Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Tensor.equal and torch.equal for QTensor #294

Merged
merged 7 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions optimum/quanto/tensor/qbits/qbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ def qlinear(input, other, bias=None):
return QuantizedLinearFunction.apply(input, other, bias)

return qlinear(*args, **kwargs)
elif func is torch.equal:
input, other = args
return input.equal(other)
# Defer to operations dispatcher
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
Expand Down
15 changes: 15 additions & 0 deletions optimum/quanto/tensor/qtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,18 @@ def qtype(self):

def numpy(self):
return self.dequantize().cpu().numpy()

def equal(self, other):
if type(self) is not type(other):
return False
self_tensors, self_meta = self.__tensor_flatten__()
_, other_meta = other.__tensor_flatten__()
for name, value in self_meta.items():
if other_meta[name] != value:
return False
for name in self_tensors:
self_t = getattr(self, name)
other_t = getattr(other, name)
if not torch.equal(self_t, other_t):
return False
return True
3 changes: 3 additions & 0 deletions optimum/quanto/tensor/weights/qbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def qlinear(input, other, bias=None):
return WeightQBytesLinearFunction.apply(input, other, bias)

return qlinear(*args, **kwargs)
elif func is torch.equal:
input, other = args
return input.equal(other)
# Defer to operations dispatcher
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
Expand Down
3 changes: 1 addition & 2 deletions test/models/test_quantized_model_for_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def compare_models(a_model, b_model):
if isinstance(b_m, QModuleMixin):
assert isinstance(a_m, QModuleMixin)
if isinstance(a_m, QModuleMixin):
assert torch.equal(a_m.weight._data, b_m.weight._data)
assert torch.equal(a_m.weight._scale, b_m.weight._scale)
assert torch.equal(a_m.weight, b_m.weight)
for (a_p_name, a_p), (b_p_name, b_p) in zip(a_m.named_parameters(), b_m.named_parameters()):
assert a_p_name == b_p_name
assert isinstance(a_p, torch.Tensor)
Expand Down
3 changes: 1 addition & 2 deletions test/models/test_quantized_model_for_pixart.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ def compare_models(a_model, b_model):
if isinstance(b_m, QModuleMixin):
assert isinstance(a_m, QModuleMixin)
if isinstance(a_m, QModuleMixin):
assert torch.equal(a_m.weight._data, b_m.weight._data)
assert torch.equal(a_m.weight._scale, b_m.weight._scale)
assert torch.equal(a_m.weight, b_m.weight)
for (a_p_name, a_p), (b_p_name, b_p) in zip(a_m.named_parameters(), b_m.named_parameters()):
assert a_p_name == b_p_name
assert isinstance(a_p, torch.Tensor)
Expand Down
6 changes: 2 additions & 4 deletions test/quantize/test_requantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def save_and_reload_state_dict(state_dict, serialization):
ids=["small", "large"],
)
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
@pytest.mark.parametrize("activations", [None, qint8], ids=["a-none", "a-qint8"])
def test_requantize_serialized_model(
Expand All @@ -64,11 +64,9 @@ def test_requantize_serialized_model(
for name, module in model.named_modules():
if isinstance(module, QModuleMixin):
module_reloaded = getattr(model_reloaded, name)
assert module_reloaded.weight.qtype == module.weight.qtype
assert torch.equal(module_reloaded.weight, module.weight)
assert module_reloaded.weight_qtype == module.weight_qtype
assert module_reloaded.activation_qtype == module.activation_qtype
assert torch.equal(module_reloaded.weight._data, module.weight._data)
assert torch.equal(module_reloaded.weight._scale, module.weight._scale)
assert torch.equal(module_reloaded.input_scale, module.input_scale)
assert torch.equal(module_reloaded.output_scale, module.output_scale)

Expand Down
59 changes: 59 additions & 0 deletions test/tensor/qbits/test_qbits_instantiate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
import torch

from optimum.quanto import QBitsTensor, qint2, qint4


def random_data_scale_shift(input_shape, dtype, qtype, axis, group_size):
out_features, in_features = input_shape
n_groups = in_features * out_features // group_size
data_shape = (n_groups, group_size) if axis == 0 else (group_size, n_groups)
scale_shape = (n_groups, 1) if axis == 0 else (1, n_groups)
min_value = -(2 ** (qtype.bits - 1))
max_value = 2 ** (qtype.bits - 1) - 1
data = torch.randint(max_value - min_value + 1, data_shape, dtype=torch.uint8)
scale = torch.full(scale_shape, 1.0 / -min_value, dtype=dtype)
shift = torch.ones(scale_shape, dtype=dtype)
return data, scale, shift


@pytest.mark.parametrize("input_shape, group_size", [[(32, 32), 16], [(1024, 1024), 128]])
@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"])
def test_qbitstensor_instantiate(input_shape, dtype, qtype, axis, group_size, device):
data, scale, shift = random_data_scale_shift(input_shape, dtype, qtype, axis, group_size)
input_stride = torch.ones(input_shape).stride()
qa = QBitsTensor(qtype, axis, group_size, input_shape, input_stride, data, scale=scale, shift=shift).to(device)
assert torch.max(torch.abs(qa.dequantize())) <= 1
assert qa.dtype == dtype
assert qa.qtype == qtype
assert qa.shape == input_shape


@pytest.mark.parametrize("input_shape, group_size", [[(32, 32), 16], [(1024, 1024), 128]])
@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"])
def test_qbitstensor_equal(input_shape, dtype, qtype, axis, group_size, device):
data, scale, shift = random_data_scale_shift(input_shape, dtype, qtype, axis, group_size)
qa = QBitsTensor(qtype, axis, group_size, data.size(), data.stride(), data, scale=scale, shift=shift).to(device)
qb = QBitsTensor(
qtype, axis, group_size, data.size(), data.stride(), data.clone(), scale=scale.clone(), shift=shift.clone()
).to(device)
assert qa.equal(qb)
57 changes: 57 additions & 0 deletions test/tensor/qbits/test_qbitstensor_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from helpers import assert_similar, random_qweight, random_tensor

from optimum.quanto import QBitsTensor, qint2, qint4, quantize_weight


@pytest.mark.parametrize("group_size", [None, 128], ids=["channel-wise", "group-wise"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
def test_qbitstensor_to_device(dtype, group_size, device):
qa = random_qweight((256, 512), dtype=dtype, qtype=qint4, group_size=group_size, device="cpu")
# Keep a copy of the dequantized Tensor as a reference
dqa = qa.dequantize()
# Move to the target device
moved_qa = qa.to(device)
assert isinstance(moved_qa, QBitsTensor)
assert moved_qa.device.type == device.type
assert moved_qa._data.device.type == device.type
assert moved_qa._scale.device.type == device.type
assert moved_qa._shift.device.type == device.type
moved_dqa = moved_qa.dequantize().to("cpu")
if type(moved_qa) is not QBitsTensor:
# Since we use an optimized packing, the order of operations during
# dequantization might differ, but the moved dequantized Tensor should be nearly identical
assert_similar(moved_dqa, dqa)
else:
assert torch.equal(moved_dqa, dqa)


def test_qbitstensor_detach():
qa = random_qweight((32, 32), qtype=qint4)
dqa = qa.detach()
assert isinstance(dqa, QBitsTensor)


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
@pytest.mark.parametrize("qtype", [qint2, qint4])
@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
def test_qbitstensor_equal(dtype, qtype, axis, device):
a = random_tensor((1024, 1024), dtype=dtype, device=device)
qa1 = quantize_weight(a, qtype=qtype, axis=axis, group_size=128)
qa2 = quantize_weight(a, qtype=qtype, axis=axis, group_size=128)
assert torch.equal(qa1, qa2)
30 changes: 1 addition & 29 deletions test/tensor/test_qbitstensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pytest
import torch
from helpers import assert_similar, random_qweight, random_tensor
from helpers import random_qweight, random_tensor

from optimum.quanto import QBitsTensor, qint2, qint4, quantize_weight

Expand Down Expand Up @@ -58,31 +58,3 @@ def test_qbitstensor_backward(qtype, axis, group_size, device):
# Backpropagate gradient to the inner float weights
qweight.dequantize().backward(gradient)
assert torch.equal(weight.grad, gradient)


@pytest.mark.parametrize("group_size", [None, 128], ids=["channel-wise", "group-wise"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
def test_to_device(dtype, group_size, device):
qa = random_qweight((256, 512), dtype=dtype, qtype=qint4, group_size=group_size, device="cpu")
# Keep a copy of the dequantized Tensor as a reference
dqa = qa.dequantize()
# Move to the target device
moved_qa = qa.to(device)
assert isinstance(moved_qa, QBitsTensor)
assert moved_qa.device.type == device.type
assert moved_qa._data.device.type == device.type
assert moved_qa._scale.device.type == device.type
assert moved_qa._shift.device.type == device.type
moved_dqa = moved_qa.dequantize().to("cpu")
if type(moved_qa) is not QBitsTensor:
# Since we use an optimized packing, the order of operations during
# dequantization might differ, but the moved dequantized Tensor should be nearly identical
assert_similar(moved_dqa, dqa)
else:
assert torch.equal(moved_dqa, dqa)


def test_detach():
qa = random_qweight((32, 32), qtype=qint4)
dqa = qa.detach()
assert isinstance(dqa, QBitsTensor)
10 changes: 10 additions & 0 deletions test/tensor/weights/test_weights_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ def test_weight_qytes_tensor_to_device(device):
assert qa._scale.device.type == device.type


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
@pytest.mark.parametrize("qtype", [qint8])
@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
def test_weight_qbytes_tensor_equal(dtype, qtype, axis, device):
a = random_tensor((32, 32), dtype=dtype, device=device)
qa1 = quantize_weight(a, qtype=qtype, axis=axis)
qa2 = quantize_weight(a, qtype=qtype, axis=axis)
assert torch.equal(qa1, qa2)


@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
@pytest.mark.parametrize("qtype", [qint8])
def test_weight_qbytes_tensor_transpose_contiguous(axis, qtype, device):
Expand Down
32 changes: 23 additions & 9 deletions test/tensor/weights/test_weights_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,37 @@
from optimum.quanto import WeightQBytesTensor, qfloat8, qint8


@pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("qtype", [qint8, qfloat8], ids=["qint8", "qfloat8"])
def test_qbytestensor_instantiate(input_shape, dtype, qtype, device):
def random_data_scale(input_shape, dtype, qtype):
if qtype.is_floating_point:
if device.type == "mps":
pytest.skip("float8 types are not supported on MPS device")
min_value = torch.finfo(qtype.dtype).min
max_value = torch.finfo(qtype.dtype).max
data = (torch.rand(input_shape) * max_value + min_value).to(qtype.dtype)
else:
max_value = torch.iinfo(qtype.dtype).max
data = torch.randint(-max_value, max_value, input_shape, dtype=qtype.dtype)
qa = WeightQBytesTensor(
qtype, None, data.size(), data.stride(), data, scale=torch.tensor(1.0 / max_value, dtype=dtype)
).to(device)
scale = torch.tensor(1.0 / max_value, dtype=dtype)
return data, scale


@pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
@pytest.mark.parametrize("qtype", [qint8, qfloat8], ids=["qint8", "qfloat8"])
def test_qbytestensor_instantiate(input_shape, dtype, qtype, device):
if qtype.is_floating_point and device.type == "mps":
pytest.skip("float8 types are not supported on MPS device")
data, scale = random_data_scale(input_shape, dtype, qtype)
qa = WeightQBytesTensor(qtype, None, data.size(), data.stride(), data, scale=scale).to(device)
assert torch.max(torch.abs(qa.dequantize())) <= 1
assert qa.dtype == dtype
assert qa.qtype == qtype
assert qa.shape == input_shape


@pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
@pytest.mark.parametrize("qtype", [qint8], ids=["qint8"])
def test_qbytestensor_equal(input_shape, dtype, qtype, device):
data, scale = random_data_scale(input_shape, dtype, qtype)
qa = WeightQBytesTensor(qtype, None, data.size(), data.stride(), data, scale=scale).to(device)
qb = WeightQBytesTensor(qtype, None, data.size(), data.stride(), data.clone(), scale=scale).to(device)
assert qa.equal(qb)
Loading