From aa2b3c332bb2a77677b45a81471a12dc6a1f9aa7 Mon Sep 17 00:00:00 2001 From: jeadie Date: Mon, 24 Jun 2024 16:26:40 +1000 Subject: [PATCH 1/5] move batch, sort_embeddings out of backends/candle/tests/ --- backends/candle/src/lib.rs | 48 ++++++++++++++++++- backends/candle/tests/common.rs | 45 ----------------- backends/candle/tests/test_bert.rs | 6 +-- backends/candle/tests/test_flash_bert.rs | 2 +- backends/candle/tests/test_flash_jina.rs | 2 +- backends/candle/tests/test_flash_jina_code.rs | 2 +- backends/candle/tests/test_flash_nomic.rs | 2 +- backends/candle/tests/test_jina.rs | 2 +- backends/candle/tests/test_jina_code.rs | 2 +- backends/candle/tests/test_nomic.rs | 2 +- backends/src/dtype.rs | 1 + 11 files changed, 58 insertions(+), 56 deletions(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index a7a5dcf0..b600d8d0 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -24,11 +24,12 @@ use candle::{DType, Device}; use candle_nn::VarBuilder; use nohash_hasher::BuildNoHashHasher; use serde::Deserialize; -use std::collections::HashMap; use std::path::PathBuf; +use std::{cmp::max, collections::HashMap}; use text_embeddings_backend_core::{ Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions, }; +use tokenizers::Encoding; /// This enum is needed to be able to differentiate between jina models that also use /// the `bert` model type and valid Bert models. @@ -373,3 +374,48 @@ impl WrapErr for Result { self.map_err(|e| BackendError::Inference(e.to_string())) } } + +pub fn batch(encodings: Vec, pooled_indices: Vec, raw_indices: Vec) -> Batch { + let mut input_ids = Vec::new(); + let mut token_type_ids = Vec::new(); + let mut position_ids = Vec::new(); + let mut cumulative_seq_lengths = Vec::with_capacity(encodings.len() + 1); + cumulative_seq_lengths.push(0); + + let mut max_length = 0; + let mut cumulative_length = 0; + + for encoding in encodings.iter() { + let encoding_length = encoding.len() as u32; + input_ids.extend(encoding.get_ids().to_vec()); + token_type_ids.extend(encoding.get_type_ids().to_vec()); + position_ids.extend(0..encoding_length); + cumulative_length += encoding_length; + cumulative_seq_lengths.push(cumulative_length); + max_length = max(max_length, encoding_length); + } + + Batch { + input_ids, + token_type_ids, + position_ids, + cumulative_seq_lengths, + max_length, + pooled_indices, + raw_indices, + } +} + +pub fn sort_embeddings(embeddings: Embeddings) -> (Vec>, Vec>) { + let mut pooled_embeddings = Vec::new(); + let mut raw_embeddings = Vec::new(); + + for (_, embedding) in embeddings { + match embedding { + Embedding::Pooled(e) => pooled_embeddings.push(e), + Embedding::All(e) => raw_embeddings.extend(e), + } + } + + (pooled_embeddings, raw_embeddings) +} diff --git a/backends/candle/tests/common.rs b/backends/candle/tests/common.rs index d7ebc67d..236a799d 100644 --- a/backends/candle/tests/common.rs +++ b/backends/candle/tests/common.rs @@ -51,20 +51,6 @@ impl From>> for SnapshotScores { } } -pub fn sort_embeddings(embeddings: Embeddings) -> (Vec>, Vec>) { - let mut pooled_embeddings = Vec::new(); - let mut raw_embeddings = Vec::new(); - - for (_, embedding) in embeddings { - match embedding { - Embedding::Pooled(e) => pooled_embeddings.push(e), - Embedding::All(e) => raw_embeddings.extend(e), - } - } - - (pooled_embeddings, raw_embeddings) -} - pub fn download_artifacts( model_id: &'static str, revision: Option<&'static str>, @@ -147,34 +133,3 @@ pub fn load_tokenizer(model_root: &Path) -> Result { tokenizer.with_padding(None); Ok(tokenizer) } - -pub fn batch(encodings: Vec, pooled_indices: Vec, raw_indices: Vec) -> Batch { - let mut input_ids = Vec::new(); - let mut token_type_ids = Vec::new(); - let mut position_ids = Vec::new(); - let mut cumulative_seq_lengths = Vec::with_capacity(encodings.len() + 1); - cumulative_seq_lengths.push(0); - - let mut max_length = 0; - let mut cumulative_length = 0; - - for encoding in encodings.iter() { - let encoding_length = encoding.len() as u32; - input_ids.extend(encoding.get_ids().to_vec()); - token_type_ids.extend(encoding.get_type_ids().to_vec()); - position_ids.extend(0..encoding_length); - cumulative_length += encoding_length; - cumulative_seq_lengths.push(cumulative_length); - max_length = max(max_length, encoding_length); - } - - Batch { - input_ids, - token_type_ids, - position_ids, - cumulative_seq_lengths, - max_length, - pooled_indices, - raw_indices, - } -} diff --git a/backends/candle/tests/test_bert.rs b/backends/candle/tests/test_bert.rs index 45d02577..30639483 100644 --- a/backends/candle/tests/test_bert.rs +++ b/backends/candle/tests/test_bert.rs @@ -1,9 +1,9 @@ mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::SnapshotScores; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; -use text_embeddings_backend_candle::CandleBackend; +use common::{download_artifacts, load_tokenizer, relative_matcher}; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_flash_bert.rs b/backends/candle/tests/test_flash_bert.rs index 1888a32b..1a3cbe14 100644 --- a/backends/candle/tests/test_flash_bert.rs +++ b/backends/candle/tests/test_flash_bert.rs @@ -5,7 +5,7 @@ mod common; use crate::common::{sort_embeddings, SnapshotScores}; use anyhow::Result; use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; -use text_embeddings_backend_candle::CandleBackend; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_flash_jina.rs b/backends/candle/tests/test_flash_jina.rs index 4a5f8276..2d5c5b75 100644 --- a/backends/candle/tests/test_flash_jina.rs +++ b/backends/candle/tests/test_flash_jina.rs @@ -4,7 +4,7 @@ mod common; use crate::common::{sort_embeddings, SnapshotScores}; use anyhow::Result; use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; -use text_embeddings_backend_candle::CandleBackend; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_flash_jina_code.rs b/backends/candle/tests/test_flash_jina_code.rs index 508bf722..35fc97fe 100644 --- a/backends/candle/tests/test_flash_jina_code.rs +++ b/backends/candle/tests/test_flash_jina_code.rs @@ -4,7 +4,7 @@ mod common; use crate::common::{sort_embeddings, SnapshotScores}; use anyhow::Result; use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; -use text_embeddings_backend_candle::CandleBackend; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_flash_nomic.rs b/backends/candle/tests/test_flash_nomic.rs index 3e9b6e1d..83ca853b 100644 --- a/backends/candle/tests/test_flash_nomic.rs +++ b/backends/candle/tests/test_flash_nomic.rs @@ -4,7 +4,7 @@ mod common; use crate::common::{sort_embeddings, SnapshotScores}; use anyhow::Result; use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; -use text_embeddings_backend_candle::CandleBackend; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_jina.rs b/backends/candle/tests/test_jina.rs index 4cd7bba6..7a685fd9 100644 --- a/backends/candle/tests/test_jina.rs +++ b/backends/candle/tests/test_jina.rs @@ -3,7 +3,7 @@ mod common; use crate::common::{sort_embeddings, SnapshotScores}; use anyhow::Result; use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; -use text_embeddings_backend_candle::CandleBackend; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_jina_code.rs b/backends/candle/tests/test_jina_code.rs index 70248e1a..78425a4b 100644 --- a/backends/candle/tests/test_jina_code.rs +++ b/backends/candle/tests/test_jina_code.rs @@ -3,7 +3,7 @@ mod common; use crate::common::{sort_embeddings, SnapshotScores}; use anyhow::Result; use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; -use text_embeddings_backend_candle::CandleBackend; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_nomic.rs b/backends/candle/tests/test_nomic.rs index 914be7ea..b1cb62ca 100644 --- a/backends/candle/tests/test_nomic.rs +++ b/backends/candle/tests/test_nomic.rs @@ -3,7 +3,7 @@ mod common; use crate::common::{sort_embeddings, SnapshotScores}; use anyhow::Result; use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; -use text_embeddings_backend_candle::CandleBackend; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/src/dtype.rs b/backends/src/dtype.rs index d2c896ce..7c73650e 100644 --- a/backends/src/dtype.rs +++ b/backends/src/dtype.rs @@ -33,6 +33,7 @@ impl fmt::Display for DType { DType::Float32 => write!(f, "float32"), // #[cfg(feature = "candle")] // DType::Q6K => write!(f, "q6k"), + _ => unimplemented!(), } } } From 8bfd8e30d069ae38f7e76f155163d66eac87168d Mon Sep 17 00:00:00 2001 From: jeadie Date: Mon, 24 Jun 2024 16:31:06 +1000 Subject: [PATCH 2/5] move tokenizer from dev to main dependency --- backends/candle/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/candle/Cargo.toml b/backends/candle/Cargo.toml index 9fb300b9..f918968e 100644 --- a/backends/candle/Cargo.toml +++ b/backends/candle/Cargo.toml @@ -25,13 +25,13 @@ thiserror = "^1.0" serde = { version = "^1.0", features = ["serde_derive"] } serde_json = "^1.0" memmap2 = "^0.9" +tokenizers = { version = "^0.19.1", default-features = false, features = ["onig", "esaxx_fast"] } [dev-dependencies] insta = { git = "https://github.com/OlivierDehaene/insta", rev = "f4f98c0410b91fb5a28b10df98e4422955be9c2c", features = ["yaml"] } is_close = "0.1.3" hf-hub = "0.3.2" anyhow = "1.0.75" -tokenizers = { version = "^0.19.1", default-features = false, features = ["onig", "esaxx_fast"] } serial_test = "2.0.0" [build-dependencies] From 274a451648772f3b26d0065dbe7dd1c1f9d4d7d3 Mon Sep 17 00:00:00 2001 From: jeadie Date: Mon, 24 Jun 2024 16:37:30 +1000 Subject: [PATCH 3/5] fix test imports --- backends/candle/tests/test_flash_bert.rs | 4 ++-- backends/candle/tests/test_flash_jina.rs | 4 ++-- backends/candle/tests/test_flash_jina_code.rs | 4 ++-- backends/candle/tests/test_flash_nomic.rs | 4 ++-- backends/candle/tests/test_jina.rs | 4 ++-- backends/candle/tests/test_jina_code.rs | 4 ++-- backends/candle/tests/test_nomic.rs | 4 ++-- backends/src/dtype.rs | 1 - 8 files changed, 14 insertions(+), 15 deletions(-) diff --git a/backends/candle/tests/test_flash_bert.rs b/backends/candle/tests/test_flash_bert.rs index 1a3cbe14..cdff547e 100644 --- a/backends/candle/tests/test_flash_bert.rs +++ b/backends/candle/tests/test_flash_bert.rs @@ -2,9 +2,9 @@ mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::SnapshotScores; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; diff --git a/backends/candle/tests/test_flash_jina.rs b/backends/candle/tests/test_flash_jina.rs index 2d5c5b75..30c1ce46 100644 --- a/backends/candle/tests/test_flash_jina.rs +++ b/backends/candle/tests/test_flash_jina.rs @@ -1,9 +1,9 @@ #![allow(dead_code, unused_imports)] mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::SnapshotScores; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; diff --git a/backends/candle/tests/test_flash_jina_code.rs b/backends/candle/tests/test_flash_jina_code.rs index 35fc97fe..4d4ea790 100644 --- a/backends/candle/tests/test_flash_jina_code.rs +++ b/backends/candle/tests/test_flash_jina_code.rs @@ -1,9 +1,9 @@ #![allow(dead_code, unused_imports)] mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::SnapshotScores; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; diff --git a/backends/candle/tests/test_flash_nomic.rs b/backends/candle/tests/test_flash_nomic.rs index 83ca853b..9a18ed91 100644 --- a/backends/candle/tests/test_flash_nomic.rs +++ b/backends/candle/tests/test_flash_nomic.rs @@ -1,9 +1,9 @@ #![allow(dead_code, unused_imports)] mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::SnapshotScores; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; diff --git a/backends/candle/tests/test_jina.rs b/backends/candle/tests/test_jina.rs index 7a685fd9..f761cd8c 100644 --- a/backends/candle/tests/test_jina.rs +++ b/backends/candle/tests/test_jina.rs @@ -1,8 +1,8 @@ mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::SnapshotScores; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; diff --git a/backends/candle/tests/test_jina_code.rs b/backends/candle/tests/test_jina_code.rs index 78425a4b..ff8f6625 100644 --- a/backends/candle/tests/test_jina_code.rs +++ b/backends/candle/tests/test_jina_code.rs @@ -1,8 +1,8 @@ mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::SnapshotScores; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; diff --git a/backends/candle/tests/test_nomic.rs b/backends/candle/tests/test_nomic.rs index b1cb62ca..9fdde986 100644 --- a/backends/candle/tests/test_nomic.rs +++ b/backends/candle/tests/test_nomic.rs @@ -1,8 +1,8 @@ mod common; -use crate::common::{sort_embeddings, SnapshotScores}; +use crate::common::SnapshotScores; use anyhow::Result; -use common::{batch, download_artifacts, load_tokenizer, relative_matcher}; +use common::{download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; diff --git a/backends/src/dtype.rs b/backends/src/dtype.rs index 7c73650e..d2c896ce 100644 --- a/backends/src/dtype.rs +++ b/backends/src/dtype.rs @@ -33,7 +33,6 @@ impl fmt::Display for DType { DType::Float32 => write!(f, "float32"), // #[cfg(feature = "candle")] // DType::Q6K => write!(f, "q6k"), - _ => unimplemented!(), } } } From dde326b330ecb609201110d4725928a6e977f0dd Mon Sep 17 00:00:00 2001 From: jeadie Date: Mon, 24 Jun 2024 16:39:15 +1000 Subject: [PATCH 4/5] dtype unreachable --- backends/src/dtype.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/src/dtype.rs b/backends/src/dtype.rs index d2c896ce..48b239bc 100644 --- a/backends/src/dtype.rs +++ b/backends/src/dtype.rs @@ -33,6 +33,7 @@ impl fmt::Display for DType { DType::Float32 => write!(f, "float32"), // #[cfg(feature = "candle")] // DType::Q6K => write!(f, "q6k"), + _ => unimplemented!() } } } From 94754a232bfe5537045d682a09218e3682bf176b Mon Sep 17 00:00:00 2001 From: jeadie Date: Wed, 3 Jul 2024 13:58:11 +1000 Subject: [PATCH 5/5] fix merge --- backends/candle/Cargo.toml | 3 +-- backends/src/dtype.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/backends/candle/Cargo.toml b/backends/candle/Cargo.toml index 883defad..60931fb6 100644 --- a/backends/candle/Cargo.toml +++ b/backends/candle/Cargo.toml @@ -25,14 +25,13 @@ thiserror = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } memmap2 = "^0.9" -tokenizers = { version = "^0.19.1", default-features = false, features = ["onig", "esaxx_fast"] } +tokenizers = { workspace = true } [dev-dependencies] insta = { git = "https://github.com/OlivierDehaene/insta", rev = "f4f98c0410b91fb5a28b10df98e4422955be9c2c", features = ["yaml"] } is_close = "0.1.3" hf-hub = "0.3.2" anyhow = { workspace = true } -tokenizers = { workspace = true } serial_test = "2.0.0" [build-dependencies] diff --git a/backends/src/dtype.rs b/backends/src/dtype.rs index 48b239bc..01c0b60f 100644 --- a/backends/src/dtype.rs +++ b/backends/src/dtype.rs @@ -33,7 +33,7 @@ impl fmt::Display for DType { DType::Float32 => write!(f, "float32"), // #[cfg(feature = "candle")] // DType::Q6K => write!(f, "q6k"), - _ => unimplemented!() + _ => unreachable!() } } }