diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py
index 8d17d0c86..16b3e005f 100644
--- a/fast_llm/engine/checkpoint/huggingface.py
+++ b/fast_llm/engine/checkpoint/huggingface.py
@@ -156,3 +156,6 @@ def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None:
         shutil.copy(self.configuration_file, config.path)
         if self.generation_utils_file:
             shutil.copy(self.generation_utils_file, config.path)
+            gen_config = pathlib.Path(self.generation_utils_file).parent / "generation_config.json"
+            if gen_config.exists():
+                shutil.copy(gen_config, config.path)
diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py
index 7e4f4e117..d2c01af0d 100644
--- a/fast_llm/models/gpt/conversion.py
+++ b/fast_llm/models/gpt/conversion.py
@@ -118,6 +118,7 @@ def import_weight(
 class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler):
     _model: GPTModel
     _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig
+    architecture: typing.ClassVar[str]
     """
     Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral)
     """
@@ -125,6 +126,7 @@ class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler):
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
         return super()._create_config_converters() + [
+            ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]),
             ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False),
             RenameParamConverter(
                 fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),)
@@ -320,8 +322,8 @@ class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler)
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
+        cls.architecture = "Starcoder2ForCausalLM"
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Starcoder2ForCausalLM"]),
             ConstantImportParamConverter(
                 fast_llm_names=(("transformer", "rotary", "type"),),
                 fast_llm_value=DefaultRotaryConfig.dynamic_type_name,
@@ -447,8 +449,8 @@ class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler)
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
+        cls.architecture = "LlamaForCausalLM"
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["LlamaForCausalLM"]),
             # TODO: Llama supports biases
             ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False),
             ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False),
@@ -499,8 +501,8 @@ class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler):
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
+        cls.architecture = "Qwen2ForCausalLM"
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]),
             ConstantImportParamConverter(
                 fast_llm_names=(("transformer", "normalization", "type"),),
                 fast_llm_value="rms_norm",
@@ -545,8 +547,8 @@ class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
+        cls.architecture = "MistralForCausalLM"
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(export_names=(("architectures",),),
export_value=["MistralForCausalLM"]), IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), ] @@ -569,8 +571,8 @@ class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "MixtralForCausalLM" return super()._create_config_converters() + [ - ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MixtralForCausalLM"]), ConstantImportParamConverter( fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk ), @@ -613,8 +615,8 @@ class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlam @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "MTPLlamaForCausalLM" return super()._create_config_converters() + [ - ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MTPLlamaForCausalLM"]), ConstantExportParamConverter( export_names=(("auto_map",),), export_value={ @@ -685,7 +687,12 @@ def _create_lm_head_converters(self) -> list[WeightConverter]: return converters -class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonHuggingfaceCheckpointHandler): +class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen2HuggingfaceCheckpointHandler): + """ + Handler for DiffusionDream Huggingface checkpoints. + Inherits from Qwen2HuggingfaceCheckpointHandler (and CustomModelingExportMixin), + but overrides _create_config_converters to update architectures and auto_map. + """ from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, generation_utils, modeling_dream @@ -697,33 +704,8 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Comm @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "DreamModel" return super()._create_config_converters() + [ - # From Qwen2HuggingfaceCheckpointHandler - Change architectures to DiffusionDream - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm - ), - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv" - ), - RopeScalingParamConverter( - fast_llm_names=( - ("transformer", "rotary", "type"), - ("transformer", "rotary", "scale_factor"), - ("transformer", "rotary", "low_frequency_factor"), - ("transformer", "rotary", "high_frequency_factor"), - ("transformer", "rotary", "original_context_length"), - ("transformer", "rotary", "attention_factor"), - ("transformer", "rotary", "beta_fast"), - ("transformer", "rotary", "beta_slow"), - ), - export_names=(("rope_scaling",),), - ), - IgnoreImportQwen2SlidingWindowParamsConverter(), - ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DreamModel"]), ConstantExportParamConverter( export_names=(("auto_map",),), export_value={ @@ -733,26 +715,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ] - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - # From Qwen2HuggingfaceCheckpointHandler - transformer_config: TransformerConfig = 
self._model.config.base_model.transformer - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, - MLPLayer2Converter, - ), - ] - -class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlamaHuggingfaceCheckpointHandler): +class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlamaHuggingfaceCheckpointHandler): from fast_llm.models.gpt.external.diffusion_llama import ( configuration_diffusion_llama, @@ -768,12 +732,8 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Comm @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "DiffusionLlamaModel" return super()._create_config_converters() + [ - # From LlamaHuggingfaceCheckpointHandler - Update architectures to DiffusionLlama - # TODO: Llama supports biases - ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), - ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), - ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DiffusionLlamaModel"]), ConstantExportParamConverter( export_names=(("auto_map",),), export_value={ @@ -789,24 +749,6 @@ def _create_config_converters(cls) -> list[ParamConverter]: # ), ] - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - # From LlamaHuggingfaceCheckpointHandler - transformer_config: TransformerConfig = self._model.config.base_model.transformer - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, - MLPLayer2Converter, - ), - ] - class AutoGPTHuggingfaceCheckpointHandler( AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC diff --git a/fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py b/fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py index 58bbd4883..daaf653c9 100644 --- a/fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py +++ b/fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,7 +17,6 @@ from transformers.modeling_rope_utils import rope_config_validation from transformers.utils import logging - logger = logging.get_logger(__name__) @@ -47,7 +45,7 @@ def __init__( max_window_layers=28, attention_dropout=0.0, mask_token_id=151666, - pad_token_id=151643, # vocab_size is set to 8192 for test cases this would fail on Embedding layer check: # pad_token_id=None, + pad_token_id=None, # vocab_size is set to 8192 for test cases this would fail on Embedding layer check: # pad_token_id=151643, **kwargs, ): self.vocab_size = vocab_size @@ -77,7 +75,7 @@ def __init__( if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) - + super().__init__( tie_word_embeddings=tie_word_embeddings, **kwargs, diff --git a/fast_llm/models/gpt/external/diffusion_dream/generation_config.json b/fast_llm/models/gpt/external/diffusion_dream/generation_config.json new file mode 100644 index 000000000..7f11d4b56 --- /dev/null +++ b/fast_llm/models/gpt/external/diffusion_dream/generation_config.json @@ -0,0 +1,7 @@ +{ + "_from_model_config": true, + "bos_token_id": 151643, + "eos_token_id": 151643, + "pad_token_id": 151643, + "transformers_version": "4.46.2" +} diff --git a/fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py b/fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py index 28bf0efe3..ba2f50e76 100644 --- a/fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py +++ b/fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -20,7 +19,7 @@ """LLaMA model configuration""" import math -from typing import Optional, Tuple +from typing import Optional from transformers.configuration_utils import PretrainedConfig from transformers.utils import is_torch_available, logging @@ -30,13 +29,14 @@ if is_torch_available(): import torch + # Update yarn implementation for RoPE (Taken from Llama but updated to use original_max_position_embeddings) def _compute_default_rope_parameters( config: Optional[PretrainedConfig] = None, device: Optional["torch.device"] = None, seq_len: Optional[int] = None, **rope_kwargs, -) -> Tuple["torch.Tensor", float]: +) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies according to the original RoPE implementation Args: @@ -72,9 +72,10 @@ def _compute_default_rope_parameters( inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) return inv_freq, attention_factor + def _compute_yarn_parameters( config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs -) -> Tuple["torch.Tensor", float]: +) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies with NTK scaling. 
Please refer to the [original paper](https://arxiv.org/abs/2309.00071) @@ -101,7 +102,7 @@ def _compute_yarn_parameters( partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) - + # Apriel: Use original max_position_embeddings instead of max_position_embeddings max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings") factor = config.rope_scaling["factor"] @@ -152,6 +153,7 @@ def linear_ramp_factor(min, max, dim): return inv_freq, attention_factor + def _check_received_keys( rope_type: str, received_keys: set, @@ -159,7 +161,6 @@ def _check_received_keys( optional_keys: Optional[set] = None, ignore_keys: Optional[set] = None, ): - """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present if "type" in received_keys: @@ -189,6 +190,7 @@ def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Opt received_keys = set(rope_scaling.keys()) _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" @@ -218,6 +220,8 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" ) + + # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE # parameterizations, as long as the callable has the same signature. @@ -232,6 +236,7 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se "yarn": _validate_yarn_parameters, } + def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): """ Validate the RoPE config arguments, given a `PretrainedConfig` object @@ -250,6 +255,7 @@ def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" ) + class DiffusionLlamaConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`LlamaModel`]. 
It is used to instantiate an LLaMA @@ -397,7 +403,7 @@ def __init__( max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-6, - use_cache=False, # cache not implemented in diffusion + use_cache=False, # cache not implemented in diffusion pad_token_id=None, bos_token_id=1, eos_token_id=2, @@ -409,7 +415,7 @@ def __init__( attention_dropout=0.0, mlp_bias=False, head_dim=None, - # mask_token_id= TODO: add the mask_token_id we will be using, + mask_token_id=131072, # TODO: add the mask_token_id we will be using, **kwargs, ): self.vocab_size = vocab_size @@ -435,6 +441,8 @@ def __init__( self.attention_dropout = attention_dropout self.mlp_bias = mlp_bias self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.mask_token_id = mask_token_id + self.pad_token_id = pad_token_id # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, copy it it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: @@ -450,4 +458,5 @@ def __init__( ) # TODO: self.mask_token_id = mask_token_id -__all__ = ["LlamaConfig"] \ No newline at end of file + +__all__ = ["DiffusionLlamaConfig"] diff --git a/fast_llm/models/gpt/external/diffusion_llama/generation_utils.py b/fast_llm/models/gpt/external/diffusion_llama/generation_utils.py index b70dcf49a..6afdb6641 100644 --- a/fast_llm/models/gpt/external/diffusion_llama/generation_utils.py +++ b/fast_llm/models/gpt/external/diffusion_llama/generation_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved. +# Copyright 2024 ServiceNow. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -130,12 +130,12 @@ def batch_sample_tokens( @dataclass -class DreamModelOutput(ModelOutput): +class SLAMModelOutput(ModelOutput): sequences: torch.LongTensor = None history: Optional[tuple[torch.FloatTensor]] = None -class DreamGenerationConfig(GenerationConfig): +class SLAMGenerationConfig(GenerationConfig): def __init__(self, **kwargs): self.temperature: float = kwargs.pop("temperature", 0.0) self.top_p: Optional[float] = kwargs.pop("top_p", None) @@ -186,7 +186,7 @@ def validate(self, is_init=False): pass -class DreamGenerationMixin(GenerationMixin): +class SLAMGenerationMixin(GenerationMixin): @staticmethod def _expand_inputs_for_generation( expand_size: int = 1, @@ -247,7 +247,7 @@ def _prepare_generated_length( generation_config.max_length = generation_config.max_new_tokens + input_ids_length elif has_default_max_length: - if generation_config.max_length == DreamGenerationConfig().max_length: + if generation_config.max_length == SLAMGenerationConfig().max_length: generation_config.max_length = generation_config.max_length + input_ids_length max_position_embeddings = getattr(self.config, "max_position_embeddings", None) if max_position_embeddings is not None: @@ -256,8 +256,8 @@ def _prepare_generated_length( return generation_config def _prepare_generation_config( - self, generation_config: Optional[DreamGenerationConfig], **kwargs: dict - ) -> DreamGenerationConfig: + self, generation_config: Optional[SLAMGenerationConfig], **kwargs: dict + ) -> SLAMGenerationConfig: """ Prepares the base generation config, then applies any generation configuration options from kwargs. This function handles retrocompatibility with respect to configuration files. 
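As a usage reference for the renamed classes above, here is a minimal sketch of driving the diffusion sampler through `SLAMGenerationConfig` and the `diffusion_generate` entry point defined in this file; the checkpoint path, prompt, and generation arguments are illustrative assumptions rather than part of this change:

```python
import torch
from transformers import AutoTokenizer

from fast_llm.models.gpt.external.diffusion_llama.generation_utils import SLAMGenerationConfig
from fast_llm.models.gpt.external.diffusion_llama.modeling_diffusion_llama import DiffusionLlamaModel

# Hypothetical path to a checkpoint exported with the DiffusionLlama handler.
checkpoint = "/path/to/exported-diffusion-llama"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = DiffusionLlamaModel.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

inputs = tokenizer("The capital of France is", return_tensors="pt")
generation_config = SLAMGenerationConfig(max_new_tokens=32, return_dict_in_generate=True)

# diffusion_generate comes from SLAMGenerationMixin (the renamed DreamGenerationMixin).
output = model.diffusion_generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    generation_config=generation_config,
)

# With return_dict_in_generate=True the result is a SLAMModelOutput with a `sequences` field.
print(tokenizer.batch_decode(output.sequences, skip_special_tokens=True))
```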
@@ -265,7 +265,7 @@ def _prepare_generation_config( # priority: `generation_config` argument > `model.generation_config` (the default generation config) using_model_generation_config = False if generation_config is None: - generation_config = DreamGenerationConfig.from_model_config(self.config) + generation_config = SLAMGenerationConfig.from_model_config(self.config) using_model_generation_config = True # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` @@ -289,7 +289,7 @@ def _prepare_generation_config( def _prepare_special_tokens( self, - generation_config: DreamGenerationConfig, + generation_config: SLAMGenerationConfig, device: Optional[Union[torch.device, str]] = None, ): """ @@ -338,9 +338,9 @@ def _tensor_or_none(token, device=None): def diffusion_generate( self, inputs: Optional[torch.Tensor] = None, - generation_config: Optional[DreamGenerationConfig] = None, + generation_config: Optional[SLAMGenerationConfig] = None, **kwargs, - ) -> Union[DreamModelOutput, torch.LongTensor]: + ) -> Union[SLAMModelOutput, torch.LongTensor]: # fix seed for reproducability torch.random.manual_seed - lm-eval is setting it torch.random.manual_seed(0) @@ -443,10 +443,10 @@ def _sample( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor], - generation_config: DreamGenerationConfig, + generation_config: SLAMGenerationConfig, generation_tokens_hook_func, generation_logits_hook_func, - ) -> Union[DreamModelOutput, torch.LongTensor]: + ) -> Union[SLAMModelOutput, torch.LongTensor]: # init values output_history = generation_config.output_history return_dict_in_generate = generation_config.return_dict_in_generate @@ -570,7 +570,7 @@ def _sample( histories.append(x.clone()) if return_dict_in_generate: - return DreamModelOutput( + return SLAMModelOutput( sequences=x, history=histories, ) @@ -582,12 +582,12 @@ def _sample_with_block( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor], - generation_config: DreamGenerationConfig, + generation_config: SLAMGenerationConfig, block_size: int, use_cache: bool, generation_tokens_hook_func, generation_logits_hook_func, - ) -> Union[DreamModelOutput, torch.LongTensor]: + ) -> Union[SLAMModelOutput, torch.LongTensor]: # init values output_history = generation_config.output_history return_dict_in_generate = generation_config.return_dict_in_generate @@ -790,7 +790,7 @@ def _sample_with_block( settled_length += block_size if return_dict_in_generate: - return DreamModelOutput( + return SLAMModelOutput( sequences=x, history=histories, ) @@ -802,11 +802,11 @@ def _sample_with_block_with_causal_kv( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor], - generation_config: DreamGenerationConfig, + generation_config: SLAMGenerationConfig, block_size: int, generation_tokens_hook_func, generation_logits_hook_func, - ) -> Union[DreamModelOutput, torch.LongTensor]: + ) -> Union[SLAMModelOutput, torch.LongTensor]: # init values output_history = generation_config.output_history return_dict_in_generate = generation_config.return_dict_in_generate @@ -1032,7 +1032,7 @@ def _sample_with_block_with_causal_kv( # print(f"settled_length: {settled_length} past_length: {past_length} x_input: {x_input.shape} past_key_values: {past_key_values.get_seq_length()}") if return_dict_in_generate: - return DreamModelOutput( + return SLAMModelOutput( sequences=x, history=histories, ) diff --git a/fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py 
b/fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py index 7e0bd7974..5e613093e 100644 --- a/fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py +++ b/fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py @@ -2,7 +2,7 @@ import os from dataclasses import dataclass -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Copyright 2022 ServiceNow. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its @@ -46,7 +46,7 @@ ) from .configuration_diffusion_llama import ROPE_INIT_FUNCTIONS, DiffusionLlamaConfig -from .generation_utils import DreamGenerationConfig, DreamGenerationMixin +from .generation_utils import SLAMGenerationConfig, SLAMGenerationMixin if is_torch_flex_attn_available(): from flash_attn import flash_attn_with_kvcache @@ -210,236 +210,59 @@ def sdpa_attention_forward( attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, - # is_causal: Optional[bool] = None, + is_causal: Optional[bool] = None, **kwargs, ) -> tuple[torch.Tensor, None]: if hasattr(module, "num_key_value_groups"): key = repeat_kv(key, module.num_key_value_groups) value = repeat_kv(value, module.num_key_value_groups) - # Note: Updates from Dream - # causal_mask = attention_mask - # if attention_mask is not None: - # causal_mask = causal_mask[:, :, :, : key.shape[-2]] - # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions # Reference: https://github.com/pytorch/pytorch/issues/112577. + print("is_causal", is_causal) query = query.contiguous() key = key.contiguous() value = value.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # Note: Updates from Dream - # if is_causal is None: - # is_causal = causal_mask is None and query.shape[2] > 1 - - # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. - # We convert it to a bool for the SDPA kernel that only accepts bools. - # note: Updates from Dream - # if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor): - # is_causal = is_causal.item() - attn_output = torch.nn.functional.scaled_dot_product_attention( query, key, value, - # Note: Updates from Dream - # attn_mask=causal_mask, attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, dropout_p=dropout, scale=scaling, - # is_causal=is_causal, - is_causal=False, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, None - - -def sdpa_attention_from_dream_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - is_causal: Optional[bool] = False, -) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. 
`model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # print(f"query_states {query_states.shape} {query_states}") - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- # is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, - dropout_p=self.attention_dropout if self.training else 0.0, is_causal=is_causal, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None -def flash_attention_from_dreamforward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 +def flash_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, is_causal: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "DreamModel is using DreamFlashAttention, it does not support `output_attentions=True`. Falling back to the manual attention implementation, " - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - # print(f"hidden_states: {hidden_states.shape} query_states {query_states.shape} key_states {key_states.shape} value_states {value_states.shape}") - # print(f"position_ids {position_ids} {position_ids.shape}") - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - # if query_states.device.type == "cuda" and attention_mask is not None: - # query_states = query_states.contiguous() - # key_states = key_states.contiguous() - # value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - # is_causal = True if causal_mask is None and q_len > 1 else False - - # attn_output_sdpa = torch.nn.functional.scaled_dot_product_attention( - # query_states, - # key_states, - # value_states, - # attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, - # dropout_p=self.attention_dropout if self.training else 0.0, - # is_causal=False, # hard coded - # ) - - # print(f"query_states {query_states.shape} key_states {key_states.shape} value_states {value_states.shape}") + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) - # replacing with flash attention attn_output = flash_attn_with_kvcache( # q dim (batch_size, seqlen, nheads, headdim) - q=query_states.transpose(1, 2).contiguous(), + q=query.transpose(1, 2).contiguous(), k_cache=key_states.transpose(1, 2).contiguous(), v_cache=value_states.transpose(1, 2).contiguous(), - causal=is_causal, # hard coded - softmax_scale=1.0 / math.sqrt(self.head_dim), + causal=is_causal, + softmax_scale=1.0 / math.sqrt(module.head_dim), ) - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value + return attn_output, None class LlamaAttention(nn.Module): @@ -453,7 +276,7 @@ def __init__(self, config: DiffusionLlamaConfig, layer_idx: int): self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.is_causal = True + # self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias @@ -475,6 +298,7 @@ def forward( attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, + is_causal: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -484,6 +308,10 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + # print(f"hidden_states: {hidden_states.shape} query_states {query_states.shape} key_states 
{key_states.shape} value_states {value_states.shape}") + # print(f"position_ids {position_ids} {position_ids.shape}") + # print(f"past_key_value {past_key_value.get_seq_length() if past_key_value is not None else None}") + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -494,6 +322,7 @@ def forward( attention_interface: Callable = eager_attention_forward + # print(f"self.config._attn_implementation {self.config._attn_implementation}") if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( @@ -501,9 +330,9 @@ def forward( 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) elif self.config._attn_implementation == "sdpa": - attention_interface = sdpa_attention_from_dream_forward - elif self.config._attn_implementation == "flash_attention": - attention_interface = flash_attention_from_dreamforward + attention_interface = sdpa_attention_forward + elif self.config._attn_implementation == "flash_attention_2": + attention_interface = flash_attention_forward else: raise ValueError(f"Unsupported attention implementation: {self.config._attn_implementation}") # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -516,12 +345,14 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, + is_causal=is_causal, **kwargs, ) - + # print(f"attn_output {attn_output.shape}") attn_output = attn_output.reshape(*input_shape, -1).contiguous() + # print(f"attn_output {attn_output.shape}") attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + return attn_output, attn_weights, past_key_value # TODO: Update after transformer update: class LlamaDecoderLayer(GradientCheckpointingLayer): @@ -546,6 +377,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + is_causal: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -561,6 +393,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + is_causal=is_causal, **kwargs, ) hidden_states = residual + hidden_states @@ -581,7 +414,6 @@ def forward( # When use_cache is True, outputs will have length: # - 2 if output_attentions is False (hidden_states, present_key_value) # - 3 if output_attentions is True (hidden_states, self_attn_weights, present_key_value) - # print(f"DreamDecoderLayer: outputs {len(outputs)}") return outputs @@ -593,8 +425,8 @@ class DiffusionLlamaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = False - _supports_sdpa = False # TODO: Enable sdpa + _supports_flash_attn_2 = True + _supports_sdpa = True # TODO: Enable sdpa _supports_flex_attn = False _supports_cache_class = True _supports_quantized_cache = False @@ -647,13 +479,13 @@ def from_pretrained( ) # NOTE(Lin): we need to override the generation config # because the generation config loaded in `from_pretrained` - # does not include all the attributes of DreamGenerationConfig + # does 
not include all the attributes of SLAMGenerationConfig resume_download = kwargs.get("resume_download", None) proxies = kwargs.get("proxies", None) subfolder = kwargs.get("subfolder", "") from_auto_class = kwargs.get("_from_auto", False) from_pipeline = kwargs.get("_from_pipeline", None) - _model.generation_config = DreamGenerationConfig.from_pretrained( + _model.generation_config = SLAMGenerationConfig.from_pretrained( pretrained_model_name_or_path, cache_dir=cache_dir, force_download=force_download, @@ -706,6 +538,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + is_causal: Optional[bool] = False, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -770,6 +603,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + is_causal=is_causal, **flash_attn_kwargs, ) @@ -831,136 +665,9 @@ class MaskedLMOutputWithPast(ModelOutput): attentions: Optional[tuple[torch.FloatTensor, ...]] = None past_key_values: Optional[tuple[Cache]] = None - # TODO: Update for diffusion with bi-directional attention (later block casual masking) - # def _update_causal_mask( - # self, - # attention_mask: Union[torch.Tensor, "BlockMask"], - # input_tensor: torch.Tensor, - # cache_position: torch.Tensor, - # past_key_values: Cache, - # output_attentions: bool = False, - # ): - # if self.config._attn_implementation == "flash_attention_2": - # if attention_mask is not None and (attention_mask == 0.0).any(): - # return attention_mask - # return None - # if self.config._attn_implementation == "flex_attention": - # if isinstance(attention_mask, torch.Tensor): - # attention_mask = make_flex_block_causal_mask(attention_mask) - # return attention_mask - - # # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # # to infer the attention mask. - # past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - # using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - # if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - # if AttentionMaskConverter._ignore_causal_mask_sdpa( - # attention_mask, - # inputs_embeds=input_tensor, - # past_key_values_length=past_seen_tokens, - # is_training=self.training, - # ): - # return None - - # dtype = input_tensor.dtype - # sequence_length = input_tensor.shape[1] - # if using_compilable_cache: - # target_length = past_key_values.get_max_cache_shape() - # else: - # target_length = ( - # attention_mask.shape[-1] - # if isinstance(attention_mask, torch.Tensor) - # else past_seen_tokens + sequence_length + 1 - # ) - - # # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
- # causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - # attention_mask, - # sequence_length=sequence_length, - # target_length=target_length, - # dtype=dtype, - # cache_position=cache_position, - # batch_size=input_tensor.shape[0], - # ) - - # if ( - # self.config._attn_implementation == "sdpa" - # and attention_mask is not None - # and attention_mask.device.type in ["cuda", "xpu", "npu"] - # and not output_attentions - # ): - # # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # # Details: https://github.com/pytorch/pytorch/issues/110213 - # min_dtype = torch.finfo(dtype).min - # causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - # return causal_mask - - # @staticmethod - # def _prepare_4d_causal_attention_mask_with_cache_position( - # attention_mask: torch.Tensor, - # sequence_length: int, - # target_length: int, - # dtype: torch.dtype, - # cache_position: torch.Tensor, - # batch_size: int, - # **kwargs, - # ): - # """ - # Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - # `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - # Args: - # attention_mask (`torch.Tensor`): - # A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - # `(batch_size, 1, query_length, key_value_length)`. - # sequence_length (`int`): - # The sequence length being processed. - # target_length (`int`): - # The target length: when generating with static cache, the mask should be as long as the static cache, - # to account for the 0 padding, the part of the cache that is not filled yet. - # dtype (`torch.dtype`): - # The dtype to use for the 4D attention mask. - # cache_position (`torch.Tensor`): - # Indices depicting the position of the input sequence tokens in the sequence. - # batch_size (`torch.Tensor`): - # Batch size. - # """ - # if attention_mask is not None and attention_mask.dim() == 4: - # # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - # causal_mask = attention_mask - # else: - # min_dtype = torch.finfo(dtype).min - # causal_mask = torch.full( - # (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - # ) - # if sequence_length != 1: - # causal_mask = torch.triu(causal_mask, diagonal=1) - # causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - # causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - # if attention_mask is not None: - # causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - # mask_length = attention_mask.shape[-1] - # padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - # causal_mask.device - # ) - # padding_mask = padding_mask == 0 - # causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - # padding_mask, min_dtype - # ) - - # return causal_mask - - -# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
- # @auto_docstring -class DiffusionLlamaModel(DiffusionLlamaPreTrainedModel, DreamGenerationMixin): +class DiffusionLlamaModel(DiffusionLlamaPreTrainedModel, SLAMGenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} @@ -1049,7 +756,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, - # is_casual=kwargs.get("is_casual", False), + is_casual=kwargs.get("is_casual", False), **kwargs, ) @@ -1077,350 +784,6 @@ def forward( ) -# @auto_docstring -# class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): -# _tied_weights_keys = ["lm_head.weight"] -# _tp_plan = {"lm_head": "colwise_rep"} -# _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - -# def __init__(self, config): -# super().__init__(config) -# self.model = LlamaModel(config) -# self.vocab_size = config.vocab_size -# self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - -# # Initialize weights and apply final processing -# self.post_init() - -# def get_input_embeddings(self): -# return self.model.embed_tokens - -# def set_input_embeddings(self, value): -# self.model.embed_tokens = value - -# def get_output_embeddings(self): -# return self.lm_head - -# def set_output_embeddings(self, new_embeddings): -# self.lm_head = new_embeddings - -# def set_decoder(self, decoder): -# self.model = decoder - -# def get_decoder(self): -# return self.model - -# @can_return_tuple -# @auto_docstring -# def forward( -# self, -# input_ids: Optional[torch.LongTensor] = None, -# attention_mask: Optional[torch.Tensor] = None, -# position_ids: Optional[torch.LongTensor] = None, -# past_key_values: Optional[Cache] = None, -# inputs_embeds: Optional[torch.FloatTensor] = None, -# labels: Optional[torch.LongTensor] = None, -# use_cache: Optional[bool] = None, -# output_attentions: Optional[bool] = None, -# output_hidden_states: Optional[bool] = None, -# cache_position: Optional[torch.LongTensor] = None, -# logits_to_keep: Union[int, torch.Tensor] = 0, -# **kwargs: Unpack[KwargsForCausalLM], -# ) -> CausalLMOutputWithPast: -# r""" -# labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): -# Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., -# config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored -# (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - -# Example: - -# ```python -# >>> from transformers import AutoTokenizer, LlamaForCausalLM - -# >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") -# >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - -# >>> prompt = "Hey, are you conscious? Can you talk to me?" -# >>> inputs = tokenizer(prompt, return_tensors="pt") - -# >>> # Generate -# >>> generate_ids = model.generate(inputs.input_ids, max_length=30) -# >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] -# "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
-# ```""" -# output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions -# output_hidden_states = ( -# output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states -# ) - -# # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) -# outputs: BaseModelOutputWithPast = self.model( -# input_ids=input_ids, -# attention_mask=attention_mask, -# position_ids=position_ids, -# past_key_values=past_key_values, -# inputs_embeds=inputs_embeds, -# use_cache=use_cache, -# output_attentions=output_attentions, -# output_hidden_states=output_hidden_states, -# cache_position=cache_position, -# **kwargs, -# ) - -# hidden_states = outputs.last_hidden_state -# # Only compute necessary logits, and do not upcast them to float if we are not computing the loss -# slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep -# logits = self.lm_head(hidden_states[:, slice_indices, :]) - -# loss = None -# if labels is not None: -# loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - -# return CausalLMOutputWithPast( -# loss=loss, -# logits=logits, -# past_key_values=outputs.past_key_values, -# hidden_states=outputs.hidden_states, -# attentions=outputs.attentions, -# ) - - -# @auto_docstring( -# custom_intro=""" -# The LLaMa Model transformer with a sequence classification head on top (linear layer). - -# [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models -# (e.g. GPT-2) do. - -# Since it does classification on the last token, it requires to know the position of the last token. If a -# `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If -# no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the -# padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in -# each row of the batch). -# """ -# ) -# class LlamaForSequenceClassification(LlamaPreTrainedModel): -# def __init__(self, config): -# super().__init__(config) -# self.num_labels = config.num_labels -# self.model = LlamaModel(config) -# self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - -# # Initialize weights and apply final processing -# self.post_init() - -# def get_input_embeddings(self): -# return self.model.embed_tokens - -# def set_input_embeddings(self, value): -# self.model.embed_tokens = value - -# @can_return_tuple -# @auto_docstring -# def forward( -# self, -# input_ids: Optional[torch.LongTensor] = None, -# attention_mask: Optional[torch.Tensor] = None, -# position_ids: Optional[torch.LongTensor] = None, -# past_key_values: Optional[Cache] = None, -# inputs_embeds: Optional[torch.FloatTensor] = None, -# labels: Optional[torch.LongTensor] = None, -# use_cache: Optional[bool] = None, -# output_attentions: Optional[bool] = None, -# output_hidden_states: Optional[bool] = None, -# ) -> SequenceClassifierOutputWithPast: -# r""" -# labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): -# Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., -# config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If -# `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
-# """ - -# transformer_outputs: BaseModelOutputWithPast = self.model( -# input_ids, -# attention_mask=attention_mask, -# position_ids=position_ids, -# past_key_values=past_key_values, -# inputs_embeds=inputs_embeds, -# use_cache=use_cache, -# output_attentions=output_attentions, -# output_hidden_states=output_hidden_states, -# ) -# hidden_states = transformer_outputs.last_hidden_state -# logits = self.score(hidden_states) - -# if input_ids is not None: -# batch_size = input_ids.shape[0] -# else: -# batch_size = inputs_embeds.shape[0] - -# if self.config.pad_token_id is None and batch_size != 1: -# raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") -# if self.config.pad_token_id is None: -# last_non_pad_token = -1 -# elif input_ids is not None: -# # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id -# non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) -# token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) -# last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) -# else: -# last_non_pad_token = -1 -# logger.warning_once( -# f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " -# "unexpected if using padding tokens in conjunction with `inputs_embeds.`" -# ) - -# pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] - -# loss = None -# if labels is not None: -# loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - -# return SequenceClassifierOutputWithPast( -# loss=loss, -# logits=pooled_logits, -# past_key_values=transformer_outputs.past_key_values, -# hidden_states=transformer_outputs.hidden_states, -# attentions=transformer_outputs.attentions, -# ) - - -# @auto_docstring -# class LlamaForQuestionAnswering(LlamaPreTrainedModel): -# base_model_prefix = "transformer" - -# # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama -# def __init__(self, config): -# super().__init__(config) -# self.transformer = LlamaModel(config) -# self.qa_outputs = nn.Linear(config.hidden_size, 2) - -# # Initialize weights and apply final processing -# self.post_init() - -# def get_input_embeddings(self): -# return self.transformer.embed_tokens - -# def set_input_embeddings(self, value): -# self.transformer.embed_tokens = value - -# @can_return_tuple -# @auto_docstring -# def forward( -# self, -# input_ids: Optional[torch.LongTensor] = None, -# attention_mask: Optional[torch.Tensor] = None, -# position_ids: Optional[torch.LongTensor] = None, -# past_key_values: Optional[Cache] = None, -# inputs_embeds: Optional[torch.FloatTensor] = None, -# start_positions: Optional[torch.LongTensor] = None, -# end_positions: Optional[torch.LongTensor] = None, -# output_attentions: Optional[bool] = None, -# output_hidden_states: Optional[bool] = None, -# **kwargs, -# ) -> QuestionAnsweringModelOutput: -# outputs: BaseModelOutputWithPast = self.transformer( -# input_ids, -# attention_mask=attention_mask, -# position_ids=position_ids, -# past_key_values=past_key_values, -# inputs_embeds=inputs_embeds, -# output_attentions=output_attentions, -# output_hidden_states=output_hidden_states, -# ) - -# sequence_output = outputs.last_hidden_state - -# logits = self.qa_outputs(sequence_output) -# start_logits, end_logits = logits.split(1, dim=-1) -# start_logits = 
start_logits.squeeze(-1).contiguous() -# end_logits = end_logits.squeeze(-1).contiguous() - -# loss = None -# if start_positions is not None and end_positions is not None: -# loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - -# return QuestionAnsweringModelOutput( -# loss=loss, -# start_logits=start_logits, -# end_logits=end_logits, -# hidden_states=outputs.hidden_states, -# attentions=outputs.attentions, -# ) - - -# @auto_docstring -# class LlamaForTokenClassification(LlamaPreTrainedModel): -# def __init__(self, config): -# super().__init__(config) -# self.num_labels = config.num_labels -# self.model = LlamaModel(config) -# if getattr(config, "classifier_dropout", None) is not None: -# classifier_dropout = config.classifier_dropout -# elif getattr(config, "hidden_dropout", None) is not None: -# classifier_dropout = config.hidden_dropout -# else: -# classifier_dropout = 0.1 -# self.dropout = nn.Dropout(classifier_dropout) -# self.score = nn.Linear(config.hidden_size, config.num_labels) - -# # Initialize weights and apply final processing -# self.post_init() - -# def get_input_embeddings(self): -# return self.model.embed_tokens - -# def set_input_embeddings(self, value): -# self.model.embed_tokens = value - -# @can_return_tuple -# @auto_docstring -# def forward( -# self, -# input_ids: Optional[torch.LongTensor] = None, -# attention_mask: Optional[torch.Tensor] = None, -# position_ids: Optional[torch.LongTensor] = None, -# past_key_values: Optional[Cache] = None, -# inputs_embeds: Optional[torch.FloatTensor] = None, -# labels: Optional[torch.LongTensor] = None, -# use_cache: Optional[bool] = None, -# output_attentions: Optional[bool] = None, -# output_hidden_states: Optional[bool] = None, -# ) -> TokenClassifierOutput: -# r""" -# labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): -# Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., -# config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If -# `config.num_labels > 1` a classification loss is computed (Cross-Entropy). -# """ - -# outputs: BaseModelOutputWithPast = self.model( -# input_ids, -# attention_mask=attention_mask, -# position_ids=position_ids, -# past_key_values=past_key_values, -# inputs_embeds=inputs_embeds, -# use_cache=use_cache, -# output_attentions=output_attentions, -# output_hidden_states=output_hidden_states, -# ) -# sequence_output = outputs.last_hidden_state -# sequence_output = self.dropout(sequence_output) -# logits = self.score(sequence_output) - -# loss = None -# if labels is not None: -# loss = self.loss_function(logits, labels, self.config) - -# return TokenClassifierOutput( -# loss=loss, -# logits=logits, -# hidden_states=outputs.hidden_states, -# attentions=outputs.attentions, -# ) - - __all__ = [ "DiffusionLlamaModel", ]
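For context on the attention changes in `modeling_diffusion_llama.py`: the refactored `sdpa_attention_forward` and `flash_attention_forward` now receive an explicit `is_causal` flag threaded down from the model `forward`, instead of a hard-coded `is_causal=False`. A minimal, self-contained sketch of what that flag changes at the kernel level, using plain PyTorch and toy shapes (not the Fast-LLM module itself):

```python
import torch

# Toy (batch, heads, seq, head_dim) tensors, matching the layout the attention functions use.
batch, heads, seq, head_dim = 1, 2, 4, 8
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

# Bidirectional attention, e.g. during diffusion denoising: every position attends to every other.
out_bidirectional = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=None, is_causal=False
)

# Causal attention, as used by block-causal / KV-cache decoding paths: position i attends to j <= i.
out_causal = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=None, is_causal=True
)

# Only the last query position sees the full sequence in both modes, so only it matches.
print(torch.allclose(out_bidirectional[:, :, -1], out_causal[:, :, -1]))  # True
print(torch.allclose(out_bidirectional, out_causal))                      # False for random inputs
```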
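On the `conversion.py` side, the per-handler `architectures` export is now produced once by the common handler from a class-level `architecture` attribute that each subclass assigns before delegating to `super()._create_config_converters()`. A simplified, self-contained sketch of that pattern; the classes below are illustrative stand-ins, not the real Fast-LLM converter API:

```python
import typing


class ConstantExportParamConverter:
    """Stand-in for the Fast-LLM converter that exports a fixed value under export_names."""

    def __init__(self, export_names, export_value):
        self.export_names = export_names
        self.export_value = export_value


class CommonHandlerSketch:
    # Subclasses provide the Hugging Face architecture name.
    architecture: typing.ClassVar[str]

    @classmethod
    def _create_config_converters(cls) -> list[ConstantExportParamConverter]:
        # The base class emits the "architectures" entry once, using the subclass value,
        # instead of every subclass repeating its own ConstantExportParamConverter.
        return [ConstantExportParamConverter((("architectures",),), [cls.architecture])]


class Qwen2HandlerSketch(CommonHandlerSketch):
    @classmethod
    def _create_config_converters(cls) -> list[ConstantExportParamConverter]:
        # As in the diff: assign cls.architecture before delegating, so the base class sees it.
        cls.architecture = "Qwen2ForCausalLM"
        return super()._create_config_converters()


print(Qwen2HandlerSketch._create_config_converters()[0].export_value)  # ['Qwen2ForCausalLM']
```

One consequence of this layout, visible in the diff as well, is that `architecture` is only set as a side effect of `_create_config_converters()`, so the attribute is undefined until that method has been called on the subclass.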