Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Move methods into PushRuleEvaluatorForEvent. #12677

Merged
merged 5 commits into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12677.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor functions to on `PushRuleEvaluatorForEvent`.
32 changes: 2 additions & 30 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,6 @@ async def action_for_event_by_user(
event, len(room_members), sender_power_level, power_levels
)

condition_cache: Dict[str, bool] = {}

# If the event is not a state event check if any users ignore the sender.
if not event.is_state():
ignorers = await self.store.ignored_by(event.sender)
Expand Down Expand Up @@ -247,8 +245,8 @@ async def action_for_event_by_user(
if "enabled" in rule and not rule["enabled"]:
continue

matches = _condition_checker(
evaluator, rule["conditions"], uid, display_name, condition_cache
matches = evaluator.check_conditions(
rule["conditions"], uid, display_name
)
if matches:
actions = [x for x in rule["actions"] if x != "dont_notify"]
Expand All @@ -267,32 +265,6 @@ async def action_for_event_by_user(
)


def _condition_checker(
evaluator: PushRuleEvaluatorForEvent,
conditions: List[dict],
uid: str,
display_name: Optional[str],
cache: Dict[str, bool],
) -> bool:
for cond in conditions:
_cache_key = cond.get("_cache_key", None)
if _cache_key:
res = cache.get(_cache_key, None)
if res is False:
return False
elif res is True:
continue

res = evaluator.matches(cond, uid, display_name)
if _cache_key:
cache[_cache_key] = bool(res)

if not res:
return False

return True


MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
Expand Down
70 changes: 66 additions & 4 deletions synapse/push/push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,55 @@ def __init__(
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)

# Maps cache keys to final values.
self._condition_cache: Dict[str, bool] = {}

def check_conditions(
self, conditions: List[dict], uid: str, display_name: Optional[str]
) -> bool:
"""
Returns true if a user's conditions/user ID/display name match the event.

Args:
conditions: The user's conditions to match.
uid: The user's MXID.
display_name: The display name.

Returns:
True if all conditions match the event, False otherwise.
"""
for cond in conditions:
_cache_key = cond.get("_cache_key", None)
if _cache_key:
res = self._condition_cache.get(_cache_key, None)
if res is False:
return False
elif res is True:
continue

res = self.matches(cond, uid, display_name)
if _cache_key:
self._condition_cache[_cache_key] = bool(res)

if not res:
return False

return True

def matches(
self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
) -> bool:
"""
Returns true if a user's condition/user ID/display name match the event.

Args:
condition: The user's condition to match.
uid: The user's MXID.
display_name: The display name, or None if there is not one.

Returns:
True if the condition matches the event, False otherwise.
"""
if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
elif condition["kind"] == "contains_display_name":
Expand All @@ -146,6 +192,16 @@ def matches(
return True

def _event_match(self, condition: dict, user_id: str) -> bool:
"""
Check an "event_match" push rule condition.

Args:
condition: The "event_match" push rule condition to match.
user_id: The user's MXID.

Returns:
True if the condition matches the event, False otherwise.
"""
pattern = condition.get("pattern", None)

if not pattern:
Expand All @@ -167,13 +223,22 @@ def _event_match(self, condition: dict, user_id: str) -> bool:

return _glob_matches(pattern, body, word_boundary=True)
else:
haystack = self._get_value(condition["key"])
haystack = self._value_cache.get(condition["key"], None)
if haystack is None:
return False

return _glob_matches(pattern, haystack)

def _contains_display_name(self, display_name: Optional[str]) -> bool:
"""
Check an "event_match" push rule condition.

Args:
display_name: The display name, or None if there is not one.

Returns:
True if the display name is found in the event body, False otherwise.
"""
if not display_name:
return False

Expand All @@ -191,9 +256,6 @@ def _contains_display_name(self, display_name: Optional[str]) -> bool:

return bool(r.search(body))

def _get_value(self, dotted_key: str) -> Optional[str]:
return self._value_cache.get(dotted_key, None)


# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
Expand Down