From 02ace469b6b4fb58c6fe15489a1676f2e07eacf7 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Mon, 21 Oct 2024 20:33:43 -0700 Subject: [PATCH] revert some changes --- src/levanter/infra/cli_helpers.py | 37 +++------------- src/levanter/infra/ray_tpu.py | 70 ++++++++++++------------------- 2 files changed, 31 insertions(+), 76 deletions(-) diff --git a/src/levanter/infra/cli_helpers.py b/src/levanter/infra/cli_helpers.py index 1e8828b47..ccd839d89 100644 --- a/src/levanter/infra/cli_helpers.py +++ b/src/levanter/infra/cli_helpers.py @@ -2,7 +2,7 @@ import base64 import os import subprocess -from typing import Any, Dict, List, Optional +from typing import Optional import yaml from google.cloud import storage @@ -59,37 +59,6 @@ def get_git_commit(): return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip() -class DockerRunCommand: - def __init__(self, image_id: str, command: List[str], *, foreground: bool, env: Dict[str, Any], name="levanter"): - self.base_part = [ - "docker", - "run", - "-t" if foreground else "-d", - f"--name={name}", - "--privileged", - "--shm-size=32gb", - "--net=host", - "--init", - "--mount", - "type=volume,source=levanter,target=/home/levanter", - "-v", - "/tmp:/tmp", - ] - - self.env_part: List[str] = [] - self.add_env(env) - - self.cmd_part = [image_id, *command] - - def add_env(self, env: Dict[str, Any]): - for k, v in env.items(): - self.env_part.extend(["-e", k + f"={str(v)}"]) - - @property - def full_cmd(self): - return self.base_part + self.env_part + self.cmd_part - - def make_docker_run_command(image_id, command, *, foreground, env, name="levanter"): docker_command = [ "docker", @@ -106,6 +75,10 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante "/tmp:/tmp", ] + # optionally add multislice env vars (if set by ray runtime env vars) + for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]: + docker_command.extend(["-e", v]) + for k, v in env.items(): docker_command.extend(["-e", k + f"={str(v)}"]) diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index 8788820ed..8c4e7a5ee 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -17,7 +17,7 @@ from ray.exceptions import NodeDiedError, RayError, RaySystemError, RayTaskError, WorkerCrashedError from ray.remote_function import RemoteFunction -from levanter.infra.cli_helpers import DockerRunCommand +from levanter.infra.cli_helpers import make_docker_run_command from levanter.utils.ray_utils import ser_exc_info @@ -63,13 +63,12 @@ class TpuRunError(_TpuRunResult): error: Exception -def run_on_pod(docker_cmd: DockerRunCommand, name: str, tpu_type: str) -> ray.ObjectRef: +def run_on_pod(remote_fn: RemoteFunction | Callable, tpu_type: str) -> ray.ObjectRef: """ Run a remote function on a TPU pod. Args: - docker_cmd: A DockerRunCommand object that holds a docker command to run - name: docker image name + remote_fn: A remote function that takes no arguments tpu_type: The type of TPU to run on, e.g. "v4-32" Returns: @@ -77,14 +76,9 @@ def run_on_pod(docker_cmd: DockerRunCommand, name: str, tpu_type: str) -> ray.Ob """ @ray.remote(resources={f"TPU-{tpu_type}-head": 1}) - def do_run(docker_cmd: DockerRunCommand, name: str) -> _TpuRunResult: + def do_run(remote_fn) -> _TpuRunResult: num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> 4 - def _run_docker(): - run_docker(docker_cmd=docker_cmd.full_cmd, name=name) - - remote_fn = ray.remote(_run_docker) - remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, num_hosts) info = _TpuInfo(tpu_name, "ACTIVE", "TPU") @@ -101,16 +95,15 @@ def _run_docker(): logger.exception("Failed to kill job after primary failure") return _handle_ray_error(info, e) - return do_run.remote(docker_cmd, name) + return do_run.remote(remote_fn) -def run_on_pod_multislice(docker_cmd: DockerRunCommand, name: str, tpu_type: str, num_slices: int) -> ray.ObjectRef: +def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> ray.ObjectRef: """ Run a remote function on multiple TPU slices. Args: - docker_cmd: A DockerRunCommand object that holds a docker command to run - name: docker image name + remote_fn: A remote function that takes no arguments tpu_type: The type of TPU to run on, e.g. "v4-32" num_slices: The number of slices to run @@ -128,7 +121,7 @@ def __init__(self): def get_slice_info(self): return self.pod_name, self.num_hosts, self.ip - def do_run(self, docker_cmd, name, coordinator_ip, slice_id, num_slices) -> _TpuRunResult: + def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResult: port = 8081 mxla_env = { "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}", @@ -137,13 +130,6 @@ def do_run(self, docker_cmd, name, coordinator_ip, slice_id, num_slices) -> _Tpu "MEGASCALE_SLICE_ID": str(slice_id), } - docker_cmd.add_env(mxla_env) - - def _run_docker(): - run_docker(docker_cmd=docker_cmd.full_cmd, name=name) - - remote_fn = ray.remote(_run_docker) - remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, self.num_hosts, env_vars=mxla_env) info = _TpuInfo(tpu_name, "ACTIVE", "TPU") @@ -178,7 +164,7 @@ def _run_docker(): coordinator_ip = slice_infos[0][2] - return [actor.do_run.remote(docker_cmd, name, coordinator_ip, i, num_slices) for i, actor in enumerate(actors)] + return [actor.do_run.remote(remote_fn, coordinator_ip, i, num_slices) for i, actor in enumerate(actors)] def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env): @@ -205,13 +191,12 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env): return remote_fn, tpu_name -def run_on_pod_resumable(docker_cmd, name, tpu_type, max_retries_preemption=1e6, max_retries_failure=10): +def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_retries_failure=10): """ Repeatedly run a function on a TPU pod until it succeeds or a maximum number of retries is reached. Args: - docker_cmd: A DockerRunCommand object that holds a docker command to run - name: docker image name + remote_fn: A remote function that takes no arguments tpu_type: The type of TPU to run on, e.g. "v4-32" max_retries_preemption: The maximum number of times to retry if the job is preempted max_retries_failure: The maximum number of times to retry if the job fails @@ -230,7 +215,7 @@ def run_on_pod_resumable(docker_cmd, name, tpu_type, max_retries_preemption=1e6, attempt += 1 problem = None try: - out = ray.get(run_on_pod(docker_cmd, name, tpu_type)) + out = ray.get(run_on_pod(remote_fn, tpu_type)) except ray.exceptions.RayTaskError as e: problem = e if "preempted" in str(e): @@ -275,14 +260,13 @@ def run_on_pod_resumable(docker_cmd, name, tpu_type, max_retries_preemption=1e6, def run_on_pod_multislice_resumable( - docker_cmd, name, tpu_type, num_slices, max_retries_preemption=1e6, max_retries_failure=10 + remote_fn, tpu_type, num_slices, max_retries_preemption=1e6, max_retries_failure=10 ): """ Repeatedly run a function on a TPU pod until it succeeds or a maximum number of retries is reached. Args: - docker_cmd: A DockerRunCommand object that holds a docker command to run - name: docker image name + remote_fn: A remote function that takes no arguments tpu_type: The type of TPU to run on, e.g. "v4-32" num_slices: The number of slices to run max_retries_preemption: The maximum number of times to retry if the job is preempted @@ -302,7 +286,7 @@ def run_on_pod_multislice_resumable( attempt += 1 problem = None try: - outs = ray.get(run_on_pod_multislice(docker_cmd, name, tpu_type, num_slices)) + outs = ray.get(run_on_pod_multislice(remote_fn, tpu_type, num_slices)) except ray.exceptions.RayTaskError as e: problem = e if "preempted" in str(e): @@ -362,30 +346,28 @@ def _run_command(*args, **kwargs): return subprocess.check_call(args, **kwargs) -def run_docker(docker_cmd, name="levanter"): - _kill_old_container(name) - try: - return _run_command(*docker_cmd) - except subprocess.CalledProcessError as e: - logger.exception("Failed to run docker command") - raise e - - def run_docker_on_pod( image_id: str, command: Sequence[str], *, tpu_type: str, num_slices: int, env: dict, name="levanter", retries=10 ): env = _massage_env(env) - docker_cmd = DockerRunCommand(image_id, command, env=env, foreground=True, name=name) + docker_cmd = make_docker_run_command(image_id, command, env=env, foreground=True, name=name) + + def run_docker(): + _kill_old_container(name) + try: + return _run_command(*docker_cmd) + except subprocess.CalledProcessError as e: + logger.exception("Failed to run docker command") + raise e if num_slices == 1: run_on_pod_resumable( - docker_cmd, name=name, tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000 + ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000 ) else: run_on_pod_multislice_resumable( - docker_cmd, - name=name, + ray.remote(run_docker), tpu_type=tpu_type, num_slices=num_slices, max_retries_failure=retries,