diff --git a/optimum/quanto/tensor/qbits/qbits.py b/optimum/quanto/tensor/qbits/qbits.py
index bb67e91f..f3c7326f 100644
--- a/optimum/quanto/tensor/qbits/qbits.py
+++ b/optimum/quanto/tensor/qbits/qbits.py
@@ -267,6 +267,9 @@ def qlinear(input, other, bias=None):
                 return QuantizedLinearFunction.apply(input, other, bias)
 
             return qlinear(*args, **kwargs)
+        elif func is torch.equal:
+            input, other = args
+            return input.equal(other)
         # Defer to operations dispatcher
         with torch._C.DisableTorchFunctionSubclass():
             return func(*args, **kwargs)
diff --git a/optimum/quanto/tensor/qtensor.py b/optimum/quanto/tensor/qtensor.py
index 22947cbf..77a9c546 100644
--- a/optimum/quanto/tensor/qtensor.py
+++ b/optimum/quanto/tensor/qtensor.py
@@ -63,3 +63,18 @@ def qtype(self):
 
     def numpy(self):
         return self.dequantize().cpu().numpy()
+
+    def equal(self, other):
+        if type(self) is not type(other):
+            return False
+        self_tensors, self_meta = self.__tensor_flatten__()
+        _, other_meta = other.__tensor_flatten__()
+        for name, value in self_meta.items():
+            if other_meta[name] != value:
+                return False
+        for name in self_tensors:
+            self_t = getattr(self, name)
+            other_t = getattr(other, name)
+            if not torch.equal(self_t, other_t):
+                return False
+        return True
diff --git a/optimum/quanto/tensor/weights/qbytes.py b/optimum/quanto/tensor/weights/qbytes.py
index 7eb3aa31..9e216f18 100644
--- a/optimum/quanto/tensor/weights/qbytes.py
+++ b/optimum/quanto/tensor/weights/qbytes.py
@@ -125,6 +125,9 @@ def qlinear(input, other, bias=None):
                 return WeightQBytesLinearFunction.apply(input, other, bias)
 
             return qlinear(*args, **kwargs)
+        elif func is torch.equal:
+            input, other = args
+            return input.equal(other)
         # Defer to operations dispatcher
         with torch._C.DisableTorchFunctionSubclass():
             return func(*args, **kwargs)
diff --git a/test/models/test_quantized_model_for_causal_lm.py b/test/models/test_quantized_model_for_causal_lm.py
index 401af9eb..302eb3cb 100644
--- a/test/models/test_quantized_model_for_causal_lm.py
+++ b/test/models/test_quantized_model_for_causal_lm.py
@@ -55,8 +55,7 @@ def compare_models(a_model, b_model):
         if isinstance(b_m, QModuleMixin):
             assert isinstance(a_m, QModuleMixin)
         if isinstance(a_m, QModuleMixin):
-            assert torch.equal(a_m.weight._data, b_m.weight._data)
-            assert torch.equal(a_m.weight._scale, b_m.weight._scale)
+            assert torch.equal(a_m.weight, b_m.weight)
         for (a_p_name, a_p), (b_p_name, b_p) in zip(a_m.named_parameters(), b_m.named_parameters()):
             assert a_p_name == b_p_name
             assert isinstance(a_p, torch.Tensor)
diff --git a/test/models/test_quantized_model_for_pixart.py b/test/models/test_quantized_model_for_pixart.py
index dec221b0..c26b968d 100644
--- a/test/models/test_quantized_model_for_pixart.py
+++ b/test/models/test_quantized_model_for_pixart.py
@@ -46,8 +46,7 @@ def compare_models(a_model, b_model):
         if isinstance(b_m, QModuleMixin):
             assert isinstance(a_m, QModuleMixin)
         if isinstance(a_m, QModuleMixin):
-            assert torch.equal(a_m.weight._data, b_m.weight._data)
-            assert torch.equal(a_m.weight._scale, b_m.weight._scale)
+            assert torch.equal(a_m.weight, b_m.weight)
         for (a_p_name, a_p), (b_p_name, b_p) in zip(a_m.named_parameters(), b_m.named_parameters()):
             assert a_p_name == b_p_name
             assert isinstance(a_p, torch.Tensor)
