hparams : move vocab params to llama_vocab #11159


Merged (2 commits) on Jan 10, 2025
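In short: the vocabulary-related hyperparameters (n_vocab and the BERT-style n_vocab_type) move out of llama_hparams and into llama_vocab, and all call sites now go through accessors on the model's vocab object. The llama-vocab.{h,cpp} side of the change is not shown in this section, so the following is only a sketch of the accessor shape implied by the call sites in the diff (vocab.n_vocab() and vocab.n_token_types()); any member beyond those two is an assumption:

    // sketch only, inferred from the call sites below; not the actual llama-vocab.h diff
    struct llama_vocab {
        uint32_t n_vocab()       const;  // vocabulary size, formerly hparams.n_vocab
        uint32_t n_token_types() const;  // BERT-style token-type count, formerly hparams.n_vocab_type

        // ... token data, special-token ids, tokenizer state, etc. (assumed)
    };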
9 changes: 5 additions & 4 deletions src/llama-context.cpp
@@ -469,11 +469,12 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
 size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
     const auto & cparams = lctx.cparams;
     const auto & hparams = lctx.model.hparams;
+    const auto & vocab = lctx.model.vocab;

     const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);

     const auto n_batch = cparams.n_batch;
-    const auto n_vocab = hparams.n_vocab;
+    const auto n_vocab = vocab.n_vocab();
     const auto n_embd = hparams.n_embd;

     // TODO: use a per-batch flag for logits presence instead
@@ -540,7 +541,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
 void llama_output_reorder(struct llama_context & ctx) {
     std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
     if (!out_ids.empty()) {
-        const uint32_t n_vocab = ctx.model.hparams.n_vocab;
+        const uint32_t n_vocab = ctx.model.vocab.n_vocab();
         const uint32_t n_embd = ctx.model.hparams.n_embd;

         const int32_t n_outputs = ctx.n_outputs;
@@ -724,7 +725,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
        }

-       return ctx->logits + j*ctx->model.hparams.n_vocab;
+       return ctx->logits + j*ctx->model.vocab.n_vocab();
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -884,7 +885,7 @@ struct llama_data_write {
    }

    void write_logits(const struct llama_context * ctx) {
-       const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
+       const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_vocab());

        write(&logits_size, sizeof(logits_size));
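Every context-side change follows the same pattern: wherever the logit width is needed, the context now asks the model's vocab instead of the hyperparameters. The invariant itself is unchanged: the logits buffer holds n_outputs rows of n_vocab floats, so row j starts at offset j*n_vocab, exactly as llama_get_logits_ith and write_logits compute it above. A minimal self-contained sketch of that layout (hypothetical helper, not part of the PR):

    #include <cstddef>
    #include <cstdint>

    // row j of the output logits buffer; mirrors the j*n_vocab arithmetic above
    static float * logits_row(float * logits, int32_t j, uint32_t n_vocab) {
        return logits + (size_t) j * n_vocab;
    }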
2 changes: 0 additions & 2 deletions src/llama-hparams.h
@@ -30,7 +30,6 @@ struct llama_hparams {
     bool use_par_res;
     bool swin_norm;

-    uint32_t n_vocab = 0;
     uint32_t n_ctx_train; // context size the model was trained on
     uint32_t n_embd;
     uint32_t n_embd_features = 0;
@@ -41,7 +40,6 @@
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
     uint32_t n_expert_used = 0;
-    uint32_t n_vocab_type = 0; // for BERT-style token types
     uint32_t n_rel_attn_bkts = 0;

     // for WavTokenizer
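With these two fields gone, any remaining internal use of hparams.n_vocab or hparams.n_vocab_type fails to compile, which is how stragglers get caught. Assuming a llama_model reference named model is in scope, the replacement reads:

    // before this PR (no longer compiles):
    //     const uint32_t n_vocab       = model.hparams.n_vocab;
    //     const uint32_t n_token_types = model.hparams.n_vocab_type;
    // after this PR: ask the vocabulary object, as the call sites in this diff do
    const uint32_t n_vocab       = model.vocab.n_vocab();
    const uint32_t n_token_types = model.vocab.n_token_types();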
21 changes: 9 additions & 12 deletions src/llama-model.cpp
@@ -402,9 +402,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     // get general kv
     ml.get_key(LLM_KV_GENERAL_NAME, name, false);

-    // get hparams kv
-    ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false);
-
     // everything past this point is not vocab-related
     if (hparams.vocab_only) {
         return;
@@ -500,6 +497,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         hparams.n_embd_head_v = 0;
     }

+    // for differentiating model types
+    uint32_t n_vocab = 0;
+    ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
+
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:
@@ -519,7 +520,7 @@
                    case 26: type = LLM_TYPE_3B; break;
                    case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B
                    // granite uses a vocab with len 49152
-                   case 32: type = hparams.n_vocab == 49152 ? LLM_TYPE_3B : (hparams.n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break;
+                   case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break;
                    case 36: type = LLM_TYPE_8B; break; // granite
                    case 40: type = LLM_TYPE_13B; break;
                    case 48: type = LLM_TYPE_34B; break;
@@ -621,7 +622,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
-               ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);

                switch (hparams.n_layer) {
@@ -644,7 +644,6 @@
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
-               ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
                hparams.f_max_alibi_bias = 8.0f;

@@ -658,7 +657,6 @@
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
-               ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);

                if (hparams.n_layer == 12 && hparams.n_embd == 768) {
@@ -1365,8 +1363,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     const int64_t n_embd_head_v = hparams.n_embd_head_v;
     const int64_t n_ff = hparams.n_ff();
     const int64_t n_embd_gqa = n_embd_v_gqa;
-    const int64_t n_vocab = hparams.n_vocab;
-    const int64_t n_vocab_type = hparams.n_vocab_type;
+    const int64_t n_vocab = vocab.n_vocab();
+    const int64_t n_token_types = vocab.n_token_types();
     const int64_t n_rot = hparams.n_rot;
     const int64_t n_expert = hparams.n_expert;
     const int64_t n_expert_used = hparams.n_expert_used;
@@ -1811,7 +1809,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
        case LLM_ARCH_NOMIC_BERT:
            {
                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-               type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0);
+               type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);

                if (arch == LLM_ARCH_BERT) {
                    pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
@@ -1865,7 +1863,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
        case LLM_ARCH_JINA_BERT_V2:
            {
                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
-               type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0); // token_type_embeddings
+               type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings

                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm
                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); // LayerNorm bias
@@ -3494,7 +3492,6 @@ void llama_model::print_info() const {
     // hparams
     LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str());
-    LLAMA_LOG_INFO("%s: n_vocab (hp) = %u\n", __func__, hparams.n_vocab);
     LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only);

     if (!hparams.vocab_only) {
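Because hparams no longer carries the vocab size, load_hparams now reads it into a local n_vocab purely for model-type disambiguation (the 32-layer Granite-vs-Llama case), while load_tensors sizes the token and token-type embedding tensors from vocab.n_vocab() and vocab.n_token_types(). For clarity, here is the 32-layer heuristic from the hunk above restated as a standalone function; the enum is redeclared locally so the sketch compiles on its own, and the example vocab sizes (32000, 128256) are assumptions about typical Llama 2 / Llama 3 checkpoints:

    #include <cstdint>

    // illustrative restatement of the 32-layer case in load_hparams (not part of the PR)
    enum llm_type_sketch { LLM_TYPE_3B, LLM_TYPE_7B, LLM_TYPE_8B };

    static llm_type_sketch classify_32_layer(uint32_t n_vocab) {
        if (n_vocab == 49152) {
            return LLM_TYPE_3B;  // granite uses a vocab with len 49152
        }
        // smaller vocabs (e.g. 32000) map to 7B-class, larger ones (e.g. 128256) to 8B-class
        return n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B;
    }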