Commit
* Quantization draft
* Finish initial quantization API
* Docs
Showing 7 changed files with 528 additions and 8 deletions.
@@ -0,0 +1,161 @@

```elixir
defmodule Axon.Quantization do
  @moduledoc """
  Model quantization.

  Model quantization is a technique for reducing the memory footprint of
  a model by converting portions of a model to use quantized representations.
  Typically, these quantized representations are low-precision integers.

  This is an **experimental** API which implements weight-only quantization.
  The implementation in this module will convert dense layers in a large
  model to quantized variants. The only supported quantization type is
  `{:s, 8}`. Axon quantization is inference-only. Training is not currently
  supported.
  """
  alias Axon.Quantization.Layers
  alias Axon.Quantization.QTensor

  @doc """
  Quantizes a model and a model state.

  Given a model and model state, this function will rewrite all
  of the dense layers in the model to perform weight-only 8-bit
  integer versions of the same operation. It will also replace values
  for all dense kernels in the given model state with quantized
  tensors.
  """
  def quantize(%Axon{} = model, %Axon.ModelState{} = model_state) do
    quantized_model = quantize_model(model)
    quantized_model_state = quantize_model_state(model, model_state)
    {quantized_model, quantized_model_state}
  end
```
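A quick end-to-end sketch of `quantize/2` (hypothetical: the model, shapes, and layer names below are illustrative and not part of the diff):

```elixir
# Build and initialize a small MLP, then quantize it for inference.
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10)

{init_fn, _predict_fn} = Axon.build(model)
model_state = init_fn.(Nx.template({1, 784}, :f32), Axon.ModelState.empty())

{q_model, q_state} = Axon.Quantization.quantize(model, model_state)

# The quantized pair is used for inference exactly like the original.
input = Nx.broadcast(0.5, {1, 784})
Axon.predict(q_model, q_state, input)
```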
```elixir
  @doc """
  Replaces standard operations with quantized variants.

  The only supported conversion is to convert regular dense layers
  to a weight-only 8-bit integer variant. Note that this only replaces
  the properties of the model. If you have a pre-trained model state
  that you wish to quantize, refer to `Axon.Quantization.quantize_model_state/2`.

  All `:dense` layers in the model are replaced with
  `Axon.Quantization.weight_only_quantized_dense/3`.
  """
  def quantize_model(%Axon{} = model) do
    quantized_dense_rewriter = fn [%Axon{} = x], _output, units, use_bias ->
      weight_only_quantized_dense(x, units, use_bias: use_bias)
    end

    Axon.rewrite_nodes(model, fn
      %Axon.Node{op: :dense, meta: meta} ->
        &quantized_dense_rewriter.(&1, &2, meta[:units], meta[:use_bias])

      _ ->
        :skip
    end)
  end
```
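Continuing the sketch above, the graph rewrite alone is not enough: the rewritten dense layers expect `QTensor` kernels, so the state must be converted separately (hypothetical usage, reusing `model` and `model_state` from the earlier sketch):

```elixir
# Rewrites every :dense node; all other nodes pass through via :skip.
q_model = Axon.Quantization.quantize_model(model)

# Pair it with a quantized state before predicting.
q_state = Axon.Quantization.quantize_model_state(model, model_state)
```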
```elixir
  @doc """
  Returns a quantized model state.

  Given a model and a model state, this function will replace
  all dense layer kernels with a quantized version of the weight.

  Training is not currently supported, so all quantized layers are
  automatically frozen.
  """
  def quantize_model_state(model, model_state) do
    dense_layer_names =
      model
      |> Axon.properties()
      |> Enum.filter(fn {_, v} -> v == :dense end)
      |> Enum.map(fn {k, _} -> k end)
      |> MapSet.new()

    state =
      Enum.reduce(dense_layer_names, model_state, fn layer_name, state ->
        update_in(state, [Access.key!(:data), layer_name, "kernel"], &QTensor.from_tensor/1)
      end)

    Axon.ModelState.freeze(state, fn [name | _] ->
      MapSet.member?(dense_layer_names, name)
    end)
  end
```
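To make the effect concrete, a hedged sketch of inspecting the result (the `"dense_0"` key assumes Axon's default layer naming for the illustrative model above):

```elixir
q_state = Axon.Quantization.quantize_model_state(model, model_state)

# Dense kernels are now QTensor structs instead of plain tensors.
%Axon.Quantization.QTensor{} = q_state.data["dense_0"]["kernel"]

# Quantized layers are frozen, so they no longer appear as trainable.
Axon.ModelState.trainable_parameters(q_state)
```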
```elixir
  ## Layers

  @doc """
  Adds a weight-only quantized dense layer to the network.

  This is equivalent to a dense layer, but works on quantized
  weights for reducing model memory footprint.

  Compiles to `Axon.Quantization.Layers.weight_only_quantized_dense/4`.

  ## Options

    * `:name` - layer name.

    * `:kernel_initializer` - initializer for `kernel` weights.
      Defaults to `:glorot_uniform`.

    * `:bias_initializer` - initializer for `bias` weights. Defaults
      to `:zeros`.

    * `:use_bias` - whether the layer should add bias to the output.
      Defaults to `true`.
  """
  def weight_only_quantized_dense(x, units, opts \\ []) do
    opts =
      Keyword.validate!(opts, [
        :name,
        :meta,
        use_bias: true,
        kernel_initializer: :glorot_uniform,
        bias_initializer: :zeros
      ])

    meta =
      opts[:meta] ||
        %{}
        |> Map.put(:units, units)
        |> Map.put(:use_bias, opts[:use_bias])

    kernel_shape = &Axon.Shape.dense_kernel(&1, units)
    bias_shape = &Axon.Shape.dense_bias(&1, units)

    kernel =
      Axon.param("kernel", kernel_shape,
        initializer: fn shape, type, key ->
          # Resolve the initializer to a function, then quantize the
          # initialized tensor so the parameter is stored as a QTensor.
          fun =
            case opts[:kernel_initializer] do
              init when is_atom(init) ->
                apply(Axon.Initializers, init, [])

              fun when is_function(fun) ->
                fun
            end

          tensor =
            case fun do
              fun when is_function(fun, 2) ->
                fun.(shape, type)

              fun when is_function(fun, 3) ->
                fun.(shape, type, key)
            end

          QTensor.from_tensor(tensor)
        end
      )

    {inputs, op} =
      if opts[:use_bias] do
        bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer])
        {[x, kernel, bias], &Layers.weight_only_quantized_dense/4}
      else
        {[x, kernel], &Layers.weight_only_quantized_dense/3}
      end

    Axon.layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense)
  end
end
```
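The layer constructor can also be used directly when building a graph by hand, rather than rewriting an existing `:dense` node (a usage sketch with illustrative shapes and names):

```elixir
model =
  Axon.input("features", shape: {nil, 64})
  |> Axon.Quantization.weight_only_quantized_dense(32, name: "q_dense")

# Initialization runs the custom initializer above, so the kernel is
# created directly as a QTensor rather than converted after the fact.
{init_fn, _predict_fn} = Axon.build(model)
state = init_fn.(Nx.template({1, 64}, :f32), Axon.ModelState.empty())
```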
@@ -0,0 +1,43 @@

```elixir
defmodule Axon.Quantization.Layers do
  @moduledoc """
  Quantized Layer Implementations.
  """
  alias Axon.Quantization.QTensor
  import Nx.Defn

  @doc """
  Weight-only quantized version of a dense layer.

  It expects the input kernel to be an `Axon.Quantization.QTensor`.
  """
  deftransform weight_only_quantized_dense(input, kernel, bias \\ 0, opts \\ []) do
    # The third positional argument may be a bias or the options list,
    # since bias is optional; normalize both into {bias, opts}.
    {bias, opts} =
      case bias do
        %Nx.Tensor{} = bias ->
          {bias, opts}

        bias when is_number(bias) ->
          {bias, opts}

        opts when is_list(opts) ->
          {Nx.tensor(0), opts}

        other ->
          raise ArgumentError, "invalid bias, expected a tensor, got #{inspect(other)}"
      end

    weight_only_quantized_dense_impl(input, kernel, bias, opts)
  end

  defnp weight_only_quantized_dense_impl(
          input,
          %QTensor{value: kernel, scale: scale},
          bias,
          _opts
        ) do
    input
    |> Nx.dot([Nx.rank(input) - 1], Nx.as_type(kernel, Nx.type(input)), [0])
    |> Nx.multiply(scale)
    |> Nx.add(bias)
  end
end
```
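A numeric sketch of the math in `weight_only_quantized_dense_impl/4`, using a simple per-tensor absmax scheme for illustration (an assumption for this example: `QTensor.from_tensor/1` may use a finer-grained scheme):

```elixir
key = Nx.Random.key(0)
{w, _key} = Nx.Random.uniform(key, -1.0, 1.0, shape: {4, 3})

# Symmetric int8 quantization: scale maps the largest |w| to 127.
scale = Nx.divide(Nx.reduce_max(Nx.abs(w)), 127)
q = w |> Nx.divide(scale) |> Nx.round() |> Nx.as_type({:s, 8})

# Same shape as the defnp body: upcast the int8 kernel, contract the
# last axis of the input with axis 0 of the kernel, then rescale.
x = Nx.iota({2, 4}, type: :f32)

x
|> Nx.dot([1], Nx.as_type(q, :f32), [0])
|> Nx.multiply(scale)
```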