This repo implements baseline methods for unsupervised object-centric learning: IODINE, MONet, Slot Attention, and Genesis V2. The IODINE, MONet, and Genesis V2 implementations are taken from here.
- IODINE (Apache-2.0 license): [paper] [original code]
- MONet: [paper] [code]
- Slot Attention (MIT license): [paper] [code1, code2] [original code]
- Genesis V2 (GPLv3 license): [paper] [code]
(Figure: visualization of training results logged by WandB.)
The directory structure of this repo looks like this:
```
├── .github                   <- GitHub Actions workflows
│
├── configs                   <- Hydra configs
│   ├── callbacks                <- Callbacks configs
│   ├── data                     <- Data configs
│   ├── debug                    <- Debugging configs
│   ├── experiment               <- *** Experiment configs ***
│   │   ├── slota
│   │   │   ├── clv6.yaml
│   │   │   └── ...
│   │   └── ...
│   ├── extras                   <- Extra utilities configs
│   ├── hparams_search           <- Hyperparameter search configs
│   ├── hydra                    <- Hydra configs
│   ├── local                    <- Local configs
│   ├── logger                   <- Logger configs (we use wandb)
│   ├── model                    <- Model configs
│   ├── paths                    <- Project paths configs
│   ├── trainer                  <- Trainer configs
│   │
│   ├── eval.yaml                <- Main config for evaluation
│   └── train.yaml               <- Main config for training
│
├── data                      <- Directory for Dataset
│   ├── CLEVR6
│   │   ├── images               <- raw images
│   │   │   ├── train
│   │   │   │   ├── CLEVR_train_******.png
│   │   │   │   └── ...
│   │   │   └── val
│   │   │       ├── CLEVR_val_******.png
│   │   │       └── ...
│   │   ├── masks                <- mask annotations
│   │   │   ├── train
│   │   │   │   ├── CLEVR_train_******.png
│   │   │   │   └── ...
│   │   │   └── val
│   │   │       ├── CLEVR_val_******.png
│   │   │       └── ...
│   │   └── scenes               <- metadata
│   │       ├── CLEVR_train_scenes.json
│   │       └── CLEVR_val_scenes.json
│   └── ...
│
├── logs                      <- Logs generated by hydra and lightning loggers
│
├── scripts                   <- Shell scripts
│
├── src                       <- Source code
│   ├── data                     <- Data scripts
│   ├── models                   <- Model scripts
│   ├── utils                    <- Utility scripts
│   │
│   ├── eval.py                  <- Run evaluation
│   └── train.py                 <- Run training
│
├── tests                     <- Tests of any kind
│
├── .env.example              <- Example of file for storing private environment variables
├── .gitignore                <- List of files ignored by git
├── .pre-commit-config.yaml   <- Configuration of pre-commit hooks for code formatting
├── .project-root             <- File for inferring the position of project root directory
├── environment.yaml          <- File for installing conda environment
├── Makefile                  <- Makefile with commands like `make train` or `make test`
├── pyproject.toml            <- Configuration options for testing and linting
├── requirements.txt          <- File for installing python dependencies
├── setup.py                  <- File for installing project as a package
└── README.md
```
Note: Each dataset provides its mask annotations and metadata in its own format, so adapt the `Dataset` class for each dataset to match its configuration.
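Because the expected layout differs per dataset, it can help to check a dataset directory before training. The following is a minimal shell sketch for the CLEVR6 layout shown in the tree above. `check_split` and `check_dataset` are hypothetical helpers and are not part of the repo. The sketch also assumes that every mask has the same file name as its image. The tree suggests this, but it may not hold for other datasets.

```bash
# Hypothetical helpers (not part of the repo). Assumption: every image under
# images/<split>/ has a mask with the same file name under masks/<split>/.
check_split() {
  local root="$1" split="$2"
  local missing=0 img name
  if [ ! -d "$root/images/$split" ]; then
    echo "missing directory: images/$split"
    return 1
  fi
  for img in "$root/images/$split"/*.png; do
    [ -e "$img" ] || continue                # glob matched no files
    name="$(basename "$img")"
    if [ ! -e "$root/masks/$split/$name" ]; then
      echo "missing mask: $split/$name"
      missing=$((missing + 1))
    fi
  done
  echo "$split: $missing missing mask(s)"
  [ "$missing" -eq 0 ]
}

check_dataset() {
  local root="${1:-data/CLEVR6}"
  local status=0 split
  for split in train val; do
    check_split "$root" "$split" || status=1
    # scene metadata file expected by the CLEVR6 layout
    if [ ! -f "$root/scenes/CLEVR_${split}_scenes.json" ]; then
      echo "missing metadata: scenes/CLEVR_${split}_scenes.json"
      status=1
    fi
  done
  return "$status"
}
```

Running `check_dataset data/CLEVR6` prints one summary line per split. It returns a non-zero status if any mask or metadata file is missing.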
This repo is built on Lightning-Hydra-Template 1.5.3 and was developed with Python 3.8.12 and PyTorch 1.11.0.
Option 1, using pip:

```bash
# clone project
git clone https://github.com/janghyuk-choi/slot-attention-lightning.git
cd slot-attention-lightning

# [OPTIONAL] create conda environment
conda create -n slota python=3.8
conda activate slota

# install pytorch according to instructions
# https://pytorch.org/get-started/

# install requirements
pip install -r requirements.txt
```
Option 2, using conda (creates the environment from `environment.yaml`):

```bash
# clone project
git clone https://github.com/janghyuk-choi/slot-attention-lightning.git
cd slot-attention-lightning

# create conda environment and install dependencies
conda env create -f environment.yaml

# activate conda environment
conda activate slota
```
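After activating the environment, you can confirm that the interpreter matches the Python 3.8 release the repo was developed with. The sketch below assumes that only the major.minor version matters. `check_python` and the `PYTHON` override variable are hypothetical and are not part of the repo. A different version may still work. This check only flags the mismatch.

```bash
# Hypothetical helper (not part of the repo): compares the active
# interpreter's major.minor version with 3.8, the version the repo was
# developed with. PYTHON is a hypothetical override for the interpreter.
check_python() {
  local py="${PYTHON:-python}"
  local version
  version="$("$py" -c 'import sys; print("%d.%d" % sys.version_info[:2])')" || return 1
  if [ "$version" != "3.8" ]; then
    echo "warning: found Python $version, but the repo was developed with 3.8" >&2
    return 1
  fi
  echo "Python $version OK"
}
```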
Train a model with an experiment configuration chosen from `configs/experiment/`:
```bash
# training Slot Attention over CLEVR6 dataset
python src/train.py \
  experiment=slota/clv6.yaml

# training Genesis V2 over CLEVRTEX dataset
python src/train.py \
  experiment=genesis2/clvt.yaml
```
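To see which `experiment=` values are available, you can list the YAML files under `configs/experiment/`. They are printed in the same `<model>/<dataset>.yaml` form that the commands above use. This is a small sketch, and `list_experiments` is a hypothetical helper that is not part of the repo.

```bash
# Hypothetical helper (not part of the repo): prints every experiment config
# as <model>/<dataset>.yaml, the form used by the experiment= override.
list_experiments() {
  local root="${1:-configs/experiment}"
  # subshell so the caller's working directory is unchanged
  (cd "$root" && find . -type f -name '*.yaml' | sed 's|^\./||' | sort)
}
```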
You can create your own experiment configs for new experiments. For simple modifications, you can override any parameter from the command line:
```bash
# training Slot Attention over CLEVR6 dataset with custom config
python src/train.py \
  experiment=slota/clv6.yaml \
  data.data_dir=/workspace/dataset/clevr_with_masks/CLEVR6 \
  trainer.check_val_every_n_epoch=10 \
  model.net.num_slots=10 \
  model.net.num_iter=5 \
  model.name="slota_k10_t5" # model.name will be used for logging on wandb
```
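To compare several slot counts, the override pattern above can be wrapped in a loop that gives each run its own `model.name`, so the runs stay distinguishable on wandb. This is a minimal sketch. `sweep_num_slots` and `DRY_RUN` are hypothetical names that are not part of the repo. The slot counts are passed as arguments.

```bash
# Hypothetical helper (not part of the repo): trains Slot Attention on CLEVR6
# once per slot count given as an argument, giving each run its own model.name.
# DRY_RUN is also hypothetical: with DRY_RUN=1 the commands are only printed.
sweep_num_slots() {
  local k
  local -a cmd
  for k in "$@"; do
    cmd=(python src/train.py
      experiment=slota/clv6.yaml
      "model.net.num_slots=$k"
      "model.name=slota_k$k")
    if [ "${DRY_RUN:-0}" = "1" ]; then
      echo "${cmd[*]}"            # print the command instead of running it
    else
      "${cmd[@]}" || return 1     # stop at the first failed run
    fi
  done
}
```

For example, `DRY_RUN=1 sweep_num_slots 7 10` prints the two training commands, and `sweep_num_slots 7 10` runs them one after another. Hydra's built-in multirun mode (`-m` with comma-separated values such as `model.net.num_slots=7,10`) is an alternative. With a fixed `model.name`, however, all of those runs would log under the same name.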
You can evaluate a trained model by loading its checkpoint. Evaluation also runs periodically during training, every `trainer.check_val_every_n_epoch` epochs.
```bash
# evaluating Slot Attention over CLEVR6 dataset
# as in the training phase, you can also customize the config from the command line
python src/eval.py \
  experiment=slota/clv6.yaml \
  ckpt_path=logs/train/runs/clv6_slota/{timestamp}/checkpoints/last.ckpt
```
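To fill in the `{timestamp}` placeholder without copying it by hand, you can pick the most recent run directory that contains `checkpoints/last.ckpt`. This sketch assumes that the run directories are named by timestamps that sort chronologically as strings, such as `YYYY-MM-DD_HH-MM-SS`. Check your `logs/` folder before relying on this. `latest_ckpt` is a hypothetical helper that is not part of the repo.

```bash
# Hypothetical helper (not part of the repo). Assumption: run directory names
# are timestamps whose string order matches chronological order.
latest_ckpt() {
  local runs_dir="${1:-logs/train/runs/clv6_slota}"
  local latest="" d
  # pathname expansion returns the directories in sorted order,
  # so the last match with a checkpoint is the newest run
  for d in "$runs_dir"/*/; do
    [ -f "${d}checkpoints/last.ckpt" ] || continue
    latest="${d}checkpoints/last.ckpt"
  done
  if [ -z "$latest" ]; then
    echo "no last.ckpt found under $runs_dir" >&2
    return 1
  fi
  echo "$latest"
}
```

The result can then be passed to the evaluation script as `ckpt_path="$(latest_ckpt logs/train/runs/clv6_slota)"`.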