Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Support relation_match in push rules.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed May 16, 2022
1 parent 250fcd5 commit faf6afb
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 3 deletions.
4 changes: 4 additions & 0 deletions synapse/push/clientformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def format_push_rules_for_user(
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart

sender_type = c.pop("sender_type", None)
if sender_type == "user_id":
c["sender"] = user.to_string()

rulearray = rules["global"][template_name]

template_rule = _rule_to_template(r)
Expand Down
39 changes: 39 additions & 0 deletions synapse/push/push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def matches(
return _sender_notification_permission(
self._event, condition, self._sender_power_level, self._power_levels
)
elif condition["kind"] == "org.matrix.msc3772.relation_match":
return self._relation_match(condition, user_id)
else:
return True

Expand Down Expand Up @@ -258,6 +260,43 @@ def _contains_display_name(self, display_name: Optional[str]) -> bool:

return bool(r.search(body))

def _relation_match(self, condition: dict, user_id: str) -> bool:
"""
Check an "relation_match" push rule condition.
Args:
condition: The "event_match" push rule condition to match.
user_id: The user's MXID.
Returns:
True if the condition matches the event, False otherwise.
"""
rel_type_pattern = condition.get("rel_type")
sender_pattern = condition.get("sender")
if sender_pattern is None:
sender_type = condition.get("sender_type")
if sender_type == "user_id":
sender_pattern = user_id
type_pattern = condition.get("type")

if not rel_type_pattern and not sender_pattern and not type_pattern:
logger.warning("relation_match condition with nothing to match")
return False

# If any other relations matches, return True.
for relation in self._relations:
if rel_type_pattern and not _glob_matches(rel_type_pattern, relation[0]):
continue
if sender_pattern and not _glob_matches(sender_pattern, relation[1]):
continue
if type_pattern and not _glob_matches(type_pattern, relation[2]):
continue
# All values must have matched.
return True

# No relations matched.
return False


# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
Expand Down
63 changes: 60 additions & 3 deletions tests/push/test_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Union
from typing import Dict, Optional, Set, Tuple, Union

import frozendict

Expand All @@ -26,7 +26,9 @@


class PushRuleEvaluatorTestCase(unittest.TestCase):
def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
def _get_evaluator(
self, content: JsonDict, relations: Optional[Set[Tuple[str, str, str]]] = None
) -> PushRuleEvaluatorForEvent:
event = FrozenEvent(
{
"event_id": "$event_id",
Expand All @@ -42,7 +44,11 @@ def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
sender_power_level = 0
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels, set()
event,
room_member_count,
sender_power_level,
power_levels,
relations or set(),
)

def test_display_name(self) -> None:
Expand Down Expand Up @@ -276,3 +282,54 @@ def test_tweaks_for_actions(self) -> None:
push_rule_evaluator.tweaks_for_actions(actions),
{"sound": "default", "highlight": True},
)

def test_relation_match(self) -> None:
"""Test the relation_match push rule kind."""
evaluator = self._get_evaluator(
{}, {("m.annotation", "@user:test", "m.reaction")}
)

# Check just relation type.
condition = {
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.annotation",
}
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))

# Check relation type and sender.
condition = {
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.annotation",
"sender": "@user:test",
}
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
condition = {
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.annotation",
"sender": "@other:test",
}
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))

# Check relation type and event type.
condition = {
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.annotation",
"type": "m.reaction",
}
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))

# Check just sender.
condition = {
"kind": "org.matrix.msc3772.relation_match",
"sender": "@user:test",
}
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
condition = {
"kind": "org.matrix.msc3772.relation_match",
"sender": "@other:test",
}
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))

# Check glob.
condition = {"kind": "org.matrix.msc3772.relation_match", "sender": "@*:test"}
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))

0 comments on commit faf6afb

Please sign in to comment.