Skip to content

Commit

Permalink
revert some changes
Browse files Browse the repository at this point in the history
  • Loading branch information
blahBlahhhJ committed Oct 22, 2024
1 parent ecf1e45 commit 02ace46
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 76 deletions.
37 changes: 5 additions & 32 deletions src/levanter/infra/cli_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import base64
import os
import subprocess
from typing import Any, Dict, List, Optional
from typing import Optional

import yaml
from google.cloud import storage
Expand Down Expand Up @@ -59,37 +59,6 @@ def get_git_commit():
return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()


class DockerRunCommand:
def __init__(self, image_id: str, command: List[str], *, foreground: bool, env: Dict[str, Any], name="levanter"):
self.base_part = [
"docker",
"run",
"-t" if foreground else "-d",
f"--name={name}",
"--privileged",
"--shm-size=32gb",
"--net=host",
"--init",
"--mount",
"type=volume,source=levanter,target=/home/levanter",
"-v",
"/tmp:/tmp",
]

self.env_part: List[str] = []
self.add_env(env)

self.cmd_part = [image_id, *command]

def add_env(self, env: Dict[str, Any]):
for k, v in env.items():
self.env_part.extend(["-e", k + f"={str(v)}"])

@property
def full_cmd(self):
return self.base_part + self.env_part + self.cmd_part


def make_docker_run_command(image_id, command, *, foreground, env, name="levanter"):
docker_command = [
"docker",
Expand All @@ -106,6 +75,10 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante
"/tmp:/tmp",
]

# optionally add multislice env vars (if set by ray runtime env vars)
for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]:
docker_command.extend(["-e", v])

for k, v in env.items():
docker_command.extend(["-e", k + f"={str(v)}"])

Expand Down
70 changes: 26 additions & 44 deletions src/levanter/infra/ray_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ray.exceptions import NodeDiedError, RayError, RaySystemError, RayTaskError, WorkerCrashedError
from ray.remote_function import RemoteFunction

from levanter.infra.cli_helpers import DockerRunCommand
from levanter.infra.cli_helpers import make_docker_run_command
from levanter.utils.ray_utils import ser_exc_info


Expand Down Expand Up @@ -63,28 +63,22 @@ class TpuRunError(_TpuRunResult):
error: Exception


def run_on_pod(docker_cmd: DockerRunCommand, name: str, tpu_type: str) -> ray.ObjectRef:
def run_on_pod(remote_fn: RemoteFunction | Callable, tpu_type: str) -> ray.ObjectRef:
"""
Run a remote function on a TPU pod.
Args:
docker_cmd: A DockerRunCommand object that holds a docker command to run
name: docker image name
remote_fn: A remote function that takes no arguments
tpu_type: The type of TPU to run on, e.g. "v4-32"
Returns:
A Ray ObjectRef that represents the result of the function
"""

@ray.remote(resources={f"TPU-{tpu_type}-head": 1})
def do_run(docker_cmd: DockerRunCommand, name: str) -> _TpuRunResult:
def do_run(remote_fn) -> _TpuRunResult:
num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> 4

def _run_docker():
run_docker(docker_cmd=docker_cmd.full_cmd, name=name)

remote_fn = ray.remote(_run_docker)

remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, num_hosts)

info = _TpuInfo(tpu_name, "ACTIVE", "TPU")
Expand All @@ -101,16 +95,15 @@ def _run_docker():
logger.exception("Failed to kill job after primary failure")
return _handle_ray_error(info, e)

return do_run.remote(docker_cmd, name)
return do_run.remote(remote_fn)


