
Commit f88f01d
Author: Swati Allabadi
Commit message: Adding base class and Hf class
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
1 parent: 5cd3fd1 · commit: f88f01d

File tree

  • QEfficient/finetune/experimental/core

1 file changed: +143 -0 lines changed

QEfficient/finetune/experimental/core/model.py

Lines changed: 143 additions & 0 deletions
@@ -4,3 +4,146 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type

import torch
import torch.nn as nn
from transformers import AutoTokenizer, BitsAndBytesConfig
import transformers

from QEfficient.finetune.experimental.core.component_registry import registry
from QEfficient.finetune.experimental.utils.dataset_helper import insert_pad_token

# Assumes a project-level `get_logger` helper (not imported in this file) that
# returns a logger exposing `log_rank_zero`.
logger = get_logger(__name__)


class BaseModel(nn.Module, ABC):
    """Shared skeleton for every finetunable model in the system."""

    def __init__(self, model_name: str, **model_kwargs: Any) -> None:
        super().__init__()
        self.model_name = model_name
        self.model_kwargs: Dict[str, Any] = model_kwargs
        self._model: Optional[nn.Module] = None
        self._tokenizer: Any = None  # HF tokenizers are not nn.Modules.

    # Factory constructor: loads the model after __init__ finishes.
    @classmethod
    def create(cls, model_name: str, **model_kwargs: Any) -> "BaseModel":
        obj = cls(model_name, **model_kwargs)
        module = obj.load_model()
        if not isinstance(module, nn.Module):
            raise TypeError(f"load_model() must return nn.Module, got {type(module)}")
        obj._model = module
        obj.add_module("_wrapped_model", module)  # register as a submodule so its parameters are tracked
        return obj

    @abstractmethod
    def load_model(self) -> nn.Module:
        """Create and return the underlying torch.nn.Module."""
        ...

    def load_tokenizer(self) -> Any:
        """Override if the model exposes a tokenizer."""
        raise NotImplementedError(f"{type(self).__name__} does not provide a tokenizer.")

    # Lazy accessors
    @property
    def model(self) -> nn.Module:
        if self._model is None:
            raise RuntimeError("Model not loaded; use .create(...) to load.")
        return self._model

    @property
    def tokenizer(self) -> Any:
        if self._tokenizer is None:
            self._tokenizer = self.load_tokenizer()
        return self._tokenizer

    # nn.Module API surface
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def get_input_embeddings(self):
        if hasattr(self.model, "get_input_embeddings"):
            return self.model.get_input_embeddings()
        logger.log_rank_zero(f"Model {self.model_name} does not expose input embeddings", logging.WARNING)
        return None

    def resize_token_embeddings(self, new_num_tokens: int) -> None:
        if hasattr(self.model, "resize_token_embeddings"):
            self.model.resize_token_embeddings(new_num_tokens)
        else:
            logger.log_rank_zero(f"Model {self.model_name} cannot resize token embeddings", logging.WARNING)

    # Optional overrides that delegate to the wrapped module.
    def to(self, *args, **kwargs):
        self.model.to(*args, **kwargs)
        return self

    def train(self, mode: bool = True):
        self.model.train(mode)
        return super().train(mode)

    def eval(self):
        return self.train(False)
@registry.model("hf")
96+
class HFModel(BaseModel):
97+
"""HuggingFace-backed model with optional quantization."""
98+
99+
def __init__(
100+
self,
101+
model_name: str,
102+
auto_class_name: str = "AutoModelForCausalLM",
103+
*,
104+
tokenizer_name: Optional[str] = None,
105+
**model_kwargs: Any,
106+
) -> None:
107+
super().__init__(model_name, **model_kwargs)
108+
self.tokenizer_name = tokenizer_name or model_name
109+
self.auto_class: Type = self._resolve_auto_class(auto_class_name)
110+
111+
@staticmethod
112+
def _resolve_auto_class(auto_class_name: str) -> Type:
113+
if not hasattr(transformers, auto_class_name):
114+
candidates = sorted(name for name in dir(transformers) if name.startswith("AutoModel"))
115+
raise ValueError(
116+
f"Unsupported Auto class '{auto_class_name}'. Available candidates: {', '.join(candidates)}"
117+
)
118+
return getattr(transformers, auto_class_name)
119+
120+
# def _build_quant_config(self) -> Optional[BitsAndBytesConfig]:
121+
# if not self.model_kwargs.get("load_in_4bit"):
122+
# return None
123+
# return BitsAndBytesConfig(
124+
# load_in_4bit=True,
125+
# bnb_4bit_quant_type=self.model_kwargs.get("bnb_4bit_quant_type", "nf4"),
126+
# bnb_4bit_compute_dtype=self.model_kwargs.get("bnb_4bit_compute_dtype", torch.float16),
127+
# bnb_4bit_use_double_quant=self.model_kwargs.get("bnb_4bit_use_double_quant", True),
128+
# )
129+
130+
def configure_model_kwargs(self) -> Dict[str, Any]:
131+
"""Hook for subclasses to tweak HF `.from_pretrained` kwargs."""
132+
extra = dict(self.model_kwargs)
133+
# extra["quantization_config"] = self._build_quant_config()
134+
return extra
135+
136+
def load_model(self) -> nn.Module:
137+
logger.log_rank_zero(f"Loading HuggingFace model '{self.model_name}' via {self.auto_class.__name__}")
138+
139+
return self.auto_class.from_pretrained(
140+
self.model_name,
141+
**self.configure_model_kwargs(),
142+
)
143+
144+
def load_tokenizer(self) -> AutoTokenizer:
145+
"""Load Hugging Face tokenizer."""
146+
logger.log_rank_zero(f"Loading tokenizer '{self.tokenizer_name}'")
147+
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
148+
insert_pad_token(tokenizer)
149+
return tokenizer
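
For orientation, the extension pattern for the base class is: subclass BaseModel, implement load_model() (and load_tokenizer() if applicable), then build instances through create(), which loads the module and registers it on the wrapper. A minimal sketch of that pattern, assuming the file's get_logger dependency resolves at import time; the TinyLinearModel class and its contents are illustrative only, not part of this commit:

import torch.nn as nn

from QEfficient.finetune.experimental.core.model import BaseModel


class TinyLinearModel(BaseModel):
    """Illustrative subclass: wraps a small torch module."""

    def load_model(self) -> nn.Module:
        # kwargs passed to create() are available as self.model_kwargs
        hidden = self.model_kwargs.get("hidden_size", 16)
        return nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())


wrapper = TinyLinearModel.create("tiny-linear", hidden_size=32)  # __init__ alone would leave .model unset
print(type(wrapper.model))  # the wrapped nn.Sequential
wrapper.eval()              # train()/eval()/to() delegate to the wrapped module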

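HFModel itself is registered with the component registry under the key "hf", and any extra keyword arguments flow through configure_model_kwargs() into the resolved Auto class's from_pretrained call. A hedged usage sketch by direct construction; the "gpt2" checkpoint and the extra kwargs are examples only, and the registry's lookup side is not shown since it lives in component_registry:

from QEfficient.finetune.experimental.core.model import HFModel

# "gpt2" is an example checkpoint; any Hub or local model id works.
model = HFModel.create(
    "gpt2",
    auto_class_name="AutoModelForCausalLM",  # resolved via getattr(transformers, ...)
    torch_dtype="auto",                      # extra kwargs reach from_pretrained through configure_model_kwargs()
)

tokenizer = model.tokenizer                  # lazily loaded; insert_pad_token() ensures a pad token exists
batch = tokenizer("hello world", return_tensors="pt")
outputs = model(**batch)                     # forward() delegates to the wrapped HF model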
0 commit comments
