
Add Embedding Quantization to QAT module_swap flow (pytorch#886)
Summary:
Pull Request resolved: pytorch#886

Adds the embedding quantizer in the same fashion as the existing module-swap setup.

Differential Revision: D62664322
TiRune authored and facebook-github-bot committed Sep 17, 2024
1 parent bd264f9 commit a087e50
Showing 3 changed files with 114 additions and 8 deletions.
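
For orientation, a minimal usage sketch of the new flow. The quantizer class, the quantize_embedding and embedding_groupsize arguments, and the prepare() entry point come from the diff below; the toy model, its dimensions, and the assumption that this prototype import path is unchanged in your checkout are illustrative only.

import torch
from torchao.quantization.prototype.qat._module_swap_api import (
    Int8DynActInt4WeightQATQuantizerModuleSwap,
)

# Toy model; widths chosen so the default group sizes divide evenly.
model = torch.nn.Sequential(
    torch.nn.Embedding(1000, 256),
    torch.nn.Linear(256, 256),
)

quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(
    quantize_embedding=True,   # new in this PR
    embedding_groupsize=32,    # new in this PR
)
model = quantizer.prepare(model)  # nn.Embedding children are swapped for Int4WeightQATEmbedding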
35 changes: 35 additions & 0 deletions torchao/quantization/GPTQ.py
@@ -965,6 +965,41 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.precision,
)


def _replace_embedding_4w(
module: torch.nn.Module,
groupsize: int,
embedding_class: Type[torch.nn.Module],
padding_allowed: bool,
copy_weights: bool = False,
):
#import the util function here to avoid circular dependency
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
return isinstance(child, nn.Embedding) and (_check_linear_int4_k(child.embedding_dim, groupsize) or padding_allowed)

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_embedding = embedding_class(
num_embeddings = child.num_embeddings,
embedding_dim = child.embedding_dim,
padding_idx = child.padding_idx,
max_norm = child.max_norm,
norm_type = child.norm_type,
scale_grad_by_freq = child.scale_grad_by_freq,
sparse = child.sparse,
device=child.weight.device,
groupsize=groupsize,
)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
# copy the weights, and doing so will result in an error
if copy_weights and child.weight.device != torch.device("meta"):
new_embedding.weight = child.weight
return new_embedding

_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)

def _replace_linear_8da4w(
module: torch.nn.Module,
groupsize: int,
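The helper above delegates the tree walk to _replace_with_custom_fn_if_matches_filter from torchao.quantization.quant_api. As a rough mental model only (this simplified sketch is not the library function; the real utility also threads the fully qualified name through the recursion and handles more cases), the pattern is:

import torch.nn as nn

def _replace_matching_children(module: nn.Module, replacement_fn, filter_fn, prefix: str = "") -> nn.Module:
    # Visit children recursively; swap any child the filter accepts.
    for name, child in list(module.named_children()):
        fqn = f"{prefix}.{name}" if prefix else name
        if filter_fn(child, fqn):
            setattr(module, name, replacement_fn(child))
        else:
            _replace_matching_children(child, replacement_fn, filter_fn, fqn)
    return module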
82 changes: 74 additions & 8 deletions torchao/quantization/prototype/qat/_module_swap_api.py
@@ -13,6 +13,7 @@
_check_linear_int4_k,
_replace_linear_int4,
_replace_linear_8da4w,
_replace_embedding_4w,
get_groupwise_affine_qparams,
groupwise_affine_quantize_tensor,
Int8DynActInt4WeightLinear,
@@ -28,6 +29,7 @@
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_get_qmin_qmax
)


@@ -47,6 +49,14 @@ class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantizer):
instead if possible.
"""

def __init__(self,
quantize_embedding: bool = False,
embedding_groupsize: int = 32,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.quantize_embedding = quantize_embedding
self.embedding_groupsize = embedding_groupsize

def prepare(
self,
model: torch.nn.Module,
@@ -62,6 +72,14 @@ def prepare(
Int8DynActInt4WeightQATLinear,
copy_weights=True,
)
if self.quantize_embedding:
_replace_embedding_4w(
model,
self.embedding_groupsize,
Int4WeightQATEmbedding,
self.padding_allowed,
copy_weights=True
)
return model

def convert(
@@ -92,7 +110,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):

# Load weights and qparams into quantized linear
n_bit = 4
-            (qmin, qmax) = child._get_qmin_qmax(n_bit)
+            (qmin, qmax) = _get_qmin_qmax(n_bit)
(s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
@@ -156,7 +174,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
(act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
x, self.scales_precision, self.zero_points_precision,
)
-            (act_qmin, act_qmax) = self._get_qmin_qmax(8)
+            (act_qmin, act_qmax) = _get_qmin_qmax(8)
x_fq = _fake_quantize_per_token(
x, act_scales, act_zp, act_qmin, act_qmax,
)
Expand All @@ -170,7 +188,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
# TODO: pass zp dtype to `get_group_qparams_symmetric` instead
weight_zp = weight_zp.to(self.zero_points_precision)
-            (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
+            (weight_qmin, weight_qmax) = _get_qmin_qmax(4)
w_fq = _fake_quantize_per_channel_group(
self.weight,
weight_scales,
@@ -183,11 +201,59 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
w_fq = self.weight
return F.linear(x_fq, w_fq)

-    # TODO: move this to common util
-    def _get_qmin_qmax(self, n_bit: int):
-        qmin = -(2 ** (n_bit - 1))
-        qmax = 2 ** (n_bit - 1) - 1
-        return (qmin, qmax)

class Int4WeightQATEmbedding(torch.nn.Embedding):
"""
    This module implements an embedding layer with int4 fake quantized grouped per-channel weights.
args:
        groupsize: the number of elements in each quantized group for the weights
scales_precision: precision of per group scales and zero points
"""

def __init__(self,
groupsize: int = 32,
scales_precision: torch.dtype = torch.float32,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.bit_width = 4
self.groupsize = groupsize
self.scales_precision = scales_precision
self.zero_points_precision = torch.int32
self._fake_quant_enabled = True

def forward(self, x):
weight = self.weight

if self._fake_quant_enabled:
(weight_scales, weight_zp) = get_group_qparams_symmetric(
self.weight, self.bit_width, self.groupsize, self.scales_precision,
)
# TODO: pass zp dtype to `get_group_qparams_symmetric` instead
weight_zp = weight_zp.to(self.zero_points_precision)
(weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
w_fq = _fake_quantize_per_channel_group(
self.weight,
weight_scales,
weight_zp,
weight_qmin,
weight_qmax,
self.groupsize,
)
else:
w_fq = self.weight

return torch.nn.functional.embedding(
x, w_fq, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)

def enable_fake_quant(self, enabled: bool = True):
self._fake_quant_enabled = enabled

def disable_fake_quant(self):
self.enable_fake_quant(False)


def enable_8da4w_fake_quant_module_swap(mod: torch.nn.Module):
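For a quick standalone sanity check of the new embedding module, a hedged sketch (the constructor signature and the enable/disable helpers come from the class above; the vocabulary size, embedding width, and batch shape are arbitrary):

import torch
from torchao.quantization.prototype.qat._module_swap_api import Int4WeightQATEmbedding

# groupsize and scales_precision come first; the remaining kwargs go to nn.Embedding.
emb = Int4WeightQATEmbedding(32, torch.float32, num_embeddings=1000, embedding_dim=128)

tokens = torch.randint(0, 1000, (2, 16))
out_fq = emb(tokens)        # forward through per-group fake-quantized weights

emb.disable_fake_quant()    # fall back to the plain float weights
out_fp = emb(tokens)
emb.enable_fake_quant()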
5 changes: 5 additions & 0 deletions torchao/quantization/prototype/qat/utils.py
@@ -259,3 +259,8 @@ def insert_subclass(lin):
return lin

return insert_subclass

def _get_qmin_qmax(n_bit: int):
qmin = -(2 ** (n_bit - 1))
qmax = 2 ** (n_bit - 1) - 1
return (qmin, qmax)
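
For reference, the helper just returns the signed integer range of an n-bit symmetric scheme; plugging in the two bit widths used above:

_get_qmin_qmax(4)   # (-8, 7), the int4 weight range
_get_qmin_qmax(8)   # (-128, 127), the int8 activation range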
