Skip to content

Commit

Permalink
dont let dynamo inline inside of NF4 constructors or __torch_dispatch…
Browse files Browse the repository at this point in the history
…__ (#544)

* dont let dynamo inline inside of NF4 constructors or __torch_dispatch__

* allow_in_graph the ctr instead of disabling it
  • Loading branch information
bdhirsh authored Jul 28, 2024
1 parent afde175 commit e5b705c
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ def get_block_absmax(inpt_tensor: torch.Tensor, block_size: int) -> torch.Tensor
class NF4Tensor(torch.Tensor):
"""NF4Tensor class for converting a weight to the QLoRA NF4 format"""

@torch._dynamo.disable
def __new__(
cls,
# Args related for base tensor construction
Expand Down Expand Up @@ -450,6 +451,7 @@ def __new__(
)
return nf4tensor

@torch._dynamo.disable
def __init__(
self,
tensor_meta: SubclassTensorArgs,
Expand Down Expand Up @@ -758,6 +760,7 @@ def __str__(self):
return self.to(torch.float32).__str__()

@classmethod
@torch._dynamo.disable
def __torch_dispatch__(cls, func, types, args, kwargs=None):
"""TODO we are not supporting torch dispatch at the moment
instead we have created a Autograd.Function to handle the linear
Expand Down Expand Up @@ -849,7 +852,7 @@ def fsdp_post_all_gather(
), f"Expects out's data to be the all-gather output"
return

return NF4Tensor(
return nf4_constructor(
tensor_meta,
block_size,
n_blocks,
Expand Down Expand Up @@ -934,3 +937,27 @@ def function_cpu(*args, **kwargs):
updated_attrs = call_from_inner_tensors(nf4tensor, "cpu", args[1:], kwargs)
updated_attrs["device"] = "cpu"
return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))

@torch._dynamo.allow_in_graph
def nf4_constructor(
tensor_meta: SubclassTensorArgs,
block_size: int,
n_blocks: int,
scaler_block_size: int,
quantized_scalers: torch.Tensor,
quantization_factor: torch.Tensor,
scaler_mean: torch.Tensor,
quantized_data: torch.Tensor,
nf4: torch.Tensor,
):
return NF4Tensor(
tensor_meta,
block_size,
n_blocks,
scaler_block_size,
quantized_scalers,
quantization_factor,
scaler_mean,
quantized_data,
nf4,
)

0 comments on commit e5b705c

Please sign in to comment.