Skip to content

Commit

Permalink
add refact model
Browse files Browse the repository at this point in the history
  • Loading branch information
ds5t5 committed Sep 25, 2023
1 parent c091cdf commit 81c00c2
Show file tree
Hide file tree
Showing 3 changed files with 659 additions and 6 deletions.
269 changes: 269 additions & 0 deletions convert-refact-hf-to-gguf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
#!/usr/bin/env python3
# HF falcon--> gguf conversion

from __future__ import annotations

import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any

import numpy as np
import torch
from transformers import AutoTokenizer # type: ignore[import]

if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf


def bytes_to_unicode():
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
return dict(zip(bs, (chr(n) for n in cs)))


def count_model_parts(dir_model: Path) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
if filename.startswith("pytorch_model-"):
num_parts += 1

if num_parts > 0:
print("gguf: found " + str(num_parts) + " model parts")
return num_parts


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a Refact model to a GGML compatible file")
parser.add_argument(
"--vocab-only", action="store_true",
help="extract only the vocab",
)
parser.add_argument(
"--outfile", type=Path,
help="path to write to; default: based on input",
)
parser.add_argument(
"model", type=Path,
help="directory containing model file, or model file itself (*.bin)",
)
parser.add_argument(
"ftype", type=int, choices=[0, 1], default=1, nargs='?',
help="output format - use 0 for float32, 1 for float16",
)
return parser.parse_args()

args = parse_args()

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():

print(f'Error: {args.model} is not a directory', file = sys.stderr)
sys.exit(1)

# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16

# map from ftype to string
ftype_str = ["f32", "f16"]

if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model "+dir_model.name)

with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)

if hparams["architectures"][0] != "GPTRefactForCausalLM":
print("Model architecture not supported: " + hparams["architectures"][0])

sys.exit(1)

# get number of model parts
num_parts = count_model_parts(dir_model)

ARCH=gguf.MODEL_ARCH.REFACT
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

print("gguf: get model metadata")

# Get refact feed forward dimension
hidden_dim = hparams["n_embd"]
inner_dim = 4 * hidden_dim
hidden_dim = int(2 * inner_dim / 3)
multiple_of = 256
ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

block_count = hparams["n_layer"]

gguf_writer.add_name("Refact")
# refact uses Alibi. So this is from config.json which might be used by training.
gguf_writer.add_context_length(hparams["n_positions"])
gguf_writer.add_embedding_length(hparams["n_embd"])

gguf_writer.add_feed_forward_length(ff_dim)
gguf_writer.add_block_count(block_count)
gguf_writer.add_head_count(hparams["n_head"])
gguf_writer.add_head_count_kv(1)
gguf_writer.add_layer_norm_rms_eps(hparams["layer_norm_epsilon"])
gguf_writer.add_file_type(ftype)

# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: list[bytearray] = []
scores: list[float] = []
toktypes: list[int] = []

tokenizer_json_file = dir_model / 'tokenizer.json'
if not tokenizer_json_file.is_file():
print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
sys.exit(1)

# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")

with open(tokenizer_json_file, "r", encoding="utf-8") as f:
tokenizer_json = json.load(f)

print("gguf: get gpt2 tokenizer vocab")

# The number of tokens in tokenizer.json can differ from the expected vocab size.
# This causes downstream issues with mismatched tensor sizes when running the inference
vocab_size = hparams["vocab_size"] if "vocab_size" in hparams else len(tokenizer_json["model"]["vocab"])

tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)

reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

for i in range(vocab_size):
if i in reverse_vocab:
text = reverse_vocab[i]
try:
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
except KeyError:
text = bytearray()
for c in reverse_vocab[i]:
if ord(c) < 256: # single byte character
text.append(byte_decoder[ord(c)])
else: # multibyte special token character
text.extend(c.encode('utf-8'))
else:
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
pad_token = f"[PAD{i}]".encode("utf8")
text = bytearray(pad_token)

tokens.append(text)
scores.append(0.0) # dymmy
toktypes.append(gguf.TokenType.NORMAL) # dummy

gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

