Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Tensor.to() #90

Merged
merged 5 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions quanto/tensor/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import copy
from typing import Optional

import torch
Expand Down Expand Up @@ -313,10 +314,23 @@ def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
@classmethod
def __torch_dispatch__(cls, op, types, args, kwargs=None):
if op.overloadpacket is torch.ops.aten.detach:
# Detach is required when copying and deserializing
t = args[0]
data = op(t._data)
scale = op(t._scale)
zeropoint = op(t._zeropoint)
return QBitsTensor(t._qtype, data, scale, zeropoint)
elif op.overloadpacket is torch.ops.aten._to_copy:
t = args[0]
# Copy scale
scale = op(t._scale, **kwargs)
# Move data and zeropoint, ignoring dtype (it only applies to scale)
data_kwargs = copy(kwargs)
data_kwargs["dtype"] = torch.uint8
data = op(t._data, **data_kwargs)
zeropoint_kwargs = copy(kwargs)
zeropoint_kwargs["dtype"] = torch.int8
zeropoint = op(t._data, **data_kwargs)
return QBitsTensor(t._qtype, data, scale, zeropoint)
args, kwargs = pytree.tree_map_only(QBitsTensor, lambda x: x.qtensor(), (args, kwargs or {}))
return op(*args, **kwargs)
6 changes: 6 additions & 0 deletions quanto/tensor/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def get_qtensor_func(func):
return _QTENSOR_FUNC_TABLE.get(func, None)


@register_qtensor_func([torch._has_compatible_shallow_copy_type])
def has_compatible_shallow_copy_type(func, input: torch.Tensor, from_: torch.Tensor):
# Prevent torch from trying to shallow copy one QTensor to another
return False


@register_qtensor_func(
[
torch.nn.functional.cross_entropy,
Expand Down
8 changes: 8 additions & 0 deletions quanto/tensor/packed.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ def __torch_dispatch__(cls, op, types, args, kwargs=None):
t = args[0]
data = op(t._data)
return PackedTensor(data, t._bits, t.size(), t.stride())
elif op.overloadpacket is torch.ops.aten._to_copy:
t = args[0]
dtype = kwargs.get("dtype", torch.uint8)
if dtype != torch.uint8:
raise ValueError(f"PackedTensor are torch.uint8 only and cannot be moved to {dtype}.")
# Move data
data = op(t._data, **kwargs)
return PackedTensor(data, t._bits, t.size(), t.stride())
args, kwargs = pytree.tree_map_only(PackedTensor, lambda x: x.unpack(), (args, kwargs or {}))
return op(*args, **kwargs)

Expand Down
7 changes: 6 additions & 1 deletion test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from packaging import version

from quanto import QTensor, absmax_scale, qint8
from quanto import QBitsTensor, QTensor, absmax_scale, qint4, qint8


def torch_min_version(v):
Expand Down Expand Up @@ -39,6 +39,11 @@ def random_qtensor(shape, qtype=qint8, dtype=torch.float32, axis=None):
return QTensor.quantize(t, qtype=qtype, scale=scale)


def random_qbitstensor(shape, qtype=qint4, dtype=torch.float32, axis=None):
t = random_tensor(shape, dtype)
return QBitsTensor.quantize(t, qtype=qtype, axis=axis)


def q_assert_close(x: torch.Tensor, xq: QTensor, atol: float = None, rtol: float = None):
# Please refer to that discussion for default rtol values based on the float type:
# https://scicomp.stackexchange.com/questions/43111/float-equality-tolerance-for-single-and-half-precision
Expand Down
15 changes: 14 additions & 1 deletion test/nn/test_qlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from helpers import assert_similar, random_qtensor

from quanto import Calibration, QTensor, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8
from quanto import Calibration, QBitsTensor, QTensor, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8
from quanto.nn import QLinear


Expand Down Expand Up @@ -113,3 +113,16 @@ def test_qlinear_gradient(tokens, embeddings, activations, weights, device):
assert torch.allclose(qlinear.bias.grad, bias_gradient)
weight_gradient = torch.matmul(gradient.squeeze().t(), qinputs.dequantize().squeeze())
assert torch.allclose(qlinear.weight.grad, weight_gradient)


@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-int4", "w-int8"])
def test_move_qlinear(use_bias, weights, device):
linear = torch.nn.Linear(32, 32, bias=use_bias)
qlinear = QLinear.from_module(linear, weights=weights)
qlinear.freeze()
qlinear.to(device)
assert qlinear.weight._data.device.type == device.type
assert qlinear.weight._scale.device.type == device.type
if isinstance(qlinear.weight, QBitsTensor):
assert qlinear.weight._zeropoint.device.type == device.type
19 changes: 19 additions & 0 deletions test/tensor/ops/test_qbitstensor_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from helpers import random_qbitstensor

from quanto import QBitsTensor, qint4


def test_to_device(device):
qa = random_qbitstensor((32, 32), qtype=qint4)
qa = qa.to(device)
assert isinstance(qa, QBitsTensor)
assert qa.device.type == device.type
assert qa._data.device.type == device.type
assert qa._scale.device.type == device.type
assert qa._zeropoint.device.type == device.type


def test_detach():
qa = random_qbitstensor((32, 32), qtype=qint4)
dqa = qa.detach()
assert isinstance(dqa, QBitsTensor)
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ def test_to_device(device):
qa = qa.to(device)
assert isinstance(qa, QTensor)
assert qa.device.type == device.type
assert qa._data.device.type == device.type
assert qa._scale.device.type == device.type


@pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)])
Expand Down
Loading