Skip to content

add support for base64 embeddings #190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions async-openai/src/embedding.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::{
config::Config,
error::OpenAIError,
types::{CreateEmbeddingRequest, CreateEmbeddingResponse},
types::{
CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
EncodingFormat,
},
Client,
};

Expand All @@ -23,13 +26,35 @@ impl<'c, C: Config> Embeddings<'c, C> {
&self,
request: CreateEmbeddingRequest,
) -> Result<CreateEmbeddingResponse, OpenAIError> {
if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
return Err(OpenAIError::InvalidArgument(
"When encoding_format is base64, use Embeddings::create_base64".into(),
));
}
self.client.post("/embeddings", request).await
}

/// Creates an embedding vector representing the input text.
///
/// The response will contain the embedding in base64 format.
///
/// # Errors
///
/// Returns [`OpenAIError::InvalidArgument`] unless the request's
/// `encoding_format` is explicitly set to base64; plain requests
/// belong to `Embeddings::create`.
pub async fn create_base64(
    &self,
    request: CreateEmbeddingRequest,
) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
    match request.encoding_format {
        // Only base64-format requests are valid here; forward to the API.
        Some(EncodingFormat::Base64) => self.client.post("/embeddings", request).await,
        _ => Err(OpenAIError::InvalidArgument(
            "When encoding_format is not base64, use Embeddings::create".into(),
        )),
    }
}
}

#[cfg(test)]
mod tests {
use crate::types::{CreateEmbeddingResponse, Embedding};
use crate::error::OpenAIError;
use crate::types::{CreateEmbeddingResponse, Embedding, EncodingFormat};
use crate::{types::CreateEmbeddingRequestArgs, Client};

#[tokio::test]
Expand Down Expand Up @@ -127,4 +152,56 @@ mod tests {
let Embedding { embedding, .. } = data.pop().unwrap();
assert_eq!(embedding.len(), dimensions as usize);
}

#[tokio::test]
async fn test_cannot_use_base64_encoding_with_normal_create_request() {
    // `create` must reject a request that asks for base64 output;
    // callers are expected to use `create_base64` for that format.
    const MODEL: &str = "text-embedding-ada-002";
    const INPUT: &str = "You shall not pass.";

    let client = Client::new();
    let request = CreateEmbeddingRequestArgs::default()
        .model(MODEL)
        .input(INPUT)
        .encoding_format(EncodingFormat::Base64)
        .build()
        .unwrap();

    let result = client.embeddings().create(request).await;
    assert!(matches!(result, Err(OpenAIError::InvalidArgument(_))));
}

#[tokio::test]
async fn test_embedding_create_base64() {
    // The base64 endpoint should decode to (approximately) the same
    // vector the plain float endpoint returns for identical input.
    const MODEL: &str = "text-embedding-ada-002";
    const INPUT: &str = "CoLoop will eat the other qual research tools...";

    let client = Client::new();

    let base64_request = CreateEmbeddingRequestArgs::default()
        .model(MODEL)
        .input(INPUT)
        .encoding_format(EncodingFormat::Base64)
        .build()
        .unwrap();
    let base64_response = client
        .embeddings()
        .create_base64(base64_request)
        .await
        .unwrap();
    let decoded: Vec<f32> = base64_response
        .data
        .into_iter()
        .next()
        .unwrap()
        .embedding
        .into();

    let plain_request = CreateEmbeddingRequestArgs::default()
        .model(MODEL)
        .input(INPUT)
        .build()
        .unwrap();
    let plain_response = client.embeddings().create(plain_request).await.unwrap();
    let expected = plain_response.data.into_iter().next().unwrap().embedding;

    assert_eq!(decoded.len(), expected.len());
    // Element-wise comparison with a small tolerance: base64 round-trips
    // raw f32 bytes, but the two API calls may differ by float noise.
    assert!(decoded
        .iter()
        .zip(expected.iter())
        .all(|(a, b)| (a - b).abs() < 1e-6));
}
}
38 changes: 38 additions & 0 deletions async-openai/src/types/embedding.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use base64::engine::{general_purpose, Engine};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -64,6 +65,32 @@ pub struct Embedding {
pub embedding: Vec<f32>,
}

/// A base64-encoded embedding vector as returned by the embeddings
/// endpoint when the request's `encoding_format` is base64.
///
/// Use the `From<Base64EmbeddingVector> for Vec<f32>` impl to decode
/// it into the underlying float values.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Base64EmbeddingVector(pub String);

impl From<Base64EmbeddingVector> for Vec<f32> {
    /// Decodes the base64 payload into `f32` values.
    ///
    /// The raw bytes are interpreted as a packed sequence of
    /// little-endian 4-byte floats, matching OpenAI's base64 encoding.
    ///
    /// # Panics
    ///
    /// Panics if the string is not valid standard base64.
    fn from(value: Base64EmbeddingVector) -> Self {
        let raw = general_purpose::STANDARD
            .decode(value.0)
            .expect("openai base64 encoding to be valid");
        raw.chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect()
    }
}

/// Represents a base64-encoded embedding vector returned by embedding endpoint.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Base64Embedding {
    /// The index of the embedding in the list of embeddings.
    pub index: u32,
    /// The object type, which is always "embedding".
    pub object: String,
    /// The embedding vector, encoded in base64.
    pub embedding: Base64EmbeddingVector,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct EmbeddingUsage {
/// The number of tokens used by the prompt.
Expand All @@ -82,3 +109,14 @@ pub struct CreateEmbeddingResponse {
/// The usage information for the request.
pub usage: EmbeddingUsage,
}

/// Response from the embeddings endpoint when the request asked for
/// base64-encoded output; each embedding arrives base64-encoded
/// rather than as a list of floats.
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateBase64EmbeddingResponse {
    /// The object type.
    pub object: String,
    /// The name of the model used to generate the embedding.
    pub model: String,
    /// The list of embeddings generated by the model.
    pub data: Vec<Base64Embedding>,
    /// The usage information for the request.
    pub usage: EmbeddingUsage,
}