diff --git a/.github/workflows/docker-base-image.yaml b/.github/workflows/docker-base-image.yaml index a5ada69c3..a5e6c3724 100644 --- a/.github/workflows/docker-base-image.yaml +++ b/.github/workflows/docker-base-image.yaml @@ -1,9 +1,12 @@ name: Build and Push Docker TPU Images on: - push: - branches: - - main + workflow_run: + workflows: ["Run Tests"] + types: + - completed + branches: [main] + workflow_dispatch: jobs: build: diff --git a/.github/workflows/launch_small_fast.yaml b/.github/workflows/launch_small_fast.yaml new file mode 100644 index 000000000..15f423674 --- /dev/null +++ b/.github/workflows/launch_small_fast.yaml @@ -0,0 +1,72 @@ +name: Launch Llama 2 Small Fast + +on: + workflow_run: + workflows: ["Build and Push Docker TPU Images"] + types: + - completed + branches: [main, "experiment/*"] +# pull_request: + workflow_dispatch: + +jobs: + test: + if: (github.event.pull_request.head.repo.full_name == github.repository) + runs-on: ubuntu-latest + env: + TPU_ZONE: "us-central2-b" + TPU_TYPE: "v4-32" + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Google Cloud SDK + uses: google-github-actions/setup-gcloud@v1 + with: + project_id: ${{ secrets.GCP_PROJECT_ID }} + + - name: Authenticate to Google Cloud + uses: google-github-actions/auth@v1 + with: + credentials_json: ${{ secrets.GCP_SA_KEY }} + + - name: Configure Google Cloud + run: | + gcloud config set project ${{ secrets.GCP_PROJECT_ID }} + REGION=${TPU_ZONE%-*} + echo "$REGION" + gcloud auth configure-docker $REGION-docker.pkg.dev + + - name: Install locally + run: | + python -m pip install --upgrade pip + pip install -e .[test] "jax[cpu]==0.4.30" + + - name: Launch Small Fast TPU Train LM job + run: | + export TPU_NAME=small-fast-${{ github.run_id }} + export WANDB_API_KEY=${{ secrets.WANDB_API_KEY }} + export RUN_ID=small_fast_${{ github.run_id }} + export HF_TOKEN=${{ secrets.HF_TOKEN }} + + cat > .config <=0.2" ] [project.urls] diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 21aaf5faa..e03add43d 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -8,11 +8,13 @@ import threading import time import warnings +from datetime import timedelta from typing import Callable, Optional import humanfriendly import jax -from tqdm import tqdm +from tqdm_loggable import tqdm_logging +from tqdm_loggable.auto import tqdm import levanter.tracker from levanter.data import DataLoader @@ -39,7 +41,9 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n else: desc = "eval" + _tqdm_logging_one_time_setup() pbar = tqdm(dataset, desc=desc, position=1, leave=False, total=max_batches) + iter_ = iter(pbar) while True: time_in = time.time() @@ -186,6 +190,8 @@ def pbar_logger(iterable=None, desc="train", **tqdm_mkwargs): kwargs["desc"] = desc if "iterable" not in kwargs: kwargs["iterable"] = iterable + + _tqdm_logging_one_time_setup() pbar = tqdm(**kwargs) def update_pbar(step: StepInfo): @@ -359,3 +365,14 @@ def compute_and_viz_log_probs(step: StepInfo): wandb.log({"log_probs": wandb.Html(path)}, step=step.step) return compute_and_viz_log_probs + + +_did_tqdm_logging_one_time_setup = False + + +def _tqdm_logging_one_time_setup(): + global _did_tqdm_logging_one_time_setup + if _did_tqdm_logging_one_time_setup: + return + _did_tqdm_logging_one_time_setup = True + tqdm_logging.tqdm_logging.set_log_rate(timedelta(seconds=60)) diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 48fcb426c..2aa9b7ff3 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -8,8 +8,8 @@ import jax.numpy as jnp import jmp import numpy as np -import tqdm from jax.sharding import Mesh +from tqdm_loggable.auto import tqdm import haliax as hax from haliax.partitioning import ResourceMapping @@ -300,7 +300,7 @@ def evaluate(self, m: LmHeadModel): iterator = LoadingTimeTrackerIterator(self.loader) n = 0 - for batch, tags in tqdm.tqdm(iterator, "eval"): + for batch, tags in tqdm(iterator, "eval"): state = self.accum_for_batch(m, state, batch, tags) n += 1 diff --git a/src/levanter/infra/docker.py b/src/levanter/infra/docker.py index 63a51ae2f..39f3b1325 100644 --- a/src/levanter/infra/docker.py +++ b/src/levanter/infra/docker.py @@ -65,7 +65,12 @@ def read(fd): return b"".join(output) else: - return subprocess.check_output(argv, stderr=subprocess.STDOUT) + try: + return subprocess.check_output(argv, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + # print the output if the command failed, reraising the exception + print(e.output.decode()) + raise e def configure_gcp_docker(project_id, region, repository): diff --git a/src/levanter/infra/tpus.py b/src/levanter/infra/tpus.py index 7e630f069..b8a8df9e0 100644 --- a/src/levanter/infra/tpus.py +++ b/src/levanter/infra/tpus.py @@ -49,6 +49,7 @@ def list_tpus(zone): "list", f"--zone={zone}", "--format=json(name.basename(), state)", + "--quiet", ] ) ) @@ -68,6 +69,7 @@ def describe_tpu(tpu_name, zone): tpu_name, f"--zone={zone}", "--format=json(name.basename(), state)", + "--quiet", ], stderr=subprocess.DEVNULL, ) @@ -77,6 +79,8 @@ def describe_tpu(tpu_name, zone): def start_tpu_vm(tpu_name, *, tpu_type, capacity_type, version, zone, node_count): + # ensure alpha is enabled + run_command("gcloud", "components", "install", "alpha", "--quiet") if version is None: version = "tpu-ubuntu2204-base" tpu_stat = describe_tpu(tpu_name, zone) @@ -196,17 +200,31 @@ def run_command(*args, **kwargs): def add_ssh_key(ssh_key_filename): # format 3072 SHA256:... key-name (RSA) - key_hash = subprocess.check_output(["ssh-keygen", "-lf", ssh_key_filename]).decode("utf-8").split()[1] - existing_keys = subprocess.check_output(["ssh-add", "-l"]).decode("utf-8").split("\n") - for key in existing_keys: - if key_hash in key: - return + try: + key_hash = ( + subprocess.check_output(["ssh-keygen", "-lf", ssh_key_filename], stderr=subprocess.STDOUT) + .decode("utf-8") + .split()[1] + ) + existing_keys = ( + subprocess.check_output(["ssh-add", "-l"], stderr=subprocess.STDOUT).decode("utf-8").split("\n") + ) + for key in existing_keys: + if key_hash in key: + return - subprocess.check_call(["ssh-add", ssh_key_filename]) + subprocess.check_call(["ssh-add", ssh_key_filename]) + except subprocess.CalledProcessError: + raise def tpu_ssh(tpu_name, zone, node_count, *args, ignore_failure=False): - add_ssh_key(os.path.expanduser("~/.ssh/google_compute_engine")) + try: + add_ssh_key(os.path.expanduser("~/.ssh/google_compute_engine")) + except subprocess.CalledProcessError as e: + print("Failed to add ssh key. This may lead to problems.", e) + pass + try: if node_count > 1: return _tpu_ssh_multislice(tpu_name, zone, node_count, *args, ignore_failure=ignore_failure) @@ -219,6 +237,7 @@ def tpu_ssh(tpu_name, zone, node_count, *args, ignore_failure=False): "tpu-vm", "ssh", tpu_name, + "--quiet", "--worker=all", f"--zone={zone}", "--command=%s" % " ".join(args), @@ -243,6 +262,7 @@ def _tpu_ssh_multislice(tpu_name, zone, node_count, *args, ignore_failure=False) "ssh", f"{tpu_name}-{i}", "--worker=all", + "--quiet", f"--zone={zone}", "--command=%s" % " ".join(args), )