Propagate: Sparsity Config ignores to top level
Add: Decorator to parse fused layers
Choose: Cutlass only when module is not ignored in sparsity_config

Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
rahul-tuli committed Jan 29, 2025
1 parent ce69f7f commit ca624cd
Showing 2 changed files with 91 additions and 28 deletions.
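
For context, the parsing added below consumes the sparsity_config block nested inside quantization_config in a model's config.json. A minimal sketch of that fragment, assuming the field names used in the diff (the format values and layer patterns are only illustrative):

quantization_config = {
    "format": "naive-quantized",        # illustrative quantization format
    "ignore": ["lm_head"],              # quantization ignore list
    "sparsity_config": {
        "format": "sparse-bitmask",     # illustrative sparsity format
        "targets": ["Linear"],          # layers the sparsity scheme applies to
        "ignore": ["re:.*down_proj"],   # layers excluded from sparsity (the new ignore list)
    },
}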
@@ -1,4 +1,5 @@
from typing import Any, Dict, List, Literal, Optional, cast
from contextlib import suppress
from typing import Any, Dict, List, Literal, Optional, Tuple, cast

import torch
from compressed_tensors.config import (CompressionFormat,
@@ -44,6 +45,7 @@ def __init__(
ignore: List[str],
quant_format: str,
sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
sparsity_ignore_list: List[str],
kv_cache_scheme: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
):
@@ -54,6 +56,7 @@
self.target_scheme_map = target_scheme_map
self.kv_cache_scheme = kv_cache_scheme
self.sparsity_scheme_map = sparsity_scheme_map
self.sparsity_ignore_list = sparsity_ignore_list
self.config = config

def get_linear_method(self) -> "CompressedTensorsLinearMethod":
@@ -98,36 +101,40 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(
config=config)
sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
config=config)

return cls(
target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
sparsity_scheme_map=sparsity_scheme_map,
sparsity_ignore_list=sparsity_ignore_list,
config=config,
)

@classmethod
def _sparsity_scheme_map_from_config(
cls, config: Dict[str,
Any]) -> Dict[str, SparsityCompressionConfig]:
def _parse_sparsity_config(
cls, config: Dict[str, Any]
) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
sparsity compression configurations
:return: A tuple with two elements
1. A dictionary mapping target layer names to their corresponding
sparsity_config
2. A list of layer names to ignore for sparsity
"""
if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)):
return dict()
return dict(), []

sparsity_config = SparsityCompressionConfig.model_validate(
sparsity_config)
sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
target: sparsity_config
for target in sparsity_config.targets or list()
}
return sparse_scheme_map
sparsity_ignore_list = sparsity_config.ignore or list()
return sparse_scheme_map, sparsity_ignore_list

@classmethod
def _quantization_scheme_map_from_config(
@@ -352,7 +359,6 @@ def get_scheme(self,
"""
compressed-tensors supports non uniform in the following way:
ignore: List of layer_names or nn.Module names to be ignored.
targets of config_groups: There can be N config_groups which each
have a quantization scheme. Each config_group has a list of targets
which can be a full layer_name, a regex for a layer_name, or
@@ -370,6 +376,8 @@
# need to make accelerate optional in ct to do this

# Will be empty for models with only sparsity
weight_quant = input_quant = None
sparsity_scheme: Optional[SparsityCompressionConfig] = None
if self.target_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
@@ -379,19 +387,24 @@
scheme_dict = self.target_scheme_map[matched_target]
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
elif self.sparsity_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_scheme_map.keys())
weight_quant = None
input_quant = None

# For models with sparsity, assumes that the sparse layers are also
# quantized for cutlass 2:4 support
sparsity_scheme: Optional[
SparsityCompressionConfig] = self.sparsity_scheme_map.get(
matched_target)
if self.sparsity_scheme_map:
is_ignored = False
with suppress(ValueError):
is_ignored = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_ignore_list)

# if the layer is in the sparsity ignore list,
# we should not apply any sparsity scheme

if not is_ignored:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_scheme_map.keys())
sparsity_scheme = self.sparsity_scheme_map.get(matched_target)

if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant,
@@ -419,6 +432,8 @@ def get_scheme(self,
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
logger.debug("Using scheme: %s for %s", scheme.__class__.__name__,
layer_name)
return scheme

def get_cache_scale(self, name: str) -> Optional[str]:
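Taken together, the changes to the first file split the sparsity config into a per-target scheme map plus an ignore list, and get_scheme now only considers a layer for a sparsity scheme (and hence the cutlass 2:4 path) when the layer is not on that ignore list. A rough standalone sketch of this flow, with the regex matching of the real find_matched_target simplified to a suffix check and all helper names hypothetical:

from typing import Any, Dict, List, Tuple

def parse_sparsity_config(
        config: Dict[str, Any]) -> Tuple[Dict[str, Dict[str, Any]], List[str]]:
    # Mirrors _parse_sparsity_config: map every target to the shared
    # sparsity config and surface the ignore list to the caller.
    sparsity_config = config.get("sparsity_config")
    if not sparsity_config:
        return {}, []
    scheme_map = {target: sparsity_config
                  for target in sparsity_config.get("targets", [])}
    return scheme_map, sparsity_config.get("ignore", [])

def wants_sparse_scheme(layer_name: str, scheme_map: Dict[str, Any],
                        ignore_list: List[str]) -> bool:
    # Simplified stand-in for the new ignore check in get_scheme: a layer on
    # the sparsity ignore list never receives a sparsity scheme, even if it
    # matches a sparsity target.
    if any(layer_name.endswith(pattern) for pattern in ignore_list):
        return False
    return bool(scheme_map)

scheme_map, ignore_list = parse_sparsity_config(
    {"sparsity_config": {"targets": ["Linear"], "ignore": ["down_proj"]}})
print(wants_sparse_scheme("model.layers.0.mlp.down_proj", scheme_map,
                          ignore_list))  # False: ignored for sparsity
print(wants_sparse_scheme("model.layers.0.mlp.up_proj", scheme_map,
                          ignore_list))  # True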
@@ -12,7 +12,7 @@ def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
CompressionFormat.naive_quantized.value,
CompressionFormat.int_quantized.value,
CompressionFormat.float_quantized.value
CompressionFormat.float_quantized.value,
]
return format in _ACTIVATION_QUANTIZATION_FORMATS

@@ -68,7 +68,7 @@ def should_ignore_layer(layer_name: Optional[str],
def check_equal_or_regex_match(layer_name: str,
targets: Iterable[str]) -> bool:
"""
Checks whether a layer_name is exactly equal or a regex match for
if target starts with 're:' to any target in list.
"""
for target in targets:
@@ -77,17 +77,64 @@ def check_equal_or_regex_match(layer_name: str,
return False


def _handle_fused_layers(func):
"""
Decorator to handle fused layers by mapping vllm fused layer names
to their corresponding unfused layer names for quantization/pruning schemes.
"""
# fused_layer_name -> unfused_layer_name
fused_layer_map = {
"qkv_proj": "q_proj",
"gate_up_proj": "up_proj",
}

def fused_layer_handler(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> Optional[str]:
"""
Wrapper function specifically designed to support the
find_matched_target function.
It handles cases where the provided layer name corresponds to a
fused layer in vllm, mapping it to its equivalent unfused layer name
based on the predefined fused_layer_map. If the original layer name
raises a ValueError in the wrapped function, this handler
will attempt to resolve the issue by substituting with unfused
layer name.
:param layer_name: Name of the layer, which may be fused.
:param module: An instance of torch.nn.Module.
:param targets: A list of target names or patterns to match.
:return: The result of the wrapped find_matched_target function with
the resolved layer name.
:raises ValueError: If the layer name cannot be resolved to a
valid target.
"""
try:
return func(layer_name, module, targets)
except ValueError:
if layer_name is None:
layer_name = ""
parent_name, fused_proj_name = layer_name.rsplit(".", 1)
unfused_proj_name = fused_layer_map.get(fused_proj_name,
fused_proj_name)
new_layer_name = f"{parent_name}.{unfused_proj_name}"
return func(new_layer_name, module, targets)

return fused_layer_handler


@_handle_fused_layers
def find_matched_target(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> str:
"""
Helper function to look up which "target" in the compressed-tensors
config that a layer corresponds to.
Recall that a compressed-tensors configs has a concept of
config_groups, where each layer can be quantized with with a different
scheme.
targets in each config_group will be a list of either layer names
(or regexes corresponding to layer names) or names of torch Modules.
First, we try to match the layer_name with a target
@@ -106,8 +153,9 @@ def find_matched_target(layer_name: Optional[str], module: Module,
True))

if matched_target is None:
raise ValueError(f"Unable to find matching target for {module} in the "
"compressed-tensors config.")
raise ValueError(
f"Unable to find matching target for {layer_name} in the "
"compressed-tensors config.")

return matched_target

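The decorator added in the second file lets find_matched_target fall back from vllm's fused module names to the unfused names that compressed-tensors configs actually target. A standalone sketch of that fallback, with matching simplified to a suffix check and a hypothetical layer name:

from typing import List

# fused vllm projection name -> representative unfused name (same mapping as the diff)
FUSED_LAYER_MAP = {"qkv_proj": "q_proj", "gate_up_proj": "up_proj"}

def match_with_fused_fallback(layer_name: str, targets: List[str]) -> str:
    def match(name: str) -> str:
        for target in targets:
            if name.endswith(target):
                return target
        raise ValueError(f"no matching target for {name}")

    try:
        return match(layer_name)
    except ValueError:
        # Retry with the representative unfused name, as _handle_fused_layers does.
        parent_name, proj_name = layer_name.rsplit(".", 1)
        unfused_name = FUSED_LAYER_MAP.get(proj_name, proj_name)
        return match(f"{parent_name}.{unfused_name}")

# vllm fuses q/k/v into qkv_proj, while compressed-tensors targets name q_proj:
print(match_with_fused_fallback("model.layers.0.self_attn.qkv_proj",
                                ["q_proj"]))  # -> "q_proj"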