tensor_map = gguf.get_tensor_name_map(ARCH,block_count)

# params for qkv transform
n_head = hparams["n_head"]
n_head_kv = 1

head_dim = hparams["n_embd"] // n_head

# tensor info
print("gguf: get tensor metadata")

if num_parts == 0:
part_names = iter(("pytorch_model.bin",))
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
)
for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(dir_model / part_name, map_location="cpu")

for name in model_part.keys():
data = model_part[name]

old_dtype = data.dtype

# convert any unsupported data types to float32
if data.dtype != torch.float16 and data.dtype != torch.float32:
data = data.to(torch.float32)

data = data.squeeze().numpy()

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ))
if new_name is None:
print("Can not map tensor '" + name + "'")
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

gguf_writer.add_tensor(new_name, data)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()

gguf_writer.close()

print(f"gguf: model successfully exported to '{fname_out}'")
print("")
32 changes: 27 additions & 5 deletions gguf-py/gguf/gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class MODEL_ARCH(IntEnum):
GPTNEOX : int = auto()
MPT : int = auto()
STARCODER : int = auto()
REFACT : int = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -116,6 +117,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GPTNEOX: "gptneox",
MODEL_ARCH.MPT: "mpt",
MODEL_ARCH.STARCODER: "starcoder",
MODEL_ARCH.REFACT: "refact",
}

MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
Expand Down Expand Up @@ -185,6 +187,20 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
},
MODEL_ARCH.REFACT: {
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
},
MODEL_ARCH.GPT2: {
# TODO
},
Expand All @@ -209,7 +225,7 @@ class TensorNameMap:
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
"transformer.wte", # gpt2 mpt
"transformer.wte", # gpt2 mpt refact
"transformer.word_embeddings", # falcon
"model.embed_tokens", # llama-hf
"tok_embeddings", # llama-pth
Expand All @@ -233,6 +249,7 @@ class TensorNameMap:
"transformer.ln_f", # gpt2 falcon
"model.norm", # llama-hf baichuan
"norm", # llama-pth
"ln_f", # refact
),

# Rope frequencies
Expand All @@ -245,7 +262,7 @@ class TensorNameMap:
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
"transformer.h.{bid}.ln_1", # gpt2
"transformer.h.{bid}.ln_1", # gpt2 refact
"transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b
"transformer.h.{bid}.ln_mlp", # falcon40b
Expand All @@ -269,25 +286,28 @@ class TensorNameMap:
# Attention query
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf
"transformer.h.{bid}.attn.q", # refact
"layers.{bid}.attention.wq", # llama-pth
),

# Attention key
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf
"transformer.h.{bid}.attn.k", # refact
"layers.{bid}.attention.wk", # llama-pth
),

# Attention value
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf
"transformer.h.{bid}.attn.v", # refact
"layers.{bid}.attention.wv", # llama-pth
),

# Attention output
MODEL_TENSOR.ATTN_OUT: (
"gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.attn.c_proj", # gpt2 refact
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"model.layers.{bid}.self_attn.o_proj", # llama-hf
Expand All @@ -303,7 +323,7 @@ class TensorNameMap:
# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
"transformer.h.{bid}.ln_2", # gpt2
"transformer.h.{bid}.ln_2", # gpt2 refact
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf
"layers.{bid}.ffn_norm", # llama-pth
Expand All @@ -317,18 +337,20 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
"model.layers.{bid}.mlp.up_proj", # llama-hf
"layers.{bid}.feed_forward.w3", # llama-pth
"transformer.h.{bid}.mlp.linear_3", # refact
),

# Feed-forward gate
MODEL_TENSOR.FFN_GATE: (
"model.layers.{bid}.mlp.gate_proj", # llama-hf
"transformer.h.{bid}.mlp.linear_1", # refact
"layers.{bid}.feed_forward.w1", # llama-pth
),

# Feed-forward down
MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
"transformer.h.{bid}.mlp.c_proj", # gpt2
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact
"transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"model.layers.{bid}.mlp.down_proj", # llama-hf
Expand Down
Loading

0 comments on commit 81c00c2

Please sign in to comment.