
Commit 2891c8a

iamlemec, cebtenzzre, and ggerganov authored
Add support for BERT embedding models (#5423)
* BERT model graph construction (build_bert)
* WordPiece tokenizer (llm_tokenize_wpm)
* Add flag for non-causal attention models
* Allow for models that only output embeddings
* Support conversion of BERT models to GGUF
* Based on prior work by @xyzhang626 and @skeskinen

---------

Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 97a3365 commit 2891c8a

8 files changed: +616, -52 lines changed


.flake8 (+1)

@@ -1,2 +1,3 @@
 [flake8]
 max-line-length = 125
+ignore = W503

convert-hf-to-gguf.py (+94)

@@ -209,6 +209,8 @@ def from_model_architecture(model_architecture):
             return InternLM2Model
         if model_architecture == "MiniCPMForCausalLM":
             return MiniCPMModel
+        if model_architecture == "BertModel":
+            return BertModel
         return Model

     def _is_model_safetensors(self) -> bool:
@@ -264,6 +266,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.INTERNLM2
         if arch == "MiniCPMForCausalLM":
             return gguf.MODEL_ARCH.MINICPM
+        if arch == "BertModel":
+            return gguf.MODEL_ARCH.BERT

         raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -1629,6 +1633,96 @@ def write_tensors(self):
         self.post_write_tensors(tensor_map, name, data_torch)


+class BertModel(Model):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.block_count = self.hparams["num_hidden_layers"]
+
+    def set_gguf_parameters(self):
+        # TODO(cebtenzzre): merge with parent class
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
+        self.gguf_writer.add_causal_attention(False)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def set_vocab(self):
+        path = self.dir_model
+        added_tokens_path = self.dir_model if self.dir_model.exists() else None
+
+        # use huggingface vocab to get all tokens
+        vocab = HfVocab(path, added_tokens_path)
+        tokens, scores, toktypes = zip(*vocab.all_tokens())
+        assert len(tokens) == vocab.vocab_size
+
+        # we need this to validate the size of the token_type embeddings
+        # though currently we are passing all zeros to the token_type embeddings
+        n_token_types = len(set(toktypes))
+        self.gguf_writer.add_token_type_count(n_token_types)
+
+        # convert to phantom space vocab
+        def phantom(tok, typ):
+            if tok.startswith(b"[") and tok.endswith(b"]"):
+                return tok
+            if tok.startswith(b"##"):
+                return tok[2:]
+            return b"\xe2\x96\x81" + tok
+        tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)]
+
+        # set up bos and eos tokens (cls and sep)
+        self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
+        self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
+
+        # add vocab to gguf
+        self.gguf_writer.add_tokenizer_model("bert")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # handle special tokens
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def write_tensors(self):
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+        tensors = dict(self.get_tensors())
+        for name, data_torch in tensors.items():
+            # we are only using BERT for embeddings so we don't need the pooling layer
+            if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
+                continue  # we don't need these
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            data = data_torch.squeeze().numpy()
+            n_dims = len(data.shape)
+            new_dtype: type[np.floating[Any]]
+
+            if (
+                self.ftype == 1 and name.endswith(".weight") and n_dims == 2
+                and name != "embeddings.token_type_embeddings.weight"  # not used with get_rows, must be F32
+            ):
+                # if f16 desired, convert any float32 2-dim weight tensors to float16
+                new_dtype = np.float16
+            else:
+                # if f32 desired, convert any float16 to float32
+                new_dtype = np.float32
+
+            print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")
+
+            if data.dtype != new_dtype:
+                data = data.astype(new_dtype)
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 ###### CONVERSION LOGIC ######

examples/embedding/embedding.cpp (+11, -1)

@@ -87,7 +87,17 @@ int main(int argc, char ** argv) {
     }

     const int n_embd = llama_n_embd(model);
-    const auto * embeddings = llama_get_embeddings(ctx);
+    auto * embeddings = llama_get_embeddings(ctx);
+
+    // l2-normalize embeddings
+    float norm = 0;
+    for (int i = 0; i < n_embd; i++) {
+        norm += embeddings[i] * embeddings[i];
+    }
+    norm = sqrt(norm);
+    for (int i = 0; i < n_embd; i++) {
+        embeddings[i] /= norm;
+    }

     for (int i = 0; i < n_embd; i++) {
         printf("%f ", embeddings[i]);
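The added loop L2-normalizes the embedding so its length is 1, which lets a plain dot product between two embeddings serve as their cosine similarity. For reference, a NumPy equivalent of the same computation (illustrative only, not part of the commit):

import numpy as np

embeddings = np.array([3.0, 4.0])        # toy 2-dimensional embedding
norm = np.sqrt(np.sum(embeddings ** 2))  # sqrt(3^2 + 4^2) = 5
embeddings = embeddings / norm           # [0.6, 0.8], now unit length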

gguf-py/gguf/constants.py (+25, -18)

@@ -50,6 +50,7 @@ class Attention:
         VALUE_LENGTH = "{arch}.attention.value_length"
         LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
         LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
+        CAUSAL = "{arch}.attention.causal"

     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -60,22 +61,23 @@ class Rope:
         SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"

     class Tokenizer:
-        MODEL = "tokenizer.ggml.model"
-        LIST = "tokenizer.ggml.tokens"
-        TOKEN_TYPE = "tokenizer.ggml.token_type"
-        SCORES = "tokenizer.ggml.scores"
-        MERGES = "tokenizer.ggml.merges"
-        BOS_ID = "tokenizer.ggml.bos_token_id"
-        EOS_ID = "tokenizer.ggml.eos_token_id"
-        UNK_ID = "tokenizer.ggml.unknown_token_id"
-        SEP_ID = "tokenizer.ggml.seperator_token_id"
-        PAD_ID = "tokenizer.ggml.padding_token_id"
-        ADD_BOS = "tokenizer.ggml.add_bos_token"
-        ADD_EOS = "tokenizer.ggml.add_eos_token"
-        ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
-        HF_JSON = "tokenizer.huggingface.json"
-        RWKV = "tokenizer.rwkv.world"
-        CHAT_TEMPLATE = "tokenizer.chat_template"
+        MODEL            = "tokenizer.ggml.model"
+        LIST             = "tokenizer.ggml.tokens"
+        TOKEN_TYPE       = "tokenizer.ggml.token_type"
+        TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count"  # for BERT-style token types
+        SCORES           = "tokenizer.ggml.scores"
+        MERGES           = "tokenizer.ggml.merges"
+        BOS_ID           = "tokenizer.ggml.bos_token_id"
+        EOS_ID           = "tokenizer.ggml.eos_token_id"
+        UNK_ID           = "tokenizer.ggml.unknown_token_id"
+        SEP_ID           = "tokenizer.ggml.seperator_token_id"
+        PAD_ID           = "tokenizer.ggml.padding_token_id"
+        ADD_BOS          = "tokenizer.ggml.add_bos_token"
+        ADD_EOS          = "tokenizer.ggml.add_eos_token"
+        ADD_PREFIX       = "tokenizer.ggml.add_space_prefix"
+        HF_JSON          = "tokenizer.huggingface.json"
+        RWKV             = "tokenizer.rwkv.world"
+        CHAT_TEMPLATE    = "tokenizer.chat_template"


 #
@@ -122,6 +124,7 @@ class MODEL_TENSOR(IntEnum):
     ATTN_OUT = auto()
     ATTN_NORM = auto()
     ATTN_NORM_2 = auto()
+    ATTN_OUT_NORM = auto()
     ATTN_ROT_EMBD = auto()
     FFN_GATE_INP = auto()
     FFN_NORM = auto()
@@ -134,6 +137,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_UP_EXP = auto()
     ATTN_Q_NORM = auto()
     ATTN_K_NORM = auto()
+    LAYER_OUT_NORM = auto()


 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -178,6 +182,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
     MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
     MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
+    MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
     MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
     MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
     MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
@@ -187,6 +192,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}",
     MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}",
     MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}",
+    MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
 }

 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -262,17 +268,18 @@ class MODEL_TENSOR(IntEnum):
     ],
     MODEL_ARCH.BERT: [
         MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
         MODEL_TENSOR.TOKEN_TYPES,
         MODEL_TENSOR.POS_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_OUT_NORM,
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_NORM,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
     ],
     MODEL_ARCH.MPT: [
         MODEL_TENSOR.TOKEN_EMBD,
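Note: BERT applies its layer norms after the attention output and after the feed-forward output (post-layernorm), which is why the BERT tensor list above swaps the pre-norm ATTN_NORM/FFN_NORM slots for the new ATTN_OUT_NORM and LAYER_OUT_NORM entries. A small illustrative snippet (not part of the commit) showing how the two new name templates expand for block 0:

templates = {
    "ATTN_OUT_NORM": "blk.{bid}.attn_output_norm",
    "LAYER_OUT_NORM": "blk.{bid}.layer_output_norm",
}
for tensor, fmt in templates.items():
    print(tensor, "->", fmt.format(bid=0))
# ATTN_OUT_NORM -> blk.0.attn_output_norm
# LAYER_OUT_NORM -> blk.0.layer_output_norm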

gguf-py/gguf/gguf_writer.py (+6)

@@ -357,6 +357,9 @@ def add_layer_norm_eps(self, value: float) -> None:
     def add_layer_norm_rms_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)

+    def add_causal_attention(self, value: bool) -> None:
+        self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
+
     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)

@@ -387,6 +390,9 @@ def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
     def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
         self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)

+    def add_token_type_count(self, value: int) -> None:
+        self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)
+
     def add_token_scores(self, scores: Sequence[float]) -> None:
         self.add_array(Keys.Tokenizer.SCORES, scores)

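A hypothetical usage sketch (the file name, architecture string, and values are illustrative, not from the commit) of how a converter records the two new metadata keys through these helpers:

import gguf

# stage the new key/value pairs in a writer (illustrative values)
writer = gguf.GGUFWriter("bert-embedding.gguf", arch="bert")
writer.add_causal_attention(False)  # key "bert.attention.causal"
writer.add_token_type_count(2)      # key "tokenizer.ggml.token_type_count"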

gguf-py/gguf/tensor_mapping.py (+10, -3)

@@ -30,6 +30,7 @@ class TensorNameMap:
         # Normalization of token embeddings
         MODEL_TENSOR.TOKEN_EMBD_NORM: (
             "word_embeddings_layernorm",  # bloom
+            "embeddings.LayerNorm",  # bert
         ),

         # Position embeddings
@@ -54,7 +55,6 @@ class TensorNameMap:
             "transformer.ln_f",  # gpt2 gpt-j falcon
             "model.norm",  # llama-hf baichuan internlm2
             "norm",  # llama-pth
-            "embeddings.LayerNorm",  # bert
             "transformer.norm_f",  # mpt
             "ln_f",  # refact bloom qwen gpt2
             "language_model.encoder.final_layernorm",  # persimmon
@@ -79,7 +79,6 @@ class TensorNameMap:
             "transformer.h.{bid}.ln_mlp",  # falcon40b
             "model.layers.{bid}.input_layernorm",  # llama-hf
             "layers.{bid}.attention_norm",  # llama-pth
-            "encoder.layer.{bid}.attention.output.LayerNorm",  # bert
             "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
             "model.layers.{bid}.ln1",  # yi
             "h.{bid}.ln_1",  # gpt2
@@ -155,6 +154,11 @@ class TensorNameMap:
             "model.layers.{bid}.attention.wo",  # internlm2
         ),

+        # Attention output norm
+        MODEL_TENSOR.ATTN_OUT_NORM: (
+            "encoder.layer.{bid}.attention.output.LayerNorm",  # bert
+        ),
+
         # Rotary embeddings
         MODEL_TENSOR.ATTN_ROT_EMBD: (
             "model.layers.{bid}.self_attn.rotary_emb.inv_freq",  # llama-hf
@@ -171,7 +175,6 @@ class TensorNameMap:
             "transformer.blocks.{bid}.norm_2",  # mpt
             "model.layers.{bid}.post_attention_layernorm",  # llama-hf
             "layers.{bid}.ffn_norm",  # llama-pth
-            "encoder.layer.{bid}.output.LayerNorm",  # bert
             "language_model.encoder.layers.{bid}.post_attention_layernorm",  # persimmon
             "model.layers.{bid}.ln2",  # yi
             "h.{bid}.ln_2",  # gpt2
@@ -266,6 +269,10 @@ class TensorNameMap:
         MODEL_TENSOR.ROPE_FREQS: (
             "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq",  # persimmon
         ),
+
+        MODEL_TENSOR.LAYER_OUT_NORM: (
+            "encoder.layer.{bid}.output.LayerNorm",  # bert
+        )
     }

     mapping: dict[str, tuple[MODEL_TENSOR, str]]
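A small illustrative sketch (not part of the commit) of how the remapped BERT entries resolve a Hugging Face tensor name, mirroring the try_suffixes lookup done in BertModel.write_tensors above:

import gguf

# 12 encoder blocks, roughly a bert-base-sized model (block count is an assumption)
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.BERT, 12)
name = tmap.get_name("encoder.layer.0.attention.output.LayerNorm.weight",
                     try_suffixes=(".weight", ".bias"))
print(name)  # expected: blk.0.attn_output_norm.weight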
