diff options
author | Peter Hawkins <phawkins@google.com> | 2022-08-22 08:15:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2022-08-22 08:18:43 -0700 |
commit | 971538e7eef7ba1b4f11b3549693f70ec8c2da27 (patch) | |
tree | 2318a65526d70929ac52782d7fd65c9b62691cfa | |
parent | e3233d514a89fbb70a69f696ef8770ccb2469d88 (diff) | |
download | tensorflow-971538e7eef7ba1b4f11b3549693f70ec8c2da27.tar.gz |
Break some dependencies of tensorflow/compiler/xla/mlir:mlir_hlo_to_hlo on TensorFlow, since MHLO->HLO conversion is used by non-TensorFlow users.
* Change the type signature of ConvertMlirHloToHlo() to avoid TensorFlow types in its shape representation function. Add XLA-typed variants of these instead. Clone the relevant parts of tf2xla/layout_util that this conversion depends on and change them to work with XLA types instead of TensorFlow types.
* Move a utility function in type_to_shape into its only caller in TensorFlow.
PiperOrigin-RevId: 469189741
-rw-r--r-- | tensorflow/compiler/mlir/tensorflow/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc | 54 | ||||
-rw-r--r-- | tensorflow/compiler/mlir/xla/BUILD | 23 | ||||
-rw-r--r-- | tensorflow/compiler/mlir/xla/layout_util.cc | 111 | ||||
-rw-r--r-- | tensorflow/compiler/mlir/xla/layout_util.h | 84 | ||||
-rw-r--r-- | tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc | 74 | ||||
-rw-r--r-- | tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h | 18 | ||||
-rw-r--r-- | tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/mlir/xla/type_to_shape.cc | 20 | ||||
-rw-r--r-- | tensorflow/compiler/mlir/xla/type_to_shape.h | 13 | ||||
-rw-r--r-- | tensorflow/compiler/mlir/xla/type_to_shape_test.cc | 39 | ||||
-rw-r--r-- | tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_helpers.h | 23 | ||||
-rw-r--r-- | tensorflow/compiler/xla/pjrt/mlir_to_hlo.cc | 5 |
15 files changed, 299 insertions, 175 deletions
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index e030a83d6a0..8ea5f5383c0 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -2101,6 +2101,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/xla/mlir_hlo", "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", "//tensorflow/compiler/xla/mlir_hlo:sink_constants_to_control_flow", + "//tensorflow/compiler/mlir/xla:layout_util", "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", "//tensorflow/compiler/mlir/xla:tf_xla_passes", "//tensorflow/compiler/mlir/xla:xla_passes", @@ -2116,6 +2117,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core/platform:errors", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime:core_cpu_internal", diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 6c9e9308f79..5971bc2cde6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -51,12 +51,14 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/xla/layout_util.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/transforms/adjust_layout.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" @@ -66,6 +68,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/error_payloads.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/protobuf/core_platform_payloads.pb.h" #include "tensorflow/core/tpu/tpu_defs.h" @@ -185,12 +188,17 @@ Status GetOutputInfo( xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs, std::vector<XlaResourceUpdate>* resource_updates) { auto shape_representation_fn_no_fast_memory = - [shape_determination_fns](const TensorShape& shape, DataType dtype) { - auto layout_preference = shape_determination_fns.layout_preference_fn( - shape, dtype, std::nullopt); - return shape_determination_fns.shape_representation_fn( - shape, dtype, /*use_fast_memory=*/false, layout_preference); - }; + [shape_determination_fns]( + const xla::Shape& xla_shape) -> StatusOr<xla::Shape> { + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + xla_shape.element_type())); + auto layout_preference = shape_determination_fns.layout_preference_fn( + shape, dtype, std::nullopt); + return shape_determination_fns.shape_representation_fn( + shape, dtype, /*use_fast_memory=*/false, layout_preference); + }; mlir::func::FuncOp main_func = module.lookupSymbol<mlir::func::FuncOp>("main"); @@ -233,9 +241,12 @@ Status GetOutputInfo( } } } - TF_ASSIGN_OR_RETURN( - xla::Shape shape, - xla::TypeToShape(buffer_ty, shape_representation_fn_no_fast_memory)); + + xla::Shape shape = xla::TypeToShape(buffer_ty); + if (shape.element_type() == xla::PRIMITIVE_TYPE_INVALID) { + return errors::InvalidArgument("XLA conversion failed for MLIR type."); + } + TF_ASSIGN_OR_RETURN(shape, shape_representation_fn_no_fast_memory(shape)); if (!result_ty.hasStaticShape()) { int64_t rank = result_ty.getRank(); @@ -512,10 +523,29 @@ Status ConvertMLIRToXlaComputation( TF_RETURN_IF_ERROR(LegalizeToHlo(module_op, device_type, prefer_tf2xla, custom_legalization_passes)); + mlir::MlirToHloConversionOptions options; + options.layout_preference_fn = + [&](const xla::Shape& xla_shape) -> StatusOr<mlir::XlaLayoutPreference> { + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + xla_shape.element_type())); + return shape_determination_fns.layout_preference_fn(shape, dtype, + std::nullopt); + }; + options.shape_representation_fn = + [&](const xla::Shape& xla_shape, bool fast_mem, + mlir::XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> { + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + xla_shape.element_type())); + return shape_determination_fns.shape_representation_fn( + shape, dtype, fast_mem, layout_preference); + }; xla::HloProto hlo_proto; - TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto, - use_tuple_args, return_tuple, - shape_determination_fns)); + TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo( + module_op, &hlo_proto, use_tuple_args, return_tuple, options)); *xla_computation = xla::XlaComputation(hlo_proto.hlo_module()); return OkStatus(); } diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 59f2463fc73..182df21da24 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -465,8 +465,6 @@ cc_library( srcs = ["type_to_shape.cc"], hdrs = ["type_to_shape.h"], deps = [ - "//tensorflow/compiler/mlir/tensorflow:convert_tensor", - "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -497,6 +495,22 @@ tf_cc_test( ) cc_library( + name = "layout_util", + srcs = ["layout_util.cc"], + hdrs = ["layout_util.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( name = "mlir_hlo_to_hlo", srcs = [ "mlir_hlo_to_hlo.cc", @@ -505,14 +519,11 @@ cc_library( hdrs = ["mlir_hlo_to_hlo.h"], deps = [ ":attribute_exporter", + ":layout_util", ":type_to_shape", ":xla_passes", "//tensorflow/compiler/mlir:name_utils", - "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:layout_util", - "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/mlir/xla/layout_util.cc b/tensorflow/compiler/mlir/xla/layout_util.cc new file mode 100644 index 00000000000..f240ac19826 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/layout_util.cc @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/layout_util.h" + +namespace mlir { + +// Rewrites the layout of xla_shape if there is tiled sharding. +xla::Status RewriteLayoutWithShardedShape( + const std::optional<xla::HloSharding>& sharding, bool use_fast_memory, + const LayoutPreferenceFn& layout_preference_fn, + const ShapeRepresentationFn& shape_representation_fn, + xla::Shape* xla_shape) { + if (sharding && !sharding->IsTileMaximal() && !sharding->IsManual()) { + // After sharding, per core shape might have different layout. For example, + // before sharding, a shape [128, 128] will be assigned default + // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2, + // the sharded shapes will have minor-to-major {0, 1}. + // + // As a result, for sharded shapes, we set their layout to per core shape's + // layout. + // + // TODO(endlessroad): for variable input & update, we might have + // different layouts which will prevent input output aliasing and + // increase memory usage. Investigate such cases. + int64_t device = *sharding->tile_assignment().begin(); + std::vector<int64_t> offset = + sharding->TileOffsetForDevice(*xla_shape, device); + std::vector<int64_t> limit = + sharding->TileLimitForDevice(*xla_shape, device); + std::vector<int64_t> dimensions(xla_shape->rank()); + for (int64_t i = 0; i < xla_shape->rank(); ++i) { + dimensions[i] = limit[i] - offset[i]; + } + xla::Shape per_device_xla_shape = + xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); + TF_ASSIGN_OR_RETURN(auto layout_preference, + layout_preference_fn + ? layout_preference_fn(per_device_xla_shape) + : XlaLayoutPreference::kNoPreference); + TF_ASSIGN_OR_RETURN( + per_device_xla_shape, + shape_representation_fn + ? shape_representation_fn(per_device_xla_shape, use_fast_memory, + layout_preference) + : per_device_xla_shape); + *xla_shape->mutable_layout() = per_device_xla_shape.layout(); + } + return xla::Status::OK(); +} + +// There is a shape_representation_fn or sharding for an output, this function +// uses a reshape to fix the layout. +xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + const LayoutPreferenceFn& layout_preference_fn, + const ShapeRepresentationFn& shape_representation_fn, + std::optional<xla::OpSharding> sharding, bool fast_mem) { + if (original_shape.IsTuple()) { + std::vector<xla::XlaOp> elements; + for (int i = 0; i < original_shape.tuple_shapes_size(); ++i) { + auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding; + TF_ASSIGN_OR_RETURN( + auto element, + ReshapeWithCorrectRepresentationAndSharding( + builder, xla::GetTupleElement(original, i), + original_shape.tuple_shapes(i), layout_preference_fn, + shape_representation_fn, subsharding, fast_mem)); + elements.push_back(element); + } + return xla::Tuple(builder, elements); + } + if (!original_shape.IsArray()) return original; + TF_ASSIGN_OR_RETURN(auto layout_preference, + layout_preference_fn + ? layout_preference_fn(original_shape) + : XlaLayoutPreference::kNoPreference); + TF_ASSIGN_OR_RETURN( + auto to_shape, + shape_representation_fn + ? shape_representation_fn(original_shape, fast_mem, layout_preference) + : original_shape); + if (sharding) { + TF_ASSIGN_OR_RETURN(auto hlo_sharding, + xla::HloSharding::FromProto(*sharding)); + + TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( + hlo_sharding, fast_mem, layout_preference_fn, shape_representation_fn, + &to_shape)); + } + if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { + for (int64_t i = 0; i < original_shape.rank(); ++i) { + to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); + } + } + return xla::Reshape(to_shape, original); +} + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/layout_util.h b/tensorflow/compiler/mlir/xla/layout_util.h new file mode 100644 index 00000000000..8828074d40a --- /dev/null +++ b/tensorflow/compiler/mlir/xla/layout_util.h @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utilities for working with XLA layout and shapes. + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_LAYOUT_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_LAYOUT_UTIL_H_ + +#include <functional> +#include <vector> + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace mlir { + +// XLA Layout preferences. Currently, when it comes to TPU, there are two +// primary layout choices for any XLA argumetns (parameter or resource): (1) +// CompactChunkPadded and (2) Linear. CompactChunkPadded is the native TPU +// layout while Linear is native host (CPU) layout. +// This enum allows the caller of XLA to progogate layout preference to the XLA +// compiler. +// kNoPreference: the generic layout where the XLA compiler has the freedom +// to assign any layout. +// kTpuPreferCompactChunkPaddedLayout: use native TPU layout on TPU. +// kTpuPreferLinearLayout: use native CPU layout on TPU. The compiler may +// insert transformation TPU kernels. +// As the layout of any argument will change from a native host layout to a +// native TPU layout either on host or on device, XLA compiler and TPU runtime +// must be in coordination to transform the parameters in a consistent way. +enum class XlaLayoutPreference { + kNoPreference = 0, + kTpuPreferCompactChunkPaddedLayout = 1, + kTpuPreferLinearLayout = 2 +}; + +// The following defines the layout preference of an xla tensor. +// The return value of LayoutPreferenceFn can be used in +// ShapeRepresentationFn. +typedef std::function<xla::StatusOr<XlaLayoutPreference>( + const xla::Shape& shape)> + LayoutPreferenceFn; + +typedef std::function<xla::StatusOr<xla::Shape>( + const xla::Shape& shape, bool fast_mem, + XlaLayoutPreference layout_preference)> + ShapeRepresentationFn; + +// Return a LayoutPreferenceFn that always uses kNoPreference layout. +LayoutPreferenceFn UseNoPreferenceLayoutFn(); + +// Rewrites the layout of xla_shape if there is tiled sharding. +xla::Status RewriteLayoutWithShardedShape( + const std::optional<xla::HloSharding>& sharding, bool use_fast_memory, + const LayoutPreferenceFn& layout_preference_fn, + const ShapeRepresentationFn& shape_representation_fn, + xla::Shape* xla_shape); + +// Adds reshapes to fix the layout of an output, if a shape_representation_fn or +// sharding is present. +xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + const LayoutPreferenceFn& layout_preference_fn, + const ShapeRepresentationFn& shape_representation_fn, + std::optional<xla::OpSharding> sharding, bool fast_mem); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 6aad9aa3283..26aa184e91e 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -47,13 +47,10 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/utils/name_utils.h" #include "tensorflow/compiler/mlir/xla/attribute_exporter.h" #include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" -#include "tensorflow/compiler/tf2xla/layout_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/quantize.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" @@ -569,17 +566,14 @@ class ConvertToHloModule { // are converted to a tuple even when there is only a single return value. // Multiple return values are always converted to a tuple and returned as a // single value. - explicit ConvertToHloModule( - mlir::ModuleOp module, xla::XlaBuilder& module_builder, - bool use_tuple_args, bool return_tuple, - tensorflow::XlaShapeLayoutHelpers::ShapeDeterminationFns - shape_determination_fns, - MlirToHloConversionOptions options) + explicit ConvertToHloModule(mlir::ModuleOp module, + xla::XlaBuilder& module_builder, + bool use_tuple_args, bool return_tuple, + MlirToHloConversionOptions options) : module_(module), module_builder_(module_builder), use_tuple_args_(use_tuple_args), return_tuple_(return_tuple), - shape_determination_fns_(shape_determination_fns), options_(options) {} // Perform the lowering to XLA. This function returns failure if an error was @@ -679,11 +673,6 @@ class ConvertToHloModule { // Whether to always return a tuple. bool return_tuple_; - // Shape determination functions to determine entry function argument and - // result shapes. - tensorflow::XlaShapeLayoutHelpers::ShapeDeterminationFns - shape_determination_fns_; - // Unique suffix to give to the name of the next lowered region. size_t region_id_ = 0; @@ -2187,8 +2176,9 @@ LogicalResult ConvertToHloModule::Lower( xla::Shape return_shape = xla::TypeToShape(ret.get().getType()); StatusOr<xla::XlaOp> reshape = - tensorflow::ReshapeWithCorrectRepresentationAndSharding( - builder, returns[index], return_shape, shape_determination_fns_, + ReshapeWithCorrectRepresentationAndSharding( + builder, returns[index], return_shape, + options_.layout_preference_fn, options_.shape_representation_fn, ret_shardings[index], /*fast_mem=*/false); if (!reshape.ok()) return inst->emitError() << reshape.status().error_message(); @@ -2336,24 +2326,18 @@ LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication( for (BlockArgument& arg : block->getArguments()) { arg_shapes->push_back(xla::TypeToShape(arg.getType())); xla::Shape& arg_shape = arg_shapes->back(); - tensorflow::TensorShape arg_tensor_shape; - auto status = - tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape); - if (!status.ok()) - return block->getParentOp()->emitError() << status.error_message(); - - tensorflow::DataType arg_dtype; - status = tensorflow::ConvertToDataType(arg.getType(), &arg_dtype); - if (!status.ok()) - return block->getParentOp()->emitError() << status.error_message(); - - CHECK(shape_determination_fns_.layout_preference_fn && // Crash OK - shape_determination_fns_.shape_representation_fn); - auto layout_preference = shape_determination_fns_.layout_preference_fn( - arg_tensor_shape, arg_dtype, std::nullopt); - auto arg_shape_status = shape_determination_fns_.shape_representation_fn( - arg_tensor_shape, arg_dtype, /*use_fast_memory=*/false, - layout_preference); + auto layout_preference_status = + options_.layout_preference_fn ? options_.layout_preference_fn(arg_shape) + : XlaLayoutPreference::kNoPreference; + if (!layout_preference_status.ok()) + return block->getParentOp()->emitError() + << layout_preference_status.status().error_message(); + + auto arg_shape_status = options_.shape_representation_fn + ? options_.shape_representation_fn( + arg_shape, /*use_fast_memory=*/false, + layout_preference_status.ValueOrDie()) + : arg_shape; if (!arg_shape_status.ok()) return block->getParentOp()->emitError() << arg_shape_status.status().error_message(); @@ -2382,9 +2366,10 @@ LogicalResult ConvertToHloModule::SetEntryTupleShardings( return block->getParentOp()->emitError() << hlo_sharding.status().error_message(); - auto status = tensorflow::RewriteLayoutWithShardedShape( + auto status = RewriteLayoutWithShardedShape( hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false, - shape_determination_fns_, &(*arg_shapes)[arg_sharding.index()]); + options_.layout_preference_fn, options_.shape_representation_fn, + &(*arg_shapes)[arg_sharding.index()]); if (!status.ok()) return block->getParentOp()->emitError() << status.error_message(); @@ -2568,24 +2553,21 @@ Status ConvertRegionToComputation(mlir::Region* region, MlirToHloConversionOptions options) { mlir::ModuleOp module; xla::XlaBuilder module_builder("main"); - ConvertToHloModule converter(module, module_builder, true, true, {}, options); + ConvertToHloModule converter(module, module_builder, true, true, options); if (failed(converter.LowerRegionAsComputation(region, func))) return tensorflow::errors::Internal( "failed to convert region to computation"); return ::tensorflow::OkStatus(); } -Status ConvertMlirHloToHlo( - mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args, - bool return_tuple, - const tensorflow::XlaShapeLayoutHelpers::ShapeDeterminationFns - shape_determination_fns, - MlirToHloConversionOptions options) { +Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, + bool use_tuple_args, bool return_tuple, + MlirToHloConversionOptions options) { TF_RETURN_IF_ERROR(PrepareForExport(module)); mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); xla::XlaBuilder module_builder("main"); ConvertToHloModule converter(module, module_builder, use_tuple_args, - return_tuple, shape_determination_fns, options); + return_tuple, options); if (failed(converter.Run())) return diag_handler.ConsumeStatus(); auto hlo_module = converter.ConsumeMainProto(); StringRef module_name = module.getName() ? *module.getName() : "main"; @@ -2602,7 +2584,7 @@ Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, TF_RETURN_IF_ERROR(PrepareForExport(module)); ConvertToHloModule converter(module, builder, /*use_tuple_args=*/false, /*return_tuple=*/false, - /*shape_determination_fns=*/{}, options); + options); ConvertToHloModule::ValueLoweringMap lowering; // xla_params should only include non-constant parameters the block arguments diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index c781a739d1c..4a29a84ccf2 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -19,11 +19,9 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/compiler/tf2xla/layout_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/mlir/xla/layout_util.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/core/framework/tensor_shape.h" namespace mlir { @@ -45,6 +43,9 @@ struct MlirToHloConversionOptions { // Legalize names to be compatible with TensorFlow. bool legalize_node_names = true; + + LayoutPreferenceFn layout_preference_fn; + ShapeRepresentationFn shape_representation_fn; }; // Converts a MLIR module in HLO dialect into a HloModuleProto. If @@ -54,14 +55,9 @@ struct MlirToHloConversionOptions { // are converted to a tuple even when there is only a single return value. // Multiple return values are always converted to a tuple and returned as a // single value. -// -// TODO(timshen): move other options into `options`. -Status ConvertMlirHloToHlo( - mlir::ModuleOp module, ::xla::HloProto* hlo_proto, bool use_tuple_args, - bool return_tuple, - const tensorflow::XlaShapeLayoutHelpers::ShapeDeterminationFns - shape_determination_fns = {}, - MlirToHloConversionOptions options = {}); +Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto, + bool use_tuple_args, bool return_tuple, + MlirToHloConversionOptions options = {}); // Transforms a Block into HLO, where the HLO is represented as calls into an // XlaBuilder. Callee functions are allowed in the Block's ancestor ModuleOp. diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index b8f53f30517..8423f354e62 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -202,8 +202,7 @@ class XlaHloToLhloPass TF_RETURN_WITH_CONTEXT_IF_ERROR( ConvertMlirHloToHlo(module, &hlo_proto, /*use_tuple_args=*/false, - /*return_tuple=*/false, - /*shape_determination_fns=*/{}), + /*return_tuple=*/false), "conversion to XLA HLO proto failed"); auto statusOrHloModule = HloModuleFromProto(hlo_proto); diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index 834cf94b849..45eef142156 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -23,13 +23,10 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -83,23 +80,6 @@ PrimitiveType TypeToPrimitiveType(mlir::Type type) { return PrimitiveType::PRIMITIVE_TYPE_INVALID; } -StatusOr<Shape> TypeToShape( - mlir::Type type, CustomShapeRepresentationFn shape_representation_fn) { - tensorflow::PartialTensorShape partial_tensor_shape = - tensorflow::ConvertTypeToTensorShape(type); - - tensorflow::TensorShape fully_defined_tensor_shape; - if (!partial_tensor_shape.AsTensorShape(&fully_defined_tensor_shape)) { - return tensorflow::errors::InvalidArgument( - "XLA HLO only allows fully-defined shape"); - } - - tensorflow::DataType dtype; - TF_RETURN_IF_ERROR(tensorflow::ConvertToDataType(type, &dtype)); - - return shape_representation_fn(fully_defined_tensor_shape, dtype); -} - Shape TypeToShape(mlir::Type type) { PrimitiveType ptype = TypeToPrimitiveType(type); if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID) diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.h b/tensorflow/compiler/mlir/xla/type_to_shape.h index 647fb56bb26..59482731254 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.h +++ b/tensorflow/compiler/mlir/xla/type_to_shape.h @@ -20,25 +20,12 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/tensor_shape.h" namespace xla { // Returns a XLA Shape equivalent of a MLIR Type, else returns empty shape. Shape TypeToShape(mlir::Type type); -// Type of a custom function that converts a TensorFlow type and shape into an -// XLA shape with optional layout info. -typedef llvm::function_ref<xla::StatusOr<xla::Shape>( - const tensorflow::TensorShape&, tensorflow::DataType)> - CustomShapeRepresentationFn; - -// Compute an XLA shape based in given MLIR type and an -// CustomShapeRepresentationFn, which allows setting custom layout in returned -// XLA shape. -StatusOr<Shape> TypeToShape( - mlir::Type type, CustomShapeRepresentationFn shape_representation_fn); - // Returns a XLA PrimitiveType equivalent of a MLIR Type that represents a // primitive type (e.g., i8, f32), else returns PRIMITIVE_TYPE_INVALID. PrimitiveType TypeToPrimitiveType(mlir::Type type); diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc index 9f70fe0e268..0c9af79e2de 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc @@ -151,45 +151,6 @@ TEST(TypeToShapeTest, ConvertTensorTypeToTypes) { EqualsProto(Shape().ToProto())); } -TEST(TypeToShapeTest, ConvertWithShapeRepresentationFn) { - tensorflow::DataType captured_dtype; - tensorflow::TensorShape captured_tensor_shape; - - // A dummy shape representation function that does nothing other than - // capturing arguments passed to it. - auto test_shape_representation_fn = [&](const tensorflow::TensorShape& shape, - tensorflow::DataType dtype) { - captured_tensor_shape = shape; - captured_dtype = dtype; - return xla::Shape(); - }; - - MLIRContext context; - Builder b(&context); - StatusOr<Shape> status_or_shape; - - // Non-fully-defined shape. - status_or_shape = - TypeToShape(RankedTensorType::get({-1, 2, 3}, b.getF32Type()), - test_shape_representation_fn); - EXPECT_TRUE(tensorflow::errors::IsInvalidArgument(status_or_shape.status())); - - // Scalar Int32 Tensor, using fast memory. - status_or_shape = - TypeToShape(b.getIntegerType(32), test_shape_representation_fn); - EXPECT_TRUE(status_or_shape.ok()); - EXPECT_EQ(captured_dtype, tensorflow::DataType::DT_INT32); - EXPECT_EQ(captured_tensor_shape, tensorflow::TensorShape()); - - // Ranked Float32 Tensor, not using fast memory. - status_or_shape = - TypeToShape(RankedTensorType::get({1, 2, 3}, b.getF32Type()), - test_shape_representation_fn); - EXPECT_TRUE(status_or_shape.ok()); - EXPECT_EQ(captured_dtype, tensorflow::DataType::DT_FLOAT); - EXPECT_EQ(captured_tensor_shape, tensorflow::TensorShape({1, 2, 3})); -} - TEST(TypeToShapeTest, ConvertMemRefToShape) { Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::F32, {10, 20, 30}, {2, 0, 1}); diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc index d38ea2c390c..b534cd3179b 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc @@ -149,8 +149,7 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction( via_builder ? ConvertMlirHloToHloViaBuilder(module, &hloProto, options) : mlir::ConvertMlirHloToHlo(module, &hloProto, emit_use_tuple_arg, - emit_return_tuple, - /*shape_determination_fns=*/{}, options); + emit_return_tuple, options); if (!status.ok()) { LOG(ERROR) << "Module conversion failed: " << status; return mlir::failure(); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5c9a4e3c592..8e50f2c3af4 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -436,6 +436,7 @@ cc_library( "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", "//tensorflow/compiler/mlir:array_container_utils", + "//tensorflow/compiler/mlir/xla:layout_util", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/xla/client:value_inference", "//tensorflow/compiler/xla/service:computation_placer_hdr", @@ -610,6 +611,7 @@ cc_library( deps = [ ":common", ":host_compute_metadata_proto_cc", + "//tensorflow/compiler/mlir/xla:layout_util", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 847ca0bac7e..400dd96b8dd 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/types/optional.h" #include "absl/types/span.h" +#include "tensorflow/compiler/mlir/xla/layout_util.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" @@ -30,25 +31,7 @@ limitations under the License. namespace tensorflow { -// XLA Layout preferences. Currently, when it comes to TPU, there are two -// primary layout choices for any XLA argumetns (parameter or resource): (1) -// CompactChunkPadded and (2) Linear. CompactChunkPadded is the native TPU -// layout while Linear is native host (CPU) layout. -// This enum allows the caller of XLA to progogate layout preference to the XLA -// compiler. -// kNoPreference: the generic layout where the XLA compiler has the freedom -// to assign any layout. -// kTpuPreferCompactChunkPaddedLayout: use native TPU layout on TPU. -// kTpuPreferLinearLayout: use native CPU layout on TPU. The compiler may -// insert transformation TPU kernels. -// As the layout of any argument will change from a native host layout to a -// native TPU layout either on host or on device, XLA compiler and TPU runtime -// must be in coordination to transform the parameters in a consistent way. -enum class XlaLayoutPreference { - kNoPreference = 0, - kTpuPreferCompactChunkPaddedLayout = 1, - kTpuPreferLinearLayout = 2 -}; +using XlaLayoutPreference = mlir::XlaLayoutPreference; // Helper methods for building XLA computations. class XlaHelpers { @@ -106,8 +89,6 @@ class XlaHelpers { // Creates an identity shape representation function. XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn(); - - struct XlaOutputDescription { // Type and shape of the output. The shape is the unflattened shape. // When `type` is DT_RESOURCE, `shape` is the shape of the resource diff --git a/tensorflow/compiler/xla/pjrt/mlir_to_hlo.cc b/tensorflow/compiler/xla/pjrt/mlir_to_hlo.cc index bd7b19be405..35aa5b83a0e 100644 --- a/tensorflow/compiler/xla/pjrt/mlir_to_hlo.cc +++ b/tensorflow/compiler/xla/pjrt/mlir_to_hlo.cc @@ -61,9 +61,8 @@ Status MlirToXlaComputation(mlir::ModuleOp module, mlir::MlirToHloConversionOptions options; // We don't want the conversion to muck with our operator names. options.legalize_node_names = false; - TF_RETURN_IF_ERROR( - ConvertMlirHloToHlo(module, &proto, use_tuple_args, return_tuple, - /*shape_determination_fns=*/{}, options)); + TF_RETURN_IF_ERROR(ConvertMlirHloToHlo(module, &proto, use_tuple_args, + return_tuple, options)); xla_computation = XlaComputation(std::move(*proto.mutable_hlo_module())); return OkStatus(); |