Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,17 @@ def forward(
logits = logits.float()
return logits, pixel_values, image_idx, outputs.past_key_values

def get_npi_file(self, model_name: str, **compiler_options):
    """Add the default node-precision-info (NPI) file for ``model_name`` to the compiler options.

    Args:
        model_name: Model identifier; only ``"google/gemma-3-4b-it"`` and
            ``"google/gemma-3-27b-it"`` have a default NPI file.
        **compiler_options: Existing compiler options. Any ``node_precision_info``
            entry passed in is overwritten with the default for the model.

    Returns:
        The ``compiler_options`` dict with ``node_precision_info`` set.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    if model_name == "google/gemma-3-4b-it":
        npi_file = constants.DEFAULT_GEMMA3_4B_NODE_PRECISION_INFO
    elif model_name == "google/gemma-3-27b-it":
        npi_file = constants.DEFAULT_GEMMA3_27B_NODE_PRECISION_INFO
    else:
        # Report the name that was actually checked; reading an attribute off
        # ``self`` here could name a different model or fail with AttributeError.
        raise ValueError(
            f"For Model {model_name} default NPI file is not supported/added. Please use one of the following: google/gemma-3-4b-it, google/gemma-3-27b-it"
        )
    compiler_options["node_precision_info"] = npi_file
    return compiler_options

def get_specializations(
self,
batch_size: int,
Expand Down
10 changes: 7 additions & 3 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,7 +1183,6 @@ def compile(
compiler_options.pop("continuous_batching", None)
compiler_options.pop("kv_cache_batch_size", None)
compiler_options.pop("full_batch_size", None)

if not skip_vision:
self.vision_model._compile(
compile_dir=compile_dir,
Expand All @@ -1199,6 +1198,10 @@ def compile(
**compiler_options,
)

# Custom NPI file options
if hasattr(self.model, "get_npi_file"):
compiler_options = self.model.get_npi_file(self.model.name_or_path)

if not skip_lang:
custom_io_lang = {}
# Inputs
Expand All @@ -1212,7 +1215,6 @@ def compile(
for output_name in output_names["lang"]:
if output_name.endswith("_RetainedState"):
custom_io_lang[output_name] = "float16" if "vision_embeds" in output_name else kv_cache_dtype

self.lang_model._compile(
compile_dir=compile_dir,
compile_only=True,
Expand Down Expand Up @@ -1801,6 +1803,9 @@ def compile(
**compiler_options,
)

if hasattr(self.model, "get_npi_file"):
compiler_options = self.model.get_npi_file(self.model.name_or_path)

custom_io = {}
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
# inputs
Expand All @@ -1819,7 +1824,6 @@ def compile(
compiler_options.pop("continuous_batching", None)
compiler_options.pop("kv_cache_batch_size", None)
compiler_options.pop("full_batch_size", None)

self._compile(
onnx_path=onnx_path,
compile_dir=compile_dir,
Expand Down
20 changes: 20 additions & 0 deletions QEfficient/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import os
from dataclasses import dataclass
from pathlib import Path

from QEfficient.transformers.models import gemma3 as gemma3

UTILS_DIR = os.path.dirname(os.path.abspath(__file__))
QEFF_DIR = os.path.dirname(UTILS_DIR)
Expand All @@ -24,6 +27,23 @@
ONNX_EXPORT_IMAGE_DEPTH = 3
ONNX_EXPORT_CTX_LEN = 1024

# Gemma3 constants: default node-precision-info YAML files, located under
# transformers/models/gemma3/configs one directory above this module's directory.
DEFAULT_GEMMA3_4B_NODE_PRECISION_INFO = str(
    Path(__file__)
    .resolve()
    .parents[1]
    .joinpath("transformers", "models", "gemma3", "configs", "fp32_nodes_gemma3_4b.yaml")
)
DEFAULT_GEMMA3_27B_NODE_PRECISION_INFO = str(
    Path(__file__)
    .resolve()
    .parents[1]
    .joinpath("transformers", "models", "gemma3", "configs", "fp32_nodes_gemma3_27b.yaml")
)
# Compiler defaults
DEFAULT_AIC_NUM_CORES = 16
DEFAULT_AIC_MXPF6_MATMUL = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
aic_enable_depth_first=True,
skip_vision=True,
mos=1,
node_precision_info="examples/gemma3_example/fp32_nodes_gemma3_4b.yaml", # Change to fp32_nodes_gemma3_27b.yaml for 27B model
)

messages = [
Expand Down Expand Up @@ -80,7 +79,6 @@
mxint8_kv_cache=False,
aic_enable_depth_first=True,
mos=1,
node_precision_info="examples/gemma3_example/fp32_nodes_gemma3_4b.yaml", # Change to fp32_nodes_gemma3_27b.yaml for 27B model
)

### IMAGE + TEXT ###
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ test = ["pytest","pytest-mock"]
docs = ["Sphinx==7.1.2","sphinx-rtd-theme==2.0.0","myst-parser==3.0.1","sphinx-multiversion"]
quality = ["black", "ruff", "hf_doc_builder@git+https://github.com/huggingface/doc-builder.git"]

[tool.setuptools.package-data]
"QEfficient.transformers.models.gemma3.configs" = ["*.yaml"]

[build-system]
requires = ["setuptools>=62.0.0"]
build-backend = "setuptools.build_meta"
Expand Down
Loading