Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fixup types
Browse files Browse the repository at this point in the history
  • Loading branch information
Half-Shot committed Oct 1, 2020
1 parent d910534 commit 97d1739
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 28 deletions.
27 changes: 19 additions & 8 deletions synapse/appservice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List

from synapse.api.constants import EventTypes
from synapse.appservice.api import ApplicationServiceApi
from synapse.types import GroupID, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.events import EventBase
from synapse.types import GroupID, UserID, get_domain_from_id
from synapse.util.caches.descriptors import _CacheContext, cached

if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
Expand All @@ -35,7 +36,13 @@ class ApplicationServiceState:
class AppServiceTransaction:
"""Represents an application service transaction."""

def __init__(self, service, id, events, ephemeral=None):
def __init__(
self,
service: ApplicationService,
id: int,
events: List[EventBase],
ephemeral=None,
):
self.service = service
self.id = id
self.events = events
Expand Down Expand Up @@ -198,9 +205,11 @@ async def _matches_user(self, event, store):
return does_match

@cached(num_args=1, cache_context=True)
async def matches_user_in_member_list(self, room_id, store, cache_context):
async def matches_user_in_member_list(
self, room_id: str, store: DataStore, cache_context: _CacheContext
):
member_list = await store.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
room_id
)

# check joined member events
Expand Down Expand Up @@ -246,15 +255,17 @@ async def is_interested(self, event, store=None) -> bool:
return False

@cached(num_args=1, cache_context=True)
async def is_interested_in_presence(self, user_id, store, cache_context):
async def is_interested_in_presence(
self, user_id: UserID, store: DataStore, cache_context: _CacheContext
):
# Find all the rooms the sender is in
if self.is_interested_in_user(user_id.to_string()):
return True
room_ids = await store.get_rooms_for_user(user_id.to_string())

# Then find out if the appservice is interested in any of those rooms
for room_id in room_ids:
if await self.matches_user_in_member_list(room_id, store, cache_context):
if await self.matches_user_in_member_list(room_id, store):
return True
return False

Expand Down
16 changes: 11 additions & 5 deletions synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
# limitations under the License.
import logging
import urllib
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, List, Optional

from prometheus_client import Counter

from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events import EventBase
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, ThirdPartyInstanceID
Expand Down Expand Up @@ -201,7 +202,13 @@ async def _get() -> Optional[JsonDict]:
key = (service.id, protocol)
return await self.protocol_meta_cache.wrap(key, _get)

async def push_bulk(self, service, events, ephemeral=None, txn_id=None):
async def push_bulk(
self,
service: ApplicationService,
events: List[EventBase],
ephemeral: Optional[Any] = None,
txn_id: Optional[int] = None,
):
if service.url is None:
return True

Expand All @@ -211,10 +218,9 @@ async def push_bulk(self, service, events, ephemeral=None, txn_id=None):
logger.warning(
"push_bulk: Missing txn ID sending events to %s", service.url
)
txn_id = str(0)
txn_id = str(txn_id)
txn_id = 0

uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
body = {"events": events}
if ephemeral:
body["de.sorunome.msc2409.ephemeral"] = ephemeral
Expand Down
31 changes: 22 additions & 9 deletions synapse/appservice/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@
components.
"""
import logging
from typing import Any, List, Optional

from synapse.appservice import ApplicationServiceState
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process

Expand Down Expand Up @@ -82,10 +84,12 @@ async def start(self):
for service in services:
self.txn_ctrl.start_recoverer(service)

def submit_event_for_as(self, service, event):
def submit_event_for_as(self, service: ApplicationService, event: EventBase):
self.queuer.enqueue(service, event)

def submit_ephemeral_events_for_as(self, service, events):
def submit_ephemeral_events_for_as(
self, service: ApplicationService, events: List[Any]
):
self.queuer.enqueue_ephemeral(service, events)


Expand All @@ -99,7 +103,7 @@ class _ServiceQueuer:

def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]}
self.queued_ephemeral = {} # dict of {service_id: [events]}
self.queued_ephemeral = {} # dict of {service_id: [events]}

# the appservices which currently have a transaction in flight
self.requests_in_flight = set()
Expand All @@ -118,7 +122,7 @@ def enqueue(self, service, event):
"as-sender-%s" % (service.id), self._send_request, service
)

def enqueue_ephemeral(self, service, events):
def enqueue_ephemeral(self, service: ApplicationService, events: List[Any]):
self.queued_ephemeral.setdefault(service.id, []).extend(events)

# start a sender for this appservice if we don't already have one
Expand All @@ -130,7 +134,9 @@ def enqueue_ephemeral(self, service, events):
"as-sender-%s" % (service.id), self._send_request, service
)

async def _send_request(self, service, ephemeral=None):
async def _send_request(
self, service: ApplicationService, ephemeral: Optional[Any] = None
):
# sanity-check: we shouldn't get here if this service already has a sender
# running.
assert service.id not in self.requests_in_flight
Expand Down Expand Up @@ -175,9 +181,16 @@ def __init__(self, clock, store, as_api):
# for UTs
self.RECOVERER_CLASS = _Recoverer

async def send(self, service, events, ephemeral=None):
async def send(
self,
service: ApplicationService,
events: List[EventBase],
ephemeral: Optional[Any] = None,
):
try:
txn = await self.store.create_appservice_txn(service=service, events=events, ephemeral=ephemeral)
txn = await self.store.create_appservice_txn(
service=service, events=events, ephemeral=ephemeral
)
service_is_up = await self.is_service_up(service)
if service_is_up:
sent = await txn.send(self.as_api)
Expand Down Expand Up @@ -221,7 +234,7 @@ def start_recoverer(self, service):
recoverer.recover()
logger.info("Now %i active recoverers", len(self.recoverers))

async def is_service_up(self, service):
async def is_service_up(self, service: ApplicationService):
state = await self.store.get_appservice_state(service)
return state == ApplicationServiceState.UP or state is None

Expand Down
5 changes: 4 additions & 1 deletion synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
import logging

from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
Expand Down Expand Up @@ -140,7 +141,9 @@ async def get_new_events(self, from_key, room_ids, **kwargs):

return (events, to_key)

async def get_new_events_as(self, from_key, service, **kwargs):
async def get_new_events_as(
self, from_key: int, service: ApplicationService, **kwargs
):
from_key = int(from_key)
to_key = self.get_current_key()

Expand Down
17 changes: 13 additions & 4 deletions synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
# limitations under the License.
import logging
import re
from typing import Any, List, Optional

from synapse.appservice import AppServiceTransaction
from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
Expand Down Expand Up @@ -172,7 +174,12 @@ async def set_appservice_state(self, service, state) -> None:
"application_services_state", {"as_id": service.id}, {"state": state}
)

async def create_appservice_txn(self, service, events, ephemeral=None):
async def create_appservice_txn(
self,
service: ApplicationService,
events: List[EventBase],
ephemeral: Optional[Any] = None,
):
"""Atomically creates a new transaction for this application service
with the given list of events.
Expand Down Expand Up @@ -353,7 +360,9 @@ def get_new_events_for_appservice_txn(txn):

return upper_bound, events

async def get_type_stream_id_for_appservice(self, service, type: str) -> int:
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
def get_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type
txn.execute(
Expand All @@ -371,7 +380,7 @@ def get_type_stream_id_for_appservice_txn(txn):
)

async def set_type_stream_id_for_appservice(
self, service, type: str, pos: int
self, service: ApplicationService, type: str, pos: int
) -> None:
def set_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type
Expand Down
4 changes: 3 additions & 1 deletion synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ def f(txn):
return results

@cached(num_args=2,)
async def _get_linearized_receipts_for_all_rooms(self, to_key, from_key=None):
async def _get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
):
def f(txn):
if from_key:
sql = """
Expand Down

0 comments on commit 97d1739

Please sign in to comment.