
Commit e08db42

sfallah, danbev, and ggerganov authored
model: EmbeddingGemma Adding Support for SentenceTransformers Dense Modules (#16367)
* model: EmbeddingGemma sentence-transformers dense linear projections support

* model: add support for EmbeddingGemma SentenceTransformers dense linear projections

  Adding support for the Dense modules used in EmbeddingGemma models. EmbeddingGemma is a SentenceTransformers model with additional modules beyond the base Transformer backbone.

  See: https://developers.googleblog.com/en/gemma-explained-embeddinggemma-architecture-and-recipe/

* model: add support for EmbeddingGemma SentenceTransformers dense linear projections

  - converting model with dense-layers is optional
  - introduced dense config params

* Update convert_hf_to_gguf.py

  Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>

* fixed formatting issues

* Update src/llama-graph.cpp

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* - removed pooling_type_opt, always allow overriding pooling_type
  - asserts checking dense features dims

* fix python lint

* fix ubuntu gcc build warning

* - fixed thread-safety test
  - moved asserts to load_hparams

* - tidying up code
  - simplifying graph-context expecting both dense weights

* minor : add TODO

---------

Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 12bbc3f commit e08db42

12 files changed with 170 additions and 7 deletions
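
In SentenceTransformers terms, EmbeddingGemma is a Transformer backbone followed by a pooling module and two Dense modules (2_Dense and 3_Dense) that project the pooled embedding. This commit lets convert_hf_to_gguf.py optionally fold those projections into the GGUF behind a new --sentence-transformers-dense-modules flag, and teaches llama.cpp to apply them after pooling. The sketch below illustrates what the Dense stage computes; the 768/3072 feature sizes, the absence of bias terms, and mean pooling are assumptions for illustration, since the converter reads the real dimensions from each module's config.json.

    # Rough sketch of the SentenceTransformers Dense stage (not the converter's code path).
    # Feature sizes are assumed; bias is omitted because the converter only exports
    # *.weight tensors.
    import torch

    hidden, proj = 768, 3072  # assumed 2_Dense/3_Dense feature sizes

    dense_2 = torch.nn.Linear(hidden, proj, bias=False)  # 2_Dense module
    dense_3 = torch.nn.Linear(proj, hidden, bias=False)  # 3_Dense module

    token_embeddings = torch.randn(1, 8, hidden)   # backbone output for 8 tokens
    pooled = token_embeddings.mean(dim=1)          # pooling module (mean pooling assumed)
    sentence_embedding = dense_3(dense_2(pooled))  # what llama.cpp now replays after pooling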

convert_hf_to_gguf.py

Lines changed: 69 additions & 3 deletions
@@ -93,13 +93,15 @@ class ModelBase:
     # Mistral format specifics
     is_mistral_format: bool = False
     disable_mistral_community_chat_template: bool = False
+    sentence_transformers_dense_modules: bool = False

     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
-                 disable_mistral_community_chat_template: bool = False):
+                 disable_mistral_community_chat_template: bool = False,
+                 sentence_transformers_dense_modules: bool = False):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -114,6 +116,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
+        self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
         if remote_hf_model_id is not None:
             self.is_safetensors = True

@@ -5269,6 +5272,53 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 @ModelBase.register("Gemma3TextModel")
 class EmbeddingGemma(Gemma3Model):
     model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+    module_paths = []
+    dense_features_dims = {}
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self.sentence_transformers_dense_modules:
+            # read modules.json to determine if model has Dense layers
+            modules_file = self.dir_model / "modules.json"
+            if modules_file.is_file():
+                with open(modules_file, encoding="utf-8") as modules_json_file:
+                    mods = json.load(modules_json_file)
+                    for mod in mods:
+                        if mod["type"] == "sentence_transformers.models.Dense":
+                            mod_path = mod["path"]
+                            # check if model.safetensors file for Dense layer exists
+                            model_tensors_file = self.dir_model / mod_path / "model.safetensors"
+                            if model_tensors_file.is_file():
+                                self.module_paths.append(mod_path)
+                                # read config.json of the Dense layer to get in/out features
+                                mod_conf_file = self.dir_model / mod_path / "config.json"
+                                if mod_conf_file.is_file():
+                                    with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
+                                        mod_conf = json.load(mod_conf_json_file)
+                                        # hparams dense_2_feat_out and dense_3_feat_in are required when loading model's dense weights
+                                        prefix = self._get_dense_prefix(mod_path)
+                                        if mod_conf["in_features"] is not None and mod_conf["out_features"] is not None:
+                                            self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        from safetensors.torch import load_file
+        module_paths = list(self.module_paths)
+        for i, module_path in enumerate(module_paths):
+            tensors_file = self.dir_model / module_path / "model.safetensors"
+            local_tensors = load_file(tensors_file)
+            tensor_name = self._get_dense_prefix(module_path)
+            for name, local_tensor in local_tensors.items():
+                if not name.endswith(".weight"):
+                    continue
+                orig_name = name.replace("linear", tensor_name)
+                name = self.map_tensor_name(orig_name)
+                yield name, local_tensor.clone()
+
+    @staticmethod
+    def _get_dense_prefix(module_path) -> str:
+        """Get the tensor name prefix for the Dense layer from module path."""
+        tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
+        return tensor_name

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -5285,6 +5335,10 @@ def set_gguf_parameters(self):
             logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
                         f"instead of {self.hparams['sliding_window']}")
             self.gguf_writer.add_sliding_window(orig_sliding_window)
+        if self.sentence_transformers_dense_modules:
+            for dense, dims in self.dense_features_dims.items():
+                logger.info(f"Setting dense layer {dense} in/out features to {dims}")
+                self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])

         self._try_set_pooling_type()

@@ -9335,6 +9389,13 @@ def parse_args() -> argparse.Namespace:
         )
     )

+    parser.add_argument(
+        "--sentence-transformers-dense-modules", action="store_true",
+        help=("Whether to include sentence-transformers dense modules."
+              "It can be used for sentence-transformers models, like google/embeddinggemma-300m"
+              "Default these modules are not included.")
+    )
+
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
         parser.error("the following arguments are required: model")
@@ -9397,9 +9458,13 @@ def main() -> None:
     if args.remote:
         hf_repo_id = args.model
         from huggingface_hub import snapshot_download
+        allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
+        if args.sentence_transformers_dense_modules:
+            # include sentence-transformers dense modules safetensors files
+            allowed_patterns.append("*.safetensors")
         local_dir = snapshot_download(
             repo_id=hf_repo_id,
-            allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
+            allow_patterns=allowed_patterns)
         dir_model = Path(local_dir)
         logger.info(f"Downloaded config and tokenizer to {local_dir}")
     else:
@@ -9467,7 +9532,8 @@ def main() -> None:
             split_max_tensors=args.split_max_tensors,
             split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
             small_first_shard=args.no_tensor_first_split,
-            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
+            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
+            sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
         )

         if args.vocab_only:
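
The converter only picks up Dense modules that are declared in modules.json and that ship both a model.safetensors and a config.json; the linear.weight tensor of each qualifying module is renamed with the dense_2/dense_3 prefix and emitted through generate_extra_tensors. Below is a standalone sketch of that directory check, handy for inspecting a local checkout before converting; the checkout path and the printed dimensions are illustrative assumptions.

    # Mirrors the checks EmbeddingGemma.__init__ performs above, outside the converter.
    import json
    from pathlib import Path

    checkout = Path("./embeddinggemma-300m")  # hypothetical local model directory

    dense_dims: dict[str, tuple[int, int]] = {}
    for mod in json.loads((checkout / "modules.json").read_text(encoding="utf-8")):
        if mod["type"] != "sentence_transformers.models.Dense":
            continue
        mod_dir = checkout / mod["path"]                  # e.g. "2_Dense" or "3_Dense"
        if not (mod_dir / "model.safetensors").is_file():
            continue                                      # modules without weights are skipped
        conf = json.loads((mod_dir / "config.json").read_text(encoding="utf-8"))
        prefix = "dense_2" if mod["path"] == "2_Dense" else "dense_3"
        dense_dims[prefix] = (conf["in_features"], conf["out_features"])

    print(dense_dims)  # e.g. {'dense_2': (768, 3072), 'dense_3': (3072, 768)}

When --remote is combined with the new flag, the *.safetensors pattern is added to the snapshot download so the Dense module weights are fetched as well.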

gguf-py/gguf/constants.py

Lines changed: 8 additions & 0 deletions
@@ -128,6 +128,8 @@ class LLM:
         ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx"
         ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs"
         EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"
+        DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in"
+        DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"

     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
@@ -433,6 +435,8 @@ class MODEL_TENSOR(IntEnum):
     TOKEN_TYPES = auto()
     POS_EMBD = auto()
     OUTPUT = auto()
+    DENSE_2_OUT = auto()  # embeddinggemma 2_Dense
+    DENSE_3_OUT = auto()  # embeddinggemma 3_Dense
     OUTPUT_NORM = auto()
     ROPE_FREQS = auto()
     ROPE_FACTORS_LONG = auto()
@@ -777,6 +781,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.POS_EMBD: "position_embd",
     MODEL_TENSOR.OUTPUT_NORM: "output_norm",
     MODEL_TENSOR.OUTPUT: "output",
+    MODEL_TENSOR.DENSE_2_OUT: "dense_2",  # embeddinggemma 2_Dense
+    MODEL_TENSOR.DENSE_3_OUT: "dense_3",  # embeddinggemma 2_Dense
     MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
     MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
     MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
@@ -1759,6 +1765,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GEMMA_EMBEDDING: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.DENSE_2_OUT,
+        MODEL_TENSOR.DENSE_3_OUT,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_Q_NORM,

gguf-py/gguf/gguf_writer.py

Lines changed: 4 additions & 0 deletions
@@ -730,6 +730,10 @@ def add_shared_kv_layers(self, value: int) -> None:
     def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
         self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)

+    def add_dense_features_dims(self, dense:str, in_f:int, out_f:int) -> None:
+        self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
+        self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f)
+
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
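
add_dense_features_dims expands the two key templates added to constants.py into per-module uint32 metadata. A small sketch of the resulting key names follows; the "gemma-embedding" arch string and the dimensions are assumptions for illustration, since the writer substitutes the model's actual architecture name and the values read from config.json.

    # Sketch of the metadata keys emitted per Dense module (assumes the gguf-py
    # from this commit).
    import gguf

    arch = "gemma-embedding"  # hypothetical arch string; the writer uses self.arch

    for dense, (in_f, out_f) in {"dense_2": (768, 3072), "dense_3": (3072, 768)}.items():
        key_in = gguf.Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=arch, dense=dense)
        key_out = gguf.Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=arch, dense=dense)
        print(f"{key_in} = {in_f}")   # e.g. gemma-embedding.dense_2_feat_in = 768
        print(f"{key_out} = {out_f}")

These are the same "%s.dense_2_feat_in" style keys that llama-arch.cpp registers below and that llama-hparams.h stores as dense_*_feat_in/out.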

gguf-py/gguf/tensor_mapping.py

Lines changed: 6 additions & 1 deletion
@@ -76,7 +76,12 @@ class TensorNameMap:
         "lm_head",                   # llama4
         "model.transformer.ff_out",  # llada
     ),
-
+    MODEL_TENSOR.DENSE_2_OUT: (
+        "dense_2_out",  # embeddinggemma
+    ),
+    MODEL_TENSOR.DENSE_3_OUT: (
+        "dense_3_out",  # embeddinggemma
+    ),
     # Output norm
     MODEL_TENSOR.OUTPUT_NORM: (
         "gpt_neox.final_layer_norm",  # gptneox

src/llama-arch.cpp

Lines changed: 9 additions & 0 deletions
@@ -219,6 +219,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },

     { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
+    // sentence-transformers dense modules feature dims
+    { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" },
+    { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" },
+    { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" },
+    { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" },

     { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
     { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
@@ -1071,6 +1076,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
+           { LLM_TENSOR_DENSE_2_OUT, "dense_2" },
+           { LLM_TENSOR_DENSE_3_OUT, "dense_3" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
@@ -2281,6 +2288,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
+    {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
     {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
     {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
     {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},

src/llama-arch.h

Lines changed: 8 additions & 0 deletions
@@ -271,13 +271,21 @@ enum llm_kv {
     LLM_KV_TOKENIZER_PREFIX_ID,
     LLM_KV_TOKENIZER_SUFFIX_ID,
     LLM_KV_TOKENIZER_MIDDLE_ID,
+
+    // sentence-transformers dense layers in and out features
+    LLM_KV_DENSE_2_FEAT_IN,
+    LLM_KV_DENSE_2_FEAT_OUT,
+    LLM_KV_DENSE_3_FEAT_IN,
+    LLM_KV_DENSE_3_FEAT_OUT,
 };

 enum llm_tensor {
     LLM_TENSOR_TOKEN_EMBD,
     LLM_TENSOR_TOKEN_EMBD_NORM,
     LLM_TENSOR_TOKEN_TYPES,
     LLM_TENSOR_POS_EMBD,
+    LLM_TENSOR_DENSE_2_OUT,
+    LLM_TENSOR_DENSE_3_OUT,
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_NORM,
     LLM_TENSOR_ROPE_FREQS,

src/llama-context.cpp

Lines changed: 6 additions & 0 deletions
@@ -2346,6 +2346,12 @@ llama_context * llama_init_from_model(
         return nullptr;
     }

+    if (params.pooling_type != model->hparams.pooling_type) {
+        //user-specified pooling-type is different from the model default
+        LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
+                       model->hparams.pooling_type, params.pooling_type);
+    }
+
     try {
         auto * ctx = new llama_context(*model, params);
         return ctx;

src/llama-graph.cpp

Lines changed: 17 additions & 0 deletions
@@ -1853,6 +1853,23 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }

+void llm_graph_context::build_dense_out(
+        ggml_tensor * dense_2,
+        ggml_tensor * dense_3) const {
+    if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
+        return;
+    }
+    ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
+    GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
+
+    cur = ggml_mul_mat(ctx0, dense_2, cur);
+    cur = ggml_mul_mat(ctx0, dense_3, cur);
+    cb(cur, "result_embd_pooled", -1);
+    res->t_embd_pooled = cur;
+    ggml_build_forward_expand(gf, cur);
+}
+
+
 void llm_graph_context::build_pooling(
         ggml_tensor * cls,
         ggml_tensor * cls_b,
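
build_dense_out feeds the pooled embedding through the two projection weights with ggml_mul_mat and overwrites t_embd_pooled, so embedding callers receive the projected vector. In numpy terms the computation is roughly the following; this is a sketch only, with the 768/3072 shapes assumed and ggml_mul_mat(W, x) modeled as x @ W.T for a weight stored as [out_features, in_features].

    # Rough numpy equivalent of build_dense_out above (shapes are assumptions).
    import numpy as np

    hidden, inter = 768, 3072
    dense_2 = np.random.randn(inter, hidden).astype(np.float32)  # [out, in]
    dense_3 = np.random.randn(hidden, inter).astype(np.float32)  # [out, in]

    t_embd_pooled = np.random.randn(4, hidden).astype(np.float32)  # 4 pooled sequences

    cur = t_embd_pooled @ dense_2.T  # ggml_mul_mat(ctx0, dense_2, cur)
    cur = cur @ dense_3.T            # ggml_mul_mat(ctx0, dense_3, cur)
    # cur now plays the role of the updated res->t_embd_pooled ("result_embd_pooled")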

src/llama-graph.h

Lines changed: 8 additions & 0 deletions
@@ -814,6 +814,14 @@ struct llm_graph_context {
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
             ggml_tensor * cls_out_b) const;
+
+    //
+    // dense (out)
+    //
+
+    void build_dense_out(
+            ggml_tensor * dense_2,
+            ggml_tensor * dense_3) const;
 };

 // TODO: better name

src/llama-hparams.h

Lines changed: 6 additions & 0 deletions
@@ -169,6 +169,12 @@ struct llama_hparams {
     uint32_t laurel_rank = 64;
     uint32_t n_embd_altup = 256;

+    // needed for sentence-transformers dense layers
+    uint32_t dense_2_feat_in = 0;  // in_features of the 2_Dense
+    uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense
+    uint32_t dense_3_feat_in = 0;  // in_features of the 3_Dense
+    uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense
+
     // xIELU
     std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_n;
     std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_p;
