aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJian Cai <jiancai@google.com>2023-02-01 15:35:07 -0800
committerTensorFlow Release Automation <jenkins@tensorflow.org>2023-02-06 22:45:27 +0000
commit53fccb429a53b52ed03de4efcda45afe1cee924e (patch)
tree550f99b5b31c40ff8fc5b22005895193cc9a2a45
parent10e2e7cf0d13ef026ccbd90ea283ffbf33159703 (diff)
downloadtensorflow-upstream-r2.11-728113a3be6.tar.gz
[Tensorflow] Fix security vulnerability with TensorListSplitOpupstream-r2.11-728113a3be6
PiperOrigin-RevId: 506441188
-rw-r--r--tensorflow/compiler/tests/tensor_list_ops_test.py11
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc2
2 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py
index 659d9f41e8d..3c9b29b1835 100644
--- a/tensorflow/compiler/tests/tensor_list_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_list_ops_test.py
@@ -236,6 +236,17 @@ class ListOpsTest(parameterized.TestCase, xla_test.XLATestCase):
self.assertAllEqual(z.shape.as_list(), [None])
self.assertAllEqual(z, [0.0, 0.0])
+ def testInvalidSplitLength(self):
+ with self.session(), self.test_scope():
+ tensor_list_split = list_ops.tensor_list_split(
+ tensor=[1], element_shape=[-1], lengths=[0]
+ )
+ with self.assertRaisesRegex(
+ errors.UnimplementedError, "All lengths must be positive"
+ ):
+ self.evaluate(tensor_list_split)
+
+
if __name__ == "__main__":
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
os.environ.get("TF_XLA_FLAGS", ""))
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
index 980ca07e117..5d299fde600 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
@@ -553,6 +553,8 @@ class TensorListSplitOp : public XlaOpKernel {
OP_REQUIRES(ctx, len == length,
errors::Unimplemented("All lengths have to be the same"));
}
+ OP_REQUIRES(ctx, length,
+ errors::Unimplemented("All lengths must be positive"));
OP_REQUIRES(
ctx, element_dims[0] % length == 0,
errors::Unimplemented("Buffer size has to be a multiple of length"));