Skip to content

Commit

Permalink
Fixed f-string printing of NF4Tensors (#297)
Browse files Browse the repository at this point in the history
  • Loading branch information
awgu committed May 30, 2024
1 parent 38dad9b commit 4c1d568
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,10 @@ def _to_copy(func, *args, **kwargs):
if not args[0][0].is_contiguous():
assert args[0][0].t().is_contiguous()
return func(args[0][0].t()).t()
return args[0][0].get_original_weight().to(args[1]["dtype"]).to(args[1]["device"])
out = args[0][0].get_original_weight().to(args[1]["dtype"])
if "device" in args[1]:
out = out.to(args[1]["device"])
return out


@implements([torch.ops.aten.to.dtype])
Expand Down

0 comments on commit 4c1d568

Please sign in to comment.