diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py
index f4e62a99c8..832af00dc3 100644
--- a/litgpt/finetune/full.py
+++ b/litgpt/finetune/full.py
@@ -42,6 +42,7 @@ def setup(
     out_dir: Path = Path("out/finetune/full"),
     precision: Optional[str] = None,
     devices: Union[int, str] = 1,
+    num_nodes: Union[int, str] = 1,
     resume: Union[bool, Literal["auto"], Path] = False,
     data: Optional[DataModule] = None,
     train: TrainArgs = TrainArgs(
@@ -57,6 +58,7 @@ def setup(
     optimizer: Union[str, Dict] = "AdamW",
     logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
     seed: int = 1337,
+    strategy: Literal["axonn", "fsdp"] = "fsdp"
 ) -> None:
     """Finetune a model.

@@ -65,7 +67,8 @@ def setup(
         out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
             /teamspace/jobs/<job-name>/share.
         precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
-        devices: How many devices/GPUs to use
+        devices: How many devices/GPUs per node to use
+        num_nodes: How many nodes to use
         resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
             from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
             ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
@@ -75,6 +78,7 @@ def setup(
         optimizer: An optimizer name (such as "AdamW") or config.
         logger_name: The name of the logger to send metrics to.
         seed: The random seed to use for reproducibility.
+        strategy: Parallel strategy to use.
     """
     checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
     pprint(locals())
@@ -87,21 +91,31 @@ def setup(
     precision = precision or get_default_supported_precision(training=True)

     logger = choose_logger(
-        logger_name, out_dir, name=f"finetune-{config.name}", resume=bool(resume), log_interval=train.log_interval
+        logger_name, out_dir, name=f"finetune-{config.name}-{strategy}-clip", resume=bool(resume), log_interval=train.log_interval,
+        project="test-litgpt"
     )

     if devices > 1:
-        strategy = FSDPStrategy(
-            auto_wrap_policy={Block},
-            activation_checkpointing_policy={Block},
-            state_dict_type="full",
-            limit_all_gathers=True,
-            cpu_offload=False,
-        )
+        if strategy == "fsdp":
+            strategy = FSDPStrategy(
+                auto_wrap_policy={Block},
+                activation_checkpointing_policy={Block},
+                state_dict_type="full",
+                limit_all_gathers=True,
+                cpu_offload=False,
+            )
+        elif strategy == "axonn":
+            try:
+                from mpi4py import MPI
+            except ImportError:
+                pass
+            from axonn.lightning import AxonnStrategy
+            strategy = AxonnStrategy(G_intra_d=num_nodes * devices)
     else:
         strategy = "auto"

-    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger)
+    fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision=precision, loggers=logger)
+    devices = devices * num_nodes
     fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
@@ -138,7 +152,7 @@ def main(

     model = fabric.setup(model)

-    optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
+    optimizer = instantiate_torch_optimizer(optimizer, model.parameters(), lr=3e-5)
     optimizer = fabric.setup_optimizers(optimizer)
     scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
     state = {"model": model, "optimizer": optimizer, "scheduler": scheduler, "iter_num": 0, "step_count": 0}
@@ -239,11 +253,12 @@ def fit(
             logits = model(input_ids)
             # shift the targets such that output n predicts token n+1
             loss = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:])
-            fabric.backward(loss / train.gradient_accumulation_iters(devices))
+            fabric.backward(loss / train.gradient_accumulation_iters(devices), model=model)

         running_loss.update(loss.detach())

         if not is_accumulating:
+            fabric.clip_gradients(model, optimizer, 1.0)
             optimizer.step()
             optimizer.zero_grad()
             scheduler.step()
diff --git a/litgpt/utils.py b/litgpt/utils.py
index db4e54d9b2..a3e3adce97 100644
--- a/litgpt/utils.py
+++ b/litgpt/utils.py
@@ -493,7 +493,7 @@ def parse_devices(devices: Union[str, int]) -> int:
 def choose_logger(
     logger_name: Literal["csv", "tensorboard", "wandb"],
     out_dir: Path,
-    name: str,
+    project: str,
     log_interval: int = 1,
     resume: Optional[bool] = None,
     **kwargs: Any,
@@ -503,7 +503,7 @@ def choose_logger(
     if logger_name == "tensorboard":
         return TensorBoardLogger(root_dir=(out_dir / "logs"), name="tensorboard", **kwargs)
     if logger_name == "wandb":
-        return WandbLogger(project=name, resume=resume, **kwargs)
+        return WandbLogger(project=project, resume=resume, **kwargs)
     raise ValueError(f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'.")
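For context, a minimal sketch of how the patched entry point might be invoked with the new options. The `devices`, `num_nodes`, and `strategy` parameters come from this diff; the checkpoint and output paths are placeholders, and an actual multi-node run would still need to be launched through the usual Fabric/torchrun or Slurm machinery with axonn and mpi4py installed.

```python
# Hypothetical driver script; assumes the patched litgpt is importable and a
# checkpoint has already been downloaded to checkpoint_dir (placeholder path).
from pathlib import Path

from litgpt.finetune.full import setup

if __name__ == "__main__":
    setup(
        checkpoint_dir=Path("checkpoints/meta-llama/Llama-2-7b-hf"),  # placeholder model
        out_dir=Path("out/finetune/full"),
        devices=4,          # GPUs per node
        num_nodes=2,        # new parameter added by this diff
        strategy="axonn",   # selects AxonnStrategy instead of the default FSDP
    )
```

With `strategy="axonn"`, the patch sets the AxoNN depth (data-parallel) group size to `num_nodes * devices`, i.e. all 8 GPUs in this example form one G_intra_d group.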