Merge pull request #6 from jkbhagatio/dev

Dev post ddp

jkbhagatio authored May 1, 2024
2 parents 0b9d1e7 + 8d7d81d commit bcdfd51

Showing 4 changed files with 98 additions and 78 deletions.
21 changes: 0 additions & 21 deletions ddp.slurm

This file was deleted.

109 changes: 53 additions & 56 deletions ddp.py → ddp_and_fsdp/ddp.py
@@ -2,6 +2,7 @@

import argparse # noqa: I001
import os
import sys
import time
from itertools import product
from pathlib import Path
@@ -17,28 +18,25 @@
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.data.distributed import DistributedSampler

# Import nanogpt from relative directory.
nanogpt_dir = Path.cwd().parent
sys.path.append(str(nanogpt_dir))
from nanogpt import NanoGPT, build_dataset

# Hyperparameters for model setup.
LR_SET = [5e-2, 1e-3, 1e-4] # learning rate set
OPTIM_SET = [Adam, AdamW, NAdam] # optimizer set
ARCH_SET = [ # model architecture set
{"ctx_len": 256, "emb_dim": 256, "n_heads": 8, "head_sz": 32, "n_blocks": 8},
{"ctx_len": 2048, "emb_dim": 768, "n_heads": 12, "head_sz": 64, "n_blocks": 12},
{"ctx_len": 2048, "emb_dim": 1024, "n_heads": 16, "head_sz": 64, "n_blocks": 12},
{"ctx_len": 2048, "emb_dim": 1024, "n_heads": 20, "head_sz": 80, "n_blocks": 12},
]

def setup(
rank: int, # rank of current process
world_size: int, # number of processes
master_addr: str, # master machine address (IP or hostname)
master_port: str, # master machine port
):
def setup(backend: str):
"""Sets up the DDP environment."""
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
# Create distributed process group.
init_process_group(backend="gloo", rank=rank, world_size=world_size)
# Create distributed process group and set cuda device according to torchrun LOCAL_RANK env var.
init_process_group(backend=backend)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def cleanup():
"""Cleans up and kills DDP environment."""
@@ -51,9 +49,10 @@ def train(
val_loader: DataLoader, # batched dataset for validation
optimizer: optim, # optimizer
loss_fn: nn.modules.loss, # loss function
rank: int, # rank of current process
global_rank: int, # rank of current process across all nodes
local_rank: int, # rank of current process within node
max_epochs: int = 5, # max n training epochs
max_batches: int = 500, # max n batches to train
max_batches: int = 1000, # max n batches to train
val_chk_interval: int = 200, # check val loss every `val_chk_interval` batches & print losses
val_iter: int = 5, # number of batches on val_loader to run and avg when computing val loss
patience_thresh: int = 1e9, # consecutive batches without val loss decrease for early stopping
@@ -69,8 +68,8 @@ def estimate_losses(
"""Estimate losses on val_loader, and return val loss and train loss avg."""
model.eval()
for val_i, (x_val, y_val) in enumerate(val_loader):
logits = model(x_val.to(rank))
val_loss = loss_fn(logits.view(-1, n_tokens), y_val.to(rank).view(-1))
logits = model(x_val.to(local_rank))
val_loss = loss_fn(logits.view(-1, n_tokens), y_val.to(local_rank).view(-1))
val_losses.append(val_loss.item())
if val_i >= (val_iter - 1):
break
@@ -101,7 +100,7 @@ def apply_gradient_centralization(optimizer):
train_losses, val_losses, train_losses_avg, val_losses_avg = [], [], [], []
init_loss, best_val_loss = float("inf"), float("inf")
patience_ct = 0
if rank == 0:
if global_rank == 0:
wandb.log({"expected_total_batches": batch_lim})
# /s>

@@ -111,9 +110,9 @@ def apply_gradient_centralization(optimizer):
for batch_i, (x_train, y_train) in enumerate(train_loader):
# <ss Model training.
optimizer.zero_grad()
logits = model(x_train.to(rank)) # -> [batch_sz, ctx_len, n_tokens], but...
logits = model(x_train.to(local_rank)) # -> [batch_sz, ctx_len, n_tokens], but...
# must reshape to compare against batch_sz vector of targets for cross-entropy loss
loss = loss_fn(logits.view(-1, n_tokens), y_train.to(rank).view(-1))
loss = loss_fn(logits.view(-1, n_tokens), y_train.to(local_rank).view(-1))
loss.backward()
apply_gradient_centralization(optimizer)
optimizer.step()
@@ -125,29 +124,29 @@ def apply_gradient_centralization(optimizer):
estimate_losses(
model, val_loader, val_losses, val_losses_avg, train_losses, train_losses_avg
)
if rank == 0:
if global_rank == 0:
wandb.log({"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]})
# Return if patience check reached (early stopping).
patience_ct = (
0 if val_losses_avg[-1] < best_val_loss else patience_ct + val_chk_interval
)
best_val_loss = min(best_val_loss, val_losses_avg[-1])
if patience_ct >= patience_thresh:
if rank == 0:
if global_rank == 0:
wandb.log(
{"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]}
)
return loss, train_losses_avg, val_losses_avg
# Return if max_batches reached.
if (batch_i + 1) * (epoch + 1) >= max_batches:
if rank == 0:
if global_rank == 0:
wandb.log({"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]})
return loss, train_losses_avg, val_losses_avg
# Save checkpoint check.
if (
Path(save_chkpt_dir).exists()
and (init_loss - loss.item()) > save_chkpt_thresh
and rank == 0
and global_rank == 0
):
torch.save(
model.module.state_dict(),
@@ -156,7 +155,7 @@ def apply_gradient_centralization(optimizer):
init_loss = loss.item()
# /ss>
# <ss Progress metrics.
if rank == 0:
if global_rank == 0:
n_comp_batches = epoch * n_batches + batch_i + 1
elapsed_t = time.time() - start_t
avg_batch_t = elapsed_t / n_comp_batches
@@ -170,7 +169,7 @@ def apply_gradient_centralization(optimizer):
)
# /ss> /s>
# Return after max_epochs reached.
if rank == 0:
if global_rank == 0:
wandb.log(
{
"train_loss": train_losses_avg[-1],
@@ -179,18 +178,17 @@ def apply_gradient_centralization(optimizer):
"estimated_time_remaining": est_remaining_t
}
)
if Path(save_chkpt_dir).exists() and rank == 0:
if Path(save_chkpt_dir).exists() and local_rank == 0:
torch.save(
model.module.state_dict(),
Path(save_chkpt_dir) / f"model_chkpt_loss{loss.item():.3f}.pth"
)
return loss, train_losses_avg, val_losses_avg

def main(
rank: int, # rank of current process
world_size: int, # number of processes
master_addr: str, # master machine address (IP or hostname)
master_port: str, # master machine port
backend: str, # DDP backend to use
global_rank: int, # rank of current process across all nodes
local_rank: int, # rank of current process within node
text_file: str, # path to text file to train on
train_config: tuple[float, optim.Optimizer, list[dict]], # lr, optimizer, model config
):
@@ -199,7 +197,7 @@ def main(
Sets up DDP env, creates dataset from text file, creates and trains model, cleans up DDP env.
"""
# Set up DDP environment.
setup(rank, world_size, master_addr, master_port)
setup(backend)
# Set up dataset.
with open(text_file) as f:
text = f.read()
@@ -215,15 +213,15 @@
)
# Set up model.
model = NanoGPT(n_tokens=len(tokens), **train_config[2])
model = DDP(model.to(rank), device_ids=[rank])
model = DDP(model.to(local_rank), device_ids=[local_rank])
# Initialize wandb config and run.
param_bytes = 4 # 32-bit floats
bytes_in_gb = 1024**3
n_tot_params = sum(p.numel() for p in model.parameters())
n_tot_params_b = round(n_tot_params / 1e9, 3)
tot_sz_gb = n_tot_params * param_bytes / bytes_in_gb
run_name = f"{train_config[1].__name__}-{train_config[0]}_{n_tot_params_b}B"
if rank == 0:
if global_rank == 0:
wandb_config = {
"n_params_bil": n_tot_params_b,
"sz_gb": tot_sz_gb,
@@ -240,7 +238,16 @@ def main(
optimizer = train_config[1](model.parameters(), lr=train_config[0])
loss_fn = nn.CrossEntropyLoss()
save_chkpt_dir = Path.home() / "nanogpt_ddp_runs" / "chkpts" / run_name
train(model, train_loader, val_loader, optimizer, loss_fn, rank, save_chkpt_dir=save_chkpt_dir)
train(
model,
train_loader,
val_loader,
optimizer,
loss_fn,
global_rank,
local_rank,
save_chkpt_dir=save_chkpt_dir
)
# Clean up DDP environment.
cleanup()

@@ -250,39 +257,29 @@ def main(
# Parse args.
parser = argparse.ArgumentParser(description="Run DDP distributed training of NanoGPTs.")
parser.add_argument(
"--train-config-idx",
"--ddp_backend",
type=str,
default="nccl",
help="DDP backend to use (typically 'nccl' on Unix-like system, 'gloo' on Windows)."
)
parser.add_argument(
"--train_config_idx",
type=int,
required=True,
help="Index of train config to run. (See `train_configs` var)"
)
parser.add_argument(
"--world-size", type=int, required=True, help="Number of processes to use for DDP."
)
#parser.add_argument("--rank", type=int, required=True, help="Rank of current process.")
parser.add_argument(
"--master-addr", type=str, required=True, help="Master address (or hostname) for DDP."
)
parser.add_argument("--master-port", type=str, default="4444", help="Master port for DDP.")
parser.add_argument(
"--text-file",
"--text_file",
type=str,
default=(Path.cwd() / "data/tiny_austen.txt"),
default=(Path.cwd().parent / "data/tiny_austen.txt"),
help="Path to text file to train on."
)
args = parser.parse_args()
# Get ranks from torchrun env vars.
global_rank = int(os.environ["RANK"]) # rank of current process across all nodes
local_rank = int(os.environ["LOCAL_RANK"]) # rank of current process within node
# Set training config.
train_configs = list(product(LR_SET, OPTIM_SET, ARCH_SET))
train_config = train_configs[args.train_config_idx]
# Run DDP training.
mp.spawn( # passes `rank` to `main` as first arg automatically
main,
args=(
args.world_size,
args.master_addr,
args.master_port,
args.text_file,
train_config,
),
nprocs=args.world_size,
join=True,
)
main(args.ddp_backend, global_rank, local_rank, args.text_file, train_config)
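
The refactor above drops the `mp.spawn` launcher and the manual `MASTER_ADDR`/`MASTER_PORT` plumbing in favour of `torchrun`, which exports the rendezvous address and the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` variables to every process it starts. A minimal sketch of that pattern (illustrative only, not code from this commit):

```python
import os

import torch
import torch.nn as nn
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP


def run(backend: str = "nccl"):
    # Rank, world size, and master address/port all come from torchrun's env vars,
    # so init_process_group() needs no explicit arguments beyond the backend.
    init_process_group(backend=backend)
    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
    global_rank = int(os.environ["RANK"])       # unique rank across all nodes
    torch.cuda.set_device(local_rank)

    # Wrap any model in DDP, pinned to this process's GPU.
    model = DDP(nn.Linear(8, 8).to(local_rank), device_ids=[local_rank])

    # ... training loop; log and checkpoint only when global_rank == 0 ...

    destroy_process_group()


if __name__ == "__main__":
    run()
```

Launched with, e.g., `torchrun --standalone --nproc_per_node=2 sketch.py`, each process picks up its own ranks automatically, which is why `setup()` in `ddp.py` now only takes a backend argument.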
37 changes: 37 additions & 0 deletions ddp_and_fsdp/ddp.slurm
@@ -0,0 +1,37 @@
#!/bin/bash
#SBATCH --job-name=ddp-training
#SBATCH --partition=a100
#SBATCH --nodes=1
#SBATCH --mem=128G
#SBATCH --ntasks=2 # processes per job
#SBATCH --gres=gpu:2 # gpus total across nodes
#SBATCH --array=0-26%3 # jobs, % max in parallel (27 unique models, given hyperparameter configurations)
#SBATCH --output=/nfs/nhome/live/jbhagat/nanogpt_ddp_runs/job_%j.out
#SBATCH --error=/nfs/nhome/live/jbhagat/nanogpt_ddp_runs/job_%j.err

# Set first node as the master
HEAD_NODE_HOSTNAME=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
HEAD_NODE_IP=$(nslookup $HEAD_NODE_HOSTNAME | grep 'Address:' | awk 'NR==2 {print $2}')

# Dynamically calculate number of processes per node, based on number of nodes assigned for this job
PROCS_PER_NODE=$(($SLURM_NTASKS / $SLURM_JOB_NUM_NODES))

# Echo vars to .out file
echo "HEAD_NODE_HOSTNAME: $HEAD_NODE_HOSTNAME, HEAD_NODE_IP: $HEAD_NODE_IP, PROCS_PER_NODE: $PROCS_PER_NODE"

# Activate env
source /nfs/nhome/live/jbhagat/mambaforge/etc/profile.d/conda.sh
conda activate nanogpt

# Run ddp
srun torchrun \
--standalone \
--nnodes=${SLURM_JOB_NUM_NODES} \
--nproc_per_node=${PROCS_PER_NODE} \
/nfs/nhome/live/jbhagat/nanoGPT/ddp_and_fsdp/ddp.py \
--train_config_idx="$SLURM_ARRAY_TASK_ID"

# rdzv args for multinode
#--rdzv_id=4444 \
#--rdzv_backend="c10d" \
#--rdzv_endpoint="$HEAD_NODE_IP:44444"
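
The `--array=0-26%3` directive works together with the config indexing at the bottom of `ddp.py`: each array task passes its `$SLURM_ARRAY_TASK_ID` as `--train_config_idx`, which selects one element of `product(LR_SET, OPTIM_SET, ARCH_SET)`. A small sketch of that mapping (not part of the commit; `ARCH_SET` is abridged to stand-ins here):

```python
from itertools import product

from torch.optim import Adam, AdamW, NAdam

LR_SET = [5e-2, 1e-3, 1e-4]
OPTIM_SET = [Adam, AdamW, NAdam]
ARCH_SET = ["arch_s", "arch_m", "arch_l", "arch_xl"]  # stand-ins for the 4 dicts in ddp.py

train_configs = list(product(LR_SET, OPTIM_SET, ARCH_SET))
print(len(train_configs))  # 3 * 3 * 4 = 36 combinations with the sets shown in the diff

# Each Slurm array task effectively does:
idx = 5  # stand-in for int("$SLURM_ARRAY_TASK_ID") passed via --train_config_idx
lr, optim_cls, arch = train_configs[idx]
print(lr, optim_cls.__name__, arch)
```

Note that with the sets shown in the diff the product has 36 entries, so `--array=0-26` (and its "27 unique models" comment) appears to cover only the first 27 of them.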
9 changes: 8 additions & 1 deletion readme.md
@@ -12,6 +12,9 @@ Multi-head self-attention is implemented "from scratch", at the level of pytorch

While the overall architecture is similar, this nanoGPT makes departures from Karpathy's nanoGPT in: naming conventions, data loading and training configuration, projecting embedding dimensions to attention heads, the format of operations in self-attention units and transformer blocks, output model generation (by adding parameters such as `temp` and `top_k`), and more.

Additionally, examples of distributed training of models across multiple GPUs using PyTorch
Distributed Data Parallel (DDP) and Fully Sharded Data Parallel (FSDP) via Slurm can be found in the `ddp_and_fsdp` directory.

## Examples

### nanoGPT-Shakespeare
@@ -41,6 +44,10 @@ Output generated from models trained after approximately 320000 (top), 640000 (m

- `tests/` contains tests that can be run via pytest for verifying components of nanoGPT work as expected.

- `ddp_and_fsdp/` contains python modules and slurm scripts for:
- 1: speeding up training of a single model across multiple GPUs via model copying and distributed batching using DDP.
- 2: training a single large model across multiple GPUs via sharding using FSDP.

- `.github/workflows/` contains a github actions workflow for building the python environment, running tests, and uploading the results to codecov.

## Usage
Expand Down Expand Up @@ -75,7 +82,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Import nanogpt
nanogpt_dir = Path.cwd()
sys.path.append(nanogpt_dir)
sys.path.append(str(nanogpt_dir))
import nanogpt

# Load in text file to train on and build dataloaders

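On the readme's new `ddp_and_fsdp` note: the difference between the two approaches is in how model state is placed on the GPUs. DDP keeps a full replica per GPU and all-reduces gradients, whereas FSDP shards parameters, gradients, and optimizer state across GPUs so that a model too large for one device can still be trained. A minimal sketch contrasting the two wrappers (illustrative only, not code from this repository):

```python
import os

import torch
import torch.nn as nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP

init_process_group(backend="nccl")  # under torchrun, as in ddp.py
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

use_fsdp = False  # illustrative toggle between the two wrappers
if use_fsdp:
    # FSDP: parameters, gradients, and optimizer state are sharded across ranks.
    wrapped = FSDP(model.to(local_rank))
else:
    # DDP: every rank holds a full copy; gradients are all-reduced after backward().
    wrapped = DDP(model.to(local_rank), device_ids=[local_rank])

# ... train `wrapped` exactly as a normal nn.Module ...
destroy_process_group()
```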