diff options
author | Justin Szaday <twelve@google.com> | 2023-01-18 11:44:11 -0800 |
---|---|---|
committer | unda <4316856+fcoUnda@users.noreply.github.com> | 2023-02-03 21:41:55 +0000 |
commit | c7f0b8e334061ab86c3d45d46ea397b17c34b3e4 (patch) | |
tree | 700e3baba63fd9850010d0f039d0ebce05ac4c35 | |
parent | 5968b6b37ee986ad563d9bae2a995aae8c9f6bea (diff) | |
download | tensorflow-c7f0b8e334061ab86c3d45d46ea397b17c34b3e4.tar.gz |
Add bounds-checking to `TPUPartitionedInput` and `TPUPartitionedOutput` ops.
PiperOrigin-RevId: 502936861
-rw-r--r-- | tensorflow/core/tpu/ops/tpu_partitioned_input_op.cc | 104 | ||||
-rw-r--r-- | tensorflow/core/tpu/ops/tpu_partitioned_output_op.cc | 79 |
2 files changed, 177 insertions, 6 deletions
diff --git a/tensorflow/core/tpu/ops/tpu_partitioned_input_op.cc b/tensorflow/core/tpu/ops/tpu_partitioned_input_op.cc index 1a185ad2107..c905664cac9 100644 --- a/tensorflow/core/tpu/ops/tpu_partitioned_input_op.cc +++ b/tensorflow/core/tpu/ops/tpu_partitioned_input_op.cc @@ -45,7 +45,30 @@ REGISTER_OP("TPUPartitionedInput") int partition_dim; TF_RETURN_IF_ERROR(c->GetAttr("partition_dim", &partition_dim)); + if (c->num_inputs() == 0) { + return errors::InvalidArgument( + "Expected at least one input to TPUPartitionedInput."); + } + ShapeHandle cur = c->input(c->num_inputs() - 1); + int rank = InferenceContext::kUnknownRank; + if (dtype == DT_RESOURCE) { + auto* shapes_and_types = + c->input_handle_shapes_and_types(c->num_inputs() - 1); + if (shapes_and_types) { + ShapeHandle shape_handle = shapes_and_types->at(0).shape; + rank = InferenceContext::Rank(shape_handle); + } + } else { + rank = InferenceContext::Rank(cur); + } + + // limitation: can only validate rank when it is known + if ((rank != InferenceContext::kUnknownRank && partition_dim >= rank) || + (partition_dim < -1)) + return errors::InvalidArgument("Cannot partition dim ", partition_dim, + " of rank ", rank, " tensor."); + for (int i = c->num_inputs() - 2; i >= 0; --i) { TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), "From merging shape ", i, @@ -101,4 +124,85 @@ REGISTER_OP("TPUPartitionedInput") return OkStatus(); }); +REGISTER_OP("TPUPartitionedInputV2") + .Input("inputs: N * T") + .Output("output: T") + .Attr("N: int >= 1") + .Attr("T: type") + .Attr("partition_dims: list(int)") + .Attr("is_packed: bool = false") + .SetShapeFn([](InferenceContext* c) { + DataType dtype; + TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype)); + std::vector<int> partition_dims; + TF_RETURN_IF_ERROR(c->GetAttr("partition_dims", &partition_dims)); + bool is_packed; + TF_RETURN_IF_ERROR(c->GetAttr("is_packed", &is_packed)); + + int num_partitions = 1; + for (const int& partition_dim : partition_dims) { + num_partitions *= partition_dim; + } + + bool replicated = partition_dims.empty(); + int num_inputs_expected = is_packed ? 1 : num_partitions; + if (!((replicated && !is_packed) || + (c->num_inputs() == num_inputs_expected))) { + // we cannot validate the number of inputs for replicated, unpacked ops + // since we cannot infer the number of partitions from partition_dims + return errors::InvalidArgument("Expected ", num_inputs_expected, + " inputs, got ", c->num_inputs(), "."); + } else if (c->num_inputs() == 0) { + return errors::InvalidArgument( + "Expected at least one input to TPUPartitionedInputV2."); + } + + ShapeHandle output_shape; + if (dtype == DT_RESOURCE) { + ShapeHandle previous_shape_handle; + const std::vector<shape_inference::ShapeAndType>* shapes_and_types = + nullptr; + for (int i = c->num_inputs() - 1; i >= 0; --i) { + shapes_and_types = c->input_handle_shapes_and_types(i); + if (shapes_and_types) { + ShapeHandle shape_handle = shapes_and_types->at(0).shape; + if (!c->FullyDefined(shape_handle)) { + return errors::InvalidArgument("Inputs must have static shape,", + "input[", i, + "] has unknown dimension."); + } + + if (i != c->num_inputs() - 1) { + ShapeHandle tmp; + if (!c->Merge(shape_handle, previous_shape_handle, &tmp).ok()) { + return errors::InvalidArgument( + "Inputs must have the same shape."); + } + } else { + previous_shape_handle = shape_handle; + } + } + } + + if (shapes_and_types) { + TF_ASSIGN_OR_RETURN( + output_shape, + _ComputeOutputShape(c, previous_shape_handle, partition_dims)); + std::vector<shape_inference::ShapeAndType> output_shapes_and_types; + output_shapes_and_types.push_back(shape_inference::ShapeAndType( + output_shape, shapes_and_types->at(0).dtype)); + c->set_output_handle_shapes_and_types(0, output_shapes_and_types); + } + } + + if (!c->FullyDefined(output_shape)) { + TF_ASSIGN_OR_RETURN( + output_shape, _ComputeOutputShape(c, c->input(0), partition_dims)); + } + + c->set_output(0, output_shape); + + return OkStatus(); + }); + } // namespace tensorflow diff --git a/tensorflow/core/tpu/ops/tpu_partitioned_output_op.cc b/tensorflow/core/tpu/ops/tpu_partitioned_output_op.cc index 9ab183ad6d3..da9fd60923c 100644 --- a/tensorflow/core/tpu/ops/tpu_partitioned_output_op.cc +++ b/tensorflow/core/tpu/ops/tpu_partitioned_output_op.cc @@ -38,19 +38,86 @@ REGISTER_OP("TPUPartitionedOutput") TF_RETURN_IF_ERROR(c->GetAttr("num_splits", &num_splits)); if (dtype == DT_RESOURCE) { return errors::Unimplemented("Not implemented."); + } else if (c->num_inputs() == 0) { + return errors::InvalidArgument( + "Expected at least one input to TPUPartitionedOutput."); } ShapeHandle input = c->input(0); + int rank = InferenceContext::Rank(input); + // limitation: can only validate rank when it is known + if ((rank != InferenceContext::kUnknownRank && partition_dim >= rank) || + (partition_dim < -1)) + return errors::InvalidArgument("Cannot partition dim ", partition_dim, + " of rank ", rank, " tensor."); + ShapeHandle newoutput0; - shape_inference::DimensionHandle new_dim; - TF_RETURN_WITH_CONTEXT_IF_ERROR( - c->Divide(c->Dim(input, partition_dim), num_splits, - true /* evenly_divisible */, &new_dim), - "Number of ways to split should evenly divide the split dimension"); - TF_CHECK_OK(c->ReplaceDim(input, partition_dim, new_dim, &newoutput0)); + if (partition_dim == -1) { + newoutput0 = input; // replicated input/output share shapes + } else { + shape_inference::DimensionHandle new_dim; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + c->Divide(c->Dim(input, partition_dim), num_splits, + true /* evenly_divisible */, &new_dim), + "Number of ways to split should evenly divide the split dimension"); + TF_CHECK_OK(c->ReplaceDim(input, partition_dim, new_dim, &newoutput0)); + } + for (int i = num_splits - 1; i >= 0; --i) { c->set_output(i, newoutput0); } + + return OkStatus(); + }); + +REGISTER_OP("TPUPartitionedOutputV2") + .Input("inputs: T") + .Output("output: num_splits * T") + .Attr("T: type") + .Attr("num_splits: int >= 1") + .Attr("partition_dims: list(int)") + .SetShapeFn([](InferenceContext* c) { + DataType dtype; + TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype)); + std::vector<int> partition_dims; + TF_RETURN_IF_ERROR(c->GetAttr("partition_dims", &partition_dims)); + int num_splits; + TF_RETURN_IF_ERROR(c->GetAttr("num_splits", &num_splits)); + if (dtype == DT_RESOURCE) { + return errors::Unimplemented("Not implemented."); + } else if (c->num_inputs() == 0) { + return errors::InvalidArgument( + "Expected at least one input to TPUPartitionedOutputV2."); + } + + ShapeHandle handle = c->input(0); + int rank = InferenceContext::Rank(handle); + int num_cores_per_replica = 1; + for (const int& partition_dim : partition_dims) { + num_cores_per_replica *= partition_dim; + } + + if (num_splits != num_cores_per_replica) { + return errors::InvalidArgument("Expected ", num_cores_per_replica, + " splits."); + } else if (rank > (int)partition_dims.size()) { + return errors::InvalidArgument("Expected at least ", rank, + " partition dimensions."); + } + + for (int i = 0; i < rank; ++i) { + shape_inference::DimensionHandle dim; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + c->Divide(c->Dim(handle, i), partition_dims[i], + true /* evenly_divisible */, &dim), + "Number of ways to split should evenly divide the split dimension"); + TF_CHECK_OK(c->ReplaceDim(handle, i, dim, &handle)); + } + + for (int i = num_splits - 1; i >= 0; --i) { + c->set_output(i, handle); + } + return OkStatus(); }); |