aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMihai Maruseac <mihaimaruseac@google.com>2022-08-19 13:15:38 -0700
committerGitHub <noreply@github.com>2022-08-19 13:15:38 -0700
commitbe8b86aff1e2233af7323ae04dba521930dbd41c (patch)
tree5d7394fefa415425115e0ef264e64ae28f8048f3
parent62a552b32ee74dff337d3e366f046834f2ae9129 (diff)
parentda380c7d3281808c9076c7ca6e917b50feba99cb (diff)
downloadtensorflow-be8b86aff1e2233af7323ae04dba521930dbd41c.tar.gz
Merge pull request #57282 from tensorflow/r2.7-72180be0344
r2.7 cherry-pick: 72180be0344 "Fix tensor shape dtype bug in parameterized_truncated_normal."
-rw-r--r--tensorflow/core/kernels/parameterized_truncated_normal_op.cc13
-rw-r--r--tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py23
2 files changed, 29 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
index 24b7e3f4ebd..a007d37c4e2 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
@@ -630,20 +631,18 @@ class ParameterizedTruncatedNormalOp : public OpKernel {
OP_REQUIRES(ctx, shape_tensor.NumElements() > 0,
errors::InvalidArgument("Shape tensor must not be empty, got ",
shape_tensor.DebugString()));
- int32_t num_batches = shape_tensor.flat<int32>()(0);
+ TensorShape tensor_shape;
+ OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_tensor, &tensor_shape));
+ int32_t num_batches = tensor_shape.dim_size(0);
int32_t samples_per_batch = 1;
- const int32_t num_dims = shape_tensor.dim_size(0);
+ const int32_t num_dims = tensor_shape.dims();
for (int32_t i = 1; i < num_dims; i++) {
- samples_per_batch *= shape_tensor.flat<int32>()(i);
+ samples_per_batch *= tensor_shape.dim_size(i);
}
const int32_t num_elements = num_batches * samples_per_batch;
// Allocate the output before fudging num_batches and samples_per_batch.
- auto shape_vec = shape_tensor.flat<int32>();
- TensorShape tensor_shape;
- OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
- shape_vec.data(), shape_vec.size(), &tensor_shape));
Tensor* samples_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tensor_shape, &samples_tensor));
diff --git a/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py b/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py
index 5ec054f6bae..865c1d7c5fa 100644
--- a/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py
+++ b/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py
@@ -307,6 +307,29 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
self.assertAllGreater(samples, 0.)
self.assertAllGreater(samples_stateless, 0.)
+ def testShapeTypes(self):
+ for shape_dtype in [np.int32, np.int64]:
+ shape = np.array([1000], dtype=shape_dtype)
+ sample_op = random_ops.parameterized_truncated_normal(
+ shape=shape, means=0.0, stddevs=0.1, minvals=-1., maxvals=1.)
+ new_seed = random_ops.random_uniform([2],
+ seed=1234,
+ minval=0,
+ maxval=(2**31 - 1),
+ dtype=np.int32)
+ sample_op_stateless = stateless.stateless_parameterized_truncated_normal(
+ shape=shape,
+ seed=new_seed,
+ means=0.0,
+ stddevs=0.1,
+ minvals=-1.,
+ maxvals=1.)
+
+ samples = self.evaluate(sample_op)
+ stateless_samples = self.evaluate(sample_op_stateless)
+ self.assertAllEqual(samples.shape, shape)
+ self.assertAllEqual(stateless_samples.shape, shape)
+
def testStatelessParameterizedTruncatedNormalHasGrads(self):
mean = variables.Variable(0.01)
stddev = variables.Variable(1.)