diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py
index 88be0fe88..c3c14a2d5 100644
--- a/src/llmcompressor/transformers/tracing/__init__.py
+++ b/src/llmcompressor/transformers/tracing/__init__.py
@@ -1,6 +1,10 @@
-from .llava import TracableLlavaForConditionalGeneration
-from .mistral import TracableMistralForCausalLM
-from .mllama import TracableMllamaForConditionalGeneration
+from .llava import (
+    LlavaForConditionalGeneration as TracableLlavaForConditionalGeneration,
+)
+from .mistral import MistralForCausalLM as TracableMistralForCausalLM
+from .mllama import (
+    MllamaForConditionalGeneration as TracableMllamaForConditionalGeneration,
+)
 
 __all__ = [
     "TracableLlavaForConditionalGeneration",
diff --git a/src/llmcompressor/transformers/tracing/llava.py b/src/llmcompressor/transformers/tracing/llava.py
index 2a80d2efb..7a71f2564 100644
--- a/src/llmcompressor/transformers/tracing/llava.py
+++ b/src/llmcompressor/transformers/tracing/llava.py
@@ -29,10 +29,10 @@
 )
 from transformers.models.mistral.configuration_mistral import MistralConfig
 
-from .mistral import TracableMistralForCausalLM
+from .mistral import MistralForCausalLM as TracableMistralForCausalLM
 
 
-class TracableLlavaForConditionalGeneration(LlavaForConditionalGeneration):
+class LlavaForConditionalGeneration(LlavaForConditionalGeneration):
     def __init__(self, config: LlavaConfig):
         super().__init__(config)
         self.vision_tower = AutoModel.from_config(config.vision_config)
diff --git a/src/llmcompressor/transformers/tracing/mistral.py b/src/llmcompressor/transformers/tracing/mistral.py
index bbfa9d319..7a63099c3 100644
--- a/src/llmcompressor/transformers/tracing/mistral.py
+++ b/src/llmcompressor/transformers/tracing/mistral.py
@@ -1090,7 +1090,7 @@ def _update_causal_mask(
         return causal_mask
 
 
-class TracableMistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
+class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
diff --git a/src/llmcompressor/transformers/tracing/mllama.py b/src/llmcompressor/transformers/tracing/mllama.py
index 955f7c270..512ba4227 100644
--- a/src/llmcompressor/transformers/tracing/mllama.py
+++ b/src/llmcompressor/transformers/tracing/mllama.py
@@ -2279,7 +2279,7 @@ def forward(
     """The Mllama model which consists of a vision encoder and a language model.""",
     MLLAMA_START_DOCSTRING,
 )
-class TracableMllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
+class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
     _supports_quantized_cache = (
         False # quant cache not supported in encoder-decoder setting
     )