diff --git a/test/test_utils.py b/test/test_utils.py index f06835c932..9f93e445bc 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -144,6 +144,10 @@ def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy): self.assertTrue(torch.equal(lp_tensor.qdata, reconstructed.qdata)) self.assertEqual(lp_tensor.attr, reconstructed.attr) + # test _get_to_kwargs + _ = lp_tensor._get_to_kwargs(torch.strided, device="cuda") + _ = lp_tensor._get_to_kwargs(layout=torch.strided, device="cuda") + # `to` / `_to_copy` original_device = lp_tensor.device lp_tensor = lp_tensor.to("cuda") diff --git a/torchao/utils.py b/torchao/utils.py index daf7eab83c..dac1e68e47 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -717,9 +717,7 @@ def _get_tensor_impl_constructor( def _get_to_kwargs(self, *args, **kwargs): # `torch._C._nn._parse_to` can't handle `layout` argument - for arg in args: - if isinstance(arg, torch.layout): - args.remove(arg) + args = tuple(arg for arg in args if not isinstance(arg, torch.layout)) if "layout" in kwargs: kwargs.pop("layout") # ignoring `non_blocking` and `memory_format` args since these are not