@@ -290,9 +290,25 @@ public IReadOnlyList<int> BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds
290290 throw new ArgumentNullException ( nameof ( tokenIds0 ) ) ;
291291 }
292292
293- // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
294- int capacity = tokenIds0 . Count ( ) + 2 + ( tokenIds1 is null ? 0 : tokenIds1 . Count ( ) + 1 ) ;
295- List < int > ids = new List < int > ( capacity : capacity ) { ClsTokenId } ;
293+ List < int > ids ;
294+
295+ if ( tokenIds0 is ICollection < int > c1 )
296+ {
297+ int capacity = c1 . Count + 2 ; // Add 2 for the [CLS] and [SEP] tokens around tokenIds0.
298+
299+ if ( tokenIds1 is not null )
300+ {
301+ capacity += tokenIds1 is ICollection < int > c2 ? c2 . Count + 1 : c1 . Count + 1 ; // Add 1 for the trailing [SEP]; when tokenIds1 has no Count, its length is estimated from tokenIds0.
302+ }
303+
304+ ids = new ( capacity ) { ClsTokenId } ;
305+ }
306+ else
307+ {
308+ // Slow path: tokenIds0 exposes no Count, so start from a small fixed capacity and let the list grow.
309+ ids = new List < int > ( 10 ) { ClsTokenId } ;
310+ }
311+
296312 ids . AddRange ( tokenIds0 ) ;
297313 ids . Add ( SepTokenId ) ;
298314
@@ -323,29 +339,48 @@ public OperationStatus BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0,
323339 throw new ArgumentNullException ( nameof ( tokenIds0 ) ) ;
324340 }
325341
326- // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
327- int capacity = tokenIds0 . Count ( ) + 2 + ( tokenIds1 is null ? 0 : tokenIds1 . Count ( ) + 1 ) ;
328- if ( buffer . Length < capacity )
342+ written = 0 ;
343+ if ( buffer . Length < 1 )
329344 {
330- written = 0 ;
331345 return OperationStatus . DestinationTooSmall ;
332346 }
333347
334- written = 0 ;
335348 buffer [ written ++ ] = ClsTokenId ;
336349 foreach ( int id in tokenIds0 )
337350 {
351+ if ( buffer . Length <= written )
352+ {
353+ written = 0 ;
354+ return OperationStatus . DestinationTooSmall ;
355+ }
356+
338357 buffer [ written ++ ] = id ;
339358 }
359+
360+ if ( buffer . Length <= written )
361+ {
362+ written = 0 ;
363+ return OperationStatus . DestinationTooSmall ;
364+ }
340365 buffer [ written ++ ] = SepTokenId ;
341366
342367 if ( tokenIds1 is not null )
343368 {
344369 foreach ( int id in tokenIds1 )
345370 {
371+ if ( buffer . Length <= written )
372+ {
373+ written = 0 ;
374+ return OperationStatus . DestinationTooSmall ;
375+ }
346376 buffer [ written ++ ] = id ;
347377 }
348378
379+ if ( buffer . Length <= written )
380+ {
381+ written = 0 ;
382+ return OperationStatus . DestinationTooSmall ;
383+ }
349384 buffer [ written ++ ] = SepTokenId ;
350385 }
351386
@@ -367,11 +402,22 @@ public IReadOnlyList<int> GetSpecialTokensMask(IEnumerable<int> tokenIds0, IEnum
367402 throw new ArgumentNullException ( nameof ( tokenIds0 ) ) ;
368403 }
369404
370- int capacity = alreadyHasSpecialTokens ?
371- tokenIds0 . Count ( ) + ( tokenIds1 ? . Count ( ) ?? 0 ) :
372- tokenIds0 . Count ( ) + 2 + ( tokenIds1 is null ? 0 : 1 ) ; // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
405+ List < int > mask ;
406+ if ( tokenIds0 is ICollection < int > c1 )
407+ {
408+ int capacity = c1 . Count + 2 ; // Capacity hint: tokenIds0 plus the [CLS] and [SEP] entries.
409+
410+ if ( tokenIds1 is not null )
411+ {
412+ capacity += tokenIds1 is ICollection < int > c2 ? c2 . Count + 1 : c1 . Count + 1 ; // Plus the second [SEP]; when tokenIds1 has no Count, its length is estimated from tokenIds0.
413+ }
373414
374- List < int > mask = new List < int > ( capacity : capacity ) ;
415+ mask = new List < int > ( capacity ) ;
416+ }
417+ else
418+ {
419+ mask = new List < int > ( 10 ) ;
420+ }
375421
376422 if ( ! alreadyHasSpecialTokens )
377423 {
@@ -420,31 +466,49 @@ public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int
420466 throw new ArgumentNullException ( nameof ( tokenIds0 ) ) ;
421467 }
422468
423- int capacity = alreadyHasSpecialTokens ?
424- tokenIds0 . Count ( ) + ( tokenIds1 ? . Count ( ) ?? 0 ) :
425- tokenIds0 . Count ( ) + 2 + ( tokenIds1 is null ? 0 : tokenIds1 . Count ( ) + 1 ) ; // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
426-
427469 written = 0 ;
428- if ( buffer . Length < capacity )
429- {
430- return OperationStatus . DestinationTooSmall ;
431- }
432-
433470 if ( ! alreadyHasSpecialTokens )
434471 {
472+ if ( buffer . Length < 1 )
473+ {
474+ return OperationStatus . DestinationTooSmall ;
475+ }
435476 buffer [ written ++ ] = 1 ; // CLS
477+
436478 foreach ( int id in tokenIds0 )
437479 {
480+ if ( buffer . Length <= written )
481+ {
482+ written = 0 ;
483+ return OperationStatus . DestinationTooSmall ;
484+ }
438485 buffer [ written ++ ] = 0 ;
439486 }
487+
488+ if ( buffer . Length <= written )
489+ {
490+ written = 0 ;
491+ return OperationStatus . DestinationTooSmall ;
492+ }
440493 buffer [ written ++ ] = 1 ; // SEP
441494
442495 if ( tokenIds1 is not null )
443496 {
444497 foreach ( int id in tokenIds1 )
445498 {
499+ if ( buffer . Length <= written )
500+ {
501+ written = 0 ;
502+ return OperationStatus . DestinationTooSmall ;
503+ }
446504 buffer [ written ++ ] = 0 ;
447505 }
506+
507+ if ( buffer . Length <= written )
508+ {
509+ written = 0 ;
510+ return OperationStatus . DestinationTooSmall ;
511+ }
448512 buffer [ written ++ ] = 1 ; // SEP
449513 }
450514
@@ -453,13 +517,23 @@ public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int
453517
454518 foreach ( int id in tokenIds0 )
455519 {
520+ if ( buffer . Length <= written )
521+ {
522+ written = 0 ;
523+ return OperationStatus . DestinationTooSmall ;
524+ }
456525 buffer [ written ++ ] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0 ;
457526 }
458527
459528 if ( tokenIds1 is not null )
460529 {
461530 foreach ( int id in tokenIds1 )
462531 {
532+ if ( buffer . Length <= written )
533+ {
534+ written = 0 ;
535+ return OperationStatus . DestinationTooSmall ;
536+ }
463537 buffer [ written ++ ] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0 ;
464538 }
465539 }
@@ -484,21 +558,38 @@ public IReadOnlyList<int> CreateTokenTypeIdsFromSequences(IEnumerable<int> token
484558 throw new ArgumentNullException ( nameof ( tokenIds0 ) ) ;
485559 }
486560
487- // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
488- int capacity = tokenIds0 . Count ( ) + 2 + ( tokenIds1 is null ? 0 : tokenIds1 . Count ( ) + 1 ) ;
561+ List < int > typeIds ;
562+ if ( tokenIds0 is ICollection < int > c1 )
563+ {
564+ int capacity = c1 . Count + 2 ; // Add 2 for [CLS] and [SEP] tokens.
565+
566+ if ( tokenIds1 is not null )
567+ {
568+ capacity += tokenIds1 is ICollection < int > c2 ? c2 . Count + 1 : c1 . Count + 1 ;
569+ }
489570
490- List < int > typeIds = new List < int > ( capacity ) ;
491- for ( int i = 0 ; i < tokenIds0 . Count ( ) + 2 ; i ++ ) // Add 2 for [CLS] and [SEP] tokens.
571+ typeIds = new List < int > ( capacity ) ;
572+ }
573+ else
574+ {
575+ typeIds = new List < int > ( 10 ) ;
576+ }
577+
578+ foreach ( var id in tokenIds0 )
492579 {
493580 typeIds . Add ( 0 ) ;
494581 }
582+ typeIds . Add ( 0 ) ; // [CLS]
583+ typeIds . Add ( 0 ) ; // [SEP]
495584
496585 if ( tokenIds1 is not null )
497586 {
498- for ( int i = 0 ; i < tokenIds1 . Count ( ) + 1 ; i ++ ) // Add 1 for [SEP] token.
587+ foreach ( int id in tokenIds1 )
499588 {
500589 typeIds . Add ( 1 ) ;
501590 }
591+
592+ typeIds . Add ( 1 ) ; // [SEP]
502593 }
503594
504595 return typeIds ;
@@ -515,22 +606,40 @@ public OperationStatus CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds
515606
516607 // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
517608 int capacity = tokenIds0 . Count ( ) + 2 + ( tokenIds1 is null ? 0 : tokenIds1 . Count ( ) + 1 ) ;
518- if ( buffer . Length < capacity )
609+ if ( buffer . Length < 2 )
519610 {
520611 return OperationStatus . DestinationTooSmall ;
521612 }
613+ buffer [ written ++ ] = 0 ; // [CLS]
614+ buffer [ written ++ ] = 0 ; // [SEP]
522615
523- for ( int i = 0 ; i < tokenIds0 . Count ( ) + 2 ; i ++ ) // Add 2 for [CLS] and [SEP] tokens.
616+ foreach ( int id in tokenIds0 )
524617 {
618+ if ( buffer . Length <= written )
619+ {
620+ written = 0 ;
621+ return OperationStatus . DestinationTooSmall ;
622+ }
525623 buffer [ written ++ ] = 0 ;
526624 }
527625
528626 if ( tokenIds1 is not null )
529627 {
530- for ( int i = 0 ; i < tokenIds1 . Count ( ) + 1 ; i ++ ) // Add 1 for [SEP] token.
628+ foreach ( int id in tokenIds1 )
531629 {
630+ if ( buffer . Length <= written )
631+ {
632+ written = 0 ;
633+ return OperationStatus . DestinationTooSmall ;
634+ }
532635 buffer [ written ++ ] = 1 ;
533636 }
637+
638+ if ( buffer . Length <= written ) // One free slot is needed for the trailing [SEP]; written == Length is already out of range.
639+ {
640+ written = 0 ; return OperationStatus . DestinationTooSmall ; // Report nothing written, matching the other size checks in this method.
641+ }
642+ buffer [ written ++ ] = 1 ; // [SEP]
534643 }
535644
536645 return OperationStatus . Done ;
0 commit comments