diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index df2b1f08d..56feb1857 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -410,6 +410,7 @@ def get_block_absmax(inpt_tensor: torch.Tensor, block_size: int) -> torch.Tensor
 class NF4Tensor(torch.Tensor):
     """NF4Tensor class for converting a weight to the QLoRA NF4 format"""
 
+    @torch._dynamo.disable
     def __new__(
         cls,
         # Args related for base tensor construction
@@ -450,6 +451,7 @@ def __new__(
         )
         return nf4tensor
 
+    @torch._dynamo.disable
     def __init__(
         self,
         tensor_meta: SubclassTensorArgs,
@@ -758,6 +760,7 @@ def __str__(self):
         return self.to(torch.float32).__str__()
 
     @classmethod
+    @torch._dynamo.disable
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         """TODO we are not supporting torch dispatch at the moment
         instead we have created a Autograd.Function to handle the linear
@@ -849,7 +852,7 @@ def fsdp_post_all_gather(
             ), f"Expects out's data to be the all-gather output"
             return
 
-        return NF4Tensor(
+        return nf4_constructor(
             tensor_meta,
             block_size,
             n_blocks,
@@ -934,3 +937,27 @@ def function_cpu(*args, **kwargs):
     updated_attrs = call_from_inner_tensors(nf4tensor, "cpu", args[1:], kwargs)
     updated_attrs["device"] = "cpu"
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+
+@torch._dynamo.allow_in_graph
+def nf4_constructor(
+    tensor_meta: SubclassTensorArgs,
+    block_size: int,
+    n_blocks: int,
+    scaler_block_size: int,
+    quantized_scalers: torch.Tensor,
+    quantization_factor: torch.Tensor,
+    scaler_mean: torch.Tensor,
+    quantized_data: torch.Tensor,
+    nf4: torch.Tensor,
+):
+    return NF4Tensor(
+        tensor_meta,
+        block_size,
+        n_blocks,
+        scaler_block_size,
+        quantized_scalers,
+        quantization_factor,
+        scaler_mean,
+        quantized_data,
+        nf4,
+    )
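
Note (not part of the patch): the pattern applied here is to keep Dynamo out of the
subclass construction path (__new__, __init__, and __torch_dispatch__ are wrapped with
torch._dynamo.disable) while exposing a single module-level constructor marked with
torch._dynamo.allow_in_graph, so that under torch.compile the NF4Tensor construction
appears as one opaque call rather than being traced through. A minimal, hypothetical
sketch of the same pattern on a toy wrapper subclass follows; ToyTensor and
toy_constructor are illustrative names, not torchao APIs.

    import torch


    class ToyTensor(torch.Tensor):
        """Toy wrapper subclass, used only to illustrate the pattern."""

        @torch._dynamo.disable  # keep Dynamo out of the raw construction path
        def __new__(cls, inner: torch.Tensor):
            return torch.Tensor._make_wrapper_subclass(
                cls, inner.shape, dtype=inner.dtype, device=inner.device
            )

        @torch._dynamo.disable
        def __init__(self, inner: torch.Tensor):
            # Stash the underlying data on the wrapper.
            self.inner = inner


    @torch._dynamo.allow_in_graph  # construction shows up as a single call in the graph
    def toy_constructor(inner: torch.Tensor) -> ToyTensor:
        return ToyTensor(inner)

Call sites inside compiled regions would then go through the allow_in_graph helper
(here toy_constructor, in the patch nf4_constructor) instead of invoking the subclass
directly, which is why fsdp_post_all_gather is switched over to nf4_constructor above.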