diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 4bd14f8d..279478ef 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -252,6 +252,8 @@ "_check_soft_dependencies", "_check_python_version", "_check_estimator_deps", + "_normalize_requirement", + "_raise_at_severity", ), "skbase.utils.random_state": ( "check_random_state", diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index cb4f4dd2..226c2912 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -2,25 +2,27 @@ """Utility to check soft dependency imports, and raise warnings or errors.""" import sys import warnings -from importlib import import_module +from functools import lru_cache +from importlib.metadata import PackageNotFoundError, version +from importlib.util import find_spec from inspect import isclass from typing import List from packaging.requirements import InvalidRequirement, Requirement -from packaging.specifiers import InvalidSpecifier, SpecifierSet - -from skbase.utils.stdout_mute import StdoutMute +from packaging.specifiers import InvalidSpecifier, Specifier, SpecifierSet +from packaging.version import InvalidVersion, Version __author__: List[str] = ["fkiraly", "mloning"] +# todo 0.10.0: remove suppress_import_stdout argument def _check_soft_dependencies( *packages, package_import_alias=None, severity="error", obj=None, msg=None, - suppress_import_stdout=False, + suppress_import_stdout="deprecated", ): """Check if required soft dependencies are installed and raise error or warning. @@ -32,40 +34,68 @@ def _check_soft_dependencies( For instance, the PEP 440 compatible package name such as "pandas"; or a package requirement specifier string such as "pandas>1.2.3". 
arg can be str, kwargs tuple, or tuple/list of str, following calls are valid: - `_check_soft_dependencies("package1")` - `_check_soft_dependencies("package1", "package2")` - `_check_soft_dependencies(("package1", "package2"))` - `_check_soft_dependencies(["package1", "package2"])` + + * ``_check_soft_dependencies("package1")`` + * ``_check_soft_dependencies("package1", "package2")`` + * ``_check_soft_dependencies(("package1", "package2"))`` + * ``_check_soft_dependencies(["package1", "package2"])`` + package_import_alias : dict with str keys and values, optional, default=empty - key-value pairs are package name, import name - import name is str used in python import, i.e., from import_name import ... - should be provided if import name differs from package name + key-value pairs are package name, import name. + import name is str used in python import, i.e., ``from import_name import ...``, + should be provided if import names differ from package name. + For example, ``{"scikit-learn": "sklearn"}`` for the well-known package. + The argument is used as a lookup and can cover more packages + than passed in ``packages``, so a global dictionary of known + aliases can be passed. + severity : str, "error" (default), "warning", "none" - behaviour for raising errors or warnings - "error" - raises a `ModuleNotFoundError` if one of packages is not installed - "warning" - raises a warning if one of packages is not installed - function returns False if one of packages is not installed, otherwise True - "none" - does not raise exception or warning - function returns False if one of packages is not installed, otherwise True + behaviour for raising errors or warnings: + + * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed + * "warning" - raises a warning if one of packages is not installed. + The function returns False if one of packages is not installed, otherwise True + * "none" - does not raise exception or warning. 
+ The function returns False if one of packages is not installed, otherwise True + obj : python class, object, str, or None, default=None if self is passed here when _check_soft_dependencies is called within __init__, or a class is passed when it is called at the start of a single-class module, the error message is more informative and will refer to the class/object; if str is passed, will be used as name of the class/object or module + msg : str, or None, default=None if str, will override the error message or warning shown with msg - suppress_import_stdout : bool, optional. Default=False - whether to suppress stdout printout upon import. Raises ------ + InvalidRequirement + if package requirement strings are not PEP 440 compatible ModuleNotFoundError error with informative message, asking to install required soft dependencies + TypeError, ValueError + on invalid arguments Returns ------- boolean - whether all packages are installed, only if no exception is raised """ + # todo 0.10.0: remove this warning + if suppress_import_stdout != "deprecated": + warnings.warn( + "In skbase _check_soft_dependencies, the suppress_import_stdout argument " + "is deprecated and no longer has any effect. " + "The argument will be removed in version 0.10.0, so users of the " + "_check_soft_dependencies utility should not pass this argument anymore. " + "The _check_soft_dependencies utility also no longer causes imports, " + "hence no stdout " + "output is created from imports, for any setting of the " + "suppress_import_stdout argument. 
If you wish to import packages " + "and make use of stdout prints, import the package directly instead.", + DeprecationWarning, + stacklevel=2, + ) + if len(packages) == 1 and isinstance(packages[0], (tuple, list)): packages = packages[0] if not all(isinstance(x, str) for x in packages): @@ -112,6 +142,7 @@ def _check_soft_dependencies( for package in packages: try: req = Requirement(package) + req = _normalize_requirement(req) except InvalidRequirement: msg_version = ( f"wrong format for package requirement string, " @@ -129,15 +160,13 @@ def _check_soft_dependencies( package_import_name = package_import_alias[package_name] else: package_import_name = package_name - # attempt import - if not possible, we know we need to raise warning/exception - try: - with StdoutMute(active=suppress_import_stdout): - pkg_ref = import_module(package_import_name) - # if package cannot be imported, make the user aware of installation requirement - except ModuleNotFoundError as e: - if msg is None: + + pkg_env_version = _get_pkg_version(package_name, package_import_name) + + # if package not present, make the user aware of installation reqs + if pkg_env_version is None: + if obj is None and msg is None: msg = ( - f"{e}. " f"{class_name} requires package {package!r} to be present " f"in the python environment, but {package!r} was not found. " ) @@ -153,24 +182,15 @@ def _check_soft_dependencies( # if msg is not None, none of the above is executed, # so if msg is passed it overrides the default messages - if severity == "error": - raise ModuleNotFoundError(msg) from e - elif severity == "warning": - warnings.warn(msg, stacklevel=2) - return False - elif severity == "none": - return False - else: - raise RuntimeError( - "Error in calling _check_soft_dependencies, severity " - 'argument must be "error", "warning", or "none",' - f"found {severity!r}." 
- ) from e + _raise_at_severity( + msg, + severity=severity, + caller="_check_soft_dependencies", + ) + return False # now we check compatibility with the version specifier if non-empty if package_version_req != SpecifierSet(""): - pkg_env_version = pkg_ref.__version__ - msg = ( f"{class_name} requires package {package!r} to be present " f"in the python environment, with version {package_version_req}, " @@ -184,23 +204,61 @@ def _check_soft_dependencies( # raise error/warning or return False if version is incompatible if pkg_env_version not in package_version_req: - if severity == "error": - raise ModuleNotFoundError(msg) - elif severity == "warning": - warnings.warn(msg, stacklevel=2) - elif severity == "none": - return False - else: - raise RuntimeError( - "Error in calling _check_soft_dependencies, severity argument" - f' must be "error", "warning", or "none", found {severity!r}.' - ) + _raise_at_severity( + msg, + severity=severity, + caller="_check_soft_dependencies", + ) + return False # if package can be imported and no version issue was caught for any string, # then obj is compatible with the requirements and we should return True return True +@lru_cache +def _get_pkg_version(package_name, package_import_name=None): + """Check whether package is available in environment, and return its version if yes. + + Returns ``Version`` object from ``lru_cache``, this should not be mutated. + + Parameters + ---------- + package_name : str, optional, default=None + name of package to check, e.g., "pandas" or "sklearn". + This is the pypi package name, not the import name, e.g., + ``scikit-learn``, not ``sklearn``. + package_import_name : str, optional, default=None + name of package to check for import, e.g., "pandas" or "sklearn". + Note: this is the import name, not the pypi package name, e.g., + ``sklearn``, not ``scikit-learn``. 
+ If not given, ``package_name`` is used as ``package_import_name``, + i.e., it is assumed that the import name is the same as the package name. + + Returns + ------- + None, if package is not found at import ``package_import_name``; + ``packaging`` ``Version`` of package, if found at import ``package_import_name`` + """ + if package_import_name is None: + package_import_name = package_name + + # optimized branching to check presence of import + # and presence of package distribution + # first we check import, then we check distribution + # because try/except consumes more runtime + pkg_spec = find_spec(package_import_name) + if pkg_spec is not None: + try: + pkg_env_version = Version(version(package_name)) + except (InvalidVersion, PackageNotFoundError): + pkg_env_version = None + else: + pkg_env_version = None + + return pkg_env_version + + def _check_python_version(obj, package=None, msg=None, severity="error"): """Check if system python version is compatible with requirements of obj. @@ -263,18 +321,8 @@ def _check_python_version(obj, package=None, msg=None, severity="error"): f" This is due to python version requirements of the {package} package." ) - if severity == "error": - raise ModuleNotFoundError(msg) - elif severity == "warning": - warnings.warn(msg, stacklevel=2) - elif severity == "none": - return False - else: - raise RuntimeError( - "Error in calling _check_python_version, severity " - f'argument must be "error", "warning", or "none", found {severity!r}.' - ) - return True + _raise_at_severity(msg, severity=severity, caller="_check_python_version") + return False def _check_estimator_deps(obj, msg=None, severity="error"): @@ -339,3 +387,89 @@ def _check_estimator_deps(obj, msg=None, severity="error"): compatible = compatible and pkg_deps_ok return compatible + + +def _normalize_requirement(req): + """Normalize packaging Requirement by removing build metadata from versions. 
+ + Parameters + ---------- + req : packaging.requirements.Requirement + requirement string to normalize, e.g., Requirement("pandas>1.2.3+foobar") + + Returns + ------- + normalized_req : packaging.requirements.Requirement + normalized requirement object with build metadata removed from versions, + e.g., Requirement("pandas>1.2.3") + """ + # Process each specifier; epochs and pre/post/dev segments are preserved + normalized_specs = [] + for spec in req.specifier: + # strip only the local label; wildcard versions like "1.*" pass through + try: + public_version = Version(spec.version).public + except InvalidVersion: + normalized_specs.append(spec) + continue + normalized_specs.append(Specifier(f"{spec.operator}{public_version}")) + + # Reconstruct the specifier set + normalized_specifier_set = SpecifierSet(",".join(str(s) for s in normalized_specs)) + + # Create a new Requirement object with the normalized specifiers + normalized_req = Requirement(f"{req.name}{normalized_specifier_set}") + + return normalized_req + + +def _raise_at_severity( + msg, + severity, + exception_type=None, + warning_type=None, + stacklevel=2, + caller="_raise_at_severity", +): + """Raise exception or warning or take no action, based on severity. 
+ + Parameters + ---------- + msg : str + message to raise or warn + severity : str, "error", "warning", or "none" + behaviour for raising errors or warnings + exception_type : Exception, default=ModuleNotFoundError + exception type to raise if severity="error" + warning_type : warning category, default=UserWarning + warning type to raise if severity="warning" + stacklevel : int, default=2 + stacklevel for warnings, relative to the caller; 2 is the caller's caller + caller : str, default="_raise_at_severity" + caller name, used in exception if severity not in ["error", "warning", "none"] + + Returns + ------- + None + + Raises + ------ + exception : exception_type, if severity="error" + warning : warning_type, if severity="warning" + ValueError : if severity not in ["error", "warning", "none"] + """ + if exception_type is None: + exception_type = ModuleNotFoundError + + if severity == "error": + raise exception_type(msg) + elif severity == "warning": + warnings.warn(msg, category=warning_type, stacklevel=stacklevel + 1) + elif severity == "none": + return None + else: + raise ValueError( + f"Error in calling {caller}, severity " + f'argument must be "error", "warning", or "none", found {severity!r}.' + ) + return None