Skip to content

Commit

Permalink
Forces test (ORNL#283)
Browse files Browse the repository at this point in the history
* force tests, which required model arg, and some typo fixing in Lennard Jones

* Add PNAPlus since it uses positions as well

* formatting
  • Loading branch information
RylieWeaver committed Sep 25, 2024
1 parent 886007e commit db705ca
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 3 deletions.
3 changes: 3 additions & 0 deletions examples/LennardJones/LJ.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
"int_emb_size": 32,
"out_emb_size": 16,
"basis_emb_size": 8,
"num_gaussians": 10,
"num_filters": 8,
"num_before_skip": 1,
"num_after_skip": 1,
"envelope_exponent": 5,
Expand Down Expand Up @@ -55,6 +57,7 @@
"Training": {
"num_epoch": 15,
"batch_size": 64,
"perc_train": 0.7,
"patience": 20,
"early_stopping": true,
"Optimizer": {
Expand Down
4 changes: 2 additions & 2 deletions examples/LennardJones/LJ_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ def __init__(self, dirpath, config, dist=False, sampling=None):

for file in rx:
filepath = os.path.join(dirpath, file)
self.dataset.append(self.transform_inumpyut_to_data_object_base(filepath))
self.dataset.append(self.transform_input_to_data_object_base(filepath))

def transform_inumpyut_to_data_object_base(self, filepath):
def transform_input_to_data_object_base(self, filepath):

# Using readline()
file = open(filepath, "r")
Expand Down
8 changes: 7 additions & 1 deletion examples/LennardJones/LennardJones.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
help="preprocess only (no training)",
)
parser.add_argument("--inputfile", help="input file", type=str, default="LJ.json")
parser.add_argument("--model_type", help="model type", type=str, default=None)
parser.add_argument("--mae", action="store_true", help="do mae calculation")
parser.add_argument("--ddstore", action="store_true", help="ddstore dataset")
parser.add_argument("--ddstore_width", type=int, help="ddstore width", default=None)
Expand Down Expand Up @@ -98,6 +99,11 @@
# Configurable run choices (JSON file that accompanies this example script).
with open(input_filename, "r") as f:
config = json.load(f)
config["NeuralNetwork"]["Architecture"]["model_type"] = (
args.model_type
if args.model_type
else config["NeuralNetwork"]["Architecture"]["model_type"]
)
verbosity = config["Verbosity"]["level"]
config["NeuralNetwork"]["Variables_of_interest"][
"graph_feature_names"
Expand Down Expand Up @@ -159,7 +165,7 @@
## This is a local split
trainset, valset, testset = split_dataset(
dataset=total,
perc_train=0.9,
perc_train=config["NeuralNetwork"]["Training"]["perc_train"],
stratify_splitting=False,
)
print("Local splitting: ", len(total), len(trainset), len(valset), len(testset))
Expand Down
27 changes: 27 additions & 0 deletions tests/test_forces_equivariant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
##############################################################################
# Copyright (c) 2024, Oak Ridge National Laboratory #
# All rights reserved. #
# #
# This file is part of HydraGNN and is distributed under a BSD 3-clause #
# license. For the licensing terms see the LICENSE file in the top-level #
# directory. #
# #
# SPDX-License-Identifier: BSD-3-Clause #
##############################################################################

import os
import pytest

import subprocess


@pytest.mark.parametrize("example", ["LennardJones"])
@pytest.mark.parametrize("model_type", ["SchNet", "EGNN", "DimeNet", "PNAPlus"])
@pytest.mark.mpi_skip()
def pytest_examples(example, model_type):
path = os.path.join(os.path.dirname(__file__), "..", "examples", example)
file_path = os.path.join(path, example + ".py") # Assuming different model scripts
return_code = subprocess.call(["python", file_path, "--model_type", model_type])

# Check the file ran without error.
assert return_code == 0

0 comments on commit db705ca

Please sign in to comment.