BERT example in torchtext #767

Merged 38 commits on Jul 16, 2020 (changes shown are from 33 commits)

Commits
1ae1a5b  pull torchBERT 70aa449  (May 1, 2020)
6c097a4  add dataloader to mlm_task  (May 5, 2020)
45c68ad  Merge remote-tracking branch 'upstream/master' into torchBERT  (May 6, 2020)
a07b218  Clean up ns and qa task  (May 6, 2020)
9c38629  revise WikiText103 to the latest abstraction  (May 7, 2020)
8ee4753  [BC breaking] Add Sentencepiece torchscript Extension (#755)  (mthrok, May 8, 2020)
3e2d13b  ns uses new WikiText103  (May 12, 2020)
bbddfc4  update README  (May 12, 2020)
39e08ee  merge master branch  (May 13, 2020)
9cc864c  Merge branch 'master' into torchBERT  (May 19, 2020)
f09f191  switch to pytorch transformer  (May 19, 2020)
c8f1498  remove init_weights in model  (May 19, 2020)
6b01c9a  remove EnWik9 dataset from data.py. Use torchtext one  (May 19, 2020)
dd9f239  revise WikiText103  (May 19, 2020)
f8ffbaa  Merge branch 'master' into torchBERT  (Jun 2, 2020)
b5b0ec1  remove language modeling datasets from torchBERT pipeline  (Jun 4, 2020)
12bdf5f  Merge branch 'master' into torchBERT  (Jun 5, 2020)
06c1279  Merge branch 'master' into torchBERT  (Jun 10, 2020)
5ec0382  integrate with torchtext  (Jun 10, 2020)
9eab8b4  fix ns task  (Jun 10, 2020)
b7f1dd4  Merge branch 'master' into torchBERT  (Jun 16, 2020)
a6aea63  Merge branch 'master' into torchBERT  (Jun 18, 2020)
5ea8d83  switch to MHA container in torchtext  (Jun 18, 2020)
fa7cc14  Add BookCorpus  (Jun 19, 2020)
22eb9c9  update epoch to 15 in mlm_task  (Jun 19, 2020)
f50bb9d  update slurm time  (Jun 21, 2020)
54b1d5c  update README.md  (Jun 22, 2020)
0010f98  update epoch ns task  (Jun 22, 2020)
3ec1528  Merge branch 'master' into torchBERT  (Jun 25, 2020)
cc60313  add BookCorpus to ns_task  (Jun 25, 2020)
441edff  Upload pre-trained models and vocab to aws s3  (Jul 2, 2020)
9ddcb82  Merge branch 'master' into torchBERT  (Jul 13, 2020)
49c558b  Update docs  (Jul 13, 2020)
9dbdf80  Merge branch 'master' into torchBERT  (Jul 15, 2020)
8982ab8  switch to torchtext.nn  (Jul 15, 2020)
98b76b8  switch to torch.save state_dict  (Jul 15, 2020)
015a8f9  use LanguageModelingDataset from torchtext  (Jul 15, 2020)
c9a9c56  Update README file with train/valid/test printout for question-answer…  (Jul 15, 2020)
66 changes: 66 additions & 0 deletions examples/BERT/README.rst
@@ -0,0 +1,66 @@
BERT with torchtext
+++++++++++++++++++

This example shows how to train a BERT model using only PyTorch and torchtext. The pre-trained BERT model is then fine-tuned for the question-answer task.


Generate pre-trained BERT
-------------------------

Train the BERT model with the masked language modeling task and the next-sentence task. Run the tasks on a local GPU or CPU:

python mlm_task.py
python ns_task.py

or run the tasks on a SLURM-powered cluster with Distributed Data Parallel (DDP):

srun --label --ntasks-per-node=1 --time=4000 --mem-per-cpu=5120 --gres=gpu:8 --cpus-per-task 80 --nodes=1 --pty python mlm_task.py --parallel DDP --log-interval 600 --dataset BookCorpus

srun --label --ntasks-per-node=1 --time=4000 --mem-per-cpu=5120 --gres=gpu:8 --cpus-per-task 80 --nodes=1 --pty python ns_task.py --parallel DDP --bert-model mlm_bert.pt --dataset BookCorpus

The resulting perplexity (ppl) of mlm_task on the test set is 18.97899.
The resulting loss of ns_task on the test set is 0.05446.

Fine-tune pre-trained BERT for question-answer task
---------------------------------------------------

With the SQuAD dataset, the pre-trained BERT model is fine-tuned for the question-answer task:

python qa_task.py --bert-model ns_bert.pt --epochs 30

The pre-trained BERT models and vocab are available:

* `bert_vocab.pt <https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/bert_vocab.pt>`_
* `mlm_bert.pt <https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/mlm_bert.pt>`_
* `ns_bert.pt <https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/ns_bert.pt>`_
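
For a quick start, here is a minimal loading sketch. It assumes the checkpoints store state_dicts (per the "switch to torch.save state_dict" commit), that it is run from examples/BERT, and that MLMTask takes the vocab size plus the usual transformer hyper-parameters; the exact constructor arguments live in model.py and mlm_task.py and may differ:

    import torch
    from model import MLMTask  # model classes defined in model.py

    # Hypothetical local paths; download the files listed above first.
    vocab = torch.load('bert_vocab.pt')

    # The hyper-parameters below are illustrative assumptions, not the values
    # hard-coded in mlm_task.py; check model.py / mlm_task.py for the real ones.
    model = MLMTask(len(vocab), 768, 12, 3072, 12, 0.2)
    model.load_state_dict(torch.load('mlm_bert.pt', map_location='cpu'))
    model.eval()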

Structure of the example
========================

model.py
--------

This file defines the Transformer and MultiheadAttention models used for BERT. The embedding layer includes PositionalEncoding and TokenTypeEncoding layers. MLMTask, NextSentenceTask, and QuestionAnswerTask are the models for the three tasks mentioned above.
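
As a rough illustration of the embedding side, below is a minimal sketch of a sinusoidal PositionalEncoding module. It assumes the standard (seq_len, batch, embed_dim) layout used by nn.Transformer and is not necessarily line-for-line identical to the module in model.py:

    import math
    import torch
    import torch.nn as nn

    class PositionalEncoding(nn.Module):
        # Standard sinusoidal positional encoding (assumed formulation; see model.py for the real one)
        def __init__(self, d_model, max_len=5000):
            super().__init__()
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, d_model)

        def forward(self, x):
            # x: (seq_len, batch, d_model); add the positional signal for the first seq_len steps
            return x + self.pe[:x.size(0)]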

data.py
-------

This file provides the datasets required to train the BERT model and the question-answer task. Please note that the BookCorpus dataset is not publicly available.


mlm_task.py, ns_task.py, qa_task.py
-----------------------------------

These three files define the train/valid/test process for the three tasks.


metrics.py
----------

This file provides two metrics (F1 and exact score) for the question-answer task.


utils.py
--------

This file provides a few utilities shared by the three tasks.
79 changes: 79 additions & 0 deletions examples/BERT/data.py
@@ -0,0 +1,79 @@
import glob
import torch

Contributor: I guess we could start incrementally porting this over into the experimental folder, which will also help with cleanup. These datasets seem useful in general.

Contributor: I second this. It would also help make this pull request shorter.

import logging
from torchtext.data.utils import get_tokenizer
import random


class LanguageModelingDataset(torch.utils.data.Dataset):
Contributor Author: Switched to the one in experimental/datasets.

"""Defines a dataset for language modeling.
"""

def __init__(self, data, vocab):
"""Initiate language modeling dataset.
"""

super(LanguageModelingDataset, self).__init__()
self.data = data
self.vocab = vocab

def __getitem__(self, i):
return self.data[i]

def __len__(self):
return len(self.data)

def __iter__(self):
for x in self.data:
yield x

def get_vocab(self):
return self.vocab


###################################################################
# Set up dataset for book corpus
###################################################################
def BookCorpus(vocab, tokenizer=get_tokenizer("basic_english"),

Contributor: Maybe it's worth moving this into experimental as well?

Contributor Author: I was thinking about this, but the original data for BookCorpus comes from the FAIR cluster.

               data_select=('train', 'test', 'valid'), removed_tokens=[],
               min_sentence_len=None):

    if isinstance(data_select, str):
        data_select = [data_select]
    if not set(data_select).issubset(set(('train', 'test', 'valid'))):
        raise TypeError('data_select is not supported!')

    extracted_files = glob.glob('/datasets01/bookcorpus/021819/*/*.txt')
    random.seed(1000)
    random.shuffle(extracted_files)

    # Split the shuffled files into roughly 85% train, 5% test, and 10% valid
    num_files = len(extracted_files)
    _path = {'train': extracted_files[:(num_files // 20 * 17)],
             'test': extracted_files[(num_files // 20 * 17):(num_files // 20 * 18)],
             'valid': extracted_files[(num_files // 20 * 18):]}

    data = {}
    for item in _path.keys():
        data[item] = []
        logging.info('Creating {} data'.format(item))
        tokens = []
        for txt_file in _path[item]:
            with open(txt_file, 'r', encoding="utf8", errors='ignore') as f:
                for line in f.readlines():
                    _tokens = tokenizer(line.strip())
                    if min_sentence_len:
                        # keep each sentence as its own list of token ids
                        if len(_tokens) >= min_sentence_len:
                            tokens.append([vocab.stoi[token] for token in _tokens])
                    else:
                        # flatten everything into a single stream of token ids
                        tokens += [vocab.stoi[token] for token in _tokens]
        data[item] = tokens

    for key in data_select:
        if data[key] == []:
            raise TypeError('Dataset {} is empty!'.format(key))
    if min_sentence_len:
        return tuple(LanguageModelingDataset(data[d], vocab)
                     for d in data_select)
    else:
        return tuple(LanguageModelingDataset(torch.tensor(data[d]).long(), vocab)
                     for d in data_select)
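
For reference, a hypothetical call site for BookCorpus. It only works on a machine that has the /datasets01 BookCorpus files (the FAIR cluster), and it assumes `bert_vocab.pt` is the released torchtext vocab and that the script is run from examples/BERT:

    import torch
    from data import BookCorpus

    vocab = torch.load('bert_vocab.pt')  # assumed: the released torchtext vocab

    # Flat token-id streams for language modeling
    train, valid = BookCorpus(vocab, data_select=('train', 'valid'))
    print(len(train), len(valid))

    # ns_task-style usage keeps whole sentences, so each item is a list of token ids
    train_ns, = BookCorpus(vocab, data_select='train', min_sentence_len=60)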
72 changes: 72 additions & 0 deletions examples/BERT/metrics.py
@@ -0,0 +1,72 @@
import collections
import re
import string


def compute_qa_exact(ans_pred_tokens_samples):

    '''
    Input: ans_pred_tokens_samples: [([ans1_tokens_candidate1, ans1_tokens_candidate2], pred1_tokens),
                                     ([ans2_tokens_candidate1, ans2_tokens_candidate2], pred2_tokens),
                                     ...
                                     ([ansn_tokens_candidate1, ansn_tokens_candidate2], predn_tokens)]
           ans1_tokens_candidate1 = ['this', 'is', 'an', 'sample', 'example']
    Output: exact score of the samples
    '''

    def normalize_txt(text):
        # lower case
        text = text.lower()

        # remove punc
        exclude = set(string.punctuation)
        text = "".join(ch for ch in text if ch not in exclude)

        # remove articles
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        text = re.sub(regex, " ", text)

        # white space fix
        return " ".join(text.split())

    exact_scores = []
    for (ans_tokens, pred_tokens) in ans_pred_tokens_samples:
        pred_str = " ".join(pred_tokens)
        candidate_score = []
        for item in ans_tokens:
            ans_str = " ".join(item)
            candidate_score.append(int(normalize_txt(ans_str) == normalize_txt(pred_str)))
        exact_scores.append(max(candidate_score))
    return 100.0 * sum(exact_scores) / len(exact_scores)


def compute_qa_f1(ans_pred_tokens_samples):

Contributor: Can this be turned into a generic F1 metric that we can throw into torchtext?

Contributor Author: The sample_f1 func can be landed in torchtext.


    '''
    Input: ans_pred_tokens_samples: [([ans1_tokens_candidate1, ans1_tokens_candidate2], pred1_tokens),
                                     ([ans2_tokens_candidate1, ans2_tokens_candidate2], pred2_tokens),
                                     ...
                                     ([ansn_tokens_candidate1, ansn_tokens_candidate2], predn_tokens)]
           ans1_tokens_candidate1 = ['this', 'is', 'an', 'sample', 'example']
    Output: f1 score of the samples
    '''
    def sample_f1(ans_tokens, pred_tokens):
        common = collections.Counter(ans_tokens) & collections.Counter(pred_tokens)
        num_same = sum(common.values())
        if len(ans_tokens) == 0 or len(pred_tokens) == 0:
            # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
            return int(ans_tokens == pred_tokens)
        if num_same == 0:
            return 0
        precision = 1.0 * num_same / len(pred_tokens)
        recall = 1.0 * num_same / len(ans_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    f1_scores = []
    for (ans_tokens, pred_tokens) in ans_pred_tokens_samples:
        candidate_score = []
        for item in ans_tokens:
            candidate_score.append(sample_f1(item, pred_tokens))
        f1_scores.append(max(candidate_score))
    return 100.0 * sum(f1_scores) / len(f1_scores)
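
A quick usage sketch for both metrics with made-up tokens (run from examples/BERT); each sample pairs a list of acceptable answer token lists with the predicted tokens:

    from metrics import compute_qa_exact, compute_qa_f1

    # Two toy samples: ([candidate answer token lists], predicted tokens)
    samples = [
        ([['in', '1889'], ['1889']], ['1889']),
        ([['the', 'eiffel', 'tower']], ['eiffel', 'tower']),
    ]
    print(compute_qa_exact(samples))  # 100.0: normalization strips the article 'the' in the second sample
    print(compute_qa_f1(samples))     # 90.0: per-sample F1 of 1.0 and 0.8, averaged and scaled by 100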