99using System . IO ;
1010using System . Net . Http ;
1111using System . Text . RegularExpressions ;
12+ using System . Threading ;
1213using System . Threading . Tasks ;
1314
1415namespace Microsoft . ML . Tokenizers
@@ -346,32 +347,41 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo
/// <summary>Creates a <see cref="Tokenizer"/> for the given model name.</summary>
/// <param name="modelName">Model name</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
public static Task<Tokenizer> CreateByModelNameAsync(
    string modelName,
    IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
    Normalizer? normalizer = null,
    CancellationToken cancellationToken = default)
{
    // Synchronous faults (unknown model, lookup failures) are surfaced through the
    // returned Task rather than thrown directly, so callers awaiting the Task observe
    // them uniformly; hence the non-async wrapper with Task.FromException.
    try
    {
        ModelEncoding encoder;

        if (!_modelToEncoding.TryGetValue(modelName, out encoder))
        {
            // Fall back to prefix matching (e.g. "gpt-4-..." style names map by family).
            foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
            {
                if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
                {
                    encoder = Encoding;
                    break;
                }
            }
        }

        if (encoder == ModelEncoding.None)
        {
            throw new NotImplementedException($"Doesn't support this model [{modelName}]");
        }

        return CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer, cancellationToken);
    }
    catch (Exception ex)
    {
        return Task.FromException<Tokenizer>(ex);
    }
}
376386
377387 private const string Cl100kBaseRegexPattern = @"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" ;
@@ -402,36 +412,38 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
402412 /// <param name="modelEncoding">Encoder label</param>
403413 /// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
404414 /// <param name="normalizer">To normalize the text before tokenization</param>
415+ /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
405416 /// <returns>The tokenizer</returns>
406417 /// <exception cref="NotImplementedException">Throws if the encoder is not supported</exception>
407- private static async Task < Tokenizer > CreateByEncoderNameAsync (
418+ private static Task < Tokenizer > CreateByEncoderNameAsync (
408419 ModelEncoding modelEncoding ,
409420 IReadOnlyDictionary < string , int > ? extraSpecialTokens ,
410- Normalizer ? normalizer )
421+ Normalizer ? normalizer ,
422+ CancellationToken cancellationToken )
411423 {
412424 switch ( modelEncoding )
413425 {
414426 case ModelEncoding . Cl100kBase :
415427 var specialTokens = new Dictionary < string , int >
416428 { { EndOfText , 100257 } , { FimPrefix , 100258 } , { FimMiddle , 100259 } , { FimSuffix , 100260 } , { EndOfPrompt , 100276 } } ;
417- return await CreateTikTokenTokenizerAsync ( Cl100kBaseRegex ( ) , Cl100kBaseVocabUrl , specialTokens , extraSpecialTokens , normalizer ) . ConfigureAwait ( false ) ;
429+ return CreateTikTokenTokenizerAsync ( Cl100kBaseRegex ( ) , Cl100kBaseVocabUrl , specialTokens , extraSpecialTokens , normalizer , cancellationToken ) ;
418430
419431 case ModelEncoding . P50kBase :
420432 specialTokens = new Dictionary < string , int > { { EndOfText , 50256 } } ;
421- return await CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , P50RanksUrl , specialTokens , extraSpecialTokens , normalizer ) . ConfigureAwait ( false ) ;
433+ return CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , P50RanksUrl , specialTokens , extraSpecialTokens , normalizer , cancellationToken ) ;
422434
423435 case ModelEncoding . P50kEdit :
424436 specialTokens = new Dictionary < string , int >
425437 { { EndOfText , 50256 } , { FimPrefix , 50281 } , { FimMiddle , 50282 } , { FimSuffix , 50283 } } ;
426- return await CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , P50RanksUrl , specialTokens , extraSpecialTokens , normalizer ) . ConfigureAwait ( false ) ;
438+ return CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , P50RanksUrl , specialTokens , extraSpecialTokens , normalizer , cancellationToken ) ;
427439
428440 case ModelEncoding . R50kBase :
429441 specialTokens = new Dictionary < string , int > { { EndOfText , 50256 } } ;
430- return await CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , R50RanksUrl , specialTokens , extraSpecialTokens , normalizer ) . ConfigureAwait ( false ) ;
442+ return CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , R50RanksUrl , specialTokens , extraSpecialTokens , normalizer , cancellationToken ) ;
431443
432444 case ModelEncoding . GPT2 :
433445 specialTokens = new Dictionary < string , int > { { EndOfText , 50256 } , } ;
434- return await CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , GPT2Url , specialTokens , extraSpecialTokens , normalizer ) . ConfigureAwait ( false ) ;
446+ return CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , GPT2Url , specialTokens , extraSpecialTokens , normalizer , cancellationToken ) ;
435447
436448 default :
437449 Debug . Assert ( false , $ "Unexpected encoder [{ modelEncoding } ]") ;
@@ -449,13 +461,15 @@ private static async Task<Tokenizer> CreateByEncoderNameAsync(
449461 /// <param name="specialTokens">Special tokens mapping. This may be mutated by the method.</param>
450462 /// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
451463 /// <param name="normalizer">To normalize the text before tokenization</param>
464+ /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
452465 /// <returns>The tokenizer</returns>
453466 private static async Task < Tokenizer > CreateTikTokenTokenizerAsync (
454467 Regex regex ,
455468 string mergeableRanksFileUrl ,
456469 Dictionary < string , int > specialTokens ,
457470 IReadOnlyDictionary < string , int > ? extraSpecialTokens ,
458- Normalizer ? normalizer )
471+ Normalizer ? normalizer ,
472+ CancellationToken cancellationToken )
459473 {
460474 if ( extraSpecialTokens is not null )
461475 {
@@ -467,9 +481,9 @@ private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
467481
468482 if ( ! _tiktokenCache . TryGetValue ( mergeableRanksFileUrl , out ( Dictionary < ReadOnlyMemory < byte > , int > encoder , Dictionary < string , int > vocab , IReadOnlyDictionary < int , byte [ ] > decoder ) cache ) )
469483 {
470- using ( Stream stream = await _httpClient . GetStreamAsync ( mergeableRanksFileUrl ) . ConfigureAwait ( false ) )
484+ using ( Stream stream = await Helpers . GetStreamAsync ( _httpClient , mergeableRanksFileUrl , cancellationToken ) . ConfigureAwait ( false ) )
471485 {
472- cache = await Tiktoken . LoadTikTokenBpeAsync ( stream , useAsync : true ) . ConfigureAwait ( false ) ;
486+ cache = await Tiktoken . LoadTikTokenBpeAsync ( stream , useAsync : true , cancellationToken ) . ConfigureAwait ( false ) ;
473487 }
474488
475489 _tiktokenCache . TryAdd ( mergeableRanksFileUrl , cache ) ;
0 commit comments