From 98d5a941cc67a0ba1c647f746ca0bf091c6f076d Mon Sep 17 00:00:00 2001 From: allaffa Date: Wed, 27 Nov 2024 00:51:04 -0500 Subject: [PATCH] model_type overwrite in config moved to right location in the code --- examples/qm9/qm9.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/qm9/qm9.py b/examples/qm9/qm9.py index ee72cddd..2e589215 100644 --- a/examples/qm9/qm9.py +++ b/examples/qm9/qm9.py @@ -45,6 +45,10 @@ def main(model_type=None): with open(filename, "r") as f: config = json.load(f) + # If a model type is provided, update the configuration accordingly. + if model_type: + config["NeuralNetwork"]["Architecture"]["model_type"] = model_type + verbosity = config["Verbosity"]["level"] var_config = config["NeuralNetwork"]["Variables_of_interest"] @@ -73,10 +77,6 @@ def main(model_type=None): config, train_loader, val_loader, test_loader ) - # If a model type is provided, update the configuration accordingly. - if model_type: - config["NeuralNetwork"]["Architecture"]["model_type"] = model_type - model = hydragnn.models.create_model_config( config=config["NeuralNetwork"], verbosity=verbosity,