Trajectory Flow Matching

Description

Trajectory Flow Matching (TFM) is a method that applies the flow matching technique from generative modeling to time series. Training is simulation-free, which allows stochastic differential equations to be fitted efficiently to time-series data. Augmented with memory, time interval prediction, and uncertainty prediction, TFM can better model irregularly sampled, stochastic trajectories such as clinical time series.

The core idea of TFM is to use flow matching to predict both the next value in the time series and its stochastic uncertainty. The prediction is conditioned on past data and on conditioning variables.
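
To give a rough sense of why such training can be simulation-free, the sketch below builds a flow matching regression target from two consecutive observations without integrating any differential equation. The function name `flow_matching_pair`, the noise level `sigma`, and the straight-line interpolation path are assumptions made for illustration. This is not the repository's implementation, which additionally conditions on memory and predicts uncertainty.

```python
import random


def flow_matching_pair(x0, x1, t0, t1, sigma=0.0, rng=random):
    """Build one regression example from two consecutive observations.

    x0, x1: lists of floats observed at times t0 < t1.
    Returns (t, x_t, target_velocity).
    Hypothetical helper for illustration only.
    """
    if t1 <= t0:
        raise ValueError("t1 must be greater than t0")
    # Sample an intermediate time uniformly between the two observations.
    t = rng.uniform(t0, t1)
    s = (t - t0) / (t1 - t0)  # fraction of the interval elapsed, in [0, 1]
    # Point on the straight line from x0 to x1, optionally perturbed by noise.
    x_t = [(1 - s) * a + s * b + sigma * rng.gauss(0.0, 1.0)
           for a, b in zip(x0, x1)]
    # Velocity of the straight-line path: constant and known in closed form,
    # so no ODE/SDE solver is needed to obtain the training target.
    target = [(b - a) / (t1 - t0) for a, b in zip(x0, x1)]
    return t, x_t, target


# Two 2-dimensional observations taken 4 time units apart.
t, x_t, u = flow_matching_pair([0.0, 1.0], [2.0, 1.0], t0=0.0, t1=4.0)
```

Under these assumptions, a network taking `t`, `x_t`, and any conditioning inputs would be regressed onto `u` with a squared-error loss. Because the time gap `t1 - t0` enters the target directly, unevenly spaced observations need no special handling in this sketch.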

How to run

Initialize environment

# clone project
git clone https://github.com/nZhangx/TrajectoryFlowMatching.git
cd TrajectoryFlowMatching

# [OPTIONAL] create conda environment
conda create -n tfm python=3.10
conda activate tfm

# install requirements
conda env create -f environment.yml

Run experiments

Under src, create a new DATA_NAME.yml under conf/data and a new MODEL_NAME.yml under conf/model with the desired configurations. Then replace the data and model definitions in conf/config.yaml with your DATA_NAME and MODEL_NAME, and run:

python src/main.py

Demo

We have included an example of TFM modeling three crossing oscillations in a self-contained Jupyter notebook, notebook/3Oscillation.ipynb.
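
The notebook holds the actual data and model. For readers who want comparable toy data without opening it, the sketch below generates three phase-shifted sinusoids observed at irregular times. The function name `sample_oscillations`, the phase offsets, and the sampling scheme are all assumptions chosen for illustration and may differ from what the notebook uses.

```python
import math
import random


def sample_oscillations(n_points=20, t_max=10.0, seed=0):
    """Return (times, trajectories) for three crossing sinusoids.

    Hypothetical toy data for illustration only.
    """
    rng = random.Random(seed)
    # Irregular sampling: sorted random observation times in [0, t_max].
    times = sorted(rng.uniform(0.0, t_max) for _ in range(n_points))
    # Three sinusoids with evenly spaced phase shifts, so their paths cross.
    phases = [0.0, 2.0 * math.pi / 3.0, 4.0 * math.pi / 3.0]
    trajectories = [[math.sin(t + p) for t in times] for p in phases]
    return times, trajectories


times, trajs = sample_oscillations()
```

Each entry of `trajs` is one trajectory evaluated at the shared, unevenly spaced `times`, which mirrors the irregular sampling that TFM is designed to handle.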

Implemented models

| Model | ICU Sepsis | ICU Cardiac Arrest | ICU GIB | ED GIB |
| --- | --- | --- | --- | --- |
| NeuralODE | 4.776 $\pm$ 0.000 | 6.153 $\pm$ 0.000 | 3.170 $\pm$ 0.000 | 10.859 $\pm$ 0.000 |
| FM baseline ODE | 4.671 $\pm$ 0.791 | 10.207 $\pm$ 1.076 | 118.439 $\pm$ 17.947 | 11.923 $\pm$ 1.123 |
| LatentODE-RNN | 61.806 $\pm$ 46.573 | 386.190 $\pm$ 558.140 | 422.886 $\pm$ 431.954 | 980.228 $\pm$ 1032.393 |
| TFM-ODE (ours) | 0.793 $\pm$ 0.017 | 2.762 $\pm$ 0.021 | 2.673 $\pm$ 0.069 | 8.245 $\pm$ 0.495 |
| NeuralSDE | 4.747 $\pm$ 0.000 | 3.250 $\pm$ 0.024 | 3.186 $\pm$ 0.000 | 10.850 $\pm$ 0.043 |
| TFM (ours) | 0.796 $\pm$ 0.026 | 2.755 $\pm$ 0.015 | 2.596 $\pm$ 0.079 | 8.613 $\pm$ 0.260 |

Available datasets

We plan to share the clinical data we used on PhysioNet. The data come from the eICU Collaborative Research Database v2.0 (ICU Sepsis and ICU Cardiac Arrest) and the Medical Information Mart for Intensive Care III (MIMIC-III) critical care database (ICU GIB).

How to cite

This repository contains the code to reproduce the main experiments and illustrations of the preprint Trajectory Flow Matching with Applications to Clinical Time Series Modeling. We are excited that it was marked as a spotlight presentation.

If you find this code useful in your research, please cite:

@article{TFM,
	title        = {Trajectory Flow Matching with Applications to Clinical Time Series Modelling},
	author       = {Zhang, Xi and Pu, Yuan and Kawamura, Yuki and Loza, Andrew and Bengio, Yoshua and Shung, Dennis and Tong, Alexander},
	year         = 2024,
	journal      = {NeurIPS},
}

License

This repo is licensed under the MIT License.
