
# Add whole word mask support for lm fine-tune #7925

Merged · 24 commits · Oct 22, 2020
52 changes: 51 additions & 1 deletion examples/language-modeling/README.md
@@ -45,19 +45,69 @@ slightly slower (over-fitting takes more epochs).

We use the `--mlm` flag so that the script may change its loss function.

If using whole-word masking, use both the `--mlm` and `--whole_word_mask` flags.

```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw

python run_language_modeling.py \
--output_dir=output \
--model_type=roberta \
--model_name_or_path=roberta-base \
--do_train \
--train_data_file=$TRAIN_FILE \
--do_eval \
--eval_data_file=$TEST_FILE \
--mlm \
    --whole_word_mask
```

For Chinese models, the procedure is the same as for the English model when using only `--mlm`. If you want whole-word masking, you first need to generate a reference file, because the default tokenization is at the character level.

**Q:** Why a ref file?

**A:** Suppose we have a Chinese sentence such as `我喜欢你`. The original Chinese BERT tokenizes it at the character level as `['我','喜','欢','你']`, but `喜欢` is actually a whole word. For the whole-word masking proxy we need a result like `['我','喜','##欢','你']`, so the ref file tells the model which positions of the original BERT tokens should receive a `##` prefix.
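To make the mapping concrete, here is a minimal sketch (not part of the PR; token indices assume the BERT tokenizer adds `[CLS]`/`[SEP]`, as `chinese_ref.py` below does) of what one line of the ref file encodes:

```python
# One line of the ref file is a JSON list of the token positions that should be
# treated as "##" continuations of a whole word.
bert_tokens = ["[CLS]", "我", "喜", "欢", "你", "[SEP]"]  # char-level BERT tokens
ltp_words = ["我", "喜欢", "你"]                          # whole words from LTP
# '欢' (index 3) continues the word '喜欢', so the ref line for this sentence is:
ref_line = [3]
print(ref_line)  # -> [3], stored as the JSON string "[3]" plus a newline
```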

**Q:** Why LTP?

**A:** Because the best-known Chinese whole-word-masking BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT, which performs well on many Chinese tasks such as CLUE (the Chinese GLUE). Its authors use LTP for word segmentation, so if we want to fine-tune their models we need LTP as well.
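As a rough illustration of the two tokenizations that the ref file has to reconcile (a sketch, not part of the PR; the LTP path and model id are placeholders, and the `LTP(...)`/`.seg(...)` usage mirrors `chinese_ref.py` below):

```python
# Compare LTP's word-level segmentation with BERT's char-level tokenization.
from ltp import LTP
from transformers import BertTokenizer

ltp_tokenizer = LTP("/path/to/ltp")  # same resource passed to chinese_ref.py via --ltp
bert_tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")  # placeholder model id

words, _ = ltp_tokenizer.seg(["我喜欢你"])     # [['我', '喜欢', '你']]   whole words
chars = bert_tokenizer.tokenize("我喜欢你")    # ['我', '喜', '欢', '你']  single characters
```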

```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export LTP_RESOURCE=/path/to/ltp/tokenizer
export BERT_RESOURCE=/path/to/bert/tokenizer
export SAVE_PATH=/path/to/data/ref.txt

python chinese_ref.py \
--file_name=$TRAIN_FILE \
    --ltp=$LTP_RESOURCE \
--bert=$BERT_RESOURCE \
--save_path=$SAVE_PATH
```
Currently the Chinese ref file is only supported by the `LineByLineWithRefDataset` class, so we also need to add the `--line_by_line` flag:

```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw
export REF_FILE=/path/to/ref.txt

python run_language_modeling.py \
    --output_dir=output \
    --model_type=roberta \
    --model_name_or_path=roberta-base \
    --do_train \
    --train_data_file=$TRAIN_FILE \
    --chinese_ref_file=$REF_FILE \
    --do_eval \
    --eval_data_file=$TEST_FILE \
    --mlm \
    --line_by_line \
    --whole_word_mask
```

> **Review (Collaborator):** Can we leave one version with just mlm and no wwm first?
>
> **Reply (Contributor Author):** Sure, I updated my README; with only `--mlm` it is the same as the English version.
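For completeness, the same setup can also be assembled programmatically. This is a minimal sketch, not part of the PR; it only relies on the classes used by this PR (`LineByLineWithRefDataset`, `DataCollatorForWholeWordMask`), while the model name, file paths, and block size are placeholders:

```python
from transformers import (
    AutoModelWithLMHead,
    AutoTokenizer,
    DataCollatorForWholeWordMask,
    LineByLineWithRefDataset,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")  # placeholder model id
model = AutoModelWithLMHead.from_pretrained("hfl/chinese-roberta-wwm-ext")

# Line-by-line dataset that also loads the per-line ref positions produced by chinese_ref.py.
train_dataset = LineByLineWithRefDataset(
    tokenizer=tokenizer,
    file_path="wiki.train.raw",  # same file the ref file was generated from
    block_size=512,
    ref_path="ref.txt",
)

# Masks all sub-tokens of a whole word together instead of masking tokens independently.
data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.15)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="output"),
    data_collator=data_collator,
    train_dataset=train_dataset,
)
trainer.train()
```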

### XLNet and permutation language modeling
147 changes: 147 additions & 0 deletions examples/language-modeling/chinese_ref.py
@@ -0,0 +1,147 @@
import argparse
import json
from typing import List

from ltp import LTP
from transformers.tokenization_bert import BertTokenizer


def _is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
    # like all of the other languages.
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True

return False


def is_chinese(word: str):
# word like '180' or '身高' or '神'
for char in word:
char = ord(char)
if not _is_chinese_char(char):
return 0
return 1


def get_chinese_word(tokens: List[str]):
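    # Collect LTP tokens that are longer than one character and consist
    # entirely of Chinese characters, i.e. candidate whole words.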
word_set = set()

for token in tokens:
chinese_word = len(token) > 1 and is_chinese(token)
if chinese_word:
word_set.add(token)
word_list = list(word_set)
return word_list


def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set):
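    # Walk over the char-level BERT tokens; whenever a span matches one of the
    # LTP whole words (longest match first), prefix its non-initial characters with "##".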
if not chinese_word_set:
return bert_tokens
max_word_len = max([len(w) for w in chinese_word_set])

bert_word = bert_tokens
start, end = 0, len(bert_word)
while start < end:
single_word = True
if is_chinese(bert_word[start]):
l = min(end - start, max_word_len)
for i in range(l, 1, -1):
whole_word = "".join(bert_word[start : start + i])
if whole_word in chinese_word_set:
for j in range(start + 1, start + i):
bert_word[j] = "##" + bert_word[j]
start = start + i
single_word = False
break
if single_word:
start += 1
return bert_word


def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer):
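    # For each line: gather LTP whole words and BERT input ids, mark sub-word
    # continuations with "##", and record their token indices (one list per line).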
ltp_res = []

for i in range(0, len(lines), 100):
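        # Segment 100 lines at a time with LTP and keep only the Chinese whole words.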
res = ltp_tokenizer.seg(lines[i : i + 100])[0]
res = [get_chinese_word(r) for r in res]
ltp_res.extend(res)
assert len(ltp_res) == len(lines)

bert_res = []
for i in range(0, len(lines), 100):
res = bert_tokenizer(lines[i : i + 100], add_special_tokens=True, truncation=True, max_length=512)
bert_res.extend(res["input_ids"])
assert len(bert_res) == len(lines)

ref_ids = []
for input_ids, chinese_word in zip(bert_res, ltp_res):

input_tokens = []
for id in input_ids:
token = bert_tokenizer._convert_id_to_token(id)
input_tokens.append(token)
input_tokens = add_sub_symbol(input_tokens, chinese_word)
ref_id = []
        # We only save the positions of Chinese sub-tokens starting with ##, which means they are part of a whole word.
for i, token in enumerate(input_tokens):
if token[:2] == "##":
clean_token = token[2:]
# save chinese tokens' pos
if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)):
ref_id.append(i)
ref_ids.append(ref_id)

assert len(ref_ids) == len(bert_res)

return ref_ids


def main(args):
# For Chinese (Ro)Bert, the best result is from : RoBERTa-wwm-ext (https://github.com/ymcui/Chinese-BERT-wwm)
# If we want to fine-tune these model, we have to use same tokenizer : LTP (https://github.com/HIT-SCIR/ltp)
with open(args.file_name, "r", encoding="utf-8") as f:
data = f.readlines()

    ltp_tokenizer = LTP(args.ltp)  # faster on a GPU device
bert_tokenizer = BertTokenizer.from_pretrained(args.bert)

ref_ids = prepare_ref(data, ltp_tokenizer, bert_tokenizer)

with open(args.save_path, "w", encoding="utf-8") as f:
data = [json.dumps(ref) + "\n" for ref in ref_ids]
f.writelines(data)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="prepare_chinese_ref")
parser.add_argument(
"--file_name",
type=str,
default="./resources/chinese-demo.txt",
help="file need process, same as training data in lm",
)
parser.add_argument(
"--ltp", type=str, default="./resources/ltp", help="resources for LTP tokenizer, usually a path"
)
parser.add_argument("--bert", type=str, default="./resources/robert", help="resources for Bert tokenizer")
parser.add_argument("--save_path", type=str, default="./resources/ref.txt", help="path to save res")

args = parser.parse_args()
main(args)
29 changes: 25 additions & 4 deletions examples/language-modeling/run_language_modeling.py
@@ -37,8 +37,10 @@
AutoTokenizer,
DataCollatorForLanguageModeling,
DataCollatorForPermutationLanguageModeling,
DataCollatorForWholeWordMask,
HfArgumentParser,
LineByLineTextDataset,
LineByLineWithRefDataset,
PreTrainedTokenizer,
TextDataset,
Trainer,
@@ -101,6 +103,10 @@ class DataTrainingArguments:
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
chinese_ref_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input ref data file for whole word mask in Chinees."},
)
line_by_line: bool = field(
default=False,
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
@@ -109,6 +115,7 @@
mlm: bool = field(
default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."}
)
    whole_word_mask: bool = field(default=False, metadata={"help": "Whether or not to use whole word masking."})
mlm_probability: float = field(
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
)
@@ -143,6 +150,16 @@ def get_dataset(
):
def _dataset(file_path):
if args.line_by_line:
if args.chinese_ref_file is not None:
if not args.whole_word_mask or not args.mlm:
raise ValueError("You need to set world whole masking and mlm to True for Chinese Whole Word Mask")
return LineByLineWithRefDataset(
tokenizer=tokenizer,
file_path=file_path,
block_size=args.block_size,
ref_path=args.chinese_ref_file,
)

return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
else:
return TextDataset(
@@ -174,7 +191,6 @@ def main():
"Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
"or remove the --do_eval argument."
)

if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -270,9 +286,14 @@ def main():
max_span_length=data_args.max_span_length,
)
else:
        if data_args.mlm and data_args.whole_word_mask:
            data_collator = DataCollatorForWholeWordMask(
                tokenizer=tokenizer, mlm_probability=data_args.mlm_probability
            )
        else:
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
            )

# Initialize our Trainer
trainer = Trainer(
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -284,13 +284,15 @@
DataCollatorForNextSentencePrediction,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSOP,
DataCollatorForWholeWordMask,
DataCollatorWithPadding,
default_data_collator,
)
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
LineByLineTextDataset,
LineByLineWithRefDataset,
LineByLineWithSOPTextDataset,
SquadDataset,
SquadDataTrainingArguments,