# Add whole word mask support for lm fine-tune #7925
Merged
## Commits (24)
- `2f610e7` (wlhgtc): ADD: add whole word mask proxy for both eng and chinese
- `eeb0e02` (wlhgtc): MOD: adjust format
- `dde73e2` (wlhgtc): MOD: reformat code
- `6470d5b` (wlhgtc): MOD: update import
- `30928b0` (wlhgtc): MOD: fix bug
- `dc7794f` (wlhgtc): MOD: add import
- `1fd0e15` (wlhgtc): MOD: fix bug
- `0aec80a` (wlhgtc): MOD: decouple code and update readme
- `156d40a` (wlhgtc): MOD: reformat code
- `30668ee` (wlhgtc): Update examples/language-modeling/README.md
- `d1c4d25` (wlhgtc): Update examples/language-modeling/README.md
- `56bf427` (wlhgtc): Update examples/language-modeling/run_language_modeling.py
- `745e49d` (wlhgtc): Update examples/language-modeling/run_language_modeling.py
- `cf909c0` (wlhgtc): Update examples/language-modeling/run_language_modeling.py
- `960e4c8` (wlhgtc): Update examples/language-modeling/run_language_modeling.py
- `45265b7` (wlhgtc): change wwm to whole_word_mask
- `68d8832` (wlhgtc): reformat code
- `ea21325` (wlhgtc): reformat
- `bf976a4` (wlhgtc): format
- `9b50ee3` (sgugger): Code quality
- `a285531` (wlhgtc): ADD: update chinese ref readme
- `ed15eba` (wlhgtc): MOD: small changes
- `87ab48c` (wlhgtc): MOD: small changes2
- `beeb7aa` (wlhgtc): update readme
New file added by this PR (147 lines), a script that prepares whole-word reference ids for Chinese text:

```python
import argparse
import json
from typing import List

from ltp import LTP
from transformers.tokenization_bert import BertTokenizer


def _is_chinese_char(cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as are Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and are handled
    # like all of the other languages.
    if (
        (cp >= 0x4E00 and cp <= 0x9FFF)
        or (cp >= 0x3400 and cp <= 0x4DBF)
        or (cp >= 0x20000 and cp <= 0x2A6DF)
        or (cp >= 0x2A700 and cp <= 0x2B73F)
        or (cp >= 0x2B740 and cp <= 0x2B81F)
        or (cp >= 0x2B820 and cp <= 0x2CEAF)
        or (cp >= 0xF900 and cp <= 0xFAFF)
        or (cp >= 0x2F800 and cp <= 0x2FA1F)
    ):
        return True

    return False


def is_chinese(word: str):
    # word like '180' or '身高' or '神'
    for char in word:
        char = ord(char)
        if not _is_chinese_char(char):
            return 0
    return 1


def get_chinese_word(tokens: List[str]):
    # Collect the multi-character, all-Chinese words among the LTP tokens.
    word_set = set()

    for token in tokens:
        chinese_word = len(token) > 1 and is_chinese(token)
        if chinese_word:
            word_set.add(token)
    word_list = list(word_set)
    return word_list


def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set):
    # Re-mark BERT's single-character pieces with "##" wherever they continue
    # a whole word found by LTP.
    if not chinese_word_set:
        return bert_tokens
    max_word_len = max([len(w) for w in chinese_word_set])

    bert_word = bert_tokens
    start, end = 0, len(bert_word)
    while start < end:
        single_word = True
        if is_chinese(bert_word[start]):
            # Try the longest possible match first.
            max_match = min(end - start, max_word_len)
            for i in range(max_match, 1, -1):
                whole_word = "".join(bert_word[start : start + i])
                if whole_word in chinese_word_set:
                    for j in range(start + 1, start + i):
                        bert_word[j] = "##" + bert_word[j]
                    start = start + i
                    single_word = False
                    break
        if single_word:
            start += 1
    return bert_word


def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer):
    # Segment the lines into words with LTP, in batches of 100.
    ltp_res = []

    for i in range(0, len(lines), 100):
        res = ltp_tokenizer.seg(lines[i : i + 100])[0]
        res = [get_chinese_word(r) for r in res]
        ltp_res.extend(res)
    assert len(ltp_res) == len(lines)

    # Tokenize the same lines with the BERT tokenizer.
    bert_res = []
    for i in range(0, len(lines), 100):
        res = bert_tokenizer(lines[i : i + 100], add_special_tokens=True, truncation=True, max_length=512)
        bert_res.extend(res["input_ids"])
    assert len(bert_res) == len(lines)

    ref_ids = []
    for input_ids, chinese_word in zip(bert_res, ltp_res):
        input_tokens = []
        for token_id in input_ids:
            token = bert_tokenizer._convert_id_to_token(token_id)
            input_tokens.append(token)
        input_tokens = add_sub_symbol(input_tokens, chinese_word)
        ref_id = []
        # We only save the positions of Chinese subwords that start with ##,
        # which means they are part of a whole word.
        for i, token in enumerate(input_tokens):
            if token[:2] == "##":
                clean_token = token[2:]
                # save the positions of the Chinese tokens
                if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)):
                    ref_id.append(i)
        ref_ids.append(ref_id)

    assert len(ref_ids) == len(bert_res)

    return ref_ids


def main(args):
    # For Chinese (Ro)BERT(a) models, the best results come from RoBERTa-wwm-ext
    # (https://github.com/ymcui/Chinese-BERT-wwm). To fine-tune these models, we
    # have to use the same word segmenter they used: LTP
    # (https://github.com/HIT-SCIR/ltp).
    with open(args.file_name, "r", encoding="utf-8") as f:
        data = f.readlines()

    ltp_tokenizer = LTP(args.ltp)  # faster on a GPU device
    bert_tokenizer = BertTokenizer.from_pretrained(args.bert)

    ref_ids = prepare_ref(data, ltp_tokenizer, bert_tokenizer)

    with open(args.save_path, "w", encoding="utf-8") as f:
        data = [json.dumps(ref) + "\n" for ref in ref_ids]
        f.writelines(data)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="prepare_chinese_ref")
    parser.add_argument(
        "--file_name",
        type=str,
        default="./resources/chinese-demo.txt",
        help="file to process; the same as the LM training data",
    )
    parser.add_argument(
        "--ltp", type=str, default="./resources/ltp", help="resources for the LTP tokenizer, usually a path"
    )
    parser.add_argument("--bert", type=str, default="./resources/robert", help="resources for the BERT tokenizer")
    parser.add_argument("--save_path", type=str, default="./resources/ref.txt", help="path to save the result")

    args = parser.parse_args()
    main(args)
```
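For intuition, here is a minimal sketch (not part of the PR; the tokens and segmentation are made up for illustration) of what `add_sub_symbol` produces and which positions end up in the ref file:

```python
# Toy example, assuming add_sub_symbol from the script above is in scope.
# Suppose LTP segments "身高180" into the words ["身高", "180"], while the
# BERT tokenizer splits the text into single characters. add_sub_symbol
# re-marks "高" as "##高" because it continues the LTP word "身高".
bert_tokens = ["[CLS]", "身", "高", "180", "[SEP]"]
chinese_words = {"身高"}
print(add_sub_symbol(bert_tokens, chinese_words))
# -> ['[CLS]', '身', '##高', '180', '[SEP]']

# prepare_ref would then record position 2 (the "##高" continuation) for this
# line, so the JSON line written to the ref file is "[2]".
```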
Review comment:

> Can we leave one version with just MLM and no WWM first?
Reply:

> Sure, I updated my README; with plain MLM only, it is the same as the English version.
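As context for the exchange above, a hypothetical sketch (not code from this PR; the helper and masking probability are made up) of the difference the ref file makes: with plain MLM, each wordpiece is an independent masking candidate; with whole word masking, the "##" continuations recovered above are masked together with the piece that starts the word.

```python
import random

tokens = ["[CLS]", "身", "##高", "180", "[SEP]"]


def whole_word_spans(tokens):
    # Group wordpieces into whole-word spans using the "##" continuation
    # marker; special tokens are never masking candidates.
    spans = []
    for i, tok in enumerate(tokens):
        if tok in ("[CLS]", "[SEP]"):
            continue
        if tok.startswith("##") and spans:
            spans[-1].append(i)
        else:
            spans.append([i])
    return spans


random.seed(1)
for span in whole_word_spans(tokens):
    if random.random() < 0.5:  # toy masking probability
        for i in span:  # whole word masking: mask the entire span at once
            tokens[i] = "[MASK]"
print(tokens)
# -> ['[CLS]', '[MASK]', '[MASK]', '180', '[SEP]'] with this seed; plain MLM
#    would instead decide for "身" and "##高" independently.
```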