Commit: edge_dim added as argument to GATv2Conv (#310)
* edge_dim added as argument to GATv2Conv (see the sketch after the change summary below)

* GAT added in edge_models

* added model_type as optional input argument to qm9 and md17 examples

* edge_dim passed into GATv2Conv stack inside create method

* architectural arguments added to vectoroutput CI test

* num_samples variable moved outside main function scope in qm9 example

* SAGE, GIN, and MFC removed from examples where there are edge features

* Correct management of node degree on GPUs

* split examples test based on whether the model needs to use data.pos or not

* model_type overwrite in config moved to right location in the code

* comment for allowed stacks in LJ force_grad

* Add MACE to test

* black formatting

---------

Co-authored-by: Rylie Weaver <rylieweaver9@gmail.com>
allaffa and RylieWeaver authored Nov 27, 2024
1 parent ca29030 commit b935c88
Showing 13 changed files with 316 additions and 189 deletions.
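For context on the headline change: torch_geometric's GATv2Conv already accepts an edge_dim constructor argument that lets the attention mechanism condition on edge features, so exposing it through HydraGNN's config amounts to passing it into every layer of the GAT stack. The actual HydraGNN wiring lives in files not expanded below; this is only a minimal sketch of the pattern, with illustrative class and argument names:

import torch
from torch_geometric.nn import GATv2Conv


class GATStack(torch.nn.Module):
    """Illustrative only: a stack of GATv2Conv layers sharing one edge_dim."""

    def __init__(self, in_dim, hidden_dim, num_conv_layers, heads=1, edge_dim=None):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        # With concat=True (the default), each layer outputs hidden_dim * heads
        # features, so every layer after the first consumes that width.
        dims = [in_dim] + [hidden_dim * heads] * (num_conv_layers - 1)
        for d in dims:
            # edge_dim=None disables edge attributes; an integer makes the
            # attention scores also depend on edge features of that size.
            self.convs.append(GATv2Conv(d, hidden_dim, heads=heads, edge_dim=edge_dim))

    def forward(self, x, edge_index, edge_attr=None):
        for conv in self.convs:
            x = conv(x, edge_index, edge_attr=edge_attr)
        return x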
12 changes: 12 additions & 0 deletions examples/md17/md17.json
@@ -11,6 +11,18 @@
         "periodic_boundary_conditions": false,
         "hidden_dim": 5,
         "num_conv_layers": 6,
+        "int_emb_size": 32,
+        "out_emb_size": 16,
+        "basis_emb_size": 8,
+        "num_gaussians": 10,
+        "num_filters": 8,
+        "num_before_skip": 1,
+        "num_after_skip": 1,
+        "envelope_exponent": 5,
+        "max_ell": 1,
+        "node_max_ell": 1,
+        "num_radial": 5,
+        "num_spherical": 2,
         "output_heads": {
             "graph":{
                 "num_sharedlayers": 2,
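The keys added above (and identically to examples/qm9/qm9.json further down) are not consumed by GAT itself; they pre-populate the config with the hyperparameters that other stacks require, so the same JSON can drive whichever model_type is passed on the command line. A rough map of the likely consumers, inferred from the standard meaning of these names in SchNet-, DimeNet-, and MACE-style models (an assumption; HydraGNN's exact wiring is in files not expanded here):

# Assumed mapping of the new config keys to the stacks that typically
# consume them; inferred from standard hyperparameter names, not from
# the HydraGNN source shown on this page.
NEW_ARCH_KEYS = {
    "num_gaussians": "SchNet-style radial basis size",
    "num_filters": "SchNet-style interaction/filter width",
    "int_emb_size": "DimeNet-style interaction embedding size",
    "out_emb_size": "DimeNet-style output embedding size",
    "basis_emb_size": "DimeNet-style basis embedding size",
    "num_before_skip": "DimeNet-style residual blocks before the skip",
    "num_after_skip": "DimeNet-style residual blocks after the skip",
    "envelope_exponent": "DimeNet-style smooth-cutoff envelope",
    "num_radial": "DimeNet-style radial basis functions",
    "num_spherical": "DimeNet-style spherical harmonics",
    "max_ell": "MACE-style spherical harmonic order (edges)",
    "node_max_ell": "MACE-style spherical harmonic order (nodes)",
}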
182 changes: 98 additions & 84 deletions examples/md17/md17.py
@@ -1,23 +1,20 @@
-import os, json
-
+import os
+import json
 import torch
 
-# FIX random seed
-random_state = 0
-torch.manual_seed(random_state)
-
 import torch_geometric
+import argparse
 
 # deprecated in torch_geometric 2.0
 try:
     from torch_geometric.loader import DataLoader
-except:
+except ImportError:
     from torch_geometric.data import DataLoader
 
 import hydragnn
 
 
 # Update each sample prior to loading.
-def md17_pre_transform(data):
+def md17_pre_transform(data, compute_edges):
     # Set descriptor as element type.
     data.x = data.z.float().view(-1, 1)
     # Only predict energy (index 0 of 2 properties) for this run.
@@ -33,78 +30,95 @@ def md17_pre_filter(data):
     return torch.rand(1) < 0.25
 
 
-# Set this path for output.
-try:
-    os.environ["SERIALIZED_DATA_PATH"]
-except:
-    os.environ["SERIALIZED_DATA_PATH"] = os.getcwd()
-
-# Configurable run choices (JSON file that accompanies this example script).
-filename = os.path.join(os.path.dirname(__file__), "md17.json")
-with open(filename, "r") as f:
-    config = json.load(f)
-verbosity = config["Verbosity"]["level"]
-arch_config = config["NeuralNetwork"]["Architecture"]
-var_config = config["NeuralNetwork"]["Variables_of_interest"]
-
-# Always initialize for multi-rank training.
-world_size, world_rank = hydragnn.utils.distributed.setup_ddp()
-
-log_name = "md17_test"
-# Enable print to log file.
-hydragnn.utils.print.print_utils.setup_log(log_name)
-
-# Use built-in torch_geometric datasets.
-# Filter function above used to run quick example.
-# NOTE: data is moved to the device in the pre-transform.
-# NOTE: transforms/filters will NOT be re-run unless the qm9/processed/ directory is removed.
-compute_edges = hydragnn.preprocess.get_radius_graph_config(arch_config)
-
-# Fix for MD17 datasets
-torch_geometric.datasets.MD17.file_names["uracil"] = "md17_uracil.npz"
-
-dataset = torch_geometric.datasets.MD17(
-    root="dataset/md17",
-    name="uracil",
-    pre_transform=md17_pre_transform,
-    pre_filter=md17_pre_filter,
-)
-train, val, test = hydragnn.preprocess.split_dataset(
-    dataset, config["NeuralNetwork"]["Training"]["perc_train"], False
-)
-(train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders(
-    train, val, test, config["NeuralNetwork"]["Training"]["batch_size"]
-)
-
-config = hydragnn.utils.input_config_parsing.update_config(
-    config, train_loader, val_loader, test_loader
-)
-
-model = hydragnn.models.create_model_config(
-    config=config["NeuralNetwork"],
-    verbosity=verbosity,
-)
-model = hydragnn.utils.distributed.get_distributed_model(model, verbosity)
-
-learning_rate = config["NeuralNetwork"]["Training"]["Optimizer"]["learning_rate"]
-optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
-scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-    optimizer, mode="min", factor=0.5, patience=5, min_lr=0.00001
-)
-
-# Run training with the given model and md17 dataset.
-writer = hydragnn.utils.model.model.get_summary_writer(log_name)
-hydragnn.utils.input_config_parsing.save_config(config, log_name)
-
-hydragnn.train.train_validate_test(
-    model,
-    optimizer,
-    train_loader,
-    val_loader,
-    test_loader,
-    writer,
-    scheduler,
-    config["NeuralNetwork"],
-    log_name,
-    verbosity,
-)
+def main(model_type=None):
+    # FIX random seed
+    random_state = 0
+    torch.manual_seed(random_state)
+
+    # Set this path for output.
+    os.environ.setdefault("SERIALIZED_DATA_PATH", os.getcwd())
+
+    # Configurable run choices (JSON file that accompanies this example script).
+    filename = os.path.join(os.path.dirname(__file__), "md17.json")
+    with open(filename, "r") as f:
+        config = json.load(f)
+
+    verbosity = config["Verbosity"]["level"]
+    arch_config = config["NeuralNetwork"]["Architecture"]
+
+    # If a model type is provided, update the configuration
+    if model_type:
+        config["NeuralNetwork"]["Architecture"]["model_type"] = model_type
+
+    # Always initialize for multi-rank training.
+    world_size, world_rank = hydragnn.utils.distributed.setup_ddp()
+
+    log_name = f"md17_test_{model_type}" if model_type else "md17_test"
+    # Enable print to log file.
+    hydragnn.utils.print.print_utils.setup_log(log_name)
+
+    # Preprocess configurations for edge computation
+    compute_edges = hydragnn.preprocess.get_radius_graph_config(arch_config)
+
+    # Fix for MD17 datasets
+    torch_geometric.datasets.MD17.file_names["uracil"] = "md17_uracil.npz"
+
+    dataset = torch_geometric.datasets.MD17(
+        root="dataset/md17",
+        name="uracil",
+        pre_transform=lambda data: md17_pre_transform(data, compute_edges),
+        pre_filter=md17_pre_filter,
+    )
+    train, val, test = hydragnn.preprocess.split_dataset(
+        dataset, config["NeuralNetwork"]["Training"]["perc_train"], False
+    )
+    (train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders(
+        train, val, test, config["NeuralNetwork"]["Training"]["batch_size"]
+    )
+
+    config = hydragnn.utils.input_config_parsing.update_config(
+        config, train_loader, val_loader, test_loader
+    )
+
+    model = hydragnn.models.create_model_config(
+        config=config["NeuralNetwork"],
+        verbosity=verbosity,
+    )
+    model = hydragnn.utils.distributed.get_distributed_model(model, verbosity)
+
+    learning_rate = config["NeuralNetwork"]["Training"]["Optimizer"]["learning_rate"]
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, mode="min", factor=0.5, patience=5, min_lr=0.00001
+    )
+
+    # Run training with the given model and md17 dataset.
+    writer = hydragnn.utils.model.model.get_summary_writer(log_name)
+    hydragnn.utils.input_config_parsing.save_config(config, log_name)
+
+    hydragnn.train.train_validate_test(
+        model,
+        optimizer,
+        train_loader,
+        val_loader,
+        test_loader,
+        writer,
+        scheduler,
+        config["NeuralNetwork"],
+        log_name,
+        verbosity,
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Run MD17 example with an optional model type."
+    )
+    parser.add_argument(
+        "--model_type",
+        type=str,
+        default=None,
+        help="Specify the model type for training (default: None).",
+    )
+    args = parser.parse_args()
+    main(model_type=args.model_type)
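A note on the design choice in the new pre_transform: PyG applies pre_transform once, in-process, while materializing the processed dataset, so the lambda closing over compute_edges works here. If the callable ever needed to be pickled (e.g. handed to worker processes), functools.partial is a drop-in, picklable alternative — a sketch, not part of this commit:

from functools import partial

# Equivalent to the lambda used in the diff above, but picklable:
pre_transform = partial(md17_pre_transform, compute_edges=compute_edges)

With the new flag, the example can then be launched with a specific stack, e.g. python examples/md17/md17.py --model_type GAT.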
12 changes: 12 additions & 0 deletions examples/qm9/qm9.json
@@ -11,6 +11,18 @@
         "periodic_boundary_conditions": false,
         "hidden_dim": 5,
         "num_conv_layers": 6,
+        "int_emb_size": 32,
+        "out_emb_size": 16,
+        "basis_emb_size": 8,
+        "num_gaussians": 10,
+        "num_filters": 8,
+        "num_before_skip": 1,
+        "num_after_skip": 1,
+        "envelope_exponent": 5,
+        "max_ell": 1,
+        "node_max_ell": 1,
+        "num_radial": 5,
+        "num_spherical": 2,
         "output_heads": {
             "graph":{
                 "num_sharedlayers": 2,
(The remaining 10 changed files are not expanded here.)