diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 9772ed1a0e7c..3ead2c9a5b71 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -157,7 +157,7 @@ def search(self, user, content, batch=None): if order_by == "rank": search_result = yield self.store.search_msgs( - room_ids, search_term, keys + user.to_string(), room_ids, search_term, keys ) count = search_result["count"] @@ -205,7 +205,7 @@ def search(self, user, content, batch=None): while len(room_events) < search_filter.limit() and i < 5: i += 1 search_result = yield self.store.search_rooms( - room_ids, search_term, keys, search_filter.limit() * 2, + user.to_string(), room_ids, search_term, keys, search_filter.limit() * 2, pagination_token=pagination_token, ) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 479b04c63697..011f6338e3bb 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -243,12 +243,24 @@ def reindex_search_txn(txn): defer.returnValue(num_rows) @defer.inlineCallbacks - def search_msgs(self, room_ids, search_term, keys): + def _find_starred_events(self, user_id, room_ids): + starred = [] + for room_id in room_ids: + account_data = yield self.get_account_data_for_room(user_id, room_id) + room_starred = account_data.get("m.room.starred_events", None) + if room_starred: + starred.extend(room_starred["starred"]) + + defer.returnValue(starred) + + @defer.inlineCallbacks + def search_msgs(self, user_id, room_ids, search_term, keys): """Performs a full text search over events with given keys. 
Args: + user_id (str): User id of searcher room_ids (list): List of room ids to search in - search_term (str): Search term to search for + search_term (str): Search term to search for, may contain expressions keys (list): List of keys to search in, currently supports "content.body", "content.name", "content.topic" @@ -257,13 +269,33 @@ def search_msgs(self, room_ids, search_term, keys): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + query_expr_map = _parse_query(search_term) + + search_query = _query_words_to_clauses(self.database_engine, query_expr_map["words"]) args = [] + bypass_room_id_filtering = False + bypass_words_matching = False + if "starred" in query_expr_map["tags"]: + bypass_room_id_filtering = True + event_ids = yield self._find_starred_events(user_id, room_ids) + if not event_ids: + defer.returnValue({ + "results": [], + "count": 0, + "highlights": [], + }) + clauses.append( + "event_id IN (%s)" % (",".join(["?"] * len(event_ids)),) + ) + args.extend(event_ids) + if not search_query: + bypass_words_matching = True + # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. 
- if len(room_ids) < 500: + if not bypass_room_id_filtering and len(room_ids) < 500: clauses.append( "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) ) @@ -281,52 +313,58 @@ def search_msgs(self, room_ids, search_term, keys): count_args = args count_clauses = clauses - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," - " room_id, event_id" - " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" - ) - args = [search_query, search_query] + args - - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" - ) - count_args = [search_query] + count_args - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" - " FROM event_search" - " WHERE value MATCH ?" - ) - args = [search_query] + args - - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ?" - ) - count_args = [search_term] + count_args + if bypass_words_matching: + sql = "SELECT room_id, event_id FROM event_search WHERE " + count_sql = "SELECT room_id, count(*) as count FROM event_search WHERE " else: - # This should be unreachable. 
- raise Exception("Unrecognized database engine") + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," + " room_id, event_id" + " FROM event_search WHERE " + ) + clause.append("vector @@ to_tsquery('english', ?)") + args = [search_query] + args + [search_query] + + count_sql = ( + "SELECT room_id, count(*) as count FROM event_search WHERE " + ) + count_clauses.append("vector @@ to_tsquery('english', ?)") + count_args.append(search_query) + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" + " FROM event_search WHERE " + ) + clauses.append("value MATCH ?") + args.append(search_query) + + count_sql = ( + "SELECT room_id, count(*) as count FROM event_search WHERE " + ) + count_clauses.append("value MATCH ?") + count_args.append(search_term) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") - for clause in clauses: - sql += " AND " + clause + sql += " AND ".join(clauses) - for clause in count_clauses: - count_sql += " AND " + clause + count_sql += " AND ".join(count_clauses) + + if not bypass_words_matching: + sql += " ORDER BY rank DESC" # We add an arbitrary limit here to ensure we don't try to pull the # entire table from the database. 
- sql += " ORDER BY rank DESC LIMIT 500" + sql += " LIMIT 500" results = yield self._execute( "search_msgs", self.cursor_to_dict, sql, *args ) - results = filter(lambda row: row["room_id"] in room_ids, results) + if not bypass_room_id_filtering: + results = filter(lambda row: row["room_id"] in room_ids, results) events = yield self._get_events([r["event_id"] for r in results]) @@ -336,7 +374,7 @@ def search_msgs(self, room_ids, search_term, keys): } highlights = None - if isinstance(self.database_engine, PostgresEngine): + if not bypass_words_matching and isinstance(self.database_engine, PostgresEngine): highlights = yield self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" @@ -351,7 +389,7 @@ def search_msgs(self, room_ids, search_term, keys): "results": [ { "event": event_map[r["event_id"]], - "rank": r["rank"], + "rank": event_map[r["event_id"]]["origin_server_ts"] if bypass_words_matching else r["rank"], } for r in results if r["event_id"] in event_map @@ -361,12 +399,13 @@ def search_msgs(self, room_ids, search_term, keys): }) @defer.inlineCallbacks - def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): + def search_rooms(self, user_id, room_ids, search_term, keys, limit, pagination_token=None): """Performs a full text search over events with given keys. 
Args: + user_id (str): User id of searcher room_id (list): The room_ids to search in - search_term (str): Search term to search for + search_term (str): Search term to search for, may contain expressions keys (list): List of keys to search in, currently supports "content.body", "content.name", "content.topic" pagination_token (str): A pagination token previously returned @@ -376,13 +415,33 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + query_expr_map = _parse_query(search_term) + + search_query = _query_words_to_clauses(self.database_engine, query_expr_map["words"]) args = [] + bypass_room_id_filtering = False + bypass_words_matching = False + if "starred" in query_expr_map["tags"]: + bypass_room_id_filtering = True + event_ids = yield self._find_starred_events(user_id, room_ids) + if not event_ids: + defer.returnValue({ + "results": [], + "count": 0, + "highlights": [], + }) + clauses.append( + "event_id IN (%s)" % (",".join(["?"] * len(event_ids)),) + ) + args.extend(event_ids) + if not search_query: + bypass_words_matching = True + # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. - if len(room_ids) < 500: + if not bypass_room_id_filtering and len(room_ids) < 500: clauses.append( "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) ) @@ -416,49 +475,65 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None ) args.extend([origin_server_ts, origin_server_ts, stream]) - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," - " origin_server_ts, stream_ordering, room_id, event_id" - " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) 
AND " - ) - args = [search_query, search_query] + args - - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " - ) - count_args = [search_query] + count_args - elif isinstance(self.database_engine, Sqlite3Engine): - # We use CROSS JOIN here to ensure we use the right indexes. - # https://sqlite.org/optoverview.html#crossjoin - # - # We want to use the full text search index on event_search to - # extract all possible matches first, then lookup those matches - # in the events table to get the topological ordering. We need - # to use the indexes in this order because sqlite refuses to - # MATCH unless it uses the full text search index - sql = ( - "SELECT rank(matchinfo) as rank, room_id, event_id," - " origin_server_ts, stream_ordering" - " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" - " FROM event_search" - " WHERE value MATCH ?" - " )" - " CROSS JOIN events USING (event_id)" - " WHERE " - ) - args = [search_query] + args + if bypass_words_matching: + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "SELECT origin_server_ts, stream_ordering, room_id, event_id" + " FROM event_search WHERE " + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "SELECT room_id, event_id, origin_server_ts, stream_ordering" + " FROM (SELECT key, event_id FROM event_search)" + " CROSS JOIN events USING (event_id)" + " WHERE " + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ? AND " - ) - count_args = [search_term] + count_args + count_sql = "SELECT room_id, count(*) as count FROM event_search WHERE " else: - # This should be unreachable. 
- raise Exception("Unrecognized database engine") + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," + " origin_server_ts, stream_ordering, room_id, event_id" + " FROM event_search WHERE " + ) + clauses.append("vector @@ to_tsquery('english', ?)") + args = [search_query] + args + [search_query] + + count_sql = "SELECT room_id, count(*) as count FROM event_search WHERE " + count_clauses.append("vector @@ to_tsquery('english', ?)") + count_args.append(search_query) + elif isinstance(self.database_engine, Sqlite3Engine): + # We use CROSS JOIN here to ensure we use the right indexes. + # https://sqlite.org/optoverview.html#crossjoin + # + # We want to use the full text search index on event_search to + # extract all possible matches first, then lookup those matches + # in the events table to get the topological ordering. We need + # to use the indexes in this order because sqlite refuses to + # MATCH unless it uses the full text search index + sql = ( + "SELECT rank(matchinfo) as rank, room_id, event_id," + " origin_server_ts, stream_ordering" + " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" + " FROM event_search" + " WHERE value MATCH ?" + " )" + " CROSS JOIN events USING (event_id)" + " WHERE " + ) + args = [search_query] + args + + count_sql = "SELECT room_id, count(*) as count FROM event_search WHERE " + + count_clauses.append("value MATCH ?") + count_args.append(search_term) + else: + # This should be unreachable. 
+ raise Exception("Unrecognized database engine") sql += " AND ".join(clauses) count_sql += " AND ".join(count_clauses) @@ -481,7 +556,8 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None "search_rooms", self.cursor_to_dict, sql, *args ) - results = filter(lambda row: row["room_id"] in room_ids, results) + if not bypass_room_id_filtering: + results = filter(lambda row: row["room_id"] in room_ids, results) events = yield self._get_events([r["event_id"] for r in results]) @@ -491,7 +567,7 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None } highlights = None - if isinstance(self.database_engine, PostgresEngine): + if not bypass_words_matching and isinstance(self.database_engine, PostgresEngine): highlights = yield self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" @@ -506,7 +582,7 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None "results": [ { "event": event_map[r["event_id"]], - "rank": r["rank"], + "rank": event_map[r["event_id"]]["origin_server_ts"] if bypass_words_matching else r["rank"], "pagination_token": "%s,%s" % ( r["origin_server_ts"], r["stream_ordering"] ), @@ -589,20 +665,44 @@ def _to_postgres_options(options_dict): ) -def _parse_query(database_engine, search_term): - """Takes a plain unicode string from the user and converts it into a form - that can be passed to database. +def _parse_query(search_term): + """Parse search query string from the user and return a query expressions map. + The query string may contain: + - plain words + - tag expressions: e.g. ':starred :read-latter' + - key value pairs: e.g. 
'before: 17/01/22 after: 17/02/01' + The return map contains: + - "words": list of plain words + - "tags": list of tags + - "criteria": dict mapping each search criterion key to its value + """ + + exprs = search_term.split() + expr_map = { "words": [], "tags": [], "criteria": {} } + + for expr in exprs: + kv = expr.split(":") + if len(kv) == 1: + expr_map["words"].append(kv[0]) + elif not kv[0]: + expr_map["tags"].append(kv[1]) + else: + expr_map["criteria"][kv[0]] = kv[1] + + return expr_map + + +def _query_words_to_clauses(database_engine, words): + """Takes a list of plain unicode string words from the user and converts it + into a form that can be passed to database. We use this so that we can add prefix matching, which isn't something that is supported by default. """ - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - if isinstance(database_engine, PostgresEngine): - return " & ".join(result + ":*" for result in results) + return " & ".join(w + ":*" for w in words) elif isinstance(database_engine, Sqlite3Engine): - return " & ".join(result + "*" for result in results) + return " & ".join(w + "*" for w in words) else: # This should be unreachable. raise Exception("Unrecognized database engine")