| author | Michael Delorimier <mdel@google.com> | 2021-02-17 23:16:12 -0800 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2021-02-17 23:20:28 -0800 |
| commit | 0d9cb3a9594456bc0c8a67d37759fbe1080176d0 | |
| tree | 0ecce3aac134bac9cc17db2c92fa009cf4e03139 | |
| parent | 618a81433c49da843b0fc9e92d26927ee3cd003d | |
Added pass xla-legalize-tf-types. This pass converts quantized types to non-quantized types (e.g. qint8 to i8).
PiperOrigin-RevId: 358111887
Change-Id: I16e67c7bd1d7d5f383236583f109fe6da82b0541
6 files changed, 266 insertions, 0 deletions
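Concretely, the pass rewrites quantized TF element types to their storage integer types in place. Below is a minimal before/after sketch mirroring the `relu_qint8` test added in this commit:

```mlir
// Before: a TF op carrying the TF-dialect quantized type !tf.qint8.
func @relu_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
  %0 = "tf.Relu"(%arg0) : (tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8>
  return %0 : tensor<1x!tf.qint8>
}

// After -xla-legalize-tf-types: the same op, now on plain i8.
func @relu_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
  %0 = "tf.Relu"(%arg0) : (tensor<1xi8>) -> tensor<1xi8>
  return %0 : tensor<1xi8>
}
```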
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 59c647e598e..f10aca20b47 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -307,6 +307,7 @@ void CreateConvertMlirToXlaHloPipeline(
   // inside PromoteResourcesToArgs.
   pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
+  pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
   pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
       /*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
       /*tf2xla_fallback_device_type=*/device_type));
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index b2d1e15b53c..63be2fa8d60 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -72,6 +72,7 @@ gentbl(
 cc_library(
     name = "xla_passes",
     srcs = [
+        "transforms/legalize_tf_types.cc",
         "transforms/passes_detail.h",
         "transforms/prepare_for_export.cc",
     ],
@@ -81,10 +82,12 @@ cc_library(
     deps = [
         ":xla_passes_inc_gen",
         "//tensorflow/compiler/mlir/hlo",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
     ],
     alwayslink = 1,
 )
@@ -109,6 +112,7 @@ cc_library(
         "//tensorflow/compiler/mlir/hlo:convert_op_folder",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:padding",
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-types.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-types.mlir
new file mode 100644
index 00000000000..56d903be892
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-types.mlir
@@ -0,0 +1,54 @@
+// RUN: tf-opt "-xla-legalize-tf-types" %s | FILECHECK_OPTS="" FileCheck %s
+
+func @relu_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
+  // CHECK: func @relu_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+  // CHECK-NEXT: %[[X:.*]] = "tf.Relu"(%arg0) : (tensor<1xi8>) -> tensor<1xi8>
+  %0 = "tf.Relu"(%arg0) : (tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8>
+  return %0: tensor<1x!tf.qint8>
+}
+
+func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1x!tf.qint8>, %arg2: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
+  // CHECK: func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1xi8>
+  // CHECK-NEXT: %0 = "tf.IfRegion"(%arg0) ( {
+  // CHECK-NEXT:   "tf.Yield"(%arg1) : (tensor<1xi8>) -> ()
+  // CHECK-NEXT:   }, {
+  // CHECK-NEXT:   "tf.Yield"(%arg2) : (tensor<1xi8>) -> ()
+  // CHECK-NEXT:   }) {is_stateless = false} : (tensor<i1>) -> tensor<1xi8>
+  // CHECK-NEXT: return %0 : tensor<1xi8>
+  %0 = "tf.IfRegion"(%arg0) ( {
+    "tf.Yield"(%arg1) : (tensor<1x!tf.qint8>) -> ()
+  }, {
+    "tf.Yield"(%arg2) : (tensor<1x!tf.qint8>) -> ()
+  }) {is_stateless = false} : (tensor<i1>) -> tensor<1x!tf.qint8>
+  return %0 : tensor<1x!tf.qint8>
+}
+
+func @id_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
+  // CHECK: func @id_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+  // CHECK-NEXT: return %arg0 : tensor<1xi8>
+  return %arg0: tensor<1x!tf.qint8>
+}
+
+func @id_qint16(%arg0: tensor<1x!tf.qint16>) -> tensor<1x!tf.qint16> {
+  // CHECK: func @id_qint16(%arg0: tensor<1xi16>) -> tensor<1xi16> {
+  // CHECK-NEXT: return %arg0 : tensor<1xi16>
+  return %arg0: tensor<1x!tf.qint16>
+}
+
+func @id_qint32(%arg0: tensor<1x!tf.qint32>) -> tensor<1x!tf.qint32> {
+  // CHECK: func @id_qint32(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+  // CHECK-NEXT: return %arg0 : tensor<1xi32>
+  return %arg0: tensor<1x!tf.qint32>
+}
+
+func @id_quint8(%arg0: tensor<1x!tf.quint8>) -> tensor<1x!tf.quint8> {
+  // CHECK: func @id_quint8(%arg0: tensor<1xui8>) -> tensor<1xui8> {
+  // CHECK-NEXT: return %arg0 : tensor<1xui8>
+  return %arg0: tensor<1x!tf.quint8>
+}
+
+func @id_quint16(%arg0: tensor<1x!tf.quint16>) -> tensor<1x!tf.quint16> {
+  // CHECK: func @id_quint16(%arg0: tensor<1xui16>) -> tensor<1xui16> {
+  // CHECK-NEXT: return %arg0 : tensor<1xui16>
+  return %arg0: tensor<1x!tf.quint16>
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc
new file mode 100644
index 00000000000..c1ce7d4aab9
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc
@@ -0,0 +1,185 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// The TF dialect uses some TF types that are illegal in the MHLO dialect and
+// some generic types that are legal in MHLO. This pass legalizes TF types into
+// types that are legal in MHLO. For example, TF::Qint8Type is converted to i8.
+// Rewrites here should run before TF to MHLO op legalizations are run.
+// TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass
+// rather than its own pass.
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes_detail.h"
+
+#define DEBUG_TYPE "xla-legalize-tf-types"
+
+namespace mlir {
+namespace mhlo {
+namespace {
+
+bool isIllegalElementType(Type type) {
+  return type
+      .isa<mlir::TF::Qint8Type, mlir::TF::Qint16Type, mlir::TF::Qint32Type,
+           mlir::TF::Quint8Type, mlir::TF::Quint16Type>();
+}
+
+Type replaceElementType(Type type) {
+  return TypeSwitch<Type, Type>(type)
+      .Case<mlir::TF::Qint8Type>([&type](Type) {
+        return mlir::IntegerType::get(type.getContext(), 8);
+      })
+      .Case<mlir::TF::Qint16Type>([&type](Type) {
+        return mlir::IntegerType::get(type.getContext(), 16);
+      })
+      .Case<mlir::TF::Qint32Type>([&type](Type) {
+        return mlir::IntegerType::get(type.getContext(), 32);
+      })
+      .Case<mlir::TF::Quint8Type>([&type](Type) {
+        return mlir::IntegerType::get(
+            type.getContext(), 8,
+            mlir::IntegerType::SignednessSemantics::Unsigned);
+      })
+      .Case<mlir::TF::Quint16Type>([&type](Type) {
+        return mlir::IntegerType::get(
+            type.getContext(), 16,
+            mlir::IntegerType::SignednessSemantics::Unsigned);
+      })
+      .Default([&type](Type) { return type; });
+}
+
+// TODO(b/180234863): What's below this line is generic so convert it to a
+// utility.
+
+bool isIllegalType(Type type) {
+  if (isIllegalElementType(type)) return true;
+  if (auto shaped = type.dyn_cast<ShapedType>())
+    return isIllegalType(shaped.getElementType());
+  return false;
+}
+
+Type replaceType(Type type) {
+  if (isIllegalElementType(type)) return replaceElementType(type);
+  if (auto shaped = type.dyn_cast<ShapedType>()) {
+    Type elem = shaped.getElementType();
+    if (isIllegalType(elem)) return shaped.clone(replaceType(elem));
+  }
+  return type;
+}
+
+// An Op is illegal iff it contains an illegalType.
+class TfTypeConversionTarget : public ConversionTarget {
+ public:
+  explicit TfTypeConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
+    markUnknownOpDynamicallyLegal();
+  }
+
+ protected:
+  bool isDynamicallyLegal(Operation *op) const override {
+    // The FuncOp type can contain types that the op's operand and result types
+    // do not contain.
+    if (auto func = dyn_cast<FuncOp>(op)) {
+      if (llvm::any_of(func.getType().getInputs(), isIllegalType) ||
+          llvm::any_of(func.getType().getResults(), isIllegalType))
+        return false;
+    }
+    if (llvm::any_of(op->getOperandTypes(), isIllegalType) ||
+        llvm::any_of(op->getResultTypes(), isIllegalType))
+      return false;
+    return true;
+  }
+};
+
+class TfTypeConverter : public TypeConverter {
+ public:
+  TfTypeConverter() {
+    addConversion([](Type type) -> Type {
+      if (isIllegalType(type))
+        return replaceType(type);
+      else
+        return type;
+    });
+  }
+};
+
+class TfTypePattern : public ConversionPattern {
+ public:
+  TfTypePattern(MLIRContext *ctx, TypeConverter &converter)
+      : ConversionPattern(1, converter, MatchAnyOpTypeTag()) {}
+
+  // The dialect conversion framework will call this matchAndRewrite on each
+  // Operation in the IR tree. This matchAndRewrite call needs to update the
+  // Operation's results and child regions.
+  LogicalResult matchAndRewrite(
+      Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // Update the results.
+    llvm::SmallVector<Type, 4> new_results;
+    if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
+                                                new_results)))
+      return failure();
+
+    // Update the regions. The dialect conversion framework wants new regions to
+    // be created and updated, rather than updating the old op. Thus we use an
+    // OperationState so we can add regions to the new op.
+    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                         new_results, op->getAttrs(), op->getSuccessors());
+    for (Region &region : op->getRegions()) {
+      Region &new_region = *state.addRegion();
+      rewriter.inlineRegionBefore(region, new_region, new_region.begin());
+      if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter())))
+        return failure();
+    }
+    rewriter.replaceOp(op, rewriter.createOperation(state)->getResults());
+
+    return success();
+  }
+};
+
+struct LegalizeTfTypesPass
+    : public LegalizeTfTypesPassBase<LegalizeTfTypesPass> {
+  void runOnOperation() override;
+};
+
+void LegalizeTfTypesPass::runOnOperation() {
+  TfTypeConverter converter;
+  OwningRewritePatternList patterns;
+  patterns.insert<TfTypePattern>(&getContext(), converter);
+  populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
+  TfTypeConversionTarget target(getContext());
+  if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
+    return signalPassFailure();
+}
+
+static PassRegistration<LegalizeTfTypesPass> registration(
+    "xla-legalize-tf-types",
+    "Replace TensorFlow types with types that are legal in the MHLO dialect");
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTfTypesPass() {
+  return std::make_unique<LegalizeTfTypesPass>();
+}
+
+}  // namespace mhlo
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h
index b5398f15089..77ba879e000 100644
--- a/tensorflow/compiler/mlir/xla/transforms/passes.h
+++ b/tensorflow/compiler/mlir/xla/transforms/passes.h
@@ -49,6 +49,10 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
     llvm::StringRef device_type);
 
+/// Replaces types that do not exist in MHLO with equivalent types that do
+/// exist.
+std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTfTypesPass();
+
 /// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list.
 void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
                                           OwningRewritePatternList& patterns);
diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_passes.td b/tensorflow/compiler/mlir/xla/transforms/xla_passes.td
index 602740cbfb9..e93634b63b9 100644
--- a/tensorflow/compiler/mlir/xla/transforms/xla_passes.td
+++ b/tensorflow/compiler/mlir/xla/transforms/xla_passes.td
@@ -15,6 +15,24 @@ limitations under the License.
 
 include "mlir/Pass/PassBase.td"
 
+def LegalizeTfTypesPass : Pass<"xla-legalize-tf-types"> {
+  let summary = "Replace TensorFlow types with types that are legal in the MHLO dialect";
+
+  let description = [{
+The TF dialect uses some TF types that are illegal in the MHLO dialect and
+some generic types that are legal in MHLO. This pass legalizes TF types into
+types that are legal in MHLO. Rewrites here should run before TF to MHLO op
+legalizations are run.
+
+Specifically, this pass replaces each quantized integer type with the
+corresponding ordinary type. For example, `TF::Qint8Type` is replaced with `i8`
+everywhere it occurs. Types that are replaced are `TF::Qint8Type`,
+`TF::Qint16Type`, `TF::Qint32Type`, `TF::Quint8Type`, and `TF::Quint16Type`.
+  }];
+
+  let constructor = "::mlir::mhlo::CreateLegalizeTfTypesPass()";
+}
+
 def PrepareForExportPass : FunctionPass<"xla-prepare-for-export"> {
   let summary = "Prepare for XLA export";
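For standalone experimentation, the pass is registered under `-xla-legalize-tf-types` (see the RUN line in the test above), so `tf-opt -xla-legalize-tf-types input.mlir` exercises it directly. Below is a minimal sketch of invoking it programmatically through the `passes.h` entry point; `LegalizeQuantizedTypes` is a hypothetical helper name, and module parsing is elided:

```cpp
#include "mlir/IR/BuiltinOps.h"          // mlir::ModuleOp
#include "mlir/Pass/PassManager.h"       // mlir::PassManager
#include "mlir/Support/LogicalResult.h"  // mlir::LogicalResult
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"

// Run only the type legalization on an already-parsed module, scheduled the
// same way CreateConvertMlirToXlaHloPipeline schedules it above, i.e. before
// any TF-to-MHLO op legalization runs.
mlir::LogicalResult LegalizeQuantizedTypes(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
  return pm.run(module);
}
```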