
Update v0.7.0 release branch with master #886

Merged on Jul 16, 2020.

Changes from all 21 commits:
1 change: 1 addition & 0 deletions .github/pytorch-probot.yml
@@ -0,0 +1 @@
tracking_issue: 876
47 changes: 47 additions & 0 deletions benchmark/experimental_vectors.py
@@ -0,0 +1,47 @@
import time

import torch
from torchtext.experimental.datasets import AG_NEWS
from torchtext.experimental.vectors import FastText as FastTextExperimental
from torchtext.vocab import FastText


def benchmark_experimental_vectors():
    def _run_benchmark_lookup(tokens, vector):
        t0 = time.monotonic()
        for token in tokens:
            vector[token]
        print("Lookup time:", time.monotonic() - t0)

    train, = AG_NEWS(data_select='train')
    vocab = train.get_vocab()
    tokens = []
    for (label, text) in train:
        for id in text.tolist():
            tokens.append(vocab.itos[id])

    # existing FastText construction
    print("Existing FastText - Not Jit Mode")
    t0 = time.monotonic()
    fast_text = FastText()
    print("Construction time:", time.monotonic() - t0)
    _run_benchmark_lookup(tokens, fast_text)

    # experimental FastText construction
    print("FastText Experimental")
    t0 = time.monotonic()
    fast_text_experimental = FastTextExperimental(validate_file=False)
    print("Construction time:", time.monotonic() - t0)

    # not jit lookup
    print("FastText Experimental - Not Jit Mode")
    _run_benchmark_lookup(tokens, fast_text_experimental)

    # jit lookup
    print("FastText Experimental - Jit Mode")
    jit_fast_text_experimental = torch.jit.script(fast_text_experimental)
    _run_benchmark_lookup(tokens, jit_fast_text_experimental)


if __name__ == "__main__":
    benchmark_experimental_vectors()
55 changes: 55 additions & 0 deletions benchmark/experimental_vocab.py
@@ -0,0 +1,55 @@
from collections import (Counter, OrderedDict)
import time

import torch
from torchtext.experimental.datasets import AG_NEWS
from torchtext.experimental.vocab import Vocab as VocabExperimental
from torchtext.vocab import Vocab


def benchmark_experimental_vocab():
    def _run_benchmark_lookup(tokens, vocab):
        t0 = time.monotonic()
        for token in tokens:
            vocab[token]
        print("Lookup time:", time.monotonic() - t0)

    train, = AG_NEWS(data_select='train')
    vocab = train.get_vocab()
    tokens = []
    for (label, text) in train:
        for id in text.tolist():
            tokens.append(vocab.itos[id])

    counter = Counter(tokens)
    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
    ordered_dict = OrderedDict(sorted_by_freq_tuples)

    # existing Vocab construction
    print("Vocab")
    t0 = time.monotonic()
    v_existing = Vocab(counter)
    print("Construction time:", time.monotonic() - t0)

    # experimental Vocab construction
    print("Vocab Experimental")
    t0 = time.monotonic()
    v_experimental = VocabExperimental(ordered_dict)
    print("Construction time:", time.monotonic() - t0)
    jit_v_experimental = torch.jit.script(v_experimental)

    # existing Vocab not jit lookup
    print("Vocab - Not Jit Mode")
    _run_benchmark_lookup(tokens, v_existing)

    # experimental Vocab not jit lookup
    print("Vocab Experimental - Not Jit Mode")
    _run_benchmark_lookup(tokens, v_experimental)

    # experimental Vocab jit lookup
    print("Vocab Experimental - Jit Mode")
    _run_benchmark_lookup(tokens, jit_v_experimental)


if __name__ == "__main__":
    benchmark_experimental_vocab()
2 changes: 1 addition & 1 deletion docs/source/experimental_vectors.rst
@@ -27,4 +27,4 @@ torchtext.experimental.vectors
:hidden:`GloVe`
~~~~~~~~~~~~~~~~

.. autofunction:: GloVe
.. autofunction:: GloVe
152 changes: 152 additions & 0 deletions examples/BERT/README.rst
@@ -0,0 +1,152 @@
BERT with torchtext
+++++++++++++++++++

This example shows how to train a BERT model using only PyTorch and torchtext, and then fine-tune the pre-trained BERT model for the question-answering task.


Generate pre-trained BERT
-------------------------

Train the BERT model on the masked language modeling and next-sentence tasks. Run the tasks on a local GPU or CPU::

    python mlm_task.py
    python ns_task.py

or run the tasks on a SLURM-powered cluster with Distributed Data Parallel (DDP)::

    srun --label --ntasks-per-node=1 --time=4000 --mem-per-cpu=5120 --gres=gpu:8 --cpus-per-task 80 --nodes=1 --pty python mlm_task.py --parallel DDP --log-interval 600 --dataset BookCorpus

    srun --label --ntasks-per-node=1 --time=4000 --mem-per-cpu=5120 --gres=gpu:8 --cpus-per-task 80 --nodes=1 --pty python ns_task.py --parallel DDP --bert-model mlm_bert.pt --dataset BookCorpus

The resulting perplexity (ppl) of mlm_task on the test set is 18.97899; the resulting loss of ns_task on the test set is 0.05446.

Fine-tune pre-trained BERT for question-answer task
---------------------------------------------------

With the SQuAD dataset, the pre-trained BERT model is fine-tuned for the question-answering task::

    python qa_task.py --bert-model ns_bert.pt --epochs 30

The pre-trained BERT models and vocab are available for download (a loading sketch follows the list):

* `bert_vocab.pt <https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/bert_vocab.pt>`_
* `mlm_bert.pt <https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/mlm_bert.pt>`_
* `ns_bert.pt <https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/ns_bert.pt>`_
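
A minimal sketch for fetching and loading these checkpoints, assuming they are plain ``torch.save`` artifacts; the local file names are arbitrary, and loading the full model objects requires the classes in model.py to be importable::

    import torch
    from torch.hub import download_url_to_file

    base = "https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/"
    for name in ("bert_vocab.pt", "mlm_bert.pt", "ns_bert.pt"):
        # assumption: each checkpoint is a plain torch.save artifact
        download_url_to_file(base + name, name)

    vocab = torch.load("bert_vocab.pt")
    ns_model = torch.load("ns_bert.pt")  # checkpoint passed to qa_task.py via --bert-model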

An example train/valid/test printout with the pre-trained BERT model on the question-answering task:

| epoch 1 | 200/ 1055 batches | lr 5.00000 | ms/batch 746.33 | loss 3.70 | ppl 40.45
| epoch 1 | 400/ 1055 batches | lr 5.00000 | ms/batch 746.78 | loss 3.06 | ppl 21.25
| epoch 1 | 600/ 1055 batches | lr 5.00000 | ms/batch 746.83 | loss 2.84 | ppl 17.15
| epoch 1 | 800/ 1055 batches | lr 5.00000 | ms/batch 746.55 | loss 2.69 | ppl 14.73
| epoch 1 | 1000/ 1055 batches | lr 5.00000 | ms/batch 745.48 | loss 2.55 | ppl 12.85
-----------------------------------------------------------------------------------------
| end of epoch 1 | time: 821.25s | valid loss 2.33 | exact 40.052% | f1 52.595%
-----------------------------------------------------------------------------------------
| epoch 2 | 200/ 1055 batches | lr 5.00000 | ms/batch 748.17 | loss 2.33 | ppl 10.25
| epoch 2 | 400/ 1055 batches | lr 5.00000 | ms/batch 745.52 | loss 2.28 | ppl 9.75
| epoch 2 | 600/ 1055 batches | lr 5.00000 | ms/batch 745.50 | loss 2.24 | ppl 9.37
| epoch 2 | 800/ 1055 batches | lr 5.00000 | ms/batch 745.10 | loss 2.22 | ppl 9.18
| epoch 2 | 1000/ 1055 batches | lr 5.00000 | ms/batch 744.61 | loss 2.16 | ppl 8.66
-----------------------------------------------------------------------------------------
| end of epoch 2 | time: 820.75s | valid loss 2.12 | exact 44.632% | f1 57.965%
-----------------------------------------------------------------------------------------
| epoch 3 | 200/ 1055 batches | lr 5.00000 | ms/batch 748.88 | loss 2.00 | ppl 7.41
| epoch 3 | 400/ 1055 batches | lr 5.00000 | ms/batch 746.46 | loss 1.99 | ppl 7.29
| epoch 3 | 600/ 1055 batches | lr 5.00000 | ms/batch 745.35 | loss 1.99 | ppl 7.30
| epoch 3 | 800/ 1055 batches | lr 5.00000 | ms/batch 746.03 | loss 1.98 | ppl 7.27
| epoch 3 | 1000/ 1055 batches | lr 5.00000 | ms/batch 746.01 | loss 1.98 | ppl 7.26
-----------------------------------------------------------------------------------------
| end of epoch 3 | time: 821.98s | valid loss 1.96 | exact 48.001% | f1 61.036%
-----------------------------------------------------------------------------------------
| epoch 4 | 200/ 1055 batches | lr 5.00000 | ms/batch 748.72 | loss 1.82 | ppl 6.19
| epoch 4 | 400/ 1055 batches | lr 5.00000 | ms/batch 745.86 | loss 1.84 | ppl 6.28
| epoch 4 | 600/ 1055 batches | lr 5.00000 | ms/batch 745.63 | loss 1.85 | ppl 6.34
| epoch 4 | 800/ 1055 batches | lr 5.00000 | ms/batch 745.59 | loss 1.82 | ppl 6.20
| epoch 4 | 1000/ 1055 batches | lr 5.00000 | ms/batch 745.55 | loss 1.83 | ppl 6.21
-----------------------------------------------------------------------------------------
| end of epoch 4 | time: 821.10s | valid loss 1.95 | exact 49.149% | f1 62.040%
-----------------------------------------------------------------------------------------
| epoch 5 | 200/ 1055 batches | lr 5.00000 | ms/batch 748.40 | loss 1.66 | ppl 5.24
| epoch 5 | 400/ 1055 batches | lr 5.00000 | ms/batch 756.09 | loss 1.69 | ppl 5.44
| epoch 5 | 600/ 1055 batches | lr 5.00000 | ms/batch 769.19 | loss 1.70 | ppl 5.46
| epoch 5 | 800/ 1055 batches | lr 5.00000 | ms/batch 764.96 | loss 1.72 | ppl 5.58
| epoch 5 | 1000/ 1055 batches | lr 5.00000 | ms/batch 773.25 | loss 1.70 | ppl 5.49
-----------------------------------------------------------------------------------------
| end of epoch 5 | time: 844.20s | valid loss 1.99 | exact 49.509% | f1 61.994%
-----------------------------------------------------------------------------------------
| epoch 6 | 200/ 1055 batches | lr 0.50000 | ms/batch 765.25 | loss 1.50 | ppl 4.49
| epoch 6 | 400/ 1055 batches | lr 0.50000 | ms/batch 749.64 | loss 1.45 | ppl 4.25
| epoch 6 | 600/ 1055 batches | lr 0.50000 | ms/batch 768.16 | loss 1.40 | ppl 4.06
| epoch 6 | 800/ 1055 batches | lr 0.50000 | ms/batch 745.69 | loss 1.43 | ppl 4.18
| epoch 6 | 1000/ 1055 batches | lr 0.50000 | ms/batch 744.90 | loss 1.40 | ppl 4.07
-----------------------------------------------------------------------------------------
| end of epoch 6 | time: 829.55s | valid loss 1.97 | exact 51.182% | f1 63.437%
-----------------------------------------------------------------------------------------
| epoch 7 | 200/ 1055 batches | lr 0.50000 | ms/batch 747.73 | loss 1.36 | ppl 3.89
| epoch 7 | 400/ 1055 batches | lr 0.50000 | ms/batch 744.50 | loss 1.37 | ppl 3.92
| epoch 7 | 600/ 1055 batches | lr 0.50000 | ms/batch 744.20 | loss 1.35 | ppl 3.86
| epoch 7 | 800/ 1055 batches | lr 0.50000 | ms/batch 743.85 | loss 1.36 | ppl 3.89
| epoch 7 | 1000/ 1055 batches | lr 0.50000 | ms/batch 744.01 | loss 1.34 | ppl 3.83
-----------------------------------------------------------------------------------------
| end of epoch 7 | time: 820.02s | valid loss 2.01 | exact 51.507% | f1 63.885%
-----------------------------------------------------------------------------------------
| epoch 8 | 200/ 1055 batches | lr 0.50000 | ms/batch 747.40 | loss 1.31 | ppl 3.72
| epoch 8 | 400/ 1055 batches | lr 0.50000 | ms/batch 744.33 | loss 1.30 | ppl 3.68
| epoch 8 | 600/ 1055 batches | lr 0.50000 | ms/batch 745.76 | loss 1.31 | ppl 3.69
| epoch 8 | 800/ 1055 batches | lr 0.50000 | ms/batch 745.04 | loss 1.31 | ppl 3.69
| epoch 8 | 1000/ 1055 batches | lr 0.50000 | ms/batch 745.13 | loss 1.31 | ppl 3.72
-----------------------------------------------------------------------------------------
| end of epoch 8 | time: 820.40s | valid loss 2.02 | exact 51.260% | f1 63.762%
-----------------------------------------------------------------------------------------
| epoch 9 | 200/ 1055 batches | lr 0.05000 | ms/batch 748.36 | loss 1.26 | ppl 3.54
| epoch 9 | 400/ 1055 batches | lr 0.05000 | ms/batch 744.55 | loss 1.26 | ppl 3.52
| epoch 9 | 600/ 1055 batches | lr 0.05000 | ms/batch 745.46 | loss 1.23 | ppl 3.44
| epoch 9 | 800/ 1055 batches | lr 0.05000 | ms/batch 745.23 | loss 1.26 | ppl 3.52
| epoch 9 | 1000/ 1055 batches | lr 0.05000 | ms/batch 744.69 | loss 1.24 | ppl 3.47
-----------------------------------------------------------------------------------------
| end of epoch 9 | time: 820.41s | valid loss 2.02 | exact 51.578% | f1 63.704%
-----------------------------------------------------------------------------------------
| epoch 10 | 200/ 1055 batches | lr 0.00500 | ms/batch 749.25 | loss 1.25 | ppl 3.50
| epoch 10 | 400/ 1055 batches | lr 0.00500 | ms/batch 745.81 | loss 1.24 | ppl 3.47
| epoch 10 | 600/ 1055 batches | lr 0.00500 | ms/batch 744.89 | loss 1.26 | ppl 3.51
| epoch 10 | 800/ 1055 batches | lr 0.00500 | ms/batch 746.02 | loss 1.23 | ppl 3.42
| epoch 10 | 1000/ 1055 batches | lr 0.00500 | ms/batch 746.61 | loss 1.25 | ppl 3.50
-----------------------------------------------------------------------------------------
| end of epoch 10 | time: 821.85s | valid loss 2.05 | exact 51.648% | f1 63.811%
-----------------------------------------------------------------------------------------
=========================================================================================
| End of training | test loss 2.05 | exact 51.337% | f1 63.645%
=========================================================================================

Structure of the example
========================

model.py
--------

This file defines the Transformer and MultiheadAttention models used for BERT. The embedding layer includes PositionalEncoding and TokenTypeEncoding layers. MLMTask, NextSentenceTask, and QuestionAnswerTask are the models for the three tasks mentioned above.
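
For orientation, below is a hedged sketch of the standard sinusoidal positional encoding that such a layer typically implements; it is illustrative only, not a copy of the ``PositionalEncoding`` layer in model.py, and the hyper-parameters are assumptions::

    import math

    import torch
    import torch.nn as nn


    class SinusoidalPositionalEncoding(nn.Module):
        # Illustrative stand-in; see model.py in this example for the real layer.
        def __init__(self, d_model, dropout=0.1, max_len=5000):
            super().__init__()
            self.dropout = nn.Dropout(p=dropout)
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float()
                                 * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, d_model)

        def forward(self, x):
            # x: (seq_len, batch, d_model)
            x = x + self.pe[:x.size(0)]
            return self.dropout(x)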

data.py
-------

This file provides the dataset helpers required to train the BERT model and the question-answering task. Please note that the BookCorpus dataset is not publicly available; a usage sketch of the ``BookCorpus`` helper follows.
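
A hedged usage sketch of the ``BookCorpus`` helper added in this PR (it globs a local, non-public corpus path, so it only runs where that data is mounted; the vocab file name refers to the checkpoint linked above)::

    import torch
    from data import BookCorpus  # examples/BERT/data.py

    # Assumption: bert_vocab.pt is the vocab checkpoint linked above and exposes
    # the stoi mapping that BookCorpus expects.
    vocab = torch.load("bert_vocab.pt")
    train_dataset, = BookCorpus(vocab, data_select='train')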


mlm_task.py, ns_task.py, qa_task.py
-----------------------------------

These three files define the train/valid/test procedures for the three tasks.


metrics.py
----------

This file provides two metrics (F1 and exact score) for the question-answering task; a usage sketch follows.
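
A small usage sketch following the input format documented in the metrics.py docstrings: each sample pairs a list of acceptable answer token lists with the predicted token list (the sample data below is made up)::

    from metrics import compute_qa_exact, compute_qa_f1

    # made-up samples: ([answer candidates as token lists], predicted tokens)
    samples = [
        ([['in', '1953'], ['1953']], ['1953']),        # exact match with one candidate
        ([['new', 'york', 'city']], ['new', 'york']),  # partial token overlap
    ]
    print(compute_qa_exact(samples))  # 50.0 -- percentage of exact matches
    print(compute_qa_f1(samples))     # 90.0 -- mean best-candidate token F1, in percent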


utils.py
--------

This file provides a few utilities shared by the three tasks.
54 changes: 54 additions & 0 deletions examples/BERT/data.py
@@ -0,0 +1,54 @@
import glob
import torch
import logging
from torchtext.data.utils import get_tokenizer
import random
from torchtext.experimental.datasets import LanguageModelingDataset


###################################################################
# Set up dataset for book corpus
###################################################################
def BookCorpus(vocab, tokenizer=get_tokenizer("basic_english"),
               data_select=('train', 'test', 'valid'), removed_tokens=[],
               min_sentence_len=None):

    if isinstance(data_select, str):
        data_select = [data_select]
    if not set(data_select).issubset(set(('train', 'test', 'valid'))):
        raise TypeError('data_select is not supported!')

    extracted_files = glob.glob('/datasets01/bookcorpus/021819/*/*.txt')
    random.seed(1000)
    random.shuffle(extracted_files)

    num_files = len(extracted_files)
    _path = {'train': extracted_files[:(num_files // 20 * 17)],
             'test': extracted_files[(num_files // 20 * 17):(num_files // 20 * 18)],
             'valid': extracted_files[(num_files // 20 * 18):]}

    data = {}
    for item in _path.keys():
        data[item] = []
        logging.info('Creating {} data'.format(item))
        tokens = []
        for txt_file in _path[item]:
            with open(txt_file, 'r', encoding="utf8", errors='ignore') as f:
                for line in f.readlines():
                    _tokens = tokenizer(line.strip())
                    if min_sentence_len:
                        if len(_tokens) >= min_sentence_len:
                            tokens.append([vocab.stoi[token] for token in _tokens])
                    else:
                        tokens += [vocab.stoi[token] for token in _tokens]
        data[item] = tokens

    for key in data_select:
        if data[key] == []:
            raise TypeError('Dataset {} is empty!'.format(key))
    if min_sentence_len:
        return tuple(LanguageModelingDataset(data[d], vocab, lambda x: x, False)
                     for d in data_select)
    else:
        return tuple(LanguageModelingDataset(torch.tensor(data[d]).long(), vocab, lambda x: x, False)
                     for d in data_select)
72 changes: 72 additions & 0 deletions examples/BERT/metrics.py
@@ -0,0 +1,72 @@
import collections
import re
import string


def compute_qa_exact(ans_pred_tokens_samples):

    '''
    Input: ans_pred_tokens_samples: [([ans1_tokens_candidate1, ans1_tokens_candidate2], pred1_tokens),
                                     ([ans2_tokens_candidate1, ans2_tokens_candidate2], pred2_tokens),
                                     ...
                                     ([ansn_tokens_candidate1, ansn_tokens_candidate2], predn_tokens)]
           ans1_tokens_candidate1 = ['this', 'is', 'an', 'sample', 'example']
    Output: exact score of the samples
    '''

    def normalize_txt(text):
        # lower case
        text = text.lower()

        # remove punc
        exclude = set(string.punctuation)
        text = "".join(ch for ch in text if ch not in exclude)

        # remove articles
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        text = re.sub(regex, " ", text)

        # white space fix
        return " ".join(text.split())

    exact_scores = []
    for (ans_tokens, pred_tokens) in ans_pred_tokens_samples:
        pred_str = " ".join(pred_tokens)
        candidate_score = []
        for item in ans_tokens:
            ans_str = " ".join(item)
            candidate_score.append(int(normalize_txt(ans_str) == normalize_txt(pred_str)))
        exact_scores.append(max(candidate_score))
    return 100.0 * sum(exact_scores) / len(exact_scores)


def compute_qa_f1(ans_pred_tokens_samples):

    '''
    Input: ans_pred_tokens_samples: [([ans1_tokens_candidate1, ans1_tokens_candidate2], pred1_tokens),
                                     ([ans2_tokens_candidate1, ans2_tokens_candidate2], pred2_tokens),
                                     ...
                                     ([ansn_tokens_candidate1, ansn_tokens_candidate2], predn_tokens)]
           ans1_tokens_candidate1 = ['this', 'is', 'an', 'sample', 'example']
    Output: f1 score of the samples
    '''
    def sample_f1(ans_tokens, pred_tokens):
        common = collections.Counter(ans_tokens) & collections.Counter(pred_tokens)
        num_same = sum(common.values())
        if len(ans_tokens) == 0 or len(pred_tokens) == 0:
            # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
            return int(ans_tokens == pred_tokens)
        if num_same == 0:
            return 0
        precision = 1.0 * num_same / len(pred_tokens)
        recall = 1.0 * num_same / len(ans_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    f1_scores = []
    for (ans_tokens, pred_tokens) in ans_pred_tokens_samples:
        candidate_score = []
        for item in ans_tokens:
            candidate_score.append(sample_f1(item, pred_tokens))
        f1_scores.append(max(candidate_score))
    return 100.0 * sum(f1_scores) / len(f1_scores)