diff options
Diffstat (limited to 'third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc')
-rw-r--r-- | third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc | 60 |
1 files changed, 29 insertions, 31 deletions
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc index c78a6b4f057..ded3ca4b6e4 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc @@ -18,11 +18,9 @@ limitations under the License. #include <optional> #include <vector> -#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -31,10 +29,12 @@ limitations under the License. #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -52,16 +52,13 @@ namespace xla { namespace gpu { namespace { +namespace ma = mlir::arith; + using llvm::SmallVector; using mlir::Location; using mlir::OpBuilder; using mlir::Value; using mlir::ValueRange; -using mlir::arith::AddIOp; -using mlir::arith::AndIOp; -using mlir::arith::CmpIOp; -using mlir::arith::CmpIPredicate; -using mlir::arith::ConstantIndexOp; using mlir::func::ReturnOp; using mlir::tensor::InsertOp; using mlir_converter::ApplyAffineMap; @@ -208,7 +205,6 @@ absl::Status MlirScatterFusion::EmitEntryFunction( b.setInsertionPointToStart(entry_function.addEntryBlock()); SmallVector<Value> result_tensors{entry_function.getArguments().back()}; - auto c0 = b.create<ConstantIndexOp>(0); auto scatter_result = EmitThreadLoopNest( b, result_tensors, thread_id_to_update_map, @@ -224,39 +220,41 @@ absl::Status MlirScatterFusion::EmitEntryFunction( // Extract slice offsets from scatter_indices operand, compute if the // whole slice of scatter_update operand will fit into the output. - mlir::Value is_in_bounds = - b.create<mlir::arith::ConstantIntOp>(1, b.getI1Type()); + mlir::Value in_bounds = b.create<ma::ConstantIntOp>(1, b.getI1Type()); SmallVector<Value, 4> indices{ llvm::ArrayRef(update_tensor_indices).drop_front()}; - for (int i = 0; i < scatter_operand->shape().rank(); ++i) { - Value extracted_index = c0; - if (i < scatter_indices->shape().dimensions(1)) { - SmallVector<Value, 4> indices_tensor_indices = { - update_tensor_indices.front(), b.create<ConstantIndexOp>(i)}; - extracted_index = ProvideParameter( - root_computation, scatter, kScatterIndicesIndex, - indices_tensor_indices, call_targets, entry_function, b); - if (extracted_index.getType() != b.getIndexType()) { - extracted_index = b.create<mlir::arith::IndexCastOp>( - b.getIndexType(), extracted_index); - } + for (int i = 0; i < scatter_indices->shape().dimensions(1); ++i) { + SmallVector<Value, 4> indices_tensor_indices = { + update_tensor_indices.front(), b.create<ma::ConstantIndexOp>(i)}; + auto index = ProvideParameter( + root_computation, scatter, kScatterIndicesIndex, + indices_tensor_indices, call_targets, entry_function, b); + auto index_ty = mlir::cast<mlir::IntegerType>(index.getType()); + if (index_ty.isUnsigned()) { + auto int_ty = b.getIntegerType(index_ty.getWidth()); + index = b.create<mlir::UnrealizedConversionCastOp>(int_ty, index) + .getResult(0); + index = b.create<ma::IndexCastUIOp>(b.getIndexType(), index); + } else { + index = b.create<ma::IndexCastOp>(b.getIndexType(), index); + auto c0 = b.create<ma::ConstantIndexOp>(0); + in_bounds = b.create<ma::AndIOp>( + in_bounds, + b.create<ma::CmpIOp>(ma::CmpIPredicate::sge, index, c0)); } - is_in_bounds = b.create<AndIOp>( - is_in_bounds, - b.create<CmpIOp>(CmpIPredicate::sge, extracted_index, c0)); - Value ub = b.create<ConstantIndexOp>( + Value ub = b.create<ma::ConstantIndexOp>( scatter_operand->shape().dimensions(i) - scatter_update->shape().dimensions(i + 1)); - is_in_bounds = b.create<AndIOp>( - is_in_bounds, - b.create<CmpIOp>(CmpIPredicate::sle, extracted_index, ub)); - indices[i] = b.create<AddIOp>(extracted_index, indices[i]); + in_bounds = b.create<ma::AndIOp>( + in_bounds, + b.create<ma::CmpIOp>(ma::CmpIPredicate::sle, index, ub)); + indices[i] = b.create<ma::AddIOp>(index, indices[i]); } // Call scatter's computation if is_in_bounds. Value output_tensor = output_tensors.front(); Value predicated_update = b.create<scf::IfOp>( - is_in_bounds, + in_bounds, [&](OpBuilder& then_builder, Location then_loc) -> void { Value updated_output = EmitScatterComputation( scatter, indices, update_elem, output_tensor, |