From 59a4b99f6ec5c20fef9bcc46e9a47f50a4b8ec44 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia <parmeetbhatia@fb.com> Date: Wed, 22 Jun 2022 17:29:42 -0400 Subject: [PATCH 01/10] Add initial code for TA based training --- .../roberta_sst2_training_with_torcharrow.py | 163 ++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 examples/torcharrow/roberta_sst2_training_with_torcharrow.py diff --git a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py new file mode 100644 index 0000000000..9de4de2cdb --- /dev/null +++ b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py @@ -0,0 +1,163 @@ +import json +from argparse import ArgumentParser + +import torch +import torch.nn as nn +import torcharrow as ta +import torcharrow._torcharrow as _ta +import torcharrow.pytorch as tap +import torchtext.functional as F +import torchtext.transforms as T +from torch.nn import Module +from torch.optim import AdamW +from torch.utils.data import DataLoader +from torcharrow import functional as ta_F +from torchtext.datasets import SST2 +from torchtext.models import RobertaClassificationHead, ROBERTA_BASE_ENCODER +from torchtext.utils import get_asset_local_path + +DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + +def init_ta_gpt2bpe_encoder(): + encoder_json_path = "https://download.pytorch.org/models/text/gpt2_bpe_encoder.json" + vocab_bpe_path = "https://download.pytorch.org/models/text/gpt2_bpe_vocab.bpe" + + encoder_json_path = get_asset_local_path(encoder_json_path) + vocab_bpe_path = get_asset_local_path(vocab_bpe_path) + _seperator = "\u0001" + + # load bpe encoder and bpe decoder + with open(encoder_json_path, "r", encoding="utf-8") as f: + bpe_encoder = json.load(f) + # load bpe vocab + with open(vocab_bpe_path, "r", encoding="utf-8") as f: + bpe_vocab = f.read() + bpe_merge_ranks = { + _seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_vocab.split("\n")[1:-1]) + } + # Caching is enabled in Eager mode + bpe = _ta.GPT2BPEEncoder(bpe_encoder, bpe_merge_ranks, _seperator, T.bytes_to_unicode(), True) + return bpe + + +def init_ta_gpt2bpe_vocab(): + vocab_path = "https://download.pytorch.org/models/text/roberta.vocab.pt" + vocab_path = get_asset_local_path(vocab_path) + vocab = torch.load(vocab_path) + ta_vocab = _ta.Vocab(vocab.get_itos(), vocab.get_default_index()) + return ta_vocab + + +class RobertaTransformDataFrameNativeOps(Module): + def __init__(self) -> None: + super().__init__() + # Tokenizer to split input text into tokens + self.tokenizer = init_ta_gpt2bpe_encoder() + + # vocabulary converting tokens to IDs + self.vocab = init_ta_gpt2bpe_vocab() + + # Add BOS token to the beginning of sentence + self.add_bos = T.AddToken(token=0, begin=True) + + # Add EOS token to the end of sentence + self.add_eos = T.AddToken(token=2, begin=False) + + def forward(self, input: ta.DataFrame) -> ta.DataFrame: + input["tokens"] = ta_F.bpe_tokenize(self.tokenizer, input["text"]) + input["tokens"] = input["tokens"].list.slice(stop=254) + input["tokens"] = ta_F.lookup_indices(self.vocab, input["tokens"]) + input["tokens"] = input["tokens"].transform(self.add_bos, format="python") + input["tokens"] = input["tokens"].transform(self.add_eos, format="python") + return input + + +def get_dataloader(split, args): + # Instantiate transform + transform = RobertaTransformDataFrameNativeOps() + + # Create SST2 datapipe and apply pre-processing + train_dp = SST2(split=split) + + # convert to DataFrame of size batches + train_dp = train_dp.dataframe(columns=["text", "labels"], dataframe_size=args.batch_size) + + # Apply transformation on DataFrame + train_dp = train_dp.map(transform) + + # (optional) Remove un-required columns + train_dp = train_dp.map(lambda x: x.drop(["text"])) + + # convert DataFrame to tensor (This will yeild named tuple) + train_dp = train_dp.map(lambda x: x.to_tensor({"tokens": tap.PadSequence(padding_value=1)})) + + # create DataLoader + dl = DataLoader(train_dp, batch_size=None) + + return dl + + +classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768) +model = ROBERTA_BASE_ENCODER.get_model(head=classifier_head) +model.to(DEVICE) + +learning_rate = 1e-5 +optim = AdamW(model.parameters(), lr=learning_rate) +criteria = nn.CrossEntropyLoss() + + +def train_step(input, target): + output = model(input) + loss = criteria(output, target) + print(float(loss)) + optim.zero_grad() + loss.backward() + optim.step() + + +def eval_step(input, target): + output = model(input) + loss = criteria(output, target).item() + return float(loss), (output.argmax(1) == target).type(torch.float).sum().item() + + +def evaluate(dataloader): + model.eval() + total_loss = 0 + correct_predictions = 0 + total_predictions = 0 + counter = 0 + with torch.no_grad(): + for batch in dataloader: + input = F.to_tensor(batch["token_ids"], padding_value=1).to(DEVICE) + target = torch.tensor(batch["target"]).to(DEVICE) + loss, predictions = eval_step(input, target) + total_loss += loss + correct_predictions += predictions + total_predictions += len(target) + counter += 1 + + return total_loss / counter, correct_predictions / total_predictions + + +def main(args): + + train_dl = get_dataloader(split="train", args=args) + dev_dl = get_dataloader(split="dev", args=args) + + for e in range(args.num_epochs): + for batch in train_dl: + input = batch.tokens.to(DEVICE) + target = batch.labels.to(DEVICE) + train_step(input, target) + + loss, accuracy = evaluate(dev_dl) + print("Epoch = [{}], loss = [{}], accuracy = [{}]".format(e, loss, accuracy)) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--batch-size", default=16, type=int) + parser.add_argument("--num-epochs", default=1, type=int) + main(parser.parse_args()) From 5e41b42e26f10935ddec95ac93d9ede4ed644be7 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia <parmeetbhatia@fb.com> Date: Wed, 20 Jul 2022 09:34:50 -0400 Subject: [PATCH 02/10] use native ops for ading tokens --- examples/torcharrow/roberta_sst2_training_with_torcharrow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py index 9de4de2cdb..d4e767ef44 100644 --- a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py +++ b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py @@ -68,8 +68,8 @@ def forward(self, input: ta.DataFrame) -> ta.DataFrame: input["tokens"] = ta_F.bpe_tokenize(self.tokenizer, input["text"]) input["tokens"] = input["tokens"].list.slice(stop=254) input["tokens"] = ta_F.lookup_indices(self.vocab, input["tokens"]) - input["tokens"] = input["tokens"].transform(self.add_bos, format="python") - input["tokens"] = input["tokens"].transform(self.add_eos, format="python") + input["tokens"] = ta_F.add_tokens(input["tokens"], [0], begin=True) + input["tokens"] = ta_F.add_tokens(input["tokens"], [2], begin=False) return input From b750378cc2028b0183507d0876bdba5c96a401c5 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia <parmeetbhatia@fb.com> Date: Wed, 20 Jul 2022 10:11:20 -0400 Subject: [PATCH 03/10] remove print --- examples/torcharrow/roberta_sst2_training_with_torcharrow.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py index d4e767ef44..d48b9853dd 100644 --- a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py +++ b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py @@ -110,7 +110,6 @@ def get_dataloader(split, args): def train_step(input, target): output = model(input) loss = criteria(output, target) - print(float(loss)) optim.zero_grad() loss.backward() optim.step() From 2abfe00288c09920696df05c4ddc2c5c593df0e1 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia <parmeetbhatia@fb.com> Date: Thu, 21 Jul 2022 09:55:44 -0400 Subject: [PATCH 04/10] minor changes in code --- .../roberta_sst2_training_with_torcharrow.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py index d48b9853dd..a31b83c32b 100644 --- a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py +++ b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py @@ -102,12 +102,8 @@ def get_dataloader(split, args): model = ROBERTA_BASE_ENCODER.get_model(head=classifier_head) model.to(DEVICE) -learning_rate = 1e-5 -optim = AdamW(model.parameters(), lr=learning_rate) -criteria = nn.CrossEntropyLoss() - -def train_step(input, target): +def train_step(input, target, optim, criteria): output = model(input) loss = criteria(output, target) optim.zero_grad() @@ -115,7 +111,7 @@ def train_step(input, target): optim.step() -def eval_step(input, target): +def eval_step(input, target, criteria): output = model(input) loss = criteria(output, target).item() return float(loss), (output.argmax(1) == target).type(torch.float).sum().item() @@ -141,22 +137,27 @@ def evaluate(dataloader): def main(args): - + print(args) train_dl = get_dataloader(split="train", args=args) dev_dl = get_dataloader(split="dev", args=args) + learning_rate = args.learning_rate + optim = AdamW(model.parameters(), lr=learning_rate) + criteria = nn.CrossEntropyLoss() + for e in range(args.num_epochs): for batch in train_dl: input = batch.tokens.to(DEVICE) target = batch.labels.to(DEVICE) - train_step(input, target) + train_step(input, target, optim, criteria) - loss, accuracy = evaluate(dev_dl) + loss, accuracy = evaluate(dev_dl, criteria) print("Epoch = [{}], loss = [{}], accuracy = [{}]".format(e, loss, accuracy)) if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("--batch-size", default=16, type=int) - parser.add_argument("--num-epochs", default=1, type=int) + parser.add_argument("--batch-size", default=16, type=int, help="Input batch size used during training") + parser.add_argument("--num-epochs", default=1, type=int, help="Number of epochs to run training") + parser.add_argument("--learning-rate", default=1e-5, type=float, help="Learning rate used for training") main(parser.parse_args()) From 2dfa7938bc67fddf3a6aabae72f574db3b6c1988 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia <parmeetbhatia@fb.com> Date: Thu, 21 Jul 2022 10:27:35 -0400 Subject: [PATCH 05/10] add readme --- examples/torcharrow/README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 examples/torcharrow/README.md diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md new file mode 100644 index 0000000000..7496f6db3d --- /dev/null +++ b/examples/torcharrow/README.md @@ -0,0 +1,16 @@ +i## Description + +This example shows end-2-end training for SST-2 binary classification using the RoBERTa model and TorchArrow based text pre-processing. The main motivation for this example is to demonstrate the authoring of a text processing pipeline on top of TorchArrow DataFrame. + +## Installation and Usage + +The example depends on TorchArrow. Install it from source following instructions at https://github.com/pytorch/torcharrow#from-source. Note that some of the natively integrated text operators (`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used in this example depend on the torch library. By default, TorchArrow doesn’t take dependency on the torch library. Hence make sure to use flag `USE_TORCH=1` during TorchArrow installation (this is also the reason why we cannot depend on nightly releases) + +To run example from command line run following command: + +```bash +python roberta_sst2_training_with_torcharrow.py \ + --batch-size 16 \ + --num-epochs 1 \ + --learning-rate 1e-5 +``` From d31dfc86be42c41b9f1b60780429774fc651cf4d Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia <parmeetbhatia@fb.com> Date: Thu, 21 Jul 2022 10:31:35 -0400 Subject: [PATCH 06/10] fix lint --- examples/torcharrow/README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md index 7496f6db3d..dd4f749deb 100644 --- a/examples/torcharrow/README.md +++ b/examples/torcharrow/README.md @@ -1,10 +1,16 @@ i## Description -This example shows end-2-end training for SST-2 binary classification using the RoBERTa model and TorchArrow based text pre-processing. The main motivation for this example is to demonstrate the authoring of a text processing pipeline on top of TorchArrow DataFrame. +This example shows end-2-end training for SST-2 binary classification using the RoBERTa model and TorchArrow based text +pre-processing. The main motivation for this example is to demonstrate the authoring of a text processing pipeline on +top of TorchArrow DataFrame. ## Installation and Usage -The example depends on TorchArrow. Install it from source following instructions at https://github.com/pytorch/torcharrow#from-source. Note that some of the natively integrated text operators (`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used in this example depend on the torch library. By default, TorchArrow doesn’t take dependency on the torch library. Hence make sure to use flag `USE_TORCH=1` during TorchArrow installation (this is also the reason why we cannot depend on nightly releases) +The example depends on TorchArrow. Install it from source following instructions at +https://github.com/pytorch/torcharrow#from-source. Note that some of the natively integrated text operators +(`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used in this example depend on the torch +library. By default, TorchArrow doesn’t take dependency on the torch library. Hence make sure to use flag `USE_TORCH=1` +during TorchArrow installation (this is also the reason why we cannot depend on nightly releases) To run example from command line run following command: From ebb9f46a38473c5306f385c98a9f6b49d8120c0b Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia <parmeetbhatia@fb.com> Date: Thu, 21 Jul 2022 10:37:31 -0400 Subject: [PATCH 07/10] edit readme --- examples/torcharrow/README.md | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md index dd4f749deb..417bfdfbdf 100644 --- a/examples/torcharrow/README.md +++ b/examples/torcharrow/README.md @@ -1,4 +1,4 @@ -i## Description +## Description This example shows end-2-end training for SST-2 binary classification using the RoBERTa model and TorchArrow based text pre-processing. The main motivation for this example is to demonstrate the authoring of a text processing pipeline on @@ -6,11 +6,25 @@ top of TorchArrow DataFrame. ## Installation and Usage -The example depends on TorchArrow. Install it from source following instructions at -https://github.com/pytorch/torcharrow#from-source. Note that some of the natively integrated text operators -(`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used in this example depend on the torch -library. By default, TorchArrow doesn’t take dependency on the torch library. Hence make sure to use flag `USE_TORCH=1` -during TorchArrow installation (this is also the reason why we cannot depend on nightly releases) +The example depends on TorchArrow and TorchData. + +### TorchArrow Installation + +Install it from source following instructions at https://github.com/pytorch/torcharrow#from-source. Note that some of +the natively integrated text operators (`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used +in this example depend on the torch library. By default, TorchArrow doesn’t take dependency on the torch library. Hence +make sure to use flag `USE_TORCH=1` during TorchArrow installation (this is also the reason why we cannot depend on +nightly releases) + +``` +USE_TORCH=1 python setup.py install +``` + +### TorchData Installation + +To install TorchData follow instructions athttps://github.com/pytorch/data#installation + +### Usage To run example from command line run following command: From a42a76198c27778805507d40c6e4a1d845632ff2 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia <parmeetbhatia@fb.com> Date: Thu, 21 Jul 2022 10:39:06 -0400 Subject: [PATCH 08/10] minor edits --- examples/torcharrow/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md index 417bfdfbdf..54a619ff59 100644 --- a/examples/torcharrow/README.md +++ b/examples/torcharrow/README.md @@ -8,7 +8,7 @@ top of TorchArrow DataFrame. The example depends on TorchArrow and TorchData. -### TorchArrow Installation +#### TorchArrow Installation Install it from source following instructions at https://github.com/pytorch/torcharrow#from-source. Note that some of the natively integrated text operators (`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used @@ -20,11 +20,11 @@ nightly releases) USE_TORCH=1 python setup.py install ``` -### TorchData Installation +#### TorchData Installation To install TorchData follow instructions athttps://github.com/pytorch/data#installation -### Usage +#### Usage To run example from command line run following command: From 3b420a2e3381c1727ae20bf5f49f6b3896c992f1 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia <parmeetbhatia@fb.com> Date: Thu, 21 Jul 2022 10:40:28 -0400 Subject: [PATCH 09/10] minor edit --- examples/torcharrow/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md index 54a619ff59..97299dc683 100644 --- a/examples/torcharrow/README.md +++ b/examples/torcharrow/README.md @@ -8,7 +8,7 @@ top of TorchArrow DataFrame. The example depends on TorchArrow and TorchData. -#### TorchArrow Installation +#### TorchArrow Install it from source following instructions at https://github.com/pytorch/torcharrow#from-source. Note that some of the natively integrated text operators (`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used @@ -20,7 +20,7 @@ nightly releases) USE_TORCH=1 python setup.py install ``` -#### TorchData Installation +#### TorchData To install TorchData follow instructions athttps://github.com/pytorch/data#installation From 998b7b9d51cbd06bbb580aa11f32cc49f774bddd Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia <parmeetbhatia@fb.com> Date: Thu, 21 Jul 2022 14:29:45 -0400 Subject: [PATCH 10/10] minor fix --- examples/torcharrow/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md index 97299dc683..a4e95765d5 100644 --- a/examples/torcharrow/README.md +++ b/examples/torcharrow/README.md @@ -22,7 +22,7 @@ USE_TORCH=1 python setup.py install #### TorchData -To install TorchData follow instructions athttps://github.com/pytorch/data#installation +To install TorchData follow instructions at https://github.com/pytorch/data#installation #### Usage