Skip to content

Commit

Permalink
Rename traceable classes so that the saved model config records the original (upstream) class name
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Dec 23, 2024
1 parent 9a08725 commit c3a663a
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 7 deletions.
10 changes: 7 additions & 3 deletions src/llmcompressor/transformers/tracing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from .llava import TracableLlavaForConditionalGeneration
from .mistral import TracableMistralForCausalLM
from .mllama import TracableMllamaForConditionalGeneration
from .llava import (
LlavaForConditionalGeneration as TracableLlavaForConditionalGeneration,
)
from .mistral import MistralForCausalLM as TracableMistralForCausalLM
from .mllama import (
MllamaForConditionalGeneration as TracableMllamaForConditionalGeneration,
)

__all__ = [
"TracableLlavaForConditionalGeneration",
Expand Down
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/tracing/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
)
from transformers.models.mistral.configuration_mistral import MistralConfig

from .mistral import TracableMistralForCausalLM
from .mistral import MistralForCausalLM as TracableMistralForCausalLM


class TracableLlavaForConditionalGeneration(LlavaForConditionalGeneration):
class LlavaForConditionalGeneration(LlavaForConditionalGeneration):
def __init__(self, config: LlavaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
Expand Down
2 changes: 1 addition & 1 deletion src/llmcompressor/transformers/tracing/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,7 @@ def _update_causal_mask(
return causal_mask


class TracableMistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

Expand Down
2 changes: 1 addition & 1 deletion src/llmcompressor/transformers/tracing/mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2279,7 +2279,7 @@ def forward(
"""The Mllama model which consists of a vision encoder and a language model.""",
MLLAMA_START_DOCSTRING,
)
class TracableMllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
_supports_quantized_cache = (
False # quant cache not supported in encoder-decoder setting
)
Expand Down

0 comments on commit c3a663a

Please sign in to comment.