aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTayo Oguntebi <tayo@google.com>2021-02-17 15:39:17 -0800
committerTensorFlower Gardener <gardener@tensorflow.org>2021-02-17 15:46:32 -0800
commit24f8591d6ed36a61470d4132e949b282243b0d23 (patch)
tree4df0b48b942f03fdc253036cc6495522ed23c40f
parent607ffbc56cf054a02b86d05e232cb6640e11519d (diff)
downloadtensorflow-24f8591d6ed36a61470d4132e949b282243b0d23.tar.gz
Sets num_tasks value for use in TPUReplicate rewrite.
PiperOrigin-RevId: 358048719 Change-Id: Ic396a70ab1bd14d760fa120548bb01dc3a2dbe56
-rw-r--r--tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc1
1 files changed, 1 insertions, 0 deletions
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
index aed0add2be3..a183c3dc522 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
@@ -3934,6 +3934,7 @@ DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes(
TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
device_set, tpu_compilation_device,
&num_tpus_per_task, &tpu_devices));
+ *num_tasks = tpu_devices.size();
string topology;
TF_RETURN_IF_ERROR(