Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nomic Embed Text V2 with Mixture-of-Experts (MoE) architecture #12466

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 53 additions & 16 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,6 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)

@classmethod
def __init_subclass__(cls):
# can't use an abstract property, because overriding it without type errors
# would require using decorated functions instead of simply defining the property
if "model_arch" not in cls.__dict__:
raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")

def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
key = next((k for k in keys if k in self.hparams), None)
if key is not None:
Expand Down Expand Up @@ -702,6 +695,8 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "ccc2ef013c104be7bae2965776d611e1d7a8a2a9c547dd93a682c9a9fc80352e":
# ref: https://huggingface.co/Xenova/gpt-4o
res = "gpt-4o"
if chkhsh == "a81863d07e75497e2194eb1a1574d5e5cd4d5f85a87a0728b922bf2bed6fb327":
res = "bert"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The newly added tokenizer is nomic-embed-text-v2-moe and not bert, is this expected?

And also this list is auto-generated, please make sure not to modify it manually


if res is None:
logger.warning("\n")
Expand Down Expand Up @@ -3141,32 +3136,74 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

@Model.register("NomicBertModel")
class NomicBertModel(BertModel):
model_arch = gguf.MODEL_ARCH.NOMIC_BERT

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None):

hparams = Model.load_hparams(dir_model)

if (hparams is None or hparams.get("moe_every_n_layers", 0) < 1):
self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT
else:
self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE

super().__init__(dir_model, ftype, fname_out, is_big_endian, use_temp_file,
eager, metadata_override, model_name, split_max_tensors,
split_max_size, dry_run, small_first_shard, hparams)

# the HF config claims n_ctx=8192, but it uses RoPE scaling
self.hparams["n_ctx"] = 2048

# SwigLU activation
assert self.hparams["activation_function"] == "swiglu"
if (self.hparams.get("moe_every_n_layers", 0) < 1):
assert self.hparams["activation_function"] == "swiglu"
else:
assert self.hparams["activation_function"] == "gelu"

# this doesn't do anything in the HF version
assert self.hparams["causal"] is False
# no bias tensors
assert self.hparams["qkv_proj_bias"] is False
assert self.hparams["mlp_fc1_bias"] is False
assert self.hparams["mlp_fc2_bias"] is False

if (self.hparams.get("moe_every_n_layers", 0) < 1):
assert self.hparams["qkv_proj_bias"] is False
assert self.hparams["mlp_fc1_bias"] is False
assert self.hparams["mlp_fc2_bias"] is False
else:
assert self.hparams["qkv_proj_bias"] is True
assert self.hparams["mlp_fc1_bias"] is True
assert self.hparams["mlp_fc2_bias"] is True

# norm at end of layer
assert self.hparams["prenorm"] is False
# standard RoPE
assert self.hparams["rotary_emb_fraction"] == 1.0
assert self.hparams["rotary_emb_interleaved"] is False
assert self.hparams["rotary_emb_scale_base"] is None

def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[tuple[str, torch.Tensor]]:
# If the tensor is an experts bias tensor, skip it by returning an empty list.
if "mlp.experts.bias" in name:
return [] # Explicitly return an empty list.

if "mlp.experts.mlp.w1" in name:
data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
return [(self.map_tensor_name(name) + ".weight", data_torch)]
Copy link
Collaborator

@ngxson ngxson Mar 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this work? (no need to return here)

map_tensor_name will append .weight if the given original name also have it

Suggested change
return [(self.map_tensor_name(name) + ".weight", data_torch)]
name += ".weight"


if "mlp.experts.mlp.w2" in name:
data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
data_torch = data_torch.transpose(1, 2)
return [(self.map_tensor_name(name) + ".weight", data_torch)]

return [(self.map_tensor_name(name), data_torch)]

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"])
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
if (self.hparams.get("moe_every_n_layers", 0) > 0):
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])


@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
Expand Down
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
{"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"},
{"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", },
{"name": "nomic-embed-text-v2-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe", },
]


