Skip to content

Commit

Permalink
Bug Fix: NF4 .to('cuda') (#158)
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim committed Apr 23, 2024
1 parent 2003325 commit ec6affe
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
14 changes: 14 additions & 0 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,20 @@ def test_to_copy(self, dtype: torch.dtype):
nf4_to_dtype = inpt_tensor_nf4.to(dtype)
torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)

@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
def test_to_copy_device(self):
    """Check the reported device of a to_nf4 result across cpu/cuda moves."""
    cpu_device = torch.device('cpu')

    # A tensor built from a CPU source should itself report the CPU device.
    nf4_on_cpu = to_nf4(torch.rand(128, device='cpu'), 32, 2)
    assert nf4_on_cpu.device == cpu_device

    # Compare only the device type: the concrete device may be cuda:0.
    moved_to_gpu = nf4_on_cpu.cuda()
    assert moved_to_gpu.device.type == "cuda"

    # Round trip back to the host.
    moved_back = moved_to_gpu.cpu()
    assert moved_back.device == cpu_device

    # Converting a source tensor that already lives on the GPU.
    nf4_on_gpu = to_nf4(torch.rand(128, device='cuda'), 32, 2)
    assert nf4_on_gpu.device.type == "cuda"

@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_to_dtype(self, dtype: torch.dtype):
inpt_tensor = torch.rand(128, dtype=dtype)
Expand Down
2 changes: 1 addition & 1 deletion torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _to_copy(func, *args, **kwargs):
if not args[0][0].is_contiguous():
assert args[0][0].t().is_contiguous()
return func(args[0][0].t()).t()
return args[0][0].get_original_weight().to(args[1]["dtype"])
return args[0][0].get_original_weight().to(args[1]["dtype"]).to(args[1]["device"])


@implements([torch.ops.aten.to.dtype])
Expand Down

0 comments on commit ec6affe

Please sign in to comment.