Improve BERT tokenization for accented characters and non-latin scripts #5740

Merged: 5 commits, Feb 28, 2024
Changes from 2 commits
135 changes: 42 additions & 93 deletions llama.cpp
@@ -8897,37 +8897,46 @@ struct llm_tokenizer_wpm {
     }
 
     std::vector<std::string> preprocess(const std::string & text) {
-        std::string ori_str = normalize(text);
-        uint64_t ori_size = ori_str.size();
+        // normalization form D
+        std::vector<uint32_t> codepoints = codepoints_from_utf8(text);
+        std::vector<uint32_t> nfd_codepoints;
+        for (uint32_t code : codepoints) {
+            auto it = nfd_map.find(code);
+            if (it != nfd_map.end()) {
+                for (uint32_t c : it->second) {
+                    nfd_codepoints.push_back(c);
+                }
+            } else {
+                nfd_codepoints.push_back(code);
+            }
+        }
 
-        // single punct / single symbol / single digit
-        // baseline: add whitespace on the left and right of punct and chinese characters
-        std::vector<std::string> words;
+        // strip accents, strip control, uniformize whitespace,
+        // to lowercase, pad chinese characters, pad punctuation
         std::string new_str = "";
-        uint64_t i = 0;
-        while (i < ori_size) {
-            int utf_char_len = utf8_len(ori_str[i]);
-            if ((utf_char_len == 1) && ispunct(ori_str[i])) {
-                new_str += " ";
-                new_str += ori_str[i];
-                new_str += " ";
-                i += 1;
-            }
-            else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) {
-                new_str += " ";
-                new_str += ori_str.substr(i, 3);
-                new_str += " ";
-                i += 3;
-            }
-            else {
-                new_str += ori_str[i];
-                i += 1;
-            }
-        }
+        for (uint32_t code : nfd_codepoints) {
+            int type = codepoint_type(code);
+            if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
+                continue;
+            }
+            code = to_lower(code);
+            if (type == CODEPOINT_TYPE_WHITESPACE) {
+                code = ' ';
+            }
+            std::string s = codepoint_to_utf8(code);
+            if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
+                new_str += " ";
+                new_str += s;
+                new_str += " ";
+            } else {
+                new_str += s;
+            }
+        }
 
         // split by whitespace
         uint64_t l = 0;
         uint64_t r = 0;
+        std::vector<std::string> words;
         while (r < new_str.size()) {
             // if is whitespace
             if (isspace(new_str[r])) {
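
The core of the change is visible in this hunk: instead of mapping precomposed accented letters through a hand-written accent_map (the removed strip_accents further down), the text is first decomposed to Unicode normalization form D, after which each accent is a separate combining codepoint that can simply be skipped. Below is a minimal, self-contained sketch of that idea; the two-entry nfd_map is a toy stand-in for the generated table llama.cpp uses, and is_accent_mark approximates the codepoint_type(code) == CODEPOINT_TYPE_ACCENT_MARK test.

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy NFD table: each precomposed letter maps to its base letter plus a
// combining mark. The real table is generated from the Unicode database.
static const std::map<uint32_t, std::vector<uint32_t>> nfd_map = {
    {0x00C9, {0x0045, 0x0301}}, // É -> E + combining acute accent
    {0x00E9, {0x0065, 0x0301}}, // é -> e + combining acute accent
};

// Stand-in for codepoint_type(code) == CODEPOINT_TYPE_ACCENT_MARK.
static bool is_accent_mark(uint32_t code) {
    return code >= 0x0300 && code <= 0x036F; // Combining Diacritical Marks block
}

int main() {
    std::vector<uint32_t> text = {0x00C9, 0x0074, 0x00E9}; // "Été" as codepoints

    // Pass 1: decompose to NFD, exactly as in the patch's first loop.
    std::vector<uint32_t> nfd;
    for (uint32_t code : text) {
        auto it = nfd_map.find(code);
        if (it != nfd_map.end()) {
            nfd.insert(nfd.end(), it->second.begin(), it->second.end());
        } else {
            nfd.push_back(code);
        }
    }

    // Pass 2: skip the combining marks, keeping only the base letters.
    std::string out;
    for (uint32_t code : nfd) {
        if (!is_accent_mark(code)) {
            out += static_cast<char>(code); // every survivor is ASCII here
        }
    }
    std::cout << out << "\n"; // prints "Ete" (lowercasing happens separately)
}

Because NFD decomposition is table-driven over all of Unicode, letters the old 52-entry accent_map never listed (Vietnamese ạ, Polish ś, and so on) lose their accents as well, assuming the shipped table covers them.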
@@ -8945,47 +8954,22 @@ struct llm_tokenizer_wpm {
         return words;
     }
 
-    std::string normalize(const std::string & text) {
-        // TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98
-        std::string text2 = strip_accents(text);
-        for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) {
-            char c = text2[i];
-            if (c >= 'A' && c <= 'Z') {
-                text2[i] = c - 'A' + 'a';
-            }
-        }
-        return text2;
+    uint32_t to_lower(uint32_t code) {
+        if (
+            (code >= 0x041 && code <= 0x05A) || // latin
+            (code >= 0x391 && code <= 0x3A9) || // greek
+            (code >= 0x410 && code <= 0x42F)    // cyrillic
+        ) {
+            return code + 32;
+        }
+        return code;
     }
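
The replacement for normalize()'s ASCII-only lowercasing relies on a Unicode layout fact: in each of the three ranges checked, the lowercase letter sits exactly 32 codepoints above its uppercase partner. A quick self-check, reproducing the merged function with illustrative asserts:

#include <cassert>
#include <cstdint>

// The merged to_lower(), reproduced for a self-check: in all three ranges
// the lowercase letter is exactly 32 codepoints above the uppercase one.
uint32_t to_lower(uint32_t code) {
    if ((code >= 0x041 && code <= 0x05A) || // latin    A..Z
        (code >= 0x391 && code <= 0x3A9) || // greek    Alpha..Omega
        (code >= 0x410 && code <= 0x42F)) { // cyrillic А..Я
        return code + 32;
    }
    return code;
}

int main() {
    assert(to_lower(0x0041) == 0x0061); // A -> a
    assert(to_lower(0x0394) == 0x03B4); // Greek Delta -> delta
    assert(to_lower(0x0416) == 0x0436); // Cyrillic Zhe -> zhe
    assert(to_lower(0x0037) == 0x0037); // '7' passes through unchanged
    return 0;
}

Capitals outside these ranges that also have no NFD decomposition (Ø, for instance) pass through unchanged; the ranges cover the common Latin, Greek, and core Cyrillic alphabets.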

-    bool is_chinese_char(const std::string & str) {
-        int len = str.length();
-        unsigned int codepoint = 0;
-        int num_bytes = 0;
-        int i = 0;
-        unsigned char ch = static_cast<unsigned char>(str[i]);
-        if (ch <= 0x7f) {
-            codepoint = ch;
-            num_bytes = 1;
-        } else if ((ch >> 5) == 0x06) {
-            codepoint = ch & 0x1f;
-            num_bytes = 2;
-        } else if ((ch >> 4) == 0x0e) {
-            codepoint = ch & 0x0f;
-            num_bytes = 3;
-        } else if ((ch >> 3) == 0x1e) {
-            codepoint = ch & 0x07;
-            num_bytes = 4;
-        }
-        for (int j = 1; j < num_bytes; ++j) {
-            if (i + j >= len) {
-                return false; // incomplete UTF-8 character
-            }
-            unsigned char next_ch = static_cast<unsigned char>(str[i + j]);
-            if ((next_ch >> 6) != 0x02) {
-                return false; // invalid trailing byte
-            }
-            codepoint = (codepoint << 6) | (next_ch & 0x3f);
-        }
+    bool is_ascii_punct(uint32_t code) {
+        return code < 256 && ispunct(code);
+    }
+
+    bool is_chinese_char(uint32_t codepoint) {
         if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) ||
             (codepoint >= 0x3400 && codepoint <= 0x4DBF) ||
             (codepoint >= 0x20000 && codepoint <= 0x2A6DF) ||
@@ -9001,41 +8985,6 @@ struct llm_tokenizer_wpm {
         return false;
     }
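
Worth noting what disappeared here: the old is_chinese_char had to decode raw UTF-8 bytes before it could test anything, while the new one is a pure range test over a codepoint the caller already has. A hedged sketch restricted to the three blocks visible above (the folded part of the diff lists further extension blocks, omitted here):

#include <cassert>
#include <cstdint>

// Membership test over the CJK blocks visible in the hunk above only.
bool is_chinese_char(uint32_t cp) {
    return (cp >= 0x4E00  && cp <= 0x9FFF)  || // CJK Unified Ideographs
           (cp >= 0x3400  && cp <= 0x4DBF)  || // Extension A
           (cp >= 0x20000 && cp <= 0x2A6DF);   // Extension B
}

int main() {
    assert(is_chinese_char(0x4E16));  // 世
    assert(!is_chinese_char(0x0041)); // 'A'
    return 0;
}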

-    std::string strip_accents(const std::string & input_string) {
-        std::string resultString;
-        std::map<std::string, char> accent_map = {
-            {"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'},
-            {"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'},
-            {"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'},
-            {"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'},
-            {"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'},
-            {"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'},
-            {"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'},
-            {"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'},
-            {"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'},
-        };
-
-        for (size_t i = 0; i < input_string.length();) {
-            int len = utf8_len(input_string[i]);
-            std::string curChar = input_string.substr(i, len);
-            auto iter = accent_map.find(curChar);
-            if (iter != accent_map.end()) {
-                resultString += iter->second;
-            } else {
-                resultString += curChar;
-            }
-            i += len;
-        }
-
-        return resultString;
-    }
-
-    static size_t utf8_len(char src) {
-        const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
-        uint8_t highbits = static_cast<uint8_t>(src) >> 4;
-        return lookup[highbits];
-    }
-
     const llama_vocab & vocab;
 };

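Taken together, the new preprocess() decomposes to NFD, drops accent marks and control characters, lowercases, normalizes whitespace, pads punctuation and CJK characters with spaces, and finally splits on whitespace. The following is a hedged, codepoint-level re-creation of that flow, not the patch itself: classify() is a toy stand-in for codepoint_type(), and the input is hard-coded as codepoints so the sketch does not need a UTF-8 decoder.

#include <cstdint>
#include <iostream>
#include <vector>

enum Type { LETTER, ACCENT_MARK, PUNCT, WHITESPACE, CJK };

// Toy stand-in for codepoint_type(); just enough classes for this input.
Type classify(uint32_t c) {
    if (c == 0x0301) return ACCENT_MARK;        // combining acute accent
    if (c == ',' || c == '!') return PUNCT;
    if (c == ' ') return WHITESPACE;
    if (c >= 0x4E00 && c <= 0x9FFF) return CJK; // CJK Unified Ideographs
    return LETTER;
}

uint32_t to_lower(uint32_t c) {
    return (c >= 'A' && c <= 'Z') ? c + 32 : c; // ASCII-only, for brevity
}

int main() {
    // "Héllo, 世界!" with é already NFD-decomposed into 'e' + U+0301
    std::vector<uint32_t> in = {'H', 'e', 0x0301, 'l', 'l', 'o', ',', ' ',
                                0x4E16, 0x754C, '!'};

    std::vector<std::vector<uint32_t>> words(1);
    for (uint32_t c : in) {
        Type t = classify(c);
        if (t == ACCENT_MARK) continue;                      // strip accents
        if (t == WHITESPACE || t == PUNCT || t == CJK) {
            if (!words.back().empty()) words.emplace_back(); // close current word
            if (t != WHITESPACE) {               // punct / CJK become 1-char words
                words.back().push_back(c);
                words.emplace_back();
            }
            continue;
        }
        words.back().push_back(to_lower(c));                 // ordinary letter
    }
    if (words.back().empty()) words.pop_back();

    // prints: hello | , | U+4E16 | U+754C | ! |
    for (const auto & w : words) {
        for (uint32_t c : w) {
            if (c < 0x80) std::cout << static_cast<char>(c);
            else          std::cout << "U+" << std::hex << std::uppercase << c << std::dec;
        }
        std::cout << " | ";
    }
    std::cout << "\n";
}

On this input the pipeline yields hello, ",", 世, 界 and "!": accents gone, case folded, and every punctuation mark or CJK character isolated as its own word, which is the pre-tokenization BERT-style WordPiece expects.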