diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f5d88aa..1ee2746 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,14 +20,40 @@ jobs: steps: - uses: actions/checkout@v3 + - name: Restore Builds + id: cache-build-restore + uses: actions/cache/restore@v4 + with: + key: '${{ runner.os }}-cargox-${{ hashFiles(''**/Cargo.toml'') }}' + path: | + onnxruntime/build/Linux/Release/ + + - name: Compile ONNX Runtime for Linux + if: steps.cache-build-restore.outputs.cache-hit != 'true' + run: | + echo Cloning ONNX Runtime repository... + git clone https://github.com/microsoft/onnxruntime --recursive --branch v1.20.1 --single-branch --depth 1 + cd onnxruntime + ./build.sh --update --build --config Release --parallel --compile_no_warning_as_error --skip_submodule_sync + cd .. + - name: Cargo Test With Release Build - run: cargo test --release + run: ORT_LIB_LOCATION="$(pwd)/onnxruntime/build/Linux/Release" cargo test --release --no-default-features --features online - name: Cargo Test Offline - run: cargo test --no-default-features --features ort-download-binaries + run: ORT_LIB_LOCATION="$(pwd)/onnxruntime/build/Linux/Release" cargo test --no-default-features - name: Cargo Clippy run: cargo clippy - name: Cargo FMT run: cargo fmt --all -- --check + + - name: Always Save Cache + id: cache-build-save + if: always() && steps.cache-build-restore.outputs.cache-hit != 'true' + uses: actions/cache/save@v4 + with: + key: '${{ steps.cache-build-restore.outputs.cache-primary-key }}' + path: | + onnxruntime/build/Linux/Release/ diff --git a/Cargo.toml b/Cargo.toml index 12f4263..2224de7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ authors = [ "Luya Wang ", "Tri ", "Denny Wong ", - "Alex Rozgo " + "Alex Rozgo ", ] documentation = "https://docs.rs/fastembed" repository = "https://github.com/Anush008/fastembed-rs" @@ -26,8 +26,8 @@ anyhow = { version = "1" } hf-hub = { version = "0.3", default-features = false } image = "0.25.2" ndarray = { version = "0.16", default-features = false } -ort = { version = "=2.0.0-rc.8", default-features = false, features = [ - "half", "ndarray", +ort = { version = "=2.0.0-rc.9", default-features = false, features = [ + "ndarray", ] } rayon = { version = "1.10", default-features = false } serde_json = { version = "1" } diff --git a/src/image_embedding/impl.rs b/src/image_embedding/impl.rs index 230ce49..a01c3a1 100644 --- a/src/image_embedding/impl.rs +++ b/src/image_embedding/impl.rs @@ -4,7 +4,10 @@ use hf_hub::{ Cache, }; use ndarray::{Array3, ArrayView3}; -use ort::{GraphOptimizationLevel, Session, Value}; +use ort::{ + session::{builder::GraphOptimizationLevel, Session}, + value::Value, +}; #[cfg(feature = "online")] use std::path::PathBuf; use std::{path::Path, thread::available_parallelism}; diff --git a/src/image_embedding/init.rs b/src/image_embedding/init.rs index 85e9739..00818cf 100644 --- a/src/image_embedding/init.rs +++ b/src/image_embedding/init.rs @@ -1,6 +1,6 @@ use std::path::{Path, PathBuf}; -use ort::{ExecutionProviderDispatch, Session}; +use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; use crate::{ImageEmbeddingModel, DEFAULT_CACHE_DIR}; diff --git a/src/lib.rs b/src/lib.rs index 3bfd651..e8f316a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,7 +62,7 @@ mod reranking; mod sparse_text_embedding; mod text_embedding; -pub use ort::ExecutionProviderDispatch; +pub use ort::execution_providers::ExecutionProviderDispatch; pub use crate::common::{ read_file_to_bytes, Embedding, Error, SparseEmbedding, TokenizerFiles, DEFAULT_CACHE_DIR, diff --git a/src/output/embedding_output.rs b/src/output/embedding_output.rs index 11d86e0..a3f6bb9 100644 --- a/src/output/embedding_output.rs +++ b/src/output/embedding_output.rs @@ -1,4 +1,5 @@ use ndarray::{Array2, ArrayView, Dim, IxDynImpl}; +use ort::session::SessionOutputs; use crate::pooling; @@ -10,7 +11,7 @@ use super::{OutputKey, OutputPrecedence}; /// pooling etc. This struct should contain all the necessary information for the /// post-processing to be performed. pub struct SingleBatchOutput<'r, 's> { - pub session_outputs: ort::SessionOutputs<'r, 's>, + pub session_outputs: SessionOutputs<'r, 's>, pub attention_mask_array: Array2, } @@ -23,17 +24,12 @@ impl<'r, 's> SingleBatchOutput<'r, 's> { &self, precedence: &impl OutputPrecedence, ) -> anyhow::Result>> { - let ort_output = precedence + let ort_output: &ort::value::Value = precedence .key_precedence() .find_map(|key| match key { - OutputKey::OnlyOne => { - // Only export the value if there is only one output available. - if self.session_outputs.len() == 1 { - self.session_outputs.values().next() - } else { - None - } - } + OutputKey::OnlyOne => self + .session_outputs + .get(self.session_outputs.keys().nth(0)?), OutputKey::ByOrder(idx) => { let x = self .session_outputs diff --git a/src/reranking/impl.rs b/src/reranking/impl.rs index d503b17..853bfc6 100644 --- a/src/reranking/impl.rs +++ b/src/reranking/impl.rs @@ -1,4 +1,8 @@ use anyhow::Result; +use ort::{ + session::{builder::GraphOptimizationLevel, Session}, + value::Value, +}; use std::thread::available_parallelism; #[cfg(feature = "online")] @@ -10,7 +14,6 @@ use crate::{ #[cfg(feature = "online")] use hf_hub::{api::sync::ApiBuilder, Cache}; use ndarray::{s, Array}; -use ort::{GraphOptimizationLevel, Session, Value}; use rayon::{iter::ParallelIterator, slice::ParallelSlice}; use tokenizers::Tokenizer; diff --git a/src/reranking/init.rs b/src/reranking/init.rs index 1a8a555..3ca3d7c 100644 --- a/src/reranking/init.rs +++ b/src/reranking/init.rs @@ -1,6 +1,6 @@ use std::path::{Path, PathBuf}; -use ort::{ExecutionProviderDispatch, Session}; +use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; use tokenizers::Tokenizer; use crate::{RerankerModel, TokenizerFiles, DEFAULT_CACHE_DIR}; diff --git a/src/sparse_text_embedding/impl.rs b/src/sparse_text_embedding/impl.rs index 0440810..444ade6 100644 --- a/src/sparse_text_embedding/impl.rs +++ b/src/sparse_text_embedding/impl.rs @@ -11,9 +11,8 @@ use hf_hub::{ Cache, }; use ndarray::{Array, CowArray}; +use ort::{session::Session, value::Value}; #[cfg_attr(not(feature = "online"), allow(unused_imports))] -use ort::GraphOptimizationLevel; -use ort::{Session, Value}; use rayon::{iter::ParallelIterator, slice::ParallelSlice}; #[cfg(feature = "online")] use std::path::PathBuf; @@ -35,6 +34,7 @@ impl SparseTextEmbedding { #[cfg(feature = "online")] pub fn try_new(options: SparseInitOptions) -> Result { use super::SparseInitOptions; + use ort::{session::builder::GraphOptimizationLevel, session::Session}; let SparseInitOptions { model_name, diff --git a/src/sparse_text_embedding/init.rs b/src/sparse_text_embedding/init.rs index b81dfe8..3ac2348 100644 --- a/src/sparse_text_embedding/init.rs +++ b/src/sparse_text_embedding/init.rs @@ -1,6 +1,6 @@ use std::path::{Path, PathBuf}; -use ort::{ExecutionProviderDispatch, Session}; +use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; use tokenizers::Tokenizer; use crate::{models::sparse::SparseModel, TokenizerFiles, DEFAULT_CACHE_DIR}; diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs index 0cd0876..7b15804 100644 --- a/src/text_embedding/impl.rs +++ b/src/text_embedding/impl.rs @@ -14,7 +14,10 @@ use hf_hub::{ Cache, }; use ndarray::Array; -use ort::{GraphOptimizationLevel, Session, Value}; +use ort::{ + session::{builder::GraphOptimizationLevel, Session}, + value::Value, +}; use rayon::{ iter::{FromParallelIterator, ParallelIterator}, slice::ParallelSlice, diff --git a/src/text_embedding/init.rs b/src/text_embedding/init.rs index f759068..d37495b 100644 --- a/src/text_embedding/init.rs +++ b/src/text_embedding/init.rs @@ -6,7 +6,7 @@ use crate::{ pooling::Pooling, EmbeddingModel, QuantizationMode, }; -use ort::{ExecutionProviderDispatch, Session}; +use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; use std::path::{Path, PathBuf}; use tokenizers::Tokenizer; diff --git a/tests/embeddings.rs b/tests/embeddings.rs index 5d64382..1d432f1 100644 --- a/tests/embeddings.rs +++ b/tests/embeddings.rs @@ -15,7 +15,7 @@ use fastembed::{ }; /// A small epsilon value for floating point comparisons. -const EPS: f32 = 1e-4; +const EPS: f32 = 1e-2; /// Precalculated embeddings for the supported models using #99 /// (4f09b6842ce1fcfaf6362678afcad9a176e05304).