author     Chandra Devarakonda <chandrasekhard@google.com>  2024-02-22 12:58:18 -0800
committer  TensorFlow Release Automation <jenkins@tensorflow.org>  2024-02-28 22:32:56 +0000
commit     1a19b6e0f2507099c4c003785371f7e4f63f5c70 (patch)
tree       1a2f72e2c3659c2a591cc1a95471250263bbed44
parent     4bdc149ac84738b06a592473595c1c9c2bd2a9a3 (diff)
download   tensorflow-upstream-r2.16-5e39a976964.tar.gz

Update grpc_tpu_worker.py file (upstream-r2.16-5e39a976964)
PiperOrigin-RevId: 609469881
-rw-r--r--  tensorflow/python/tools/grpc_tpu_worker.py | 40
1 file changed, 17 insertions(+), 23 deletions(-)
diff --git a/tensorflow/python/tools/grpc_tpu_worker.py b/tensorflow/python/tools/grpc_tpu_worker.py
index 154fcaeacf0..e8cf2a59c5b 100644
--- a/tensorflow/python/tools/grpc_tpu_worker.py
+++ b/tensorflow/python/tools/grpc_tpu_worker.py
@@ -61,7 +61,11 @@ def setup_env_vars():
os.environ['TPU_STDERR_LOG_LEVEL'] = '0'
os.environ['CLOUD_TPU_TASK_ID'] = worker_id
os.environ['TPU_LOCK_DEVICE'] = 'true'
- os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
+ os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = (
+ worker_network_endpoints.split(',')[0].split(':')[2] + ':8476'
+ )
+ os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476'
+
accelerator_type_to_host_bounds = {
# v2
'v2-8': '1,1,1',
@@ -78,29 +82,16 @@ def setup_env_vars():
'v3-512': '8,8,1',
'v3-1024': '8,16,1',
'v3-2048': '16,16,1',
- # v4
- 'v4-8': '1,1,1',
- 'v4-16': '1,1,2',
- 'v4-32': '1,1,4',
- 'v4-64': '1,2,4',
- 'v4-128': '2,2,4',
- 'v4-256': '2,2,8',
- 'v4-512': '2,4,8',
- 'v4-1024': '4,4,8',
- 'v4-2048': '4,4,16',
- 'v4-4096': '4,8,16',
- }
- os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
- accelerator_type]
- os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = worker_network_endpoints.split(
- ',')[0].split(':')[2] + ':8476'
- os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476'
-
- os.environ['TPU_STDERR_LOG_LEVEL'] = '0'
+ }
- if accelerator_type not in ['v4-8', 'v4-16', 'v4-32', 'v4-64']:
- os.environ['TPU_TOPOLOGY_WRAP'] = 'true,true,true'
+ # For v4/v5 TPUs, don't set any topology-related flags;
+ # libtpu will set these values.
+ if not (accelerator_type.startswith('v4-') or
+ accelerator_type.startswith('v5')):
+ os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
+ os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
+ accelerator_type]
# Set the hostname override.
os.environ['TPU_HOSTNAME_OVERRIDE'] = get_host_ip()
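
Taken together, the two hunks above leave setup_env_vars() setting the mesh-controller variables for every accelerator type, while TPU_CHIPS_PER_HOST_BOUNDS and TPU_HOST_BOUNDS are now set only for pre-v4 types (v4/v5 leave them to libtpu). Below is a minimal sketch of the resulting logic; the worker_network_endpoints format ('worker-id:uid:ip-address', comma-separated per worker) and the sample values are assumptions for illustration, not taken from this commit.

import os

def set_tpu_env_vars(worker_id, accelerator_type, worker_network_endpoints):
  """Sketch of the post-commit env-var logic (names taken from the diff)."""
  os.environ['CLOUD_TPU_TASK_ID'] = worker_id
  # Mesh-controller address/port are set unconditionally; the endpoint format
  # is assumed here to be 'worker-id:uid:ip-address' per worker.
  os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = (
      worker_network_endpoints.split(',')[0].split(':')[2] + ':8476'
  )
  os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476'

  accelerator_type_to_host_bounds = {
      'v2-8': '1,1,1',
      'v3-2048': '16,16,1',
      # ... remaining v2/v3 entries elided; the v4 table was removed by this commit.
  }
  # For v4/v5, leave the topology flags unset so libtpu can pick them.
  if not (accelerator_type.startswith('v4-') or
          accelerator_type.startswith('v5')):
    os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
    os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[accelerator_type]

For example, set_tpu_env_vars('0', 'v4-8', '0:abcd1234:10.164.0.2') (assumed values) would set TPU_MESH_CONTROLLER_ADDRESS to '10.164.0.2:8476' and leave both host-bounds variables unset.
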
@@ -111,7 +102,10 @@ def main(unused_args):
server_def = tensorflow_server_pb2.ServerDef(protocol='grpc')
job_def = server_def.cluster.job.add()
job_def.name = 'tpu_worker'
- job_def.tasks[0] = 'localhost:8470'
+ tpu_task_port = os.getenv('TPU_TASK_PORT')
+ if tpu_task_port is None or not tpu_task_port:
+ tpu_task_port = '8470' # If TPU task port is not available, use 8470.
+ job_def.tasks[0] = 'localhost:' + tpu_task_port
server_def.job_name = 'tpu_worker'
server_def.task_index = 0
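
The main() hunk makes the gRPC task port configurable through a TPU_TASK_PORT environment variable, falling back to 8470 when it is unset or empty. A minimal standalone sketch of that fallback (variable name from the diff, everything else assumed):

import os

def resolve_tpu_task_address(default_port='8470'):
  # Prefer TPU_TASK_PORT when it is set to a non-empty value; otherwise fall
  # back to the historical default port 8470 used by the TPU worker.
  tpu_task_port = os.getenv('TPU_TASK_PORT')
  if not tpu_task_port:
    tpu_task_port = default_port
  return 'localhost:' + tpu_task_port

With TPU_TASK_PORT unset this yields 'localhost:8470'; with TPU_TASK_PORT=8471 it yields 'localhost:8471'.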