Skip to content

Commit

Permalink
Add register_apply_tensor_subclass (pytorch#366)
Browse files Browse the repository at this point in the history
Summary:
`register_apply_tensor_subclass` allows users to add a string shortcut for
a new apply_tensor_subclass function; they can use this to test their new dtype tensor subclass.

See `test/quantization/test_quant_api.py -k test_register_apply_tensor_subclass` for details.

Test Plan:
python test/quantization/test_quant_api.py -k test_register_apply_tensor_subclass

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored Jun 14, 2024
1 parent 0252ee0 commit 924ebdc
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 14 deletions.
20 changes: 18 additions & 2 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
Int8WeightOnlyQuantizedLinearWeight,
Int4WeightOnlyQuantizedLinearWeight,
)
from torchao import quantize
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
Quantizer,
TwoStepQuantizer,
quantize,
int8da_int4w,
int4wo,
int8wo,
Expand All @@ -51,6 +51,7 @@
from torchao.utils import unwrap_tensor_subclass
import copy
import tempfile
from torch.testing._internal.common_utils import TestCase


def dynamic_quant(model, example_inputs):
Expand Down Expand Up @@ -147,7 +148,7 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)

class TestQuantFlow(unittest.TestCase):
class TestQuantFlow(TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
Expand Down Expand Up @@ -601,5 +602,20 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
# make sure it compiles
torch._export.aot_compile(m_unwrapped, example_inputs)

def test_register_apply_tensor_subclass(self):
    """Check that a string shortcut only works with `quantize` once registered.

    Before registration, passing the unknown shortcut to `quantize` must
    raise a ValueError whose message contains "not supported". After
    `register_apply_tensor_subclass` is called, the same shortcut must be
    accepted and the resulting model must still run on its example inputs.
    """
    from torchao import register_apply_tensor_subclass

    def _double_weight(weight):
        # Stand-in for a real tensor-subclass conversion.
        return weight * 2

    shortcut = "my_dtype"
    model = ToyLinearModel().eval()
    inputs = model.example_inputs()

    # The shortcut is not registered yet, so the lookup must be rejected.
    with self.assertRaisesRegex(ValueError, "not supported"):
        quantize(model, shortcut)

    register_apply_tensor_subclass(shortcut, _double_weight)

    # make sure it runs
    quantize(model, shortcut)
    model(*inputs)

if __name__ == "__main__":
unittest.main()
4 changes: 4 additions & 0 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,14 @@

from torchao.quantization import (
autoquant,
quantize,
register_apply_tensor_subclass,
)
from . import dtypes

__all__ = [
"dtypes",
"autoquant",
"quantize",
"register_apply_tensor_subclass",
]
12 changes: 2 additions & 10 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@
from .autoquant import *

__all__ = [
"DynamicallyPerAxisQuantizedLinear",
"apply_weight_only_int8_quant",
"apply_dynamic_quant",
"change_linear_weights_to_int8_dqtensors",
"change_linear_weights_to_int8_woqtensors",
"change_linear_weights_to_int4_woqtensors",
"swap_conv2d_1x1_to_linear"
"safe_int_mm",
"autoquant",
Expand All @@ -31,14 +25,12 @@
"swap_linear_with_smooth_fq_linear",
"smooth_fq_linear_to_inference",
"set_smooth_fq_attribute",
"Int8DynamicallyQuantizedLinearWeight",
"Int8WeightOnlyQuantizedLinearWeight",
"Int4WeightOnlyQuantizedLinearWeight",
"compute_error",
"WeightOnlyInt8QuantLinear",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"quantize_affine",
"dequantize_affine",
"choose_qprams_affine",
"quantize",
"register_apply_tensor_subclass",
]
23 changes: 21 additions & 2 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
Int4WeightOnlyGPTQQuantizer,
Int4WeightOnlyQuantizer,
)
import logging
from .autoquant import autoquant, AutoQuantizableLinearWeight


Expand All @@ -50,13 +51,14 @@
"TwoStepQuantizer",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"quantize",
"autoquant",
"_get_subclass_inserter",
"quantize",
"int8da_int4w",
"int8da_int8w",
"int4wo",
"int8wo",
"register_apply_tensor_subclass",
]

from .GPTQ import (
Expand Down Expand Up @@ -292,7 +294,8 @@ def filter_fn(module, fqn):
m = quantize(m, apply_weight_quant, filter_fn)
"""
if isinstance(apply_tensor_subclass, str):
assert apply_tensor_subclass in _APPLY_TS_TABLE, f"{apply_tensor_subclass} not supported: {_APPLY_TS_TABLE.keys()}"
if apply_tensor_subclass not in _APPLY_TS_TABLE:
raise ValueError(f"{apply_tensor_subclass} not supported: {_APPLY_TS_TABLE.keys()}")
apply_tensor_subclass = _APPLY_TS_TABLE[apply_tensor_subclass]

assert not isinstance(apply_tensor_subclass, str)
Expand Down Expand Up @@ -438,3 +441,19 @@ def get_per_token_block_size(x):
"int8_weight_only": int8wo(),
"int8_dynamic": int8da_int8w(),
}

def register_apply_tensor_subclass(name: str, apply_tensor_subclass: Callable):
    """Register `apply_tensor_subclass` under the string shortcut `name`.

    The callable is stored in the module-level `_APPLY_TS_TABLE`, so the
    shortcut string can afterwards be passed to `quantize` in place of the
    callable itself.

    Args:
        name: shortcut string to register.
        apply_tensor_subclass: function that takes a weight Tensor as input
            and outputs a tensor with the tensor subclass applied.

    If `name` is already registered, a warning is logged and the existing
    entry is overwritten.

    Example:
        def apply_my_dtype(weight):
            return weight * 2
        register_apply_tensor_subclass("my_dtype", apply_my_dtype)
        # calls `apply_my_dtype` on weights
        quantize(m, "my_dtype")
    """
    already_registered = name in _APPLY_TS_TABLE
    if already_registered:
        logging.warning(f"shortcut string {name} already exist, overwriting")

    _APPLY_TS_TABLE[name] = apply_tensor_subclass

0 comments on commit 924ebdc

Please sign in to comment.