Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert search code to async/await. #7460

Merged
merged 1 commit into from
May 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7460.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert the search handler to async/await.
44 changes: 20 additions & 24 deletions synapse/handlers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from unpaddedbase64 import decode_base64, encode_base64

from twisted.internet import defer

from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
Expand All @@ -39,8 +37,7 @@ def __init__(self, hs):
self.state_store = self.storage.state
self.auth = hs.get_auth()

@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
async def get_old_rooms_from_upgraded_room(self, room_id):
"""Retrieves room IDs of old rooms in the history of an upgraded room.

We do so by checking the m.room.create event of the room for a
Expand All @@ -60,7 +57,7 @@ def get_old_rooms_from_upgraded_room(self, room_id):
historical_room_ids = []

# The initial room must have been known for us to get this far
predecessor = yield self.store.get_room_predecessor(room_id)
predecessor = await self.store.get_room_predecessor(room_id)

while True:
if not predecessor:
Expand All @@ -75,7 +72,7 @@ def get_old_rooms_from_upgraded_room(self, room_id):

# Don't add it to the list until we have checked that we are in the room
try:
next_predecessor_room = yield self.store.get_room_predecessor(
next_predecessor_room = await self.store.get_room_predecessor(
predecessor_room_id
)
except NotFoundError:
Expand All @@ -89,8 +86,7 @@ def get_old_rooms_from_upgraded_room(self, room_id):

return historical_room_ids

@defer.inlineCallbacks
def search(self, user, content, batch=None):
async def search(self, user, content, batch=None):
"""Performs a full text search for a user.

Args:
Expand Down Expand Up @@ -179,7 +175,7 @@ def search(self, user, content, batch=None):
search_filter = Filter(filter_dict)

# TODO: Search through left rooms too
rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
Expand All @@ -192,7 +188,7 @@ def search(self, user, content, batch=None):
historical_room_ids = []
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = yield self.get_old_rooms_from_upgraded_room(room_id)
ids = await self.get_old_rooms_from_upgraded_room(room_id)
historical_room_ids += ids

# Prevent any historical events from being filtered
Expand Down Expand Up @@ -223,7 +219,7 @@ def search(self, user, content, batch=None):
count = None

if order_by == "rank":
search_result = yield self.store.search_msgs(room_ids, search_term, keys)
search_result = await self.store.search_msgs(room_ids, search_term, keys)

count = search_result["count"]

Expand All @@ -238,7 +234,7 @@ def search(self, user, content, batch=None):

filtered_events = search_filter.filter([r["event"] for r in results])

events = yield filter_events_for_client(
events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)

Expand Down Expand Up @@ -267,7 +263,7 @@ def search(self, user, content, batch=None):
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5:
i += 1
search_result = yield self.store.search_rooms(
search_result = await self.store.search_rooms(
room_ids,
search_term,
keys,
Expand All @@ -288,7 +284,7 @@ def search(self, user, content, batch=None):

filtered_events = search_filter.filter([r["event"] for r in results])

events = yield filter_events_for_client(
events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)

Expand Down Expand Up @@ -343,11 +339,11 @@ def search(self, user, content, batch=None):
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None:
now_token = yield self.hs.get_event_sources().get_current_token()
now_token = await self.hs.get_event_sources().get_current_token()

contexts = {}
for event in allowed_events:
res = yield self.store.get_events_around(
res = await self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit
)

Expand All @@ -357,11 +353,11 @@ def search(self, user, content, batch=None):
len(res["events_after"]),
)

res["events_before"] = yield filter_events_for_client(
res["events_before"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_before"]
)

res["events_after"] = yield filter_events_for_client(
res["events_after"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_after"]
)

Expand Down Expand Up @@ -390,7 +386,7 @@ def search(self, user, content, batch=None):
[(EventTypes.Member, sender) for sender in senders]
)

state = yield self.state_store.get_state_for_event(
state = await self.state_store.get_state_for_event(
last_event_id, state_filter
)

Expand All @@ -412,18 +408,18 @@ def search(self, user, content, batch=None):
time_now = self.clock.time_msec()

for context in contexts.values():
context["events_before"] = yield self._event_serializer.serialize_events(
context["events_before"] = await self._event_serializer.serialize_events(
context["events_before"], time_now
)
context["events_after"] = yield self._event_serializer.serialize_events(
context["events_after"] = await self._event_serializer.serialize_events(
context["events_after"], time_now
)

state_results = {}
if include_state:
rooms = {e.room_id for e in allowed_events}
for room_id in rooms:
state = yield self.state_handler.get_current_state(room_id)
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())

state_results.values()
Expand All @@ -437,7 +433,7 @@ def search(self, user, content, batch=None):
{
"rank": rank_map[e.event_id],
"result": (
yield self._event_serializer.serialize_event(e, time_now)
await self._event_serializer.serialize_event(e, time_now)
),
"context": contexts.get(e.event_id, {}),
}
Expand All @@ -452,7 +448,7 @@ def search(self, user, content, batch=None):
if state_results:
s = {}
for room_id, state in state_results.items():
s[room_id] = yield self._event_serializer.serialize_events(
s[room_id] = await self._event_serializer.serialize_events(
state, time_now
)

Expand Down