def run_on_pod_multislice(docker_cmd: DockerRunCommand, name: str, tpu_type: str, num_slices: int) -> ray.ObjectRef:
def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> ray.ObjectRef:
"""
Run a remote function on multiple TPU slices.
Args:
docker_cmd: A DockerRunCommand object that holds a docker command to run
name: docker image name
remote_fn: A remote function that takes no arguments
tpu_type: The type of TPU to run on, e.g. "v4-32"
num_slices: The number of slices to run
Expand All @@ -128,7 +121,7 @@ def __init__(self):
def get_slice_info(self):
return self.pod_name, self.num_hosts, self.ip

def do_run(self, docker_cmd, name, coordinator_ip, slice_id, num_slices) -> _TpuRunResult:
def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResult:
port = 8081
mxla_env = {
"MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
Expand All @@ -137,13 +130,6 @@ def do_run(self, docker_cmd, name, coordinator_ip, slice_id, num_slices) -> _Tpu
"MEGASCALE_SLICE_ID": str(slice_id),
}

docker_cmd.add_env(mxla_env)

def _run_docker():
run_docker(docker_cmd=docker_cmd.full_cmd, name=name)

remote_fn = ray.remote(_run_docker)

remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, self.num_hosts, env_vars=mxla_env)

info = _TpuInfo(tpu_name, "ACTIVE", "TPU")
Expand Down Expand Up @@ -178,7 +164,7 @@ def _run_docker():

coordinator_ip = slice_infos[0][2]

return [actor.do_run.remote(docker_cmd, name, coordinator_ip, i, num_slices) for i, actor in enumerate(actors)]
return [actor.do_run.remote(remote_fn, coordinator_ip, i, num_slices) for i, actor in enumerate(actors)]


def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):
Expand All @@ -205,13 +191,12 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):
return remote_fn, tpu_name


def run_on_pod_resumable(docker_cmd, name, tpu_type, max_retries_preemption=1e6, max_retries_failure=10):
def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_retries_failure=10):
"""
Repeatedly run a function on a TPU pod until it succeeds or a maximum number of retries is reached.
Args:
docker_cmd: A DockerRunCommand object that holds a docker command to run
name: docker image name
remote_fn: A remote function that takes no arguments
tpu_type: The type of TPU to run on, e.g. "v4-32"
max_retries_preemption: The maximum number of times to retry if the job is preempted
max_retries_failure: The maximum number of times to retry if the job fails
Expand All @@ -230,7 +215,7 @@ def run_on_pod_resumable(docker_cmd, name, tpu_type, max_retries_preemption=1e6,
attempt += 1
problem = None
try:
out = ray.get(run_on_pod(docker_cmd, name, tpu_type))
out = ray.get(run_on_pod(remote_fn, tpu_type))
except ray.exceptions.RayTaskError as e:
problem = e
if "preempted" in str(e):
Expand Down Expand Up @@ -275,14 +260,13 @@ def run_on_pod_resumable(docker_cmd, name, tpu_type, max_retries_preemption=1e6,


def run_on_pod_multislice_resumable(
docker_cmd, name, tpu_type, num_slices, max_retries_preemption=1e6, max_retries_failure=10
remote_fn, tpu_type, num_slices, max_retries_preemption=1e6, max_retries_failure=10
):
"""
Repeatedly run a function on a TPU pod until it succeeds or a maximum number of retries is reached.
Args:
docker_cmd: A DockerRunCommand object that holds a docker command to run
name: docker image name
remote_fn: A remote function that takes no arguments
tpu_type: The type of TPU to run on, e.g. "v4-32"
num_slices: The number of slices to run
max_retries_preemption: The maximum number of times to retry if the job is preempted
Expand All @@ -302,7 +286,7 @@ def run_on_pod_multislice_resumable(
attempt += 1
problem = None
try:
outs = ray.get(run_on_pod_multislice(docker_cmd, name, tpu_type, num_slices))
outs = ray.get(run_on_pod_multislice(remote_fn, tpu_type, num_slices))
except ray.exceptions.RayTaskError as e:
problem = e
if "preempted" in str(e):
Expand Down Expand Up @@ -362,30 +346,28 @@ def _run_command(*args, **kwargs):
return subprocess.check_call(args, **kwargs)


def run_docker(docker_cmd, name="levanter"):
_kill_old_container(name)
try:
return _run_command(*docker_cmd)
except subprocess.CalledProcessError as e:
logger.exception("Failed to run docker command")
raise e


def run_docker_on_pod(
image_id: str, command: Sequence[str], *, tpu_type: str, num_slices: int, env: dict, name="levanter", retries=10
):
env = _massage_env(env)

docker_cmd = DockerRunCommand(image_id, command, env=env, foreground=True, name=name)
docker_cmd = make_docker_run_command(image_id, command, env=env, foreground=True, name=name)

def run_docker():
_kill_old_container(name)
try:
return _run_command(*docker_cmd)
except subprocess.CalledProcessError as e:
logger.exception("Failed to run docker command")
raise e

if num_slices == 1:
run_on_pod_resumable(
docker_cmd, name=name, tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
)
else:
run_on_pod_multislice_resumable(
docker_cmd,
name=name,
ray.remote(run_docker),
tpu_type=tpu_type,
num_slices=num_slices,
max_retries_failure=retries,
Expand Down

0 comments on commit 02ace46

Please sign in to comment.