
NablaDFT dependency #13

Closed
shenoynikhil opened this issue Oct 21, 2023 · 1 comment

shenoynikhil (Collaborator) commented Oct 21, 2023

https://github.com/OpenDrugDiscovery/openQDC/blob/a86f884aa124d5a5d8745d030ac24bf129ef4fa5/src/openqdc/datasets/nabladft.py#L7

Currently this line raises an error if the nablaDFT package isn't present. If people only want to download the preprocessed files, this package isn't necessary. I would recommend wrapping the import in a try/except to make it optional:

try:
    from nablaDFT.dataset import HamiltonianDatabase
except ImportError:
    print('nablaDFT package is necessary to process raw files. Skip if directly downloading preprocessed files.')
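If the import fails, the name HamiltonianDatabase is simply not bound, so any code path that does process raw files should still fail with a clear message. A minimal sketch of that guard (the None sentinel and the error text are illustrative, not part of the current openQDC code):

try:
    from nablaDFT.dataset import HamiltonianDatabase
except ImportError:
    # Sentinel so the module still imports; raw-file processing checks it below.
    HamiltonianDatabase = None

def read_raw_entries(self):
    if HamiltonianDatabase is None:
        raise ImportError(
            "nablaDFT is required to process raw files; "
            "skip this step if you only need the preprocessed files."
        )
    ...
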
FNTwin (Collaborator) commented Oct 23, 2023

We can use some of the openff import utils for this:

import importlib
from functools import wraps
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


class MissingOptionalDependencyError(BaseException):
    """
    An exception raised when an optional dependency is required
    but cannot be found.

    Attributes
    ----------
    library_name
        The name of the missing library.
    """

    def __init__(self, library_name: str):
        """

        Parameters
        ----------
        library_name
            The name of the missing library.
        """

        message = f"The required {library_name} module could not be imported."

        super(MissingOptionalDependencyError, self).__init__(message)

        self.library_name = library_name


def has_package(package_name: str) -> bool:
    """
    Helper function to generically check if a Python package is installed.
    Intended to be used to check for optional dependencies.

    Parameters
    ----------
    package_name : str
        The name of the Python package to check the availability of

    Returns
    -------
    package_available : bool
        Boolean indicator if the package is available or not

    Examples
    --------
    >>> has_numpy = has_package('numpy')
    >>> has_numpy
    True
    >>> has_foo = has_package('other_non_installed_package')
    >>> has_foo
    False
    """
    try:
        importlib.import_module(package_name)
    except ModuleNotFoundError:
        return False
    return True


def requires_package(package_name: str) -> Callable[..., Any]:
    """
    Helper function to denote that a function requires some optional
    dependency. A function decorated with this decorator will raise
    `MissingOptionalDependencyError` if the package is not found by
    `importlib.import_module()`.

    Parameters
    ----------
    package_name : str
        The name of the module to be imported.

    Raises
    ------
    MissingOptionalDependencyError

    """

    def inner_decorator(function: F) -> F:
        @wraps(function)
        def wrapper(*args, **kwargs):
            try:
                importlib.import_module(package_name)
            except ImportError:
                raise MissingOptionalDependencyError(library_name=package_name)

            return function(*args, **kwargs)

        return wrapper

    return inner_decorator

So read_raw_entries would just be wrapped by the decorator and we would have some reusable utils:

   @requires_package("nablaDFT.dataset.HamiltonianDatabase")
   def read_raw_entries(self):
        raw_path = p_join(self.root, "dataset_full.db")
        train = nablaDFT.dataset.HamiltonianDatabase(raw_path)
        n, c = len(train), 20
        step_size = int(np.ceil(n / os.cpu_count()))

        fn = lambda i: read_chunk_from_db(raw_path, i * step_size, min((i + 1) * step_size, n))
        samples = dm.parallelized(
            fn, list(range(c)), n_jobs=c, progress=False, scheduler="threads"
        )  # don't use more than 1 job

        return sum(samples, [])
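The module-level import in nabladft.py could then be guarded with has_package, so download-only users never hit an ImportError. A minimal sketch, assuming the utils above end up somewhere like openqdc.utils.package_utils (the exact module path is illustrative):

# Hypothetical sketch for the top of nabladft.py; the utils import path is assumed.
from openqdc.utils.package_utils import has_package, requires_package

if has_package("nablaDFT"):
    from nablaDFT.dataset import HamiltonianDatabase
# If nablaDFT is missing, downloading preprocessed files still works;
# @requires_package("nablaDFT") on read_raw_entries raises a clear error otherwise.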

FNTwin mentioned this issue Oct 27, 2023
FNTwin closed this as completed Oct 30, 2023