
Clean up eager quant in llm_export #10684

Open: wants to merge 18 commits into main
2 changes: 1 addition & 1 deletion .ci/scripts/test_llama.sh
@@ -229,7 +229,7 @@ if [[ "${CUSTOM}" == "ON" ]]; then
EXPORT_ARGS="${EXPORT_ARGS} model.use_sdpa_with_kv_cache=true"
fi
if [[ "${QE}" == "ON" ]]; then
EXPORT_ARGS="${EXPORT_ARGS} quantization.embedding_quantize=\"8,1024\""
EXPORT_ARGS="${EXPORT_ARGS} quantization.embedding_quantize=\"8,0\""
fi
if [[ "${MPS}" == "ON" ]]; then
EXPORT_ARGS="${EXPORT_ARGS} backend.mps.enabled=true model.enable_dynamic_shape=false debug.verbose=true"
19 changes: 6 additions & 13 deletions examples/apple/coreml/llama/export.py
@@ -19,7 +19,7 @@
replace_linear_with_split_linear,
)
from executorch.examples.models.llama.source_transformation.quantize import (
EmbeddingQuantHandler,
get_quant_embedding_transform,
)

from executorch.exir.backend.utils import format_delegated_graph
@@ -116,18 +116,10 @@ def main() -> None:
] # dtype for model/inputs

if export_args.embedding_quantize:
bitwidth, group_size = export_args.embedding_quantize.split(",")
if group_size == "none" or group_size == "None" or group_size == "0":
group_size = None
else:
group_size = int(group_size)
bitwidth = int(bitwidth)
model = EmbeddingQuantHandler(
model,
bitwidth=bitwidth,
group_size=group_size,
packed=(bitwidth in [2, 4]),
).quantized_model()
quantize_embedding = get_quant_embedding_transform(
export_args.embedding_quantize
)
quantize_embedding(model)

if export_args.target_split_size is not None:
replace_linear_with_split_linear(
@@ -230,6 +222,7 @@ def main() -> None:
],
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
do_quant_fusion_and_const_prop=True,
)
)

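Note on the `do_quant_fusion_and_const_prop=True` additions in this and the other exporters below: quant fusion and constant propagation now run inside `to_executorch` via the backend config instead of an explicit `QuantFusionPass` entry. A minimal sketch of the resulting call, assuming `edge_program` is the lowered `EdgeProgramManager` produced earlier in the script:

```python
# Sketch only; `edge_program` is assumed to come from to_edge_transform_and_lower(...).
from executorch.exir import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

executorch_program = edge_program.to_executorch(
    ExecutorchBackendConfig(
        memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
        sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
        # Replaces the explicit QuantFusionPass entry in `passes`.
        do_quant_fusion_and_const_prop=True,
    )
)
```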
3 changes: 2 additions & 1 deletion examples/models/llama/README.md
@@ -404,7 +404,8 @@ python -m extension.llm.export.export_llm \
```

A few notes:
- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `model.use_shared_embedding=True` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized asymmetrically or not by specifying a third argument. For example, `quantization.embedding_quantize="torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and is asymmetric (this is the default behavior if you simply use `quantization.embedding_quantize="torchao:4,32"`), whereas `quantization.embedding_quantize="torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32 and is symmetric. If `model.use_shared_embedding=True` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `model.use_shared_embedding=True` to take advantage of this and reduce memory. When this option is enabled, you can control whether embeddings are quantized symmetrically or asymmetrically with a third argument. For example, `quantization.embedding_quantize="torchao:4,32,true"` means that the embedding is quantized to 4 bits with group_size=32 and is symmetric (this is the default behavior if you simply use `quantization.embedding_quantize="torchao:4,32"`), whereas `quantization.embedding_quantize="torchao:4,32,false"` means that the embedding is quantized to 4 bits with group_size=32 and is asymmetric. If `model.use_shared_embedding=True` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.

- To do channelwise quantization, set group_size to 0. This works for both linear and embedding layers; a rough sketch of the mapping is shown after this list.

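As a rough illustration of how these argument strings map onto torchao configs (the helper below and its name are illustrative, not part of the exported CLI or of this PR):

```python
# Illustrative sketch only: applying a "<bitwidth>,<group_size>" embedding-quant
# spec with torchao. Helper name and signature are assumptions.
import torch
import torch.nn as nn
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, MappingType, quantize_


def quantize_embeddings(
    model: nn.Module, bitwidth: int, group_size: int, is_symmetric: bool = True
) -> nn.Module:
    # group_size == 0 selects channelwise quantization (one scale per embedding row).
    granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
    mapping = MappingType.SYMMETRIC if is_symmetric else MappingType.ASYMMETRIC
    quantize_(
        model,
        IntxWeightOnlyConfig(
            weight_dtype=getattr(torch, f"int{bitwidth}"),  # e.g. torch.int4
            granularity=granularity,
            mapping_type=mapping,
        ),
        lambda m, fqn: isinstance(m, nn.Embedding),  # only touch embedding layers
    )
    return model
```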
Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels.
3 changes: 2 additions & 1 deletion examples/models/llama/export_llama_lib.py
@@ -1342,7 +1342,8 @@ def _get_source_transforms( # noqa
"""
transforms.append(
get_quant_embedding_transform(
embedding_quantize, use_shared_embedding, checkpoint_dtype
embedding_quantize,
use_shared_embedding,
)
)

113 changes: 62 additions & 51 deletions examples/models/llama/source_transformation/quantize.py
@@ -16,6 +16,15 @@

from executorch.extension.llm.export.builder import DType

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
MappingType,
quantize_,
)
from torchao.utils import unwrap_tensor_subclass


try:
from fairseq2.nn.embedding import (
@@ -117,13 +126,6 @@ def quantize( # noqa C901
bitwidth = int(matches[0][0])

from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
MappingType,
quantize_,
)
from torchao.utils import unwrap_tensor_subclass

with torch.no_grad():
# Computation dtype is fixed to fp32 in the implementation of quantize_, so
@@ -148,14 +150,18 @@
# TODO: Default value for group size for 8da4w. Need this here for refactor, will clean this up.
group_size = 128

from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.utils import unwrap_tensor_subclass

quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
quantize_(
model,
Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=(
PerAxis(0) if group_size == 0 else PerGroup(group_size)
),
weight_mapping_type=MappingType.SYMMETRIC,
),
)
model = unwrap_tensor_subclass(model)

# TODO: deal with checkpoint / computation dtype decoupling.

if verbose:
print("quantized model:", model)
return model
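For reference, a minimal self-contained sketch of the 8-bit-dynamic-activation / int4-weight path introduced above (the helper name `quantize_8da4w` is an assumption, not code from the PR):

```python
# Sketch under the same config as the diff above; `quantize_8da4w` is a made-up name.
import torch
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass


def quantize_8da4w(model: torch.nn.Module, group_size: int = 128) -> torch.nn.Module:
    quantize_(
        model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerAxis(0) if group_size == 0 else PerGroup(group_size),
            weight_mapping_type=MappingType.SYMMETRIC,
        ),
    )
    # Flatten tensor subclasses so the quantized model can go through torch.export.
    return unwrap_tensor_subclass(model)
```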
@@ -733,34 +739,35 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
def get_quant_embedding_transform(
embedding_quantize: str,
use_shared_embedding: bool = False,
dtype_override: Optional[DType] = None,
):
if embedding_quantize.startswith("torchao:"):
use_torchao = embedding_quantize.startswith("torchao:")
if use_torchao:
quant_args = embedding_quantize.split(":")[1].split(",")
else:
quant_args = embedding_quantize.split(",")
assert len(quant_args) in [
2,
3,
], f"Expected 2 or 3 embedding quant_args, but got: {quant_args}"

bitwidth = int(quant_args[0])
group_size = quant_args[1]
if group_size in ["none", "None", "0"]:
group_size = 0
group_size = int(group_size)
is_symmetric = (
quant_args[2].lower() == "true" if len(quant_args) > 2 else True
)

weight_dtype = getattr(torch, f"int{bitwidth}")
granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
mapping_type = MappingType.SYMMETRIC if is_symmetric else MappingType.ASYMMETRIC

if use_torchao:
from torchao.experimental.quant_api import (
EmbeddingQuantizer,
SharedEmbeddingQuantizer,
)
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import MappingType

quant_args = embedding_quantize.split(":")[1].split(",")
if len(quant_args) == 2:
bitwidth, group_size = quant_args
is_asymmetric = True
else:
bitwidth, group_size, is_asymmetric = quant_args

if group_size in ["none", "None", "0"]:
group_size = 0

group_size = int(group_size)
bitwidth = int(bitwidth)
is_asymmetric = bool(is_asymmetric)
weight_dtype = getattr(torch, f"int{bitwidth}")
granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
mapping_type = (
MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC
)

def _torchao_embedding_quantizer(model):
with torch.no_grad():
@@ -769,7 +776,6 @@ def _torchao_embedding_quantizer(model):
weight_dtype=weight_dtype,
granularity=granularity,
mapping_type=mapping_type,
use_fallback=False,
).quantize(model)
else:
SharedEmbeddingQuantizer(
@@ -781,20 +787,25 @@

return _torchao_embedding_quantizer

bitwidth, group_size = embedding_quantize.split(",")
if group_size == "none" or group_size == "None" or group_size == "0":
group_size = None
else:
group_size = int(group_size)
bitwidth = int(bitwidth)
torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
return lambda model: EmbeddingQuantHandler(
model,
bitwidth=bitwidth,
group_size=group_size,
packed=(bitwidth in [2, 4]),
precision=torch_dtype,
).quantized_model()
def _embedding_quantizer(model):
assert weight_dtype in [
torch.int2,
torch.int4,
torch.int8,
], "Only 2, 4, or 8-bit embeddings are supported unless using torchao"
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=weight_dtype,
granularity=granularity,
mapping_type=mapping_type,
),
lambda m, fqn: isinstance(m, nn.Embedding),
)
model = unwrap_tensor_subclass(model)
return model

return _embedding_quantizer

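A quick usage sketch of the refactored transform, mirroring how the CoreML and LLaVA exporters call it (the toy module is an assumption for illustration; the real exporters pass the eager LLM module):

```python
import torch.nn as nn

# Toy stand-in for the exported language model (illustrative only).
model = nn.Sequential(nn.Embedding(32000, 2048), nn.Linear(2048, 32000))

# Plain form: 8-bit weights, channelwise (group_size 0).
model = get_quant_embedding_transform("8,0")(model)

# torchao-prefixed form: 4-bit, group_size 32; the optional third flag picks
# symmetric ("true", the default) vs asymmetric mapping.
# get_quant_embedding_transform("torchao:4,32,false")(model)
```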

def get_quant_weight_transform(
19 changes: 5 additions & 14 deletions examples/models/llava/export_llava.py
@@ -24,7 +24,7 @@
replace_kv_cache_with_custom_kv_cache,
)
from executorch.examples.models.llama.source_transformation.quantize import (
EmbeddingQuantHandler,
get_quant_embedding_transform,
get_quant_weight_transform,
)
from executorch.examples.models.llama.source_transformation.sdpa import (
@@ -38,7 +38,6 @@
)

from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import (
ConstraintBasedSymShapeEvalPass,
HintBasedSymShapeEvalPass,
@@ -178,15 +177,9 @@ def forward(self, images):


def export_token_embedding(llava, prompt):
def quant_embedding(model):
return EmbeddingQuantHandler(
model,
bitwidth=8,
group_size=32,
packed=False,
).quantized_model()

quantized_token_embed = quant_embedding(llava.model_.language_model.model)
quantized_token_embed = get_quant_embedding_transform("8,32")(
llava.model_.language_model.model,
)
token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
dynamic_shapes = [{1: token_dim_1}]
with torch.no_grad():
@@ -248,15 +241,13 @@ def export_all(llava_model: LlavaModel):
executorch_program = lowered_and_edge.to_executorch(
ExecutorchBackendConfig(
extract_delegate_segments=True,
passes=[
QuantFusionPass(),
],
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
sym_shape_eval_pass={
"image_encoder": ConstraintBasedSymShapeEvalPass(),
"text_model": ConstraintBasedSymShapeEvalPass(),
"token_embedding": HintBasedSymShapeEvalPass(),
},
do_quant_fusion_and_const_prop=True,
)
)
for execution_plan in executorch_program._emitter_output.program.execution_plan:
1 change: 1 addition & 0 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -440,6 +440,7 @@ def lowering_modules(
alloc_graph_output=False,
),
extract_delegate_segments=True,
do_quant_fusion_and_const_prop=True,
)
with torch.no_grad():
# backend option