diff --git a/QEfficient/finetune/experimental/core/model.py b/QEfficient/finetune/experimental/core/model.py
index d647b73a6..2f967d85d 100644
--- a/QEfficient/finetune/experimental/core/model.py
+++ b/QEfficient/finetune/experimental/core/model.py
@@ -4,3 +4,145 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional, Type
+
+import torch.nn as nn
+import transformers
+from transformers import AutoTokenizer
+from transformers.utils.logging import get_logger
+
+from QEfficient.finetune.experimental.core.component_registry import registry
+from QEfficient.finetune.experimental.core.utils.dataset_utils import insert_pad_token
+
+logger = get_logger(__name__)
+
+
+class BaseModel(nn.Module, ABC):
+    """Shared skeleton for every finetunable model in the system."""
+
+    def __init__(self, model_name: str, **model_kwargs: Any) -> None:
+        super().__init__()
+        self.model_name = model_name
+        self.model_kwargs: Dict[str, Any] = model_kwargs
+        self._model: Optional[nn.Module] = None
+        self._tokenizer: Any = None  # HF tokenizers are not nn.Modules.
+
+    # Factory constructor: loads the model only after __init__ has finished.
+    @classmethod
+    def create(cls, model_name: str, **model_kwargs: Any) -> "BaseModel":
+        obj = cls(model_name, **model_kwargs)
+        module = obj.load_model()
+        if not isinstance(module, nn.Module):
+            raise TypeError(f"load_model() must return nn.Module, got {type(module)}")
+        obj._model = module
+        return obj
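+
+    # Illustrative usage of the two-phase construction (model name and kwargs are
+    # examples only, not part of this module):
+    #
+    #     model = HFModel.create("gpt2", torch_dtype="float32")
+    #     outputs = model(input_ids=batch["input_ids"])
+    #
+    # Because load_model() runs after __init__, subclasses can rely on all of
+    # their attributes being set when the inner module is built.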
to load.") + return self._model + + @property + def tokenizer(self) -> Any: + if self._tokenizer is None: + self._tokenizer = self.load_tokenizer() + return self._tokenizer + + # nn.Module API surface + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def get_input_embeddings(self): + if hasattr(self.model, "get_input_embeddings"): + return self.model.get_input_embeddings() + logger.info(f"Model {self.model_name} does not expose input embeddings", logging.WARNING) + return None + + def resize_token_embeddings(self, new_num_tokens: int) -> None: + if hasattr(self.model, "resize_token_embeddings"): + self.model.resize_token_embeddings(new_num_tokens) + else: + logger.info(f"Model {self.model_name} cannot resize token embeddings", logging.WARNING) + + # optional + def to(self, *args, **kwargs): + self.model.to(*args, **kwargs) + return self + + def train(self, mode: bool = True): + self.model.train(mode) + return super().train(mode) + + def eval(self): + return self.train(False) + + +@registry.model("hf") +class HFModel(BaseModel): + """HuggingFace-backed model with optional quantization.""" + + def __init__( + self, + model_name: str, + auto_class_name: str = "AutoModelForCausalLM", + *, + tokenizer_name: Optional[str] = None, + **model_kwargs: Any, + ) -> None: + super().__init__(model_name, **model_kwargs) + self.tokenizer_name = tokenizer_name or model_name + self.auto_class: Type = self._resolve_auto_class(auto_class_name) + + @staticmethod + def _resolve_auto_class(auto_class_name: str) -> Type: + if not hasattr(transformers, auto_class_name): + candidates = sorted(name for name in dir(transformers) if name.startswith("AutoModel")) + raise ValueError( + f"Unsupported Auto class '{auto_class_name}'. Available candidates: {', '.join(candidates)}" + ) + return getattr(transformers, auto_class_name) + + # def _build_quant_config(self) -> Optional[BitsAndBytesConfig]: + # if not self.model_kwargs.get("load_in_4bit"): + # return None + # return BitsAndBytesConfig( + # load_in_4bit=True, + # bnb_4bit_quant_type=self.model_kwargs.get("bnb_4bit_quant_type", "nf4"), + # bnb_4bit_compute_dtype=self.model_kwargs.get("bnb_4bit_compute_dtype", torch.float16), + # bnb_4bit_use_double_quant=self.model_kwargs.get("bnb_4bit_use_double_quant", True), + # ) + + def configure_model_kwargs(self) -> Dict[str, Any]: + """Hook for subclasses to tweak HF `.from_pretrained` kwargs.""" + extra = dict(self.model_kwargs) + # extra["quantization_config"] = self._build_quant_config() + return extra + + def load_model(self) -> nn.Module: + logger.info(f"Loading HuggingFace model '{self.model_name}' via {self.auto_class.__name__}") + + return self.auto_class.from_pretrained( + self.model_name, + **self.configure_model_kwargs(), + ) + + def load_tokenizer(self) -> AutoTokenizer: + """Load Hugging Face tokenizer.""" + logger.info(f"Loading tokenizer '{self.tokenizer_name}'") + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + insert_pad_token(tokenizer) + return tokenizer diff --git a/QEfficient/finetune/experimental/tests/test_model.py b/QEfficient/finetune/experimental/tests/test_model.py new file mode 100644 index 000000000..5174f971f --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_model.py @@ -0,0 +1,143 @@ +import pytest +import torch +import torch.nn as nn +from unittest import mock + +import transformers +from QEfficient.finetune.experimental.core import model +from QEfficient.finetune.experimental.core.model import BaseModel, HFModel + + +class 
+
+    def load_model(self) -> nn.Module:
+        logger.info(f"Loading HuggingFace model '{self.model_name}' via {self.auto_class.__name__}")
+        return self.auto_class.from_pretrained(
+            self.model_name,
+            **self.configure_model_kwargs(),
+        )
+
+    def load_tokenizer(self) -> AutoTokenizer:
+        """Load the Hugging Face tokenizer and ensure it has a pad token."""
+        logger.info(f"Loading tokenizer '{self.tokenizer_name}'")
+        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
+        insert_pad_token(tokenizer)
+        return tokenizer
diff --git a/QEfficient/finetune/experimental/tests/test_model.py b/QEfficient/finetune/experimental/tests/test_model.py
new file mode 100644
index 000000000..5174f971f
--- /dev/null
+++ b/QEfficient/finetune/experimental/tests/test_model.py
@@ -0,0 +1,143 @@
+import pytest
+import torch
+import torch.nn as nn
+from unittest import mock
+
+from QEfficient.finetune.experimental.core import model
+from QEfficient.finetune.experimental.core.model import BaseModel, HFModel
+
+
+class TestMockModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(2, 2)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
+class TestCustomModel(BaseModel):
+    def __init__(self, model_name):
+        super().__init__(model_name)
+
+    def load_model(self) -> nn.Module:
+        return TestMockModel()
+
+    def load_tokenizer(self):
+        return "dummy-tokenizer"
+
+
+# BaseModel tests
+def test_model_property_errors_if_not_created():
+    m = TestCustomModel("dummy")
+    with pytest.raises(RuntimeError):
+        _ = m.model  # must call .create()
+
+
+def test_create_builds_and_registers():
+    m = TestCustomModel.create("dummy")
+    # inner model exists and is registered as a submodule
+    assert "_model" in m._modules
+    assert isinstance(m.model, TestMockModel)
+    # forward works
+    out = m(torch.zeros(1, 2))
+    assert out.shape == (1, 2)
+
+
+def test_tokenizer_lazy_loading():
+    m = TestCustomModel.create("dummy")
+    assert m._tokenizer is None
+    tok = m.tokenizer
+    assert tok == "dummy-tokenizer"
+    assert m._tokenizer == tok
+
+
+def test_to_moves_inner_and_returns_self():
+    m = TestCustomModel.create("dummy")
+    with mock.patch.object(TestMockModel, "to", autospec=True) as mocked_to:
+        ret = m.to("cuda:0")
+    mocked_to.assert_called_once_with(m.model, "cuda:0")
+    assert ret is m
+
+
+def test_train_eval_sync_flags():
+    m = TestCustomModel.create("dummy")
+    m.eval()
+    assert m.training is False
+    assert m.model.training is False
+    m.train()
+    assert m.training is True
+    assert m.model.training is True
+
+
+def test_resize_token_embeddings_and_get_input_embeddings_warn():
+    m = TestCustomModel.create("dummy")
+
+    # resize_token_embeddings: underlying model lacks the method, should warn and not raise
+    with mock.patch("QEfficient.finetune.experimental.core.model.logger.warning") as mocked_log:
+        m.resize_token_embeddings(10)
+    mocked_log.assert_called_once()
+
+    # get_input_embeddings: underlying model lacks the method, should warn and return None
+    with mock.patch("QEfficient.finetune.experimental.core.model.logger.warning") as mocked_log:
+        assert m.get_input_embeddings() is None
+    mocked_log.assert_called_once()
+
+
+def test_state_dict_contains_inner_params():
+    m = TestCustomModel.create("dummy")
+    sd = m.state_dict()
+    # should contain params from TestMockModel.linear
+    assert any("linear.weight" in k for k in sd)
+    assert any("linear.bias" in k for k in sd)
+
+
+# HFModel tests
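+# These stub out the HuggingFace entry points (the Auto class, AutoTokenizer and
+# insert_pad_token) with fakes/mocks, so no network access or pretrained weights
+# are needed.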
+def test_hfmodel_invalid_auto_class_raises():
+    with pytest.raises(ValueError):
+        HFModel.create("hf-name", auto_class_name="AutoDoesNotExist")
+
+
+def test_hfmodel_loads_auto_and_tokenizer(monkeypatch):
+    # fake HF Auto class
+    class FakeAuto(nn.Module):
+        @classmethod
+        def from_pretrained(cls, name, **kwargs):
+            inst = cls()
+            inst.loaded = (name, kwargs)
+            return inst
+
+        def forward(self, x):
+            return x
+
+    fake_tok = mock.Mock()
+
+    # Monkeypatch the transformers entry points used by HFModel
+    monkeypatch.setattr(
+        "QEfficient.finetune.experimental.core.model.transformers.AutoModelForCausalLM",
+        FakeAuto,
+        raising=False,
+    )
+    monkeypatch.setattr(
+        model,
+        "AutoTokenizer",
+        mock.Mock(from_pretrained=mock.Mock(return_value=fake_tok)),
+    )
+    monkeypatch.setattr(
+        "QEfficient.finetune.experimental.core.model.insert_pad_token",
+        mock.Mock(),
+        raising=False,
+    )
+
+    m = HFModel.create("hf-name")
+    assert isinstance(m.model, FakeAuto)
+
+    # load the tokenizer
+    tok = m.load_tokenizer()
+    assert tok is fake_tok
+
+    # tokenizer was loaded and pad token inserted
+    model.AutoTokenizer.from_pretrained.assert_called_once_with("hf-name")
+    model.insert_pad_token.assert_called_once_with(fake_tok)