jAM (JAX implementation of Action Matching)
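Action Matching is a method for learning the time evolution of distributions from samples. That is, suppose we observe the time evolution of some random variable $X_t$ with density $q_t$ only through samples drawn from $q_t$ at different times $t$.
The key idea is to learn a function $s_t(x)$, the action, whose gradient $\nabla_x s_t(x)$ is the vector field that moves samples along this evolution, i.e. the densities satisfy the continuity equation $\frac{\partial q_t}{\partial t} = -\nabla_x \cdot \big(q_t \nabla_x s_t\big)$.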
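To make the key idea concrete, here is a minimal JAX sketch of the Action Matching objective. It assumes time is rescaled to $t \in [0, 1]$ and drops an additive constant that does not depend on $s$:

$$\mathcal{L}(s) = \mathbb{E}_{q_0}[s_0(x)] - \mathbb{E}_{q_1}[s_1(x)] + \int_0^1 \mathbb{E}_{q_t}\Big[\tfrac{1}{2}\|\nabla_x s_t(x)\|^2 + \partial_t s_t(x)\Big]\,dt.$$

Minimizing it over $s$ drives $\nabla_x s_t$ toward the true vector field. New samples then come from integrating $\frac{dx}{dt} = \nabla_x s_t(x)$ starting at samples of $q_0$. The toy quadratic `action` and the helpers `am_loss` and `euler_sample` are hypothetical illustrations, not this repository's code. A learned model such as a neural network would take the place of the toy action.

```python
import jax
import jax.numpy as jnp


def action(params, t, x):
    # Hypothetical toy action s_t(x) = 0.5 * (a + b * t) * ||x||^2.
    # A learned model would replace this in practice.
    a, b = params[0], params[1]
    return 0.5 * (a + b * t) * jnp.sum(x ** 2)


# Per-sample derivatives of the action.
grad_x = jax.grad(action, argnums=2)  # nabla_x s_t(x)
grad_t = jax.grad(action, argnums=1)  # d s_t(x) / dt


def am_loss(params, x0, x1, t, xt):
    """Monte Carlo estimate of the Action Matching objective.

    x0: samples from q_0, shape (n, d); x1: samples from q_1, shape (m, d);
    t: times in [0, 1], shape (k,); xt: one sample from q_t per time, shape (k, d).
    """
    s0 = jax.vmap(action, in_axes=(None, None, 0))(params, 0.0, x0)
    s1 = jax.vmap(action, in_axes=(None, None, 0))(params, 1.0, x1)
    gx = jax.vmap(grad_x, in_axes=(None, 0, 0))(params, t, xt)
    dt_s = jax.vmap(grad_t, in_axes=(None, 0, 0))(params, t, xt)
    kinetic = 0.5 * jnp.sum(gx ** 2, axis=-1)  # 0.5 * ||nabla_x s_t||^2
    return s0.mean() - s1.mean() + (kinetic + dt_s).mean()


def euler_sample(params, x0, n_steps=100):
    """Push samples of q_0 forward with dx/dt = nabla_x s_t(x) (explicit Euler)."""
    dt = 1.0 / n_steps

    def step(x, i):
        t = i * dt
        v = jax.vmap(grad_x, in_axes=(None, None, 0))(params, t, x)
        return x + dt * v, None

    x1, _ = jax.lax.scan(step, x0, jnp.arange(n_steps))
    return x1
```

In training, the times `t` would be drawn with `jax.random.uniform`, and `jax.grad(am_loss)` gives the parameter gradient to hand to an optimizer.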
| Method | Link |
|---|---|
| Action Matching | |
| Entropic (Stochastic) Action Matching | |
| Unbalanced Action Matching (with reweighting) | |
Run all the code using `main.py`, choosing the experiment with the `config` flag and the task with the `mode` flag:

- `config` takes the path to the config file.
- `mode` takes one of the following values: `"train"`, `"eval"`, `"fid_stats"`. All modes require a config file. Mind the `data.uniform_dequantization` flag when computing the dataset statistics for FID evaluation.
- `workdir` is the path to the working directory for storing states.
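The README does not define `data.uniform_dequantization`. As an assumption, the sketch below shows what uniform dequantization commonly means for 8-bit images: adding $u \sim \mathcal{U}[0, 1)$ noise to each integer pixel value before rescaling by 256. Whether this noise is added changes the images the FID reference statistics are computed on, which is presumably why the flag matters for the `fid_stats` mode. `uniform_dequantize` is a hypothetical helper, not this repository's data pipeline.

```python
import jax
import jax.numpy as jnp


def uniform_dequantize(key, images_uint8):
    """Turn 8-bit pixel values {0, ..., 255} into continuous values in [0, 1].

    Each integer k becomes (k + u) / 256 with u ~ U[0, 1), so it lands in the
    bin [k / 256, (k + 1) / 256); float32 rounding can at most reach 1.0.
    """
    noise = jax.random.uniform(key, images_uint8.shape)  # u ~ U[0, 1), float32
    return (images_uint8.astype(jnp.float32) + noise) / 256.0
```

Dividing by 256 rather than 255 keeps the bins of neighbouring pixel values disjoint, so each continuous value still identifies its original integer.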
For instance, on a cluster with Slurm, you would run the code like this inside your sbatch script:
```bash
python main.py --config configs/am/cifar/generation.py \
  --workdir $PWD/checkpoint/${SLURM_JOB_ID} \
  --mode 'train'
```