diff --git a/README.md b/README.md index 96943fd..8248583 100644 --- a/README.md +++ b/README.md @@ -5,12 +5,12 @@ This repository is based on the official implementation of [Discrete Point Flow] # Environment -This repository requires: +We provide all necessary requirements in form of a `environment.yml`. -- pytorch -- ... - -We further provide all necessary requirements in for of a `requirements.txt`. +For our evaluation we rely on the efficient implementation of the EMD metric provided by [PointFlow](https://github.com/stevenygd/PointFlow). +To this end, we refer to the installation instructions provided there. +Alternatively, the precompiled code can be downloaded [here](https://drive.google.com/drive/folders/1jFo6gSuQNjVq-8oB0iZ2YkMRFVP7t8GC?usp=sharing), +which needs to be unzipped and placed in `lib/metrics/` and is expected to work with the provided `environment.yml`. # Datasets @@ -38,14 +38,12 @@ Since the preprocessing takes up to a week, we provide the preprocessed datasets # Pretrained models All pretrained models including the corresponding config files can be downloaded [here](https://drive.google.com/drive/folders/1fkVBVqxy2_zTevwd3WdnROPreYke-zuU?usp=sharing). -To use the models, you need to download the models and put the files in the root directory `./`. -Then, specify the `path2data` storing preprocessed data and path2save directory storing all saved -checkpoints. - +To use the models during evaluation, specify your path to the preprocessed data `path2data` in the configs of the pretrained models. + # Training All training configurations can be found in `configs/`. Prior to training/evaluation remember to set -`path2data` in the resp. config file accordingly. +`path2data` in the resp. config file accordingly. Note, `path2save` specifies the logging directory and defaults to `./results`. ## Generative modeling diff --git a/configs/config_generative_modeling_airplane.yaml b/configs/config_generative_modeling_airplane.yaml index 612178e..356ac89 100644 --- a/configs/config_generative_modeling_airplane.yaml +++ b/configs/config_generative_modeling_airplane.yaml @@ -41,11 +41,11 @@ num_workers: 8 p_decoder_base_type: free p_decoder_base_var: -3.9551 p_decoder_n_features: 64 -p_decoder_n_flows: 1 #21 +p_decoder_n_flows: 21 p_latent_space_size: 3 p_prior_n_layers: 1 params_reduce_mode: depth_and_feature -path2data: /usr/local/google/home/postels/research/data/ShapeNet # your/path/to/data +path2data: your/path/to/data path2save: ./results pc_enc_init_n_channels: 3 pc_enc_init_n_features: 64 diff --git a/scripts/run_evaluate_ae.sh b/scripts/run_evaluate_ae.sh index d4d26dc..edbe37d 100644 --- a/scripts/run_evaluate_ae.sh +++ b/scripts/run_evaluate_ae.sh @@ -1 +1 @@ -python evaluate_ae.py ./configs/config_autoencoding.yaml path_to_trained_model test 2048 2048 autoencoding --weights_type learned_weights --reps 1 --f1_threshold_lst 0.0001 --cd --f1 --emd +python evaluate_ae.py path_to_experiment_root name_of_trained_model test 2048 2048 autoencoding --weights_type learned_weights --reps 1 --f1_threshold_lst 0.0001 --cd --f1 --emd diff --git a/scripts/run_evaluate_gen.sh b/scripts/run_evaluate_gen.sh index 9c080e5..ce82943 100644 --- a/scripts/run_evaluate_gen.sh +++ b/scripts/run_evaluate_gen.sh @@ -1 +1 @@ -python evaluate_ae.py ./configs/config_generative_modeling_airplane.yaml path_to_trained_model test 2048 2048 generating --weights_type learned_weights --reps 10 --f1_threshold_lst 0.0001 --cd --emd +python evaluate_ae.py path_to_experiment_root name_of_trained_model test 2048 2048 generating --weights_type learned_weights --reps 10 --f1_threshold_lst 0.0001 --cd --emd diff --git a/scripts/run_evaluate_svr.sh b/scripts/run_evaluate_svr.sh index 5c9badf..f7ac153 100644 --- a/scripts/run_evaluate_svr.sh +++ b/scripts/run_evaluate_svr.sh @@ -1 +1 @@ -python evaluate_ae.py ./configs/config_SVR.yaml path_to_trained_model test 2500 2500 reconstruction --weights_type learned_weights --reps 1 --f1_threshold_lst 0.001 --cd --f1 --emd --unit_scale_evaluation +python evaluate_ae.py path_to_experiment_root name_of_trained_model test 2500 2500 reconstruction --weights_type learned_weights --reps 1 --f1_threshold_lst 0.001 --cd --f1 --emd --unit_scale_evaluation