BERT example in torchtext #767
Conversation
This commit adds a `torchscript` extension `_torchtext.so`, which contains a simple interface to `SentencePiece`.

- SentencePiece `v0.1.86` is used.
- `libsentencepiece.a` is built right before `_torchtext.so` is compiled. The logic for triggering this build from `setuptools` can be found under `build_tools/setup_helpers`.
- `_torchtext.so` provides an interface to train a SentencePiece model and to load a model from file.

Breaking change: previously, `torchtext.data.functional.load_sp_model` returned a `sentencepiece.SentencePieceProcessor` object, which supported the following methods in addition to `__len__` and `__getitem__`:

```
$ grep '$self->' third_party/sentencepiece/python/sentencepiece.i
return $self->Load(filename);
return $self->LoadFromSerializedProto(filename);
return $self->SetEncodeExtraOptions(extra_option);
return $self->SetDecodeExtraOptions(extra_option);
return $self->SetVocabulary(valid_vocab);
return $self->ResetVocabulary();
return $self->LoadVocabulary(filename, threshold);
return $self->EncodeAsPieces(input);
return $self->EncodeAsIds(input);
return $self->NBestEncodeAsPieces(input, nbest_size);
return $self->NBestEncodeAsIds(input, nbest_size);
return $self->SampleEncodeAsPieces(input, nbest_size, alpha);
return $self->SampleEncodeAsIds(input, nbest_size, alpha);
return $self->DecodePieces(input);
return $self->DecodeIds(input);
return $self->EncodeAsSerializedProto(input);
return $self->SampleEncodeAsSerializedProto(input, nbest_size, alpha);
return $self->NBestEncodeAsSerializedProto(input, nbest_size);
return $self->DecodePiecesAsSerializedProto(pieces);
return $self->DecodeIdsAsSerializedProto(ids);
return $self->GetPieceSize();
return $self->PieceToId(piece);
return $self->IdToPiece(id);
return $self->GetScore(id);
return $self->IsUnused(id);
return $self->IsControl(id);
return $self->IsUnused(id);
return $self->GetPieceSize();
return $self->PieceToId(key);
```

The new C++ extension provides the following methods:

```
Encode(input)
EncodeAsIds(input)
EncodeAsPieces(input)
```
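Code written against the old `SentencePieceProcessor` surface will break on the methods that were dropped. A small defensive helper (hypothetical — `check_sp_model` and `NEW_API` are illustrative names, not part of this PR) can make such code fail fast with a clear message:

```python
# The methods exposed by the new C++ extension, per the PR description.
NEW_API = {"Encode", "EncodeAsIds", "EncodeAsPieces"}


def check_sp_model(sp_model, needed):
    """Hypothetical migration helper: verify that a loaded model object
    still exposes every method the calling code relies on, so code written
    against the old SentencePieceProcessor fails early and clearly."""
    missing = [name for name in needed if not hasattr(sp_model, name)]
    if missing:
        raise AttributeError(
            "load_sp_model now returns the C++ extension object; "
            "missing methods: " + ", ".join(missing))
    return sp_model
```

For example, a pipeline that only encodes text would pass `["EncodeAsIds"]`, while one that also called the removed `DecodeIds` would get an immediate `AttributeError` instead of a failure deep inside a training loop.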
```
@@ -0,0 +1,603 @@
import torch
```
I guess we could start incrementally porting this over into the experimental folder, which would also help with cleanup. These datasets seem useful in general.
I second this. It would also help make this pull request shorter.
```
return 100.0 * sum(exact_scores) / len(exact_scores)


def compute_qa_f1(ans_pred_tokens_samples):
```
Can this be turned into a generic F1 metric that we can throw into torchtext?
The `sample_f1` func can be landed in torchtext.
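A generic token-overlap F1 along these lines might look like the sketch below. This is an illustrative implementation of the standard SQuAD-style metric, not the PR's code; the input layout (a list of `(answers_tokens, pred_tokens)` pairs, taking the best F1 over the gold answers for each sample) is an assumption:

```python
from collections import Counter


def sample_f1(pred_tokens, ans_tokens):
    """Token-overlap F1 for a single QA sample (SQuAD-style)."""
    common = Counter(pred_tokens) & Counter(ans_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ans_tokens)
    return 2 * precision * recall / (precision + recall)


def compute_qa_f1(ans_pred_tokens_samples):
    """Mean per-sample F1 over a dataset, scaled to 0-100.
    Each element is (list_of_gold_answer_token_lists, predicted_tokens);
    the best-matching gold answer is scored for each sample."""
    scores = [max(sample_f1(pred, ans) for ans in answers)
              for answers, pred in ans_pred_tokens_samples]
    return 100.0 * sum(scores) / len(scores)
```

Only `sample_f1` is QA-agnostic; `compute_qa_f1` is a thin aggregation wrapper, which matches the suggestion that the per-sample function is the part worth landing in torchtext.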
```
elif args.dataset == 'BookCorpus':
    train_dataset, test_dataset, valid_dataset = BookCorpus(vocab)

train_data = process_raw_data(train_dataset.data, args)
```
I'd move everything from this point to the end of the function into a separate `def train(...)` function.
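The suggested extraction could be sketched as follows. The signature and names are purely illustrative (the real step would run the model, loss, and optimizer); the point is the shape: the in-line epoch code becomes a function that iterates over batches and reports a mean loss:

```python
def train(model_step, batches):
    """Illustrative extraction of an in-line training loop.

    model_step: callable that runs one optimisation step on a batch
    and returns its loss; batches: an iterable of batches.
    Returns the mean loss over the epoch.
    """
    total_loss = 0.0
    count = 0
    for batch in batches:
        total_loss += model_step(batch)
        count += 1
    return total_loss / max(count, 1)
```

Factoring the loop out this way also makes it reusable across the three tasks in the example, which ties into the deduplication comments elsewhere in this review.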
```
start_time = time.time()


def run_main(args, rank=None):
```
It seems like some of this code used here could be shared with the other run_main functions?
You are right. However, since those three tasks are quite different, the `run_main` func is set up explicitly here, instead of passing a lot of arguments.
Looks pretty good! I think the next steps could include some code deduplication. Looks like you could factor out some more stuff already and also use more stuff from torchtext.
It would be nice if this were divided into at least 3 PRs: one for each of the two tasks, and one for some of the abstractions that could land straight into torchtext (e.g. datasets).
Thanks for the feedback. Yes, there are currently two PRs to land the datasets into torchtext and a separate PR to merge the model. Once those are done, this PR will contain only the task-related work.
```
def setup(rank, world_size, seed):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
```
Nit: this is specific to the particular distributed implementation.
It's set up for SLURM.
Codecov Report
```
@@           Coverage Diff           @@
##           master     #767   +/-  ##
=======================================
  Coverage   77.44%   77.44%
=======================================
  Files          44       44
  Lines        3055     3055
=======================================
  Hits         2366     2366
  Misses        689      689
```

Continue to review the full report at Codecov.
Force-pushed from c1b12f7 to b61e3d5
Force-pushed from e08705b to 441edff
examples/BERT/data.py (Outdated)
```
import random


class LanguageModelingDataset(torch.utils.data.Dataset):
```
Switched to the one in `experimental/datasets`.
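For reference, the interface such a dataset has to satisfy is small. The sketch below shows a minimal map-style language-modeling dataset, duck-typed without the `torch` dependency for illustration (the real one subclasses `torch.utils.data.Dataset`, which only requires `__getitem__` and `__len__`):

```python
class LanguageModelingDataset:
    """Minimal map-style dataset sketch: wraps a flat sequence of
    token ids plus the vocab used to produce them. Field names are
    illustrative, not the experimental/datasets implementation."""

    def __init__(self, data, vocab):
        self.data = data    # e.g. a flat list (or tensor) of token ids
        self.vocab = vocab

    def __getitem__(self, i):
        return self.data[i]

    def __len__(self):
        return len(self.data)
```

Because a map-style dataset is just indexing plus a length, moving it into `torchtext.experimental.datasets` carries essentially no example-specific baggage.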
```
###################################################################
# Set up dataset for book corpus
###################################################################
def BookCorpus(vocab, tokenizer=get_tokenizer("basic_english"),
```
Maybe it's worth moving this into experimental as well?
I was thinking about this, but the original data for BookCorpus comes from the FAIR cluster.
```
return processed_data


def collate_batch(batch, args, cls_id, sep_id, pad_id):
```
You have `collate_batch` earlier as well. Is there some combination possible here with a utils file?
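A shared version in a utils module might look like the sketch below. This is a guess at the common core, not the PR's code: the helper names and the exact wrapping scheme (prepend `cls_id`, append `sep_id`, then right-pad with `pad_id`) are assumptions for illustration:

```python
def pad_sequences(seqs, pad_id):
    """Right-pad a list of token-id lists to the longest length."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]


def collate_batch(batch, cls_id, sep_id, pad_id):
    """Hypothetical shared collate function for a utils module:
    wrap each sequence with the special tokens, then pad the batch
    so every row has equal length."""
    wrapped = [[cls_id] + seq + [sep_id] for seq in batch]
    return pad_sequences(wrapped, pad_id)
```

Task-specific variants could then call this common function and add only their own extras (masking, segment ids, etc.), which is the deduplication the comment asks about.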
Train a BERT model with PyTorch and torchtext, including the masked language modeling and next sentence prediction tasks. Fine-tune the BERT model for the question-answering task.

There are a few things still to do:

- Land the datasets in `torchtext.experimental.datasets`, including SQuAD. Ongoing PRs: "Question answer datasets: SQuAD1 and SQuAD2" #773 and "experimental.dataset WikiText2, WikiText103, PennTreeBank, WMTNewsCrawl" #774.