aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp248
1 files changed, 202 insertions, 46 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8860674ef847..a28b90b1d95c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -84,6 +84,195 @@ static LogicalResult isContraction(Operation *op) {
hasMultiplyAddBody(genericOp.region()));
}
+static bool hasOnlyScalarElementwiseOp(Region &r) {
+ if (!llvm::hasSingleElement(r))
+ return false;
+ for (Operation &op : r.front()) {
+ if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
+ op.hasTrait<OpTrait::ElementwiseMappable>()) ||
+ llvm::any_of(op.getResultTypes(),
+ [](Type type) { return !type.isIntOrIndexOrFloat(); }))
+ return false;
+ }
+ return true;
+}
+
+// Return true if the op is an element-wise linalg op.
+static bool isElementwise(Operation *op) {
+ auto genericOp = dyn_cast<linalg::GenericOp>(op);
+ if (!genericOp)
+ return false;
+ if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
+ return false;
+ // TODO: relax the restrictions on indexing map.
+ for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
+ if (!genericOp.getOutputIndexingMap(i).isIdentity())
+ return false;
+ }
+ // Currently limit the input indexing map to minor identity as other
+ // permutations might require adding transpose ops to convert the vector read
+ // to the right shape.
+ for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
+ if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
+ return false;
+ }
+ return hasOnlyScalarElementwiseOp(genericOp.getRegion());
+}
+
+static VectorType extractVectorTypeFromScalarView(Value v) {
+ MemRefType mt = v.getType().cast<MemRefType>();
+ return mt.getShape().empty()
+ ? VectorType()
+ : VectorType::get(mt.getShape(), mt.getElementType());
+}
+
+static Value transferReadVector(OpBuilder &builder, Value memref) {
+ edsc::ScopedContext scope(builder);
+ auto memrefType = memref.getType().cast<MemRefType>();
+ if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
+ SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
+ return vector_transfer_read(vectorType, memref, indices);
+ }
+ return std_load(memref);
+}
+
+static void transferWriteVector(OpBuilder &builder, Value value, Value memref) {
+ edsc::ScopedContext scope(builder);
+ auto memrefType = memref.getType().cast<MemRefType>();
+ if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
+ SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
+ if (vectorType != value.getType())
+ value = vector_broadcast(vectorType, value);
+ vector_transfer_write(value, memref, indices);
+ } else {
+ std_store(value, memref);
+ }
+}
+
+namespace {
+// Transforms scalar operations into their vectorized counterparts,
+// while using the provided generic op to map:
+// * Its arguments to transfer reads from the views of the generic op.
+// * linalg.yield ops to transfer writes to the views of the generic op.
+class GenericVectorizer {
+public:
+ GenericVectorizer(OpBuilder &builder, linalg::GenericOp generic)
+ : builder(builder), generic(generic) {}
+
+ // Takes a scalar operation and builds its vectorized counterpart or
+ // counterparts using the underlying builder.
+ // If operands of the scalar operation are referring to previously vectorized
+ // operations, then in their vectorized form these operands will be referring
+ // to previous vectorization results.
+ void vectorize(Operation &scalarOp) {
+ auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
+ if (yieldOp) {
+ for (auto outputAndMemref :
+ llvm::zip(yieldOp.values(), generic.getOutputBuffers())) {
+ Value vectorValue = vectorize(std::get<0>(outputAndMemref));
+ transferWriteVector(builder, vectorValue, std::get<1>(outputAndMemref));
+ }
+ return;
+ }
+ Operation *vectorOp = uncachedVectorize(scalarOp);
+ assert(scalarOp.getNumResults() == vectorOp->getNumResults());
+ for (auto result :
+ llvm::zip(scalarOp.getResults(), vectorOp->getResults())) {
+ valueCache[std::get<0>(result)] = std::get<1>(result);
+ }
+ }
+
+private:
+ // Transforms a scalar value into its vectorized counterpart, recursively
+ // vectorizing operations as necessary using the underlying builder.
+ // Keeps track of previously vectorized values and reuses vectorization
+ // results if these values come up again.
+ Value vectorize(Value scalarValue) {
+ // Don't vectorize values coming from outside the region.
+ if (scalarValue.getParentRegion() != &generic.region())
+ return scalarValue;
+ auto vectorValueIt = valueCache.find(scalarValue);
+ if (vectorValueIt != valueCache.end())
+ return vectorValueIt->second;
+
+ // If the value is from the region but not in the cache it means it is a
+ // block argument.
+ auto scalarArg = scalarValue.cast<BlockArgument>();
+ assert(scalarArg.getOwner() == &generic.region().front());
+ Value vector_arg =
+ generic.getInputsAndOutputBuffers()[scalarArg.getArgNumber()];
+ Value vectorResult = transferReadVector(builder, vector_arg);
+ valueCache[scalarArg] = vectorResult;
+ return vectorResult;
+ }
+
+ // Return the largest shape of all the given values. Return an empty
+ // SmallVector if there are no vector value.
+ static SmallVector<int64_t, 4> getLargestShape(ArrayRef<Value> values) {
+ SmallVector<int64_t, 4> largestShape;
+ int64_t maxSize = 1;
+ for (Value value : values) {
+ auto vecType = value.getType().dyn_cast<VectorType>();
+ if (!vecType)
+ continue;
+ if (maxSize < vecType.getNumElements()) {
+ largestShape.assign(vecType.getShape().begin(),
+ vecType.getShape().end());
+ }
+ }
+ return largestShape;
+ }
+
+ // If the value's type doesn't have the given shape broadcast it.
+ Value broadcastIfNeeded(Value value, ArrayRef<int64_t> shape) {
+ auto vecType = value.getType().dyn_cast<VectorType>();
+ if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
+ return value;
+ auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
+ : value.getType());
+ return builder.create<vector::BroadcastOp>(
+ builder.getInsertionPoint()->getLoc(), newVecType, value);
+ }
+
+ // Takes a scalar operation and builds its vectorized counterpart or
+ // counterparts using underlying builder without involving any caches.
+ Operation *uncachedVectorize(Operation &base_scalarOp) {
+ SmallVector<Value, 4> vectorizedOperands;
+ for (Value operand : base_scalarOp.getOperands()) {
+ vectorizedOperands.push_back(vectorize(operand));
+ }
+ SmallVector<int64_t, 4> shape = getLargestShape(vectorizedOperands);
+ for (Value &operand : vectorizedOperands)
+ operand = broadcastIfNeeded(operand, shape);
+ OperationState state(base_scalarOp.getLoc(), base_scalarOp.getName());
+ state.addAttributes(base_scalarOp.getAttrs());
+ state.addOperands(vectorizedOperands);
+ if (shape.empty()) {
+ state.addTypes(base_scalarOp.getResultTypes());
+ } else {
+ SmallVector<VectorType, 4> vectorizedTypes;
+ for (auto Type : base_scalarOp.getResultTypes())
+ vectorizedTypes.push_back(VectorType::get(shape, Type));
+ state.addTypes(vectorizedTypes);
+ }
+ return builder.createOperation(state);
+ }
+
+ OpBuilder &builder;
+ linalg::GenericOp generic;
+ llvm::DenseMap<Value, Value> valueCache;
+};
+} // namespace
+
+// Replaces elementwise linalg.generic ops with their bodies with scalar
+// operations from these bodies promoted to vector operations.
+static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
+ GenericVectorizer vectorizer(builder, op);
+ for (Operation &scalarOp : op.region().front()) {
+ vectorizer.vectorize(scalarOp);
+ }
+}
+
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
@@ -96,7 +285,8 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
if (isa<linalg::FillOp, linalg::CopyOp>(op))
return success();
-
+ if (isElementwise(op))
+ return success();
return isContraction(op);
}
@@ -108,28 +298,11 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
edsc::ScopedContext scope(builder, op->getLoc());
// In the case of 0-D memrefs, return null and special case to scalar load or
// store later.
- auto extractVectorTypeFromScalarView = [](Value v) {
- MemRefType mt = v.getType().cast<MemRefType>();
- return mt.getShape().empty()
- ? VectorType()
- : VectorType::get(mt.getShape(), mt.getElementType());
- };
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
// Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
- Value viewOutput = fillOp.output();
- if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
- auto vecType =
- VectorType::get(fillOp.getOutputBufferType(0).getShape(),
- fillOp.getOutputBufferType(0).getElementType());
- Value vector = vector_broadcast(vecType, fillOp.value());
- Value zero = std_constant_index(0);
- SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
- vector_transfer_write(vector, viewOutput, indicesOutput);
- } else {
- std_store(fillOp.value(), viewOutput);
- }
+ transferWriteVector(builder, fillOp.value(), fillOp.output());
return;
}
if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
@@ -138,36 +311,19 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
<< "Rewrite linalg.copy as vector.transfer_read + "
"vector.transfer_write: "
<< *op);
- Value zero = std_constant_index(0);
- Value viewInput = copyOp.input();
- Value viewOutput = copyOp.output();
- Value vector;
- if (VectorType inputType = extractVectorTypeFromScalarView(viewInput)) {
- SmallVector<Value, 4> indicesInput(inputType.getRank(), zero);
- if (copyOp.inputPermutation())
- vector = vector_transfer_read(
- extractVectorTypeFromScalarView(viewInput), viewInput, indicesInput,
- copyOp.inputPermutation().getValue());
- else
- vector =
- vector_transfer_read(extractVectorTypeFromScalarView(viewInput),
- viewInput, indicesInput);
- } else {
- vector = std_load(viewInput).value;
- }
- if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
- SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
- if (copyOp.outputPermutation())
- vector_transfer_write(vector, viewOutput, indicesOutput,
- copyOp.outputPermutation().getValue());
- else
- vector_transfer_write(vector, viewOutput, indicesOutput);
- } else {
- std_store(vector, viewOutput);
- }
+ Value vector = transferReadVector(builder, copyOp.input());
+ transferWriteVector(builder, vector, copyOp.output());
return;
}
+ if (isElementwise(op)) {
+ LLVM_DEBUG(dbgs() << dbgPref
+ << "Rewrite linalg op as vector.transfer_read + "
+ "vector_op + vector.transfer_write: "
+ << *op);
+ return vectorizeElementwise(cast<linalg::GenericOp>(op), builder);
+ }
+
assert(succeeded(isContraction(op)) && "Expected contraction");
// Vectorize other ops as vector contraction.