
Commit 98eb2fa

[StaticQuant] add a linear observer class and test
stack-info: PR: #807, branch: drisspg/stack/8
1 parent de0b7fd commit 98eb2fa

File tree

5 files changed: +361 -11 lines changed

ruff.toml

Lines changed: 2 additions & 0 deletions
@@ -8,4 +8,6 @@ include = [
     "torchao/dtypes/nf4tensor.py",
     "test/dtypes/test_nf4.py",
     "torchao/float8/float8_tensor.py",
+    "torchao/quantization/linear_observer_tensor.py",
+    "test/quantization/test_observer.py",
 ]

test/quantization/test_observer.py

Lines changed: 108 additions & 4 deletions
@@ -1,5 +1,6 @@
 import re
 import torch
+import torch.nn as nn
 from torch.testing._internal.common_utils import TestCase
 from torchao.quantization.observer import (
     AffineQuantizedMinMaxObserver,
@@ -9,13 +10,23 @@
 from torchao.quantization.quant_primitives import (
     MappingType,
 )
+from torchao.quantization.linear_observer_tensor import (
+    insert_observers_,
+)
+from torch.testing._internal import common_utils
 import unittest
+
 # NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
 from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
 
+
 class TestQuantFlow(TestCase):
     def _test_obs_helper(self, obs1, obs2):
-        example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)]
+        example_inputs = [
+            torch.randn(10, 2048),
+            torch.randn(10, 2048),
+            torch.randn(10, 2048),
+        ]
         for example_input in example_inputs:
             obs1(example_input)
             obs2(example_input)
@@ -26,13 +37,29 @@ def _test_obs_helper(self, obs1, obs2):
         self.assertTrue(torch.allclose(zero_point1, zero_point2))
 
     def test_min_max_per_tensor_affine(self):
-        obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
+        obs = AffineQuantizedMinMaxObserver(
+            MappingType.ASYMMETRIC,
+            torch.uint8,
+            granularity_type=PerTensor(),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+        )
         ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
         self._test_obs_helper(obs, ref_obs)
 
     def test_min_max_per_channel_affine(self):
-        obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
-        ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
+        obs = AffineQuantizedMinMaxObserver(
+            MappingType.ASYMMETRIC,
+            torch.uint8,
+            granularity_type=PerAxis(axis=0),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+        )
+        ref_obs = PerChannelMinMaxObserver(
+            dtype=torch.uint8, qscheme=torch.per_channel_affine
+        )
         self._test_obs_helper(obs, ref_obs)
 
     def test_block_size_calc_success(self):
@@ -109,5 +136,82 @@ def test_block_size_row_errors(self):
             obs(example_input)
 
 
+class TestLinearObserver(TestCase):
+    @common_utils.parametrize("observe_weight", [True, False])
+    def test_linear_observer_tensor(self, observe_weight: bool):
+        # Create a simple linear layer
+        in_features, out_features = 10, 5
+        linear = nn.Linear(in_features, out_features)
+
+        # Create observers
+        input_observer = AffineQuantizedMinMaxObserver(
+            MappingType.SYMMETRIC,
+            torch.float8_e4m3fn,
+            granularity_type=PerTensor(),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+            zero_point_domain=None,
+        )
+        if observe_weight:
+            weight_observer = AffineQuantizedMinMaxObserver(
+                MappingType.SYMMETRIC,
+                torch.float8_e4m3fn,
+                granularity_type=PerTensor(),
+                eps=torch.finfo(torch.float32).eps,
+                scale_dtype=torch.float,
+                zero_point_dtype=torch.int,
+                zero_point_domain=None,
+            )
+        else:
+            weight_observer = None
+
+        # Wrap the weight with LinearActivationWeightObservedTensor
+        insert_observers_(linear, input_observer, weight_observer)
+
+        # Create some example inputs
+        example_inputs = [torch.randn(5, in_features) for _ in range(3)]
+        max_val = 42.1234
+        min_val = -39.760
+        big_tensor = torch.full((6, in_features), max_val)
+        small_tensor = torch.full((40, in_features), min_val)
+        example_inputs.extend([big_tensor, small_tensor])
+
+        # Run forward passes
+        for example_input in example_inputs:
+            _ = linear(example_input)
+
+        input_observer = linear.weight.input_observer
+
+        # Check that the observers have recorded statistics
+        assert input_observer.min_val == min_val
+        assert input_observer.max_val == max_val
+
+        # Calculate qparams and ensure they're not None
+        input_scale, input_zero_point = input_observer.calculate_qparams()
+
+        max_fp8 = torch.finfo(torch.float8_e4m3fn).max
+        self.assertEqual(
+            input_scale.item(),
+            max_val / max_fp8,
+        )
+        self.assertIsNotNone(input_zero_point)
+
+        if observe_weight:
+            weight_observer = linear.weight.weight_observer
+            weight_scale, weight_zero_point = weight_observer.calculate_qparams()
+            torch.testing.assert_close(
+                weight_scale,
+                torch.max(linear.weight.original_weight_tensor) / max_fp8,
+                atol=5e-5,
+                rtol=0.0,
+            )
+            self.assertIsNotNone(weight_zero_point)
+        else:
+            self.assertIsNone(linear.weight.weight_observer)
+
+
+common_utils.instantiate_parametrized_tests(TestLinearObserver)
+
 if __name__ == "__main__":
     unittest.main()
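
The float8 check in `test_linear_observer_tensor` can be made concrete with a small sketch (not part of the commit): for symmetric per-tensor quantization, assuming the observer maps the largest observed magnitude onto the largest representable `float8_e4m3fn` value, the expected scale follows directly from the calibration range used above.

```python
import torch

# Calibration range used by the test above (assumed symmetric mapping).
min_val, max_val = -39.760, 42.1234

# Largest finite value representable in float8_e4m3fn (448.0).
max_fp8 = torch.finfo(torch.float8_e4m3fn).max

# Symmetric per-tensor scale: largest observed magnitude over largest fp8 value.
amax = max(abs(min_val), abs(max_val))
expected_scale = amax / max_fp8  # ~0.0940, what the test asserts for input_scale

print(expected_scale)
```
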
torchao/quantization/linear_observer_tensor.py (new file)

Lines changed: 236 additions & 0 deletions
@@ -0,0 +1,236 @@
import torch
import torch.nn as nn
from typing import Callable, Optional, Dict
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import (
    TorchAOBaseTensor,
    TORCH_VERSION_AT_LEAST_2_5,
)

from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _is_linear,
)
from torchao.quantization.observer import AffineQuantizedObserverBase

__all__ = [
    "LinearActivationWeightObservedTensor",
    "insert_observers_",
]

aten = torch.ops.aten
Tensor = torch.Tensor


class LinearActivationWeightObservedTensor(TorchAOBaseTensor):
    """
    This subclass of Tensor is used in conjunction with a static calibration flow.
    The flow is broken up into 3 parts:
    1. Insert the LinearActivationWeightObservedTensor subclass into the model's nn.Linear layers
    2. Run the model with a calibration dataset; the observer will record the min/max of the input and weight
    3. quantize_ the model to static using the statistics recorded by the observer

    This subclass wraps the original weight tensor on the nn.Linear layer. When forward is called, the observer
    will first calculate statistics on BOTH the input and weight, and then run the linear op.
    """

    original_weight_tensor: torch.Tensor
    input_observer: Optional[AffineQuantizedObserverBase]
    weight_observer: Optional[AffineQuantizedObserverBase]

    def __new__(
        cls,
        original_weight_tensor: torch.Tensor,
        input_observer: Optional[AffineQuantizedObserverBase] = None,
        weight_observer: Optional[AffineQuantizedObserverBase] = None,
    ):
        kwargs = {}
        dtype = original_weight_tensor.dtype
        kwargs["dtype"] = dtype
        kwargs["requires_grad"] = False
        kwargs["device"] = original_weight_tensor.device
        shape = original_weight_tensor.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        original_weight_tensor: torch.Tensor,
        input_observer: Optional[AffineQuantizedObserverBase] = None,
        weight_observer: Optional[AffineQuantizedObserverBase] = None,
    ):
        self.original_weight_tensor = original_weight_tensor
        self.input_observer = input_observer
        self.weight_observer = weight_observer

    def __repr__(self):
        return (
            f"LinearActivationWeightObservedTensor(\n"
            f"original_weight={self.original_weight_tensor}\n"
            f"input_observer={self.input_observer.__class__.__name__ if self.input_observer else None}\n"
            f"weight_observer={self.weight_observer.__class__.__name__ if self.weight_observer else None}\n)"
        )

    def __tensor_flatten__(self):
        return ["original_weight_tensor"], [self.input_observer, self.weight_observer]

    @classmethod
    def __tensor_unflatten__(
        cls,
        tensor_data_dict: Dict[str, Tensor],
        tensor_attributes,
        outer_size,
        outer_stride,
    ):
        original_weight_tensor = tensor_data_dict["original_weight_tensor"]
        (input_observer, weight_observer) = tensor_attributes
        return cls(original_weight_tensor, input_observer, weight_observer)

    @classmethod
    def from_float(
        cls,
        original_weight_tensor: Tensor,
        input_observer: Optional[AffineQuantizedObserverBase] = None,
        weight_observer: Optional[AffineQuantizedObserverBase] = None,
    ):
        return cls(original_weight_tensor, input_observer, weight_observer)

    def _apply_fn_to_data(self, fn: Callable):
        """Applies a fn to the tensor component of the LinearActivationWeightObservedTensor"""
        return self.__class__(
            fn(self.original_weight_tensor),
            self.input_observer,
            self.weight_observer,
        )

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self._apply_fn_to_data(lambda x: x.to(**kwargs))


implements = LinearActivationWeightObservedTensor.implements


@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    if weight_tensor.input_observer is not None:
        input_tensor = weight_tensor.input_observer(input_tensor)
    if weight_tensor.weight_observer is not None:
        weight_tensor = weight_tensor.weight_observer(
            weight_tensor.original_weight_tensor
        )
    else:
        weight_tensor = weight_tensor.original_weight_tensor

    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements(aten.detach.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )


@implements(aten.clone.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
    )


@implements(aten._to_copy.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func,
        args,
        kwargs,
        args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
    )


if TORCH_VERSION_AT_LEAST_2_5:
    # Allow a model with LinearActivationWeightObservedTensor weights to be loaded with `weights_only=True`
    torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor])


def insert_observers_(
    model: nn.Module,
    input_observer: Optional[AffineQuantizedObserverBase],
    weight_observer: Optional[AffineQuantizedObserverBase],
    *,
    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
):
    """
    Converts the weight of linear modules in the model to a LinearActivationWeightObservedTensor.

    This function wraps the weight of each selected linear module with a LinearActivationWeightObservedTensor,
    which enables observation of both input and weight tensors during forward passes.
    The wrapped weight is then re-wrapped as an nn.Parameter to maintain compatibility
    with PyTorch's module system.

    Example::

        ```
        import torch
        import torch.nn as nn
        from torchao.quantization.linear_observer_tensor import insert_observers_
        from torchao.quantization.observer import (
            AffineQuantizedMinMaxObserver,
            PerTensor,
            MappingType
        )

        # Create observers
        input_observer = AffineQuantizedMinMaxObserver(
            MappingType.SYMMETRIC,
            torch.float8_e4m3fn,
            granularity_type=PerTensor(),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
            zero_point_domain=None,
        )

        # Create a linear module
        linear_module = nn.Linear(10, 20)

        # Convert the linear module's weight to an observed tensor
        insert_observers_(linear_module, input_observer, weight_observer=None)

        # The linear_module can now be used as usual, with observers calculating statistics
        output = linear_module(torch.randn(10, 10))
        ```

    Args:
        model (nn.Module): The nn.Module to convert.
        input_observer (Optional[AffineQuantizedObserverBase]): Observer for input tensor.
        weight_observer (Optional[AffineQuantizedObserverBase]): Observer for weight tensor.
        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): Filter function to select which modules to convert.
            If not provided, all linear modules will be converted.

    Returns:
        nn.Module: The model with the weights of matching linear modules wrapped in a LinearActivationWeightObservedTensor.
    """

    def convert_to_linear_observer(linear_module: nn.Linear):
        # Wrap the weight with LinearActivationWeightObservedTensor and then with nn.Parameter
        linear_module.weight = nn.Parameter(
            LinearActivationWeightObservedTensor.from_float(
                linear_module.weight,
                input_observer=input_observer,
                weight_observer=weight_observer,
            ),
            requires_grad=linear_module.weight.requires_grad,
        )
        return linear_module

    _replace_with_custom_fn_if_matches_filter(
        model,
        convert_to_linear_observer,
        _is_linear if filter_fn is None else filter_fn,
    )
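
Taken together, the pieces in this commit support the three-step flow described in the class docstring. The sketch below (not part of the diff) mirrors the observer settings from the test and the docstring example: it inserts observers into a single linear layer, runs a few calibration batches, and reads back the recorded quantization parameters; the final static `quantize_` step is outside the scope of this commit.

```python
import torch
import torch.nn as nn
from torchao.quantization.linear_observer_tensor import insert_observers_
from torchao.quantization.observer import (
    AffineQuantizedMinMaxObserver,
    PerTensor,
    MappingType,
)


def make_observer():
    # Per-tensor symmetric float8 observer, same settings as in the test above.
    return AffineQuantizedMinMaxObserver(
        MappingType.SYMMETRIC,
        torch.float8_e4m3fn,
        granularity_type=PerTensor(),
        eps=torch.finfo(torch.float32).eps,
        scale_dtype=torch.float,
        zero_point_dtype=torch.int,
        zero_point_domain=None,
    )


linear = nn.Linear(16, 8)

# Step 1: wrap the weight so both input and weight are observed on forward.
insert_observers_(linear, make_observer(), make_observer())

# Step 2: run calibration data; each forward records per-tensor min/max.
for _ in range(8):
    linear(torch.randn(4, 16))

# Step 3 (handled by a later static-quant flow): read the recorded statistics.
input_scale, input_zero_point = linear.weight.input_observer.calculate_qparams()
weight_scale, weight_zero_point = linear.weight.weight_observer.calculate_qparams()
print(input_scale.item(), weight_scale.item())
```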
