Skip to content
This repository has been archived by the owner on Apr 12, 2024. It is now read-only.

Commit

Permalink
Add a bulk user info endpoint and deprecate the old one (#46)
Browse files Browse the repository at this point in the history
The current `/user/<user_id>/info` API was useful in that it could be used by any user to lookup whether another user was deactivate or expired. However, it was impractical as it only allowed for a single lookup at once. Clients trying to use this API were met with speed issues as they tried to query this information for all users in a room.

This PR adds an equivalent CS and Federation API that takes a list of user IDs, and returning a mapping from user ID to info dictionary.

Note that the federation in this PR was a bit trickier than in the original #12 as we can no longer use a federation query, as those don't allow for JSON bodies - which we require to pass a list of user IDs. Instead we do the whole thing of adding a method to transport/client and transport/server.

This PR also adds unittests. The earlier PR used Sytest, presumably for testing across federation, but as this is Synapse-specific that felt a little gross. Unit tests for the deprecated endpoint have not been added.
  • Loading branch information
anoadragon453 authored Jun 19, 2020
1 parent 6708163 commit 53949a9
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 38 deletions.
1 change: 1 addition & 0 deletions changelog.d/46.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a bulk version of the User Info API. Deprecate the single-use version.
16 changes: 15 additions & 1 deletion synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from six.moves import urllib

Expand Down Expand Up @@ -1021,6 +1021,20 @@ def get_room_complexity(self, destination, room_id):

return self.client.get_json(destination=destination, path=path)

def get_info_of_users(self, destination: str, user_ids: List[str]):
"""
Args:
destination: The remote server
user_ids: A list of user IDs to query info about
Returns:
Deferred[List]: A dictionary of User ID to information about that user.
"""
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/users/info")
data = {"user_ids": user_ids}

return self.client.post_json(destination=destination, path=path, data=data)


def _create_path(federation_prefix, path, *args):
"""
Expand Down
53 changes: 53 additions & 0 deletions synapse/federation/transport/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
assert_params_in_dict,
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
Expand Down Expand Up @@ -849,6 +850,57 @@ async def on_POST(self, origin, content, query):
return 200, data


class FederationUserInfoServlet(BaseFederationServlet):
"""
Return information about a set of users.
This API returns expiration and deactivation information about a set of
users. Requested users not local to this homeserver will be ignored.
Example request:
POST /users/info
{
"user_ids": [
"@alice:example.com",
"@bob:example.com"
]
}
Example response
{
"@alice:example.com": {
"expired": false,
"deactivated": true
}
}
"""

PATH = "/users/info"
PREFIX = FEDERATION_UNSTABLE_PREFIX

def __init__(self, handler, authenticator, ratelimiter, server_name):
super(FederationUserInfoServlet, self).__init__(
handler, authenticator, ratelimiter, server_name
)
self.handler = handler

async def on_POST(self, origin, content, query):
assert_params_in_dict(content, required=["user_ids"])

user_ids = content.get("user_ids", [])

if not isinstance(user_ids, list):
raise SynapseError(
400,
"'user_ids' must be a list of user ID strings",
errcode=Codes.INVALID_PARAM,
)

data = await self.handler.store.get_info_for_users(user_ids)
return 200, data


class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"

Expand Down Expand Up @@ -1410,6 +1462,7 @@ async def on_GET(self, origin, content, query, room_id):
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
FederationUserInfoServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]

OPENID_SERVLET_CLASSES = (
Expand Down
113 changes: 77 additions & 36 deletions synapse/rest/client/v2_alpha/user_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
# limitations under the License.

import logging
from typing import Dict

from signedjson.sign import sign_json

from twisted.internet import defer

from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.types import UserID

from ._base import client_patterns
Expand Down Expand Up @@ -92,45 +95,43 @@ async def on_POST(self, request):
return 200, results


class UserInfoServlet(RestServlet):
class SingleUserInfoServlet(RestServlet):
"""
Deprecated and replaced by `/users/info`
GET /user/{user_id}/info HTTP/1.1
"""

PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")

def __init__(self, hs):
super(UserInfoServlet, self).__init__()
super(SingleUserInfoServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.clock = hs.get_clock()
self.transport_layer = hs.get_federation_transport_client()
registry = hs.get_federation_registry()

if not registry.query_handlers.get("user_info"):
registry.register_query_handler("user_info", self._on_federation_query)

@defer.inlineCallbacks
def on_GET(self, request, user_id):
async def on_GET(self, request, user_id):
# Ensure the user is authenticated
yield self.auth.get_user_by_req(request, allow_guest=False)
await self.auth.get_user_by_req(request)

user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
# Attempt to make a federation request to the server that owns this user
args = {"user_id": user_id}
res = yield self.transport_layer.make_query(
res = await self.transport_layer.make_query(
user.domain, "user_info", args, retry_on_dns_fail=True
)
defer.returnValue((200, res))
return 200, res

res = yield self._get_user_info(user_id)
defer.returnValue((200, res))
user_id_to_info = await self.store.get_info_for_users([user_id])
return 200, user_id_to_info[user_id]

@defer.inlineCallbacks
def _on_federation_query(self, args):
async def _on_federation_query(self, args):
"""Called when a request for user information appears over federation
Args:
Expand All @@ -147,32 +148,72 @@ def _on_federation_query(self, args):
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")

res = yield self._get_user_info(user_id)
defer.returnValue(res)
user_ids_to_info_dict = await self.store.get_info_for_users([user_id])
return user_ids_to_info_dict[user_id]

@defer.inlineCallbacks
def _get_user_info(self, user_id):
"""Retrieve information about a given user

Args:
user_id (str): The User ID of a given user on this homeserver
class UserInfoServlet(RestServlet):
"""Bulk version of `/user/{user_id}/info` endpoint
Returns:
Deferred[dict]: Deactivation and expiration information for a given user
"""
# Check whether user is deactivated
is_deactivated = yield self.store.get_user_deactivated_status(user_id)
GET /users/info HTTP/1.1
# Check whether user is expired
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
is_expired = (
expiration_ts is not None and self.clock.time_msec() >= expiration_ts
)
Returns a dictionary of user_id to info dictionary. Supports remote users
"""

PATTERNS = client_patterns("/users/info$", unstable=True, releases=())

def __init__(self, hs):
super(UserInfoServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.transport_layer = hs.get_federation_transport_client()

async def on_POST(self, request):
# Ensure the user is authenticated
await self.auth.get_user_by_req(request)

# Extract the user_ids from the request
body = parse_json_object_from_request(request)
assert_params_in_dict(body, required=["user_ids"])

user_ids = body["user_ids"]
if not isinstance(user_ids, list):
raise SynapseError(
400,
"'user_ids' must be a list of user ID strings",
errcode=Codes.INVALID_PARAM,
)

# Separate local and remote users
local_user_ids = set()
remote_server_to_user_ids = {} # type: Dict[str, set]
for user_id in user_ids:
user = UserID.from_string(user_id)

if self.hs.is_mine(user):
local_user_ids.add(user_id)
else:
remote_server_to_user_ids.setdefault(user.domain, set())
remote_server_to_user_ids[user.domain].add(user_id)

# Retrieve info of all local users
user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids)

# Request info of each remote user from their remote homeserver
for server_name, user_id_set in remote_server_to_user_ids.items():
# Make a request to the given server about their own users
res = await self.transport_layer.get_info_of_users(
server_name, list(user_id_set)
)

for user_id, info in res:
user_id_to_info_dict[user_id] = info

res = {"expired": is_expired, "deactivated": is_deactivated}
defer.returnValue(res)
return 200, user_id_to_info_dict


def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)
SingleUserInfoServlet(hs).register(http_server)
UserInfoServlet(hs).register(http_server)
50 changes: 50 additions & 0 deletions synapse/storage/data_stores/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import logging
import re
from typing import List

from six import iterkeys

Expand Down Expand Up @@ -304,6 +305,55 @@ def delete_account_validity_for_user(self, user_id):
desc="delete_account_validity_for_user",
)

@defer.inlineCallbacks
def get_info_for_users(
self, user_ids: List[str],
):
"""Return the user info for a given set of users
Args:
user_ids: A list of users to return information about
Returns:
Deferred[Dict[str, bool]]: A dictionary mapping each user ID to
a dict with the following keys:
* expired - whether this is an expired user
* deactivated - whether this is a deactivated user
"""
# Get information of all our local users
def _get_info_for_users_txn(txn):
rows = []

for user_id in user_ids:
sql = """
SELECT u.name, u.deactivated, av.expiration_ts_ms
FROM users as u
LEFT JOIN account_validity as av
ON av.user_id = u.name
WHERE u.name = ?
"""

txn.execute(sql, (user_id,))
row = txn.fetchone()
if row:
rows.append(row)

return rows

info_rows = yield self.db.runInteraction(
"get_info_for_users", _get_info_for_users_txn
)

return {
user_id: {
"expired": (
expiration is not None and self.clock.time_msec() >= expiration
),
"deactivated": deactivated == 1,
}
for user_id, deactivated, expiration in info_rows
}

async def is_server_admin(self, user):
"""Determines if a user is an admin of this homeserver.
Expand Down
Loading

0 comments on commit 53949a9

Please sign in to comment.