From 53fccb429a53b52ed03de4efcda45afe1cee924e Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Wed, 1 Feb 2023 15:35:07 -0800 Subject: [Tensorflow] Fix security vulnerability with TensorListSplitOp PiperOrigin-RevId: 506441188 --- tensorflow/compiler/tests/tensor_list_ops_test.py | 11 +++++++++++ tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc | 2 ++ 2 files changed, 13 insertions(+) diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 659d9f41e8d..3c9b29b1835 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -236,6 +236,17 @@ class ListOpsTest(parameterized.TestCase, xla_test.XLATestCase): self.assertAllEqual(z.shape.as_list(), [None]) self.assertAllEqual(z, [0.0, 0.0]) + def testInvalidSplitLength(self): + with self.session(), self.test_scope(): + tensor_list_split = list_ops.tensor_list_split( + tensor=[1], element_shape=[-1], lengths=[0] + ) + with self.assertRaisesRegex( + errors.UnimplementedError, "All lengths must be positive" + ): + self.evaluate(tensor_list_split) + + if __name__ == "__main__": os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " + os.environ.get("TF_XLA_FLAGS", "")) diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 980ca07e117..5d299fde600 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -553,6 +553,8 @@ class TensorListSplitOp : public XlaOpKernel { OP_REQUIRES(ctx, len == length, errors::Unimplemented("All lengths have to be the same")); } + OP_REQUIRES(ctx, length, + errors::Unimplemented("All lengths must be positive")); OP_REQUIRES( ctx, element_dims[0] % length == 0, errors::Unimplemented("Buffer size has to be a multiple of length")); -- cgit v1.2.3