Diffstat (limited to 'third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc')
-rw-r--r--  third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc | 60
1 file changed, 29 insertions(+), 31 deletions(-)
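
The change replaces the five mlir::arith using-declarations with a single namespace alias (ma) and reworks the bounds check in EmitEntryFunction: the loop is narrowed from the full operand rank to the dimensions actually covered by the index vector (scatter_indices->shape().dimensions(1)), and unsigned index types get a dedicated conversion path. Unsigned values are reinterpreted as signless integers with an unrealized_conversion_cast and widened via arith.index_castui; since they cannot be negative, the >= 0 lower-bound check is emitted only on the signed path.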
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
index c78a6b4f057..ded3ca4b6e4 100644
--- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
@@ -18,11 +18,9 @@ limitations under the License.
#include <optional>
#include <vector>
-#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
@@ -31,10 +29,12 @@ limitations under the License.
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
+#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/ValueRange.h" // from @llvm-project
+#include "mlir/Support/LLVM.h" // from @llvm-project
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
@@ -52,16 +52,13 @@ namespace xla {
namespace gpu {
namespace {
+namespace ma = mlir::arith;
+
using llvm::SmallVector;
using mlir::Location;
using mlir::OpBuilder;
using mlir::Value;
using mlir::ValueRange;
-using mlir::arith::AddIOp;
-using mlir::arith::AndIOp;
-using mlir::arith::CmpIOp;
-using mlir::arith::CmpIPredicate;
-using mlir::arith::ConstantIndexOp;
using mlir::func::ReturnOp;
using mlir::tensor::InsertOp;
using mlir_converter::ApplyAffineMap;
@@ -208,7 +205,6 @@ absl::Status MlirScatterFusion::EmitEntryFunction(
b.setInsertionPointToStart(entry_function.addEntryBlock());
SmallVector<Value> result_tensors{entry_function.getArguments().back()};
- auto c0 = b.create<ConstantIndexOp>(0);
auto scatter_result = EmitThreadLoopNest(
b, result_tensors, thread_id_to_update_map,
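Note that the hoisted index constant c0 removed above is now created only inside the signed-index branch of the next hunk, the one place the >= 0 check still needs it.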
@@ -224,39 +220,41 @@ absl::Status MlirScatterFusion::EmitEntryFunction(
// Extract slice offsets from scatter_indices operand, compute if the
// whole slice of scatter_update operand will fit into the output.
- mlir::Value is_in_bounds =
- b.create<mlir::arith::ConstantIntOp>(1, b.getI1Type());
+ mlir::Value in_bounds = b.create<ma::ConstantIntOp>(1, b.getI1Type());
SmallVector<Value, 4> indices{
llvm::ArrayRef(update_tensor_indices).drop_front()};
- for (int i = 0; i < scatter_operand->shape().rank(); ++i) {
- Value extracted_index = c0;
- if (i < scatter_indices->shape().dimensions(1)) {
- SmallVector<Value, 4> indices_tensor_indices = {
- update_tensor_indices.front(), b.create<ConstantIndexOp>(i)};
- extracted_index = ProvideParameter(
- root_computation, scatter, kScatterIndicesIndex,
- indices_tensor_indices, call_targets, entry_function, b);
- if (extracted_index.getType() != b.getIndexType()) {
- extracted_index = b.create<mlir::arith::IndexCastOp>(
- b.getIndexType(), extracted_index);
- }
+ for (int i = 0; i < scatter_indices->shape().dimensions(1); ++i) {
+ SmallVector<Value, 4> indices_tensor_indices = {
+ update_tensor_indices.front(), b.create<ma::ConstantIndexOp>(i)};
+ auto index = ProvideParameter(
+ root_computation, scatter, kScatterIndicesIndex,
+ indices_tensor_indices, call_targets, entry_function, b);
+ auto index_ty = mlir::cast<mlir::IntegerType>(index.getType());
+ if (index_ty.isUnsigned()) {
+ auto int_ty = b.getIntegerType(index_ty.getWidth());
+ index = b.create<mlir::UnrealizedConversionCastOp>(int_ty, index)
+ .getResult(0);
+ index = b.create<ma::IndexCastUIOp>(b.getIndexType(), index);
+ } else {
+ index = b.create<ma::IndexCastOp>(b.getIndexType(), index);
+ auto c0 = b.create<ma::ConstantIndexOp>(0);
+ in_bounds = b.create<ma::AndIOp>(
+ in_bounds,
+ b.create<ma::CmpIOp>(ma::CmpIPredicate::sge, index, c0));
}
- is_in_bounds = b.create<AndIOp>(
- is_in_bounds,
- b.create<CmpIOp>(CmpIPredicate::sge, extracted_index, c0));
- Value ub = b.create<ConstantIndexOp>(
+ Value ub = b.create<ma::ConstantIndexOp>(
scatter_operand->shape().dimensions(i) -
scatter_update->shape().dimensions(i + 1));
- is_in_bounds = b.create<AndIOp>(
- is_in_bounds,
- b.create<CmpIOp>(CmpIPredicate::sle, extracted_index, ub));
- indices[i] = b.create<AddIOp>(extracted_index, indices[i]);
+ in_bounds = b.create<ma::AndIOp>(
+ in_bounds,
+ b.create<ma::CmpIOp>(ma::CmpIPredicate::sle, index, ub));
+ indices[i] = b.create<ma::AddIOp>(index, indices[i]);
}
// Call scatter's computation if is_in_bounds.
Value output_tensor = output_tensors.front();
Value predicated_update =
b.create<scf::IfOp>(
- is_in_bounds,
+ in_bounds,
[&](OpBuilder& then_builder, Location then_loc) -> void {
Value updated_output = EmitScatterComputation(
scatter, indices, update_elem, output_tensor,
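For reference, the unsigned-index handling added in the last hunk can be read as a small standalone helper. This is a sketch only, assuming the MLIR APIs already used in the patch: the CastToIndex name is hypothetical, and the patch inlines this logic directly in the loop body rather than factoring it out.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

namespace ma = mlir::arith;

// Hypothetical helper (not part of the patch): converts an integer SSA value
// to the index type. Unsigned integer types cannot be fed to arith.index_cast
// directly, so the value is first reinterpreted as a signless integer of the
// same width via unrealized_conversion_cast, then zero-extended with
// arith.index_castui. Signless values take the sign-extending
// arith.index_cast instead.
mlir::Value CastToIndex(mlir::ImplicitLocOpBuilder& b, mlir::Value index) {
  auto index_ty = mlir::cast<mlir::IntegerType>(index.getType());
  if (index_ty.isUnsigned()) {
    auto int_ty = b.getIntegerType(index_ty.getWidth());
    index = b.create<mlir::UnrealizedConversionCastOp>(int_ty, index)
                .getResult(0);
    return b.create<ma::IndexCastUIOp>(b.getIndexType(), index);
  }
  return b.create<ma::IndexCastOp>(b.getIndexType(), index);
}

The detour through unrealized_conversion_cast is needed because the arith dialect only operates on signless integers, while the value produced for an unsigned HLO index type carries an explicitly unsigned MLIR integer type.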