[WIP] full finetune / qlora + ac/offload/optm in bwd #21

Closed · wants to merge 3 commits
183 changes: 118 additions & 65 deletions transformer_nuggets/llama/finetune.py
@@ -1,5 +1,9 @@
"""
Used to train a model from scratch on big dense blocks of text data using causal attention.
# full finetuning on single gpu without FSDP
python transformer_nuggets/llama/finetune.py --profile --model 7B --enable_ac --full_finetune --optim_in_bwd

# qlora on 2 gpus with FSDP
python transformer_nuggets/llama/finetune.py --profile --model 7B --fsdp_num_gpus 2 --use_fsdp2 --enable_ac --register_nf4_param --cpu_offload
"""
import argparse
import functools
@@ -9,6 +13,8 @@
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from torch.distributed._composable import checkpoint
from torch.distributed.optim import _apply_optimizer_in_backward

import numpy as np
import torch
@@ -41,9 +47,49 @@ class Hyperparameters(transformer_nuggets.llama.train.Hyperparameters):

@dataclass
class TrainingConfig(transformer_nuggets.llama.train.TrainingConfig):
log_interval: int = 10
log_interval: int = 1
track_max_memory: bool = False
use_fsdp2: bool = False
enable_ac: bool = False
qlora_debug: bool = False
full_finetune: bool = False
optim_in_bwd: bool = False
cpu_offload: bool = False
register_nf4_param: bool = False


def get_profile_context(
hyper_params: Hyperparameters, train_config: TrainingConfig, rank: int = 0
):
"""Returns a context manager that can be used to profile the model."""

def trace_handler(prof):
fp8_linear_type = hyper_params.fp8_linear_type

dtype_str = fp8_linear_type if fp8_linear_type else "bf16"
output_str = str(train_config.log_dir / f"trace_llama_7b_hf_{dtype_str}_rank_{rank}.json")
prof.export_chrome_trace(output_str)
logging.info(f"Wrote profile to: {output_str}")

if train_config.profile:
context = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(wait=100, warmup=1, active=2, repeat=1),
record_shapes=True,
with_stack=True,
on_trace_ready=trace_handler,
with_flops=True,
profile_memory=True,
experimental_config=torch.profiler._ExperimentalConfig(
enable_cuda_sync_events=True,
),
)
return context
else:
return nullcontext()


def main(
@@ -64,24 +110,30 @@ def main(

# Setup Model
model_args = ModelArgs.from_name(training_config.model_name)
if rank == 0:
logging.info(f"Initializing model: {training_config.model_name}")
# if rank == 0:
# logging.info(f"Initializing model: {training_config.model_name}")
with training_config.device:
model = Transformer(model_args).to(torch.bfloat16)
torch.set_default_dtype(torch.bfloat16)
model = Transformer(model_args)
model.init_parameters()

qlora_config = qlora.QloraConfig(
hyper_params.lora_r,
hyper_params.lora_alpha,
hyper_params.lora_dropout,
training_config.register_nf4_param,
)
qlora.swap_for_qlora(model, qlora_config, torch.bfloat16)

if training_config.qlora_debug:
qlora.swap_for_qlora_debug(model, qlora_config, torch.bfloat16)
elif not training_config.full_finetune:
qlora.swap_for_qlora(model, qlora_config, torch.bfloat16)
model.setup_caches(
hyper_params.micro_batch_size, hyper_params.max_seq_length, training_config.device
)

if rank == 0:
logging.info("Setting up the dataloaders")
# if rank == 0:
# logging.info("Setting up the dataloaders")
train_data, val_data = load_datasets(hyper_params, training_config, rank, world_size)
train_dataloader = DataLoader(
train_data,
@@ -90,7 +142,11 @@
)
val_dataloader = DataLoader(val_data, batch_size=hyper_params.micro_batch_size, num_workers=2)

log_num_params(model)
# log_num_params(model)

if training_config.enable_ac:
for layer in model.layers:
checkpoint(layer)

if world_size > 1:
if training_config.use_fsdp2:
@@ -101,29 +157,39 @@
fully_shard,
reshard_after_forward=True,
)
fsdp_kwargs = {}
if training_config.cpu_offload:
from torch.distributed._composable.fsdp import OffloadPolicy
fsdp_kwargs["offload_policy"] = OffloadPolicy("cpu")
for layer in model.layers:
fully_shard_fn(layer)
fully_shard_fn(model)
fully_shard_fn(layer, **fsdp_kwargs)
fully_shard_fn(model, **fsdp_kwargs)
else:
model = FSDP(
model,
use_orig_params=True,
auto_wrap_policy=ModuleWrapPolicy([TransformerBlock]),
)

torch.cuda.reset_peak_memory_stats()

if training_config.compile:
model = torch.compile(model)

if rank == 0:
logging.info(model)

optimizer = torch.optim.AdamW(
[p for p in model.parameters() if p.requires_grad],
lr=hyper_params.learning_rate,
weight_decay=hyper_params.weight_decay,
betas=(hyper_params.beta1, hyper_params.beta2),
foreach=hyper_params.foreach_optimizer,
)
# if rank == 0:
# logging.info(model)

if training_config.optim_in_bwd:
optimizer = None
_apply_optimizer_in_backward(optimizer_class=torch.optim.SGD, params=[p for p in model.parameters() if p.requires_grad], optimizer_kwargs={"lr": 2e-5})
else:
optimizer = torch.optim.AdamW(
[p for p in model.parameters() if p.requires_grad],
lr=hyper_params.learning_rate,
weight_decay=hyper_params.weight_decay,
betas=(hyper_params.beta1, hyper_params.beta2),
foreach=hyper_params.foreach_optimizer,
)

train(
model,
@@ -141,6 +207,12 @@ def entrypoint(
profile: bool = False,
use_fsdp2: bool = False,
model_name: str = "7B",
enable_ac: bool = False,
qlora_debug: bool = False,
full_finetune: bool = False,
optim_in_bwd: bool = False,
cpu_offload: bool = False,
register_nf4_param: bool = False,
rank: int = 0,
world_size: int = 1,
):
@@ -154,6 +226,12 @@
device=torch.device(f"cuda:{rank}"),
use_fsdp2=use_fsdp2,
model_name=model_name,
enable_ac=enable_ac,
qlora_debug=qlora_debug,
full_finetune=full_finetune,
optim_in_bwd=optim_in_bwd,
cpu_offload=cpu_offload,
register_nf4_param=register_nf4_param,
)
main(hyper_params, training_config, rank, world_size)

@@ -193,9 +271,9 @@ def train(
training_config.log_dir
/ f"qlora_train_loss_{dtype_str}_overfit_{training_config.overfit}_compile_{training_config.compile}_{rank}.csv"
)
if rank == 0:
logging.info(f"val_loss_file: {val_loss_file}")
logging.info(f"train_loss_file: {train_loss_file}")
# if rank == 0:
# logging.info(f"val_loss_file: {val_loss_file}")
# logging.info(f"train_loss_file: {train_loss_file}")

this_batch_loss = torch.tensor(0.0, device=training_config.device)
this_batch_n = 0
@@ -204,8 +282,9 @@
with profile_context as p:
for iter_num in range(hyper_params.max_iters):
lr = get_lr(iter_num, hyper_params)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
if optimizer is not None:
for param_group in optimizer.param_groups:
param_group["lr"] = lr

input_ids, targets = next(train_iter)
input_ids = input_ids.pin_memory().to(training_config.device)
@@ -229,7 +308,7 @@
# Scale the loss by grad_accumulation iters
(loss / hyper_params.gradient_accumulation_iters).backward()

if not is_accumulating:
if not is_accumulating and optimizer is not None:
optimizer.step()
optimizer.zero_grad()
step_count += 1
@@ -269,10 +348,10 @@

if rank == 0:
mem_stats = torch.cuda.memory_stats()
peak_active = mem_stats["active_bytes.all.peak"] / (1024 * 1024)
peak_reserved = mem_stats["reserved_bytes.all.peak"] / (1024 * 1024)
peak_active = mem_stats["active_bytes.all.peak"] / (1000 * 1000)
peak_reserved = mem_stats["reserved_bytes.all.peak"] / (1000 * 1000)
logging.info(
f"iter={iter_num} max_iters={hyper_params.max_iters} loss={loss_val:.4f} Peak Active Memory: {peak_active} MB, Peak Reserve Memory: {peak_reserved} MB"
f"iter={iter_num} max_iters={hyper_params.max_iters} step_count={step_count} loss={loss_val:.4f} Peak Active Memory: {peak_active} MB, Peak Reserve Memory: {peak_reserved} MB"
)

if training_config.profile:
@@ -288,6 +367,8 @@
torch.cuda.memory._dump_snapshot(f"{memory_trace_path}")
torch.cuda.memory._record_memory_history(enabled=None)
logging.info(f"Wrote memory traces to: {memory_trace_path}")
if iter_num == 260:
break


class Dataset(IterableDataset):
@@ -366,10 +447,16 @@ def load_datasets(
help="if specified, runs FSDP with this many GPUs on a single host",
)
parser.add_argument("--use_fsdp2", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--enable_ac", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--model", default="7B")
parser.add_argument("--qlora_debug", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--full_finetune", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--optim_in_bwd", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--cpu_offload", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--register_nf4_param", action=argparse.BooleanOptionalAction, default=False)
args = parser.parse_args()
fsdp_num_gpus = args.fsdp_num_gpus
inner_args = (args.profile, args.use_fsdp2, args.model)
inner_args = (args.profile, args.use_fsdp2, args.model, args.enable_ac, args.qlora_debug, args.full_finetune, args.optim_in_bwd, args.cpu_offload, args.register_nf4_param)

if fsdp_num_gpus is None or fsdp_num_gpus == 1:
entrypoint(*inner_args)
@@ -413,37 +500,3 @@ def validate(
model.train()
write_loss_to_file(loss_file, training_iter, loss.item())
return val_loss.item()


def get_profile_context(
hyper_params: Hyperparameters, train_config: TrainingConfig, rank: int = 0
):
"""Returns a context manager that can be used to profile the model."""

def trace_handler(prof):
fp8_linear_type = hyper_params.fp8_linear_type

dtype_str = fp8_linear_type if fp8_linear_type else "bf16"
output_str = str(train_config.log_dir / f"trace_llama_7b_hf_{dtype_str}_rank_{rank}.json")
prof.export_chrome_trace(output_str)
logging.info(f"Wrote profile to: {output_str}")

if train_config.profile:
context = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(wait=100, warmup=1, active=2, repeat=1),
record_shapes=True,
with_stack=True,
on_trace_ready=trace_handler,
with_flops=True,
profile_memory=True,
experimental_config=torch.profiler._ExperimentalConfig(
enable_cuda_sync_events=True,
),
)
return context
else:
return nullcontext()
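
The two memory-saving pieces this file wires together are composable activation checkpointing (`checkpoint(layer)`) and optimizer-in-backward (`_apply_optimizer_in_backward`). Below is a minimal standalone sketch of how they behave; it is not part of this diff, and the toy `Block` module, its dimensions, and the learning rate are illustrative only, while the two PyTorch calls are the ones the script imports:

```python
# Standalone sketch, not part of the PR: toy module and hyperparameters are
# made up; the PR applies the same two calls to the Llama Transformer layers.
import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.distributed.optim import _apply_optimizer_in_backward


class Block(nn.Module):
    def __init__(self, dim: int = 128) -> None:
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(x)


model = nn.Sequential(*[Block() for _ in range(4)])

# Composable activation checkpointing: each block's activations are recomputed
# during backward instead of being held live for the whole forward pass.
for layer in model:
    checkpoint(layer)

# Optimizer-in-backward: a per-parameter SGD step fires as soon as that
# parameter's gradient is produced, so the training loop needs no
# optimizer.step()/zero_grad(), and full-model gradients are never resident
# all at once.
_apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=[p for p in model.parameters() if p.requires_grad],
    optimizer_kwargs={"lr": 2e-5},
)

loss = model(torch.randn(8, 128)).sum()
loss.backward()  # parameters are updated inside this call
```

This is also why `optimizer` stays `None` when `--optim_in_bwd` is set, and why the training loop above guards `optimizer.step()` with `optimizer is not None`.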
1 change: 1 addition & 0 deletions transformer_nuggets/llama/model.py
@@ -58,6 +58,7 @@ def from_name(cls, name: str):
"CodeLlama-7b-Python-hf": dict(
block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000
),
"mini": dict(n_layer=2, n_head=32, dim=4096),
"7B": dict(n_layer=32, n_head=32, dim=4096),
"13B": dict(n_layer=40, n_head=40, dim=5120),
"30B": dict(n_layer=60, n_head=52, dim=6656),
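
The single model.py change adds a `mini` config (2 layers, otherwise 7B-shaped), presumably so the finetune script can be smoke-tested without materializing full 7B weights. An assumed invocation, mirroring the commands in the finetune.py docstring, would be `python transformer_nuggets/llama/finetune.py --profile --model mini --enable_ac --full_finetune --optim_in_bwd`.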