diff --git a/examples/qm9/qm9.py b/examples/qm9/qm9.py index ee72cddd..2e589215 100644 --- a/examples/qm9/qm9.py +++ b/examples/qm9/qm9.py @@ -45,6 +45,10 @@ def main(model_type=None): with open(filename, "r") as f: config = json.load(f) + # If a model type is provided, update the configuration accordingly. + if model_type: + config["NeuralNetwork"]["Architecture"]["model_type"] = model_type + verbosity = config["Verbosity"]["level"] var_config = config["NeuralNetwork"]["Variables_of_interest"] @@ -73,10 +77,6 @@ def main(model_type=None): config, train_loader, val_loader, test_loader ) - # If a model type is provided, update the configuration accordingly. - if model_type: - config["NeuralNetwork"]["Architecture"]["model_type"] = model_type - model = hydragnn.models.create_model_config( config=config["NeuralNetwork"], verbosity=verbosity,