From f16b9da39f553f5a6296ffd9952df87b1c8f5aec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 29 Jun 2024 14:17:38 +0100 Subject: [PATCH 01/10] Update _dependencies.py --- skbase/utils/dependencies/_dependencies.py | 94 ++++++++++++++++------ 1 file changed, 69 insertions(+), 25 deletions(-) diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index cb4f4dd2..17a949d8 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -2,14 +2,14 @@ """Utility to check soft dependency imports, and raise warnings or errors.""" import sys import warnings -from importlib import import_module +from importlib.metadata import PackageNotFoundError, version +from importlib.util import find_spec from inspect import isclass from typing import List from packaging.requirements import InvalidRequirement, Requirement -from packaging.specifiers import InvalidSpecifier, SpecifierSet - -from skbase.utils.stdout_mute import StdoutMute +from packaging.specifiers import InvalidSpecifier, Specifier, SpecifierSet +from packaging.version import InvalidVersion, Version __author__: List[str] = ["fkiraly", "mloning"] @@ -129,32 +129,44 @@ def _check_soft_dependencies( package_import_name = package_import_alias[package_name] else: package_import_name = package_name - # attempt import - if not possible, we know we need to raise warning/exception - try: - with StdoutMute(active=suppress_import_stdout): - pkg_ref = import_module(package_import_name) - # if package cannot be imported, make the user aware of installation requirement - except ModuleNotFoundError as e: - if msg is None: + + # optimized branching to check presence of import + # and presence of package distribution + # first we check import, then we check distribution + # because try/except consumes more runtime + pkg_spec = find_spec(package_import_name) + if pkg_spec is not None: + try: + pkg_env_version = 
Version(version(package_name)) + except (InvalidVersion, PackageNotFoundError): + pkg_spec = None + + # if package not present, make the user aware of installation reqs + if pkg_spec is None: + if obj is None and msg is None: msg = ( - f"{e}. " - f"{class_name} requires package {package!r} to be present " - f"in the python environment, but {package!r} was not found. " + f"'{package}' not found. " + f"'{package}' is a soft dependency and not included in the " + f"base sktime installation. Please run: `pip install {package}` to " + f"install the {package} package. " + f"To install all soft dependencies, run: `pip install " + f"sktime[all_extras]`" ) - if obj is not None: - msg = msg + ( - f"{package!r} is a dependency of {class_name} and required " - f"to construct it. " - ) - msg = msg + ( - f"Please run: `pip install {package}` to " + elif msg is None: # obj is not None, msg is None + msg = ( + f"{class_name} requires package '{package}' to be present " + f"in the python environment, but '{package}' was not found. " + f"'{package}' is a soft dependency and not included in the base " + f"sktime installation. Please run: `pip install {package}` to " f"install the {package} package. " + f"To install all soft dependencies, run: `pip install " + f"sktime[all_extras]`" ) # if msg is not None, none of the above is executed, # so if msg is passed it overrides the default messages if severity == "error": - raise ModuleNotFoundError(msg) from e + raise ModuleNotFoundError(msg) elif severity == "warning": warnings.warn(msg, stacklevel=2) return False @@ -165,12 +177,10 @@ def _check_soft_dependencies( "Error in calling _check_soft_dependencies, severity " 'argument must be "error", "warning", or "none",' f"found {severity!r}." 
- ) from e + ) # now we check compatibility with the version specifier if non-empty if package_version_req != SpecifierSet(""): - pkg_env_version = pkg_ref.__version__ - msg = ( f"{class_name} requires package {package!r} to be present " f"in the python environment, with version {package_version_req}, " @@ -339,3 +349,37 @@ def _check_estimator_deps(obj, msg=None, severity="error"): compatible = compatible and pkg_deps_ok return compatible + + +def _normalize_requirement(req): + """Normalize packaging Requirement by removing build metadata from versions. + + Parameters + ---------- + req : packaging.requirements.Requirement + requirement string to normalize, e.g., Requirement("pandas>1.2.3+foobar") + + Returns + ------- + normalized_req : packaging.requirements.Requirement + normalized requirement object with build metadata removed from versions, + e.g., Requirement("pandas>1.2.3") + """ + # Process each specifier in the requirement + normalized_specs = [] + for spec in req.specifier: + # Parse the version and remove the build metadata + spec_v = Version(spec.version) + version_wo_build_metadata = f"{spec_v.major}.{spec_v.minor}.{spec_v.micro}" + + # Create a new specifier without the build metadata + normalized_spec = Specifier(f"{spec.operator}{version_wo_build_metadata}") + normalized_specs.append(normalized_spec) + + # Reconstruct the specifier set + normalized_specifier_set = SpecifierSet(",".join(str(s) for s in normalized_specs)) + + # Create a new Requirement object with the normalized specifiers + normalized_req = Requirement(f"{req.name}{normalized_specifier_set}") + + return normalized_req From 5b5291efa453684400123dc190b7a27e259969f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 29 Jun 2024 14:35:32 +0100 Subject: [PATCH 02/10] deprecation --- skbase/utils/dependencies/_dependencies.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/skbase/utils/dependencies/_dependencies.py 
b/skbase/utils/dependencies/_dependencies.py index 17a949d8..ca82e0da 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -14,13 +14,14 @@ __author__: List[str] = ["fkiraly", "mloning"] +# todo 0.10.0: remove suppress_import_stdout argument def _check_soft_dependencies( *packages, package_import_alias=None, severity="error", obj=None, msg=None, - suppress_import_stdout=False, + suppress_import_stdout="deprecated", ): """Check if required soft dependencies are installed and raise error or warning. @@ -54,8 +55,6 @@ def _check_soft_dependencies( if str is passed, will be used as name of the class/object or module msg : str, or None, default=None if str, will override the error message or warning shown with msg - suppress_import_stdout : bool, optional. Default=False - whether to suppress stdout printout upon import. Raises ------ @@ -66,6 +65,22 @@ def _check_soft_dependencies( ------- boolean - whether all packages are installed, only if no exception is raised """ + # todo 0.10.0: remove this warning + if suppress_import_stdout != "deprecated": + warnings.warn( + "In skbase _check_soft_dependencies, the suppress_import_stdout argument " + "is deprecated and no longer has any effect. " + "The argument will be removed in version 0.10.0, so users of the " + "_check_soft_dependencies utility should not pass this argument anymore. " + "The _check_soft_dependencies utility also no longer causes imports, " + "hence no stdout " + "output is created from imports, for any setting of the " + "suppress_import_stdout argument. 
If you wish to import packages " + "and make use of stdout prints, import the package directly instead.", + DeprecationWarning, + stacklevel=2, + ) + if len(packages) == 1 and isinstance(packages[0], (tuple, list)): packages = packages[0] if not all(isinstance(x, str) for x in packages): From bb81f58ceb783110606459f1937055b37a59ab3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 29 Jun 2024 14:46:33 +0100 Subject: [PATCH 03/10] revert error msg --- skbase/utils/dependencies/_dependencies.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index ca82e0da..72118149 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -160,12 +160,17 @@ def _check_soft_dependencies( if pkg_spec is None: if obj is None and msg is None: msg = ( - f"'{package}' not found. " - f"'{package}' is a soft dependency and not included in the " - f"base sktime installation. Please run: `pip install {package}` to " + f"{class_name} requires package {package!r} to be present " + f"in the python environment, but {package!r} was not found. " + ) + if obj is not None: + msg = msg + ( + f"{package!r} is a dependency of {class_name} and required " + f"to construct it. " + ) + msg = msg + ( + f"Please run: `pip install {package}` to " f"install the {package} package. 
" - f"To install all soft dependencies, run: `pip install " - f"sktime[all_extras]`" ) elif msg is None: # obj is not None, msg is None msg = ( From 23aafba0bbd79dc88a4b1cf5ddc9e761ddb45fea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 29 Jun 2024 14:47:28 +0100 Subject: [PATCH 04/10] Update _dependencies.py --- skbase/utils/dependencies/_dependencies.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index 72118149..180b2f05 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -172,16 +172,6 @@ def _check_soft_dependencies( f"Please run: `pip install {package}` to " f"install the {package} package. " ) - elif msg is None: # obj is not None, msg is None - msg = ( - f"{class_name} requires package '{package}' to be present " - f"in the python environment, but '{package}' was not found. " - f"'{package}' is a soft dependency and not included in the base " - f"sktime installation. Please run: `pip install {package}` to " - f"install the {package} package. 
" - f"To install all soft dependencies, run: `pip install " - f"sktime[all_extras]`" - ) # if msg is not None, none of the above is executed, # so if msg is passed it overrides the default messages From 3fdecf1bdee484e9e0c57a4e0362648a67ed9711 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 29 Jun 2024 14:48:34 +0100 Subject: [PATCH 05/10] add normalization --- skbase/utils/dependencies/_dependencies.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index 180b2f05..e123e44c 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -127,6 +127,7 @@ def _check_soft_dependencies( for package in packages: try: req = Requirement(package) + req = _normalize_requirement(req) except InvalidRequirement: msg_version = ( f"wrong format for package requirement string, " From 6b50dc9e75bf8d5e1312298756e381519644be7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 29 Jun 2024 15:13:39 +0100 Subject: [PATCH 06/10] refactor --- skbase/utils/dependencies/_dependencies.py | 101 +++++++++++++++------ 1 file changed, 75 insertions(+), 26 deletions(-) diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index e123e44c..003b7033 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -2,6 +2,7 @@ """Utility to check soft dependency imports, and raise warnings or errors.""" import sys import warnings +from functools import lru_cache from importlib.metadata import PackageNotFoundError, version from importlib.util import find_spec from inspect import isclass @@ -33,33 +34,47 @@ def _check_soft_dependencies( For instance, the PEP 440 compatible package name such as "pandas"; or a package requirement specifier string such as "pandas>1.2.3". 
arg can be str, kwargs tuple, or tuple/list of str, following calls are valid: - `_check_soft_dependencies("package1")` - `_check_soft_dependencies("package1", "package2")` - `_check_soft_dependencies(("package1", "package2"))` - `_check_soft_dependencies(["package1", "package2"])` + + * ``_check_soft_dependencies("package1")`` + * ``_check_soft_dependencies("package1", "package2")`` + * ``_check_soft_dependencies(("package1", "package2"))`` + * ``_check_soft_dependencies(["package1", "package2"])`` + package_import_alias : dict with str keys and values, optional, default=empty - key-value pairs are package name, import name - import name is str used in python import, i.e., from import_name import ... - should be provided if import name differs from package name + key-value pairs are package name, import name. + import name is str used in python import, i.e., ``from import_name import ...``, + should be provided if import names differ from package name. + For example, ``{"scikit-learn": "sklearn"}`` for the well-known package. + The argument is used as a lookup and can cover more packages + than passed in ``packages``, so a global dictionary of known + aliases can be passed. + severity : str, "error" (default), "warning", "none" - behaviour for raising errors or warnings - "error" - raises a `ModuleNotFoundError` if one of packages is not installed - "warning" - raises a warning if one of packages is not installed - function returns False if one of packages is not installed, otherwise True - "none" - does not raise exception or warning - function returns False if one of packages is not installed, otherwise True + behaviour for raising errors or warnings: + + * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed + * "warning" - raises a warning if one of packages is not installed. + The function returns False if one of packages is not installed, otherwise True + * "none" - does not raise exception or warning. 
+ The function returns False if one of packages is not installed, otherwise True + obj : python class, object, str, or None, default=None if self is passed here when _check_soft_dependencies is called within __init__, or a class is passed when it is called at the start of a single-class module, the error message is more informative and will refer to the class/object; if str is passed, will be used as name of the class/object or module + msg : str, or None, default=None if str, will override the error message or warning shown with msg Raises ------ + InvalidRequirement + if package requirement strings are not PEP 440 compatible ModuleNotFoundError error with informative message, asking to install required soft dependencies + TypeError, ValueError + on invalid arguments Returns ------- @@ -146,19 +161,10 @@ def _check_soft_dependencies( else: package_import_name = package_name - # optimized branching to check presence of import - # and presence of package distribution - # first we check import, then we check distribution - # because try/except consumes more runtime - pkg_spec = find_spec(package_import_name) - if pkg_spec is not None: - try: - pkg_env_version = Version(version(package_name)) - except (InvalidVersion, PackageNotFoundError): - pkg_spec = None + pkg_env_version = _get_pkg_version(package_name, package_import_name) # if package not present, make the user aware of installation reqs - if pkg_spec is None: + if pkg_env_version is None: if obj is None and msg is None: msg = ( f"{class_name} requires package {package!r} to be present " @@ -184,7 +190,7 @@ def _check_soft_dependencies( elif severity == "none": return False else: - raise RuntimeError( + raise ValueError( "Error in calling _check_soft_dependencies, severity " 'argument must be "error", "warning", or "none",' f"found {severity!r}." 
@@ -212,7 +218,7 @@ def _check_soft_dependencies( elif severity == "none": return False else: - raise RuntimeError( + raise ValueError( "Error in calling _check_soft_dependencies, severity argument" f' must be "error", "warning", or "none", found {severity!r}.' ) @@ -222,6 +228,49 @@ def _check_soft_dependencies( return True +@lru_cache +def _get_pkg_version(package_name, package_import_name=None): + """Check whether package is available in environment, and return its version if yes. + + Returns ``Version`` object from ``lru_cache``, this should not be mutated. + + Parameters + ---------- + package_name : str, optional, default=None + name of package to check, e.g., "pandas" or "sklearn". + This is the pypi package name, not the import name, e.g., + ``scikit-learn``, not ``sklearn``. + package_import_name : str, optional, default=None + name of package to check for import, e.g., "pandas" or "sklearn". + Note: this is the import name, not the pypi package name, e.g., + ``sklearn``, not ``scikit-learn``. + If not given, ``package_name`` is used as ``package_import_name``, + i.e., it is assumed that the import name is the same as the package name. 
+ + Returns + ------- + None, if package is not found at import ``package_import_name``; + ``packaging`` ``Version`` of package, if found at import ``package_import_name`` + """ + if package_import_name is None: + package_import_name = package_name + + # optimized branching to check presence of import + # and presence of package distribution + # first we check import, then we check distribution + # because try/except consumes more runtime + pkg_spec = find_spec(package_import_name) + if pkg_spec is not None: + try: + pkg_env_version = Version(version(package_name)) + except (InvalidVersion, PackageNotFoundError): + pkg_env_version = None + else: + pkg_env_version = None + + return pkg_env_version + + def _check_python_version(obj, package=None, msg=None, severity="error"): """Check if system python version is compatible with requirements of obj. From ea1c0f5a549d7c7d55f0fe1ff80d876d014049eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 29 Jun 2024 15:16:12 +0100 Subject: [PATCH 07/10] conftest --- skbase/tests/conftest.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 4bd14f8d..664e92c2 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -252,6 +252,8 @@ "_check_soft_dependencies", "_check_python_version", "_check_estimator_deps", + "_get_pkg_version", + "_normalize_requirement", ), "skbase.utils.random_state": ( "check_random_state", From 83eaa7efb573b04bfda963ce465607bc799e5309 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 29 Jun 2024 15:20:17 +0100 Subject: [PATCH 08/10] Update conftest.py --- skbase/tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 664e92c2..cbcb2096 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -252,7 +252,6 @@ "_check_soft_dependencies", "_check_python_version", "_check_estimator_deps", - 
"_get_pkg_version", "_normalize_requirement", ), "skbase.utils.random_state": ( From 5dece13ebd55b7c46dbfa310600e181122bb64af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 4 Jul 2024 16:01:39 +0100 Subject: [PATCH 09/10] further refactor --- skbase/utils/dependencies/_dependencies.py | 102 +++++++++++++-------- 1 file changed, 66 insertions(+), 36 deletions(-) diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index 003b7033..226c2912 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -182,19 +182,12 @@ def _check_soft_dependencies( # if msg is not None, none of the above is executed, # so if msg is passed it overrides the default messages - if severity == "error": - raise ModuleNotFoundError(msg) - elif severity == "warning": - warnings.warn(msg, stacklevel=2) - return False - elif severity == "none": - return False - else: - raise ValueError( - "Error in calling _check_soft_dependencies, severity " - 'argument must be "error", "warning", or "none",' - f"found {severity!r}." - ) + _raise_at_severity( + msg, + severity=severity, + caller="_check_soft_dependencies", + ) + return False # now we check compatibility with the version specifier if non-empty if package_version_req != SpecifierSet(""): @@ -211,17 +204,12 @@ def _check_soft_dependencies( # raise error/warning or return False if version is incompatible if pkg_env_version not in package_version_req: - if severity == "error": - raise ModuleNotFoundError(msg) - elif severity == "warning": - warnings.warn(msg, stacklevel=2) - elif severity == "none": - return False - else: - raise ValueError( - "Error in calling _check_soft_dependencies, severity argument" - f' must be "error", "warning", or "none", found {severity!r}.' 
- ) + _raise_at_severity( + msg, + severity=severity, + caller="_check_soft_dependencies", + ) + return False # if package can be imported and no version issue was caught for any string, # then obj is compatible with the requirements and we should return True @@ -333,18 +321,8 @@ def _check_python_version(obj, package=None, msg=None, severity="error"): f" This is due to python version requirements of the {package} package." ) - if severity == "error": - raise ModuleNotFoundError(msg) - elif severity == "warning": - warnings.warn(msg, stacklevel=2) - elif severity == "none": - return False - else: - raise RuntimeError( - "Error in calling _check_python_version, severity " - f'argument must be "error", "warning", or "none", found {severity!r}.' - ) - return True + _raise_at_severity(msg, severity=severity, caller="_check_python_version") + return False def _check_estimator_deps(obj, msg=None, severity="error"): @@ -443,3 +421,55 @@ def _normalize_requirement(req): normalized_req = Requirement(f"{req.name}{normalized_specifier_set}") return normalized_req + + +def _raise_at_severity( + msg, + severity, + exception_type=None, + warning_type=None, + stacklevel=2, + caller="_raise_at_severity", +): + """Raise exception or warning or take no action, based on severity. 
+ + Parameters + ---------- + msg : str + message to raise or warn + severity : str, "error", "warning", or "none" + behaviour for raising errors or warnings + exception_type : Exception, default=ModuleNotFoundError + exception type to raise if severity="error" + warning_type : warning, default=UserWarning + warning type to raise if severity="warning" + stacklevel : int, default=2 + stacklevel for warnings, if severity="warning" + caller : str, default="_raise_at_severity" + caller name, used in exception if severity not in ["error", "warning", "none"] + + Returns + ------- + None + + Raises + ------ + exception : exception_type, if severity="error" + warning : warning_type, if severity="warning" + ValueError : if severity not in ["error", "warning", "none"] + """ + if exception_type is None: + exception_type = ModuleNotFoundError + + if severity == "error": + raise exception_type(msg) + elif severity == "warning": + warnings.warn(msg, category=warning_type, stacklevel=stacklevel) + elif severity == "none": + return None + else: + raise ValueError( + f"Error in calling {caller}, severity " + f'argument must be "error", "warning", or "none", found {severity!r}.' + ) + return None From a05f658d007de665ed6672d65153803abd7ffb51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 9 Jul 2024 20:55:50 +0100 Subject: [PATCH 10/10] Update conftest.py --- skbase/tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index cbcb2096..279478ef 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -253,6 +253,7 @@ "_check_python_version", "_check_estimator_deps", "_normalize_requirement", + "_raise_at_severity", ), "skbase.utils.random_state": ( "check_random_state",