
Conditional sets

This repository provides the code to generate the datasets and validation metrics introduced in "Conditional Permutation Invariant Flows" [1]. It contains two main tasks, each consisting of several sub-variants, and it provides functions to compute the metrics reported in [1].

The datasets are derivatives of two publicly available datasets. Since both are distributed under licenses that do not allow republishing, the code assumes the user has obtained them from their respective sources. The two datasets and their sub-tasks are described briefly below.

If you use this code, please cite our paper:

@article{zwartsenberg2023conditional,
  title={Conditional Permutation Invariant Flows},
  author={Berend Zwartsenberg and Adam Scibior and Matthew Niedoba and Vasileios Lioutas and Justice Sefas and Yunpeng Liu and Setareh Dabiri and Jonathan Wilder Lavington and Trevor Campbell and Frank Wood},
  journal={Transactions on Machine Learning Research},
  issn={2835-8856},
  year={2023},
  url={https://openreview.net/forum?id=DUsgPi3oCC},
}

Conditional traffic scenes

Here, a dataset is provided on top of the INTERACTION dataset [2]. The task is to fit a distribution over Nx5 tensors, conditional on the map: each datapoint is an Nx5 tensor, where N is the number of agents and the five per-agent features are (x, y, length, width, orientation). The INTERACTION dataset is required for this task. Its root path is referred to as interaction_path, and should look like:

├── interaction_path
│   ├── recorded_trackfiles
│   ├── maps 
│   ├── ...

The module create_traffic_dataset.py provides a dataset whose __getitem__ method returns a pair (location, datapoint), where location is the conditioning information from which datapoint, an Nx5 traffic scene, is to be predicted. Both train and val splits are provided. There are two tasks: the first has variable N, which not all generative methods support (prune_outer=0); the second fixes N=7, which is compatible with more methods (prune_outer=1). Full details can be found in [1]. The module infractions.py provides the metrics reported in the paper. A convenience script, write-datasets.sh, processes all the datasets and pre-writes the samples. A minimal usage sketch follows below.
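As a minimal sketch of how the dataset might be consumed (the class name and constructor arguments here are assumptions, not the module's actual API; consult create_traffic_dataset.py for the real signature):

from torch.utils.data import DataLoader

# Hypothetical import: the actual class name in create_traffic_dataset.py may differ.
from create_traffic_dataset import TrafficDataset

# Hypothetical arguments: INTERACTION root path, split, and pruning mode.
dataset = TrafficDataset(interaction_path="/data/interaction", split="train", prune_outer=1)

location, datapoint = dataset[0]
print(datapoint.shape)  # (N, 5): x, y, length, width, orientation per agent

# With prune_outer=1, every scene has a fixed N=7 agents, so default batching works.
loader = DataLoader(dataset, batch_size=32, shuffle=True)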

Object detection

A dataset based on the CLEVR dataset [3] is provided. The task is to fit an Nx3 distribution conditioned on the input image, where N is the number of objects and the three per-object dimensions are (x, y, size). The CLEVR dataset is required for this task, and its root path is set through clevr_path. The path passed there should look like:

├── clevr_path 
│   ├── images
│   ├── scenes 
│   ├── ...

The module create_detection_dataset.py provides a dataset whose __getitem__ method returns an image path along with an Nx3 datapoint. Both train and val splits are available through the split argument. Two modes are provided: one where N is fixed to 3 (task=CLEVR3), and another where N ranges from 3 to 6 (task=CLEVR6). A function is also provided that computes the intersection over union (IoU) metric reported in the paper; a usage sketch follows below.
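As with the traffic task, a brief usage sketch may help. The class name and constructor arguments below are assumptions rather than the module's documented API, and the IoU helper is an illustrative axis-aligned implementation under the assumption that each (x, y, size) triple describes a square box; consult the repository code for the actual functions:

import torch

# Hypothetical import: the actual class name in create_detection_dataset.py may differ.
from create_detection_dataset import DetectionDataset

# Hypothetical arguments: CLEVR root path, split, and task variant.
dataset = DetectionDataset(clevr_path="/data/clevr", split="val", task="CLEVR6")

image_path, datapoint = dataset[0]
print(image_path, datapoint.shape)  # (N, 3): x, y, size per object, with 3 <= N <= 6

def to_box(obj):
    # Interpret an (x, y, size) triple as a square box (x1, y1, x2, y2) centered at (x, y).
    x, y, size = obj
    half = size / 2
    return torch.stack([x - half, y - half, x + half, y + half])

def iou(box_a, box_b):
    # Axis-aligned intersection over union between two (x1, y1, x2, y2) boxes.
    inter_w = (torch.min(box_a[2], box_b[2]) - torch.max(box_a[0], box_b[0])).clamp(min=0)
    inter_h = (torch.min(box_a[3], box_b[3]) - torch.max(box_a[1], box_b[1])).clamp(min=0)
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)

print(iou(to_box(datapoint[0]), to_box(datapoint[1])))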

Installation

This repository requires torchdrivesim. Please follow the installation instructions at https://docs.torchdrivesim.org/en/latest/, and then run pip install -r requirements.txt.

[1] "Permutation Invariant Flows", Zwartsenberg et al., TMLR (2023)
[2] Interaction Dataset
[3] CLEVR Dataset
