aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSandeep Dasgupta <sdasgup@google.com>2023-04-28 18:46:13 -0700
committerTensorFlower Gardener <gardener@tensorflow.org>2023-04-28 18:50:46 -0700
commita64c0d35b540f514717251a2397f3e238a16f403 (patch)
tree0baf5d0b37b81603a850870d6b22da58957ceca3
parentd2029b77a9209eaa863ff9231141faf7ae68b3d3 (diff)
downloadtensorflow-a64c0d35b540f514717251a2397f3e238a16f403.tar.gz
Integrate StableHLO at openxla/stablehlo@43d81c6
PiperOrigin-RevId: 528037738
-rw-r--r--third_party/stablehlo/temporary.patch85
-rw-r--r--third_party/stablehlo/workspace.bzl4
2 files changed, 2 insertions, 87 deletions
diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
index 7bdfe5ef756..18aa6f251a0 100644
--- a/third_party/stablehlo/temporary.patch
+++ b/third_party/stablehlo/temporary.patch
@@ -1,61 +1,3 @@
-diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp
---- stablehlo/stablehlo/dialect/AssemblyFormat.cpp
-+++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp
-@@ -203,7 +203,7 @@
- return parser.emitError(loc, "expected tensor with complex element type");
-
- // Assign LHS and RHS to inferred type
-- Type realType = createRealType(type);
-+ Type realType = createRealType(shapedType);
- lhs = rhs = realType;
- result = type;
- return success();
-diff --ruN a/stablehlo/stablehlo/dialect/ChloOps.cpp b/stablehlo/stablehlo/dialect/ChloOps.cpp
---- stablehlo/stablehlo/dialect/ChloOps.cpp
-+++ stablehlo/stablehlo/dialect/ChloOps.cpp
-@@ -190,7 +190,7 @@
- ValueShapeRange operands, DictionaryAttr attributes,
- RegionRange /*regions*/,
- SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
-- ShapedType lhsType = operands[0].getType();
-+ ShapedType lhsType = operands[0].getType().cast<ShapedType>();
- Type elementType = ComplexType::get(lhsType.getElementType());
- return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
- attributes, elementType,
-diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp
---- stablehlo/stablehlo/dialect/StablehloOps.cpp
-+++ stablehlo/stablehlo/dialect/StablehloOps.cpp
-@@ -228,7 +228,7 @@
- Attribute value) {
- ShapedType type;
- if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
-- type = elemAttr.getType();
-+ type = cast<ShapedType>(elemAttr.getType());
- } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
- // All XLA types must be tensor types. In the build() method, we want to
- // provide more flexibility by allowing attributes of scalar types. But we
-@@ -2518,7 +2518,7 @@
- void StablehloDialect::printAttribute(Attribute attr,
- DialectAsmPrinter& os) const {
- if (auto type_extensions = attr.dyn_cast<TypeExtensionsAttr>()) {
-- hlo::printTypeExtensions(attr, os);
-+ hlo::printTypeExtensions(hlo::BoundedAttrInterface(attr), os);
- return;
- }
- LogicalResult result = generatedAttributePrinter(attr, os);
-diff --ruN a/stablehlo/stablehlo/dialect/VhloOps.cpp b/stablehlo/stablehlo/dialect/VhloOps.cpp
---- stablehlo/stablehlo/dialect/VhloOps.cpp
-+++ stablehlo/stablehlo/dialect/VhloOps.cpp
-@@ -181,7 +181,8 @@
- void TensorV1Attr::print(mlir::AsmPrinter& p) const {
- p << '<'
- << DenseIntOrFPElementsAttr::getFromRawBuffer(
-- convertTypeToBuiltinForPrint(getType()), getData())
-+ convertTypeToBuiltinForPrint(getType()).cast<ShapedType>(),
-+ getData())
- << '>';
- }
-
diff --ruN a/stablehlo/stablehlo/integrations/python/mlir/dialects/stablehlo.py b/stablehlo/stablehlo/integrations/python/mlir/dialects/stablehlo.py
--- stablehlo/stablehlo/integrations/python/mlir/dialects/stablehlo.py
+++ stablehlo/stablehlo/integrations/python/mlir/dialects/stablehlo.py
@@ -70,31 +12,4 @@ diff --ruN a/stablehlo/stablehlo/integrations/python/mlir/dialects/stablehlo.py
+ is still forward compatible with.
+ """
+ return "0.9.0"
-diff --ruN a/stablehlo/stablehlo/reference/Ops.cpp b/stablehlo/stablehlo/reference/Ops.cpp
---- stablehlo/stablehlo/reference/Ops.cpp
-+++ stablehlo/stablehlo/reference/Ops.cpp
-@@ -619,7 +619,8 @@
- resultIt != inputs[0].index_end(); ++resultIt) {
- SmallVector<Tensor> args;
- for (size_t i = 0; i < inputs.size(); ++i) {
-- auto tensor = Tensor(computation.getArgument(i).getType());
-+ auto tensor =
-+ Tensor(cast<ShapedType>(computation.getArgument(i).getType()));
- tensor.set({}, inputs[i].get(*resultIt));
- args.push_back(tensor);
- }
-diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
---- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
-+++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
-@@ -144,8 +144,8 @@
- if (auto attr = vhloAttr.dyn_cast<vhlo::TensorV1Attr>()) {
- auto builtinType = typeConverter->convertType(attr.getType());
- if (!builtinType) return {};
-- return DenseIntOrFPElementsAttr::getFromRawBuffer(builtinType,
-- attr.getData());
-+ return DenseIntOrFPElementsAttr::getFromRawBuffer(
-+ cast<ShapedType>(builtinType), attr.getData());
- }
- if (auto attr = vhloAttr.dyn_cast<vhlo::TransposeV1Attr>()) {
- RETURN_CONVERTED_ENUM_ATTR(Transpose, V1);
diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl
index 7ad4ab5ee33..d6d2f3b10a1 100644
--- a/third_party/stablehlo/workspace.bzl
+++ b/third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
def repo():
# LINT.IfChange
- STABLEHLO_COMMIT = "45a85ebd8afcc67429d7158c25af2381e80f74f9"
- STABLEHLO_SHA256 = "33163b1aeac7495532b212378b7909d61b30afda99a78a57771a1761413451b9"
+ STABLEHLO_COMMIT = "43d81c6883ade82052920bd367c61f9e52f09954"
+ STABLEHLO_SHA256 = "57a8a93e51211f990d760631f2bfdbba5257b22dda3d60e35a186bba988a2ace"
# LINT.ThenChange(Google-internal path)
tf_http_archive(