Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

# Add whole word mask support for lm fine-tune #7925

Merged
merged 24 commits into from
Oct 22, 2020
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion examples/language-modeling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,57 @@ slightly slower (over-fitting takes more epochs).

We use the `--mlm` flag so that the script may change its loss function.

If use whole-word masking, use both `--mlm` and `--wwm` flag(for English Model).
wlhgtc marked this conversation as resolved.
Show resolved Hide resolved

```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw

python run_language_modeling.py \
--output_dir=output \
--model_type=roberta \
--model_name_or_path=roberta-base \
--do_train \
--train_data_file=$TRAIN_FILE \
--do_eval \
--eval_data_file=$TEST_FILE \
--mlm \
--wwm
```

For Chinese Model, we need to generate ref files, case it's char level.
wlhgtc marked this conversation as resolved.
Show resolved Hide resolved

```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export LTP_RESOURCE=/path/to/ltp/tokenizer
export BERT_RESOURCE=/path/to/bert/tokenizer
export SAVE_PATH=/path/to/data/ref.txt

python chinese_ref.py \
--file_name=$TRAIN_FILE \
--ltp=$LTP_RESOURCE
--bert=$BERT_RESOURCE \
--save_path=$SAVE_PATH
```
Then:


```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw
export REF_FILE=/path/to/ref.txt

python run_language_modeling.py \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we leave one version with just mlm and no wwm first?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I update my readme, it's same with English version if only mlm.

--output_dir=output \
--model_type=roberta \
--model_name_or_path=roberta-base \
--do_train \
--train_data_file=$TRAIN_FILE \
--chinese_ref_file=$REF_FILE \
--do_eval \
--eval_data_file=$TEST_FILE \
--mlm
--mlm \
--wwm
```

### XLNet and permutation language modeling
Expand Down
148 changes: 148 additions & 0 deletions examples/language-modeling/chinese_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import json
import random
import argparse
from ltp import LTP
from transformers.tokenization_bert import BertTokenizer

from typing import List


def _is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True

return False


def is_chinese(word: str):
# word like '180' or '身高' or '神'
for char in word:
char = ord(char)
if not _is_chinese_char(char):
return 0
return 1


def get_chinese_word(tokens: List[str]):
word_set = set()

for token in tokens:
chinese_word = len(token) > 1 and is_chinese(token)
if chinese_word:
word_set.add(token)
word_list = list(word_set)
return word_list


def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set()):
if not chinese_word_set:
return bert_tokens
max_word_len = max([len(w) for w in chinese_word_set])

bert_word = bert_tokens
start, end = 0, len(bert_word)
while start < end:
single_word = True
if is_chinese(bert_word[start]):
l = min(end - start, max_word_len)
for i in range(l, 1, -1):
whole_word = "".join(bert_word[start : start + i])
if whole_word in chinese_word_set:
for j in range(start + 1, start + i):
bert_word[j] = "##" + bert_word[j]
start = start + i
single_word = False
break
if single_word:
start += 1
return bert_word


def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer):
ltp_res = []

for i in range(0, len(lines), 100):
res = ltp_tokenizer.seg(lines[i : i + 100])[0]
res = [get_chinese_word(r) for r in res]
ltp_res.extend(res)
assert len(ltp_res) == len(lines)

bert_res = []
for i in range(0, len(lines), 100):
res = bert_tokenizer(lines[i : i + 100], add_special_tokens=True, truncation=True, max_length=512)
bert_res.extend(res["input_ids"])
assert len(bert_res) == len(lines)

ref_ids = []
for input_ids, chinese_word in zip(bert_res, ltp_res):

input_tokens = []
for id in input_ids:
token = bert_tokenizer._convert_id_to_token(id)
input_tokens.append(token)
input_tokens = add_sub_symbol(input_tokens, chinese_word)
ref_id = []
# We only save pos of chinese subwords start with ##, which mean is part of a whole word.
for i, token in enumerate(input_tokens):
if token[:2] == "##":
clean_token = token[2:]
# save chinese tokens' pos
if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)):
ref_id.append(i)
ref_ids.append(ref_id)

assert len(ref_ids) == len(bert_res)

return ref_ids


