From 490f7b822e6854a8f3eb4ffbe7bd9821c83e115e Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Wed, 2 Apr 2025 07:42:11 +0000 Subject: [PATCH 1/3] Added changes to ensure mxint8 compilations of VLMs work. Modified modelling files of InternVL and Llava to have 'vision_embeds' as the name of the image_embeddings. Modified modeling_auto file to incorporate mxint8 modifications for VLMs. LIMITATIONS: It is expected that the Processor of a model always gives vision components in 'float16'. Signed-off-by: quic-dhirajku Signed-off-by: Dhiraj Kumar Sah --- .../models/internvl/modeling_internvl.py | 26 +-- .../models/llava/modeling_llava.py | 30 ++-- .../transformers/models/modeling_auto.py | 163 +++++------------- 3 files changed, 70 insertions(+), 149 deletions(-) diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 8ab178e2e..85ff5e96f 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -20,8 +20,8 @@ def __init__(self, model): self.model = model def forward(self, pixel_values): - vit_embeds = self.model.extract_feature(pixel_values) - return vit_embeds + vision_embeds = self.model.extract_feature(pixel_values) + return vision_embeds class QEffInternDecoderWrapper(nn.Module): @@ -31,7 +31,7 @@ def __init__(self, model): self.config = self.model.language_model.config self.language_model = self.model.language_model - def forward(self, input_ids, vit_embeds, position_ids, past_key_values): + def forward(self, input_ids, vision_embeds, position_ids, past_key_values): input_embeds = self.model.language_model.get_input_embeddings()(input_ids) B, N, C = input_embeds.shape image_input_embeds = input_embeds.reshape(B * N, C) @@ -39,13 +39,13 @@ def forward(self, input_ids, vit_embeds, position_ids, past_key_values): selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1 indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) - image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] + image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds) outputs = self.model.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True ) - return outputs.logits, vit_embeds, outputs.past_key_values + return outputs.logits, vision_embeds, outputs.past_key_values class QEffInternVLModel(nn.Module): @@ -122,7 +122,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} - lang_dynamic_axes["vit_embeds"] = {0: "num_patches"} + lang_dynamic_axes["vision_embeds"] = {0: "num_patches"} vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"} pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"} @@ -139,7 +139,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): return dynamic_axes def get_output_names(self, kv_offload: bool = False): - vision_output_names = ["vit_embeds"] + vision_output_names = ["vision_embeds"] lang_output_names = 
["logits"] for i in range(self.language_model.config.num_hidden_layers): for kv in ["key", "value"]: @@ -147,7 +147,7 @@ def get_output_names(self, kv_offload: bool = False): output_names = {} if kv_offload: - lang_output_names.insert(1, "vit_embeds_RetainedState") + lang_output_names.insert(1, "vision_embeds_RetainedState") output_names["vision"] = vision_output_names output_names["lang"] = lang_output_names else: @@ -175,7 +175,7 @@ def get_dummy_inputs(self, kv_offload: bool = False): # Define shapes inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) - inputs_shapes["vit_embeds"] = ( + inputs_shapes["vision_embeds"] = ( constants.INTERN_NUM_PATCHES, constants.INTERN_FEATURE_SIZE, self.language_model.config.hidden_size, @@ -196,7 +196,7 @@ def get_dummy_inputs(self, kv_offload: bool = False): lang_inputs = {} vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) - lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32) + lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) lang_inputs["position_ids"] = ( torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -220,21 +220,21 @@ def get_dummy_inputs(self, kv_offload: bool = False): inputs["vision"] = vision_inputs inputs["lang"] = lang_inputs else: - lang_inputs.pop("vit_embeds") + lang_inputs.pop("vision_embeds") inputs = {**vision_inputs, **lang_inputs} return inputs def forward(self, input_ids, pixel_values, position_ids, past_key_values): input_embeds = self.language_model.get_input_embeddings()(input_ids) - vit_embeds = self.extract_feature(pixel_values) + vision_embeds = self.extract_feature(pixel_values) B, N, C = input_embeds.shape image_input_embeds = input_embeds.reshape(B * N, C) image_input_ids = input_ids.reshape(B * N) selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1 indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) - image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] + image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds) outputs = self.language_model( diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index 4ce9f087e..d99d8dfc1 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -38,9 +38,9 @@ def forward(self, pixel_values): selected_image_feature = selected_image_feature else: raise ValueError(f"Unexpected select feature strategy: {self.model.config.vision_feature_select_strategy}") - image_features = self.model.multi_modal_projector(selected_image_feature) + vision_embeds = self.model.multi_modal_projector(selected_image_feature) - return image_features + return vision_embeds class QEFFLlavaDecoderWrapper(nn.Module): @@ -50,21 +50,21 @@ def __init__(self, model): self.config = self.model.config self.language_model = self.model.language_model - def 
forward(self, input_ids, image_features, position_ids, past_key_values): + def forward(self, input_ids, vision_embeds, position_ids, past_key_values): inputs_embeds = self.model.get_input_embeddings()(input_ids) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.model.config.image_token_index indices1 = mask.to(torch.int64).cumsum(1) - 1 indices0 = torch.arange(mask.shape[0]).view(-1, 1) - image_features_expanded = image_features[indices0, indices1] - inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + vision_embeds_expanded = vision_embeds[indices0, indices1] + inputs_embeds = torch.where(mask.unsqueeze(-1), vision_embeds_expanded, inputs_embeds) outputs = self.model.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, ) - return outputs.logits, image_features, outputs.past_key_values + return outputs.logits, vision_embeds, outputs.past_key_values class QEffLlavaForConditionalGeneration(LlavaForConditionalGeneration): @@ -86,14 +86,14 @@ def forward(self, input_ids, position_ids, pixel_values, past_key_values): selected_image_feature = selected_image_feature else: raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") - image_features = self.multi_modal_projector(selected_image_feature) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + vision_embeds = self.multi_modal_projector(selected_image_feature) + vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.config.image_token_index indices1 = mask.to(torch.int64).cumsum(1) - 1 indices0 = torch.arange(mask.shape[0]).view(-1, 1) - image_features_expanded = image_features[indices0, indices1] - image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + vision_embeds_expanded = vision_embeds[indices0, indices1] + image_inputs_embeds = torch.where(mask.unsqueeze(-1), vision_embeds_expanded, inputs_embeds) # *where to skip image encoder for decode* inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds) outputs = self.language_model( @@ -118,7 +118,7 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): } lang_inputs = { "input_ids": torch.ones((BS, SEQ_LEN), dtype=torch.int64), - "image_features": torch.ones((BS, 576, self.language_model.config.hidden_size), dtype=torch.float32), + "vision_embeds": torch.ones((BS, 576, self.language_model.config.hidden_size), dtype=torch.float32), "attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64), } lang_inputs["position_ids"] = lang_inputs.pop("attention_mask").cumsum(1) @@ -137,7 +137,7 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): inputs["vision"] = vision_inputs inputs["lang"] = lang_inputs else: - lang_inputs.pop("image_features") + lang_inputs.pop("vision_embeds") inputs = {**vision_inputs, **lang_inputs} return inputs @@ -218,7 +218,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): return dynamic_axes def get_output_names(self, kv_offload: bool = False): - vision_output_names = ["image_features"] + vision_output_names = ["vision_embeds"] lang_output_names = ["logits"] for i in range(self.language_model.config.num_hidden_layers): for kv in ["key", "value"]: @@ -226,7 +226,7 @@ def get_output_names(self, kv_offload: bool = 
False): output_names = {} if kv_offload: - lang_output_names.insert(1, "image_features_RetainedState") + lang_output_names.insert(1, "vision_embeds_RetainedState") output_names["vision"] = vision_output_names output_names["lang"] = lang_output_names else: diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 0531af7b8..34f34e5a9 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -584,7 +584,6 @@ def export( ) self.lang_model.export(inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir) - return self.onnx_path def compile( self, @@ -638,9 +637,12 @@ def compile( custom_io_vision = {} kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" - custom_io_vision["pixel_values"] = kv_cache_dtype + custom_io_vision["pixel_values"] = "float16" for output_name in output_names["vision"]: - custom_io_vision[output_name] = kv_cache_dtype + if output_name.startswith("past_"): + custom_io_vision[output_name] = kv_cache_dtype + else: + custom_io_vision[output_name] = "float16" if vision_onnx_path: self.vision_model.onnx_path = vision_onnx_path @@ -670,12 +672,14 @@ def compile( # Inputs for output_name in output_names["lang"]: if output_name.endswith("_RetainedState"): - custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype + custom_io_lang[output_name[: -len("_RetainedState")]] = ( + "float16" if "vision_embeds" in output_name else kv_cache_dtype + ) # outputs for output_name in output_names["lang"]: if output_name.endswith("_RetainedState"): - custom_io_lang[output_name] = kv_cache_dtype + custom_io_lang[output_name] = "float16" if "vision_embeds" in output_name else kv_cache_dtype self.lang_model._compile( compile_dir, @@ -912,7 +916,7 @@ def export( inputs = self.model.get_dummy_inputs() dynamic_axes = self.model.get_onnx_dynamic_axes() output_names = self.model.get_output_names() - return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) + self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) def compile( self, @@ -964,12 +968,14 @@ def compile( # inputs for input_name in output_names: if input_name.endswith("_RetainedState"): - custom_io[input_name[: -len("_RetainedState")]] = kv_cache_dtype + custom_io[input_name[: -len("_RetainedState")]] = ( + "float16" if "pixel_values" in input_name else kv_cache_dtype + ) # outputs for output_name in output_names: if output_name.endswith("_RetainedState"): - custom_io[output_name] = kv_cache_dtype + custom_io[output_name] = "float16" if "pixel_values" in output_name else kv_cache_dtype self._compile( onnx_path, @@ -1164,69 +1170,9 @@ def get_model_config(self) -> dict: class QEFFAutoModelForImageTextToText: """ - The QEFFAutoModelForImageTextToText class is used to work with multimodal language models from the HuggingFace hub. - While you can initialize the class directly, it's best to use the ``from_pretrained`` method for this purpose. This class supports both single and dual QPC approaches. + A factory class for creating QEFFAutoModelForImageTextToText instances with for single and Dual QPC approach Attributes: _hf_auto_class (class): The Hugging Face AutoModel class for ImageTextToText models. - - ``Mandatory`` Args: - :pretrained_model_name_or_path (str): Model card name from HuggingFace or local path to model directory. - - ``Optional`` Args: - :kv_offload (bool): Flag to toggle between single and dual QPC approaches. 
If set to False, the Single QPC approach will be used; otherwise, the dual QPC approach will be applied. Defaults to True. - - .. code-block:: python - import requests - from PIL import Image - from transformers import AutoProcessor, TextStreamer - - from QEfficient import QEFFAutoModelForImageTextToText - - # Add HuggingFace Token to access the model - HF_TOKEN = "" - model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" - query = "Describe this image." - image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" - - ## STEP - 1 Load the Processor and Model, and kv_offload=True/False for dual and single qpc - processor = AutoProcessor.from_pretrained(model_name, token=token) - model = QEFFAutoModelForImageTextToText.from_pretrained(model_name, token=token, attn_implementation="eager", kv_offload=False) - - ## STEP - 2 Export & Compile the Model - model.compile( - prefill_seq_len=32, - ctx_len=512, - img_size=560, - num_cores=16, - num_devices=1, - mxfp6_matmul=False, - ) - - ## STEP - 3 Load and process the inputs for Inference - image = Image.open(requests.get(image_url, stream=True).raw) - messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": query}, - ], - } - ] - input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)] - inputs = processor( - text=input_text, - images=image, - return_tensors="pt", - add_special_tokens=False, - padding="max_length", - max_length=prefill_seq_len, - ) - - ## STEP - 4 Run Inference on the compiled model - streamer = TextStreamer(processor.tokenizer) - model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) - """ _hf_auto_class = AutoModelForImageTextToText @@ -1273,6 +1219,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): :model (nn.Module): PyTorch model :continuous_batching (bool): Weather this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode. + :enable_qnn (bool): Enables QNN Compilation path for the model. .. code-block:: python @@ -1303,6 +1250,7 @@ def __init__( model: nn.Module, continuous_batching: bool = False, is_tlm: bool = False, + enable_qnn: bool = False, **kwargs, ): model_class_name = model.__class__.__name__ @@ -1334,6 +1282,8 @@ def __init__( self.model, transformed = SpDTransform.apply(self.model) self.is_tlm = is_tlm + self.enable_qnn = enable_qnn + @property def model_name(self) -> str: mname = self.model.__class__.__name__ @@ -1342,12 +1292,18 @@ def model_name(self) -> str: return mname def __repr__(self) -> str: - return self.__class__.__name__ + "\n" + self.model.__repr__() + return self.__class__.__name__ + "\n" + self.model.__repr__ @classmethod @with_replaced_quantizers def from_pretrained( - cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs + cls, + pretrained_model_name_or_path, + continuous_batching: bool = False, + is_tlm: bool = False, + enable_qnn: bool = False, + *args, + **kwargs, ): """ This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM. 
@@ -1358,6 +1314,7 @@ def from_pretrained( :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory. :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode. + :enable_qnn (bool): Enables QNN Compilation path for the model. :args, kwargs: Additional arguments to pass to transformers.AutoModelForCausalLM. .. code-block:: python @@ -1391,6 +1348,7 @@ def from_pretrained( kv_offload = kwargs.pop("kv_offload", None) kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) # This is support models that should be classified to in a different auto class but transformers load them via this class @@ -1400,7 +1358,7 @@ def from_pretrained( model, kv_offload=kv_offload ) - return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching) + return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching, enable_qnn=enable_qnn) @property def model_hash(self) -> str: @@ -1780,26 +1738,20 @@ def export(self, export_dir: Optional[str] = None) -> str: inputs = self.model.get_dummy_inputs() dynamic_axes = self.model.get_onnx_dynamic_axes() output_names = self.model.get_output_names() - return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) + self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) def compile( self, onnx_path: Optional[str] = None, compile_dir: Optional[str] = None, *, - prefill_seq_len: Optional[int] = 1, - encoder_ctx_len: Optional[int] = None, - ctx_len: int = 150, - full_batch_size: Optional[int] = None, - kv_cache_batch_size: Optional[int] = None, + encoder_ctx_len: int = 1500, + decoder_ctx_len: int = 150, + feature_len: int = 3000, batch_size: int = 1, num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, - mxint8_kv_cache: bool = False, - num_speculative_tokens: Optional[int] = None, - enable_qnn: bool = False, - qnn_config: Optional[str] = None, **compiler_options, ) -> str: """ @@ -1810,41 +1762,19 @@ def compile( ``Optional`` Args: :onnx_path (str, optional): Path to pre-exported onnx model. :compile_dir (str, optional): Path for saving the qpc generated. - :encoder_ctx_len (int, optional): The maximum length of context for encoder, based on the AutoProcessor output. ``Defaults to checking config, if None in config then 1500`` - :ctx_len (int, optional): The maximum length of context to keep for decoding. ``Defaults to 150``. + :seq_len (int, optional): The length of the prompt should be less that ``seq_len``. ``Defaults to 32``. :batch_size (int, optional): Batch size. ``Defaults to 1``. :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1. :num_cores (int): Number of cores used to compile the model. :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``. :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``. 
- - Other args are not yet implemented for AutoModelForSpeechSeq2Seq + :allow_mxint8_mdp_io (bool, optional): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.`` Returns: :str: Path of the compiled ``qpc`` package. """ - specializations, compiler_options = self.model.get_specializations( - batch_size, - encoder_ctx_len, - ctx_len, - **compiler_options, - ) - - if full_batch_size: - logger.warning("Continuous batching is not yet enabled for AutoModelForSpeechSeq2Seq") + specializations = self.model.get_specializations(batch_size, encoder_ctx_len, decoder_ctx_len, feature_len) - if kv_cache_batch_size: - logger.warning("Prefix caching is not yet enabled for AutoModelForSpeechSeq2Seq") - - if mxint8_kv_cache: - logger.warning("mxint8 cache is not yet enabled for AutoModelForSpeechSeq2Seq") - - if num_speculative_tokens: - logger.warning("Speculative decoding is not yet enabled for AutoModelForSpeechSeq2Seq") - - if enable_qnn or qnn_config: - logger.warning("QNN compile is not yet enabled for AutoModelForSpeechSeq2Seq") - - return self._compile( + self._compile( onnx_path, compile_dir, compile_only=True, @@ -1862,6 +1792,7 @@ def generate( inputs: torch.Tensor, generation_len: int, streamer: Optional[TextStreamer] = None, + enable_debug_logs: bool = False, device_ids: List[int] = None, ) -> Union[torch.Tensor, np.ndarray]: """ @@ -1870,8 +1801,9 @@ def generate( ``Mandatory`` Args: :processor: autoprocessor to process inputs and decode logits - :inputs (torch.Tensor): inputs to run the execution. + :inputs (np.ndarray): inputs to run the execution. :generation_len (int): length upto which to generate + :sample_rate (int): sampling rate at which input audio is stored in inputs (needed for processor) :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model Returns: :dict: Output from the ``AI_100`` or ``PyTorch`` runtime. @@ -1882,20 +1814,9 @@ def generate( inputs = self.auto_correct_inputs(inputs) if self.qpc_session is None: - self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids) + self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids, enable_debug_logs=enable_debug_logs) self.batch_size = self.qpc_session.bindings[0].dims[0] - inputs["input_features"] = inputs["input_features"].numpy().astype(np.float32) - - # add start token id and initial position ids to inputs - seq_len = 1 - inputs["decoder_input_ids"] = ( - torch.ones((self.batch_size, seq_len), dtype=torch.int64) * self.model.config.decoder_start_token_id - ).numpy() - inputs["decoder_position_ids"] = ( - torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(self.batch_size, 1).numpy() - ) - self.qpc_session.skip_buffers( [x for x in self.qpc_session.input_names + self.qpc_session.output_names if x.startswith("past_")] ) From dc3b639b4255d9f063d9d6d6363104494b20f67f Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Thu, 17 Apr 2025 12:08:19 +0000 Subject: [PATCH 2/3] Rebased to update transformers version and addressed comments to edit out older qnn based changes. 
Signed-off-by: quic-dhirajku Signed-off-by: Dhiraj Kumar Sah --- .../transformers/models/modeling_auto.py | 144 ++++++++++++++---- 1 file changed, 115 insertions(+), 29 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 34f34e5a9..f0c935154 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -584,6 +584,7 @@ def export( ) self.lang_model.export(inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir) + return self.onnx_path def compile( self, @@ -916,7 +917,7 @@ def export( inputs = self.model.get_dummy_inputs() dynamic_axes = self.model.get_onnx_dynamic_axes() output_names = self.model.get_output_names() - self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) + return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) def compile( self, @@ -1170,9 +1171,69 @@ def get_model_config(self) -> dict: class QEFFAutoModelForImageTextToText: """ - A factory class for creating QEFFAutoModelForImageTextToText instances with for single and Dual QPC approach + The QEFFAutoModelForImageTextToText class is used to work with multimodal language models from the HuggingFace hub. + While you can initialize the class directly, it's best to use the ``from_pretrained`` method for this purpose. This class supports both single and dual QPC approaches. Attributes: _hf_auto_class (class): The Hugging Face AutoModel class for ImageTextToText models. + + ``Mandatory`` Args: + :pretrained_model_name_or_path (str): Model card name from HuggingFace or local path to model directory. + + ``Optional`` Args: + :kv_offload (bool): Flag to toggle between single and dual QPC approaches. If set to False, the Single QPC approach will be used; otherwise, the dual QPC approach will be applied. Defaults to True. + + .. code-block:: python + import requests + from PIL import Image + from transformers import AutoProcessor, TextStreamer + + from QEfficient import QEFFAutoModelForImageTextToText + + # Add HuggingFace Token to access the model + HF_TOKEN = "" + model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" + query = "Describe this image." 
+ image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" + + ## STEP - 1 Load the Processor and Model, and kv_offload=True/False for dual and single qpc + processor = AutoProcessor.from_pretrained(model_name, token=token) + model = QEFFAutoModelForImageTextToText.from_pretrained(model_name, token=token, attn_implementation="eager", kv_offload=False) + + ## STEP - 2 Export & Compile the Model + model.compile( + prefill_seq_len=32, + ctx_len=512, + img_size=560, + num_cores=16, + num_devices=1, + mxfp6_matmul=False, + ) + + ## STEP - 3 Load and process the inputs for Inference + image = Image.open(requests.get(image_url, stream=True).raw) + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": query}, + ], + } + ] + input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)] + inputs = processor( + text=input_text, + images=image, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", + max_length=prefill_seq_len, + ) + + ## STEP - 4 Run Inference on the compiled model + streamer = TextStreamer(processor.tokenizer) + model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + """ _hf_auto_class = AutoModelForImageTextToText @@ -1219,7 +1280,6 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): :model (nn.Module): PyTorch model :continuous_batching (bool): Weather this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode. - :enable_qnn (bool): Enables QNN Compilation path for the model. .. code-block:: python @@ -1250,7 +1310,6 @@ def __init__( model: nn.Module, continuous_batching: bool = False, is_tlm: bool = False, - enable_qnn: bool = False, **kwargs, ): model_class_name = model.__class__.__name__ @@ -1282,8 +1341,6 @@ def __init__( self.model, transformed = SpDTransform.apply(self.model) self.is_tlm = is_tlm - self.enable_qnn = enable_qnn - @property def model_name(self) -> str: mname = self.model.__class__.__name__ @@ -1292,18 +1349,12 @@ def model_name(self) -> str: return mname def __repr__(self) -> str: - return self.__class__.__name__ + "\n" + self.model.__repr__ + return self.__class__.__name__ + "\n" + self.model.__repr__() @classmethod @with_replaced_quantizers def from_pretrained( - cls, - pretrained_model_name_or_path, - continuous_batching: bool = False, - is_tlm: bool = False, - enable_qnn: bool = False, - *args, - **kwargs, + cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs ): """ This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM. @@ -1314,7 +1365,6 @@ def from_pretrained( :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory. :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. 
If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode. - :enable_qnn (bool): Enables QNN Compilation path for the model. :args, kwargs: Additional arguments to pass to transformers.AutoModelForCausalLM. .. code-block:: python @@ -1348,7 +1398,6 @@ def from_pretrained( kv_offload = kwargs.pop("kv_offload", None) kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) # This is support models that should be classified to in a different auto class but transformers load them via this class @@ -1358,7 +1407,7 @@ def from_pretrained( model, kv_offload=kv_offload ) - return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching, enable_qnn=enable_qnn) + return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching) @property def model_hash(self) -> str: @@ -1738,20 +1787,26 @@ def export(self, export_dir: Optional[str] = None) -> str: inputs = self.model.get_dummy_inputs() dynamic_axes = self.model.get_onnx_dynamic_axes() output_names = self.model.get_output_names() - self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) + return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) def compile( self, onnx_path: Optional[str] = None, compile_dir: Optional[str] = None, *, - encoder_ctx_len: int = 1500, - decoder_ctx_len: int = 150, - feature_len: int = 3000, + prefill_seq_len: Optional[int] = 1, + encoder_ctx_len: Optional[int] = None, + ctx_len: int = 150, + full_batch_size: Optional[int] = None, + kv_cache_batch_size: Optional[int] = None, batch_size: int = 1, num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, + mxint8_kv_cache: bool = False, + num_speculative_tokens: Optional[int] = None, + enable_qnn: bool = False, + qnn_config: Optional[str] = None, **compiler_options, ) -> str: """ @@ -1762,19 +1817,41 @@ def compile( ``Optional`` Args: :onnx_path (str, optional): Path to pre-exported onnx model. :compile_dir (str, optional): Path for saving the qpc generated. - :seq_len (int, optional): The length of the prompt should be less that ``seq_len``. ``Defaults to 32``. + :encoder_ctx_len (int, optional): The maximum length of context for encoder, based on the AutoProcessor output. ``Defaults to checking config, if None in config then 1500`` + :ctx_len (int, optional): The maximum length of context to keep for decoding. ``Defaults to 150``. :batch_size (int, optional): Batch size. ``Defaults to 1``. :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1. :num_cores (int): Number of cores used to compile the model. :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``. :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``. - :allow_mxint8_mdp_io (bool, optional): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.`` + + Other args are not yet implemented for AutoModelForSpeechSeq2Seq Returns: :str: Path of the compiled ``qpc`` package. 
""" - specializations = self.model.get_specializations(batch_size, encoder_ctx_len, decoder_ctx_len, feature_len) + specializations, compiler_options = self.model.get_specializations( + batch_size, + encoder_ctx_len, + ctx_len, + **compiler_options, + ) - self._compile( + if full_batch_size: + logger.warning("Continuous batching is not yet enabled for AutoModelForSpeechSeq2Seq") + + if kv_cache_batch_size: + logger.warning("Prefix caching is not yet enabled for AutoModelForSpeechSeq2Seq") + + if mxint8_kv_cache: + logger.warning("mxint8 cache is not yet enabled for AutoModelForSpeechSeq2Seq") + + if num_speculative_tokens: + logger.warning("Speculative decoding is not yet enabled for AutoModelForSpeechSeq2Seq") + + if enable_qnn or qnn_config: + logger.warning("QNN compile is not yet enabled for AutoModelForSpeechSeq2Seq") + + return self._compile( onnx_path, compile_dir, compile_only=True, @@ -1792,7 +1869,6 @@ def generate( inputs: torch.Tensor, generation_len: int, streamer: Optional[TextStreamer] = None, - enable_debug_logs: bool = False, device_ids: List[int] = None, ) -> Union[torch.Tensor, np.ndarray]: """ @@ -1801,9 +1877,8 @@ def generate( ``Mandatory`` Args: :processor: autoprocessor to process inputs and decode logits - :inputs (np.ndarray): inputs to run the execution. + :inputs (torch.Tensor): inputs to run the execution. :generation_len (int): length upto which to generate - :sample_rate (int): sampling rate at which input audio is stored in inputs (needed for processor) :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model Returns: :dict: Output from the ``AI_100`` or ``PyTorch`` runtime. @@ -1814,9 +1889,20 @@ def generate( inputs = self.auto_correct_inputs(inputs) if self.qpc_session is None: - self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids, enable_debug_logs=enable_debug_logs) + self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids) self.batch_size = self.qpc_session.bindings[0].dims[0] + inputs["input_features"] = inputs["input_features"].numpy().astype(np.float32) + + # add start token id and initial position ids to inputs + seq_len = 1 + inputs["decoder_input_ids"] = ( + torch.ones((self.batch_size, seq_len), dtype=torch.int64) * self.model.config.decoder_start_token_id + ).numpy() + inputs["decoder_position_ids"] = ( + torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(self.batch_size, 1).numpy() + ) + self.qpc_session.skip_buffers( [x for x in self.qpc_session.input_names + self.qpc_session.output_names if x.startswith("past_")] ) From cbb15c932aa9d6c44204c337153c2ee157d18ce5 Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Tue, 22 Apr 2025 08:57:54 +0000 Subject: [PATCH 3/3] Formatting issue resolved Signed-off-by: Dhiraj Kumar Sah --- QEfficient/transformers/models/modeling_auto.py | 4 ++-- tests/transformers/spd/test_pld_inference.py | 6 +++--- tests/transformers/spd/test_spd_inference.py | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index f0c935154..5134d0042 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -674,8 +674,8 @@ def compile( for output_name in output_names["lang"]: if output_name.endswith("_RetainedState"): custom_io_lang[output_name[: -len("_RetainedState")]] = ( - "float16" if "vision_embeds" in 
output_name else kv_cache_dtype - ) + "float16" if "vision_embeds" in output_name else kv_cache_dtype + ) # outputs for output_name in output_names["lang"]: diff --git a/tests/transformers/spd/test_pld_inference.py b/tests/transformers/spd/test_pld_inference.py index 88d86a9be..e5d472734 100644 --- a/tests/transformers/spd/test_pld_inference.py +++ b/tests/transformers/spd/test_pld_inference.py @@ -145,9 +145,9 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int): """ num_chunks = -(input_len // -prefill_seq_len) # ceil divide without float input_len_padded = num_chunks * prefill_seq_len # Convert input_len to a multiple of prefill_seq_len - assert ( - input_len_padded <= ctx_len - ), "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len" + assert input_len_padded <= ctx_len, ( + "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len" + ) return input_len_padded diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index 39dbd95cb..b78afdc38 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -75,9 +75,9 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int): """ num_chunks = -(input_len // -prefill_seq_len) # ceil divide without float input_len_padded = num_chunks * prefill_seq_len # Convert input_len to a multiple of prefill_seq_len - assert ( - input_len_padded <= ctx_len - ), "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len" + assert input_len_padded <= ctx_len, ( + "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len" + ) return input_len_padded @@ -320,9 +320,9 @@ def test_spec_decode_inference( for prompt, generation in zip(prompts, batch_decode): print(f"{prompt=} {generation=}") # validation check - assert mean_num_accepted_tokens == float( - num_speculative_tokens + 1 - ), f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens + 1}" + assert mean_num_accepted_tokens == float(num_speculative_tokens + 1), ( + f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens + 1}" + ) del target_model_session del draft_model_session generated_ids = np.asarray(generated_ids[0]).flatten()
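
The substantive change in PATCH 1/3 is the custom-IO dtype selection in the compile() path for image-text-to-text models: vision tensors (pixel_values, vision_embeds, and the vision_embeds_RetainedState buffer) are pinned to float16, while only the KV-cache tensors follow the mxint8_kv_cache flag. The sketch below condenses that rule into a standalone helper. It is a minimal illustration of the logic in the diff, not QEfficient's actual API; the helper name build_custom_io and the sample output names are assumptions made for the example.

    from typing import Dict, List


    def build_custom_io(
        vision_output_names: List[str],
        lang_output_names: List[str],
        mxint8_kv_cache: bool,
    ) -> Dict[str, Dict[str, str]]:
        # The flag only controls KV-cache precision.
        kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"

        # Vision QPC: pixel_values and every non-KV output stay float16, matching
        # the stated limitation that the processor yields float16 vision inputs.
        custom_io_vision = {"pixel_values": "float16"}
        for name in vision_output_names:
            custom_io_vision[name] = kv_cache_dtype if name.startswith("past_") else "float16"

        # Language QPC: each retained-state buffer appears as an input (without the
        # suffix) and as an output (with it); vision_embeds stays float16, KV-cache
        # buffers use kv_cache_dtype.
        custom_io_lang = {}
        for name in lang_output_names:
            if name.endswith("_RetainedState"):
                dtype = "float16" if "vision_embeds" in name else kv_cache_dtype
                custom_io_lang[name[: -len("_RetainedState")]] = dtype
                custom_io_lang[name] = dtype

        return {"vision": custom_io_vision, "lang": custom_io_lang}


    if __name__ == "__main__":
        io = build_custom_io(
            vision_output_names=["vision_embeds"],
            lang_output_names=["logits", "vision_embeds_RetainedState", "past_key.0_RetainedState"],
            mxint8_kv_cache=True,
        )
        # vision_embeds stays float16 on both QPCs; past_key.0 becomes mxint8.
        print(io)

The single-QPC path in the same patch applies the same rule keyed on "pixel_values" instead of "vision_embeds" when it builds its custom-IO map.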
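
PATCH 1/3 also records a limitation: the vision QPC is compiled with float16 pixel_values, so the processor is expected to hand over vision components in float16. A possible guard, assuming the processor returns float32 tensors (as Hugging Face processors commonly do), is to downcast before running inference; the snippet below is an illustrative workaround, not something the patch itself adds.

    import torch


    def ensure_float16_vision(inputs: dict) -> dict:
        # Downcast only the vision component; token ids and masks keep their dtypes.
        pixel_values = inputs.get("pixel_values")
        if pixel_values is not None and pixel_values.dtype == torch.float32:
            inputs["pixel_values"] = pixel_values.to(torch.float16)
        return inputs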