From a067431773ba0194d6bc352faf40a770c1eac81c Mon Sep 17 00:00:00 2001 From: setzer22 Date: Sun, 26 Mar 2023 20:56:03 +0200 Subject: [PATCH] Embedding extraction (#72) * Add code to extract embeddings * Add ad_hoc_test for embeddings * Use Tensor::read_data * Adjust safety comment * Remove ad-hoc test * Clippy & fmt * Fix comment that ended halfway --------- Co-authored-by: Philpax --- llama-rs/src/lib.rs | 54 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index aa0a5baf..c718dda1 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -1,5 +1,6 @@ mod ggml; +use core::slice; use std::{ collections::HashMap, fmt::Display, @@ -422,6 +423,17 @@ pub enum InferenceError { UserCallback(Box), } +/// Used in a call to `evaluate` to request information from the transformer. +#[derive(Default)] +pub struct EvaluateOutputRequest { + /// Returns all the logits for the provided batch of tokens. + /// Output shape is n_batch * n_vocab + pub all_logits: Option>, + /// Returns the embeddings for the provided batch of tokens + /// Output shape is n_batch * n_embd + pub embeddings: Option>, +} + /// NOTE: The original code relies in promotion rules and automatic cast between /// int to float. What we do instead is use this macro to convert every term of /// the multiplication to f64, which should have enough precision bits to hold @@ -1094,11 +1106,17 @@ impl Model { } /// Evaluates the transformer. + /// + /// The provided `output_request` struct lets you specify which additional + /// data you are interested in fetching from the transformer. Setting a + /// field to a `Some` value will clear and fill the provided vector with + /// data. The provided vector will be resized to the exact output size. pub fn evaluate( &self, session: &mut InferenceSession, params: &InferenceParameters, input_tokens: &[TokenId], + output_request: &mut EvaluateOutputRequest, ) { let n = input_tokens.len(); let n_past = session.n_past as i32; @@ -1317,12 +1335,16 @@ impl Model { input_layer = current; } + // Used at the end to optionally extract the embeddings. + let embeddings_tensor; + // norm { input_layer = ctx0.op_norm(&input_layer); // inpL = norm*inpL input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer); + embeddings_tensor = input_layer.share(); } // lm_head @@ -1347,6 +1369,28 @@ impl Model { ) }; + // Extract logits + if let Some(all_logits) = &mut output_request.all_logits { + all_logits.resize(n_vocab as usize * n, 0.0); + // SAFETY: Tensor data can be read (properly aligned, initialized, + // data will not be mutated or otherwise aliased during the copy), + // and we're not reading past the end of the tensor data. + assert_eq!(input_layer.nelements(), n_vocab * n as i32); + unsafe { + input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits)); + } + } + + // Extract embeddings + if let Some(embeddings) = &mut output_request.embeddings { + embeddings.resize(n_embd as usize * n, 0.0); + // SAFETY: Same rationale as for the "Extract logits" section applies. + assert_eq!(embeddings_tensor.nelements(), n_embd * n as i32); + unsafe { + embeddings_tensor.read_data(0, bytemuck::cast_slice_mut(embeddings)); + } + } + // Adjust the required memory per token if we didn't know that already if session.mem_per_token == 0 { session.mem_per_token = ctx0.used_mem() / n; @@ -1418,7 +1462,7 @@ impl InferenceSession { } for batch in prompt_tokens.chunks(8) { - model.evaluate(self, params, batch); + model.evaluate(self, params, batch, &mut EvaluateOutputRequest::default()); for &tk in batch { // NOTE: No string ever tokenizes to the end of sentence. So we // can just return the id here. @@ -1452,7 +1496,12 @@ impl InferenceSession { self.tokens.push(next_token); // Then, evaluate the network again to compute the new last_logits - model.evaluate(self, params, &[next_token]); + model.evaluate( + self, + params, + &[next_token], + &mut EvaluateOutputRequest::default(), + ); // Return the next token Ok(if next_token as TokenId == EOD_TOKEN_ID { @@ -1533,7 +1582,6 @@ impl InferenceSession { /// ggml context. While the provided `InferenceSnapshotRef` object is alive, /// no other methods for this model object should be called. pub unsafe fn get_snapshot(&mut self) -> InferenceSnapshotRef<'_> { - use core::slice; let memory_k = unsafe { slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes()) };