Embedding extraction (#72)
* Add code to extract embeddings

* Add ad_hoc_test for embeddings

* Use Tensor::read_data

* Adjust safety comment

* Remove ad-hoc test

* Clippy & fmt

* Fix comment that ended halfway

---------

Co-authored-by: Philpax <me@philpax.me>
setzer22 and philpax authored Mar 26, 2023
1 parent b103dcd commit a067431
Showing 1 changed file with 51 additions and 3 deletions.
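
In rough terms, this commit lets a caller ask `Model::evaluate` to hand back the full per-token logits and/or the embeddings for an evaluated batch via a new `EvaluateOutputRequest` struct. A minimal usage sketch, not part of the diff itself: the `model`, `session`, `params`, and `tokens` values are assumed to come from the library's existing model-loading and session APIs.

```rust
// Sketch only: `model`, `session`, `params`, and `tokens` are assumed to be set
// up with the existing llama-rs APIs; only EvaluateOutputRequest and the new
// evaluate() parameter come from this commit.
let mut output_request = EvaluateOutputRequest {
    // Ask for the embeddings of the evaluated batch; leave logits alone.
    embeddings: Some(Vec::new()),
    ..Default::default()
};
model.evaluate(&mut session, &params, &tokens, &mut output_request);

if let Some(embeddings) = &output_request.embeddings {
    // One n_embd-sized embedding per token in the batch.
    println!("got {} embedding values", embeddings.len());
}
```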
54 changes: 51 additions & 3 deletions llama-rs/src/lib.rs
@@ -1,5 +1,6 @@
mod ggml;

use core::slice;
use std::{
    collections::HashMap,
    fmt::Display,
@@ -422,6 +423,17 @@ pub enum InferenceError {
    UserCallback(Box<dyn std::error::Error>),
}

/// Used in a call to `evaluate` to request information from the transformer.
#[derive(Default)]
pub struct EvaluateOutputRequest {
    /// Returns all the logits for the provided batch of tokens.
    /// Output shape is `n_batch * n_vocab`.
    pub all_logits: Option<Vec<f32>>,
    /// Returns the embeddings for the provided batch of tokens.
    /// Output shape is `n_batch * n_embd`.
    pub embeddings: Option<Vec<f32>>,
}
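
Because the struct derives `Default`, both fields start out as `None`, so `EvaluateOutputRequest::default()` asks for nothing extra; that is what the updated `InferenceSession` call sites below pass when they only need the normal last-token logits the session already keeps. A small sketch of the two modes (only the struct itself is from this diff; the rest is illustrative):

```rust
// Request nothing beyond the usual last-token logits.
let plain = EvaluateOutputRequest::default();
assert!(plain.all_logits.is_none() && plain.embeddings.is_none());

// Request both per-token logits and embeddings; evaluate() will resize and
// fill these vectors (n_batch * n_vocab and n_batch * n_embd values).
let full = EvaluateOutputRequest {
    all_logits: Some(Vec::new()),
    embeddings: Some(Vec::new()),
};
```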

/// NOTE: The original code relies on promotion rules and automatic casts from
/// int to float. What we do instead is use this macro to convert every term of
/// the multiplication to f64, which should have enough precision bits to hold
@@ -1094,11 +1106,17 @@ impl Model {
}

/// Evaluates the transformer.
///
/// The provided `output_request` struct lets you specify which additional
/// data you are interested in fetching from the transformer. Setting a
/// field to a `Some` value will clear and fill the provided vector with
/// data. The provided vector will be resized to the exact output size.
pub fn evaluate(
    &self,
    session: &mut InferenceSession,
    params: &InferenceParameters,
    input_tokens: &[TokenId],
    output_request: &mut EvaluateOutputRequest,
) {
    let n = input_tokens.len();
    let n_past = session.n_past as i32;
@@ -1317,12 +1335,16 @@ impl Model {
input_layer = current;
}

// Used at the end to optionally extract the embeddings.
let embeddings_tensor;

// norm
{
    input_layer = ctx0.op_norm(&input_layer);

    // inpL = norm*inpL
    input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer);
    embeddings_tensor = input_layer.share();
}

// lm_head
@@ -1347,6 +1369,28 @@ impl Model {
)
};

// Extract logits
if let Some(all_logits) = &mut output_request.all_logits {
    all_logits.resize(n_vocab as usize * n, 0.0);
    // SAFETY: Tensor data can be read (properly aligned, initialized,
    // data will not be mutated or otherwise aliased during the copy),
    // and we're not reading past the end of the tensor data.
    assert_eq!(input_layer.nelements(), n_vocab * n as i32);
    unsafe {
        input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits));
    }
}

// Extract embeddings
if let Some(embeddings) = &mut output_request.embeddings {
    embeddings.resize(n_embd as usize * n, 0.0);
    // SAFETY: The same rationale as in the "Extract logits" section applies.
    assert_eq!(embeddings_tensor.nelements(), n_embd * n as i32);
    unsafe {
        embeddings_tensor.read_data(0, bytemuck::cast_slice_mut(embeddings));
    }
}

// Adjust the required memory per token if we didn't know that already
if session.mem_per_token == 0 {
    session.mem_per_token = ctx0.used_mem() / n;
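
Both extraction branches above copy tensor data into the caller's `Vec<f32>` by viewing it as raw bytes: `bytemuck::cast_slice_mut` reinterprets `&mut [f32]` as the `&mut [u8]` destination that `Tensor::read_data` appears to expect here. A self-contained sketch of just that cast, with no ggml involved (assumes the `bytemuck` crate, which this file already depends on):

```rust
fn main() {
    // Reinterpret a &mut [f32] as &mut [u8], as the extraction code above does
    // before handing the buffer to Tensor::read_data. Safe because f32 and u8
    // are both Pod and the byte length divides evenly.
    let mut logits = vec![0.0f32; 4];
    let n_floats = logits.len();
    let bytes: &mut [u8] = bytemuck::cast_slice_mut(&mut logits);
    assert_eq!(bytes.len(), n_floats * std::mem::size_of::<f32>());
}
```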
@@ -1418,7 +1462,7 @@ impl InferenceSession {
}

for batch in prompt_tokens.chunks(8) {
model.evaluate(self, params, batch);
model.evaluate(self, params, batch, &mut EvaluateOutputRequest::default());
for &tk in batch {
// NOTE: No string ever tokenizes to the end of sentence. So we
// can just return the id here.
@@ -1452,7 +1496,12 @@ impl InferenceSession {
self.tokens.push(next_token);

// Then, evaluate the network again to compute the new last_logits
model.evaluate(self, params, &[next_token]);
model.evaluate(
    self,
    params,
    &[next_token],
    &mut EvaluateOutputRequest::default(),
);

// Return the next token
Ok(if next_token as TokenId == EOD_TOKEN_ID {
@@ -1533,7 +1582,6 @@ impl InferenceSession {
/// ggml context. While the provided `InferenceSnapshotRef` object is alive,
/// no other methods for this model object should be called.
pub unsafe fn get_snapshot(&mut self) -> InferenceSnapshotRef<'_> {
use core::slice;
let memory_k = unsafe {
slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes())
};
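
A hedged sketch of one downstream use the new `embeddings` output makes possible: comparing two prompts by cosine similarity over their `n_embd`-sized vectors. None of this is part of the commit, and which token's embedding to use (for example, the last one in the batch) is up to the caller.

```rust
/// Cosine similarity between two embedding vectors of equal length.
/// Purely illustrative; not part of the llama-rs API in this commit.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}
```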