Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for Bert #137

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion Sources/Tokenizers/Decoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ extension Decoder {

enum DecoderType: String {
case Sequence
// case WordPiece
case WordPiece
case ByteLevel
case Replace
case ByteFallback
Expand All @@ -47,11 +47,47 @@ struct DecoderFactory {
case .Fuse : return FuseDecoder(config: config)
case .Strip : return StripDecoder(config: config)
case .Metaspace : return MetaspaceDecoder(config: config)
case .WordPiece : return WordPieceDecoder(config: config)
default : fatalError("Unsupported Decoder type: \(typeName)")
}
}
}

class WordPieceDecoder: Decoder {
let prefix: String
let cleanup: Bool

required public init(config: Config) {
guard let prefix = config.prefix?.stringValue else { fatalError("Missing `prefix` configuration for WordPieceDecoder.") }
self.prefix = prefix
self.cleanup = config.cleanup?.boolValue ?? false
}

func decode(tokens: [String]) -> [String] {
var newTokens = [String]()
newTokens.reserveCapacity(tokens.count)
for (index, token) in tokens.enumerated() {
var decodedToken = token
if index != 0 {
if decodedToken.hasPrefix(prefix) {
decodedToken = String(decodedToken.dropFirst(prefix.count))
} else {
decodedToken = " \(decodedToken)"
}
}
if cleanup {
decodedToken = cleanUpTokenization(decodedToken)
}
newTokens.append(decodedToken)
}
return newTokens
}

private func cleanUpTokenization(_ token: String) -> String {
return token.trimmingCharacters(in: .whitespacesAndNewlines)
}
}

class DecoderSequence: Decoder {
let decoders: [Decoder]

Expand Down
3 changes: 2 additions & 1 deletion Sources/Tokenizers/Normalizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ enum NormalizerType: String {
case NFKD
case NFKC
case Bert
case BertNormalizer
case Precompiled
case StripAccents
case Strip
Expand All @@ -51,7 +52,7 @@ struct NormalizerFactory {
case .NFC: return NFCNormalizer(config: config)
case .NFKD: return NFKDNormalizer(config: config)
case .NFKC: return NFKCNormalizer(config: config)
case .Bert: return BertNormalizer(config: config)
case .Bert, .BertNormalizer: return BertNormalizer(config: config)
case .Precompiled: return PrecompiledNormalizer(config: config)
case .StripAccents: return StripAccentsNormalizer(config: config)
case .Strip: return StripNormalizer(config: config)
Expand Down
18 changes: 16 additions & 2 deletions Sources/Tokenizers/PreTokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ enum PreTokenizerType: String {
case Whitespace
case WhitespaceSplit
case Metaspace
case BertPreTokenizer
// Several more to be supported
case Unknown = ""
}
Expand All @@ -63,11 +64,25 @@ struct PreTokenizerFactory {
case .Split: return SplitPreTokenizer(config: config)
case .Whitespace, .WhitespaceSplit: return WhitespacePreTokenizer(config: config)
case .Metaspace: return MetaspacePreTokenizer(config: config)
case .BertPreTokenizer: return BertPreTokenizer(config: config)
default: fatalError("Unsupported PreTokenizer type: \(typeName)")
}
}
}

class BertPreTokenizer: PreTokenizer {
let re: String

required init(config: Config) {
// Ref: https://github.com/huggingface/transformers.js/blob/27920d84831e323275b38f0b5186644b7936e1a2/src/tokenizers.js#L1002
re = "[^\\s\(Constants.PUNCTUATION_REGEX)]+|[\(Constants.PUNCTUATION_REGEX)]"
}

func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
return text.ranges(of: re).map { String(text[$0]) }
}
}

class PreTokenizerSequence: PreTokenizer {
let preTokenizers: [PreTokenizer]

Expand Down Expand Up @@ -184,11 +199,10 @@ class ByteLevelPreTokenizer: PreTokenizer {
}

class PunctuationPreTokenizer: PreTokenizer {
let PUNCTUATION_REGEX = #"\p{P}\u0021-\u002F\u003A-\u0040\u005B-\u0060\u007B-\u007E"#
let re: String

required init(config: Config) {
re = "[^\(PUNCTUATION_REGEX)]+|[\(PUNCTUATION_REGEX)]+"
re = "[^\(Constants.PUNCTUATION_REGEX)]+|[\(Constants.PUNCTUATION_REGEX)]+"
}

func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
Expand Down
4 changes: 4 additions & 0 deletions Sources/Tokenizers/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,7 @@ struct Utils {
}
}

enum Constants {
static let PUNCTUATION_REGEX = #"\p{P}\u0021-\u002F\u003A-\u0040\u005B-\u0060\u007B-\u007E"#
}

20 changes: 20 additions & 0 deletions Tests/PreTokenizerTests/PreTokenizerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,24 @@ class PreTokenizerTests: XCTestCase {
["▁Hey", "▁my", "▁friend", "▁", "▁<s>", "▁how", "▁are", "▁you"]
)
}

func testBertPreTokenizer() {
let preTokenizer1 = BertPreTokenizer(config: Config([:]))
XCTAssertEqual(
preTokenizer1.preTokenize(text: "Hey friend!"),
["Hey", "friend", "!"]
)
XCTAssertEqual(
preTokenizer1.preTokenize(text: "Hey friend! How are you?!?"),
["Hey", "friend", "!", "How", "are", "you", "?", "!", "?"]
)
XCTAssertEqual(
preTokenizer1.preTokenize(text: " Hey, friend , what's up? "),
["Hey", ",", "friend", ",", "what", "\'", "s", "up", "?"]
)
XCTAssertEqual(
preTokenizer1.preTokenize(text: " Hey, friend , 0 99 what's up? "),
["Hey", ",", "friend", ",", "0", "99", "what", "\'", "s", "up", "?"]
)
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original Rust implementation also had this test case scenario:

        XCTAssertEqual(
            preTokenizer.preTokenize(text: "野口里佳 Noguchi Rika"),
            ["野", "口", "里", "佳", "Noguchi", "Rika"]
        )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that as well, but figured that the architecture here is a bit different and it should be handled by BertNormalizer, is this assumption correct @pcuenca?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think this is great as it is, we can always iterate if needed.

}