Skip to content
This repository has been archived by the owner on Apr 12, 2024. It is now read-only.

Add a bulk user info endpoint and deprecate the old one #46

Merged
merged 15 commits into from
Jun 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/46.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a bulk version of the User Info API. Deprecate the single-use version.
16 changes: 15 additions & 1 deletion synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from six.moves import urllib

Expand Down Expand Up @@ -1021,6 +1021,20 @@ def get_room_complexity(self, destination, room_id):

return self.client.get_json(destination=destination, path=path)

def get_info_of_users(self, destination: str, user_ids: List[str]):
"""
Args:
destination: The remote server
user_ids: A list of user IDs to query info about

Returns:
Deferred[List]: A dictionary of User ID to information about that user.
"""
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/users/info")
data = {"user_ids": user_ids}

return self.client.post_json(destination=destination, path=path, data=data)


def _create_path(federation_prefix, path, *args):
"""
Expand Down
53 changes: 53 additions & 0 deletions synapse/federation/transport/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
assert_params_in_dict,
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
Expand Down Expand Up @@ -849,6 +850,57 @@ async def on_POST(self, origin, content, query):
return 200, data


class FederationUserInfoServlet(BaseFederationServlet):
"""
Return information about a set of users.

This API returns expiration and deactivation information about a set of
users. Requested users not local to this homeserver will be ignored.

Example request:
POST /users/info

{
"user_ids": [
"@alice:example.com",
"@bob:example.com"
]
}

Example response
{
"@alice:example.com": {
"expired": false,
"deactivated": true
}
}
"""

PATH = "/users/info"
PREFIX = FEDERATION_UNSTABLE_PREFIX

def __init__(self, handler, authenticator, ratelimiter, server_name):
super(FederationUserInfoServlet, self).__init__(
handler, authenticator, ratelimiter, server_name
)
self.handler = handler

async def on_POST(self, origin, content, query):
assert_params_in_dict(content, required=["user_ids"])

user_ids = content.get("user_ids", [])

if not isinstance(user_ids, list):
raise SynapseError(
400,
"'user_ids' must be a list of user ID strings",
errcode=Codes.INVALID_PARAM,
)

data = await self.handler.store.get_info_for_users(user_ids)
return 200, data


class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"

Expand Down Expand Up @@ -1410,6 +1462,7 @@ async def on_GET(self, origin, content, query, room_id):
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
FederationUserInfoServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]

OPENID_SERVLET_CLASSES = (
Expand Down
113 changes: 77 additions & 36 deletions synapse/rest/client/v2_alpha/user_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
# limitations under the License.

import logging
from typing import Dict

from signedjson.sign import sign_json

from twisted.internet import defer

from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.types import UserID

from ._base import client_patterns
Expand Down Expand Up @@ -92,45 +95,43 @@ async def on_POST(self, request):
return 200, results


class UserInfoServlet(RestServlet):
class SingleUserInfoServlet(RestServlet):
"""
Deprecated and replaced by `/users/info`

GET /user/{user_id}/info HTTP/1.1
"""

PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")

def __init__(self, hs):
super(UserInfoServlet, self).__init__()
super(SingleUserInfoServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.clock = hs.get_clock()
self.transport_layer = hs.get_federation_transport_client()
registry = hs.get_federation_registry()

if not registry.query_handlers.get("user_info"):
registry.register_query_handler("user_info", self._on_federation_query)

@defer.inlineCallbacks
def on_GET(self, request, user_id):
async def on_GET(self, request, user_id):
# Ensure the user is authenticated
yield self.auth.get_user_by_req(request, allow_guest=False)
await self.auth.get_user_by_req(request)

user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
# Attempt to make a federation request to the server that owns this user
args = {"user_id": user_id}
res = yield self.transport_layer.make_query(
res = await self.transport_layer.make_query(
user.domain, "user_info", args, retry_on_dns_fail=True
)
defer.returnValue((200, res))
return 200, res

res = yield self._get_user_info(user_id)
defer.returnValue((200, res))
user_id_to_info = await self.store.get_info_for_users([user_id])
return 200, user_id_to_info[user_id]

@defer.inlineCallbacks
def _on_federation_query(self, args):
async def _on_federation_query(self, args):
"""Called when a request for user information appears over federation

