diff options
author | Chandra Devarakonda <chandrasekhard@google.com> | 2024-02-22 12:58:18 -0800 |
---|---|---|
committer | TensorFlow Release Automation <jenkins@tensorflow.org> | 2024-02-28 22:32:56 +0000 |
commit | 1a19b6e0f2507099c4c003785371f7e4f63f5c70 (patch) | |
tree | 1a2f72e2c3659c2a591cc1a95471250263bbed44 | |
parent | 4bdc149ac84738b06a592473595c1c9c2bd2a9a3 (diff) | |
download | tensorflow-upstream-r2.16-5e39a976964.tar.gz |
Update grpc_tpu_worker.py file
upstream-r2.16-5e39a976964
PiperOrigin-RevId: 609469881
-rw-r--r-- | tensorflow/python/tools/grpc_tpu_worker.py | 40 |
1 files changed, 17 insertions, 23 deletions
```diff
diff --git a/tensorflow/python/tools/grpc_tpu_worker.py b/tensorflow/python/tools/grpc_tpu_worker.py
index 154fcaeacf0..e8cf2a59c5b 100644
--- a/tensorflow/python/tools/grpc_tpu_worker.py
+++ b/tensorflow/python/tools/grpc_tpu_worker.py
@@ -61,7 +61,11 @@ def setup_env_vars():
   os.environ['TPU_STDERR_LOG_LEVEL'] = '0'
   os.environ['CLOUD_TPU_TASK_ID'] = worker_id
   os.environ['TPU_LOCK_DEVICE'] = 'true'
-  os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
+  os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = (
+      worker_network_endpoints.split(',')[0].split(':')[2] + ':8476'
+  )
+  os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476'
+
   accelerator_type_to_host_bounds = {
       # v2
       'v2-8': '1,1,1',
@@ -78,29 +82,16 @@ def setup_env_vars():
       'v3-512': '8,8,1',
       'v3-1024': '8,16,1',
       'v3-2048': '16,16,1',
-      # v4
-      'v4-8': '1,1,1',
-      'v4-16': '1,1,2',
-      'v4-32': '1,1,4',
-      'v4-64': '1,2,4',
-      'v4-128': '2,2,4',
-      'v4-256': '2,2,8',
-      'v4-512': '2,4,8',
-      'v4-1024': '4,4,8',
-      'v4-2048': '4,4,16',
-      'v4-4096': '4,8,16',
-  }
-  os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
-      accelerator_type]
-  os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = worker_network_endpoints.split(
-      ',')[0].split(':')[2] + ':8476'
-  os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476'
-
-  os.environ['TPU_STDERR_LOG_LEVEL'] = '0'
+  }
 
-  if accelerator_type not in ['v4-8', 'v4-16', 'v4-32', 'v4-64']:
-    os.environ['TPU_TOPOLOGY_WRAP'] = 'true,true,true'
+  # If v4 TPU don't set any topology related flags,
+  # libtpu will set these values.
+  if not (accelerator_type.startswith('v4-') or
+          accelerator_type.startswith('v5')):
+    os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
+    os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
+        accelerator_type]
 
   # Set the hostname override.
   os.environ['TPU_HOSTNAME_OVERRIDE'] = get_host_ip()
 
@@ -111,7 +102,10 @@ def main(unused_args):
   server_def = tensorflow_server_pb2.ServerDef(protocol='grpc')
   job_def = server_def.cluster.job.add()
   job_def.name = 'tpu_worker'
-  job_def.tasks[0] = 'localhost:8470'
+  tpu_task_port = os.getenv('TPU_TASK_PORT')
+  if tpu_task_port is None or not tpu_task_port:
+    tpu_task_port = '8470'  # If TPU task port is not available, use 8470.
+  job_def.tasks[0] = 'localhost:' + tpu_task_port
   server_def.job_name = 'tpu_worker'
   server_def.task_index = 0
 
```