diff --git a/panel/io/cache.py b/panel/io/cache.py index 2d17befbea..5a3016dbeb 100644 --- a/panel/io/cache.py +++ b/panel/io/cache.py @@ -350,7 +350,8 @@ def cache( max_items=max_items, ttl=ttl, to_disk=to_disk, - cache_path=cache_path + cache_path=cache_path, + per_session=per_session, ) func_hash = None # noqa @@ -434,7 +435,7 @@ def wrapped_func(*args, **kwargs): func_cache[hash_value] = (ret, time, 0, time) return ret - def clear(session_context=None): + def clear(): global func_hash # clear called before anything is cached. if 'func_hash' not in globals(): @@ -448,10 +449,13 @@ def clear(session_context=None): else: cache = state._memoize_cache.get(func_hash, {}) cache.clear() + wrapped_func.clear = clear if per_session and state.curdoc and state.curdoc.session_context: - state.curdoc.on_session_destroyed(clear) + def server_clear(session_context): + clear() + state.curdoc.on_session_destroyed(server_clear) try: wrapped_func.__dict__.update(func.__dict__) diff --git a/panel/tests/io/test_cache.py b/panel/tests/io/test_cache.py index a11ea773b4..589cd27d71 100644 --- a/panel/tests/io/test_cache.py +++ b/panel/tests/io/test_cache.py @@ -3,10 +3,13 @@ import pathlib import time +from collections import Counter + import numpy as np import pandas as pd import param import pytest +import requests try: import diskcache @@ -15,7 +18,8 @@ diskcache_available = pytest.mark.skipif(diskcache is None, reason="requires diskcache") from panel.io.cache import _find_hash_func, cache -from panel.io.state import set_curdoc +from panel.io.state import set_curdoc, state +from panel.tests.util import serve_and_wait ################ # Test hashing # @@ -219,6 +223,26 @@ def test_per_session_cache(document): assert fn(a=0, b=0) == 0 assert fn(a=0, b=0) == 1 +def test_per_session_cache_server(port): + counts = Counter() + + @cache(per_session=True) + def get_data(): + counts[state.curdoc] += 1 + return "Some data" + + def app(): + get_data() + get_data() + return + + serve_and_wait(app, port=port) + + requests.get(f"http://localhost:{port}/") + requests.get(f"http://localhost:{port}/") + + assert list(counts.values()) == [1, 1] + @pytest.mark.xdist_group("cache") @diskcache_available def test_disk_cache():