Skip to content

Commit

Permalink
[NF4Tensor] Switch to save for backward since are now a tensor input (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg authored Jun 5, 2024
1 parent c7cd729 commit 2b91917
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,13 +863,13 @@ class LinearNF4(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
"""Save the quantized nf4 weight for backward pass"""
ctx.nf4_weight = weight
ctx.save_for_backward(weight)
return F.linear(input, weight.to(input.dtype))

@staticmethod
def backward(ctx, grad_output):
"""The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)"""
weight: NF4Tensor = ctx.nf4_weight
weight: NF4Tensor = ctx.saved_tensors[0]
return grad_output @ weight.to(grad_output.dtype), None


Expand Down

0 comments on commit 2b91917

Please sign in to comment.