11use crate :: {
22 config:: Config ,
33 error:: OpenAIError ,
4- types:: { CreateEmbeddingRequest , CreateEmbeddingResponse } ,
4+ types:: {
5+ CreateBase64EmbeddingResponse , CreateEmbeddingRequest , CreateEmbeddingResponse ,
6+ EncodingFormat ,
7+ } ,
58 Client ,
69} ;
710
@@ -23,13 +26,35 @@ impl<'c, C: Config> Embeddings<'c, C> {
2326 & self ,
2427 request : CreateEmbeddingRequest ,
2528 ) -> Result < CreateEmbeddingResponse , OpenAIError > {
29+ if matches ! ( request. encoding_format, Some ( EncodingFormat :: Base64 ) ) {
30+ return Err ( OpenAIError :: InvalidArgument (
31+ "When encoding_format is base64, use Embeddings::create_base64" . into ( ) ,
32+ ) ) ;
33+ }
34+ self . client . post ( "/embeddings" , request) . await
35+ }
36+
37+ /// Creates an embedding vector representing the input text.
38+ ///
39+ /// The response will contain the embedding in base64 format.
40+ pub async fn create_base64 (
41+ & self ,
42+ request : CreateEmbeddingRequest ,
43+ ) -> Result < CreateBase64EmbeddingResponse , OpenAIError > {
44+ if !matches ! ( request. encoding_format, Some ( EncodingFormat :: Base64 ) ) {
45+ return Err ( OpenAIError :: InvalidArgument (
46+ "When encoding_format is not base64, use Embeddings::create" . into ( ) ,
47+ ) ) ;
48+ }
49+
2650 self . client . post ( "/embeddings" , request) . await
2751 }
2852}
2953
3054#[ cfg( test) ]
3155mod tests {
32- use crate :: types:: { CreateEmbeddingResponse , Embedding } ;
56+ use crate :: error:: OpenAIError ;
57+ use crate :: types:: { CreateEmbeddingResponse , Embedding , EncodingFormat } ;
3358 use crate :: { types:: CreateEmbeddingRequestArgs , Client } ;
3459
3560 #[ tokio:: test]
@@ -127,4 +152,56 @@ mod tests {
127152 let Embedding { embedding, .. } = data. pop ( ) . unwrap ( ) ;
128153 assert_eq ! ( embedding. len( ) , dimensions as usize ) ;
129154 }
155+
156+ #[ tokio:: test]
157+ async fn test_cannot_use_base64_encoding_with_normal_create_request ( ) {
158+ let client = Client :: new ( ) ;
159+
160+ const MODEL : & str = "text-embedding-ada-002" ;
161+ const INPUT : & str = "You shall not pass." ;
162+
163+ let b64_request = CreateEmbeddingRequestArgs :: default ( )
164+ . model ( MODEL )
165+ . input ( INPUT )
166+ . encoding_format ( EncodingFormat :: Base64 )
167+ . build ( )
168+ . unwrap ( ) ;
169+ let b64_response = client. embeddings ( ) . create ( b64_request) . await ;
170+ assert ! ( matches!( b64_response, Err ( OpenAIError :: InvalidArgument ( _) ) ) ) ;
171+ }
172+
173+ #[ tokio:: test]
174+ async fn test_embedding_create_base64 ( ) {
175+ let client = Client :: new ( ) ;
176+
177+ const MODEL : & str = "text-embedding-ada-002" ;
178+ const INPUT : & str = "CoLoop will eat the other qual research tools..." ;
179+
180+ let b64_request = CreateEmbeddingRequestArgs :: default ( )
181+ . model ( MODEL )
182+ . input ( INPUT )
183+ . encoding_format ( EncodingFormat :: Base64 )
184+ . build ( )
185+ . unwrap ( ) ;
186+ let b64_response = client
187+ . embeddings ( )
188+ . create_base64 ( b64_request)
189+ . await
190+ . unwrap ( ) ;
191+ let b64_embedding = b64_response. data . into_iter ( ) . next ( ) . unwrap ( ) . embedding ;
192+ let b64_embedding: Vec < f32 > = b64_embedding. into ( ) ;
193+
194+ let request = CreateEmbeddingRequestArgs :: default ( )
195+ . model ( MODEL )
196+ . input ( INPUT )
197+ . build ( )
198+ . unwrap ( ) ;
199+ let response = client. embeddings ( ) . create ( request) . await . unwrap ( ) ;
200+ let embedding = response. data . into_iter ( ) . next ( ) . unwrap ( ) . embedding ;
201+
202+ assert_eq ! ( b64_embedding. len( ) , embedding. len( ) ) ;
203+ for ( b64, normal) in b64_embedding. iter ( ) . zip ( embedding. iter ( ) ) {
204+ assert ! ( ( b64 - normal) . abs( ) < 1e-6 ) ;
205+ }
206+ }
130207}
0 commit comments