Add Embedding Related Functionality #133

Merged
11 commits merged on Mar 5, 2024
10 changes: 10 additions & 0 deletions Cargo.lock


2 changes: 1 addition & 1 deletion Cargo.toml
@@ -3,7 +3,7 @@ resolver = "2"
members = [
"llama-cpp-sys-2",
"llama-cpp-2",
"simple",
"simple", "embeddings",
]

[workspace.dependencies]
15 changes: 15 additions & 0 deletions embeddings/Cargo.toml
@@ -0,0 +1,15 @@
[package]
name = "embeddings"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
llama-cpp-2 = { path = "../llama-cpp-2", version = "0.1.34" }
hf-hub = { workspace = true }
clap = { workspace = true , features = ["derive"] }
anyhow = { workspace = true }

[lints]
workspace = true
215 changes: 215 additions & 0 deletions embeddings/src/main.rs
@@ -0,0 +1,215 @@
//! This is a translation of embedding.cpp in llama.cpp using llama-cpp-2.
#![allow(
clippy::cast_possible_wrap,
clippy::cast_possible_truncation,
clippy::cast_precision_loss,
clippy::cast_sign_loss
)]

use std::io::Write;
use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;

use anyhow::{bail, Context, Result};
use clap::Parser;
use hf_hub::api::sync::ApiBuilder;

use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::ggml_time_us;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::AddBos;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::params::LlamaModelParams;

#[derive(clap::Parser, Debug, Clone)]
struct Args {
/// The path to the model
#[command(subcommand)]
model: Model,
/// The prompt
#[clap(default_value = "Hello my name is")]
prompt: String,
/// Whether to normalise the produced embeddings
#[clap(short)]
normalise: bool,
/// Disable offloading layers to the gpu
#[cfg(feature = "cublas")]
#[clap(long)]
disable_gpu: bool,
}


#[derive(clap::Subcommand, Debug, Clone)]
enum Model {
/// Use an already downloaded model
Local {
/// The path to the model. e.g. `/home/marcus/.cache/huggingface/hub/models--TheBloke--Llama-2-7B-Chat-GGUF/blobs/08a5566d61d7cb6b420c3e4387a39e0078e1f2fe5f055f3a03887385304d4bfa`
path: PathBuf,
},
/// Download a model from huggingface (or use a cached version)
#[clap(name = "hf-model")]
HuggingFace {
/// the repo containing the model. e.g. `BAAI/bge-small-en-v1.5`
repo: String,
/// the model name. e.g. `BAAI-bge-small-v1.5.Q4_K_M.gguf`
model: String,
},
}

impl Model {
/// Convert the model to a path - may download from huggingface
fn get_or_load(self) -> Result<PathBuf> {
match self {
Model::Local { path } => Ok(path),
Model::HuggingFace { model, repo } => ApiBuilder::new()
.with_progress(true)
.build()
.with_context(|| "unable to create huggingface api")?
.model(repo)
.get(&model)
.with_context(|| "unable to download model"),
}
}
}

fn main() -> Result<()> {
let Args {
model,
prompt,
normalise,
#[cfg(feature = "cublas")]
disable_gpu,
} = Args::parse();

// init LLM
let backend = LlamaBackend::init()?;

// offload all layers to the gpu
let model_params = {
#[cfg(feature = "cublas")]
if !disable_gpu {
LlamaModelParams::default().with_n_gpu_layers(1000)
} else {
LlamaModelParams::default()
}
#[cfg(not(feature = "cublas"))]
LlamaModelParams::default()
};

let model_path = model
.get_or_load()
.with_context(|| "failed to get model from args")?;

let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
.with_context(|| "unable to load model")?;

// initialize the context
let ctx_params = LlamaContextParams::default()
.with_n_threads_batch(std::thread::available_parallelism()?.get() as u32)
.with_embeddings(true);

let mut ctx = model
.new_context(&backend, ctx_params)
.with_context(|| "unable to create the llama_context")?;

// Split the prompt to display the batching functionality
let prompt_lines = prompt.lines();

// tokenize the prompt
let tokens_lines_list = prompt_lines.map(|line| model.str_to_token(line, AddBos::Always))
.collect::<Result<Vec<_>, _>>()
.with_context(|| format!("failed to tokenize {prompt}"))?;

let n_ctx = ctx.n_ctx() as usize;
let n_ctx_train = model.n_ctx_train();

eprintln!("n_ctx = {n_ctx}, n_ctx_train = {n_ctx_train}");

if tokens_lines_list.iter().any(|tok| n_ctx < tok.len()) {
bail!("One of the provided prompts exceeds the size of the context window");
}

// print the prompt token-by-token
eprintln!();

for (i, token_line) in tokens_lines_list.iter().enumerate() {
eprintln!("Prompt {i}");
for token in token_line {
eprintln!(" {} --> {}", token, model.token_to_str(*token)?);
}
eprintln!()
}

std::io::stderr().flush()?;

// create a llama_batch with the size of the context
// we use this object to submit token data for decoding
let mut batch = LlamaBatch::new(n_ctx, 1);

let mut max_seq_id_batch = 0;
let mut output = Vec::with_capacity(tokens_lines_list.len());

let t_main_start = ggml_time_us();

for tokens in &tokens_lines_list {
// Flush the batch if the next prompt would exceed our batch size
if (batch.n_tokens() as usize + tokens.len()) > n_ctx {
batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?;
max_seq_id_batch = 0;
}

batch.add_sequence(&tokens, max_seq_id_batch, false)?;
max_seq_id_batch += 1;
}
// Handle final batch
batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?;

let t_main_end = ggml_time_us();

for (i, embeddings) in output.iter().enumerate() {
eprintln!("Embeddings {i}: {embeddings:?}");
eprintln!();
}

let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
let total_tokens: usize = tokens_lines_list.iter().map(|v| v.len()).sum();
eprintln!(
"Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n",
total_tokens,
duration.as_secs_f32(),
total_tokens as f32 / duration.as_secs_f32()
);

println!("{}", ctx.timings());

Ok(())
}

fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, output: &mut Vec<Vec<f32>>, normalise: bool) -> Result<()> {
ctx.clear_kv_cache();
ctx.decode(batch).with_context(|| "llama_decode() failed")?;

for i in 0..s_batch {
let embedding = ctx.embeddings_seq_ith(i).with_context(|| "Failed to get embeddings")?;
let output_embeddings = if normalise {
normalize(embedding)
} else {
embedding.to_vec()
};

output.push(output_embeddings);
}

batch.clear();

Ok(())
}

fn normalize(input: &[f32]) -> Vec<f32> {
let magnitude = input.iter().fold(0.0, |acc, &val| val.mul_add(val, acc)).sqrt();

input.iter().map(|&val| val / magnitude).collect()
}
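
The `normalize` helper performs L2 normalisation: each component is divided by the vector's Euclidean magnitude, so the returned embedding has unit length and the dot product of two normalised embeddings equals their cosine similarity. The sketch below illustrates that follow-on use; the `cosine_similarity` helper and the toy vectors are illustrative only and not part of this PR.

// Illustrative sketch (not part of this diff): with L2-normalised embeddings,
// cosine similarity reduces to a plain dot product.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "embeddings must have the same dimension");
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

// Same L2 normalisation as the example's `normalize`, repeated here so the sketch compiles on its own.
fn normalize(input: &[f32]) -> Vec<f32> {
    let magnitude = input.iter().map(|&v| v * v).sum::<f32>().sqrt();
    input.iter().map(|&v| v / magnitude).collect()
}

fn main() {
    // Two toy vectors standing in for produced embeddings.
    let a = normalize(&[1.0, 2.0, 2.0]);
    let b = normalize(&[2.0, 3.0, 4.0]);
    // Unit-length vectors, so this dot product is the cosine similarity (close to 1.0 here).
    println!("cosine similarity = {:.4}", cosine_similarity(&a, &b));
}
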
66 changes: 63 additions & 3 deletions llama-cpp-2/src/context.rs
@@ -2,15 +2,15 @@

use std::fmt::{Debug, Formatter};
use std::num::NonZeroI32;
use std::ptr::NonNull;
use std::slice;

use crate::llama_batch::LlamaBatch;
use crate::model::LlamaModel;
use crate::timing::LlamaTimings;
use crate::token::data::LlamaTokenData;
use crate::token::LlamaToken;
use crate::DecodeError;
use std::ptr::NonNull;
use std::slice;
use crate::{DecodeError, EmbeddingsError};

pub mod kv_cache;
pub mod params;
@@ -24,6 +24,7 @@ pub struct LlamaContext<'a> {
/// a reference to the contexts model.
pub model: &'a LlamaModel,
initialized_logits: Vec<i32>,
embeddings_enabled: bool,
}

impl Debug for LlamaContext<'_> {
@@ -38,11 +39,13 @@ impl<'model> LlamaContext<'model> {
pub(crate) fn new(
llama_model: &'model LlamaModel,
llama_context: NonNull<llama_cpp_sys_2::llama_context>,
embeddings_enabled: bool,
) -> Self {
Self {
context: llama_context,
model: llama_model,
initialized_logits: Vec::new(),
embeddings_enabled,
}
}

@@ -80,6 +83,63 @@ impl<'model> LlamaContext<'model> {
}
}

/// Get the embeddings for the `i`th sequence in the current context.
///
/// # Returns
///
/// A slice containing the embeddings for the `i`th sequence in the last decoded batch.
/// The size corresponds to the `n_embd` parameter of the context's model.
///
/// # Errors
///
/// - If the current context was constructed without embeddings enabled.
/// - If the current model has a pooling type of [`llama_cpp_sys_2::LLAMA_POOLING_TYPE_NONE`].
/// - If the given sequence index exceeds the maximum sequence id.
pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
if !self.embeddings_enabled {
return Err(EmbeddingsError::NotEnabled);
}

unsafe {
let embedding = llama_cpp_sys_2::llama_get_embeddings_seq(self.context.as_ptr(), i);

// Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
if embedding.is_null() {
Err(EmbeddingsError::NonePoolType)
} else {
Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
}
}
}

/// Get the embeddings for the `i`th token in the current context.
///
/// # Returns
///
/// A slice containing the embeddings for the given token in the last decoded batch.
/// The size corresponds to the `n_embd` parameter of the context's model.
///
/// # Errors
///
/// - If the current context was constructed without embeddings enabled.
/// - If the given token did not have logits enabled when it was passed.
/// - If the given token index exceeds the maximum token id.
pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
if !self.embeddings_enabled {
return Err(EmbeddingsError::NotEnabled);
}

unsafe {
let embedding = llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i);
// Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
if embedding.is_null() {
Err(EmbeddingsError::LogitsNotEnabled)
} else {
Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
}
}
}

/// Get the logits for the ith token in the context.
///
/// # Panics
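
Putting the two halves of this PR together: a context created with `with_embeddings(true)` can decode one or more token sequences through a `LlamaBatch` and then read the pooled per-sequence embedding back with `embeddings_seq_ith`. Below is a condensed sketch of that flow, distilled from the `embeddings` example above; the model path and batch size are placeholders, and the error handling from the full example is trimmed.

use std::path::PathBuf;

use anyhow::Result;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel};

fn main() -> Result<()> {
    let backend = LlamaBackend::init()?;

    // Placeholder path: point this at any GGUF embedding model on disk.
    let model_path = PathBuf::from("model.gguf");
    let model = LlamaModel::load_from_file(&backend, model_path, &LlamaModelParams::default())?;

    // Embeddings must be enabled on the context, otherwise
    // `embeddings_seq_ith` returns `EmbeddingsError::NotEnabled`.
    let ctx_params = LlamaContextParams::default().with_embeddings(true);
    let mut ctx = model.new_context(&backend, ctx_params)?;

    // Tokenize a single prompt and submit it as sequence 0.
    let tokens = model.str_to_token("Hello my name is", AddBos::Always)?;
    let mut batch = LlamaBatch::new(512, 1); // placeholder batch size
    batch.add_sequence(&tokens, 0, false)?;
    ctx.decode(&mut batch)?;

    // One pooled embedding per sequence, `n_embd` floats long.
    let embedding = ctx.embeddings_seq_ith(0)?;
    println!("embedding has {} dimensions", embedding.len());

    Ok(())
}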