Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move batch, sort_embeddings into backends/candle #321

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/candle/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ thiserror = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
memmap2 = "^0.9"
tokenizers = { workspace = true }

[dev-dependencies]
insta = { git = "https://github.com/OlivierDehaene/insta", rev = "f4f98c0410b91fb5a28b10df98e4422955be9c2c", features = ["yaml"] }
is_close = "0.1.3"
hf-hub = "0.3.2"
anyhow = { workspace = true }
tokenizers = { workspace = true }
serial_test = "2.0.0"

[build-dependencies]
Expand Down
48 changes: 47 additions & 1 deletion backends/candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ use candle::{DType, Device};
use candle_nn::VarBuilder;
use nohash_hasher::BuildNoHashHasher;
use serde::Deserialize;
use std::collections::HashMap;
use std::path::PathBuf;
use std::{cmp::max, collections::HashMap};
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
};
use tokenizers::Encoding;

/// This enum is needed to be able to differentiate between jina models that also use
/// the `bert` model type and valid Bert models.
Expand Down Expand Up @@ -465,3 +466,48 @@ impl<O> WrapErr<O> for Result<O, candle::Error> {
self.map_err(|e| BackendError::Inference(e.to_string()))
}
}

/// Flatten tokenized sequences into a single ragged [`Batch`].
///
/// Sequences are concatenated back-to-back into flat buffers;
/// `cumulative_seq_lengths` stores CSR-style offsets, so the i-th sequence
/// occupies `cumulative_seq_lengths[i]..cumulative_seq_lengths[i + 1]`.
/// Position ids restart at 0 for every sequence.
///
/// `pooled_indices` / `raw_indices` select which sequences should yield
/// pooled vs. raw (per-token) embeddings; they are forwarded untouched.
pub fn batch(encodings: Vec<Encoding>, pooled_indices: Vec<u32>, raw_indices: Vec<u32>) -> Batch {
    // Total token count lets us size the flattened buffers exactly once
    // instead of growing them repeatedly inside the loop.
    let total_length: usize = encodings.iter().map(|e| e.len()).sum();

    let mut input_ids = Vec::with_capacity(total_length);
    let mut token_type_ids = Vec::with_capacity(total_length);
    let mut position_ids = Vec::with_capacity(total_length);
    let mut cumulative_seq_lengths = Vec::with_capacity(encodings.len() + 1);
    cumulative_seq_lengths.push(0);

    let mut max_length = 0;
    let mut cumulative_length = 0;

    for encoding in encodings.iter() {
        // NOTE(review): assumes each sequence length fits in u32 — presumably
        // guaranteed by upstream truncation; confirm.
        let encoding_length = encoding.len() as u32;
        // extend_from_slice copies directly from the encoding's slice,
        // avoiding the intermediate Vec that `.to_vec()` allocated per encoding.
        input_ids.extend_from_slice(encoding.get_ids());
        token_type_ids.extend_from_slice(encoding.get_type_ids());
        position_ids.extend(0..encoding_length);
        cumulative_length += encoding_length;
        cumulative_seq_lengths.push(cumulative_length);
        max_length = max(max_length, encoding_length);
    }

    Batch {
        input_ids,
        token_type_ids,
        position_ids,
        cumulative_seq_lengths,
        max_length,
        pooled_indices,
        raw_indices,
    }
}

/// Split a backend's embedding map into pooled and raw (per-token) vectors.
///
/// Each `Pooled` entry contributes exactly one vector; each `All` entry
/// contributes one vector per token, flattened into the raw list.
/// NOTE(review): output order follows the map's iteration order — confirm
/// callers do not rely on index order here.
pub fn sort_embeddings(embeddings: Embeddings) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let mut pooled_embeddings = Vec::new();
    let mut raw_embeddings = Vec::new();

    embeddings
        .into_iter()
        .for_each(|(_, embedding)| match embedding {
            Embedding::Pooled(pooled) => pooled_embeddings.push(pooled),
            Embedding::All(per_token) => raw_embeddings.extend(per_token),
        });

    (pooled_embeddings, raw_embeddings)
}
45 changes: 0 additions & 45 deletions backends/candle/tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,6 @@ impl From<Vec<Vec<f32>>> for SnapshotEmbeddings {
}
}

/// Partition backend embeddings into (pooled, raw per-token) vectors.
///
/// `Pooled` entries push one vector each; `All` entries are flattened so every
/// token vector lands in the raw list.
pub fn sort_embeddings(embeddings: Embeddings) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let mut pooled_embeddings = Vec::new();
    let mut raw_embeddings = Vec::new();

    embeddings
        .into_iter()
        .for_each(|(_, embedding)| match embedding {
            Embedding::Pooled(pooled) => pooled_embeddings.push(pooled),
            Embedding::All(per_token) => raw_embeddings.extend(per_token),
        });

    (pooled_embeddings, raw_embeddings)
}

