Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge commit '208e1d3eb' into anoa/dinsic_release_1_21_x
Browse files Browse the repository at this point in the history
* commit '208e1d3eb':
  Fix typing for `@cached` wrapped functions (#8240)
  Remove useless changelog about reverting a #8239.
  Revert pinning of setuptools (#8239)
  Fix typing for SyncHandler (#8237)
  wrap `_get_e2e_device_keys_and_signatures_txn` in a non-txn method (#8231)
  Add an overload for simple_select_one_onecol_txn. (#8235)
  • Loading branch information
anoadragon453 committed Oct 20, 2020
2 parents 2df215d + 208e1d3 commit 255860b
Show file tree
Hide file tree
Showing 17 changed files with 200 additions and 53 deletions.
2 changes: 1 addition & 1 deletion INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ mkdir -p ~/synapse
virtualenv -p python3 ~/synapse/env
source ~/synapse/env/bin/activate
pip install --upgrade pip
pip install --upgrade setuptools!=50.0 # setuptools==50.0 fails on some older Python versions
pip install --upgrade setuptools
pip install matrix-synapse
```

Expand Down
1 change: 0 additions & 1 deletion changelog.d/8212.bugfix

This file was deleted.

1 change: 1 addition & 0 deletions changelog.d/8231.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.
1 change: 1 addition & 0 deletions changelog.d/8235.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `StreamStore`.
1 change: 1 addition & 0 deletions changelog.d/8237.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix type hints in `SyncHandler`.
1 change: 1 addition & 0 deletions changelog.d/8240.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix type hints for functions decorated with `@cached`.
3 changes: 2 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[mypy]
namespace_packages = True
plugins = mypy_zope:plugin
plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
follow_imports = silent
check_untyped_defs = True
show_error_codes = True
Expand Down Expand Up @@ -51,6 +51,7 @@ files =
synapse/storage/util,
synapse/streams,
synapse/types.py,
synapse/util/caches/descriptors.py,
synapse/util/caches/stream_change_cache.py,
synapse/util/metrics.py,
tests/replication,
Expand Down
85 changes: 85 additions & 0 deletions scripts-dev/mypy_synapse_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This is a mypy plugin for Synpase to deal with some of the funky typing that
can crop up, e.g the cache descriptors.
"""

from typing import Callable, Optional

from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType


class SynapsePlugin(Plugin):
def get_method_signature_hook(
self, fullname: str
) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith(
"synapse.util.caches.descriptors._CachedFunction.__call__"
):
return cached_function_method_signature
return None


def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
"""Fixes the `_CachedFunction.__call__` signature to be correct.
It already has *almost* the correct signature, except:
1. the `self` argument needs to be marked as "bound"; and
2. any `cache_context` argument should be removed.
"""

# First we mark this as a bound function signature.
signature = bind_self(ctx.default_signature)

# Secondly, we remove any "cache_context" args.
#
# Note: We should be only doing this if `cache_context=True` is set, but if
# it isn't then the code will raise an exception when its called anyway, so
# its not the end of the world.
context_arg_index = None
for idx, name in enumerate(signature.arg_names):
if name == "cache_context":
context_arg_index = idx
break

if context_arg_index:
arg_types = list(signature.arg_types)
arg_types.pop(context_arg_index)

arg_names = list(signature.arg_names)
arg_names.pop(context_arg_index)

arg_kinds = list(signature.arg_kinds)
arg_kinds.pop(context_arg_index)

signature = signature.copy_modified(
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
)

return signature


def plugin(version: str):
# This is the entry point of the plugin, and let's us deal with the fact
# that the mypy plugin interface is *not* stable by looking at the version
# string.
#
# However, since we pin the version of mypy Synapse uses in CI, we don't
# really care.
return SynapsePlugin
10 changes: 5 additions & 5 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,11 +443,11 @@ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
if not prevs - seen:
return

latest = await self.store.get_latest_event_ids_in_room(room_id)
latest_list = await self.store.get_latest_event_ids_in_room(room_id)

# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
latest = set(latest)
latest = set(latest_list)
latest |= seen

logger.info(
Expand Down Expand Up @@ -784,7 +784,7 @@ async def _process_received_pdu(
# keys across all devices.
current_keys = [
key
for device in cached_devices
for device in cached_devices.values()
for key in device.get("keys", {}).get("keys", {}).values()
]

Expand Down Expand Up @@ -2129,8 +2129,8 @@ async def _check_for_soft_fail(
if backfilled or event.internal_metadata.is_outlier():
return

extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids)
extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids())

if extrem_ids == prev_event_ids:
Expand Down
12 changes: 7 additions & 5 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import itertools
import logging
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple

import attr
from prometheus_client import Counter
Expand Down Expand Up @@ -44,6 +44,9 @@
from synapse.util.metrics import Measure, measure_func
from synapse.visibility import filter_events_for_client

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# Debug logger for https://github.com/matrix-org/synapse/issues/4422
Expand Down Expand Up @@ -244,7 +247,7 @@ def __nonzero__(self) -> bool:


class SyncHandler(object):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs_config = hs.config
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
Expand Down Expand Up @@ -717,9 +720,8 @@ async def compute_summary(
]

missing_hero_state = await self.store.get_events(missing_hero_event_ids)
missing_hero_state = missing_hero_state.values()

for s in missing_hero_state:
for s in missing_hero_state.values():
cache.set(s.state_key, s.event_id)
state[(EventTypes.Member, s.state_key)] = s

Expand Down Expand Up @@ -1771,7 +1773,7 @@ async def _generate_room_entry(
ignored_users: Set[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[List[JsonDict]],
tags: Optional[Dict[str, Dict[str, Any]]],
account_data: Dict[str, JsonDict],
always_include: bool = False,
):
Expand Down
4 changes: 0 additions & 4 deletions synapse/python_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@
"Jinja2>=2.9",
"bleach>=1.4.3",
"typing-extensions>=3.7.4",
# setuptools is required by a variety of dependencies, unfortunately version
# 50.0 is incompatible with older Python versions, see
# https://github.com/pypa/setuptools/issues/2352
"setuptools!=50.0",
]

CONDITIONAL_REQUIREMENTS = {
Expand Down
24 changes: 24 additions & 0 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,30 @@ async def simple_select_one_onecol(
allow_none=allow_none,
)

@overload
@classmethod
def simple_select_one_onecol_txn(
cls,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[False] = False,
) -> Any:
...

@overload
@classmethod
def simple_select_one_onecol_txn(
cls,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[True] = True,
) -> Optional[Any]:
...

@classmethod
def simple_select_one_onecol_txn(
cls,
Expand Down
4 changes: 1 addition & 3 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,7 @@ async def _get_device_update_edus_by_remote(
List of objects representing an device update EDU
"""
devices = (
await self.db_pool.runInteraction(
"get_e2e_device_keys_and_signatures_txn",
self._get_e2e_device_keys_and_signatures_txn,
await self.get_e2e_device_keys_and_signatures(
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
Expand Down
52 changes: 38 additions & 14 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

@attr.s
class DeviceKeyLookupResult:
"""The type returned by _get_e2e_device_keys_and_signatures_txn"""
"""The type returned by get_e2e_device_keys_and_signatures"""

display_name = attr.ib(type=Optional[str])

Expand All @@ -60,11 +60,7 @@ async def get_e2e_device_keys_for_federation_query(
"""
now_stream_id = self.get_device_stream_token()

devices = await self.db_pool.runInteraction(
"get_e2e_device_keys_and_signatures_txn",
self._get_e2e_device_keys_and_signatures_txn,
[(user_id, None)],
)
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])

if devices:
user_devices = devices[user_id]
Expand Down Expand Up @@ -108,11 +104,7 @@ async def get_e2e_device_keys_for_cs_api(
if not query_list:
return {}

results = await self.db_pool.runInteraction(
"get_e2e_device_keys_and_signatures_txn",
self._get_e2e_device_keys_and_signatures_txn,
query_list,
)
results = await self.get_e2e_device_keys_and_signatures(query_list)

# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
Expand All @@ -135,12 +127,45 @@ async def get_e2e_device_keys_for_cs_api(
return rv

@trace
def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
async def get_e2e_device_keys_and_signatures(
self,
query_list: List[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Fetch a list of device keys, together with their cross-signatures.
Args:
query_list: List of pairs of user_ids and device_ids. Device id can be None
to indicate "all devices for this user"
include_all_devices: whether to return devices without device keys
include_deleted_devices: whether to include null entries for
devices which no longer exist (but were in the query_list).
This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data.
"""
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)

result = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_and_signatures_txn,
query_list,
include_all_devices,
include_deleted_devices,
)

log_kv(result)
return result

def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
query_clauses = []
query_params = []
signature_query_clauses = []
Expand Down Expand Up @@ -230,7 +255,6 @@ def _get_e2e_device_keys_and_signatures_txn(
)
signing_user_signatures[signing_key_id] = signature

log_kv(result)
return result

async def get_e2e_one_time_keys(
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ async def get_invite_for_local_user_in_room(
return None

async def get_rooms_for_local_user_where_membership_is(
self, user_id: str, membership_list: List[str]
) -> Optional[List[RoomsForUser]]:
self, user_id: str, membership_list: Collection[str]
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
Expand All @@ -314,7 +314,7 @@ async def get_rooms_for_local_user_where_membership_is(
The RoomsForUser that the user matches the membership types.
"""
if not membership_list:
return None
return []

rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is",
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)

tags_by_room = {}
tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]]
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"])
Expand Down Expand Up @@ -123,7 +123,7 @@ def get_tag_content(txn, tag_ids):

async def get_updated_tags(
self, user_id: str, stream_id: int
) -> Dict[str, List[str]]:
) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version
Expand Down
Loading

0 comments on commit 255860b

Please sign in to comment.