Skip to content

Commit

Permalink
attempt at launching small fast in CI, add tqdm_loggable (#719)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Sep 13, 2024
1 parent 2645efb commit 5fc4084
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 28 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/docker-base-image.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
name: Build and Push Docker TPU Images

on:
push:
branches:
- main
workflow_run:
workflows: ["Run Tests"]
types:
- completed
branches: [main]
workflow_dispatch:

jobs:
build:
Expand Down
72 changes: 72 additions & 0 deletions .github/workflows/launch_small_fast.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Launches a small "llama_small_fast" training run on a preemptible TPU slice
# after the TPU Docker images have been rebuilt, as an end-to-end smoke test.
name: Launch Llama 2 Small Fast

on:
  workflow_run:
    workflows: ["Build and Push Docker TPU Images"]
    types:
      - completed
    branches: [main, "experiment/*"]
  # pull_request:
  workflow_dispatch:

jobs:
  test:
    # Skip only fork PRs. For workflow_run / workflow_dispatch events,
    # github.event.pull_request is absent, so the bare equality check alone
    # would evaluate false and skip the job on every real trigger; guard on
    # the event name first.
    if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
    runs-on: ubuntu-latest
    env:
      TPU_ZONE: "us-central2-b"
      TPU_TYPE: "v4-32"

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Google Cloud SDK
        uses: google-github-actions/setup-gcloud@v1
        with:
          project_id: ${{ secrets.GCP_PROJECT_ID }}

      - name: Authenticate to Google Cloud
        uses: google-github-actions/auth@v1
        with:
          credentials_json: ${{ secrets.GCP_SA_KEY }}

      - name: Configure Google Cloud
        run: |
          gcloud config set project ${{ secrets.GCP_PROJECT_ID }}
          REGION=${TPU_ZONE%-*}
          echo "$REGION"
          gcloud auth configure-docker $REGION-docker.pkg.dev

      - name: Install locally
        run: |
          python -m pip install --upgrade pip
          pip install -e .[test] "jax[cpu]==0.4.30"

      - name: Launch Small Fast TPU Train LM job
        run: |
          export TPU_NAME=small-fast-${{ github.run_id }}
          export WANDB_API_KEY=${{ secrets.WANDB_API_KEY }}
          export RUN_ID=small_fast_${{ github.run_id }}
          export HF_TOKEN=${{ secrets.HF_TOKEN }}
          cat > .config <<EOF
          env:
            WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
            WANDB_ENTITY: stanford-mercury
            WANDB_PROJECT: levanter
            HF_TOKEN: ${{ secrets.HF_TOKEN }}
          EOF
          # NOTE: every argument line below needs a trailing backslash; without
          # it, the remaining flags run as a separate (failing) shell command.
          python infra/launch.py -e CI 1 --foreground --tpu_name ${TPU_NAME} --run_id $RUN_ID --zone ${TPU_ZONE} --tpu_type ${TPU_TYPE} --preemptible -- \
            python -m levanter.main.train_lm \
            --config_path config/llama_small_fast.yaml \
            --trainer.checkpointer.base_path gs://levanter-checkpoints/llama-itest/ \
            --trainer.checkpointer.save_interval 10m \
            --trainer.num_train_steps 10000

      - name: Cleanup
        if: ${{ always() }}
        run: |
          export TPU_NAME=small-fast-${{ github.run_id }}
          gcloud compute tpus queued-resources delete $TPU_NAME --zone ${TPU_ZONE} --quiet --force
45 changes: 31 additions & 14 deletions infra/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def main():
else:
raise ValueError(f"Unknown docker registry: {registry}")

failure = None

for i in range(retries + 1):
try:
launch_job(
Expand All @@ -128,24 +130,39 @@ def main():
print(f"Error running command {e.cmd}")
if i < retries - 1:
print("Retrying... %d/%d" % (i + 1, retries))
else:
print("Retries exhausted. Raising error.")
print(f"Error running command {e.cmd}")
print(f"Output: {e.output}")
failure = e
else:
print("Job finished with no error.")
break

if autodelete:
print("Autodelete is set to True. Tearing down machine...")
levanter.infra.tpus.run_command(
"gcloud",
"alpha",
"compute",
"tpus",
"queued-resources",
"delete",
tpu_name,
"--quiet",
f"--zone={zone}",
"--force",
)
try:
if autodelete:
print("Autodelete is set to True. Tearing down machine...")
levanter.infra.tpus.run_command(
"gcloud",
"alpha",
"compute",
"tpus",
"queued-resources",
"delete",
tpu_name,
"--quiet",
f"--zone={zone}",
"--force",
)
except Exception as e:
print(f"Error tearing down TPU {tpu_name}: {e}")
if failure:
raise failure
else:
raise e

if failure:
raise failure


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ dependencies = [
"filelock~=3.13",
# "ai2-olmo",
"async-lru~=2.0",
"tqdm-loggable>=0.2"
]

[project.urls]
Expand Down
19 changes: 18 additions & 1 deletion src/levanter/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
import threading
import time
import warnings
from datetime import timedelta
from typing import Callable, Optional

import humanfriendly
import jax
from tqdm import tqdm
from tqdm_loggable import tqdm_logging
from tqdm_loggable.auto import tqdm

import levanter.tracker
from levanter.data import DataLoader
Expand All @@ -39,7 +41,9 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n
else:
desc = "eval"

_tqdm_logging_one_time_setup()
pbar = tqdm(dataset, desc=desc, position=1, leave=False, total=max_batches)

iter_ = iter(pbar)
while True:
time_in = time.time()
Expand Down Expand Up @@ -186,6 +190,8 @@ def pbar_logger(iterable=None, desc="train", **tqdm_mkwargs):
kwargs["desc"] = desc
if "iterable" not in kwargs:
kwargs["iterable"] = iterable

_tqdm_logging_one_time_setup()
pbar = tqdm(**kwargs)

def update_pbar(step: StepInfo):
Expand Down Expand Up @@ -359,3 +365,14 @@ def compute_and_viz_log_probs(step: StepInfo):
wandb.log({"log_probs": wandb.Html(path)}, step=step.step)

return compute_and_viz_log_probs


# Tracks whether tqdm_loggable's log-rate throttle has already been configured.
_did_tqdm_logging_one_time_setup = False


def _tqdm_logging_one_time_setup():
    """Configure tqdm_loggable to emit progress log lines at most once a minute.

    Idempotent: only the first call takes effect; later calls are no-ops.
    Relies on the module-level ``tqdm_logging`` import from ``tqdm_loggable``.
    """
    global _did_tqdm_logging_one_time_setup
    if not _did_tqdm_logging_one_time_setup:
        _did_tqdm_logging_one_time_setup = True
        # tqdm_loggable exposes a `tqdm_logging` class inside the
        # `tqdm_logging` module; set_log_rate throttles logged progress lines.
        tqdm_logging.tqdm_logging.set_log_rate(timedelta(seconds=60))
4 changes: 2 additions & 2 deletions src/levanter/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import jax.numpy as jnp
import jmp
import numpy as np
import tqdm
from jax.sharding import Mesh
from tqdm_loggable.auto import tqdm

import haliax as hax
from haliax.partitioning import ResourceMapping
Expand Down Expand Up @@ -300,7 +300,7 @@ def evaluate(self, m: LmHeadModel):
iterator = LoadingTimeTrackerIterator(self.loader)
n = 0

for batch, tags in tqdm.tqdm(iterator, "eval"):
for batch, tags in tqdm(iterator, "eval"):
state = self.accum_for_batch(m, state, batch, tags)
n += 1

Expand Down
7 changes: 6 additions & 1 deletion src/levanter/infra/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@ def read(fd):

return b"".join(output)
else:
return subprocess.check_output(argv, stderr=subprocess.STDOUT)
try:
return subprocess.check_output(argv, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
# print the output if the command failed, reraising the exception
print(e.output.decode())
raise e


def configure_gcp_docker(project_id, region, repository):
Expand Down
34 changes: 27 additions & 7 deletions src/levanter/infra/tpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def list_tpus(zone):
"list",
f"--zone={zone}",
"--format=json(name.basename(), state)",
"--quiet",
]
)
)
Expand All @@ -68,6 +69,7 @@ def describe_tpu(tpu_name, zone):
tpu_name,
f"--zone={zone}",
"--format=json(name.basename(), state)",
"--quiet",
],
stderr=subprocess.DEVNULL,
)
Expand All @@ -77,6 +79,8 @@ def describe_tpu(tpu_name, zone):


