diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index f181ee5eb..3c0d1a495 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -38,6 +38,7 @@ CustomOpsTransform, KVCacheModuleMethodMapperTransform, KVCacheTransform, + SamplerTransform, SpDTransform, VlmKVOffloadTransform, VlmNoKVOffloadTransform, @@ -75,7 +76,7 @@ def __repr__(self) -> str: @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, *args, **kwargs): if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') @@ -85,7 +86,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = Fals kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) - return cls(model, is_tlm=is_tlm) + return cls(model, qaic_config=qaic_config) @property def model_name(self) -> str: @@ -1268,7 +1269,10 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): ``Mandatory`` Args: :model (nn.Module): PyTorch model :continuous_batching (bool): Weather this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. - :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode. + :qaic_config (dict): Dictionary with the following keys: + :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode. + :include_sampler (bool): Enable/Disable sampling of next tokens during decode. + :return_pdfs (bool): Return probability distributions (logits/probs) or sampled next tokens. If `is_tlm`=True, then `return_pdfs`=True always. If `is_tlm`=False, then `return_pdfs`=True for Speculative Decoding Draft Language Model and `return_pdfs`=False for regular model. .. code-block:: python @@ -1327,6 +1331,10 @@ def __init__( self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) self.is_tlm = transformed + # Sampling + self.model, transformed = SamplerTransform.apply(self.model, qaic_config, **kwargs) + self.include_sampler = transformed + @property def model_name(self) -> str: mname = self.model.__class__.__name__ @@ -1355,7 +1363,10 @@ def from_pretrained( Args: :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory. :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. - :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode. + :qaic_config (dict): Dictionary with the following keys: + :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. 
If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode. + :include_sampler (bool): Enable/Disable sampling of next tokens during decode. + :return_pdfs (bool): Return probability distributions (logits/probs) or sampled next tokens. If `is_tlm`=True, then `return_pdfs`=True always. If `is_tlm`=False, then `return_pdfs`=True for Speculative Decoding Draft Language Model and `return_pdfs`=False for regular model. :args, kwargs: Additional arguments to pass to transformers.AutoModelForCausalLM. .. code-block:: python @@ -1414,6 +1425,7 @@ def model_hash(self) -> str: mhash.update(to_hashable(self.model.config.to_diff_dict())) mhash.update(to_hashable({"continuous_batching": self.continuous_batching})) mhash.update(to_hashable({"is_tlm": self.is_tlm})) + mhash.update(to_hashable({"include_sampler": self.include_sampler})) mhash.update(to_hashable(self._transform_names())) mhash = mhash.hexdigest()[:16] return mhash @@ -1457,7 +1469,13 @@ def export(self, export_dir: Optional[str] = None) -> str: 0: "full_batch_size" if self.continuous_batching else "batch_size", 2: "ctx_len", } - output_names = ["logits"] + output_names = [] + if self.include_sampler: + if self.model.return_pdfs: + output_names.append("probs") + output_names.append("next_tokens") + else: + output_names.append("logits") for i in range(self.num_layers): for kv in ["key", "value"]: @@ -1474,6 +1492,48 @@ def export(self, export_dir: Optional[str] = None) -> str: example_inputs["num_logits_to_keep"] = torch.arange(nlk).view(nlk, 1) dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"} + if self.include_sampler: + nlk = constants.ONNX_EXPORT_EXAMPLE_NLK # Number of Logits to Keep + max_top_k_ids = constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS + + example_inputs["last_accepted_output_tokens"] = torch.randint(low=0, high=self.model.config.vocab_size, size=(bs, nlk)) + dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "num_logits_to_keep"} + + example_inputs["past_repetition_penalty_buffer"] = torch.zeros( + fbs if self.continuous_batching else bs, self.model.config.vocab_size, dtype=torch.bool) + dynamic_axes["past_repetition_penalty_buffer"] = { + 0: "full_batch_size" if self.continuous_batching else "batch_size", + } + output_names.append("past_repetition_penalty_buffer_RetainedState") + + example_inputs["repetition_penalties"] = torch.ones((bs, 1), dtype=torch.float) * 0.5 + dynamic_axes["repetition_penalties"] = {0: "batch_size"} + + example_inputs["past_presence_penalty_buffer"] = torch.zeros( + fbs if self.continuous_batching else bs, self.model.config.vocab_size, dtype=torch.bool) + dynamic_axes["past_presence_penalty_buffer"] = { + 0: "full_batch_size" if self.continuous_batching else "batch_size", + } + output_names.append("past_presence_penalty_buffer_RetainedState") + + example_inputs["presence_penalties"] = torch.zeros((bs, 1), dtype=torch.float) + 0.5 + dynamic_axes["presence_penalties"] = {0: "batch_size"} + + example_inputs["temperatures"] = torch.ones((bs, 1), dtype=torch.float) + dynamic_axes["temperatures"] = {0: "batch_size"} + + example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) + dynamic_axes["top_ks"] = {0: "batch_size"} + + example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * 0.80 + dynamic_axes["top_ps"] = {0: "batch_size"} + + example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * 0.99 + dynamic_axes["min_ps"] = {0: "batch_size"} + + 
example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) + dynamic_axes["random_numbers"] = {0: "batch_size"} + return self._export( example_inputs, output_names, @@ -1488,12 +1548,14 @@ def build_prefill_specialization( batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, + max_top_k_ids: Optional[int] = None, ): spec = { "batch_size": 1 if self.continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, - "num_logits_to_keep": 1 if self.is_tlm else None, + "num_logits_to_keep": 1 if self.is_tlm or self.include_sampler else None, + "max_top_k_ids": max_top_k_ids if self.include_sampler else None, } if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size @@ -1511,6 +1573,7 @@ def build_decode_specialization( kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, num_speculative_tokens: Optional[int] = None, + max_top_k_ids: Optional[int] = None, ): if prefill_seq_len == 1 and not self.continuous_batching: return None # Avoid duplication with prefill @@ -1518,7 +1581,8 @@ def build_decode_specialization( "batch_size": full_batch_size if self.continuous_batching else batch_size, "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1, "ctx_len": ctx_len, - "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None, + "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm or self.include_sampler else None, + "max_top_k_ids": max_top_k_ids if self.include_sampler else None, } if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size @@ -1600,12 +1664,23 @@ def compile( if prefill_only is None or prefill_only or prefill_seq_len == 1: specializations.append( self.build_prefill_specialization( - prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + max_top_k_ids=constants.Constants.MAX_TOP_K_IDS if self.include_sampler else None, ) ) if prefill_only is None or not prefill_only: decode_spec = self.build_decode_specialization( - prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, + max_top_k_ids=constants.Constants.MAX_TOP_K_IDS if self.include_sampler else None, ) if decode_spec: specializations.append(decode_spec) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 333c734ba..1ec5dbb56 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -267,11 +267,11 @@ QEffWhisperPositionalEmbedding, ) from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry +from QEfficient.transformers.sampler.sampler import sampler_forward from QEfficient.transformers.spd.spd_transform_forward import tlm_forward SPD_TARGET = "target" - class CustomOpsTransform(ModuleMappingTransform): _module_mapping = { GemmaRMSNorm: GemmaCustomRMSNormAIC, @@ -456,6 +456,40 @@ def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) - return model, transformed +class SamplerTransform: + """ + ``Mandatory`` Args: + :model 
(nn.Module): PyTorch model.
+
+    Returns:
+        :model (nn.Module): PyTorch model.
+        :transformed (bool): whether transformation was applied successfully.
+    """
+
+    # supported architectures
+    _module_mapping = {
+        # Llama
+        QEffLlamaForCausalLM,
+    }
+
+    @classmethod
+    def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
+        transformed = False
+        if qaic_config is None or (include_sampler := qaic_config.get("include_sampler")) is None:
+            return model, transformed
+        elif not include_sampler:
+            return model, transformed
+        elif (model_class := model.__class__) in cls._module_mapping:
+            model.forward = MethodType(sampler_forward, model)
+            model.return_pdfs = qaic_config.get("return_pdfs", False)
+            transformed = True
+        else:
+            raise NotImplementedError(
+                f"model class {model_class} does not yet support on-device sampling of next tokens."
+            )
+        return model, transformed
+
+
 class VlmKVOffloadTransform(ModuleMappingTransform):
     # supported architectures
     _module_mapping = {
diff --git a/QEfficient/transformers/sampler/__init__.py b/QEfficient/transformers/sampler/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py
new file mode 100644
index 000000000..768374a41
--- /dev/null
+++ b/QEfficient/transformers/sampler/sampler.py
@@ -0,0 +1,325 @@
+from dataclasses import dataclass
+import torch
+import torch.nn.functional as F
+
+# from QEfficient.customop import CtxScatterFunc
+from QEfficient.utils.constants import Constants
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
+from typing import List, Optional, Tuple, Union
+
+
+@dataclass
+class QEffCausalLMOutputWithPast(ModelOutput):
+    loss: Optional[torch.FloatTensor] = None
+    probs: torch.FloatTensor = None
+    next_tokens: torch.IntTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    past_repetition_penalty_buffer: Optional[torch.Tensor] = None
+    past_presence_penalty_buffer: Optional[torch.Tensor] = None
+
+
+def filter_hidden_states(
+    hidden_states: torch.Tensor,
+    position_ids: torch.Tensor,
+    num_logits_to_keep: Optional[int] = None,
+) -> torch.Tensor:
+    """
+    Filter hidden states based on whether this is a TLM SpD model.
+
+    ``Mandatory`` Args:
+        :hidden_states (torch.Tensor): Hidden states tensor.
+        :position_ids (torch.Tensor): Position ids tensor.
+    ``Optional`` Args:
+        :num_logits_to_keep (int, optional): Number of speculative tokens, specified only for a TLM SpD model.
+
+    Returns:
+        :torch.Tensor: Filtered hidden states.
+ """ + batch_size = position_ids.size(0) + batch_indices = torch.arange(batch_size) + # Cast to INT32 to avoid issue while running in ONNXRT + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + + if num_logits_to_keep is None: + # return the last logit + return hidden_states[batch_indices.view(-1, 1), logit_index] + + # gather approach + num_logits_to_keep = num_logits_to_keep.shape[0] + lower_idx = torch.where( + logit_index < num_logits_to_keep, + 0, + logit_index + 1 - num_logits_to_keep, + ).view( + -1, 1 + ) # shape: [bsz, 1] + spec_idx = torch.arange(num_logits_to_keep).view(1, -1) # shape: [1, k] + indices = torch.add(lower_idx, spec_idx).unsqueeze(2) # shape: [bsz, k, 1] + indices = indices.repeat( + 1, 1, hidden_states.size(-1) + ) # shape: [bsz, ,k, d_model] + hidden_states = torch.gather( + hidden_states, dim=1, index=indices + ) # shape: [bsz, k, d_model] + return hidden_states + + +def sampler_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, + last_accepted_output_tokens: Optional[torch.Tensor] = None, # (batch_size, spec_length or less) + past_repetition_penalty_buffer: Optional[torch.Tensor] = None, + repetition_penalties: Optional[torch.Tensor] = None, + past_presence_penalty_buffer: Optional[torch.Tensor] = None, + presence_penalties: Optional[torch.Tensor] = None, + temperatures: Optional[torch.Tensor] = None, + top_ks: Optional[torch.Tensor] = None, + top_ps: Optional[torch.Tensor] = None, + min_ps: Optional[torch.Tensor] = None, + random_numbers: Optional[torch.Tensor] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + last_accepted_output_tokens (`torch.Tensor`, *optional*): + Output tokens accepted by the Speculative Decoding Draft Language Model. + + past_repetition_penalty_buffer (`torch.Tensor`, *optional*): + RetainedState buffer used as a mask to apply repetition penalty to the input + prompt and the output generated so far. + + repetition_penalties (`torch.Tensor`, *optional*): + Sampling parameter that penalizes new tokens based on whether they appear in the + prompt and the generated text so far. Values > 1 encourage the model to use + new tokens, while values < 1 encourage the model to repeat tokens. 
+ + past_presence_penalty_buffer (`torch.Tensor`, *optional*): + RetainedState buffer used as a mask to apply presence penalty to the output + generated so far. + + presence_penalties (`torch.Tensor`, *optional*): + Sampling parameter that penalizes new tokens based on whether they appear in the + generated text so far. Values > 0 encourage the model to use new tokens, while + values < 0 encourage the model to repeat tokens. + + temperatures (`torch.Tensor`, *optional*): + Sampling parameter that controls the randomness of the sampling. Lower values + make the model more deterministic, while higher values make the model more + random. Zero means greedy sampling. + + top_ks (`torch.Tensor`, *optional*): + Sampling parameter that controls the number of top tokens to consider. + + top_ps (`torch.Tensor`, *optional*): + Sampling parameter that controls the cumulative probability of the top tokens to + consider. Must be in (0, 1]. Set to 1.0 to consider all tokens. + + min_ps (`torch.Tensor`, *optional*): + Sampling parameter that represents the minimum probability for a token to be + considered, relative to the probability of the most likely token. Must be in + [0, 1]. Set to 0.0 to disable this. + + random_numbers (`torch.Tensor`, *optional*): + Sampling parameter that represents the random seeds to use for random sampling. + Must be in [-1, 1]. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = filter_hidden_states( + outputs[0], position_ids, num_logits_to_keep + ) + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split( + self.vocab_size // self.config.pretraining_tp, dim=0 + ) + logits = [ + F.linear(hidden_states, lm_head_slices[i]) + for i in range(self.config.pretraining_tp) + ] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() # (batch_size, num_logits_to_keep aka spec_length, vocab_size) + + # Perform Sampling + batch_size, spec_length, vocab_size = logits.shape + + # Select relevant rows + batch_index_reshaped = batch_index.view(-1) + past_repetition_penalty_buffer_selected = past_repetition_penalty_buffer[batch_index_reshaped] + past_presence_penalty_buffer_selected = past_presence_penalty_buffer[batch_index_reshaped] + + logits = logits.reshape(-1, vocab_size) # Reshape tensor to 2D + + if input_ids.shape[1] > spec_length: # Prefill phase, initialize retained states + # TODO: Replace scatter_ with CtxScatterFunc; Replace -1 with int_max while exporting on onnx + # past_repetition_penalty_buffer_selected = CtxScatterFunc.apply(past_repetition_penalty_buffer_selected.unsqueeze(1), input_ids, 1).squeeze(1) + if position_ids[0, 0] == 0: + past_repetition_penalty_buffer_selected = torch.zeros(past_repetition_penalty_buffer_selected.shape, dtype=torch.bool) + past_presence_penalty_buffer_selected = torch.zeros(past_presence_penalty_buffer_selected.shape, dtype=torch.bool) + past_repetition_penalty_buffer_selected.scatter_(1, input_ids, 1) + + else: # Decode phase, update retained states + past_repetition_penalty_buffer_selected.scatter_(1, last_accepted_output_tokens, 1) + past_presence_penalty_buffer_selected.scatter_(1, last_accepted_output_tokens, 1) + # TODO: For frequency retain state, first gather and then scatter + + # Update relevant rows in original tensors + past_repetition_penalty_buffer[batch_index_reshaped] = past_repetition_penalty_buffer_selected + past_presence_penalty_buffer[batch_index_reshaped] = past_presence_penalty_buffer_selected + + # Greedy Sampling + greedy_samples = torch.argmax(logits, dim=1, keepdim=True) # (batch_size * spec_length, 1) + if (temperatures == 0).all() and self.return_pdfs == False: + return QEffCausalLMOutputWithPast( + loss=None, + probs=None, + next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + past_repetition_penalty_buffer=past_repetition_penalty_buffer, + past_presence_penalty_buffer=past_presence_penalty_buffer, + ) + + # Repetition Penalty + if 
(repetition_penalties != 1.).any(): + repetition_penalties = repetition_penalties.repeat(spec_length, vocab_size) # (batch_size, 1) -> (batch_size * spec_length, vocab_size) + past_repetition_penalty_buffer_selected = past_repetition_penalty_buffer_selected.repeat(spec_length, 1) # (batch_size, vocab_size) -> (batch_size * spec_length, vocab_size) + repetition_penalties[past_repetition_penalty_buffer_selected == 0] = 1.0 + logits = torch.where(logits > 0, logits / repetition_penalties, logits * repetition_penalties) + + # Presence Penalty + if (presence_penalties != 0.).any(): + presence_penalties = presence_penalties.repeat(spec_length, 1) # (batch_size, 1) -> (batch_size * spec_length, 1) + past_presence_penalty_buffer_selected = past_presence_penalty_buffer_selected.repeat(spec_length, 1) # (batch_size, vocab_size) -> (batch_size * spec_length, vocab_size) + logits -= presence_penalties * past_presence_penalty_buffer_selected + + # TODO: Frequency Penalty + + # Temperature Scaling + temperatures = temperatures.repeat(spec_length, 1) # (batch_size, 1) -> (batch_size * spec_length, 1) + logits /= temperatures + + # Top K + # TODO (Optimization): if (top_ks != -1 or top_ks != Constants.MAX_TOP_K_IDS).any() is False: skip but will need topk_values_asc and topk_indices_asc + topk_values, topk_indices = torch.topk(logits, k=Constants.MAX_TOP_K_IDS, dim=1) # (batch_size * spec_length, vocab_size) + topk_values_asc = topk_values.flip(dims=[1]) + topk_indices_asc = topk_indices.flip(dims=[1]) + top_ks[top_ks > Constants.MAX_TOP_K_IDS] = Constants.MAX_TOP_K_IDS # Clip k to max value + # True values in this mask indicate the positions of the non-top K values + topk_mask = torch.arange(topk_values_asc.shape[1]).unsqueeze(0) < (topk_values_asc.size(1) - top_ks.to(torch.long)).repeat(spec_length, 1) # (batch_size * spec_length, Constants.MAX_TOP_K_IDS) + topk_values_asc[topk_mask] = torch.finfo(torch.float16).min + + # Top P + # TODO (Optimization): if (top_ps != 1.).any() is False: skip but will need top_probs for Min P + top_probs = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, Constants.MAX_TOP_K_IDS) + topk_probs_sum = torch.cumsum(top_probs, dim=1) + top_p_mask = topk_probs_sum <= 1 - top_ps.repeat(spec_length, 1) # (batch_size * spec_length, Constants.MAX_TOP_K_IDS) + top_p_mask[:, Constants.MAX_TOP_K_IDS - 1] = False + topk_values_asc[top_p_mask] = torch.finfo(torch.float16).min + + # Min P + if (min_ps != 0.).any(): + scaled_min_p = torch.mul(min_ps.repeat(spec_length, 1), top_probs[:, Constants.MAX_TOP_K_IDS - 1:]) # (batch_size * spec_length, 1) + min_p_mask = top_probs < scaled_min_p # (batch_size * spec_length, Constants.MAX_TOP_K_IDS) + topk_values_asc[min_p_mask] = torch.finfo(torch.float16).min + + probs = None + if self.return_pdfs: + # Update the logits + logits.fill_(torch.finfo(torch.float16).min) + logits = logits.scatter(1, topk_indices_asc, topk_values_asc) # (batch_size * spec_length, vocab_size) + # Softmax + probs = torch.softmax(logits, dim=1).reshape(-1, spec_length, vocab_size) # (batch_size, spec_length, vocab_size) + + # Random Sampling + topk_probs_asc = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, Constants.MAX_TOP_K_IDS) + gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick + y = topk_probs_asc + gumbel_noise + random_samples_indices = torch.argmax(y, dim=1, keepdim=True) + random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices) # (batch_size * spec_length, 1) + + 
# Sample the next tokens + next_tokens = torch.where(temperatures == 0, greedy_samples, random_samples).reshape(-1, spec_length, 1) # (batch_size, spec_length, 1) + + return QEffCausalLMOutputWithPast( + loss=None, + probs=probs, + next_tokens=next_tokens, # Return sampled next tokens instead of logits + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + past_repetition_penalty_buffer=past_repetition_penalty_buffer, + past_presence_penalty_buffer=past_presence_penalty_buffer, + ) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index c8f74907a..055368bb0 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -56,10 +56,11 @@ def get_models_dir(): QEFF_MODELS_DIR = get_models_dir() -ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1 +ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 2 ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep +ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512 ONNX_EXPORT_OPSET = 13 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] @@ -97,6 +98,7 @@ class Constants: MAX_QPC_LIMIT = 30 MAX_RETRIES = 10 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download NUM_SPECULATIVE_TOKENS = 2 + MAX_TOP_K_IDS = 512 SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK apps version. SDK_PLATFORM_XML = ( "/opt/qti-aic/versions/platform.xml" # This xml file is parsed to find out the SDK platform version.
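
Usage sketch for the new `qaic_config` argument, assuming the public `QEfficient.QEFFAutoModelForCausalLM` entry point; the model card and the `compile()` arguments shown are illustrative, not prescribed by this change:

```python
from QEfficient import QEFFAutoModelForCausalLM

# include_sampler routes decode through sampler_forward (SamplerTransform);
# return_pdfs=False returns sampled next_tokens instead of probability distributions.
qaic_config = {"include_sampler": True, "return_pdfs": False}

model = QEFFAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative; only Llama-family models are in SamplerTransform._module_mapping
    continuous_batching=False,
    qaic_config=qaic_config,
)
model.export()                                   # adds the sampler inputs/outputs to the ONNX graph
model.compile(prefill_seq_len=32, ctx_len=128)   # specializations gain max_top_k_ids when include_sampler is set
```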
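The exported graph now takes per-sequence sampling tensors alongside the usual model inputs. A sketch of how a host runtime might assemble them; the names, shapes, and dtypes follow the new `export()` block, while the concrete values are only examples:

```python
import torch

batch_size, spec_length, vocab_size = 2, 1, 32000  # illustrative sizes; vocab_size comes from the model config

sampler_inputs = {
    # tokens accepted in the previous decode step; used to update the penalty buffers
    "last_accepted_output_tokens": torch.zeros(batch_size, spec_length, dtype=torch.int64),
    # RetainedState buffers tracking which vocab ids have appeared in the prompt / generated text
    # (sized by full_batch_size instead of batch_size when continuous batching is enabled)
    "past_repetition_penalty_buffer": torch.zeros(batch_size, vocab_size, dtype=torch.bool),
    "past_presence_penalty_buffer": torch.zeros(batch_size, vocab_size, dtype=torch.bool),
    # per-sequence sampling parameters, all shaped (batch_size, 1)
    "repetition_penalties": torch.full((batch_size, 1), 1.2),
    "presence_penalties": torch.full((batch_size, 1), 0.5),
    "temperatures": torch.full((batch_size, 1), 0.8),              # 0 selects the greedy path
    "top_ks": torch.full((batch_size, 1), 40, dtype=torch.int32),  # clipped to Constants.MAX_TOP_K_IDS (512)
    "top_ps": torch.full((batch_size, 1), 0.95),
    "min_ps": torch.zeros(batch_size, 1),                          # 0.0 disables min-p filtering
    "random_numbers": torch.rand(batch_size, 1),                   # per-sequence seeds for the random-sampling branch
}
```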
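With `include_sampler` set, `build_prefill_specialization` and `build_decode_specialization` force `num_logits_to_keep` and add a `max_top_k_ids` entry sourced from `Constants.MAX_TOP_K_IDS`. Roughly, for a non-continuous-batching, non-TLM compile with assumed sequence lengths, the resulting specializations look like this:

```python
prefill_spec = {
    "batch_size": 1,
    "seq_len": 32,             # prefill_seq_len
    "ctx_len": 128,
    "num_logits_to_keep": 1,   # set whenever is_tlm or include_sampler is True
    "max_top_k_ids": 512,      # Constants.MAX_TOP_K_IDS
}
decode_spec = {
    "batch_size": 1,
    "seq_len": 1,              # (num_speculative_tokens + 1) for a TLM
    "ctx_len": 128,
    "num_logits_to_keep": 1,   # likewise (num_speculative_tokens + 1) for a TLM
    "max_top_k_ids": 512,
}
```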
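For reference, `sampler_forward` applies, in order: repetition penalty, presence penalty, temperature scaling, top-k, top-p, and min-p filtering, then returns either the greedy argmax (temperature 0) or a randomly selected surviving candidate. The sketch below restates those semantics for a single row of logits; it is a simplification rather than the exported graph: it collapses the two penalty buffers into one `seen_mask` and uses `torch.multinomial` in place of the on-device selection driven by `random_numbers`.

```python
import torch

def sample_one_row(logits, seen_mask, rep_pen, pres_pen, temperature, top_k, top_p, min_p):
    # logits: (vocab_size,); seen_mask: bool (vocab_size,) marking token ids already in the prompt/output
    greedy = int(torch.argmax(logits))  # greedy sample is taken from the raw logits, as in sampler_forward
    penalties = torch.where(seen_mask, torch.full_like(logits, rep_pen), torch.ones_like(logits))
    logits = torch.where(logits > 0, logits / penalties, logits * penalties)  # repetition penalty
    logits = logits - pres_pen * seen_mask.float()                            # presence penalty
    if temperature == 0:
        return greedy
    logits = logits / temperature                                             # temperature scaling
    values, indices = torch.topk(logits, k=top_k)                             # top-k (k clipped to MAX_TOP_K_IDS on device)
    probs = torch.softmax(values, dim=-1)
    keep = torch.cumsum(probs, dim=-1) <= top_p                               # top-p (nucleus) filter
    keep[0] = True                                                            # never drop the most likely candidate
    keep &= probs >= min_p * probs[0]                                         # min-p cutoff relative to the best candidate
    values = values.masked_fill(~keep, float("-inf"))
    choice = torch.multinomial(torch.softmax(values, dim=-1), 1)              # stands in for the Gumbel-based pick
    return int(indices[choice])
```

For example, `sample_one_row(torch.randn(32000), torch.zeros(32000, dtype=torch.bool), 1.2, 0.5, 0.8, 40, 0.95, 0.0)` returns a single token id.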