Add support for DeepseekV2ForCausalLM #7519

Merged 30 commits on May 28, 2024.

The diff below shows changes from 17 of the 30 commits.

Commits:
c8c353f Added initial support for DeepseekV2ForCausalLM. (sszymczy, May 16, 2024)
b24c9ed Merge branch 'ggerganov:master' into deepseek-v2 (fairydreaming, May 17, 2024)
0398964 Removed unnecessary tensor operations. (sszymczy, May 18, 2024)
b50c07c Added five new DeepSeek-V2-specific parameters: (sszymczy, May 18, 2024)
79f8417 Added initial support for DeepSeek-V2-Lite model. (sszymczy, May 18, 2024)
6050941 Corrected mscale calculation. (sszymczy, May 18, 2024)
7e4786b Added expert_weights_scale parameter for scaling MoE gate weights. (sszymczy, May 19, 2024)
71a7422 Temporarily hard-coded mscale value for DeepSeek-V2 (FIXME!). (sszymczy, May 19, 2024)
f99df46 Replaced hardcoded mscale value with rescaling attn_factor that resul… (sszymczy, May 19, 2024)
3ae7235 Whitespace formatting fixes. (sszymczy, May 19, 2024)
68a5103 Referenced the relevant GitHub discussion instead of providing long c… (sszymczy, May 20, 2024)
7be56da Added YaRN log multiplier model header parameter corresponding to the… (sszymczy, May 20, 2024)
842ff3f Added 16B and 236B model types for DeepSeek-V2. (sszymczy, May 21, 2024)
c033958 Removed usage of output bias tensor since it's not present in DeepSee… (sszymczy, May 21, 2024)
a54685b Merge remote-tracking branch 'upstream/master' into deepseek-v2 (sszymczy, May 24, 2024)
bb9c361 gguf-py : re-add SCALING_YARN_LOG_MUL removed during merge by accident (sszymczy, May 24, 2024)
f3b5e7d llama : correct llm_build_moe_ffn() arguments in build_arctic() (sszymczy, May 26, 2024)
abef8b2 llama : code style corrections (sszymczy, May 27, 2024)
a654cd9 llama : rename n_expert_ff to n_ff_exp (sszymczy, May 27, 2024)
5a3e6b6 llama : rename qk_rope_head_dim, qk_nope_head_dim variables to n_embd… (sszymczy, May 27, 2024)
20769c0 llama : remove trailing whitespaces (sszymczy, May 27, 2024)
fac1e80 llama : rename moe_intermediate_size variable to n_ff_exp (sszymczy, May 27, 2024)
56f7011 llama : rename n_leading_dense_layer to n_layer_dense_lead (sszymczy, May 27, 2024)
82cec8b llama : use attn_factor in mscale calculation to match the rope_yarn(… (sszymczy, May 27, 2024)
5cc7ec1 llama : rename query_states, key_states, value_states to q_states, k_… (sszymczy, May 27, 2024)
d02130d llama : print DeepSeek-V2-specific parameters in llm_load_print_meta() (sszymczy, May 27, 2024)
bde971a convert-hf : fix flake8 Lint errors (sszymczy, May 27, 2024)
98ff6e1 Merge remote-tracking branch 'upstream/master' into deepseek-v2 (sszymczy, May 28, 2024)
841cd47 llama : replace ggml_new_tensor_3d + ggml_set_inplace + ggml_set_inpl… (sszymczy, May 28, 2024)
3efb659 gguf-py, llama : whitespace formatting fixes (sszymczy, May 28, 2024)
convert-hf-to-gguf.py (+82, -0)

@@ -2617,6 +2617,88 @@ def write_tensors(self):
                raise ValueError(f"Unprocessed experts: {experts}")


@Model.register("DeepseekV2ForCausalLM")
class DeepseekV2Model(Model):
    model_arch = gguf.MODEL_ARCH.DEEPSEEK2

    def set_vocab(self):
        self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        hparams = self.hparams

        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
        if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
            self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
        self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
        self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
        self.gguf_writer.add_value_length(hparams["v_head_dim"])
        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
        self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
        self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
        self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])

        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
            if self.hparams["rope_scaling"].get("type") == "yarn":
                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
                self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
                self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * hparams["rope_scaling"]["mscale_all_dim"])
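                # DeepSeek-V2's yarn_get_mscale() evaluates to
                # 0.1 * mscale_all_dim * ln(scale) + 1.0, so writing the combined
                # coefficient 0.1 * mscale_all_dim as a generic YaRN log multiplier
                # lets the runtime recover mscale as 1.0 + log_mul * ln(scale).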

    _experts: list[dict[str, Tensor]] | None = None

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        n_head = self.hparams["num_attention_heads"]
        n_kv_head = self.hparams.get("num_key_value_heads")

        # process the experts separately
        if name.find("mlp.experts") != -1:
            n_experts = self.hparams["n_routed_experts"]
            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:
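                # n_experts * 3 tensors collected: down_proj, gate_proj and
                # up_proj have arrived for every expert in this block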
                tensors: list[tuple[str, Tensor]] = []

                # merge the experts into a single 3d tensor
                for w_name in ["down_proj", "gate_proj", "up_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
                        datas.append(self._experts[bid][ename])
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)

                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

                    new_name = self.map_tensor_name(merged_name)

                    tensors.append((new_name, data_torch))
                return tensors
            else:
                return []

        return [(self.map_tensor_name(name), data_torch)]

    def write_tensors(self):
        super().write_tensors()

        if self._experts is not None:
            # flatten `list[dict[str, Tensor]]` into `list[str]`
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")


###### CONVERSION LOGIC ######


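As an aside, the expert merge in modify_tensors() above turns n_experts separate 2D matrices into one 3D tensor per projection, which is the layout stored in the GGUF file. A minimal standalone sketch of the shape transformation (sizes here are hypothetical, chosen only for illustration):

import torch

n_experts, n_ff_exp, n_embd = 4, 8, 16

# one [n_ff_exp, n_embd] weight per routed expert, as in the HF checkpoint
per_expert = [torch.randn(n_ff_exp, n_embd) for _ in range(n_experts)]

# stacking along dim 0 yields the single [n_experts, n_ff_exp, n_embd] tensor
# stored as "model.layers.{bid}.mlp.experts.{w_name}.weight"
merged = torch.stack(per_expert, dim=0)
assert merged.shape == (n_experts, n_ff_exp, n_embd)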
gguf-py/gguf/constants.py (+52, -0)

@@ -37,11 +37,15 @@ class LLM:
        CONTEXT_LENGTH = "{arch}.context_length"
        EMBEDDING_LENGTH = "{arch}.embedding_length"
        BLOCK_COUNT = "{arch}.block_count"
        LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
        FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
        EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
        USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
        TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
        EXPERT_COUNT = "{arch}.expert_count"
        EXPERT_USED_COUNT = "{arch}.expert_used_count"
        EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
        EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
        POOLING_TYPE = "{arch}.pooling_type"
        LOGIT_SCALE = "{arch}.logit_scale"

@@ -55,6 +59,8 @@ class Attention:
        LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
        LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
        CAUSAL = "{arch}.attention.causal"
        Q_LORA_RANK = "{arch}.attention.q_lora_rank"
        KV_LORA_RANK = "{arch}.attention.kv_lora_rank"

    class Rope:
        DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -64,6 +70,7 @@ class Rope:
        SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
        SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
        SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
        SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"

    class SSM:
        CONV_KERNEL = "{arch}.ssm.conv_kernel"
@@ -140,6 +147,7 @@ class MODEL_ARCH(IntEnum):
    DBRX = auto()
    OLMO = auto()
    ARCTIC = auto()
    DEEPSEEK2 = auto()


class MODEL_TENSOR(IntEnum):
@@ -185,6 +193,12 @@ class MODEL_TENSOR(IntEnum):
    SSM_A = auto()
    SSM_D = auto()
    SSM_OUT = auto()
    ATTN_Q_A = auto()
    ATTN_Q_B = auto()
    ATTN_KV_A_MQA = auto()
    ATTN_KV_B = auto()
    ATTN_Q_A_NORM = auto()
    ATTN_KV_A_NORM = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -221,6 +235,7 @@ class MODEL_TENSOR(IntEnum):
    MODEL_ARCH.DBRX: "dbrx",
    MODEL_ARCH.OLMO: "olmo",
    MODEL_ARCH.ARCTIC: "arctic",
    MODEL_ARCH.DEEPSEEK2: "deepseek2",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -266,6 +281,12 @@ class MODEL_TENSOR(IntEnum):
    MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
    MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
    MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
    MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
    MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
    MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
    MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
    MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
    MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -757,6 +778,33 @@ class MODEL_TENSOR(IntEnum):
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
    ],
    MODEL_ARCH.DEEPSEEK2: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_Q_A,
        MODEL_TENSOR.ATTN_Q_B,
        MODEL_TENSOR.ATTN_KV_A_MQA,
        MODEL_TENSOR.ATTN_KV_B,
        MODEL_TENSOR.ATTN_Q_A_NORM,
        MODEL_TENSOR.ATTN_KV_A_NORM,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.ATTN_ROT_EMBD,
        MODEL_TENSOR.FFN_GATE_INP,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.FFN_GATE_EXP,
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
        MODEL_TENSOR.FFN_GATE_SHEXP,
        MODEL_TENSOR.FFN_DOWN_SHEXP,
        MODEL_TENSOR.FFN_UP_SHEXP,
    ],
    # TODO
}
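For context, the ATTN_Q_A/ATTN_Q_B and ATTN_KV_A_MQA/ATTN_KV_B pairs come from DeepSeek-V2's multi-head latent attention, which factors the query and key/value projections through low-rank bottlenecks (the q_lora_rank and kv_lora_rank keys above). A rough sketch of the query path; the dimensions are illustrative assumptions, and the real model applies an RMSNorm (attn_q_a_norm) between the two projections:

import torch

n_embd, q_lora_rank, n_head, qk_head_dim = 5120, 1536, 128, 192

q_a = torch.randn(q_lora_rank, n_embd)                # attn_q_a: compress
q_b = torch.randn(n_head * qk_head_dim, q_lora_rank)  # attn_q_b: expand

x = torch.randn(1, n_embd)
# Q is produced in two steps through the rank-q_lora_rank bottleneck
q = (x @ q_a.T) @ q_b.T  # shape [1, n_head * qk_head_dim]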

@@ -790,6 +838,10 @@ class MODEL_TENSOR(IntEnum):
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_ROT_EMBD,
    ],
    MODEL_ARCH.DEEPSEEK2: [
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_ROT_EMBD,
    ],
}

#
gguf-py/gguf/gguf_writer.py (+21, -0)

@@ -376,9 +376,15 @@ def add_embedding_length(self, length: int) -> None:
    def add_block_count(self, length: int) -> None:
        self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)

    def add_leading_dense_block_count(self, length: int) -> None:
        self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)

    def add_feed_forward_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_expert_feed_forward_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_parallel_residual(self, use: bool) -> None:
        self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)

@@ -409,6 +415,12 @@ def add_expert_count(self, count: int) -> None:
    def add_expert_used_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)

    def add_expert_shared_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)

    def add_expert_weights_scale(self, value: float) -> None:
        self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)

    def add_layer_norm_eps(self, value: float) -> None:
        self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)

@@ -418,6 +430,12 @@ def add_layer_norm_rms_eps(self, value: float) -> None:
    def add_causal_attention(self, value: bool) -> None:
        self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)

    def add_q_lora_rank(self, length: int) -> None:
        self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length)

    def add_kv_lora_rank(self, length: int) -> None:
        self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)

    def add_pooling_type(self, value: PoolingType) -> None:
        self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

@@ -442,6 +460,9 @@ def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
    def add_rope_scaling_finetuned(self, value: bool) -> None:
        self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)

    def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
        self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)

    def add_ssm_conv_kernel(self, value: int) -> None:
        self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
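For illustration, a converter populates these new keys roughly as follows; the file name and hyperparameter values below are made-up placeholders, not a real DeepSeek-V2 config:

import gguf

writer = gguf.GGUFWriter("deepseek2-example.gguf", "deepseek2")
writer.add_leading_dense_block_count(1)          # first_k_dense_replace
writer.add_q_lora_rank(1536)                     # q_lora_rank
writer.add_kv_lora_rank(512)                     # kv_lora_rank
writer.add_expert_shared_count(2)                # n_shared_experts
writer.add_expert_weights_scale(16.0)            # routed_scaling_factor
writer.add_rope_scaling_yarn_log_mul(0.1 * 1.0)  # 0.1 * mscale_all_dim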

gguf-py/gguf/tensor_mapping.py (+28, -1)

@@ -256,6 +256,7 @@ class TensorNameMap:

        MODEL_TENSOR.FFN_UP_SHEXP: (
            "model.layers.{bid}.mlp.shared_expert.up_proj",   # qwen2moe
            "model.layers.{bid}.mlp.shared_experts.up_proj",  # deepseek2
        ),

        # AWQ-activation gate
@@ -285,6 +286,7 @@ class TensorNameMap:

        MODEL_TENSOR.FFN_GATE_SHEXP: (
            "model.layers.{bid}.mlp.shared_expert.gate_proj",   # qwen2moe
            "model.layers.{bid}.mlp.shared_experts.gate_proj",  # deepseek2
        ),

        # Feed-forward down
@@ -320,6 +322,7 @@ class TensorNameMap:

        MODEL_TENSOR.FFN_DOWN_SHEXP: (
            "model.layers.{bid}.mlp.shared_expert.down_proj",   # qwen2moe
            "model.layers.{bid}.mlp.shared_experts.down_proj",  # deepseek2
        ),

        MODEL_TENSOR.ATTN_Q_NORM: (
@@ -383,6 +386,30 @@ class TensorNameMap:
            "model.layers.{bid}.out_proj",
            "backbone.layers.{bid}.mixer.out_proj",
        ),

        MODEL_TENSOR.ATTN_Q_A: (
            "model.layers.{bid}.self_attn.q_a_proj",  # deepseek2
        ),

        MODEL_TENSOR.ATTN_Q_B: (
            "model.layers.{bid}.self_attn.q_b_proj",  # deepseek2
        ),

        MODEL_TENSOR.ATTN_KV_A_MQA: (
            "model.layers.{bid}.self_attn.kv_a_proj_with_mqa",  # deepseek2
        ),

        MODEL_TENSOR.ATTN_KV_B: (
            "model.layers.{bid}.self_attn.kv_b_proj",  # deepseek2
        ),

        MODEL_TENSOR.ATTN_Q_A_NORM: (
            "model.layers.{bid}.self_attn.q_a_layernorm",  # deepseek2
        ),

        MODEL_TENSOR.ATTN_KV_A_NORM: (
            "model.layers.{bid}.self_attn.kv_a_layernorm",  # deepseek2
        ),
    }

    # architecture-specific block mappings
@@ -415,7 +442,7 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
                if tensor not in MODEL_TENSORS[arch]:
                    continue
                # TODO: make this configurable
                n_experts = 160  # raised from 128: DeepSeek-V2 defines up to 160 routed experts
                for xid in range(n_experts):
                    tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
                    self.mapping[tensor_name] = (tensor, tensor_name)
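Taken together, these entries let the converter's map_tensor_name() translate HF checkpoint names into GGUF names. A small sketch of a lookup (gguf-py API as of this PR; block count chosen arbitrarily):

import gguf

tmap = gguf.TensorNameMap(gguf.MODEL_ARCH.DEEPSEEK2, 4)
name = tmap.get_name("model.layers.0.self_attn.q_a_proj.weight",
                     try_suffixes=(".weight", ".bias"))
print(name)  # expected: "blk.0.attn_q_a.weight"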