From 96ec3ab85d8e953845279776cab60e2daba6a28a Mon Sep 17 00:00:00 2001 From: Theo West Date: Thu, 2 May 2024 13:48:44 +0200 Subject: [PATCH] first-commit --- .github/workflows/apptainer-image.yml | 41 ++ .github/workflows/conda.yml | 37 ++ .github/workflows/docker-image.yml | 31 ++ .github/workflows/mypy.yml | 31 ++ .github/workflows/requirements.txt | 13 + .gitignore | 425 ++++++++++++++++++ README.md | 553 ++++++++++++++++++++++++ arguments.py | 73 ++++ base.py | 110 +++++ configs/rounD/example.json | 29 ++ configs/synthD/ci.json | 29 ++ container/Dockerfile | 37 ++ container/apptainer.def | 59 +++ container/environment.yml | 27 ++ datamodules/__init__.py | 3 + datamodules/dataloader.py | 165 +++++++ datamodules/dataset.py | 53 +++ datamodules/transforms.py | 102 +++++ metrics/__init__.py | 8 + metrics/collision_rate.py | 81 ++++ metrics/log_likelihood.py | 123 ++++++ metrics/min_ade.py | 71 +++ metrics/min_apde.py | 79 ++++ metrics/min_brier.py | 86 ++++ metrics/min_fde.py | 76 ++++ metrics/miss_rate.py | 80 ++++ metrics/unittest/__init__.py | 0 metrics/unittest/test_ade.py | 26 ++ metrics/unittest/test_apde.py | 26 ++ metrics/unittest/test_brier.py | 39 ++ metrics/unittest/test_collision_rate.py | 45 ++ metrics/unittest/test_fde.py | 26 ++ metrics/unittest/test_mr.py | 26 ++ metrics/unittest/test_nll.py | 67 +++ metrics/unittest/test_utils.py | 70 +++ metrics/utils.py | 84 ++++ models/prototype.py | 55 +++ preamble.py | 53 +++ preprocess.sh | 4 + preprocessing/__init__.py | 1 + preprocessing/arguments.py | 47 ++ preprocessing/configs/highD.json | 16 + preprocessing/configs/inD.json | 211 +++++++++ preprocessing/configs/rounD.json | 157 +++++++ preprocessing/configs/uniD.json | 91 ++++ preprocessing/preprocess_highway.py | 415 ++++++++++++++++++ preprocessing/preprocess_urban.py | 398 +++++++++++++++++ preprocessing/utils/__init__.py | 5 + preprocessing/utils/common.py | 296 +++++++++++++ preprocessing/utils/highway_graph.py | 160 +++++++ preprocessing/utils/highway_utils.py | 269 ++++++++++++ preprocessing/utils/lanelet_graph.py | 285 ++++++++++++ train.py | 122 ++++++ train.sh | 1 + 54 files changed, 5417 insertions(+) create mode 100644 .github/workflows/apptainer-image.yml create mode 100644 .github/workflows/conda.yml create mode 100644 .github/workflows/docker-image.yml create mode 100644 .github/workflows/mypy.yml create mode 100644 .github/workflows/requirements.txt create mode 100644 .gitignore create mode 100644 README.md create mode 100644 arguments.py create mode 100644 base.py create mode 100644 configs/rounD/example.json create mode 100644 configs/synthD/ci.json create mode 100644 container/Dockerfile create mode 100644 container/apptainer.def create mode 100644 container/environment.yml create mode 100644 datamodules/__init__.py create mode 100644 datamodules/dataloader.py create mode 100644 datamodules/dataset.py create mode 100644 datamodules/transforms.py create mode 100644 metrics/__init__.py create mode 100644 metrics/collision_rate.py create mode 100644 metrics/log_likelihood.py create mode 100644 metrics/min_ade.py create mode 100644 metrics/min_apde.py create mode 100644 metrics/min_brier.py create mode 100644 metrics/min_fde.py create mode 100644 metrics/miss_rate.py create mode 100644 metrics/unittest/__init__.py create mode 100644 metrics/unittest/test_ade.py create mode 100644 metrics/unittest/test_apde.py create mode 100644 metrics/unittest/test_brier.py create mode 100644 metrics/unittest/test_collision_rate.py create mode 100644 metrics/unittest/test_fde.py 
create mode 100644 metrics/unittest/test_mr.py create mode 100644 metrics/unittest/test_nll.py create mode 100644 metrics/unittest/test_utils.py create mode 100644 metrics/utils.py create mode 100644 models/prototype.py create mode 100644 preamble.py create mode 100644 preprocess.sh create mode 100644 preprocessing/__init__.py create mode 100644 preprocessing/arguments.py create mode 100644 preprocessing/configs/highD.json create mode 100644 preprocessing/configs/inD.json create mode 100644 preprocessing/configs/rounD.json create mode 100644 preprocessing/configs/uniD.json create mode 100644 preprocessing/preprocess_highway.py create mode 100644 preprocessing/preprocess_urban.py create mode 100644 preprocessing/utils/__init__.py create mode 100644 preprocessing/utils/common.py create mode 100644 preprocessing/utils/highway_graph.py create mode 100644 preprocessing/utils/highway_utils.py create mode 100644 preprocessing/utils/lanelet_graph.py create mode 100644 train.py create mode 100644 train.sh diff --git a/.github/workflows/apptainer-image.yml b/.github/workflows/apptainer-image.yml new file mode 100644 index 0000000..bff3730 --- /dev/null +++ b/.github/workflows/apptainer-image.yml @@ -0,0 +1,41 @@ +name: Apptainer CI + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + + build: + + runs-on: ubuntu-22.04 + + steps: + - uses: actions/checkout@v4 + - uses: eWaterCycle/setup-apptainer@v2 + with: + apptainer-version: 1.1.2 + + - name: Clean disk space + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + sudo apt-get clean + + - name: Build the Apptainer image + run: | + cd container + apptainer build dronalize.sif apptainer.def + + - name: Download Data + run: | + wget -O synthD.zip "https://liuonline-my.sharepoint.com/:u:/g/personal/thewe60_liu_se/EfK1PKrQ3X9LgOd_8TSw1g0BINzDadmTxHF_RHKg_31dGw?e=SEKX9X&download=1" + unzip synthD.zip -d ./data + + - name: Run PyTorch Training Loop + run: | + apptainer run container/dronalize.sif python train.py --config ci diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml new file mode 100644 index 0000000..95a5b4a --- /dev/null +++ b/.github/workflows/conda.yml @@ -0,0 +1,37 @@ +name: Conda CI + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + + build: + runs-on: ubuntu-22.04 + defaults: + run: + shell: bash -el {0} + steps: + - uses: actions/checkout@v4 + - uses: conda-incubator/setup-miniconda@v3 + with: + activate-environment: dronalize + environment-file: container/environment.yml + python-version: 3.11 + auto-activate-base: false + + - name: Download Data + run: | + wget -O synthD.zip "https://liuonline-my.sharepoint.com/:u:/g/personal/thewe60_liu_se/EfK1PKrQ3X9LgOd_8TSw1g0BINzDadmTxHF_RHKg_31dGw?e=SEKX9X&download=1" + unzip synthD.zip -d ./data + + - name: Run PyTorch Training Loop + run: | + python train.py --config ci + + - name: Run Unit Tests + run: | + python -m pip install pytest + pytest diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml new file mode 100644 index 0000000..6cd1adf --- /dev/null +++ b/.github/workflows/docker-image.yml @@ -0,0 +1,31 @@ +name: Docker CI + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + + build: + runs-on: ubuntu-22.04 + env: + IMAGE_TAG: ${{ github.run_id }} + steps: + - uses: actions/checkout@v4 + - name: Build the Docker image + run: | + cd container + docker 
build -f Dockerfile . -t dronalize:${IMAGE_TAG} + + - name: Download Data + run: | + wget -O synthD.zip "https://liuonline-my.sharepoint.com/:u:/g/personal/thewe60_liu_se/EfK1PKrQ3X9LgOd_8TSw1g0BINzDadmTxHF_RHKg_31dGw?e=SEKX9X&download=1" + unzip synthD.zip -d ./data + + - name: Run PyTorch Training Loop + run: | + docker run -v "$(pwd)":/app -w /app dronalize:${IMAGE_TAG} python train.py --config ci + + diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml new file mode 100644 index 0000000..da1e770 --- /dev/null +++ b/.github/workflows/mypy.yml @@ -0,0 +1,31 @@ +name: Linting + +on: + push: + branches: [ main ] + paths: + - '**.py' + pull_request: + branches: [ main ] + paths: + - '**.py' + +jobs: + mypy: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r .github/workflows/requirements.txt + pip install mypy + - name: Run mypy + run: mypy --ignore-missing-imports . diff --git a/.github/workflows/requirements.txt b/.github/workflows/requirements.txt new file mode 100644 index 0000000..df47412 --- /dev/null +++ b/.github/workflows/requirements.txt @@ -0,0 +1,13 @@ +torch +torchvision +torch_geometric +lightning +torchmetrics +pandas +scikit-learn +matplotlib +tqdm +utm +osmium +networkx +pytest diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..975f7cf --- /dev/null +++ b/.gitignore @@ -0,0 +1,425 @@ +### Folders ### +data/ +lightning_logs/ +wandb/ + +### Apptainer ### +*.sif + +### Slurm ### +slurm*.out + +### Files ### +*.csv +*.zip + +### JupyterNotebooks ### +Untitled* + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### macOS ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### macOS Patch ### +# iCloud generated files +*.icloud + +### PyCharm ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml 
+.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + + +# Sonarlint plugin +# https://plugins.jetbrains.com/plugin/7973-sonarlint +.idea/**/sonarlint/ + +# SonarQube Plugin +# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin +.idea/**/sonarIssues.xml + +# Markdown Navigator plugin +# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced +.idea/**/markdown-navigator.xml +.idea/**/markdown-navigator-enh.xml +.idea/**/markdown-navigator/ + +# Cache file creation bug +# See https://youtrack.jetbrains.com/issue/JBR-2257 +.idea/$CACHE_FILE$ + +# CodeStream plugin +# https://plugins.jetbrains.com/plugin/12206-codestream +.idea/codestream.xml + +# Azure Toolkit for IntelliJ plugin +# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij +.idea/**/azureSettings.xml + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +bin/ +share/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook + +# IPython + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
+# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### Vim ### +# Swap +[._]*.s[a-v][a-z] +!*.svg # comment out if you don't need vector files +[._]*.sw[a-p] +[._]s[a-rt-v][a-z] +[._]ss[a-gi-z] +[._]sw[a-p] + +# Session +Session.vim +Sessionx.vim + +# Temporary +.netrwhist +# Auto-generated tag files +tags +# Persistent undo +[._]*.un~ + +### VirtualEnv ### +# Virtualenv +# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ +[Bb]in +[Ii]nclude +[Ll]ib +[Ll]ib64 +[Ll]ocal +[Ss]cripts +pyvenv.cfg +pip-selfcheck.json + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..70812b7 --- /dev/null +++ b/README.md @@ -0,0 +1,553 @@ + +
+Dronalize + + +______________________________________________________________________ + +[![paper](https://img.shields.io/badge/Paper-arXiv-8A2BE2.svg)](https://arxiv.org/abs/2405.00604) +[![python](https://img.shields.io/badge/Python-3.10%20%7C%203.11-blue.svg)](https://www.python.org/) +[![pytorch](https://img.shields.io/badge/PyTorch-2.2-blue.svg)](https://pytorch.org/) +[![contributions](https://img.shields.io/badge/Contributions-welcome-297D1E)](#contributing) +[![license](https://img.shields.io/badge/License-Apache%202.0-2F2F2F.svg)](LICENSE) +
+[![Docker Status](https://github.com/westny/dronalize/actions/workflows/docker-image.yml/badge.svg)](.github/workflows/docker-image.yml) +[![Apptainer Status](https://github.com/westny/dronalize/actions/workflows/apptainer-image.yml/badge.svg)](.github/workflows/apptainer-image.yml) +[![Conda Status](https://github.com/westny/dronalize/actions/workflows/conda.yml/badge.svg)](.github/workflows/conda.yml) +[![Linting Status](https://github.com/westny/dronalize/actions/workflows/mypy.yml/badge.svg)](.github/workflows/mypy.yml) + +
+
+**Dronalize** is a toolbox designed to reduce the development effort for researchers working with the **D**rone datasets from [leveLXData](https://levelxdata.com/) on behavior prediction problems.
+It includes tools for data preprocessing, visualization, and evaluation, as well as a model development pipeline for data-driven motion forecasting.
+The toolbox relies heavily on [PyTorch](https://pytorch.org/docs/stable/index.html), [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/), and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) for its functionality.
+
+***
+
+*All code in this repository is developed by the authors of the paper and is not an official product of [leveLXData](https://levelxdata.com/).
+All inquiries regarding the datasets should be directed to them.*
+
+***
+
+- [Installation](#installation)
+- [Usage](#usage)
+- [Datasets](#datasets)
+- [Related Work](#related-work)
+- [Contributing](#contributing)
+- [Cite](#cite)
+
+# Installation
+There are several ways to install the toolbox, depending on your needs and preferences.
+Our recommendation and personal preference is to use containers for reproducibility and consistency across different environments.
+We provide both an Apptainer definition file and a Dockerfile for this purpose.
+Both recipes use the `mamba` package manager to create their environments, and both rely on an `environment.yml` file that can also be used to create a local conda environment if desired.
+
+### Apptainer
+[Apptainer](https://apptainer.org/docs/user/main/index.html) is a lightweight containerization tool that we prefer for its simplicity and ease of use.
+Once installed, you can build the container by running the following command:
+
+```bash
+apptainer build dronalize.sif /path/to/definition_file
+```
+
+where `/path/to/definition_file` is the path to the `apptainer.def` file in the repository.
+Once built, the container is easy to run and only requires a few extra arguments.
+For example, to start the container and execute the `train.py` script, run the following command from the repository root directory:
+
+```bash
+apptainer run /path/to/dronalize.sif python train.py
+```
+
+If you have CUDA installed and want to use GPU acceleration, you can add the `--nv` flag to the `run` command:
+
+```bash
+apptainer run --nv /path/to/dronalize.sif python train.py
+```
+
+### Docker
+If you prefer to use [Docker](https://www.docker.com/get-started/), you can build the container by running the following command from the `container` directory:
+
+```bash
+docker build -t dronalize .
+```
+
+This will create a Docker image named `dronalize` with all the necessary dependencies.
+To run the container, you can use the following command:
+
+```bash
+docker run -it dronalize
+```
+
+Note that training with Docker requires mounting the repository, including the data directory, into the container.
+For example, from the repository root directory:
+
+```bash
+docker run -v "$(pwd)":/app -w /app dronalize python train.py
+```
+
+To use GPU acceleration, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) and run the container with the `--gpus all` flag:
+
+```bash
+docker run --gpus all -v "$(pwd)":/app -w /app dronalize python train.py
+```
+
+### Conda
+If you prefer not to use containers, you can create a [conda](https://conda.io/projects/conda/en/latest/index.html) environment using the `environment.yml` file.
+To create the environment, run the following command: + +```bash +conda env create -f /path/to/environment.yml +``` +or if using [mamba](https://mamba.readthedocs.io/en/latest/) + +```bash +mamba env create -f /path/to/environment.yml +``` + +This will create a new conda environment named `dronalize` with all the necessary dependencies. +Once the environment is created, you can activate it by running: + +```bash +conda activate dronalize +``` + +The environment is now ready to use, and you can run the scripts in the repository. + +
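+
+As a quick, optional sanity check (not part of the documented workflow), you can verify that the core dependencies resolve inside the activated environment:
+
+```python
+# Minimal import check; the printed versions depend on your installation.
+import torch
+import torch_geometric
+import lightning
+
+print(torch.__version__, torch_geometric.__version__, lightning.__version__)
+print("CUDA available:", torch.cuda.is_available())
+```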
+
+# Usage
+The **Dronalize** toolbox is designed for two main purposes: data preprocessing and evaluation of trajectory prediction models.
+It was developed to be used in conjunction with [PyTorch](https://pytorch.org/docs/stable/index.html), and in particular the [Lightning](https://lightning.ai/docs/pytorch/stable/) framework.
+
+### Preprocessing
+
+> Before running the preprocessing scripts, make sure to unzip the downloaded datasets and place them in a directory of your choice.
+> By default, the scripts expect the datasets to be located in the `../datasets` directory, but this can be changed by specifying the `--path` argument.
+> Make sure the unzipped folders are named `highD`, `rounD`, and `inD`, respectively.
+
+All functionality related to data preprocessing is contained in the `preprocessing` module.
+Since the **D**rone datasets have minor differences in their structure, there are two separate preprocessing scripts depending on the dataset used.
+For example, to preprocess the `inD` or `rounD` datasets, run the following command (replace `dataset_name` with the respective dataset name):
+
+```bash
+python -m preprocessing.preprocess_urban --dataset 'dataset_name' --path 'path/to/datasets'
+```
+
+while for the `highD` dataset, you should run:
+
+```bash
+python -m preprocessing.preprocess_highway --dataset 'highD' --path 'path/to/datasets'
+```
+
+By default, these scripts save the preprocessed data in the `data` directory; this can be changed by specifying the `--output-dir` argument.
+There is also an option to use threading for faster processing by setting the `--use-threads` flag, which we recommend.
+
+There are additional **default** arguments in the respective configuration files within the `preprocessing/configs` directory that should not be changed, in order to facilitate consistency across different studies.
+Finally, `preprocess.sh` is a script that can be used to preprocess all datasets in one go using the recommended default arguments:
+
+```bash
+. preprocess.sh
+```
+
+Using Apptainer, the shell script can be executed as follows:
+```bash
+apptainer run /path/to/dronalize.sif bash preprocess.sh
+```
+
+> Depending on the dataset and the number of workers, preprocessing can take some time.
+> In our experience, preprocessing all datasets takes around 8 hours on a standard workstation with 8 cores and 32 GB of RAM.
+
+
+### Data Loading
+In [datamodules](datamodules), you will find the necessary classes for loading the preprocessed data into PyTorch training pipelines.
+It includes:
+- `DroneDataset`: A `Dataset` class built around `torch_geometric`. Found in: [dataset.py](datamodules/dataset.py)
+- `DroneDataModule`: A `DataModule` class, including `DataLoader`s, built around `lightning.pytorch`. Found in: [dataloader.py](datamodules/dataloader.py)
+- `CoordinateTransform` and `CoordinateShift`: Example transformations. Found in: [transforms.py](datamodules/transforms.py)
+
+> [dataloader.py](datamodules/dataloader.py) is designed to be runnable as a standalone script for quick testing of the data loading pipeline.
+It includes a `__main__` block that can be used to load the data and visualize it for debugging and/or educational purposes.
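+
+For orientation, below is a minimal sketch of how the data module can be used in your own scripts.
+It mirrors the `__main__` block in [dataloader.py](datamodules/dataloader.py); the config keys and argument values shown are illustrative and assume the preprocessed data is located in `data/`.
+
+```python
+from argparse import Namespace
+
+from datamodules import DroneDataModule
+
+# Illustrative values; see the files in configs/ and arguments.py for the full set of options.
+config = {'root': 'data', 'name': 'rounD', 'batch_size': 32}
+args = Namespace(small_ds=False, num_workers=4, pin_memory=True, persistent_workers=True)
+
+dm = DroneDataModule(config, args)
+dm.setup()
+
+batch = next(iter(dm.train_dataloader()))  # a heterogeneous torch_geometric mini-batch
+print(batch)
+```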
+
+### Modeling
+In [models/prototype.py](models/prototype.py), there is a baseline neural network for trajectory prediction.
+It is a simple encoder-decoder model that takes the past trajectory of a road user as input and outputs a predicted future trajectory.
+It learns interactions between road users by encoding the scene as a graph and uses a GNN to process the data.
+The model can be used as a starting point for developing more advanced models, where adding map-aware mechanisms would be a natural next step.
+
+```python
+# prototype.py
+import torch
+import torch.nn as nn
+import torch_geometric.nn as ptg
+from torch_geometric.data import HeteroData
+
+
+class Net(nn.Module):
+    def __init__(self, config: dict) -> None:
+        super().__init__()
+        num_inputs = config["num_inputs"]
+        num_outputs = config["num_outputs"]
+        num_hidden = config["num_hidden"]
+        self.ph = config["pred_hrz"]
+
+        self.embed = nn.Linear(num_inputs, num_hidden)
+        self.encoder = nn.GRU(num_hidden, num_hidden, batch_first=True)
+        self.interaction = ptg.GATv2Conv(num_hidden, num_hidden, concat=False)
+        self.decoder = nn.GRU(num_hidden, num_hidden, batch_first=True)
+        self.output = nn.Linear(num_hidden, num_outputs)
+
+    def forward(self, data: HeteroData) -> torch.Tensor:
+        edge_index = data['agent']['edge_index']
+        x = torch.cat([data['agent']['inp_pos'],
+                       data['agent']['inp_vel'],
+                       data['agent']['inp_yaw']], dim=-1)
+
+        x = self.embed(x)
+        _, h = self.encoder(x)
+        x = h[-1]
+
+        x = self.interaction(x, edge_index)
+        x = x.unsqueeze(1).repeat(1, self.ph, 1)
+        x, _ = self.decoder(x)
+
+        pred = self.output(x)
+
+        return pred
+```
+
+### Model Training
+The toolbox includes a training script, [train.py](train.py), that can be used to train your models on the preprocessed data.
+The script is designed to be run from the repository root directory and includes several arguments that can be used to configure the training process.
+By default, it uses configuration files in `.json` format found in the [configs](configs) directory, detailing the required modules and hyperparameters for training.
+Additional runtime arguments, such as the number of workers, GPU acceleration, debug mode, and model checkpointing, can be specified when running the script (see [arguments.py](arguments.py) for more information).
+
+The training script is designed to be used with PyTorch Lightning; besides the custom data modules mentioned above, it also requires a `LightningModule` that defines the model and training loop.
+In [base.py](base.py), you will find a base class that can be modified to build your own `LightningModule`.
+In its current form, it can be used to train and evaluate the baseline model.
+It also details how to use the proposed evaluation metrics for trajectory prediction.
+
+An example of how to train the model is shown below:
+```bash
+[apptainer run --nv path/to/dronalize.sif] python train.py --add-name test --dry-run 0 --use-cuda 1 --num-workers 4
+```
+
+We recommend that users modify the default arguments in [arguments.py](arguments.py) to suit their needs.
+
+> Note that the default logger is set to `wandb` ([weights & biases](https://wandb.ai/)) for logging performance metrics during training.
+> It is our preferred tool for tracking experiments, but it can easily be replaced with other logging tools by modifying the `Trainer` in the training script.
+>
+> See the official [Lightning documentation](https://lightning.ai/docs/pytorch/stable/) for more information on customizing training behavior and how to use the library in general.
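+
+As an illustration of the note above, replacing `wandb` could look roughly like the following sketch (the logger choice, directory, and hyperparameter values are placeholders; the actual `Trainer` setup lives in [train.py](train.py) and may use different arguments):
+
+```python
+import lightning.pytorch as pl
+from lightning.pytorch.loggers import CSVLogger  # or TensorBoardLogger, WandbLogger, ...
+
+# Log metrics to local CSV files instead of wandb.
+logger = CSVLogger(save_dir="lightning_logs", name="dronalize_example")
+trainer = pl.Trainer(max_epochs=100, accelerator="auto", logger=logger)
+```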
+
+### Metrics
+The toolbox includes several evaluation metrics for trajectory prediction, implemented in the [`metrics`](metrics) module.
+The metrics are designed to handle both uni- and multi-modal predictions.
+Predictions are expected to be in the form of `(batch_size, num_timesteps, 2)` or `(batch_size, num_timesteps, num_modes, 2)`, where `num_modes` is the number of modes in the prediction.
+> There is also support for mode-first predictions of shape `(batch_size, num_modes, num_timesteps, 2)` that can be used by setting the `mode_first` flag to `True`.
+> Users can of course change the default behavior by directly modifying the metrics.
+
+Most metrics are also compatible with specifying a `min_criterion` (`FDE`, `ADE`, `MAP`) that is used to select which of the modes to evaluate against the ground-truth target (Default: `FDE`).
+Setting `min_criterion` to `MAP` will evaluate the metrics based on the mode with the highest predicted probability.
+Note that `MAP` can only be used in conjunction with the optional argument `Prob` of shape `(batch_size, num_modes)` representing the weights of each mode.
+
+The following metrics are implemented:
+- [**Min. Average Displacement Error (minADE)**](metrics/min_ade.py)
+- [**Min. Final Displacement Error (minFDE)**](metrics/min_fde.py)
+- [**Min. Average Path Displacement Error (minAPDE)**](metrics/min_apde.py)
+- [**Miss Rate**](metrics/miss_rate.py)
+- [**Collision Rate**](metrics/collision_rate.py)
+- [**Min. Brier**](metrics/min_brier.py)
+- [**Negative Log-Likelihood (NLL)**](metrics/log_likelihood.py)
+
+For their mathematical definitions, please refer to the paper.
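+
+As a rough usage sketch (the tensors below are random placeholders, and the exact mask semantics and constructor options, e.g., `min_criterion` and `mode_first`, are defined in the respective metric files), the metrics follow the familiar update/compute pattern also used in [base.py](base.py):
+
+```python
+import torch
+
+from metrics import MinADE, MinFDE
+
+num_agents, num_timesteps, num_modes = 32, 25, 6
+pred = torch.randn(num_agents, num_timesteps, num_modes, 2)     # multi-modal predictions
+trg = torch.randn(num_agents, num_timesteps, 2)                 # ground-truth positions
+mask = torch.ones(num_agents, num_timesteps, dtype=torch.bool)  # time steps to score
+
+min_ade, min_fde = MinADE(), MinFDE()
+min_ade.update(pred, trg, mask=mask)
+min_fde.update(pred, trg, mask=mask)
+
+print(min_ade.compute(), min_fde.compute())
+```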
+
+# Datasets
+
+The toolbox has been developed for use with the *[highD](https://levelxdata.com/highd-dataset/)*, *[rounD](https://levelxdata.com/round-dataset/)*, and *[inD](https://levelxdata.com/ind-dataset/)* datasets.
+The datasets contain recorded trajectories from different locations in Germany, including various highways, roundabouts, and intersections.
+Their high quality and reliability make them particularly suitable for early-stage research and development.
+They are freely available for non-commercial use, which matches our intended audience, but access must be requested through the following links:
+- *[highD](https://levelxdata.com/highd-dataset/)*
+- *[rounD](https://levelxdata.com/round-dataset/)*
+- *[inD](https://levelxdata.com/ind-dataset/)*
+
+> Several datasets in the leveLXData suite were recently updated (April 2024) with improvements to the maps, as well as the addition of some new locations.
+> This toolbox is designed to work with the updated datasets, and we recommend using the latest versions to get the newest features and to avoid having to modify the toolbox.
+>
+> We found that the toolbox works with the *[uniD](https://levelxdata.com/unid-dataset/)* dataset with minor adjustments, but we have yet to evaluate it in detail.
+>
+> We are working on adding support for the *[exiD](https://levelxdata.com/exid-dataset/)* dataset, which we aim to include in future versions of the toolbox.
+
+***
+
+### *[highD](https://arxiv.org/abs/1810.05642)*: The Highway Drone Dataset
+
+ Abstract +

+ Scenario-based testing for the safety validation of + highly automated vehicles is a promising approach that is being + examined in research and industry. This approach heavily relies + on data from real-world scenarios to derive the necessary + scenario information for testing. Measurement data should be + collected at a reasonable effort, contain naturalistic behavior of + road users and include all data relevant for a description of the + identified scenarios in sufficient quality. However, the current + measurement methods fail to meet at least one of the + requirements. Thus, we propose a novel method to measure data + from an aerial perspective for scenario-based validation + fulfilling the mentioned requirements. Furthermore, we provide + a large-scale naturalistic vehicle trajectory dataset from German + highways called highD. We evaluate the data in terms of + quantity, variety and contained scenarios. Our dataset consists + of 16.5 hours of measurements from six locations with 110 000 + vehicles, a total driven distance of 45 000 km and 5600 recorded + complete lane changes. +

+
+ +
+ Bibtex + + @inproceedings{highDdataset, + title={The highD Dataset: A Drone Dataset of Naturalistic Vehicle Trajectories on German Highways for Validation of Highly Automated Driving Systems}, + author={Krajewski, Robert and Bock, Julian and Kloeker, Laurent and Eckstein, Lutz}, + booktitle={2018 21st International Conference on Intelligent Transportation Systems (ITSC)}, + pages={2118-2125}, + year={2018}, + doi={10.1109/ITSC.2018.8569552} + } +
+
+> #### Dataset Overview
+> - Naturalistic trajectory dataset from six different recording locations
+> - In total 110 500 vehicles
+> - Road user classes: car, truck
+
+ highD.gif +
+ +*** + +### *[rounD](https://ieeexplore.ieee.org/document/9294728)*: The Roundabouts Drone Dataset + +
+ Abstract +

+ The development and validation of automated vehicles involves a large number of challenges to be overcome. + Due to the high complexity, many classic approaches quickly reach their limits and data-driven methods become necessary. + This creates an unavoidable need for trajectory datasets of road users in all relevant traffic scenarios. + As these trajectories should include naturalistic and diverse behavior, they have to be recorded in public traffic. + Roundabouts are particularly interesting because of the density of interaction between road users, which must be considered by an automated vehicle for behavior planning. + We present a new dataset of road user trajectories at roundabouts in Germany. + Using a camera-equipped drone, traffic at a total of three different roundabouts in Germany was recorded. + The tracks consisting of positions, headings, speeds, accelerations and classes of objects were extracted from recorded videos using deep neural networks. + The dataset contains a total of six hours of recordings with more than 13 746 road users including cars, vans, trucks, buses, pedestrians, bicycles and motorcycles. + In order to make the dataset as accessible as possible for tasks like scenario classification, road user behavior prediction or driver modeling, we provide source code for parsing and visualizing the dataset as well as maps of the recording sites. +

+
+ +
+ Bibtex + + @inproceedings{rounDdataset, + title={The rounD Dataset: A Drone Dataset of Road User Trajectories at Roundabouts in Germany}, + author={Krajewski, Robert and Moers, Tobias and Bock, Julian and Vater, Lennart and Eckstein, Lutz}, + booktitle={2020 IEEE 23rd International Conference on Intelligent Transportation Systems (ITSC)}, + pages={1-6}, + year={2020}, + doi={10.1109/ITSC45102.2020.9294728} + } +
+
+> #### Dataset Overview
+> - Naturalistic trajectory dataset from three different recording locations
+> - In total ~ 13 740 road users
+> - Road user classes: car, van, trailer, truck, bus, pedestrian, bicyclist, motorcyclist
+
+ rounD.gif +
+ +*** + +### *[inD](https://arxiv.org/abs/1911.07602)*: The Intersections Drone Dataset + +
+ Abstract +

+ Automated vehicles rely heavily on data-driven + methods, especially for complex urban environments. Large + datasets of real world measurement data in the form of road + user trajectories are crucial for several tasks like road user + prediction models or scenario-based safety validation. So far, + though, this demand is unmet as no public dataset of urban + road user trajectories is available in an appropriate size, quality + and variety. By contrast, the highway drone dataset (highD) has + recently shown that drones are an efficient method for acquiring + naturalistic road user trajectories. Compared to driving studies + or ground-level infrastructure sensors, one major advantage of + using a drone is the possibility to record naturalistic behavior, + as road users do not notice measurements taking place. Due to + the ideal viewing angle, an entire intersection scenario can be + measured with significantly less occlusion than with sensors at + ground level. Both the class and the trajectory of each road + user can be extracted from the video recordings with high + precision using state-of-the-art deep neural networks. Therefore, + we propose the creation of a comprehensive, large-scale urban + intersection dataset with naturalistic road user behavior using + camera-equipped drones as successor of the highD dataset. The + resulting dataset contains more than 11500 road users including + vehicles, bicyclists and pedestrians at intersections in Germany + and is called inD. The dataset consists of 10 hours of measurement + data from four intersections. +

+
+ +
+ Bibtex + + @inproceedings{inDdataset, + title={The inD Dataset: A Drone Dataset of Naturalistic Road User Trajectories at German Intersections}, + author={Bock, Julian and Krajewski, Robert and Moers, Tobias and Runde, Steffen and Vater, Lennart and Eckstein, Lutz}, + booktitle={2020 IEEE Intelligent Vehicles Symposium (IV)}, + pages={1929-1934}, + year={2020}, + doi={10.1109/IV47402.2020.9304839} + } +
+
+> #### Dataset Overview
+> - Naturalistic trajectory dataset from four different recording locations
+> - In total ~ 8200 vehicles and ~ 5300 vulnerable road users (VRUs)
+> - Road user classes: car, truck/bus, pedestrian, bicyclist
+
+ inD.gif +
+ + +
+ +# Related work +We have been working with the **D**rone datasets in several research projects, resulting in multiple published papers focused on behavior prediction. +If you're interested in learning more about our findings, please refer to the following publications: + +#### [Diffusion-Based Environment-Aware Trajectory Prediction](https://arxiv.org/abs/2403.11643) +- **Authors:** Theodor Westny, Björn Olofsson, and Erik Frisk +- **Published In:** Manuscript submitted for publication + +
+ Abstract +

+ The ability to predict the future trajectories of traffic participants is crucial for the safe and efficient operation of autonomous vehicles. + In this paper, a diffusion-based generative model for multi-agent trajectory prediction is proposed. + The model is capable of capturing the complex interactions between traffic participants and the environment, accurately learning the multimodal nature of the data. + The effectiveness of the approach is assessed on large-scale datasets of real-world traffic scenarios, showing that our model outperforms several well-established methods in terms of prediction accuracy. + By the incorporation of differential motion constraints on the model output, we illustrate that our model is capable of generating a diverse set of realistic future trajectories. + Through the use of an interaction-aware guidance signal, we further demonstrate that the model can be adapted to predict the behavior of less cooperative agents, emphasizing its practical applicability under uncertain traffic conditions. +

+
+ +
+ Bibtex + + @article{westny2024diffusion, + title={Diffusion-Based Environment-Aware Trajectory Prediction}, + author={Westny, Theodor and Olofsson, Bj{\"o}rn and Frisk, Erik}, + journal={arXiv preprint arXiv:2403.11643}, + year={2024} + } +
+ +#### [MTP-GO: Graph-Based Probabilistic Multi-Agent Trajectory Prediction with Neural ODEs](https://arxiv.org/abs/2302.00735) +- **Authors:** Theodor Westny, Joel Oskarsson, Björn Olofsson, and Erik Frisk +- **Published In:** 2023 IEEE Transactions on Intelligent Vehicles, Vol. 8, No. 9 + +
+ Abstract +

+ Enabling resilient autonomous motion planning requires robust predictions of surrounding road users' future behavior. + In response to this need and the associated challenges, we introduce our model titled MTP-GO. + The model encodes the scene using temporal graph neural networks to produce the inputs to an underlying motion model. + The motion model is implemented using neural ordinary differential equations where the state-transition functions are learned with the rest of the model. + Multimodal probabilistic predictions are obtained by combining the concept of mixture density networks and Kalman filtering. + The results illustrate the predictive capabilities of the proposed model across various data sets, outperforming several state-of-the-art methods on a number of metrics. +

+
+ +
+ Bibtex + + @article{westny2023mtp, + title="{MTP-GO}: Graph-Based Probabilistic Multi-Agent Trajectory Prediction with Neural {ODEs}", + author={Westny, Theodor and Oskarsson, Joel and Olofsson, Bj{\"o}rn and Frisk, Erik}, + journal={IEEE Transactions on Intelligent Vehicles}, + year={2023}, + volume={8}, + number={9}, + pages={4223-4236}, + doi={10.1109/TIV.2023.3282308}} + } +
+ + +#### [Evaluation of Differentially Constrained Motion Models for Graph-Based Trajectory Prediction](https://arxiv.org/abs/2304.05116) +- **Authors:** Theodor Westny, Joel Oskarsson, Björn Olofsson, and Erik Frisk +- **Published In:** In 2023 IEEE Intelligent Vehicles Symposium (IV) + +
+ Abstract +

+ Given their flexibility and encouraging performance, deep-learning models are becoming standard for motion prediction in autonomous driving. + However, with great flexibility comes a lack of interpretability and possible violations of physical constraints. + Accompanying these data-driven methods with differentially-constrained motion models to provide physically feasible trajectories is a promising future direction. + The foundation for this work is a previously introduced graph-neural-network-based model, MTP-GO. + The neural network learns to compute the inputs to an underlying motion model to provide physically feasible trajectories. + This research investigates the performance of various motion models in combination with numerical solvers for the prediction task. + The study shows that simpler models, such as low-order integrator models, are preferred over more complex, e.g., kinematic models, to achieve accurate predictions. + Further, the numerical solver can have a substantial impact on performance, advising against commonly used first-order methods like Euler forward. + Instead, a second-order method like Heun’s can greatly improve predictions. +

+
+ +
+ Bibtex + + @inproceedings{westny2023eval, + title={Evaluation of Differentially Constrained Motion Models for Graph-Based Trajectory Prediction}, + author={Westny, Theodor and Oskarsson, Joel and Olofsson, Bj{\"o}rn and Frisk, Erik}, + booktitle={IEEE Intelligent Vehicles Symposium (IV)}, + pages={}, + year={2023}, + doi={10.1109/IV55152.2023.10186615} + } +
+
+## Contributing
+We welcome contributions to the toolbox, and we encourage you to submit pull requests with new features, bug fixes, or improvements.
+Any form of collaboration is appreciated, and we are open to suggestions for new features or changes to the existing codebase.
+Please direct your inquiries to the authors of the paper.
+
+## Cite
+If you use the toolbox in your research, please consider citing the paper:
+
+```
+@article{westny2024dronalize,
+  title={A Preprocessing and Evaluation Toolbox for Trajectory Prediction Research on the Drone Datasets},
+  author={Westny, Theodor and Olofsson, Bj{\"o}rn and Frisk, Erik},
+  journal={arXiv preprint arXiv:2405.00604},
+  year={2024}
+}
+```
+
+Feel free to [email us](mailto:theodor.westny@liu.se) if you have any questions or notice any issues with the toolbox.
+If you have any suggestions for improvements or new features, we would be happy to hear from you.
+
+## License
+This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
diff --git a/arguments.py b/arguments.py
new file mode 100644
index 0000000..a58d1cc
--- /dev/null
+++ b/arguments.py
@@ -0,0 +1,73 @@
+# Copyright 2024, Theodor Westny. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from argparse import ArgumentParser, ArgumentTypeError + + +def str_to_bool(value: bool | str) -> bool: + """Used for boolean arguments in argparse; avoiding `store_true` and `store_false`.""" + true_vals = ("yes", "true", "t", "y", "1") + false_vals = ("no", "false", "f", "n", "0") + if isinstance(value, bool): + return value + if value.lower() in true_vals: + return True + if value.lower() in false_vals: + return False + raise ArgumentTypeError('Boolean value expected.') + + +parser = ArgumentParser(description='Dronalize learning arguments') + + +# Program arguments +parser.add_argument('--seed', type=int, default=42, + help='random seed (default: 42)') +parser.add_argument('--use-logger', type=str_to_bool, default=False, + const=True, nargs="?", + help='if logger should be used (default: False)') +parser.add_argument('--use-cuda', type=str_to_bool, default=False, + const=True, nargs="?", + help='if cuda exists and should be used (default: False)') +parser.add_argument('--num-workers', type=int, default=1, + help='number of workers in dataloader (default: 1)') +parser.add_argument('--pin-memory', type=str_to_bool, default=True, + const=True, nargs="?", + help='if pin memory should be used (default: True)') +parser.add_argument('--persistent-workers', type=str_to_bool, default=True, + const=True, nargs="?", + help='if persistent workers should be used (default: True)') +parser.add_argument('--store-model', type=str_to_bool, default=False, + const=True, nargs="?", + help='if checkpoints should be stored (default: False)') +parser.add_argument('--store-samples', type=str_to_bool, default=False, + const=True, nargs="?", + help='if samples should be stored (default: False)') +parser.add_argument('--overwrite', type=str_to_bool, default=False, + const=True, nargs="?", + help='overwrite if model exists (default: False)') +parser.add_argument('--add-name', type=str, default="", + help='additional string to add to save name') +parser.add_argument('--dry-run', type=str_to_bool, default=True, + const=True, nargs="?", + help='verify the code and the model (default: True)') +parser.add_argument('--small-ds', type=str_to_bool, default=False, + const=True, nargs="?", + help='Use tiny versions of dataset for fast testing (default: False)') +parser.add_argument('--config', type=str, default="example", + help='config file path for experiment (default: example)') +parser.add_argument('--root', type=str, default="", + help='root path for dataset (default: "")') + +args = parser.parse_args() diff --git a/base.py b/base.py new file mode 100644 index 0000000..7730dea --- /dev/null +++ b/base.py @@ -0,0 +1,110 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import lightning.pytorch as pl + +from torch import nn +from torch_geometric.data import HeteroData +from torch_geometric.nn import radius, knn_graph +from torch_geometric.utils import to_undirected + +from metrics import MinADE, MinFDE, MinAPDE, MissRate, CollisionRate + + +class LitModel(pl.LightningModule): + def __init__(self, model: nn.Module, config: dict, **kwargs) -> None: + super().__init__() + self.model = model + self.dataset = config["dataset"] + self.max_epochs = config["epochs"] + self.learning_rate = config["lr"] + + self.save_hyperparameters(ignore=['model']) + + self.min_ade = MinADE() + self.min_fde = MinFDE() + self.min_apde = MinAPDE() + self.mr = MissRate() + self.cr = CollisionRate() + + def post_process(self, data: HeteroData) -> HeteroData: + pos = data['agent']['inp_pos'][:, -1] + map_pos = data['map_point']['position'] + + agent_batch = data['agent']['batch'] + map_batch = data['map_point']['batch'] + + edge_index_a2a = knn_graph(x=pos, k=8, batch=agent_batch, loop=True) + edge_index_a2a = to_undirected(edge_index_a2a) + + edge_index_m2a = radius(x=pos, y=map_pos, r=20, batch_x=agent_batch, + batch_y=map_batch, max_num_neighbors=8) + + data['agent']['edge_index'] = edge_index_a2a + data['map', 'to', 'agent']['edge_index'] = edge_index_m2a + return data + + def forward(self, data: HeteroData) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + data = self.post_process(data) + valid_mask = data['agent']['valid_mask'] + trg = data['agent']['trg_pos'] + + pred = self.model(data) + + num_valid_steps = valid_mask.sum(-1) + + norm = torch.linalg.norm(pred - trg, dim=-1) + + masked_norm = norm * valid_mask + + scored_agents = num_valid_steps > 0 + + summed_loss = masked_norm[scored_agents].sum(-1) / num_valid_steps[scored_agents] + + loss = summed_loss.mean() + return loss, pred, trg + + def training_step(self, data: HeteroData) -> torch.Tensor: + loss, _, trg = self(data) + + self.log("train_loss", loss, on_step=False, on_epoch=True, + batch_size=trg.size(0), prog_bar=True) + return loss + + def validation_step(self, data: HeteroData) -> None: + ma_mask = data['agent']['ma_mask'] + ptr = data['agent']['ptr'] + + loss, pred, trg = self(data) + + self.min_ade.update(pred, trg, mask=ma_mask) + self.min_fde.update(pred, trg, mask=ma_mask) + self.min_apde.update(pred, trg, mask=ma_mask) + self.mr.update(pred, trg, mask=ma_mask) + self.cr.update(pred, trg, ptr, mask=ma_mask) + + metric_dict = {"val_loss": loss, + "val_min_ade": self.min_ade, + "val_min_fde": self.min_fde, + "val_min_apde": self.min_apde, + "val_mr": self.mr, + "val_cr": self.cr} + + self.log_dict(metric_dict, on_step=False, on_epoch=True, + batch_size=trg.size(0), prog_bar=True) + + def configure_optimizers(self) -> torch.optim.Optimizer: + optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate) + return optimizer diff --git a/configs/rounD/example.json b/configs/rounD/example.json new file mode 100644 index 0000000..61a35a4 --- /dev/null +++ b/configs/rounD/example.json @@ -0,0 +1,29 @@ +{ + "experiment_name": "Example setup", + "task": "motion forecasting", + "dataset": "rounD", + "model": { + "class": "Net", + "module": "models.prototype", + "num_inputs": 5, + "num_hidden": 32, + "num_outputs": 2, + "pred_hrz": 25 + }, + "litmodule": { + "class": "LitModel", + "module": "base" + }, + "datamodule": { + "class": "DroneDataModule", + "module": "datamodules.dataloader", + "batch_size": 128, + "root": "data/", + "name": "rounD" + }, + "training": { + "dataset": "rounD", + 
"epochs": 100, + "lr": 5e-4 + } +} \ No newline at end of file diff --git a/configs/synthD/ci.json b/configs/synthD/ci.json new file mode 100644 index 0000000..7bb67a6 --- /dev/null +++ b/configs/synthD/ci.json @@ -0,0 +1,29 @@ +{ + "experiment_name": "Example setup", + "task": "motion forecasting", + "dataset": "synthD", + "model": { + "class": "Net", + "module": "models.prototype", + "num_inputs": 5, + "num_hidden": 32, + "num_outputs": 2, + "pred_hrz": 25 + }, + "litmodule": { + "class": "LitModel", + "module": "base" + }, + "datamodule": { + "class": "DroneDataModule", + "module": "datamodules.dataloader", + "batch_size": 16, + "root": "data/", + "name": "synthD" + }, + "training": { + "dataset": "synthD", + "epochs": 100, + "lr": 5e-4 + } +} diff --git a/container/Dockerfile b/container/Dockerfile new file mode 100644 index 0000000..869222a --- /dev/null +++ b/container/Dockerfile @@ -0,0 +1,37 @@ +# Use an official CUDA base image from NVIDIA +FROM nvidia/cuda:12.1.0-base-ubuntu22.04 + +# Set environment variables +ENV PATH /opt/mambaforge/bin:$PATH +ENV PYTHONNOUSERSITE 1 + +# Run updates and install necessary packages +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + nano \ + wget \ + curl \ + ca-certificates && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Install Mambaforge +RUN curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-$(uname)-$(uname -m).sh" && \ + bash Mambaforge-$(uname)-$(uname -m).sh -b -p /opt/mambaforge && \ + rm Mambaforge-$(uname)-$(uname -m).sh + +# Copy your environment.yml into the Docker image +COPY environment.yml /opt/environment.yml + +# Create the Conda environment using Mamba +RUN /opt/mambaforge/bin/mamba env create -f /opt/environment.yml + +# Clean up conda packages to reduce the container size +RUN /opt/mambaforge/bin/mamba clean -a --yes + +# Make RUN commands use the new environment: +SHELL ["mamba", "run", "-n", "dronalize", "/bin/bash", "-c"] + +# The code to run when container is started +ENTRYPOINT ["mamba", "run", "--no-capture-output", "-n", "dronalize"] +CMD ["python", "--version"] diff --git a/container/apptainer.def b/container/apptainer.def new file mode 100644 index 0000000..119b3dd --- /dev/null +++ b/container/apptainer.def @@ -0,0 +1,59 @@ +Bootstrap: docker +From: ubuntu:22.04 + +%files + environment.yml /opt/environment.yml + +%environment +export PATH="/opt/mambaforge/bin:$PATH" +export PYTHONNOUSERSITE=1 + +%post + +# Install necessary system packages +apt-get update && apt-get install -y --no-install-recommends \ + git \ + nano \ + wget \ + curl \ + ca-certificates && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Install Mambaforge +cd /tmp +curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-$(uname)-$(uname -m).sh" +bash Mambaforge-$(uname)-$(uname -m).sh -fp /opt/mambaforge -b +rm Mambaforge*sh + +export PATH=/opt/mambaforge/bin:$PATH + +# Activate mamba and create the environment +. /opt/mambaforge/etc/profile.d/conda.sh +mamba env create -f /opt/environment.yml + +# Clean up conda packages to reduce the container size +mamba clean -a --yes + +%runscript +. /opt/mambaforge/etc/profile.d/conda.sh +conda activate dronalize +exec "$@" + +%help +This is a container for the Dronalize project. It contains all the necessary dependencies to run the project. 
+To run the project, you can use the following command: +``` +apptainer run dronalize.sif +``` +where `` is the command you want to run and `` are the arguments for the command. + +For example, to run the train script, you can use the following command: +``` +apptainer run dronalize.sif python train.py +``` + +To enable GPU support, you can use the following command: +``` +apptainer run --nv dronalize.sif python train.py +``` \ No newline at end of file diff --git a/container/environment.yml b/container/environment.yml new file mode 100644 index 0000000..d61e29f --- /dev/null +++ b/container/environment.yml @@ -0,0 +1,27 @@ +name: dronalize +channels: + - conda-forge + - pytorch + - nvidia + - pyg +dependencies: + - python=3.11 + - pip + - pytorch=2.2 + - torchvision + - pytorch-cuda=12.1 + - lightning + - wandb + - scikit-learn + - tqdm + - seaborn + - matplotlib + - pandas + - utm + - pyarrow + - pyg + - pytorch-cluster + - pytorch-scatter + - pip: + - osmium + - networkx diff --git a/datamodules/__init__.py b/datamodules/__init__.py new file mode 100644 index 0000000..7d0090b --- /dev/null +++ b/datamodules/__init__.py @@ -0,0 +1,3 @@ +from datamodules.dataloader import DroneDataModule +from datamodules.dataset import DroneDataset +from datamodules.transforms import CoordinateTransform, CoordinateShift diff --git a/datamodules/dataloader.py b/datamodules/dataloader.py new file mode 100644 index 0000000..4126d0c --- /dev/null +++ b/datamodules/dataloader.py @@ -0,0 +1,165 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional +from argparse import Namespace + +import torch +import numpy as np +import lightning.pytorch as pl +import matplotlib.pyplot as plt + +from matplotlib import colors +from matplotlib.collections import LineCollection + +from lightning.pytorch import LightningDataModule +from torch_geometric.utils import subgraph +from torch_geometric.data import Dataset +from torch_geometric.loader import DataLoader +from datamodules.dataset import DroneDataset +from datamodules.transforms import CoordinateShift + + +class DroneDataModule(LightningDataModule): + train: Dataset = None + val: Dataset = None + test: Dataset = None + + def __init__(self, + config: dict, + args: Namespace) -> None: + super().__init__() + self.root = config["root"] + self.dataset = config["name"] + self.batch_size = config["batch_size"] + self.transform = CoordinateShift() + + self.small_data = args.small_ds + self.num_workers = args.num_workers + self.pin_memory = args.pin_memory + self.persistent_workers = args.persistent_workers + + def setup(self, stage: Optional[str] = None) -> None: + self.train = DroneDataset(root=self.root, dataset=self.dataset, split='train', + transform=self.transform, small_data=self.small_data) + self.val = DroneDataset(root=self.root, dataset=self.dataset, split='val', + transform=self.transform, small_data=self.small_data) + self.test = DroneDataset(root=self.root, dataset=self.dataset, split='test', + transform=self.transform, small_data=self.small_data) + + def train_dataloader(self) -> DataLoader: + return DataLoader(self.train, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers) + + def val_dataloader(self) -> DataLoader: + return DataLoader(self.val, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers) + + def test_dataloader(self) -> DataLoader: + return DataLoader(self.test, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers) + + +if __name__ == "__main__": + pl.seed_everything(42) + + + def get_segments(pos, color): + linefade = colors.to_rgb(color) + (0.0,) + myfade = colors.LinearSegmentedColormap.from_list('my', [linefade, color]) + alphas = np.clip(np.exp(np.linspace(0, 1, pos.shape[0] - 1)) - 0.6, 0, 1) + tmp = pos[:, :2][:, None, :] + segments = np.hstack((tmp[:-1], tmp[1:])) + return segments, alphas, myfade + + + config = {'root': '../data', 'name': 'rounD', 'batch_size': 32} + args = Namespace(small_ds=False, num_workers=0, pin_memory=False, persistent_workers=False) + dm = DroneDataModule(config, args) + dm.setup() + + gen = iter(dm.train_dataloader()) + data = next(gen) + + BATCH_IDX = 24 # 24 is used to create Fig. 
1 in the paper + + batch = data['agent']['batch'] == BATCH_IDX + pos = data['agent']['inp_pos'][batch] + heading = data['agent']['inp_yaw'][batch] + pos_eq_zero = pos == 0 + pos_eq_zero[0] = False + pos[pos_eq_zero] = float("nan") + + gt = data['agent']['trg_pos'][batch] + gt[gt == 0] = float("nan") + + valid_mask = data['agent']['valid_mask'][batch] + ma_mask = data['agent']['ma_mask'][batch] + ma_idx = torch.where(ma_mask[:, 0])[0] + + map_batch = data['map_point']['batch'] == BATCH_IDX + map_pos = data['map_point']['position'][map_batch] + map_type = data['map_point']['type'][map_batch] + map_edge_index = data['map_point', 'to', 'map_point']['edge_index'] + map_edge_type = data['map_point', 'to', 'map_point']['type'] + + map_edge_index, map_edge_type = subgraph(map_batch, map_edge_index, + map_edge_type, relabel_nodes=True) + + # + for i in range(map_edge_index.shape[1]): + if map_edge_type[i] == 2: + edge = map_edge_index[:, i] + plt.plot(map_pos[edge, 0], map_pos[edge, 1], color='gray', lw=1, + zorder=1, alpha=.9, linestyle='solid') + + elif map_edge_type[i] == 1: + edge = map_edge_index[:, i] + plt.plot(map_pos[edge, 0], map_pos[edge, 1], color='darkgray', lw=0.5, + zorder=0, alpha=.6, linestyle=(0, (5, 10))) + + ax = plt.gca() + + COLOR = 'tab:red' + for i in range(pos.shape[0]): + if i == 0: + COLOR = 'tab:blue' + elif i in ma_idx: + COLOR = 'tab:green' + else: + COLOR = 'tab:red' + + segments, alphas, myfade = get_segments(pos[i], COLOR) + lc = LineCollection(segments, array=alphas, cmap=myfade, lw=5, zorder=0) + line = ax.add_collection(lc) + plt.plot(gt[i, :, 0], gt[i, :, 1], c=COLOR, marker='.', markersize=10, lw=2, alpha=0.3) + + ax.set_aspect('equal') + ax.set_xlim(-50, 200) + ax.set_ylim(-30, 35) + plt.axis('off') + plt.tight_layout() + plt.show() + + print(data) diff --git a/datamodules/dataset.py b/datamodules/dataset.py new file mode 100644 index 0000000..8f0c3d0 --- /dev/null +++ b/datamodules/dataset.py @@ -0,0 +1,53 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
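+
+# DroneDataset below simply indexes the pickled HeteroData samples stored under
+# <root>/<dataset>/<split> and loads one file per get() call. The small_data flag
+# restricts the dataset to the first 100 files, which keeps debugging runs short.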
+ +import os +import pickle +from typing import Optional, Callable +from torch_geometric.data import Dataset +from torch_geometric.data import HeteroData + + +class DroneDataset(Dataset): + def __init__(self, + root: str, + dataset: str, + split: str, + transform: Optional[Callable] = None, + small_data: bool = False) -> None: + super().__init__(root=root, transform=transform, pre_transform=None, pre_filter=None) + assert split in ['train', 'val', 'test'], 'Split must be one of [train, val, test]' + + self.root = root + self.dataset = dataset + self.split = split + self.path = os.path.join(self.root, self.dataset, self.split) + self.files = os.listdir(self.path) + + if small_data: + self.files = self.files[:100] + + self._num_samples = len(self.files) + + def len(self) -> int: + return self._num_samples + + def get(self, idx: int) -> HeteroData: + with open(os.path.join(self.path, self.files[idx]), 'rb') as f: + return HeteroData(pickle.load(f)) + + +if __name__ == "__main__": + ds = DroneDataset(root='data', dataset='highD', split='test') + print(len(ds)) diff --git a/datamodules/transforms.py b/datamodules/transforms.py new file mode 100644 index 0000000..8143e13 --- /dev/null +++ b/datamodules/transforms.py @@ -0,0 +1,102 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch_geometric.data import HeteroData +from torch_geometric.transforms import BaseTransform + + +class CoordinateTransform(BaseTransform): + """ + Transform the coordinates of the agents and map points + to be relative to the last position of the TA. 
+ """ + + def __call__(self, data: HeteroData) -> HeteroData: + hist_pos = data['agent']['inp_pos'] + hist_vel = data['agent']['inp_vel'] + hist_ori = data['agent']['inp_yaw'] + + fut_pos = data['agent']['trg_pos'] + fut_vel = data['agent']['trg_vel'] + fut_ori = data['agent']['trg_yaw'] + + map_pos = data['map_point']['position'] + + ta_index = data['agent']['ta_index'] + ta_pos = hist_pos[ta_index] + ta_ori = hist_ori[ta_index] + + # Get the last observed states + origin = ta_pos[-1].unsqueeze(0) + ori = ta_ori[-1] + + rot_mat_t = torch.tensor([[torch.cos(ori), -torch.sin(ori)], + [torch.sin(ori), torch.cos(ori)]]) + # rot_mat_t = torch.tensor([[torch.cos(ori), torch.sin(ori)], + # [-torch.sin(ori), torch.cos(ori)]]) + + hist_mask = hist_pos != 0 + fut_mask = fut_pos != 0 + + n_hist_pos = (hist_pos - origin) @ rot_mat_t * hist_mask + n_hist_vel = hist_vel @ rot_mat_t * hist_mask + n_hist_ori = torch.atan2(torch.sin(hist_ori - ori), torch.cos(hist_ori - ori)) + + n_fut_pos = (fut_pos - origin) @ rot_mat_t * fut_mask + n_fut_vel = fut_vel @ rot_mat_t * fut_mask + n_fut_ori = torch.atan2(torch.sin(fut_ori - ori), torch.cos(fut_ori - ori)) + + n_map_pos = (map_pos - origin) @ rot_mat_t + + data['agent']['inp_pos'] = n_hist_pos + data['agent']['inp_vel'] = n_hist_vel + data['agent']['inp_yaw'] = n_hist_ori + + data['agent']['trg_pos'] = n_fut_pos + data['agent']['trg_vel'] = n_fut_vel + data['agent']['trg_yaw'] = n_fut_ori + + data['map_point']['position'] = n_map_pos + + return data + + +class CoordinateShift(BaseTransform): + """ + Shifts the origin of the global coordinate system to be in the last position of the TA. + """ + + def __call__(self, data: HeteroData) -> HeteroData: + hist_pos = data['agent']['inp_pos'] + fut_pos = data['agent']['trg_pos'] + map_pos = data['map_point']['position'] + + ta_index = data['agent']['ta_index'] + ta_pos = hist_pos[ta_index] + + origin = ta_pos[-1].unsqueeze(0) + + hist_mask = hist_pos != 0 + fut_mask = fut_pos != 0 + + n_hist_pos = (hist_pos - origin) * hist_mask + n_fut_pos = (fut_pos - origin) * fut_mask + n_map_pos = map_pos - origin + + data['agent']['inp_pos'] = n_hist_pos + data['agent']['trg_pos'] = n_fut_pos + data['map_point']['position'] = n_map_pos + + return data diff --git a/metrics/__init__.py b/metrics/__init__.py new file mode 100644 index 0000000..ed0e636 --- /dev/null +++ b/metrics/__init__.py @@ -0,0 +1,8 @@ +from metrics.min_ade import MinADE +from metrics.min_fde import MinFDE +from metrics.miss_rate import MissRate +from metrics.min_brier import MinBrier +from metrics.log_likelihood import NegativeLogLikelihood +from metrics.min_apde import MinAPDE +from metrics.collision_rate import CollisionRate +from metrics.utils import filter_prediction diff --git a/metrics/collision_rate.py b/metrics/collision_rate.py new file mode 100644 index 0000000..1e740c0 --- /dev/null +++ b/metrics/collision_rate.py @@ -0,0 +1,81 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
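+
+# CollisionRate below counts, for every scenario in the batch (delimited by ptr), how
+# many agent pairs come closer than collision_threshold at each predicted time step,
+# and normalizes by the total number of pair-time combinations T * C(n, 2). The zero
+# self-distances on the cdist diagonal are excluded, and each symmetric pair is
+# counted once by halving the sum.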
+ +from math import comb +from typing import Optional +import torch +from torchmetrics import Metric +from metrics.utils import filter_prediction + + +class CollisionRate(Metric): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + + def update(self, + pred: torch.Tensor, + trg: torch.Tensor, + ptr: torch.Tensor, + prob: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + best_idx: Optional[torch.Tensor] = None, + collision_criterion: str = 'FDE', + collision_threshold: float = 1.0, + mode_first: bool = False) -> None: + """ + Update the metric state. + :param: pred: The predicted trajectory. (N, T, M, 2) or (N, T, 2) + :param: trg: The ground-truth target trajectory. (N, T, 2) + :param: ptr: The pointer tensor to indicate which agents are in the same scene. (batch_size) + :param: prob: The probability of the predictions. (N, M) + :param: mask: The mask for valid positions. (N, T) + :param: best_idx: The index of the best prediction. (N,) (to avoid recomputing it) + :param: collision_criterion: Either 'FDE', 'ADE', or 'MAP'. + :param: collision_threshold: The collision threshold in meters. + :param: mode_first: Whether the mode is the first dimension. (default: False) + """ + + assert pred.dim() > 2, "The prediction tensor must have at least 3 dimensions." + + if pred.dim() == 4: + pred, _ = filter_prediction(pred, trg, mask, prob, collision_criterion, + best_idx, mode_first=mode_first) + + seq_len = pred.size(1) + + # Compute the collision rate for each scenario + for i in range(len(ptr) - 1): + ptr_from = ptr[i] + ptr_to = ptr[i + 1] + + # Get the scenario + scenario = pred[ptr_from:ptr_to] + n = scenario.size(0) + + # Compute the number of possible collisions + self.count += seq_len * comb(n, 2) # T * (n * (n - 1)) // 2 + for t in range(seq_len): + dists = torch.cdist(scenario[:, t], scenario[:, t], p=2) # (n, n) + + # Find the collisions and filter out the self-collisions + collisions = (dists < collision_threshold) & (dists != 0.0) + self.sum += collisions.sum().item() / 2 + + def compute(self) -> torch.Tensor: + """ + Compute the final metric. + """ + return self.sum / self.count diff --git a/metrics/log_likelihood.py b/metrics/log_likelihood.py new file mode 100644 index 0000000..358b5e5 --- /dev/null +++ b/metrics/log_likelihood.py @@ -0,0 +1,123 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
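+
+# NegativeLogLikelihood below evaluates the targets under a torch.distributions object
+# built from the predicted means and scales (multivariate normal, normal, or Laplace).
+# Multi-modal predictions are combined into a MixtureSameFamily with Categorical mode
+# weights (probabilities or logits), and the per-step NLL is averaged over each scored
+# agent's valid time steps before being accumulated.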
+ +import warnings +from typing import Optional, Any +import torch +import torch.distributions as tdist +from torchmetrics import Metric + + +class NegativeLogLikelihood(Metric): + dist: Any + + def __init__(self, + dist: str = "mvn", + **kwargs) -> None: + super().__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + self.dist = self.get_distribution_initializer(dist) + + @staticmethod + def get_distribution_initializer(dist_name: str) -> Any: + if dist_name == "mvn": + return tdist.MultivariateNormal + if dist_name == "normal": + return tdist.Normal + if dist_name == "laplace": + return tdist.Laplace + raise ValueError(f"Invalid distribution name: {dist_name}") + + @staticmethod + def handle_mode_first(pred: torch.Tensor, scale: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + if pred.dim() == 4: + return pred.transpose(1, 2), scale.transpose(1, 2) + warnings.warn("'mode_first' is set to True but the predictions" + " are not multi-modal. Ignoring the flag.") + return pred, scale + + def create_distribution(self, pred, scale, is_tril): + if self.dist.__name__ == "MultivariateNormal": + assert scale.size(-1) == scale.size(-2), "Covariance matrix must be square." + if not is_tril: + scale = torch.linalg.cholesky(scale) + return self.dist(loc=pred, scale_tril=scale) + return self.dist(loc=pred, scale=scale) + + def update(self, + pred: torch.Tensor, + trg: torch.Tensor, + scale: torch.Tensor, + prob: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + logits: bool = False, + is_tril: bool = False, + mode_first: bool = False) -> None: + """ + Update the metric state. + :param: pred: The predicted trajectory. (N, T, M, 2) or (N, T, 2) + :param: trg: The ground-truth target trajectory. (N, T, 2) + :param: scale: The scale of the predictions. (N, T, M, 2, (2)) or (N, T, 2, (2)) + :param: prob: The probability of the predictions. (N, M) + :param: mask: The mask for valid positions. (N, T) + :param: logits: Whether the probabilities are logits. + :param: is_tril: Whether the scale is a lower triangular matrix. + :param: mode_first: Whether the mode is the first dimension. 
+ """ + + if mode_first: + # (N, M, T, 2) -> (N, T, M, 2) + pred, scale = self.handle_mode_first(pred, scale) + + batch_size, seq_len = pred.size()[:2] + + distribution = self.create_distribution(pred, scale, is_tril) + + if pred.dim() == 4: + if prob is None: + prob = torch.ones(batch_size, pred.shape[2], device=pred.device) / pred.shape[2] + if logits: + prob *= 0.0 + prob = prob.unsqueeze(1).expand(-1, seq_len, -1) # (N, T, M) + + mix = tdist.Categorical(logits=prob) if logits else tdist.Categorical(probs=prob) + if self.dist.__name__ != "MultivariateNormal": + distribution = tdist.Independent(distribution, 1) + distribution = tdist.MixtureSameFamily(mix, distribution) + + # Compute the negative log-likelihood + neg_log_prob = distribution.log_prob(trg).neg() # (N, T) + + if mask is not None: + neg_log_prob = neg_log_prob * mask + valid_time_steps = mask.sum(dim=-1) + scored_agents = valid_time_steps > 0 + neg_log_prob = neg_log_prob[scored_agents] + valid_time_steps = valid_time_steps[scored_agents] + batch_size = int(scored_agents.sum().item()) + else: + valid_time_steps = torch.ones_like(neg_log_prob).sum(-1) # (N,) + + nll = neg_log_prob.sum(-1) / valid_time_steps # (N,) + + self.sum += nll.sum() + self.count += batch_size + + def compute(self) -> torch.Tensor: + """ + Compute the final metric. + """ + return self.sum / self.count diff --git a/metrics/min_ade.py b/metrics/min_ade.py new file mode 100644 index 0000000..74609f6 --- /dev/null +++ b/metrics/min_ade.py @@ -0,0 +1,71 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +import torch +from torchmetrics import Metric +from metrics.utils import filter_prediction + + +class MinADE(Metric): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + + def update(self, + pred: torch.Tensor, + trg: torch.Tensor, + prob: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + best_idx: Optional[torch.Tensor] = None, + min_criterion: str = 'FDE', + mode_first: bool = False) -> None: + """ + Update the metric state. + :param: pred: The predicted trajectory. (N, T, M, 2) or (N, T, 2) + :param: trg: The ground-truth target trajectory. (N, T, 2) + :param: prob: The probability of the predictions. (N, M) + :param: mask: The mask for valid positions. (N, T) + :param: best_idx: The index of the best prediction. (N,) (to avoid recomputing it) + :param: min_criterion: Either 'FDE', 'ADE', or 'MAP'. + :param: mode_first: Whether the mode is the first dimension. 
(default: False) + """ + + if pred.dim() == 4: + pred, _ = filter_prediction(pred, trg, mask, prob, min_criterion, + best_idx, mode_first=mode_first) + + batch_size = pred.size(0) + + norm = torch.linalg.norm(pred - trg, dim=-1) # (N, T) + if mask is not None: + num_valid_steps = mask.sum(dim=-1) # (N,) + scored_agents = num_valid_steps > 0 + norm = norm * mask # (N, T) + norm = norm[scored_agents] + num_valid_steps = num_valid_steps[scored_agents] + batch_size = int(scored_agents.sum().item()) + else: + num_valid_steps = torch.ones_like(norm).sum(dim=-1) # (N,) + + ade = norm.sum(dim=-1) / num_valid_steps # (N,) + self.sum += ade.sum() + self.count += batch_size + + def compute(self) -> torch.Tensor: + """ + Compute the final metric. + """ + return self.sum / self.count diff --git a/metrics/min_apde.py b/metrics/min_apde.py new file mode 100644 index 0000000..1e2e0cf --- /dev/null +++ b/metrics/min_apde.py @@ -0,0 +1,79 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +import torch +from torchmetrics import Metric +from metrics.utils import filter_prediction + + +class MinAPDE(Metric): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + + def update(self, + pred: torch.Tensor, + trg: torch.Tensor, + prob: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + best_idx: Optional[torch.Tensor] = None, + min_criterion: str = 'FDE', + mode_first: bool = False) -> None: + """ + Update the metric state. + :param: pred: The predicted trajectory. (N, T, M, 2) or (N, T, 2) + :param: trg: The ground-truth target trajectory. (N, T, 2) + :param: prob: The probability of the predictions. (N, M) + :param: mask: The mask for valid positions. (N, T) + :param: best_idx: The index of the best prediction. (N,) (to avoid recomputing it) + :param: min_criterion: Either 'FDE', 'ADE', or 'MAP'. + :param: mode_first: Whether the mode is the first dimension. 
(default: False) + """ + + if pred.dim() == 4: + pred, _ = filter_prediction(pred, trg, mask, prob, min_criterion, + best_idx, mode_first=mode_first) + + batch_size, seq_len = pred.size()[:2] + + cdist = torch.cdist(pred, trg, p=2) # (N, T, T) + + if mask is not None: + mask_exp = (mask.unsqueeze(1).repeat(1, seq_len, 1) & + mask.unsqueeze(-1).repeat(1, 1, seq_len)) # (N, T, T) + cdist[~mask_exp] = float('1e9') + path_dist, _ = cdist.min(dim=-1) # (N, T) + num_valid_steps = mask.sum(dim=-1) + path_dist = path_dist * mask + + scored_agents = num_valid_steps > 0 + path_dist = path_dist[scored_agents] + num_valid_steps = num_valid_steps[scored_agents] + batch_size = int(scored_agents.sum().item()) + else: + path_dist, _ = cdist.min(dim=-1) # (N, T) + num_valid_steps = torch.ones_like(path_dist).sum(-1) # (N,) + + apde = path_dist.sum(-1) / num_valid_steps # (N,) + + self.sum += apde.sum() + self.count += batch_size + + def compute(self) -> torch.Tensor: + """ + Compute the final metric. + """ + return self.sum / self.count diff --git a/metrics/min_brier.py b/metrics/min_brier.py new file mode 100644 index 0000000..a526985 --- /dev/null +++ b/metrics/min_brier.py @@ -0,0 +1,86 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +import torch +from torchmetrics import Metric +from metrics.utils import filter_prediction + + +class MinBrier(Metric): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + + def update(self, + pred: torch.Tensor, + trg: torch.Tensor, + prob: torch.Tensor, + mask: Optional[torch.Tensor] = None, + best_idx: Optional[torch.Tensor] = None, + logit: bool = False, + min_criterion: str = 'FDE', + mode_first: bool = False) -> None: + """ + Update the metric state. + :param: pred: The predicted trajectory. (N, T, M, 2) + :param: trg: The ground-truth target trajectory. (N, T, 2) + :param: prob: The probability of the predictions. (N, M) + :param: mask: The mask for valid positions. (N, T) + :param: best_idx: The index of the best prediction. (N,) (to avoid recomputing it) + :param: logit: Whether the probabilities are logits. + :param: min_criterion: Either 'FDE', 'ADE', or 'MAP'. + :param: mode_first: Whether the mode is the first dimension. (default: False) + """ + assert prob is not None, ("Probabilistic criterion requires" + " the probability of the predictions.") + assert pred.dim() == 4, "The predictions must be multi-modal." 
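+
+        # The remainder follows a Brier-style score: select the best mode, take the
+        # probability assigned to it, and weight its final displacement error by
+        # (1 - p), so low-confidence selections are penalized.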
+ + pred, best_idx = filter_prediction(pred, trg, mask, prob, min_criterion, + best_idx, mode_first=mode_first) + + batch_size, seq_len = pred.size()[:2] + + prob = prob[torch.arange(batch_size), best_idx] # (N,) + + if mask is not None: + mask_reversed = 1 * mask.flip(dims=[-1]) + last_idx = seq_len - 1 - mask_reversed.argmax(dim=-1) + + pred = pred[torch.arange(batch_size), last_idx] # (N, 2) + trg = trg[torch.arange(batch_size), last_idx] # (N, 2) + + scored_agents = mask.sum(dim=-1) > 0 + pred = pred[scored_agents] + trg = trg[scored_agents] + prob = prob[scored_agents] + batch_size = int(scored_agents.sum().item()) + else: + pred = pred[:, -1] # (N, 2) + trg = trg[:, -1] # (N, 2) + + if logit: + prob = torch.sigmoid(prob) + + brier = (1.0 - prob) * torch.linalg.norm(pred - trg, dim=-1) # (N,) + + self.sum += brier.sum() + self.count += batch_size + + def compute(self) -> torch.Tensor: + """ + Compute the final metric. + """ + return self.sum / self.count diff --git a/metrics/min_fde.py b/metrics/min_fde.py new file mode 100644 index 0000000..fc31c84 --- /dev/null +++ b/metrics/min_fde.py @@ -0,0 +1,76 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +import torch +from torchmetrics import Metric +from metrics.utils import filter_prediction + + +class MinFDE(Metric): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + + def update(self, + pred: torch.Tensor, + trg: torch.Tensor, + prob: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + best_idx: Optional[torch.Tensor] = None, + min_criterion: str = 'FDE', + mode_first: bool = False) -> None: + """ + Update the metric state. + :param: pred: The predicted trajectory. (N, T, M, 2) or (N, T, 2) + :param: trg: The ground-truth target trajectory. (N, T, 2) + :param: prob: The probability of the predictions. (N, M) + :param: mask: The mask for valid positions. (N, T) + :param: best_idx: The index of the best prediction. (N,) (to avoid recomputing it) + :param: min_criterion: Either 'FDE', 'ADE', or 'MAP'. + :param: mode_first: Whether the mode is the first dimension. 
(default: False) + """ + + if pred.dim() == 4: + pred, _ = filter_prediction(pred, trg, mask, prob, min_criterion, + best_idx, mode_first=mode_first) + + batch_size, seq_len = pred.size()[:2] + + if mask is not None: + mask_reversed = 1 * mask.flip(dims=[-1]) + last_idx = seq_len - 1 - mask_reversed.argmax(dim=-1) + + pred = pred[torch.arange(batch_size), last_idx] # (N, 2) + trg = trg[torch.arange(batch_size), last_idx] # (N, 2) + + scored_agents = mask.sum(dim=-1) > 0 + pred = pred[scored_agents] + trg = trg[scored_agents] + batch_size = int(scored_agents.sum().item()) + else: + pred = pred[:, -1] # (N, 2) + trg = trg[:, -1] # (N, 2) + + fde = torch.linalg.norm(pred - trg, dim=-1) # (N,) + + self.sum += fde.sum() + self.count += batch_size + + def compute(self) -> torch.Tensor: + """ + Compute the final metric. + """ + return self.sum / self.count diff --git a/metrics/miss_rate.py b/metrics/miss_rate.py new file mode 100644 index 0000000..0b92b00 --- /dev/null +++ b/metrics/miss_rate.py @@ -0,0 +1,80 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +import torch +from torchmetrics import Metric +from metrics.utils import filter_prediction + + +class MissRate(Metric): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') + + def update(self, + pred: torch.Tensor, + trg: torch.Tensor, + mask: Optional[torch.Tensor] = None, + prob: Optional[torch.Tensor] = None, + best_idx: Optional[torch.Tensor] = None, + miss_criterion: str = 'FDE', + miss_threshold: float = 2.0, + mode_first: bool = False) -> None: + """ + Update the metric state. + :param: pred: The predicted trajectory. (N, T, M, 2) or (N, T, 2) + :param: trg: The ground-truth target trajectory. (N, T, 2) + :param: prob: The probability of the predictions. (N, M) + :param: mask: The mask for valid positions. (N, T) + :param: best_idx: The index of the best prediction. (N,) (to avoid recomputing it) + :param: min_criterion: Either 'FDE', 'ADE', or 'MAP'. + :param: miss_threshold: The threshold for a missed prediction. (default: 2.0) + :param: mode_first: Whether the mode is the first dimension. 
(default: False) + """ + + if pred.dim() == 4: + pred, _ = filter_prediction(pred, trg, mask, prob, miss_criterion, + best_idx, mode_first=mode_first) + + batch_size, seq_len = pred.size()[:2] + + if mask is not None: + mask_reversed = 1 * mask.flip(dims=[-1]) + last_idx = seq_len - 1 - mask_reversed.argmax(dim=-1) + + pred = pred[torch.arange(batch_size), last_idx] # (N, 2) + trg = trg[torch.arange(batch_size), last_idx] # (N, 2) + + scored_agents = mask.sum(dim=-1) > 0 + pred = pred[scored_agents] + trg = trg[scored_agents] + batch_size = int(scored_agents.sum().item()) + else: + pred = pred[:, -1] # (N, 2) + trg = trg[:, -1] # (N, 2) + + norm = torch.linalg.norm(pred - trg, dim=-1) # (N,) + + mr = norm > miss_threshold # (N,) + + self.sum += mr.sum() + self.count += batch_size + + def compute(self) -> torch.Tensor: + """ + Compute the final metric. + """ + return self.sum / self.count diff --git a/metrics/unittest/__init__.py b/metrics/unittest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/metrics/unittest/test_ade.py b/metrics/unittest/test_ade.py new file mode 100644 index 0000000..ba51ca0 --- /dev/null +++ b/metrics/unittest/test_ade.py @@ -0,0 +1,26 @@ +import torch +from metrics.min_ade import MinADE + + +def test_min_ade_multimodal() -> None: + batch_size, seq_len, num_modes, num_dims = 32, 25, 6, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + mask = torch.randint(0, 2, (batch_size, seq_len)) + + min_ade = MinADE() + for msk in [None, mask]: + min_ade.update(pred, trg, msk) + min_ade.compute() + + +def test_min_ade_unimodal() -> None: + batch_size, seq_len, num_dims = 32, 25, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_dims) + mask = torch.randint(0, 2, (batch_size, seq_len)) + + min_ade = MinADE() + for msk in [None, mask]: + min_ade.update(pred, trg, msk) + min_ade.compute() diff --git a/metrics/unittest/test_apde.py b/metrics/unittest/test_apde.py new file mode 100644 index 0000000..d8fcb4a --- /dev/null +++ b/metrics/unittest/test_apde.py @@ -0,0 +1,26 @@ +import torch +from metrics.min_apde import MinAPDE + + +def test_min_apde_multimodal() -> None: + batch_size, seq_len, num_modes, num_dims = 32, 25, 6, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + mask = torch.randint(0, 2, (batch_size, seq_len)) + + min_apde = MinAPDE() + for msk in [None, mask]: + min_apde.update(pred, trg, msk) + min_apde.compute() + + +def test_min_apde_unimodal() -> None: + batch_size, seq_len, num_dims = 32, 25, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_dims) + mask = torch.randint(0, 2, (batch_size, seq_len)) + + min_apde = MinAPDE() + for msk in [None, mask]: + min_apde.update(pred, trg, msk) + min_apde.compute() diff --git a/metrics/unittest/test_brier.py b/metrics/unittest/test_brier.py new file mode 100644 index 0000000..829a61f --- /dev/null +++ b/metrics/unittest/test_brier.py @@ -0,0 +1,39 @@ +import pytest +import torch +from metrics.min_brier import MinBrier + + +def test_min_brier_multimodal() -> None: + batch_size, seq_len, num_modes, num_dims = 32, 25, 6, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + prob = torch.randn(batch_size, num_modes) + mask = torch.randint(0, 2, (batch_size, seq_len)) + + min_brier = MinBrier() + for msk in [None, mask]: 
+ min_brier.update(pred, trg, prob, msk) + min_brier.compute() + + +def test_min_brier_with_none_probability(): + batch_size, seq_len, num_dims = 32, 25, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_dims) + prob = None + + min_brier = MinBrier() + with pytest.raises(AssertionError, match="Probabilistic criterion requires" + " the probability of the predictions."): + min_brier.update(pred, trg, prob) + + +def test_min_brier_with_unimodal_prediction(): + batch_size, seq_len, num_modes, num_dims = 32, 25, 6, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_dims) # Assuming this should be multi-modal + prob = torch.randn(batch_size, num_modes) + + min_brier = MinBrier() + with pytest.raises(AssertionError, match="The predictions must be multi-modal."): + min_brier.update(pred, trg, prob) diff --git a/metrics/unittest/test_collision_rate.py b/metrics/unittest/test_collision_rate.py new file mode 100644 index 0000000..2d2a216 --- /dev/null +++ b/metrics/unittest/test_collision_rate.py @@ -0,0 +1,45 @@ +import pytest +import torch +from metrics.collision_rate import CollisionRate + + +def test_collision_rate_multimodal() -> None: + n_scenarios = 10 + batch_size, seq_len, num_modes, num_dims = 32, 25, 6, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + + ptr = torch.randint(1, batch_size, (n_scenarios - 1,)).unique().sort()[0] + ptr = torch.cat([torch.tensor([0]), ptr, torch.tensor([batch_size])]) + + cr = CollisionRate() + cr.update(pred, trg, ptr) + cr.compute() + + +def test_collision_rate_unimodal() -> None: + n_scenarios = 10 + batch_size, seq_len, num_dims = 32, 25, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_dims) + + ptr = torch.randint(1, batch_size, (n_scenarios - 1,)).unique().sort()[0] + ptr = torch.cat([torch.tensor([0]), ptr, torch.tensor([batch_size])]) + + cr = CollisionRate() + cr.update(pred, trg, ptr) + cr.compute() + + +def test_collision_rate_dimension() -> None: + n_scenarios = 10 + batch_size, seq_len, num_dims = 32, 25, 2 + trg = torch.randn(seq_len, num_dims) + pred = torch.randn(seq_len, num_dims) + + ptr = torch.randint(1, batch_size, (n_scenarios - 1,)).unique().sort()[0] + ptr = torch.cat([torch.tensor([0]), ptr, torch.tensor([batch_size])]) + + cr = CollisionRate() + with pytest.raises(AssertionError): + cr.update(pred, trg, ptr) diff --git a/metrics/unittest/test_fde.py b/metrics/unittest/test_fde.py new file mode 100644 index 0000000..73db9ee --- /dev/null +++ b/metrics/unittest/test_fde.py @@ -0,0 +1,26 @@ +import torch +from metrics.min_fde import MinFDE + + +def test_min_fde_multimodal() -> None: + batch_size, seq_len, num_modes, num_dims = 32, 25, 6, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + mask = torch.randint(0, 2, (batch_size, seq_len)) + + min_fde = MinFDE() + for msk in [None, mask]: + min_fde.update(pred, trg, msk) + min_fde.compute() + + +def test_min_fde_unimodal() -> None: + batch_size, seq_len, num_dims = 32, 25, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_dims) + mask = torch.randint(0, 2, (batch_size, seq_len)) + + min_fde = MinFDE() + for msk in [None, mask]: + min_fde.update(pred, trg, msk) + min_fde.compute() diff --git a/metrics/unittest/test_mr.py b/metrics/unittest/test_mr.py new file mode 
100644 index 0000000..dc91369 --- /dev/null +++ b/metrics/unittest/test_mr.py @@ -0,0 +1,26 @@ +import torch +from metrics.miss_rate import MissRate + + +def test_miss_rate_multimodal() -> None: + batch_size, seq_len, num_modes, num_dims = 32, 25, 6, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + mask = torch.randint(0, 2, (batch_size, seq_len)) + + mr = MissRate() + for msk in [None, mask]: + mr.update(pred, trg, msk) + mr.compute() + + +def test_miss_rate_unimodal() -> None: + batch_size, seq_len, num_dims = 32, 25, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + pred = torch.randn(batch_size, seq_len, num_dims) + mask = torch.randint(0, 2, (batch_size, seq_len)) + + mr = MissRate() + for msk in [None, mask]: + mr.update(pred, trg, msk) + mr.compute() diff --git a/metrics/unittest/test_nll.py b/metrics/unittest/test_nll.py new file mode 100644 index 0000000..726db60 --- /dev/null +++ b/metrics/unittest/test_nll.py @@ -0,0 +1,67 @@ +import pytest +import torch +from metrics.log_likelihood import NegativeLogLikelihood + + +@pytest.mark.parametrize("dist", ["normal", "mvn", "laplace"]) +def test_nll_valid_distribution(dist: str) -> None: + # Test that no exceptions are raised for valid distributions + try: + _ = NegativeLogLikelihood(dist) + except AssertionError: + pytest.fail(f"AssertionError raised for valid distribution {dist}") + + +@pytest.mark.parametrize("dist", ["invalid", "unrecognized", ""]) +def test_nll_invalid_distribution(dist: str) -> None: + # Test that an AssertionError is raised for invalid distributions + with pytest.raises(ValueError, match=f"Invalid distribution name: {dist}"): + _ = NegativeLogLikelihood(dist) + + +def test_nll_mvn_invalid_covariance() -> None: + batch_size, seq_len, num_modes, num_dims = 32, 25, 6, 2 + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + trg = torch.randn(batch_size, seq_len, num_dims) + logits = torch.randn(batch_size, num_modes) + prob = torch.softmax(logits, dim=-1) + + # Not square + scale = torch.ones((batch_size, seq_len, num_modes, num_dims)) + mask = torch.randint(0, 2, (batch_size, seq_len)) + + nll = NegativeLogLikelihood("mvn") + with pytest.raises(AssertionError, match="Covariance matrix must be square."): + nll.update(pred, trg, scale, prob, mask) + + +def test_nll_mvn_valid_scale() -> None: + batch_size, seq_len, num_modes, num_dims = 32, 25, 6, 2 + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + trg = torch.randn(batch_size, seq_len, num_dims) + logits = torch.randn(batch_size, num_modes) + prob = torch.softmax(logits, dim=-1) + scale = torch.ones_like(pred).diag_embed() * 0.1 + mask = torch.randint(0, 2, (batch_size, seq_len)) + + nll = NegativeLogLikelihood("mvn") + for msk in [None, mask]: + nll.update(pred, trg, scale, prob, msk) + result = nll.compute() + assert result is not None + + +@pytest.mark.parametrize("dist", ["normal", "laplace"]) +def test_nll_independent(dist: str) -> None: + batch_size, seq_len, num_modes, num_dims = 32, 25, 6, 2 + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + trg = torch.randn(batch_size, seq_len, num_dims) + logit = torch.randn(batch_size, num_modes) + scale = torch.ones_like(pred) * 0.1 + mask = torch.randint(0, 2, (batch_size, seq_len)) + + nll = NegativeLogLikelihood(dist) + for lg in [None, logit]: + nll.update(pred, trg, scale, lg, mask, logits=True) + result = nll.compute() + assert result is not None diff --git a/metrics/unittest/test_utils.py 
b/metrics/unittest/test_utils.py new file mode 100644 index 0000000..8f1c6da --- /dev/null +++ b/metrics/unittest/test_utils.py @@ -0,0 +1,70 @@ +import pytest +import torch +from metrics.utils import filter_prediction + + +def test_filter_prediction_pre_indexed() -> None: + batch_size, seq_len, num_modes, num_dims = 10, 25, 6, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + best_idx = torch.randint(0, num_modes, (batch_size,)) + pred_, best_idx_ = filter_prediction(pred, trg, best_idx=best_idx) + assert pred_.size() == (batch_size, seq_len, num_dims) + assert (best_idx == best_idx_).all() + + +def test_filter_prediction_criterion() -> None: + batch_size, seq_len, num_modes, num_dims = 10, 25, 6, 2 + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + trg = torch.randn(batch_size, seq_len, num_dims) + mask = torch.randint(0, 2, (batch_size, seq_len)) + prob = torch.randn(batch_size, num_modes) + + for min_criterion in ['FDE', 'ADE', 'MAP']: + pred_, best_idx_ = filter_prediction(pred, trg, mask, prob, + min_criterion, mode_first=False) + assert pred_.size() == (batch_size, seq_len, num_dims) + assert best_idx_.size() == (batch_size,) + + invalid_criterion = 'invalid' + with pytest.raises(ValueError) as exc: + filter_prediction(pred, trg, min_criterion=f'{invalid_criterion}') + + assert exc.value == f"Invalid criterion: {invalid_criterion}" + + +def test_filter_prediction_mode_consistency() -> None: + batch_size, seq_len, num_modes, num_dims = 10, 25, 6, 2 + trg = torch.randn(batch_size, seq_len, num_dims) + mask = torch.randint(0, 2, (batch_size, seq_len)) + prob = torch.randn(batch_size, num_modes) + + for min_criterion in ['FDE', 'ADE', 'MAP']: + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + pred_list = [] + best_idx_list = [] + for mode_first in [False, True]: + if mode_first: + pred_ = pred.transpose(1, 2).clone() + else: + pred_ = pred.clone() + + pred_, best_idx_ = filter_prediction(pred_, trg, mask, prob, + min_criterion, mode_first=mode_first) + + pred_list.append(pred_) + best_idx_list.append(best_idx_) + + assert torch.allclose(pred_list[0], pred_list[1]) + assert torch.allclose(best_idx_list[0], best_idx_list[1]) + + +def test_filter_prediction_dimension() -> None: + batch_size, seq_len, num_modes, num_dims = 10, 25, 6, 4 + pred = torch.randn(batch_size, seq_len, num_modes, num_dims) + trg = torch.randn(batch_size, seq_len, num_dims) + with pytest.warns(UserWarning): + pred_, best_idx_ = filter_prediction(pred, trg) + assert pred_.size() == (batch_size, seq_len, 2) + assert best_idx_.size() == (batch_size,) diff --git a/metrics/utils.py b/metrics/utils.py new file mode 100644 index 0000000..872054c --- /dev/null +++ b/metrics/utils.py @@ -0,0 +1,84 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
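+
+# filter_prediction below reduces a multi-modal prediction (N, T, M, 2) to the single
+# best mode per agent, either reusing a precomputed best_idx or selecting it by the
+# requested criterion: endpoint error (FDE), accumulated displacement (ADE), or highest
+# predicted probability (MAP). Extra output dimensions beyond x/y are dropped with a
+# warning, and mode-first inputs (N, M, T, 2) are transposed up front.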
+ +import warnings +from typing import Optional +import torch + + +def filter_prediction(pred: torch.Tensor, + trg: torch.Tensor, + mask: Optional[torch.Tensor] = None, + prob: Optional[torch.Tensor] = None, + min_criterion: str = 'FDE', + best_idx: Optional[torch.Tensor] = None, + mode_first: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: + if mode_first: + # (N, M, T, 2) -> (N, T, M, 2) + pred = pred.transpose(1, 2) + + if pred.size(-1) > 2 or trg.size(-1) > 2: + warnings.warn("The last dimension of the prediction or target tensors" + " is greater than 2. Only the first two dimensions will be considered.") + pred = pred[..., :2] + trg = trg[..., :2] + + batch_size, seq_len = pred.size()[:2] + + if best_idx is not None: + pred = pred[torch.arange(batch_size), :, best_idx] # (N, T, 2) + return pred, best_idx + + if min_criterion == "FDE": + if mask is not None: + mask_reversed = 1 * mask.flip(dims=[-1]) # (N, T) + last_idx = seq_len - 1 - mask_reversed.argmax(dim=-1) # (N,) + + last_pred = pred[torch.arange(batch_size), last_idx] # (N, M, 2) + last_trg = trg[torch.arange(batch_size), last_idx] # (N, 2) + else: + last_pred = pred[:, -1] + last_trg = trg[:, -1] + + best_idx = torch.linalg.norm(last_pred - last_trg.unsqueeze(1), + dim=-1).argmin(dim=-1) # (N,) + + pred = pred[torch.arange(batch_size), :, best_idx] # (N, T, 2) + + elif min_criterion == "ADE": + if mask is not None: + multi_mask = mask.unsqueeze(-1).unsqueeze(-1) # (N, T, 1, 1) + masked_pred = pred * multi_mask # (N, T, M, 2) + masked_trg = trg.unsqueeze(2) * multi_mask # (N, T, 1, 2) + else: + masked_pred = pred # (N, T, M, 2) + masked_trg = trg.unsqueeze(2) # (N, T, 1, 2) + + norm = torch.linalg.norm(masked_pred - masked_trg, dim=-1) # (N, T, M) + + best_idx = norm.sum(dim=1).argmin(dim=-1) # (N,) + pred = pred[torch.arange(batch_size), :, best_idx] # (N, T, 2) + + elif min_criterion == "MAP": + assert prob is not None, ("Probabilistic criterion requires" + " the probability of the predictions.") + + best_idx = prob.argmax(dim=-1) # (N,) + pred = pred[torch.arange(batch_size), :, best_idx] # (N, T, 2) + + else: + raise ValueError(f"Invalid criterion: {min_criterion}") + + return pred, best_idx diff --git a/models/prototype.py b/models/prototype.py new file mode 100644 index 0000000..a7791fd --- /dev/null +++ b/models/prototype.py @@ -0,0 +1,55 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
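+
+# Net below is the prototype forecasting model: each agent's history of position,
+# velocity, and yaw is embedded and encoded with a GRU, the final hidden states are
+# exchanged between agents through a single GATv2 layer over the agent-agent edges,
+# and the result is repeated over the prediction horizon and decoded by a second GRU
+# into per-step x/y outputs. The map inputs are left unused here (see the
+# commented-out lines in forward).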
+ + +import torch +from torch import nn +import torch_geometric.nn as ptg +from torch_geometric.data import HeteroData + + +class Net(nn.Module): + def __init__(self, config: dict) -> None: + super().__init__() + num_inputs = config["num_inputs"] + num_outputs = config["num_outputs"] + num_hidden = config["num_hidden"] + self.ph = config["pred_hrz"] + + self.embed = nn.Linear(num_inputs, num_hidden) + self.encoder = nn.GRU(num_hidden, num_hidden, batch_first=True) + self.interaction = ptg.GATv2Conv(num_hidden, num_hidden, concat=False) + self.decoder = nn.GRU(num_hidden, num_hidden, batch_first=True) + self.output = nn.Linear(num_hidden, num_outputs) + + def forward(self, data: HeteroData) -> torch.Tensor: + edge_index = data['agent']['edge_index'] + x = torch.cat([data['agent']['inp_pos'], + data['agent']['inp_vel'], + data['agent']['inp_yaw']], dim=-1) + + # map_to_agent_edge_index = data['map', 'to', 'agent']['edge_index'] + # map_pos = data['map_point']['position'] + + x = self.embed(x) + _, h = self.encoder(x) + x = h[-1] + + x = self.interaction(x, edge_index) + x = x.unsqueeze(1).repeat(1, self.ph, 1) + x, _ = self.decoder(x) + + pred = self.output(x) + + return pred diff --git a/preamble.py b/preamble.py new file mode 100644 index 0000000..4ea3363 --- /dev/null +++ b/preamble.py @@ -0,0 +1,53 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
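+
+# load_config below resolves a config name (with or without the .json extension) by
+# searching the subdirectories of configs/ and returns the parsed dictionary, while
+# import_from_module fetches a class from a module path at runtime; this is how the
+# "module"/"class" entries in the JSON configs can be turned into actual objects.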
+ + +import json +import importlib +from typing import Callable +from pathlib import Path + + +def load_config(config: str) -> dict: + # check if file contains ".json" extension + if not config.endswith(".json"): + config += ".json" + + # check if file exists in any of the config subdirectories + config_path = Path("configs") + + # get all subdirectories + subdirs = [d for d in config_path.iterdir() if d.is_dir()] + + # get all files in subdirectories + files = [f for d in subdirs for f in d.iterdir() if f.is_file()] + + # check if config is any of the files + if not any(config in f.name for f in files): + raise FileNotFoundError(f"Config file {config} not found.") + + config = [str(f) for f in files if config in f.name][0] + + with open(config, 'r', encoding='utf-8') as openfile: + conf = json.load(openfile) + return conf + + +def import_module(module_name: str) -> object: + return importlib.import_module(module_name) + + +def import_from_module(module_name: str, class_name: str) -> Callable: + module = import_module(module_name) + return getattr(module, class_name) diff --git a/preprocess.sh b/preprocess.sh new file mode 100644 index 0000000..ccac3a2 --- /dev/null +++ b/preprocess.sh @@ -0,0 +1,4 @@ +python -m preprocessing.preprocess_highway --dataset 'highD' --debug 0 --use-threads 1 +python -m preprocessing.preprocess_urban --dataset 'rounD' --debug 0 --use-threads 1 +python -m preprocessing.preprocess_urban --dataset 'inD' --debug 0 --use-threads 1 +# python -m preprocessing.preprocess_urban --dataset 'uniD' --debug 1 --use-threads 1 diff --git a/preprocessing/__init__.py b/preprocessing/__init__.py new file mode 100644 index 0000000..5735660 --- /dev/null +++ b/preprocessing/__init__.py @@ -0,0 +1 @@ +from preprocessing.arguments import args diff --git a/preprocessing/arguments.py b/preprocessing/arguments.py new file mode 100644 index 0000000..71d94f8 --- /dev/null +++ b/preprocessing/arguments.py @@ -0,0 +1,47 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
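+
+# str_to_bool below lets boolean options be passed with explicit values on the command
+# line (e.g. --use-threads 1 or --debug false, as in preprocess.sh above), while a bare
+# flag still evaluates to True through const=True and nargs="?".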
+ +from argparse import ArgumentParser, ArgumentTypeError + + +def str_to_bool(value: bool | str) -> bool: + """Used for boolean arguments in argparse; avoiding `store_true` and `store_false`.""" + true_vals = ("yes", "true", "t", "y", "1") + false_vals = ("no", "false", "f", "n", "0") + if isinstance(value, bool): + return value + if value.lower() in true_vals: + return True + if value.lower() in false_vals: + return False + raise ArgumentTypeError('Boolean value expected.') + + +parser = ArgumentParser(description='Preprocessing arguments') + +# Program arguments +parser.add_argument('--path', type=str, default="../datasets", + help='path to dataset') +parser.add_argument('--dataset', type=str, default="rounD", + help='name of dataset') +parser.add_argument('--output-dir', type=str, default="data", + help='output directory for processed data') +parser.add_argument('--add-name', type=str, default="", + help='additional string to add to output-dir save name') +parser.add_argument('--use-threads', type=str_to_bool, default=False, + const=True, nargs="?", help='if multiprocessing should be used') +parser.add_argument('--debug', type=str_to_bool, default=False, + const=True, nargs="?", help='debugging mode') + +args = parser.parse_args() diff --git a/preprocessing/configs/highD.json b/preprocessing/configs/highD.json new file mode 100644 index 0000000..8ee9c8c --- /dev/null +++ b/preprocessing/configs/highD.json @@ -0,0 +1,16 @@ +{ + "dataset": "highD", + "seed": 42, + "sample_freq": 25, + "input_len": 2, + "output_len": 5, + "downsample": 5, + "skip_lc_samples": 12, + "skip_lk_samples": 25, + "n_inputs": 7, + "n_outputs": 7, + "lane_graph": { + "spacing": 10.0, + "buffer": 2.0 + } +} diff --git a/preprocessing/configs/inD.json b/preprocessing/configs/inD.json new file mode 100644 index 0000000..e487410 --- /dev/null +++ b/preprocessing/configs/inD.json @@ -0,0 +1,211 @@ +{ + "dataset": "inD", + "seed": 42, + "sample_freq": 25, + "input_len": 3, + "output_len": 5, + "downsample": 5, + "skip_samples": 25, + "n_inputs": 7, + "n_outputs": 7, + "recordings": { + "00": { + "x0": 143.255269808385, + "y0": -57.91170481615564, + "location": "04_aseag", + "include": true + }, + "01": { + "x0": 143.255269808385, + "y0": -57.91170481615564, + "location": "04_aseag", + "include": true + }, + "02": { + "x0": 143.255269808385, + "y0": -57.91170481615564, + "location": "04_aseag", + "include": true + }, + "03": { + "x0": 143.255269808385, + "y0": -57.91170481615564, + "location": "04_aseag", + "include": true + }, + "04": { + "x0": 143.255269808385, + "y0": -57.91170481615564, + "location": "04_aseag", + "include": true + }, + "05": { + "x0": 143.255269808385, + "y0": -57.91170481615564, + "location": "04_aseag", + "include": true + }, + "06": { + "x0": 143.255269808385, + "y0": -57.91170481615564, + "location": "04_aseag", + "include": true + }, + "07": { + "x0": 55.72110867398384, + "y0": -32.74837088734138, + "location": "01_bendplatz", + "include": true + }, + "08": { + "x0": 55.72110867398384, + "y0": -32.74837088734138, + "location": "01_bendplatz", + "include": true + }, + "09": { + "x0": 55.72110867398384, + "y0": -32.74837088734138, + "location": "01_bendplatz", + "include": true + }, + "10": { + "x0": 55.72110867398384, + "y0": -32.74837088734138, + "location": "01_bendplatz", + "include": true + }, + "11": { + "x0": 55.72110867398384, + "y0": -32.74837088734138, + "location": "01_bendplatz_construction", + "include": true + }, + "12": { + "x0": 55.72110867398384, + "y0": -32.74837088734138, + 
"location": "01_bendplatz_construction", + "include": true + }, + "13": { + "x0": 55.72110867398384, + "y0": -32.74837088734138, + "location": "01_bendplatz_construction", + "include": true + }, + "14": { + "x0": 55.72110867398384, + "y0": -32.74837088734138, + "location": "01_bendplatz_construction", + "include": true + }, + "15": { + "x0": 55.72110867398384, + "y0": -32.74837088734138, + "location": "01_bendplatz_construction", + "include": true + }, + "16": { + "x0": 55.72110867398384, + "y0": -32.74837088734138, + "location": "01_bendplatz_construction", + "include": true + }, + "17": { + "x0": 55.72110867398384, + "y0": -32.74837088734138, + "location": "01_bendplatz_construction", + "include": true + }, + "18": { + "x0": 47.4118205383659, + "y0": -28.8381176470473, + "location": "02_frankenburg", + "include": true + }, + "19": { + "x0": 47.4118205383659, + "y0": -28.8381176470473, + "location": "02_frankenburg", + "include": true + }, + "20": { + "x0": 47.4118205383659, + "y0": -28.8381176470473, + "location": "02_frankenburg", + "include": true + }, + "21": { + "x0": 47.4118205383659, + "y0": -28.8381176470473, + "location": "02_frankenburg", + "include": true + }, + "22": { + "x0": 47.4118205383659, + "y0": -28.8381176470473, + "location": "02_frankenburg", + "include": true + }, + "23": { + "x0": 47.4118205383659, + "y0": -28.8381176470473, + "location": "02_frankenburg", + "include": true + }, + "24": { + "x0": 47.4118205383659, + "y0": -28.8381176470473, + "location": "02_frankenburg", + "include": true + }, + "25": { + "x0": 47.4118205383659, + "y0": -28.8381176470473, + "location": "02_frankenburg", + "include": true + }, + "26": { + "x0": 47.4118205383659, + "y0": -28.8381176470473, + "location": "02_frankenburg", + "include": true + }, + "27": { + "x0": 47.4118205383659, + "y0": -28.8381176470473, + "location": "02_frankenburg", + "include": true + }, + "28": { + "x0": 47.4118205383659, + "y0": -28.8381176470473, + "location": "02_frankenburg", + "include": true + }, + "29": { + "x0": 47.4118205383659, + "y0": -28.8381176470473, + "location": "02_frankenburg", + "include": true + }, + "30": { + "x0": 40.080060675120016, + "y0": -25.416623842759034, + "location": "03_heckstrasse", + "include": true + }, + "31": { + "x0": 40.080060675120016, + "y0": -25.416623842759034, + "location": "03_heckstrasse", + "include": true + }, + "32": { + "x0": 40.080060675120016, + "y0": -25.416623842759034, + "location": "03_heckstrasse", + "include": true + } + } +} \ No newline at end of file diff --git a/preprocessing/configs/rounD.json b/preprocessing/configs/rounD.json new file mode 100644 index 0000000..10d594e --- /dev/null +++ b/preprocessing/configs/rounD.json @@ -0,0 +1,157 @@ +{ + "dataset": "rounD", + "seed": 42, + "sample_freq": 25, + "input_len": 3, + "output_len": 5, + "downsample": 5, + "skip_samples": 12, + "n_inputs": 7, + "n_outputs": 7, + "recordings": { + "00": { + "x0": 115.51669730710512, + "y0": -70.6429033531912, + "location": "1_kackertstrasse", + "include": true + }, + "01": { + "x0": 137.8338032894461, + "y0": -61.07768929146573, + "location": "2_thiergarten", + "include": true + }, + "02": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "03": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "04": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "05": { + "x0": 80.97640635242064, + 
"y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "06": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "07": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "08": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "09": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "10": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "11": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "12": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "13": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "14": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "15": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "16": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "17": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "18": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "19": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "20": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "21": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "22": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + }, + "23": { + "x0": 80.97640635242064, + "y0": -46.93989929086365, + "location": "0_neuweiler", + "include": true + } + } +} \ No newline at end of file diff --git a/preprocessing/configs/uniD.json b/preprocessing/configs/uniD.json new file mode 100644 index 0000000..fbab07e --- /dev/null +++ b/preprocessing/configs/uniD.json @@ -0,0 +1,91 @@ +{ + "dataset": "uniD", + "seed": 42, + "sample_freq": 25, + "input_len": 3, + "output_len": 5, + "downsample": 5, + "skip_samples": 25, + "n_inputs": 7, + "n_outputs": 7, + "recordings": { + "00": { + "x0": 121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + }, + "01": { + "x0": 121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + }, + "02": { + "x0": 121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + }, + "03": { + "x0": 121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + }, + "04": { + "x0": 121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + }, + "05": { + "x0": 121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + }, + "06": { + "x0": 121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + }, + "07": { + "x0": 121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + }, + "08": { + "x0": 
121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + }, + "09": { + "x0": 121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + }, + "10": { + "x0": 121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + }, + "11": { + "x0": 121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + }, + "12": { + "x0": 121.46225158193216, + "y0": -86.64134747714613, + "location": "0_superc", + "include": true + } + } +} diff --git a/preprocessing/preprocess_highway.py b/preprocessing/preprocess_highway.py new file mode 100644 index 0000000..e6c2985 --- /dev/null +++ b/preprocessing/preprocess_highway.py @@ -0,0 +1,415 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import sys +import time +import json +import pickle +import warnings +from typing import Any +from multiprocessing import Pool, Value, Lock + +import torch +import pandas as pd +import numpy as np +from tqdm import tqdm + +from preprocessing.arguments import args +from preprocessing.utils.highway_graph import get_highway_graph +from preprocessing.utils.highway_utils import * +from preprocessing.utils.common import * + +worker_counter: Any +worker_lock: Any + + +def process_id(id0: int, + rec_id: str, + out_dir: str, + fr_dict: dict, + tr_meta: pd.DataFrame, + tr: pd.DataFrame, + ln_graph: dict, + current_set: str = 'train', + dataset: str = 'highD', + fz: int = 25, + input_len: int = 2, + output_len: int = 5, + n_inputs: int = 7, + n_outputs: int = 7, + ds_factor: int = 5, + skip: int = 12, + debug: bool = False, + ) -> None: + """ + Extracts the data for a given set of frames and saves it to a pickle file. 
+ :param id0: The trackId of the target vehicle + :param rec_id: The ID of the recording + :param out_dir: Output directory + :param current_set: The current set (train, val, test) + :param fr_dict: The frames to extract + :param tr_meta: The meta-data of the tracks + :param tr: The trajectory data + :param ln_graph: The lane graph + :param fz: The sampling frequency + :param input_len: The length of the input sequence + :param output_len: The length of the output sequence + :param n_inputs: The number of input features + :param n_outputs: The number of output features + :param ds_factor: The down-sampling factor + :param skip: The number of frames to skip + :param dataset: The dataset name + :param debug: Debug mode + :return: None + """ + + not_set = get_other_sets(current_set) + + if not_set is None: + not_set = ['val', 'test'] + df = tr[tr.trackId == id0] + frames = df.frame.to_numpy() + + # Remove frames that are not in the current set + frames = update_frames(frames, fr_dict[not_set[0]], fr_dict[not_set[1]]) + + if len(frames) < fz * (input_len + output_len) + 1: + return None + + driving_dir = int(tr_meta[tr_meta.trackId == id0].drivingDirection.iloc[0]) + + # First, we filter out the frames where the target vehicle is performing a lane keep + # that way we can sample more frames for lane changes + lk_frames = [] + lc_frames = [] + for frame in frames[::skip]: + prediction_frame = frame + fz * input_len + final_frame = prediction_frame + fz * output_len + if final_frame not in frames: + break + ta_intent = get_maneuver(tr, prediction_frame - 1, [id0], prop='maneuver')[0] + if ta_intent == 3: + lk_frames.append(frame) + else: + lc_frames.append(frame) + + n_lc = len(lc_frames) + n_lk = len(lk_frames) + + # Our goal is to not sample more lane keep frames than lane change frames + if n_lc > 0: + keep_lk = min(n_lc, n_lk) + else: + keep_lk = min(n_lk, 5) + + # The stride is selected such that we retain 'keep_lk' lane keep frames + k = max(math.ceil(n_lk / (keep_lk + 1)), 1) + + # 'Slicing' the lane keep frames assures that we sample + # from all parts of the trajectory + lk_frames = lk_frames[::k] + + # Combine the lane change and lane keep frames + updated_frames = lc_frames + lk_frames + + for frame in updated_frames: + prediction_frame = frame + fz * input_len + final_frame = prediction_frame + fz * output_len + + sas = get_neighbors(tr, prediction_frame - 1, id0, driving_dir) + sa_ids = pd.unique(sas.trackId) + n_sas = len(sa_ids) + + agent_ids = [id0, *sa_ids] + + # Retrieve meta information + intention = get_maneuver(tr, prediction_frame - 1, agent_ids, prop='maneuver') + agent_type = class_list_to_int_list(get_meta_property(tr_meta, agent_ids, prop='class')) + + # Convert to tensors + intention_tensor = torch.tensor(intention).long() + agent_type_tensor = torch.tensor(agent_type).long() + + input_array = np.empty((n_sas + 1, fz * input_len, n_inputs)) + target_array = np.empty((n_sas + 1, fz * output_len, n_outputs)) + + for j, v_id in enumerate(agent_ids): + input_array[j] = get_features(tr, frame, prediction_frame - 1, n_inputs, v_id) + target_array[j] = get_features(tr, prediction_frame, final_frame - 1, n_outputs, v_id) + + # Down-sample the data + if ds_factor > 1: + input_array = decimate_nan(input_array, pad_order='front', ds_factor=ds_factor, fz=fz) + target_array = decimate_nan(target_array, pad_order='back', ds_factor=ds_factor, fz=fz) + + # Convert to tensors + input_tensor = torch.from_numpy(input_array).float() + target_tensor = torch.from_numpy(target_array).float() 
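+
+        # Shape note (a sketch, assuming the default arguments fz=25,
+        # input_len=2, output_len=5 and ds_factor=5): after down-sampling,
+        # input_tensor is (n_sas + 1, 10, n_inputs) and target_tensor is
+        # (n_sas + 1, 25, n_outputs), i.e., 2 s of history and 5 s of future.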
+ + # Create masks + three_sec = 3 * fz / ds_factor + + input_mask, valid_mask, sa_mask, ma_mask = \ + get_masks(input_tensor, target_tensor, int(three_sec)) + + # make nans into zeros + input_tensor[torch.isnan(input_tensor)] = 0. + target_tensor[torch.isnan(target_tensor)] = 0. + + agent = {'num_nodes': n_sas + 1, + 'ta_index': 0, + 'ids': agent_ids, + 'type': agent_type_tensor, + 'inp_pos': input_tensor[..., :2], + 'inp_vel': input_tensor[..., 2:4], + 'inp_acc': input_tensor[..., 4:6], + 'inp_yaw': input_tensor[..., 6:], + 'trg_pos': target_tensor[..., :2], + 'trg_vel': target_tensor[..., 2:4], + 'trg_acc': target_tensor[..., 4:6], + 'trg_yaw': target_tensor[..., 6:], + 'intention': intention_tensor, + 'input_mask': input_mask, + 'valid_mask': valid_mask, + 'sa_mask': sa_mask, + 'ma_mask': ma_mask} + + data: dict[str, Any] = {'rec_id': rec_id, 'agent': agent} + # data['x_min'] = None + # data['x_max'] = None + data.update(ln_graph['upper_map'] if driving_dir == 1 else ln_graph['lower_map']) + + if not debug: + with worker_lock: + save_name = f"{dataset}_{current_set}_{worker_counter.value}" + worker_counter.value += 1 + with open(f"{out_dir}/{current_set}/{save_name}.pkl", "wb") as file: + pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) + return None + + +def process_ids(current_set: str, + rec_id: str, + out_dir: str, + fr_dict: dict, + tr_meta: pd.DataFrame, + tr: pd.DataFrame, + ln_graph: dict + ) -> None: + """ + Extracts the data for a given set of frames and saves it to a pickle file. + :param current_set: The current set (train, val, test) + :param rec_id: The recording ID + :param out_dir: Output directory + :param fr_dict: The frames to extract + :param tr_meta: The meta-data of the tracks + :param tr: The trajectory data + :param ln_graph: The lane graph + """ + assert current_set in ['train', 'val', 'test'], 'current_set must be one of [train, val, test]' + + fz = config["sample_freq"] + ds = config["dataset"] + input_len = config["input_len"] + output_len = config["output_len"] + n_inputs = config["n_inputs"] + n_outputs = config["n_outputs"] + ds_factor = config["downsample"] + skip_lc = config["skip_lc_samples"] + skip_lk = config["skip_lk_samples"] + + debug = args.debug + + outer_lc_args = (ds, fz, input_len, output_len, n_inputs, + n_outputs, ds_factor, skip_lc, debug) + outer_lk_args = (ds, fz, input_len, output_len, n_inputs, + n_outputs, ds_factor, skip_lk, debug) + + # Check if there are any saved samples in the current set directory + set_dir = f"{output_dir}/{current_set}" + if len(os.listdir(set_dir)) > 0: + # get the highest save_id + save_ids = [int(f.split('_')[-1].split('.')[0]) for f in os.listdir(set_dir)] + save_id = max(save_ids) + 1 + else: + save_id = 0 + + save_id_counter = Value('i', save_id) + save_lock = Lock() + + ta_ids = list(tr_meta[tr_meta['class'].isin(['car', 'truck'])].trackId) + frame_range = fr_dict[current_set] + ta_set = {ta_id for ta_id in ta_ids + if any(tr[tr.trackId == ta_id].frame.isin(frame_range))} + + # Get the ids of all the TAs that perform lane changes + lc_ta_ids = {ta_id for ta_id in ta_set + if int(tr_meta[tr_meta.trackId == ta_id].numLaneChanges.iloc[0]) > 0} + + # Compute the ids of all the TAs that perform lane keeping + lk_ta_ids = ta_set - lc_ta_ids + + frac = max(len(lc_ta_ids) // 10, 5) + + # Remove some of the lane keeping data + lk_ta_ids = set(np.random.choice(list(lk_ta_ids), frac, replace=False)) + + lc_arguments = [ + (ta_id, rec_id, out_dir, fr_dict, tr_meta, tr, + ln_graph, current_set, 
*outer_lc_args) for ta_id in lc_ta_ids + ] + lk_arguments = [ + (ta_id, rec_id, out_dir, fr_dict, tr_meta, tr, + ln_graph, current_set, *outer_lk_args) for ta_id in lk_ta_ids + ] + + arguments = lc_arguments + lk_arguments + + if args.use_threads: + n_workers = 1 + cpu_count = os.cpu_count() + if cpu_count is None: + warnings.warn("Could not determine the number of CPU cores. Using 1 thread.") + elif cpu_count <= 2: + warnings.warn("The number of CPU cores is too low. Using 1 thread.") + else: + n_workers = cpu_count + + with Pool(n_workers, initializer=init_worker, + initargs=(save_id_counter, save_lock)) as pool: + with tqdm(total=len(arguments), desc=f"{current_set.capitalize()}", + position=1, leave=False) as pbar: + for _ in pool.imap_unordered(worker_function, arguments): + pbar.update() + else: + for arg in tqdm(arguments, desc=f"{current_set.capitalize()}", position=1, leave=False): + process_id(*arg) + + +def init_worker(counter, lock): + # Attach the counter and lock to the worker + global worker_counter, worker_lock + worker_counter, worker_lock = counter, lock + + +def worker_function(arg: tuple) -> None: + # Wrapper function to call extract_by_frame with multiple arguments + return process_id(*arg) + + +def erase_previous_line(double_jump: bool = False): + """Erase the previous line in the terminal.""" + sys.stdout.write('\x1b[1A') # Move the cursor up one line + sys.stdout.write('\x1b[2K') # Clear the entire line + if double_jump: + sys.stdout.write('\x1b[1A') + + +if __name__ == "__main__": + if args.debug: + print("DEBUG MODE: ON \n") + + # worker_counter: Any + # worker_lock: Any + + output_dir = create_directories(args) + print(f"Output directory: {output_dir} \n") + + with open("preprocessing/configs/" + args.dataset + ".json", + "r", encoding="utf-8") as conf_file: + config = json.load(conf_file) + + random_seed = config["seed"] + np.random.seed(random_seed) + + rec_ids = [f"{i:02}" for i in range(1, 60 + 1)] + + try: + for r_id in tqdm(rec_ids, desc="Main process: ", position=0, leave=True): + print(f"Preprocessing started for recording {r_id}...") + + # Construct the base directory path for your data + base_dir = os.path.join(args.path, args.dataset, "data") + + # Use os.path.join for each specific file + rec_meta_path = os.path.join(base_dir, f"{r_id}_recordingMeta.csv") + tracks_meta_path = os.path.join(base_dir, f"{r_id}_tracksMeta.csv") + tracks_path = os.path.join(base_dir, f"{r_id}_tracks.csv") + + # Read the CSV files + rec_meta = pd.read_csv(rec_meta_path) + tracks_meta = pd.read_csv(tracks_meta_path) + tracks = pd.read_csv(tracks_path) + + # For the lanelet file, construct the path similarly + upper_map, lower_map, x_min, x_max = \ + get_highway_graph(rec_meta, tracks, + spacing=config["lane_graph"]["spacing"], + buffer=config["lane_graph"]["buffer"]) + lane_graph = {'upper_map': upper_map, 'lower_map': lower_map} + + # Perform some initial renaming + if 'trackId' not in tracks_meta.columns: + tracks_meta.rename(columns={'id': 'trackId'}, inplace=True) + tracks.rename(columns={'id': 'trackId'}, inplace=True) + if 'vx' not in tracks.columns: + tracks.rename(columns={'xVelocity': 'vx'}, inplace=True) + tracks.rename(columns={'yVelocity': 'vy'}, inplace=True) + tracks.rename(columns={'xAcceleration': 'ax'}, inplace=True) + tracks.rename(columns={'yAcceleration': 'ay'}, inplace=True) + if "x" not in tracks.columns: + tracks.rename(columns={'xCenter': 'x'}, inplace=True) + tracks.rename(columns={'yCenter': 'y'}, inplace=True) + + # Make class lowercase in 
tracks_meta + tracks_meta['class'] = tracks_meta['class'].str.lower() + + tracks = align_origin_w_centroid(tracks_meta, tracks, debug=args.debug) + tracks = add_driving_direction(tracks_meta, tracks) + tracks = add_maneuver(tracks_meta, tracks, debug=args.debug) + tracks = update_signs(rec_meta, tracks_meta, tracks, debug=args.debug) + tracks = add_heading_feat(tracks, debug=args.debug) + + # Determine train, val, test split (by frames) + train_frames, val_frames, test_frames = \ + get_frame_split(tracks_meta.finalFrame.array[-1], + seed=random_seed) + frame_dict = {'train': train_frames, 'val': val_frames, 'test': test_frames} + + shared_args = (r_id, output_dir, frame_dict, tracks_meta, tracks, lane_graph) + + tasks = [ + ('train',) + shared_args, + ('val',) + shared_args, + ('test',) + shared_args + ] + + # Erase preprocessing message + erase_previous_line() + + # Print and immediately erase a "done" message (as an example) + print("Preprocessing completed.") + time.sleep(1) # Just to let the user see the message + erase_previous_line(True) + + for task in tasks: + process_ids(*task) + + except KeyboardInterrupt: + print("Interrupted.") + + finally: + print("Finished.") diff --git a/preprocessing/preprocess_urban.py b/preprocessing/preprocess_urban.py new file mode 100644 index 0000000..2941157 --- /dev/null +++ b/preprocessing/preprocess_urban.py @@ -0,0 +1,398 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import time +import json +import pickle +import warnings +from typing import Any +from multiprocessing import Pool, Value, Lock + +import torch +import pandas as pd +import numpy as np +from tqdm import tqdm + +from preprocessing.arguments import args +from preprocessing.utils.lanelet_graph import get_lanelet_graph +from preprocessing.utils.common import * + +worker_counter: Any +worker_lock: Any + + +def process_id( + id0: int, + rec_id: str, + out_dir: str, + fr_dict: dict, + tr_meta: pd.DataFrame, + tr: pd.DataFrame, + ln_graph: dict, + current_set: str = "train", + dataset: str = "rounD", + fz: int = 25, + input_len: int = 3, + output_len: int = 5, + n_inputs: int = 7, + n_outputs: int = 7, + ds_factor: int = 5, + skip: int = 12, + debug: bool = False, +) -> None: + """ + Extracts the data for a given set of frames and saves it to a pickle file. 
+ :param id0: The trackId of the target vehicle + :param rec_id: The ID of the recording + :param out_dir: Output directory + :param current_set: The current set (train, val, test) + :param fr_dict: The frames to extract + :param tr_meta: The meta-data of the tracks + :param tr: The trajectory data + :param ln_graph: The lane graph + :param fz: The sampling frequency + :param input_len: The length of the input sequence + :param output_len: The length of the output sequence + :param n_inputs: The number of input features + :param n_outputs: The number of output features + :param ds_factor: The down-sampling factor + :param skip: The number of frames to skip + :param dataset: The dataset name + :param debug: Debug mode + :return: None + """ + + not_set = get_other_sets(current_set) + + if not_set is None: + not_set = ["val", "test"] + df = tr[tr.trackId == id0] + frames = df.frame.to_numpy() + + # Remove frames that are not in the current set + frames = update_frames(frames, fr_dict[not_set[0]], fr_dict[not_set[1]]) + + if len(frames) < fz * (input_len + output_len) + 1: + return None + for frame in frames[0:-1:skip]: # Skip every 0.5 second + prediction_frame = frame + fz * input_len + final_frame = prediction_frame + fz * output_len + if final_frame not in frames: + break + + sas = get_neighbors(tr, prediction_frame - 1, id0) + sa_ids = pd.unique(sas.trackId) + n_sas = len(sa_ids) + + agent_ids = [id0, *sa_ids] + + # Retrieve meta information + agent_type = class_list_to_int_list(get_meta_property(tr_meta, agent_ids, prop="class")) + + # Convert to tensors + agent_type_tensor = torch.tensor(agent_type).long() + + input_array = np.empty((n_sas + 1, fz * input_len, n_inputs)) + target_array = np.empty((n_sas + 1, fz * output_len, n_outputs)) + + for j, v_id in enumerate(agent_ids): + input_array[j] = get_features(tr, frame, prediction_frame - 1, n_inputs, v_id) + target_array[j] = get_features(tr, prediction_frame, final_frame - 1, n_outputs, v_id) + + # Down-sample the data + if ds_factor > 1: + input_array = decimate_nan(input_array, pad_order="front", ds_factor=ds_factor, fz=fz) + target_array = decimate_nan(target_array, pad_order="back", ds_factor=ds_factor, fz=fz) + + # Convert to tensors + input_tensor = torch.from_numpy(input_array).float() + target_tensor = torch.from_numpy(target_array).float() + + # Create masks + three_sec = 3 * fz / ds_factor + + # Detect static_ids (we don't want to score on parked vehicles) + non_scored_ids = [] + for j in range(len(agent_ids)): + a = len(input_tensor[j, :, 0].unique()) + b = len(target_tensor[j, :, 0].unique()) + if a + b <= 3: + non_scored_ids.append(j) + if 0 in non_scored_ids: + continue + + non_scored_ids_tensor = torch.tensor(non_scored_ids) + + input_mask, valid_mask, sa_mask, ma_mask = get_masks( + input_tensor, target_tensor, int(three_sec), non_scored_ids_tensor + ) + + # make nans into zeros + input_tensor[torch.isnan(input_tensor)] = 0.0 + target_tensor[torch.isnan(target_tensor)] = 0.0 + + agent = { + "num_nodes": n_sas + 1, + "ta_index": 0, + "ids": agent_ids, + "type": agent_type_tensor, + "inp_pos": input_tensor[..., :2], + "inp_vel": input_tensor[..., 2:4], + "inp_acc": input_tensor[..., 4:6], + "inp_yaw": input_tensor[..., 6:], + "trg_pos": target_tensor[..., :2], + "trg_vel": target_tensor[..., 2:4], + "trg_acc": target_tensor[..., 4:6], + "trg_yaw": target_tensor[..., 6:], + "input_mask": input_mask, + "valid_mask": valid_mask, + "sa_mask": sa_mask, + "ma_mask": ma_mask, + } + + data: dict[str, Any] = {'rec_id': rec_id, 
'agent': agent} + data.update(ln_graph) + + if not debug: + with worker_lock: + save_name = f"{dataset}_{current_set}_{worker_counter.value}" + worker_counter.value += 1 + + with open(f"{out_dir}/{current_set}/{save_name}.pkl", "wb") as file: + pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) + return None + +def process_ids( + current_set: str, + rec_id: str, + out_dir: str, + fr_dict: dict, + tr_meta: pd.DataFrame, + tr: pd.DataFrame, + ln_graph: dict +) -> None: + """ + Extracts the data for a given set of frames and saves it to a pickle file. + :param current_set: The current set (train, val, test) + :param rec_id: The recording ID + :param out_dir: Output directory + :param fr_dict: The frames to extract + :param tr_meta: The meta-data of the tracks + :param tr: The trajectory data + :param ln_graph: The lanelet graph + """ + assert current_set in ["train", "val", "test"], "current_set must be one of [train, val, test]" + + fz = config["sample_freq"] + ds = config["dataset"] + input_len = config["input_len"] + output_len = config["output_len"] + n_inputs = config["n_inputs"] + n_outputs = config["n_outputs"] + ds_factor = config["downsample"] + skip = config["skip_samples"] + debug = args.debug + + outer_args = (ds, fz, input_len, output_len, n_inputs, n_outputs, ds_factor, skip, debug) + + # Check if there are any saved samples in the current set directory + set_dir = f"{output_dir}/{current_set}" + if len(os.listdir(set_dir)) > 0: + # get the highest save_id + save_ids = [int(f.split("_")[-1].split(".")[0]) for f in os.listdir(set_dir)] + save_id = max(save_ids) + 1 + else: + save_id = 0 + + save_id_counter = Value("i", save_id) + save_lock = Lock() + + ta_ids = list( + tr_meta[ + tr_meta["class"].isin(["car", "van", "truck", "truck_bus", "bus", + "motorcycle", "bicycle", "pedestrian"]) + ].trackId + ) + frame_range = fr_dict[current_set] + ta_ids_set = {ta_id for ta_id in ta_ids if + any(tr[tr.trackId == ta_id].frame.isin(frame_range))} + + parked_vehicles = set( + tr_meta[(tr_meta.initialFrame == 0) & + (tr_meta.finalFrame == tr_meta.finalFrame.max())].trackId.values + ) + + ta_ids = list(ta_ids_set - parked_vehicles) + + arguments = [ + (ta_id, rec_id, out_dir, fr_dict, tr_meta, + tr, ln_graph, current_set, *outer_args) for ta_id in ta_ids + ] + + if args.use_threads: + n_workers = 1 + cpu_count = os.cpu_count() + if cpu_count is None: + warnings.warn("Could not determine the number of CPU cores. Using 1 thread.") + elif cpu_count <= 2: + warnings.warn("The number of CPU cores is too low. 
Using 1 thread.") + else: + n_workers = cpu_count + + with Pool(n_workers, initializer=init_worker, + initargs=(save_id_counter, save_lock)) as pool: + with tqdm(total=len(ta_ids), desc=f"{current_set.capitalize()}", + position=1, leave=False) as pbar: + for _ in pool.imap_unordered(worker_function, arguments): + pbar.update() + else: + for arg in tqdm(arguments, desc=f"{current_set.capitalize()}", position=1, leave=False): + process_id(*arg) + + +def init_worker(counter, lock): + # Attach the counter and lock to the worker + global worker_counter, worker_lock + worker_counter, worker_lock = counter, lock + + +def worker_function(arg: tuple) -> None: + # Wrapper function to call extract_by_frame with multiple arguments + return process_id(*arg) + + +def erase_previous_line(double_jump: bool = False): + """Erase the previous line in the terminal.""" + sys.stdout.write("\x1b[1A") # Move the cursor up one line + sys.stdout.write("\x1b[2K") # Clear the entire line + if double_jump: + sys.stdout.write("\x1b[1A") + + +if __name__ == "__main__": + if args.debug: + print("DEBUG MODE: ON \n") + + output_dir = create_directories(args) + print(f"Output directory: {output_dir} \n") + + with open("preprocessing/configs/" + args.dataset + ".json", "r", + encoding="utf-8") as conf_file: + config = json.load(conf_file) + + random_seed = config["seed"] + np.random.seed(random_seed) + + rec_ids = [] + recordings = config["recordings"] + for key, value in recordings.items(): + if value["include"]: + rec_ids.append(key) + + if args.dataset == "inD": + temp_path = os.path.join(args.path, args.dataset, "maps", "lanelets") + dirs = os.listdir(temp_path) + + # check if "01_bendplatz_constuction" is in the directory + if "01_bendplatz_constuction" in dirs: + # fix spelling error in directory + os.rename( + os.path.join(temp_path, "01_bendplatz_constuction"), + os.path.join(temp_path, "01_bendplatz_construction"), + ) + + elif args.dataset == "uniD": + temp_path = os.path.join(args.path, args.dataset, "maps") + dirs = os.listdir(temp_path) + + # check if lanelet directory is named "lanelet" instead of "lanelets" + if "lanelet" in dirs: + # update name in directory for consistency + os.rename(os.path.join(temp_path, "lanelet"), os.path.join(temp_path, "lanelets")) + + try: + for r_id in tqdm(rec_ids, desc="Main process: ", position=0, leave=True): + print(f"Preprocessing started for recording {r_id}...") + + # Get the approximate geographical center of the scene + p0 = (recordings[r_id]["x0"], recordings[r_id]["y0"]) + + # Construct the base directory path for your data + base_dir = os.path.join(args.path, args.dataset, "data") + + # Use os.path.join for each specific file + rec_meta_path = os.path.join(base_dir, f"{r_id}_recordingMeta.csv") + tracks_meta_path = os.path.join(base_dir, f"{r_id}_tracksMeta.csv") + tracks_path = os.path.join(base_dir, f"{r_id}_tracks.csv") + + # Read the CSV files + rec_meta = pd.read_csv(rec_meta_path) + tracks_meta = pd.read_csv(tracks_meta_path) + tracks = pd.read_csv(tracks_path) + + # For the lanelet file, construct the path similarly + location = recordings[r_id]["location"] + path_to_lanelet = os.path.join(args.path, args.dataset, "maps", "lanelets", location) + osm_file = os.listdir(path_to_lanelet)[0] + lanelet_path = os.path.join(path_to_lanelet, osm_file) + + lane_graph = get_lanelet_graph(rec_meta, lanelet_path, p0[0], p0[1], return_torch=True) + + # Perform some initial renaming + if "vx" not in tracks.columns: + tracks.rename(columns={"xVelocity": "vx"}, inplace=True) + 
tracks.rename(columns={"yVelocity": "vy"}, inplace=True) + tracks.rename(columns={"xAcceleration": "ax"}, inplace=True) + tracks.rename(columns={"yAcceleration": "ay"}, inplace=True) + if "psi" not in tracks.columns: + tracks.rename(columns={"heading": "psi"}, inplace=True) + # convert all psi values to radians and wrap to pi + radians = np.deg2rad(tracks.psi) + tracks.psi = np.arctan2(np.sin(radians), np.cos(radians)) + if "x" not in tracks.columns: + tracks.rename(columns={"xCenter": "x"}, inplace=True) + tracks.rename(columns={"yCenter": "y"}, inplace=True) + tracks.x = tracks.x - p0[0] + tracks.y = tracks.y - p0[1] + + # Make class lowercase in tracks_meta + tracks_meta["class"] = tracks_meta["class"].str.lower() + + # Determine train, val, test split (by frames) + train_frames, val_frames, test_frames = \ + get_frame_split(tracks_meta.finalFrame.array[-1], seed=random_seed) + frame_dict = {"train": train_frames, "val": val_frames, "test": test_frames} + + shared_args = (r_id, output_dir, frame_dict, tracks_meta, tracks, lane_graph) + + tasks = [("train",) + shared_args, ("val",) + shared_args, ("test",) + shared_args] + + # Erase preprocessing message + erase_previous_line() + + # Print and immediately erase a "done" message (as an example) + print("Preprocessing completed.") + time.sleep(1) # Just to let the user see the message + erase_previous_line(True) + + for task in tasks: + process_ids(*task) + + except KeyboardInterrupt: + print("Interrupted.") + + finally: + print("Finished.") diff --git a/preprocessing/utils/__init__.py b/preprocessing/utils/__init__.py new file mode 100644 index 0000000..6643008 --- /dev/null +++ b/preprocessing/utils/__init__.py @@ -0,0 +1,5 @@ +from preprocessing.utils.common import * +from preprocessing.utils.lanelet_graph import get_lanelet_graph +from preprocessing.utils.highway_graph import get_highway_graph +from preprocessing.utils.highway_utils import * +from preprocessing.utils.exit_utils import * diff --git a/preprocessing/utils/common.py b/preprocessing/utils/common.py new file mode 100644 index 0000000..633446e --- /dev/null +++ b/preprocessing/utils/common.py @@ -0,0 +1,296 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import shutil +from typing import Optional +from argparse import Namespace + +import torch +import numpy as np +from pandas import DataFrame +from scipy.signal import decimate +from sklearn.model_selection import train_test_split + + +def create_directories(args: Namespace) -> str: + """Create directories for processed data.""" + + data_dir = args.dataset + args.add_name + output_dir = args.output_dir + "/" + data_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + elif args.debug: + return output_dir + else: + print(f"Folder {output_dir} already exists.") + inp = input("Would you like to overwrite it? 
(y/n) \n").lower() + if inp == "y": + print("Overwriting folder...") + # Clear folder + try: + shutil.rmtree(output_dir) + except FileNotFoundError as e: + print(f"Failed to delete {output_dir}. Reason: {e}") + else: + os.makedirs(output_dir) + else: + print("Exiting...") + sys.exit() + + # create subdirectories + if not os.path.exists(output_dir + "/train"): + os.makedirs(output_dir + "/train") + if not os.path.exists(output_dir + "/val"): + os.makedirs(output_dir + "/val") + if not os.path.exists(output_dir + "/test"): + os.makedirs(output_dir + "/test") + + return output_dir + + +def get_frame_split(n_frames: int, seed: int = 42) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Split the frames into train, validation, and test sets.""" + + all_frames = list(range(1, n_frames + 1)) + + # Divide all frames into ten lists of equal length + frame_lists = np.array_split(all_frames, 10) + + # Split the lists into train (80%), validation (10%), and test (10%) sets + train, valtest = train_test_split(frame_lists, test_size=0.2, random_state=seed) + val, test = train_test_split(valtest, test_size=0.5, random_state=seed) + + # Sort the order of the arrays in the list and concatenate them + train = np.concatenate(sorted(train, key=lambda x: x[0])) + val = np.concatenate(sorted(val, key=lambda x: x[0])) + test = np.concatenate(sorted(test, key=lambda x: x[0])) + return train, val, test + + +def update_frames(agent_frames: np.ndarray, + alt_set1: np.ndarray, alt_set2: np.ndarray) -> np.ndarray: + """ + Remove frames from the agent_frames that are in alt_set1 or alt_set2 + :param agent_frames: array of frames that are in the agent set + :param alt_set1: array of frames that are in an alternative set + :param alt_set2: array of frames that are in another alternative set + :return: + """ + assert agent_frames[-1] > agent_frames[0], "The frames are not in the correct order" + idx_b = np.isin(agent_frames, alt_set1) + idx_c = np.isin(agent_frames, alt_set2) + idx = np.logical_or(idx_b, idx_c) + return agent_frames[~idx] + + +def class_list_to_int_list(class_list: list[str]) -> list[int]: + """ + Convert a list of class names to a list of integers + :param class_list: + :return: + """ + class_to_int = { + 'car': 0, + 'van': 0, + 'trailer': 0, + 'truck': 1, + 'truck_bus': 2, + 'bus': 2, + 'motorcycle': 3, + 'bicycle': 4, + 'pedestrian': 5, + 'animal': 6, + } + return [class_to_int[c] for c in class_list] + + +def get_other_sets(current_set: str) -> Optional[list[str]]: + """ + Get the other sets (train, val, test) based on the current set. + :param current_set: The current set + :return: The other sets + """ + match current_set: + case 'train': + return ['val', 'test'] + case 'val': + return ['train', 'test'] + case 'test': + return ['train', 'val'] + case _: + return None + + +def get_neighbors(df: DataFrame, frame: int, id0: int, + driving_dir: Optional[int] = None) -> DataFrame: + """ + Get the vehicles (except id0) present at a given frame. 
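+    :param df: The trajectory data
+    :param frame: The frame number
+    :param id0: The trackId of the target agent (excluded from the result)
+    :param driving_dir: Optional driving direction used to filter the agents (highway data)
+    :return: DataFrame with the surrounding agents at the given frame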
+ """ + if driving_dir is None: + df1 = df[(df.frame == frame) & (df.trackId != id0)] + else: + df1 = df[(df.frame == frame) & (df.trackId != id0) & (df.drivingDirection == driving_dir)] + return df1 + + +def get_meta_property(tracks_meta: DataFrame, agent_ids: list, prop: str = 'class') -> list[str]: + """ + Get a meta property of the agent from the tracks_meta DataFrame + """ + prp = [tracks_meta[tracks_meta.trackId == v_id][prop].values[0] for v_id in agent_ids] + return prp + + +def get_maneuver(tracks: DataFrame, frame: int, + agent_ids: list, prop='maneuver') -> list[int]: + """ + Get the maneuver of the agents at a given frame + """ + prp = [tracks[(tracks.trackId == v_id) & (tracks.frame == frame)][prop].values[0] + for v_id in agent_ids] + return prp + + +def get_features(df: DataFrame, + frame_start: int, + frame_end: int, + n_features: int, + track_id: int = -1) -> np.ndarray: + """ + Get the features of the agent with id track_id in the frame range [frame_start, frame_end] + """ + + return_array = np.empty((frame_end - frame_start + 1, n_features)) + return_array[:] = np.NaN + + if track_id != -1: + dfx = df[(df.frame >= frame_start) & (df.frame <= frame_end) & (df.trackId == track_id)] + else: + dfx = df[(df.frame >= frame_start) & (df.frame <= frame_end)] + try: + first_frame = dfx.frame.values[0] + except IndexError: + return return_array + frame_offset = first_frame - frame_start + + features = dfx[['x', 'y', 'vx', 'vy', 'ax', 'ay', 'psi']].to_numpy() + + return_array[frame_offset:frame_offset + features.shape[0], :] = features + + return return_array + + +def decimate_nan(x: np.ndarray, + pad_order: str = 'front', + ds_factor: int = 5, + fz: int = 25, + max_s: float = 1.0, + filter_order: int = 7) -> np.ndarray: + decimation_bound = max_s * fz + target_ds_len = int(x.shape[1] / ds_factor) + y = np.zeros((x.shape[0], target_ds_len, x.shape[2])) + + not_nan_idx = ~np.isnan(x[..., 0]) + decimation_check = not_nan_idx.sum(axis=1) >= decimation_bound + decimation_idx = np.where(decimation_check)[0] + lazy_idx = np.where(~decimation_check)[0] + + for idx in decimation_idx: + arr = x[idx] + + # slice out the non-nan values + arr = arr[not_nan_idx[idx]][::-1] + + # decimate the array + arr = decimate(arr, ds_factor, n=filter_order, axis=0, zero_phase=True)[::-1] + + # pad the array with NaNs to the target length along the first axis + if pad_order == 'front': + n_pad = ((target_ds_len - arr.shape[0], 0), (0, 0)) + arr = np.pad(arr, n_pad, mode='constant', constant_values=np.NaN) + elif pad_order == 'back': + n_pad = ((0, target_ds_len - arr.shape[0]), (0, 0)) + arr = np.pad(arr, n_pad, mode='constant', constant_values=np.NaN) + else: + raise ValueError("pad_order should be either 'front' or 'back'") + + y[idx] = arr + + for idx in lazy_idx: + y[idx] = x[idx, -1:0:-ds_factor][::-1] + + return y + + +def get_masks(x: torch.Tensor, + y: torch.Tensor, + ma_frames: int = 15, + non_scored_ids: Optional[torch.Tensor] = None, + k_max: int = 8) -> list[torch.Tensor]: + """ + Get masks for the input and output tensors. + There are four different masks that are created: + + 1. input_mask: Mask for the input tensor (contains True for valid values (not NaNs)) + + 2. valid_mask: Mask for the output tensor (contains True for valid values (not NaNs)). + This mask is recommended to be used for the loss calculation during training. + + 3. ta_mask: Mask for the target agent (contains True for the target agent). + This mask is used to identify the target agent in the single-target task. 
+    It is recommended to be used for quantitative metrics calculation for validation and testing.
+
+    4. ma_mask: Mask for the multi-agent task (contains True for the surrounding agents).
+    This mask is used to identify the surrounding agents in the multi-target task.
+    It is recommended to be used for quantitative metrics calculation for validation and testing.
+    For the multi-target task, the 8 closest agents to the target agent are selected.
+    The requirement is that the agent should be visible for at least 3 seconds into the future.
+
+    :param x: The input tensor
+    :param y: The output tensor
+    :param ma_frames: The number of frames for the multi-agent mask (should represent 3 seconds)
+    :param non_scored_ids: The ids of the agents that should not be scored
+    :param k_max: The maximum number of agents to consider for the multi-agent mask
+    :return: input_mask, valid_mask, ta_mask, ma_mask
+    """
+    input_mask = torch.isnan(x).sum(dim=2) == 0
+    valid_mask = torch.isnan(y).sum(dim=2) == 0
+
+    # Target agent mask. This will always have an adequate number of samples by design
+    ta_mask = torch.zeros_like(valid_mask)
+    ta_mask[0] = True
+
+    # Multi-agent mask
+
+    # Only interested in indices where valid mask is True for at least 3 s
+    long_idx = torch.where(valid_mask.sum(dim=1) >= ma_frames)[0]
+
+    # Remove non-scored agents from long_idx (e.g., parked vehicles)
+    if non_scored_ids is not None:
+        long_idx = long_idx[~torch.isin(long_idx, non_scored_ids)]
+
+    # Find the (at most) k_max closest agents
+    dist = torch.norm(x[long_idx, -1, :2] - x[0, -1, :2], dim=1)
+    max_k = min(k_max, len(long_idx))  # topk fails if k exceeds the number of candidates
+    _, indices = torch.topk(dist, max_k, largest=False)
+
+    intersection = long_idx[indices]
+
+    ma_mask = torch.zeros_like(valid_mask)
+    ma_mask[intersection] = True
+    ma_mask = ma_mask & valid_mask  # only keep indices where valid mask is True
+
+    return [input_mask, valid_mask, ta_mask, ma_mask]
diff --git a/preprocessing/utils/highway_graph.py b/preprocessing/utils/highway_graph.py
new file mode 100644
index 0000000..0e5b088
--- /dev/null
+++ b/preprocessing/utils/highway_graph.py
@@ -0,0 +1,160 @@
+# Copyright 2024, Theodor Westny. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import numpy as np
+import pandas as pd
+
+
+def section_graph(lane_markings: list[float],
+                  road_x: tuple[float, float],
+                  spacing: float = 3.0,
+                  direction: int = 1) -> dict:
+    """
+    Create a graph representation of a section of a highway.
+    :param lane_markings: list of y-coordinates of the lane markings
+    :param road_x: tuple of the min. and max. x-coordinates of the road
+    :param spacing: spacing between the points
+    :param direction: direction of the road:
+        1 positive right (lower section), -1 positive left (upper section)
+    :return: a dict with 'map_point' node data and ('map_point', 'to', 'map_point') edge data
+    """
+
+    pos = []
+    node_type = []
+
+    # We would like to have a FLU coordinate system.
+ # We move the origin to the bottom right corner of the road + if direction == 1: + y0 = lane_markings[0] + else: + y0 = lane_markings[-1] + + n_markings = len(lane_markings) + + x_l = np.arange(road_x[0], road_x[1], spacing) + for j, lmi in enumerate(lane_markings): + y_l = np.abs(np.ones(x_l.shape) * lmi - y0) + xi = np.stack((x_l, y_l), axis=1) + pos.append(xi) + if j in (0, n_markings - 1): + node_cls = np.ones_like(x_l) * 2 # 2 is the type for road boundary nodes + else: + node_cls = np.ones_like(x_l) # 1 is the type for lane line nodes + node_type.append(node_cls) + + pos_arr = np.concatenate(pos, axis=0) + node_attr = np.concatenate(node_type, axis=0) + + map_data: dict = { + 'map_point': {}, + ('map_point', 'to', 'map_point'): {} + } + + map_data['map_point']['num_nodes'] = pos_arr.shape[0] + map_data['map_point']['type'] = torch.from_numpy(node_attr).long() + # map_data['map_point']['y0'] = y0 + # map_data['map_point']['driving_dir'] = direction + map_data['map_point']['position'] = torch.from_numpy(pos_arr).float() + + nodes_per_lane = len(x_l) + node_idx = 0 + edge_index = [] + edge_attr = [] + for j in range(n_markings): + edges = np.array([[i, i + 1] for i in range(nodes_per_lane - 1)] + + [[i + 1, i] for i in range(nodes_per_lane - 1)]).T + node_idx + edge_index.append(edges) + node_idx += nodes_per_lane + if j in (0, n_markings - 1): + edge_cls = np.ones((len(edges[0]), 1)) * 2 # 2 is the type for road boundaries + else: + edge_cls = np.ones((len(edges[0]), 1)) # 1 is the type for lane lines + edge_attr.append(edge_cls) + + edge_index = np.concatenate(edge_index, axis=1) + edge_attr = np.concatenate(edge_attr, axis=0) + + map_data['map_point', 'to', 'map_point']['edge_index'] = torch.from_numpy(edge_index).long() + map_data['map_point', 'to', 'map_point']['type'] = torch.from_numpy(edge_attr).float() + + return map_data + + +def get_highway_graph(rec_meta: pd.DataFrame, + tracks: pd.DataFrame, + spacing: float = 3.0, + buffer: float = 10.0) -> tuple[dict, dict, float, float]: + """ + Get the graph representation of the highway from the lane markings. 
+ :param rec_meta: meta dataframe of the recording (used to get the lane markings) + :param tracks: trajectory dataframe of the recording (used to get the range of x values) + :param spacing: spacing between the lane graph nodes + :param buffer: buffer to add to the range of x values + :return: + """ + + ulm = [float(l) for l in list(rec_meta['upperLaneMarkings'])[0].split(';')] + llm = [float(l) for l in list(rec_meta['lowerLaneMarkings'])[0].split(';')] + + x_min = tracks.x.min() + x_max = tracks.x.max() + + norm_max = int(x_max - x_min) + buffer + norm_min = - buffer + + # make sure the range is divisible by spacing + norm_max = norm_max + (spacing - norm_max % spacing) + norm_min = norm_min - (norm_min % spacing) + + data_ulm = section_graph(ulm, (norm_min, norm_max), spacing, direction=-1) + data_llm = section_graph(llm, (norm_min, norm_max), spacing, direction=1) + return data_ulm, data_llm, x_min, x_max + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + + i = 4 + rec_idx = f"0{str(i)}" if i < 10 else str(i) + ROOT = "../../data_sets/highD/data" + + recording_meta = pd.read_csv(f"{ROOT}/{rec_idx}_recordingMeta.csv") + tracks_csv = pd.read_csv(f"{ROOT}/{rec_idx}_tracks.csv") + + data_upper, data_lower, *_ = get_highway_graph(recording_meta, tracks_csv, spacing=10) + + plot_data = data_upper + + # plt.figure(figsize=(20, 5)) + + # plot upper lane markings using edge_index + for i in range(plot_data['map_point', 'to', 'map_point']['edge_index'].shape[1]): + source = plot_data['map_point', 'to', 'map_point']['edge_index'][0, i] + target = plot_data['map_point', 'to', 'map_point']['edge_index'][1, i] + source_pos = plot_data['map_point']['position'][source] + target_pos = plot_data['map_point']['position'][target] + COLOR = 'k' if plot_data['map_point', 'to', 'map_point']['type'][i] == 2 else 'grey' + plt.plot([source_pos[0], target_pos[0]], + [source_pos[1], target_pos[1]], color=COLOR, zorder=1) + + # plot all points + for i in range(plot_data['map_point']['position'].shape[0]): + COLOR = 'r' if plot_data['map_point']['type'][i] == 2 else 'b' + plt.scatter(plot_data['map_point']['position'][i, 0], + plot_data['map_point']['position'][i, 1], color=COLOR, s=5, zorder=2) + + # plt.axis('equal') + plt.show() diff --git a/preprocessing/utils/highway_utils.py b/preprocessing/utils/highway_utils.py new file mode 100644 index 0000000..cb634ec --- /dev/null +++ b/preprocessing/utils/highway_utils.py @@ -0,0 +1,269 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
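+
+# Helper routines for the highway recordings: aligning bounding-box coordinates
+# with the vehicle center, adding heading/maneuver/driving-direction features,
+# and flipping signs so that all tracks share a common forward-left (FLU) frame.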
+
+import numpy as np
+from pandas import DataFrame, unique
+
+
+def align_origin_w_centroid(tracks_meta: DataFrame,
+                            tracks: DataFrame, debug: bool = False) -> DataFrame:
+    """
+    The coordinates are given with respect to the upper-left corner of the bounding box.
+    This function modifies the dataframe such that the coordinates are aligned
+    with the center of the bounding box.
+    """
+    if debug:
+        return tracks
+
+    ids = tracks_meta.trackId
+    driving_dirs = tracks_meta.drivingDirection
+    for i, dd in zip(ids, driving_dirs):
+        tracks.loc[tracks.trackId == i, 'y'] += tracks.loc[tracks.trackId == i, 'height'] / 2
+        if dd == 2:
+            tracks.loc[tracks.trackId == i, 'x'] += tracks.loc[tracks.trackId == i, 'width'] / 2
+    return tracks
+
+
+def add_heading_feat(tracks: DataFrame, debug: bool = False) -> DataFrame:
+    """
+    Add heading as a feature to the tracks dataframe
+    """
+    tracks['psi'] = np.empty(len(tracks))
+
+    if debug:
+        return tracks
+
+    t_ids = unique(tracks.trackId)
+    for t_id in t_ids:
+        vy = tracks.loc[tracks.trackId == t_id, 'vy'].to_numpy()
+        vx = tracks.loc[tracks.trackId == t_id, 'vx'].to_numpy()
+        psi = np.arctan2(vy, vx)
+        tracks.loc[tracks['trackId'] == t_id, ['psi']] = psi
+    return tracks
+
+
+def add_maneuver(tracks_meta: DataFrame, tracks: DataFrame,
+                 fz: int = 25, debug: bool = False) -> DataFrame:
+    """
+    Add maneuver as a feature to the tracks dataframe.
+
+    There are 7 different maneuvers:
+    0: left lane change within the next (1) second
+    1: left lane change within the next 3 seconds
+    2: left lane change within the next 5 seconds
+    3: no lane change (lane keep)
+    4: right lane change within the next (1) second
+    5: right lane change within the next 3 seconds
+    6: right lane change within the next 5 seconds
+    """
+
+    lane_change_left = [0, 1, 2]
+    lane_change_right = [4, 5, 6]
+
+    tracks['maneuver'] = np.ones(len(tracks), dtype=int) * 3
+
+    if debug:
+        return tracks
+
+    t_ids = unique(tracks.trackId)
+    for t_id in t_ids:
+        if int(tracks_meta[tracks_meta.trackId == t_id].numLaneChanges.iloc[0]) > 0:
+            df = tracks[tracks.trackId == t_id]
+            dr_dir = df.drivingDirection.values[0]
+            frames = df.frame.to_numpy()
+            lanes = df.laneId.to_numpy()
+            event_indices = [i for i in range(1, len(lanes)) if lanes[i] != lanes[i - 1]]
+
+            for event_index in event_indices:
+                five_seconds_prior = range(max(0, event_index - 5 * fz + 1), event_index)
+                three_seconds_prior = range(max(0, event_index - 3 * fz + 1), event_index)
+                one_second_prior = range(max(0, event_index - fz + 1), event_index)
+
+                five_second_frames = (frames[i] for i in five_seconds_prior)
+                three_second_frames = (frames[i] for i in three_seconds_prior)
+                one_second_frames = (frames[i] for i in one_second_prior)
+
+                delta_lane = lanes[event_index] - lanes[event_index - 1]
+                if dr_dir == 1:
+                    # Traveling from right to left. Lane index increases
+                    maneuvers = lane_change_left if delta_lane > 0 else lane_change_right
+                else:
+                    # Traveling from left to right.
Lane index increases + maneuvers = lane_change_right if delta_lane > 0 else lane_change_left + + tracks.loc[(tracks['trackId'] == t_id) & + (tracks['frame'].isin(five_second_frames)), ['maneuver']] = maneuvers[2] + tracks.loc[(tracks['trackId'] == t_id) & + (tracks['frame'].isin(three_second_frames)), ['maneuver']] = maneuvers[1] + tracks.loc[(tracks['trackId'] == t_id) & + (tracks['frame'].isin(one_second_frames)), ['maneuver']] = maneuvers[0] + + return tracks + + +def add_driving_direction(tracks_meta: DataFrame, tracks: DataFrame) -> DataFrame: + """ + Add driving direction (1 or 2) as a feature to the tracks dataframe. + If driving direction is 1, the vehicle is driving from right to left (negative x). + If driving direction is 2, the vehicle is driving from left to right (positive x). + """ + + tracks['drivingDirection'] = np.empty(len(tracks)) + t_ids = unique(tracks.trackId) + for t_id in t_ids: + driving_direction = tracks_meta[tracks_meta.trackId == t_id].drivingDirection.values[0] + tracks.loc[tracks['trackId'] == t_id, ['drivingDirection']] = driving_direction + return tracks + + +def add_displacement_feat(rec_meta: DataFrame, tracks_meta: DataFrame, + tracks: DataFrame, + debug: bool = False) -> DataFrame: + """ + Add roadDisplacement and laneDisplacement as features to the tracks dataframe. + These features are used to determine the relative position + of the vehicle with respect to the road and the lane. + They could potentially replace the lane graph and be added + to the input features of the model. + """ + + ulm = [float(l) for l in list(rec_meta['upperLaneMarkings'])[0].split(';')] + llm = [float(l) for l in list(rec_meta['lowerLaneMarkings'])[0].split(';')] + + def compute_road_w(): + upper_l = ulm[-1] - ulm[0] + lower_l = llm[-1] - llm[0] + return upper_l, lower_l + + def compute_lane_w(): + upper_l = np.mean(np.diff(ulm)) + lower_l = np.mean(np.diff(llm)) + return np.mean([upper_l, lower_l]) + + def get_road_edge_markings(): + return ulm[0], llm[0] + + def get_lane_markings(): + combined = ulm + llm + return np.array(combined) + + def get_dyl(y, dd, lm, lw): + dy = 2 * (y - lm) / lw - 1 + if dd == 2: + dy *= (-1) + return dy + + def get_dy(y, dd, curr_lane_id, lm, lw): + dy = 2 * (y - lm[curr_lane_id - 2]) / lw - 1 + if dd == 2: + dy *= (-1) + return dy + + tracks['roadDisplacement'] = np.empty(len(tracks)) + tracks['laneDisplacement'] = np.empty(len(tracks)) + + if debug: + return tracks + + ur, lr = get_road_edge_markings() + ruw, rlw = compute_road_w() + + lm = get_lane_markings() + lw = compute_lane_w() + t_ids = unique(tracks.trackId) + for t_id in t_ids: + driving_dir = int(tracks_meta[tracks_meta.trackId == t_id].drivingDirection.iloc[0]) + lane_ids = tracks.loc[tracks.trackId == t_id, 'laneId'].to_numpy() + y = tracks.loc[tracks.trackId == t_id, 'y'].to_numpy() + d_y = get_dy(y, driving_dir, lane_ids, lm, lw) + + marking, width = (ur, ruw) if driving_dir == 1 else (lr, rlw) + d_y_r = get_dyl(y, driving_dir, marking, width) + + tracks.loc[tracks['trackId'] == t_id, ['laneDisplacement']] = d_y + tracks.loc[tracks['trackId'] == t_id, ['roadDisplacement']] = d_y_r + return tracks + + +def get_disp_features(df: DataFrame, frame_start: int, frame_end: int, track_id=-1) -> np.ndarray: + return_array = np.empty((frame_end - frame_start + 1, 2)) + return_array[:] = np.NaN + + if track_id != -1: + dfx = df[(df.frame >= frame_start) & (df.frame <= frame_end) & (df.trackId == track_id)] + else: + dfx = df[(df.frame >= frame_start) & (df.frame <= frame_end)] + try: + 
first_frame = dfx.frame.values[0] + except IndexError: + return return_array + frame_offset = first_frame - frame_start + + features = dfx[['roadDisplacement', 'laneDisplacement']].to_numpy() + + return_array[frame_offset:frame_offset + features.shape[0], :] = features + + return return_array + + +def update_signs(rec_meta: DataFrame, tracks_meta: DataFrame, + tracks: DataFrame, debug: bool = False) -> DataFrame: + """ + We are looking to unify the coordinate system under a + FLU (frontward-leftward-upward) coordinate system (ISO standard): + Forward motion = positive x + Leftward motion = positive y + (Upward motion = positive z) + + This requires updating the tracks differently depending + on how the vehicles are moving (driving direction). + To find the origin of the FLU coordinate system, + we utilize the lower and upper lane markings. + Longitudinal motion is updated based on the driving direction. + + """ + if debug: + return tracks + + ulm = [float(x) for x in list(rec_meta['upperLaneMarkings'])[0].split(';')] + llm = [float(x) for x in list(rec_meta['lowerLaneMarkings'])[0].split(';')] + + # subtract x_min from all tracks (to make everything start/end at 0) + x_min = tracks.x.min() + tracks.x -= x_min + + x_max = tracks.x.max() + + t_ids = unique(tracks.trackId) + for t_id in t_ids: + driving_dir = tracks_meta[tracks_meta.trackId == t_id].drivingDirection.values[0] + + if driving_dir == 1: + tracks.loc[(tracks['trackId'] == t_id), ['y']] = \ + tracks.loc[(tracks['trackId'] == t_id), ['y']] - ulm[0] + tracks.loc[(tracks['trackId'] == t_id), ['x']] = \ + -tracks.loc[(tracks['trackId'] == t_id), ['x']] + x_max + tracks.loc[(tracks['trackId'] == t_id), ['vx']] = \ + -tracks.loc[(tracks['trackId'] == t_id), ['vx']] + tracks.loc[(tracks['trackId'] == t_id), ['ax']] = \ + -tracks.loc[(tracks['trackId'] == t_id), ['ax']] + else: + tracks.loc[(tracks['trackId'] == t_id), ['y']] = \ + llm[-1] - tracks.loc[(tracks['trackId'] == t_id), ['y']] + tracks.loc[(tracks['trackId'] == t_id), ['vy']] = \ + -tracks.loc[(tracks['trackId'] == t_id), ['vy']] + tracks.loc[(tracks['trackId'] == t_id), ['ay']] = \ + -tracks.loc[(tracks['trackId'] == t_id), ['ay']] + + return tracks diff --git a/preprocessing/utils/lanelet_graph.py b/preprocessing/utils/lanelet_graph.py new file mode 100644 index 0000000..3665065 --- /dev/null +++ b/preprocessing/utils/lanelet_graph.py @@ -0,0 +1,285 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +import utm +import torch +import pandas as pd +import osmium as osm +import networkx as nx + +edge_types = { + ("fence", "wall", "road_border"): "road_border", + ("curbstone",): "curb" +} + +type_conversion = { + None: 1, + "road_border": 2, + "curb": 3 +} + + +class OSMHandler(osm.SimpleHandler): + def __init__(self, + utm_x0: float, + utm_y0: float, + map_x0: float = 0., + map_y0: float = 0.) 
-> None: + osm.SimpleHandler.__init__(self) + self.graph = nx.Graph() + self.nodes: dict = {} + self.ways: dict = {} + self.relations: dict = {} + self.utm_x0 = utm_x0 + self.utm_y0 = utm_y0 + self.map_x0 = map_x0 + self.map_y0 = map_y0 + + self.node_idx = 0 + + def coordinate_shift(self, lat: float, lon: float) -> tuple[float, float]: + utm_x, utm_y, *_ = utm.from_latlon(lat, lon) + x, y = utm_x - (self.utm_x0 + self.map_x0), utm_y - (self.utm_y0 + self.map_y0) + return x, y + + def node(self, n: osm.Node) -> None: + x, y = self.coordinate_shift(n.location.lat, n.location.lon) + + self.graph.add_node(self.node_idx, pos=(x, y), type=None) + self.nodes[n.id] = self.node_idx + self.node_idx += 1 + + def way(self, w: osm.Way) -> None: + self.ways[w.id] = {"nodes": [n.ref for n in w.nodes], + "tags": {tag.k: tag.v for tag in w.tags}} + # Add edges to the graph + nodes = list(w.nodes) + for i in range(len(nodes) - 1): + from_node = self.nodes[nodes[i].ref] + to_node = self.nodes[nodes[i + 1].ref] + + if from_node in self.graph and to_node in self.graph: + self.graph.add_edge(from_node, to_node, type=None) + + for tag in w.tags: + if tag.k == 'type': + for key, value in edge_types.items(): + if tag.v in key: + for i in range(len(nodes) - 1): + this_node = self.nodes[nodes[i].ref] + next_node = self.nodes[nodes[i + 1].ref] + if this_node in self.graph: + self.graph.nodes[this_node]['type'] = value + if next_node in self.graph: + # check if edge exists + if not self.graph.has_edge(this_node, next_node): + self.graph.add_edge(this_node, next_node, type=value) + else: + # update type + self.graph.edges[this_node, next_node]['type'] = value + + def relation(self, r: osm.Relation) -> None: + self.relations[r.id] = {"members": [(m.type, m.ref, m.role) for m in r.members], + "tags": {tag.k: tag.v for tag in r.tags}} + + +# @overload +# def get_urban_graph(rec_meta: pd.DataFrame, +# lanelet_file: str, +# map_x0: float = 0., +# map_y0: float = 0., +# return_torch: bool = True) -> dict: +# ... +# +# +# @overload +# def get_urban_graph(rec_meta: pd.DataFrame, +# lanelet_file: str, +# map_x0: float = 0., +# map_y0: float = 0., +# return_torch: bool = False) -> OSMHandler: +# ... 
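+# If the commented-out overloads above are ever revived, typing.overload with
+# Literal flags would let callers see the concrete return type instead of Any.
+# A minimal sketch only -- the signatures mirror get_lanelet_graph below and are
+# not part of this patch:
+#
+# from typing import Literal, overload
+#
+# @overload
+# def get_lanelet_graph(rec_meta: pd.DataFrame, lanelet_file: str,
+#                       map_x0: float = 0., map_y0: float = 0.,
+#                       *, return_torch: Literal[True]) -> dict: ...
+#
+# @overload
+# def get_lanelet_graph(rec_meta: pd.DataFrame, lanelet_file: str,
+#                       map_x0: float = 0., map_y0: float = 0.,
+#                       return_torch: Literal[False] = False) -> 'OSMHandler': ...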
+ + +def get_lanelet_graph(rec_meta: pd.DataFrame, + lanelet_file: str, + map_x0: float = 0., + map_y0: float = 0., + return_torch: bool = False) -> Any: # OSMHandler | dict: + meta = rec_meta + x_utm_origin = meta.xUtmOrigin.values[0] + y_utm_origin = meta.yUtmOrigin.values[0] + + osm_handler = OSMHandler(x_utm_origin, y_utm_origin, map_x0, map_y0) + osm_handler.apply_file(lanelet_file) + + if return_torch: + return get_torch_graph(osm_handler.graph) + + return osm_handler + + +def get_torch_graph(graph: nx.Graph) -> dict: + graph = graph.to_directed(as_view=False) # this naming convention is the opposite from PyG + num_nodes = graph.number_of_nodes() + + pos = nx.get_node_attributes(graph, 'pos') + position = torch.tensor([pos[i] for i in range(num_nodes)], dtype=torch.float) + + node_types = nx.get_node_attributes(graph, 'type') + node_types = torch.tensor([type_conversion[node_types[i]] + for i in range(num_nodes)], dtype=torch.long) + + e_types = nx.get_edge_attributes(graph, 'type') + edge_index = torch.tensor(list(graph.edges), dtype=torch.long).t().contiguous() + edge_attr = torch.tensor([type_conversion[e_types[(u, v)]] + for u, v in graph.edges], dtype=torch.long) + + map_data: dict = { + 'map_point': {}, + ('map_point', 'to', 'map_point'): {} + } + + map_data['map_point']['num_nodes'] = num_nodes + map_data['map_point']['type'] = node_types + # map_data['map_point']['driving_dir'] = 0 + # map_data['map_point']['y0'] = 0.0 + map_data['map_point']['position'] = position + + map_data['map_point', 'to', 'map_point']['edge_index'] = edge_index + map_data['map_point', 'to', 'map_point']['type'] = edge_attr[:, None] + + return map_data + + +ind_lanelet_mapping = { + tuple(f"{i:02}" for i in range(7, 10 + 1)): "01_bendplatz", + tuple(f"{i:02}" for i in range(11, 17 + 1)): "01_bendplatz_construction", + tuple(f"{i:02}" for i in range(18, 29 + 1)): "02_frankenburg", + tuple(f"{i:02}" for i in range(30, 32 + 1)): "03_heckstrasse", + tuple(f"{i:02}" for i in range(0, 6 + 1)): "04_aseag" +} + +round_lanelet_mapping = { + ("00",): "1_kackertstrasse", + ("01",): "2_thiergarten", + tuple(f"{i:02}" for i in range(2, 23 + 1)): "0_neuweiler", +} + +unid_lanelet_mapping = { + tuple(f"{i:02}" for i in range(0, 12 + 1)): "0_superc", +} + +ds_mapping = { + "rounD": round_lanelet_mapping, + "inD": ind_lanelet_mapping, + "uniD": unid_lanelet_mapping + +} + +if __name__ == "__main__": + import os + import matplotlib.pyplot as plt + from matplotlib.transforms import Bbox + + ROOT = "../../../data_sets" + + IDX = 30 + DS = "inD" + STR_IDX = '0' + str(IDX) if IDX < 10 else str(IDX) + + meta_file_pth = f"data/{STR_IDX}_recordingMeta.csv" + meta_file_pth = os.path.join(ROOT, DS, meta_file_pth) + + meta = pd.read_csv(meta_file_pth) + + PATH = None + mapping = ds_mapping[DS] + + for key, path in mapping.items(): + if STR_IDX in key: + PATH = path + break + + if PATH is None: + raise ValueError(f"Index {IDX} not found in mapping") + + path_to_lanelet = os.path.join(ROOT, DS, "maps", "lanelets", PATH) + files = os.listdir(path_to_lanelet) + path = os.path.join(path_to_lanelet, files[0]) + + osm_handler = get_lanelet_graph(meta, path, return_torch=False) + + # # Get all 'relation' unique types + # unique_types = set() + # unique_subtypes = set() + # for way in osm_handler.relations.values(): + # for k, v in way["tags"].items(): + # if k == "type": + # unique_types.add(v) + # else: + # unique_subtypes.add(v) + # print(unique_types) + # print(unique_subtypes) + + pos = nx.get_node_attributes(osm_handler.graph, 
'pos') + + # get all road_border nodes + road_border_nodes = [n for n, attr in osm_handler.graph.nodes(data=True) + if attr['type'] == 'road_border'] + curb_nodes = [n for n, attr in osm_handler.graph.nodes(data=True) + if attr['type'] == 'curb'] + + # get all edges of type road_border + road_border_edges = [(u, v) for u, v, attr in osm_handler.graph.edges(data=True) + if attr['type'] == 'road_border'] + curb_edges = [(u, v) for u, v, attr in osm_handler.graph.edges(data=True) + if attr['type'] == 'curb'] + lane_edges = [(u, v) for u, v, attr in osm_handler.graph.edges(data=True) + if attr['type'] is None] + + # get all non road_border nodes + other = [n for n, attr in osm_handler.graph.nodes(data=True) if attr['type'] is None] + + fig = plt.figure() + + NS = 0.8 + + nx.draw_networkx_nodes(osm_handler.graph, pos=pos, + nodelist=other, node_color="tab:blue", node_size=NS) + nx.draw_networkx_nodes(osm_handler.graph, pos=pos, + nodelist=curb_nodes, node_color="tab:green", node_size=NS) + nx.draw_networkx_nodes(osm_handler.graph, pos=pos, + nodelist=road_border_nodes, node_color="tab:red", node_size=NS) + + nx.draw_networkx_edges(osm_handler.graph, pos=pos, + edgelist=road_border_edges, edge_color='k', width=0.5) + nx.draw_networkx_edges(osm_handler.graph, pos=pos, + edgelist=curb_edges, edge_color='grey', width=0.5) + nx.draw_networkx_edges(osm_handler.graph, pos=pos, + edgelist=lane_edges, edge_color='grey', width=0.5, + style='dashed') + + plt.gca().set_aspect('equal', adjustable='box') + plt.axis('off') + plt.tight_layout() + + bbox = fig.bbox_inches + bbox_points = [[bbox.x0 + 0.5, bbox.y0 + 1], [bbox.x1 - 0.5, bbox.y1 - 0.8]] + bbox = Bbox(bbox_points) + + # plt.savefig("in_graph.pdf", bbox_inches=bbox, pad_inches=0) + + plt.show() diff --git a/train.py b/train.py new file mode 100644 index 0000000..1dca5ed --- /dev/null +++ b/train.py @@ -0,0 +1,122 @@ +# Copyright 2024, Theodor Westny. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
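+# Example invocation, for reference only. The flag spellings are assumed from the
+# argparse attributes used below and from train.sh; the config path is one of the
+# files added in this patch:
+#   python train.py --config configs/rounD/example.json --add-name Dev --use-cuda 1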
+ + +import time +import pathlib +import warnings +import torch + +from torch.multiprocessing import set_sharing_strategy +from lightning.pytorch import Trainer, seed_everything +from lightning.pytorch.callbacks import Callback, ModelCheckpoint +from lightning.pytorch.loggers import Logger, WandbLogger +from lightning.pytorch.strategies import Strategy, DDPStrategy + +from arguments import args +from preamble import load_config, import_from_module + +torch.set_float32_matmul_precision('medium') +warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") +warnings.filterwarnings("ignore", ".*Checkpoint directory*") + +set_sharing_strategy('file_system') + +# Load configuration and import modules +config = load_config(args.config) +TorchModel = import_from_module(config["model"]["module"], config["model"]["class"]) +LitDataModule = import_from_module(config["datamodule"]["module"], config["datamodule"]["class"]) +LitModel = import_from_module(config["litmodule"]["module"], config["litmodule"]["class"]) + + +def main(save_name: str) -> None: + ds = config["dataset"] + ckpt_path = pathlib.Path(f"saved_models/{ds}/{save_name}.ckpt") + + # Check if checkpoint exists and the overwrite flag is not set + if ckpt_path.exists() and not args.overwrite: + ckpt = str(ckpt_path) + else: + ckpt = None + + # Setup callbacks list for training + callback_list: list[Callback] = [] + if args.store_model: + ckpt_cb = ModelCheckpoint( + dirpath=str(ckpt_path.parent), # Using parent directory of the checkpoint + filename=save_name + "_{epoch:02d}", + ) + + ckpt_cb_best = ModelCheckpoint( + dirpath=str(ckpt_path.parent), + filename=save_name, + monitor="val_loss", + mode="min" + ) + + callback_list += [ckpt_cb, ckpt_cb_best] + + # Determine the number of devices, strategy and accelerator + strategy: str | Strategy + if torch.cuda.is_available() and args.use_cuda: + devices = -1 if torch.cuda.device_count() > 1 else 1 + strategy = DDPStrategy(find_unused_parameters=True, + gradient_as_bucket_view=True) if devices == -1 else 'auto' + accelerator = "auto" + else: + devices, strategy, accelerator = 1, 'auto', "cpu" + + # Setup logger + logger: bool | Logger + if args.dry_run: + logger = False + args.small_ds = True + elif not args.use_logger: + logger = False + else: + run_name = f"{save_name}_{time.strftime('%d-%m_%H:%M:%S')}" + logger = WandbLogger(project="dronalize", name=run_name) + + # Setup model, datamodule and trainer + net = TorchModel(config["model"]) + model = LitModel(net, config["training"]) + + if args.root: + config["datamodule"]["root"] = args.root + + datamodule = LitDataModule(config["datamodule"], args) + + trainer = Trainer(max_epochs=config["training"]["epochs"], + logger=logger, + devices=devices, + strategy=strategy, + accelerator=accelerator, + callbacks=callback_list, + fast_dev_run=args.dry_run, + enable_checkpointing=args.store_model) + + # Start training + trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt) + + +if __name__ == "__main__": + seed_everything(args.seed, workers=True) + + ds_name = config["dataset"] + full_save_name = f"Example{args.add_name}-{ds_name}" + print('----------------------------------------------------') + print(f'\nGetting ready to train model: {full_save_name} \n') + print('----------------------------------------------------') + + main(full_save_name) diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..29f20cd --- /dev/null +++ b/train.sh @@ -0,0 +1 @@ +apptainer run --nv apptainer/dronalize.sif 
python train.py --add-name Dev --dry-run 1 --use-cuda 1 --n-workers 4
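Note on downstream use: get_torch_graph() in preprocessing/utils/lanelet_graph.py returns a plain dict keyed like a PyTorch Geometric heterogeneous graph. A minimal sketch of wrapping it, assuming torch_geometric is installed and reusing the meta and path variables from the __main__ example in that file (not part of this patch):

    from torch_geometric.data import HeteroData

    map_data = get_lanelet_graph(meta, path, return_torch=True)
    graph = HeteroData(map_data)  # node store 'map_point', edge store ('map_point', 'to', 'map_point')
    print(graph['map_point'].num_nodes, graph['map_point', 'to', 'map_point'].edge_index.shape)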