feat: Support Gemma 2 (#223)
* feat: add gemma 2 model

---------

Co-authored-by: Rongjie Yi <41737961+yirongjie@users.noreply.github.com>
Co-authored-by: yirongjie <yirj0809@gmail.com>
3 people authored Jan 10, 2025
1 parent 352e6da commit 036bbbb
Showing 6 changed files with 362 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -87,6 +87,7 @@ Why on-device multimodal LLM? - It's a key building block for [intelligent perso
| [TinyLLaMA 1.1B](https://github.com/jzhang38/TinyLlama) | [✔️](https://huggingface.co/mllmTeam/tinyllama-1.1b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/tinyllama-1.1b-mllm/tree/main) | |
| [LLaVA 7B](https://github.com/haotian-liu/LLaVA) | [✔️](https://huggingface.co/mllmTeam/llava-1.5-7b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/llava-1.5-7b-mllm/tree/main) | |
| [Gemma 2B](https://github.com/google/gemma_pytorch) | [✔️](https://huggingface.co/mllmTeam/gemma-2b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/gemma-2b-mllm/tree/main) | |
| [Gemma 2 2B](https://github.com/google/gemma_pytorch) | [✔️](https://huggingface.co/mllmTeam/gemma-2-2b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/gemma-2-2b-mllm/tree/main) | |
| [Qwen 1.5 0.5B](https://github.com/QwenLM/Qwen) | [✔️](https://huggingface.co/mllmTeam/qwen-1.5-0.5b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/qwen-1.5-0.5b-mllm/tree/main) | |
| [Qwen 1.5 1.8B](https://github.com/QwenLM/Qwen) | [✔️](https://huggingface.co/mllmTeam/qwen-1.5-1.8b-chat-mllm) | [✔️](https://huggingface.co/mllmTeam/qwen-1.5-1.8b-chat-mllm) | [✔️](https://huggingface.co/mllmTeam/qwen-1.5-1.8b-chat-mllm) |
| [Qwen 2.5 1.5B](https://github.com/QwenLM/Qwen2.5) | [✔️](https://huggingface.co/mllmTeam/qwen-2.5-1.5b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/qwen-2.5-1.5b-mllm/tree/main) | |
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -61,6 +61,7 @@ func_llm_add_executable(demo_stablelm)
func_llm_add_executable(demo_sparse_llama)
func_llm_add_executable(demo_elastic_llama)
func_llm_add_executable(demo_gemma)
func_llm_add_executable(demo_gemma2)
func_llm_add_executable(demo_qwen)
func_llm_add_executable(demo_mistral)
func_llm_add_executable(demo_yi)
53 changes: 53 additions & 0 deletions examples/demo_gemma2.cpp
@@ -0,0 +1,53 @@
#include "cmdline.h"
#include "models/gemma2/configuration_gemma2.hpp"
#include "models/gemma2/modeling_gemma2.hpp"
#include "models/gemma/tokenization_gemma.hpp"
#include "processor/PostProcess.hpp"

using namespace mllm;

int main(int argc, char **argv) {
    cmdline::parser cmdParser;
    cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/gemma2_vocab.mllm");
    cmdParser.add<string>("model", 'm', "specify mllm model path", false, "../models/gemma-2-2b-q4_k.mllm");
    cmdParser.add<int>("limits", 'l', "max KV cache size", false, 400);
    cmdParser.add<int>("thread", 't', "num of threads", false, 4);
    cmdParser.parse_check(argc, argv);

    string vocab_path = cmdParser.get<string>("vocab");
    string model_path = cmdParser.get<string>("model");
    int tokens_limit = cmdParser.get<int>("limits");
    CPUBackend::cpu_threads = cmdParser.get<int>("thread");

    // gemma2 uses the same tokenizer as gemma
    auto tokenizer = GemmaTokenizer(vocab_path);

    Gemma2Config config(tokens_limit, "2B", RoPEType::HFHUBROPE);
    auto model = Gemma2ForCausalLM(config);
    model.load(model_path);

    vector<string> in_strs = {
        "Hello, who are you?",
        "What can you do?",
        "Please introduce Beijing University of Posts and Telecommunications."};

    for (int i = 0; i < in_strs.size(); ++i) {
        auto in_str = in_strs[i];
        auto input_tensor = tokenizer.tokenize(in_str);

        std::cout << "[Q] " << in_str << std::endl;
        std::cout << "[A] " << std::flush;
        for (int step = 0; step < 200; step++) {
            auto result = model({input_tensor});
            auto [out_string, out_token] = tokenizer.detokenize(result[0]);
            auto [not_end, output_string] = tokenizer.postprocess(out_string);
            if (!not_end) { break; }
            std::cout << output_string << std::flush;
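            // the call below presumably feeds the sampled token back as the next-step
            // input, relying on the KV cache to keep the earlier context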
            chatPostProcessing(out_token, input_tensor, {});
        }
        printf("\n");
        model.clear_kvcache();
    }

    return 0;
}
79 changes: 79 additions & 0 deletions src/models/gemma2/configuration_gemma2.hpp
@@ -0,0 +1,79 @@
#ifndef CONFIG_GEMMA2_HPP
#define CONFIG_GEMMA2_HPP
#include "Types.hpp"
#include "models/transformer/configuration_transformer.hpp"

using namespace mllm;

class Gemma2NameConfig : public TransformerNameConfig {
public:
    /**
     * @brief Gemma 2, following the Hugging Face naming scheme
     *
     * @param type RoPEType
     */
    void init(RoPEType type = RoPEType::HFHUBROPE) {
        switch (type) {
        case RoPEType::HFHUBROPE: {
            blk_name = "model.layers.";
            _attn_base_name = "self_attn.";
            _ffn_base_name = "mlp.";
            _q_proj_name = "q_proj";
            _k_proj_name = "k_proj";
            _v_proj_name = "v_proj";
            _o_proj_name = "o_proj";
            _gate_proj_name = "gate_proj";
            _up_proj_name = "up_proj";
            _down_proj_name = "down_proj";
            _attn_norm_name = "input_layernorm";
            _ffn_norm_name = "post_attention_layernorm";
            _pre_feedforward_layernorm = "pre_feedforward_layernorm";
            _post_feedforward_layernorm = "post_feedforward_layernorm";
            token_embd_name = "model.embed_tokens";
            post_norm_name = "model.norm";
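            // the LM head reuses the embedding table (the weights are tied),
            // so it intentionally points at model.embed_tokens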
            lm_head_name = "model.embed_tokens";
            break;
        }
        default: {
            throw std::runtime_error("Unsupported gemma RoPE type");
        }
        }
    }

    std::string blk_name;
    std::string token_embd_name;
    std::string post_norm_name;
    std::string lm_head_name;
    std::string _gate_proj_name;
    std::string _pre_feedforward_layernorm;
    std::string _post_feedforward_layernorm;
};

struct Gemma2Config : public TransformerConfig {
    explicit Gemma2Config(int token_limit, const string billions = "2B", RoPEType type = RoPEType::HFHUBROPE) :
        cache_limit(token_limit) {
        names_config.init(type);
        if (!(billions == "2B" || billions == "2b")) {
            throw std::runtime_error("Unsupported model size");
        }
        RoPE_type = type;
    }

    int vocab_size = 256000;
    int max_position_embeddings = 8192;
    int num_hidden_layers = 26;
    int num_attention_heads = 8;
    int num_key_value_heads = 4;
    int hidden_size = 2304;
    int sliding_window = 4096;
    int intermediate_size = 9216;
    int head_dim = 256;
    float rms_norm_eps = 1e-6;
    float rope_theta = 10000;

    int cache_limit;
    RoPEType RoPE_type = RoPEType::HFHUBROPE;
    Gemma2NameConfig names_config;
};

#endif //! CONFIG_GEMMA2_HPP
228 changes: 228 additions & 0 deletions src/models/gemma2/modeling_gemma2.hpp
@@ -0,0 +1,228 @@
#ifndef MODELING_GEMMA2_HPP
#define MODELING_GEMMA2_HPP

#include "Backend.hpp"
#include "Layer.hpp"
#include "Module.hpp"
#include "Tensor.hpp"
#include "configuration_gemma2.hpp"
#include <cmath>
using namespace mllm;

class Gemma2Attention final : public Module {
public:
    Gemma2Attention() {}
    Gemma2Attention(const Gemma2Config &config, const Gemma2NameConfig &names, const string &base_name) {
        hidden_size = config.hidden_size;
        num_heads = config.num_attention_heads;
        // in gemma2, the head_dim is fixed to 2048 / num_heads rather than hidden_size (2304) / num_heads
        head_dim = 2048 / num_heads;
        num_key_value_heads = config.num_key_value_heads;
        num_key_value_groups = num_heads / num_key_value_heads;
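        // grouped-query attention: with the 2B config this is 8 query heads over
        // 4 key/value heads, so each KV head serves num_key_value_groups = 2 query heads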

        // init layers
        q_proj = Linear(hidden_size, head_dim * num_heads, false, base_name + names._q_proj_name);
        k_proj = Linear(hidden_size, head_dim * num_key_value_heads, false,
                        base_name + names._k_proj_name);
        v_proj = Linear(hidden_size, head_dim * num_key_value_heads, false,
                        base_name + names._v_proj_name);
        o_proj = Linear(head_dim * num_heads, hidden_size, false, base_name + names._o_proj_name);
        q_rope = RoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings,
                      base_name + "q_rope");
        k_rope = RoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings,
                      base_name + "k_rope");
        k_cache = KVCache(num_key_value_groups, config.cache_limit, base_name + "k_cache");
        v_cache = KVCache(num_key_value_groups, config.cache_limit, base_name + "v_cache");

        softmax = Softmax(DIMENSION, true, base_name + "softmax");
    }

    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
        auto query_states = q_proj(inputs[0]);
        auto key_states = k_proj(inputs[1]);
        auto value_states = v_proj(inputs[2]);

        // [batch, heads, sequence, dims]
        query_states = query_states.view(-1, num_heads, -1, head_dim);
        key_states = key_states.view(-1, num_key_value_heads, -1, head_dim);
        value_states = value_states.view(-1, num_key_value_heads, -1, head_dim);

        // rotary position embedding
        query_states = q_rope(query_states);
        key_states = k_rope(key_states);

        // kv cache
        key_states = k_cache(key_states);
        value_states = v_cache(value_states);

        // attention weight
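        // Q*K^T scaled by 1/sqrt(head_dim) (= 1/16 for head_dim 256); the softmax below also
        // receives the cached sequence length, presumably so masking is applied over the full cache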
        auto atten_weight =
            Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION))
            / std::sqrt(head_dim);

        atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen());

        // attention output
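        // fold the heads back: [batch, heads, sequence, head_dim] -> [batch, 1, sequence, heads * head_dim]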
        auto atten_output = Tensor::mm(atten_weight, value_states);
        atten_output = atten_output.view(-1, 1, -1, head_dim * num_heads);
        atten_output = o_proj(atten_output);
        return {atten_output};
    }

    vector<KVCache *> get_cache() {
        return {&k_cache, &v_cache};
    }
    vector<RoPE *> get_rope() {
        return {&q_rope, &k_rope};
    }

private:
    int hidden_size;
    int num_heads;
    int head_dim;
    int num_key_value_heads;
    int num_key_value_groups;
    int layer_num = 0;
    Layer q_proj;
    Layer k_proj;
    Layer v_proj;
    Layer o_proj;
    RoPE q_rope;
    RoPE k_rope;
    KVCache k_cache;
    KVCache v_cache;
    Softmax softmax;
};

class Gemma2MLP final : public Module {
public:
    Gemma2MLP() = default;
    Gemma2MLP(int hidden_size, int intermediate_size, const Gemma2NameConfig &names, const std::string &base_name) {
        gate_proj = Linear(hidden_size, intermediate_size, false, base_name + names._gate_proj_name);
        gelu = GELU(base_name + "act");
        up_proj = Linear(hidden_size, intermediate_size, false, base_name + names._up_proj_name);
        down_proj = Linear(intermediate_size, hidden_size, false, base_name + names._down_proj_name);
    }

    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
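        // gated GELU (GeGLU) feed-forward: down_proj(gelu(gate_proj(x)) * up_proj(x))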
        auto x = gate_proj(inputs[0]);
        x = gelu(x);
        auto y = up_proj(inputs[0]);
        x = x * y;
        x = down_proj(x);
        return {x};
    }

private:
    Layer gate_proj;
    Layer up_proj;
    Layer down_proj;

    Layer gelu; ///< F.gelu(gate, approximate="tanh")
};

class Gemma2Decoder final : public Module {
public:
    Gemma2Decoder() = default;
    Gemma2Decoder(const Gemma2Config &config, const Gemma2NameConfig &names, const string &base_name) {
        self_attn = Gemma2Attention(config, names, base_name + names._attn_base_name);
        mlp = Gemma2MLP(config.hidden_size, config.intermediate_size, names, base_name + names._ffn_base_name);
        input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, true, base_name + names._attn_norm_name);
        post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, true, base_name + names._ffn_norm_name);
        pre_feedforward_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, true, base_name + names._pre_feedforward_layernorm);
        post_feedforward_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, true, base_name + names._post_feedforward_layernorm);
    }

    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
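        // Gemma 2 "sandwich" normalization: both the attention block and the MLP are wrapped
        // in a pre-norm and a post-norm, and each residual is added after the post-norm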
        auto x = input_layernorm(inputs[0]);
        x = self_attn({x, x, x})[0];
        x = post_attention_layernorm(x);
        auto tmp = x + inputs[0];
        x = pre_feedforward_layernorm(tmp);
        x = mlp({x})[0];
        x = post_feedforward_layernorm(x);
        x = x + tmp;
        return {x};
    }

    Gemma2Attention &get_attention() {
        return self_attn;
    }

private:
    // MultiHeadAttention self_attn;
    Gemma2Attention self_attn;
    Gemma2MLP mlp;
    Layer input_layernorm;
    Layer post_attention_layernorm;
    Layer pre_feedforward_layernorm;
    Layer post_feedforward_layernorm;
};

class Gemma2Model final : public Module {
public:
    Gemma2Model() = default;
    Gemma2Model(const Gemma2Config &config, const Gemma2NameConfig &names, const string &base_name) {
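        // stack num_hidden_layers (26 for the 2B config) decoder blocks; the List helper
        // presumably derives each block's parameter prefix from base_name ("model.layers.")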
        blocks = List<Gemma2Decoder>(config.num_hidden_layers, config, names, base_name);
        norm = RMSNorm(config.hidden_size, config.rms_norm_eps, true, names.post_norm_name);
    }

    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
        auto x = inputs[0];
        for (auto &block : blocks) {
            x = block({x})[0];
        }
        x = norm(x);
        return {x};
    }

    void clear_kvcache() override {
        for (auto &block : blocks) {
            auto kvcache = block.get_attention().get_cache();
            for (auto &cache : kvcache) { cache->clearCache(); }
            auto ropes = block.get_attention().get_rope();
            for (auto &rope : ropes) { rope->clearCache(); }
        }
    }

private:
    std::vector<Gemma2Decoder> blocks;
    Layer norm;
};

class Gemma2ForCausalLM final : public Module {
public:
    Gemma2ForCausalLM(Gemma2Config &config) {
        auto names = config.names_config;
        hidden_size = config.hidden_size;
        embedding = Embedding(config.vocab_size, config.hidden_size, names.token_embd_name);
        model = Gemma2Model(config, names, names.blk_name);

        // gemma's lm_head and tok_embedding are tied: they share the same parameters,
        // so the lm_head is computed by multiplying with the transposed embedding instead.
        lm_head = Parameter(1, config.vocab_size, 1, config.hidden_size, names.lm_head_name + ".weight");
    }

    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
        auto x = embedding(inputs[0]);

        // scale the embeddings by sqrt(hidden_size) (Gemma's embedding normalizer)
        x = x * std::sqrt(hidden_size);

        // go through the model
        auto outputs = model({x})[0];
        outputs = Tensor::mm(outputs, lm_head().transpose(Chl::SEQUENCE, Chl::DIMENSION));
        return {outputs};
    }
    void clear_kvcache() override {
        model.clear_kvcache();
    }

private:
    int hidden_size;
    Layer embedding;
    Parameter lm_head;
    Gemma2Model model;
};

#endif //! MODELING_GEMMA2_HPP
Binary file added vocab/gemma2_vocab.mllm
