
[RFC] torchao Contributor Guide #391

Open
jerryzh168 opened this issue Jun 18, 2024 · 16 comments

jerryzh168 commented Jun 18, 2024

Status: Draft
Updated: 09/18/2024

Objective

In this doc we’ll talk about how different optimization techniques are structured in torchao and how to contribute to torchao.

torchao Stack Overview

First we want to lay out the torchao stack:

Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc.
---------------------------------------------------------------------------------------------
        Quantized Tensors (derived dtypes): AffineQuantizedTensor, LUTQuantizedTensor
---------------------------------------------------------------------------------------------
  Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize
---------------------------------------------------------------------------------------------
            Basic dtypes: uint1-uint7, int2-int8, float3-float8

Any quantization algorithm uses some components from the stack above. For example, int4_weight_only quantization uses:
(1) the weight only quantization flow
(2) the tinygemm bf16 activation + int4 weight kernel and quant primitive ops
(3) the AffineQuantizedTensor tensor subclass with TensorCoreTiledLayout
(4) the torch.uint4 dtype (simulated with quant_min/quant_max right now)

Note: we'll also talk about how to compose sparsity with quantization in the Quantized Tensors section

Basic DTypes

dtype is a bit of an overloaded term; by basic dtype we mean the dtypes that make sense without any extra metadata (e.g. that make sense when people call torch.empty(.., dtype)). For more details please check out: https://dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833

No matter what quantization we are doing, in the end we will use some low precision dtypes to represent the quantized data. The dtypes we aim to support in torchao are:

  • torch.uint1 - torch.uint8
  • torch.int1 to torch.int8 (to be added in pytorch)
  • torch.float3_e2_m0, torch.float4_e2_m1, torch.float4_e3_m0, torch.float5_e2_m2, torch.float5_e3_m1, torch.float6_e2_m3, torch.float6_e3_m2, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz (float8 is added to torch, we also plan to add float4 and float6 to torch if they become popular)

Note: some of the above are prototype only for now. We'll consider adding them to pytorch core when they become popular and have hardware support.

Current Support

In terms of actual implementation, there are two parts:
1) In PyTorch, we need to add the dtype to torch.dtype, e.g. torch.uint2 (example: pytorch/pytorch#117208), but these are just placeholders so that we can use torch.uint2.
2) Outside of PyTorch (e.g. in torchao), we implement the tensor operations for these dtypes with tensor subclasses; a standard packing format is also needed.

Adding placeholder dtype in PyTorch

As mentioned in https://dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833, the criterion for adding a dtype to PyTorch is that it shows wide adoption. Of the fundamental dtypes mentioned above, the ones that are supported in PyTorch are:

  • torch.uint1 - torch.uint8, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz

We may add torch.int2 to torch.int7 to PyTorch soon due to requests from the edge team, but for the other dtypes we plan to wait until there is more evidence of wide adoption and hardware support.

Implementing tensor operations for these dtypes with Tensor subclasses

For this, the requirement is that we decide on a "standard" packing format, hopefully one that is amenable to efficient implementation, but for both uintx and floatx we haven't integrated enough kernels to decide on this yet. So the current packing implementations (e.g. UintxTensor, a subclass of TorchAOBaseTensor) are not final. We can revisit after more intx and floatx kernels are integrated into torchao.

Quantization Primitive Ops / Efficient Kernels

Quantization Primitive Ops

Quantization primitive ops are the operators used to convert between low precision quantized tensors and high precision tensors. We will mainly have the following quantization primitive operators:

  • choose_qparams ops: choose quantization parameters based on the original Tensor (e.g. scale and zero_point for affine quantization), typically used in dynamic quantization
  • quantize op: quantizes the original high precision tensor to a low precision tensor with one of the dtypes mentioned in the previous section, based on the quantization parameters
  • dequantize op: dequantizes the low precision tensor into a high precision tensor based on the quantization parameters

There could be variations of the above to accommodate specific use cases; for example, for static quantization we may have choose_qparams_affine_with_min_max, which chooses quantization parameters based on min/max values derived from the observation process.
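
To make this concrete, here is a hedged sketch of how the affine quantization primitive ops in torchao.quantization.quant_primitives fit together (exact argument names and defaults may differ slightly across torchao versions):

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

x = torch.randn(32, 64)                # high precision tensor
mapping_type = MappingType.ASYMMETRIC  # affine mapping with a zero_point
block_size = (1, 32)                   # per group quantization, group size 32
target_dtype = torch.uint8             # container dtype standing in for "uint4"
quant_min, quant_max = 0, 15           # simulate the uint4 range

# choose_qparams: derive scale/zero_point from the tensor
scale, zero_point = choose_qparams_affine(
    x, mapping_type, block_size, target_dtype, quant_min, quant_max
)
# quantize: high precision values -> low precision values
q = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
# dequantize: low precision values -> high precision values (used in reference/fallback paths)
x_hat = dequantize_affine(
    q, block_size, scale, zero_point, target_dtype, quant_min, quant_max, output_dtype=x.dtype
)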

Efficient kernels

We'll also have efficient kernels that work with the low precision tensors, for example:

  • _weight_int4pack_mm: the tinygemm int4 kernel (bf16 activation + int4 weight)
  • int_matmul: takes two int8 tensors and outputs an int32 tensor
  • int_scaled_matmul: does the matmul and also applies a scale to the result

Note: we can also rely on torch.compile to generate kernels (through triton); for example, the current int8 weight only quantization kernel just relies on torch.compile to get a speedup. In this case there is no specific "efficient kernel" corresponding to that type of quantization.
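
As a hedged sketch of that last point (using the model level APIs introduced later in this doc and a made-up toy model; it assumes a CUDA device, and no hand written int8 weight only kernel is involved):

import torch
from torchao.quantization import quantize_, int8_weight_only

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16).eval()
quantize_(model, int8_weight_only())
# torch.compile/inductor generates the fused dequantize + matmul triton kernel here
model = torch.compile(model, mode="max-autotune")
out = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))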

Quantized Tensors (derived dtypes)

On top of the basic dtypes, quantization primitive ops and efficient kernels, we can glue everything together and build a Quantized (low precision) Tensor by subclassing torch.Tensor. It can be constructed from a high precision Tensor and some parameters that configure the specific quantization the user wants. We can also call this a derived dtype, since it can be represented with Tensors of basic dtypes plus some extra metadata like scale.

The existing example in torchao is AffineQuantizedTensor, meaning the low precision Tensor is quantized from the high precision Tensor by an affine mapping, that is: low_precision_val = high_precision_val / scale + zero_point, where scale and zero_point are the quantization parameters that can be calculated by quantization primitive ops or through some optimization procedure. Affine quantization is a very common type of quantization, since it's straightforward: when we map from higher precision values to lower precision values, we do an affine transformation (high_precision_val / scale + zero_point). Another common type of quantization, especially for lower bitwidths (e.g. lower than 4 bit), is look up table based quantization.

Layout and Packing

Native tensors have a hardcoded list of layouts: https://github.com/pytorch/pytorch/blob/647815049ec28a72dc1bb6a977791927bba058d5/c10/core/Layout.h#L11. The most common one is the strided layout, which provides a strided, multi-dimensional view of storage; we also have some sparse and mkldnn layouts.

The idea of packing the tensor into different formats fits nicely with the layout concept, which is why we want to reuse it for packing. The extension of layout can be achieved with Python level tensor subclasses, without modifying C++ pytorch core code.

We use this to support the different ways the same quantized Tensor can be packed for efficient execution. For example, for _weight_int4pack_mm we need to pack the weight into a format that is friendly to Tensor Cores, which we call TensorCoreTiledLayoutType. We add a layout_tensor to the quantized tensor to store the packed (or unpacked) weight, and we use a layout_type to store the different parameters that are relevant for packing.

class AffineQuantizedTensor(...):
    # layout_tensor is also implemented with tensor subclass    
    layout_tensor: torch.Tensor

    @property
    def layout_type(self) -> LayoutType:
        return self.layout_tensor.layout_type

Note that layout is an abstraction not only for custom data representation; it also defines how the layout Tensor interacts with different operators. For example, the same data representation can have different implementations when running the same operator (e.g. transpose, quantized_linear), even though the operator semantics stay the same.

Quantized + sparse Tensors can also be supported through the Layout abstraction, for example int4 weight only quantization + sparsity. We also provide some common utils that help people add different layouts to a quantized tensor; please check out the developer guide below for code examples.

Quantization Algorithms/Flows

At the top of the stack are the final quantization algorithms and quantization flows. Traditionally we have weight only quantization, dynamic quantization and static quantization, but now we are also seeing more types of quantization coming up.

For demonstration purposes, let's say that after the previous step we have AffineQuantizedTensor and a to_affine_quantized factory function defined. For simplicity, let's say to_affine_quantized takes a high precision floating point Tensor and a target_dtype (e.g. torch.int8) and converts it to an AffineQuantizedTensor with the corresponding dtype.

Note: the below is all for explaining the concepts; a more detailed introduction to the utils and examples we provide can be found in the Tensor Subclass Developer Guide section.

Weight Only Quantization

This is the simplest form of quantization, and it's easy to apply weight only quantization to the model, especially since we have the Quantized Tensor. All we need to do is:

linear_module.weight = torch.nn.Parameter(to_affine_quantized(linear_module.weight), requires_grad=False)

Apply the above to all linear modules in the model and we'll get a weight only quantized model.
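
For example, a minimal sketch of applying this to every linear module in a model (using the simplified to_affine_quantized factory assumed above; in torchao the quantize_ model level API does this traversal for you):

import torch

def apply_weight_only_quant(model: torch.nn.Module) -> torch.nn.Module:
    # swap the weight of every nn.Linear for a quantized tensor subclass
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            module.weight = torch.nn.Parameter(
                to_affine_quantized(module.weight), requires_grad=False
            )
    return model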

Dynamic Activation Quantization + Weight Quantization

This is called "dynamic quantization" before but it means we quantize activation dynamically at runtime, and also quantize the weights as well. Compared to the weight only quantization, the main question is how do we apply the quantization to activation. In torchao, the common pattern we use is by applying to_linear_activation_quantized on top of quantized weight:

quantized_weight = to_affine_quantized(linear_module.weight)
activation_and_weight_quantized = to_linear_activation_quantized(quantized_weight, input_quant_func)
linear_module.weight = torch.nn.Parameter(activation_and_weight_quantized, requires_grad=False)

to_linear_activation_quantized is used to apply quantization to the activation: it takes an input_quant_func that will quantize the activation, together with the original (quantized) weight, and at runtime when it encounters an F.linear op it applies the stored input_quant_func to the activation and redispatches to F.linear with the quantized activation and weight.
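
For illustration, an input_quant_func can be as simple as another call into the quantized tensor factory (a sketch in terms of the simplified to_affine_quantized assumed in this section, not the exact torchao signature):

def input_quant_func(x: torch.Tensor) -> torch.Tensor:
    # dynamically quantize the activation to int8 at runtime
    return to_affine_quantized(x, torch.int8)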

If the above does not work, the user can also do module swaps, or use torch.export.unflatten() to get a traced module that they can modify.

But using tensor subclasses is preferred because it is easier for serialization/deserialization: if we use tensor subclasses to support dynamic quantization, we can load the quantized weights directly without further preparation of the model. Otherwise, we'd need to do a module swap or other modifications to the model first before loading the quantized weights.

Static Quantization

Static quantization means the activation is statically quantized instead of dynamically quantized at runtime. In terms of flow, static quantization requires calibration with sample data so that we can figure out the appropriate quantization parameters.

At the high level there are three steps for static quantization: (1) insert observers (2) calibration (3) quantize the model

Insert Observers

In the insert observers step, we need to add observer modules to the input (and output) activation and weight of the operator to collect statistics of the Tensor. So there are two things we need to address: how to define an observer module, and how to add the observer module to the model.

How to define observer module

Observers are specific to: (1) the type of quantization (e.g. affine quantization, look up table based quantization) and (2) the type of stats we want to track, e.g. min/max observer, moving average observer.

Generally an observer module should define forward and calculate_qparams.

For affine quantization, we defined AffineQuantizedMinMaxObserver, which records min_val/max_val based on the granularity of the affine quantization, and also defines how to calculate_qparams based on the recorded stats.
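
For reference, here is a minimal, hypothetical per-tensor min/max observer sketch (torchao's AffineQuantizedMinMaxObserver additionally handles granularity/block_size and reuses the affine quant primitives; the attribute names here are made up for illustration):

import torch

class MinMaxObserver(torch.nn.Module):
    def __init__(self, target_dtype: torch.dtype = torch.int8, eps: float = 1e-5):
        super().__init__()
        self.target_dtype = target_dtype
        self.eps = eps
        self.min_val = None
        self.max_val = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # record running min/max and pass the tensor through unchanged
        min_val, max_val = torch.aminmax(x.detach())
        self.min_val = min_val if self.min_val is None else torch.minimum(self.min_val, min_val)
        self.max_val = max_val if self.max_val is None else torch.maximum(self.max_val, max_val)
        return x

    def calculate_qparams(self):
        # asymmetric per tensor qparams from the recorded range (always include zero)
        qmin = torch.iinfo(self.target_dtype).min
        qmax = torch.iinfo(self.target_dtype).max
        min_val = torch.clamp(self.min_val, max=0.0)
        max_val = torch.clamp(self.max_val, min=0.0)
        scale = torch.clamp((max_val - min_val) / (qmax - qmin), min=self.eps)
        zero_point = torch.clamp(qmin - torch.round(min_val / scale), qmin, qmax)
        return scale, zero_point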

How to add observer module to the model
  1. Use Tensor Subclasses
    If the only operator you are interested in quantizing is linear, you can use the linear activation weight observer; we also have a corresponding insert_observer_ API that handles modifying the weight of linear.

  2. Module swap
    Alternatively, you could also define an ObservedLinear module (or other module types) and swap the non-observed module with the observed one.

Calibration

The calibration step is typically straightforward: we just need to run the model through the calibration dataset. For more complicated calibration (e.g. where we record all inputs and do optimizations based on all inputs), we'll cover some of them in the next section.
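
For example, a basic calibration loop is just the following (assuming an observed_model from the previous step and an iterable calibration_data of input tuples):

observed_model.eval()
with torch.no_grad():
    for example_inputs in calibration_data:
        observed_model(*example_inputs)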

Quantize

We can reuse the quantize_ API but provide a different apply_tensor_subclass function that converts the observed linear module to a linear module with a quantized weight and statically quantized input activation. This can be done in the same manner as dynamic quantization (with to_linear_activation_quantized), see the example.

Alternatively, the user can do a module swap as well, as sketched below.
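
A hedged sketch of the module swap variant, converting the hypothetical ObservedLinear modules from the previous step into statically quantized linears (ObservedLinear, its act_obs attribute and the simplified to_affine_quantized factory are all assumptions for illustration, not the actual torchao example code; the forward only simulates activation quantization instead of calling a fused low precision kernel):

import torch
import torch.nn.functional as F

class StaticQuantLinear(torch.nn.Module):
    def __init__(self, observed: "ObservedLinear"):
        super().__init__()
        # activation qparams come from the calibration step
        self.act_scale, self.act_zero_point = observed.act_obs.calculate_qparams()
        # weight is quantized ahead of time
        self.weight = torch.nn.Parameter(
            to_affine_quantized(observed.weight), requires_grad=False
        )
        self.bias = observed.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # statically quantize then dequantize the activation (reference implementation)
        x_q = torch.clamp(torch.round(x / self.act_scale) + self.act_zero_point, -128, 127)
        x_dq = (x_q - self.act_zero_point) * self.act_scale
        return F.linear(x_dq, self.weight, self.bias)

def swap_observed_with_quantized(module: torch.nn.Module) -> torch.nn.Module:
    for name, child in module.named_children():
        if isinstance(child, ObservedLinear):
            setattr(module, name, StaticQuantLinear(child))
        else:
            swap_observed_with_quantized(child)
    return module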

Other Quantization Flows

For other quantization flows/algorithms that do not fit into any of the above, we also intend to provide examples for common patterns. For example, the GPTQ-like quantization flow adopted by Autoround uses MultiTensor and module hooks to optimize the module.

If you are working on a new quantization algorithm/flow and not sure how to implement it in a PyTorch native way, please feel free to open an issue to describe how your algorithm works and we can help advise on the implementation details.

Training

The above flows are mainly focused on inference, but low bit dtype Tensors can be used in training as well.

Quantization Aware Training

Low Bit Optimizers

Today we have some prototype low bit optimizers: https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim that implement specific 4 bit, 8 bit and float8 variants, and are also composable with FSDP (with look up table quantization). We can extend AffineQuantizedTensor to be used in optimizers as well, following the example.

Quantized Training

Similar to the low bit optimizers, we have a quantized training prototype in https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training, and we could extend AffineQuantizedTensor to support training as well. Initial enablement is in progress, but there will be a lot of follow up work needed, including making it work with different kernels etc.

Case Study: How does int4 weight only quantization work in torchao?

To connect everything together, here is a more detailed walk through for how int4 weight only quantization is implemented in torchao.

High Level Summary

Quantization Flow: quantize_(model, int4_weight_only())
      * What happens: linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight), requires_grad=False)
      * quantization primitive ops: choose_qparams and quantize_affine are called to quantize the Tensor
      * quantized Tensor will be `AffineQuantizedTensor`, a quantized tensor with derived dtype (e.g. int4 with scale and zero_point)
      * packing op `_convert_weight_to_int4pack` to pack the quantized weight for efficient execution

During Model Execution: model(input)
      * `torch.ops.aten._weight_int4pack_mm` is called on input and the packed weight
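
Putting the summary into a short, hedged end to end snippet (the toy model and shapes are made up; int4 weight only currently targets CUDA with bfloat16 weights):

import torch
from torchao.quantization import quantize_, int4_weight_only

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False)).to("cuda", torch.bfloat16).eval()
quantize_(model, int4_weight_only())
# model[0].weight is now an AffineQuantizedTensor (wrapped in nn.Parameter) with a
# tensor core tiled packed weight underneath
out = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))  # hits _weight_int4pack_mm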

During Quantization

First we start with the API call: quantize_(model, int4_weight_only()). What this does is convert the weights of nn.Linear modules in the model to int4 quantized tensors (AffineQuantizedTensor with int4 dtype, asymmetric, per group quantized), using the layout for the tinygemm kernel: the "tensor_core_tiled" layout type.

  • quantize_: the model level API that quantizes the weight of linear by applying the conversion function from user (second argument)
  • int4_weight_only: the function that returns a function that converts weight of linear to int4 weight only quantized weight
  • TensorCoreTiledLayoutType: the tensor core tiled layout type, storing parameters for the packing format
  • TensorCoreTiledAQTLayout: the tensor core tiled layout tensor, stores the packed weight for efficient int4 weight only kernel (tinygemm kernel)

During Model Execution

When we run the quantized model model(inputs), we'll run through the functional linear operator in nn.Linear:

return F.linear(input, weight, bias)

where input is a bfloat16 floating point Tensor and weight is an int4 AffineQuantizedTensor. This calls into the __torch_function__ of the AffineQuantizedTensor subclass, which ends up in the implementation for F.linear when one of the inputs is an AffineQuantizedTensor, so it calls:

return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)

_quantized_linear_op goes through the _AQT_QLINEAR_DISPATCH_TABLE and checks each dispatch condition; if a dispatch condition passes, it will call the corresponding implementation with input/weight/bias. Please check out this doc for the explanation of dispatch_condition and impl.

In this case, the dispatch_condition for the int4 weight only quantization kernel will be this and the implementation we are using will be this; the function takes a bfloat16 input Tensor and an int4 AffineQuantizedTensor, and calls torch.ops.aten._weight_int4pack_mm with the input Tensor and the packed weight stored in weight_tensor.layout_tensor.

During Save/Load

Since AffineQuantizedTensor weight is still a torch.Tensor, save/load works the same way as the original high precision floating point model.

Tensor Subclass Developer Guide

We have covered the high level overview and how everything is connected together in the previous section. This section will focus on tensor subclasses, which are the main extension point we rely on to provide the flexibility of supporting inference, training and fine tuning with low precision Tensors, and composability with torch.compile, autograd and distributed primitives in these scenarios.

Prerequisites

Some externally available resources for tensor subclasses:

Why Tensor Subclass?

There are multiple ways people can implement quantization techniques or new dtypes; the main motivations for us to recommend the tensor subclass based approach are:
(1) It's natural for quantization to be modeled as a dtype conversion, so implementing it with tensor subclasses means we are not introducing new concepts, but reusing existing concepts like dtype and layout that already exist in pytorch core
(2) Since tensor subclasses intercept computation at the torch function or aten op level, as long as the same function/operator is used, we will be able to quantize the model. This allows models that use variants of native modules (e.g. a slightly modified version of nn.Linear) to still be compatible with quantization
(3) Tensor subclass is also the approach adopted by other techniques like sparsity and distributed, so implementing quantization or dtype conversion with tensor subclasses makes it easier to compose with those techniques

Example Code for a new Quantization Technique or DType

Please feel free to start with https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/my_dtype_tensor_subclass.py for an end to end working example that combines everything we talk about here, and come back to this doc for clarifications and documentation.

Basic Structure

A tensor subclass needs to define a few basic methods: __new__, __init__, __tensor_flatten__, __tensor_unflatten__, and also dispatch functions for torch functions (__torch_function__) and aten ops (__torch_dispatch__).

Here is an example of basic structure:

# check out docs in https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L437
from torchao.utils import TorchAOBaseTensor

class MyDTypeLayout(TorchAOBaseTensor):
    # see tutorial code for details
    pass

class MyDtypeTensor(TorchAOBaseTensor):
    """We need to define `__new__` for constructing a new tensor subclass instance and `__init__` for initialize
    the instance. There is no requirement on what the argument list should look like here, only requirement is
    that `__new__` must return a Tensor instance with `torch.Tensor._make_wrapper_subclass(cls, shape, ...)` call
    """
    @staticmethod
    def __new__(
        cls,
        layout_tensor: MyDTypeLayout,
        shape: torch.Size,
        dtype: Optional[torch.dtype] = None,
    ):
        ...
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        layout_tensor: MyDTypeLayout,
        shape: torch.Size, ...
    ):
        self.layout_tensor = layout_tensor


    """`__tensor_flatten__` and `__tensor_unflatten__` are used to desugar the tensor into native Tensors/attributes and
    reconstruct the tensor subclass instance from the desugared tensor and attributes, these are required to define
    a Tensor subclass for torch.compile support
    """
    def __tensor_flatten__(self):
        return ["layout_tensor"], [self.shape]

    """see https://github.com/pytorch/pytorch/blob/3bc2004f9123a32f381ef64202252d59109507f3/torch/utils/_python_dispatch.py#L289 for documentations for outer_size and outer_stride
    """
    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        layout_tensor = tensor_data_dict["layout_tensor"]
        shape, = tensor_attributes
        return cls(
            layout_tensor,
            shape if outer_size is None else outer_size,
        )


    """classmethod that converts from a floating point Tensor (fp32/fp16/bf16) to the current dtype
    """
    @classmethod
    def from_float(
        cls,
        input_float: torch.Tensor,
    ):
        mapping_type = MappingType.SYMMETRIC
        block_size = input_float.shape
        dtype = torch.int16
        scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
        int_data = (input_float / scale).to(dtype)
        layout_tensor = MyDTypeLayout.from_plain(int_data, scale)
        return cls(layout_tensor, input_float.shape)
    

    """[Optional] see docs for `Layout/Packing` under `Quantized Tensors` section to understand what layout_type is
    """
    @property
    def layout_type(self) -> LayoutType:
        return self.layout_tensor.layout_type

    """There are two entry points that we can modify the behavior of a pytorch op: torch_function and torch_dispatch:
    
    __torch_function__: will be called whenever a torch level function is called on the Tensor object, for example: torch.nn.functional.linear,
    tensor.detach, tensor.reshape, tensor.t etc.

    __torch_dispatch__: will be called in the C++ dispatcher, when an aten operator is called on the Tensor object, for example:
    aten.mm, aten.addmm, aten.detach.default, aten.t.default etc.
    you can checkout https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L361-L389 to understand what `__torch_function__` and `__torch_dispatch__` are doing, but with `TorchAoBaseTensor` user can use
    some helper functions directly (see next section)

Operator Support

There are two types of operator support: torch functions and aten ops. For torch functions (e.g. torch.nn.functional.linear), we'll need to overwrite the __torch_function__ callback in the Tensor subclass; for aten ops (e.g. torch.ops.aten.mm), we'll need to overwrite the __torch_dispatch__ callback function.
For a new dtype, we'd like people to define the following decorator:

If your dtype class is inherited from `TorchAOBaseTensor`, you can do:

implements = my_dtype_tensor_cls.implements

And we can implement the operator dispatch with the following:

# Example for torch_function dispatch for torch.nn.functional.linear
def _quantized_linear_op(input_tensor, weight_tensor, bias):
    if isinstance(input_tensor, MyDtypeTensor):
        input_tensor = input_tensor.dequantize()
    if isinstance(weight_tensor, MyDtypeTensor):
        weight_tensor = weight_tensor.dequantize()
    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements(torch.nn.functional.linear)
def _(*args, **kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
    # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
    # make the branches easier to understand in `_quantized_linear_op`
    try:
        return _quantized_linear_op(input_tensor, weight_tensor, bias)
    except NotImplementedError:
        if isinstance(input_tensor, MyDtypeTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, MyDtypeTensor):
            weight_tensor = weight_tensor.dequantize()
        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

# Example for aten op dispatch for aten.detach.default
@implements(aten.detach.default)
def _(func, *args, **kwargs):
    # `return_and_correct_aliasing` should be used by wrapper tensor ``__torch_dispatch__`` subclasses that would like to 
    # work with torch.compile. It ensures that the subclass properly implements the aliasing behavior of every op, 
    # which is needed for correctness in AOTAutograd.

    # `_apply_fn_to_data` just applies the function to the tensor data in `args[0]`, `args[0]` is a tensor subclass
    # of `my_dtype`
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )

What ops do we need to overwrite? This depends on the model we are trying to quantize; commonly overwritten ops are:
  • __torch_function__: torch.nn.functional.linear
  • __torch_dispatch__: torch.ops.aten.addmm.default, torch.ops.aten.mm.default, torch.ops.aten.detach.default, torch.ops.aten.t.default

You can also find the ops that can be overwritten in __torch_function__ or __torch_dispatch__ with the following code. Start with a model that you want to optimize, overwrite just the important ops like linear, and gradually expand the coverage until the tests run and you get the expected optimized generated code (see the Optimized Operators section for more details):

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) + x

m = M()
example_inputs = (torch.randn(10, 10),)

from torch.overrides import TorchFunctionMode
class TorchFunctionLoggingMode(TorchFunctionMode):
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print(f"TORCH_FUNC={str(func)}")
        return func(*args, **kwargs)

with TorchFunctionLoggingMode():
    m(*example_inputs)

## Example output
# TORCH_FUNC=<built-in function linear>
# TORCH_FUNC=<method 'add' of 'torch._C.TensorBase' objects>


from torch.utils._python_dispatch import TorchDispatchMode
class TorchDispatchLoggingMode(TorchDispatchMode):
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print(f"ATEN_FUNC={str(func)}")
        return func(*args, **kwargs)

with TorchDispatchLoggingMode():
    m(*example_inputs)

## Example output
# ATEN_FUNC=aten.t.default
# ATEN_FUNC=aten.addmm.default
# ATEN_FUNC=aten.add.Tensor

# or a more polished logging for torch_dispatch (aten) ops: https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py

Alternatively, you can run a test example (e.g. use your quantized model with tensor parallelism, FSDP etc.) and discover the missing ops and add them until the test passes.

We are still working on a table that lists, for each feature, the operators that need to be supported.

Adding Efficient Kernels

Custom triton kernels

Custom triton kernels can be implemented and registered in https://github.com/pytorch/ao/tree/main/torchao/kernel
Implementation Example:

def int_scaled_matmul_kernel(a, b, scales1, c, config):
    M, K = a.shape
    K, N = b.shape
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    scaled_matmul_kernel_with_block_pointers[grid](
        a,
        b,
        c,
        scales1,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        scales1.stride(0),
        scales1.stride(1),
        num_warps=config.num_warps,
        num_stages=config.num_stages,
        num_ctas=config.num_ctas,
        EVEN_K=(K % 2 == 0),
        **config.kwargs,
    )
    return c

Register as a custom op:

@torch.library.impl(lib, "int_scaled_matmul", "Meta")
def int_scaled_matmul_meta(a, b, scales1):
    M, K = a.shape
    K, N = b.shape
    return torch.empty((M, N), device=a.device, dtype=scales1.dtype)

@torch.library.impl(lib, "int_scaled_matmul", "CUDA")
def int_scaled_matmul_cuda(a, b, scales1):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    # Allocate output.
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=scales1.dtype)
    # 1D launch kernel where each block gets its own program.
    best_config = get_best_config_fn(
        int_scaled_matmul_kernel, [a, b, scales1, c], int8_mm_kernel_configs
    )
    return int_scaled_matmul_kernel(a, b, scales1, c, best_config)

@torch.library.impl(lib, "int_scaled_matmul", "CPU")
def int_scaled_matmul_cpu(a, b, scales1):
    c = torch._int_mm(a, b)
    return c.to(scales1.dtype) * scales1

You may need to define your own autotuner as well.

Custom hand written kernels

Custom kernels (implementations) for cpu/cuda/mps can be implemented through https://github.com/pytorch/ao/tree/main/torchao/csrc (e.g. int4 cuda), and are accessible through torch.ops.my_custom_op

Dispatches

For dispatching to optimized kernels for cpu/cuda/mps devices, we can add checks for the dispatch conditions in __torch_function__ or __torch_dispatch__ and dispatch to the target operators, for example:

# from ao/torchao/dtypes/aqt.py, lines 348 to 355 at commit cbc74ee
if (
    is_cuda and
    weight_is_uint4 and
    weight_qtensor.dtype == torch.bfloat16 and
    len(weight_qtensor.shape) == 2 and
    weight_qtensor.block_size[0] == 1 and
    weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT
):
    ...

Specifically for AffineQuantizedTensor, we also allow people to extend the quantized linear to use a new efficient kernel or implementation by defining two functions:
dispatch_condition (defines the condition to dispatch to the kernel) and impl (actual implementation that takes activation, (quantized) weight and bias Tensors and runs the efficient kernel), both taking input_tensor, weight_tensor, bias as arguments. They can be registered into the quantized linear dispatch of AffineQuantizedTensor with register_aqt_quantized_linear_dispatch. Here is an example showing how it works:

def test_register_new_dispatch(self):
    from torchao.dtypes.affine_quantized_tensor import (
        register_aqt_quantized_linear_dispatch,
        deregister_aqt_quantized_linear_dispatch,
    )
    from torchao.dtypes import to_affine_quantized_intx
    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization.quant_primitives import MappingType

    def dispatch_condition(input_tensor, weight_tensor, bias):
        return (
            isinstance(weight_tensor, AffineQuantizedTensor) and
            weight_tensor.quant_min == 0 and
            weight_tensor.quant_max == 2**6 - 1
        )

    def impl(input_tensor, weight_tensor, bias):
        # this is just for testing, normally people will call into a uint6 weight only
        # quantized linear operator here
        assert False, "dispatching to my impl for uint6 weight only quant"

    register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

Packing/Layout

Sometimes the quantized weights have to be packed in order to yield optimal performance. For this we want to extend the "layout" concept in Tensor and introduce an indirection for tensor data storage, see #278 for more details.

Here is an example (see notebook for full code):

# 1. define a base layout for your dtype
class MyDTypeLayout(torch.Tensor):
    """
    Base class for the layout tensor for `MyDTypeTensor`
    """

    # this should be set for each layout class during registration
    extended_layout: Optional[str] = None

    # get the original unpacked Tensors
    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.int_data, self.scale

    # how to get the layout tensor from plain tensors
    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
    ):
        pass

# 2. define and register new layout

# if `MyDTypeTensor` is inherited from `TorchAoBaseTensor`, we can use classmethod `register_layout_cls` and
# `get_layout_tensor_constructor` directly
register_layout_cls = MyDTypeTensor.register_layout_cls
get_layout_tensor_constructor = MyDTypeTensor.get_layout_tensor_constructor

@register_layout_cls("plain")
class MyDTypePlainLayout(MyDTypeLayout):
    def __new__(cls, ...):
        pass

    def __init__(self, ...):
        pass

    @classmethod
    def __tensor_flatten__(self):
        pass

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
        pass

    @classmethod
    def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
        packed = pack(int_data, scale, zero_point, inner_k_tiles)
        return cls(packed, ...)


# 3. use the layout tensor in the original tensor subclass

class MyDtypeTensor(torch.Tensor):
    @classmethod
    def from_float(
        cls,
        input_float: torch.Tensor,
        extended_layout: str = "plain",
    ):
        # int_data and scale are assumed to be computed from input_float as in the earlier from_float example
        layout_tensor_ctr = get_layout_tensor_constructor(extended_layout)
        layout_tensor = layout_tensor_ctr(int_data, scale)
        return cls(layout_tensor, input_float.shape)

Flow

After the tensor subclass is implemented, we can also wrap that into factory functions, e.g.

# convert from floating point tensor to affine quantized tensor
to_my_dtype = MyDTypeTensor.from_float

For the model level API, people can reuse torchao.quantization.quantize_, which applies a tensor subclass conversion to the weight of linear modules and supports a filtering function: https://github.com/pytorch/ao/blob/aeee551b15eebeaabf98ffab9a00addc675a12a9/torchao/quantization/quant_api.py (TODO: replace this with the torchao doc website link when that's ready)
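
A hedged usage sketch of quantize_ with one of the built-in conversion functions plus a filter function (the filter function signature (module, fqn) -> bool reflects the current quantize_ API and may change across versions; for a custom dtype, the tutorial linked above shows how to wrap to_my_dtype into such a conversion function):

import torch
from torchao.quantization import quantize_, int4_weight_only

model = torch.nn.Sequential(
    torch.nn.Linear(256, 256, bias=False),
    torch.nn.Linear(256, 256, bias=False),
).to("cuda", torch.bfloat16)

# quantize every nn.Linear except the second one, as an example of filtering
def filter_fn(module: torch.nn.Module, fqn: str) -> bool:
    return isinstance(module, torch.nn.Linear) and fqn != "1"

quantize_(model, int4_weight_only(), filter_fn)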

See Quantization Algorithms/Flows section for examples of weight only/dynamic quant/static quant and other types of model level APIs based on the factory function.

Using torch.compile for Performance

Note: for PyTorch 2.4 and below, we need to use the following in order to be compatible with torch.compile:

from torchao.quantization.utils import unwrap_tensor_subclass
m_unwrapped = unwrap_tensor_subclass(m)

To aim for performance optimization, we should run through torch.compile with fullgraph mode first and remove any unnecessary graph breaks. You can add TORCH_LOGS="output_code" when you run the script in order to see the inductor generated code, e.g. TORCH_LOGS="output_code" python example.py

model = torch.compile(model, mode="max-autotune", fullgraph=True)

Serialization

This test shows how we expect save/load to work for a model quantized with the tensor subclass based API:

import tempfile
import torch
from torchao.quantization import quantize_, int8_weight_only

m = ToyLinearModel().eval().to(torch.bfloat16)
example_inputs = m.example_inputs(dtype=torch.bfloat16)

quantize_(m, int8_weight_only())
ref = m(*example_inputs)
with tempfile.NamedTemporaryFile() as f:
    torch.save(m.state_dict(), f)
    f.seek(0)
    state_dict = torch.load(f)

# at load time, we can initialize the model in meta device (to avoid memory cost), 
# then use assign=True to load a quantized state_dict
with torch.device("meta"):
    m_load_time = ToyLinearModel()

m_load_time.load_state_dict(state_dict, assign=True)

res = m_load_time(*example_inputs)
assert torch.equal(res, ref)

You can check out the serialization doc for more details.

Note: we are also integrated with Hugging Face and support serialization/deserialization through the Hugging Face save_pretrained/push_to_hub/from_pretrained APIs, available after huggingface/transformers#33456 is landed.

Other Feature Support

The above only covers basic feature support; we also provide examples on how to add support for training, tensor parallel, FSDP etc. by extending MyDTypeTensor, and we'll put more examples in the developer_api_guide folder covering the following use cases.

General Guide on Extending torchao

For a new use case, for example a training dtype (like fp4 training), it's fine to start by adding a new tensor subclass in the prototype folder https://github.com/pytorch/ao/tree/main/torchao/prototype, but you could also take a look at AffineQuantizedTensor if what you want to do is mostly supported there, e.g. adding an int3 kernel for the exact same affine quantization. Please feel free to open an issue if you have questions on what to do for a specific new use case.

To contribute to existing code base:

Tensor Subclass Functionality/Composability Testing

We are also working on test suites to test the functionality of tensor subclasses and their composability with different systems like torch.compile, DTensor etc.:

Kernel Microbenchmarks

Before we test performance on models, we can also do some microbenchmarks on a single linear operator (or other compute intensive/memory intensive operators) with different input dimensions to get a sense of the speedup. For a specific kernel that you'd like to benchmark, you can create a benchmark file like https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_aq.py and run the benchmark with the shapes that are important for the target model. A quick way to get the relevant shapes for linear and other ops is by running the example with this:

Change the model to the one you are interested in optimizing, and run the following:

python tutorials/developer_api_guide/print_op_and_shapes.py

Example output:

TORCH_FUNC=<built-in function linear> (M, K, N): 10 10 10
TORCH_FUNC=<method 'add' of 'torch._C.TensorBase' objects> args[0] shape: torch.Size([10, 10])

all linear shapes (M, K, N): [(10, 10, 10)]

The output of all linear shapes can be copy pasted into a microbenchmarking script under benchmarks/benchmark_your_kernel.py for benchmarking.

For benchmark helper functions, right now we have benchmark_model(model, num_runs, args=(), kwargs=None, device_type=None) and benchmark_torch_function_in_microseconds(f, *args, **kwargs); feel free to use either one for now, but we'll probably keep one in the future.
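
A hedged usage sketch of the first helper (the import path is an assumption and the toy model is made up; check the helper's docstring for the exact return units):

import torch
from torchao.utils import benchmark_model  # import location assumed

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16).eval()
example_inputs = (torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16),)
model = torch.compile(model, mode="max-autotune", fullgraph=True)

benchmark_model(model, 5, example_inputs)              # warmup
avg_time = benchmark_model(model, 20, example_inputs)  # timed runs
print("average time per iteration:", avg_time)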

Model Benchmarks and Eval

After you have the quantization flow implemented, you can run benchmarks and eval on the llama (llama2/llama3) or sam models, which are already modified to be friendly to torch.compile, and compare with the existing techniques in torchao.

Note: the llama model (llama2/llama3) is our representative memory bound model and sam is our representative compute bound model.

Please check out the --help option of each script to understand the supported options, e.g. you can use --profile=profile_path to get a chrome trace of the run for a detailed breakdown: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-tracing-functionality

Please let us know if there are any new important models that make sense to be added to the torchao model benchmark/eval folder.

jerryzh168 added the rfc label Jun 18, 2024
@gau-nernst
Collaborator

Regarding 1, apart from the feedback I gave in #384, I'm starting to think of another alternative:

quantizer = Int4WeightOnlyQuantizer(groupsize=32)
quantizer.quantize(model)

But then this feels like the old api change_linear_weights_to_int4_woqtensors(model, groupsize=32), which we have moved away from. The current quantize() does feel somewhat more convenient.

Personally I don't really like a function returning a function, like the current int4wo and int8wo. Feels like having a proper class makes it cleaner (we can also inspect the quant hyperparams after instantiation) - as discussed in #384.

Another option is to expose apply_int4wo_quant() directly and the user should call functools.partial() on it (same effect as the current int4wo() implementation)

from functools import partial

quantize(model, partial(apply_int4wo_quant, groupsize=32))

Also, since the quantization is in-place, I think it's good to use quantize_() instead to clearly signal the in-place behavior.

@drisspg
Contributor

drisspg commented Jun 18, 2024

For the manual API, why have both a string and int4wo(group_size)? I think it would be cleaner to just have one version of this

@jeromeku
Collaborator

jeromeku commented Jun 18, 2024

Is there a tutorial or end-to-end example of how to compose these APIs to implement a non-trivial quantization method (e.g., AWQ, GPTQ, etc.) and specialized deployment layout (e.g., Marlin)? Basically a reference impl of how these tools can be used to facilitate the translation of research ideas to deployment-ready libraries.

If not, happy to work on one.

@jerryzh168
Contributor Author

jerryzh168 commented Jun 18, 2024

> Regarding 1, apart from the feedback I gave in #384, I'm starting to think of another alternative:
>
> quantizer = Int4WeightOnlyQuantizer(groupsize=32)
> quantizer.quantize(model)
>
> But then this feels like the old api change_linear_weights_to_int4_woqtensors(model, groupsize=32), which we have moved away from. The current quantize() does feel somewhat more convenient.
>
> Personally I don't really like a function returning a function, like the current int4wo and int8wo. Feels like having a proper class makes it cleaner (we can also inspect the quant hyperparams after instantiation) - as discussed in #384.
>
> Another option is to expose apply_int4wo_quant() directly and the user should call functools.partial() on it (same effect as the current int4wo() implementation)
>
> from functools import partial
>
> quantize(model, partial(apply_int4wo_quant, groupsize=32))
>
> Also, since the quantization is in-place, I think it's good to use quantize_() instead to clearly signal the in-place behavior.

The quantizer API is actually what I have been thinking about before as the "Unified Quantization API": https://github.com/pytorch/ao/blob/main/torchao/quantization/unified.py and these two APIs will cover most of the current quant flows. It's also used by the QAT prototype (class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer)). Personally I think we could use this so we have a unified experience for modeling users, but Christian has raised some concerns on this one, since he feels introducing classes is a bit overkill I think.

The partial function idea has been raised in our meetings before as well, but that also doesn't seem very straightforward to use.

For now I'm planning to just use quantize(model, int4_weight_only(groupsize=32)), but I'm open to changes in the future if there is more feedback on this API.

Also, in the ideal future I think we'd expect modeling users to just use autoquant and not worry about all these details.

@jerryzh168
Contributor Author

> For the manual API, why have both a string and int4wo(group_size)? I think it would be cleaner to just have one version of this

So the motivation for the string is that people don't need to import anything to use it, it's just a simple shortcut and we'll make sure to align the names

@jerryzh168
Contributor Author

jerryzh168 commented Jun 18, 2024

> Is there a tutorial or end-to-end example of how to compose these APIs to implement a non-trivial quantization method (e.g., AWQ, GPTQ, etc.) and specialized deployment layout (e.g., Marlin)? Basically a reference impl of how these tools can be used to facilitate the translation of research ideas to deployment-ready libraries.
>
> If not, happy to work on one.

Not yet. My understanding is that this doc talks about how we build the fundamental "dtype" for quantization; it can serve as a building block for more sophisticated quantization methods that can utilize the "dtype" as a data representation.

I'm planning to put up an example of static quant (with module swap) that could potentially help demonstrate how these other techniques (e.g. ones that require calibration etc.) can be implemented in similar ways. Please feel free to work on a tutorial showing what a real world end to end quantization example looks like, utilizing the "dtype" that we build with tensor subclasses in this doc.

We also plan to build out hqq with this design #255, cc @HDCharles, although this one also doesn't require calibration.
Also there is GPTQ, which could be refactored to use tensor subclasses and compose with AffineQuantizedTensor. The main thing for GPTQ is we were not sure if people are interested in using it, but it seems like we have some feedback saying this is important: #384 (comment), so maybe we could refactor it as well.

@drisspg
Contributor

drisspg commented Jun 18, 2024

> So the motivation for the string is that people don't need to import anything to use it, it's just a simple shortcut and we'll make sure to align the names

But they are already importing the quantize api right? Idk I tend to be in favor of verbosity, but this was a nit anyways so carry on

@jerryzh168
Contributor Author

> So the motivation for the string is that people don't need to import anything to use it, it's just a simple shortcut and we'll make sure to align the names
>
> But they are already importing the quantize api right? Idk I tend to be in favor of verbosity, but this was a nit anyways so carry on

Yeah, we are thinking of just removing these for now; it would be better for people to also see the docstrings for these things, and an extra import doesn't seem to be a big issue.

@vadimkantorov

About subclasses: I hope there will still be a way to (when needed) register custom fused kernels which do e.g. q-someop-dq in a fused way, without having separate kernel launches for q and dq. I know this type of graph matching is possible with torch.compile, but I hope that the explicit introduction of subclasses (seemingly used mainly for representational/expressiveness/dispatch purposes) will not make this more complicated.

Also, hoping that it will work nicely with profiling/tracing to know exactly what kernel is getting invoked and exactly where any q/dq is happening (especially for autoquant regimes).

This is kind of similar to what was originally done with quint8 dtype, right? (except now it will allow user-powered extension and dispatch is based on subclass type instead of dtype)

@jerryzh168
Contributor Author

> About subclasses: I hope there will still be a way to (when needed) register custom fused kernels which do e.g. q-someop-dq in a fused way, without having separate kernel launches for q and dq. I know this type of graph matching is possible with torch.compile, but I hope that the explicit introduction of subclasses (seemingly used mainly for representational/expressiveness/dispatch purposes) will not make this more complicated.

Yeah, I think we should still be able to register inductor fusion passes, but one thing here is that q/dq ops are no longer large ops in the torch.compile path; we are planning to keep them as smaller aten ops (sub/mul etc.) so they can participate in normal inductor optimization directly, so the optimization story will be a bit different for inductor/torch.compile I think.

However, we are preserving q/dq ops as high level ops for executorch (export path), since the current executorch backends need to work with the patterns like (dq -> fp32 op -> q), this is WIP in #434

> Also, hoping that it will work nicely with profiling/tracing to know exactly what kernel is getting invoked and exactly where any q/dq is happening (especially for autoquant regimes).

yeah we can definitely provide additional information on what kernel is picked for autoquant, cc @HDCharles

> This is kind of similar to what was originally done with quint8 dtype, right? (except now it will allow user-powered extension and dispatch is based on subclass type instead of dtype)

Yes, this is similar to quint8, except it's built in python with the tensor subclass extension point, which allows us to stay out of core and have faster iteration speed as well. For dispatch, I feel it could also continue to use dtype, after we sort out the dtype story: #442

jerryzh168 added a commit to jerryzh168/ao that referenced this issue Jul 2, 2024
Summary:
Addressing feedback for `quantize` API from pytorch#391 (comment)

this is an API that changes model inplace, so we want to change the name to reflect that. inplace model quantization is important especially for LLM since it will be hard to load the model to memory. we typically load the model to meta device and then load the quantized weight.

Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

@kimishpatel
Contributor

> However, we are preserving q/dq ops as high level ops for executorch (export path), since the current executorch backends need to work with the patterns like (dq -> fp32 op -> q), this is WIP in #434

Based on the example, it seems like it would be the property of DTypeTensor that decides whether to use q-dq or not, right?

@kimishpatel
Contributor

So what I understand from this proposal, as far as wrapping LayoutTensor and DTypeTensor is concerned, is that we have:

A. Static quantization (both activation and weights are quantized)
B. Dynamic quantization. Weight is quantized AOT, act quantized dynamically
C. Weight only quantization.

It is not clear how the proposed API addresses A, but I presume you have ideas, so I will assume it will work.

Tensor subclass, as I understand it, does/can do two things: 1) override the representation of the tensor, e.g. linear.weight changed from torch.Tensor to DTypeTensor, and 2) change the dispatch behavior to dictate how an op with DTypeTensor should be executed.
DTypeTensor seems well suited for 1, but 2, which dictates the execution semantics of an op with DTypeTensor in its args, has a conflict with B and C. What I mean by that is that a 4-bit DTypeTensor, with whatever layout, can do both B and C. If so, what would be the right design? Should we introduce yet another tensor subclass like WeightOnlyQuantizedTensor(DTypeTensor) and have a DynamicQuantWeightTensor that will dynamically quantize the activation tensor? Or add more args to DTypeTensor, e.g. DTypeTensor.quant_type: Enum('dynamic', 'weight_only', 'static')? Given there aren't many varieties in between static and dynamic act quantization, I would be ok if we "suggest" the arg based approach.

On the DTypeLayout: I feel that each backend or kernel that has its own special layout for execution should have its own tensor subclass; however, this can also result in proliferation, e.g. DTypeLayoutCUDA, DTypeLayoutCUDAMySpecialPacking, DTypeLayoutMetalDefault etc. I actually liked the PT2E workflow in this regard, where the representation was canonical and the execution semantics, arising from weight packing etc., were done as a separate transform. If I were to think of the same here, then I would say for 4-bit there is DTypeTensor and DTypeDefaultLayout, and subsequent transforms can replace the tensor subclass with their backend specific tensor subclass.

Separate from the above: for the comment on using q-dq based dispatch vs. a fused op, I think we can allow overriding behavior where users can plug in their own implementation, including custom fused ops, for a specific DTypeTensor subclass that uses a specific DTypeLayout tensor.

@jerryzh168
Contributor Author

jerryzh168 commented Jul 3, 2024

> However, we are preserving q/dq ops as high level ops for executorch (export path), since the current executorch backends need to work with the patterns like (dq -> fp32 op -> q), this is WIP in #434
>
> Based on the example, it seems like it would be the property of DTypeTensor that decides whether to use q-dq or not, right?

yeah this is correct

> static quantization

yeah working on an example for this right now

> dynamic quantization

I should probably add more docs for this one. Right now it's implemented by applying a LinearActQuantizedTensor (which stores an input_quant_func and the original weight) on top of an affine quantized tensor:

weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
weight = to_linear_act_quantized(weight, input_quant_func)

In LinearActQuantizedTensor, when dispatching to the linear op (if isinstance(weight_tensor, LinearActQuantizedTensor):), we'll apply input_quant_func to the input and then continue the dispatch. And in the AffineQuantizedTensor dispatch, the implementation is picked based on the types of input and weight, which I think is not distinguishable from the final dispatch of static quant:

if isinstance(weight_qtensor, AffineQuantizedTensor):
    weight_is_int8 = _aqt_is_int8(weight_qtensor)
    weight_is_uint4 = _aqt_is_uint4(weight_qtensor)
    if isinstance(input_tensor, AffineQuantizedTensor):
        ...

Also, I want to highlight that dynamic quant and static quant are not considered purely a dtype problem, since they also involve flows (how do I convert my model to use these quantized tensors?). I'm also working on giving more details/examples of how to do that as well.

> DTypeLayout

  1. I feel we could still produce a canonical representation for executorch, e.g. we can introduce a "canonical" (name TBD) layout that will use q/dq etc. without any optimizations and rely on backend lowering to do weight packing
  2. For "DTypeDefaultLayout and subsequent transforms can replace the tensor subclass with their backend specific tensor subclass": yes, I think this should be implemented under this API: default_layout_tensor.to(extended_layout="my_optimized_packing_format"). Right now we haven't implemented this part, but that's what we can do following the current design

> For the comment on using q-dq based dispatch vs. a fused op, I think we can allow overriding behavior where users can plug in their own implementation, including custom fused ops,

Yeah I think so, users should be able to customize what they would like by implementing a new LayoutTensor type, although I guess the difference here is that the user has to reason through different dispatch layers to figure out the final representation they will see in the end, like in the dynamic quant example.

@kimishpatel
Contributor

> I feel we could still produce a canonical representation for executorch, e.g. we can introduce a "canonical" (name TBD) layout that will use q/dq etc. without any optimizations and rely on backend lowering to do weight packing

@jerryzh168 please note that my questions/responses are not motivated by whether it works for executorch or not. My comment on canonical representation was to borrow the same concept from PT2E, where quantization and execution of quantized ops are separated. In the APIs currently proposed, that is not the case, and that's what I was highlighting.

@kimishpatel
Contributor

> I feel we could still produce a canonical representation for executorch, e.g. we can introduce a "canonical" (name TBD) layout that will use q/dq etc. without any optimizations and rely on backend lowering to do weight packing
>
> @jerryzh168 please note that my questions/responses are not motivated by whether it works for executorch or not. My comment on canonical representation was to borrow the same concept from PT2E, where quantization and execution of quantized ops are separated. In the APIs currently proposed, that is not the case, and that's what I was highlighting.

And I mean this for eager mode, not for export. Basically, in an exported graph there is a) quant and b) lowering. What is the equivalent of that in the eager mode subclass based API, and would it be useful to have that?

msaroufim pushed a commit that referenced this issue Jul 4, 2024
@jerryzh168
Contributor Author

jerryzh168 commented Jul 8, 2024

@kimishpatel I see, yeah I think the separation of quant and lowering makes sense for the executorch stack, but for eager it is not really applicable, since in eager people would just expect to quantize a model and get acceleration; requiring the eager mode use case to do an extra lowering step seems to change the UX for eager mode? What do you think?

dbyoung18 pushed a commit to dbyoung18/ao that referenced this issue Jul 31, 2024
jerryzh168 added a commit to jerryzh168/ao that referenced this issue Aug 1, 2024
Summary:
Moved notebook: https://colab.research.google.com/drive/1jqC53MwiW9dSiPS-a6hng_yo0ywdc3nH#scrollTo=Aj9ii4darSRA from pytorch#391 to `tutorials` folder
so that code can be executed while we develop new APIs/utils and being kept up to date

Test Plan:
python

Reviewers:
python tutorials/developer_api_guide.py

regression tests:
python test/quantization/test_quant_api.py
python test/integration/test_integraton.py

msaroufim pushed a commit that referenced this issue Aug 2, 2024
jainapurva pushed a commit that referenced this issue Aug 7, 2024
jerryzh168 changed the title from "[RFC] Tensor Subclass based Quantization API" to "[RFC] torchao Contributor Guide" on Sep 10, 2024
supriyar pinned this issue Sep 13, 2024
jerryzh168 added a commit to jerryzh168/ao that referenced this issue Nov 8, 2024
Summary:
1. updated torchao api reference for quantization to include the APIs we want to expose, renamed torchao/quantization/linear_activation_weight_observer.py
and removed the safe_int_mm and int_scaled_matmul from quant_primitives.py
2. added pytorch#391 to torchao docs

Test Plan:
CI

jerryzh168 added a commit to jerryzh168/ao that referenced this issue Nov 12, 2024
jerryzh168 added a commit that referenced this issue Nov 13, 2024