A PyTorch implementation of "A simple neural network module for relational reasoning" (https://arxiv.org/abs/1706.01427), working on the CLEVR dataset.
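For orientation, the paper defines a question-conditioned Relation Network as `RN(O) = f_φ(Σ_{i,j} g_θ(o_i, o_j, q))`. A function g is applied to every ordered pair of objects together with the question embedding q, the outputs are summed, and f maps the sum to an answer. The pure-Python sketch below shows only this pairwise aggregation. In the real model g and f are MLPs, so the toy `g` and `f` in the usage lines are hypothetical stand-ins, and the name `relation_network` is ours, not the repository's.

```python
from itertools import product


def relation_network(objects, question, g, f):
    """Compute f(sum of g(o_i, o_j, q) over all ordered object pairs (i, j)).

    objects:  list of object feature vectors (lists of numbers)
    question: question embedding, passed unchanged to g
    g, f:     callables; in the actual model these are MLPs (g_theta, f_phi)
    """
    if not objects:
        raise ValueError("at least one object is required")
    pair_sum = None
    # product(..., repeat=2) yields all n*n ordered pairs, including i == j
    for o_i, o_j in product(objects, repeat=2):
        out = g(o_i, o_j, question)
        # element-wise accumulation of the per-pair relation vectors
        pair_sum = list(out) if pair_sum is None else [a + b for a, b in zip(pair_sum, out)]
    return f(pair_sum)


# Toy usage with hypothetical g and f (not the trained networks):
def toy_g(a, b, q):
    return [a[0] * b[0] + q[0]]


def toy_f(v):
    return 2 * v[0]


print(relation_network([[1.0], [2.0]], [10.0], toy_g, toy_f))  # 98.0
```

Summing over pairs makes the output independent of the order in which objects are listed, which is the property that lets RN work on an unordered set of CNN feature-map cells or state-description objects.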
This code tries to reproduce the results obtained by the DeepMind team for both the From Pixels and the State Descriptions versions described in the paper. Since the paper does not disclose all the network details, our results may differ from the original ones.
The model can also be trained with a slightly modified version of RN, called IR, which enables the extraction of relational features in order to perform Relational Content-Based Image Retrieval (R-CBIR).
We released pretrained models for both the Original and the Image Retrieval architectures; detailed instructions on how to use them are given below.
Accuracy values measured on the test set:
Model | Test accuracy
---|---
From Pixels | 93.6%
State Descriptions | 97.9%
- Download and extract the CLEVR_v1.0 dataset: http://cs.stanford.edu/people/jcjohns/clevr/
- Clone this repository and move into it:
git clone https://github.com/mesnico/RelationNetworks-CLEVR
cd RelationNetworks-CLEVR
- Set up a virtual environment (optional, but recommended):
mkdir env
virtualenv -p /usr/bin/python3 env
source env/bin/activate
- Install requirements:
pip3 install -r requirements.txt
The training code can be run either with Docker or with a standard Python installation with PyTorch. If Docker is used, an image with all the needed dependencies is built, and it can easily be run inside a Docker container.
To train the State Descriptions version, move to the cloned directory and issue the command:
python3 train.py --clevr-dir path/to/CLEVR_v1.0/ --model 'original-sd' | tee logfile.log
We reached an accuracy of around 98% on the test set. With these parameters, training uses an exponentially increasing learning rate (slow-start method). Without this policy, our training stalled at around 70% accuracy. Our training curve measured on the test set:
To train the From Pixels version, move to the cloned directory and issue the command:
python3 train.py --clevr-dir path/to/CLEVR_v1.0/ --model 'original-fp' | tee logfile.log
We used the same exponentially increasing learning-rate policy employed for the State Descriptions version, and reached around 93% accuracy on the test set.
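The slow-start policy used for both versions can be sketched as a step-wise exponential schedule: the learning rate starts small, is multiplied by a constant factor at fixed epoch intervals, and is capped at a maximum. This is a minimal pure-Python sketch under that assumption; the function name and the default values of `base_lr`, `gamma`, `step` and `max_lr` are hypothetical placeholders, not necessarily the values used in our runs (see `python3 train.py --help` for the actual training options).

```python
def slow_start_lr(epoch, base_lr=5e-6, gamma=2.0, step=20, max_lr=5e-4):
    """Learning rate at `epoch` under a step-wise exponential increase.

    All default values are hypothetical placeholders.
    The rate is multiplied by `gamma` every `step` epochs and never exceeds `max_lr`.
    """
    if epoch < 0 or step <= 0:
        raise ValueError("epoch must be >= 0 and step must be > 0")
    return min(base_lr * gamma ** (epoch // step), max_lr)


for epoch in (0, 20, 40, 100, 200):
    print(epoch, slow_start_lr(epoch))
```

In PyTorch, such a schedule could be applied with `torch.optim.lr_scheduler.LambdaLR`, which sets the learning rate to the optimizer's initial rate times `lr_lambda(epoch)`: with the optimizer created at `lr=5e-6`, passing `lr_lambda=lambda e: slow_start_lr(e) / 5e-6` reproduces this sketch.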
We prepared a JSON configuration file from which model hyperparameters can be tuned. The `--config` option specifies a JSON configuration file, while the `--model` option loads a specific hyperparameter configuration defined in that file. By default, the configuration file is `config.json` and the default model is `original-fp`.
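The interaction between `--config` and `--model` can be sketched as follows: the JSON file maps model names to hyperparameter dictionaries, and `--model` selects one entry. This is a hypothetical illustration, not the actual parsing code of `train.py`; the function names and the keys in `EXAMPLE_CONFIG` (`state_description`, `dropout`) are placeholders, so check `config.json` for the real schema.

```python
import argparse
import json


def parse_config_args(argv=None):
    """Parse only --config and --model; defaults mirror the documented ones."""
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--config", default="config.json")
    parser.add_argument("--model", default="original-fp")
    # parse_known_args ignores the other train.py options (e.g. --clevr-dir)
    args, _unknown = parser.parse_known_args(argv)
    return args


def select_hyperparams(config_text, model_name):
    """Return the hyperparameter dict stored under `model_name` in a JSON string."""
    configs = json.loads(config_text)
    if model_name not in configs:
        raise KeyError(f"model {model_name!r} not found; available: {sorted(configs)}")
    return configs[model_name]


# Hypothetical configuration content, for illustration only:
EXAMPLE_CONFIG = """
{
  "original-fp": {"state_description": false, "dropout": 0.5},
  "original-sd": {"state_description": true, "dropout": 0.5}
}
"""

args = parse_config_args(["--clevr-dir", "path/to/CLEVR_v1.0/", "--model", "original-sd"])
print(select_hyperparams(EXAMPLE_CONFIG, args.model))
```

With a real file, the JSON text would be read from `args.config` before calling `select_hyperparams`.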
Once training ends, some plots (invalid answers, training loss, test loss, test accuracy) can be generated using the `plot.py` script:
python3 plot.py -i -trl -tsl -a logfile.log
These plots are also saved inside the `img/` folder.
To explore the other arguments available to customize training, issue the command:
python3 train.py --help
A test session can also be run after training, by loading the checkpoint of the trained network saved at a specific epoch. To do so, specify the `--test` option:
python3 train.py --clevr-dir path/to/CLEVR_v1.0/ --model 'original-fp' --resume RN_epoch_xxx.pth --test
IMPORTANT: if CUDA reports an out-of-memory error because there is not enough VRAM for testing, lower the test batch size to 64 or 32 using the `--test-batch-size` option (e.g. `--test-batch-size 32`).
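When resuming, it can be handy to pick the most recent checkpoint automatically. A minimal sketch, assuming checkpoint filenames end in `epoch_<N>.pth` as in `RN_epoch_xxx.pth` and the pretrained `original_fp_epoch_493.pth`; `latest_checkpoint` and the directory `path/to/checkpoints` are hypothetical, not part of the repository.

```python
import re
from pathlib import Path

# Assumed naming scheme: "<prefix>epoch_<N>.pth"
CHECKPOINT_RE = re.compile(r"epoch_(\d+)\.pth$")


def latest_checkpoint(directory):
    """Return the .pth file with the highest epoch number in `directory`, or None."""
    best_epoch, best_path = -1, None
    for path in Path(directory).glob("*.pth"):
        match = CHECKPOINT_RE.search(path.name)
        # compare epochs as integers, so epoch_10 ranks above epoch_9
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path


if __name__ == "__main__":
    ckpt = latest_checkpoint("path/to/checkpoints")  # hypothetical directory
    if ckpt is not None:
        print("python3 train.py --clevr-dir path/to/CLEVR_v1.0/ "
              f"--model 'original-fp' --resume {ckpt} --test")
```

Epoch numbers are compared as integers because plain string sorting would place `epoch_9` after `epoch_10`. The `--model` value must still match the architecture the checkpoint was trained with.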
We released pretrained models of the Original and Image Retrieval architectures for the challenging From Pixels version.
Original architecture, epoch 493:
python3 train.py --clevr-dir path/to/CLEVR_v1.0/ --model 'original-fp' --resume pretrained_models/original_fp_epoch_493.pth --test
Image Retrieval architecture, epoch 312:
python3 train.py --clevr-dir path/to/CLEVR_v1.0/ --model 'ir-fp' --resume pretrained_models/ir_fp_epoch_312.pth --test
Once a test has been performed at least once (a test session can be run explicitly, but it is also run automatically after every training epoch), some insights are saved into `test_results/`, and a confusion plot can be generated from them:
python3 confusionplot.py test_results/test.pickle
This helps to discover the network's weaknesses and possibly address them.
This plot is also saved inside the `img/` folder.
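The idea behind a confusion plot can be sketched in pure Python: count how often each ground-truth answer is mapped to each predicted answer, then compute per-answer accuracy to spot weak categories. This sketch does not read `test_results/test.pickle`, whose internal structure is not documented here; it assumes hypothetical parallel lists of ground-truth and predicted answers, and the function names are ours.

```python
from collections import Counter, defaultdict


def confusion_counts(ground_truth, predictions):
    """Count how often each (true answer, predicted answer) pair occurs."""
    if len(ground_truth) != len(predictions):
        raise ValueError("ground_truth and predictions must have the same length")
    return Counter(zip(ground_truth, predictions))


def per_answer_accuracy(counts):
    """Fraction of correct predictions for each ground-truth answer."""
    totals = defaultdict(int)
    correct = defaultdict(int)
    for (truth, pred), n in counts.items():
        totals[truth] += n
        if truth == pred:
            correct[truth] += n
    return {answer: correct[answer] / totals[answer] for answer in totals}


# Hypothetical toy data, not real CLEVR results:
truth = ["yes", "no", "2", "2", "cube"]
pred = ["yes", "yes", "2", "3", "cube"]
counts = confusion_counts(truth, pred)
print(counts[("no", "yes")])        # 1
print(per_answer_accuracy(counts))  # {'yes': 1.0, 'no': 0.0, '2': 0.5, 'cube': 1.0}
```

Off-diagonal cells with large counts (here `("no", "yes")`) point to the answer pairs the network confuses most often.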
- The question and answer dictionaries are built from the training set data, so the model will not work with words it has never seen before.
- All words in the dataset are treated case-insensitively, since we do not want the model to learn case biases.
- For network settings, see sections 4 and B of the original paper https://arxiv.org/abs/1706.01427.
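The first two notes can be illustrated with a minimal sketch of a case-insensitive question dictionary built from training questions only. The regex tokenizer, the function names and the convention of reserving index 0 (e.g. for padding) are assumptions for illustration; the repository's actual preprocessing may differ.

```python
import re


def tokenize(text):
    """Lower-case and split into alphanumeric word tokens (hypothetical tokenizer)."""
    return re.findall(r"[a-z0-9]+", text.lower())


def build_dictionary(training_questions):
    """Map each word seen in training to an index starting at 1 (0 is reserved)."""
    word_to_index = {}
    for question in training_questions:
        for word in tokenize(question):
            if word not in word_to_index:
                word_to_index[word] = len(word_to_index) + 1
    return word_to_index


def encode(question, word_to_index):
    """Encode a question; raises KeyError for words never seen in training."""
    return [word_to_index[word] for word in tokenize(question)]


vocab = build_dictionary(["Is there a red cube?", "What color is the sphere?"])
print(encode("IS there a Cube?", vocab))  # [1, 2, 3, 5]
```

Lower-casing before lookup is what makes "IS" and "is" share one index; a word absent from the training questions, such as "cylinder" here, raises `KeyError`.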
Special thanks to https://github.com/aelnouby and https://github.com/rosinality for their great support. Below are their Relation Network repositories working on CLEVR: