diff options
author | Jian Cai <jiancai@google.com> | 2023-02-01 15:35:07 -0800 |
---|---|---|
committer | TensorFlow Release Automation <jenkins@tensorflow.org> | 2023-02-06 22:45:27 +0000 |
commit | 53fccb429a53b52ed03de4efcda45afe1cee924e (patch) | |
tree | 550f99b5b31c40ff8fc5b22005895193cc9a2a45 | |
parent | 10e2e7cf0d13ef026ccbd90ea283ffbf33159703 (diff) | |
download | tensorflow-upstream-r2.11-728113a3be6.tar.gz |
[Tensorflow] Fix security vulnerability with TensorListSplitOpupstream-r2.11-728113a3be6
PiperOrigin-RevId: 506441188
-rw-r--r-- | tensorflow/compiler/tests/tensor_list_ops_test.py | 11 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc | 2 |
2 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 659d9f41e8d..3c9b29b1835 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -236,6 +236,17 @@ class ListOpsTest(parameterized.TestCase, xla_test.XLATestCase): self.assertAllEqual(z.shape.as_list(), [None]) self.assertAllEqual(z, [0.0, 0.0]) + def testInvalidSplitLength(self): + with self.session(), self.test_scope(): + tensor_list_split = list_ops.tensor_list_split( + tensor=[1], element_shape=[-1], lengths=[0] + ) + with self.assertRaisesRegex( + errors.UnimplementedError, "All lengths must be positive" + ): + self.evaluate(tensor_list_split) + + if __name__ == "__main__": os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " + os.environ.get("TF_XLA_FLAGS", "")) diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 980ca07e117..5d299fde600 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -553,6 +553,8 @@ class TensorListSplitOp : public XlaOpKernel { OP_REQUIRES(ctx, len == length, errors::Unimplemented("All lengths have to be the same")); } + OP_REQUIRES(ctx, length, + errors::Unimplemented("All lengths must be positive")); OP_REQUIRES( ctx, element_dims[0] % length == 0, errors::Unimplemented("Buffer size has to be a multiple of length")); |