author     Peter Hawkins <phawkins@google.com>              2022-08-22 08:15:09 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2022-08-22 08:18:43 -0700
commit     971538e7eef7ba1b4f11b3549693f70ec8c2da27 (patch)
tree       2318a65526d70929ac52782d7fd65c9b62691cfa
parent     e3233d514a89fbb70a69f696ef8770ccb2469d88 (diff)
download   tensorflow-971538e7eef7ba1b4f11b3549693f70ec8c2da27.tar.gz
Break some dependencies of tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo on TensorFlow, since MHLO->HLO conversion is used by non-TensorFlow users.
* Change the type signature of ConvertMlirHloToHlo() to avoid TensorFlow types in its shape representation function. Add XLA-typed variants of these instead. Clone the relevant parts of tf2xla/layout_util that this conversion depends on and change them to work with XLA types instead of TensorFlow types.
* Move a utility function in type_to_shape into its only caller in TensorFlow.

PiperOrigin-RevId: 469189741
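For context, a minimal sketch of how a caller might drive MHLO->HLO conversion through the new options-based signature after this change. The wrapper function name and the trivial callback bodies below are illustrative assumptions, not part of this patch; the MlirToHloConversionOptions fields and the ConvertMlirHloToHlo signature are taken from the diff itself.

// Hedged sketch: a non-TensorFlow caller using the XLA-typed callbacks.
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"

xla::Status ConvertForExample(mlir::ModuleOp module, xla::HloProto* hlo_proto) {
  mlir::MlirToHloConversionOptions options;
  // Both callbacks now operate purely on xla::Shape; no TensorShape/DataType.
  options.layout_preference_fn =
      [](const xla::Shape& shape) -> xla::StatusOr<mlir::XlaLayoutPreference> {
    return mlir::XlaLayoutPreference::kNoPreference;
  };
  options.shape_representation_fn =
      [](const xla::Shape& shape, bool fast_mem,
         mlir::XlaLayoutPreference preference) -> xla::StatusOr<xla::Shape> {
    return shape;  // Keep the inferred shape and layout unchanged.
  };
  return mlir::ConvertMlirHloToHlo(module, hlo_proto,
                                   /*use_tuple_args=*/false,
                                   /*return_tuple=*/false, options);
}

Callers that still depend on TensorFlow shape-determination functions (as compile_mlir_util.cc does below) can wrap them in these XLA-typed callbacks by converting the xla::Shape back to a TensorShape/DataType pair, which is exactly what this patch does in ConvertMLIRToXlaComputation.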
-rw-r--r--  tensorflow/compiler/mlir/tensorflow/BUILD | 2
-rw-r--r--  tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc | 54
-rw-r--r--  tensorflow/compiler/mlir/xla/BUILD | 23
-rw-r--r--  tensorflow/compiler/mlir/xla/layout_util.cc | 111
-rw-r--r--  tensorflow/compiler/mlir/xla/layout_util.h | 84
-rw-r--r--  tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc | 74
-rw-r--r--  tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h | 18
-rw-r--r--  tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc | 3
-rw-r--r--  tensorflow/compiler/mlir/xla/type_to_shape.cc | 20
-rw-r--r--  tensorflow/compiler/mlir/xla/type_to_shape.h | 13
-rw-r--r--  tensorflow/compiler/mlir/xla/type_to_shape_test.cc | 39
-rw-r--r--  tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc | 3
-rw-r--r--  tensorflow/compiler/tf2xla/BUILD | 2
-rw-r--r--  tensorflow/compiler/tf2xla/xla_helpers.h | 23
-rw-r--r--  tensorflow/compiler/xla/pjrt/mlir_to_hlo.cc | 5
15 files changed, 299 insertions, 175 deletions
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index e030a83d6a0..8ea5f5383c0 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -2101,6 +2101,7 @@ COMPILE_MLIR_UTIL_DEPS = [
"//tensorflow/compiler/xla/mlir_hlo",
"//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration",
"//tensorflow/compiler/xla/mlir_hlo:sink_constants_to_control_flow",
+ "//tensorflow/compiler/mlir/xla:layout_util",
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
"//tensorflow/compiler/mlir/xla:tf_xla_passes",
"//tensorflow/compiler/mlir/xla:xla_passes",
@@ -2116,6 +2117,7 @@ COMPILE_MLIR_UTIL_DEPS = [
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/core/platform:errors",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime:core_cpu_internal",
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 6c9e9308f79..5971bc2cde6 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -51,12 +51,14 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
+#include "tensorflow/compiler/mlir/xla/layout_util.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/transforms/adjust_layout.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
@@ -66,6 +68,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/error_payloads.h"
+#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/core_platform_payloads.pb.h"
#include "tensorflow/core/tpu/tpu_defs.h"
@@ -185,12 +188,17 @@ Status GetOutputInfo(
xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs,
std::vector<XlaResourceUpdate>* resource_updates) {
auto shape_representation_fn_no_fast_memory =
- [shape_determination_fns](const TensorShape& shape, DataType dtype) {
- auto layout_preference = shape_determination_fns.layout_preference_fn(
- shape, dtype, std::nullopt);
- return shape_determination_fns.shape_representation_fn(
- shape, dtype, /*use_fast_memory=*/false, layout_preference);
- };
+ [shape_determination_fns](
+ const xla::Shape& xla_shape) -> StatusOr<xla::Shape> {
+ TensorShape shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
+ TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+ xla_shape.element_type()));
+ auto layout_preference = shape_determination_fns.layout_preference_fn(
+ shape, dtype, std::nullopt);
+ return shape_determination_fns.shape_representation_fn(
+ shape, dtype, /*use_fast_memory=*/false, layout_preference);
+ };
mlir::func::FuncOp main_func =
module.lookupSymbol<mlir::func::FuncOp>("main");
@@ -233,9 +241,12 @@ Status GetOutputInfo(
}
}
}
- TF_ASSIGN_OR_RETURN(
- xla::Shape shape,
- xla::TypeToShape(buffer_ty, shape_representation_fn_no_fast_memory));
+
+ xla::Shape shape = xla::TypeToShape(buffer_ty);
+ if (shape.element_type() == xla::PRIMITIVE_TYPE_INVALID) {
+ return errors::InvalidArgument("XLA conversion failed for MLIR type.");
+ }
+ TF_ASSIGN_OR_RETURN(shape, shape_representation_fn_no_fast_memory(shape));
if (!result_ty.hasStaticShape()) {
int64_t rank = result_ty.getRank();
@@ -512,10 +523,29 @@ Status ConvertMLIRToXlaComputation(
TF_RETURN_IF_ERROR(LegalizeToHlo(module_op, device_type, prefer_tf2xla,
custom_legalization_passes));
+ mlir::MlirToHloConversionOptions options;
+ options.layout_preference_fn =
+ [&](const xla::Shape& xla_shape) -> StatusOr<mlir::XlaLayoutPreference> {
+ TensorShape shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
+ TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+ xla_shape.element_type()));
+ return shape_determination_fns.layout_preference_fn(shape, dtype,
+ std::nullopt);
+ };
+ options.shape_representation_fn =
+ [&](const xla::Shape& xla_shape, bool fast_mem,
+ mlir::XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
+ TensorShape shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
+ TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+ xla_shape.element_type()));
+ return shape_determination_fns.shape_representation_fn(
+ shape, dtype, fast_mem, layout_preference);
+ };
xla::HloProto hlo_proto;
- TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto,
- use_tuple_args, return_tuple,
- shape_determination_fns));
+ TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(
+ module_op, &hlo_proto, use_tuple_args, return_tuple, options));
*xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
return OkStatus();
}
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 59f2463fc73..182df21da24 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -465,8 +465,6 @@ cc_library(
srcs = ["type_to_shape.cc"],
hdrs = ["type_to_shape.h"],
deps = [
- "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
- "//tensorflow/compiler/mlir/tensorflow:convert_type",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto_cc",
@@ -497,6 +495,22 @@ tf_cc_test(
)
cc_library(
+ name = "layout_util",
+ srcs = ["layout_util.cc"],
+ hdrs = ["layout_util.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto_cc",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
name = "mlir_hlo_to_hlo",
srcs = [
"mlir_hlo_to_hlo.cc",
@@ -505,14 +519,11 @@ cc_library(
hdrs = ["mlir_hlo_to_hlo.h"],
deps = [
":attribute_exporter",
+ ":layout_util",
":type_to_shape",
":xla_passes",
"//tensorflow/compiler/mlir:name_utils",
- "//tensorflow/compiler/mlir/tensorflow:convert_type",
"//tensorflow/compiler/mlir/tensorflow:error_util",
- "//tensorflow/compiler/tf2xla:common",
- "//tensorflow/compiler/tf2xla:layout_util",
- "//tensorflow/compiler/tf2xla:xla_helpers",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/mlir/xla/layout_util.cc b/tensorflow/compiler/mlir/xla/layout_util.cc
new file mode 100644
index 00000000000..f240ac19826
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/layout_util.cc
@@ -0,0 +1,111 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/xla/layout_util.h"
+
+namespace mlir {
+
+// Rewrites the layout of xla_shape if there is tiled sharding.
+xla::Status RewriteLayoutWithShardedShape(
+ const std::optional<xla::HloSharding>& sharding, bool use_fast_memory,
+ const LayoutPreferenceFn& layout_preference_fn,
+ const ShapeRepresentationFn& shape_representation_fn,
+ xla::Shape* xla_shape) {
+ if (sharding && !sharding->IsTileMaximal() && !sharding->IsManual()) {
+ // After sharding, per core shape might have different layout. For example,
+ // before sharding, a shape [128, 128] will be assigned default
+ // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
+ // the sharded shapes will have minor-to-major {0, 1}.
+ //
+ // As a result, for sharded shapes, we set their layout to per core shape's
+ // layout.
+ //
+ // TODO(endlessroad): for variable input & update, we might have
+ // different layouts which will prevent input output aliasing and
+ // increase memory usage. Investigate such cases.
+ int64_t device = *sharding->tile_assignment().begin();
+ std::vector<int64_t> offset =
+ sharding->TileOffsetForDevice(*xla_shape, device);
+ std::vector<int64_t> limit =
+ sharding->TileLimitForDevice(*xla_shape, device);
+ std::vector<int64_t> dimensions(xla_shape->rank());
+ for (int64_t i = 0; i < xla_shape->rank(); ++i) {
+ dimensions[i] = limit[i] - offset[i];
+ }
+ xla::Shape per_device_xla_shape =
+ xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
+ TF_ASSIGN_OR_RETURN(auto layout_preference,
+ layout_preference_fn
+ ? layout_preference_fn(per_device_xla_shape)
+ : XlaLayoutPreference::kNoPreference);
+ TF_ASSIGN_OR_RETURN(
+ per_device_xla_shape,
+ shape_representation_fn
+ ? shape_representation_fn(per_device_xla_shape, use_fast_memory,
+ layout_preference)
+ : per_device_xla_shape);
+ *xla_shape->mutable_layout() = per_device_xla_shape.layout();
+ }
+ return xla::Status::OK();
+}
+
+// If there is a shape_representation_fn or sharding for an output, this
+// function uses a reshape to fix the layout.
+xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
+ xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
+ const LayoutPreferenceFn& layout_preference_fn,
+ const ShapeRepresentationFn& shape_representation_fn,
+ std::optional<xla::OpSharding> sharding, bool fast_mem) {
+ if (original_shape.IsTuple()) {
+ std::vector<xla::XlaOp> elements;
+ for (int i = 0; i < original_shape.tuple_shapes_size(); ++i) {
+ auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
+ TF_ASSIGN_OR_RETURN(
+ auto element,
+ ReshapeWithCorrectRepresentationAndSharding(
+ builder, xla::GetTupleElement(original, i),
+ original_shape.tuple_shapes(i), layout_preference_fn,
+ shape_representation_fn, subsharding, fast_mem));
+ elements.push_back(element);
+ }
+ return xla::Tuple(builder, elements);
+ }
+ if (!original_shape.IsArray()) return original;
+ TF_ASSIGN_OR_RETURN(auto layout_preference,
+ layout_preference_fn
+ ? layout_preference_fn(original_shape)
+ : XlaLayoutPreference::kNoPreference);
+ TF_ASSIGN_OR_RETURN(
+ auto to_shape,
+ shape_representation_fn
+ ? shape_representation_fn(original_shape, fast_mem, layout_preference)
+ : original_shape);
+ if (sharding) {
+ TF_ASSIGN_OR_RETURN(auto hlo_sharding,
+ xla::HloSharding::FromProto(*sharding));
+
+ TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
+ hlo_sharding, fast_mem, layout_preference_fn, shape_representation_fn,
+ &to_shape));
+ }
+ if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
+ for (int64_t i = 0; i < original_shape.rank(); ++i) {
+ to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
+ }
+ }
+ return xla::Reshape(to_shape, original);
+}
+
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/layout_util.h b/tensorflow/compiler/mlir/xla/layout_util.h
new file mode 100644
index 00000000000..8828074d40a
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/layout_util.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utilities for working with XLA layout and shapes.
+
+#ifndef TENSORFLOW_COMPILER_MLIR_XLA_LAYOUT_UTIL_H_
+#define TENSORFLOW_COMPILER_MLIR_XLA_LAYOUT_UTIL_H_
+
+#include <functional>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace mlir {
+
+// XLA Layout preferences. Currently, when it comes to TPU, there are two
+// primary layout choices for any XLA arguments (parameter or resource): (1)
+// CompactChunkPadded and (2) Linear. CompactChunkPadded is the native TPU
+// layout while Linear is native host (CPU) layout.
+// This enum allows the caller of XLA to propagate layout preference to the XLA
+// compiler.
+// kNoPreference: the generic layout where the XLA compiler has the freedom
+// to assign any layout.
+// kTpuPreferCompactChunkPaddedLayout: use native TPU layout on TPU.
+// kTpuPreferLinearLayout: use native CPU layout on TPU. The compiler may
+// insert transformation TPU kernels.
+// As the layout of any argument will change from a native host layout to a
+// native TPU layout either on host or on device, XLA compiler and TPU runtime
+// must be in coordination to transform the parameters in a consistent way.
+enum class XlaLayoutPreference {
+ kNoPreference = 0,
+ kTpuPreferCompactChunkPaddedLayout = 1,
+ kTpuPreferLinearLayout = 2
+};
+
+// The following defines the layout preference of an xla tensor.
+// The return value of LayoutPreferenceFn can be used in
+// ShapeRepresentationFn.
+typedef std::function<xla::StatusOr<XlaLayoutPreference>(
+ const xla::Shape& shape)>
+ LayoutPreferenceFn;
+
+typedef std::function<xla::StatusOr<xla::Shape>(
+ const xla::Shape& shape, bool fast_mem,
+ XlaLayoutPreference layout_preference)>
+ ShapeRepresentationFn;
+
+// Return a LayoutPreferenceFn that always uses kNoPreference layout.
+LayoutPreferenceFn UseNoPreferenceLayoutFn();
+
+// Rewrites the layout of xla_shape if there is tiled sharding.
+xla::Status RewriteLayoutWithShardedShape(
+ const std::optional<xla::HloSharding>& sharding, bool use_fast_memory,
+ const LayoutPreferenceFn& layout_preference_fn,
+ const ShapeRepresentationFn& shape_representation_fn,
+ xla::Shape* xla_shape);
+
+// Adds reshapes to fix the layout of an output, if a shape_representation_fn or
+// sharding is present.
+xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
+ xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
+ const LayoutPreferenceFn& layout_preference_fn,
+ const ShapeRepresentationFn& shape_representation_fn,
+ std::optional<xla::OpSharding> sharding, bool fast_mem);
+
+} // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_XLA_LAYOUT_UTIL_H_
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 6aad9aa3283..26aa184e91e 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -47,13 +47,10 @@ limitations under the License.
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
-#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
-#include "tensorflow/compiler/tf2xla/layout_util.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
@@ -569,17 +566,14 @@ class ConvertToHloModule {
// are converted to a tuple even when there is only a single return value.
// Multiple return values are always converted to a tuple and returned as a
// single value.
- explicit ConvertToHloModule(
- mlir::ModuleOp module, xla::XlaBuilder& module_builder,
- bool use_tuple_args, bool return_tuple,
- tensorflow::XlaShapeLayoutHelpers::ShapeDeterminationFns
- shape_determination_fns,
- MlirToHloConversionOptions options)
+ explicit ConvertToHloModule(mlir::ModuleOp module,
+ xla::XlaBuilder& module_builder,
+ bool use_tuple_args, bool return_tuple,
+ MlirToHloConversionOptions options)
: module_(module),
module_builder_(module_builder),
use_tuple_args_(use_tuple_args),
return_tuple_(return_tuple),
- shape_determination_fns_(shape_determination_fns),
options_(options) {}
// Perform the lowering to XLA. This function returns failure if an error was
@@ -679,11 +673,6 @@ class ConvertToHloModule {
// Whether to always return a tuple.
bool return_tuple_;
- // Shape determination functions to determine entry function argument and
- // result shapes.
- tensorflow::XlaShapeLayoutHelpers::ShapeDeterminationFns
- shape_determination_fns_;
-
// Unique suffix to give to the name of the next lowered region.
size_t region_id_ = 0;
@@ -2187,8 +2176,9 @@ LogicalResult ConvertToHloModule::Lower(
xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
StatusOr<xla::XlaOp> reshape =
- tensorflow::ReshapeWithCorrectRepresentationAndSharding(
- builder, returns[index], return_shape, shape_determination_fns_,
+ ReshapeWithCorrectRepresentationAndSharding(
+ builder, returns[index], return_shape,
+ options_.layout_preference_fn, options_.shape_representation_fn,
ret_shardings[index], /*fast_mem=*/false);
if (!reshape.ok())
return inst->emitError() << reshape.status().error_message();
@@ -2336,24 +2326,18 @@ LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication(
for (BlockArgument& arg : block->getArguments()) {
arg_shapes->push_back(xla::TypeToShape(arg.getType()));
xla::Shape& arg_shape = arg_shapes->back();
- tensorflow::TensorShape arg_tensor_shape;
- auto status =
- tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape);
- if (!status.ok())
- return block->getParentOp()->emitError() << status.error_message();
-
- tensorflow::DataType arg_dtype;
- status = tensorflow::ConvertToDataType(arg.getType(), &arg_dtype);
- if (!status.ok())
- return block->getParentOp()->emitError() << status.error_message();
-
- CHECK(shape_determination_fns_.layout_preference_fn && // Crash OK
- shape_determination_fns_.shape_representation_fn);
- auto layout_preference = shape_determination_fns_.layout_preference_fn(
- arg_tensor_shape, arg_dtype, std::nullopt);
- auto arg_shape_status = shape_determination_fns_.shape_representation_fn(
- arg_tensor_shape, arg_dtype, /*use_fast_memory=*/false,
- layout_preference);
+ auto layout_preference_status =
+ options_.layout_preference_fn ? options_.layout_preference_fn(arg_shape)
+ : XlaLayoutPreference::kNoPreference;
+ if (!layout_preference_status.ok())
+ return block->getParentOp()->emitError()
+ << layout_preference_status.status().error_message();
+
+ auto arg_shape_status = options_.shape_representation_fn
+ ? options_.shape_representation_fn(
+ arg_shape, /*use_fast_memory=*/false,
+ layout_preference_status.ValueOrDie())
+ : arg_shape;
if (!arg_shape_status.ok())
return block->getParentOp()->emitError()
<< arg_shape_status.status().error_message();
@@ -2382,9 +2366,10 @@ LogicalResult ConvertToHloModule::SetEntryTupleShardings(
return block->getParentOp()->emitError()
<< hlo_sharding.status().error_message();
- auto status = tensorflow::RewriteLayoutWithShardedShape(
+ auto status = RewriteLayoutWithShardedShape(
hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false,
- shape_determination_fns_, &(*arg_shapes)[arg_sharding.index()]);
+ options_.layout_preference_fn, options_.shape_representation_fn,
+ &(*arg_shapes)[arg_sharding.index()]);
if (!status.ok())
return block->getParentOp()->emitError() << status.error_message();
@@ -2568,24 +2553,21 @@ Status ConvertRegionToComputation(mlir::Region* region,
MlirToHloConversionOptions options) {
mlir::ModuleOp module;
xla::XlaBuilder module_builder("main");
- ConvertToHloModule converter(module, module_builder, true, true, {}, options);
+ ConvertToHloModule converter(module, module_builder, true, true, options);
if (failed(converter.LowerRegionAsComputation(region, func)))
return tensorflow::errors::Internal(
"failed to convert region to computation");
return ::tensorflow::OkStatus();
}
-Status ConvertMlirHloToHlo(
- mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args,
- bool return_tuple,
- const tensorflow::XlaShapeLayoutHelpers::ShapeDeterminationFns
- shape_determination_fns,
- MlirToHloConversionOptions options) {
+Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
+ bool use_tuple_args, bool return_tuple,
+ MlirToHloConversionOptions options) {
TF_RETURN_IF_ERROR(PrepareForExport(module));
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
xla::XlaBuilder module_builder("main");
ConvertToHloModule converter(module, module_builder, use_tuple_args,
- return_tuple, shape_determination_fns, options);
+ return_tuple, options);
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
auto hlo_module = converter.ConsumeMainProto();
StringRef module_name = module.getName() ? *module.getName() : "main";
@@ -2602,7 +2584,7 @@ Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
TF_RETURN_IF_ERROR(PrepareForExport(module));
ConvertToHloModule converter(module, builder,
/*use_tuple_args=*/false, /*return_tuple=*/false,
- /*shape_determination_fns=*/{}, options);
+ options);
ConvertToHloModule::ValueLoweringMap lowering;
// xla_params should only include non-constant parameters the block arguments
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
index c781a739d1c..4a29a84ccf2 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
@@ -19,11 +19,9 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
-#include "tensorflow/compiler/tf2xla/layout_util.h"
-#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/mlir/xla/layout_util.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/core/framework/tensor_shape.h"
namespace mlir {
@@ -45,6 +43,9 @@ struct MlirToHloConversionOptions {
// Legalize names to be compatible with TensorFlow.
bool legalize_node_names = true;
+
+ LayoutPreferenceFn layout_preference_fn;
+ ShapeRepresentationFn shape_representation_fn;
};
// Converts a MLIR module in HLO dialect into a HloModuleProto. If
@@ -54,14 +55,9 @@ struct MlirToHloConversionOptions {
// are converted to a tuple even when there is only a single return value.
// Multiple return values are always converted to a tuple and returned as a
// single value.
-//
-// TODO(timshen): move other options into `options`.
-Status ConvertMlirHloToHlo(
- mlir::ModuleOp module, ::xla::HloProto* hlo_proto, bool use_tuple_args,
- bool return_tuple,
- const tensorflow::XlaShapeLayoutHelpers::ShapeDeterminationFns
- shape_determination_fns = {},
- MlirToHloConversionOptions options = {});
+Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
+ bool use_tuple_args, bool return_tuple,
+ MlirToHloConversionOptions options = {});
// Transforms a Block into HLO, where the HLO is represented as calls into an
// XlaBuilder. Callee functions are allowed in the Block's ancestor ModuleOp.
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index b8f53f30517..8423f354e62 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -202,8 +202,7 @@ class XlaHloToLhloPass
TF_RETURN_WITH_CONTEXT_IF_ERROR(
ConvertMlirHloToHlo(module, &hlo_proto,
/*use_tuple_args=*/false,
- /*return_tuple=*/false,
- /*shape_determination_fns=*/{}),
+ /*return_tuple=*/false),
"conversion to XLA HLO proto failed");
auto statusOrHloModule = HloModuleFromProto(hlo_proto);
diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc
index 834cf94b849..45eef142156 100644
--- a/tensorflow/compiler/mlir/xla/type_to_shape.cc
+++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc
@@ -23,13 +23,10 @@ limitations under the License.
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/Support/DebugStringHelper.h" // from @llvm-project
-#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -83,23 +80,6 @@ PrimitiveType TypeToPrimitiveType(mlir::Type type) {
return PrimitiveType::PRIMITIVE_TYPE_INVALID;
}
-StatusOr<Shape> TypeToShape(
- mlir::Type type, CustomShapeRepresentationFn shape_representation_fn) {
- tensorflow::PartialTensorShape partial_tensor_shape =
- tensorflow::ConvertTypeToTensorShape(type);
-
- tensorflow::TensorShape fully_defined_tensor_shape;
- if (!partial_tensor_shape.AsTensorShape(&fully_defined_tensor_shape)) {
- return tensorflow::errors::InvalidArgument(
- "XLA HLO only allows fully-defined shape");
- }
-
- tensorflow::DataType dtype;
- TF_RETURN_IF_ERROR(tensorflow::ConvertToDataType(type, &dtype));
-
- return shape_representation_fn(fully_defined_tensor_shape, dtype);
-}
-
Shape TypeToShape(mlir::Type type) {
PrimitiveType ptype = TypeToPrimitiveType(type);
if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID)
diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.h b/tensorflow/compiler/mlir/xla/type_to_shape.h
index 647fb56bb26..59482731254 100644
--- a/tensorflow/compiler/mlir/xla/type_to_shape.h
+++ b/tensorflow/compiler/mlir/xla/type_to_shape.h
@@ -20,25 +20,12 @@ limitations under the License.
#include "mlir/IR/Types.h" // from @llvm-project
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/framework/tensor_shape.h"
namespace xla {
// Returns a XLA Shape equivalent of a MLIR Type, else returns empty shape.
Shape TypeToShape(mlir::Type type);
-// Type of a custom function that converts a TensorFlow type and shape into an
-// XLA shape with optional layout info.
-typedef llvm::function_ref<xla::StatusOr<xla::Shape>(
- const tensorflow::TensorShape&, tensorflow::DataType)>
- CustomShapeRepresentationFn;
-
-// Compute an XLA shape based in given MLIR type and an
-// CustomShapeRepresentationFn, which allows setting custom layout in returned
-// XLA shape.
-StatusOr<Shape> TypeToShape(
- mlir::Type type, CustomShapeRepresentationFn shape_representation_fn);
-
// Returns a XLA PrimitiveType equivalent of a MLIR Type that represents a
// primitive type (e.g., i8, f32), else returns PRIMITIVE_TYPE_INVALID.
PrimitiveType TypeToPrimitiveType(mlir::Type type);
diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc
index 9f70fe0e268..0c9af79e2de 100644
--- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc
+++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc
@@ -151,45 +151,6 @@ TEST(TypeToShapeTest, ConvertTensorTypeToTypes) {
EqualsProto(Shape().ToProto()));
}
-TEST(TypeToShapeTest, ConvertWithShapeRepresentationFn) {
- tensorflow::DataType captured_dtype;
- tensorflow::TensorShape captured_tensor_shape;
-
- // A dummy shape representation function that does nothing other than
- // capturing arguments passed to it.
- auto test_shape_representation_fn = [&](const tensorflow::TensorShape& shape,
- tensorflow::DataType dtype) {
- captured_tensor_shape = shape;
- captured_dtype = dtype;
- return xla::Shape();
- };
-
- MLIRContext context;
- Builder b(&context);
- StatusOr<Shape> status_or_shape;
-
- // Non-fully-defined shape.
- status_or_shape =
- TypeToShape(RankedTensorType::get({-1, 2, 3}, b.getF32Type()),
- test_shape_representation_fn);
- EXPECT_TRUE(tensorflow::errors::IsInvalidArgument(status_or_shape.status()));
-
- // Scalar Int32 Tensor, using fast memory.
- status_or_shape =
- TypeToShape(b.getIntegerType(32), test_shape_representation_fn);
- EXPECT_TRUE(status_or_shape.ok());
- EXPECT_EQ(captured_dtype, tensorflow::DataType::DT_INT32);
- EXPECT_EQ(captured_tensor_shape, tensorflow::TensorShape());
-
- // Ranked Float32 Tensor, not using fast memory.
- status_or_shape =
- TypeToShape(RankedTensorType::get({1, 2, 3}, b.getF32Type()),
- test_shape_representation_fn);
- EXPECT_TRUE(status_or_shape.ok());
- EXPECT_EQ(captured_dtype, tensorflow::DataType::DT_FLOAT);
- EXPECT_EQ(captured_tensor_shape, tensorflow::TensorShape({1, 2, 3}));
-}
-
TEST(TypeToShapeTest, ConvertMemRefToShape) {
Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::F32, {10, 20, 30},
{2, 0, 1});
diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc
index d38ea2c390c..b534cd3179b 100644
--- a/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc
+++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc
@@ -149,8 +149,7 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
via_builder
? ConvertMlirHloToHloViaBuilder(module, &hloProto, options)
: mlir::ConvertMlirHloToHlo(module, &hloProto, emit_use_tuple_arg,
- emit_return_tuple,
- /*shape_determination_fns=*/{}, options);
+ emit_return_tuple, options);
if (!status.ok()) {
LOG(ERROR) << "Module conversion failed: " << status;
return mlir::failure();
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 5c9a4e3c592..8e50f2c3af4 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -436,6 +436,7 @@ cc_library(
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:shape_inference",
"//tensorflow/compiler/mlir:array_container_utils",
+ "//tensorflow/compiler/mlir/xla:layout_util",
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
"//tensorflow/compiler/xla/client:value_inference",
"//tensorflow/compiler/xla/service:computation_placer_hdr",
@@ -610,6 +611,7 @@ cc_library(
deps = [
":common",
":host_compute_metadata_proto_cc",
+ "//tensorflow/compiler/mlir/xla:layout_util",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index 847ca0bac7e..400dd96b8dd 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "absl/types/span.h"
+#include "tensorflow/compiler/mlir/xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
@@ -30,25 +31,7 @@ limitations under the License.
namespace tensorflow {
-// XLA Layout preferences. Currently, when it comes to TPU, there are two
-// primary layout choices for any XLA argumetns (parameter or resource): (1)
-// CompactChunkPadded and (2) Linear. CompactChunkPadded is the native TPU
-// layout while Linear is native host (CPU) layout.
-// This enum allows the caller of XLA to progogate layout preference to the XLA
-// compiler.
-// kNoPreference: the generic layout where the XLA compiler has the freedom
-// to assign any layout.
-// kTpuPreferCompactChunkPaddedLayout: use native TPU layout on TPU.
-// kTpuPreferLinearLayout: use native CPU layout on TPU. The compiler may
-// insert transformation TPU kernels.
-// As the layout of any argument will change from a native host layout to a
-// native TPU layout either on host or on device, XLA compiler and TPU runtime
-// must be in coordination to transform the parameters in a consistent way.
-enum class XlaLayoutPreference {
- kNoPreference = 0,
- kTpuPreferCompactChunkPaddedLayout = 1,
- kTpuPreferLinearLayout = 2
-};
+using XlaLayoutPreference = mlir::XlaLayoutPreference;
// Helper methods for building XLA computations.
class XlaHelpers {
@@ -106,8 +89,6 @@ class XlaHelpers {
// Creates an identity shape representation function.
XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn();
-
-
struct XlaOutputDescription {
// Type and shape of the output. The shape is the unflattened shape.
// When `type` is DT_RESOURCE, `shape` is the shape of the resource
diff --git a/tensorflow/compiler/xla/pjrt/mlir_to_hlo.cc b/tensorflow/compiler/xla/pjrt/mlir_to_hlo.cc
index bd7b19be405..35aa5b83a0e 100644
--- a/tensorflow/compiler/xla/pjrt/mlir_to_hlo.cc
+++ b/tensorflow/compiler/xla/pjrt/mlir_to_hlo.cc
@@ -61,9 +61,8 @@ Status MlirToXlaComputation(mlir::ModuleOp module,
mlir::MlirToHloConversionOptions options;
// We don't want the conversion to muck with our operator names.
options.legalize_node_names = false;
- TF_RETURN_IF_ERROR(
- ConvertMlirHloToHlo(module, &proto, use_tuple_args, return_tuple,
- /*shape_determination_fns=*/{}, options));
+ TF_RETURN_IF_ERROR(ConvertMlirHloToHlo(module, &proto, use_tuple_args,
+ return_tuple, options));
xla_computation = XlaComputation(std::move(*proto.mutable_hlo_module()));
return OkStatus();