Skip to content

Commit ec6affe

Browse files
authored
Bug Fix: NF4 .to('cuda') (#158)
1 parent 2003325 commit ec6affe

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

test/dtypes/test_nf4.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,20 @@ def test_to_copy(self, dtype: torch.dtype):
192192
nf4_to_dtype = inpt_tensor_nf4.to(dtype)
193193
torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)
194194

195+
@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
196+
def test_to_copy_device(self):
197+
inpt_tensor = torch.rand(128, device='cpu')
198+
t = to_nf4(inpt_tensor, 32, 2)
199+
assert t.device == torch.device('cpu')
200+
z = t.cuda()
201+
assert z.device.type == "cuda" # Because the device could be cuda:0
202+
x = z.cpu()
203+
assert x.device == torch.device('cpu')
204+
205+
inpt_tensor = torch.rand(128, device='cuda')
206+
t = to_nf4(inpt_tensor, 32, 2)
207+
assert t.device.type == "cuda"
208+
195209
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
196210
def test_to_dtype(self, dtype: torch.dtype):
197211
inpt_tensor = torch.rand(128, dtype=dtype)

torchao/dtypes/nf4tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _to_copy(func, *args, **kwargs):
4747
if not args[0][0].is_contiguous():
4848
assert args[0][0].t().is_contiguous()
4949
return func(args[0][0].t()).t()
50-
return args[0][0].get_original_weight().to(args[1]["dtype"])
50+
return args[0][0].get_original_weight().to(args[1]["dtype"]).to(args[1]["device"])
5151

5252

5353
@implements([torch.ops.aten.to.dtype])

0 commit comments

Comments
 (0)