-
Notifications
You must be signed in to change notification settings - Fork 812
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
BERT example in torchtext #767
Changes from 33 commits
1ae1a5b
6c097a4
45c68ad
a07b218
9c38629
8ee4753
3e2d13b
bbddfc4
39e08ee
9cc864c
f09f191
c8f1498
6b01c9a
dd9f239
f8ffbaa
b5b0ec1
12bdf5f
06c1279
5ec0382
9eab8b4
b7f1dd4
a6aea63
5ea8d83
fa7cc14
22eb9c9
f50bb9d
54b1d5c
0010f98
3ec1528
cc60313
441edff
9ddcb82
49c558b
9dbdf80
8982ab8
98b76b8
015a8f9
c9a9c56
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
BERT with torchtext | ||
+++++++++ | ||
|
||
This example shows how to train a BERT model with PyTorch and torchtext only. Then, we fine-tune the pre-trained BERT for the question-answer task. | ||
|
||
|
||
Generate pre-trained BERT | ||
------------------------- | ||
|
||
Train the BERT model with masked language modeling task and next-sentence task. Run the tasks on a local GPU or CPU: | ||
|
||
python mlm_task.py | ||
python ns_task.py | ||
|
||
or run the tasks on a SLURM powered cluster with Distributed Data Parallel (DDP): | ||
|
||
srun --label --ntasks-per-node=1 --time=4000 --mem-per-cpu=5120 --gres=gpu:8 --cpus-per-task 80 --nodes=1 --pty python mlm_task.py --parallel DDP --log-interval 600 --dataset BookCorpus | ||
|
||
srun --label --ntasks-per-node=1 --time=4000 --mem-per-cpu=5120 --gres=gpu:8 --cpus-per-task 80 --nodes=1 --pty python ns_task.py --parallel DDP --bert-model mlm_bert.pt --dataset BookCorpus | ||
|
||
The result ppl of mlm_task is 18.97899 for the test set. | ||
The result loss of ns_task is 0.05446 for the test set. | ||
|
||
Fine-tune pre-trained BERT for question-answer task | ||
--------------------------------------------------- | ||
|
||
With SQuAD dataset, the pre-trained BERT is used for question-answer task: | ||
|
||
python qa_task.py --bert-model ns_bert.pt --epochs 30 | ||
|
||
The pre-trained BERT models and vocab are available: | ||
|
||
* `bert_vocab.pt <https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/bert_vocab.pt>`_ | ||
* `mlm_bert.pt <https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/mlm_bert.pt>`_ | ||
* `ns_bert.pt <https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/ns_bert.pt>`_ | ||
|
||
Structure of the example | ||
======================== | ||
|
||
model.py | ||
-------- | ||
|
||
This file defines the Transformer and MultiheadAttention models used for BERT. The embedding layer include PositionalEncoding and TokenTypeEncoding layers. MLMTask, NextSentenceTask, and QuestionAnswerTask are the models for the three tasks mentioned above. | ||
|
||
data.py | ||
------- | ||
|
||
This file provides a few datasets required to train the BERT model and question-answer task. Please note that BookCorpus dataset is not available publicly. | ||
|
||
|
||
mlm_task.py, ns_task.py, qa_task.py | ||
----------------------------------- | ||
|
||
Those three files define the train/valid/test process for the tasks. | ||
|
||
|
||
metrics.py | ||
---------- | ||
|
||
This file provides two metrics (F1 and exact score) for question-answer task | ||
|
||
|
||
utils.py | ||
-------- | ||
|
||
This file provides a few utils used by the three tasks. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import glob | ||
import torch | ||
import logging | ||
from torchtext.data.utils import get_tokenizer | ||
import random | ||
|
||
|
||
class LanguageModelingDataset(torch.utils.data.Dataset): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Switched to the one in |
||
"""Defines a dataset for language modeling. | ||
""" | ||
|
||
def __init__(self, data, vocab): | ||
"""Initiate language modeling dataset. | ||
""" | ||
|
||
super(LanguageModelingDataset, self).__init__() | ||
self.data = data | ||
self.vocab = vocab | ||
|
||
def __getitem__(self, i): | ||
return self.data[i] | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
def __iter__(self): | ||
for x in self.data: | ||
yield x | ||
|
||
def get_vocab(self): | ||
return self.vocab | ||
|
||
|
||
################################################################### | ||
# Set up dataset for book corpus | ||
################################################################### | ||
def BookCorpus(vocab, tokenizer=get_tokenizer("basic_english"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe it's worth moving this into experimental as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking about this. But the original data for BookCorpus comes from FAIR cluster. |
||
data_select=('train', 'test', 'valid'), removed_tokens=[], | ||
min_sentence_len=None): | ||
|
||
if isinstance(data_select, str): | ||
data_select = [data_select] | ||
if not set(data_select).issubset(set(('train', 'test', 'valid'))): | ||
raise TypeError('data_select is not supported!') | ||
|
||
extracted_files = glob.glob('/datasets01/bookcorpus/021819/*/*.txt') | ||
random.seed(1000) | ||
random.shuffle(extracted_files) | ||
|
||
num_files = len(extracted_files) | ||
_path = {'train': extracted_files[:(num_files // 20 * 17)], | ||
'test': extracted_files[(num_files // 20 * 17):(num_files // 20 * 18)], | ||
'valid': extracted_files[(num_files // 20 * 18):]} | ||
|
||
data = {} | ||
for item in _path.keys(): | ||
data[item] = [] | ||
logging.info('Creating {} data'.format(item)) | ||
tokens = [] | ||
for txt_file in _path[item]: | ||
with open(txt_file, 'r', encoding="utf8", errors='ignore') as f: | ||
for line in f.readlines(): | ||
_tokens = tokenizer(line.strip()) | ||
if min_sentence_len: | ||
if len(_tokens) >= min_sentence_len: | ||
tokens.append([vocab.stoi[token] for token in _tokens]) | ||
else: | ||
tokens += [vocab.stoi[token] for token in _tokens] | ||
data[item] = tokens | ||
|
||
for key in data_select: | ||
if data[key] == []: | ||
raise TypeError('Dataset {} is empty!'.format(key)) | ||
if min_sentence_len: | ||
return tuple(LanguageModelingDataset(data[d], vocab) | ||
for d in data_select) | ||
else: | ||
return tuple(LanguageModelingDataset(torch.tensor(data[d]).long(), vocab) | ||
for d in data_select) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import collections | ||
import re | ||
import string | ||
|
||
|
||
def compute_qa_exact(ans_pred_tokens_samples): | ||
|
||
''' | ||
Input: ans_pred_tokens_samples: [([ans1_tokens_candidate1, ans1_tokens_candidate2], pred1_tokens), | ||
([ans2_tokens_candidate1, ans2_tokens_candidate2], pred2_tokens), | ||
... | ||
([ansn_tokens_candidate1, ansn_tokens_candidate2], predn_tokens)] | ||
ans1_tokens_candidate1 = ['this', 'is', 'an', 'sample', 'example'] | ||
Output: exact score of the samples | ||
''' | ||
|
||
def normalize_txt(text): | ||
# lower case | ||
text = text.lower() | ||
|
||
# remove punc | ||
exclude = set(string.punctuation) | ||
text = "".join(ch for ch in text if ch not in exclude) | ||
|
||
# remove articles | ||
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) | ||
text = re.sub(regex, " ", text) | ||
|
||
# white space fix | ||
return " ".join(text.split()) | ||
|
||
exact_scores = [] | ||
for (ans_tokens, pred_tokens) in ans_pred_tokens_samples: | ||
pred_str = " ".join(pred_tokens) | ||
candidate_score = [] | ||
for item in ans_tokens: | ||
ans_str = " ".join(item) | ||
candidate_score.append(int(normalize_txt(ans_str) == normalize_txt(pred_str))) | ||
exact_scores.append(max(candidate_score)) | ||
return 100.0 * sum(exact_scores) / len(exact_scores) | ||
|
||
|
||
def compute_qa_f1(ans_pred_tokens_samples): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can this be turned into a generic f1 metric that we can throw into torchtext? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
''' | ||
Input: ans_pred_tokens_samples: [([ans1_tokens_candidate1, ans1_tokens_candidate2], pred1_tokens), | ||
([ans2_tokens_candidate1, ans2_tokens_candidate2], pred2_tokens), | ||
... | ||
([ansn_tokens_candidate1, ansn_tokens_candidate2], predn_tokens)] | ||
ans1_tokens_candidate1 = ['this', 'is', 'an', 'sample', 'example'] | ||
Output: f1 score of the samples | ||
''' | ||
def sample_f1(ans_tokens, pred_tokens): | ||
common = collections.Counter(ans_tokens) & collections.Counter(pred_tokens) | ||
num_same = sum(common.values()) | ||
if len(ans_tokens) == 0 or len(pred_tokens) == 0: | ||
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise | ||
return int(ans_tokens == pred_tokens) | ||
if num_same == 0: | ||
return 0 | ||
precision = 1.0 * num_same / len(pred_tokens) | ||
recall = 1.0 * num_same / len(ans_tokens) | ||
f1 = (2 * precision * recall) / (precision + recall) | ||
return f1 | ||
|
||
f1_scores = [] | ||
for (ans_tokens, pred_tokens) in ans_pred_tokens_samples: | ||
candidate_score = [] | ||
for item in ans_tokens: | ||
candidate_score.append(sample_f1(item, pred_tokens)) | ||
f1_scores.append(max(candidate_score)) | ||
return 100.0 * sum(f1_scores) / len(f1_scores) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we could start incrementally porting this over into the experimental folder which will also help with cleanup. These datasets seem useful in general.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I second this. I would also help make this pull request shorter.