Skip to content

Commit

Permalink
Convert search code to async/await. (matrix-org#7460)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored and phil-flex committed Jun 16, 2020
1 parent 152a791 commit baa12cb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 24 deletions.
1 change: 1 addition & 0 deletions changelog.d/7460.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert the search handler to async/await.
44 changes: 20 additions & 24 deletions synapse/handlers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from unpaddedbase64 import decode_base64, encode_base64

from twisted.internet import defer

from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
Expand All @@ -39,8 +37,7 @@ def __init__(self, hs):
self.state_store = self.storage.state
self.auth = hs.get_auth()

@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
async def get_old_rooms_from_upgraded_room(self, room_id):
"""Retrieves room IDs of old rooms in the history of an upgraded room.
We do so by checking the m.room.create event of the room for a
Expand All @@ -60,7 +57,7 @@ def get_old_rooms_from_upgraded_room(self, room_id):
historical_room_ids = []

# The initial room must have been known for us to get this far
predecessor = yield self.store.get_room_predecessor(room_id)
predecessor = await self.store.get_room_predecessor(room_id)

while True:
if not predecessor:
Expand All @@ -75,7 +72,7 @@ def get_old_rooms_from_upgraded_room(self, room_id):

# Don't add it to the list until we have checked that we are in the room
try:
next_predecessor_room = yield self.store.get_room_predecessor(
next_predecessor_room = await self.store.get_room_predecessor(
predecessor_room_id
)
except NotFoundError:
Expand All @@ -89,8 +86,7 @@ def get_old_rooms_from_upgraded_room(self, room_id):

return historical_room_ids

@defer.inlineCallbacks
def search(self, user, content, batch=None):
async def search(self, user, content, batch=None):
"""Performs a full text search for a user.
Args:
Expand Down Expand Up @@ -179,7 +175,7 @@ def search(self, user, content, batch=None):
search_filter = Filter(filter_dict)

# TODO: Search through left rooms too
rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
Expand All @@ -192,7 +188,7 @@ def search(self, user, content, batch=None):
historical_room_ids = []
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = yield self.get_old_rooms_from_upgraded_room(room_id)
ids = await self.get_old_rooms_from_upgraded_room(room_id)
historical_room_ids += ids

# Prevent any historical events from being filtered
Expand Down Expand Up @@ -223,7 +219,7 @@ def search(self, user, content, batch=None):
count = None

if order_by == "rank":
search_result = yield self.store.search_msgs(room_ids, search_term, keys)
search_result = await self.store.search_msgs(room_ids, search_term, keys)

count = search_result["count"]

Expand All @@ -238,7 +234,7 @@ def search(self, user, content, batch=None):

filtered_events = search_filter.filter([r["event"] for r in results])

events = yield filter_events_for_client(
events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)

Expand Down Expand Up @@ -267,7 +263,7 @@ def search(self, user, content, batch=None):
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5:
i += 1
search_result = yield self.store.search_rooms(
search_result = await self.store.search_rooms(
room_ids,
search_term,
keys,
Expand All @@ -288,7 +284,7 @@ def search(self, user, content, batch=None):

filtered_events = search_filter.filter([r["event"] for r in results])

events = yield filter_events_for_client(
events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)

Expand Down Expand Up @@ -343,11 +339,11 @@ def search(self, user, content, batch=None):
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None:
now_token = yield self.hs.get_event_sources().get_current_token()
now_token = await self.hs.get_event_sources().get_current_token()

contexts = {}
for event in allowed_events:
res = yield self.store.get_events_around(
res = await self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit
)

Expand All @@ -357,11 +353,11 @@ def search(self, user, content, batch=None):
len(res["events_after"]),
)

res["events_before"] = yield filter_events_for_client(
res["events_before"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_before"]
)

res["events_after"] = yield filter_events_for_client(
res["events_after"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_after"]
)

Expand Down Expand Up @@ -390,7 +386,7 @@ def search(self, user, content, batch=None):
[(EventTypes.Member, sender) for sender in senders]
)

state = yield self.state_store.get_state_for_event(
state = await self.state_store.get_state_for_event(
last_event_id, state_filter
)

Expand All @@ -412,18 +408,18 @@ def search(self, user, content, batch=None):
time_now = self.clock.time_msec()

for context in contexts.values():
context["events_before"] = yield self._event_serializer.serialize_events(
context["events_before"] = await self._event_serializer.serialize_events(
context["events_before"], time_now
)
context["events_after"] = yield self._event_serializer.serialize_events(
context["events_after"] = await self._event_serializer.serialize_events(
context["events_after"], time_now
)

state_results = {}
if include_state:
rooms = {e.room_id for e in allowed_events}
for room_id in rooms:
state = yield self.state_handler.get_current_state(room_id)
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())

state_results.values()
Expand All @@ -437,7 +433,7 @@ def search(self, user, content, batch=None):
{
"rank": rank_map[e.event_id],
"result": (
yield self._event_serializer.serialize_event(e, time_now)
await self._event_serializer.serialize_event(e, time_now)
),
"context": contexts.get(e.event_id, {}),
}
Expand All @@ -452,7 +448,7 @@ def search(self, user, content, batch=None):
if state_results:
s = {}
for room_id, state in state_results.items():
s[room_id] = yield self._event_serializer.serialize_events(
s[room_id] = await self._event_serializer.serialize_events(
state, time_now
)

Expand Down

0 comments on commit baa12cb

Please sign in to comment.