aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Delorimier <mdel@google.com>2021-02-17 23:16:12 -0800
committerTensorFlower Gardener <gardener@tensorflow.org>2021-02-17 23:20:28 -0800
commit0d9cb3a9594456bc0c8a67d37759fbe1080176d0 (patch)
tree0ecce3aac134bac9cc17db2c92fa009cf4e03139
parent618a81433c49da843b0fc9e92d26927ee3cd003d (diff)
downloadtensorflow-0d9cb3a9594456bc0c8a67d37759fbe1080176d0.tar.gz
Added pass xla-legalize-tf-types. This pass converts quantized types to non-quantized (e.g. qint8 to i8).
PiperOrigin-RevId: 358111887 Change-Id: I16e67c7bd1d7d5f383236583f109fe6da82b0541
-rw-r--r--tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc1
-rw-r--r--tensorflow/compiler/mlir/xla/BUILD4
-rw-r--r--tensorflow/compiler/mlir/xla/tests/legalize-tf-types.mlir54
-rw-r--r--tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc185
-rw-r--r--tensorflow/compiler/mlir/xla/transforms/passes.h4
-rw-r--r--tensorflow/compiler/mlir/xla/transforms/xla_passes.td18
6 files changed, 266 insertions, 0 deletions
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 59c647e598e..f10aca20b47 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -307,6 +307,7 @@ void CreateConvertMlirToXlaHloPipeline(
// inside PromoteResourcesToArgs.
pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
+ pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
/*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
/*tf2xla_fallback_device_type=*/device_type));
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index b2d1e15b53c..63be2fa8d60 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -72,6 +72,7 @@ gentbl(
cc_library(
name = "xla_passes",
srcs = [
+ "transforms/legalize_tf_types.cc",
"transforms/passes_detail.h",
"transforms/prepare_for_export.cc",
],
@@ -81,10 +82,12 @@ cc_library(
deps = [
":xla_passes_inc_gen",
"//tensorflow/compiler/mlir/hlo",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
@@ -109,6 +112,7 @@ cc_library(
"//tensorflow/compiler/mlir/hlo:convert_op_folder",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:padding",
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-types.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-types.mlir
new file mode 100644
index 00000000000..56d903be892
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-types.mlir
@@ -0,0 +1,54 @@
+// RUN: tf-opt "-xla-legalize-tf-types" %s | FILECHECK_OPTS="" FileCheck %s
+
+func @relu_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
+ // CHECK: func @relu_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+ // CHECK-NEXT: %[[X:.*]] = "tf.Relu"(%arg0) : (tensor<1xi8>) -> tensor<1xi8>
+ %0 = "tf.Relu"(%arg0) : (tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8>
+ return %0: tensor<1x!tf.qint8>
+}
+
+func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1x!tf.qint8>, %arg2: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
+ // CHECK: func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1xi8>
+ // CHECK-NEXT: %0 = "tf.IfRegion"(%arg0) ( {
+ // CHECK-NEXT: "tf.Yield"(%arg1) : (tensor<1xi8>) -> ()
+ // CHECK-NEXT: }, {
+ // CHECK-NEXT: "tf.Yield"(%arg2) : (tensor<1xi8>) -> ()
+ // CHECK-NEXT: }) {is_stateless = false} : (tensor<i1>) -> tensor<1xi8>
+ // CHECK-NEXT: return %0 : tensor<1xi8>
+ %0 = "tf.IfRegion"(%arg0) ( {
+ "tf.Yield"(%arg1) : (tensor<1x!tf.qint8>) -> ()
+ }, {
+ "tf.Yield"(%arg2) : (tensor<1x!tf.qint8>) -> ()
+ }) {is_stateless = false} : (tensor<i1>) -> tensor<1x!tf.qint8>
+ return %0 : tensor<1x!tf.qint8>
+}
+
+func @id_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
+ // CHECK: func @id_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+ // CHECK-NEXT: return %arg0 : tensor<1xi8>
+ return %arg0: tensor<1x!tf.qint8>
+}
+
+func @id_qint16(%arg0: tensor<1x!tf.qint16>) -> tensor<1x!tf.qint16> {
+ // CHECK: func @id_qint16(%arg0: tensor<1xi16>) -> tensor<1xi16> {
+ // CHECK-NEXT: return %arg0 : tensor<1xi16>
+ return %arg0: tensor<1x!tf.qint16>
+}
+
+func @id_qint32(%arg0: tensor<1x!tf.qint32>) -> tensor<1x!tf.qint32> {
+ // CHECK: func @id_qint32(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK-NEXT: return %arg0 : tensor<1xi32>
+ return %arg0: tensor<1x!tf.qint32>
+}
+
+func @id_quint8(%arg0: tensor<1x!tf.quint8>) -> tensor<1x!tf.quint8> {
+ // CHECK: func @id_quint8(%arg0: tensor<1xui8>) -> tensor<1xui8> {
+ // CHECK-NEXT: return %arg0 : tensor<1xui8>
+ return %arg0: tensor<1x!tf.quint8>
+}
+
+func @id_quint16(%arg0: tensor<1x!tf.quint16>) -> tensor<1x!tf.quint16> {
+ // CHECK: func @id_quint16(%arg0: tensor<1xui16>) -> tensor<1xui16> {
+ // CHECK-NEXT: return %arg0 : tensor<1xui16>
+ return %arg0: tensor<1x!tf.quint16>
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc
new file mode 100644
index 00000000000..c1ce7d4aab9
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc
@@ -0,0 +1,185 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// The TF dialect uses some TF types that are illegal in the MHLO dialect and
+// some generic types that are legal in MHLO. This pass legalizes TF types into
+// types that are legal in MHLO. For example, TF::Qint8Type is converted to i8.
+// Rewrites here should run before TF to MHLO op legalizations are run.
+// TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass
+// rather than its own pass.
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/BuiltinOps.h" // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
+#include "mlir/IR/PatternMatch.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Support/LLVM.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes_detail.h"
+
+#define DEBUG_TYPE "xla-legalize-tf-types"
+
+namespace mlir {
+namespace mhlo {
+namespace {
+
+bool isIllegalElementType(Type type) {
+ return type
+ .isa<mlir::TF::Qint8Type, mlir::TF::Qint16Type, mlir::TF::Qint32Type,
+ mlir::TF::Quint8Type, mlir::TF::Quint16Type>();
+}
+
+Type replaceElementType(Type type) {
+ return TypeSwitch<Type, Type>(type)
+ .Case<mlir::TF::Qint8Type>([&type](Type) {
+ return mlir::IntegerType::get(type.getContext(), 8);
+ })
+ .Case<mlir::TF::Qint16Type>([&type](Type) {
+ return mlir::IntegerType::get(type.getContext(), 16);
+ })
+ .Case<mlir::TF::Qint32Type>([&type](Type) {
+ return mlir::IntegerType::get(type.getContext(), 32);
+ })
+ .Case<mlir::TF::Quint8Type>([&type](Type) {
+ return mlir::IntegerType::get(
+ type.getContext(), 8,
+ mlir::IntegerType::SignednessSemantics::Unsigned);
+ })
+ .Case<mlir::TF::Quint16Type>([&type](Type) {
+ return mlir::IntegerType::get(
+ type.getContext(), 16,
+ mlir::IntegerType::SignednessSemantics::Unsigned);
+ })
+ .Default([&type](Type) { return type; });
+}
+
+// TODO(b/180234863): What's below this line is generic so convert it to a
+// utility.
+
+bool isIllegalType(Type type) {
+ if (isIllegalElementType(type)) return true;
+ if (auto shaped = type.dyn_cast<ShapedType>())
+ return isIllegalType(shaped.getElementType());
+ return false;
+}
+
+Type replaceType(Type type) {
+ if (isIllegalElementType(type)) return replaceElementType(type);
+ if (auto shaped = type.dyn_cast<ShapedType>()) {
+ Type elem = shaped.getElementType();
+ if (isIllegalType(elem)) return shaped.clone(replaceType(elem));
+ }
+ return type;
+}
+
+// An Op is illegal iff it contains an illegalType.
+class TfTypeConversionTarget : public ConversionTarget {
+ public:
+ explicit TfTypeConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
+ markUnknownOpDynamicallyLegal();
+ }
+
+ protected:
+ bool isDynamicallyLegal(Operation *op) const override {
+ // The FuncOp type can contain types that the op's operand and result types
+ // do not contain.
+ if (auto func = dyn_cast<FuncOp>(op)) {
+ if (llvm::any_of(func.getType().getInputs(), isIllegalType) ||
+ llvm::any_of(func.getType().getResults(), isIllegalType))
+ return false;
+ }
+ if (llvm::any_of(op->getOperandTypes(), isIllegalType) ||
+ llvm::any_of(op->getResultTypes(), isIllegalType))
+ return false;
+ return true;
+ }
+};
+
+class TfTypeConverter : public TypeConverter {
+ public:
+ TfTypeConverter() {
+ addConversion([](Type type) -> Type {
+ if (isIllegalType(type))
+ return replaceType(type);
+ else
+ return type;
+ });
+ }
+};
+
+class TfTypePattern : public ConversionPattern {
+ public:
+ TfTypePattern(MLIRContext *ctx, TypeConverter &converter)
+ : ConversionPattern(1, converter, MatchAnyOpTypeTag()) {}
+
+ // The dialect conversion framework will call this matchAndRewrite on each
+ // Operation in the IR tree. This call matchAndRewrite needs to update the
+ // Operation's results and child regions.
+ LogicalResult matchAndRewrite(
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Update the results.
+ llvm::SmallVector<Type, 4> new_results;
+ if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
+ new_results)))
+ return failure();
+
+ // Update the regions. The dialect conversion framework wants new regions to
+ // be created and updated, rather than updating the old op. Thus we use an
+ // OperationState so we can add regions to the new up.
+ OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+ new_results, op->getAttrs(), op->getSuccessors());
+ for (Region &region : op->getRegions()) {
+ Region &new_region = *state.addRegion();
+ rewriter.inlineRegionBefore(region, new_region, new_region.begin());
+ if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter())))
+ return failure();
+ }
+ rewriter.replaceOp(op, rewriter.createOperation(state)->getResults());
+
+ return success();
+ }
+};
+
+struct LegalizeTfTypesPass
+ : public LegalizeTfTypesPassBase<LegalizeTfTypesPass> {
+ void runOnOperation() override;
+};
+
+void LegalizeTfTypesPass::runOnOperation() {
+ TfTypeConverter converter;
+ OwningRewritePatternList patterns;
+ patterns.insert<TfTypePattern>(&getContext(), converter);
+ populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
+ TfTypeConversionTarget target(getContext());
+ if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
+ return signalPassFailure();
+}
+
+static PassRegistration<LegalizeTfTypesPass> registration(
+ "xla-legalize-tf-types",
+ "Replace TensorFlow types with types that are legal in the MHLO dialect");
+
+} // namespace
+
+std::unique_ptr<OperationPass<>> CreateLegalizeTfTypesPass() {
+ return std::make_unique<LegalizeTfTypesPass>();
+}
+
+} // namespace mhlo
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h
index b5398f15089..77ba879e000 100644
--- a/tensorflow/compiler/mlir/xla/transforms/passes.h
+++ b/tensorflow/compiler/mlir/xla/transforms/passes.h
@@ -49,6 +49,10 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
llvm::StringRef device_type);
+/// Replaces types that do not exist in MHLO with equivalent types that do
+/// exist.
+std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTfTypesPass();
+
/// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list.
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
OwningRewritePatternList& patterns);
diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_passes.td b/tensorflow/compiler/mlir/xla/transforms/xla_passes.td
index 602740cbfb9..e93634b63b9 100644
--- a/tensorflow/compiler/mlir/xla/transforms/xla_passes.td
+++ b/tensorflow/compiler/mlir/xla/transforms/xla_passes.td
@@ -15,6 +15,24 @@ limitations under the License.
include "mlir/Pass/PassBase.td"
+def LegalizeTfTypesPass : Pass<"xla-legalize-tf-types"> {
+ let summary = "Replace TensorFlow types with types that are legal in the MHLO dialect";
+
+ let description = [{
+The TF dialect uses some TF types that are illegal in the MHLO dialect and
+some generic types that are legal in MHLO. This pass legalizes TF types into
+types that are legal in MHLO. Rewrites here should run before TF to MHLO op
+legalizations are run.
+
+Specifically, this pass replaces each quantized integer type with the
+corresponding ordinary types. For example, `TF::Qint8Type` is replaced with `i8`
+everywhere it occurs. Types that are replaced are `TF::Qint8Type`,
+`TF::Qint16Type`, `TF::Qint32Type`, `TF::Quint8Type`, and `TF::Quint16Type`.
+ }];
+
+ let constructor = "::mlir::mhlo::CreateLegalizeTfTypesPass()";
+}
+
def PrepareForExportPass : FunctionPass<"xla-prepare-for-export"> {
let summary = "Prepare for XLA export";