Skip to content

Commit

Permalink
Linting
Browse files Browse the repository at this point in the history
  • Loading branch information
AKuederle committed Aug 23, 2024
1 parent 41ecc3b commit 1a6b744
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 14 deletions.
14 changes: 7 additions & 7 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import multiprocessing
import pickle
import warnings
from functools import partial
from typing import Callable, Literal
Expand All @@ -11,7 +9,7 @@

from tests._example_pipelines import CacheWarning, ExampleClassOtherModule
from tpcp import Algorithm
from tpcp.caching import global_disk_cache, global_ram_cache, hybrid_cache, remove_any_cache, _is_cached
from tpcp.caching import _is_cached, global_disk_cache, global_ram_cache, hybrid_cache, remove_any_cache


class ExampleClass(Algorithm):
Expand Down Expand Up @@ -53,6 +51,7 @@ def example_class(request):
yield request.param
remove_any_cache(request.param[1])


@pytest.fixture()
def simple_example_class(request):
yield ExampleClassOtherModule
Expand Down Expand Up @@ -91,6 +90,7 @@ def get_cache_method(self, request, joblib_cache):
else:
self.cache_method = partial(global_ram_cache, None)
self.cache_method_name = request.param

def test_caching_twice_same_instance(self, example_class):
config, example_class = example_class
action_name = config.get("action_method_name", "action")
Expand Down Expand Up @@ -175,14 +175,16 @@ def test_double_cache_warning(self, example_class):
config, example_class = example_class
action_name = config.get("action_method_name", "action")
self.cache_method(**config)(example_class)
with pytest.warns(UserWarning, match=f"The action method {action_name} of {example_class.__name__} is already cached"):
with pytest.warns(
UserWarning, match=f"The action method {action_name} of {example_class.__name__} is already cached"
):
self.cache_method(**config)(example_class)

@pytest.mark.parametrize("restore_in_parallel_process", [True, False])
def test_cache_correctly_restored_in_parallel_process(self, simple_example_class, restore_in_parallel_process):
from tpcp.parallel import delayed
from joblib import Parallel

from tpcp.parallel import delayed

self.cache_method(restore_in_parallel_process=restore_in_parallel_process)(simple_example_class)

Expand Down Expand Up @@ -230,8 +232,6 @@ def test_double_cache_error_ram_first(self, joblib_cache, simple_example_class):
global_disk_cache(joblib_cache)(simple_example_class)




class TestHybridCache:
def test_staggered_cache_all_disabled(self):
cached_func = hybrid_cache(joblib.Memory(None), False)(example_func)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_global_parallel_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def reset_parallel_context():
yield
parallel._PARALLEL_CONTEXT_CALLBACKS = {}


def test_simple_callback():
set_config("set")

Expand All @@ -37,7 +38,6 @@ def func():
assert joblib.Parallel(n_jobs=2)(delayed(func)() for _ in range(2)) == ["set", "set"]



def test_doctest():
from tpcp import parallel

Expand Down
3 changes: 1 addition & 2 deletions tpcp/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
ValidationError,
)


if TYPE_CHECKING:
from collections.abc import Iterable

Expand Down Expand Up @@ -701,7 +700,7 @@ def _is_builtin_class_instance(obj: Any) -> bool:
return type(obj).__module__ == "builtins"


def clone(algorithm: T, *, safe: bool = False) -> T: # noqa: C901, PLR0911
def clone(algorithm: T, *, safe: bool = False) -> T: # noqa: C901, PLR0911, PLR0912
"""Construct a new algorithm object with the same parameters.
This is a modified version from sklearn and the original was published under a BSD-3 license and the original file
Expand Down
8 changes: 4 additions & 4 deletions tpcp/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,14 @@ def _handle_double_cached(obj, action_name, cache_type):
def _register_global_parallel_callback(func, name):
def wrapper(_):
func()
def _callback():

def _callback():
return None, wrapper

register_global_parallel_callback(_callback, name=name)


def global_disk_cache(
def global_disk_cache( # noqa: C901
memory: Memory = Memory(None),
*,
cache_only: Optional[Sequence[str]] = None,
Expand Down Expand Up @@ -180,7 +180,7 @@ def global_disk_cache(
"""
_global_cache_warning()

def inner(algorithm_object: type[Algorithm]):
def inner(algorithm_object: type[Algorithm]): # noqa: C901
# This only return the first action method, but this is fine for now
# This method is "unbound", as we are working on the class, not an instance
to_cache_action_method_name, action_method_raw = _get_action_method(algorithm_object, action_method_name)
Expand Down Expand Up @@ -276,7 +276,7 @@ def remove_disk_cache(algorithm_object: type[Algorithm]):
return algorithm_object


def global_ram_cache(
def global_ram_cache( # noqa: C901
max_n: Optional[int] = None,
*,
cache_only: Optional[Sequence[str]] = None,
Expand Down

0 comments on commit 1a6b744

Please sign in to comment.