Enable model.to(device) for int8 weight only quantized model (pytorch#486)

Summary:
Fix implementation issues in `int8_wo_quantized_model.to(device)` so that an int8 weight-only quantized model can be moved across devices.

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_model_to_device
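For context, a minimal sketch of the workflow the test above exercises and this commit enables. This is an illustration, not part of the commit: the toy model is made up, and the import path is assumed from the test below (which calls `quantize_` and `int8_weight_only` directly); both may differ across torchao versions.

```python
import torch
# Assumed import path; the test below calls quantize_ and int8_weight_only
# directly, so they are importable from torchao's quantization APIs.
from torchao.quantization.quant_api import quantize_, int8_weight_only

# Quantize on CPU first, then move the quantized model to CUDA.
# Moving an int8 weight-only quantized model is what this commit fixes.
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval().to(torch.bfloat16)
quantize_(model, int8_weight_only())  # weights become quantized tensor subclasses

model.to("cuda")  # previously hit implementation issues in the subclass .to()
out = model(torch.randn(8, 64, dtype=torch.bfloat16, device="cuda"))
```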

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored Jul 8, 2024
1 parent bf64e23 commit 8f0a0d5
Showing 2 changed files with 22 additions and 3 deletions.
16 changes: 16 additions & 0 deletions test/quantization/test_quant_api.py
@@ -620,5 +620,21 @@ def test_quantized_tensor_subclass_save_load(self):
         self.assertEqual(res, ref)
 
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_model_to_device(self):
+        m = ToyLinearModel().eval().to(torch.bfloat16)
+        m_copy = copy.deepcopy(m)
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu")
+
+        quantize_(m, int8_weight_only())
+        ref = m(*example_inputs)
+
+        example_inputs_cuda = (example_inputs[0].to("cuda"),)
+        m.to(device="cuda")
+        cuda_res = m(*example_inputs_cuda)
+        self.assertEqual(cuda_res.cpu(), ref)
+
+
 if __name__ == "__main__":
     unittest.main()
9 changes: 6 additions & 3 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -259,8 +259,11 @@ def _get_to_kwargs(self, *args, **kwargs):
 
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs.pop("device")
+        # not supported yet
+        kwargs.pop("memory_format")
         return self.__class__(
-            self.layout_tensor.to(kwargs["device"]),
+            self.layout_tensor.to(device),
             self.block_size,
             self.shape,
             self.quant_min,
@@ -470,8 +473,8 @@ def to(self, *args, **kwargs):
         if device != "cuda" or (isinstance(device, torch.device) and device.type != "cuda"):
             raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device")
         return self.__class__(
-            self.packed_weight.to(kwargs["device"]),
-            self.scale_and_zero.to(kwargs["device"]),
+            self.packed_weight.to(device),
+            self.scale_and_zero.to(device),
             self.transposed
         )
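Both hunks follow the usual tensor-subclass `to()` pattern: normalize the arguments, pop the ones the wrapper handles itself (here `device`, plus `memory_format`, which is not supported yet and is dropped rather than forwarded), then rebuild the wrapper around the moved inner tensors. Below is a minimal self-contained sketch of that pattern; the `WrappedTensor` class is hypothetical, and it uses torch's private `_parse_to` helper on the assumption that `_get_to_kwargs` wraps something along those lines.

```python
import torch

class WrappedTensor:
    """Hypothetical wrapper holding packed data plus quantization scales."""
    def __init__(self, packed: torch.Tensor, scale: torch.Tensor):
        self.packed = packed
        self.scale = scale

    def to(self, *args, **kwargs):
        # Normalize .to()-style arguments the way torch.nn.Module.to does;
        # returns (device, dtype, non_blocking, memory_format).
        device, dtype, non_blocking, memory_format = torch._C._nn._parse_to(*args, **kwargs)
        # memory_format is not supported for the packed layout, so it is
        # dropped here instead of being forwarded to the inner tensors.
        return type(self)(
            self.packed.to(device),
            self.scale.to(device),
        )

# Usage sketch:
# w = WrappedTensor(packed, scale)
# w_cuda = w.to("cuda")
```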

