Skip to content

Commit

Permalink
get max_mem on all nodes (#62853)
Browse files Browse the repository at this point in the history
  • Loading branch information
Difers authored Apr 17, 2024
1 parent e082e68 commit 4da4005
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions python/paddle/distributed/launch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,32 @@ def launch():
# Seed the configured metric key so downstream reporting can always rely
# on its presence, even when this trial produced no measurement.
if tuner_cfg['metric_cfg']['name'] not in cur_cfg:
    cur_cfg[tuner_cfg['metric_cfg']['name']] = None

# Multi-node runs: publish this node's peak memory usage to the shared
# KV store (etcd-style client — presumably etcd3; verify against the
# client's construction elsewhere in this file) and gather every node's
# value so cur_cfg["max_mem_usage"] holds the cluster-wide maximum.
path = f"auto_tuner/mem/{job_id}/{ip}"
if nnodes > 1:
    # Retry the publish until the store accepts it.  str(None) is
    # written when this node has no reading, so readers must expect
    # the literal string "None".
    while not client.put(
        path, str(cur_cfg["max_mem_usage"]).encode('latin-1')
    ):
        time.sleep(1)
    # Trailing slash keeps the scan scoped to THIS job id; a bare
    # prefix would also match job ids that merely start with ours
    # (e.g. job "12" matching "123/...") and corrupt the node count.
    prefix = f"auto_tuner/mem/{job_id}/"
    result = list(client.get_prefix(prefix))
    # Block until every node has reported its memory value.
    while len(result) != nnodes:
        time.sleep(1)
        result = list(client.get_prefix(prefix))
    # get_prefix yields (value_bytes, metadata) pairs; keep the values.
    mem_allnodes = [value.decode() for value, _ in result]

    for mem in mem_allnodes:
        # Skip nodes that reported no reading — both a literal None
        # and the "None" string produced by str(None) on the put side.
        if mem is None or mem == "None":
            continue
        # Any single OOM dominates: the trial as a whole ran out of
        # memory, so record it and stop folding.
        if mem == "OOM":
            cur_cfg["max_mem_usage"] = mem
            break
        local = cur_cfg["max_mem_usage"]
        if local is None or local == "None":
            # No local reading yet: adopt the remote value directly
            # (the original int(None) here would have raised).
            cur_cfg["max_mem_usage"] = int(mem)
        elif local == "OOM":
            # Local OOM already dominates every numeric remote value.
            break
        else:
            cur_cfg["max_mem_usage"] = max(int(mem), int(local))

# if need accurate peak memory
if os.environ.get("FLAGS_log_memory_stats", False):
max_peak_memory = None
Expand Down

0 comments on commit 4da4005

Please sign in to comment.