Args:
Expand All @@ -147,32 +148,72 @@ def _on_federation_query(self, args):
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")

res = yield self._get_user_info(user_id)
defer.returnValue(res)
user_ids_to_info_dict = await self.store.get_info_for_users([user_id])
return user_ids_to_info_dict[user_id]

@defer.inlineCallbacks
def _get_user_info(self, user_id):
"""Retrieve information about a given user

Args:
user_id (str): The User ID of a given user on this homeserver
class UserInfoServlet(RestServlet):
"""Bulk version of `/user/{user_id}/info` endpoint

Returns:
Deferred[dict]: Deactivation and expiration information for a given user
"""
# Check whether user is deactivated
is_deactivated = yield self.store.get_user_deactivated_status(user_id)
GET /users/info HTTP/1.1

# Check whether user is expired
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
is_expired = (
expiration_ts is not None and self.clock.time_msec() >= expiration_ts
)
Returns a dictionary of user_id to info dictionary. Supports remote users
"""

PATTERNS = client_patterns("/users/info$", unstable=True, releases=())

def __init__(self, hs):
super(UserInfoServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.transport_layer = hs.get_federation_transport_client()

async def on_POST(self, request):
# Ensure the user is authenticated
await self.auth.get_user_by_req(request)

# Extract the user_ids from the request
body = parse_json_object_from_request(request)
assert_params_in_dict(body, required=["user_ids"])

user_ids = body["user_ids"]
babolivier marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(user_ids, list):
raise SynapseError(
400,
"'user_ids' must be a list of user ID strings",
errcode=Codes.INVALID_PARAM,
)

# Separate local and remote users
local_user_ids = set()
remote_server_to_user_ids = {} # type: Dict[str, set]
for user_id in user_ids:
user = UserID.from_string(user_id)

if self.hs.is_mine(user):
local_user_ids.add(user_id)
else:
remote_server_to_user_ids.setdefault(user.domain, set())
remote_server_to_user_ids[user.domain].add(user_id)

# Retrieve info of all local users
user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids)

# Request info of each remote user from their remote homeserver
for server_name, user_id_set in remote_server_to_user_ids.items():
# Make a request to the given server about their own users
res = await self.transport_layer.get_info_of_users(
server_name, list(user_id_set)
)

for user_id, info in res:
user_id_to_info_dict[user_id] = info

res = {"expired": is_expired, "deactivated": is_deactivated}
defer.returnValue(res)
return 200, user_id_to_info_dict


def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)
SingleUserInfoServlet(hs).register(http_server)
UserInfoServlet(hs).register(http_server)
50 changes: 50 additions & 0 deletions synapse/storage/data_stores/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import logging
import re
from typing import List

from six import iterkeys

Expand Down Expand Up @@ -295,6 +296,55 @@ def delete_account_validity_for_user(self, user_id):
desc="delete_account_validity_for_user",
)

@defer.inlineCallbacks
def get_info_for_users(
self, user_ids: List[str],
):
"""Return the user info for a given set of users

Args:
user_ids: A list of users to return information about

Returns:
Deferred[Dict[str, bool]]: A dictionary mapping each user ID to
a dict with the following keys:
* expired - whether this is an expired user
* deactivated - whether this is a deactivated user
"""
# Get information of all our local users
def _get_info_for_users_txn(txn):
rows = []

for user_id in user_ids:
sql = """
SELECT u.name, u.deactivated, av.expiration_ts_ms
FROM users as u
LEFT JOIN account_validity as av
ON av.user_id = u.name
WHERE u.name = ?
"""

txn.execute(sql, (user_id,))
row = txn.fetchone()
if row:
rows.append(row)

return rows

info_rows = yield self.db.runInteraction(
"get_info_for_users", _get_info_for_users_txn
)

return {
user_id: {
"expired": (
expiration is not None and self.clock.time_msec() >= expiration
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved
),
"deactivated": deactivated == 1,
}
for user_id, deactivated, expiration in info_rows
}

async def is_server_admin(self, user):
"""Determines if a user is an admin of this homeserver.

Expand Down
Loading