refactor(QModuleMixin): accept optimizer parameter
dacorvo committed Mar 25, 2024
1 parent ce3ff35 commit ef95921
Showing 5 changed files with 60 additions and 17 deletions.
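
The change threads an optional optimizer argument from quantize() through QModuleMixin.from_module() and qcreate() down to the weight quantization calls. Below is a minimal usage sketch, not part of the commit: it only uses names imported in the updated test_quantize_mlp.py, and the two-layer Sequential model is a placeholder.

import torch

from quanto import AbsmaxOptimizer, MaxOptimizer, freeze, qint4, qint8, quantize

# Placeholder model; any torch.nn.Module containing supported layers (e.g. Linear) works.
model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))

# qint4 weights paired with MaxOptimizer, as exercised by test_quantize_mlp_weights_only_optimizers.
quantize(model, weights=qint4, activations=None, optimizer=MaxOptimizer())
freeze(model)

# qint8 weights pair with AbsmaxOptimizer; a mismatched pair raises ValueError
# (see test_quantize_mlp_wrong_optimizer below).
model8 = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
quantize(model8, weights=qint8, activations=None, optimizer=AbsmaxOptimizer())
freeze(model8)
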
7 changes: 5 additions & 2 deletions quanto/nn/qconv2d.py
@@ -2,7 +2,7 @@

 import torch

-from ..tensor import QTensor, qtype
+from ..tensor import Optimizer, QTensor, qtype
 from .qmodule import QModuleMixin, register_qmodule


@@ -12,7 +12,9 @@
 @register_qmodule(torch.nn.Conv2d)
 class QConv2d(QModuleMixin, torch.nn.Conv2d):
     @classmethod
-    def qcreate(cls, module, weights: qtype, activations: Optional[qtype] = None):
+    def qcreate(
+        cls, module, weights: qtype, activations: Optional[qtype] = None, optimizer: Optional[Optimizer] = None
+    ):
         return cls(
             in_channels=module.in_channels,
             out_channels=module.out_channels,
@@ -27,6 +29,7 @@ def qcreate(cls, module, weights: qtype, activations: Optional[qtype] = None):
             device=module.weight.device,
             weights=weights,
             activations=activations,
+            optimizer=optimizer,
         )

     def qforward(self, input: torch.Tensor) -> torch.Tensor:
11 changes: 9 additions & 2 deletions quanto/nn/qlayernorm.py
@@ -2,7 +2,7 @@

 import torch

-from ..tensor import qtype
+from ..tensor import Optimizer, qtype
 from .qmodule import QModuleMixin, register_qmodule


@@ -12,7 +12,13 @@
 @register_qmodule(torch.nn.LayerNorm)
 class QLayerNorm(QModuleMixin, torch.nn.LayerNorm):
     @classmethod
-    def qcreate(cls, module, weights: Optional[qtype] = None, activations: Optional[qtype] = None):
+    def qcreate(
+        cls,
+        module,
+        weights: Optional[qtype] = None,
+        activations: Optional[qtype] = None,
+        optimizer: Optional[Optimizer] = None,
+    ):
         if activations is None:
             return None
         return cls(
@@ -24,6 +30,7 @@ def qcreate(cls, module, weights: Optional[qtype] = None, activations: Optional[qtype] = None):
             device=module.weight.device,
             weights=None,  # We never quantize QLayerNorm weights
             activations=activations,
+            optimizer=None,  # We never quantize QLayerNorm weights
         )

     def qforward(self, input: torch.Tensor) -> torch.Tensor:
7 changes: 5 additions & 2 deletions quanto/nn/qlinear.py
@@ -2,7 +2,7 @@

 import torch

-from ..tensor import QTensor, qtype
+from ..tensor import Optimizer, QTensor, qtype
 from .qmodule import QModuleMixin, register_qmodule


@@ -12,7 +12,9 @@
 @register_qmodule(torch.nn.Linear)
 class QLinear(QModuleMixin, torch.nn.Linear):
     @classmethod
-    def qcreate(cls, module, weights: qtype, activations: Optional[qtype] = None):
+    def qcreate(
+        cls, module, weights: qtype, activations: Optional[qtype] = None, optimizer: Optional[Optimizer] = None
+    ):
         return cls(
             module.in_features,
             module.out_features,
@@ -21,6 +23,7 @@ def qcreate(cls, module, weights: qtype, activations: Optional[qtype] = None):
             device=module.weight.device,
             weights=weights,
             activations=activations,
+            optimizer=optimizer,
         )

     def qforward(self, input: torch.Tensor) -> torch.Tensor:
23 changes: 17 additions & 6 deletions quanto/nn/qmodule.py
@@ -4,7 +4,7 @@

 import torch

-from ..tensor import QBitsTensor, QTensor, absmax_scale, qint2, qint4, qtype, qtypes
+from ..tensor import Optimizer, QBitsTensor, QTensor, qint2, qint4, qtype, qtypes


 __all__ = ["QModuleMixin", "register_qmodule", "quantize_module"]
@@ -77,6 +77,7 @@ def __init__(
         *args,
         weights: Optional[Union[qtype, str]] = None,
         activations: Optional[Union[qtype, str]] = None,
+        optimizer: Optional[Optimizer] = None,
         **kwargs,
     ):
         # The tests below are meant to help people writing their own quantized Module class
@@ -103,6 +104,7 @@
                 group_size = group_size // 2
         self.weight_group_size = group_size
         self.activation_qtype = activations
+        self.optimizer = optimizer
         self.register_buffer("input_scale", torch.ones(()))
         self.register_buffer("output_scale", torch.ones(()))

@@ -183,9 +185,13 @@ def deserialize_tensor_subclass(t, state_dict, prefix):

     @classmethod
     def from_module(
-        cls, module: torch.nn.Module, weights: Optional[qtype] = None, activations: Optional[qtype] = None
+        cls,
+        module: torch.nn.Module,
+        weights: Optional[qtype] = None,
+        activations: Optional[qtype] = None,
+        optimizer: Optional[Optimizer] = None,
     ):
-        qmodule = cls.qcreate(module, weights, activations)
+        qmodule = cls.qcreate(module, weights, activations, optimizer)
         if qmodule is None:
             return None
         with torch.no_grad():
@@ -217,11 +223,16 @@ def qweight(self):
         # Quantize dynamically the weights per-axis
         if self.weight_qtype in (qint2, qint4):
             return QBitsTensor.quantize(
-                self.weight, qtype=self.weight_qtype, axis=0, group_size=self.weight_group_size
+                self.weight,
+                qtype=self.weight_qtype,
+                axis=0,
+                group_size=self.weight_group_size,
+                optimizer=self.optimizer,
             )
         elif isinstance(self.weight_qtype, qtype):
-            wscale = absmax_scale(self.weight, axis=0)
-            return QTensor.quantize(self.weight, qtype=self.weight_qtype, axis=0, group_size=None, scale=wscale)
+            return QTensor.quantize(
+                self.weight, qtype=self.weight_qtype, axis=0, group_size=None, optimizer=self.optimizer
+            )
         raise ValueError(f"Invalid quantized weights type {self.weight_qtype}")

     def qforward(self, input: torch.Tensor) -> torch.Tensor:
29 changes: 24 additions & 5 deletions test/model/test_quantize_mlp.py
@@ -6,12 +6,15 @@
 from helpers import assert_similar, get_device_memory, random_qtensor, random_tensor

 from quanto import (
+    AbsmaxOptimizer,
     Calibration,
+    MaxOptimizer,
     QLinear,
     QTensor,
     freeze,
     qfloat8_e4m3fn,
     qfloat8_e5m2,
+    qint4,
     qint8,
     quantize,
     safe_load,
@@ -48,10 +51,10 @@ def get_outputs(model, batch_size, input_features, device):
     return model(qinputs)


-def _test_quantize_mlp(weights, activations, frozen, device):
+def _test_quantize_mlp(weights, activations, optimizer, frozen, device):
     model = MLP(32, 10, 128).to(device)
     output = get_outputs(model, 1, 32, device)
-    quantize(model, weights=weights, activations=activations)
+    quantize(model, weights=weights, activations=activations, optimizer=optimizer)
     if frozen:
         freeze(model)
     check_mlp(model, frozen)
@@ -66,14 +69,14 @@ def _test_quantize_mlp(weights, activations, frozen, device):
 @pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
 @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
 def test_quantize_mlp_weights_only(weights, frozen, device):
-    _test_quantize_mlp(weights, None, frozen, device)
+    _test_quantize_mlp(weights, None, None, frozen, device)


 @pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
 @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
 @pytest.mark.skip_device("mps")
 def test_quantize_mlp_int8_activations(weights, frozen, device):
-    _test_quantize_mlp(weights, qint8, frozen, device)
+    _test_quantize_mlp(weights, qint8, None, frozen, device)


 @pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@@ -85,7 +88,7 @@ def test_quantize_mlp_int8_activations(weights, frozen, device):
 @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
 @pytest.mark.skip_device("mps")
 def test_quantize_mlp_float8_activations(weights, activations, frozen, device):
-    _test_quantize_mlp(weights, activations, frozen, device)
+    _test_quantize_mlp(weights, activations, None, frozen, device)


 def save_and_reload_state_dict(state_dict, serialization):
@@ -171,3 +174,19 @@ def test_quantized_mlp_device_memory(weights, dtype, weights_only, device):
     reloaded_memory = get_device_memory(device)
     # Device memory can be lower when reloading (less fragmentation ?)
     assert reloaded_memory <= quantized_memory
+
+
+@pytest.mark.parametrize(
+    "weights, optimizer", [[qint8, AbsmaxOptimizer()], [qint4, MaxOptimizer()]], ids=["w-qint8", "w-qint4"]
+)
+@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
+def test_quantize_mlp_weights_only_optimizers(weights, optimizer, frozen, device):
+    _test_quantize_mlp(weights, None, optimizer, frozen, device)
+
+
+@pytest.mark.parametrize(
+    "weights, optimizer", [[qint8, MaxOptimizer()], [qint4, AbsmaxOptimizer()]], ids=["w-qint8", "w-qint4"]
+)
+def test_quantize_mlp_wrong_optimizer(weights, optimizer, device):
+    with pytest.raises(ValueError):
+        _test_quantize_mlp(weights, None, optimizer, False, device)
