Original Paper: https://arxiv.org/pdf/1703.07469.pdf
Model Checkpoint: https://huggingface.co/eddyyeo/robustfill
The RobustFill network by Devlin et al. is trained for the following task: given a few example input-output string pairs, generate a program in a domain-specific language that transforms the given inputs into the given outputs. This program can then be used to transform unseen inputs. For example:
Given these pairs:
Input | Output |
---|---|
Jacob Devlin | Devlin, J. |
Eddy Yeo | Yeo, E. |
Andrej Karpathy | Karpathy, A. |
Anatoly Yakovenko | Yakovenko, A. |
The RobustFill network will generate a program that can be used to transform an unbounded number of unseen inputs:
Unseen input | Transformed Output |
---|---|
Elon Musk | Musk, E. |
Joe Rogan | Rogan, J. |
Balaji Srinivasan | Srinivasan, B. |
The program generated by our trained network for the example above is as follows:
Concat(
    Compose(
        Trim(),
        GetFrom(<Type.LOWER: 6>)
    ),
    ConstStr(','),
    ConstStr(' '),
    GetUpto(<Type.CHAR: 8>),
    ConstStr('.')
)
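To make the program above concrete, here is a minimal Python sketch of how such a program could be evaluated on a string. It is not the repository's interpreter. The helper names (`get_from`, `get_upto`, and so on) are hypothetical, the regexes for `LOWER` and `CHAR` follow the token classes in the RobustFill paper, and the semantics of `GetFrom` and `GetUpto` (everything after, or up to, the end of the first match) are an assumption chosen because they reproduce the tables above.

```python
import re

# Assumed token classes, following the RobustFill paper's regex types.
LOWER = r"[a-z]+"       # a run of lowercase letters
CHAR = r"[A-Za-z0-9]"   # a single alphanumeric character


def get_from(pattern):
    """Assumed semantics: return everything after the end of the first match."""
    def run(s):
        m = re.search(pattern, s)
        return s[m.end():] if m else ""
    return run


def get_upto(pattern):
    """Assumed semantics: return everything up to the end of the first match."""
    def run(s):
        m = re.search(pattern, s)
        return s[:m.end()] if m else ""
    return run


def trim():
    """Strip leading and trailing whitespace."""
    return lambda s: s.strip()


def const_str(c):
    """Ignore the input and emit a constant string."""
    return lambda s: c


def compose(outer, inner):
    """Apply `inner` to the input, then `outer` to its result."""
    return lambda s: outer(inner(s))


def concat(*parts):
    """Run every sub-expression on the same input and join the results."""
    return lambda s: "".join(p(s) for p in parts)


# The generated program from above, rebuilt with the hypothetical helpers.
program = concat(
    compose(trim(), get_from(LOWER)),
    const_str(","),
    const_str(" "),
    get_upto(CHAR),
    const_str("."),
)

for name in ["Jacob Devlin", "Elon Musk", "Joe Rogan"]:
    print(name, "->", program(name))
```

Under these assumptions, `GetFrom(LOWER)` skips past the lowercase tail of the first name, `Trim()` removes the leading space before the last name, and `GetUpto(CHAR)` keeps only the first initial.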
See the demo notebook to reproduce this result with the model checkpoint.
The network was trained on Google Cloud with 4 x NVIDIA Tesla P4 GPUs using PyTorch's Distributed Data Parallel.
Set up the environment:
python3 -m venv env
source env/bin/activate
pip install -r requirements.txt
Train the neural network. The script automatically uses GPU(s) if they are available.
python train.py --mode full
For testing purposes, run a smaller network (on CPU) with a smaller problem size to check that the loss goes to 0.
python train.py --mode easy
Run the profiler:
python train.py --mode profile
Run the unit tests:
python -m unittest
Lint:
flake8