From 032f58f6d141ebf529c9569c7a7829cf0e154f01 Mon Sep 17 00:00:00 2001
From: Adrien Wald
Date: Sat, 16 Mar 2024 22:14:23 +0000
Subject: [PATCH] add support for base64 embeddings (#190)

* add support for base64 embeddings
* Base64Embedding is an implementation detail
* feat: separate Embeddings::create_base64 method
* chore: use newtype for hosting base64 decoding instead
* chore: remove unused error variant
---
 async-openai/src/embedding.rs       | 81 ++++++++++++++++++++++++++++-
 async-openai/src/types/embedding.rs | 38 ++++++++++++++
 2 files changed, 117 insertions(+), 2 deletions(-)

diff --git a/async-openai/src/embedding.rs b/async-openai/src/embedding.rs
index 1d886e00..70d027eb 100644
--- a/async-openai/src/embedding.rs
+++ b/async-openai/src/embedding.rs
@@ -1,7 +1,10 @@
 use crate::{
     config::Config,
     error::OpenAIError,
-    types::{CreateEmbeddingRequest, CreateEmbeddingResponse},
+    types::{
+        CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
+        EncodingFormat,
+    },
     Client,
 };
 
@@ -23,13 +26,35 @@ impl<'c, C: Config> Embeddings<'c, C> {
         &self,
         request: CreateEmbeddingRequest,
     ) -> Result<CreateEmbeddingResponse, OpenAIError> {
+        if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
+            return Err(OpenAIError::InvalidArgument(
+                "When encoding_format is base64, use Embeddings::create_base64".into(),
+            ));
+        }
         self.client.post("/embeddings", request).await
     }
+
+    /// Creates an embedding vector representing the input text.
+    ///
+    /// The response will contain the embedding in base64 format.
+    pub async fn create_base64(
+        &self,
+        request: CreateEmbeddingRequest,
+    ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
+        if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
+            return Err(OpenAIError::InvalidArgument(
+                "When encoding_format is not base64, use Embeddings::create".into(),
+            ));
+        }
+        self.client.post("/embeddings", request).await
+    }
 }
 
 #[cfg(test)]
 mod tests {
-    use crate::types::{CreateEmbeddingResponse, Embedding};
+    use crate::error::OpenAIError;
+    use crate::types::{CreateEmbeddingResponse, Embedding, EncodingFormat};
     use crate::{types::CreateEmbeddingRequestArgs, Client};
 
     #[tokio::test]
@@ -127,4 +152,56 @@ mod tests {
         let Embedding { embedding, .. } = data.pop().unwrap();
         assert_eq!(embedding.len(), dimensions as usize);
     }
+
+    #[tokio::test]
+    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
+        let client = Client::new();
+
+        const MODEL: &str = "text-embedding-ada-002";
+        const INPUT: &str = "You shall not pass.";
+
+        let b64_request = CreateEmbeddingRequestArgs::default()
+            .model(MODEL)
+            .input(INPUT)
+            .encoding_format(EncodingFormat::Base64)
+            .build()
+            .unwrap();
+        let b64_response = client.embeddings().create(b64_request).await;
+        assert!(matches!(b64_response, Err(OpenAIError::InvalidArgument(_))));
+    }
+
+    #[tokio::test]
+    async fn test_embedding_create_base64() {
+        let client = Client::new();
+
+        const MODEL: &str = "text-embedding-ada-002";
+        const INPUT: &str = "CoLoop will eat the other qual research tools...";
+
+        let b64_request = CreateEmbeddingRequestArgs::default()
+            .model(MODEL)
+            .input(INPUT)
+            .encoding_format(EncodingFormat::Base64)
+            .build()
+            .unwrap();
+        let b64_response = client
+            .embeddings()
+            .create_base64(b64_request)
+            .await
+            .unwrap();
+        let b64_embedding = b64_response.data.into_iter().next().unwrap().embedding;
+        let b64_embedding: Vec<f32> = b64_embedding.into();
+
+        let request = CreateEmbeddingRequestArgs::default()
+            .model(MODEL)
+            .input(INPUT)
+            .build()
+            .unwrap();
+        let response = client.embeddings().create(request).await.unwrap();
+        let embedding = response.data.into_iter().next().unwrap().embedding;
+
+        assert_eq!(b64_embedding.len(), embedding.len());
+        for (b64, normal) in b64_embedding.iter().zip(embedding.iter()) {
+            assert!((b64 - normal).abs() < 1e-6);
+        }
+    }
 }

diff --git a/async-openai/src/types/embedding.rs b/async-openai/src/types/embedding.rs
index 295bc480..5f2ad73d 100644
--- a/async-openai/src/types/embedding.rs
+++ b/async-openai/src/types/embedding.rs
@@ -1,3 +1,4 @@
+use base64::engine::{general_purpose, Engine};
 use derive_builder::Builder;
 use serde::{Deserialize, Serialize};
 
@@ -64,6 +65,32 @@ pub struct Embedding {
     pub embedding: Vec<f32>,
 }
 
+#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
+pub struct Base64EmbeddingVector(pub String);
+
+impl From<Base64EmbeddingVector> for Vec<f32> {
+    fn from(value: Base64EmbeddingVector) -> Self {
+        let bytes = general_purpose::STANDARD
+            .decode(value.0)
+            .expect("openai base64 encoding to be valid");
+        let chunks = bytes.chunks_exact(4);
+        chunks
+            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
+            .collect()
+    }
+}
+
+/// Represents a base64-encoded embedding vector returned by the embedding endpoint.
+#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
+pub struct Base64Embedding {
+    /// The index of the embedding in the list of embeddings.
+    pub index: u32,
+    /// The object type, which is always "embedding".
+    pub object: String,
+    /// The embedding vector, encoded in base64.
+    pub embedding: Base64EmbeddingVector,
+}
+
 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
 pub struct EmbeddingUsage {
     /// The number of tokens used by the prompt.
@@ -82,3 +109,14 @@ pub struct CreateEmbeddingResponse {
     /// The usage information for the request.
     pub usage: EmbeddingUsage,
 }
+
+#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
+pub struct CreateBase64EmbeddingResponse {
+    pub object: String,
+    /// The name of the model used to generate the embedding.
+    pub model: String,
+    /// The list of embeddings generated by the model.
+    pub data: Vec<Base64Embedding>,
+    /// The usage information for the request.
+    pub usage: EmbeddingUsage,
+}
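
Usage note (illustrative, not part of the patch): a minimal caller-side sketch of the new
Embeddings::create_base64 method and the Base64EmbeddingVector -> Vec<f32> conversion it
introduces. It assumes the async-openai crate with this patch applied, a tokio runtime with
the "macros" feature, and an OPENAI_API_KEY in the environment; the model name is taken from
the tests above, while the input string is made up.

use async_openai::{
    types::{CreateEmbeddingRequestArgs, EncodingFormat},
    Client,
};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::new();

    // Ask for base64-encoded embeddings; create_base64 rejects any other encoding_format.
    let request = CreateEmbeddingRequestArgs::default()
        .model("text-embedding-ada-002")
        .input("An example sentence to embed.")
        .encoding_format(EncodingFormat::Base64)
        .build()?;

    let response = client.embeddings().create_base64(request).await?;

    // Each Base64EmbeddingVector decodes to little-endian f32s via the From impl in the patch.
    for item in response.data {
        let vector: Vec<f32> = item.embedding.into();
        println!("embedding {} has {} dimensions", item.index, vector.len());
    }

    Ok(())
}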