
Commit

add axonn as an option, lower lr of optimizer, change wandb logging, add multi-node support
siddharth9820 committed Jul 3, 2024
1 parent 43145e8 commit 56e4e1e
Showing 2 changed files with 29 additions and 14 deletions.
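For context, here is a minimal sketch (not part of the commit) of how the updated entry point might be invoked with the new multi-node and strategy options; the checkpoint path and device counts are placeholders:

    from pathlib import Path

    from litgpt.finetune.full import setup

    setup(
        checkpoint_dir=Path("checkpoints/some-model"),  # placeholder checkpoint directory
        devices=4,          # GPUs per node
        num_nodes=2,        # new in this commit: number of nodes
        strategy="axonn",   # new in this commit: "axonn" or the default "fsdp"
    )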
39 changes: 27 additions & 12 deletions litgpt/finetune/full.py
@@ -42,6 +42,7 @@ def setup(
out_dir: Path = Path("out/finetune/full"),
precision: Optional[str] = None,
devices: Union[int, str] = 1,
+    num_nodes: Union[int, str] = 1,
resume: Union[bool, Literal["auto"], Path] = False,
data: Optional[DataModule] = None,
train: TrainArgs = TrainArgs(
@@ -57,6 +58,7 @@ def setup(
optimizer: Union[str, Dict] = "AdamW",
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
seed: int = 1337,
+    strategy: Literal["axonn", "fsdp"] = "fsdp"
) -> None:
"""Finetune a model.
@@ -65,7 +67,8 @@ def setup(
out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
/teamspace/jobs/<job-name>/share.
precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
-        devices: How many devices/GPUs to use
+        devices: How many devices/GPUs per node to use
+        nodes: How many nodes to use
resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
@@ -75,6 +78,7 @@ def setup(
optimizer: An optimizer name (such as "AdamW") or config.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
+        strategy: Parallel strategy to use.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
@@ -87,21 +91,31 @@ def setup(

precision = precision or get_default_supported_precision(training=True)
logger = choose_logger(
-        logger_name, out_dir, name=f"finetune-{config.name}", resume=bool(resume), log_interval=train.log_interval
+        logger_name, out_dir, name=f"finetune-{config.name}-{strategy}-clip", resume=bool(resume), log_interval=train.log_interval,
+        project="test-litgpt"
)

-    if devices > 1:
-        strategy = FSDPStrategy(
-            auto_wrap_policy={Block},
-            activation_checkpointing_policy={Block},
-            state_dict_type="full",
-            limit_all_gathers=True,
-            cpu_offload=False,
-        )
+    if strategy == "fsdp":
+        strategy = FSDPStrategy(
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
+            state_dict_type="full",
+            limit_all_gathers=True,
+            cpu_offload=False,
+        )
+    elif strategy == "axonn":
+        try:
+            from mpi4py import MPI
+        except ImportError:
+            pass
+        from axonn.lightning import AxonnStrategy
+        strategy = AxonnStrategy(G_intra_d=num_nodes * devices)
     else:
         strategy = "auto"

-    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger)
+    fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision=precision, loggers=logger)
+    devices = devices * num_nodes
fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer)
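The net effect of this hunk: the parallel strategy now follows the new strategy argument, and devices is redefined to the global GPU count (devices * num_nodes) before being passed to main, so downstream bookkeeping such as gradient accumulation sees the full world size. Below is a standalone sketch of the AxoNN path, assuming the axonn package with its Lightning integration is installed and the job is started under an MPI-aware launcher such as srun or mpirun:

    import lightning as L
    from axonn.lightning import AxonnStrategy

    devices, num_nodes = 4, 2  # GPUs per node and node count (placeholders)
    strategy = AxonnStrategy(G_intra_d=devices * num_nodes)  # every rank placed on AxoNN's depth dimension
    fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision="bf16-true")
    fabric.launch()  # processes themselves come from the external launcher
    fabric.print(f"world size: {fabric.world_size}")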


@@ -138,7 +152,7 @@ def main(

model = fabric.setup(model)

-    optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
+    optimizer = instantiate_torch_optimizer(optimizer, model.parameters(), lr=3e-5)
optimizer = fabric.setup_optimizers(optimizer)
scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
state = {"model": model, "optimizer": optimizer, "scheduler": scheduler, "iter_num": 0, "step_count": 0}
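With the default "AdamW" optimizer name, the lowered learning rate above is roughly equivalent to the following (a sketch; instantiate_torch_optimizer may also merge settings from an optimizer config dict):

    import torch
    from torch import nn

    module = nn.Linear(16, 16)  # stand-in for the GPT model
    optimizer = torch.optim.AdamW(module.parameters(), lr=3e-5)  # hard-coded lower LR from this commit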
@@ -239,11 +253,12 @@ def fit(
logits = model(input_ids)
# shift the targets such that output n predicts token n+1
loss = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:])
-            fabric.backward(loss / train.gradient_accumulation_iters(devices))
+            fabric.backward(loss / train.gradient_accumulation_iters(devices), model=model)

running_loss.update(loss.detach())

if not is_accumulating:
+            fabric.clip_gradients(model, optimizer, 1.0)
optimizer.step()
optimizer.zero_grad()
scheduler.step()
4 changes: 2 additions & 2 deletions litgpt/utils.py
@@ -493,7 +493,7 @@ def parse_devices(devices: Union[str, int]) -> int:
def choose_logger(
logger_name: Literal["csv", "tensorboard", "wandb"],
out_dir: Path,
-    name: str,
+    project: str,
log_interval: int = 1,
resume: Optional[bool] = None,
**kwargs: Any,
@@ -503,7 +503,7 @@ def choose_logger(
if logger_name == "tensorboard":
return TensorBoardLogger(root_dir=(out_dir / "logs"), name="tensorboard", **kwargs)
if logger_name == "wandb":
-        return WandbLogger(project=name, resume=resume, **kwargs)
+        return WandbLogger(project=project, resume=resume, **kwargs)
raise ValueError(f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'.")
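A hypothetical call against the updated signature, mirroring how full.py now uses it: project selects the W&B project, while the run name still reaches WandbLogger through **kwargs:

    from pathlib import Path

    from litgpt.utils import choose_logger

    logger = choose_logger(
        "wandb",
        out_dir=Path("out/finetune/full"),
        project="test-litgpt",               # W&B project, as hard-coded in full.py above
        name="finetune-model-axonn-clip",    # placeholder run name, forwarded via **kwargs
        log_interval=1,
        resume=False,
    )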


