diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index 48249434b..25c33417a 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -272,7 +272,10 @@ def _to_copy(func, *args, **kwargs):
     if not args[0][0].is_contiguous():
         assert args[0][0].t().is_contiguous()
         return func(args[0][0].t()).t()
-    return args[0][0].get_original_weight().to(args[1]["dtype"]).to(args[1]["device"])
+    out = args[0][0].get_original_weight().to(args[1]["dtype"])
+    if "device" in args[1]:
+        out = out.to(args[1]["device"])
+    return out
 
 
 @implements([torch.ops.aten.to.dtype])