diff --git a/get_rank_from_slurm.sh b/get_rank_from_slurm.sh
new file mode 100755
index 0000000000..881bdc4a0e
--- /dev/null
+++ b/get_rank_from_slurm.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+# select_gpu_device wrapper script
+export RANK=${SLURM_PROCID}
+exec $*
diff --git a/litgpt/__main__.py b/litgpt/__main__.py
index dba2fe4c06..5aeca8e3d4 100644
--- a/litgpt/__main__.py
+++ b/litgpt/__main__.py
@@ -1,4 +1,5 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+from mpi4py import MPI
 
 import torch
diff --git a/litgpt/chat/base.py b/litgpt/chat/base.py
index d5b2f047fb..c831fc4e54 100644
--- a/litgpt/chat/base.py
+++ b/litgpt/chat/base.py
@@ -1,5 +1,10 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+try:
+    from mpi4py import MPI
+except ImportError:
+    pass
+from axonn.lightning import AxonnStrategy
 import sys
 import time
 from pathlib import Path
@@ -142,23 +147,26 @@ def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature,
 
 def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, top_k, top_p, stop_tokens):
     while True:
-        try:
-            if not multiline:
-                prompt = input(">> Prompt: ")
-            else:
-                print(">> Prompt: (Type '!submit' on a new line to end input).")
-                prompt_lines = []
-                while True:
-                    line = input()
-                    if line.strip().lower() in ("!submit", "!quit", "!exit"):
-                        break
-                    prompt_lines.append(line)
-                prompt = "\n".join(prompt_lines)
-
-        except KeyboardInterrupt:
-            break
-
-        prompt = prompt.lower().strip()
+        prompt = None
+        if fabric.global_rank == 0:
+            try:
+                if not multiline:
+                    prompt = input(">> Prompt: ")
+                else:
+                    print(">> Prompt: (Type '!submit' on a new line to end input).")
+                    prompt_lines = []
+                    while True:
+                        line = input()
+                        if line.strip().lower() in ("!submit", "!quit", "!exit"):
+                            break
+                        prompt_lines.append(line)
+                    prompt = "\n".join(prompt_lines)
+
+            except KeyboardInterrupt:
+                break
+
+            prompt = prompt.lower().strip()
+        prompt = fabric.broadcast(prompt, src=0)
 
         if not prompt or prompt in ("!quit", "!exit"):
             break
@@ -219,7 +227,9 @@ def main(
         plugins = BitsandbytesPrecision(quantize[4:], dtype)
         precision = None
 
-    fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)
+    strategy = AxonnStrategy(G_intra_r=4)
+    fabric = L.Fabric(devices=4, precision=precision, plugins=plugins, strategy=strategy)
+    fabric.launch()
 
     checkpoint_path = checkpoint_dir / "lit_model.pth"
     check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
@@ -233,9 +243,12 @@ def main(
     config = Config.from_file(checkpoint_dir / "model_config.yaml")
 
     with fabric.init_module(empty_init=True):
-        model = GPT(config)
+        model = GPT(config, use_axonn_linear=True)
         # enable the kv cache
+
+    with fabric.init_tensor():
         model.set_kv_cache(batch_size=1)
+
     load_checkpoint(fabric, model, checkpoint_path)
     model.eval()
diff --git a/litgpt/data/base.py b/litgpt/data/base.py
index 36ef33fb8a..cc97ec17ae 100644
--- a/litgpt/data/base.py
+++ b/litgpt/data/base.py
@@ -121,3 +121,5 @@ def _sft_collate_fn(
         batched[key] = batched[key][:, :max_seq_length]
 
     return batched
+
+
diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py
index f4e62a99c8..8ed3551f0f 100644
--- a/litgpt/finetune/full.py
+++ b/litgpt/finetune/full.py
@@ -1,4 +1,6 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+from mpi4py import MPI
+MPI.Init()
 import dataclasses
 import math
 import os
@@ -36,12 +38,26 @@
     save_hyperparameters,
 )
 
+@torch.no_grad()
+def global_collate(input_ids, targets, pad_idx=0, ignore_idx=-100):
+    tensors = [input_ids, targets]
+    paddings = [pad_idx, ignore_idx]
+    padded_tensors = []
+    for tensor, padding in zip(tensors, paddings):
+        local_sq = torch.tensor(tensor.shape[1], device="cuda")
+        torch.distributed.all_reduce(local_sq, torch.distributed.ReduceOp.MAX)
+        global_sq = local_sq
+        padded_tensor = torch.full((tensor.shape[0], global_sq), fill_value=padding, device="cuda")
+        padded_tensor[:, :tensor.shape[1]].copy_(tensor)
+        padded_tensors.append(padded_tensor)
+    return padded_tensors
+
 
 def setup(
     checkpoint_dir: Path,
     out_dir: Path = Path("out/finetune/full"),
     precision: Optional[str] = None,
     devices: Union[int, str] = 1,
+    num_nodes: Union[int, str] = 1,
     resume: Union[bool, Literal["auto"], Path] = False,
     data: Optional[DataModule] = None,
     train: TrainArgs = TrainArgs(
@@ -55,8 +71,9 @@ def setup(
     ),
     eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
     optimizer: Union[str, Dict] = "AdamW",
-    logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
+    logger_name: Literal["wandb", "tensorboard", "csv"] = "wandb",
     seed: int = 1337,
+    strategy: Literal["axonn", "fsdp"] = "fsdp"
 ) -> None:
     """Finetune a model.
@@ -65,7 +82,8 @@ def setup(
         out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
             /teamspace/jobs//share.
         precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
-        devices: How many devices/GPUs to use
+        devices: How many devices/GPUs per node to use
+        num_nodes: How many nodes to use
         resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
             from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
             ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
@@ -75,6 +93,7 @@ def setup(
         optimizer: An optimizer name (such as "AdamW") or config.
        logger_name: The name of the logger to send metrics to.
         seed: The random seed to use for reproducibility.
+        strategy: Parallel strategy to use.
""" checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) pprint(locals()) @@ -87,21 +106,29 @@ def setup( precision = precision or get_default_supported_precision(training=True) logger = choose_logger( - logger_name, out_dir, name=f"finetune-{config.name}", resume=bool(resume), log_interval=train.log_interval + logger_name, out_dir, name=f"finetune-{config.name}-{strategy}-clip", resume=bool(resume), log_interval=train.log_interval, + project="test-litgpt" ) if devices > 1: - strategy = FSDPStrategy( - auto_wrap_policy={Block}, - activation_checkpointing_policy={Block}, - state_dict_type="full", - limit_all_gathers=True, - cpu_offload=False, - ) + if strategy == "fsdp": + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + #activation_checkpointing_policy={Block}, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + elif strategy == "axonn": + from axonn.lightning import AxonnStrategy + strategy = AxonnStrategy(G_intra_r=num_nodes * devices, + #activation_checkpointing_policy={Block}, + overlap_communication=True) else: strategy = "auto" - fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger) + fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision=precision, loggers=logger) + devices = devices * num_nodes fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer) @@ -132,13 +159,14 @@ def main( checkpoint_path = checkpoint_dir / "lit_model.pth" with fabric.init_module(empty_init=(devices > 1)): + config.use_axonn_linear=False model = GPT(config) fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") model = fabric.setup(model) - optimizer = instantiate_torch_optimizer(optimizer, model.parameters()) + optimizer = instantiate_torch_optimizer(optimizer, model.parameters(), lr=3e-5) optimizer = fabric.setup_optimizers(optimizer) scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) state = {"model": model, "optimizer": optimizer, "scheduler": scheduler, "iter_num": 0, "step_count": 0} @@ -233,17 +261,19 @@ def fit( iter_t0 = time.perf_counter() batch = next(train_iterator) input_ids, targets = batch["input_ids"], batch["labels"] + input_ids, targets = global_collate(input_ids, targets) is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices) != 0 with fabric.no_backward_sync(model, enabled=is_accumulating): logits = model(input_ids) # shift the targets such that output n predicts token n+1 loss = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:]) - fabric.backward(loss / train.gradient_accumulation_iters(devices)) + fabric.backward(loss / train.gradient_accumulation_iters(devices), model=model) running_loss.update(loss.detach()) if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=1.0) optimizer.step() optimizer.zero_grad() scheduler.step() @@ -306,6 +336,7 @@ def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: Eva if k >= eval.max_iters: break input_ids, targets = batch["input_ids"], batch["labels"] + input_ids, targets = global_collate(input_ids, targets) logits = model(input_ids) losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0) diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index 37a2794d42..49725a0455 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -150,7 +150,7 @@ def generate( @torch.inference_mode() def 
 def main(
     checkpoint_dir: Path,
-    prompt: str = "What food do llamas eat?",
+    prompt: str = "Write a matplotlib program to draw a stacked bar chart",
     *,
     num_samples: int = 1,
     max_new_tokens: int = 50,
@@ -212,7 +212,10 @@ def main(
         plugins = BitsandbytesPrecision(quantize[4:], dtype)
         precision = None
 
-    fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)
+    from axonn.lightning import AxonnStrategy
+    strategy = AxonnStrategy(G_intra_r=4)
+    fabric = L.Fabric(devices=4, precision=precision, plugins=plugins, strategy=strategy)
+    fabric.launch()
 
     check_valid_checkpoint_dir(checkpoint_dir)
     config = Config.from_file(checkpoint_dir / "model_config.yaml")
@@ -233,7 +236,7 @@ def main(
     fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
     t0 = time.perf_counter()
     with fabric.init_module(empty_init=True):
-        model = GPT(config)
+        model = GPT(config, use_axonn_linear=True)
     fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
     with fabric.init_tensor():
         # set the max_seq_length to limit the memory usage to what we need
diff --git a/litgpt/model.py b/litgpt/model.py
index fe71c60b80..6962c33e23 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -15,14 +15,15 @@
 from litgpt.config import Config
+from axonn.intra_layer import Linear
 
 
 class GPT(nn.Module):
-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Config, use_axonn_linear=False) -> None:
         super().__init__()
         assert config.padded_vocab_size is not None
+        config.use_axonn_linear = use_axonn_linear
         self.config = config
 
-        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
         self.transformer = nn.ModuleDict(
             dict(
                 wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
@@ -30,6 +31,12 @@ def __init__(self, config: Config) -> None:
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
         )
+
+        if use_axonn_linear:
+            self.lm_head = Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
+        else:
+            self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
+
         self.max_seq_length = self.config.block_size
         self.mask_cache: Optional[torch.Tensor] = None
@@ -193,14 +200,31 @@ def __init__(self, config: Config) -> None:
         super().__init__()
         shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         # key, query, value projections for all heads, but in a batch
-        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
+        if config.use_axonn_linear:
+            self.attn = Linear(config.n_embd, shape, bias=config.bias, expert_mode=False)
+        else:
+            self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
         # output projection
         # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head`
-        self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
+        if config.use_axonn_linear:
+            self.proj = Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias, expert_mode=False, transpose=True)
+        else:
+            self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
         # disabled by default
         self.kv_cache: Optional[KVCache] = None
 
         self.config = config
+        if config.use_axonn_linear:
+            # adjust number of heads
+            from copy import deepcopy
+            from axonn import axonn as ax
+            self.config = deepcopy(self.config)
+            attention_world_size = ax.config.G_intra_r
+            assert self.config.n_head % attention_world_size == 0
+            self.config.n_head //= attention_world_size
+            assert self.config.n_query_groups % attention_world_size == 0
+            self.config.n_query_groups //= attention_world_size
+
 
     def forward(
         self,
@@ -287,8 +311,12 @@ def build_kv_cache(
 class GptNeoxMLP(nn.Module):
     def __init__(self, config: Config) -> None:
         super().__init__()
-        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+        if config.use_axonn_linear:
+            self.fc = Linear(config.n_embd, config.intermediate_size, bias=config.bias, expert_mode=False)
+            self.proj = Linear(config.intermediate_size, config.n_embd, bias=config.bias, expert_mode=False, transpose=True)
+        else:
+            self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
+            self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
 
         self.config = config
@@ -301,9 +329,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class LLaMAMLP(nn.Module):
     def __init__(self, config: Config) -> None:
         super().__init__()
-        self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+        if config.use_axonn_linear:
+            self.fc_1 = Linear(config.n_embd, config.intermediate_size, bias=config.bias, expert_mode=False)
+            self.fc_2 = Linear(config.n_embd, config.intermediate_size, bias=config.bias, expert_mode=False)
+            self.proj = Linear(config.intermediate_size, config.n_embd, bias=config.bias, expert_mode=False, transpose=True)
+        else:
+            self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
+            self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
+            self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
 
         self.config = config
diff --git a/litgpt/utils.py b/litgpt/utils.py
index db4e54d9b2..a3e3adce97 100644
--- a/litgpt/utils.py
+++ b/litgpt/utils.py
@@ -493,7 +493,7 @@ def parse_devices(devices: Union[str, int]) -> int:
 def choose_logger(
     logger_name: Literal["csv", "tensorboard", "wandb"],
     out_dir: Path,
-    name: str,
+    project: str,
     log_interval: int = 1,
     resume: Optional[bool] = None,
     **kwargs: Any,
@@ -503,7 +503,7 @@ def choose_logger(
     if logger_name == "tensorboard":
         return TensorBoardLogger(root_dir=(out_dir / "logs"), name="tensorboard", **kwargs)
     if logger_name == "wandb":
-        return WandbLogger(project=name, resume=resume, **kwargs)
+        return WandbLogger(project=project, resume=resume, **kwargs)
     raise ValueError(f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'.")
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000000..d0f3aad0f4
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+module load pytorch/2.3.1
+. /global/common/software/m4641/venv-2.3.1/bin/activate
+
+
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+NNODES=$SLURM_JOB_NUM_NODES
+GPUS=$(( NNODES * 4 ))
+
+export WORLD_SIZE=$GPUS
+export MASTER_ADDR=$(hostname)
+export MASTER_PORT=29500
+export CUDA_VISIBLE_DEVICES=3,2,1,0
+export NCCL_NET_GDR_LEVEL=PHB
+export NCCL_CROSS_NIC=1
+export NCCL_SOCKET_IFNAME=hsn
+# these are specific to perlmutter's slingshot-11 network
+#
+export NCCL_NET="AWS Libfabric"
+export FI_CXI_RDZV_THRESHOLD=0
+export FI_CXI_RDZV_GET_MIN=0
+export FI_CXI_OFLOW_BUF_SIZE=1073741824
+export FI_CXI_OFLOW_BUF_COUNT=1
+
+export MPICH_GPU_SUPPORT_ENABLED=0
+
+#MODEL="tiiuae/falcon-7b"
+#MODEL="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+MODEL="google/gemma-7b"
+STRAT="axonn" #fsdp/axonn
+
+SCRIPT="python -u -m litgpt finetune_full $MODEL --data Alpaca --devices 4 --train.global_batch_size 32 --train.micro_batch_size 2 --num_nodes $NNODES --strategy $STRAT"
+
+run_cmd="srun -C gpu -N $NNODES -n $GPUS -c 32 --cpu-bind=cores --gpus-per-node=4 --ntasks-per-node=4 ./get_rank_from_slurm.sh $SCRIPT"
+
+echo $run_cmd
+eval $run_cmd
+set +x
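Note (not part of the patch): run.sh reads SLURM_JOB_NUM_NODES and issues its own srun, so it is meant to execute inside an existing Slurm allocation. A minimal sketch of an sbatch wrapper that provides that allocation is shown below; the node count, walltime, and account are placeholders, not values taken from this change.

#!/bin/bash
#SBATCH -N 2                 # example node count; run.sh derives GPUS = nodes * 4
#SBATCH -C gpu
#SBATCH --gpus-per-node=4
#SBATCH -t 00:30:00          # placeholder walltime
#SBATCH -A your_account      # placeholder allocation/account

# run.sh performs the srun launch itself (one task per GPU via get_rank_from_slurm.sh)
bash run.sh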