diff --git a/examples/multimodal_vision/llava_example.py b/examples/multimodal_vision/llava_example.py new file mode 100644 index 000000000..c86cf0dfe --- /dev/null +++ b/examples/multimodal_vision/llava_example.py @@ -0,0 +1,54 @@ +from transformers import AutoProcessor + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration +from llmcompressor.transformers.utils.data_collator import llava_data_collator + +# Load model. +model_id = "llava-hf/llava-1.5-7b-hf" +model = TraceableLlavaForConditionalGeneration.from_pretrained( + model_id, device_map="auto", torch_dtype="auto" +) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +# Oneshot arguments +DATASET_ID = "flickr30k" +DATASET_SPLIT = {"calibration": "test[:512]"} +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Recipe +recipe = [ + GPTQModifier( + targets="Linear", + scheme="W4A16", + ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"], + sequential_targets=["LlamaDecoderLayer"], + ), +] + +# Perform oneshot +oneshot( + model=model, + tokenizer=model_id, + dataset=DATASET_ID, + splits=DATASET_SPLIT, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + data_collator=llava_data_collator, +) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") + +# Save to disk compressed. +SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +processor.save_pretrained(SAVE_DIR) diff --git a/examples/multimodal_vision/mllama_example.py b/examples/multimodal_vision/mllama_example.py new file mode 100644 index 000000000..16c17f18e --- /dev/null +++ b/examples/multimodal_vision/mllama_example.py @@ -0,0 +1,53 @@ +from transformers import AutoProcessor + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +from llmcompressor.transformers.tracing import TraceableMllamaForConditionalGeneration +from llmcompressor.transformers.utils.data_collator import mllama_data_collator + +# Load model. +model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" +model = TraceableMllamaForConditionalGeneration.from_pretrained( + model_id, device_map="auto", torch_dtype="auto" +) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +# Oneshot arguments +DATASET_ID = "flickr30k" +DATASET_SPLIT = {"calibration": "test[:512]"} +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Recipe +recipe = [ + GPTQModifier( + targets="Linear", + scheme="W4A16", + ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], + ), +] + +# Perform oneshot +oneshot( + model=model, + tokenizer=model_id, + dataset=DATASET_ID, + splits=DATASET_SPLIT, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + data_collator=mllama_data_collator, +) + +# Confirm generations of the quantized model look sane. 
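+# Note: this text-only prompt exercises only the language model; the vision tower
+# is bypassed. The commented-out lines below are a hedged sketch of a multimodal
+# check -- the dummy image and the "<|image|>" prompt format are assumptions about
+# the mllama processor, not part of the original example.
+# from PIL import Image
+# image = Image.new("RGB", (560, 560))
+# inputs = processor(images=image, text="<|image|>Describe the image.", return_tensors="pt").to("cuda")
+# print(processor.decode(model.generate(**inputs, max_new_tokens=20)[0]))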
+print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") + +# Save to disk compressed. +SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +processor.save_pretrained(SAVE_DIR) diff --git a/examples/multimodal_vision/pixtral_example.py b/examples/multimodal_vision/pixtral_example.py new file mode 100644 index 000000000..e068a6dc9 --- /dev/null +++ b/examples/multimodal_vision/pixtral_example.py @@ -0,0 +1,54 @@ +from transformers import AutoProcessor + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration +from llmcompressor.transformers.utils.data_collator import pixtral_data_collator + +# Load model. +model_id = "mgoin/pixtral-12b" +model = TraceableLlavaForConditionalGeneration.from_pretrained( + model_id, device_map="auto", torch_dtype="auto" +) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +# Oneshot arguments +DATASET_ID = "flickr30k" +DATASET_SPLIT = {"calibration": "test[:512]"} +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Recipe +recipe = [ + GPTQModifier( + targets="Linear", + scheme="W4A16", + ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"], + sequential_targets=["MistralDecoderLayer"], + ), +] + +# Perform oneshot +oneshot( + model=model, + tokenizer=model_id, + dataset=DATASET_ID, + splits=DATASET_SPLIT, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + data_collator=pixtral_data_collator, +) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") + +# Save to disk compressed. 
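+# The directory suffix encodes the quantization scheme: W4A16 means 4-bit weights
+# with 16-bit activations, and G128 refers to the group size of 128 used by the
+# W4A16 preset. `save_compressed=True` stores the weights in the compressed-tensors
+# packed format rather than as fake-quantized dense tensors.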
+SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +processor.save_pretrained(SAVE_DIR) diff --git a/pyproject.toml b/pyproject.toml index 1baa7d2c0..e9cd799bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,12 +8,13 @@ target-version = ['py38'] [tool.isort] profile = "black" +skip = ["src/llmcompressor/transformers/tracing/"] [tool.mypy] files = "src/guidellm" [tool.ruff] -exclude = ["build", "dist", "env", ".venv"] +exclude = ["build", "dist", "env", ".venv", "src/llmcompressor/transformers/tracing/"] lint.select = ["E", "F", "W"] [tool.flake8] diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index c5200cf0f..178fbce4c 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,56 +1,51 @@ +import contextlib import warnings -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union import torch -from compressed_tensors.quantization import ( - QuantizationScheme, - disable_quantization, - enable_quantization, +from compressed_tensors.quantization import QuantizationScheme +from compressed_tensors.utils import ( + align_module_device, + get_execution_device, + getattr_chain, + update_offload_parameter, ) from loguru import logger -from pydantic import Field, field_validator -from torch.nn import Module +from pydantic import Field, PrivateAttr, field_validator from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory from llmcompressor.modifiers.quantization.calibration import freeze_module_quantization -from llmcompressor.modifiers.quantization.gptq.utils import ( - GPTQWrapper, - get_output_error, +from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import ( + accumulate_hessian, + make_empty_hessian, + quantize_weight, ) -from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor -from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward -from llmcompressor.utils.fsdp.context import fix_fsdp_module_name -from llmcompressor.utils.helpers import DisableKVCache -from llmcompressor.utils.pytorch.module import ( - get_layers, - get_no_split_params, - qat_active, +from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier +from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.pipelines.basic import run_pipeline as run_basic +from llmcompressor.pipelines.layer_sequential import ( + run_pipeline as run_layer_sequential, ) +from llmcompressor.pipelines.sequential import run_pipeline as run_sequential +from llmcompressor.utils.metric_logging import CompressionLogger +from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active __all__ = ["GPTQModifier"] -class GPTQModifier(Modifier): +class GPTQModifier(Modifier, HooksMixin): """ Modifier for applying the one-shot OBCQ algorithm to a model - Lifecycle: - - on_initialize - - initialize_compression() - - compressible_layers() - - LayerCompressor.pre_compress() - - apply_compression() - - run_calibration_forward() - - LayerCompressor.compress() - - LayerCompressor.post_compress() - - LayerCompressor.revert_layer_wrappers() | Sample yaml: | test_stage: | obcq_modifiers: | GPTQModifier: - | dampening_frac: 0.001 | block_size: 128 + | dampening_frac: 0.001 + | offload_hessians: False 
| config_groups: | group_0: | targets: @@ -65,29 +60,32 @@ class GPTQModifier(Modifier): | group_size: 128 | actorder: False - - :param sequential_update: Whether or not to update weights sequentially by layer. - This option is depreciated and setting to False is no longer supported - :param targets: list of layer names to compress during GPTQ, or '__ALL__' - to compress every layer in the model + Lifecycle: + - on_initialize_structure + - _build_quant_modifier + - on_initialize + - register_hook(module, compress_module, "forward") + - run_sequential / run_layer_sequential / run_basic + - make_empty_hessian + - accumulate_hessian + - quantize_weight + - on_finalize + - remove_hooks() + - model.apply(freeze_module_quantization) + + :param sequential_targets: list of layer names to compress during GPTQ, or + '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass + :param dampening_frac: Amount of dampening to apply to H, as a fraction of the + diagonal norm :param quantize: Set to True to quantize using an existing quantization modifier, or pass in the configuration for a quantization modifier if one does not already exist in the recipe - :param dampening_frac: Amount of dampening to apply to H, as a fraction of the - diagonal norm + :param offload_hessians: Set to True for decreased memory usage but increased + runtime. :param config_groups: [Used, if a quantization modifier is not specified], dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized. - :param ignore: [Used, if a quantization modifier is not specified] - optional list of module class names or submodule names to not - quantize even if they match a target in config_groups. Defaults to empty list. - :param disable_quantization_observer_epoch: [Used, if a quantization modifier is - not specified] Epoch to disable updates to the module - quantization observers. At this point, quantized weights and zero points will - not be updated. Leave None to not disable observers during QAT. Default is None - :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used :param scheme: [Used, if a quantization modifier is not specified], the quantization scheme to apply to the model, this is a dictionary that supports all keys from QuantizationScheme except targets, which will be set to the targets parameter @@ -95,24 +93,40 @@ class GPTQModifier(Modifier): `preset_scheme_name: targets` for example: `W8A8: ['Linear']` for weight 8 bit or a string of a preset scheme if targets is provided and activation 8 bit quantization on the Linear layers. + :param targets: list of layer names to quantize if a scheme is provided. Defaults + to Linear layers + :param ignore: [Used, if a quantization modifier is not specified] + optional list of module class names or submodule names to not + quantize even if they match a target in config_groups. Defaults to empty list. + :param num_calibration_steps: Number of steps to run post training calibration for. + When None, the entire calibration_dataloader is used + :param disable_quantization_observer_epoch: [Used, if a quantization modifier is + not specified] Epoch to disable updates to the module + quantization observers. At this point, quantized weights and zero points will + not be updated. Leave None to not disable observers during QAT. 
Default is None """ + # gptq modifier arguments sequential_update: bool = True # DEPRECIATED - targets: Union[str, List[str], None] = None sequential_targets: Union[str, List[str], None] = None block_size: int = 128 - quantize: Union[bool, Dict] = True dampening_frac: Optional[float] = 0.01 + quantize: Union[bool, Dict] = True + offload_hessians: bool = False + + # arguments used for attached quant modifier config_groups: Optional[Dict[str, QuantizationScheme]] = None + scheme: Optional[Union[str, Dict[str, Any]]] = None + targets: Union[str, List[str], None] = None ignore: List[str] = Field(default_factory=list) - disable_quantization_observer_epoch: Optional[float] = None num_calibration_steps: Optional[int] = None - scheme: Optional[Union[str, Dict[str, Any]]] = None + disable_quantization_observer_epoch: Optional[float] = None - model: Optional[Any] = None - layer_compressors_: Optional[List[Any]] = None - compressible_layers_: Optional[List] = None - quantization_modifier_: Any = None + # private variables + _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr(default=None) + _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) + _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict) + _update_size: Optional[int] = PrivateAttr(default=None) @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -175,8 +189,8 @@ def on_initialize_structure(self, state: State, **kwargs): self._build_quant_modifier_from_dict(self.quantize) self.quantize = True - if self.quantization_modifier_: - self.quantization_modifier_.on_initialize_structure(state, **kwargs) + if self._quantization_modifier: + self._quantization_modifier.on_initialize_structure(state, **kwargs) def on_initialize(self, state: "State", **kwargs) -> bool: """ @@ -184,27 +198,79 @@ def on_initialize(self, state: "State", **kwargs) -> bool: :param state: session state storing input model and calibration data """ + # initialize quantization modifier if not self.initialized_structure_: self.on_initialize_structure(state, **kwargs) - if self.quantization_modifier_: - self.quantization_modifier_.initialize(state, **kwargs) + if self._quantization_modifier: + self._quantization_modifier.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - modifiable_model = state.model - calibration_dataloader = state.data.calib - + # register hooks + for name, module in state.model.named_modules(): + if getattr_chain(module, "quantization_scheme.weights", None) is not None: + # HACK: previously, embeddings were not quantized because they were not + # accessible by the layer compressor. For now, we manually ignore it, + # but in the FUTURE this should be ignored by the user + if not isinstance(module, torch.nn.Embedding): + post_hook = partial(self.compress_module, name) + self.register_hook(module, post_hook, "forward") + + # infer sequential targets if self.sequential_targets is None: - # if no targets are provided, default to the modules that shouldn't be - # split by FSDP. 
For Transformers models this is equivalent to the - # decoder layers (ie LlamaDecoderLayer) - self.sequential_targets = get_no_split_params(modifiable_model) - - self.initialize_compression(modifiable_model, calibration_dataloader) - self.apply_compression(calibration_dataloader) - state.model.apply(freeze_module_quantization) - - return True + self.sequential_targets = get_no_split_params(state.model) + if isinstance(self.sequential_targets, str): + self.sequential_targets = [self.sequential_targets] + + # infer update size + if self._update_size is None: + self._update_size = len(state.data.calib) + + # infer pipeline + model_name = state.model.__class__.__name__ + input_names = state.data.calib.dataset.column_names + unfixable_errors = ( + torch.OutOfMemoryError, + torch._C._LinAlgError, + KeyboardInterrupt, + ) + try: + run_sequential( + state.model, + state.data.calib, + self.sequential_targets, + self.ignore, + ) + return True + + except Exception as exception: + if isinstance(exception, torch.fx.proxy.TraceError): + warnings.warn(f"Failed to trace {model_name} with inputs {input_names}") + if isinstance(exception, unfixable_errors): + raise exception + + warnings.warn("Falling back to layer_sequential pipeline") + try: + run_layer_sequential( + state.model, + state.data.calib, + self.sequential_targets, + ) + return True + + except Exception as exception: + if isinstance(exception, TypeError): + warnings.warn(f"{model_name} fails layer-wise assumptions") + if isinstance(exception, unfixable_errors): + raise exception + + warnings.warn( + "Falling back to basic pipeline, which requires extra memory and " + "may result in decreased accuracy. Consider using " + "`offload_hessians=True`" + ) + run_basic(state.model, state.data.calib) + return True def on_finalize(self, state: "State", **kwargs) -> bool: """ @@ -212,113 +278,98 @@ def on_finalize(self, state: "State", **kwargs) -> bool: :param state: session state storing input model and calibration data """ - if self.quantization_modifier_: - self.quantization_modifier_.finalize(state, **kwargs) + if self._quantization_modifier: + self._quantization_modifier.finalize(state, **kwargs) - return True - - def compressible_layers(self) -> Dict: - """ - Retrieves the modules corresponding to a list of - compressible layer names - - :precondition: self.model is set and is a torch.nn.Module - :return: dictionary of modules to compress - """ - if not isinstance(self.model, Module): - raise ValueError( - "`self.model` must be a torch.nn.Module to use " - f"the {self.__class__.__qualname__} modifier but got " - f"{type(self.model)} instead" - ) + self.remove_hooks() + self._hessians = dict() + self._num_samples = dict() + state.model.apply(freeze_module_quantization) - return get_layers(self.sequential_targets, self.model) + return True - def initialize_compression( + def compress_module( self, - model: Module, - dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None, + name: str, + module: torch.nn.Module, + args: Tuple[torch.Tensor, ...], + _output: torch.Tensor, ): """ - Setup for GPTQ, initializes the model - and other parameters, also initilializes the - compressible layers of model, and sets the device + Quantize a module's weight according to the GPTQ algorithm - :param model: model to initialize for compression - :param dataloader: calibration data, not used by GPTQ in this function - """ - self.model = model - self.compressible_layers_ = self.compressible_layers() - self.layer_compressors_ = [] - - for idx, (name, layer) in 
enumerate(self.compressible_layers_.items()): - name = fix_fsdp_module_name(name) - logger.info(f"Preparing {name} for compression") - args = self._pruning_arguments() - comp_cls = self._compression_class() - compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args) - self.layer_compressors_.append(compressor) - - # for the initial forward data pass, add an early stop exception in order - # to capture inputs right before being compressed by first module - first_layer_compressor = self.layer_compressors_[0] - first_layer_compressor.set_early_stop() - - @torch.no_grad() - def apply_compression( - self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None - ) -> Dict: - """ - Run GPTQ on the loaded model, using dataloader as calibration data + :param name: name of module being quantized + :param module: module being quantized + :param args: input arguments for module forward pass - :param dataloader: calibration data for GPTQ + :return: total loss from applying weight quantization to this module """ - class_name = self.__class__.__name__.replace("PyTorch", "") - logger.info( - f"Running {class_name} calibration with " f"{len(dataloader)} samples..." - ) - - # quantization scales and zp are already initialized but we do not - # want to calibrate wrt to these - self.model.apply(disable_quantization) - - with DisableKVCache(self.model): - # run_calibration_forward uses the early stop exception to capture values - # as intermediates right before the forward pass of the first module - intermediates = run_calibration_forward( - self.model, dataloader, mask_padding=True + # Assume that first argument is the input + inp = args[0] + quant_args = getattr_chain(module, "quantization_scheme.weights") + + # Initialize hessian if not present + if module not in self._num_samples: + init_device = ( + "cpu" if self.offload_hessians else get_execution_device(module) + ) + self._hessians[module] = make_empty_hessian(module, device=init_device) + self._num_samples[module] = 0 + + # Accumulate hessian with input with optional offloading + with self._maybe_onload_hessian(module): + self._hessians[module], self._num_samples[module] = accumulate_hessian( + inp, + module, + self._hessians[module], + self._num_samples[module], ) - self.layer_compressors_[0].clear_early_stop() - num_layers = len(self.compressible_layers_) - for idx, layer_compressor in enumerate(self.layer_compressors_): - logger.info(f"\n===== Compressing layer {idx+1}/{num_layers} " " =====") + # After enough samples are accumulated, perform quantization + if self._num_samples[module] >= self._update_size: + logger.info(f"Quantizing {name} using {self._num_samples[module]} samples") + with ( + torch.no_grad(), + align_module_device(module), + self._maybe_onload_hessian(module), + CompressionLogger(module) as comp_logger, + ): + loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( + module=module, + quant_args=quant_args, + hessians_dict=self._hessians, + blocksize=self.block_size, + percdamp=self.dampening_frac, + ) + comp_logger.set_loss(loss) + + update_offload_parameter(module, "weight", quantized_weight) + update_offload_parameter(module, "weight_scale", scale) + update_offload_parameter(module, "weight_zero_point", zero_point) + if g_idx is not None: + update_offload_parameter(module, "weight_g_idx", g_idx) - # run the forward pass for each transformer layer (block) one at a time - logger.info(f"Calibrating {layer_compressor.name}...") - layer_compressor.pre_compress() - unquantized_outputs = 
layer_compressor.calibrate_layer(intermediates) + # self._hessians[module] already deleted by quantize_weight + del self._num_samples[module] - layer_compressor.compress() - layer_compressor.post_compress() - layer_compressor.revert_layer_wrappers() + @contextlib.contextmanager + def _maybe_onload_hessian(self, module: torch.nn.Module): + if self.offload_hessians: + device = get_execution_device(module) + self._hessians[module] = self._hessians[module].to(device=device) - # perform a second forward pass of the module to calculate - # weight-quantized outputs for use as inputs to the next layer - quantized_outputs = layer_compressor.calibrate_layer(intermediates) - error = get_output_error(unquantized_outputs, quantized_outputs) - logger.info(f"Mean output error from quantization: {error:.3f}") - intermediates = quantized_outputs + yield - # re-enable quantization - self.model.apply(enable_quantization) + if self.offload_hessians: + if module in self._hessians: # may have been deleted in context + self._hessians[module] = self._hessians[module].to(device="cpu") def _build_quant_modifier(self): """ Build a quantization modifier based on the specified config_groups, ignore list, and num_calibration_steps. - :postcondition: self.quantization_modifier_ is set to the built + :postcondition: self._quantization_modifier is set to the built quantization modifier """ @@ -344,26 +395,9 @@ def _build_quant_modifier(self): def _build_quant_modifier_from_dict(self, quant_config): modifier_type = list(quant_config.keys())[0] modifier_args = quant_config[modifier_type] - self.quantization_modifier_ = ModifierFactory.create( + self._quantization_modifier = ModifierFactory.create( modifier_type, allow_registered=True, allow_experimental=True, **modifier_args, ) - - def _pruning_arguments(self): - """ - Gather the parameters needed for root module compression in a dict - - :return: dict of params for pruning - """ - return { - "blocksize": self.block_size, - "percdamp": self.dampening_frac, - } - - def _compression_class(self): - """ - :return: wrapper class used for root modules of this compression class - """ - return GPTQWrapper diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py b/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py index a8673dfc2..ec39da973 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py @@ -1,4 +1,3 @@ # flake8: noqa -from .gptq_wrapper import * -from .helpers import * +from .gptq_quantize import * diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py new file mode 100644 index 000000000..6f0ae60fb --- /dev/null +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -0,0 +1,290 @@ +import math +from copy import copy +from typing import Dict, Optional, Tuple, Union + +import torch +import transformers +from compressed_tensors.quantization import ( + ActivationOrdering, + QuantizationArgs, + QuantizationStrategy, + fake_quantize, +) + +from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD +from llmcompressor.observers.base import Observer +from llmcompressor.pytorch.utils.helpers import tensor_sparsity + +GPTQ_PRECISION = torch.float32 + +__all__ = ["make_empty_hessian", "accumulate_hessian", "quantize_weight"] + + +def make_empty_hessian( + module: torch.nn.Module, device: Optional[torch.device] = None +) -> torch.Tensor: + 
weight = module.weight + num_columns = weight.shape[1] + device = device if device is not None else weight.device + return torch.zeros((num_columns, num_columns), device=device, dtype=GPTQ_PRECISION) + + +def accumulate_hessian( + inp: torch.Tensor, + module: torch.nn.Module, + H: Optional[torch.Tensor], + num_samples: int, +) -> Tuple[torch.Tensor, int]: + inp = inp.to(device=H.device) + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + + num_added = inp.shape[0] + + if isinstance(module, (torch.nn.Linear, transformers.Conv1D)): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + if isinstance(module, torch.nn.Conv2d): + unfold = torch.nn.Unfold( + module.kernel_size, + dilation=module.dilation, + padding=module.padding, + stride=module.stride, + ) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + + H *= num_samples / (num_samples + num_added) + num_samples += num_added + + inp = inp.to(dtype=GPTQ_PRECISION) + inp = math.sqrt(2 / num_samples) * inp + H += inp.matmul(inp.t()) + + return H, num_samples + + +def quantize_weight( + module: torch.nn.Module, + quant_args: QuantizationArgs, + hessians_dict: Dict[torch.nn.Module, torch.Tensor], + blocksize: int = 128, + percdamp: float = 0.01, +) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: + """ + Quantize a module weight according to the GPTQ algorithm + + :param module: module with weight being quantized + :param quant_args: quantization arguments used to find quantization parameters + :param hessian_dict: dictionary containing preaccumulated hessian for quantization + :param blocksize: chunk size of quantization updates + :param percdamp: dampening factor on hessian diagonal + :return: loss, quantized_weight, scale, zero_point, g_idx + """ + strategy = quant_args.strategy + actorder = quant_args.actorder + final_shape = module.weight.shape + final_dtype = module.weight.dtype + W = module.weight.clone() + H = hessians_dict[module] # unfortunately python does not have a `move` keyword + del hessians_dict[module] # so we have to delete the original reference manually + + # create observer for calculating quantization parameters + observer = Observer.load_from_registry( + quant_args.observer, + quantization_args=quant_args, + averaging_constant=1.0, # ignore moving average + ) + + # standardize shape and dtype + if isinstance(module, torch.nn.Conv2d): + W = W.flatten(1) + elif isinstance(module, transformers.Conv1D): + W.transpose_(0, 1) + W = W.to(dtype=GPTQ_PRECISION) + num_rows = W.shape[0] + num_columns = W.shape[1] + + if strategy == QuantizationStrategy.GROUP: + # mapping from column index to group index + g_idx = ( + torch.arange(num_columns, device=W.device, dtype=torch.int) + // quant_args.group_size + ) + + if actorder == ActivationOrdering.GROUP: + # permute by activation order first, then update groups + W, H, perm = _apply_activation_ordering(W, H) + scale, zero_point = observer(W, g_idx=None) + + # use identity g_idx (invert permutation later) + + elif actorder == ActivationOrdering.WEIGHT: + # update groups first, then permute by activation order + scale, zero_point = observer(W, g_idx=None) + W, H, perm = _apply_activation_ordering(W, H) + + # permute g_idx to maintain identity mapping after unpermutation + g_idx = g_idx[perm] + + else: + scale, zero_point = observer(W, g_idx=None) + else: + scale, zero_point = observer(W, g_idx=None) + + # sparsity mask + sparsity = tensor_sparsity(W) + preserve_zeros = sparsity >= 
SPARSITY_THRESHOLD + W_nz_mask = ( + (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float() + if preserve_zeros + else None + ) + + losses = torch.zeros(num_rows, device=module.weight.device) + + # mask dead hessian values + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + # compute inverse hessian in place to save memory + try: + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(H.shape[0], device=H.device) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + except torch._C._LinAlgError: + raise torch._C._LinAlgError( + "Failed to invert hessian due to numerical instability. Consider " + "increasing GPTQModifier.dampening_frac, increasing the number " + "of calibration samples, or shuffling the calibration dataset" + ) + + # See section 3.4 of https://arxiv.org/abs/2203.07259 + for i1 in range(0, num_columns, blocksize): + i2 = min(i1 + blocksize, num_columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + if preserve_zeros: + W1_nz_mask = W_nz_mask[:, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = w.clone() + + # quantize column + if strategy == QuantizationStrategy.TENSOR: + q = fake_quantize( + q, + scale, + zero_point, + quant_args, + ) + elif strategy == QuantizationStrategy.CHANNEL: + q = fake_quantize( + q, + scale[:, 0], + zero_point[:, 0], + quant_args, + ) + elif strategy == QuantizationStrategy.GROUP: + # get the group index for the current column + column_idx = i1 + i + group_index = g_idx[column_idx] + + # Since we're only applying quantization to a slice, this + # ends up being a channelwise application + altered_qargs = copy(quant_args) + altered_qargs.strategy = QuantizationStrategy.CHANNEL + q = fake_quantize( + q, + scale[:, group_index], + zero_point[:, group_index], + altered_qargs, + ) + else: + raise ValueError( + f"Quantization strategy is not supported for GPTQ: {strategy}" + ) + + # propagate column error + Q1[:, i] = q + losses1[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + if preserve_zeros: + W1[:, i:] -= w1_err * W1_nz_mask[:, i:] + else: + W1[:, i:] -= w1_err + Err1[:, i] = err1 + + # propagate block error + W[:, i1:i2] = Q1 + losses += torch.sum(losses1, 1) / 2 + + w_err = Err1.matmul(Hinv[i1:i2, i2:]) + if preserve_zeros: + W[:, i2:] -= w_err * W_nz_mask[:, i2:] + else: + W[:, i2:] -= w_err + + has_gidx = False + if strategy == QuantizationStrategy.GROUP: + if actorder == ActivationOrdering.WEIGHT: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + + elif actorder == ActivationOrdering.GROUP: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + g_idx = g_idx[invperm] + + # only save g_idx if mapping is not identity + has_gidx = True + + if not has_gidx: + g_idx = None + + if isinstance(module, transformers.Conv1D): + W.transpose_(0, 1) + W = W.reshape(final_shape).to(final_dtype) + + loss = torch.sum(losses).item() + return ( + loss, + W, + scale.to(dtype=final_dtype), + zero_point.to(dtype=quant_args.pytorch_dtype()), + g_idx, + ) + + +def _apply_activation_ordering( + W: torch.Tensor, H: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Permute weight and hessian in order of greatest outupt activations + + 
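Columns are ordered by the magnitude of the corresponding Hessian diagonal entries (largest first), so the columns with the greatest expected activation energy are quantized earliest.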
:param W: weight to permute + :param H: hessian used to determine activation ordering + :return: permuted weight, permuted hessian, permutation map + """ + perm = torch.argsort(torch.diag(H), descending=True) + return W[:, perm], H[perm][:, perm], perm diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py deleted file mode 100644 index 02eafb669..000000000 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py +++ /dev/null @@ -1,354 +0,0 @@ -import time -from typing import Tuple - -from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy -from compressed_tensors.quantization.lifecycle.forward import fake_quantize - -from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD -from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper -from llmcompressor.observers import Observer -from llmcompressor.pytorch.utils.helpers import tensor_sparsity -from llmcompressor.utils import getattr_chain -from llmcompressor.utils.metric_logging import get_GPU_memory_usage, get_layer_size_mb - -try: - import transformers -except ImportError as err: - transformers = None - transformers_err = err - -import math -from copy import copy - -import torch -import torch.nn as nn -from compressed_tensors.utils import ( - get_offloaded_device, - is_module_offloaded, - update_parameter_data, - update_prefix_dict, -) -from loguru import logger - -__all__ = ["GPTQWrapper"] - - -class GPTQWrapper(ModuleCompressionWrapper): - """ - Runs GPTQ on a single module that contains no sub-modules - - Lifecycle: - - add_batch - - compress - - free - - :param name: name of module to run compression on - :param layer: module to run compression on - """ - - def __init__(self, name, layer): - super().__init__(name=name, layer=layer) - - # for Hessian calculation - self.register_buffer( - "H", - torch.zeros( - (self.columns, self.columns), device=self.dev, dtype=torch.float32 - ), - ) - - def add_batch(self, inp: torch.Tensor, out: torch.Tensor): - """ - Add a batch of layer input and output data to the Hessian calculation - - :param inp: tensor containing layer input - :param out: tensor containing layer output - """ - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - tmp = inp.shape[0] - if isinstance(self.layer, nn.Linear) or isinstance( - self.layer, transformers.Conv1D - ): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - self.H *= self.nsamples / (self.nsamples + tmp) - self.nsamples += tmp - inp = inp.to(dtype=self.H.dtype) - inp = math.sqrt(2 / self.nsamples) * inp - self.H += inp.matmul(inp.t()) - - def compress( - self, - blocksize: int = 128, - percdamp: float = 0.01, - ): - """ - Run pruning and quantization(if applicable) on the layer up to the target - sparsity value. 
- - :param blocksize: Number of columns to compress in one pass - :param percdamp: Amount of dampening to apply to H, as a fraction of the - diagonal norm - """ - args_loc = "quantization_scheme.weights" - quant_args = getattr_chain(self.layer, args_loc, None) - if quant_args is None: - logger.debug(f"Skipping unquantized layer {self.name}...") - return - - if is_module_offloaded(self.layer): - self.layer._hf_hook.pre_forward(self.layer) - - strategy = quant_args.strategy - actorder = quant_args.actorder - final_shape = self.layer.weight.shape - final_dtype = self.layer.weight.dtype - W = self.layer.weight.data.clone() - - # create observer for calculating quantization parameters - observer = Observer.load_from_registry( - quant_args.observer, - quantization_args=quant_args, - averaging_constant=1.0, # ignore moving average - ) - - # standardize shape and dtype - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - elif isinstance(self.layer, transformers.Conv1D): - W.transpose_(0, 1) - W = W.float() - - tick = time.time() - - if strategy == QuantizationStrategy.GROUP: - # mapping from column index to group index - g_idx = ( - torch.arange(self.columns, device=W.device, dtype=torch.int) - // quant_args.group_size - ) - - if actorder == ActivationOrdering.GROUP: - # permute by activation order first, then update groups - W, self.H, perm = self._apply_activation_ordering(W, self.H) - scale, zero_point = observer(W, g_idx=None) - - # use identity g_idx (invert permutation later) - - elif actorder == ActivationOrdering.WEIGHT: - # update groups first, then permute by activation order - scale, zero_point = observer(W, g_idx=None) - W, self.H, perm = self._apply_activation_ordering(W, self.H) - - # permute g_idx to maintain identity mapping after unpermutation - g_idx = g_idx[perm] - - else: - scale, zero_point = observer(W, g_idx=None) - else: - scale, zero_point = observer(W, g_idx=None) - - # sparsity mask - sparsity = tensor_sparsity(W) - preserve_zeros = sparsity >= SPARSITY_THRESHOLD - W_nz_mask = ( - (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float() - if preserve_zeros - else None - ) - - # mask dead hessian values - dead = torch.diag(self.H) == 0 - self.H[dead, dead] = 1 - W[:, dead] = 0 - - Losses = torch.zeros(self.rows, device=self.dev) - - # compute inverse hessian in place to save memory - try: - damp = percdamp * torch.mean(torch.diag(self.H)) - diag = torch.arange(self.columns, device=self.dev) - self.H[diag, diag] += damp - self.H = torch.linalg.cholesky(self.H) - self.H = torch.cholesky_inverse(self.H) - self.H = torch.linalg.cholesky(self.H, upper=True) - Hinv = self.H - except torch._C._LinAlgError: - raise ValueError( - "Failed to invert hessian due to numerical instability. 
Consider " - "increasing GPTQModifier.dampening_frac, increasing the number " - "of calibration samples, or shuffling the calibration dataset" - ) - - # See section 3.4 of https://arxiv.org/abs/2203.07259 - for i1 in range(0, self.columns, blocksize): - i2 = min(i1 + blocksize, self.columns) - count = i2 - i1 - - W1 = W[:, i1:i2].clone() - Q1 = torch.zeros_like(W1) - Err1 = torch.zeros_like(W1) - Losses1 = torch.zeros_like(W1) - Hinv1 = Hinv[i1:i2, i1:i2] - - if preserve_zeros: - W1_nz_mask = W_nz_mask[:, i1:i2] - - for i in range(count): - w = W1[:, i] - d = Hinv1[i, i] - q = w.clone() - - # quantize column - if strategy == QuantizationStrategy.TENSOR: - q = fake_quantize( - q, - scale, - zero_point, - self.layer.quantization_scheme.weights, - ) - elif strategy == QuantizationStrategy.CHANNEL: - q = fake_quantize( - q, - scale[:, 0], - zero_point[:, 0], - quant_args, - ) - elif strategy == QuantizationStrategy.GROUP: - # get the group index for the current column - column_idx = i1 + i - group_index = g_idx[column_idx] - - # update quantization parameters to reflect changes - # resulting from previous blocks - if ( - actorder != ActivationOrdering.WEIGHT - and column_idx % quant_args.group_size == 0 - ): - _scale, _zero_point = observer.get_qparams_along_dim( - W[:, g_idx == group_index], dim=0 - ) - scale[:, group_index] = _scale[:, 0] - zero_point[:, group_index] = _zero_point[:, 0] - - # Since we're only applying quantization to a slice, this - # ends up being a channelwise application - altered_qargs = copy(quant_args) - altered_qargs.strategy = QuantizationStrategy.CHANNEL - q = fake_quantize( - q, - scale[:, group_index], - zero_point[:, group_index], - altered_qargs, - ) - else: - raise ValueError( - "Quantization strategy is not supported for GPTQ: " - f"{strategy}" - ) - - # propagate column error - Q1[:, i] = q - Losses1[:, i] = (w - q) ** 2 / d**2 - - err1 = (w - q) / d - w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - if preserve_zeros: - W1[:, i:] -= w1_err * W1_nz_mask[:, i:] - else: - W1[:, i:] -= w1_err - Err1[:, i] = err1 - - # propagate block error - W[:, i1:i2] = Q1 - Losses += torch.sum(Losses1, 1) / 2 - - w_err = Err1.matmul(Hinv[i1:i2, i2:]) - if preserve_zeros: - W[:, i2:] -= w_err * W_nz_mask[:, i2:] - else: - W[:, i2:] -= w_err - - if "METRIC" in logger._core.levels.keys(): - self._log_metrics(tick, Losses) - - if strategy == QuantizationStrategy.GROUP: - if actorder == ActivationOrdering.WEIGHT: - # restore original permutation - invperm = torch.argsort(perm) - W = W[:, invperm] - - elif actorder == ActivationOrdering.GROUP: - # restore original permutation - invperm = torch.argsort(perm) - W = W[:, invperm] - g_idx = g_idx[invperm] - - # only save g_idx if mapping is not identity - update_parameter_data(self.layer, g_idx, "weight_g_idx") - - if isinstance(self.layer, transformers.Conv1D): - W.transpose_(0, 1) - W = W.reshape(final_shape).to(final_dtype) - - update_parameter_data(self.layer, scale, "weight_scale") - update_parameter_data(self.layer, zero_point, "weight_zero_point") - - # This is a bit hacky, but FSDP updates only work if we change - # the weight in place, clone() or direct assignment won't work - self.layer.weight -= self.layer.weight - self.layer.weight += W - - if is_module_offloaded(self.layer): - device = get_offloaded_device(self.layer) - update_prefix_dict(self.layer, "weight", self.layer.weight.to(device)) - self.layer._hf_hook.post_forward(self.layer, None) - - def free(self): - """ - Free the Hessian memory after the layer 
is complete - """ - delattr(self, "H") - super().free() - - def _apply_activation_ordering( - self, W: torch.Tensor, H: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Permute weight and hessian in order of greatest outupt activations - - :param W: weight to permute - """ - perm = torch.argsort(torch.diag(H), descending=True) - return W[:, perm], H[perm][:, perm], perm - - def _log_metrics(self, start_tick: float, losses: torch.Tensor): - """ - Log metrics related to compression algorithm - - :param start_tick: time when algorithm started" - :param losses: loss as result of algorithm - """ - patch = logger.patch(lambda r: r.update(function="compress")) - patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) - patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) - - patch.log( - "METRIC", - f"Compressed layer size: {get_layer_size_mb(self.layer)} MB", - ) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py deleted file mode 100644 index 58fedc634..000000000 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Any, Iterable, List, Tuple, Union - -import torch - -__all__ = ["get_output_error"] - - -def get_output_error( - unquantized: List[Tuple[Union[Iterable, torch.Tensor], Any]], - quantized: List[Tuple[Union[Iterable, torch.Tensor], Any]], -) -> torch.Tensor: - """ - Calculate mean l1 loss between weight-unquantized outputs and weight-quantized - outputs - - :param unquantized: unquantized-weight outputs - :param quantized: quantized-weight outputs - :return: mean l1 loss between outputs - """ - unquantized_outputs = sum( - [ - [output for output in outputs] - if isinstance(outputs, Iterable) - else [outputs] - for outputs, _ in unquantized - ], - start=[], - ) - - quantized_outputs = sum( - [ - [output for output in outputs] - if isinstance(outputs, Iterable) - else [outputs] - for outputs, _ in quantized - ], - start=[], - ) - - if len(unquantized_outputs) != len(quantized_outputs): - raise ValueError( - "Number of samples of weight-unquantized and weight-quantized " - "outputs differs" - ) - - return sum( - [ - torch.nn.functional.l1_loss(unq, q) - for unq, q in zip(unquantized_outputs, quantized_outputs) - ] - ) / len(unquantized_outputs) diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index 3dd3caa7e..3f3aa3d02 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -4,7 +4,6 @@ import torch from compressed_tensors import get_execution_device from loguru import logger -from torch.nn import Module from tqdm import tqdm from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper @@ -14,8 +13,7 @@ fix_fsdp_module_name, summon_full_params_context, ) -from llmcompressor.utils.pytorch import set_layer -from llmcompressor.utils.pytorch.module import get_prunable_layers +from llmcompressor.utils.pytorch.module import get_prunable_layers, set_layer __all__ = ["LayerCompressor"] @@ -45,8 +43,8 @@ class LayerCompressor: def __init__( self, 
module_compressor_class: ModuleCompressionWrapper, - model: Module, - layer: Module, + model: torch.nn.Module, + layer: torch.nn.Module, layer_index: int, name: str, args: Dict, diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 9003ff22d..c9869f267 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -1,5 +1,5 @@ from itertools import cycle -from typing import Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch.nn import Module @@ -25,7 +25,7 @@ class EarlyStopException(Exception): :param kwargs: keyword inputs passed to the layer where the excetion was raised """ - def __init__(self, args: Tuple, kwargs: Dict): + def __init__(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]): self.args = tensors_to_device(args, "cpu") self.kwargs = kwargs @@ -34,7 +34,9 @@ def apply_pad_mask_to_batch(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.T """ Apply a mask to the input ids of a batch. This is used to zero out padding tokens so they do not contribute to the hessian calculation in the - SparseGPT algorithm + GPTQ and SparseGPT algorithms + + Assumes that `attention_mask` only contains zeros and ones :param batch: batch to apply padding to if it exists :return: batch with padding zeroed out in the input_ids @@ -58,7 +60,7 @@ def run_calibration_forward( :param model: PyTorch model to run :param calibration_dataloader: data to use for calibration :param num_calibration_steps: number of items in calibration_dataloader to process, - None or a negative number to process all available data + None or a negative number to process all available data :param calibration_function: option to pass a custom forward function for model :param device: option to move the model to a specific device before calibration :param mask_padding: whether to zero out padding tokens during calibration diff --git a/src/llmcompressor/pipelines/__init__.py b/src/llmcompressor/pipelines/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/llmcompressor/pipelines/basic/__init__.py b/src/llmcompressor/pipelines/basic/__init__.py new file mode 100644 index 000000000..fc60475ca --- /dev/null +++ b/src/llmcompressor/pipelines/basic/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline import run_pipeline diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py new file mode 100644 index 000000000..c7552a654 --- /dev/null +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -0,0 +1,31 @@ +import torch +import torch.utils.data.dataloader +import tqdm +from compressed_tensors.utils import get_execution_device + +from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch +from llmcompressor.pytorch.utils.helpers import tensors_to_device +from llmcompressor.utils.helpers import calibration_forward_context + +__all__ = ["run_pipeline"] + + +def run_pipeline(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): + """ + Run a basic data pipeline. + + Batches are fetched from the data loader and are used to perform forward passes + through the model. 
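Any modifier hooks registered on individual modules (for example, the GPTQ forward hooks that accumulate Hessians) fire during these forward passes.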
This pipeline is typically used for basic model calibration + and, unlike the sequential pipelines, does not propagate compression error when + used to calibrate model compression + + :param model: model being calibrated + :param dataloader: loads data for calibration + """ + model_device = get_execution_device(model) + + with calibration_forward_context(model): + for batch in tqdm.tqdm(dataloader, desc="Calibrating"): + batch = apply_pad_mask_to_batch(batch) + batch = tensors_to_device(batch, model_device) + model(**batch) diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py new file mode 100644 index 000000000..57c9b1486 --- /dev/null +++ b/src/llmcompressor/pipelines/cache.py @@ -0,0 +1,187 @@ +import warnings +from dataclasses import dataclass, fields, is_dataclass +from typing import Any, Dict, List, Optional, Union + +import torch +import tqdm + + +@dataclass +class IntermediateValue: + """ + Dataclass which recursively defines offloaded values and which device to onload to + + :param value: either an offloaded Tensor, an primative value, or a recursable value + :param device: if the value is a Tensor, then the device to onload the tensor to, + otherwise None + """ + + value: Union[torch.Tensor, "IntermediateValue", Any] + device: Union[torch.device, None] + + +class IntermediatesCache: + """ + Cache which stores intermediate values (activations) produced by batched, sequential + execution of models. Values are offloaded to the `offload_device` when stored in + the cache and onloaded to their original device when fetched from the cache + + Currently supports nested offloading of dataclass instances and tuples + + Construct using `empty` and `from_dataloader` class methods + """ + + batch_intermediates: List[Dict[str, IntermediateValue]] + offload_device: torch.device + + def __init__( + self, + batch_intermediates: List[Dict[str, IntermediateValue]], + offload_device: torch.device, + ): + self.batch_intermediates = batch_intermediates + self.offload_device = offload_device + + @classmethod + def empty(cls, num_batches: int, offload_device: torch.device): + """ + Construct an empty cache + + :param num_batches: the expected number of batches to be stored + :param offload_device: device to offload values to + """ + batch_intermediates = [{} for _ in range(num_batches)] + return cls(batch_intermediates, offload_device) + + @classmethod + def from_dataloader( + cls, + dataloader: torch.utils.data.DataLoader, + model_device: torch.device, + mask_padding: bool = True, + offload_device: torch.device = torch.device("cpu"), + ): + """ + Initialize a cache with data from the provided dataloader + + :param dataloader: dataloader which generates values to be cached + :param model_device: device which values will be onloaded to when fetched + :param mask_padding: zero out padding tokens if True. 
This affects modifiers + such as GPTQ and SparseGPT + :param offload_device: device to offload values to + """ + # note: list comprehesion was found to not improve performance + batch_intermediates = [] + for batch in tqdm.tqdm(dataloader, desc="Preparing intermediates cache"): + intermediate = {} + for key, value in batch.items(): + if mask_padding and key == "input_ids": + value = cls._mask_padding(value, batch["attention_mask"]) + intermediate[key] = IntermediateValue(value=value, device=model_device) + + batch_intermediates.append(intermediate) + + return cls(batch_intermediates, offload_device) + + def fetch( + self, batch_index: int, input_names: Optional[List[str]] = None + ) -> Dict[str, Any]: + """ + Fetch values belonging to a batch + + :param batch_index: index of batch whose values are being fetched + :param input_names: list of keys whose values are being fetched + :return: dictionary mapping keys to onloaded values + """ + intermediates = self.batch_intermediates[batch_index] + + return { + key: self._onload_value(subgraph_input) + for key, subgraph_input in intermediates.items() + if input_names is None or key in input_names + } + + def update(self, batch_index: int, values: Dict[str, Any]): + """ + Update/put values belonging to a batch + + :param batch_index: index of batch whose values will be updated + :param values: dictionary mapping keys to values used for update + """ + intermediates = {k: self._offload_value(v) for k, v in values.items()} + self.batch_intermediates[batch_index].update(intermediates) + + def delete(self, batch_index: int, consumed_names: Optional[List[str]] = None): + """ + Delete values from the cache + + :param batch_index: index of batch whose values will be deleted + :param consumed_names: list of keys whose values will be deleted, defaults to + removing all keys + """ + intermediates = self.batch_intermediates[batch_index] + + if consumed_names is None: + consumed_names = list(intermediates.keys()) + + for name in consumed_names: + del intermediates[name] + + def _onload_value(self, intermediate: IntermediateValue) -> Any: + value = intermediate.value + device = intermediate.device + + if isinstance(value, torch.Tensor): + return value.to(device=device) + + elif is_dataclass(value): + for field in fields(value): # `asdict` is recursive, not applicable here + v = getattr(value, field.name) + setattr(value, field.name, self._onload_value(v)) + + return value + + elif isinstance(value, tuple): + return tuple(self._onload_value(v) for v in value) + + elif isinstance(value, (int, str, float, bool)) or value is None: + return value + + else: + return value + + def _offload_value(self, value: Any) -> IntermediateValue: + if isinstance(value, torch.Tensor): + return IntermediateValue( + value=value.to(device=self.offload_device), device=value.device + ) + + elif is_dataclass(value): + for field in fields(value): # `asdict` is recursive, not applicable here + v = getattr(value, field.name) + setattr(value, field.name, self._offload_value(v)) + + return IntermediateValue(value=value, device=None) + + if isinstance(value, tuple): + return IntermediateValue( + value=tuple(self._offload_value(v) for v in value), device=None + ) + + if isinstance(value, (int, str, float, bool)) or value is None: + return IntermediateValue(value=value, device=None) + + else: + warnings.warn(f"Offloading not implemented for type {type(value)}.") + return IntermediateValue(value=value, device=None) + + @staticmethod + def _mask_padding( + input_ids: torch.Tensor, attention_mask: 
torch.Tensor + ) -> torch.Tensor: + if attention_mask.dim() == 4: + # some attention masks, such as those from pixtral, are are 4d + attention_mask = attention_mask[0, 0, 0].unsqueeze(0) + + # Assumes that `attention_mask` only contains zeros and ones + return input_ids * attention_mask diff --git a/src/llmcompressor/pipelines/layer_sequential/__init__.py b/src/llmcompressor/pipelines/layer_sequential/__init__.py new file mode 100644 index 000000000..fc60475ca --- /dev/null +++ b/src/llmcompressor/pipelines/layer_sequential/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline import run_pipeline diff --git a/src/llmcompressor/pipelines/layer_sequential/helpers.py b/src/llmcompressor/pipelines/layer_sequential/helpers.py new file mode 100644 index 000000000..91b300cfd --- /dev/null +++ b/src/llmcompressor/pipelines/layer_sequential/helpers.py @@ -0,0 +1,129 @@ +import contextlib +import inspect +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import torch +import tqdm +from compressed_tensors.quantization import find_name_or_class_matches +from compressed_tensors.utils import get_execution_device +from torch.nn import Module +from torch.utils.data.dataloader import DataLoader + +from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch +from llmcompressor.pipelines.cache import IntermediatesCache +from llmcompressor.pytorch.utils.helpers import tensors_to_device +from llmcompressor.utils.helpers import calibration_forward_context + +__all__ = ["match_modules", "capture_first_layer_intermediates", "to_next_layer_kwargs"] + + +def match_modules(model: Module, target_names: List[str]) -> List[Module]: + """ + Find all submodules which match the `target_names` and sort them by name + + :param model: model to search for submodules in + :param target_names: patterns of submodule names to match + :return: list of submodules + """ + names_layers = [ + (name, module) + for name, module in model.named_modules() + if find_name_or_class_matches(name, module, target_names) + ] + + names_layers = sorted(names_layers, key=lambda name_layer: name_layer[0]) + return [layer for _name, layer in names_layers] + + +def capture_first_layer_intermediates( + model: Module, + first_layer: Module, + dataloader: DataLoader, + mask_padding: bool = True, +) -> IntermediatesCache: + """ + Captures the intermediate activations directly before the first model layer. + This is meant to capture any model preprocessing before model layers are executed + + Note that if any modules compressed prior to the execution of the first layer, the + compression error induced by compressing those modules will not be propagated to + subsequent activations, as they would be for modules which are compressed within + a layer + + :param model: model containing layers + :param first_layer: the first layer of the model + :param dataloader: dataloader of calibration inputs + :param mask_padding: zero out padding tokens if True. 
This affects modifiers such as + GPTQ and SparseGPT + """ + model_device = get_execution_device(model) + intermediates = IntermediatesCache.empty(len(dataloader), torch.device("cpu")) + signature = inspect.signature(first_layer.forward) + + with calibration_forward_context(model), early_stop_hook(first_layer): + desc = "Preparing intermediates cache" + for batch_index, batch in enumerate(tqdm.tqdm(dataloader, desc=desc)): + batch = apply_pad_mask_to_batch(batch) if mask_padding else batch + batch = tensors_to_device(batch, model_device) + + try: + model(**batch) + except EarlyStopException as exception: + layer_args = args_to_kwargs(exception._args, signature) + assert not set(layer_args.keys()) & set(exception._kwargs.keys()) + layer_args.update(exception._kwargs) + + intermediates.update(batch_index, layer_args) + else: + raise ValueError( + "Attempted to capture first layer intermediates, but " + "EarlyStopException was not raised" + ) + + return intermediates + + +def to_next_layer_kwargs(args: Tuple[Any, ...], next_layer: Module) -> Dict[str, Any]: + """ + Convert a list of arguments to a dictionary of keyword arguments which match the + next layer's function signature + + :param args: list of argument values + :param next_layer: the next layer whose function signature must be matched + :return: dictionary mapping function signature keywords to argument values + """ + signature = inspect.signature(next_layer.forward) + return args_to_kwargs(args, signature) + + +def args_to_kwargs( + args: Tuple[Any, ...], signature: inspect.Signature +) -> Dict[str, Any]: + return {name: arg for name, arg in zip(signature.parameters.keys(), args)} + + +@contextlib.contextmanager +def early_stop_hook(module: Module): + def trigger_early_stop_fn(module, args, kwargs): + raise EarlyStopException(_args=args, _kwargs=kwargs) + + handle = module.register_forward_pre_hook(trigger_early_stop_fn, with_kwargs=True) + + try: + yield + finally: + handle.remove() + + +@dataclass +class EarlyStopException(Exception): + """ + Note: this exception is different from the exception defined in + llmcompressor.modifiers.utils.pytorch_helpers, and will eventually replace it + + Attribute names `args` and `kwargs` are reserved for `dataclass` + """ + + _args: Tuple[Any, ...] + _kwargs: Dict[str, Any] diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py new file mode 100644 index 000000000..f93fd6f2d --- /dev/null +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -0,0 +1,75 @@ +from typing import List + +import torch +import torch.utils.data.dataloader +import tqdm + +from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.pipelines.cache import IntermediatesCache +from llmcompressor.pipelines.layer_sequential.helpers import ( + capture_first_layer_intermediates, + match_modules, + to_next_layer_kwargs, +) +from llmcompressor.utils.helpers import calibration_forward_context + +__all__ = ["run_pipeline"] + + +def run_pipeline( + model: torch.nn.Module, + dataloader: torch.utils.data.DataLoader, + sequential_targets: List[str], +): + """ + Run a layer-wise sequential data pipeline according to the following steps: + + 1. Layers are identified according to `sequential_targets` + 2. A hook is attached to the first layer. This hook raises an exception which is + then caught and used to capture the input arguments to the first layer + 3. 
The inputs to the first layer are used to calibrate the first layer, and the + output of each layer is used as the input to calibrate the next layer + + This pipeline requires that the model have distinct layers defined in its + architecture and that the outputs of the previous layer are exactly the inputs + to the next layer. This is violated by encoder-decoder architectures among others. + + If your model architecture violates these assumptions, consider using the sequential + pipeline (see llmcompressor.pipelines.sequential). Architectures which are known to + fail these assumptions include GPT-J and most vision language models. + + :param model: model being calibrated + :param dataloader: loads data for calibration + :param sequential_targets: patterns which match to the layer modules of the model + """ + # find layers + layers = match_modules(model, sequential_targets) + + with calibration_forward_context(model): + # prepare intermediates cache + intermediates: IntermediatesCache = capture_first_layer_intermediates( + model, layers[0], dataloader + ) + + num_layers = len(layers) + for layer_index, layer in enumerate(layers): + # prepare tqdm description texts + calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating" + prop_desc = f"({layer_index + 1}/{num_layers}): Propagating" + + # do a preliminary pass to trigger modifier hooks + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): + inputs = intermediates.fetch(batch_index) + layer(**inputs) + + # this pass does not trigger modifier hooks + # and is only used for capturing outputs from the newly compressed modules + with HooksMixin.disable_hooks(): + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=prop_desc): + inputs = intermediates.fetch(batch_index) + output = layer(**inputs) + + if layer_index < num_layers - 1: + output = to_next_layer_kwargs(output, layers[layer_index + 1]) + intermediates.delete(batch_index) + intermediates.update(batch_index, output) diff --git a/src/llmcompressor/pipelines/sequential/__init__.py b/src/llmcompressor/pipelines/sequential/__init__.py new file mode 100644 index 000000000..fc60475ca --- /dev/null +++ b/src/llmcompressor/pipelines/sequential/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline import run_pipeline diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py new file mode 100644 index 000000000..4945ba01e --- /dev/null +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -0,0 +1,365 @@ +import inspect +from collections import deque +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Set + +from compressed_tensors import has_offloaded_params +from compressed_tensors.quantization import find_name_or_class_matches +from torch.fx import Graph, GraphModule, Node +from torch.fx.proxy import Argument +from torch.nn import Module +from transformers import PreTrainedModel +from transformers.configuration_utils import PretrainedConfig +from transformers.utils.fx import HFTracer + +from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.utils.helpers import calibration_forward_context + +__all__ = ["trace_subgraphs", "Subgraph"] + + +@dataclass +class Subgraph: + """ + Dataclass specifying an executable subgraph of a model graph + + :param graph: subgraph of model graph + :param input_names: argument names of the compiled forward function + :param consumed_names: argument names which are not used by any subsequent subgraphs + and can 
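A hedged sketch of how this layer-sequential pipeline might be driven for a decoder-only model; the model id, the `LlamaDecoderLayer` target, and `calibration_dataloader` below are assumptions for illustration, not part of this diff.

import torch
from transformers import AutoModelForCausalLM

from llmcompressor.pipelines.layer_sequential import run_pipeline

# assumed: `calibration_dataloader` yields dicts of input_ids / attention_mask tensors
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype="auto", device_map="auto"
)
run_pipeline(
    model,
    calibration_dataloader,
    sequential_targets=["LlamaDecoderLayer"],  # layer class name to calibrate layer by layer
)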
therefore be deleted from the intermediates cache + """ + + graph: Graph + input_names: Set[str] + consumed_names: Set[str] + + def compile_forward(self) -> Callable[[Any], Any]: + """ + Generate and compile code for executing this subgraph + + :return: function which, when called, executes this subgraph + """ + code = self.graph.python_code("self") + exec(code.src, code.globals) + return code.globals.get("forward") + + +def trace_subgraphs( + model: PreTrainedModel, + sample_input: Dict[str, Any], + sequential_targets: List[str], + ignore: List[str], +) -> List[Subgraph]: + """ + Trace a model to produce subgraphs, where each sequential target belongs to exactly + one subgraph and where executing each subgraph in order is equivalent to executing + the original model + + :param model: model being traced + :param sample_input: inputs whose values will change during execution but whose + __len__, __bool__, and __contains__ values are assumed constant across batches + :param sequential_targets: list of patterns matching sequential targets + :param ignore: list of patterns matching modules to ignore during tracing + :return: a list of Subgraphs in order of execution + """ + # find modules + sequential_targets = match_modules(model, sequential_targets) + ignore = match_modules(model, ignore) + + # initialize arguments + tracer = get_tracer(model, sequential_targets, ignore) + concrete_args = populate_concrete_args(model, sample_input) + + # trace + with ( + calibration_forward_context(model), + HooksMixin.disable_hooks(), + ): + graph = GraphModule( + model, + tracer.trace( + model, + dummy_inputs=sample_input, + concrete_args=concrete_args, + complete_concrete_args_with_inputs_not_in_dummy_inputs=False, + # bug in trace throws an error for variadic + # args and kwargs in function signature + ), + ) + + # copy metadata + graph.config = model.config + graph.class_for_deserialization = model.__class__ + graph.device = model.device + + # perform subgraph partition + partitions = topological_partition(graph, sequential_targets) + subgraphs = partition_graph(model, partitions) + trace_consumed_names(subgraphs) + + return subgraphs + + +def get_tracer( + model: Module, sequential_targets: Set[Module], ignore: Set[Module] +) -> HFTracer: + """ + Get a tracer specialized for the given model. The resulting tracer will not trace + inside of sequential targets, ignored targets, or offloaded modules. 
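A short sketch, under the assumption of a traceable decoder-only model and a `sample_input` dict drawn from a calibration dataloader, of how `trace_subgraphs` and `Subgraph.compile_forward` fit together; the target name and variable names are illustrative.

from llmcompressor.pipelines.sequential.helpers import trace_subgraphs

# assumed: `model` is a PreTrainedModel and `sample_input` is one batch from a dataloader
subgraphs = trace_subgraphs(
    model,
    sample_input,
    sequential_targets=["LlamaDecoderLayer"],
    ignore=[],
)

first = subgraphs[0]
forward_fn = first.compile_forward()
# each compiled forward takes the owning model plus the subgraph's named inputs
outputs = forward_fn(
    model, **{key: value for key, value in sample_input.items() if key in first.input_names}
)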
+ + Tracing within sequential targets and ignored targets is unnecessary, and tracing + within offloaded modules may result in meta tensors being added to the model graph + + :param model: model being traced + :param sequential_targets: modules which are sequential targets + :param ignore: modules which are ignored + """ + offloaded_modules = set(m for m in model.modules() if has_offloaded_params(m)) + skip_trace_modules = sequential_targets | offloaded_modules | ignore + + class SequentialTracer(HFTracer): + def create_arg(self, a: Any) -> Argument: + # special extension allows models which depend on config values to be traced + if isinstance(a, PretrainedConfig): + kwargs = {k: self.create_arg(v) for k, v in a.to_dict().items()} + return self.create_node("call_function", a.__class__, (), kwargs) + + else: + return super().create_arg(a) + + def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: + return module in skip_trace_modules or super().is_leaf_module( + module, module_qualified_name + ) + + return SequentialTracer() + + +def populate_concrete_args(model: Module, sample_input: Dict) -> Dict: + """ + Creates concrete args which, unlike the equivalent function provided by + transformers.utils.fx, creates default values for variadic arguments, which are + needed by some models. + + :param model: model being traced + :param sample_input: values used to symbolically trace the model. All arguments + to the model.forward function which are not in the sample_input are considered + concrete args + :return: dictionary mapping concrete argument names to their default values + """ + sig = inspect.signature(model.forward) + + concrete_args = {} + for parameter in sig.parameters.values(): + if parameter.name in sample_input: + continue + if parameter.kind == inspect._ParameterKind.VAR_POSITIONAL: + value = list() + elif parameter.kind == inspect._ParameterKind.VAR_KEYWORD: + value = dict() + elif parameter.name == "use_cache": + value = False + else: + value = parameter.default + + concrete_args[parameter.name] = value + + return concrete_args + + +def find_target_nodes(graph: GraphModule, targets: Set[Module]) -> Set[Node]: + """ + Find all nodes whose execution is equivalent to executing the target modules. + Note that these nodes are guaranteed to be treated as leaf nodes by SequentialTracer + + :param graph: graph containing target nodes + :param targets: modules whose nodes are being searched for + :return: set of all nodes which call the target modules + """ + return set( + node + for node in graph.graph.nodes + if node.op == "call_module" and graph.get_submodule(node.target) in targets + ) + + +def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List[Node]]: + """ + Partition the graph into partitions such that each `target` belongs to exactly one + partition and executing each partition depends only on intermediate values produced + by executing the partitions before it. 
+ + :param graph: graph being partitioned + :param targets: target modules which will be assigned to disjoint partitions + :return: list of partitions, where each partition is a list of nodes belonging to + that partition + """ + assert graph_is_well_formed(graph.graph) + target_nodes = find_target_nodes(graph, targets) + + partitions: List[List[Node]] = [[]] + remaining_indegrees = { + node: len([node for node in node.all_input_nodes if node.op != "get_attr"]) + for node in graph.graph.nodes + } + partition_index = 0 # global counter + + # start with graph input nodes, + # but delay the `get_attr` nodes as long as possible + queue = deque( + node + for node in graph.graph.nodes + if remaining_indegrees[node] == 0 and node.op != "get_attr" + ) + while len(queue) > 0: + node = queue.popleft() + + # assign to partition + partitions[partition_index].append(node) + + # guarantee targets are assigned to disjoint partitions + if node in target_nodes: + partition_index += 1 + partitions.append([]) + + # recurse on last indegree only in order to guarantee that + # the node is assigned to maximal partition + for user in node.users: + remaining_indegrees[user] -= 1 + if remaining_indegrees[user] == 0: + queue.append(user) + + # an ideal implementation would involve implicitly consolidating partition indices + # so that each node is assigned to the maximum partition possible (in order to delay + # execution as long as possible), but saving these nodes for last covers the most + # common and costly case (get_attr) + for node in graph.graph.find_nodes(op="get_attr"): + user_partitions = [] + for user in node.users: + for index in range(len(partitions)): + if user in partitions[index]: + user_partitions.append(index) + break + partition_index = min(user_partitions) + partitions[partition_index].insert(0, node) + + assert set().union(*partitions) == set(graph.graph.nodes) + return partitions + + +def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgraph]: + """ + Convert each partition into a Subgraph. Each Subgraph returns a dictionary mapping + of output node names to their computed values. Note that the `consumed_names` + attribute of each Subgraph remains empty, to be later populated by + `trace_consumed_names` + + :param model: model which owns the produced Subgraphs + :param partitions: list of partitions, where each partition is a list of nodes + belonging to that partition + :return: list of subgraphs in order of execution + """ + subgraphs = [] + + # create subgraphs + for partition_nodes in partitions: + # create a new graph for the partition + graph = Graph(model) + node_map = {} + + # add placeholders for inputs not in this subgraph. 
use set to deduplicate + new_input_nodes = { + input_node + for node in partition_nodes + for input_node in node.all_input_nodes + if input_node not in partition_nodes and input_node.op + } + for input_node in new_input_nodes: + node_map[input_node] = graph.placeholder(input_node.name) + + # add the nodes to subgraph + for node in partition_nodes: + node_map[node] = graph.node_copy(node, lambda n: node_map[n]) + + # add an output node to collect all subgraph outputs into a dictionary + if len(graph.find_nodes(op="output")) <= 0: + output_dict = { + node.name: node_map[node] + for node in partition_nodes + if any(user not in partition_nodes for user in node.users.keys()) + } + graph.output(output_dict) + + # save the subgraph for this partition + graph.lint() + input_names = set(node.name for node in graph.nodes if node.op == "placeholder") + subgraphs.append( + Subgraph( + graph=graph, + input_names=input_names, + consumed_names=set(), # populated later + ) + ) + + assert graph_is_well_formed(graph) + + return subgraphs + + +def trace_consumed_names(subgraphs: List[Subgraph]): + """ + Populate the `consumed_names` attribute of each Subgraph according to when inputs + are last used in order to vacate the `intermediates` cache and save memory + + :param subgraphs: list of subgraphs with empty `consumed_names` attributes + """ + # populate consumed_names according to when inputs are last used + # in order to vacate the `intermediates` cache and save memory + all_input_names = set().union(*(subgraph.input_names for subgraph in subgraphs)) + for input_name in all_input_names: + for subgraph in reversed(subgraphs): + if input_name in subgraph.input_names: + subgraph.consumed_names.add(input_name) + break + else: + raise ValueError(f"Could not find input name {input_name} in subgraphs") + + +def graph_is_well_formed(graph: Graph) -> bool: + """ + A graph is well formed if and only if + `nodeA in nodeB.users <=> nodeB in nodeA.all_input_nodes` + + :param graph: graph being checked + :return: True if the graph is well formed, False otherwise + """ + for node in graph.nodes: + for user in node.users: + if node not in user.all_input_nodes: + return False + + for input_node in node.all_input_nodes: + if node not in input_node.users: + return False + + if len(node.users) != len(set(node.users)) or len(node.all_input_nodes) != len( + set(node.all_input_nodes) + ): + return False + + return True + + +def match_modules(model: Module, target_names: List[str]) -> Set[Module]: + """ + Find modules whose names match the patterns given by `target_names` + + :param model: model containing submodules to find + :param target_names: target patterns to find + :return: all submodules matching `target_names` + """ + return set( + module + for name, module in model.named_modules() + if find_name_or_class_matches(name, module, target_names) + ) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py new file mode 100644 index 000000000..647d5761e --- /dev/null +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -0,0 +1,77 @@ +from typing import List + +import torch +import torch.utils.data.dataloader +import tqdm +from compressed_tensors.utils import get_execution_device + +from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.pipelines.cache import IntermediatesCache +from llmcompressor.pipelines.sequential.helpers import trace_subgraphs +from llmcompressor.utils.helpers import calibration_forward_context + +__all__ = 
["run_pipeline"] + + +def run_pipeline( + model: torch.nn.Module, + dataloader: torch.utils.data.DataLoader, + sequential_targets: List[str], + ignore: List[str], +): + """ + Run a sequential data pipeline according to the following steps: + + 1. The model is partitioned into subgraphs according to `sequential_targets` + 2. Data passes through each subgraph sequentially. Data is passed through each + subgraph twice, once to trigger calibration hooks, then a second time in order + to capture activations after quantization has occurred through the hooks. + 3. The intermediate activations between each subgraph are cached and offloaded to + the cpu between each batch in order to save memory + + This pipeline requires that the model be traceable with respect to data from the + data loader. This may be an issue for vision language models with vision datasets, + due to specialized input processing in the model. + + In the event that tracing fails, a torch.fx.proxy.TraceError will be raised. A model + can be made traceable by wrapping the untraceable functions (see + llmcompressor.transformers.tracing) + + :param model: model being calibrated + :param dataloader: loads data for calibration + :param sequential_targets: patterns which match to the layer modules of the model + :param ignore: patterns which match to modules which should be ignored by tracing + """ + # trace subgraphs + sample_input = next(iter(dataloader)) + subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore) + + with calibration_forward_context(model): + # prepare intermediates cache + model_device = get_execution_device(model) + intermediates = IntermediatesCache.from_dataloader(dataloader, model_device) + + num_subgraphs = len(subgraphs) + for subgraph_index, subgraph in enumerate(subgraphs): + # prepare tqdm description texts + calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating" + prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating" + + # compile subgraph forward function + forward_function = subgraph.compile_forward() + + # do an preliminary pass to trigger modifier hooks + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): + inputs = intermediates.fetch(batch_index, subgraph.input_names) + forward_function(model, **inputs) + + # this pass does not trigger modifier hooks + # and is only used for capturing outputs from the newly compressed modules + with HooksMixin.disable_hooks(): + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=prop_desc): + inputs = intermediates.fetch(batch_index, subgraph.input_names) + output = forward_function(model, **inputs) + + if subgraph_index < num_subgraphs - 1: + intermediates.update(batch_index, output) + intermediates.delete(batch_index, subgraph.consumed_names) diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py index 7c839c5a7..35ab51220 100644 --- a/src/llmcompressor/transformers/compression/helpers.py +++ b/src/llmcompressor/transformers/compression/helpers.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union import psutil import torch @@ -172,6 +172,7 @@ def custom_offload_device_map( model_stub: str, max_memory_per_gpu: Union[str, int], num_gpus: int = 1, + model_cls: Type = AutoModelForCausalLM, **model_kwargs, ) -> Dict[Union[int, str], Union[int, str]]: """ @@ -182,6 +183,8 @@ def custom_offload_device_map( :param 
max_memory_per_gpu: Max memory to allocate on each GPU, as either a string such as "10GB" or an integer number of bytes :param num_gpus: number of gpus to utilize + :param model_cls: model class to use when initializing model structure, + default is AutoModelForCausalLM :param model_kwargs: keyword arguments to pass to model initializer :return: memory mapping for layers of model_stub to be passed to from_pretrained() """ @@ -191,7 +194,7 @@ def custom_offload_device_map( device_map = {} with init_empty_weights(): - dummy_model = AutoModelForCausalLM.from_pretrained(model_stub, **model_kwargs) + dummy_model = model_cls.from_pretrained(model_stub, **model_kwargs) device_map = infer_auto_device_map( dummy_model, max_memory=memory_limits, @@ -207,6 +210,7 @@ def calculate_offload_device_map( reserve_for_hessians=False, num_gpus: int = 1, torch_dtype: torch.dtype = torch.float16, + model_cls: Type = AutoModelForCausalLM, **model_kwargs, ) -> Dict[Union[int, str], Union[int, str]]: """ @@ -216,6 +220,8 @@ def calculate_offload_device_map( :param model_stub: local path or HF stub to calculate mapping for :param reserve_for_hessians: whether to reserve memory for GPTQ :param num_gpus: number of gpus to utilize + :param model_cls: model class to use when initializing model structure, + default is AutoModelForCausalLM :param model_kwargs: keyword arguments to pass to model initializer :return: memory mapping for layers of model_stub to be passed to from_pretrained() """ @@ -230,7 +236,7 @@ def calculate_offload_device_map( device_map = {} with init_empty_weights(): - dummy_model = AutoModelForCausalLM.from_pretrained( + dummy_model = model_cls.from_pretrained( model_stub, torch_dtype=torch_dtype, **model_kwargs ) diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py new file mode 100644 index 000000000..4baa5864d --- /dev/null +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -0,0 +1,13 @@ +from .llava import ( + LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration, +) +from .mistral import MistralForCausalLM as TraceableMistralForCausalLM +from .mllama import ( + MllamaForConditionalGeneration as TraceableMllamaForConditionalGeneration, +) + +__all__ = [ + "TraceableLlavaForConditionalGeneration", + "TraceableMllamaForConditionalGeneration", + "TraceableMistralForCausalLM", +] diff --git a/src/llmcompressor/transformers/tracing/llava.py b/src/llmcompressor/transformers/tracing/llava.py new file mode 100644 index 000000000..0f993a356 --- /dev/null +++ b/src/llmcompressor/transformers/tracing/llava.py @@ -0,0 +1,273 @@ +# flake8: noqa +# coding=utf-8 +# Copyright 2023 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
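A hedged usage sketch of the new `model_cls` argument added to the offload-device-map helpers above: planning a device map for a multimodal model while reserving GPTQ Hessian memory. The model stub is a placeholder and the resulting map depends on local hardware.

from llmcompressor.transformers.compression.helpers import calculate_offload_device_map
from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # placeholder stub
device_map = calculate_offload_device_map(
    model_id,
    reserve_for_hessians=True,
    num_gpus=1,
    model_cls=TraceableLlavaForConditionalGeneration,
)
model = TraceableLlavaForConditionalGeneration.from_pretrained(
    model_id, device_map=device_map, torch_dtype="auto"
)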
+# vllm-project: no copyright +"""PyTorch Llava model.""" + +from typing import List, Optional, Tuple, Union + +import torch +from transformers import AutoModel, AutoModelForCausalLM, LlavaForConditionalGeneration +from transformers.models.llava.configuration_llava import LlavaConfig +from transformers.models.llava.modeling_llava import ( + LlavaCausalLMOutputWithPast, + LlavaMultiModalProjector, + LlavaPreTrainedModel, + logger, +) +from transformers.models.mistral.configuration_mistral import MistralConfig +from transformers.utils.fx import HFProxy + +# TRACING: Reuse traceable subclass +from .mistral import MistralForCausalLM as TraceableMistralForCausalLM + + +# TRACING: The shape of image_features is known and documented by +# LlavaForConditionalGeneration.get_image_features +def maybe_install_metadata_image_features( + image_features: Union[torch.Tensor, HFProxy], + pixel_values: Union[torch.Tensor, HFProxy], + config: LlavaConfig, +) -> Union[torch.Tensor, HFProxy]: + if isinstance(image_features, HFProxy): + # (num_images, image_length, embed_dim) + num_images = pixel_values._metadata.size(0) + image_length = config.image_seq_length + embed_dim = config.vision_config.intermediate_size + + original_fn = image_features.tracer.patched_torch_methods["empty"][1] + metadata = original_fn( + (num_images, image_length, embed_dim), device=torch.device("meta") + ) + image_features.install_metadata(metadata) + + return image_features + + +# TRACING: The shape of inputs_embeds is known. This function compensates for +# the fact that shape inference through `masked_scatter` is not implemented yet +def maybe_install_metadata_inputs_embeds( + inputs_embeds_masked: Union[torch.Tensor, HFProxy], + inputs_embeds: Union[torch.Tensor, HFProxy], + special_image_mask: Union[torch.Tensor, HFProxy], + image_features: Union[torch.Tensor, HFProxy], +) -> Union[torch.Tensor, HFProxy]: + if isinstance(inputs_embeds_masked, HFProxy): + metadata = inputs_embeds._metadata.masked_scatter( + special_image_mask._metadata.to(bool), image_features._metadata + ) + inputs_embeds_masked.install_metadata(metadata) + + return inputs_embeds + + +# TRACING: override `__init__` and `forward` +class LlavaForConditionalGeneration(LlavaForConditionalGeneration): + def __init__(self, config: LlavaConfig): + super(LlavaPreTrainedModel, self).__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + + # TRACING: Must use TraceableMistralForCausalLM which wraps an untraceable function + if isinstance(config.text_config, MistralConfig): + self.language_model = TraceableMistralForCausalLM(config.text_config) + else: + self.language_model = AutoModelForCausalLM.from_config(config.text_config) + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + 
return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + legacy_processing = False + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing + # not very reliable, but we don't expect one to actually pass 500+ images for one prompt + # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True + + # TRACING: Assume that the user will not pass 500+ images for a single prompt + # instead always use legacy_processing = False + # legacy_processing = ( + # (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length + # ) or (input_ids.shape[-1] == 1 and pixel_values is not None) + + image_features = None + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + image_features = maybe_install_metadata_image_features( + image_features, pixel_values, self.config + ) + + if legacy_processing: + logger.warning_once( + "Expanding inputs for image tokens in LLaVa should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 
+ ) + # prefill stage vs decoding stage (legacy behavior copied) + if input_ids.shape[1] != 1: + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) + else: + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] + + # @raushan retain only the new behavior after v4.47 + elif image_features is not None: + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] * image_features.shape[1] + + # TRACING: Assume that processing and tokenization was done correctly + # if n_image_tokens != n_image_features: + if False: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds_masked = inputs_embeds.masked_scatter(special_image_mask, image_features) + + # TRACING: install metadata + inputs_embeds_masked = maybe_install_metadata_inputs_embeds(inputs_embeds_masked, inputs_embeds, special_image_mask, image_features) + inputs_embeds = inputs_embeds_masked + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device), + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) diff --git a/src/llmcompressor/transformers/tracing/mistral.py b/src/llmcompressor/transformers/tracing/mistral.py new file mode 100644 index 000000000..3c9102b23 --- /dev/null +++ b/src/llmcompressor/transformers/tracing/mistral.py @@ -0,0 +1,251 @@ +# flake8: noqa +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# vllm-project: no copyright +"""PyTorch Mistral model.""" + +import torch +from torch import nn + +from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.models.mistral.configuration_mistral import MistralConfig +from transformers.utils import ( + logging, +) + +# TRACING: imports +from transformers.models.mistral.modeling_mistral import ( + MistralPreTrainedModel, + MistralModel, + MistralForCausalLM, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralForQuestionAnswering, +) + +logger = logging.get_logger(__name__) + + +# TRACING: This function is untracable +@torch.fx.wrap +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: MistralConfig, + past_key_values: Cache, +): + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if ( + not isinstance(past_key_values, SlidingWindowCache) + or sequence_length > target_length + ): + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + return causal_mask + + +# TRACING: must use wrapped _prepare_4d_causal_attention_mask_with_cache_position +class MistralModel(MistralModel): + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + use_cache: bool, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and use_cache: + is_padding_right = ( + attention_mask[:, -1].sum().item() != input_tensor.size()[0] + ) + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) + + return causal_mask + + +# TRACING: Must use MistralModel with wrapped function +class MistralForCausalLM(MistralForCausalLM): + def __init__(self, config): + super(MistralPreTrainedModel, self).__init__(config) + # TRACING: Must use MistralModel with wrapped function + self.model = MistralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + +# TRACING: Must use MistralModel with wrapped function +class MistralForSequenceClassification(MistralForSequenceClassification): + def __init__(self, config): + super(MistralPreTrainedModel, self).__init__(config) + self.num_labels = config.num_labels + # TRACING: Must use MistralModel with wrapped function + self.model = MistralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + +# TRACING: Must use MistralModel with wrapped function +class MistralForTokenClassification(MistralForTokenClassification): + def __init__(self, config): + super(MistralPreTrainedModel, self).__init__(config) + self.num_labels = config.num_labels + # TRACING: Must use MistralModel with wrapped function + self.model = MistralModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + +# TRACING: Must use MistralModel with wrapped function +class MistralForQuestionAnswering(MistralForQuestionAnswering): + def __init__(self, config): + super(MistralPreTrainedModel, self).__init__(config) + # TRACING: Must use MistralModel with wrapped function + self.model = MistralModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() diff --git a/src/llmcompressor/transformers/tracing/mllama.py b/src/llmcompressor/transformers/tracing/mllama.py new file mode 100644 index 000000000..8b65b179c --- /dev/null +++ b/src/llmcompressor/transformers/tracing/mllama.py @@ -0,0 +1,161 @@ +# flake8: noqa +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
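The `@torch.fx.wrap` pattern used throughout these tracing shims can be seen in isolation in the following self-contained sketch; the module and function here are made up for illustration. Wrapping a function makes symbolic tracing record it as a single leaf `call_function` node instead of tracing through its data-dependent control flow.

import torch

@torch.fx.wrap
def build_causal_mask(length):
    # data-dependent logic that symbolic tracing cannot follow
    return torch.tril(torch.ones(length, length))

class TinyAttention(torch.nn.Module):
    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        # `build_causal_mask` appears in the traced graph as one call_function node
        return scores * build_causal_mask(scores.shape[-1])

graph_module = torch.fx.symbolic_trace(TinyAttention())
print(graph_module.graph)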
+# vllm-project: no copyright +"""PyTorch Mllama model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.utils import ( + add_start_docstrings, + logging, +) + +# TRACING: imports +from transformers.models.mllama.modeling_mllama import ( + MLLAMA_START_DOCSTRING, + MllamaForConditionalGeneration, +) + +logger = logging.get_logger(__name__) + + +# TRACING: This function is not traceable +@torch.fx.wrap +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +# TRACING: needs to use wrapped _prepare_cross_attention_mask +@add_start_docstrings( + """The Mllama model which consists of a vision encoder and a language model.""", + MLLAMA_START_DOCSTRING, +) +class MllamaForConditionalGeneration(MllamaForConditionalGeneration): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same 
time, and must specify either one" + ) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") + # get vision tokens from vision model + vision_outputs = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.hidden_size + ) + + if cross_attention_mask is not None: + # TRACING: use wrapped function + cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None and cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + labels=labels, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + return outputs diff --git a/src/llmcompressor/transformers/utils/data_collator.py b/src/llmcompressor/transformers/utils/data_collator.py new file mode 100644 index 000000000..b2dc7c651 --- /dev/null +++ b/src/llmcompressor/transformers/utils/data_collator.py @@ -0,0 +1,48 @@ +import torch + +__all__ = [ + "mllama_data_collator", + "pixtral_data_collator", + "llava_data_collator", + "qwen2_vl_data_collator", +] + + +def mllama_data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"]), + "aspect_ratio_ids": torch.tensor(batch[0]["aspect_ratio_ids"]), + "aspect_ratio_mask": torch.tensor(batch[0]["aspect_ratio_mask"]), + "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]), + } + + +def pixtral_data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"])[0], + } + + +def llava_data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"]), + } + + +def qwen2_vl_data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + 
"pixel_values": torch.tensor(batch[0]["pixel_values"]), + "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), + } diff --git a/tests/llmcompressor/modifiers/calibration/__init__.py b/tests/llmcompressor/modifiers/calibration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py b/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py deleted file mode 100644 index 203d1fe03..000000000 --- a/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py +++ /dev/null @@ -1,41 +0,0 @@ -from collections import OrderedDict - -import torch -from compressed_tensors.quantization.lifecycle.apply import apply_quantization_config -from compressed_tensors.quantization.quant_config import QuantizationConfig -from compressed_tensors.quantization.quant_scheme import preset_name_to_scheme -from loguru import logger - -from llmcompressor.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper - - -def test_ignore(): - model = torch.nn.Sequential( - OrderedDict( - [ - ("first_layer", torch.nn.Linear(2, 3)), - ("second_layer", torch.nn.Linear(3, 5)), - ] - ) - ) - - config = QuantizationConfig( - config_groups={"group_0": preset_name_to_scheme("W8A8", targets=["Linear"])}, - ignore=["first_layer"], - ) - apply_quantization_config(model, config) - - messages = [] - logger.add(lambda m: messages.append(m)) - - with torch.no_grad(): - first_compressor = GPTQWrapper("first_layer", model.first_layer) - first_compressor.add_batch(torch.ones(2), None) - first_compressor.compress() - - second_compressor = GPTQWrapper("second_layer", model.second_layer) - second_compressor.add_batch(torch.ones(3), None) - second_compressor.compress() - - assert sum("Skipping unquantized layer first_layer" in m for m in messages) == 1 - assert sum("Skipping unquantized layer second_layer" in m for m in messages) == 0 diff --git a/tests/llmcompressor/pipelines/test_cache.py b/tests/llmcompressor/pipelines/test_cache.py new file mode 100644 index 000000000..71a72eb25 --- /dev/null +++ b/tests/llmcompressor/pipelines/test_cache.py @@ -0,0 +1,166 @@ +from dataclasses import dataclass + +import pytest +import torch +from torch.utils.data import DataLoader, StackDataset + +from llmcompressor.pipelines.cache import IntermediatesCache, IntermediateValue + + +@pytest.fixture +def sample_dataloader(): + # Create sample input tensors + input_ids = torch.tensor([[1, 2, 3, 0], [4, 5, 6, 0]], dtype=torch.long) + attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 0]], dtype=torch.bool) + + # Create dataset and dataloader + dataset = StackDataset(input_ids=input_ids, attention_mask=attention_mask) + return DataLoader(dataset, batch_size=2) + + +@pytest.fixture +def sample_cache(sample_dataloader): + return IntermediatesCache.from_dataloader( + dataloader=sample_dataloader, + model_device=torch.device("cpu"), + mask_padding=True, + offload_device=torch.device("cpu"), + ) + + +def test_initialization(sample_dataloader): + cache = IntermediatesCache.from_dataloader( + dataloader=sample_dataloader, + model_device=torch.device("cpu"), + mask_padding=True, + ) + + assert isinstance(cache, IntermediatesCache) + assert len(cache.batch_intermediates) > 0 + assert isinstance(cache.batch_intermediates[0], dict) + + +def test_fetch_inputs(sample_cache): + fetched = sample_cache.fetch(0, ["input_ids", "attention_mask"]) + + assert isinstance(fetched, dict) + assert "input_ids" in fetched + assert "attention_mask" in 
fetched + assert isinstance(fetched["input_ids"], torch.Tensor) + assert isinstance(fetched["attention_mask"], torch.Tensor) + + +def test_update_intermediates(sample_cache): + new_outputs = { + "hidden_states": torch.randn(2, 4, 768), + "logits": torch.randn(2, 4, 1000), + } + + sample_cache.update(0, new_outputs) + + # Verify the updates were stored + assert "hidden_states" in sample_cache.batch_intermediates[0] + assert "logits" in sample_cache.batch_intermediates[0] + + +def test_delete_intermediates(sample_cache): + # First add some intermediates + new_outputs = { + "hidden_states": torch.randn(2, 4, 768), + "logits": torch.randn(2, 4, 1000), + } + sample_cache.update(0, new_outputs) + + # Then delete them + sample_cache.delete(0, ["hidden_states"]) + + assert "hidden_states" not in sample_cache.batch_intermediates[0] + assert "logits" in sample_cache.batch_intermediates[0] + + +def test_mask_padding(): + input_ids = torch.tensor([[1, 2, 3, 0], [4, 5, 6, 0]]) + attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 0]]) + + masked = IntermediatesCache._mask_padding(input_ids, attention_mask) + + # Check if padding tokens are properly masked + expected = torch.tensor([[1, 2, 3, 0], [4, 5, 6, 0]]) + assert torch.equal(masked, expected) + + +def test_offload_and_onload_tensor(): + cache = IntermediatesCache([], torch.device("cpu")) + + # Test tensor offloading + original_tensor = torch.randn(2, 3).to("cpu") + offloaded = cache._offload_value(original_tensor) + + assert isinstance(offloaded, IntermediateValue) + assert isinstance(offloaded.value, torch.Tensor) + assert offloaded.device == original_tensor.device + + # Test tensor onloading + onloaded = cache._onload_value(offloaded) + assert torch.equal(onloaded, original_tensor) + + +@dataclass +class SampleDataclass: + a: torch.Tensor + b: int + + +def test_offload_and_onload_dataclass(): + cache = IntermediatesCache([], torch.device("cpu")) + + # Create a sample dataclass instance + sample_data = SampleDataclass(a=torch.randn(2, 3), b=42) + + # Test dataclass offloading + offloaded = cache._offload_value(sample_data) + assert isinstance(offloaded, IntermediateValue) + assert isinstance(offloaded.value, SampleDataclass) + assert isinstance(offloaded.value.a, IntermediateValue) + assert isinstance(offloaded.value.b, IntermediateValue) + + # Test dataclass onloading + onloaded = cache._onload_value(offloaded) + assert onloaded == sample_data + + +def test_4d_attention_mask(): + input_ids = torch.tensor([[1, 2, 3, 0]]) + attention_mask = torch.ones(1, 1, 1, 4) # 4D attention mask + + masked = IntermediatesCache._mask_padding(input_ids, attention_mask) + + # Check if the function handles 4D attention mask properly + expected = torch.tensor([[1, 2, 3, 0]]) + assert torch.equal(masked, expected) + + +def test_device_handling(sample_dataloader): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + cuda_device = torch.device("cuda") + cpu_device = torch.device("cpu") + + # Create a cache with GPU as model device and CPU as offload device + cache = IntermediatesCache.from_dataloader( + dataloader=sample_dataloader, + model_device=cuda_device, + offload_device=cpu_device, + ) + + # Add some GPU tensors + new_outputs = {"hidden_states": torch.randn(2, 3).to(cuda_device)} + cache.update(0, new_outputs) + + # Verify tensors are offloaded to CPU + assert cache.batch_intermediates[0]["hidden_states"].value.device.type == "cpu" + + # Verify tensors are loaded back to GPU when fetched + fetched = cache.fetch(0, ["hidden_states"]) + 
assert fetched["hidden_states"].device.type == "cuda" diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 1a229a6aa..069869436 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -71,15 +71,15 @@ def test_create_default_quant_modifier(self): kwargs = dict(block_size=128) modifier = GPTQModifier(**kwargs) - assert modifier.quantization_modifier_ is None + assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - assert isinstance(modifier.quantization_modifier_, QuantizationModifier) - modifier.quantization_modifier_.create_init_config() + assert isinstance(modifier._quantization_modifier, QuantizationModifier) + modifier._quantization_modifier.create_init_config() default_config_group_name = "group_0" - should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[ + should_be_default_quant_scheme = modifier._quantization_modifier.config_groups[ default_config_group_name ] assert should_be_default_quant_scheme.input_activations is None @@ -108,7 +108,7 @@ def test_set_quant_if_modifer_already_exists(self): kwargs = dict(block_size=128) modifier = GPTQModifier(**kwargs) - assert not modifier.quantization_modifier_ + assert not modifier._quantization_modifier modifier.on_initialize_structure(testing_harness.get_state()) # since quantization modifier is already applied, quantization must be set in @@ -145,14 +145,14 @@ def test_set_quant_in_gptq(self): kwargs = dict(block_size=128, quantize=self.quant_config) modifier = GPTQModifier(**kwargs) - assert modifier.quantization_modifier_ is None + assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - self.assertIsInstance(modifier.quantization_modifier_, QuantizationModifier) + self.assertIsInstance(modifier._quantization_modifier, QuantizationModifier) - dict_scheme = dict(modifier.quantization_modifier_.config_groups) + dict_scheme = dict(modifier._quantization_modifier.config_groups) self._check_config( dict(dict_scheme["config_group_0"].weights), self.quant_kwargs["config_groups"]["config_group_0"]["weights"],