From 59a4b99f6ec5c20fef9bcc46e9a47f50a4b8ec44 Mon Sep 17 00:00:00 2001
From: Parmeet  Singh Bhatia <parmeetbhatia@fb.com>
Date: Wed, 22 Jun 2022 17:29:42 -0400
Subject: [PATCH 01/10] Add initial code for TorchArrow-based training
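
Adds an end-to-end SST-2 fine-tuning example in which the text
pre-processing (GPT-2 BPE tokenization, vocabulary lookup, truncation,
BOS/EOS insertion) is authored against TorchArrow DataFrames and only
materialized as padded tensors at the end of the pipeline.

As a rough sketch of the intended data flow (the column names and the
`PadSequence` conversion mirror the example code; the two toy rows are
purely illustrative, and the transform class is the one defined in this
file):

```python
import torcharrow as ta
import torcharrow.pytorch as tap

# A toy two-row batch; the real pipeline builds DataFrames of
# `batch_size` rows from the SST2 datapipe.
df = ta.dataframe({"text": ["a great movie", "a dull plot"], "labels": [1, 0]})

# Tokenize, map tokens to ids and add the BOS/EOS ids, all as DataFrame columns.
df = RobertaTransformDataFrameNativeOps()(df)

# Drop the raw text and materialize padded tensors; this yields a named
# tuple with `tokens` and `labels` fields, as consumed by the training loop.
batch = df.drop(["text"]).to_tensor({"tokens": tap.PadSequence(padding_value=1)})
```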

---
 .../roberta_sst2_training_with_torcharrow.py  | 163 ++++++++++++++++++
 1 file changed, 163 insertions(+)
 create mode 100644 examples/torcharrow/roberta_sst2_training_with_torcharrow.py

diff --git a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py
new file mode 100644
index 0000000000..9de4de2cdb
--- /dev/null
+++ b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py
@@ -0,0 +1,163 @@
+import json
+from argparse import ArgumentParser
+
+import torch
+import torch.nn as nn
+import torcharrow as ta
+import torcharrow._torcharrow as _ta
+import torcharrow.pytorch as tap
+import torchtext.functional as F
+import torchtext.transforms as T
+from torch.nn import Module
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+from torcharrow import functional as ta_F
+from torchtext.datasets import SST2
+from torchtext.models import RobertaClassificationHead, ROBERTA_BASE_ENCODER
+from torchtext.utils import get_asset_local_path
+
+DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+
+def init_ta_gpt2bpe_encoder():
+    encoder_json_path = "https://download.pytorch.org/models/text/gpt2_bpe_encoder.json"
+    vocab_bpe_path = "https://download.pytorch.org/models/text/gpt2_bpe_vocab.bpe"
+
+    encoder_json_path = get_asset_local_path(encoder_json_path)
+    vocab_bpe_path = get_asset_local_path(vocab_bpe_path)
+    _seperator = "\u0001"
+
+    # load bpe encoder and bpe decoder
+    with open(encoder_json_path, "r", encoding="utf-8") as f:
+        bpe_encoder = json.load(f)
+    # load bpe vocab
+    with open(vocab_bpe_path, "r", encoding="utf-8") as f:
+        bpe_vocab = f.read()
+    bpe_merge_ranks = {
+        _seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_vocab.split("\n")[1:-1])
+    }
+    # Caching is enabled in Eager mode
+    bpe = _ta.GPT2BPEEncoder(bpe_encoder, bpe_merge_ranks, _seperator, T.bytes_to_unicode(), True)
+    return bpe
+
+
+def init_ta_gpt2bpe_vocab():
+    vocab_path = "https://download.pytorch.org/models/text/roberta.vocab.pt"
+    vocab_path = get_asset_local_path(vocab_path)
+    vocab = torch.load(vocab_path)
+    ta_vocab = _ta.Vocab(vocab.get_itos(), vocab.get_default_index())
+    return ta_vocab
+
+
+class RobertaTransformDataFrameNativeOps(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        # Tokenizer to split input text into tokens
+        self.tokenizer = init_ta_gpt2bpe_encoder()
+
+        # vocabulary converting tokens to IDs
+        self.vocab = init_ta_gpt2bpe_vocab()
+
+        # Add BOS token to the beginning of sentence
+        self.add_bos = T.AddToken(token=0, begin=True)
+
+        # Add EOS token to the end of sentence
+        self.add_eos = T.AddToken(token=2, begin=False)
+
+    def forward(self, input: ta.DataFrame) -> ta.DataFrame:
+        input["tokens"] = ta_F.bpe_tokenize(self.tokenizer, input["text"])
+        input["tokens"] = input["tokens"].list.slice(stop=254)
+        input["tokens"] = ta_F.lookup_indices(self.vocab, input["tokens"])
+        input["tokens"] = input["tokens"].transform(self.add_bos, format="python")
+        input["tokens"] = input["tokens"].transform(self.add_eos, format="python")
+        return input
+
+
+def get_dataloader(split, args):
+    # Instantiate transform
+    transform = RobertaTransformDataFrameNativeOps()
+
+    # Create SST2 datapipe and apply pre-processing
+    train_dp = SST2(split=split)
+
+    # Batch the datapipe into DataFrames of `batch_size` rows each
+    train_dp = train_dp.dataframe(columns=["text", "labels"], dataframe_size=args.batch_size)
+
+    # Apply transformation on DataFrame
+    train_dp = train_dp.map(transform)
+
+    # (optional) Remove columns that are no longer required
+    train_dp = train_dp.map(lambda x: x.drop(["text"]))
+
+    # Convert the DataFrame to tensors (this will yield a named tuple per batch)
+    train_dp = train_dp.map(lambda x: x.to_tensor({"tokens": tap.PadSequence(padding_value=1)}))
+
+    # create DataLoader
+    dl = DataLoader(train_dp, batch_size=None)
+
+    return dl
+
+
+classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
+model = ROBERTA_BASE_ENCODER.get_model(head=classifier_head)
+model.to(DEVICE)
+
+learning_rate = 1e-5
+optim = AdamW(model.parameters(), lr=learning_rate)
+criteria = nn.CrossEntropyLoss()
+
+
+def train_step(input, target):
+    output = model(input)
+    loss = criteria(output, target)
+    print(float(loss))
+    optim.zero_grad()
+    loss.backward()
+    optim.step()
+
+
+def eval_step(input, target):
+    output = model(input)
+    loss = criteria(output, target).item()
+    return float(loss), (output.argmax(1) == target).type(torch.float).sum().item()
+
+
+def evaluate(dataloader):
+    model.eval()
+    total_loss = 0
+    correct_predictions = 0
+    total_predictions = 0
+    counter = 0
+    with torch.no_grad():
+        for batch in dataloader:
+            input = batch.tokens.to(DEVICE)
+            target = batch.labels.to(DEVICE)
+            loss, predictions = eval_step(input, target)
+            total_loss += loss
+            correct_predictions += predictions
+            total_predictions += len(target)
+            counter += 1
+
+    return total_loss / counter, correct_predictions / total_predictions
+
+
+def main(args):
+
+    train_dl = get_dataloader(split="train", args=args)
+    dev_dl = get_dataloader(split="dev", args=args)
+
+    for e in range(args.num_epochs):
+        for batch in train_dl:
+            input = batch.tokens.to(DEVICE)
+            target = batch.labels.to(DEVICE)
+            train_step(input, target)
+
+    loss, accuracy = evaluate(dev_dl)
+    print("Epoch = [{}], loss = [{}], accuracy = [{}]".format(e, loss, accuracy))
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--batch-size", default=16, type=int)
+    parser.add_argument("--num-epochs", default=1, type=int)
+    main(parser.parse_args())

From 5e41b42e26f10935ddec95ac93d9ede4ed644be7 Mon Sep 17 00:00:00 2001
From: Parmeet  Singh Bhatia <parmeetbhatia@fb.com>
Date: Wed, 20 Jul 2022 09:34:50 -0400
Subject: [PATCH 02/10] Use native ops for adding tokens
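
Instead of appending the BOS/EOS ids by applying the `T.AddToken`
transforms row by row through `DataFrame.transform(..., format="python")`,
the transform now calls the `add_tokens` op from `torcharrow.functional`,
keeping this step on native TorchArrow operators like the rest of the
pipeline. Roughly (reduced to a single column for illustration):

```python
# before: per-row Python transform
tokens = tokens.transform(T.AddToken(token=0, begin=True), format="python")
# after: native TorchArrow op
tokens = ta_F.add_tokens(tokens, [0], begin=True)
```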

---
 examples/torcharrow/roberta_sst2_training_with_torcharrow.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py
index 9de4de2cdb..d4e767ef44 100644
--- a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py
+++ b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py
@@ -68,8 +68,8 @@ def forward(self, input: ta.DataFrame) -> ta.DataFrame:
         input["tokens"] = ta_F.bpe_tokenize(self.tokenizer, input["text"])
         input["tokens"] = input["tokens"].list.slice(stop=254)
         input["tokens"] = ta_F.lookup_indices(self.vocab, input["tokens"])
-        input["tokens"] = input["tokens"].transform(self.add_bos, format="python")
-        input["tokens"] = input["tokens"].transform(self.add_eos, format="python")
+        input["tokens"] = ta_F.add_tokens(input["tokens"], [0], begin=True)
+        input["tokens"] = ta_F.add_tokens(input["tokens"], [2], begin=False)
         return input
 
 

From b750378cc2028b0183507d0876bdba5c96a401c5 Mon Sep 17 00:00:00 2001
From: Parmeet  Singh Bhatia <parmeetbhatia@fb.com>
Date: Wed, 20 Jul 2022 10:11:20 -0400
Subject: [PATCH 03/10] remove print

---
 examples/torcharrow/roberta_sst2_training_with_torcharrow.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py
index d4e767ef44..d48b9853dd 100644
--- a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py
+++ b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py
@@ -110,7 +110,6 @@ def get_dataloader(split, args):
 def train_step(input, target):
     output = model(input)
     loss = criteria(output, target)
-    print(float(loss))
     optim.zero_grad()
     loss.backward()
     optim.step()

From 2abfe00288c09920696df05c4ddc2c5c593df0e1 Mon Sep 17 00:00:00 2001
From: Parmeet  Singh Bhatia <parmeetbhatia@fb.com>
Date: Thu, 21 Jul 2022 09:55:44 -0400
Subject: [PATCH 04/10] minor changes in code
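
Moves the optimizer and loss-criterion construction out of module scope
into `main`, threads them explicitly through `train_step`, `eval_step` and
`evaluate`, and exposes the learning rate as a command-line argument next
to the existing batch-size and epoch flags, e.g. (values shown are the
defaults):

```bash
python roberta_sst2_training_with_torcharrow.py \
        --batch-size 16 \
        --num-epochs 1 \
        --learning-rate 1e-5
```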

---
 .../roberta_sst2_training_with_torcharrow.py  | 27 ++++++++++++++-------------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py
index d48b9853dd..a31b83c32b 100644
--- a/examples/torcharrow/roberta_sst2_training_with_torcharrow.py
+++ b/examples/torcharrow/roberta_sst2_training_with_torcharrow.py
@@ -102,12 +102,8 @@ def get_dataloader(split, args):
 model = ROBERTA_BASE_ENCODER.get_model(head=classifier_head)
 model.to(DEVICE)
 
-learning_rate = 1e-5
-optim = AdamW(model.parameters(), lr=learning_rate)
-criteria = nn.CrossEntropyLoss()
 
-
-def train_step(input, target):
+def train_step(input, target, optim, criteria):
     output = model(input)
     loss = criteria(output, target)
     optim.zero_grad()
@@ -115,23 +111,23 @@ def train_step(input, target):
     optim.step()
 
 
-def eval_step(input, target):
+def eval_step(input, target, criteria):
     output = model(input)
     loss = criteria(output, target).item()
     return float(loss), (output.argmax(1) == target).type(torch.float).sum().item()
 
 
-def evaluate(dataloader):
+def evaluate(dataloader, criteria):
     model.eval()
     total_loss = 0
     correct_predictions = 0
     total_predictions = 0
     counter = 0
     with torch.no_grad():
         for batch in dataloader:
             input = batch.tokens.to(DEVICE)
             target = batch.labels.to(DEVICE)
-            loss, predictions = eval_step(input, target)
+            loss, predictions = eval_step(input, target, criteria)
             total_loss += loss
             correct_predictions += predictions
             total_predictions += len(target)
@@ -141,22 +137,27 @@ def evaluate(dataloader):
 
 
 def main(args):
-
+    print(args)
     train_dl = get_dataloader(split="train", args=args)
     dev_dl = get_dataloader(split="dev", args=args)
 
+    learning_rate = args.learning_rate
+    optim = AdamW(model.parameters(), lr=learning_rate)
+    criteria = nn.CrossEntropyLoss()
+
     for e in range(args.num_epochs):
         for batch in train_dl:
             input = batch.tokens.to(DEVICE)
             target = batch.labels.to(DEVICE)
-            train_step(input, target)
+            train_step(input, target, optim, criteria)
 
-    loss, accuracy = evaluate(dev_dl)
+    loss, accuracy = evaluate(dev_dl, criteria)
     print("Epoch = [{}], loss = [{}], accuracy = [{}]".format(e, loss, accuracy))
 
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("--batch-size", default=16, type=int)
-    parser.add_argument("--num-epochs", default=1, type=int)
+    parser.add_argument("--batch-size", default=16, type=int, help="Input batch size used during training")
+    parser.add_argument("--num-epochs", default=1, type=int, help="Number of epochs to run training")
+    parser.add_argument("--learning-rate", default=1e-5, type=float, help="Learning rate used for training")
     main(parser.parse_args())

From 2dfa7938bc67fddf3a6aabae72f574db3b6c1988 Mon Sep 17 00:00:00 2001
From: Parmeet  Singh Bhatia <parmeetbhatia@fb.com>
Date: Thu, 21 Jul 2022 10:27:35 -0400
Subject: [PATCH 05/10] add readme

---
 examples/torcharrow/README.md | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)
 create mode 100644 examples/torcharrow/README.md

diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md
new file mode 100644
index 0000000000..7496f6db3d
--- /dev/null
+++ b/examples/torcharrow/README.md
@@ -0,0 +1,16 @@
+i## Description
+
+This example shows end-to-end training for SST-2 binary classification using the RoBERTa model and TorchArrow-based text pre-processing. The main motivation for this example is to demonstrate the authoring of a text processing pipeline on top of TorchArrow DataFrame.
+
+## Installation and Usage
+
+The example depends on TorchArrow. Install it from source following instructions at https://github.com/pytorch/torcharrow#from-source.  Note that some of the natively integrated text operators (`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used in this example depend on the torch library. By default, TorchArrow doesn’t take dependency on the torch library. Hence make sure to use flag `USE_TORCH=1` during TorchArrow installation (this is also the reason why we cannot depend on nightly releases)
+
+To run the example from the command line, use the following command:
+
+```bash
+python roberta_sst2_training_with_torcharrow.py \
+        --batch-size 16 \
+        --num-epochs 1 \
+        --learning-rate 1e-5
+```

From d31dfc86be42c41b9f1b60780429774fc651cf4d Mon Sep 17 00:00:00 2001
From: Parmeet  Singh Bhatia <parmeetbhatia@fb.com>
Date: Thu, 21 Jul 2022 10:31:35 -0400
Subject: [PATCH 06/10] fix lint

---
 examples/torcharrow/README.md | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md
index 7496f6db3d..dd4f749deb 100644
--- a/examples/torcharrow/README.md
+++ b/examples/torcharrow/README.md
@@ -1,10 +1,16 @@
 i## Description
 
-This example shows end-to-end training for SST-2 binary classification using the RoBERTa model and TorchArrow-based text pre-processing. The main motivation for this example is to demonstrate the authoring of a text processing pipeline on top of TorchArrow DataFrame.
+This example shows end-to-end training for SST-2 binary classification using the RoBERTa model and TorchArrow-based text
+pre-processing. The main motivation for this example is to demonstrate the authoring of a text processing pipeline on
+top of TorchArrow DataFrame.
 
 ## Installation and Usage
 
-The example depends on TorchArrow. Install it from source following instructions at https://github.com/pytorch/torcharrow#from-source.  Note that some of the natively integrated text operators (`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used in this example depend on the torch library. By default, TorchArrow doesn’t take dependency on the torch library. Hence make sure to use flag `USE_TORCH=1` during TorchArrow installation (this is also the reason why we cannot depend on nightly releases)
+The example depends on TorchArrow. Install it from source following instructions at
+https://github.com/pytorch/torcharrow#from-source. Note that some of the natively integrated text operators
+(`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used in this example depend on the torch
+library. By default, TorchArrow doesn’t take dependency on the torch library. Hence make sure to use flag `USE_TORCH=1`
+during TorchArrow installation (this is also the reason why we cannot depend on nightly releases)
 
 To run the example from the command line, use the following command:
 

From ebb9f46a38473c5306f385c98a9f6b49d8120c0b Mon Sep 17 00:00:00 2001
From: Parmeet  Singh Bhatia <parmeetbhatia@fb.com>
Date: Thu, 21 Jul 2022 10:37:31 -0400
Subject: [PATCH 07/10] edit readme

---
 examples/torcharrow/README.md | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md
index dd4f749deb..417bfdfbdf 100644
--- a/examples/torcharrow/README.md
+++ b/examples/torcharrow/README.md
@@ -1,4 +1,4 @@
-i## Description
+## Description
 
 This example shows end-to-end training for SST-2 binary classification using the RoBERTa model and TorchArrow-based text
 pre-processing. The main motivation for this example is to demonstrate the authoring of a text processing pipeline on
@@ -6,11 +6,25 @@ top of TorchArrow DataFrame.
 
 ## Installation and Usage
 
-The example depends on TorchArrow. Install it from source following instructions at
-https://github.com/pytorch/torcharrow#from-source. Note that some of the natively integrated text operators
-(`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used in this example depend on the torch
-library. By default, TorchArrow doesn’t take dependency on the torch library. Hence make sure to use flag `USE_TORCH=1`
-during TorchArrow installation (this is also the reason why we cannot depend on nightly releases)
+The example depends on TorchArrow and TorchData.
+
+### TorchArrow Installation
+
+Install it from source following instructions at https://github.com/pytorch/torcharrow#from-source. Note that some of
+the natively integrated text operators (`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used
+in this example depend on the torch library. By default, TorchArrow doesn’t take a dependency on the torch library, so
+make sure to use the `USE_TORCH=1` flag during TorchArrow installation (this is also the reason why we cannot depend on
+nightly releases).
+
+```bash
+USE_TORCH=1 python setup.py install
+```
+
+### TorchData Installation
+
+To install TorchData follow instructions athttps://github.com/pytorch/data#installation
+
+### Usage
 
 To run the example from the command line, use the following command:
 

From a42a76198c27778805507d40c6e4a1d845632ff2 Mon Sep 17 00:00:00 2001
From: Parmeet  Singh Bhatia <parmeetbhatia@fb.com>
Date: Thu, 21 Jul 2022 10:39:06 -0400
Subject: [PATCH 08/10] minor edits

---
 examples/torcharrow/README.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md
index 417bfdfbdf..54a619ff59 100644
--- a/examples/torcharrow/README.md
+++ b/examples/torcharrow/README.md
@@ -8,7 +8,7 @@ top of TorchArrow DataFrame.
 
 The example depends on TorchArrow and TorchData.
 
-### TorchArrow Installation
+#### TorchArrow Installation
 
 Install it from source following instructions at https://github.com/pytorch/torcharrow#from-source. Note that some of
 the natively integrated text operators (`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used
@@ -20,11 +20,11 @@ nightly releases)
 USE_TORCH=1 python setup.py install
 ```
 
-### TorchData Installation
+#### TorchData Installation
 
 To install TorchData follow instructions athttps://github.com/pytorch/data#installation
 
-### Usage
+#### Usage
 
 To run the example from the command line, use the following command:
 

From 3b420a2e3381c1727ae20bf5f49f6b3896c992f1 Mon Sep 17 00:00:00 2001
From: Parmeet  Singh Bhatia <parmeetbhatia@fb.com>
Date: Thu, 21 Jul 2022 10:40:28 -0400
Subject: [PATCH 09/10] minor edit

---
 examples/torcharrow/README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md
index 54a619ff59..97299dc683 100644
--- a/examples/torcharrow/README.md
+++ b/examples/torcharrow/README.md
@@ -8,7 +8,7 @@ top of TorchArrow DataFrame.
 
 The example depends on TorchArrow and TorchData.
 
-#### TorchArrow Installation
+#### TorchArrow
 
 Install it from source following instructions at https://github.com/pytorch/torcharrow#from-source. Note that some of
 the natively integrated text operators (`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used
@@ -20,7 +20,7 @@ nightly releases)
 USE_TORCH=1 python setup.py install
 ```
 
-#### TorchData Installation
+#### TorchData
 
 To install TorchData follow instructions athttps://github.com/pytorch/data#installation
 

From 998b7b9d51cbd06bbb580aa11f32cc49f774bddd Mon Sep 17 00:00:00 2001
From: Parmeet  Singh Bhatia <parmeetbhatia@fb.com>
Date: Thu, 21 Jul 2022 14:29:45 -0400
Subject: [PATCH 10/10] minor fix

---
 examples/torcharrow/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md
index 97299dc683..a4e95765d5 100644
--- a/examples/torcharrow/README.md
+++ b/examples/torcharrow/README.md
@@ -22,7 +22,7 @@ USE_TORCH=1 python setup.py install
 
 #### TorchData
 
-To install TorchData follow instructions athttps://github.com/pytorch/data#installation
+To install TorchData, follow the instructions at https://github.com/pytorch/data#installation
 
 #### Usage