Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use state before join to determine if we _should_perform_remote_join #13270

Merged
merged 7 commits into from
Jul 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13270.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug introduced in Synapse 1.40 where a user invited to a restricted room would be briefly unable to join.
2 changes: 1 addition & 1 deletion synapse/events/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ async def build(
The signed and hashed event.
"""
if auth_event_ids is None:
state_ids = await self._state.get_current_state_ids(
state_ids = await self._state.compute_state_after_events(
self.room_id, prev_event_ids
)
auth_event_ids = self._event_auth_handler.compute_auth_events(
Expand Down
35 changes: 20 additions & 15 deletions synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,14 +755,14 @@ async def update_membership_locked(

latest_event_ids = await self.store.get_prev_events_for_room(room_id)

current_state_ids = await self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids
state_before_join = await self.state_handler.compute_state_after_events(
room_id, latest_event_ids
)

# TODO: Refactor into dictionary of explicitly allowed transitions
# between old and new state, with specific error messages for some
# transitions and generic otherwise
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
if old_state_id:
old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
Expand Down Expand Up @@ -813,11 +813,11 @@ async def update_membership_locked(
if action == "kick":
raise AuthError(403, "The target user is not in the room")

is_host_in_room = await self._is_host_in_room(current_state_ids)
is_host_in_room = await self._is_host_in_room(state_before_join)

if effective_membership_state == Membership.JOIN:
if requester.is_guest:
guest_can_join = await self._can_guest_join(current_state_ids)
guest_can_join = await self._can_guest_join(state_before_join)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
Expand Down Expand Up @@ -855,7 +855,12 @@ async def update_membership_locked(

# Check if a remote join should be performed.
remote_join, remote_room_hosts = await self._should_perform_remote_join(
target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
target.to_string(),
room_id,
remote_room_hosts,
content,
is_host_in_room,
state_before_join,
)
if remote_join:
if ratelimit:
Expand Down Expand Up @@ -995,6 +1000,7 @@ async def _should_perform_remote_join(
remote_room_hosts: List[str],
content: JsonDict,
is_host_in_room: bool,
state_before_join: StateMap[str],
) -> Tuple[bool, List[str]]:
"""
Check whether the server should do a remote join (as opposed to a local
Expand All @@ -1014,6 +1020,8 @@ async def _should_perform_remote_join(
content: The content to use as the event body of the join. This may
be modified.
is_host_in_room: True if the host is in the room.
state_before_join: The state before the join event (i.e. the resolution of
the states after its parent events).

Returns:
A tuple of:
Expand All @@ -1030,20 +1038,17 @@ async def _should_perform_remote_join(
# If the host is in the room, but not one of the authorised hosts
# for restricted join rules, a remote join must be used.
room_version = await self.store.get_room_version(room_id)
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
room_id
)

# If restricted join rules are not being used, a local join can always
# be used.
if not await self.event_auth_handler.has_restricted_join_rules(
current_state_ids, room_version
state_before_join, room_version
):
return False, []

# If the user is invited to the room or already joined, the join
# event can always be issued locally.
prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
prev_member_event = None
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
Expand All @@ -1058,10 +1063,10 @@ async def _should_perform_remote_join(
#
# If not, generate a new list of remote hosts based on which
# can issue invites.
event_map = await self.store.get_events(current_state_ids.values())
event_map = await self.store.get_events(state_before_join.values())
current_state = {
state_key: event_map[event_id]
for state_key, event_id in current_state_ids.items()
for state_key, event_id in state_before_join.items()
}
allowed_servers = get_servers_from_users(
get_users_which_can_issue_invite(current_state)
Expand All @@ -1075,7 +1080,7 @@ async def _should_perform_remote_join(

# Ensure the member should be allowed access via membership in a room.
await self.event_auth_handler.check_restricted_join_rules(
current_state_ids, room_version, user_id, prev_member_event
state_before_join, room_version, user_id, prev_member_event
)

# If this is going to be a local join, additional information must
Expand All @@ -1085,7 +1090,7 @@ async def _should_perform_remote_join(
EventContentFields.AUTHORISING_USER
] = await self.event_auth_handler.get_user_which_could_invite(
room_id,
current_state_ids,
state_before_join,
)

return False, []
Expand Down
21 changes: 13 additions & 8 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,27 @@ def __init__(self, hs: "HomeServer"):
ReplicationUpdateCurrentStateRestServlet.make_client(hs)
)

async def get_current_state_ids(
async def compute_state_after_events(
self,
room_id: str,
latest_event_ids: Collection[str],
event_ids: Collection[str],
) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room
"""Fetch the state after each of the given event IDs. Resolve them and return.

This is typically used where `event_ids` is a collection of forward extremities
in a room, intended to become the `prev_events` of a new event E. If so, the
return value of this function represents the state before E.

Args:
room_id:
latest_event_ids: The forward extremities to resolve.
room_id: the room_id containing the given events.
event_ids: the events whose state should be fetched and resolved.

Returns:
the state dict, mapping from (event_type, state_key) -> event_id
the state dict (a mapping from (event_type, state_key) -> event_id) which
holds the resolution of the states after the given event IDs.
"""
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
logger.debug("calling resolve_state_groups from compute_state_after_events")
ret = await self.resolve_state_groups_for_events(room_id, event_ids)
return await ret.get_state(self._state_storage_controller, StateFilter.all())

async def get_current_users_in_room(
Expand Down