From 7378819870b8f1de3f10c96eb8e053f0381fa247 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20G=C3=B3mez?= Date: Tue, 24 Aug 2021 12:02:21 +0200 Subject: [PATCH] - Fixed alignment in print_cfg - Added reg_loss_weight as parameter in cfg --- nidn/training/run_training.py | 4 +++- nidn/training/utils/validate_config.py | 3 +++ nidn/utils/print_cfg.py | 4 ++-- nidn/utils/resources/default_config.toml | 1 + 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/nidn/training/run_training.py b/nidn/training/run_training.py index aea7722..c12cddc 100644 --- a/nidn/training/run_training.py +++ b/nidn/training/run_training.py @@ -138,7 +138,9 @@ def run_training( loss += spectrum_loss if run_cfg.type == "classification" and run_cfg.use_regularization_loss: - loss += 0.05 * _likelihood_regularization_loss_fn(material_ids, run_cfg.L) + loss += run_cfg.reg_loss_weight * _likelihood_regularization_loss_fn( + material_ids, run_cfg.L + ) # We store the model if it has the lowest loss yet # (this is to avoid losing good results during a run that goes wild) diff --git a/nidn/training/utils/validate_config.py b/nidn/training/utils/validate_config.py index 25a7f3a..aac765d 100644 --- a/nidn/training/utils/validate_config.py +++ b/nidn/training/utils/validate_config.py @@ -32,6 +32,7 @@ def _validate_config(cfg: DotMap): "absorption_loss", "type", "use_regularization_loss", + "reg_loss_weight", "add_noise", "noise_scale", ] @@ -59,6 +60,7 @@ def _validate_config(cfg: DotMap): "imag_max_eps", "siren_omega", "noise_scale", + "reg_loss_weight", ] boolean_keys = ["use_regularization_loss", "add_noise"] string_keys = ["model_type", "type"] @@ -99,6 +101,7 @@ def _validate_config(cfg: DotMap): "iterations", "eps_oversampling", "noise_scale", + "reg_loss_weight", ] for key in positive_value_keys: if not (cfg[key] > 0): diff --git a/nidn/utils/print_cfg.py b/nidn/utils/print_cfg.py index 7d1e456..ac040b7 100644 --- a/nidn/utils/print_cfg.py +++ b/nidn/utils/print_cfg.py @@ -21,8 +21,8 @@ def print_cfg(cfg: DotMap): print() else: if idx % 3 == 2: - print(f"{key:<20}: {value:<20}") + print(f"{key:<23}: {value:<15}|") else: - print(f"{key:<20}: {value:<20}", end="") + print(f"{key:<23}: {value:<15}|", end="") idx += 1 print() diff --git a/nidn/utils/resources/default_config.toml b/nidn/utils/resources/default_config.toml index fc43eaf..ff6cae4 100644 --- a/nidn/utils/resources/default_config.toml +++ b/nidn/utils/resources/default_config.toml @@ -4,6 +4,7 @@ model_type = "siren" iterations = 2000 learning_rate = 6.5e-5 type = "classification" # "classification" or "regression" +reg_loss_weight = 0.05 # weighting of the regularization loss use_regularization_loss = true # only relevant for classification # Loss