This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Copy v_transposed like llama.cpp #68

Merged
1 change: 1 addition & 0 deletions llama-cli/src/main.rs
@@ -112,6 +112,7 @@ fn main() {
             }
         }),
         play_back_previous_tokens: false,
+        ..Default::default()
     };
     let inference_session_params = {
         let mem_typ = if args.float16 {
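The ..Default::default() added above means the CLI names only the fields it overrides and takes defaults for everything else, so a new library field (such as the one introduced in llama-rs below) needs no further CLI change. A minimal standalone sketch of Rust's struct update syntax, using a hypothetical Params struct rather than the real InferenceParameters:

// Hypothetical struct for illustration only; not part of llama-rs.
#[derive(Debug, PartialEq)]
struct Params {
    threads: i32,
    deterministic: bool,
}

impl Default for Params {
    fn default() -> Self {
        Self { threads: 4, deterministic: true }
    }
}

fn main() {
    // `..Default::default()` fills every field not listed explicitly, so
    // adding more fields to Params later does not break this call site.
    let p = Params { threads: 8, ..Default::default() };
    assert_eq!(p, Params { threads: 8, deterministic: true });
}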
63 changes: 44 additions & 19 deletions llama-rs/src/lib.rs
@@ -174,6 +174,7 @@ impl Default for InferenceSessionParameters {
     }
 }
 
+#[derive(Clone, Debug, PartialEq)]
 /// The parameters that drive text generation.
 pub struct InferenceParameters {
     pub n_threads: i32,
@@ -184,6 +185,7 @@ pub struct InferenceParameters {
     pub temp: f32,
     pub bias_tokens: TokenBias,
     pub play_back_previous_tokens: bool,
+    pub increased_determinism: bool,
 }
 
 impl Default for InferenceParameters {
@@ -197,6 +199,7 @@ impl Default for InferenceParameters {
             temp: 0.80,
             bias_tokens: TokenBias::default(),
             play_back_previous_tokens: false,
+            increased_determinism: true,
         }
     }
 }
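Because increased_determinism defaults to true, callers that build InferenceParameters with struct update syntax get the new copying behaviour automatically, and opting out is explicit. A sketch of such a caller (field values other than the flag are illustrative):

let params = InferenceParameters {
    n_threads: 8,
    // Skip the extra contiguous copy of the transposed value cache and keep
    // the previous behaviour.
    increased_determinism: false,
    ..Default::default()
};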
@@ -1094,11 +1097,13 @@ impl Model {
     pub fn evaluate(
         &self,
         session: &mut InferenceSession,
-        n_threads: i32,
+        params: &InferenceParameters,
         input_tokens: &[TokenId],
     ) {
         let n = input_tokens.len();
         let n_past = session.n_past as i32;
+        let n_threads = params.n_threads;
+        let increased_determinism = params.increased_determinism;
 
         let Hyperparameters {
             n_vocab,
@@ -1127,6 +1132,27 @@
 
         let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd);
 
+        // Defined here to avoid repetition and creating a binding inside nested loops.
+        // See the call site below for more context.
+        let vtrans_fun = |il: usize| -> ggml::Tensor {
+            ctx0.op_permute(
+                &ctx0.op_reshape_3d(
+                    &ctx0.op_view_1d(
+                        &session.memory_v,
+                        (n_past + n as i32) * n_embd,
+                        il * n_ctx as usize * session.memory_v.element_size() * n_embd as usize,
+                    ),
+                    n_embd / n_head,
+                    n_head,
+                    n_past + n as i32,
+                ),
+                1,
+                2,
+                0,
+                3,
+            )
+        };
+
         for il in 0..n_layer as usize {
             let input_self_attention = input_layer.share();
             let mut current: ggml::Tensor;
@@ -1226,22 +1252,21 @@
             let k_q_soft_max = ctx0.op_soft_max(&k_q_masked);
 
             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
-            let v_transposed = ctx0.op_permute(
-                &ctx0.op_reshape_3d(
-                    &ctx0.op_view_1d(
-                        &session.memory_v,
-                        (n_past + n as i32) * n_embd,
-                        il * n_ctx as usize * session.memory_v.element_size() * n_embd as usize,
-                    ),
-                    n_embd / n_head,
-                    n_head,
-                    n_past + n as i32,
-                ),
-                1,
-                2,
-                0,
-                3,
-            );
+            let v_transposed = {
+                if !increased_determinism {
+                    vtrans_fun(il)
+                } else {
+                    ctx0.op_cpy(
+                        &vtrans_fun(il),
+                        &ctx0.new_tensor_3d(
+                            ggml::TYPE_F32,
+                            n_past + n as i32,
+                            n_embd / n_head,
+                            n_head,
+                        ),
+                    )
+                }
+            };
 
             // KQV = transpose(V) * KQ_soft_max
             let k_q_v = ctx0.op_mul_mat(&v_transposed, &k_q_soft_max);
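With the flag set, the permuted view of the value-cache slice is first copied (op_cpy) into a freshly allocated contiguous F32 tensor and that copy is fed to op_mul_mat, mirroring llama.cpp; with the flag cleared, the non-contiguous view from vtrans_fun is multiplied directly. As a standalone illustration of what materializing a transposed view into a contiguous buffer means, here is a small sketch in plain Rust (hypothetical helper, not part of this PR or of ggml):

// Copy the transpose of a row-major rows x cols matrix into a new contiguous
// buffer, analogous to op_cpy of a permuted ggml view into a fresh tensor.
fn transpose_contiguous(src: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut dst = vec![0.0f32; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            // Element (r, c) of the source becomes element (c, r) of the
            // destination, which is laid out densely in its own row-major order.
            dst[c * rows + r] = src[r * cols + c];
        }
    }
    dst
}

fn main() {
    // A 2x3 matrix stored row-major: [[1, 2, 3], [4, 5, 6]].
    let m = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let t = transpose_contiguous(&m, 2, 3);
    // Its 3x2 transpose, now contiguous: [[1, 4], [2, 5], [3, 6]].
    assert_eq!(t, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}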
@@ -1393,7 +1418,7 @@ impl InferenceSession {
         }
 
         for batch in prompt_tokens.chunks(8) {
-            model.evaluate(self, params.n_threads, batch);
+            model.evaluate(self, params, batch);
            for &tk in batch {
                 // NOTE: No string ever tokenizes to the end of sentence. So we
                 // can just return the id here.
@@ -1427,7 +1452,7 @@
         self.tokens.push(next_token);
 
         // Then, evaluate the network again to compute the new last_logits
-        model.evaluate(self, params.n_threads, &[next_token]);
+        model.evaluate(self, params, &[next_token]);
 
         // Return the next token
         Ok(if next_token as TokenId == EOD_TOKEN_ID {
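Callers of Model::evaluate now pass the whole parameter struct rather than a bare thread count, which is how increased_determinism reaches the attention loop. A sketch of an external call site (model, session, and token setup elided; names are illustrative):

let params = InferenceParameters::default(); // increased_determinism: true
model.evaluate(&mut session, &params, &prompt_token_ids);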