Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

overwrite to() for QTensor and QBitsTensor #88

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions quanto/tensor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,11 @@ def __torch_dispatch__(cls, op, types, args, kwargs=None):
def numpy(self):
    """Return the dequantized contents of this tensor as a NumPy array.

    The conversion happens in three steps: dequantize, copy to CPU,
    then convert the resulting tensor with its own ``numpy()`` method.
    """
    dequantized = self.dequantize()
    on_cpu = dequantized.cpu()
    return on_cpu.numpy()

def to(self, *args, **kwargs):
    """Apply ``to()`` to the inner data and scale tensors and return ``self``.

    Every positional and keyword argument is forwarded unchanged, first to
    ``self._data.to`` and then to ``self._scale.to``. The results are stored
    back on this instance, so the object is mutated in place and the very
    same object is returned (no new instance is created).
    """
    # Order matters only for parity with the original: data first, then scale.
    for attr_name in ("_data", "_scale"):
        inner = getattr(self, attr_name)
        setattr(self, attr_name, inner.to(*args, **kwargs))
    return self
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to return super().to(*args, **kwargs), but it was causing weird behavior with tests using QBitsTensor: it was calling __torch_function__ afterwards. To reproduce, return super().to(*args, **kwargs) and run python -m pytest -sv test/nn/test_qlinear.py::test_move_qlinear

Copy link
Collaborator

@dacorvo dacorvo Feb 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tensor subclasses are very special beasts: you should not override the base Tensor methods that way; instead, do it through the dispatch.



class AffineQuantizer(Function):
"""A standard affine quantizer."""
Expand Down Expand Up @@ -407,3 +412,7 @@ def __torch_dispatch__(cls, op, types, args, kwargs=None):
return QBitsTensor(t._qtype, data, scale, zeropoint)
args, kwargs = pytree.tree_map_only(QBitsTensor, lambda x: x.qtensor(), (args, kwargs or {}))
return op(*args, **kwargs)

def to(self, *args, **kwargs):
    """Apply ``to()`` to the zero-point tensor, then defer to the parent ``to()``.

    The arguments are forwarded unchanged to ``self._zeropoint.to`` and the
    result is stored back on this instance. The same arguments are then passed
    to the parent class implementation, whose return value is returned as-is.
    """
    moved_zeropoint = self._zeropoint.to(*args, **kwargs)
    self._zeropoint = moved_zeropoint
    return super().to(*args, **kwargs)
25 changes: 24 additions & 1 deletion test/nn/test_qlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from helpers import assert_similar, random_qtensor

from quanto import Calibration, QTensor, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8
from quanto import Calibration, QBitsTensor, QTensor, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8
from quanto.nn import QLinear


Expand All @@ -27,6 +27,29 @@ def _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, act
assert_similar(out, qout, atol=atol)


@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-int4", "w-int8"])
@pytest.mark.parametrize(
    "activations",
    [None, qfloat8_e5m2, qfloat8_e4m3fn],
    ids=["None", "a-float8-e5m2", "a-float8-e4m3"],
)
def test_move_qlinear(use_bias, weights, activations, device):
    """Check that moving a frozen QLinear moves its weight's inner tensors.

    A 32x32 ``torch.nn.Linear`` is converted with ``QLinear.from_module``,
    run once under ``Calibration()``, frozen, and moved with ``to(device)``.
    The weight's ``_data`` / ``_scale`` (and ``_zeropoint`` for a
    ``QBitsTensor``) must then report the requested device type.
    """
    qlinear = QLinear.from_module(
        torch.nn.Linear(32, 32, bias=use_bias),
        weights=weights,
        activations=activations,
    )
    # A single forward pass under Calibration() precedes freeze().
    calibration_inputs = random_qtensor((1, 32, 32))
    with torch.no_grad():
        with Calibration():
            qlinear(calibration_inputs)
    qlinear.freeze()
    qlinear.to(device)
    if isinstance(qlinear.weight, QTensor):
        expected_type = device.type
        assert qlinear.weight._data.device.type == expected_type
        assert qlinear.weight._scale.device.type == expected_type
    # Checked separately: only QBitsTensor weights carry a zero-point.
    if isinstance(qlinear.weight, QBitsTensor):
        assert qlinear.weight._zeropoint.device.type == device.type


@pytest.mark.parametrize("batch_size", [1, 10])
@pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
Expand Down
8 changes: 8 additions & 0 deletions test/tensor/test_qtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
from quanto import QTensor, absmax_scale, qfloat8_e4m3fn, qfloat8_e5m2, qint8, qint16, qint32


def test_qtensor_move(device):
    """Check that ``to(device)`` leaves the data and scale on ``device``.

    A random float32-based quantized tensor of shape (2, 4, 8) is built with
    ``random_qtensor`` and moved; both ``_data`` and ``_scale`` of the result
    must report the requested device type.
    """
    moved = random_qtensor((2, 4, 8), dtype=torch.float32).to(device)
    expected_type = device.type
    assert moved._data.device.type == expected_type
    assert moved._scale.device.type == expected_type
Comment on lines +11 to +16
Copy link
Member Author

@SunMarc SunMarc Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following works even before this PR. This is why you never had this specific issue.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



@pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("qtype", [qint8], ids=["qint8"])
Expand Down
Loading