Expand Down
20 changes: 19 additions & 1 deletion gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class LLM:
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
POOLING_TYPE = "{arch}.pooling_type"
LOGIT_SCALE = "{arch}.logit_scale"
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
Expand Down Expand Up @@ -239,6 +240,7 @@ class MODEL_ARCH(IntEnum):
REFACT = auto()
BERT = auto()
NOMIC_BERT = auto()
NOMIC_BERT_MOE = auto()
JINA_BERT_V2 = auto()
BLOOM = auto()
STABLELM = auto()
Expand Down Expand Up @@ -441,6 +443,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.REFACT: "refact",
MODEL_ARCH.BERT: "bert",
MODEL_ARCH.NOMIC_BERT: "nomic-bert",
MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
Expand Down Expand Up @@ -768,11 +771,26 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.NOMIC_BERT_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.TOKEN_TYPES,
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.JINA_BERT_V2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,9 @@ def add_expert_weights_norm(self, value: bool) -> None:
def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)

def add_moe_every_n_layers(self, value: int) -> None:
self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)

def add_swin_norm(self, value: bool) -> None:
self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)

Expand Down
7 changes: 5 additions & 2 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
),

MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
Expand Down Expand Up @@ -308,7 +309,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate_up_proj", # phi3
"model.layers.layers.{bid}.mlp.up_proj", # plamo
"model.layers.{bid}.feed_forward.w3", # internlm2
"encoder.layers.{bid}.mlp.fc11", # nomic-bert
"encoder.layers.{bid}.mlp.fc1", # nomic-bert
"model.layers.{bid}.mlp.c_fc", # starcoder2
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic
Expand All @@ -322,6 +323,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
),

MODEL_TENSOR.FFN_UP_SHEXP: (
Expand All @@ -342,7 +344,6 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.c_fc2", # jais
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
"model.layers.{bid}.feed_forward.w1", # internlm2
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
"transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic
Expand Down Expand Up @@ -388,6 +389,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
),

MODEL_TENSOR.FFN_DOWN_EXP: (
Expand All @@ -397,6 +399,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
),

MODEL_TENSOR.FFN_DOWN_SHEXP: (
Expand Down
20 changes: 20 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_REFACT, "refact" },
{ LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
Expand Down Expand Up @@ -99,6 +100,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
Expand Down Expand Up @@ -433,6 +435,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_NOMIC_BERT_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_JINA_BERT_V2,
{
Expand Down
2 changes: 2 additions & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ enum llm_arch {
LLM_ARCH_REFACT,
LLM_ARCH_BERT,
LLM_ARCH_NOMIC_BERT,
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
Expand Down Expand Up @@ -103,6 +104,7 @@ enum llm_kv {
LLM_KV_EXPERT_WEIGHTS_SCALE,
LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC,
LLM_KV_MOE_EVERY_N_LAYERS,
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
Expand Down
1 change: 1 addition & 0 deletions src/llama-hparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ struct llama_hparams {
float expert_weights_scale = 0.0;
bool expert_weights_norm = false;
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
uint32_t moe_every_n_layers = 0;

float f_norm_eps;
float f_norm_rms_eps;
Expand Down
30 changes: 23 additions & 7 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -667,10 +667,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
} break;
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers);

if (hparams.n_layer == 12 && hparams.n_embd == 768) {
type = LLM_TYPE_137M;
Expand Down Expand Up @@ -1911,6 +1913,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_BERT:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
Expand Down Expand Up @@ -1944,20 +1947,31 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
}

if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
}

layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);

layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0);

layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);

if (arch == LLM_ARCH_BERT) {
if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) {
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
} else {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);

if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) {
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
} else {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
}
}

layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
Expand Down Expand Up @@ -11650,6 +11664,7 @@ llm_graph_result_ptr llama_model::build_graph(
case LLM_ARCH_BERT:
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
{
llm = std::make_unique<llm_build_bert>(*this, params, gf);
} break;
Expand Down Expand Up @@ -11982,6 +11997,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_DBRX:
case LLM_ARCH_BERT:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_STABLELM:
case LLM_ARCH_BITNET:
case LLM_ARCH_QWEN:
Expand Down
Loading