diff --git a/src/openqdc/utils/preprocess.py b/src/openqdc/utils/preprocess.py new file mode 100644 index 0000000..1142dca --- /dev/null +++ b/src/openqdc/utils/preprocess.py @@ -0,0 +1,56 @@ +import click +import numpy as np +from loguru import logger + +from openqdc import datasets + +options = [ + datasets.ANI1, + datasets.ANI1CCX, + datasets.ANI1X, + datasets.COMP6, + datasets.DESS, + datasets.GDML, + datasets.GEOM, + datasets.ISO17, + datasets.Molecule3D, + datasets.NablaDFT, + datasets.OrbnetDenali, + datasets.PCQM_B3LYP, + datasets.PCQM_PM6, + datasets.QM7X, + datasets.QMugs, + datasets.SN2RXN, + datasets.SolvatedPeptides, + datasets.Spice, + datasets.TMQM, + datasets.Transition1X, + datasets.WaterClusters, +] + +options_map = {d.__name__: d for d in options} + + +@click.command() +@click.option("--dataset", "-d", type=str, default="ani1", help="Dataset name or index.") +def preprocess(dataset): + if dataset not in options_map: + dataset_id = int(dataset) + + data_class = options[dataset_id] + data_class().preprocess() + data = data_class() + logger.info(f"Preprocessing {data.__name__}") + + n = len(data) + for i in np.random.choice(n, 3, replace=False): + x = data[i] + print(x.name, x.subset, end=" ") + for k in x: + if x[k] is not None: + print(k, x[k].shape, end=" ") + print() + + +if __name__ == "__main__": + preprocess() diff --git a/tests/test_import.py b/tests/test_import.py new file mode 100644 index 0000000..0736e30 --- /dev/null +++ b/tests/test_import.py @@ -0,0 +1,2 @@ +def test_open_qdc(): + import openQDC # noqa