about summary refs log tree commit diff
diff options
context:
space:
mode:
author: Justin Szaday <twelve@google.com> 2023-01-18 11:44:11 -0800
committer: unda <4316856+fcoUnda@users.noreply.github.com> 2023-02-03 21:41:55 +0000
commitc7f0b8e334061ab86c3d45d46ea397b17c34b3e4 (patch)
tree700e3baba63fd9850010d0f039d0ebce05ac4c35
parent5968b6b37ee986ad563d9bae2a995aae8c9f6bea (diff)
downloadtensorflow-c7f0b8e334061ab86c3d45d46ea397b17c34b3e4.tar.gz
Add bounds-checking to `TPUPartitionedInput` and `TPUPartitionedOutput` ops.
PiperOrigin-RevId: 502936861
-rw-r--r-- tensorflow/core/tpu/ops/tpu_partitioned_input_op.cc | 104
-rw-r--r-- tensorflow/core/tpu/ops/tpu_partitioned_output_op.cc | 79
2 files changed, 177 insertions(+), 6 deletions(-)
diff --git a/tensorflow/core/tpu/ops/tpu_partitioned_input_op.cc b/tensorflow/core/tpu/ops/tpu_partitioned_input_op.cc
index 1a185ad2107..c905664cac9 100644
--- a/tensorflow/core/tpu/ops/tpu_partitioned_input_op.cc
+++ b/tensorflow/core/tpu/ops/tpu_partitioned_input_op.cc
@@ -45,7 +45,30 @@ REGISTER_OP("TPUPartitionedInput")
int partition_dim;
TF_RETURN_IF_ERROR(c->GetAttr("partition_dim", &partition_dim));
+ if (c->num_inputs() == 0) {
+ return errors::InvalidArgument(
+ "Expected at least one input to TPUPartitionedInput.");
+ }
+
ShapeHandle cur = c->input(c->num_inputs() - 1);
+ int rank = InferenceContext::kUnknownRank;
+ if (dtype == DT_RESOURCE) {
+ auto* shapes_and_types =
+ c->input_handle_shapes_and_types(c->num_inputs() - 1);
+ if (shapes_and_types) {
+ ShapeHandle shape_handle = shapes_and_types->at(0).shape;
+ rank = InferenceContext::Rank(shape_handle);
+ }
+ } else {
+ rank = InferenceContext::Rank(cur);
+ }
+
+ // limitation: can only validate rank when it is known
+ if ((rank != InferenceContext::kUnknownRank && partition_dim >= rank) ||
+ (partition_dim < -1))
+ return errors::InvalidArgument("Cannot partition dim ", partition_dim,
+ " of rank ", rank, " tensor.");
+
for (int i = c->num_inputs() - 2; i >= 0; --i) {
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
"From merging shape ", i,
@@ -101,4 +124,85 @@ REGISTER_OP("TPUPartitionedInput")
return OkStatus();
});
+REGISTER_OP("TPUPartitionedInputV2")
+ .Input("inputs: N * T")
+ .Output("output: T")
+ .Attr("N: int >= 1")
+ .Attr("T: type")
+ .Attr("partition_dims: list(int)")
+ .Attr("is_packed: bool = false")
+ .SetShapeFn([](InferenceContext* c) {
+ DataType dtype;
+ TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
+ std::vector<int> partition_dims;
+ TF_RETURN_IF_ERROR(c->GetAttr("partition_dims", &partition_dims));
+ bool is_packed;
+ TF_RETURN_IF_ERROR(c->GetAttr("is_packed", &is_packed));
+
+ int num_partitions = 1;
+ for (const int& partition_dim : partition_dims) {
+ num_partitions *= partition_dim;
+ }
+
+ bool replicated = partition_dims.empty();
+ int num_inputs_expected = is_packed ? 1 : num_partitions;
+ if (!((replicated && !is_packed) ||
+ (c->num_inputs() == num_inputs_expected))) {
+ // we cannot validate the number of inputs for replicated, unpacked ops
+ // since we cannot infer the number of partitions from partition_dims
+ return errors::InvalidArgument("Expected ", num_inputs_expected,
+ " inputs, got ", c->num_inputs(), ".");
+ } else if (c->num_inputs() == 0) {
+ return errors::InvalidArgument(
+ "Expected at least one input to TPUPartitionedInputV2.");
+ }
+
+ ShapeHandle output_shape;
+ if (dtype == DT_RESOURCE) {
+ ShapeHandle previous_shape_handle;
+ const std::vector<shape_inference::ShapeAndType>* shapes_and_types =
+ nullptr;
+ for (int i = c->num_inputs() - 1; i >= 0; --i) {
+ shapes_and_types = c->input_handle_shapes_and_types(i);
+ if (shapes_and_types) {
+ ShapeHandle shape_handle = shapes_and_types->at(0).shape;
+ if (!c->FullyDefined(shape_handle)) {
+ return errors::InvalidArgument("Inputs must have static shape,",
+ "input[", i,
+ "] has unknown dimension.");
+ }
+
+ if (i != c->num_inputs() - 1) {
+ ShapeHandle tmp;
+ if (!c->Merge(shape_handle, previous_shape_handle, &tmp).ok()) {
+ return errors::InvalidArgument(
+ "Inputs must have the same shape.");
+ }
+ } else {
+ previous_shape_handle = shape_handle;
+ }
+ }
+ }
+
+ if (shapes_and_types) {
+ TF_ASSIGN_OR_RETURN(
+ output_shape,
+ _ComputeOutputShape(c, previous_shape_handle, partition_dims));
+ std::vector<shape_inference::ShapeAndType> output_shapes_and_types;
+ output_shapes_and_types.push_back(shape_inference::ShapeAndType(
+ output_shape, shapes_and_types->at(0).dtype));
+ c->set_output_handle_shapes_and_types(0, output_shapes_and_types);
+ }
+ }
+
+ if (!c->FullyDefined(output_shape)) {
+ TF_ASSIGN_OR_RETURN(
+ output_shape, _ComputeOutputShape(c, c->input(0), partition_dims));
+ }
+
+ c->set_output(0, output_shape);
+
+ return OkStatus();
+ });
+
} // namespace tensorflow
diff --git a/tensorflow/core/tpu/ops/tpu_partitioned_output_op.cc b/tensorflow/core/tpu/ops/tpu_partitioned_output_op.cc
index 9ab183ad6d3..da9fd60923c 100644
--- a/tensorflow/core/tpu/ops/tpu_partitioned_output_op.cc
+++ b/tensorflow/core/tpu/ops/tpu_partitioned_output_op.cc
@@ -38,19 +38,86 @@ REGISTER_OP("TPUPartitionedOutput")
TF_RETURN_IF_ERROR(c->GetAttr("num_splits", &num_splits));
if (dtype == DT_RESOURCE) {
return errors::Unimplemented("Not implemented.");
+ } else if (c->num_inputs() == 0) {
+ return errors::InvalidArgument(
+ "Expected at least one input to TPUPartitionedOutput.");
}
ShapeHandle input = c->input(0);
+ int rank = InferenceContext::Rank(input);
+ // limitation: can only validate rank when it is known
+ if ((rank != InferenceContext::kUnknownRank && partition_dim >= rank) ||
+ (partition_dim < -1))
+ return errors::InvalidArgument("Cannot partition dim ", partition_dim,
+ " of rank ", rank, " tensor.");
+
ShapeHandle newoutput0;
- shape_inference::DimensionHandle new_dim;
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- c->Divide(c->Dim(input, partition_dim), num_splits,
- true /* evenly_divisible */, &new_dim),
- "Number of ways to split should evenly divide the split dimension");
- TF_CHECK_OK(c->ReplaceDim(input, partition_dim, new_dim, &newoutput0));
+ if (partition_dim == -1) {
+ newoutput0 = input; // replicated input/output share shapes
+ } else {
+ shape_inference::DimensionHandle new_dim;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ c->Divide(c->Dim(input, partition_dim), num_splits,
+ true /* evenly_divisible */, &new_dim),
+ "Number of ways to split should evenly divide the split dimension");
+ TF_CHECK_OK(c->ReplaceDim(input, partition_dim, new_dim, &newoutput0));
+ }
+
for (int i = num_splits - 1; i >= 0; --i) {
c->set_output(i, newoutput0);
}
+
+ return OkStatus();
+ });
+
+REGISTER_OP("TPUPartitionedOutputV2")
+ .Input("inputs: T")
+ .Output("output: num_splits * T")
+ .Attr("T: type")
+ .Attr("num_splits: int >= 1")
+ .Attr("partition_dims: list(int)")
+ .SetShapeFn([](InferenceContext* c) {
+ DataType dtype;
+ TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
+ std::vector<int> partition_dims;
+ TF_RETURN_IF_ERROR(c->GetAttr("partition_dims", &partition_dims));
+ int num_splits;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_splits", &num_splits));
+ if (dtype == DT_RESOURCE) {
+ return errors::Unimplemented("Not implemented.");
+ } else if (c->num_inputs() == 0) {
+ return errors::InvalidArgument(
+ "Expected at least one input to TPUPartitionedOutputV2.");
+ }
+
+ ShapeHandle handle = c->input(0);
+ int rank = InferenceContext::Rank(handle);
+ int num_cores_per_replica = 1;
+ for (const int& partition_dim : partition_dims) {
+ num_cores_per_replica *= partition_dim;
+ }
+
+ if (num_splits != num_cores_per_replica) {
+ return errors::InvalidArgument("Expected ", num_cores_per_replica,
+ " splits.");
+ } else if (rank > (int)partition_dims.size()) {
+ return errors::InvalidArgument("Expected at least ", rank,
+ " partition dimensions.");
+ }
+
+ for (int i = 0; i < rank; ++i) {
+ shape_inference::DimensionHandle dim;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ c->Divide(c->Dim(handle, i), partition_dims[i],
+ true /* evenly_divisible */, &dim),
+ "Number of ways to split should evenly divide the split dimension");
+ TF_CHECK_OK(c->ReplaceDim(handle, i, dim, &handle));
+ }
+
+ for (int i = num_splits - 1; i >= 0; --i) {
+ c->set_output(i, handle);
+ }
+
return OkStatus();
});