Skip to content

Commit

Permalink
convert-hf : support for mixtral-instruct (#4428)
Browse files Browse the repository at this point in the history
* convert : typo fix, add additional hyperparameters, use LLaMA arch for Mixtral-instruct

* convert : use sentencepiece tokenizer for Mixtral-instruct

* convert : make flake8 happy
  • Loading branch information
Mrkvak committed Dec 12, 2023
1 parent 90c12e6 commit 82e4f64
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,18 @@ def set_gguf_parameters(self):
self.gguf_writer.add_embedding_length(n_embd)
if (n_ff := self.hparams.get("intermediate_size")) is not None:
self.gguf_writer.add_feed_forward_length(n_ff)
if (n_head := self.hparams.get("num_attention_head")) is not None:
if (n_head := self.hparams.get("num_attention_heads")) is not None:
self.gguf_writer.add_head_count(n_head)
if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
self.gguf_writer.add_head_count_kv(n_head_kv)

if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps)
if (n_experts := self.hparams.get("num_local_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)

self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))

def write_tensors(self):
Expand Down Expand Up @@ -170,6 +180,8 @@ def from_model_architecture(model_architecture):
return StableLMModel
if model_architecture == "QWenLMHeadModel":
return QwenModel
if model_architecture == "MixtralForCausalLM":
return MixtralModel
return Model

def _is_model_safetensors(self) -> bool:
Expand Down Expand Up @@ -207,6 +219,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.STABLELM
if arch == "QWenLMHeadModel":
return gguf.MODEL_ARCH.QWEN
if arch == "MixtralForCausalLM":
return gguf.MODEL_ARCH.LLAMA

raise NotImplementedError(f'Architecture "{arch}" not supported!')

Expand Down Expand Up @@ -837,6 +851,11 @@ def set_gguf_parameters(self):
self.gguf_writer.add_layer_norm_eps(1e-5)


class MixtralModel(Model):
def set_vocab(self):
self._set_vocab_sentencepiece()


class QwenModel(Model):
@staticmethod
def token_bytes_to_string(b):
Expand Down

0 comments on commit 82e4f64

Please sign in to comment.