From a60f6e8e0f55024620add8ffa644cbafd9eb4ea3 Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 22 Feb 2024 14:45:23 -0800 Subject: [PATCH 1/3] full finetune / qlora + ac/offload/optm in bwd Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- transformer_nuggets/llama/finetune.py | 149 +++++++++++++++++--------- transformer_nuggets/llama/model.py | 1 + transformer_nuggets/quant/qlora.py | 90 ++++++++++++++++ 3 files changed, 187 insertions(+), 53 deletions(-) diff --git a/transformer_nuggets/llama/finetune.py b/transformer_nuggets/llama/finetune.py index 8c52d53..693f0f5 100644 --- a/transformer_nuggets/llama/finetune.py +++ b/transformer_nuggets/llama/finetune.py @@ -1,5 +1,9 @@ """ -Used to train a model from scratch on big dense blocks of text data using causal attention. +# full finetuning on single gpu without FSDP +python transformer_nuggets/llama/finetune.py --profile --model 7B --enable_ac --full_finetune --optim_in_bwd + +# qlora on 2 gpus with FSDP +python transformer_nuggets/llama/finetune.py --profile --model 7B --fsdp_num_gpus 2 --use_fsdp2 --enable_ac --cpu_offload """ import argparse import functools @@ -9,6 +13,8 @@ from contextlib import nullcontext from dataclasses import dataclass from pathlib import Path +from torch.distributed._composable import checkpoint +from torch.distributed.optim import _apply_optimizer_in_backward import numpy as np import torch @@ -44,6 +50,45 @@ class TrainingConfig(transformer_nuggets.llama.train.TrainingConfig): log_interval: int = 10 track_max_memory: bool = False use_fsdp2: bool = False + enable_ac: bool = False + qlora_debug: bool = False + full_finetune: bool = False + optim_in_bwd: bool = False + cpu_offload: bool = False + + +def get_profile_context( + hyper_params: Hyperparameters, train_config: TrainingConfig, rank: int = 0 +): + """Returns a context manager that can be used to profile the model.""" + + def trace_handler(prof): + fp8_linear_type = hyper_params.fp8_linear_type + + dtype_str = fp8_linear_type if fp8_linear_type else "bf16" + output_str = str(train_config.log_dir / f"trace_llama_7b_hf_{dtype_str}_rank_{rank}.json") + prof.export_chrome_trace(output_str) + logging.info(f"Wrote profile to: {output_str}") + + if train_config.profile: + context = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=100, warmup=1, active=2, repeat=1), + record_shapes=True, + with_stack=True, + on_trace_ready=trace_handler, + with_flops=True, + profile_memory=True, + experimental_config=torch.profiler._ExperimentalConfig( + enable_cuda_sync_events=True, + ), + ) + return context + else: + return nullcontext() def main( @@ -67,7 +112,8 @@ def main( if rank == 0: logging.info(f"Initializing model: {training_config.model_name}") with training_config.device: - model = Transformer(model_args).to(torch.bfloat16) + torch.set_default_dtype(torch.bfloat16) + model = Transformer(model_args) model.init_parameters() qlora_config = qlora.QloraConfig( @@ -75,7 +121,11 @@ def main( hyper_params.lora_alpha, hyper_params.lora_dropout, ) - qlora.swap_for_qlora(model, qlora_config, torch.bfloat16) + + if training_config.qlora_debug: + qlora.swap_for_qlora_debug(model, qlora_config, torch.bfloat16) + elif not training_config.full_finetune: + qlora.swap_for_qlora(model, qlora_config, torch.bfloat16) model.setup_caches( hyper_params.micro_batch_size, hyper_params.max_seq_length, training_config.device ) @@ -92,18 +142,23 @@ def main( 
log_num_params(model) + if training_config.enable_ac: + for layer in model.layers: + checkpoint(layer) + if world_size > 1: if training_config.use_fsdp2: # move import to top when fsdp2 is landed - from torch.distributed._composable.fsdp import fully_shard + from torch.distributed._composable.fsdp import fully_shard, OffloadPolicy fully_shard_fn = functools.partial( fully_shard, reshard_after_forward=True, ) + offload_policy = OffloadPolicy("cpu" if training_config.cpu_offload else None) for layer in model.layers: - fully_shard_fn(layer) - fully_shard_fn(model) + fully_shard_fn(layer, offload_policy=offload_policy) + fully_shard_fn(model, offload_policy=offload_policy) else: model = FSDP( model, @@ -111,19 +166,25 @@ def main( auto_wrap_policy=ModuleWrapPolicy([TransformerBlock]), ) + torch.cuda.reset_peak_memory_stats() + if training_config.compile: model = torch.compile(model) if rank == 0: logging.info(model) - optimizer = torch.optim.AdamW( - [p for p in model.parameters() if p.requires_grad], - lr=hyper_params.learning_rate, - weight_decay=hyper_params.weight_decay, - betas=(hyper_params.beta1, hyper_params.beta2), - foreach=hyper_params.foreach_optimizer, - ) + if training_config.optim_in_bwd: + optimizer = None + _apply_optimizer_in_backward(optimizer_class=torch.optim.SGD, params=[p for p in model.parameters() if p.requires_grad], optimizer_kwargs={"lr": 2e-5}) + else: + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], + lr=hyper_params.learning_rate, + weight_decay=hyper_params.weight_decay, + betas=(hyper_params.beta1, hyper_params.beta2), + foreach=hyper_params.foreach_optimizer, + ) train( model, @@ -141,6 +202,11 @@ def entrypoint( profile: bool = False, use_fsdp2: bool = False, model_name: str = "7B", + enable_ac: bool = False, + qlora_debug: bool = False, + full_finetune: bool = False, + optim_in_bwd: bool = False, + cpu_offload: bool = False, rank: int = 0, world_size: int = 1, ): @@ -154,6 +220,11 @@ def entrypoint( device=torch.device(f"cuda:{rank}"), use_fsdp2=use_fsdp2, model_name=model_name, + enable_ac=enable_ac, + qlora_debug=qlora_debug, + full_finetune=full_finetune, + optim_in_bwd=optim_in_bwd, + cpu_offload=cpu_offload, ) main(hyper_params, training_config, rank, world_size) @@ -204,8 +275,9 @@ def train( with profile_context as p: for iter_num in range(hyper_params.max_iters): lr = get_lr(iter_num, hyper_params) - for param_group in optimizer.param_groups: - param_group["lr"] = lr + if optimizer is not None: + for param_group in optimizer.param_groups: + param_group["lr"] = lr input_ids, targets = next(train_iter) input_ids = input_ids.pin_memory().to(training_config.device) @@ -229,7 +301,7 @@ def train( # Scale the loss by grad_accumulation iters (loss / hyper_params.gradient_accumulation_iters).backward() - if not is_accumulating: + if not is_accumulating and optimizer is not None: optimizer.step() optimizer.zero_grad() step_count += 1 @@ -269,8 +341,8 @@ def train( if rank == 0: mem_stats = torch.cuda.memory_stats() - peak_active = mem_stats["active_bytes.all.peak"] / (1024 * 1024) - peak_reserved = mem_stats["reserved_bytes.all.peak"] / (1024 * 1024) + peak_active = mem_stats["active_bytes.all.peak"] / (1000 * 1000) + peak_reserved = mem_stats["reserved_bytes.all.peak"] / (1000 * 1000) logging.info( f"iter={iter_num} max_iters={hyper_params.max_iters} loss={loss_val:.4f} Peak Active Memory: {peak_active} MB, Peak Reserve Memory: {peak_reserved} MB" ) @@ -366,10 +438,15 @@ def load_datasets( help="if specified, runs FSDP 
with this many GPUs on a single host", ) parser.add_argument("--use_fsdp2", action=argparse.BooleanOptionalAction, default=False) + parser.add_argument("--enable_ac", action=argparse.BooleanOptionalAction, default=False) parser.add_argument("--model", default="7B") + parser.add_argument("--qlora_debug", action=argparse.BooleanOptionalAction, default=False) + parser.add_argument("--full_finetune", action=argparse.BooleanOptionalAction, default=False) + parser.add_argument("--optim_in_bwd", action=argparse.BooleanOptionalAction, default=False) + parser.add_argument("--cpu_offload", action=argparse.BooleanOptionalAction, default=False) args = parser.parse_args() fsdp_num_gpus = args.fsdp_num_gpus - inner_args = (args.profile, args.use_fsdp2, args.model) + inner_args = (args.profile, args.use_fsdp2, args.model, args.enable_ac, args.qlora_debug, args.full_finetune, args.optim_in_bwd, args.cpu_offload) if fsdp_num_gpus is None or fsdp_num_gpus == 1: entrypoint(*inner_args) @@ -413,37 +490,3 @@ def validate( model.train() write_loss_to_file(loss_file, training_iter, loss.item()) return val_loss.item() - - -def get_profile_context( - hyper_params: Hyperparameters, train_config: TrainingConfig, rank: int = 0 -): - """Returns a context manager that can be used to profile the model.""" - - def trace_handler(prof): - fp8_linear_type = hyper_params.fp8_linear_type - - dtype_str = fp8_linear_type if fp8_linear_type else "bf16" - output_str = str(train_config.log_dir / f"trace_llama_7b_hf_{dtype_str}_rank_{rank}.json") - prof.export_chrome_trace(output_str) - logging.info(f"Wrote profile to: {output_str}") - - if train_config.profile: - context = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - schedule=torch.profiler.schedule(wait=100, warmup=1, active=2, repeat=1), - record_shapes=True, - with_stack=True, - on_trace_ready=trace_handler, - with_flops=True, - profile_memory=True, - experimental_config=torch.profiler._ExperimentalConfig( - enable_cuda_sync_events=True, - ), - ) - return context - else: - return nullcontext() diff --git a/transformer_nuggets/llama/model.py b/transformer_nuggets/llama/model.py index b54d0aa..7cce0bc 100644 --- a/transformer_nuggets/llama/model.py +++ b/transformer_nuggets/llama/model.py @@ -58,6 +58,7 @@ def from_name(cls, name: str): "CodeLlama-7b-Python-hf": dict( block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000 ), + "mini": dict(n_layer=2, n_head=32, dim=4096), "7B": dict(n_layer=32, n_head=32, dim=4096), "13B": dict(n_layer=40, n_head=40, dim=5120), "30B": dict(n_layer=60, n_head=52, dim=6656), diff --git a/transformer_nuggets/quant/qlora.py b/transformer_nuggets/quant/qlora.py index 1d5f1f1..a1c407c 100644 --- a/transformer_nuggets/quant/qlora.py +++ b/transformer_nuggets/quant/qlora.py @@ -245,3 +245,93 @@ def swap_for_qlora(model: torch.nn.Module, qlora_config: QloraConfig, dtype) -> for name, param in model.named_parameters(): if "lora_" not in name: param.requires_grad = False + +class QloraLinearDebug(nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + weight: torch.Tensor, + r: int, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + ) -> None: + super().__init__() + self.weight = nn.Parameter(weight.new_zeros((weight.shape[0], int(weight.shape[1]/4))), requires_grad=False) + # self.weight = weight.new_zeros((weight.shape[0], int(weight.shape[1]/4))) + # self.weight.requires_grad = False + nn.init.kaiming_uniform_(self.weight, 
a=math.sqrt(5)) + self.r = r + self.lora_alpha = lora_alpha + self.in_features = in_features + self.out_features = out_features + self.lora_A = nn.Parameter(weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + + # Optional dropout + if lora_dropout > 0.0: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + result = F.linear(x, self.weight.repeat(1, 4)) + result2 = ( + result + + (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) + * self.scaling + ) + return result2 + +class QloraMLPDebug(nn.Module): + # This very notably doesn't save on backward compute + def __init__( + self, + weight1: torch.Tensor, + weight2: torch.Tensor, + weight3: torch.Tensor, + QloraConfig: QloraConfig = None, + ) -> None: + super().__init__() + if QloraConfig is None: + QloraConfig = QloraConfig() + + lora_r = QloraConfig.lora_r + lora_alpha = QloraConfig.lora_alpha + lora_dropout = QloraConfig.lora_dropout + + self.qlora_w1 = QloraLinearDebug( + weight1.shape[1], weight1.shape[0], weight1, lora_r, lora_alpha, lora_dropout + ) + self.qlora_w2 = QloraLinearDebug( + weight2.shape[1], weight2.shape[0], weight2, lora_r, lora_alpha, lora_dropout + ) + self.qlora_w3 = QloraLinearDebug( + weight3.shape[1], weight3.shape[0], weight3, lora_r, lora_alpha, lora_dropout + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.qlora_w1(x)) * self.qlora_w3(x) + x = self.qlora_w2(x) + return x + +def swap_for_qlora_debug(model: torch.nn.Module, qlora_config: QloraConfig, dtype) -> None: + logging.info("Swapping for Qlora...") + for module in tqdm(model.layers): + feed_forward = module.feed_forward + w1 = feed_forward.w1.weight.to(dtype=dtype) + w2 = feed_forward.w2.weight.to(dtype=dtype) + w3 = feed_forward.w3.weight.to(dtype=dtype) + new_mod = QloraMLPDebug(w1, w2, w3, qlora_config) + module.feed_forward = new_mod + + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False From 0576139d763946200c4dfc40e82a1fc0bc11d1df Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 7 Mar 2024 12:47:42 -0800 Subject: [PATCH 2/3] sharding and cpu offloadin nf4tensor Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- transformer_nuggets/llama/finetune.py | 44 ++-- transformer_nuggets/quant/nf4_tensor.py | 328 +++++++++++++++++++++++- transformer_nuggets/quant/qlora.py | 100 +++++++- 3 files changed, 443 insertions(+), 29 deletions(-) diff --git a/transformer_nuggets/llama/finetune.py b/transformer_nuggets/llama/finetune.py index 693f0f5..fe3b67d 100644 --- a/transformer_nuggets/llama/finetune.py +++ b/transformer_nuggets/llama/finetune.py @@ -47,7 +47,7 @@ class Hyperparameters(transformer_nuggets.llama.train.Hyperparameters): @dataclass class TrainingConfig(transformer_nuggets.llama.train.TrainingConfig): - log_interval: int = 10 + log_interval: int = 1 track_max_memory: bool = False use_fsdp2: bool = False enable_ac: bool = False @@ -55,6 +55,7 @@ class TrainingConfig(transformer_nuggets.llama.train.TrainingConfig): full_finetune: bool = False optim_in_bwd: bool = False cpu_offload: bool = False + register_nf4_param: bool = False def get_profile_context( @@ -109,8 +110,8 @@ def main( # Setup Model 
model_args = ModelArgs.from_name(training_config.model_name) - if rank == 0: - logging.info(f"Initializing model: {training_config.model_name}") + # if rank == 0: + # logging.info(f"Initializing model: {training_config.model_name}") with training_config.device: torch.set_default_dtype(torch.bfloat16) model = Transformer(model_args) @@ -120,6 +121,7 @@ def main( hyper_params.lora_r, hyper_params.lora_alpha, hyper_params.lora_dropout, + training_config.register_nf4_param, ) if training_config.qlora_debug: @@ -130,8 +132,8 @@ def main( hyper_params.micro_batch_size, hyper_params.max_seq_length, training_config.device ) - if rank == 0: - logging.info("Setting up the dataloaders") + # if rank == 0: + # logging.info("Setting up the dataloaders") train_data, val_data = load_datasets(hyper_params, training_config, rank, world_size) train_dataloader = DataLoader( train_data, @@ -140,7 +142,7 @@ def main( ) val_dataloader = DataLoader(val_data, batch_size=hyper_params.micro_batch_size, num_workers=2) - log_num_params(model) + # log_num_params(model) if training_config.enable_ac: for layer in model.layers: @@ -149,16 +151,19 @@ def main( if world_size > 1: if training_config.use_fsdp2: # move import to top when fsdp2 is landed - from torch.distributed._composable.fsdp import fully_shard, OffloadPolicy + from torch.distributed._composable.fsdp import fully_shard fully_shard_fn = functools.partial( fully_shard, reshard_after_forward=True, ) - offload_policy = OffloadPolicy("cpu" if training_config.cpu_offload else None) + fsdp_kwargs = {} + if training_config.cpu_offload: + from torch.distributed._composable.fsdp import OffloadPolicy + fsdp_kwargs["offload_policy"] = OffloadPolicy("cpu") for layer in model.layers: - fully_shard_fn(layer, offload_policy=offload_policy) - fully_shard_fn(model, offload_policy=offload_policy) + fully_shard_fn(layer, **fsdp_kwargs) + fully_shard_fn(model, **fsdp_kwargs) else: model = FSDP( model, @@ -171,8 +176,8 @@ def main( if training_config.compile: model = torch.compile(model) - if rank == 0: - logging.info(model) + # if rank == 0: + # logging.info(model) if training_config.optim_in_bwd: optimizer = None @@ -207,6 +212,7 @@ def entrypoint( full_finetune: bool = False, optim_in_bwd: bool = False, cpu_offload: bool = False, + register_nf4_param: bool = False, rank: int = 0, world_size: int = 1, ): @@ -225,6 +231,7 @@ def entrypoint( full_finetune=full_finetune, optim_in_bwd=optim_in_bwd, cpu_offload=cpu_offload, + register_nf4_param=register_nf4_param, ) main(hyper_params, training_config, rank, world_size) @@ -264,9 +271,9 @@ def train( training_config.log_dir / f"qlora_train_loss_{dtype_str}_overfit_{training_config.overfit}_compile_{training_config.compile}_{rank}.csv" ) - if rank == 0: - logging.info(f"val_loss_file: {val_loss_file}") - logging.info(f"train_loss_file: {train_loss_file}") + # if rank == 0: + # logging.info(f"val_loss_file: {val_loss_file}") + # logging.info(f"train_loss_file: {train_loss_file}") this_batch_loss = torch.tensor(0.0, device=training_config.device) this_batch_n = 0 @@ -344,7 +351,7 @@ def train( peak_active = mem_stats["active_bytes.all.peak"] / (1000 * 1000) peak_reserved = mem_stats["reserved_bytes.all.peak"] / (1000 * 1000) logging.info( - f"iter={iter_num} max_iters={hyper_params.max_iters} loss={loss_val:.4f} Peak Active Memory: {peak_active} MB, Peak Reserve Memory: {peak_reserved} MB" + f"iter={iter_num} max_iters={hyper_params.max_iters} step_count={step_count} loss={loss_val:.4f} Peak Active Memory: {peak_active} MB, Peak Reserve 
Memory: {peak_reserved} MB" ) if training_config.profile: @@ -360,6 +367,8 @@ def train( torch.cuda.memory._dump_snapshot(f"{memory_trace_path}") torch.cuda.memory._record_memory_history(enabled=None) logging.info(f"Wrote memory traces to: {memory_trace_path}") + if iter_num == 260: + break class Dataset(IterableDataset): @@ -444,9 +453,10 @@ def load_datasets( parser.add_argument("--full_finetune", action=argparse.BooleanOptionalAction, default=False) parser.add_argument("--optim_in_bwd", action=argparse.BooleanOptionalAction, default=False) parser.add_argument("--cpu_offload", action=argparse.BooleanOptionalAction, default=False) + parser.add_argument("--register_nf4_param", action=argparse.BooleanOptionalAction, default=False) args = parser.parse_args() fsdp_num_gpus = args.fsdp_num_gpus - inner_args = (args.profile, args.use_fsdp2, args.model, args.enable_ac, args.qlora_debug, args.full_finetune, args.optim_in_bwd, args.cpu_offload) + inner_args = (args.profile, args.use_fsdp2, args.model, args.enable_ac, args.qlora_debug, args.full_finetune, args.optim_in_bwd, args.cpu_offload, args.register_nf4_param) if fsdp_num_gpus is None or fsdp_num_gpus == 1: entrypoint(*inner_args) diff --git a/transformer_nuggets/quant/nf4_tensor.py b/transformer_nuggets/quant/nf4_tensor.py index 96018b8..81268d5 100644 --- a/transformer_nuggets/quant/nf4_tensor.py +++ b/transformer_nuggets/quant/nf4_tensor.py @@ -1,13 +1,303 @@ import logging from dataclasses import dataclass -from typing import Dict, Tuple - +from typing import Dict, Tuple, Any, Optional, Union +import math import torch logging.basicConfig(level=logging.INFO) bnb_available = False +aten = torch.ops.aten +NF4_OPS_TABLE: Dict[Any, Any] = {} + +def implements(aten_ops): + """Register aten ops to the float8 op table""" + + def decorator(func): + for op in aten_ops: + NF4_OPS_TABLE[op] = func + return func + + return decorator + +@implements( + [ + aten.detach.default, + ] +) +def nf4_detach(aten_op, args, kwargs=None): + # nn.Parameter need detach + quantized_data = aten_op(args[0].quantized_data, *args[1:], **kwargs) + tensor_meta = SubclassTensorArgs( + args[0].size(), + args[0].stride(), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ) + return NF4Tensor( + tensor_meta, + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + args[0].quantized_scalers, + args[0].quantization_factor, + args[0].scaler_mean, + quantized_data, + args[0].nf4, + ) + +@implements( + [ + aten.split.Tensor, + ] +) +def nf4_split(aten_op, args, kwargs=None): + # torch.chunk + # TODO: find a better way to derive 2048 + assert args[0].dim() == 2 + assert args[0].quantized_data.numel() * 2 == args[0].numel() + # TODO: assume dim-0 sharding + num_chunks = int(args[0].size(0) / args[1]) + split_size = int(args[0].quantized_data.size(0) / num_chunks) + assert len(args) == 2 + + # TODO figure out n-D tensor + # figure / 2 float + quantized_data_chunks = aten_op(args[0].quantized_data, split_size, **kwargs) + return [ + NF4Tensor( + SubclassTensorArgs( + (int(args[0].size(0) / num_chunks), args[0].size(1)), + args[0].stride(), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ), + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + args[0].quantized_scalers, + args[0].quantization_factor, + args[0].scaler_mean, + quantized_data, + args[0].nf4, + ) for quantized_data in quantized_data_chunks + ] + +@implements( + [ + aten.new_zeros.default, + ] +) +def 
nf4_new_zeros(aten_op, args, kwargs=None): + # TODO: find a better way to derive /2 + assert len(args[1]) == 2 + new_zeros = aten_op(args[0].quantized_data, *(([int(math.prod(args[1]) / 2)],) + args[2:]), **kwargs) + return NF4Tensor( + SubclassTensorArgs( + (args[1][0], args[1][1]), + args[0].stride(), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ), + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + args[0].quantized_scalers, + args[0].quantization_factor, + args[0].scaler_mean, + new_zeros, + args[0].nf4, + ) + +@implements( + [ + aten.slice.Tensor, + ] +) +def nf4_slice(aten_op, args, kwargs=None): + assert len(args) == 4 + assert args[1] == 0, args[2] == 0 + assert args[3] == args[0].size(0) + sliced_data = aten_op(args[0].quantized_data, *(args[1], args[2], int(args[3] * args[0].size(1) / 2)), **kwargs) + return NF4Tensor( + SubclassTensorArgs( + args[0].size(), + args[0].stride(), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ), + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + args[0].quantized_scalers, + args[0].quantization_factor, + args[0].scaler_mean, + sliced_data, + args[0].nf4, + ) + +@implements( + [ + aten.copy_.default, + ] +) +def nf4_copy_(aten_op, args, kwargs=None): + assert len(args) == 2 + quantized_data = aten_op(args[0].quantized_data, args[1].quantized_data, **kwargs) + tensor_meta = SubclassTensorArgs( + args[0].size(), + args[0].stride(), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ) + return NF4Tensor( + tensor_meta, + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + args[0].quantized_scalers, + args[0].quantization_factor, + args[0].scaler_mean, + quantized_data, + args[0].nf4, + ) + +@implements( + [ + aten.view.default, + ] +) +def nf4_view(aten_op, args, kwargs=None): + assert len(args) == 2, args[1] == -1 + quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **kwargs) + tensor_meta = SubclassTensorArgs( + [args[0].numel()], + (1, ), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ) + return NF4Tensor( + tensor_meta, + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + args[0].quantized_scalers, + args[0].quantization_factor, + args[0].scaler_mean, + quantized_data, + args[0].nf4, + ) + +@implements( + [ + aten.as_strided.default, + ] +) +def nf4_as_strided(aten_op, args, kwargs=None): + assert len(args) == 4, len(args[1]) == 2 + assert args[0].size(0) == args[1][0] and args[0].size(1) == args[1][1] + quantized_data_size = [int(math.prod(args[1]) / 2)] + quantized_data_stride = (1,) + quantized_data_offset = 0 + strided_data = aten_op(args[0].quantized_data, *(quantized_data_size, quantized_data_stride, quantized_data_offset), **kwargs) + return NF4Tensor( + SubclassTensorArgs( + args[0].size(), + args[0].stride(), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ), + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + args[0].quantized_scalers, + args[0].quantization_factor, + args[0].scaler_mean, + strided_data, + args[0].nf4, + ) + +@implements( + [ + aten._to_copy.default, + ] +) +def nf4_to_copy(aten_op, args, kwargs=None): + quantized_data_kwargs = kwargs + quantized_data_kwargs['dtype'] = args[0].quantized_data.dtype + quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **quantized_data_kwargs) + + return 
NF4Tensor( + SubclassTensorArgs( + args[0].size(), + args[0].stride(), + args[0].storage_offset(), + args[0].dtype, + kwargs['device'], + args[0].requires_grad, + ), + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + args[0].quantized_scalers, + args[0].quantization_factor, + args[0].scaler_mean, + quantized_data, + args[0].nf4, + ) + + +@implements( + [ + aten.is_pinned.default, + ] +) +def nf4_is_pinned(aten_op, args, kwargs=None): + return aten_op(args[0].quantized_data, *(args[1:]), **kwargs) + + +@implements( + [ + aten._pin_memory.default, + ] +) +def nf4_pin_memory(aten_op, args, kwargs=None): + quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **kwargs) + + return NF4Tensor( + SubclassTensorArgs( + args[0].size(), + args[0].stride(), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ), + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + args[0].quantized_scalers, + args[0].quantization_factor, + args[0].scaler_mean, + quantized_data, + args[0].nf4, + ) + @dataclass class SubclassTensorArgs: @@ -393,7 +683,37 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): - raise NotImplementedError("NF4Tensor does not support torch dispatch") + def allowed_subclasses(type): + return ( + issubclass(cls, type) + or issubclass(torch._subclasses.fake_tensor.FakeTensor, type) + or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type) + ) + + if not all(allowed_subclasses(t) for t in types): + return NotImplemented("Up to the next one to handle") + + if func in NF4_OPS_TABLE: + return NF4_OPS_TABLE[func](func, args, kwargs) + + raise NotImplementedError(f"NF4Tensor does not support torch dispatch {func}") - # Do not force the Float8Tensor type on the returned tensor __torch_function__ = torch._C._disabled_torch_function_impl + + # @classmethod + # def __torch_function__(cls, func, types, args=(), kwargs=None): + # # Define a standard `__torch_function__` that propagates state + # kwargs = kwargs or {} + + # def wrap(tensor_meta, block_size, n_blocks, scaler_block_size, quantized_scalers, quantization_factor, scaler_mean, quantized_data, nf4, o: Any): + # if isinstance(o, torch.Tensor) and not isinstance(o, cls): + # return cls(tensor_meta, block_size, n_blocks, scaler_block_size, quantized_scalers, quantization_factor, scaler_mean, quantized_data, nf4) + # return o + + # with torch._C.DisableTorchFunctionSubclass(): + # if isinstance(args[0], cls): + # out = func(*args, **kwargs) + # return tree_map( + # functools.partial(wrap, args[0].tensor_meta, args[0].block_size, args[0].n_blocks, args[0].scaler_block_size, args[0].quantized_scalers, args[0].quantization_factor, args[0].scaler_mean, args[0].quantized_data, args[0].nf4), out + # ) + # return func(*args, **kwargs) diff --git a/transformer_nuggets/quant/qlora.py b/transformer_nuggets/quant/qlora.py index a1c407c..f54339d 100644 --- a/transformer_nuggets/quant/qlora.py +++ b/transformer_nuggets/quant/qlora.py @@ -1,14 +1,14 @@ import logging import math from dataclasses import dataclass -from typing import Tuple +from typing import Tuple, Dict, Any, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F from tqdm import tqdm from transformer_nuggets.quant.deqaunt_kernel import dequant_nf4_tensor -from transformer_nuggets.quant.nf4_tensor import NF4Tensor +from transformer_nuggets.quant.nf4_tensor import 
NF4Tensor, SubclassTensorArgs logging.basicConfig(level=logging.INFO) @@ -159,9 +159,13 @@ def __init__( r: int, lora_alpha: int = 1, lora_dropout: float = 0.0, + register_nf4_param: bool = False, ) -> None: super().__init__() - self.weight = NF4Tensor.from_tensor(weight) + if register_nf4_param: + self.weight = nn.Parameter(NF4Tensor.from_tensor(weight), requires_grad=False) + else: + self.weight = NF4Tensor.from_tensor(weight) self.r = r self.lora_alpha = lora_alpha self.in_features = in_features @@ -191,12 +195,91 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) return result2 + def fsdp_extensions(self) -> Dict[str, Any]: + from torch.distributed._composable.fsdp import FSDPTensorExtensions + + weight_extensions = FSDPTensorExtensions( + self._fsdp_pre_all_gather, self._fsdp_post_all_gather + ) + return {"weight": weight_extensions} + + def _fsdp_pre_all_gather(self, sharded_param: torch.Tensor): + # TODO: shard Tensor-type params + return (sharded_param.quantized_data, ), ( + SubclassTensorArgs( + sharded_param.size(), + sharded_param.stride(), + sharded_param.storage_offset(), + sharded_param.dtype, + sharded_param.device, + sharded_param.requires_grad, + ), + sharded_param.block_size, + sharded_param.n_blocks, + sharded_param.scaler_block_size, + sharded_param.quantized_scalers, + sharded_param.quantization_factor, + sharded_param.scaler_mean, + sharded_param.nf4, + ) + + # def fsdp_post_all_gather( + # self, + # all_gather_outputs: Tuple[torch.Tensor, ...], + # metadata: Any, + # param_dtype: torch.dtype, + # *, + # out: Optional[torch.Tensor] = None, + # ) -> Union[Tuple[Tuple[torch.Tensor, ...]], None]: + # (quantized_scalers, quantization_factor, scaler_mean, quantized_data, nf4) = all_gather_outputs + # (tensor_meta, block_size, n_blocks, scaler_block_size) = metadata + # if out is not None: + # return + # return (quantized_scalers, quantization_factor, scaler_mean, quantized_data, nf4), () + + # def _fsdp_pre_all_gather(self, sharded_param: torch.Tensor): + # float8_tensor = self.cast_to_float8_e4m3fn(sharded_param, reduce_amax=True) + # return (float8_tensor._data,), (float8_tensor._scale,) + + def _fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[torch.Tensor] = None, + ) -> Union[Tuple[NF4Tensor, Tuple[torch.Tensor, ...]], None]: + (quantized_data, ) = all_gather_outputs + (tensor_meta, block_size, n_blocks, scaler_block_size, quantized_scalers, quantization_factor, scaler_mean, nf4) = metadata + # TODO: figure out x 2 + tensor_meta.original_shape = (tensor_meta.original_shape[0] * 2, tensor_meta.original_shape[1]) + if out is not None: + assert isinstance(out, NF4Tensor), f"{type(out)}" + assert ( + quantized_data.untyped_storage().data_ptr() + == out.quantized_data.untyped_storage().data_ptr() + ), f"Expects out's data to be the all-gather output" + return + + return NF4Tensor( + tensor_meta, + block_size, + n_blocks, + scaler_block_size, + quantized_scalers, + quantization_factor, + scaler_mean, + quantized_data, + nf4, + ), (quantized_data, ) + @dataclass class QloraConfig: lora_r: int = 2 lora_alpha: int = 1 lora_dropout: float = 0.0 + register_nf4_param: bool = False class QloraMLP(nn.Module): @@ -215,15 +298,16 @@ def __init__( lora_r = QloraConfig.lora_r lora_alpha = QloraConfig.lora_alpha lora_dropout = QloraConfig.lora_dropout + register_nf4_param = QloraConfig.register_nf4_param self.qlora_w1 = QloraLinear( - weight1.shape[1], weight1.shape[0], weight1, lora_r, 
lora_alpha, lora_dropout + weight1.shape[1], weight1.shape[0], weight1, lora_r, lora_alpha, lora_dropout, register_nf4_param ) self.qlora_w2 = QloraLinear( - weight2.shape[1], weight2.shape[0], weight2, lora_r, lora_alpha, lora_dropout + weight2.shape[1], weight2.shape[0], weight2, lora_r, lora_alpha, lora_dropout, register_nf4_param ) self.qlora_w3 = QloraLinear( - weight3.shape[1], weight3.shape[0], weight3, lora_r, lora_alpha, lora_dropout + weight3.shape[1], weight3.shape[0], weight3, lora_r, lora_alpha, lora_dropout, register_nf4_param ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -233,8 +317,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def swap_for_qlora(model: torch.nn.Module, qlora_config: QloraConfig, dtype) -> None: - logging.info("Swapping for Qlora...") - for module in tqdm(model.layers): + # logging.info("Swapping for Qlora...") + for module in model.layers: feed_forward = module.feed_forward w1 = feed_forward.w1.weight.to(dtype=dtype) w2 = feed_forward.w2.weight.to(dtype=dtype) From 886278921bdb54b43472af7e495dfe36b389b31d Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 2 Apr 2024 11:09:38 -0700 Subject: [PATCH 3/3] avoid hardcoding shapes and gpus in __torch_dispatch__ Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- transformer_nuggets/llama/finetune.py | 2 +- transformer_nuggets/quant/nf4_tensor.py | 128 ++++++++++++++---------- transformer_nuggets/quant/qlora.py | 40 +++----- 3 files changed, 88 insertions(+), 82 deletions(-) diff --git a/transformer_nuggets/llama/finetune.py b/transformer_nuggets/llama/finetune.py index fe3b67d..6ba85ad 100644 --- a/transformer_nuggets/llama/finetune.py +++ b/transformer_nuggets/llama/finetune.py @@ -3,7 +3,7 @@ python transformer_nuggets/llama/finetune.py --profile --model 7B --enable_ac --full_finetune --optim_in_bwd # qlora on 2 gpus with FSDP -python transformer_nuggets/llama/finetune.py --profile --model 7B --fsdp_num_gpus 2 --use_fsdp2 --enable_ac --cpu_offload +python transformer_nuggets/llama/finetune.py --profile --model 7B --fsdp_num_gpus 2 --use_fsdp2 --enable_ac --register_nf4_param --cpu_offload """ import argparse import functools diff --git a/transformer_nuggets/quant/nf4_tensor.py b/transformer_nuggets/quant/nf4_tensor.py index 81268d5..2bfeb80 100644 --- a/transformer_nuggets/quant/nf4_tensor.py +++ b/transformer_nuggets/quant/nf4_tensor.py @@ -28,7 +28,11 @@ def decorator(func): ) def nf4_detach(aten_op, args, kwargs=None): # nn.Parameter need detach + quantized_scalers = aten_op(args[0].quantized_scalers, *args[1:], **kwargs) + quantization_factor = aten_op(args[0].quantization_factor, *args[1:], **kwargs) quantized_data = aten_op(args[0].quantized_data, *args[1:], **kwargs) + scaler_mean = aten_op(args[0].scaler_mean, *args[1:], **kwargs) + nf4 = aten_op(args[0].nf4, *args[1:], **kwargs) tensor_meta = SubclassTensorArgs( args[0].size(), args[0].stride(), @@ -42,11 +46,11 @@ def nf4_detach(aten_op, args, kwargs=None): args[0].block_size, args[0].n_blocks, args[0].scaler_block_size, - args[0].quantized_scalers, - args[0].quantization_factor, - args[0].scaler_mean, + quantized_scalers, + quantization_factor, + scaler_mean, quantized_data, - args[0].nf4, + nf4, ) @implements( @@ -56,21 +60,21 @@ def nf4_detach(aten_op, args, kwargs=None): ) def nf4_split(aten_op, args, kwargs=None): # torch.chunk - # TODO: find a better way to derive 2048 - assert args[0].dim() == 2 - assert args[0].quantized_data.numel() * 2 == args[0].numel() - # TODO: assume dim-0 sharding - num_chunks = 
int(args[0].size(0) / args[1]) - split_size = int(args[0].quantized_data.size(0) / num_chunks) - assert len(args) == 2 - - # TODO figure out n-D tensor - # figure / 2 float - quantized_data_chunks = aten_op(args[0].quantized_data, split_size, **kwargs) + # TODO: find if there are other args/kwargs in aten.split + assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.split.Tensor with 2 args" + # TODO: assert on dim-0 sharding. how to get dim from torch.chunk? + num_chunks = args[0].size(0) // args[1] + + # TODO: assert numel % num_chunks == 0 + quantized_scalers_chunks = aten_op(args[0].quantized_scalers, args[0].quantized_scalers.numel() // num_chunks, **kwargs) + quantization_factor_chunks = aten_op(args[0].quantization_factor, args[0].quantization_factor.numel() // num_chunks, **kwargs) + quantized_data_chunks = aten_op(args[0].quantized_data, args[0].quantized_data.numel() // num_chunks, **kwargs) + + assert len(args) == 2, "only support 2d because of tensor meta" return [ NF4Tensor( SubclassTensorArgs( - (int(args[0].size(0) / num_chunks), args[0].size(1)), + (args[0].size(0) // num_chunks, args[0].size(1)), args[0].stride(), args[0].storage_offset(), args[0].dtype, @@ -80,12 +84,14 @@ def nf4_split(aten_op, args, kwargs=None): args[0].block_size, args[0].n_blocks, args[0].scaler_block_size, - args[0].quantized_scalers, - args[0].quantization_factor, + quantized_scalers, + quantization_factor, args[0].scaler_mean, quantized_data, args[0].nf4, - ) for quantized_data in quantized_data_chunks + ) for quantized_scalers, quantization_factor, quantized_data in zip( + quantized_scalers_chunks, quantization_factor_chunks, quantized_data_chunks + ) ] @implements( @@ -94,9 +100,19 @@ def nf4_split(aten_op, args, kwargs=None): ] ) def nf4_new_zeros(aten_op, args, kwargs=None): - # TODO: find a better way to derive /2 - assert len(args[1]) == 2 - new_zeros = aten_op(args[0].quantized_data, *(([int(math.prod(args[1]) / 2)],) + args[2:]), **kwargs) + assert len(args[0].shape) == 2 and len(args[1]) == 2, "only support new zeros on 2D" + assert args[0].numel() % math.prod(args[1]) == 0 + ratio = args[0].numel() // math.prod(args[1]) + + assert args[0].quantized_scalers.size(0) % ratio == 0, f"quantized_scalers.numel() must be divisible by {ratio}" + quantized_scalers_new_zeros = aten_op(args[0].quantized_scalers, [args[0].quantized_scalers.size(0) // ratio], **kwargs) + + assert args[0].quantization_factor.size(0) % ratio == 0, f"quantization_factor.size(0) must be divisible by {ratio}" + quantization_factor_new_zeros = aten_op(args[0].quantization_factor, [args[0].quantization_factor.size(0) // ratio], **kwargs) + + assert args[0].quantized_data.size(0) % ratio == 0, f"quantized_data.size(0) must be divisible by {ratio}" + quantized_data_new_zeros = aten_op(args[0].quantized_data, [args[0].quantized_data.size(0) // ratio], **kwargs) + return NF4Tensor( SubclassTensorArgs( (args[1][0], args[1][1]), @@ -109,10 +125,10 @@ def nf4_new_zeros(aten_op, args, kwargs=None): args[0].block_size, args[0].n_blocks, args[0].scaler_block_size, - args[0].quantized_scalers, - args[0].quantization_factor, + quantized_scalers_new_zeros, + quantization_factor_new_zeros, args[0].scaler_mean, - new_zeros, + quantized_data_new_zeros, args[0].nf4, ) @@ -123,9 +139,10 @@ def nf4_new_zeros(aten_op, args, kwargs=None): ) def nf4_slice(aten_op, args, kwargs=None): assert len(args) == 4 - assert args[1] == 0, args[2] == 0 - assert args[3] == args[0].size(0) - sliced_data = 
aten_op(args[0].quantized_data, *(args[1], args[2], int(args[3] * args[0].size(1) / 2)), **kwargs) + assert args[1] == 0, f"only support dim=0 but got dim={args[1]}" + # TODO: maybe relax? + assert args[2] == 0, f"only support start=0 but got start={args[2]}" + assert args[3] == args[0].size(0), f"only support end == size(0) but got end={args[3]} and size(0)={args[0].size(0)}" return NF4Tensor( SubclassTensorArgs( args[0].size(), @@ -141,7 +158,7 @@ def nf4_slice(aten_op, args, kwargs=None): args[0].quantized_scalers, args[0].quantization_factor, args[0].scaler_mean, - sliced_data, + args[0].quantized_data, args[0].nf4, ) @@ -151,26 +168,30 @@ def nf4_slice(aten_op, args, kwargs=None): ] ) def nf4_copy_(aten_op, args, kwargs=None): - assert len(args) == 2 + assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.copy_.default with 2 args" + quantized_scalers = aten_op(args[0].quantized_scalers, args[1].quantized_scalers, **kwargs) + quantization_factor = aten_op(args[0].quantization_factor, args[1].quantization_factor, **kwargs) quantized_data = aten_op(args[0].quantized_data, args[1].quantized_data, **kwargs) + scaler_mean = aten_op(args[0].scaler_mean, args[1].scaler_mean, **kwargs) + nf4 = aten_op(args[0].nf4, args[1].nf4, **kwargs) tensor_meta = SubclassTensorArgs( - args[0].size(), - args[0].stride(), - args[0].storage_offset(), - args[0].dtype, - args[0].device, - args[0].requires_grad, + args[1].size(), + args[1].stride(), + args[1].storage_offset(), + args[1].dtype, + args[1].device, + args[1].requires_grad, ) return NF4Tensor( tensor_meta, - args[0].block_size, - args[0].n_blocks, - args[0].scaler_block_size, - args[0].quantized_scalers, - args[0].quantization_factor, - args[0].scaler_mean, + args[1].block_size, + args[1].n_blocks, + args[1].scaler_block_size, + quantized_scalers, + quantization_factor, + scaler_mean, quantized_data, - args[0].nf4, + nf4, ) @implements( @@ -180,6 +201,8 @@ def nf4_copy_(aten_op, args, kwargs=None): ) def nf4_view(aten_op, args, kwargs=None): assert len(args) == 2, args[1] == -1 + quantized_scalers = aten_op(args[0].quantized_scalers, *(args[1:]), **kwargs) + quantization_factor = aten_op(args[0].quantization_factor, *(args[1:]), **kwargs) quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **kwargs) tensor_meta = SubclassTensorArgs( [args[0].numel()], @@ -194,8 +217,8 @@ def nf4_view(aten_op, args, kwargs=None): args[0].block_size, args[0].n_blocks, args[0].scaler_block_size, - args[0].quantized_scalers, - args[0].quantization_factor, + quantized_scalers, + quantization_factor, args[0].scaler_mean, quantized_data, args[0].nf4, @@ -207,16 +230,13 @@ def nf4_view(aten_op, args, kwargs=None): ] ) def nf4_as_strided(aten_op, args, kwargs=None): - assert len(args) == 4, len(args[1]) == 2 - assert args[0].size(0) == args[1][0] and args[0].size(1) == args[1][1] - quantized_data_size = [int(math.prod(args[1]) / 2)] - quantized_data_stride = (1,) - quantized_data_offset = 0 - strided_data = aten_op(args[0].quantized_data, *(quantized_data_size, quantized_data_stride, quantized_data_offset), **kwargs) + assert len(args[1]) == 2 and math.prod(args[1]) == args[0].numel(), "only support same numel" + assert args[2] == [args[1][1], 1], f"only support stride {[args[1][1], 1]}" + assert args[0].storage_offset() == args[3], f"only support same storage offset" return NF4Tensor( SubclassTensorArgs( - args[0].size(), - args[0].stride(), + torch.Size(args[1]), + tuple(args[2]), args[0].storage_offset(), args[0].dtype, 
args[0].device, @@ -228,7 +248,7 @@ def nf4_as_strided(aten_op, args, kwargs=None): args[0].quantized_scalers, args[0].quantization_factor, args[0].scaler_mean, - strided_data, + args[0].quantized_data, args[0].nf4, ) diff --git a/transformer_nuggets/quant/qlora.py b/transformer_nuggets/quant/qlora.py index f54339d..4af8503 100644 --- a/transformer_nuggets/quant/qlora.py +++ b/transformer_nuggets/quant/qlora.py @@ -204,8 +204,11 @@ def fsdp_extensions(self) -> Dict[str, Any]: return {"weight": weight_extensions} def _fsdp_pre_all_gather(self, sharded_param: torch.Tensor): - # TODO: shard Tensor-type params - return (sharded_param.quantized_data, ), ( + return ( + sharded_param.quantized_scalers, + sharded_param.quantization_factor, + sharded_param.quantized_data, + ), ( SubclassTensorArgs( sharded_param.size(), sharded_param.stride(), @@ -217,30 +220,10 @@ def _fsdp_pre_all_gather(self, sharded_param: torch.Tensor): sharded_param.block_size, sharded_param.n_blocks, sharded_param.scaler_block_size, - sharded_param.quantized_scalers, - sharded_param.quantization_factor, sharded_param.scaler_mean, sharded_param.nf4, ) - # def fsdp_post_all_gather( - # self, - # all_gather_outputs: Tuple[torch.Tensor, ...], - # metadata: Any, - # param_dtype: torch.dtype, - # *, - # out: Optional[torch.Tensor] = None, - # ) -> Union[Tuple[Tuple[torch.Tensor, ...]], None]: - # (quantized_scalers, quantization_factor, scaler_mean, quantized_data, nf4) = all_gather_outputs - # (tensor_meta, block_size, n_blocks, scaler_block_size) = metadata - # if out is not None: - # return - # return (quantized_scalers, quantization_factor, scaler_mean, quantized_data, nf4), () - - # def _fsdp_pre_all_gather(self, sharded_param: torch.Tensor): - # float8_tensor = self.cast_to_float8_e4m3fn(sharded_param, reduce_amax=True) - # return (float8_tensor._data,), (float8_tensor._scale,) - def _fsdp_post_all_gather( self, all_gather_outputs: Tuple[torch.Tensor, ...], @@ -249,13 +232,16 @@ def _fsdp_post_all_gather( *, out: Optional[torch.Tensor] = None, ) -> Union[Tuple[NF4Tensor, Tuple[torch.Tensor, ...]], None]: - (quantized_data, ) = all_gather_outputs - (tensor_meta, block_size, n_blocks, scaler_block_size, quantized_scalers, quantization_factor, scaler_mean, nf4) = metadata - # TODO: figure out x 2 - tensor_meta.original_shape = (tensor_meta.original_shape[0] * 2, tensor_meta.original_shape[1]) + (quantized_scalers, quantization_factor, quantized_data) = all_gather_outputs + (tensor_meta, block_size, n_blocks, scaler_block_size, scaler_mean, nf4) = metadata + tensor_meta.original_shape = torch.Size([quantized_data.size(0) * 2]) if out is not None: assert isinstance(out, NF4Tensor), f"{type(out)}" assert ( + quantized_scalers.untyped_storage().data_ptr() + == out.quantized_scalers.untyped_storage().data_ptr() and + quantization_factor.untyped_storage().data_ptr() + == out.quantization_factor.untyped_storage().data_ptr() and quantized_data.untyped_storage().data_ptr() == out.quantized_data.untyped_storage().data_ptr() ), f"Expects out's data to be the all-gather output" @@ -271,7 +257,7 @@ def _fsdp_post_all_gather( scaler_mean, quantized_data, nf4, - ), (quantized_data, ) + ), (quantized_scalers, quantization_factor, quantized_data) @dataclass
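
---

Editor's note (not part of the patches above; illustration only): the core mechanism these patches build on is nf4_tensor.py's __torch_dispatch__, which routes aten ops through a module-level table populated by an implements() decorator and raises NotImplementedError for anything unregistered. Below is a minimal, self-contained Python sketch of that pattern using a toy subclass. All names here (ScaledTensor, OPS_TABLE, scaled_detach) are made up for the example and do not exist in the repository; this is a sketch of the technique, not the patch's implementation.

from typing import Any, Dict

import torch

aten = torch.ops.aten
OPS_TABLE: Dict[Any, Any] = {}


def implements(aten_ops):
    # Register a handler for one or more aten ops, mirroring the decorator in nf4_tensor.py.
    def decorator(func):
        for op in aten_ops:
            OPS_TABLE[op] = func
        return func
    return decorator


class ScaledTensor(torch.Tensor):
    # Toy wrapper subclass: a plain tensor plus a scale, standing in for the
    # quantized_data / quantized_scalers / quantization_factor that NF4Tensor carries.

    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: float):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor, scale: float):
        self._data = data
        self._scale = scale

    def __repr__(self):
        return f"ScaledTensor(scale={self._scale}, data={self._data})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Route through the op table; unhandled ops fail loudly, as in nf4_tensor.py.
        if func in OPS_TABLE:
            return OPS_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"ScaledTensor does not support {func}")

    __torch_function__ = torch._C._disabled_torch_function_impl


@implements([aten.detach.default])
def scaled_detach(aten_op, args, kwargs):
    # Apply the op to the inner plain tensor, then rebuild the subclass around the
    # result -- the same recipe nf4_detach uses so nn.Parameter can wrap the tensor.
    return ScaledTensor(aten_op(args[0]._data, *args[1:], **kwargs), args[0]._scale)


if __name__ == "__main__":
    t = ScaledTensor(torch.randn(4, 4), 0.5)
    print(t.detach())   # dispatches to scaled_detach via OPS_TABLE
    # t + t             # aten.add is not registered -> NotImplementedError

The registered handlers in the patches (nf4_detach, nf4_split, nf4_new_zeros, nf4_copy_, nf4_view, nf4_as_strided, ...) all follow this shape: apply the aten op to the relevant inner tensors (quantized_data, quantized_scalers, quantization_factor), then reconstruct an NF4Tensor with an updated SubclassTensorArgs, which is what lets FSDP2 shard, copy, and offload the quantized parameter without ever dequantizing it.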