Skip to content

Clean up eager quant in llm_export #10684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/models/llama/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ python -m examples.models.llama.export_llama \
```

A few notes:
- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized asymmetrically or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and is asymmetric (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32 and is symmetric. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized symmetrically or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and is symmetric (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32 and is asymmetric. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
- To do channelwise quantization, specify group_size to 0. This works for both linear and embedding layers.

Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels.
Expand Down
323 changes: 53 additions & 270 deletions examples/models/llama/source_transformation/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@

from sentencepiece import SentencePieceProcessor

from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
MappingType,
quantize_,
)


try:
from fairseq2.nn.embedding import (
Expand Down Expand Up @@ -118,15 +127,6 @@ def quantize( # noqa C901
assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
bitwidth = int(matches[0][0])

from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
MappingType,
quantize_,
)
from torchao.utils import unwrap_tensor_subclass

with torch.no_grad():
# Computation dtype is fixed to fp32 in the implementation of quantize_, so
# no way to decouple checkpoint and computation dtype.
Expand All @@ -141,7 +141,6 @@ def quantize( # noqa C901
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
),
)
model = unwrap_tensor_subclass(model)
if verbose:
print("quantized model:", model)
return model
Expand All @@ -150,14 +149,17 @@ def quantize( # noqa C901
if group_size is None:
raise Exception("For 8da4w quantization, group size must be specified.")

from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.utils import unwrap_tensor_subclass

quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
model = unwrap_tensor_subclass(model)

quantize_(
model,
Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=(
PerAxis(0) if group_size == 0 else PerGroup(group_size)
),
weight_mapping_type=MappingType.SYMMETRIC,
),
)
# TODO: deal with checkpoint / computation dtype decoupling.

if verbose:
print("quantized model:", model)
return model
Expand Down Expand Up @@ -563,254 +565,32 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
)


#########################################################################
##### embedding table quantization ######


def replace_embedding_weight_only_grouped_int8_per_channel(
module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
):
for name, child in module.named_children():
# print(f"name: {name}")
if isinstance(child, nn.Embedding):
# print(f"{name, child}")
# print(f"weights size: {child.weight.size()}")
setattr(
module,
name,
QuantizedGroupEmbedding(
device=device,
vocab_size=child.weight.shape[0],
embedding_dim=child.weight.shape[1],
group_size=group_size,
dtype=child.weight.dtype,
packed=packed,
bitwidth=bitwidth,
),
)
else:
replace_embedding_weight_only_grouped_int8_per_channel(
child, device, bitwidth, group_size, packed
)


class EmbeddingQuantHandler(QuantHandler):
def __init__(
self,
mod,
device="cpu",
*,
bitwidth: int = 8,
group_size: Optional[int] = None,
packed=False,
precision: Optional[torch.dtype] = None,
):
if isinstance(packed, str):
packed = packed == "True"
self.mod = mod
self.device = device
self.group_size = group_size
self.bitwidth = bitwidth
self.packed = packed
# Dtype of the weights right before quantization.
self.precision = precision
if (bitwidth not in [2, 4]) and packed:
raise RuntimeError("pack only works with bitsize 2, 4")

@torch.no_grad()
def create_quantized_state_dict(self, packed=False) -> Dict:
cur_state_dict = self.mod.state_dict()

if self.bitwidth == 2:
range_min = -2
range_max = 1
elif self.bitwidth == 4:
range_min = -8
range_max = 7
elif self.bitwidth == 8:
range_min = -128
range_max = 127
else:
raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

for fqn, mod in self.mod.named_modules():
if isinstance(mod, nn.Embedding):
# print("****")
# print(f"Embedding identified: {fqn, mod}")
# print(f"weights size: {mod.weight.size()}")
# print(f"quantize {fqn}...")

print(
f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
)
weight, scales, _ = dynamically_quantize_per_channel(
(
mod.weight.to(dtype=self.precision)
if self.precision
else mod.weight
),
range_min,
range_max,
torch.int8,
self.group_size,
scales_dtype=mod.weight.dtype,
)

if packed:
if self.bitwidth == 2:
if weight.shape[-1] % 4 != 0:
raise RuntimeError("automatic padding not implemented yet")
weight_range_shifted = weight.add(2).view(torch.uint8)
weight_view = weight_range_shifted.view(
weight.shape[0], weight.shape[1] // 4, 4
)
weight_0 = weight_view[:, :, 0]
weight_1 = weight_view[:, :, 1] << 2
weight_2 = weight_view[:, :, 2] << 4
weight_3 = weight_view[:, :, 3] << 6
weight_packed = weight_0 + weight_1 + weight_2 + weight_3
weight = weight_packed
elif self.bitwidth == 4:
if weight.shape[-1] % 2 != 0:
raise RuntimeError("automatic padding not implemented yet")
weight_range_shifted = weight.add(8).view(torch.uint8)
weight_view = weight_range_shifted.view(
weight.shape[0], weight.shape[1] // 2, 2
)
weight_even = weight_view[:, :, 0] * 16 # left shift 4
weight_odd = weight_view[:, :, 1]
weight_packed = weight_even + weight_odd
weight = weight_packed

weight = weight.to(device=self.device)
scales = scales.to(device=self.device)
# Update state dict
cur_state_dict[f"{fqn}.weight"] = weight
# squeeze makes group_size=rowsize unidimensional
cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

return cur_state_dict

def convert_for_runtime(self) -> nn.Module:
replace_embedding_weight_only_grouped_int8_per_channel(
self.mod, self.device, self.bitwidth, self.group_size, self.packed
)
return self.mod

def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict(self.packed)
self.convert_for_runtime()
self.mod.load_state_dict(model_updated_state_dict, assign=True)
return self.mod


class QuantizedGroupEmbedding(torch.nn.Module):
def __init__(
self,
device,
vocab_size: int,
embedding_dim: int,
group_size: Optional[int] = None,
dtype=torch.half,
packed=False,
bitwidth: int = 8,
) -> None:
super().__init__()
if group_size is None or group_size == 0:
group_size = embedding_dim
self.group_size = group_size
self.dtype = dtype
self.packed = packed
self.bitwidth = bitwidth
if not packed:
self.register_buffer(
"weight",
torch.zeros(
(vocab_size, embedding_dim), dtype=torch.int8, device=device
),
)
else: # packed
if bitwidth == 2:
self.register_buffer(
"weight",
torch.zeros(
(vocab_size, embedding_dim // 4),
dtype=torch.uint8,
device=device,
),
)
elif bitwidth == 4:
self.register_buffer(
"weight",
torch.zeros(
(vocab_size, embedding_dim // 2),
dtype=torch.uint8,
device=device,
),
)

groups_per_row = (embedding_dim + group_size - 1) // group_size
if groups_per_row > 1:
self.register_buffer(
"scales",
torch.ones(
(vocab_size, groups_per_row), dtype=torch.float16, device=device
),
)
else:
self.register_buffer(
"scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
)

@torch.no_grad()
def forward(self, indices: torch.Tensor) -> torch.Tensor:
if not self.packed: # 8bit
return torch.ops.quantized_decomposed.embedding_byte.dtype(
self.weight, self.scales, None, -128, 127, indices, dtype=self.dtype
)
else: # packed
if self.bitwidth == 2:
return torch.ops.quantized_decomposed.embedding_2bit.dtype(
self.weight, self.scales, None, -2, 1, indices, dtype=self.dtype
)
############################ Source Transform Start #######################

# Remaining case (always return to make pyre happy)
assert self.bitwidth == 4
return torch.ops.quantized_decomposed.embedding_4bit.dtype(
self.weight, self.scales, None, -8, 7, indices, dtype=self.dtype
)

def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
use_torchao = args.embedding_quantize.startswith("torchao:")
if use_torchao:
quant_args = args.embedding_quantize.split(":")[1].split(",")
else:
quant_args = args.embedding_quantize.split(",")

############################ Source Transform Start #######################
bitwidth = int(quant_args[0])
group_size = quant_args[0]
if group_size in ["none", "None", "0"]:
group_size = 0
group_size = int(group_size)
is_symmetric = bool(quant_args[3]) if len(quant_args) > 2 else True

weight_dtype = getattr(torch, f"int{bitwidth}")
granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
mapping_type = MappingType.SYMMETRIC if is_symmetric else MappingType.ASYMMETRIC

def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
if args.embedding_quantize.startswith("torchao:"):
if use_torchao:
from torchao.experimental.quant_api import (
EmbeddingQuantizer,
SharedEmbeddingQuantizer,
)
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import MappingType

quant_args = args.embedding_quantize.split(":")[1].split(",")
if len(quant_args) == 2:
bitwidth, group_size = quant_args
is_asymmetric = True
else:
bitwidth, group_size, is_asymmetric = quant_args

if group_size in ["none", "None", "0"]:
group_size = 0

group_size = int(group_size)
bitwidth = int(bitwidth)
is_asymmetric = bool(is_asymmetric)
weight_dtype = getattr(torch, f"int{bitwidth}")
granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
mapping_type = (
MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC
)

def _torchao_embedding_quantizer(model):
with torch.no_grad():
Expand All @@ -831,20 +611,23 @@ def _torchao_embedding_quantizer(model):

return _torchao_embedding_quantizer

bitwidth, group_size = args.embedding_quantize.split(",")
if group_size == "none" or group_size == "None" or group_size == "0":
group_size = None
else:
group_size = int(group_size)
bitwidth = int(bitwidth)
torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
return lambda model: EmbeddingQuantHandler(
model,
bitwidth=bitwidth,
group_size=group_size,
packed=(bitwidth in [2, 4]),
precision=torch_dtype,
).quantized_model()
def _quantize_embedding(model):
assert weight_dtype in [
torch.int2,
torch.int4,
torch.int8,
], "Only 2, 4, or 8-bit embeddings are supported unless using torchao"
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=weight_dtype,
granularity=granularity,
mapping_type=mapping_type,
),
)
return model

return _quantize_embedding


def get_quant_weight_transform(
Expand Down
Loading