def start_tpu_vm(tpu_name, *, tpu_type, capacity_type, version, zone, node_count):
# ensure alpha is enabled
run_command("gcloud", "components", "install", "alpha", "--quiet")
if version is None:
version = "tpu-ubuntu2204-base"
tpu_stat = describe_tpu(tpu_name, zone)
Expand Down Expand Up @@ -196,17 +200,31 @@ def run_command(*args, **kwargs):

def add_ssh_key(ssh_key_filename):
# format 3072 SHA256:... key-name (RSA)
key_hash = subprocess.check_output(["ssh-keygen", "-lf", ssh_key_filename]).decode("utf-8").split()[1]
existing_keys = subprocess.check_output(["ssh-add", "-l"]).decode("utf-8").split("\n")
for key in existing_keys:
if key_hash in key:
return
try:
key_hash = (
subprocess.check_output(["ssh-keygen", "-lf", ssh_key_filename], stderr=subprocess.STDOUT)
.decode("utf-8")
.split()[1]
)
existing_keys = (
subprocess.check_output(["ssh-add", "-l"], stderr=subprocess.STDOUT).decode("utf-8").split("\n")
)
for key in existing_keys:
if key_hash in key:
return

subprocess.check_call(["ssh-add", ssh_key_filename])
subprocess.check_call(["ssh-add", ssh_key_filename])
except subprocess.CalledProcessError:
raise


def tpu_ssh(tpu_name, zone, node_count, *args, ignore_failure=False):
add_ssh_key(os.path.expanduser("~/.ssh/google_compute_engine"))
try:
add_ssh_key(os.path.expanduser("~/.ssh/google_compute_engine"))
except subprocess.CalledProcessError as e:
print("Failed to add ssh key. This may lead to problems.", e)
pass

try:
if node_count > 1:
return _tpu_ssh_multislice(tpu_name, zone, node_count, *args, ignore_failure=ignore_failure)
Expand All @@ -219,6 +237,7 @@ def tpu_ssh(tpu_name, zone, node_count, *args, ignore_failure=False):
"tpu-vm",
"ssh",
tpu_name,
"--quiet",
"--worker=all",
f"--zone={zone}",
"--command=%s" % " ".join(args),
Expand All @@ -243,6 +262,7 @@ def _tpu_ssh_multislice(tpu_name, zone, node_count, *args, ignore_failure=False)
"ssh",
f"{tpu_name}-{i}",
"--worker=all",
"--quiet",
f"--zone={zone}",
"--command=%s" % " ".join(args),
)
Expand Down

0 comments on commit 5fc4084

Please sign in to comment.