
Commit feae967

Adding support for gemma3 in TransformerEvalWrapper
Summary: This is to support AWQ for gemma3.

Test Plan:

Before the change:

```
  File "/data/users/jerryzh/ao/.github/scripts/torchao_model_releases/quantize_gemma3.py", line 59, in <module>
    TransformerEvalWrapper(
  File "/home/jerryzh/.conda/envs/hmbd/lib/python3.10/site-packages/torchao-0.14.0+gitb47f1a36-py3.10.egg/torchao/_models/_eval.py", line 82, in run_eval
    result = evaluate(
  File "/home/jerryzh/.conda/envs/hmbd/lib/python3.10/site-packages/lm_eval/utils.py", line 456, in _wrapper
    return fn(*args, **kwargs)
  File "/home/jerryzh/.conda/envs/hmbd/lib/python3.10/site-packages/lm_eval/evaluator.py", line 585, in evaluate
    resps = getattr(lm, reqtype)(cloned_reqs)
  File "/home/jerryzh/.conda/envs/hmbd/lib/python3.10/site-packages/lm_eval/api/model.py", line 391, in loglikelihood
    return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
  File "/home/jerryzh/.conda/envs/hmbd/lib/python3.10/site-packages/lm_eval/models/huggingface.py", line 1279, in _loglikelihood_tokens
    multi_logits = F.log_softmax(
  File "/home/jerryzh/.conda/envs/hmbd/lib/python3.10/site-packages/torch/nn/functional.py", line 2245, in log_softmax
    ret = input.log_softmax(dim)
AttributeError: 'Gemma3CausalLMOutputWithPast' object has no attribute 'log_softmax'
Running loglikelihood requests: 0%|
```

After the change: quantize_qwen3.py: https://gist.github.com/jerryzh168/85cc75bc1feb4723fddb156582adc6ad

Uploaded checkpoint after the change: https://huggingface.co/jerryzh168/gemma-3-12b-it-AWQ-INT4

Reviewers:

Subscribers:

Tasks:

Tags:
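For context, a minimal sketch of the failure mode. The stub class below is a stand-in for the real transformers output type, and the tensor shapes are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

# Stand-in for transformers' Gemma3CausalLMOutputWithPast: the model
# forward returns a structured output object, not a raw logits tensor.
class FakeGemma3Output:
    def __init__(self, logits):
        self.logits = logits

output = FakeGemma3Output(logits=torch.randn(1, 4, 32))

# Before this commit, the wrapper returned the whole output object, so
# lm_eval eventually called F.log_softmax on it:
#   F.log_softmax(output, dim=-1)
#   -> AttributeError: ... object has no attribute 'log_softmax'

# Unwrapping the logits tensor first, as the fix does, works:
log_probs = F.log_softmax(output.logits, dim=-1)
print(log_probs.shape)  # torch.Size([1, 4, 32])
```

This reproduces the AttributeError in the test plan above: `F.log_softmax` dispatches to `input.log_softmax(dim)`, which only exists on tensors.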
1 parent b47f1a3 commit feae967

File tree

1 file changed: +8 additions, −3 deletions

torchao/_models/_eval.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -59,11 +59,16 @@ def _model_call(self, inps):
         with torch.device(self._device):
             if hasattr(self._model, "setup_caches"):
                 self._model.setup_caches(self.batch_size, max_seq_length)
-        logits = self._model(*input)
+        output = self._model(*input)
         from transformers.modeling_outputs import CausalLMOutputWithPast
+        from transformers.models.gemma3.modeling_gemma3 import (
+            Gemma3CausalLMOutputWithPast,
+        )

-        if isinstance(logits, CausalLMOutputWithPast):
-            logits = logits.logits
+        if isinstance(output, (CausalLMOutputWithPast, Gemma3CausalLMOutputWithPast)):
+            logits = output.logits
+        else:
+            logits = output
         return logits

     def run_eval(self, tasks, limit):
```
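As a side note, an alternative (a sketch only, not what this commit does) would be to duck-type on the `logits` attribute instead of enumerating HF output classes, which would also cover other architecture-specific output types without new imports:

```python
def _extract_logits(output):
    # Works for raw tensors and for any transformers ModelOutput
    # variant that carries a `logits` field (CausalLMOutputWithPast,
    # Gemma3CausalLMOutputWithPast, ...).
    return output.logits if hasattr(output, "logits") else output
```

The explicit isinstance check in the commit keeps the behavior narrowly scoped to known output types, at the cost of needing a new branch for each architecture-specific output class.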
