diff --git a/quanto/tensor/core.py b/quanto/tensor/core.py
index 5d338f36..da99afc6 100644
--- a/quanto/tensor/core.py
+++ b/quanto/tensor/core.py
@@ -300,6 +300,11 @@ def __torch_dispatch__(cls, op, types, args, kwargs=None):
     def numpy(self):
         return self.dequantize().cpu().numpy()
 
+    def to(self, *args, **kwargs):
+        self._data = self._data.to(*args, **kwargs)
+        self._scale = self._scale.to(*args, **kwargs)
+        return self
+
 
 class AffineQuantizer(Function):
     """A standard affine quantizer."""
@@ -407,3 +412,7 @@ def __torch_dispatch__(cls, op, types, args, kwargs=None):
             return QBitsTensor(t._qtype, data, scale, zeropoint)
         args, kwargs = pytree.tree_map_only(QBitsTensor, lambda x: x.qtensor(), (args, kwargs or {}))
         return op(*args, **kwargs)
+
+    def to(self, *args, **kwargs):
+        self._zeropoint = self._zeropoint.to(*args, **kwargs)
+        return super().to(*args, **kwargs)
diff --git a/test/nn/test_qlinear.py b/test/nn/test_qlinear.py
index 80b0c96a..99ddd0ef 100644
--- a/test/nn/test_qlinear.py
+++ b/test/nn/test_qlinear.py
@@ -2,7 +2,7 @@
 import torch
 from helpers import assert_similar, random_qtensor
 
-from quanto import Calibration, QTensor, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8
+from quanto import Calibration, QBitsTensor, QTensor, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8
 from quanto.nn import QLinear
 
 
@@ -27,6 +27,29 @@ def _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, act
     assert_similar(out, qout, atol=atol)
 
 
+@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
+@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-int4", "w-int8"])
+@pytest.mark.parametrize(
+    "activations",
+    [None, qfloat8_e5m2, qfloat8_e4m3fn],
+    ids=["None", "a-float8-e5m2", "a-float8-e4m3"],
+)
+def test_move_qlinear(use_bias, weights, activations, device):
+    linear = torch.nn.Linear(32, 32, bias=use_bias)
+    qlinear = QLinear.from_module(linear, weights=weights, activations=activations)
+    # Calibration is optional for weight-only quantization
+    qinputs = random_qtensor((1, 32, 32))
+    with torch.no_grad(), Calibration():
+        qlinear(qinputs)
+    qlinear.freeze()
+    qlinear.to(device)
+    if isinstance(qlinear.weight, QTensor):
+        assert qlinear.weight._data.device.type == device.type
+        assert qlinear.weight._scale.device.type == device.type
+    if isinstance(qlinear.weight, QBitsTensor):
+        assert qlinear.weight._zeropoint.device.type == device.type
+
+
 @pytest.mark.parametrize("batch_size", [1, 10])
 @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)])
 @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
diff --git a/test/tensor/test_qtensor.py b/test/tensor/test_qtensor.py
index 62b791e0..84e5fe2f 100644
--- a/test/tensor/test_qtensor.py
+++ b/test/tensor/test_qtensor.py
@@ -8,6 +8,14 @@
 from quanto import QTensor, absmax_scale, qfloat8_e4m3fn, qfloat8_e5m2, qint8, qint16, qint32
 
 
+def test_qtensor_move(device):
+    input_shape = (2, 4, 8)
+    qa = random_qtensor(input_shape, dtype=torch.float32)
+    qa = qa.to(device)
+    assert qa._data.device.type == device.type
+    assert qa._scale.device.type == device.type
+
+
 @pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
 @pytest.mark.parametrize("qtype", [qint8], ids=["qint8"])
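
For context, a minimal usage sketch of what this patch enables, mirroring the new tests. This is an illustrative example, not part of the patch: it assumes a CUDA device is available, reuses the `random_qtensor` test helper imported above, and picks `qint4` weights so that the frozen weight is (as the `QBitsTensor` check in the new test suggests) a `QBitsTensor` carrying a `_zeropoint`:

import torch

from helpers import random_qtensor  # test helper used in the new tests above
from quanto import Calibration, qint4
from quanto.nn import QLinear

# Quantize a plain linear layer with int4 weights, calibrate it once, then freeze.
linear = torch.nn.Linear(32, 32)
qlinear = QLinear.from_module(linear, weights=qint4, activations=None)
qinputs = random_qtensor((1, 32, 32))
with torch.no_grad(), Calibration():
    qlinear(qinputs)
qlinear.freeze()

# With the new overrides, moving the module also moves the tensors wrapped by
# the quantized weight: _data and _scale, plus _zeropoint for a QBitsTensor.
qlinear.to("cuda")
assert qlinear.weight._data.device.type == "cuda"
assert qlinear.weight._scale.device.type == "cuda"

Note that, unlike `torch.Tensor.to`, these overrides mutate the quantized tensor in place and return `self`: the inner `_data`, `_scale` (and, for `QBitsTensor`, `_zeropoint`) tensors are reassigned rather than a new `QTensor` being constructed, which is why the new tests only assert on the device of the inner tensors after the move.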