11/// Payload tokenization logic
22use crate :: TextEmbeddingsError ;
3+ use std:: collections:: HashMap ;
34use tokenizers:: tokenizer:: Tokenizer ;
45pub use tokenizers:: Encoding as RawEncoding ;
56use tokenizers:: { TruncationDirection , TruncationParams , TruncationStrategy } ;
@@ -19,6 +20,8 @@ impl Tokenization {
1920 tokenizer : Tokenizer ,
2021 max_input_length : usize ,
2122 position_offset : usize ,
23+ default_prompt : Option < String > ,
24+ prompts : Option < HashMap < String , String > > ,
2225 ) -> Self {
2326 tracing:: info!( "Starting {workers} tokenization workers" ) ;
2427
@@ -29,12 +32,16 @@ impl Tokenization {
2932 for _ in 0 ..workers {
3033 let tokenizer_clone = tokenizer. clone ( ) ;
3134 let receiver_clone = receiver. clone ( ) ;
35+ let default_prompt_clone = default_prompt. clone ( ) ;
36+ let prompts_clone = prompts. clone ( ) ;
3237 // Spawn worker
3338 std:: thread:: spawn ( move || {
3439 tokenizer_worker (
3540 tokenizer_clone,
3641 max_input_length,
3742 position_offset,
43+ default_prompt_clone,
44+ prompts_clone,
3845 receiver_clone,
3946 )
4047 } ) ;
@@ -49,6 +56,7 @@ impl Tokenization {
4956 inputs : EncodingInput ,
5057 truncate : bool ,
5158 truncation_direction : TruncationDirection ,
59+ prompt_name : Option < String > ,
5260 ) -> Result < ValidEncoding , TextEmbeddingsError > {
5361 // Check if inputs is empty
5462 if inputs. is_empty ( ) {
@@ -66,6 +74,7 @@ impl Tokenization {
6674 inputs,
6775 truncate,
6876 truncation_direction,
77+ prompt_name,
6978 response_sender,
7079 Span :: current ( ) ,
7180 ) )
@@ -82,7 +91,8 @@ impl Tokenization {
8291 & self ,
8392 inputs : EncodingInput ,
8493 add_special_tokens : bool ,
85- ) -> Result < RawEncoding , TextEmbeddingsError > {
94+ prompt_name : Option < String > ,
95+ ) -> Result < ( Option < String > , RawEncoding ) , TextEmbeddingsError > {
8696 // Check if inputs is empty
8797 if inputs. is_empty ( ) {
8898 return Err ( TextEmbeddingsError :: Validation (
@@ -98,6 +108,7 @@ impl Tokenization {
98108 . send ( TokenizerRequest :: Tokenize (
99109 inputs,
100110 add_special_tokens,
111+ prompt_name,
101112 response_sender,
102113 Span :: current ( ) ,
103114 ) )
@@ -147,6 +158,8 @@ fn tokenizer_worker(
147158 mut tokenizer : Tokenizer ,
148159 max_input_length : usize ,
149160 position_offset : usize ,
161+ default_prompt : Option < String > ,
162+ prompts : Option < HashMap < String , String > > ,
150163 receiver : async_channel:: Receiver < TokenizerRequest > ,
151164) {
152165 // Loop over requests
@@ -156,11 +169,17 @@ fn tokenizer_worker(
156169 inputs,
157170 truncate,
158171 truncation_direction,
172+ prompt_name,
159173 response_tx,
160174 parent_span,
161175 ) => {
162176 parent_span. in_scope ( || {
163177 if !response_tx. is_closed ( ) {
178+ let default_prompt_clone = match prompt_name {
179+ None => default_prompt. clone ( ) ,
180+ Some ( _) => None ,
181+ } ;
182+
164183 // It's possible that the user dropped its request resulting in a send error.
165184 // We just discard the error
166185 let _ = response_tx. send ( encode_input (
@@ -169,20 +188,37 @@ fn tokenizer_worker(
169188 truncation_direction,
170189 max_input_length,
171190 position_offset,
191+ default_prompt_clone,
192+ prompt_name,
193+ prompts. as_ref ( ) ,
172194 & mut tokenizer,
173195 ) ) ;
174196 }
175197 } )
176198 }
177- TokenizerRequest :: Tokenize ( inputs, add_special_tokens, response_tx, parent_span) => {
199+ TokenizerRequest :: Tokenize (
200+ inputs,
201+ add_special_tokens,
202+ prompt_name,
203+ response_tx,
204+ parent_span,
205+ ) => {
178206 parent_span. in_scope ( || {
179207 if !response_tx. is_closed ( ) {
208+ let default_prompt_clone = match prompt_name {
209+ None => default_prompt. clone ( ) ,
210+ Some ( _) => None ,
211+ } ;
212+
180213 // It's possible that the user dropped its request resulting in a send error.
181214 // We just discard the error
182215 let _ = response_tx. send ( tokenize_input (
183216 inputs,
184217 add_special_tokens,
185218 None ,
219+ default_prompt_clone,
220+ prompt_name,
221+ prompts. as_ref ( ) ,
186222 & mut tokenizer,
187223 ) ) ;
188224 }
@@ -212,40 +248,104 @@ fn decode_ids(
212248 . decode ( & ids, skip_special_tokens) ?)
213249}
214250
251+ fn prepare_pre_prompt (
252+ default_prompt : Option < String > ,
253+ prompt_name : Option < String > ,
254+ prompts : Option < & HashMap < String , String > > ,
255+ ) -> Result < Option < String > , TextEmbeddingsError > {
256+ let pre_prompt = if let Some ( prompt_name) = prompt_name. as_ref ( ) {
257+ match prompts {
258+ None => {
259+ return Err ( TextEmbeddingsError :: Validation ( format ! ( "`default-prompt-name` is set to `{prompt_name}` but no prompts were found in the Sentence Transformers configuration" ) ) ) ;
260+ }
261+ Some ( prompts) if !prompts. contains_key ( prompt_name) => {
262+ return Err ( TextEmbeddingsError :: Validation ( format ! ( "`default-prompt-name` is set to `{prompt_name}` but it was not found in the Sentence Transformers prompts. Available prompts: {:?}" , prompts. keys( ) ) ) ) ;
263+ }
264+ Some ( prompts) => prompts. get ( prompt_name) . cloned ( ) ,
265+ }
266+ } else {
267+ default_prompt
268+ } ;
269+ Ok ( pre_prompt)
270+ }
271+
215272fn tokenize_input (
216273 inputs : EncodingInput ,
217274 add_special_tokens : bool ,
218275 truncate_params : Option < TruncationParams > ,
276+ default_prompt : Option < String > ,
277+ prompt_name : Option < String > ,
278+ prompts : Option < & HashMap < String , String > > ,
219279 tokenizer : & mut Tokenizer ,
220- ) -> Result < RawEncoding , TextEmbeddingsError > {
280+ ) -> Result < ( Option < String > , RawEncoding ) , TextEmbeddingsError > {
281+ let pre_prompt = prepare_pre_prompt ( default_prompt, prompt_name, prompts) ?;
282+
221283 let encoding = match inputs {
222284 // encode input
223- EncodingInput :: Single ( s) => tokenizer
224- . with_truncation ( truncate_params) ?
225- . encode :: < String > ( s, add_special_tokens) ?,
226- EncodingInput :: Dual ( s1, s2) => {
227- tokenizer
285+ EncodingInput :: Single ( s) => {
286+ let s = if let Some ( mut pre_prompt) = pre_prompt {
287+ pre_prompt. push_str ( & s) ;
288+ pre_prompt
289+ } else {
290+ s
291+ } ;
292+
293+ let encoding = tokenizer
228294 . with_truncation ( truncate_params) ?
229- . encode :: < ( String , String ) > ( ( s1, s2) , add_special_tokens) ?
295+ . encode :: < & str > ( & s, add_special_tokens) ?;
296+
297+ ( Some ( s) , encoding)
298+ }
299+ EncodingInput :: Dual ( s1, s2) => {
300+ if pre_prompt. is_some ( ) {
301+ return Err ( TextEmbeddingsError :: Validation (
302+ "`prompt_name` cannot be set with dual inputs" . to_string ( ) ,
303+ ) ) ;
304+ }
305+
306+ (
307+ None ,
308+ tokenizer
309+ . with_truncation ( truncate_params) ?
310+ . encode :: < ( String , String ) > ( ( s1, s2) , add_special_tokens) ?,
311+ )
230312 }
231313 // input is encoded -> convert to tokenizers Encoding
232314 EncodingInput :: Ids ( ids) => {
233- let text = tokenizer. decode ( & ids, false ) ?;
234- tokenizer
235- . with_truncation ( truncate_params) ?
236- . encode :: < String > ( text, false ) ?
315+ if let Some ( mut pre_prompt) = pre_prompt {
316+ let text = tokenizer. decode ( & ids, true ) ?;
317+ pre_prompt. push_str ( & text) ;
318+
319+ let encoding = tokenizer
320+ . with_truncation ( truncate_params) ?
321+ . encode :: < & str > ( & pre_prompt, true ) ?;
322+
323+ ( Some ( pre_prompt) , encoding)
324+ } else {
325+ let text = tokenizer. decode ( & ids, false ) ?;
326+
327+ let encoding = tokenizer
328+ . with_truncation ( truncate_params) ?
329+ . encode :: < & str > ( & text, false ) ?;
330+
331+ ( Some ( text) , encoding)
332+ }
237333 }
238334 } ;
239335 Ok ( encoding)
240336}
241337
242338/// Get input length and optionally truncate it
339+ #[ allow( clippy:: too_many_arguments) ]
243340fn encode_input (
244341 inputs : EncodingInput ,
245342 truncate : bool ,
246343 truncation_direction : TruncationDirection ,
247344 max_input_length : usize ,
248345 position_offset : usize ,
346+ default_prompt : Option < String > ,
347+ prompt_name : Option < String > ,
348+ prompts : Option < & HashMap < String , String > > ,
249349 tokenizer : & mut Tokenizer ,
250350) -> Result < ValidEncoding , TextEmbeddingsError > {
251351 // Default truncation params
@@ -256,7 +356,15 @@ fn encode_input(
256356 stride : 0 ,
257357 } ) ;
258358
259- let encoding = tokenize_input ( inputs, true , truncate_params, tokenizer) ?;
359+ let ( _, encoding) = tokenize_input (
360+ inputs,
361+ true ,
362+ truncate_params,
363+ default_prompt,
364+ prompt_name,
365+ prompts,
366+ tokenizer,
367+ ) ?;
260368 let seq_len = encoding. len ( ) ;
261369
262370 if seq_len > max_input_length {
@@ -315,13 +423,15 @@ enum TokenizerRequest {
315423 EncodingInput ,
316424 bool ,
317425 TruncationDirection ,
426+ Option < String > ,
318427 oneshot:: Sender < Result < ValidEncoding , TextEmbeddingsError > > ,
319428 Span ,
320429 ) ,
321430 Tokenize (
322431 EncodingInput ,
323432 bool ,
324- oneshot:: Sender < Result < RawEncoding , TextEmbeddingsError > > ,
433+ Option < String > ,
434+ oneshot:: Sender < Result < ( Option < String > , RawEncoding ) , TextEmbeddingsError > > ,
325435 Span ,
326436 ) ,
327437 Decode (
0 commit comments