Torcharrow based training using RoBERTa model and SST2 classification dataset (#1808)
parmeet authored Jul 21, 2022
1 parent ed69973 commit 4fb43aa
Showing 2 changed files with 199 additions and 0 deletions.
36 changes: 36 additions & 0 deletions examples/torcharrow/README.md
@@ -0,0 +1,36 @@
## Description

This example shows end-to-end training for SST-2 binary classification using the RoBERTa model and TorchArrow-based text
pre-processing. The main motivation for this example is to demonstrate the authoring of a text processing pipeline on
top of TorchArrow DataFrames.
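
Before looking at the full script, the minimal sketch below (toy values only, not part of the example script) illustrates the
DataFrame-to-tensor step that the pipeline relies on: a TorchArrow DataFrame holding ragged token-id lists is converted into
padded tensors via `torcharrow.pytorch`.

```python
import torcharrow as ta
import torcharrow.pytorch as tap

# Toy DataFrame mirroring the schema produced by the example's transform:
# variable-length token-id lists plus integer labels (values are made up).
df = ta.dataframe({"tokens": [[0, 9064, 2], [0, 713, 16, 205, 2]], "labels": [0, 1]})

# Convert the whole DataFrame into a named tuple of tensors; the ragged token
# lists are padded to equal length with RoBERTa's padding index 1.
batch = df.to_tensor({"tokens": tap.PadSequence(padding_value=1)})
print(batch.tokens)  # padded token-id tensor (2 x 5 here)
print(batch.labels)  # tensor([0, 1])
```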

## Installation and Usage

The example depends on TorchArrow and TorchData.

#### TorchArrow

Install it from source following the instructions at https://github.com/pytorch/torcharrow#from-source. Note that some of
the natively integrated text operators used in this example (`bpe_tokenize` for tokenization, `lookup_indices` for
vocabulary look-up) depend on the torch library. By default, TorchArrow does not take a dependency on torch, so make sure
to pass the flag `USE_TORCH=1` when installing TorchArrow (this is also why the example cannot rely on nightly releases):

```bash
USE_TORCH=1 python setup.py install
```

#### TorchData

To install TorchData, follow the instructions at https://github.com/pytorch/data#installation
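
With a released PyTorch version already installed, this typically amounts to (follow the linked instructions for nightly or source builds):

```bash
pip install torchdata
```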

#### Usage

To run the example from the command line, run the following command:

```bash
python roberta_sst2_training_with_torcharrow.py \
--batch-size 16 \
--num-epochs 1 \
--learning-rate 1e-5
```
163 changes: 163 additions & 0 deletions examples/torcharrow/roberta_sst2_training_with_torcharrow.py
@@ -0,0 +1,163 @@
import json
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torcharrow as ta
import torcharrow._torcharrow as _ta
import torcharrow.pytorch as tap
import torchtext.transforms as T
from torch.nn import Module
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torcharrow import functional as ta_F
from torchtext.datasets import SST2
from torchtext.models import RobertaClassificationHead, ROBERTA_BASE_ENCODER
from torchtext.utils import get_asset_local_path

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def init_ta_gpt2bpe_encoder():
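    # Build TorchArrow's native GPT-2 BPE encoder from the encoder/merges assets published for torchtext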
encoder_json_path = "https://download.pytorch.org/models/text/gpt2_bpe_encoder.json"
vocab_bpe_path = "https://download.pytorch.org/models/text/gpt2_bpe_vocab.bpe"

encoder_json_path = get_asset_local_path(encoder_json_path)
vocab_bpe_path = get_asset_local_path(vocab_bpe_path)
_seperator = "\u0001"

    # load the BPE encoder json (token-to-id mapping)
with open(encoder_json_path, "r", encoding="utf-8") as f:
bpe_encoder = json.load(f)
# load bpe vocab
with open(vocab_bpe_path, "r", encoding="utf-8") as f:
bpe_vocab = f.read()
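    # Rank each BPE merge pair by its line position; [1:-1] skips the header line and the trailing empty line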
bpe_merge_ranks = {
_seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_vocab.split("\n")[1:-1])
}
# Caching is enabled in Eager mode
bpe = _ta.GPT2BPEEncoder(bpe_encoder, bpe_merge_ranks, _seperator, T.bytes_to_unicode(), True)
return bpe


def init_ta_gpt2bpe_vocab():
vocab_path = "https://download.pytorch.org/models/text/roberta.vocab.pt"
vocab_path = get_asset_local_path(vocab_path)
vocab = torch.load(vocab_path)
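    # Wrap the torchtext vocab's token list and default (unknown) index in a native TorchArrow vocab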
ta_vocab = _ta.Vocab(vocab.get_itos(), vocab.get_default_index())
return ta_vocab


class RobertaTransformDataFrameNativeOps(Module):
def __init__(self) -> None:
super().__init__()
# Tokenizer to split input text into tokens
self.tokenizer = init_ta_gpt2bpe_encoder()

# vocabulary converting tokens to IDs
self.vocab = init_ta_gpt2bpe_vocab()

        # Note: BOS (id 0) and EOS (id 2) tokens are added inside forward() via the
        # native TorchArrow op `add_tokens`, so no torchtext AddToken transforms are needed here

def forward(self, input: ta.DataFrame) -> ta.DataFrame:
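        # Tokenize the raw text column into BPE tokens with TorchArrow's native op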
input["tokens"] = ta_F.bpe_tokenize(self.tokenizer, input["text"])
input["tokens"] = input["tokens"].list.slice(stop=254)
input["tokens"] = ta_F.lookup_indices(self.vocab, input["tokens"])
input["tokens"] = ta_F.add_tokens(input["tokens"], [0], begin=True)
input["tokens"] = ta_F.add_tokens(input["tokens"], [2], begin=False)
return input


def get_dataloader(split, args):
# Instantiate transform
transform = RobertaTransformDataFrameNativeOps()

    # Create the SST2 datapipe
    dp = SST2(split=split)

    # Batch rows into DataFrames of `args.batch_size` rows each
    dp = dp.dataframe(columns=["text", "labels"], dataframe_size=args.batch_size)

    # Apply the pre-processing transform on each DataFrame
    dp = dp.map(transform)

    # (optional) Remove columns that are no longer needed
    dp = dp.map(lambda x: x.drop(["text"]))

    # Convert each DataFrame to tensors (this yields a named tuple with `tokens` and `labels` fields)
    dp = dp.map(lambda x: x.to_tensor({"tokens": tap.PadSequence(padding_value=1)}))

    # Create the DataLoader; batching already happened at the DataFrame level, hence batch_size=None
    dl = DataLoader(dp, batch_size=None)

return dl


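# RoBERTa base encoder with a two-class classification head for SST-2 binary classification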
classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
model = ROBERTA_BASE_ENCODER.get_model(head=classifier_head)
model.to(DEVICE)


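# A single optimization step: forward pass, loss computation, backward pass, parameter update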
def train_step(input, target, optim, criteria):
output = model(input)
loss = criteria(output, target)
optim.zero_grad()
loss.backward()
optim.step()


def eval_step(input, target, criteria):
output = model(input)
loss = criteria(output, target).item()
return float(loss), (output.argmax(1) == target).type(torch.float).sum().item()


def evaluate(dataloader, criteria):
model.eval()
total_loss = 0
correct_predictions = 0
total_predictions = 0
counter = 0
with torch.no_grad():
for batch in dataloader:
            input = batch.tokens.to(DEVICE)
            target = batch.labels.to(DEVICE)
            loss, predictions = eval_step(input, target, criteria)
total_loss += loss
correct_predictions += predictions
total_predictions += len(target)
counter += 1

return total_loss / counter, correct_predictions / total_predictions


def main(args):
print(args)
train_dl = get_dataloader(split="train", args=args)
dev_dl = get_dataloader(split="dev", args=args)

learning_rate = args.learning_rate
optim = AdamW(model.parameters(), lr=learning_rate)
criteria = nn.CrossEntropyLoss()

    for e in range(args.num_epochs):
        # evaluate() puts the model in eval mode, so switch back to train mode each epoch
        model.train()
        for batch in train_dl:
            input = batch.tokens.to(DEVICE)
            target = batch.labels.to(DEVICE)
            train_step(input, target, optim, criteria)

loss, accuracy = evaluate(dev_dl, criteria)
print("Epoch = [{}], loss = [{}], accuracy = [{}]".format(e, loss, accuracy))


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--batch-size", default=16, type=int, help="Input batch size used during training")
parser.add_argument("--num-epochs", default=1, type=int, help="Number of epochs to run training")
parser.add_argument("--learning-rate", default=1e-5, type=float, help="Learning rate used for training")
main(parser.parse_args())
