Update torchao api reference and add contributor guide
Summary:
1. Updated the torchao API reference for quantization to include the APIs we want to expose, renamed torchao/quantization/linear_activation_weight_observer.py to torchao/quantization/linear_activation_weight_observed_tensor.py, and removed safe_int_mm and int_scaled_matmul from quant_primitives.py (they are now exposed from torchao.kernel).
2. Added pytorch#391 to the torchao docs.

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Nov 8, 2024
1 parent e41ca4e commit cba0672
Showing 16 changed files with 825 additions and 73 deletions.
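For orientation, a minimal sketch of the import-path change in item 1 of the summary, assuming the torchao.kernel exports added by this commit (the shapes and the per-row scale layout for int_scaled_matmul are illustrative assumptions):

```python
import torch

# New import path introduced by this commit; both functions are also
# re-exported from torchao.quantization for convenience.
from torchao.kernel import safe_int_mm, int_scaled_matmul

a = torch.randint(-128, 127, (16, 32), dtype=torch.int8)
b = torch.randint(-128, 127, (32, 8), dtype=torch.int8)

c = safe_int_mm(a, b)  # int8 @ int8 -> int32, with fallbacks where torch._int_mm is unsupported
scales = torch.rand(16, 1)  # assumed layout: one scale per output row
y = int_scaled_matmul(a, b, scales)  # roughly (a @ b) * scales
```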
4 changes: 3 additions & 1 deletion docs/source/api_ref_dtypes.rst
@@ -12,9 +12,11 @@ torchao.dtypes

to_nf4
to_affine_quantized_intx
to_affine_quantized_floatx
to_affine_quantized_intx_static
to_affine_quantized_floatx
to_affine_quantized_floatx_static
to_affine_quantized_fpx
NF4Tensor
AffineQuantizedTensor

..
9 changes: 3 additions & 6 deletions docs/source/api_ref_intro.rst
@@ -1,16 +1,13 @@
``torchao`` API Reference
=========================

This section introduces the torchao API reference.
Dive into the details of how torchao integrates with PyTorch to
optimize your machine learning models.
This section introduces the torchao API reference. Dive into the details of how torchao integrates with PyTorch to optimize your machine learning models.

.. toctree::
:glob:
:maxdepth: 1
:caption: Python API Reference

api_ref_sparsity
api_ref_quantization
api_ref_dtypes
api_ref_kernel
api_ref_quantization
api_ref_sparsity
39 changes: 32 additions & 7 deletions docs/source/api_ref_quantization.rst
@@ -9,15 +9,40 @@ torchao.quantization
.. autosummary::
:toctree: generated/
:nosignatures:

SmoothFakeDynQuantMixin
SmoothFakeDynamicallyQuantizedLinear
swap_linear_with_smooth_fq_linear
smooth_fq_linear_to_inference
Int4WeightOnlyGPTQQuantizer
Int4WeightOnlyQuantizer
autoquant

quantize_
int8_dynamic_activation_int4_weight
int8_dynamic_activation_int8_weight
int4_weight_only
int8_weight_only
float8_weight_only
float8_dynamic_activation_float8_weight
float8_static_activation_float8_weight
uintx_weight_only
fpx_weight_only

to_linear_activation_quantized
to_linear_activation_weight_observed

swap_linear_with_smooth_fq_linear
smooth_fq_linear_to_inference

choose_qparams_affine
choose_qparams_affine_with_min_max
choose_qparams_affine_floatx
quantize_affine
quantize_affine_floatx
dequantize_affine
dequantize_affine_floatx
choose_qparams_and_quantize_affine_hqq
fake_quantize_affine
fake_quantize_affine_cachemask

safe_int_mm
int_scaled_matmul

MappingType
ZeroPointDomain
TorchAODType
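A usage sketch of the APIs listed above: the top-level quantize_ entry point and an explicit round trip through the affine quant primitives (the model, block_size, and int8 settings are illustrative, not prescriptive):

```python
import torch
from torchao.quantization import (
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
    int8_weight_only,
    quantize_,
    quantize_affine,
)

# Top-level API: swaps Linear weights for quantized tensor subclasses in place.
model = torch.nn.Sequential(torch.nn.Linear(128, 128))
quantize_(model, int8_weight_only())

# Quant primitives: explicit asymmetric int8 quantize/dequantize with per-row groups.
x = torch.randn(4, 8)
block_size = (1, 8)  # one scale/zero_point per row
scale, zero_point = choose_qparams_affine(
    x, MappingType.ASYMMETRIC, block_size, torch.int8, quant_min=-128, quant_max=127
)
xq = quantize_affine(x, block_size, scale, zero_point, torch.int8, quant_min=-128, quant_max=127)
xdq = dequantize_affine(xq, block_size, scale, zero_point, torch.int8, quant_min=-128, quant_max=127)
print((x - xdq).abs().max())  # small round-trip error
```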

674 changes: 674 additions & 0 deletions docs/source/contributor_guide.rst

Large diffs are not rendered by default.

18 changes: 11 additions & 7 deletions docs/source/index.rst
@@ -1,9 +1,7 @@
Welcome to the torchao Documentation
=======================================

**torchao** is an open-source library that provides the functionality
to quantize and prune your models using native PyTorch. Our documentation is under development
with more content coming soon.
`**torchao** <https://github.com/pytorch/ao>`__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please check out the torchao `README <https://github.com/pytorch/ao#torchao-pytorch-architecture-optimization>`__ for an overall introduction to the library and recent highlights and updates. The documentation here will focus on: 1. API Reference, 2. Developer / Researcher Contribution Guide, and 3. Tutorials.

..
.. grid:: 3
@@ -81,13 +79,19 @@ with more content coming soon.
:maxdepth: 1
:caption: API Reference

api_ref_sparsity
api_ref_intro
api_ref_quantization
api_ref_dtypes
api_ref_quantization
api_ref_sparsity
..
api_ref_kernel
.. toctree::
:glob:
:maxdepth: 1
:caption: Contributor Guide

contributor_guide

.. toctree::
:glob:
:maxdepth: 1
10 changes: 6 additions & 4 deletions test/integration/test_integration.py
@@ -33,8 +33,10 @@
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
)
from torchao.quantization.quant_primitives import (
from torchao.quantization import (
safe_int_mm,
)
from torchao.quantization.quant_primitives import (
choose_qparams_affine,
quantize_affine,
dequantize_affine,
@@ -781,7 +783,7 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
@@ -973,11 +975,11 @@ def test_weight_only_groupwise_embedding_quant(self):
group_size = 64
m = nn.Embedding(4096, 128)
input = torch.randint(0, 4096, (1, 6))

quantize_(m, int8_weight_only(group_size=group_size), filter_fn=lambda x, *args: isinstance(x, nn.Embedding))
y_q = m(input)
y_ref = m.weight.dequantize()[input]

sqnr = compute_error(y_ref, y_q)

self.assertGreater(sqnr, 45.0)
4 changes: 3 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -31,10 +31,12 @@
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_floatx,
int_scaled_matmul,
quantize_affine,
quantize_affine_floatx,
)
from torchao.kernel import (
int_scaled_matmul,
)
from torchao.quantization.utils import (
pack_tinygemm_scales_and_zeros,
)
7 changes: 7 additions & 0 deletions torchao/kernel/__init__.py
@@ -0,0 +1,7 @@
from torchao.kernel.intmm import int_scaled_matmul
from torchao.kernel.intmm import safe_int_mm

__all__ = [
"safe_int_mm",
"int_scaled_matmul",
]
68 changes: 50 additions & 18 deletions torchao/quantization/__init__.py
@@ -8,6 +8,10 @@
from .quant_api import * # noqa: F403
from .subclass import * # noqa: F403
from .quant_primitives import * # noqa: F403
from torchao.kernel import (
safe_int_mm,
int_scaled_matmul,
)
from .utils import * # noqa: F403
from .weight_only import * # noqa: F403
from .unified import *
@@ -20,38 +24,66 @@
from .linear_activation_scale import (
to_weight_tensor_with_linear_activation_scale_metadata,
)
from .linear_activation_weight_observed_tensor import (
to_linear_activation_weight_observed,
)


__all__ = [
"swap_conv2d_1x1_to_linear"
"safe_int_mm",
# top level API - auto
"autoquant",
"DEFAULT_AUTOQUANT_CLASS_LIST",
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
"OTHER_AUTOQUANT_CLASS_LIST",
"get_scale",
"SmoothFakeDynQuantMixin",
"SmoothFakeDynamicallyQuantizedLinear",
"swap_linear_with_smooth_fq_linear",
"smooth_fq_linear_to_inference",
"set_smooth_fq_attribute",
"compute_error",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"quantize_affine",
"dequantize_affine",
"choose_qprams_affine",

# top level API - manual
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
"float8_weight_only",
"float8_dynamic_activation_float8_weight",
"float8_static_activation_float8_weight"
"uintx_weight_only",
"fpx_weight_only",
"LinearActivationQuantizedTensor",

# smooth quant - subject to change
"swap_conv2d_1x1_to_linear"
"get_scale",
"SmoothFakeDynQuantMixin",
"SmoothFakeDynamicallyQuantizedLinear",
"swap_linear_with_smooth_fq_linear",
"smooth_fq_linear_to_inference",
"set_smooth_fq_attribute",
"compute_error",

# building blocks
"to_linear_activation_quantized",
"to_weight_tensor_with_linear_activation_scale_metadata",
"float8_weight_only",
"float8_dynamic_activation_float8_weight",
"float8_static_activation_float8_weight"

# quant primitive ops
"choose_qprams_affine",
"choose_qparams_affine_with_min_max",
"choose_qparams_affine_floatx",
"quantize_affine",
"quantize_affine_floatx",
"dequantize_affine",
"dequantize_affine_floatx",
"choose_qparams_and_quantize_affine_hqq",
"fake_quantize_affine",
"fake_quantize_affine_cachemask",

# operators/kernels
"safe_int_mm",
"int_scaled_matmul",

"MappingType",
"ZeroPointDomain",
"TorchAODType",

"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"LinearActivationQuantizedTensor",
]
26 changes: 13 additions & 13 deletions torchao/quantization/autoquant.py
@@ -18,7 +18,7 @@
PerRow,
PerTensor,
)
from .quant_primitives import safe_int_mm
from torchao.kernel import safe_int_mm
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
from torchao.quantization.utils import quantize_activation_per_token_absmax
from torchao.float8.inference import Float8MMConfig
@@ -661,31 +661,31 @@ def _change_autoquantizable_to_quantized(model, supress_autoquant_errors=True, *
# TODO: Mode being a list is weird, should be a string or some object
@torch.no_grad()
def autoquant(
model,
example_input=None,
qtensor_class_list=DEFAULT_AUTOQUANT_CLASS_LIST,
filter_fn=None,
mode=["interpolate", .85],
manual=False,
set_inductor_config=True,
supress_autoquant_errors=True,
**aq_kwargs
):
"""
Autoquantization is a process which identifies the fastest way to quantize each layer of a model over some set of potential
qtensor subclasses.
Autoquantization happens in three steps:
1-Prepare Model: the model is searched for Linear layers whose weights are exchanged for AutoQuantizableLinearWeight.
2-Shape Calibration: the user runs the model on one or more inputs, the details of the activation shape/dtype seen by
the AutoQuantizableLinearWeight are recorded so we know what shapes/dtypes to use in order to optimize the quantized op in step 3
3-Finalize Autoquantization: for each AutoQuantizableLinearWeight, benchmarks are run for each shape/dtype on each member of the qtensor_class_list.
the fastest option is picked, resulting in a highly performant model
This autoquant function performs step 1. Steps 2 and 3 can be completed by simply running the model.
If `example_input` is provided, this function also runs the model (which completes steps 2 and 3).
This autoquant api can handle models which have already had torch.compile applied to them, in which case, once the model is run and quantized,
the torch.compile process normally proceeds as well.
To optimize over a combination of input shapes/dtypes, the user can set manual=True, run the model with all desired shapes/dtypes, then
@@ -699,7 +699,7 @@ def autoquant(
filter_fn (callable, optional): A filter function to apply to the model parameters. Defaults to None.
mode (list, optional): A list containing mode settings for quantization. The first element is the mode type (e.g., "interpolate"),
and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85].
manual (bool, optional): Whether to stop shape calibration and do autoquant after a single run (default, False) or to wait for
the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged.
set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
supress_autoquant_errors (bool, optional): Whether to suppress errors during autoquantization. (defaults to True)
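A usage sketch of the flow this docstring describes, with example_input supplied so all three steps run in one call (the model, shapes, and dtype are illustrative):

```python
import torch
import torchao

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to("cuda", torch.bfloat16)
example_input = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)

# Step 1: Linear weights are wrapped in AutoQuantizableLinearWeight; because
# example_input is provided, the model is also run, which completes shape
# calibration (step 2) and finalization (step 3).
model = torchao.autoquant(torch.compile(model), example_input=example_input)
out = model(example_input)
```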
14 changes: 7 additions & 7 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -16,7 +16,7 @@
class LinearActivationQuantizedTensor(TorchAOBaseTensor):
"""
Applies activation quantization for linear operator, this is used to support
dynamic quantization or static quantization, user can pass in a `input_quant_func`
dynamic quantization, user can pass in a `input_quant_func`
that is used to quantize the activation
Args:
Expand All @@ -30,7 +30,7 @@ class LinearActivationQuantizedTensor(TorchAOBaseTensor):
Restriction: Must not contain tensor values.
"""
quant_kwargs: Dict[str, Any]

def __new__(
cls,
original_weight_tensor: torch.Tensor,
@@ -56,7 +56,7 @@ def __init__(
self.quant_kwargs = quant_kwargs

def __repr__(self):
return f"LinearActivationQuantizedTensor({self.original_weight_tensor}, {self.input_quant_func}, quant_kwargs={self.quant_kwargs}))"
return f"{self.__class__.__name__}({self.original_weight_tensor}, {self.input_quant_func}, quant_kwargs={self.quant_kwargs}))"

def __tensor_flatten__(self):
return ["original_weight_tensor"], [self.input_quant_func, self.quant_kwargs]
@@ -82,9 +82,9 @@ def _quantized_linear_op(input_tensor: torch.Tensor, weight_tensor: torch.Tensor
return torch.nn.functional.linear(aqt, original_weight_tensor, bias)

@classmethod
def from_float(cls,
input_float: torch.Tensor,
input_quant_func: Callable,
quant_kwargs: Optional[Dict[str, Any]] = None):
if quant_kwargs is None:
quant_kwargs = {}
@@ -199,4 +199,4 @@ def _(func, types, args, kwargs):

if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([LinearActivationQuantizedTensor])
torch.serialization.add_safe_globals([LinearActivationQuantizedTensor])
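A minimal sketch of the dynamic-quantization wiring this class provides; the identity input_quant_func below is a stand-in for a real activation quantizer (e.g. one returning an AffineQuantizedTensor):

```python
import torch
from torchao.quantization import to_linear_activation_quantized

def input_quant_func(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for illustration: a real implementation would return a
    # quantized tensor subclass instead of the input unchanged.
    return x

linear = torch.nn.Linear(16, 16)
linear.weight = torch.nn.Parameter(
    to_linear_activation_quantized(linear.weight, input_quant_func),
    requires_grad=False,
)
y = linear(torch.randn(2, 16))  # input_quant_func is applied to the activation first
```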
torchao/quantization/linear_activation_weight_observed_tensor.py (renamed from torchao/quantization/linear_activation_weight_observer.py)
@@ -11,6 +11,7 @@

__all__ = [
"LinearActivationWeightObservedTensor",
"to_linear_activation_weight_observed",
]

aten = torch.ops.aten
@@ -147,6 +148,8 @@ def _(func, types, args, kwargs):
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)

to_linear_activation_weight_observed = LinearActivationWeightObservedTensor.from_float


if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`