Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
prtos committed Oct 26, 2023
1 parent 00d2904 commit 42e8db8
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/openqdc/utils/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import click
import numpy as np
from loguru import logger

from openqdc import datasets

options = [
datasets.ANI1,
datasets.ANI1CCX,
datasets.ANI1X,
datasets.COMP6,
datasets.DESS,
datasets.GDML,
datasets.GEOM,
datasets.ISO17,
datasets.Molecule3D,
datasets.NablaDFT,
datasets.OrbnetDenali,
datasets.PCQM_B3LYP,
datasets.PCQM_PM6,
datasets.QM7X,
datasets.QMugs,
datasets.SN2RXN,
datasets.SolvatedPeptides,
datasets.Spice,
datasets.TMQM,
datasets.Transition1X,
datasets.WaterClusters,
]

options_map = {d.__name__: d for d in options}


@click.command()
@click.option("--dataset", "-d", type=str, default="ani1", help="Dataset name or index.")
def preprocess(dataset):
if dataset not in options_map:
dataset_id = int(dataset)

data_class = options[dataset_id]
data_class().preprocess()
data = data_class()
logger.info(f"Preprocessing {data.__name__}")

n = len(data)
for i in np.random.choice(n, 3, replace=False):
x = data[i]
print(x.name, x.subset, end=" ")
for k in x:
if x[k] is not None:
print(k, x[k].shape, end=" ")
print()


if __name__ == "__main__":
preprocess()
2 changes: 2 additions & 0 deletions tests/test_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def test_open_qdc():
import openQDC # noqa

0 comments on commit 42e8db8

Please sign in to comment.