
Adding int4 quantized tensor subclass #15


Merged: 3 commits, Nov 28, 2023
64 changes: 45 additions & 19 deletions README.md
Relevant APIs can be found in `torchao.quantization.quant_api`.
Note: While these techniques are designed to improve model performance, in some cases the opposite can occur.
This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.

The following APIs use quantized [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor). By taking a linear op/module and replacing the original weight with a q-tensor subclass, we're able to convert it into a quantized version of the op. Upon replacement, these q-tensor subclasses quantize the original weight and override the dispatch for linear ops to instead use the subclass's `_quantized_op` method.

This tensor subclass method of quantization is preferred over older module-swap-based methods because it doesn't modify the graph and is generally more composable and flexible.
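
For illustration, here is a minimal, hypothetical sketch of the mechanism (not the torchao implementation; the class name, quantization scheme, and "quantized op" here are made up): a weight subclass quantizes its data on construction and overrides `__torch_function__` so that `F.linear` calls route through its own op instead of the floating point matmul.

```
import torch
import torch.nn.functional as F

class ToyInt8Weight(torch.Tensor):
    """Toy weight-only int8 tensor subclass that intercepts F.linear."""

    @staticmethod
    def __new__(cls, float_weight):
        # present the original shape/dtype to the rest of the framework
        return torch.Tensor._make_subclass(cls, float_weight, require_grad=False)

    def __init__(self, float_weight):
        # per-output-channel symmetric quantization to int8
        self.scale = float_weight.abs().amax(dim=1, keepdim=True) / 127.0
        self.int_data = torch.round(float_weight / self.scale).to(torch.int8)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else kwargs.get("bias")
            # the "quantized op": dequantize + matmul stands in for a real int8 kernel
            w_dq = w.int_data.to(x.dtype) * w.scale
            return F.linear(x, w_dq, bias)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

w = ToyInt8Weight(torch.randn(8, 16))
out = F.linear(torch.randn(2, 16), w)  # dispatches through __torch_function__
print(out.shape)  # torch.Size([2, 8])
```

The real subclasses below follow the same pattern but dispatch to actual int8/int4 kernels rather than dequantizing.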

### A8W8 Dynamic Quantization

The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.

Example

```
import torch
from torchao.quantization import quant_api

# some user model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

# convert linear modules to quantized linear modules
quant_api.change_linear_weights_to_int8_dqtensors(model)

# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
model(input)
```

This technique works best when the `torch._inductor.config.force_fuse_int_mm_with_mul` option is enabled. This allows fusion of the int8*int8 -> int32 matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor.
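
For example, set the flag before calling `torch.compile`:

```
import torch._inductor.config

# fuse the int8 matmul with the following mul (rescale) op during compilation
torch._inductor.config.force_fuse_int_mm_with_mul = True
```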


### A16W8 WeightOnly Quantization

The `change_linear_weights_to_int8_woqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8WeightOnlyQuantizedLinearWeight`. In practice this converts the floating point linear matmul of the original linear op to a weight-only quantized linear matmul.

Example

```
import torch
from torchao.quantization import quant_api

# some user model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

# convert linear modules to quantized linear modules
quant_api.change_linear_weights_to_int8_woqtensors(model)

# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
model(input)
```

This technique works best when the `torch._inductor.config.use_mixed_mm` option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor.
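
For example:

```
import torch._inductor.config

# fuse dequantization of the int8 weight into the matmul instead of
# materializing a full floating point weight tensor
torch._inductor.config.use_mixed_mm = True
```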


### A16W4 WeightOnly Quantization

The `change_linear_weights_to_int4_woqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int4WeightOnlyQuantizedLinearWeight`. In practice this converts the floating point linear matmul of the original linear op to a weight-only quantized linear matmul.

Example

```
# some user model and example input
...

# convert linear modules to quantized linear modules
quant_api.change_linear_weights_to_int4_woqtensors(model)

# compile the model to improve performance
...
```
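
For comparison with the examples above, a filled-in version of this snippet might look like the following; the layer shape, dtype, and device are illustrative assumptions rather than requirements taken from this repo (the int4 kernels generally favor large `bfloat16` linears on CUDA):

```
import torch
from torchao.quantization import quant_api

# illustrative model and input; int4 kernels generally favor large, evenly divisible shapes
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
input = torch.randn(32, 1024, dtype=torch.bfloat16, device='cuda')

# convert linear weights to the Int4WeightOnlyQuantizedLinearWeight subclass
quant_api.change_linear_weights_to_int4_woqtensors(model)

# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
model(input)
```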

The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.

## Other APIs

### Module Swap APIs

The `apply_dynamic_quant` and `apply_weight_only_int8_quant` APIs can be used in the same manner as above to achieve dynamic and weight-only quantization using module swaps instead of quantized tensor subclasses.
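
For example, mirroring the snippets above with the subclass conversion swapped for the module swap call:

```
import torch
from torchao.quantization import quant_api

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)

# module swap variants of the conversions above (use one or the other)
quant_api.apply_dynamic_quant(model)             # A8W8 dynamic quantization
# quant_api.apply_weight_only_int8_quant(model)  # A16W8 weight-only quantization
```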

### A8W8 Dynamic Quantization with Smoothquant

We've also implemented a version of [smoothquant](https://arxiv.org/abs/2211.10438) with the same GEMM format as above.
Because it requires calibration, the API is slightly more complicated, and it is currently only available as a module swap API.

Example

```
...
model = torch.compile(model, mode='max-autotune')
model(input)
```
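
Since the example above is truncated in this diff view, here is a fuller sketch of the calibration flow. It assumes `swap_linear_with_smooth_fq_linear` and `smooth_fq_linear_to_inference` helpers in `torchao.quantization.smoothquant`; the exact names, signatures, and calibration loop are assumptions, not confirmed by this diff.

```
import torch
from torchao.quantization.smoothquant import (
    swap_linear_with_smooth_fq_linear,   # assumed helper: insert observed smoothquant linears
    smooth_fq_linear_to_inference,       # assumed helper: freeze scales for inference
)

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)

# 1. swap in smoothquant linear modules that record activation statistics
swap_linear_with_smooth_fq_linear(model)

# 2. calibrate by running representative data through the model
for _ in range(8):
    model(torch.randn(32, 32, dtype=torch.bfloat16, device='cuda'))

# 3. freeze the collected scales and switch the modules to inference mode
smooth_fq_linear_to_inference(model)

# 4. compile and run as usual
model = torch.compile(model, mode='max-autotune')
model(torch.randn(32, 32, dtype=torch.bfloat16, device='cuda'))
```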

Like the other dynamic quantization APIs, enabling the `torch._inductor.config.force_fuse_int_mm_with_mul` option may significantly improve performance.

## License

`torchao` is released under the [BSD 3](https://github.com/pytorch-labs/ao/blob/main/LICENSE) license.