This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Embedding extraction #72

Merged
merged 8 commits on Mar 26, 2023
54 changes: 51 additions & 3 deletions llama-rs/src/lib.rs
@@ -1,5 +1,6 @@
mod ggml;

use core::slice;
use std::{
collections::HashMap,
fmt::Display,
@@ -422,6 +423,17 @@ pub enum InferenceError {
UserCallback(Box<dyn std::error::Error>),
}

/// Used in a call to `evaluate` to request information from the transformer.
#[derive(Default)]
pub struct EvaluateOutputRequest {
/// Returns all the logits for the provided batch of tokens.
/// Output shape is `n_batch * n_vocab`.
pub all_logits: Option<Vec<f32>>,
/// Returns the embeddings for the provided batch of tokens.
/// Output shape is `n_batch * n_embd`.
pub embeddings: Option<Vec<f32>>,
}
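A rough usage sketch (editorial illustration, not part of this patch): a caller that wants embeddings passes a request whose `embeddings` field is `Some(Vec::new())` and reads the filled vector back after `evaluate` returns. The `model`, `session`, `params`, and `tokens` bindings below are assumed to be set up the way the rest of this crate does it.

// Illustrative only: request embeddings, skip the full per-token logits.
let mut output_request = EvaluateOutputRequest {
    all_logits: None,
    embeddings: Some(Vec::new()),
};
model.evaluate(&mut session, &params, &tokens, &mut output_request);
// The vector now holds n_batch * n_embd floats, one n_embd-sized run per token.
let embeddings = output_request.embeddings.unwrap();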

/// NOTE: The original code relies on promotion rules and automatic casts from
/// int to float. What we do instead is use this macro to convert every term of
/// the multiplication to f64, which should have enough precision bits to hold
@@ -1094,11 +1106,17 @@ impl Model {
}

/// Evaluates the transformer.
///
/// The provided `output_request` struct lets you specify which additional
/// data you are interested in fetching from the transformer. Setting a
/// field to a `Some` value will clear the provided vector and fill it with
/// data, resized to the exact output size.
pub fn evaluate(
&self,
session: &mut InferenceSession,
params: &InferenceParameters,
input_tokens: &[TokenId],
output_request: &mut EvaluateOutputRequest,
) {
let n = input_tokens.len();
let n_past = session.n_past as i32;
@@ -1317,12 +1335,16 @@ impl Model {
input_layer = current;
}

// Used at the end to optionally extract the embeddings.
let embeddings_tensor;

// norm
{
input_layer = ctx0.op_norm(&input_layer);

// inpL = norm*inpL
input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer);
embeddings_tensor = input_layer.share();
}

// lm_head
@@ -1347,6 +1369,28 @@
)
};

// Extract logits
if let Some(all_logits) = &mut output_request.all_logits {
all_logits.resize(n_vocab as usize * n, 0.0);
// SAFETY: Tensor data can be read (properly aligned, initialized,
// data will not be mutated or otherwise aliased during the copy),
// and we're not reading past the end of the tensor data.
assert_eq!(input_layer.nelements(), n_vocab * n as i32);
unsafe {
input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits));
}
}

// Extract embeddings
if let Some(embeddings) = &mut output_request.embeddings {
embeddings.resize(n_embd as usize * n, 0.0);
// SAFETY: Same rationale as for the "Extract logits" section applies.
assert_eq!(embeddings_tensor.nelements(), n_embd * n as i32);
unsafe {
embeddings_tensor.read_data(0, bytemuck::cast_slice_mut(embeddings));
}
}
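Editorial note on the buffer layout, as a sketch assuming `embeddings` and `n_embd` are in scope as in the illustration near `EvaluateOutputRequest`: the data is copied out in the tensor's memory order, i.e. one contiguous run of `n_embd` values per token, so the last token's embedding can be sliced out like this.

// Illustrative only: slice out the embedding of the last token in the batch.
// `n_embd` is the model's embedding size, taken as a usize here.
let n_tokens = embeddings.len() / n_embd;
let last_embedding = &embeddings[(n_tokens - 1) * n_embd..n_tokens * n_embd];
assert_eq!(last_embedding.len(), n_embd);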

// Adjust the required memory per token if we didn't know that already
if session.mem_per_token == 0 {
session.mem_per_token = ctx0.used_mem() / n;
@@ -1418,7 +1462,7 @@ impl InferenceSession {
}

for batch in prompt_tokens.chunks(8) {
model.evaluate(self, params, batch);
model.evaluate(self, params, batch, &mut EvaluateOutputRequest::default());
for &tk in batch {
// NOTE: No string ever tokenizes to the end of sentence. So we
// can just return the id here.
@@ -1452,7 +1496,12 @@ impl InferenceSession {
self.tokens.push(next_token);

// Then, evaluate the network again to compute the new last_logits
model.evaluate(self, params, &[next_token]);
model.evaluate(
self,
params,
&[next_token],
&mut EvaluateOutputRequest::default(),
);

// Return the next token
Ok(if next_token as TokenId == EOD_TOKEN_ID {
@@ -1533,7 +1582,6 @@ impl InferenceSession {
/// ggml context. While the provided `InferenceSnapshotRef` object is alive,
/// no other methods for this model object should be called.
pub unsafe fn get_snapshot(&mut self) -> InferenceSnapshotRef<'_> {
use core::slice;
let memory_k = unsafe {
slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes())
};