Skip to content

Commit

Permalink
refactor(embedding): improve logging and simplify voyage embedding implementation (
Browse files Browse the repository at this point in the history
#3600)

* refactor(embedding): improve logging and simplify voyage embedding implementation

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
wsxiaoys and autofix-ci[bot] authored Dec 20, 2024
1 parent 4a4c595 commit 4184a0e
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 132 deletions.
8 changes: 6 additions & 2 deletions crates/http-api-bindings/src/embedding/llama.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tabby_inference::Embedding;
use tracing::Instrument;

use crate::create_reqwest_client;
use crate::{create_reqwest_client, embedding_info_span};

pub struct LlamaCppEngine {
client: reqwest::Client,
Expand Down Expand Up @@ -44,7 +45,10 @@ impl Embedding for LlamaCppEngine {
request = request.bearer_auth(api_key);
}

let response = request.send().await?;
let response = request
.send()
.instrument(embedding_info_span!("llamacpp"))
.await?;
if response.status().is_server_error() {
let error = response.text().await?;
return Err(anyhow::anyhow!(
Expand Down
22 changes: 14 additions & 8 deletions crates/http-api-bindings/src/embedding/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
mod llama;
mod openai;
mod voyage;

use core::panic;
use std::sync::Arc;

use llama::LlamaCppEngine;
use openai::OpenAIEmbeddingEngine;
use tabby_common::config::HttpModelConfig;
use tabby_inference::Embedding;

use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine};
use super::rate_limit;

pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
Expand All @@ -30,16 +29,16 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
config.api_key.as_deref(),
),
"ollama/embedding" => ollama_api_bindings::create_embedding(config).await,
"voyage/embedding" => VoyageEmbeddingEngine::create(
config.api_endpoint.as_deref(),
"voyage/embedding" => OpenAIEmbeddingEngine::create(
config
.api_endpoint
.as_deref()
.unwrap_or("https://api.voyageai.com/v1"),
config
.model_name
.as_deref()
.expect("model_name must be set for voyage/embedding"),
config
.api_key
.clone()
.expect("api_key must be set for voyage/embedding"),
config.api_key.as_deref(),
),
unsupported_kind => panic!(
"Unsupported kind for http embedding model: {}",
Expand All @@ -52,3 +51,10 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
config.rate_limit.request_per_minute,
))
}

/// Builds a `tracing` info span named "embedding" with a `kind` field
/// identifying the embedding backend (e.g. "openai", "llamacpp").
/// Exported so sibling modules can instrument their request futures with it.
#[macro_export]
macro_rules! embedding_info_span {
($kind:expr) => {
tracing::info_span!("embedding", kind = $kind)
};
}
91 changes: 67 additions & 24 deletions crates/http-api-bindings/src/embedding/openai.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use anyhow::Context;
use async_openai::{
config::OpenAIConfig,
types::{CreateEmbeddingRequest, EmbeddingInput},
};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tabby_inference::Embedding;
use tracing::{info_span, Instrument};
use tracing::Instrument;

use crate::embedding_info_span;

pub struct OpenAIEmbeddingEngine {
client: async_openai::Client<OpenAIConfig>,
client: Client,
api_endpoint: String,
api_key: String,
model_name: String,
}

Expand All @@ -18,41 +20,69 @@ impl OpenAIEmbeddingEngine {
model_name: &str,
api_key: Option<&str>,
) -> Box<dyn Embedding> {
let config = OpenAIConfig::default()
.with_api_base(api_endpoint)
.with_api_key(api_key.unwrap_or_default());

let client = async_openai::Client::with_config(config);

let client = Client::new();
Box::new(Self {
client,
api_endpoint: format!("{}/embeddings", api_endpoint),
api_key: api_key.unwrap_or_default().to_owned(),
model_name: model_name.to_owned(),
})
}
}

/// JSON request body for an OpenAI-compatible `/embeddings` endpoint.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
// Texts to embed; a single prompt is sent as a one-element batch.
input: Vec<String>,
// Model identifier, e.g. "voyage-code-2" for the Voyage-backed endpoint.
model: String,
}

/// JSON response body returned by the embeddings endpoint.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
// One entry per input text; only the first entry is consumed by `embed`.
data: Vec<EmbeddingData>,
}

/// A single embedding vector within the response payload.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
embedding: Vec<f32>,
}

#[async_trait]
impl Embedding for OpenAIEmbeddingEngine {
async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
let request = CreateEmbeddingRequest {
let request = EmbeddingRequest {
input: vec![prompt.to_owned()],
model: self.model_name.clone(),
input: EmbeddingInput::String(prompt.to_owned()),
encoding_format: None,
user: None,
dimensions: None,
};
let resp = self

let request_builder = self
.client
.embeddings()
.create(request)
.instrument(info_span!("embedding", kind = "openai"))
.post(&self.api_endpoint)
.json(&request)
.header("content-type", "application/json")
.bearer_auth(&self.api_key);

let response = request_builder
.send()
.instrument(embedding_info_span!("openai"))
.await?;
let data = resp

if !response.status().is_success() {
let status = response.status();
let error = response.text().await?;
return Err(anyhow::anyhow!("Error {}: {}", status.as_u16(), error));
}

let response_body = response
.json::<EmbeddingResponse>()
.await
.context("Failed to parse response body")?;

response_body
.data
.into_iter()
.next()
.context("Failed to get embedding")?;
Ok(data.embedding)
.map(|data| data.embedding)
.ok_or_else(|| anyhow::anyhow!("No embedding data found"))
}
}

Expand All @@ -73,4 +103,17 @@ mod tests {
let embedding = engine.embed("Hello, world!").await.unwrap();
assert_eq!(embedding.len(), 768);
}

#[tokio::test]
#[ignore]
async fn test_voyage_embedding() {
let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY must be set");
let engine = OpenAIEmbeddingEngine::create(
"https://api.voyageai.com/v1",
"voyage-code-2",
Some(&api_key),
);
let embedding = engine.embed("Hello, world!").await.unwrap();
assert_eq!(embedding.len(), 1536);
}
}
98 changes: 0 additions & 98 deletions crates/http-api-bindings/src/embedding/voyage.rs

This file was deleted.

0 comments on commit 4184a0e

Please sign in to comment.