This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

[WIP] External sharded cache #12955

Closed
wants to merge 11 commits
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -178,6 +178,7 @@ jaeger-client = { version = ">=4.0.0", optional = true }
pyjwt = { version = ">=1.6.4", optional = true }
txredisapi = { version = ">=1.4.7", optional = true }
hiredis = { version = "*", optional = true }
jump-consistent-hash = { version = ">=3.2.0", optional = true }
Pympler = { version = "*", optional = true }
parameterized = { version = ">=0.7.4", optional = true }
idna = { version = ">=2.5", optional = true }
@@ -199,7 +200,7 @@ opentracing = ["jaeger-client", "opentracing"]
jwt = ["pyjwt"]
# hiredis is not a *strict* dependency, but it makes things much faster.
# (if it is not installed, we fall back to slow code.)
redis = ["txredisapi", "hiredis"]
redis = ["txredisapi", "hiredis", "jump-consistent-hash"]
# Required to use experimental `caches.track_memory_usage` config option.
cache_memory = ["pympler"]
test = ["parameterized", "idna"]
@@ -233,7 +234,7 @@ all = [
# jwt
"pyjwt",
# redis
"txredisapi", "hiredis",
"txredisapi", "hiredis", "jump-consistent-hash",
# cache_memory
"pympler",
# omitted:
2 changes: 2 additions & 0 deletions stubs/jump.pyi
@@ -0,0 +1,2 @@
def hash(key: int, buckets: int) -> int:
...
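For background, jump consistent hashing deterministically maps an integer key onto one of `buckets` slots, and when the number of buckets grows only roughly 1/buckets of the keys are remapped, which is what makes it suitable for picking a cache shard. A minimal sketch of how the stubbed function is used later in this PR (the event ID and shard names are invented):

import binascii

import jump

shards = ["cache-0", "cache-1", "cache-2"]  # hypothetical shard names

redis_key = "sharded_cache_v1:_get_joined_profile_from_event_id:$example_event"
# Reduce the string key to a stable unsigned 32-bit integer, as the PR does,
# then let jump.hash pick a bucket in range(len(shards)).
int_key = binascii.crc32(redis_key.encode()) & 0xFFFFFFFF
shard_index = jump.hash(int_key, len(shards))
assert 0 <= shard_index < len(shards)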
2 changes: 2 additions & 0 deletions stubs/txredisapi.pyi
@@ -33,7 +33,9 @@ class RedisProtocol(protocol.Protocol):
only_if_not_exists: bool = False,
only_if_exists: bool = False,
) -> "Deferred[None]": ...
def mset(self, values: dict[str, Any]) -> "Deferred[Any]": ...
def get(self, key: str) -> "Deferred[Any]": ...
def mget(self, keys: list[str]) -> "Deferred[Any]": ...

class SubscriberProtocol(RedisProtocol):
def __init__(self, *args: object, **kwargs: object): ...
18 changes: 18 additions & 0 deletions synapse/config/redis.py
@@ -35,6 +35,12 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.redis_port = redis_config.get("port", 6379)
self.redis_password = redis_config.get("password")

cache_shard_config = redis_config.get("cache_shards") or {}
self.cache_shard_hosts = cache_shard_config.get("hosts", [])
self.cache_shard_expire = cache_shard_config.get("expire_caches", False)
self.cache_shard_ttl = cache_shard_config.get("cache_entry_ttl", False)

def generate_config_section(self, **kwargs: Any) -> str:
return """\
# Configuration for Redis when using workers. This *must* be enabled when
@@ -54,4 +60,16 @@ def generate_config_section(self, **kwargs: Any) -> str:
# Optional password if configured on the Redis instance
#
#password: <secret_password>

# Optional: one or more Redis hosts to use for long-term shared caches.
# These should be configured to automatically evict records when out of
# memory, and should not be the same instance as the one used for replication.
#
#cache_shards:
# enabled: false
# expire_caches: false
Comment from the PR author:
Not currently supported, but could fairly easily be done using a Redis transaction block and setting each key one by one.

# cache_entry_ttl: 30m
# hosts:
# - host: localhost
# port: 6379
"""
191 changes: 191 additions & 0 deletions synapse/replication/tcp/external_sharded_cache.py
@@ -0,0 +1,191 @@
# Copyright 2022 Beeper
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import binascii
import logging
import marshal
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Iterable

import jump
from prometheus_client import Counter, Histogram

from twisted.internet import defer

from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable
from synapse.replication.tcp.redis import lazyConnection
from synapse.util import unwrapFirstError

if TYPE_CHECKING:
from synapse.server import HomeServer

set_counter = Counter(
"synapse_external_sharded_cache_set",
"Number of times we set a cache",
labelnames=["cache_name"],
)

get_counter = Counter(
"synapse_external_sharded_cache_get",
"Number of times we get a cache",
labelnames=["cache_name", "hit"],
)

response_timer = Histogram(
"synapse_external_sharded_cache_response_time_seconds",
"Time taken to get a response from Redis for a cache get/set request",
labelnames=["method"],
buckets=(
0.001,
0.002,
0.005,
0.01,
0.02,
0.05,
),
)


logger = logging.getLogger(__name__)


class ExternalShardedCache:
"""A cache backed by an external Redis. Does nothing if no Redis is
configured.
"""

def __init__(self, hs: "HomeServer"):
self._redis_shards = []

if hs.config.redis.redis_enabled and hs.config.redis.cache_shard_hosts:
for shard in hs.config.redis.cache_shard_hosts:
logger.info(
"Connecting to redis (host=%r port=%r) for external cache",
shard["host"],
shard["port"],
)
self._redis_shards.append(
lazyConnection(
hs=hs,
host=shard["host"],
port=shard["port"],
reconnect=True,
),
)

def _get_redis_key(self, cache_name: str, key: str) -> str:
return "sharded_cache_v1:%s:%s" % (cache_name, key)

def _get_redis_shard_id(self, redis_key: str) -> int:
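# Reduce the string key to a stable unsigned 32-bit integer (crc32), then
# use jump consistent hashing to pick the shard that owns it.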
key = binascii.crc32(redis_key.encode()) & 0xFFFFFFFF
idx = jump.hash(key, len(self._redis_shards))
return idx

def is_enabled(self) -> bool:
"""Whether the external cache is used or not.

It is safe to use the cache when this returns false; the methods will simply
no-op. The check is only useful to avoid doing unnecessary work.
"""
return bool(self._redis_shards)

async def mset(
self,
cache_name: str,
values: dict[str, Any],
) -> None:
"""Add the key/value combinations to the named cache, with the expiry time given."""

if not self.is_enabled():
return

set_counter.labels(cache_name).inc(len(values))

logger.debug("Caching %s: %r", cache_name, values)

shard_id_to_encoded_values: dict[int, dict[str, Any]] = defaultdict(dict)
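# Group each value, marshal-encoded, under the shard that owns its key, so
# that every shard can be sent a single MSET below.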

for key, value in values.items():
redis_key = self._get_redis_key(cache_name, key)
shard_id = self._get_redis_shard_id(redis_key)
shard_id_to_encoded_values[shard_id][redis_key] = marshal.dumps(value)

with opentracing.start_active_span(
"ExternalShardedCache.set",
tags={opentracing.SynapseTags.CACHE_NAME: cache_name},
):
with response_timer.labels("set").time():
deferreds = [
self._redis_shards[shard_id].mset(shard_values)
for shard_id, shard_values in shard_id_to_encoded_values.items()
]
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
).addErrback(unwrapFirstError)

async def set(self, cache_name: str, key: str, value: Any) -> None:
await self.mset(cache_name, {key: value})

async def _mget_shard(
self, shard_id: int, key_mapping: dict[str, str]
) -> dict[str, Any]:
results = await self._redis_shards[shard_id].mget(list(key_mapping.values()))
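# Redis MGET returns values in request order, and key_mapping preserves
# insertion order, so results can be zipped back to the caller's keys.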
original_keys = list(key_mapping.keys())
mapped_results: dict[str, Any] = {}
for i, result in enumerate(results):
if result:
result = marshal.loads(result)
mapped_results[original_keys[i]] = result
return mapped_results

async def mget(self, cache_name: str, keys: Iterable[str]) -> dict[str, Any]:
"""Look up a key/value combinations in the named cache."""

if not self.is_enabled():
return None

shard_id_to_key_mapping: dict[int, dict[str, str]] = defaultdict(dict)
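# Group the requested keys by owning shard, remembering which prefixed
# Redis key corresponds to each caller-supplied key.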

for key in keys:
redis_key = self._get_redis_key(cache_name, key)
shard_id = self._get_redis_shard_id(redis_key)
shard_id_to_key_mapping[shard_id][key] = redis_key

with opentracing.start_active_span(
"ExternalShardedCache.get",
tags={opentracing.SynapseTags.CACHE_NAME: cache_name},
):
with response_timer.labels("get").time():
deferreds = [
defer.ensureDeferred(self._mget_shard(shard_id, key_mapping))
for shard_id, key_mapping in shard_id_to_key_mapping.items()
]
results: list[dict[str, Any]] = await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)  # type: ignore
).addErrback(unwrapFirstError)

combined_results: dict[str, Any] = {}
for result in results:
combined_results.update(result)

logger.debug("Got cache result %s %s: %r", cache_name, keys, combined_results)

hits = sum(1 for value in combined_results.values() if value is not None)
get_counter.labels(cache_name, True).inc(hits)
get_counter.labels(cache_name, False).inc(len(combined_results) - hits)

return combined_results
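As an illustration only (not part of the diff): assuming a `HomeServer` configured with cache shards, a round trip through the class above looks roughly like this; the cache name and values are made up, and values must be serialisable by `marshal`:

async def example(hs):
    cache = hs.get_external_sharded_cache()
    if not cache.is_enabled():
        return  # no shards configured; set/get calls simply no-op

    await cache.mset("example_cache", {"key1": {"a": 1}, "key2": [1, 2, 3]})

    result = await cache.mget("example_cache", ["key1", "key2", "key3"])
    # Every requested key appears in the result; misses come back as None.
    assert result["key1"] == {"a": 1}
    assert result["key3"] is None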
5 changes: 5 additions & 0 deletions synapse/server.py
@@ -110,6 +110,7 @@
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.external_cache import ExternalCache
from synapse.replication.tcp.external_sharded_cache import ExternalShardedCache
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@@ -771,6 +772,10 @@ def get_event_auth_handler(self) -> EventAuthHandler:
def get_external_cache(self) -> ExternalCache:
return ExternalCache(self)

@cache_in_self
def get_external_sharded_cache(self) -> ExternalShardedCache:
return ExternalShardedCache(self)

@cache_in_self
def get_account_handler(self) -> AccountHandler:
return AccountHandler(self)
44 changes: 35 additions & 9 deletions synapse/storage/databases/main/roommember.py
@@ -830,15 +830,41 @@ async def _get_joined_profiles_from_event_ids(
Map from event ID to `user_id` and ProfileInfo (or None if not join event).
"""

rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_joined_profiles_from_event_ids",
)
sharded_cache = self.hs.get_external_sharded_cache()
Comment from the PR author:
This all feels ripe for some kind of wrapper - I had a brief look at cachedList but I don't think it should live there because it's already pretty complex! Maybe some kind of util wrapper? Or out of scope for this PR?

sharded_cache_enabled = sharded_cache.is_enabled()

missing = []
rows = []

if sharded_cache_enabled:
event_id_to_row = await sharded_cache.mget(
"_get_joined_profile_from_event_id", event_ids
)
for event_id, row in event_id_to_row.items():
if row:
rows.append(row)
else:
missing.append(event_id)
else:
missing = list(event_ids)

if missing:
missing_rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=missing,
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_joined_profiles_from_event_ids",
)
rows += missing_rows

if sharded_cache_enabled and missing_rows:
await sharded_cache.mset(
"_get_joined_profile_from_event_id",
{row["event_id"]: row for row in missing_rows},
)

return {
row["event_id"]: (
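Following up on the author's comment above about a reusable wrapper: one possible shape, sketched purely for illustration and not part of this PR, is a generic cache-aside helper that the change to `_get_joined_profiles_from_event_ids` could then be expressed through (the helper name and signature are invented):

from typing import Any, Awaitable, Callable, Iterable

from synapse.replication.tcp.external_sharded_cache import ExternalShardedCache


async def sharded_cached_batch(
    sharded_cache: ExternalShardedCache,
    cache_name: str,
    keys: Iterable[str],
    fetch: Callable[[list[str]], Awaitable[dict[str, Any]]],
) -> dict[str, Any]:
    """Hypothetical helper: read keys from the sharded cache, call `fetch`
    for the misses, and backfill the cache with whatever `fetch` returns.
    """
    keys = list(keys)
    results: dict[str, Any] = {}
    missing = keys

    if sharded_cache.is_enabled():
        cached = await sharded_cache.mget(cache_name, keys)
        results = {key: value for key, value in cached.items() if value}
        missing = [key for key in keys if not cached.get(key)]

    if missing:
        fetched = await fetch(missing)
        results.update(fetched)
        if sharded_cache.is_enabled() and fetched:
            await sharded_cache.mset(cache_name, fetched)

    return results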