pub fn download_artifacts(
model_id: &'static str,
revision: Option<&'static str>,
Expand Down Expand Up @@ -232,34 +218,3 @@ pub fn load_tokenizer(model_root: &Path) -> Result<Tokenizer> {
tokenizer.with_padding(None);
Ok(tokenizer)
}

/// Flatten tokenized sequences into a single ragged [`Batch`].
///
/// Sequences are concatenated back-to-back; `cumulative_seq_lengths` holds
/// CSR-style offsets, so the i-th sequence spans
/// `cumulative_seq_lengths[i]..cumulative_seq_lengths[i + 1]`. Position ids
/// restart at 0 for each sequence. `pooled_indices` / `raw_indices` are
/// forwarded untouched.
pub fn batch(encodings: Vec<Encoding>, pooled_indices: Vec<u32>, raw_indices: Vec<u32>) -> Batch {
    // Pre-compute total token count so the flat buffers allocate once.
    let total_length: usize = encodings.iter().map(|e| e.len()).sum();

    let mut input_ids = Vec::with_capacity(total_length);
    let mut token_type_ids = Vec::with_capacity(total_length);
    let mut position_ids = Vec::with_capacity(total_length);
    let mut cumulative_seq_lengths = Vec::with_capacity(encodings.len() + 1);
    cumulative_seq_lengths.push(0);

    let mut max_length = 0;
    let mut cumulative_length = 0;

    for encoding in encodings.iter() {
        // NOTE(review): assumes sequence length fits in u32 — confirm upstream truncation.
        let encoding_length = encoding.len() as u32;
        // Copy straight from the encoding's slices; `.to_vec()` allocated a
        // throwaway intermediate Vec on every iteration.
        input_ids.extend_from_slice(encoding.get_ids());
        token_type_ids.extend_from_slice(encoding.get_type_ids());
        position_ids.extend(0..encoding_length);
        cumulative_length += encoding_length;
        cumulative_seq_lengths.push(cumulative_length);
        max_length = max(max_length, encoding_length);
    }

    Batch {
        input_ids,
        token_type_ids,
        position_ids,
        cumulative_seq_lengths,
        max_length,
        pooled_indices,
        raw_indices,
    }
}
4 changes: 2 additions & 2 deletions backends/candle/tests/test_bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores};
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher};
use text_embeddings_backend_candle::CandleBackend;
use common::{cosine_matcher, download_artifacts, load_tokenizer, relative_matcher};
use text_embeddings_backend_candle::{batch, CandleBackend};
use text_embeddings_backend_core::{Backend, ModelType, Pool};

#[test]
Expand Down
6 changes: 3 additions & 3 deletions backends/candle/tests/test_flash_bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores};
use crate::common::{SnapshotEmbeddings, SnapshotScores};
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher};
use text_embeddings_backend_candle::CandleBackend;
use common::{cosine_matcher, download_artifacts, load_tokenizer, relative_matcher};
use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend};
use text_embeddings_backend_core::{Backend, ModelType, Pool};

#[test]
Expand Down
6 changes: 3 additions & 3 deletions backends/candle/tests/test_flash_gte.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#![allow(dead_code, unused_imports)]
mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings};
use crate::common::SnapshotEmbeddings;
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::CandleBackend;
use common::{cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend};
use text_embeddings_backend_core::{Backend, ModelType, Pool};

#[test]
Expand Down
6 changes: 3 additions & 3 deletions backends/candle/tests/test_flash_jina.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#![allow(dead_code, unused_imports)]
mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings};
use crate::common::SnapshotEmbeddings;
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::CandleBackend;
use common::{cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend};
use text_embeddings_backend_core::{Backend, ModelType, Pool};

#[test]
Expand Down
6 changes: 3 additions & 3 deletions backends/candle/tests/test_flash_jina_code.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#![allow(dead_code, unused_imports)]
mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings};
use crate::common::SnapshotEmbeddings;
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::CandleBackend;
use common::{cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend};
use text_embeddings_backend_core::{Backend, ModelType, Pool};

#[test]
Expand Down
6 changes: 3 additions & 3 deletions backends/candle/tests/test_flash_mistral.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#![allow(dead_code, unused_imports)]
mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings};
use crate::common::SnapshotEmbeddings;
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::CandleBackend;
use common::{cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend};
use text_embeddings_backend_core::{Backend, ModelType, Pool};

#[test]
Expand Down
6 changes: 3 additions & 3 deletions backends/candle/tests/test_flash_nomic.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#![allow(dead_code, unused_imports)]
mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings};
use crate::common::SnapshotEmbeddings;
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::CandleBackend;
use common::{cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend};
use text_embeddings_backend_core::{Backend, ModelType, Pool};

#[test]
Expand Down
6 changes: 3 additions & 3 deletions backends/candle/tests/test_flash_qwen2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings};
use crate::common::SnapshotEmbeddings;
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::CandleBackend;
use common::{cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend};
use text_embeddings_backend_core::{Backend, ModelType, Pool};
use tokenizers::processors::sequence::Sequence;
use tokenizers::processors::template::TemplateProcessing;
Expand Down
6 changes: 3 additions & 3 deletions backends/candle/tests/test_jina.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings};
use crate::common::SnapshotEmbeddings;
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::CandleBackend;
use common::{cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend};
use text_embeddings_backend_core::{Backend, ModelType, Pool};

#[test]
Expand Down
6 changes: 3 additions & 3 deletions backends/candle/tests/test_jina_code.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings};
use crate::common::SnapshotEmbeddings;
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::CandleBackend;
use common::{cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend};
use text_embeddings_backend_core::{Backend, ModelType, Pool};

#[test]
Expand Down
6 changes: 3 additions & 3 deletions backends/candle/tests/test_nomic.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings};
use crate::common::SnapshotEmbeddings;
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::CandleBackend;
use common::{cosine_matcher, download_artifacts, load_tokenizer};
use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend};
use text_embeddings_backend_core::{Backend, ModelType, Pool};

#[test]
Expand Down
1 change: 1 addition & 0 deletions backends/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ impl fmt::Display for DType {
DType::Float32 => write!(f, "float32"),
// #[cfg(feature = "candle")]
// DType::Q6K => write!(f, "q6k"),
_ => unreachable!()
}
}
}