use ClientPool to prevent race conditions when using pylibmc as memcached package #287

Open · wants to merge 2 commits into base: main
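For readers unfamiliar with the pattern this PR adopts: pylibmc's `ClientPool` holds a fixed number of client clones, and each caller reserves one for the duration of an operation so that no two threads ever share a connection. A minimal sketch of that pattern (illustrative only, not part of this diff):

```python
import pylibmc

# Configure one client, then clone it into a pool with one slot per expected thread.
base_client = pylibmc.Client(["127.0.0.1:11211"])
pool = pylibmc.ClientPool()
pool.fill(base_client, 4)

# reserve() is a context manager: it checks a client out of the pool and
# returns it when the block exits. With block=True the call waits for a
# free client instead of failing when all of them are in use.
with pool.reserve(block=True) as mc:
    mc.set("some_key", "some_value")
    value = mc.get("some_key")
```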
87 changes: 71 additions & 16 deletions src/cachelib/memcached.py
@@ -49,19 +49,25 @@ def __init__(
servers: _t.Any = None,
default_timeout: int = 300,
key_prefix: _t.Optional[str] = None,
threads: int = 1,
blocking: bool = False,
):
BaseCache.__init__(self, default_timeout)

self.pylibmc_used = False

if servers is None or isinstance(servers, (list, tuple)):
if servers is None:
servers = ["127.0.0.1:11211"]
self._client = self.import_preferred_memcache_lib(servers)
self._client = self.import_preferred_memcache_lib(servers, threads)
if self._client is None:
raise RuntimeError("no memcache module found")
else:
# NOTE: servers is actually an already initialized memcache
# client.
self._client = servers

self.blocking = blocking
self.key_prefix = key_prefix

def _normalize_key(self, key: str) -> str:
@@ -81,7 +87,11 @@ def get(self, key: str) -> _t.Any:
# checks for so long keys can occur because it's tested from user
# submitted data etc we fail silently for getting.
if _test_memcached_key(key):
return self._client.get(key)
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
return mc.get(key)
else:
return self._client.get(key)

def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
key_mapping = {}
@@ -90,7 +100,11 @@ def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
if _test_memcached_key(key):
key_mapping[encoded_key] = key
_keys = list(key_mapping)
d = rv = self._client.get_multi(_keys) # type: _t.Dict[str, _t.Any]
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
d = rv = mc.get_multi(_keys) # type: _t.Dict[str, _t.Any]
else:
d = rv = self._client.get_multi(_keys) # type: _t.Dict[str, _t.Any]
if self.key_prefix:
rv = {}
for key, value in d.items():
@@ -104,14 +118,22 @@ def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool:
key = self._normalize_key(key)
timeout = self._normalize_timeout(timeout)
return bool(self._client.add(key, value, timeout))
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
return bool(mc.add(key, value, timeout))
else:
return bool(self._client.add(key, value, timeout))

def set(
self, key: str, value: _t.Any, timeout: _t.Optional[int] = None
) -> _t.Optional[bool]:
key = self._normalize_key(key)
timeout = self._normalize_timeout(timeout)
return bool(self._client.set(key, value, timeout))
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
return bool(mc.set(key, value, timeout))
else:
return bool(self._client.set(key, value, timeout))

def get_many(self, *keys: str) -> _t.List[_t.Any]:
d = self.get_dict(*keys)
@@ -126,16 +148,26 @@ def set_many(
new_mapping[key] = value

timeout = self._normalize_timeout(timeout)
failed_keys = self._client.set_multi(
new_mapping, timeout
) # type: _t.List[_t.Any]
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
failed_keys = mc.set_multi(
new_mapping, timeout
) # type: _t.List[_t.Any]
else:
failed_keys = self._client.set_multi(
new_mapping, timeout
) # type: _t.List[_t.Any]
k_normkey = zip(mapping.keys(), new_mapping.keys()) # noqa: B905
return [k for k, nkey in k_normkey if nkey not in failed_keys]

def delete(self, key: str) -> bool:
key = self._normalize_key(key)
if _test_memcached_key(key):
return bool(self._client.delete(key))
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
return bool(mc.delete(key))
else:
return bool(self._client.delete(key))
return False

def delete_many(self, *keys: str) -> _t.List[_t.Any]:
@@ -144,36 +176,59 @@ def delete_many(self, *keys: str) -> _t.List[_t.Any]:
key = self._normalize_key(key)
if _test_memcached_key(key):
new_keys.append(key)
self._client.delete_multi(new_keys)
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
mc.delete_multi(new_keys)
else:
self._client.delete_multi(new_keys)
return [k for k in new_keys if not self.has(k)]

def has(self, key: str) -> bool:
key = self._normalize_key(key)
if _test_memcached_key(key):
return bool(self._client.append(key, ""))
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
return bool(mc.append(key, ""))
else:
return bool(self._client.append(key, ""))
return False

def clear(self) -> bool:
return bool(self._client.flush_all())
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
return bool(mc.flush_all())
else:
return bool(self._client.flush_all())

def inc(self, key: str, delta: int = 1) -> _t.Optional[int]:
key = self._normalize_key(key)
value = (self._client.get(key) or 0) + delta
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
value = (mc.get(key) or 0) + delta
else:
value = (self._client.get(key) or 0) + delta
return value if self.set(key, value) else None

def dec(self, key: str, delta: int = 1) -> _t.Optional[int]:
key = self._normalize_key(key)
value = (self._client.get(key) or 0) - delta
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
value = (mc.get(key) or 0) - delta
else:
value = (self._client.get(key) or 0) - delta
return value if self.set(key, value) else None

def import_preferred_memcache_lib(self, servers: _t.Any) -> _t.Any:
def import_preferred_memcache_lib(self, servers: _t.Any, threads: int) -> _t.Any:
"""Returns an initialized memcache client. Used by the constructor."""
try:
import pylibmc # type: ignore
except ImportError:
pass
else:
return pylibmc.Client(servers)
self.pylibmc_used = True
_client_pool = pylibmc.ClientPool()
_client_pool.fill(pylibmc.Client(servers), threads)
return _client_pool

try:
from google.appengine.api import memcache # type: ignore
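With the changes above, the cache can be constructed with the new `threads` and `blocking` parameters; a hedged usage sketch (values are illustrative):

```python
from cachelib import MemcachedCache

# threads sizes the pylibmc ClientPool (one pooled client per worker thread);
# blocking=True makes each operation wait for a free pooled client rather
# than fail when the pool is exhausted.
cache = MemcachedCache(
    servers=["127.0.0.1:11211"],
    default_timeout=300,
    key_prefix="myapp_",
    threads=4,
    blocking=True,
)

cache.set("greeting", "hello", timeout=60)
assert cache.get("greeting") == "hello"
```

With python-memcache or the App Engine client installed instead of pylibmc, `threads` and `blocking` have no effect: `pylibmc_used` stays `False` and the unpooled code path is taken.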
6 changes: 5 additions & 1 deletion tests/test_interface_uniformity.py
@@ -13,7 +13,11 @@
@pytest.fixture(autouse=True)
def create_cache_list(request, tmpdir):
mc = MemcachedCache()
mc._client.flush_all()
if mc.pylibmc_used:
with mc._client.reserve(block=mc.blocking) as client:
client.flush_all()
else:
mc._client.flush_all()
rc = RedisCache(port=6360)
rc._write_client.flushdb()
request.cls.cache_list = [FileSystemCache(tmpdir), mc, rc, SimpleCache()]
6 changes: 5 additions & 1 deletion tests/test_memcached_cache.py
@@ -10,7 +10,11 @@
def cache_factory(request):
def _factory(self, *args, **kwargs):
mc = MemcachedCache(*args, **kwargs)
mc._client.flush_all()
if mc.pylibmc_used:
with mc._client.reserve(block=mc.blocking) as client:
client.flush_all()
else:
mc._client.flush_all()
return mc

request.cls.cache_factory = _factory
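As a rough illustration of the race this PR guards against (not part of the PR's test suite), a hypothetical stress check could exercise a pooled cache from several threads at once: with a single shared `pylibmc.Client`, concurrent requests can interleave on one connection, whereas each pooled client is reserved by exactly one thread at a time.

```python
import threading

from cachelib import MemcachedCache

# Hypothetical smoke test: eight workers, eight pooled clients.
cache = MemcachedCache(threads=8, blocking=True)

def worker(n: int) -> None:
    for i in range(100):
        cache.set(f"k{n}:{i}", i)
        assert cache.get(f"k{n}:{i}") == i

workers = [threading.Thread(target=worker, args=(n,)) for n in range(8)]
for t in workers:
    t.start()
for t in workers:
    t.join()
```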