def main(args):
# For Chinese (Ro)Bert, the best result is from : RoBERTa-wwm-ext (https://github.com/ymcui/Chinese-BERT-wwm)
# If we want to fine-tune these model, we have to use same tokenizer : LTP (https://github.com/HIT-SCIR/ltp)
with open(args.file_name, "r", encoding="utf-8") as f:
data = f.readlines()

ltp_tokenizer = LTP(args.ltp) # faster in GPU device
bert_tokenizer = BertTokenizer.from_pretrained(args.bert)

ref_ids = prepare_ref(data, ltp_tokenizer, bert_tokenizer)

with open(args.save_path, "w", encoding="utf-8") as f:
data = [json.dumps(ref) + "\n" for ref in ref_ids]
f.writelines(data)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="prepare_chinese_ref")
parser.add_argument(
"--file_name",
type=str,
default="./resources/chinese-demo.txt",
help="file need process, same as training data in lm",
)
parser.add_argument(
"--ltp", type=str, default="./resources/ltp", help="resources for LTP tokenizer, usually a path"
)
parser.add_argument("--bert", type=str, default="./resources/robert", help="resources for Bert tokenizer")
parser.add_argument("--save_path", type=str, default="./resources/ref.txt", help="path to save res")

args = parser.parse_args()
main(args)
32 changes: 25 additions & 7 deletions examples/language-modeling/run_language_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@
AutoModelWithLMHead,
AutoTokenizer,
DataCollatorForLanguageModeling,
DataCollatorForWholeWordMask,
DataCollatorForPermutationLanguageModeling,
HfArgumentParser,
LineByLineTextDataset,
LineByLineWithRefDataset,
PreTrainedTokenizer,
TextDataset,
Trainer,
Expand Down Expand Up @@ -101,6 +103,9 @@ class DataTrainingArguments:
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
chinese_ref_file: Optional[str] = field(
default=None, metadata={"help": "An optional input ref data file for whole word mask(wwm) in Chinees."},
wlhgtc marked this conversation as resolved.
Show resolved Hide resolved
)
line_by_line: bool = field(
default=False,
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
Expand All @@ -109,6 +114,7 @@ class DataTrainingArguments:
mlm: bool = field(
default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."}
)
wwm: bool = field(default=False, metadata={"help": "Use Whole Word Mask."})
wlhgtc marked this conversation as resolved.
Show resolved Hide resolved
mlm_probability: float = field(
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
)
Expand Down Expand Up @@ -143,6 +149,16 @@ def get_dataset(
):
def _dataset(file_path):
if args.line_by_line:
if args.chinese_ref_file:
wlhgtc marked this conversation as resolved.
Show resolved Hide resolved
if not args.wwm or args.mlm:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test is not consistent with the error message that wants both of those to be True.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

raise ValueError("Need set wwm and mlm to true for Chinese Whole Word Mask")
wlhgtc marked this conversation as resolved.
Show resolved Hide resolved
return LineByLineWithRefDataset(
tokenizer=tokenizer,
file_path=file_path,
block_size=args.block_size,
ref_path=args.chinese_ref_file,
)

return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
else:
return TextDataset(
Expand Down Expand Up @@ -174,7 +190,6 @@ def main():
"Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
"or remove the --do_eval argument."
)

if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
Expand Down Expand Up @@ -265,14 +280,17 @@ def main():
)
if config.model_type == "xlnet":
data_collator = DataCollatorForPermutationLanguageModeling(
tokenizer=tokenizer,
plm_probability=data_args.plm_probability,
max_span_length=data_args.max_span_length,
tokenizer=tokenizer, plm_probability=data_args.plm_probability, max_span_length=data_args.max_span_length,
)
else:
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
)
if data_args.mlm and data_args.wwm:
data_collator = DataCollatorForWholeWordMask(
tokenizer=tokenizer, mlm_probability=data_args.mlm_probability
)
else:
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
)

# Initialize our Trainer
trainer = Trainer(
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@
from .data.data_collator import (
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorForWholeWordMask,
DataCollatorForNextSentencePrediction,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSOP,
Expand All @@ -291,6 +292,7 @@
GlueDataset,
GlueDataTrainingArguments,
LineByLineTextDataset,
LineByLineWithRefDataset,
LineByLineWithSOPTextDataset,
SquadDataset,
SquadDataTrainingArguments,
Expand Down
Loading