diff --git a/transformer_nuggets/llama/finetune.py b/transformer_nuggets/llama/finetune.py
index 8c52d53..6ba85ad 100644
--- a/transformer_nuggets/llama/finetune.py
+++ b/transformer_nuggets/llama/finetune.py
@@ -1,5 +1,9 @@
 """
-Used to train a model from scratch on big dense blocks of text data using causal attention.
+# full finetuning on single gpu without FSDP
+python transformer_nuggets/llama/finetune.py --profile --model 7B --enable_ac --full_finetune --optim_in_bwd
+
+# qlora on 2 gpus with FSDP
+python transformer_nuggets/llama/finetune.py --profile --model 7B --fsdp_num_gpus 2 --use_fsdp2 --enable_ac --register_nf4_param --cpu_offload
 """
 import argparse
 import functools
@@ -9,6 +13,8 @@
 from contextlib import nullcontext
 from dataclasses import dataclass
 from pathlib import Path
+from torch.distributed._composable import checkpoint
+from torch.distributed.optim import _apply_optimizer_in_backward

 import numpy as np
 import torch
@@ -41,9 +47,49 @@ class Hyperparameters(transformer_nuggets.llama.train.Hyperparameters):

 @dataclass
 class TrainingConfig(transformer_nuggets.llama.train.TrainingConfig):
-    log_interval: int = 10
+    log_interval: int = 1
     track_max_memory: bool = False
     use_fsdp2: bool = False
+    enable_ac: bool = False
+    qlora_debug: bool = False
+    full_finetune: bool = False
+    optim_in_bwd: bool = False
+    cpu_offload: bool = False
+    register_nf4_param: bool = False
+
+
+def get_profile_context(
+    hyper_params: Hyperparameters, train_config: TrainingConfig, rank: int = 0
+):
+    """Returns a context manager that can be used to profile the model."""
+
+    def trace_handler(prof):
+        fp8_linear_type = hyper_params.fp8_linear_type
+
+        dtype_str = fp8_linear_type if fp8_linear_type else "bf16"
+        output_str = str(train_config.log_dir / f"trace_llama_7b_hf_{dtype_str}_rank_{rank}.json")
+        prof.export_chrome_trace(output_str)
+        logging.info(f"Wrote profile to: {output_str}")
+
+    if train_config.profile:
+        context = torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            schedule=torch.profiler.schedule(wait=100, warmup=1, active=2, repeat=1),
+            record_shapes=True,
+            with_stack=True,
+            on_trace_ready=trace_handler,
+            with_flops=True,
+            profile_memory=True,
+            experimental_config=torch.profiler._ExperimentalConfig(
+                enable_cuda_sync_events=True,
+            ),
+        )
+        return context
+    else:
+        return nullcontext()

 def main(
@@ -64,24 +110,30 @@ def main(

     # Setup Model
     model_args = ModelArgs.from_name(training_config.model_name)
-    if rank == 0:
-        logging.info(f"Initializing model: {training_config.model_name}")
+    # if rank == 0:
+    #     logging.info(f"Initializing model: {training_config.model_name}")
     with training_config.device:
-        model = Transformer(model_args).to(torch.bfloat16)
+        torch.set_default_dtype(torch.bfloat16)
+        model = Transformer(model_args)
         model.init_parameters()

     qlora_config = qlora.QloraConfig(
         hyper_params.lora_r,
         hyper_params.lora_alpha,
         hyper_params.lora_dropout,
+        training_config.register_nf4_param,
     )
-    qlora.swap_for_qlora(model, qlora_config, torch.bfloat16)
+
+    if training_config.qlora_debug:
+        qlora.swap_for_qlora_debug(model, qlora_config, torch.bfloat16)
+    elif not training_config.full_finetune:
+        qlora.swap_for_qlora(model, qlora_config, torch.bfloat16)
     model.setup_caches(
         hyper_params.micro_batch_size, hyper_params.max_seq_length, training_config.device
     )

-    if rank == 0:
-        logging.info("Setting up the dataloaders")
+    # if rank == 0:
+    #     logging.info("Setting up the dataloaders")
     train_data, val_data = load_datasets(hyper_params, training_config, rank, world_size)
     train_dataloader = DataLoader(
         train_data,
@@ -90,7 +142,11 @@ def main(
     )
     val_dataloader = DataLoader(val_data, batch_size=hyper_params.micro_batch_size, num_workers=2)

-    log_num_params(model)
+    # log_num_params(model)
+
+    if training_config.enable_ac:
+        for layer in model.layers:
+            checkpoint(layer)

     if world_size > 1:
         if training_config.use_fsdp2:
@@ -101,9 +157,13 @@ def main(
                 fully_shard,
                 reshard_after_forward=True,
             )
+            fsdp_kwargs = {}
+            if training_config.cpu_offload:
+                from torch.distributed._composable.fsdp import OffloadPolicy
+                fsdp_kwargs["offload_policy"] = OffloadPolicy("cpu")
             for layer in model.layers:
-                fully_shard_fn(layer)
-            fully_shard_fn(model)
+                fully_shard_fn(layer, **fsdp_kwargs)
+            fully_shard_fn(model, **fsdp_kwargs)
         else:
             model = FSDP(
                 model,
@@ -111,19 +171,25 @@ def main(
                 auto_wrap_policy=ModuleWrapPolicy([TransformerBlock]),
             )

+    torch.cuda.reset_peak_memory_stats()
+
     if training_config.compile:
         model = torch.compile(model)

-    if rank == 0:
-        logging.info(model)
-
-    optimizer = torch.optim.AdamW(
-        [p for p in model.parameters() if p.requires_grad],
-        lr=hyper_params.learning_rate,
-        weight_decay=hyper_params.weight_decay,
-        betas=(hyper_params.beta1, hyper_params.beta2),
-        foreach=hyper_params.foreach_optimizer,
-    )
+    # if rank == 0:
+    #     logging.info(model)
+
+    if training_config.optim_in_bwd:
+        optimizer = None
+        _apply_optimizer_in_backward(optimizer_class=torch.optim.SGD, params=[p for p in model.parameters() if p.requires_grad], optimizer_kwargs={"lr": 2e-5})
+    else:
+        optimizer = torch.optim.AdamW(
+            [p for p in model.parameters() if p.requires_grad],
+            lr=hyper_params.learning_rate,
+            weight_decay=hyper_params.weight_decay,
+            betas=(hyper_params.beta1, hyper_params.beta2),
+            foreach=hyper_params.foreach_optimizer,
+        )

     train(
         model,
@@ -141,6 +207,12 @@ def entrypoint(
     profile: bool = False,
     use_fsdp2: bool = False,
     model_name: str = "7B",
+    enable_ac: bool = False,
+    qlora_debug: bool = False,
+    full_finetune: bool = False,
+    optim_in_bwd: bool = False,
+    cpu_offload: bool = False,
+    register_nf4_param: bool = False,
     rank: int = 0,
     world_size: int = 1,
 ):
@@ -154,6 +226,12 @@ def entrypoint(
         device=torch.device(f"cuda:{rank}"),
         use_fsdp2=use_fsdp2,
         model_name=model_name,
+        enable_ac=enable_ac,
+        qlora_debug=qlora_debug,
+        full_finetune=full_finetune,
+        optim_in_bwd=optim_in_bwd,
+        cpu_offload=cpu_offload,
+        register_nf4_param=register_nf4_param,
     )
     main(hyper_params, training_config, rank, world_size)

@@ -193,9 +271,9 @@ def train(
         training_config.log_dir
         / f"qlora_train_loss_{dtype_str}_overfit_{training_config.overfit}_compile_{training_config.compile}_{rank}.csv"
     )
-    if rank == 0:
-        logging.info(f"val_loss_file: {val_loss_file}")
-        logging.info(f"train_loss_file: {train_loss_file}")
+    # if rank == 0:
+    #     logging.info(f"val_loss_file: {val_loss_file}")
+    #     logging.info(f"train_loss_file: {train_loss_file}")

     this_batch_loss = torch.tensor(0.0, device=training_config.device)
     this_batch_n = 0
@@ -204,8 +282,9 @@ def train(
     with profile_context as p:
         for iter_num in range(hyper_params.max_iters):
             lr = get_lr(iter_num, hyper_params)
-            for param_group in optimizer.param_groups:
-                param_group["lr"] = lr
+            if optimizer is not None:
+                for param_group in optimizer.param_groups:
+                    param_group["lr"] = lr

             input_ids, targets = next(train_iter)
             input_ids = input_ids.pin_memory().to(training_config.device)
@@ -229,7 +308,7 @@ def train(
                 # Scale the loss by grad_accumulation iters
                 (loss / hyper_params.gradient_accumulation_iters).backward()
-            if not is_accumulating:
+            if not is_accumulating and optimizer is not None:
                 optimizer.step()
                 optimizer.zero_grad()
                 step_count += 1

@@ -269,10 +348,10 @@ def train(
             if rank == 0:
                 mem_stats = torch.cuda.memory_stats()
-                peak_active = mem_stats["active_bytes.all.peak"] / (1024 * 1024)
-                peak_reserved = mem_stats["reserved_bytes.all.peak"] / (1024 * 1024)
+                peak_active = mem_stats["active_bytes.all.peak"] / (1000 * 1000)
+                peak_reserved = mem_stats["reserved_bytes.all.peak"] / (1000 * 1000)
                 logging.info(
-                    f"iter={iter_num} max_iters={hyper_params.max_iters} loss={loss_val:.4f} Peak Active Memory: {peak_active} MB, Peak Reserve Memory: {peak_reserved} MB"
+                    f"iter={iter_num} max_iters={hyper_params.max_iters} step_count={step_count} loss={loss_val:.4f} Peak Active Memory: {peak_active} MB, Peak Reserve Memory: {peak_reserved} MB"
                 )

             if training_config.profile:
@@ -288,6 +367,8 @@ def train(
                 torch.cuda.memory._dump_snapshot(f"{memory_trace_path}")
                 torch.cuda.memory._record_memory_history(enabled=None)
                 logging.info(f"Wrote memory traces to: {memory_trace_path}")
+            if iter_num == 260:
+                break


 class Dataset(IterableDataset):
@@ -366,10 +447,16 @@ def load_datasets(
         help="if specified, runs FSDP with this many GPUs on a single host",
     )
     parser.add_argument("--use_fsdp2", action=argparse.BooleanOptionalAction, default=False)
+    parser.add_argument("--enable_ac", action=argparse.BooleanOptionalAction, default=False)
     parser.add_argument("--model", default="7B")
+    parser.add_argument("--qlora_debug", action=argparse.BooleanOptionalAction, default=False)
+    parser.add_argument("--full_finetune", action=argparse.BooleanOptionalAction, default=False)
+    parser.add_argument("--optim_in_bwd", action=argparse.BooleanOptionalAction, default=False)
+    parser.add_argument("--cpu_offload", action=argparse.BooleanOptionalAction, default=False)
+    parser.add_argument("--register_nf4_param", action=argparse.BooleanOptionalAction, default=False)
     args = parser.parse_args()
     fsdp_num_gpus = args.fsdp_num_gpus
-    inner_args = (args.profile, args.use_fsdp2, args.model)
+    inner_args = (args.profile, args.use_fsdp2, args.model, args.enable_ac, args.qlora_debug, args.full_finetune, args.optim_in_bwd, args.cpu_offload, args.register_nf4_param)

     if fsdp_num_gpus is None or fsdp_num_gpus == 1:
         entrypoint(*inner_args)
@@ -413,37 +500,3 @@ def validate(
     model.train()
     write_loss_to_file(loss_file, training_iter, loss.item())
     return val_loss.item()
-
-
-def get_profile_context(
-    hyper_params: Hyperparameters, train_config: TrainingConfig, rank: int = 0
-):
-    """Returns a context manager that can be used to profile the model."""
-
-    def trace_handler(prof):
-        fp8_linear_type = hyper_params.fp8_linear_type
-
-        dtype_str = fp8_linear_type if fp8_linear_type else "bf16"
-        output_str = str(train_config.log_dir / f"trace_llama_7b_hf_{dtype_str}_rank_{rank}.json")
-        prof.export_chrome_trace(output_str)
-        logging.info(f"Wrote profile to: {output_str}")
-
-    if train_config.profile:
-        context = torch.profiler.profile(
-            activities=[
-                torch.profiler.ProfilerActivity.CPU,
-                torch.profiler.ProfilerActivity.CUDA,
-            ],
-            schedule=torch.profiler.schedule(wait=100, warmup=1, active=2, repeat=1),
-            record_shapes=True,
-            with_stack=True,
-            on_trace_ready=trace_handler,
-            with_flops=True,
-            profile_memory=True,
-            experimental_config=torch.profiler._ExperimentalConfig(
-                enable_cuda_sync_events=True,
-            ),
-        )
-        return context
-    else:
-        return nullcontext()
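
For reference, a minimal, self-contained sketch (not from this PR; the toy `Block` module and sizes are made up) of the two techniques the finetune script now toggles: composable activation checkpointing via `checkpoint(module)` and per-parameter optimizer-in-backward via `_apply_optimizer_in_backward`, which is why the training loop above skips `optimizer.step()`/`zero_grad()` when `optimizer is None`:

```python
import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.distributed.optim import _apply_optimizer_in_backward


class Block(nn.Module):  # hypothetical stand-in for a transformer layer
    def __init__(self):
        super().__init__()
        self.ffn = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))

    def forward(self, x):
        return x + self.ffn(x)


model = nn.Sequential(*[Block() for _ in range(4)])

# Recompute each block's activations during backward instead of storing them.
for block in model:
    checkpoint(block)

# Register per-parameter SGD steps that run as each gradient is produced,
# so no separate optimizer.step() / zero_grad() is needed in the loop.
_apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=[p for p in model.parameters() if p.requires_grad],
    optimizer_kwargs={"lr": 2e-5},
)

x = torch.randn(8, 64)
model(x).sum().backward()  # parameters are updated inside this call
```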
diff --git a/transformer_nuggets/llama/model.py b/transformer_nuggets/llama/model.py
index b54d0aa..7cce0bc 100644
--- a/transformer_nuggets/llama/model.py
+++ b/transformer_nuggets/llama/model.py
@@ -58,6 +58,7 @@ def from_name(cls, name: str):
         "CodeLlama-7b-Python-hf": dict(
             block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000
         ),
+        "mini": dict(n_layer=2, n_head=32, dim=4096),
         "7B": dict(n_layer=32, n_head=32, dim=4096),
         "13B": dict(n_layer=40, n_head=40, dim=5120),
         "30B": dict(n_layer=60, n_head=52, dim=6656),
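
The nf4_tensor.py changes below route aten ops through a registration table that `__torch_dispatch__` consults. A minimal sketch of that dispatch-table pattern, shown on a hypothetical wrapper subclass (toy names only, not the real NF4Tensor):

```python
from typing import Any, Dict
import torch

aten = torch.ops.aten
OPS_TABLE: Dict[Any, Any] = {}


def implements(aten_ops):
    """Map a list of aten overloads onto one handler function."""
    def decorator(func):
        for op in aten_ops:
            OPS_TABLE[op] = func
        return func
    return decorator


class ToyTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        # Wrapper subclass: the outer tensor carries only metadata.
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner: torch.Tensor):
        self.inner = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in OPS_TABLE:
            return OPS_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"ToyTensor does not support {func}")

    __torch_function__ = torch._C._disabled_torch_function_impl


@implements([aten.detach.default])
def toy_detach(aten_op, args, kwargs):
    # Apply the op to the inner tensor and re-wrap to preserve the subclass.
    return ToyTensor(aten_op(args[0].inner, **kwargs))


t = ToyTensor(torch.randn(4, 4))
print(type(t.detach()))  # <class '__main__.ToyTensor'>
```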
diff --git a/transformer_nuggets/quant/nf4_tensor.py b/transformer_nuggets/quant/nf4_tensor.py
index 96018b8..2bfeb80 100644
--- a/transformer_nuggets/quant/nf4_tensor.py
+++ b/transformer_nuggets/quant/nf4_tensor.py
@@ -1,13 +1,323 @@
 import logging
 from dataclasses import dataclass
-from typing import Dict, Tuple
-
+from typing import Dict, Tuple, Any, Optional, Union
+import math
 import torch

 logging.basicConfig(level=logging.INFO)

 bnb_available = False

+aten = torch.ops.aten
+NF4_OPS_TABLE: Dict[Any, Any] = {}
+
+
+def implements(aten_ops):
+    """Register aten ops to the NF4 op table"""
+
+    def decorator(func):
+        for op in aten_ops:
+            NF4_OPS_TABLE[op] = func
+        return func
+
+    return decorator
+
+
+@implements(
+    [
+        aten.detach.default,
+    ]
+)
+def nf4_detach(aten_op, args, kwargs=None):
+    # nn.Parameter needs detach
+    quantized_scalers = aten_op(args[0].quantized_scalers, *args[1:], **kwargs)
+    quantization_factor = aten_op(args[0].quantization_factor, *args[1:], **kwargs)
+    quantized_data = aten_op(args[0].quantized_data, *args[1:], **kwargs)
+    scaler_mean = aten_op(args[0].scaler_mean, *args[1:], **kwargs)
+    nf4 = aten_op(args[0].nf4, *args[1:], **kwargs)
+    tensor_meta = SubclassTensorArgs(
+        args[0].size(),
+        args[0].stride(),
+        args[0].storage_offset(),
+        args[0].dtype,
+        args[0].device,
+        args[0].requires_grad,
+    )
+    return NF4Tensor(
+        tensor_meta,
+        args[0].block_size,
+        args[0].n_blocks,
+        args[0].scaler_block_size,
+        quantized_scalers,
+        quantization_factor,
+        scaler_mean,
+        quantized_data,
+        nf4,
+    )
+
+
+@implements(
+    [
+        aten.split.Tensor,
+    ]
+)
+def nf4_split(aten_op, args, kwargs=None):
+    # torch.chunk
+    # TODO: find if there are other args/kwargs in aten.split
+    assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.split.Tensor with 2 args"
+    # TODO: assert on dim-0 sharding. how to get dim from torch.chunk?
+    num_chunks = args[0].size(0) // args[1]
+
+    # TODO: assert numel % num_chunks == 0
+    quantized_scalers_chunks = aten_op(args[0].quantized_scalers, args[0].quantized_scalers.numel() // num_chunks, **kwargs)
+    quantization_factor_chunks = aten_op(args[0].quantization_factor, args[0].quantization_factor.numel() // num_chunks, **kwargs)
+    quantized_data_chunks = aten_op(args[0].quantized_data, args[0].quantized_data.numel() // num_chunks, **kwargs)
+
+    assert len(args) == 2, "only support 2d because of tensor meta"
+    return [
+        NF4Tensor(
+            SubclassTensorArgs(
+                (args[0].size(0) // num_chunks, args[0].size(1)),
+                args[0].stride(),
+                args[0].storage_offset(),
+                args[0].dtype,
+                args[0].device,
+                args[0].requires_grad,
+            ),
+            args[0].block_size,
+            args[0].n_blocks,
+            args[0].scaler_block_size,
+            quantized_scalers,
+            quantization_factor,
+            args[0].scaler_mean,
+            quantized_data,
+            args[0].nf4,
+        ) for quantized_scalers, quantization_factor, quantized_data in zip(
+            quantized_scalers_chunks, quantization_factor_chunks, quantized_data_chunks
+        )
+    ]
+
+
+@implements(
+    [
+        aten.new_zeros.default,
+    ]
+)
+def nf4_new_zeros(aten_op, args, kwargs=None):
+    assert len(args[0].shape) == 2 and len(args[1]) == 2, "only support new zeros on 2D"
+    assert args[0].numel() % math.prod(args[1]) == 0
+    ratio = args[0].numel() // math.prod(args[1])
+
+    assert args[0].quantized_scalers.size(0) % ratio == 0, f"quantized_scalers.numel() must be divisible by {ratio}"
+    quantized_scalers_new_zeros = aten_op(args[0].quantized_scalers, [args[0].quantized_scalers.size(0) // ratio], **kwargs)
+
+    assert args[0].quantization_factor.size(0) % ratio == 0, f"quantization_factor.size(0) must be divisible by {ratio}"
+    quantization_factor_new_zeros = aten_op(args[0].quantization_factor, [args[0].quantization_factor.size(0) // ratio], **kwargs)
+
+    assert args[0].quantized_data.size(0) % ratio == 0, f"quantized_data.size(0) must be divisible by {ratio}"
+    quantized_data_new_zeros = aten_op(args[0].quantized_data, [args[0].quantized_data.size(0) // ratio], **kwargs)
+
+    return NF4Tensor(
+        SubclassTensorArgs(
+            (args[1][0], args[1][1]),
+            args[0].stride(),
+            args[0].storage_offset(),
+            args[0].dtype,
+            args[0].device,
+            args[0].requires_grad,
+        ),
+        args[0].block_size,
+        args[0].n_blocks,
+        args[0].scaler_block_size,
+        quantized_scalers_new_zeros,
+        quantization_factor_new_zeros,
+        args[0].scaler_mean,
+        quantized_data_new_zeros,
+        args[0].nf4,
+    )
+
+
+@implements(
+    [
+        aten.slice.Tensor,
+    ]
+)
+def nf4_slice(aten_op, args, kwargs=None):
+    assert len(args) == 4
+    assert args[1] == 0, f"only support dim=0 but got dim={args[1]}"
+    # TODO: maybe relax?
+    assert args[2] == 0, f"only support start=0 but got start={args[2]}"
+    assert args[3] == args[0].size(0), f"only support end == size(0) but got end={args[3]} and size(0)={args[0].size(0)}"
+    return NF4Tensor(
+        SubclassTensorArgs(
+            args[0].size(),
+            args[0].stride(),
+            args[0].storage_offset(),
+            args[0].dtype,
+            args[0].device,
+            args[0].requires_grad,
+        ),
+        args[0].block_size,
+        args[0].n_blocks,
+        args[0].scaler_block_size,
+        args[0].quantized_scalers,
+        args[0].quantization_factor,
+        args[0].scaler_mean,
+        args[0].quantized_data,
+        args[0].nf4,
+    )
+
+
+@implements(
+    [
+        aten.copy_.default,
+    ]
+)
+def nf4_copy_(aten_op, args, kwargs=None):
+    assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.copy_.default with 2 args"
+    quantized_scalers = aten_op(args[0].quantized_scalers, args[1].quantized_scalers, **kwargs)
+    quantization_factor = aten_op(args[0].quantization_factor, args[1].quantization_factor, **kwargs)
+    quantized_data = aten_op(args[0].quantized_data, args[1].quantized_data, **kwargs)
+    scaler_mean = aten_op(args[0].scaler_mean, args[1].scaler_mean, **kwargs)
+    nf4 = aten_op(args[0].nf4, args[1].nf4, **kwargs)
+    tensor_meta = SubclassTensorArgs(
+        args[1].size(),
+        args[1].stride(),
+        args[1].storage_offset(),
+        args[1].dtype,
+        args[1].device,
+        args[1].requires_grad,
+    )
+    return NF4Tensor(
+        tensor_meta,
+        args[1].block_size,
+        args[1].n_blocks,
+        args[1].scaler_block_size,
+        quantized_scalers,
+        quantization_factor,
+        scaler_mean,
+        quantized_data,
+        nf4,
+    )
+
+
+@implements(
+    [
+        aten.view.default,
+    ]
+)
+def nf4_view(aten_op, args, kwargs=None):
+    assert len(args) == 2 and args[1] == -1, "only support view(-1)"
+    quantized_scalers = aten_op(args[0].quantized_scalers, *(args[1:]), **kwargs)
+    quantization_factor = aten_op(args[0].quantization_factor, *(args[1:]), **kwargs)
+    quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **kwargs)
+    tensor_meta = SubclassTensorArgs(
+        [args[0].numel()],
+        (1, ),
+        args[0].storage_offset(),
+        args[0].dtype,
+        args[0].device,
+        args[0].requires_grad,
+    )
+    return NF4Tensor(
+        tensor_meta,
+        args[0].block_size,
+        args[0].n_blocks,
+        args[0].scaler_block_size,
+        quantized_scalers,
+        quantization_factor,
+        args[0].scaler_mean,
+        quantized_data,
+        args[0].nf4,
+    )
+
+
+@implements(
+    [
+        aten.as_strided.default,
+    ]
+)
+def nf4_as_strided(aten_op, args, kwargs=None):
+    assert len(args[1]) == 2 and math.prod(args[1]) == args[0].numel(), "only support same numel"
+    assert args[2] == [args[1][1], 1], f"only support stride {[args[1][1], 1]}"
+    assert args[0].storage_offset() == args[3], "only support same storage offset"
+    return NF4Tensor(
+        SubclassTensorArgs(
+            torch.Size(args[1]),
+            tuple(args[2]),
+            args[0].storage_offset(),
+            args[0].dtype,
+            args[0].device,
+            args[0].requires_grad,
+        ),
+        args[0].block_size,
+        args[0].n_blocks,
+        args[0].scaler_block_size,
+        args[0].quantized_scalers,
+        args[0].quantization_factor,
+        args[0].scaler_mean,
+        args[0].quantized_data,
+        args[0].nf4,
+    )
+
+
+@implements(
+    [
+        aten._to_copy.default,
+    ]
+)
+def nf4_to_copy(aten_op, args, kwargs=None):
+    quantized_data_kwargs = kwargs
+    quantized_data_kwargs['dtype'] = args[0].quantized_data.dtype
+    quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **quantized_data_kwargs)
+
+    return NF4Tensor(
+        SubclassTensorArgs(
+            args[0].size(),
+            args[0].stride(),
+            args[0].storage_offset(),
+            args[0].dtype,
+            kwargs['device'],
+            args[0].requires_grad,
+        ),
+        args[0].block_size,
+        args[0].n_blocks,
+        args[0].scaler_block_size,
+        args[0].quantized_scalers,
+        args[0].quantization_factor,
+        args[0].scaler_mean,
+        quantized_data,
+        args[0].nf4,
+    )
+
+
+@implements(
+    [
+        aten.is_pinned.default,
+    ]
+)
+def nf4_is_pinned(aten_op, args, kwargs=None):
+    return aten_op(args[0].quantized_data, *(args[1:]), **kwargs)
+
+
+@implements(
+    [
+        aten._pin_memory.default,
+    ]
+)
+def nf4_pin_memory(aten_op, args, kwargs=None):
+    quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **kwargs)
+
+    return NF4Tensor(
+        SubclassTensorArgs(
+            args[0].size(),
+            args[0].stride(),
+            args[0].storage_offset(),
+            args[0].dtype,
+            args[0].device,
+            args[0].requires_grad,
+        ),
+        args[0].block_size,
+        args[0].n_blocks,
+        args[0].scaler_block_size,
+        args[0].quantized_scalers,
+        args[0].quantization_factor,
+        args[0].scaler_mean,
+        quantized_data,
+        args[0].nf4,
+    )
+

 @dataclass
 class SubclassTensorArgs:
@@ -393,7 +703,37 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
-        raise NotImplementedError("NF4Tensor does not support torch dispatch")
+        def allowed_subclasses(type):
+            return (
+                issubclass(cls, type)
+                or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
+                or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type)
+            )
+
+        if not all(allowed_subclasses(t) for t in types):
+            return NotImplemented  # let the next subclass in line handle it
+
+        if func in NF4_OPS_TABLE:
+            return NF4_OPS_TABLE[func](func, args, kwargs)
+
+        raise NotImplementedError(f"NF4Tensor does not support torch dispatch {func}")

-    # Do not force the Float8Tensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl
+
+    # @classmethod
+    # def __torch_function__(cls, func, types, args=(), kwargs=None):
+    #     # Define a standard `__torch_function__` that propagates state
+    #     kwargs = kwargs or {}
+
+    #     def wrap(tensor_meta, block_size, n_blocks, scaler_block_size, quantized_scalers, quantization_factor, scaler_mean, quantized_data, nf4, o: Any):
+    #         if isinstance(o, torch.Tensor) and not isinstance(o, cls):
+    #             return cls(tensor_meta, block_size, n_blocks, scaler_block_size, quantized_scalers, quantization_factor, scaler_mean, quantized_data, nf4)
+    #         return o
+
+    #     with torch._C.DisableTorchFunctionSubclass():
+    #         if isinstance(args[0], cls):
+    #             out = func(*args, **kwargs)
+    #             return tree_map(
+    #                 functools.partial(wrap, args[0].tensor_meta, args[0].block_size, args[0].n_blocks, args[0].scaler_block_size, args[0].quantized_scalers, args[0].quantization_factor, args[0].scaler_mean, args[0].quantized_data, args[0].nf4), out
+    #             )
+    #         return func(*args, **kwargs)
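
Aside on the sharding math in `nf4_split` / `nf4_new_zeros` above: chunking the logical 2D weight along dim 0 corresponds to chunking each flat NF4 buffer by the same ratio, provided every buffer is evenly divisible. A small illustrative check with made-up sizes (these block sizes are hypothetical, not the real NF4 defaults):

```python
import torch

out_features, in_features = 8, 16        # logical weight shape
block_size, scaler_block_size = 4, 8     # hypothetical NF4 block sizes
numel = out_features * in_features

quantized_data = torch.zeros(numel // 2, dtype=torch.uint8)     # two 4-bit codes per byte
quantized_scalers = torch.zeros(numel // block_size, dtype=torch.int8)
quantization_factor = torch.zeros(numel // block_size // scaler_block_size)

num_chunks = 2  # e.g. world_size shards along dim 0
for buf in (quantized_data, quantized_scalers, quantization_factor):
    # Each flat buffer splits evenly, so every shard gets a self-contained run
    # of quantization blocks matching a contiguous slice of weight rows.
    chunks = torch.chunk(buf, num_chunks)
    assert all(c.numel() == buf.numel() // num_chunks for c in chunks)
```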
diff --git a/transformer_nuggets/quant/qlora.py b/transformer_nuggets/quant/qlora.py
index 1d5f1f1..4af8503 100644
--- a/transformer_nuggets/quant/qlora.py
+++ b/transformer_nuggets/quant/qlora.py
@@ -1,14 +1,14 @@
 import logging
 import math
 from dataclasses import dataclass
-from typing import Tuple
+from typing import Tuple, Dict, Any, Optional, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from tqdm import tqdm

 from transformer_nuggets.quant.deqaunt_kernel import dequant_nf4_tensor
-from transformer_nuggets.quant.nf4_tensor import NF4Tensor
+from transformer_nuggets.quant.nf4_tensor import NF4Tensor, SubclassTensorArgs

 logging.basicConfig(level=logging.INFO)
@@ -159,9 +159,13 @@ def __init__(
         r: int,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        register_nf4_param: bool = False,
     ) -> None:
         super().__init__()
-        self.weight = NF4Tensor.from_tensor(weight)
+        if register_nf4_param:
+            self.weight = nn.Parameter(NF4Tensor.from_tensor(weight), requires_grad=False)
+        else:
+            self.weight = NF4Tensor.from_tensor(weight)
         self.r = r
         self.lora_alpha = lora_alpha
         self.in_features = in_features
@@ -191,12 +195,77 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
         return result2

+    def fsdp_extensions(self) -> Dict[str, Any]:
+        from torch.distributed._composable.fsdp import FSDPTensorExtensions
+
+        weight_extensions = FSDPTensorExtensions(
+            self._fsdp_pre_all_gather, self._fsdp_post_all_gather
+        )
+        return {"weight": weight_extensions}
+
+    def _fsdp_pre_all_gather(self, sharded_param: torch.Tensor):
+        return (
+            sharded_param.quantized_scalers,
+            sharded_param.quantization_factor,
+            sharded_param.quantized_data,
+        ), (
+            SubclassTensorArgs(
+                sharded_param.size(),
+                sharded_param.stride(),
+                sharded_param.storage_offset(),
+                sharded_param.dtype,
+                sharded_param.device,
+                sharded_param.requires_grad,
+            ),
+            sharded_param.block_size,
+            sharded_param.n_blocks,
+            sharded_param.scaler_block_size,
+            sharded_param.scaler_mean,
+            sharded_param.nf4,
+        )
+
+    def _fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple[NF4Tensor, Tuple[torch.Tensor, ...]], None]:
+        (quantized_scalers, quantization_factor, quantized_data) = all_gather_outputs
+        (tensor_meta, block_size, n_blocks, scaler_block_size, scaler_mean, nf4) = metadata
+        tensor_meta.original_shape = torch.Size([quantized_data.size(0) * 2])
+        if out is not None:
+            assert isinstance(out, NF4Tensor), f"{type(out)}"
+            assert (
+                quantized_scalers.untyped_storage().data_ptr()
+                == out.quantized_scalers.untyped_storage().data_ptr() and
+                quantization_factor.untyped_storage().data_ptr()
+                == out.quantization_factor.untyped_storage().data_ptr() and
+                quantized_data.untyped_storage().data_ptr()
+                == out.quantized_data.untyped_storage().data_ptr()
+            ), "Expects out's data to be the all-gather output"
+            return
+
+        return NF4Tensor(
+            tensor_meta,
+            block_size,
+            n_blocks,
+            scaler_block_size,
+            quantized_scalers,
+            quantization_factor,
+            scaler_mean,
+            quantized_data,
+            nf4,
+        ), (quantized_scalers, quantization_factor, quantized_data)
+

 @dataclass
 class QloraConfig:
     lora_r: int = 2
     lora_alpha: int = 1
     lora_dropout: float = 0.0
+    register_nf4_param: bool = False


 class QloraMLP(nn.Module):
@@ -215,15 +284,16 @@ def __init__(
         lora_r = QloraConfig.lora_r
         lora_alpha = QloraConfig.lora_alpha
         lora_dropout = QloraConfig.lora_dropout
+        register_nf4_param = QloraConfig.register_nf4_param
         self.qlora_w1 = QloraLinear(
-            weight1.shape[1], weight1.shape[0], weight1, lora_r, lora_alpha, lora_dropout
+            weight1.shape[1], weight1.shape[0], weight1, lora_r, lora_alpha, lora_dropout, register_nf4_param
         )
         self.qlora_w2 = QloraLinear(
-            weight2.shape[1], weight2.shape[0], weight2, lora_r, lora_alpha, lora_dropout
+            weight2.shape[1], weight2.shape[0], weight2, lora_r, lora_alpha, lora_dropout, register_nf4_param
         )
         self.qlora_w3 = QloraLinear(
-            weight3.shape[1], weight3.shape[0], weight3, lora_r, lora_alpha, lora_dropout
+            weight3.shape[1], weight3.shape[0], weight3, lora_r, lora_alpha, lora_dropout, register_nf4_param
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
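
The `_fsdp_pre_all_gather` / `_fsdp_post_all_gather` pair above follows FSDP2's extension contract: hand FSDP only the flat inner tensors plus non-tensor metadata, then rebuild the subclass after the all-gather. A toy simulation of that contract (plain tensors and a NamedTuple stand-in, not FSDP's real machinery):

```python
from typing import Any, NamedTuple, Tuple
import torch


class ToyQuantized(NamedTuple):  # stand-in for a quantized-weight container
    data: torch.Tensor
    scalers: torch.Tensor
    block_size: int


def pre_all_gather(p: ToyQuantized) -> Tuple[Tuple[torch.Tensor, ...], Any]:
    # Expose only the dense inner buffers; keep non-tensor state as metadata.
    return (p.data, p.scalers), (p.block_size,)


def post_all_gather(outputs: Tuple[torch.Tensor, ...], metadata: Any) -> ToyQuantized:
    data, scalers = outputs
    (block_size,) = metadata
    return ToyQuantized(data, scalers, block_size)


# Simulate two ranks, each holding half of the flat buffers.
shards = [ToyQuantized(torch.full((4,), r, dtype=torch.uint8), torch.ones(2) * r, 4) for r in range(2)]
gathered = tuple(torch.cat(ts) for ts in zip(*(pre_all_gather(s)[0] for s in shards)))
full = post_all_gather(gathered, pre_all_gather(shards[0])[1])
assert full.data.numel() == 8 and full.scalers.numel() == 4
```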
@@ -233,13 +303,103 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 def swap_for_qlora(model: torch.nn.Module, qlora_config: QloraConfig, dtype) -> None:
+    # logging.info("Swapping for Qlora...")
+    for module in model.layers:
+        feed_forward = module.feed_forward
+        w1 = feed_forward.w1.weight.to(dtype=dtype)
+        w2 = feed_forward.w2.weight.to(dtype=dtype)
+        w3 = feed_forward.w3.weight.to(dtype=dtype)
+        new_mod = QloraMLP(w1, w2, w3, qlora_config)
+        module.feed_forward = new_mod
+
+    for name, param in model.named_parameters():
+        if "lora_" not in name:
+            param.requires_grad = False
+
+
+class QloraLinearDebug(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        weight: torch.Tensor,
+        r: int,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(weight.new_zeros((weight.shape[0], int(weight.shape[1]/4))), requires_grad=False)
+        # self.weight = weight.new_zeros((weight.shape[0], int(weight.shape[1]/4)))
+        # self.weight.requires_grad = False
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        self.r = r
+        self.lora_alpha = lora_alpha
+        self.in_features = in_features
+        self.out_features = out_features
+        self.lora_A = nn.Parameter(weight.new_zeros((r, in_features)))
+        self.lora_B = nn.Parameter(weight.new_zeros((out_features, r)))
+        self.scaling = self.lora_alpha / self.r
+
+        # Optional dropout
+        if lora_dropout > 0.0:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+        nn.init.zeros_(self.lora_B)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        result = F.linear(x, self.weight.repeat(1, 4))
+        result2 = (
+            result
+            + (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1))
+            * self.scaling
+        )
+        return result2
+
+
+class QloraMLPDebug(nn.Module):
+    # This very notably doesn't save on backward compute
+    def __init__(
+        self,
+        weight1: torch.Tensor,
+        weight2: torch.Tensor,
+        weight3: torch.Tensor,
+        QloraConfig: QloraConfig = None,
+    ) -> None:
+        super().__init__()
+        if QloraConfig is None:
+            QloraConfig = QloraConfig()
+
+        lora_r = QloraConfig.lora_r
+        lora_alpha = QloraConfig.lora_alpha
+        lora_dropout = QloraConfig.lora_dropout
+
+        self.qlora_w1 = QloraLinearDebug(
+            weight1.shape[1], weight1.shape[0], weight1, lora_r, lora_alpha, lora_dropout
+        )
+        self.qlora_w2 = QloraLinearDebug(
+            weight2.shape[1], weight2.shape[0], weight2, lora_r, lora_alpha, lora_dropout
+        )
+        self.qlora_w3 = QloraLinearDebug(
+            weight3.shape[1], weight3.shape[0], weight3, lora_r, lora_alpha, lora_dropout
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.silu(self.qlora_w1(x)) * self.qlora_w3(x)
+        x = self.qlora_w2(x)
+        return x
+
+
+def swap_for_qlora_debug(model: torch.nn.Module, qlora_config: QloraConfig, dtype) -> None:
     logging.info("Swapping for Qlora...")
     for module in tqdm(model.layers):
         feed_forward = module.feed_forward
         w1 = feed_forward.w1.weight.to(dtype=dtype)
         w2 = feed_forward.w2.weight.to(dtype=dtype)
         w3 = feed_forward.w3.weight.to(dtype=dtype)
-        new_mod = QloraMLP(w1, w2, w3, qlora_config)
+        new_mod = QloraMLPDebug(w1, w2, w3, qlora_config)
         module.feed_forward = new_mod

     for name, param in model.named_parameters():
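
For completeness, the forward math that `QloraLinear` / `QloraLinearDebug` implement is a frozen base matmul plus a low-rank update scaled by `lora_alpha / r`. A quick illustrative check with arbitrary sizes (not the real model dimensions):

```python
import torch

torch.manual_seed(0)
bsz, in_f, out_f, r, alpha = 2, 8, 4, 2, 1
x = torch.randn(bsz, in_f)
W = torch.randn(out_f, in_f)          # frozen base weight (NF4-quantized in the real code)
A = torch.randn(r, in_f) * 0.01       # trainable LoRA A
B = torch.zeros(out_f, r)             # trainable LoRA B, zero-init => no-op at start

out = x @ W.t() + (x @ A.t() @ B.t()) * (alpha / r)
assert torch.allclose(out, x @ W.t())  # holds because B starts at zero
```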