diff --git a/doc/changelog.d/3705.added.md b/doc/changelog.d/3705.added.md new file mode 100644 index 0000000000..9bc4540620 --- /dev/null +++ b/doc/changelog.d/3705.added.md @@ -0,0 +1 @@ +feat: speed up `requires_package` using caching \ No newline at end of file diff --git a/src/ansys/mapdl/core/misc.py b/src/ansys/mapdl/core/misc.py index 63b2e35b61..e9bdc3cce4 100644 --- a/src/ansys/mapdl/core/misc.py +++ b/src/ansys/mapdl/core/misc.py @@ -22,7 +22,7 @@ """Module for miscellaneous functions and methods""" from enum import Enum -from functools import wraps +from functools import cache, wraps import importlib import inspect import os @@ -415,6 +415,16 @@ def write_array(filename: Union[str, bytes], array: np.ndarray) -> None: np.savetxt(filename, array, fmt="%20.12f") +@cache +def is_package_installed_cached(package_name): + try: + importlib.import_module(package_name) + return True + + except ModuleNotFoundError: + return False + + def requires_package(package_name: str, softerror: bool = False) -> Callable: """ Decorator check whether a package is installed or not. @@ -430,11 +440,11 @@ def requires_package(package_name: str, softerror: bool = False) -> Callable: def decorator(function): @wraps(function) def wrapper(self, *args, **kwargs): - try: - importlib.import_module(package_name) + + if is_package_installed_cached(package_name): return function(self, *args, **kwargs) - except ModuleNotFoundError: + else: msg = ( f"To use the method '{function.__name__}', " f"the package '{package_name}' is required.\n" diff --git a/tests/common.py b/tests/common.py index 142e0a681a..9a7581f843 100644 --- a/tests/common.py +++ b/tests/common.py @@ -26,6 +26,7 @@ import subprocess import time from typing import Dict, List +from warnings import warn import psutil diff --git a/tests/test_mapdl.py b/tests/test_mapdl.py index 7d902e099a..5621b8dcf5 100644 --- a/tests/test_mapdl.py +++ b/tests/test_mapdl.py @@ -2836,3 +2836,14 @@ def test_none_on_selecting(mapdl, cleared, func): assert len(selfunc("all")) > 0 assert len(selfunc(None)) == 0 + + +def test_requires_package_speed(): + from ansys.mapdl.core.misc import requires_package + + @requires_package("pyvista") + def my_func(i): + return i + 1 + + for i in range(1_000_000): + my_func(i)