Skip to content

Commit d192e6b

Browse files
authored
add support for base64 embeddings (#190)
* add support for base64 embeddings
* Base64Embedding is an implementation detail
* feat: separate Embeddings::create_base64 method
* chore: use newtype for hosting base64 decoding instead
* chore: remove unused error variant
1 parent 49a29a9 commit d192e6b

File tree

2 files changed

+117
-2
lines changed

2 files changed

+117
-2
lines changed

async-openai/src/embedding.rs

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use crate::{
22
config::Config,
33
error::OpenAIError,
4-
types::{CreateEmbeddingRequest, CreateEmbeddingResponse},
4+
types::{
5+
CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
6+
EncodingFormat,
7+
},
58
Client,
69
};
710

@@ -23,13 +26,35 @@ impl<'c, C: Config> Embeddings<'c, C> {
2326
&self,
2427
request: CreateEmbeddingRequest,
2528
) -> Result<CreateEmbeddingResponse, OpenAIError> {
29+
if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
30+
return Err(OpenAIError::InvalidArgument(
31+
"When encoding_format is base64, use Embeddings::create_base64".into(),
32+
));
33+
}
34+
self.client.post("/embeddings", request).await
35+
}
36+
37+
/// Creates an embedding vector representing the input text.
38+
///
39+
/// The response will contain the embedding in base64 format.
40+
pub async fn create_base64(
41+
&self,
42+
request: CreateEmbeddingRequest,
43+
) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
44+
if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
45+
return Err(OpenAIError::InvalidArgument(
46+
"When encoding_format is not base64, use Embeddings::create".into(),
47+
));
48+
}
49+
2650
self.client.post("/embeddings", request).await
2751
}
2852
}
2953

3054
#[cfg(test)]
3155
mod tests {
32-
use crate::types::{CreateEmbeddingResponse, Embedding};
56+
use crate::error::OpenAIError;
57+
use crate::types::{CreateEmbeddingResponse, Embedding, EncodingFormat};
3358
use crate::{types::CreateEmbeddingRequestArgs, Client};
3459

3560
#[tokio::test]
@@ -127,4 +152,56 @@ mod tests {
127152
let Embedding { embedding, .. } = data.pop().unwrap();
128153
assert_eq!(embedding.len(), dimensions as usize);
129154
}
155+
156+
#[tokio::test]
157+
async fn test_cannot_use_base64_encoding_with_normal_create_request() {
158+
let client = Client::new();
159+
160+
const MODEL: &str = "text-embedding-ada-002";
161+
const INPUT: &str = "You shall not pass.";
162+
163+
let b64_request = CreateEmbeddingRequestArgs::default()
164+
.model(MODEL)
165+
.input(INPUT)
166+
.encoding_format(EncodingFormat::Base64)
167+
.build()
168+
.unwrap();
169+
let b64_response = client.embeddings().create(b64_request).await;
170+
assert!(matches!(b64_response, Err(OpenAIError::InvalidArgument(_))));
171+
}
172+
173+
#[tokio::test]
174+
async fn test_embedding_create_base64() {
175+
let client = Client::new();
176+
177+
const MODEL: &str = "text-embedding-ada-002";
178+
const INPUT: &str = "CoLoop will eat the other qual research tools...";
179+
180+
let b64_request = CreateEmbeddingRequestArgs::default()
181+
.model(MODEL)
182+
.input(INPUT)
183+
.encoding_format(EncodingFormat::Base64)
184+
.build()
185+
.unwrap();
186+
let b64_response = client
187+
.embeddings()
188+
.create_base64(b64_request)
189+
.await
190+
.unwrap();
191+
let b64_embedding = b64_response.data.into_iter().next().unwrap().embedding;
192+
let b64_embedding: Vec<f32> = b64_embedding.into();
193+
194+
let request = CreateEmbeddingRequestArgs::default()
195+
.model(MODEL)
196+
.input(INPUT)
197+
.build()
198+
.unwrap();
199+
let response = client.embeddings().create(request).await.unwrap();
200+
let embedding = response.data.into_iter().next().unwrap().embedding;
201+
202+
assert_eq!(b64_embedding.len(), embedding.len());
203+
for (b64, normal) in b64_embedding.iter().zip(embedding.iter()) {
204+
assert!((b64 - normal).abs() < 1e-6);
205+
}
206+
}
130207
}

async-openai/src/types/embedding.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use base64::engine::{general_purpose, Engine};
12
use derive_builder::Builder;
23
use serde::{Deserialize, Serialize};
34

@@ -64,6 +65,32 @@ pub struct Embedding {
6465
pub embedding: Vec<f32>,
6566
}
6667

68+
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
69+
pub struct Base64EmbeddingVector(pub String);
70+
71+
impl From<Base64EmbeddingVector> for Vec<f32> {
72+
fn from(value: Base64EmbeddingVector) -> Self {
73+
let bytes = general_purpose::STANDARD
74+
.decode(value.0)
75+
.expect("openai base64 encoding to be valid");
76+
let chunks = bytes.chunks_exact(4);
77+
chunks
78+
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
79+
.collect()
80+
}
81+
}
82+
83+
/// Represents a base64-encoded embedding vector returned by the embedding endpoint.
84+
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
85+
pub struct Base64Embedding {
86+
/// The index of the embedding in the list of embeddings.
87+
pub index: u32,
88+
/// The object type, which is always "embedding".
89+
pub object: String,
90+
/// The embedding vector, encoded in base64.
91+
pub embedding: Base64EmbeddingVector,
92+
}
93+
6794
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
6895
pub struct EmbeddingUsage {
6996
/// The number of tokens used by the prompt.
@@ -82,3 +109,14 @@ pub struct CreateEmbeddingResponse {
82109
/// The usage information for the request.
83110
pub usage: EmbeddingUsage,
84111
}
112+
113+
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
114+
pub struct CreateBase64EmbeddingResponse {
115+
pub object: String,
116+
/// The name of the model used to generate the embedding.
117+
pub model: String,
118+
/// The list of embeddings generated by the model.
119+
pub data: Vec<Base64Embedding>,
120+
/// The usage information for the request.
121+
pub usage: EmbeddingUsage,
122+
}

0 commit comments

Comments (0)