diff --git a/test/quantize/test_requantize.py b/test/quantize/test_requantize.py
index 95dfce6c..3f0f20b3 100644
--- a/test/quantize/test_requantize.py
+++ b/test/quantize/test_requantize.py
@@ -44,7 +44,7 @@ def save_and_reload_state_dict(state_dict, serialization):
     ids=["small", "large"],
 )
 @pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
 @pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
 @pytest.mark.parametrize("activations", [None, qint8], ids=["a-none", "a-qint8"])
 def test_requantize_serialized_model(
@@ -64,11 +64,9 @@ def test_requantize_serialized_model(
     for name, module in model.named_modules():
         if isinstance(module, QModuleMixin):
             module_reloaded = getattr(model_reloaded, name)
-            assert module_reloaded.weight.qtype == module.weight.qtype
+            assert torch.equal(module_reloaded.weight, module.weight)
             assert module_reloaded.weight_qtype == module.weight_qtype
             assert module_reloaded.activation_qtype == module.activation_qtype
-            assert torch.equal(module_reloaded.weight._data, module.weight._data)
-            assert torch.equal(module_reloaded.weight._scale, module.weight._scale)
             assert torch.equal(module_reloaded.input_scale, module.input_scale)
             assert torch.equal(module_reloaded.output_scale, module.output_scale)
 
diff --git a/test/tensor/qbits/test_qbits_instantiate.py b/test/tensor/qbits/test_qbits_instantiate.py
new file mode 100644
index 00000000..39ed7ceb
--- /dev/null
+++ b/test/tensor/qbits/test_qbits_instantiate.py
@@ -0,0 +1,59 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+import torch
+
+from optimum.quanto import QBitsTensor, qint2, qint4
+
+
+def random_data_scale_shift(input_shape, dtype, qtype, axis, group_size):
+    out_features, in_features = input_shape
+    n_groups = in_features * out_features // group_size
+    data_shape = (n_groups, group_size) if axis == 0 else (group_size, n_groups)
+    scale_shape = (n_groups, 1) if axis == 0 else (1, n_groups)
+    min_value = -(2 ** (qtype.bits - 1))
+    max_value = 2 ** (qtype.bits - 1) - 1
+    data = torch.randint(max_value - min_value + 1, data_shape, dtype=torch.uint8)
+    scale = torch.full(scale_shape, 1.0 / -min_value, dtype=dtype)
+    shift = torch.ones(scale_shape, dtype=dtype)
+    return data, scale, shift
+
+
+@pytest.mark.parametrize("input_shape, group_size", [[(32, 32), 16], [(1024, 1024), 128]])
+@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
+@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"])
+def test_qbitstensor_instantiate(input_shape, dtype, qtype, axis, group_size, device):
+    data, scale, shift = random_data_scale_shift(input_shape, dtype, qtype, axis, group_size)
+    input_stride = torch.ones(input_shape).stride()
+    qa = QBitsTensor(qtype, axis, group_size, input_shape, input_stride, data, scale=scale, shift=shift).to(device)
+    assert torch.max(torch.abs(qa.dequantize())) <= 1
+    assert qa.dtype == dtype
+    assert qa.qtype == qtype
+    assert qa.shape == input_shape
+
+
+@pytest.mark.parametrize("input_shape, group_size", [[(32, 32), 16], [(1024, 1024), 128]])
+@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
+@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"])
+def test_qbitstensor_equal(input_shape, dtype, qtype, axis, group_size, device):
+    data, scale, shift = random_data_scale_shift(input_shape, dtype, qtype, axis, group_size)
+    qa = QBitsTensor(qtype, axis, group_size, data.size(), data.stride(), data, scale=scale, shift=shift).to(device)
+    qb = QBitsTensor(
+        qtype, axis, group_size, data.size(), data.stride(), data.clone(), scale=scale.clone(), shift=shift.clone()
+    ).to(device)
+    assert qa.equal(qb)
diff --git a/test/tensor/qbits/test_qbitstensor_dispatch.py b/test/tensor/qbits/test_qbitstensor_dispatch.py
new file mode 100644
index 00000000..c82be0dc
--- /dev/null
+++ b/test/tensor/qbits/test_qbitstensor_dispatch.py
@@ -0,0 +1,57 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from helpers import assert_similar, random_qweight, random_tensor
+
+from optimum.quanto import QBitsTensor, qint2, qint4, quantize_weight
+
+
+@pytest.mark.parametrize("group_size", [None, 128], ids=["channel-wise", "group-wise"])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
+def test_qbitstensor_to_device(dtype, group_size, device):
+    qa = random_qweight((256, 512), dtype=dtype, qtype=qint4, group_size=group_size, device="cpu")
+    # Keep a copy of the dequantized Tensor as a reference
+    dqa = qa.dequantize()
+    # Move to the target device
+    moved_qa = qa.to(device)
+    assert isinstance(moved_qa, QBitsTensor)
+    assert moved_qa.device.type == device.type
+    assert moved_qa._data.device.type == device.type
+    assert moved_qa._scale.device.type == device.type
+    assert moved_qa._shift.device.type == device.type
+    moved_dqa = moved_qa.dequantize().to("cpu")
+    if type(moved_qa) is not QBitsTensor:
+        # Since we use an optimized packing, the order of operations during
+        # dequantization might differ, but the moved dequantized Tensor should be nearly identical
+        assert_similar(moved_dqa, dqa)
+    else:
+        assert torch.equal(moved_dqa, dqa)
+
+
+def test_qbitstensor_detach():
+    qa = random_qweight((32, 32), qtype=qint4)
+    dqa = qa.detach()
+    assert isinstance(dqa, QBitsTensor)
+
+
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
+@pytest.mark.parametrize("qtype", [qint2, qint4])
+@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
+def test_qbitstensor_equal(dtype, qtype, axis, device):
+    a = random_tensor((1024, 1024), dtype=dtype, device=device)
+    qa1 = quantize_weight(a, qtype=qtype, axis=axis, group_size=128)
+    qa2 = quantize_weight(a, qtype=qtype, axis=axis, group_size=128)
+    assert torch.equal(qa1, qa2)
diff --git a/test/tensor/test_qbitstensor.py b/test/tensor/test_qbitstensor.py
index 70e17058..137ca87b 100644
--- a/test/tensor/test_qbitstensor.py
+++ b/test/tensor/test_qbitstensor.py
@@ -16,7 +16,7 @@
 
 import pytest
 import torch
-from helpers import assert_similar, random_qweight, random_tensor
+from helpers import random_qweight, random_tensor
 
 from optimum.quanto import QBitsTensor, qint2, qint4, quantize_weight
 
@@ -58,31 +58,3 @@ def test_qbitstensor_backward(qtype, axis, group_size, device):
     # Backpropagate gradient to the inner float weights
     qweight.dequantize().backward(gradient)
     assert torch.equal(weight.grad, gradient)
-
-
-@pytest.mark.parametrize("group_size", [None, 128], ids=["channel-wise", "group-wise"])
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
-def test_to_device(dtype, group_size, device):
-    qa = random_qweight((256, 512), dtype=dtype, qtype=qint4, group_size=group_size, device="cpu")
-    # Keep a copy of the dequantized Tensor as a reference
-    dqa = qa.dequantize()
-    # Move to the target device
-    moved_qa = qa.to(device)
-    assert isinstance(moved_qa, QBitsTensor)
-    assert moved_qa.device.type == device.type
-    assert moved_qa._data.device.type == device.type
-    assert moved_qa._scale.device.type == device.type
-    assert moved_qa._shift.device.type == device.type
-    moved_dqa = moved_qa.dequantize().to("cpu")
-    if type(moved_qa) is not QBitsTensor:
-        # Since we use an optimized packing, the order of operations during
-        # dequantization might differ, but the moved dequantized Tensor should be nearly identical
-        assert_similar(moved_dqa, dqa)
-    else:
-        assert torch.equal(moved_dqa, dqa)
-
-
-def test_detach():
-    qa = random_qweight((32, 32), qtype=qint4)
-    dqa = qa.detach()
-    assert isinstance(dqa, QBitsTensor)
diff --git a/test/tensor/weights/test_weights_dispatch.py b/test/tensor/weights/test_weights_dispatch.py
index 3ac05ab5..335a5d80 100644
--- a/test/tensor/weights/test_weights_dispatch.py
+++ b/test/tensor/weights/test_weights_dispatch.py
@@ -14,6 +14,16 @@ def test_weight_qytes_tensor_to_device(device):
     assert qa._scale.device.type == device.type
 
 
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
+@pytest.mark.parametrize("qtype", [qint8])
+@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
+def test_weight_qbytes_tensor_equal(dtype, qtype, axis, device):
+    a = random_tensor((32, 32), dtype=dtype, device=device)
+    qa1 = quantize_weight(a, qtype=qtype, axis=axis)
+    qa2 = quantize_weight(a, qtype=qtype, axis=axis)
+    assert torch.equal(qa1, qa2)
+
+
 @pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
 @pytest.mark.parametrize("qtype", [qint8])
 def test_weight_qbytes_tensor_transpose_contiguous(axis, qtype, device):
diff --git a/test/tensor/weights/test_weights_instantiate.py b/test/tensor/weights/test_weights_instantiate.py
index 2e26c139..d795e064 100644
--- a/test/tensor/weights/test_weights_instantiate.py
+++ b/test/tensor/weights/test_weights_instantiate.py
@@ -19,23 +19,37 @@
 from optimum.quanto import WeightQBytesTensor, qfloat8, qint8
 
 
-@pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
-@pytest.mark.parametrize("qtype", [qint8, qfloat8], ids=["qint8", "qfloat8"])
-def test_qbytestensor_instantiate(input_shape, dtype, qtype, device):
+def random_data_scale(input_shape, dtype, qtype):
     if qtype.is_floating_point:
-        if device.type == "mps":
-            pytest.skip("float8 types are not supported on MPS device")
         min_value = torch.finfo(qtype.dtype).min
         max_value = torch.finfo(qtype.dtype).max
         data = (torch.rand(input_shape) * max_value + min_value).to(qtype.dtype)
     else:
         max_value = torch.iinfo(qtype.dtype).max
         data = torch.randint(-max_value, max_value, input_shape, dtype=qtype.dtype)
-    qa = WeightQBytesTensor(
-        qtype, None, data.size(), data.stride(), data, scale=torch.tensor(1.0 / max_value, dtype=dtype)
-    ).to(device)
+    scale = torch.tensor(1.0 / max_value, dtype=dtype)
+    return data, scale
+
+
+@pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
+@pytest.mark.parametrize("qtype", [qint8, qfloat8], ids=["qint8", "qfloat8"])
+def test_qbytestensor_instantiate(input_shape, dtype, qtype, device):
+    if qtype.is_floating_point and device.type == "mps":
+        pytest.skip("float8 types are not supported on MPS device")
+    data, scale = random_data_scale(input_shape, dtype, qtype)
+    qa = WeightQBytesTensor(qtype, None, data.size(), data.stride(), data, scale=scale).to(device)
     assert torch.max(torch.abs(qa.dequantize())) <= 1
     assert qa.dtype == dtype
     assert qa.qtype == qtype
     assert qa.shape == input_shape
+
+
+@pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
+@pytest.mark.parametrize("qtype", [qint8], ids=["qint8"])
+def test_qbytestensor_equal(input_shape, dtype, qtype, device):
+    data, scale = random_data_scale(input_shape, dtype, qtype)
+    qa = WeightQBytesTensor(qtype, None, data.size(), data.stride(), data, scale=scale).to(device)
+    qb = WeightQBytesTensor(qtype, None, data.size(), data.stride(), data.clone(), scale=scale).to(device)
+    assert qa.equal(qb)