Torcharrow based training using RoBERTa model and SST2 classification dataset (#1808)
Showing 2 changed files with 199 additions and 0 deletions.
## Description

This example shows end-to-end training for SST-2 binary classification using the RoBERTa model and TorchArrow-based text pre-processing. The main motivation for this example is to demonstrate the authoring of a text processing pipeline on top of a TorchArrow DataFrame.
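For a flavour of what this looks like, here is a minimal, illustrative sketch (not part of the committed files; the `text`/`labels` column names simply mirror the schema used by the example script below):

```python
# Minimal sketch, assuming only the public torcharrow DataFrame factory.
# The committed script builds similar DataFrames from the SST-2 datapipe and then
# applies native text operators (bpe_tokenize, lookup_indices, ...) column-wise.
import torcharrow as ta

df = ta.dataframe({"text": ["a charming movie", "a dull sequel"], "labels": [1, 0]})
print(df)          # 2 rows with columns: text, labels
print(df["text"])  # the string column that the example's transform tokenizes
```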
## Installation and Usage

The example depends on TorchArrow and TorchData.

#### TorchArrow

Install it from source following the instructions at https://github.com/pytorch/torcharrow#from-source. Note that some of the natively integrated text operators used in this example (`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) depend on the torch library. By default, TorchArrow does not take a dependency on the torch library, so make sure to set the flag `USE_TORCH=1` during TorchArrow installation (this is also the reason why we cannot depend on nightly releases):

```
USE_TORCH=1 python setup.py install
```
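As a quick sanity check (our suggestion, not part of the original instructions), you can verify that the build picked up the torch-backed native operators by importing the internal module that the example script uses:

```python
# Assumed check: these are the same imports the training script relies on.
import torcharrow as ta
import torcharrow._torcharrow as _ta

# GPT2BPEEncoder and Vocab should be available when TorchArrow was built with USE_TORCH=1.
print(hasattr(_ta, "GPT2BPEEncoder"), hasattr(_ta, "Vocab"))
```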
#### TorchData

To install TorchData, follow the instructions at https://github.com/pytorch/data#installation.

#### Usage

To run the example from the command line, run the following command:

```bash
python roberta_sst2_training_with_torcharrow.py \
    --batch-size 16 \
    --num-epochs 1 \
    --learning-rate 1e-5
```
examples/torcharrow/roberta_sst2_training_with_torcharrow.py (163 additions, 0 deletions)
import json
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torcharrow as ta
import torcharrow._torcharrow as _ta
import torcharrow.pytorch as tap
import torchtext.functional as F
import torchtext.transforms as T
from torch.nn import Module
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torcharrow import functional as ta_F
from torchtext.datasets import SST2
from torchtext.models import RobertaClassificationHead, ROBERTA_BASE_ENCODER
from torchtext.utils import get_asset_local_path

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def init_ta_gpt2bpe_encoder():
    encoder_json_path = "https://download.pytorch.org/models/text/gpt2_bpe_encoder.json"
    vocab_bpe_path = "https://download.pytorch.org/models/text/gpt2_bpe_vocab.bpe"

    encoder_json_path = get_asset_local_path(encoder_json_path)
    vocab_bpe_path = get_asset_local_path(vocab_bpe_path)
    _seperator = "\u0001"

    # load bpe encoder
    with open(encoder_json_path, "r", encoding="utf-8") as f:
        bpe_encoder = json.load(f)
    # load bpe vocab and build merge ranks
    with open(vocab_bpe_path, "r", encoding="utf-8") as f:
        bpe_vocab = f.read()
    bpe_merge_ranks = {
        _seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_vocab.split("\n")[1:-1])
    }
    # Caching is enabled in Eager mode
    bpe = _ta.GPT2BPEEncoder(bpe_encoder, bpe_merge_ranks, _seperator, T.bytes_to_unicode(), True)
    return bpe


def init_ta_gpt2bpe_vocab():
    vocab_path = "https://download.pytorch.org/models/text/roberta.vocab.pt"
    vocab_path = get_asset_local_path(vocab_path)
    vocab = torch.load(vocab_path)
    ta_vocab = _ta.Vocab(vocab.get_itos(), vocab.get_default_index())
    return ta_vocab


class RobertaTransformDataFrameNativeOps(Module):
    def __init__(self) -> None:
        super().__init__()
        # Tokenizer to split input text into tokens
        self.tokenizer = init_ta_gpt2bpe_encoder()

        # Vocabulary for converting tokens to IDs
        self.vocab = init_ta_gpt2bpe_vocab()

        # Add BOS token to the beginning of the sentence
        self.add_bos = T.AddToken(token=0, begin=True)

        # Add EOS token to the end of the sentence
        self.add_eos = T.AddToken(token=2, begin=False)

    def forward(self, input: ta.DataFrame) -> ta.DataFrame:
        input["tokens"] = ta_F.bpe_tokenize(self.tokenizer, input["text"])
        input["tokens"] = input["tokens"].list.slice(stop=254)
        input["tokens"] = ta_F.lookup_indices(self.vocab, input["tokens"])
        input["tokens"] = ta_F.add_tokens(input["tokens"], [0], begin=True)
        input["tokens"] = ta_F.add_tokens(input["tokens"], [2], begin=False)
        return input


def get_dataloader(split, args):
    # Instantiate transform
    transform = RobertaTransformDataFrameNativeOps()

    # Create SST2 datapipe and apply pre-processing
    train_dp = SST2(split=split)

    # Convert to DataFrames with batch_size rows each
    train_dp = train_dp.dataframe(columns=["text", "labels"], dataframe_size=args.batch_size)

    # Apply transformation on DataFrame
    train_dp = train_dp.map(transform)

    # (optional) Remove columns that are no longer required
    train_dp = train_dp.map(lambda x: x.drop(["text"]))

    # Convert DataFrame to tensor (this will yield a named tuple)
    train_dp = train_dp.map(lambda x: x.to_tensor({"tokens": tap.PadSequence(padding_value=1)}))

    # Create DataLoader
    dl = DataLoader(train_dp, batch_size=None)

    return dl


classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
model = ROBERTA_BASE_ENCODER.get_model(head=classifier_head)
model.to(DEVICE)


def train_step(input, target, optim, criteria):
    output = model(input)
    loss = criteria(output, target)
    optim.zero_grad()
    loss.backward()
    optim.step()


def eval_step(input, target, criteria):
    output = model(input)
    loss = criteria(output, target).item()
    return float(loss), (output.argmax(1) == target).type(torch.float).sum().item()


def evaluate(dataloader, criteria):
    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0
    counter = 0
    with torch.no_grad():
        for batch in dataloader:
            # Batches are named tuples produced by DataFrame.to_tensor();
            # tokens are already padded with padding_value=1.
            input = batch.tokens.to(DEVICE)
            target = batch.labels.to(DEVICE)
            loss, predictions = eval_step(input, target, criteria)
            total_loss += loss
            correct_predictions += predictions
            total_predictions += len(target)
            counter += 1

    return total_loss / counter, correct_predictions / total_predictions


def main(args):
    print(args)
    train_dl = get_dataloader(split="train", args=args)
    dev_dl = get_dataloader(split="dev", args=args)

    learning_rate = args.learning_rate
    optim = AdamW(model.parameters(), lr=learning_rate)
    criteria = nn.CrossEntropyLoss()

    for e in range(args.num_epochs):
        for batch in train_dl:
            input = batch.tokens.to(DEVICE)
            target = batch.labels.to(DEVICE)
            train_step(input, target, optim, criteria)

        loss, accuracy = evaluate(dev_dl, criteria)
        print("Epoch = [{}], loss = [{}], accuracy = [{}]".format(e, loss, accuracy))


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--batch-size", default=16, type=int, help="Input batch size used during training")
    parser.add_argument("--num-epochs", default=1, type=int, help="Number of epochs to run training")
    parser.add_argument("--learning-rate", default=1e-5, type=float, help="Learning rate used for training")
    main(parser.parse_args())