1+ // Licensed to the .NET Foundation under one or more agreements.
2+ // The .NET Foundation licenses this file to you under the MIT license.
3+ // See the LICENSE file in the project root for more information.
4+
5+ using System ;
6+ using System . Buffers ;
7+ using System . Diagnostics ;
8+ using System . Globalization ;
9+ using System . Runtime . CompilerServices ;
10+ using System . Runtime . InteropServices ;
11+ using System . Text ;
12+
13+ namespace Microsoft . ML . Tokenizers
14+ {
15+ /// <summary>
16+ /// Normalizer that performs the Bert model normalization.
17+ /// </summary>
18+ internal sealed class BertNormalizer : Normalizer
19+ {
20+ private readonly bool _doLowerCase ;
21+ private readonly bool _tokenizeChineseChars ;
22+ private readonly bool _stripAccents ;
23+
24+ /// <summary>
25+ /// Normalize the input string.
26+ /// </summary>
27+ /// <param name="original">The input string to normalize.</param>
28+ /// <returns>The normalized string.</returns>
29+ public override string Normalize ( string original )
30+ {
31+ if ( string . IsNullOrEmpty ( original ) )
32+ {
33+ return string . Empty ;
34+ }
35+
36+ if ( _stripAccents )
37+ {
38+ original = original . Normalize ( NormalizationForm . FormD ) ;
39+ }
40+
41+ Span < char > casingBuffer = stackalloc char [ 10 ] ;
42+ char [ ] buffer = ArrayPool < char > . Shared . Rent ( original . Length ) ;
43+ int index = 0 ;
44+
45+ for ( int i = 0 ; i < original . Length ; i ++ )
46+ {
47+ char c = original [ i ] ;
48+
49+ if ( c == '\u0000 ' || c == '\uFFFD ' )
50+ {
51+ continue ;
52+ }
53+
54+ int inc = 0 ;
55+ int codePoint = ( int ) c ;
56+ if ( char . IsHighSurrogate ( c ) && i + 1 < original . Length && char . IsLowSurrogate ( original [ i + 1 ] ) )
57+ {
58+ codePoint = char . ConvertToUtf32 ( c , original [ i + 1 ] ) ;
59+ inc = 1 ;
60+ }
61+
62+ UnicodeCategory category = CharUnicodeInfo . GetUnicodeCategory ( original , i ) ;
63+
64+ if ( category == UnicodeCategory . Control )
65+ {
66+ i += inc ;
67+ continue ;
68+ }
69+
70+ if ( category == UnicodeCategory . SpaceSeparator )
71+ {
72+ InsertChar ( ref buffer , ref index , ' ' ) ;
73+ i += inc ;
74+ continue ;
75+ }
76+
77+ if ( _stripAccents && category is UnicodeCategory . NonSpacingMark or UnicodeCategory . SpacingCombiningMark )
78+ {
79+ i += inc ;
80+ continue ;
81+ }
82+
83+ if ( _doLowerCase && category == UnicodeCategory . UppercaseLetter )
84+ {
85+ int length = original . AsSpan ( ) . Slice ( i , inc + 1 ) . ToLowerInvariant ( casingBuffer ) ;
86+ Debug . Assert ( length > 0 ) ;
87+
88+ InsertSpan ( ref buffer , ref index , casingBuffer . Slice ( 0 , length ) ) ;
89+
90+ i += inc ;
91+ continue ;
92+ }
93+
94+ if ( _tokenizeChineseChars && IsChineseChar ( codePoint ) )
95+ {
96+ InsertChar ( ref buffer , ref index , ' ' ) ;
97+ InsertChar ( ref buffer , ref index , c ) ;
98+ if ( inc > 0 )
99+ {
100+ InsertChar ( ref buffer , ref index , original [ i + 1 ] ) ;
101+ }
102+ InsertChar ( ref buffer , ref index , ' ' ) ;
103+
104+ i += inc ;
105+ continue ;
106+ }
107+
108+ InsertChar ( ref buffer , ref index , c ) ;
109+ if ( inc > 0 )
110+ {
111+ InsertChar ( ref buffer , ref index , original [ i + 1 ] ) ;
112+ }
113+ i += inc ;
114+ }
115+
116+ string result = index == 0 ? string . Empty : new string ( buffer , 0 , index ) . Normalize ( NormalizationForm . FormC ) ;
117+ ArrayPool < char > . Shared . Return ( buffer ) ;
118+ return result ;
119+ }
120+
121+ /// <summary>
122+ /// Normalize the input character span.
123+ /// </summary>
124+ /// <param name="original">The input character span to normalize.</param>
125+ /// <returns>The normalized string.</returns>
126+ public override string Normalize ( ReadOnlySpan < char > original )
127+ {
128+ if ( original . IsEmpty )
129+ {
130+ return string . Empty ;
131+ }
132+
133+ return Normalize ( original . ToString ( ) ) ;
134+ }
135+
136+ /// <summary>
137+ /// Initializes a new instance of the <see cref="BertNormalizer"/> class.
138+ /// </summary>
139+ /// <param name="doLowerCase">Whether to lowercase the input.</param>
140+ /// <param name="tokenizeChineseChars">Whether to tokenize Chinese characters.</param>
141+ /// <param name="stripAccents">Whether to strip accents from the input.</param>
142+ public BertNormalizer ( bool doLowerCase , bool tokenizeChineseChars , bool stripAccents )
143+ {
144+ _doLowerCase = doLowerCase ;
145+ _tokenizeChineseChars = tokenizeChineseChars ;
146+ _stripAccents = stripAccents ;
147+ }
148+
149+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
150+ private static void InsertChar ( ref char [ ] buffer , ref int index , char c )
151+ {
152+ if ( index >= buffer . Length )
153+ {
154+ Helpers . ArrayPoolGrow ( ref buffer , index + 40 ) ;
155+ }
156+
157+ buffer [ index ++ ] = c ;
158+ }
159+
160+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
161+ private static void InsertSpan ( ref char [ ] buffer , ref int index , Span < char > chars )
162+ {
163+ if ( index + buffer . Length >= buffer . Length )
164+ {
165+ Helpers . ArrayPoolGrow ( ref buffer , index + buffer . Length + 10 ) ;
166+ }
167+
168+ chars . CopyTo ( buffer . AsSpan ( index ) ) ;
169+ index += chars . Length ;
170+ }
171+
172+ /// <summary>
173+ /// Checks whether CP is the codepoint of a CJK character.
174+ /// This defines a "chinese character" as anything in the CJK Unicode block:
175+ /// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
176+ /// </summary>
177+ /// <param name="codePoint">The codepoint to check.</param>
178+ /// <remarks>
179+ /// The CJK Unicode block is NOT all Japanese and Korean characters,
180+ /// despite its name. The modern Korean Hangul alphabet is a different block,
181+ /// as is Japanese Hiragana and Katakana. Those alphabets are used to write
182+ /// space-separated words, so they are not treated specially and handled
183+ /// like the all of the other languages.
184+ /// </remarks>
185+ /// <returns>True if the codepoint is a CJK character, false otherwise.</returns>
186+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
187+ private static bool IsChineseChar ( int codePoint )
188+ {
189+ return ( codePoint > 0x3400 ) && // Quick check to exit early if the codepoint is outside of the CJK range
190+ ( ( ( uint ) ( codePoint - 0x3400 ) <= ( uint ) ( 0x4DBF - 0x3400 ) ) ||
191+ ( ( uint ) ( codePoint - 0xF900 ) <= ( uint ) ( 0xFAFF - 0xF900 ) ) ||
192+ ( ( uint ) ( codePoint - 0x4E00 ) <= ( uint ) ( 0x9FFF - 0x4E00 ) ) ||
193+ ( ( uint ) ( codePoint - 0x20000 ) <= ( uint ) ( 0x2A6DF - 0x20000 ) ) ||
194+ ( ( uint ) ( codePoint - 0x2A700 ) <= ( uint ) ( 0x2B73F - 0x2A700 ) ) ||
195+ ( ( uint ) ( codePoint - 0x2B740 ) <= ( uint ) ( 0x2B81F - 0x2B740 ) ) ||
196+ ( ( uint ) ( codePoint - 0x2B820 ) <= ( uint ) ( 0x2CEAF - 0x2B820 ) ) ||
197+ ( ( uint ) ( codePoint - 0x2F800 ) <= ( uint ) ( 0x2FA1F - 0x2F800 ) ) ) ;
198+ }
199+ }
200+ }
0 commit comments