Skip to content

Commit

Permalink
add support for base64 embeddings (#190)
Browse files Browse the repository at this point in the history
* add support for base64 embeddings

* Base64Embedding is an implementation detail

* feat: separate Embeddings::create_base64 method

* chore: use newtype for hosting base64 decoding instead

* chore: remove unused error variant
  • Loading branch information
adri1wald authored Mar 16, 2024
1 parent 208bc08 commit 032f58f
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 2 deletions.
81 changes: 79 additions & 2 deletions async-openai/src/embedding.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::{
config::Config,
error::OpenAIError,
types::{CreateEmbeddingRequest, CreateEmbeddingResponse},
types::{
CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
EncodingFormat,
},
Client,
};

Expand All @@ -23,13 +26,35 @@ impl<'c, C: Config> Embeddings<'c, C> {
&self,
request: CreateEmbeddingRequest,
) -> Result<CreateEmbeddingResponse, OpenAIError> {
if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
return Err(OpenAIError::InvalidArgument(
"When encoding_format is base64, use Embeddings::create_base64".into(),
));
}
self.client.post("/embeddings", request).await
}

/// Creates an embedding vector representing the input text.
///
/// The embedding in the response is returned base64-encoded, so the
/// request's `encoding_format` must be set to [`EncodingFormat::Base64`].
///
/// # Errors
///
/// Returns [`OpenAIError::InvalidArgument`] when `encoding_format` is
/// anything other than base64; use `Embeddings::create` for float output.
pub async fn create_base64(
    &self,
    request: CreateEmbeddingRequest,
) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
    match request.encoding_format {
        // Only base64-encoded responses deserialize into
        // `CreateBase64EmbeddingResponse`, so any other format is rejected
        // up front instead of failing during deserialization.
        Some(EncodingFormat::Base64) => self.client.post("/embeddings", request).await,
        _ => Err(OpenAIError::InvalidArgument(
            "When encoding_format is not base64, use Embeddings::create".into(),
        )),
    }
}
}

#[cfg(test)]
mod tests {
use crate::types::{CreateEmbeddingResponse, Embedding};
use crate::error::OpenAIError;
use crate::types::{CreateEmbeddingResponse, Embedding, EncodingFormat};
use crate::{types::CreateEmbeddingRequestArgs, Client};

#[tokio::test]
Expand Down Expand Up @@ -127,4 +152,56 @@ mod tests {
let Embedding { embedding, .. } = data.pop().unwrap();
assert_eq!(embedding.len(), dimensions as usize);
}

#[tokio::test]
async fn test_cannot_use_base64_encoding_with_normal_create_request() {
    // `Embeddings::create` must refuse a request that asks for base64
    // output — that combination is served by `create_base64` instead.
    let client = Client::new();

    let request = CreateEmbeddingRequestArgs::default()
        .model("text-embedding-ada-002")
        .input("You shall not pass.")
        .encoding_format(EncodingFormat::Base64)
        .build()
        .unwrap();

    let result = client.embeddings().create(request).await;
    assert!(matches!(result, Err(OpenAIError::InvalidArgument(_))));
}

#[tokio::test]
async fn test_embedding_create_base64() {
    // Fetch the same input twice — once base64-encoded, once as plain
    // floats — and check that the decoded vectors agree element-wise.
    // NOTE(review): this makes two live API calls; the endpoint is assumed
    // to be numerically stable across calls for identical input — the 1e-6
    // tolerance only absorbs tiny drift. Confirm if this test flakes.
    const MODEL: &str = "text-embedding-ada-002";
    const INPUT: &str = "CoLoop will eat the other qual research tools...";

    let client = Client::new();

    // Base64-encoded variant, decoded to floats via the newtype's `Into`.
    let base64_request = CreateEmbeddingRequestArgs::default()
        .model(MODEL)
        .input(INPUT)
        .encoding_format(EncodingFormat::Base64)
        .build()
        .unwrap();
    let base64_response = client
        .embeddings()
        .create_base64(base64_request)
        .await
        .unwrap();
    let first_b64 = base64_response.data.into_iter().next().unwrap();
    let decoded: Vec<f32> = first_b64.embedding.into();

    // Default (float) variant for the same input.
    let float_request = CreateEmbeddingRequestArgs::default()
        .model(MODEL)
        .input(INPUT)
        .build()
        .unwrap();
    let float_response = client.embeddings().create(float_request).await.unwrap();
    let floats = float_response.data.into_iter().next().unwrap().embedding;

    // Both encodings should describe the same vector.
    assert_eq!(decoded.len(), floats.len());
    for (decoded_value, float_value) in decoded.iter().zip(floats.iter()) {
        assert!((decoded_value - float_value).abs() < 1e-6);
    }
}
}
38 changes: 38 additions & 0 deletions async-openai/src/types/embedding.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use base64::engine::{general_purpose, Engine};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -64,6 +65,32 @@ pub struct Embedding {
pub embedding: Vec<f32>,
}

/// A base64-encoded embedding vector as received from the embeddings endpoint
/// when the request's `encoding_format` is base64.
///
/// Decode it into float components with `Vec::<f32>::from` (or `.into()`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Base64EmbeddingVector(pub String);

impl From<Base64EmbeddingVector> for Vec<f32> {
    /// Decodes the base64 payload into the embedding's `f32` components.
    ///
    /// The wire format is a base64 string wrapping consecutive 4-byte
    /// little-endian floats; any trailing bytes that do not fill a whole
    /// 4-byte chunk are ignored by `chunks_exact`.
    fn from(vector: Base64EmbeddingVector) -> Self {
        // A malformed payload is a server-side contract violation, hence
        // `expect` rather than a fallible conversion.
        let raw = general_purpose::STANDARD
            .decode(vector.0)
            .expect("openai base64 encoding to be valid");
        raw.chunks_exact(4)
            .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
            .collect()
    }
}

/// Represents a base64-encoded embedding vector returned by the embedding endpoint.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Base64Embedding {
    /// The index of the embedding in the list of embeddings.
    pub index: u32,
    /// The object type, which is always "embedding".
    pub object: String,
    /// The embedding vector, encoded in base64; convert to `Vec<f32>` via `From`/`Into`.
    pub embedding: Base64EmbeddingVector,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct EmbeddingUsage {
/// The number of tokens used by the prompt.
Expand All @@ -82,3 +109,14 @@ pub struct CreateEmbeddingResponse {
/// The usage information for the request.
pub usage: EmbeddingUsage,
}

/// Response to a `CreateEmbeddingRequest` whose `encoding_format` is base64,
/// as returned by `Embeddings::create_base64`. Mirrors `CreateEmbeddingResponse`
/// but carries base64-encoded vectors.
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateBase64EmbeddingResponse {
    // NOTE(review): presumably always "list", matching the float-encoded
    // response — confirm against the API reference.
    pub object: String,
    /// The name of the model used to generate the embedding.
    pub model: String,
    /// The list of embeddings generated by the model.
    pub data: Vec<Base64Embedding>,
    /// The usage information for the request.
    pub usage: EmbeddingUsage,
}

0 comments on commit 032f58f

Please sign in to comment.