diff options
author | Mihai Maruseac <mihaimaruseac@google.com> | 2022-08-19 13:15:38 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-19 13:15:38 -0700 |
commit | be8b86aff1e2233af7323ae04dba521930dbd41c (patch) | |
tree | 5d7394fefa415425115e0ef264e64ae28f8048f3 | |
parent | 62a552b32ee74dff337d3e366f046834f2ae9129 (diff) | |
parent | da380c7d3281808c9076c7ca6e917b50feba99cb (diff) | |
download | tensorflow-be8b86aff1e2233af7323ae04dba521930dbd41c.tar.gz |
Merge pull request #57282 from tensorflow/r2.7-72180be0344
r2.7 cherry-pick: 72180be0344 "Fix tensor shape dtype bug in parameterized_truncated_normal."
-rw-r--r-- | tensorflow/core/kernels/parameterized_truncated_normal_op.cc | 13 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py | 23 |
2 files changed, 29 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc index 24b7e3f4ebd..a007d37c4e2 100644 --- a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc +++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/kernels/stateless_random_ops.h" #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/platform/logging.h" @@ -630,20 +631,18 @@ class ParameterizedTruncatedNormalOp : public OpKernel { OP_REQUIRES(ctx, shape_tensor.NumElements() > 0, errors::InvalidArgument("Shape tensor must not be empty, got ", shape_tensor.DebugString())); - int32_t num_batches = shape_tensor.flat<int32>()(0); + TensorShape tensor_shape; + OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_tensor, &tensor_shape)); + int32_t num_batches = tensor_shape.dim_size(0); int32_t samples_per_batch = 1; - const int32_t num_dims = shape_tensor.dim_size(0); + const int32_t num_dims = tensor_shape.dims(); for (int32_t i = 1; i < num_dims; i++) { - samples_per_batch *= shape_tensor.flat<int32>()(i); + samples_per_batch *= tensor_shape.dim_size(i); } const int32_t num_elements = num_batches * samples_per_batch; // Allocate the output before fudging num_batches and samples_per_batch. - auto shape_vec = shape_tensor.flat<int32>(); - TensorShape tensor_shape; - OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape( - shape_vec.data(), shape_vec.size(), &tensor_shape)); Tensor* samples_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tensor_shape, &samples_tensor)); diff --git a/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py b/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py index 5ec054f6bae..865c1d7c5fa 100644 --- a/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py +++ b/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py @@ -307,6 +307,29 @@ class ParameterizedTruncatedNormalTest(test.TestCase): self.assertAllGreater(samples, 0.) self.assertAllGreater(samples_stateless, 0.) + def testShapeTypes(self): + for shape_dtype in [np.int32, np.int64]: + shape = np.array([1000], dtype=shape_dtype) + sample_op = random_ops.parameterized_truncated_normal( + shape=shape, means=0.0, stddevs=0.1, minvals=-1., maxvals=1.) + new_seed = random_ops.random_uniform([2], + seed=1234, + minval=0, + maxval=(2**31 - 1), + dtype=np.int32) + sample_op_stateless = stateless.stateless_parameterized_truncated_normal( + shape=shape, + seed=new_seed, + means=0.0, + stddevs=0.1, + minvals=-1., + maxvals=1.) + + samples = self.evaluate(sample_op) + stateless_samples = self.evaluate(sample_op_stateless) + self.assertAllEqual(samples.shape, shape) + self.assertAllEqual(stateless_samples.shape, shape) + def testStatelessParameterizedTruncatedNormalHasGrads(self): mean = variables.Variable(0.01) stddev = variables.Variable(1.) |