From 542a410c238842c2d2788cce33bc950636f007ce Mon Sep 17 00:00:00 2001 From: Valentijn Hol Date: Sat, 2 Mar 2024 16:45:33 +0100 Subject: [PATCH 01/10] Add functionality for creating embeddings --- llama-cpp-2/src/context.rs | 32 ++++++++++++++-- llama-cpp-2/src/context/params.rs | 61 ++++++++++++++++++++++++++++++- llama-cpp-2/src/lib.rs | 9 +++++ llama-cpp-2/src/llama_backend.rs | 14 +++++++ llama-cpp-2/src/llama_batch.rs | 52 +++++++++++++++++++++++--- llama-cpp-2/src/model.rs | 28 +++++++------- 6 files changed, 174 insertions(+), 22 deletions(-) diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs index 2cccb13a..71397625 100644 --- a/llama-cpp-2/src/context.rs +++ b/llama-cpp-2/src/context.rs @@ -2,15 +2,15 @@ use std::fmt::{Debug, Formatter}; use std::num::NonZeroI32; +use std::ptr::NonNull; +use std::slice; +use crate::{DecodeError, EmbeddingsError}; use crate::llama_batch::LlamaBatch; use crate::model::LlamaModel; use crate::timing::LlamaTimings; use crate::token::data::LlamaTokenData; use crate::token::LlamaToken; -use crate::DecodeError; -use std::ptr::NonNull; -use std::slice; pub mod kv_cache; pub mod params; @@ -24,6 +24,7 @@ pub struct LlamaContext<'a> { /// a reference to the contexts model. pub model: &'a LlamaModel, initialized_logits: Vec, + embeddings_enabled: bool, } impl Debug for LlamaContext<'_> { @@ -38,11 +39,13 @@ impl<'model> LlamaContext<'model> { pub(crate) fn new( llama_model: &'model LlamaModel, llama_context: NonNull, + embeddings_enabled: bool, ) -> Self { Self { context: llama_context, model: llama_model, initialized_logits: Vec::new(), + embeddings_enabled, } } @@ -80,6 +83,29 @@ impl<'model> LlamaContext<'model> { } } + /// Get the embeddings for the `i`th sequence in the current context. + /// + /// # Returns + /// + /// A slice containing the embeddings for the last decoded batch. + /// The size corresponds to the `n_embd` parameter of the context's model. + /// + /// # Errors + /// + /// When the current context was constructed without enabling embeddings. + pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> { + if !self.embeddings_enabled { + return Err(EmbeddingsError::NotEnabled) + } + + unsafe { + Ok(std::slice::from_raw_parts( + llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i), + self.model.n_embd() as usize, + )) + } + } + /// Get the logits for the ith token in the context. /// /// # Panics diff --git a/llama-cpp-2/src/context/params.rs b/llama-cpp-2/src/context/params.rs index fba41c2d..60566cb5 100644 --- a/llama-cpp-2/src/context/params.rs +++ b/llama-cpp-2/src/context/params.rs @@ -1,8 +1,9 @@ //! A safe wrapper around `llama_context_params`. -use llama_cpp_sys_2; use std::fmt::Debug; use std::num::NonZeroU32; +use llama_cpp_sys_2; + /// A rusty wrapper around `rope_scaling_type`. #[repr(i8)] #[derive(Copy, Clone, Debug, PartialEq, Eq)] @@ -267,6 +268,19 @@ impl LlamaContextParams { self.context_params.n_threads } + /// Get the number of threads allocated for batches. + /// + /// # Examples + /// + /// ```rust + /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); + /// assert_eq!(params.n_threads_batch(), 4); + /// ``` + #[must_use] + pub fn n_threads_batch(&self) -> u32 { + self.context_params.n_threads_batch + } + /// Set the number of threads. /// /// # Examples @@ -282,6 +296,51 @@ impl LlamaContextParams { self.context_params.n_threads = n_threads; self } + + /// Set the number of threads allocated for batches. 
+ /// + /// # Examples + /// + /// ```rust + /// use llama_cpp_2::context::params::LlamaContextParams; + /// let params = LlamaContextParams::default() + /// .with_n_threads_batch(8); + /// assert_eq!(params.n_threads_batch(), 8); + /// ``` + #[must_use] + pub fn with_n_threads_batch(mut self, n_threads: u32) -> Self { + self.context_params.n_threads_batch = n_threads; + self + } + + /// Check whether embeddings are enabled + /// + /// # Examples + /// + /// ```rust + /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); + /// assert!(!params.embedding()); + /// ``` + #[must_use] + pub fn embedding(&self) -> bool { + self.context_params.embedding + } + + /// Enable the use of embeddings + /// + /// # Examples + /// + /// ```rust + /// use llama_cpp_2::context::params::LlamaContextParams; + /// let params = LlamaContextParams::default() + /// .with_embedding(true); + /// assert!(params.embedding()); + /// ``` + #[must_use] + pub fn with_embedding(mut self, embedding: bool) -> Self { + self.context_params.embedding = embedding; + self + } } /// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`) diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs index ab9efb64..c8ed8425 100644 --- a/llama-cpp-2/src/lib.rs +++ b/llama-cpp-2/src/lib.rs @@ -52,6 +52,8 @@ pub enum LLamaCppError { /// There was an error adding a token to a batch. #[error["{0}"]] BatchAddError(#[from] BatchAddError), + #[error(transparent)] + EmbeddingError(#[from] EmbeddingsError), } /// Failed to Load context @@ -76,6 +78,13 @@ pub enum DecodeError { Unknown(c_int), } +/// When embedding related functions fail +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum EmbeddingsError { + #[error("Embeddings weren't enabled in the context options")] + NotEnabled, +} + /// Decode a error from llama.cpp into a [`DecodeError`]. impl From for DecodeError { fn from(value: NonZeroI32) -> Self { diff --git a/llama-cpp-2/src/llama_backend.rs b/llama-cpp-2/src/llama_backend.rs index 59b4a39a..cfd45e10 100644 --- a/llama-cpp-2/src/llama_backend.rs +++ b/llama-cpp-2/src/llama_backend.rs @@ -3,6 +3,7 @@ use crate::LLamaCppError; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering::SeqCst; +use llama_cpp_sys_2::ggml_log_level; /// Representation of an initialized llama backend /// This is required as a parameter for most llama functions as the backend must be initialized @@ -68,6 +69,19 @@ impl LlamaBackend { } Ok(LlamaBackend {}) } + + /// Change the output of llama.cpp's logging to be voided instead of pushed to `stderr`. + pub fn void_logs(&mut self) { + unsafe extern "C" fn void_log( + _level: ggml_log_level, + _text: *const ::std::os::raw::c_char, + _user_data: *mut ::std::os::raw::c_void, + ) {} + + unsafe { + llama_cpp_sys_2::llama_log_set(Some(void_log), std::ptr::null_mut()) + } + } } /// A rusty wrapper around `numa_strategy`. diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs index 0748dd85..cb94c502 100644 --- a/llama-cpp-2/src/llama_batch.rs +++ b/llama-cpp-2/src/llama_batch.rs @@ -6,11 +6,11 @@ use llama_cpp_sys_2::{llama_batch, llama_batch_free, llama_batch_init, llama_pos /// A safe wrapper around `llama_batch`. #[derive(Debug)] pub struct LlamaBatch { - /// The number of tokens the batch was allocated with. they are safe to write to - but not necessarily read from as they are not necessarily initilized + /// The number of tokens the batch was allocated with. 
they are safe to write to - but not necessarily read from as they are not necessarily initialized allocated: usize, - /// The logits that are initilized. Used by [`LlamaContext`] to ensure that only initilized logits are accessed. + /// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed. pub(crate) initialized_logits: Vec, - /// The llama_cpp batch. always initilize by `llama_cpp_sys_2::llama_batch_init(allocated, , )` + /// The llama_cpp batch. always initialize by `llama_cpp_sys_2::llama_batch_init(allocated, , )` pub(crate) llama_batch: llama_batch, } @@ -31,7 +31,7 @@ impl LlamaBatch { } /// add a token to the batch for sequences [`seq_ids`] at position [pos]. If [logits] is true, the - /// token will be initilized and can be read from after the next decode. + /// token will be initialized and can be read from after the next decode. /// /// # Panics /// @@ -90,7 +90,49 @@ impl LlamaBatch { Ok(()) } - /// Create a new `LlamaBatch` that cab contain up to `n_tokens` tokens. + + /// Add a sequence of tokens to the batch for the given sequence id. If [logits_all] is true, the + /// tokens will be initialized and can be read from after the next decode. + /// + /// Either way the last token in the sequence will have its logits set to `true`. + /// + /// # Errors + /// + /// Returns an error if there is insufficient space in the buffer + pub fn add_sequence(&mut self, tokens: &[LlamaToken], + seq_id: i32, + logits_all: bool) -> Result<(), BatchAddError> { + let n_tokens_0 = self.llama_batch.n_tokens; + let n_tokens = tokens.len(); + + if self.allocated < n_tokens_0 as usize + n_tokens { + return Err(BatchAddError::InsufficientSpace(self.allocated)); + } + if n_tokens == 0 { + return Ok(()) + } + + self.llama_batch.n_tokens += n_tokens as i32; + for (i, token) in tokens.iter().enumerate() { + let j = n_tokens_0 as usize + i; + unsafe { + self.llama_batch.token.add(j).write(token.0); + self.llama_batch.pos.add(j).write(i as i32); + let seq_id_ptr = *self.llama_batch.seq_id.add(j); + seq_id_ptr.write(seq_id); + self.llama_batch.n_seq_id.add(j).write(1); + self.llama_batch.logits.add(j).write(logits_all as i8) + } + } + + unsafe { + self.llama_batch.logits.add(n_tokens - 1).write(true as i8); + } + + Ok(()) + } + + /// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens. /// /// # Arguments /// diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs index 6d2242dc..35b4a833 100644 --- a/llama-cpp-2/src/model.rs +++ b/llama-cpp-2/src/model.rs @@ -1,15 +1,16 @@ //! A safe wrapper around `llama_model`. -use crate::context::params::LlamaContextParams; +use std::ffi::CString; +use std::os::raw::c_int; +use std::path::Path; +use std::ptr::NonNull; + +use crate::{LlamaContextLoadError, LlamaModelLoadError, StringToTokenError, TokenToStringError}; use crate::context::LlamaContext; +use crate::context::params::LlamaContextParams; use crate::llama_backend::LlamaBackend; use crate::model::params::LlamaModelParams; use crate::token::LlamaToken; use crate::token_type::LlamaTokenType; -use crate::{LlamaContextLoadError, LlamaModelLoadError, StringToTokenError, TokenToStringError}; -use std::ffi::CString; -use std::os::raw::c_int; -use std::path::Path; -use std::ptr::NonNull; pub mod params; @@ -29,6 +30,7 @@ pub enum AddBos { /// Do not add the beginning of stream token to the start of the string. 
Never, } + unsafe impl Send for LlamaModel {} unsafe impl Sync for LlamaModel {} @@ -38,12 +40,12 @@ impl LlamaModel { /// /// # Panics /// - /// If the number of tokens the model was trained on does not fit into an `u16`. This should be impossible on most + /// If the number of tokens the model was trained on does not fit into an `u32`. This should be impossible on most /// platforms due to llama.cpp returning a `c_int` (i32 on most platforms) which is almost certainly positive. #[must_use] - pub fn n_ctx_train(&self) -> u16 { + pub fn n_ctx_train(&self) -> u32 { let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) }; - u16::try_from(n_ctx_train).expect("n_ctx_train fits into an u16") + u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32") } /// Get all tokens in the model. @@ -54,6 +56,7 @@ impl LlamaModel { .map(LlamaToken::new) .map(|llama_token| (llama_token, self.token_to_str(llama_token))) } + /// Get the beginning of stream token. #[must_use] pub fn token_bos(&self) -> LlamaToken { @@ -276,7 +279,7 @@ impl LlamaModel { /// # Errors /// /// See [`LlamaModelLoadError`] for more information. - #[tracing::instrument(skip_all)] + #[tracing::instrument(skip_all, fields(params))] pub fn load_from_file( _: &LlamaBackend, path: impl AsRef, @@ -290,13 +293,12 @@ impl LlamaModel { let cstr = CString::new(path)?; let llama_model = unsafe { - println!("{:?}", params.params); llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params) }; let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?; - println!("Loaded {path:?}"); + tracing::debug!(?path, "Loaded model"); Ok(LlamaModel { model }) } @@ -318,7 +320,7 @@ impl LlamaModel { }; let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?; - Ok(LlamaContext::new(self, context)) + Ok(LlamaContext::new(self, context, params.embedding())) } } From c0faaef3753965278c9874939098f31eb411b9b6 Mon Sep 17 00:00:00 2001 From: Valentijn Hol Date: Sat, 2 Mar 2024 17:23:29 +0100 Subject: [PATCH 02/10] Add en embeddings example --- Cargo.lock | 10 ++ Cargo.toml | 2 +- embeddings/Cargo.toml | 15 +++ embeddings/src/main.rs | 216 +++++++++++++++++++++++++++++++++ llama-cpp-2/src/llama_batch.rs | 1 + 5 files changed, 243 insertions(+), 1 deletion(-) create mode 100644 embeddings/Cargo.toml create mode 100644 embeddings/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 9f0c657d..68241a6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -445,6 +445,16 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "embeddings" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "hf-hub", + "llama-cpp-2", +] + [[package]] name = "encode_unicode" version = "0.3.6" diff --git a/Cargo.toml b/Cargo.toml index ec70b251..4cc4588f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ resolver = "2" members = [ "llama-cpp-sys-2", "llama-cpp-2", - "simple", + "simple", "embeddings", ] [workspace.dependencies] diff --git a/embeddings/Cargo.toml b/embeddings/Cargo.toml new file mode 100644 index 00000000..86f4e5aa --- /dev/null +++ b/embeddings/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "embeddings" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +llama-cpp-2 = { path = "../llama-cpp-2", version = "0.1.34" } +hf-hub = { workspace = true } 
+clap = { workspace = true , features = ["derive"] } +anyhow = { workspace = true } + +[lints] +workspace = true diff --git a/embeddings/src/main.rs b/embeddings/src/main.rs new file mode 100644 index 00000000..9cf00fa4 --- /dev/null +++ b/embeddings/src/main.rs @@ -0,0 +1,216 @@ +//! This is a translation of embedding.cpp in llama.cpp using llama-cpp-2. +#![allow( +clippy::cast_possible_wrap, +clippy::cast_possible_truncation, +clippy::cast_precision_loss, +clippy::cast_sign_loss +)] + +use std::io::Write; +use std::path::PathBuf; +use std::str::FromStr; +use std::time::Duration; + +use anyhow::{bail, Context, Result}; +use clap::Parser; +use hf_hub::api::sync::ApiBuilder; +use llama_cpp_2::context::LlamaContext; + +use llama_cpp_2::context::params::LlamaContextParams; +use llama_cpp_2::ggml_time_us; +use llama_cpp_2::llama_backend::LlamaBackend; +use llama_cpp_2::llama_batch::LlamaBatch; +use llama_cpp_2::model::AddBos; +use llama_cpp_2::model::LlamaModel; +use llama_cpp_2::model::params::LlamaModelParams; + +#[derive(clap::Parser, Debug, Clone)] +struct Args { + /// The path to the model + #[command(subcommand)] + model: Model, + /// The prompt + #[clap(default_value = "Hello my name is")] + prompt: String, + /// Whether to normalise the produced embeddings + #[clap(short)] + normalise: bool, + /// Disable offloading layers to the gpu + #[cfg(feature = "cublas")] + #[clap(long)] + disable_gpu: bool, +} + + +#[derive(clap::Subcommand, Debug, Clone)] +enum Model { + /// Use an already downloaded model + Local { + /// The path to the model. e.g. `/home/marcus/.cache/huggingface/hub/models--TheBloke--Llama-2-7B-Chat-GGUF/blobs/08a5566d61d7cb6b420c3e4387a39e0078e1f2fe5f055f3a03887385304d4bfa` + path: PathBuf, + }, + /// Download a model from huggingface (or use a cached version) + #[clap(name = "hf-model")] + HuggingFace { + /// the repo containing the model. e.g. `TheBloke/Llama-2-7B-Chat-GGUF` + repo: String, + /// the model name. e.g. `llama-2-7b-chat.Q4_K_M.gguf` + model: String, + }, +} + +impl Model { + /// Convert the model to a path - may download from huggingface + fn get_or_load(self) -> Result { + match self { + Model::Local { path } => Ok(path), + Model::HuggingFace { model, repo } => ApiBuilder::new() + .with_progress(true) + .build() + .with_context(|| "unable to create huggingface api")? 
+ .model(repo) + .get(&model) + .with_context(|| "unable to download model"), + } + } +} + +fn main() -> Result<()> { + let Args { + model, + prompt, + normalise, + #[cfg(feature = "cublas")] + disable_gpu, + } = Args::parse(); + + // init LLM + let backend = LlamaBackend::init()?; + + // offload all layers to the gpu + let model_params = { + #[cfg(feature = "cublas")] + if !disable_gpu { + LlamaModelParams::default().with_n_gpu_layers(1000) + } else { + LlamaModelParams::default() + } + #[cfg(not(feature = "cublas"))] + LlamaModelParams::default() + }; + + let model_path = model + .get_or_load() + .with_context(|| "failed to get model from args")?; + + let model = LlamaModel::load_from_file(&backend, model_path, &model_params) + .with_context(|| "unable to load model")?; + + // initialize the context + let ctx_params = LlamaContextParams::default() + .with_n_threads_batch(std::thread::available_parallelism()?.get() as u32) + .with_embedding(true); + + let mut ctx = model + .new_context(&backend, ctx_params) + .with_context(|| "unable to create the llama_context")?; + + // Split the prompt to display the batching functionality + let prompt_lines = prompt.lines(); + + // tokenize the prompt + let tokens_lines_list = prompt_lines.map(|line| model.str_to_token(&line, AddBos::Always)) + .collect::, _>>() + .with_context(|| format!("failed to tokenize {prompt}"))?; + + let n_ctx = ctx.n_ctx() as usize; + let n_ctx_train = model.n_ctx_train(); + + eprintln!("n_ctx = {n_ctx}, n_ctx_train = {n_ctx_train}"); + + if tokens_lines_list.iter().any(|tok| n_ctx < tok.len()) { + bail!("One of the provided prompts exceeds the size of the context window"); + } + + // print the prompt token-by-token + eprintln!(); + + for (i, token_line) in tokens_lines_list.iter().enumerate() { + eprintln!("Prompt {i}"); + for token in token_line { + eprintln!(" {} --> {}", token, model.token_to_str(*token)?); + } + eprintln!() + } + + std::io::stderr().flush()?; + + // create a llama_batch with the size of the context + // we use this object to submit token data for decoding + let mut batch = LlamaBatch::new(n_ctx, tokens_lines_list.len() as i32); + + // Amount of tokens in the current batch + let mut s_batch = 0; + let mut output = Vec::with_capacity(tokens_lines_list.len()); + + let t_main_start = ggml_time_us(); + + for tokens in &tokens_lines_list { + // Flush the batch if the next prompt would exceed our batch size + if (batch.n_tokens() as usize + tokens.len()) > n_ctx { + batch_decode(&mut ctx, &mut batch, s_batch, &mut output, normalise)?; + s_batch = 0; + } + + batch.add_sequence(&tokens, s_batch, false)?; + s_batch += 1; + } + // Handle final batch + batch_decode(&mut ctx, &mut batch, s_batch, &mut output, normalise)?; + + let t_main_end = ggml_time_us(); + + for (i, embeddings) in output.iter().enumerate() { + eprintln!("Embeddings {i}: {embeddings:?}"); + eprintln!("\n"); + } + + let duration = Duration::from_micros((t_main_end - t_main_start) as u64); + let total_tokens: usize = tokens_lines_list.iter().map(|v| v.len()).sum(); + + eprintln!( + "Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n", + total_tokens, + duration.as_secs_f32(), + total_tokens as f32 / duration.as_secs_f32() + ); + + println!("{}", ctx.timings()); + + Ok(()) +} + +fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, output: &mut Vec>, normalise: bool) -> Result<()> { + ctx.clear_kv_cache(); + ctx.decode(batch).with_context(|| "llama_decode() failed")?; + batch.clear(); + + for i in 0..s_batch { + 
let embedding = ctx.embeddings_ith(i).with_context(|| "Failed to get embeddings")?; + let output_embeddings = if normalise { + normalize(embedding) + } else { + embedding.to_vec() + }; + + output.push(output_embeddings); + } + + Ok(()) +} + +fn normalize(input: &[f32]) -> Vec { + let magnitude = input.iter().fold(0.0, |acc, &val| val.mul_add(val, acc)).sqrt(); + + input.iter().map(|&val| val / magnitude).collect() +} diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs index cb94c502..4baaa0bb 100644 --- a/llama-cpp-2/src/llama_batch.rs +++ b/llama-cpp-2/src/llama_batch.rs @@ -127,6 +127,7 @@ impl LlamaBatch { unsafe { self.llama_batch.logits.add(n_tokens - 1).write(true as i8); + self.initialized_logits.push(self.llama_batch.n_tokens - 1); } Ok(()) From 32b53edf5162149b84548c1002efb3c687d58038 Mon Sep 17 00:00:00 2001 From: Valentijn Hol Date: Tue, 5 Mar 2024 18:00:14 +0100 Subject: [PATCH 03/10] Update for the (hopefully stable!) `llama.cpp` changes. --- embeddings/src/main.rs | 7 ++-- llama-cpp-2/src/context.rs | 53 +++++++++++++++++++++++++++---- llama-cpp-2/src/context/params.rs | 14 ++++---- llama-cpp-2/src/lib.rs | 4 +++ llama-cpp-2/src/llama_batch.rs | 9 +++--- llama-cpp-2/src/model.rs | 2 +- 6 files changed, 66 insertions(+), 23 deletions(-) diff --git a/embeddings/src/main.rs b/embeddings/src/main.rs index 9cf00fa4..31371378 100644 --- a/embeddings/src/main.rs +++ b/embeddings/src/main.rs @@ -109,7 +109,7 @@ fn main() -> Result<()> { // initialize the context let ctx_params = LlamaContextParams::default() .with_n_threads_batch(std::thread::available_parallelism()?.get() as u32) - .with_embedding(true); + .with_embeddings(true); let mut ctx = model .new_context(&backend, ctx_params) @@ -193,10 +193,9 @@ fn main() -> Result<()> { fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, output: &mut Vec>, normalise: bool) -> Result<()> { ctx.clear_kv_cache(); ctx.decode(batch).with_context(|| "llama_decode() failed")?; - batch.clear(); for i in 0..s_batch { - let embedding = ctx.embeddings_ith(i).with_context(|| "Failed to get embeddings")?; + let embedding = ctx.embeddings_seq_ith(i).with_context(|| "Failed to get embeddings")?; let output_embeddings = if normalise { normalize(embedding) } else { @@ -206,6 +205,8 @@ fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, ou output.push(output_embeddings); } + batch.clear(); + Ok(()) } diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs index 71397625..080e8fc6 100644 --- a/llama-cpp-2/src/context.rs +++ b/llama-cpp-2/src/context.rs @@ -5,12 +5,12 @@ use std::num::NonZeroI32; use std::ptr::NonNull; use std::slice; -use crate::{DecodeError, EmbeddingsError}; use crate::llama_batch::LlamaBatch; use crate::model::LlamaModel; use crate::timing::LlamaTimings; use crate::token::data::LlamaTokenData; use crate::token::LlamaToken; +use crate::{DecodeError, EmbeddingsError}; pub mod kv_cache; pub mod params; @@ -92,17 +92,51 @@ impl<'model> LlamaContext<'model> { /// /// # Errors /// - /// When the current context was constructed without enabling embeddings. + /// - When the current context was constructed without enabling embeddings. + /// - If the current model had a pooling type of [`llama_cpp_sys_2::LLAMA_POOLING_TYPE_NONE`] + /// - If the given sequence index exceeds the max sequence id. 
+ pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> { + if !self.embeddings_enabled { + return Err(EmbeddingsError::NotEnabled); + } + + unsafe { + let embedding = llama_cpp_sys_2::llama_get_embeddings_seq(self.context.as_ptr(), i); + + // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here. + if embedding.is_null() { + Err(EmbeddingsError::NonePoolType) + } else { + Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize)) + } + } + } + + /// Get the embeddings for the `i`th token in the current context. + /// + /// # Returns + /// + /// A slice containing the embeddings for the last decoded batch of the given token. + /// The size corresponds to the `n_embd` parameter of the context's model. + /// + /// # Errors + /// + /// - When the current context was constructed without enabling embeddings. + /// - When the given token didn't have logits enabled when it was passed. + /// - If the given token index exceeds the max token id. pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> { if !self.embeddings_enabled { - return Err(EmbeddingsError::NotEnabled) + return Err(EmbeddingsError::NotEnabled); } unsafe { - Ok(std::slice::from_raw_parts( - llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i), - self.model.n_embd() as usize, - )) + let embedding = llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i); + // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here. + if embedding.is_null() { + Err(EmbeddingsError::LogitsNotEnabled) + } else { + Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize)) + } } } @@ -155,6 +189,11 @@ impl<'model> LlamaContext<'model> { let timings = unsafe { llama_cpp_sys_2::llama_get_timings(self.context.as_ptr()) }; LlamaTimings { timings } } + + /// Returns a reference to the raw [llama_cpp_sys_2::llama_context] pointer. 
+ pub fn raw_ctx(&self) -> &NonNull { + &self.context + } } impl Drop for LlamaContext<'_> { diff --git a/llama-cpp-2/src/context/params.rs b/llama-cpp-2/src/context/params.rs index 60566cb5..edf3b709 100644 --- a/llama-cpp-2/src/context/params.rs +++ b/llama-cpp-2/src/context/params.rs @@ -319,11 +319,11 @@ impl LlamaContextParams { /// /// ```rust /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); - /// assert!(!params.embedding()); + /// assert!(!params.embeddings()); /// ``` #[must_use] - pub fn embedding(&self) -> bool { - self.context_params.embedding + pub fn embeddings(&self) -> bool { + self.context_params.embeddings } /// Enable the use of embeddings @@ -333,12 +333,12 @@ impl LlamaContextParams { /// ```rust /// use llama_cpp_2::context::params::LlamaContextParams; /// let params = LlamaContextParams::default() - /// .with_embedding(true); - /// assert!(params.embedding()); + /// .with_embeddings(true); + /// assert!(params.embeddings()); /// ``` #[must_use] - pub fn with_embedding(mut self, embedding: bool) -> Self { - self.context_params.embedding = embedding; + pub fn with_embeddings(mut self, embedding: bool) -> Self { + self.context_params.embeddings = embedding; self } } diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs index c8ed8425..b5c1b7d9 100644 --- a/llama-cpp-2/src/lib.rs +++ b/llama-cpp-2/src/lib.rs @@ -83,6 +83,10 @@ pub enum DecodeError { pub enum EmbeddingsError { #[error("Embeddings weren't enabled in the context options")] NotEnabled, + #[error("Logits were not enabled for the given token")] + LogitsNotEnabled, + #[error("Can't use sequence embeddings with a model supporting only LLAMA_POOLING_TYPE_NONE")] + NonePoolType, } /// Decode a error from llama.cpp into a [`DecodeError`]. 
diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs index 4baaa0bb..493089be 100644 --- a/llama-cpp-2/src/llama_batch.rs +++ b/llama-cpp-2/src/llama_batch.rs @@ -121,14 +121,13 @@ impl LlamaBatch { let seq_id_ptr = *self.llama_batch.seq_id.add(j); seq_id_ptr.write(seq_id); self.llama_batch.n_seq_id.add(j).write(1); - self.llama_batch.logits.add(j).write(logits_all as i8) + + let write_logits = logits_all || i == n_tokens - 1; + self.llama_batch.logits.add(j).write(write_logits as i8) } } - unsafe { - self.llama_batch.logits.add(n_tokens - 1).write(true as i8); - self.initialized_logits.push(self.llama_batch.n_tokens - 1); - } + self.initialized_logits.push(self.llama_batch.n_tokens - 1); Ok(()) } diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs index 35b4a833..e9709e95 100644 --- a/llama-cpp-2/src/model.rs +++ b/llama-cpp-2/src/model.rs @@ -320,7 +320,7 @@ impl LlamaModel { }; let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?; - Ok(LlamaContext::new(self, context, params.embedding())) + Ok(LlamaContext::new(self, context, params.embeddings())) } } From 78601ecc3554a49ee7848f2a083b98fc049054ab Mon Sep 17 00:00:00 2001 From: Valentijn Hol Date: Tue, 5 Mar 2024 18:05:51 +0100 Subject: [PATCH 04/10] Remove accidental escape hatch --- llama-cpp-2/src/context.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs index 080e8fc6..85066718 100644 --- a/llama-cpp-2/src/context.rs +++ b/llama-cpp-2/src/context.rs @@ -189,11 +189,6 @@ impl<'model> LlamaContext<'model> { let timings = unsafe { llama_cpp_sys_2::llama_get_timings(self.context.as_ptr()) }; LlamaTimings { timings } } - - /// Returns a reference to the raw [llama_cpp_sys_2::llama_context] pointer. - pub fn raw_ctx(&self) -> &NonNull { - &self.context - } } impl Drop for LlamaContext<'_> { From 12a3f8bf9c22b2314a0cb9b4e8820450a9d22d43 Mon Sep 17 00:00:00 2001 From: Valentijn Hol Date: Tue, 5 Mar 2024 18:09:07 +0100 Subject: [PATCH 05/10] Properly assign the initialised logits tracking (although nothing seems to be done with them atm) --- llama-cpp-2/src/llama_batch.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs index 493089be..eb837bd7 100644 --- a/llama-cpp-2/src/llama_batch.rs +++ b/llama-cpp-2/src/llama_batch.rs @@ -123,12 +123,13 @@ impl LlamaBatch { self.llama_batch.n_seq_id.add(j).write(1); let write_logits = logits_all || i == n_tokens - 1; - self.llama_batch.logits.add(j).write(write_logits as i8) + self.llama_batch.logits.add(j).write(write_logits as i8); + if write_logits { + self.initialized_logits.push(j as i32); + } } } - self.initialized_logits.push(self.llama_batch.n_tokens - 1); - Ok(()) } From 5bce968e943bb23ac8f6f31e17dd329e4efd1eb0 Mon Sep 17 00:00:00 2001 From: Valentijn Hol Date: Tue, 5 Mar 2024 18:13:41 +0100 Subject: [PATCH 06/10] Update the example to use better reference models and allocate less memory for sequence ids --- embeddings/src/main.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/embeddings/src/main.rs b/embeddings/src/main.rs index 31371378..d52dc9a5 100644 --- a/embeddings/src/main.rs +++ b/embeddings/src/main.rs @@ -52,9 +52,9 @@ enum Model { /// Download a model from huggingface (or use a cached version) #[clap(name = "hf-model")] HuggingFace { - /// the repo containing the model. e.g. `TheBloke/Llama-2-7B-Chat-GGUF` + /// the repo containing the model. e.g. 
`BAAI/bge-small-en-v1.5` repo: String, - /// the model name. e.g. `llama-2-7b-chat.Q4_K_M.gguf` + /// the model name. e.g. `BAAI-bge-small-v1.5.Q4_K_M.gguf` model: String, }, } @@ -147,7 +147,7 @@ fn main() -> Result<()> { // create a llama_batch with the size of the context // we use this object to submit token data for decoding - let mut batch = LlamaBatch::new(n_ctx, tokens_lines_list.len() as i32); + let mut batch = LlamaBatch::new(n_ctx, 1); // Amount of tokens in the current batch let mut s_batch = 0; From fa7b508132e4c99980e7b6a488494549bce03726 Mon Sep 17 00:00:00 2001 From: Valentijn Hol Date: Tue, 5 Mar 2024 18:18:08 +0100 Subject: [PATCH 07/10] Fix doctest --- llama-cpp-2/src/token_type.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama-cpp-2/src/token_type.rs b/llama-cpp-2/src/token_type.rs index 44c2dbd3..06600687 100644 --- a/llama-cpp-2/src/token_type.rs +++ b/llama-cpp-2/src/token_type.rs @@ -28,15 +28,15 @@ pub enum LlamaTokenType { /// /// ``` /// # use std::convert::TryFrom; -/// # use std::ffi::c_uint; +/// # use std::ffi::c_int; /// # use std::num::TryFromIntError; /// # use std::result::Result; /// # use llama_cpp_2::token_type::{LlamaTokenTypeFromIntError, LlamaTokenType}; /// # fn main() -> Result<(), LlamaTokenTypeFromIntError> { -/// let llama_token_type = LlamaTokenType::try_from(0 as c_uint)?; +/// let llama_token_type = LlamaTokenType::try_from(0 as c_int)?; /// assert_eq!(llama_token_type, LlamaTokenType::Undefined); /// -/// let bad_llama_token_type = LlamaTokenType::try_from(100 as c_uint); +/// let bad_llama_token_type = LlamaTokenType::try_from(100 as c_int); /// assert_eq!(Err(LlamaTokenTypeFromIntError::UnknownValue(100)), bad_llama_token_type); /// # Ok(()) /// # } From 9126146f8b31f307db9293b30241ba3058c304e9 Mon Sep 17 00:00:00 2001 From: Valentijn Hol Date: Tue, 5 Mar 2024 18:22:20 +0100 Subject: [PATCH 08/10] Use a better name for the `s_batch` variable and remove excessive whitespace --- embeddings/src/main.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/embeddings/src/main.rs b/embeddings/src/main.rs index d52dc9a5..de9d33ce 100644 --- a/embeddings/src/main.rs +++ b/embeddings/src/main.rs @@ -14,8 +14,8 @@ use std::time::Duration; use anyhow::{bail, Context, Result}; use clap::Parser; use hf_hub::api::sync::ApiBuilder; -use llama_cpp_2::context::LlamaContext; +use llama_cpp_2::context::LlamaContext; use llama_cpp_2::context::params::LlamaContextParams; use llama_cpp_2::ggml_time_us; use llama_cpp_2::llama_backend::LlamaBackend; @@ -149,8 +149,7 @@ fn main() -> Result<()> { // we use this object to submit token data for decoding let mut batch = LlamaBatch::new(n_ctx, 1); - // Amount of tokens in the current batch - let mut s_batch = 0; + let mut max_seq_id_batch = 0; let mut output = Vec::with_capacity(tokens_lines_list.len()); let t_main_start = ggml_time_us(); @@ -158,26 +157,25 @@ fn main() -> Result<()> { for tokens in &tokens_lines_list { // Flush the batch if the next prompt would exceed our batch size if (batch.n_tokens() as usize + tokens.len()) > n_ctx { - batch_decode(&mut ctx, &mut batch, s_batch, &mut output, normalise)?; - s_batch = 0; + batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?; + max_seq_id_batch = 0; } - batch.add_sequence(&tokens, s_batch, false)?; - s_batch += 1; + batch.add_sequence(&tokens, max_seq_id_batch, false)?; + max_seq_id_batch += 1; } // Handle final batch - batch_decode(&mut ctx, &mut batch, s_batch, &mut output, 
normalise)?; + batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?; let t_main_end = ggml_time_us(); for (i, embeddings) in output.iter().enumerate() { eprintln!("Embeddings {i}: {embeddings:?}"); - eprintln!("\n"); + eprintln!(); } let duration = Duration::from_micros((t_main_end - t_main_start) as u64); let total_tokens: usize = tokens_lines_list.iter().map(|v| v.len()).sum(); - eprintln!( "Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n", total_tokens, From 0b0e850a5f88a4e73a45cdc253ace918c4f40046 Mon Sep 17 00:00:00 2001 From: Valentijn Hol Date: Tue, 5 Mar 2024 19:45:52 +0100 Subject: [PATCH 09/10] Actually fix the doc-test --- llama-cpp-2/src/token_type.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama-cpp-2/src/token_type.rs b/llama-cpp-2/src/token_type.rs index 06600687..35c441a9 100644 --- a/llama-cpp-2/src/token_type.rs +++ b/llama-cpp-2/src/token_type.rs @@ -33,10 +33,10 @@ pub enum LlamaTokenType { /// # use std::result::Result; /// # use llama_cpp_2::token_type::{LlamaTokenTypeFromIntError, LlamaTokenType}; /// # fn main() -> Result<(), LlamaTokenTypeFromIntError> { -/// let llama_token_type = LlamaTokenType::try_from(0 as c_int)?; +/// let llama_token_type = LlamaTokenType::try_from(0 as llama_cpp_sys_2::llama_token_type)?; /// assert_eq!(llama_token_type, LlamaTokenType::Undefined); /// -/// let bad_llama_token_type = LlamaTokenType::try_from(100 as c_int); +/// let bad_llama_token_type = LlamaTokenType::try_from(100 as llama_cpp_sys_2::llama_token_type); /// assert_eq!(Err(LlamaTokenTypeFromIntError::UnknownValue(100)), bad_llama_token_type); /// # Ok(()) /// # } From ccb434be92b6ba6e3eba65eb20759d857ba98a75 Mon Sep 17 00:00:00 2001 From: Valentijn Hol Date: Tue, 5 Mar 2024 19:52:50 +0100 Subject: [PATCH 10/10] Swap out unsafe for a nested `LLamaBatch::add` call --- llama-cpp-2/src/llama_batch.rs | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs index eb837bd7..1f1ecc54 100644 --- a/llama-cpp-2/src/llama_batch.rs +++ b/llama-cpp-2/src/llama_batch.rs @@ -108,26 +108,9 @@ impl LlamaBatch { if self.allocated < n_tokens_0 as usize + n_tokens { return Err(BatchAddError::InsufficientSpace(self.allocated)); } - if n_tokens == 0 { - return Ok(()) - } - self.llama_batch.n_tokens += n_tokens as i32; for (i, token) in tokens.iter().enumerate() { - let j = n_tokens_0 as usize + i; - unsafe { - self.llama_batch.token.add(j).write(token.0); - self.llama_batch.pos.add(j).write(i as i32); - let seq_id_ptr = *self.llama_batch.seq_id.add(j); - seq_id_ptr.write(seq_id); - self.llama_batch.n_seq_id.add(j).write(1); - - let write_logits = logits_all || i == n_tokens - 1; - self.llama_batch.logits.add(j).write(write_logits as i8); - if write_logits { - self.initialized_logits.push(j as i32); - } - } + self.add(*token, i as llama_pos, &[seq_id], logits_all || i == n_tokens - 1)?; } Ok(())
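
For reference, a minimal end-to-end sketch of the embeddings API this series adds, condensed from the `embeddings` example introduced in PATCH 02. The model filename below is a placeholder; any GGUF embedding model with a pooling type (e.g. a BGE variant) should work, since `embeddings_seq_ith` returns `EmbeddingsError::NonePoolType` for models without pooled sequence embeddings.

```rust
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel};

fn main() -> anyhow::Result<()> {
    let backend = LlamaBackend::init()?;
    let model = LlamaModel::load_from_file(
        &backend,
        "bge-small-en-v1.5.Q4_K_M.gguf", // placeholder path to a local GGUF model
        &LlamaModelParams::default(),
    )?;

    // Embeddings must be enabled on the context, otherwise
    // `embeddings_seq_ith` returns `EmbeddingsError::NotEnabled`.
    let mut ctx = model.new_context(
        &backend,
        LlamaContextParams::default().with_embeddings(true),
    )?;

    // Tokenize one prompt and submit it as sequence 0.
    let tokens = model.str_to_token("Hello my name is", AddBos::Always)?;
    let mut batch = LlamaBatch::new(ctx.n_ctx() as usize, 1);
    batch.add_sequence(&tokens, 0, false)?;

    // Decode, then read back the pooled embedding for sequence 0.
    ctx.decode(&mut batch)?;
    let embedding: &[f32] = ctx.embeddings_seq_ith(0)?;
    println!("n_embd = {}", embedding.len());
    Ok(())
}
```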