From ec6affec2ab80cb63490c6ea43b0cd3854a8be4e Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Mon, 22 Apr 2024 17:03:40 -0700
Subject: [PATCH] Bug Fix: NF4 .to('cuda') (#158)

---
 test/dtypes/test_nf4.py     | 14 ++++++++++++++
 torchao/dtypes/nf4tensor.py |  2 +-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index e3b25e3c3..55bbe0bcb 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -192,6 +192,20 @@ def test_to_copy(self, dtype: torch.dtype):
         nf4_to_dtype = inpt_tensor_nf4.to(dtype)
         torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    def test_to_copy_device(self):
+        inpt_tensor = torch.rand(128, device='cpu')
+        t = to_nf4(inpt_tensor, 32, 2)
+        assert t.device == torch.device('cpu')
+        z = t.cuda()
+        assert z.device.type == "cuda" # Because the device could be cuda:0
+        x = z.cpu()
+        assert x.device == torch.device('cpu')
+
+        inpt_tensor = torch.rand(128, device='cuda')
+        t = to_nf4(inpt_tensor, 32, 2)
+        assert t.device.type == "cuda"
+
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_to_dtype(self, dtype: torch.dtype):
         inpt_tensor = torch.rand(128, dtype=dtype)
diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index 886eb6c0a..f09d53821 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -47,7 +47,7 @@ def _to_copy(func, *args, **kwargs):
     if not args[0][0].is_contiguous():
         assert args[0][0].t().is_contiguous()
         return func(args[0][0].t()).t()
-    return args[0][0].get_original_weight().to(args[1]["dtype"])
+    return args[0][0].get_original_weight().to(args[1]["dtype"]).to(args[1]["device"])
 
 
 @implements([torch.ops.aten.to.dtype])
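
Note (not part of the patch): before this fix, _to_copy honored only the dtype
kwarg, so .to('cuda') / .cuda() on an NF4Tensor returned a tensor that stayed
on the original device. Below is a minimal usage sketch of the fixed behavior,
mirroring test_to_copy_device above; it assumes to_nf4 is importable from
torchao.dtypes.nf4tensor (as in the test file) and that a CUDA device is
available.

    import torch
    from torchao.dtypes.nf4tensor import to_nf4

    weight = torch.rand(128, device="cpu")
    nf4_weight = to_nf4(weight, 32, 2)        # quantize on CPU, same block sizes as the test
    cuda_weight = nf4_weight.cuda()           # after the fix, the result lands on the CUDA device
    assert cuda_weight.device.type == "cuda"  # could be cuda:0, so only check the device type
    back_on_cpu = cuda_weight.cpu()
    assert back_on_cpu.device == torch.device("cpu")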