From f0a499b1f539ef1314b0ed34ac6f4810c2e60e4e Mon Sep 17 00:00:00 2001 From: FNTwin Date: Mon, 3 Jun 2024 15:10:58 +0000 Subject: [PATCH] Moved sanitization in cli --- openqdc/cli.py | 26 +++++++---- openqdc/datasets/interaction/__init__.py | 16 +++---- openqdc/datasets/potential/__init__.py | 58 ++++++++++++------------ 3 files changed, 53 insertions(+), 47 deletions(-) diff --git a/openqdc/cli.py b/openqdc/cli.py index faae4ce..1d98509 100644 --- a/openqdc/cli.py +++ b/openqdc/cli.py @@ -16,8 +16,15 @@ app = typer.Typer(help="OpenQDC CLI") +def sanitize(dictionary): + return {k.lower().replace("_", "").replace("-", ""): v for k, v in dictionary.items()} + + +SANITIZED_AVAILABLE_DATASETS = sanitize(AVAILABLE_DATASETS) + + def exist_dataset(dataset): - if dataset not in AVAILABLE_DATASETS: + if dataset not in sanitize(AVAILABLE_DATASETS): logger.error(f"{dataset} is not available. Please open an issue on Github for the team to look into it.") return False return True @@ -57,10 +64,10 @@ def download( """ for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)): if exist_dataset(dataset): - if AVAILABLE_DATASETS[dataset].no_init().is_cached() and not overwrite: + if SANITIZED_AVAILABLE_DATASETS[dataset].no_init().is_cached() and not overwrite: logger.info(f"{dataset} is already cached. Skipping download") else: - AVAILABLE_DATASETS[dataset](overwrite_local_cache=True, cache_dir=cache_dir) + SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=True, cache_dir=cache_dir) @app.command() @@ -115,18 +122,17 @@ def fetch( openqdc fetch Spice """ if datasets[0].lower() == "all": - dataset_names = AVAILABLE_DATASETS + dataset_names = list(sanitize(AVAILABLE_DATASETS).keys()) elif datasets[0].lower() == "potential": - dataset_names = AVAILABLE_POTENTIAL_DATASETS + dataset_names = list(sanitize(AVAILABLE_POTENTIAL_DATASETS).keys()) elif datasets[0].lower() == "interaction": - dataset_names = AVAILABLE_INTERACTION_DATASETS + dataset_names = list(sanitize(AVAILABLE_INTERACTION_DATASETS).keys()) else: dataset_names = datasets - for dataset in list(map(lambda x: x.lower().replace("_", ""), dataset_names)): if exist_dataset(dataset): try: - AVAILABLE_DATASETS[dataset].fetch(cache_dir, overwrite) + SANITIZED_AVAILABLE_DATASETS[dataset].fetch(cache_dir, overwrite) except Exception as e: logger.error(f"Something unexpected happended while fetching {dataset}: {repr(e)}") @@ -152,9 +158,9 @@ def preprocess( """ for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)): if exist_dataset(dataset): - logger.info(f"Preprocessing {AVAILABLE_DATASETS[dataset].__name__}") + logger.info(f"Preprocessing {SANITIZED_AVAILABLE_DATASETS[dataset].__name__}") try: - AVAILABLE_DATASETS[dataset].no_init().preprocess(upload=upload, overwrite=overwrite) + SANITIZED_AVAILABLE_DATASETS[dataset].no_init().preprocess(upload=upload, overwrite=overwrite) except Exception as e: logger.error(f"Error while preprocessing {dataset}. {e}. Did you fetch the dataset first?") raise e diff --git a/openqdc/datasets/interaction/__init__.py b/openqdc/datasets/interaction/__init__.py index b415273..ab0b212 100644 --- a/openqdc/datasets/interaction/__init__.py +++ b/openqdc/datasets/interaction/__init__.py @@ -6,12 +6,12 @@ from .x40 import X40 AVAILABLE_INTERACTION_DATASETS = { - "des5m": DES5M, - "des370k": DES370K, - "dess66": DESS66, - "dess66x8": DESS66x8, - "l7": L7, - "metcalf": Metcalf, - "splinter": Splinter, - "x40": X40, + "DES5M": DES5M, + "DES370K": DES370K, + "DESS66": DESS66, + "DESS66x8": DESS66x8, + "L7": L7, + "Metcalf": Metcalf, + "Splinter": Splinter, + "X40": X40, } diff --git a/openqdc/datasets/potential/__init__.py b/openqdc/datasets/potential/__init__.py index 5e473a8..591ddc4 100644 --- a/openqdc/datasets/potential/__init__.py +++ b/openqdc/datasets/potential/__init__.py @@ -21,33 +21,33 @@ from .waterclusters3_30 import WaterClusters AVAILABLE_POTENTIAL_DATASETS = { - "ani1": ANI1, - "ani1ccx": ANI1CCX, - "ani1ccxv2": ANI1CCX_V2, - "ani1x": ANI1X, - "comp6": COMP6, - "gdml": GDML, - "geom": GEOM, - "iso17": ISO17, - "molecule3d": Molecule3D, - "nabladft": NablaDFT, - "orbnetdenali": OrbnetDenali, - "pcqmb3lyp": PCQM_B3LYP, - "pcqmpm6": PCQM_PM6, - "qm7x": QM7X, - "qm7xv2": QM7X_V2, - "qmugs": QMugs, - "qmugsv2": QMugs_V2, - "sn2rxn": SN2RXN, - "solvatedpeptides": SolvatedPeptides, - "spice": Spice, - "spicev2": SpiceV2, - "spicevl2": SpiceVL2, - "tmqm": TMQM, - "transition1x": Transition1X, - "watercluster": WaterClusters, - "multixcqm9": MultixcQM9, - "multixcqm9v2": MultixcQM9_V2, - "revmd17": RevMD17, - "md22": MD22, + "ANI1": ANI1, + "ANI1CCX": ANI1CCX, + "ANI1CCX_V2": ANI1CCX_V2, + "ANI1X": ANI1X, + "COMP6": COMP6, + "GDML": GDML, + "GEOM": GEOM, + "ISO17": ISO17, + "Molecule3D": Molecule3D, + "NablaDFT": NablaDFT, + "OrbnetDenali": OrbnetDenali, + "PCQM_B3LYP": PCQM_B3LYP, + "PCQM_PM6": PCQM_PM6, + "QM7X": QM7X, + "QM7X_V2": QM7X_V2, + "QMugs": QMugs, + "QMugs_V2": QMugs_V2, + "SN2RXN": SN2RXN, + "SolvatedPeptides": SolvatedPeptides, + "Spice": Spice, + "SpiceV2": SpiceV2, + "SpiceVL2": SpiceVL2, + "TMQM": TMQM, + "Transition1X": Transition1X, + "WaterClusters": WaterClusters, + "MultixcQM9": MultixcQM9, + "MultixcQM9_V2": MultixcQM9_V2, + "RevMD17": RevMD17, + "MD22": MD22, }