Commit
Merge branch 'main' into memory_stats
zhenying-liu authored Nov 26, 2024
2 parents 956e8d2 + f29bf3a commit 7ec6086
Showing 96 changed files with 4,282 additions and 711 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
@@ -1,2 +1,2 @@
# Changes in this file should match with requiredReviewers in file .github/workflows/AddLabel.yml
* @gobbleturk @jonb377 @khatwanimohit @bvandermoon @vipannalla
* @gobbleturk @khatwanimohit @bvandermoon @vipannalla
28 changes: 28 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,28 @@
# Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description should include relevant details and context, for example:
- why this change is being made,
- the problem being solved and any relevant context,
- why this is a good solution,
- some information about the specific implementation,
- shortcomings of the solution and possible future improvements.

If the change fixes a bug or a GitHub issue, please include a link, e.g.:
FIXES: b/123456
FIXES: #123456

# Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

# Checklist

Before submitting this PR, please make sure of the following (put an X in the square brackets):
- [ ] I have performed a self-review of my code.
- [ ] I have added the necessary comments to my code, particularly in hard-to-understand areas.
- [ ] I have run end-to-end tests and provided workload links above if applicable.
- [ ] I have made or will make corresponding changes to the doc if needed.
1 change: 0 additions & 1 deletion .github/workflows/AddLabel.yml
@@ -54,7 +54,6 @@ jobs:
// This list should match with CODEOWNERS
let requiredReviewers = {
gobbleturk: "",
jonb377: "",
khatwanimohit: "",
bvandermoon: "",
vipannalla: "",
4 changes: 4 additions & 0 deletions .github/workflows/UnitTests.yml
@@ -103,12 +103,16 @@ jobs:
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 ici_tensor_parallelism=4 enable_checkpointing=false attention=${{ matrix.device.attention }}
- name: Test decode.py
run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=1
- name: Test int8_decode
run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=1 quantization=int8 quantize_kvcache=True
- name: Test decode.py with per_device_batch_size < 1
run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=.25
- name: Test int8_training
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }}
- name: Test fp8_training
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=fp8 steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }}
- name: Test train.py with dropout
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} max_target_length=128 per_device_batch_size=1 dropout_rate=0.02
- name: Test generate_param_only_checkpoint
run: bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M-%S) -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -a ${{ matrix.device.attention }}
- name: Test generate_param_only_checkpoint with int8 quantization
15 changes: 10 additions & 5 deletions .github/workflows/UploadDockerImages.yml
@@ -41,7 +41,10 @@ jobs:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_nightly MODE=nightly DEVICE=tpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_nightly
- name: build jax stable stack image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_stable_stack_0.4.33 MODE=stable_stack DEVICE=TPU PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_stable_stack_0.4.33 BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.33-rev1 MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_stable_stack_0.4.35 MODE=stable_stack DEVICE=TPU PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_stable_stack_0.4.35 BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
- name: build image with stable stack nightly jax
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_stable_stack_nightly_jax MODE=stable_stack DEVICE=tpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_jax_stable_stack_nightly BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu/jax_nightly:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
gpu:
strategy:
fail-fast: false
@@ -54,10 +57,12 @@ jobs:
run: docker system prune --all --force
- name: build jax stable image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_stable MODE=stable DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_stable
- name: build jax nightly image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_nightly MODE=nightly DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_nightly
- name: build jax pinned image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_pinned MODE=pinned DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_pinned
- name: build jax stable stack image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_stable_stack_0.4.35 MODE=stable_stack DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_stable_stack_0.4.35 BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/gpu:jax0.4.35-cuda_dl24.10-rev1 MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
- name: build image with stable stack nightly jax
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_stable_stack_nightly_jax MODE=stable_stack DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_jax_stable_stack_nightly BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/gpu/jax_nightly:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
11 changes: 11 additions & 0 deletions .github/workflows/require-checklist.yml
@@ -0,0 +1,11 @@
name: Require Checklist
on:
pull_request:
types: [opened, edited, synchronize]
jobs:
check_pr_body:
runs-on: ubuntu-latest
steps:
- uses: mheap/require-checklist-action@v2
with:
requireChecklist: true # If this is true and there are no checklists detected, the action will fail
19 changes: 14 additions & 5 deletions MaxText/accelerator_to_spec_map.py
@@ -29,19 +29,28 @@ class SystemCharacteristics:
platform: str
topology_name: Optional[str]
chip_config_name: Optional[str] # 'megacore' or 'default'
chips_per_host_bounds: Optional[tuple]
chips_per_host_bounds: Optional[tuple] # number of chips on each host in each dimension.
devices_per_slice: int
wrap: Optional[tuple]
wrap: Optional[tuple] # wrap-around (torus) for each dimension.


UserFacingNameToSystemCharacteristics = {
# v5e
# v6e: one core per chip with 32 GB HBM
"v6e-1": SystemCharacteristics("tpu", "v6e:1x1", "default", (1, 1, 1), 1, (False, False, False)),
"v6e-4": SystemCharacteristics("tpu", "v6e:2x2", "default", (2, 2, 1), 4, (False, False, False)),
"v6e-8": SystemCharacteristics("tpu", "v6e:2x4", "default", (2, 2, 1), 8, (False, False, False)),
"v6e-16": SystemCharacteristics("tpu", "v6e:4x4", "default", (2, 2, 1), 16, (False, False, False)),
"v6e-32": SystemCharacteristics("tpu", "v6e:4x8", "default", (2, 2, 1), 32, (False, False, False)),
"v6e-64": SystemCharacteristics("tpu", "v6e:8x8", "default", (2, 2, 1), 64, (False, False, False)),
"v6e-128": SystemCharacteristics("tpu", "v6e:8x16", "default", (2, 2, 1), 128, (False, True, False)),
"v6e-256": SystemCharacteristics("tpu", "v6e:16x16", "default", (2, 2, 1), 256, (True, True, False)),
# v5e: one core per chip with 16 GB HBM
"v5e-16": SystemCharacteristics("tpu", "v5e:4x4", "default", (2, 2, 1), 16, (False, False, False)),
"v5e-32": SystemCharacteristics("tpu", "v5e:4x8", "default", (2, 2, 1), 32, (False, False, False)),
"v5e-64": SystemCharacteristics("tpu", "v5e:8x8", "default", (2, 2, 1), 64, (False, False, False)),
"v5e-128": SystemCharacteristics("tpu", "v5e:8x16", "default", (2, 2, 1), 128, (False, True, False)),
"v5e-256": SystemCharacteristics("tpu", "v5e:16x16", "default", (2, 2, 1), 256, (True, True, False)),
# v4
# v4: one megacore per chip with 32 GB HBM
"v4-8": SystemCharacteristics("tpu", "v4:2x2x1", "megacore", (2, 2, 1), 4, (False, False, False)),
"v4-16": SystemCharacteristics("tpu", "v4:2x2x2", "megacore", (2, 2, 1), 8, (False, False, False)),
"v4-32": SystemCharacteristics("tpu", "v4:2x2x4", "megacore", (2, 2, 1), 16, (False, False, False)),
@@ -54,7 +63,7 @@ class SystemCharacteristics:
"v4-1536": SystemCharacteristics("tpu", "v4:8x8x12", "megacore", (2, 2, 1), 768, (True, True, True)),
"v4-2048": SystemCharacteristics("tpu", "v4:8x8x16", "megacore", (2, 2, 1), 1024, (True, True, True)),
"v4-4096": SystemCharacteristics("tpu", "v4:8x16x16", "megacore", (2, 2, 1), 2048, (True, True, True)),
# v5p
# v5p: one megacore per chip with 96 GB HBM
"v5p-8": SystemCharacteristics("tpu", "v5:2x2x1", "megacore", (2, 2, 1), 4, (False, False, False)),
"v5p-16": SystemCharacteristics("tpu", "v5:2x2x2", "megacore", (2, 2, 1), 8, (False, False, False)),
"v5p-32": SystemCharacteristics("tpu", "v5:2x2x4", "megacore", (2, 2, 1), 16, (False, False, False)),
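For context, here is a minimal sketch of reading one of the new v6e entries back from the map above. It assumes the MaxText directory is on PYTHONPATH (MaxText uses flat module imports); the expected values in the comments mirror the diff and are illustrative only.

# Illustrative only: looking up a v6e entry from the table added above.
from accelerator_to_spec_map import UserFacingNameToSystemCharacteristics

sys_char = UserFacingNameToSystemCharacteristics["v6e-256"]
print(sys_char.topology_name)          # "v6e:16x16"
print(sys_char.chips_per_host_bounds)  # (2, 2, 1): chips per host in each dimension
print(sys_char.wrap)                   # (True, True, False): wrap-around per dimension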
29 changes: 10 additions & 19 deletions MaxText/checkpointing.py
@@ -38,8 +38,6 @@

abstract_logger = ocp.logging.abstract_logger
cloud_logger = ocp.logging.cloud_logger
composite_logger = ocp.logging.composite_logger
standard_logger = ocp.logging.standard_logger


def create_orbax_checkpoint_manager(
@@ -107,7 +105,6 @@ def create_orbax_emergency_checkpoint_manager(
global_mesh=global_mesh,
abstract_state=abstract_state,
options=options,
local_state_handler=emergency_checkpoint_manager.local_checkpoint_handler(),
logger=orbax_logger,
)

@@ -218,8 +215,11 @@ def map_to_pspec(data):
),
None,
)

if dataset_type == "grain" and data_iterator is not None:
if (
dataset_type == "grain"
and data_iterator is not None
and (checkpoint_manager.directory / str(latest_step) / "iter").exists()
):
return (
checkpoint_manager.restore(
latest_step,
@@ -262,32 +262,23 @@ def map_to_pspec(data):
return None, None


def setup_checkpoint_logger(config) -> composite_logger.CompositeLogger | None:
def setup_checkpoint_logger(config) -> cloud_logger.CloudLogger | None:
"""Setup checkpoint logger.
Args:
config
Returns:
CompositeLogger
CloudLogger
"""
orbax_cloud_logger = None
orbax_standard_logger = None
max_logging.log("Setting up checkpoint logger...")
if config.enable_checkpoint_cloud_logger:
logger_name = f"checkpoint_{config.run_name}"
logger_name = f"goodput_{config.run_name}"
options = cloud_logger.CloudLoggerOptions(job_name=config.run_name, logger_name=logger_name)
orbax_cloud_logger = cloud_logger.CloudLogger(options=options)
max_logging.log("Successfully set up checkpoint cloud logger.")
return orbax_cloud_logger

if config.enable_checkpoint_standard_logger:
orbax_standard_logger = standard_logger.StandardLogger()
max_logging.log("Successfully set up checkpoint standard logger.")

orbax_logger = None
if orbax_cloud_logger is not None and orbax_standard_logger is not None:
orbax_logger = composite_logger.CompositeLogger(orbax_cloud_logger, orbax_standard_logger)
max_logging.log("Successfully set up checkpoint composite logger.")

return orbax_logger
return orbax_cloud_logger


def load_params_from_path(load_parameters_from_path, abstract_unboxed_params):
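A hedged usage sketch of the simplified logger setup: with cloud logging disabled, setup_checkpoint_logger now simply returns None rather than assembling a CompositeLogger. The stub config below is a stand-in for MaxText's pyconfig, used only to make the example self-contained; it assumes the MaxText directory is on PYTHONPATH.

# Illustrative only: exercising the simplified return type of setup_checkpoint_logger.
from dataclasses import dataclass
import checkpointing  # assumes the MaxText/ directory is on PYTHONPATH

@dataclass
class _StubConfig:
  run_name: str = "demo_run"
  enable_checkpoint_cloud_logger: bool = False  # keep False so no GCP client is created

logger = checkpointing.setup_checkpoint_logger(_StubConfig())
assert logger is None  # nothing to pass through to Orbax when cloud logging is disabled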
2 changes: 2 additions & 0 deletions MaxText/common_types.py
@@ -38,10 +38,12 @@
LENGTH = "activation_length"
EMBED = "activation_embed"
HEAD = "activation_heads"
PREFILL_KV_BATCH = "activation_prefill_kv_batch"
KV_BATCH = "activation_kv_batch"
KV_HEAD = "activation_kv_heads"
KV_HEAD_DIM = "activation_kv_head_dim"
D_KV = "activation_kv"
CACHE_BATCH_PREFILL = "cache_batch_prefill"
CACHE_BATCH = "cache_batch"
CACHE_SEQUENCE = "cache_sequence"
CACHE_HEADS = "cache_heads"
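As a hedged illustration of how logical axis names such as the new PREFILL_KV_BATCH are typically consumed, the sketch below maps them to mesh axes with Flax's logical partitioning helper. The rules are invented for the example and are not MaxText's actual sharding configuration.

# Illustrative only: mapping logical activation axes to mesh axes.
from flax import linen as nn

PREFILL_KV_BATCH = "activation_prefill_kv_batch"
CACHE_BATCH_PREFILL = "cache_batch_prefill"

# Hypothetical rules: shard the prefill KV batch over the "data" mesh axis,
# keep the prefill cache batch replicated.
logical_axis_rules = [
    (PREFILL_KV_BATCH, "data"),
    (CACHE_BATCH_PREFILL, None),
]

spec = nn.logical_to_mesh_axes((PREFILL_KV_BATCH, None, CACHE_BATCH_PREFILL), logical_axis_rules)
print(spec)  # e.g. PartitionSpec('data', None, None)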
56 changes: 56 additions & 0 deletions MaxText/configs/a3/llama_3.1_405b/128vm.sh
@@ -0,0 +1,56 @@
echo "Running 128vm.sh"
# Example command to invoke this script via XPK, assume you've installed xpk
# COMMAND="bash MaxText/configs/a3/llama_3.1_405b/128vm.sh"
# COMMAND='export LD_LIBRARY_PATH=/usr/local/cuda-12.6/compat:$LD_LIBRARY_PATH;'"${COMMAND}";
#
# xpk workload create --project=${PROJECT} --cluster=${CLUSTER_NAME} --zone=${ZONE} \
# --workload=${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
# --device-type=${DEVICE_TYPE} --num-nodes=2 --priority=high \
# --command="$COMMAND" --env=XLA_FLAGS=$XLA_FLAGS

# Stop execution if any command exits with error
set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-31-128vm-$(date +%Y-%m-%d-%H-%M)"
export EXECUTABLE="train.py"

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
done

export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0
--xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

# 128 nodes
python MaxText/$EXECUTABLE MaxText/configs/models/llama3.1_405b.yml run_name=$RUN_NAME \
base_config=base.yml \
run_name=gpu_train_test \
hardware=gpu \
steps=10 \
model_name=llama3.1-405b \
enable_checkpointing=False \
attention=cudnn_flash_te \
remat_policy=full \
use_iota_embed=True \
scan_layers=True \
dataset_type=synthetic \
async_checkpointing=False \
logits_dot_in_fp32=False \
per_device_batch_size=1.0 \
max_target_length=8192 \
dcn_fsdp_parallelism=128 \
ici_fsdp_parallelism=8 \
base_output_directory=$OUTPUT_PATH \
profiler=xplane
