diff --git a/changelog.d/9381.misc b/changelog.d/9381.misc new file mode 100644 index 000000000000..5688166120b9 --- /dev/null +++ b/changelog.d/9381.misc @@ -0,0 +1 @@ +Update the version of black used to 20.8b1. diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index ab1e1f1f4c95..67e032244ecc 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -92,7 +92,7 @@ def _domain(self): return self.config["user"].split(":")[1] def do_config(self, line): - """ Show the config for this client: "config" + """Show the config for this client: "config" Edit a key value mapping: "config key value" e.g. "config token 1234" Config variables: user: The username to auth with. @@ -360,7 +360,7 @@ def do_joinalias(self, line): print(e) def do_topic(self, line): - """"topic [set|get] []" + """ "topic [set|get] []" Set the topic for a room: topic set Get the topic for a room: topic get """ @@ -690,7 +690,7 @@ def do_online(self, line): self._do_presence_state(2, line) def _parse(self, line, keys, force_keys=False): - """ Parses the given line. + """Parses the given line. Args: line : The line to parse @@ -721,7 +721,7 @@ def _run_and_pprint( query_params={"access_token": None}, alt_text=None, ): - """ Runs an HTTP request and pretty prints the output. + """Runs an HTTP request and pretty prints the output. Args: method: HTTP method diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index 345120b61267..851e80c25bb4 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -23,11 +23,10 @@ class HttpClient: - """ Interface for talking json over http - """ + """Interface for talking json over http""" def put_json(self, url, data): - """ Sends the specifed json data using PUT + """Sends the specifed json data using PUT Args: url (str): The URL to PUT data to. @@ -41,7 +40,7 @@ def put_json(self, url, data): pass def get_json(self, url, args=None): - """ Gets some json from the given host homeserver and path + """Gets some json from the given host homeserver and path Args: url (str): The URL to GET data from. @@ -58,7 +57,7 @@ def get_json(self, url, args=None): class TwistedHttpClient(HttpClient): - """ Wrapper around the twisted HTTP client api. + """Wrapper around the twisted HTTP client api. Attributes: agent (twisted.web.client.Agent): The twisted Agent used to send the @@ -87,8 +86,7 @@ def get_json(self, url, args=None): defer.returnValue(json.loads(body)) def _create_put_request(self, url, json_data, headers_dict={}): - """ Wrapper of _create_request to issue a PUT request - """ + """Wrapper of _create_request to issue a PUT request""" if "Content-Type" not in headers_dict: raise defer.error(RuntimeError("Must include Content-Type header for PUTs")) @@ -98,8 +96,7 @@ def _create_put_request(self, url, json_data, headers_dict={}): ) def _create_get_request(self, url, headers_dict={}): - """ Wrapper of _create_request to issue a GET request - """ + """Wrapper of _create_request to issue a GET request""" return self._create_request("GET", url, headers_dict=headers_dict) @defer.inlineCallbacks @@ -127,8 +124,7 @@ def do_request( @defer.inlineCallbacks def _create_request(self, method, url, producer=None, headers_dict={}): - """ Creates and sends a request to the given url - """ + """Creates and sends a request to the given url""" headers_dict["User-Agent"] = ["Synapse Cmd Client"] retries_left = 5 @@ -185,8 +181,7 @@ def stopProducing(self): class _JsonProducer: - """ Used by the twisted http client to create the HTTP body from json - """ + """Used by the twisted http client to create the HTTP body from json""" def __init__(self, jsn): self.data = jsn diff --git a/contrib/experiments/cursesio.py b/contrib/experiments/cursesio.py index 15a22c3a0ebb..cff73650e6fe 100644 --- a/contrib/experiments/cursesio.py +++ b/contrib/experiments/cursesio.py @@ -63,8 +63,7 @@ def print_log(self, text): self.redraw() def redraw(self): - """ method for redisplaying lines - based on internal list of lines """ + """method for redisplaying lines based on internal list of lines""" self.stdscr.clear() self.paintStatus(self.statusText) diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index d4c35ff2fc19..7fbc7d8fc6fd 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -56,7 +56,7 @@ def excpetion_errback(failure): class InputOutput: - """ This is responsible for basic I/O so that a user can interact with + """This is responsible for basic I/O so that a user can interact with the example app. """ @@ -68,8 +68,7 @@ def set_home_server(self, server): self.server = server def on_line(self, line): - """ This is where we process commands. - """ + """This is where we process commands.""" try: m = re.match(r"^join (\S+)$", line) @@ -133,7 +132,7 @@ def emit(self, record): class Room: - """ Used to store (in memory) the current membership state of a room, and + """Used to store (in memory) the current membership state of a room, and which home servers we should send PDUs associated with the room to. """ @@ -148,8 +147,7 @@ def __init__(self, room_name): self.have_got_metadata = False def add_participant(self, participant): - """ Someone has joined the room - """ + """Someone has joined the room""" self.participants.add(participant) self.invited.discard(participant) @@ -160,14 +158,13 @@ def add_participant(self, participant): self.oldest_server = server def add_invited(self, invitee): - """ Someone has been invited to the room - """ + """Someone has been invited to the room""" self.invited.add(invitee) self.servers.add(origin_from_ucid(invitee)) class HomeServer(ReplicationHandler): - """ A very basic home server implentation that allows people to join a + """A very basic home server implentation that allows people to join a room and then invite other people. """ @@ -181,8 +178,7 @@ def __init__(self, server_name, replication_layer, output): self.output = output def on_receive_pdu(self, pdu): - """ We just received a PDU - """ + """We just received a PDU""" pdu_type = pdu.pdu_type if pdu_type == "sy.room.message": @@ -199,23 +195,20 @@ def on_receive_pdu(self, pdu): ) def _on_message(self, pdu): - """ We received a message - """ + """We received a message""" self.output.print_line( "#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"]) ) def _on_join(self, context, joinee): - """ Someone has joined a room, either a remote user or a local user - """ + """Someone has joined a room, either a remote user or a local user""" room = self._get_or_create_room(context) room.add_participant(joinee) self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED")) def _on_invite(self, origin, context, invitee): - """ Someone has been invited - """ + """Someone has been invited""" room = self._get_or_create_room(context) room.add_invited(invitee) @@ -228,8 +221,7 @@ def _on_invite(self, origin, context, invitee): @defer.inlineCallbacks def send_message(self, room_name, sender, body): - """ Send a message to a room! - """ + """Send a message to a room!""" destinations = yield self.get_servers_for_context(room_name) try: @@ -247,8 +239,7 @@ def send_message(self, room_name, sender, body): @defer.inlineCallbacks def join_room(self, room_name, sender, joinee): - """ Join a room! - """ + """Join a room!""" self._on_join(room_name, joinee) destinations = yield self.get_servers_for_context(room_name) @@ -269,8 +260,7 @@ def join_room(self, room_name, sender, joinee): @defer.inlineCallbacks def invite_to_room(self, room_name, sender, invitee): - """ Invite someone to a room! - """ + """Invite someone to a room!""" self._on_invite(self.server_name, room_name, invitee) destinations = yield self.get_servers_for_context(room_name) diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py index b3de468687a6..495fd4e10a91 100644 --- a/contrib/jitsimeetbridge/jitsimeetbridge.py +++ b/contrib/jitsimeetbridge/jitsimeetbridge.py @@ -193,15 +193,12 @@ def advertiseSsrcs(self): time.sleep(7) print("SSRC spammer started") while self.running: - ssrcMsg = ( - "%(nick)s" - % { - "tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), - "nick": self.userId, - "assrc": self.ssrcs["audio"], - "vssrc": self.ssrcs["video"], - } - ) + ssrcMsg = "%(nick)s" % { + "tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), + "nick": self.userId, + "assrc": self.ssrcs["audio"], + "vssrc": self.ssrcs["video"], + } res = self.sendIq(ssrcMsg) print("reply from ssrc announce: ", res) time.sleep(10) diff --git a/docs/code_style.md b/docs/code_style.md index f6c825d7d410..190f8ab2de88 100644 --- a/docs/code_style.md +++ b/docs/code_style.md @@ -8,16 +8,16 @@ errors in code. The necessary tools are detailed below. +First install them with: + + pip install -e ".[lint,mypy]" + - **black** The Synapse codebase uses [black](https://pypi.org/project/black/) as an opinionated code formatter, ensuring all comitted code is properly formatted. - First install `black` with: - - pip install --upgrade black - Have `black` auto-format your code (it shouldn't change any functionality) with: @@ -28,10 +28,6 @@ The necessary tools are detailed below. `flake8` is a code checking tool. We require code to pass `flake8` before being merged into the codebase. - Install `flake8` with: - - pip install --upgrade flake8 flake8-comprehensions - Check all application and test code with: flake8 synapse tests @@ -41,10 +37,6 @@ The necessary tools are detailed below. `isort` ensures imports are nicely formatted, and can suggest and auto-fix issues such as double-importing. - Install `isort` with: - - pip install --upgrade isort - Auto-fix imports with: isort -rc synapse tests diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index f7f18805e49a..18df68305b88 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -87,7 +87,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg. signature = signature.copy_modified( - arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds, + arg_types=arg_types, + arg_names=arg_names, + arg_kinds=arg_kinds, ) return signature diff --git a/setup.py b/setup.py index 99425d52de8c..08ba4eb764b7 100755 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ def exec_file(path_segments): # We pin black so that our tests don't start failing on new releases. CONDITIONAL_REQUIREMENTS["lint"] = [ "isort==5.7.0", - "black==19.10b0", + "black==20.8b1", "flake8-comprehensions", "flake8", ] diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi index 7b9fd079d9b1..0eaef0049860 100644 --- a/stubs/sortedcontainers/sorteddict.pyi +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -89,12 +89,16 @@ class SortedDict(Dict[_KT, _VT]): def __reduce__( self, ) -> Tuple[ - Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]], + Type[SortedDict[_KT, _VT]], + Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]], ]: ... def __repr__(self) -> str: ... def _check(self) -> None: ... def islice( - self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, + self, + start: Optional[int] = ..., + stop: Optional[int] = ..., + reverse=bool, ) -> Iterator[_KT]: ... def bisect_left(self, value: _KT) -> int: ... def bisect_right(self, value: _KT) -> int: ... diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi index 8f6086b3ff38..f80a3a72ce04 100644 --- a/stubs/sortedcontainers/sortedlist.pyi +++ b/stubs/sortedcontainers/sortedlist.pyi @@ -31,7 +31,9 @@ class SortedList(MutableSequence[_T]): DEFAULT_LOAD_FACTOR: int = ... def __init__( - self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ..., + self, + iterable: Optional[Iterable[_T]] = ..., + key: Optional[_Key[_T]] = ..., ): ... # NB: currently mypy does not honour return type, see mypy #3307 @overload @@ -76,10 +78,18 @@ class SortedList(MutableSequence[_T]): def __len__(self) -> int: ... def reverse(self) -> None: ... def islice( - self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, + self, + start: Optional[int] = ..., + stop: Optional[int] = ..., + reverse=bool, ) -> Iterator[_T]: ... def _islice( - self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool, + self, + min_pos: int, + min_idx: int, + max_pos: int, + max_idx: int, + reverse: bool, ) -> Iterator[_T]: ... def irange( self, diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 67ecbd32ffe4..89e62b0e367b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -168,7 +168,7 @@ async def get_user_by_req( rights: str = "access", allow_expired: bool = False, ) -> synapse.types.Requester: - """ Get a registered user's ID. + """Get a registered user's ID. Args: request: An HTTP request with an access_token query parameter. @@ -294,9 +294,12 @@ async def _get_appservice_user_id(self, request): return user_id, app_service async def get_user_by_access_token( - self, token: str, rights: str = "access", allow_expired: bool = False, + self, + token: str, + rights: str = "access", + allow_expired: bool = False, ) -> TokenLookupResult: - """ Validate access token and get user_id from it + """Validate access token and get user_id from it Args: token: The access token to get the user by @@ -489,7 +492,7 @@ def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService: return service async def is_server_admin(self, user: UserID) -> bool: - """ Check if the given user is a local server admin. + """Check if the given user is a local server admin. Args: user: user to check @@ -500,7 +503,10 @@ async def is_server_admin(self, user: UserID) -> bool: return await self.store.is_server_admin(user) def compute_auth_events( - self, event, current_state_ids: StateMap[str], for_verification: bool = False, + self, + event, + current_state_ids: StateMap[str], + for_verification: bool = False, ) -> List[str]: """Given an event and current state return the list of event IDs used to auth an event. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 565a8cd76a59..e6ea95ba33f4 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -128,8 +128,7 @@ class UserTypes: class RelationTypes: - """The types of relations known to this server. - """ + """The types of relations known to this server.""" ANNOTATION = "m.annotation" REPLACE = "m.replace" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index cd6670d0a266..2a789ea3e823 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -390,8 +390,7 @@ def error_dict(self): class LimitExceededError(SynapseError): - """A client has sent too many requests and is being throttled. - """ + """A client has sent too many requests and is being throttled.""" def __init__( self, @@ -408,8 +407,7 @@ def error_dict(self): class RoomKeysVersionError(SynapseError): - """A client has tried to upload to a non-current version of the room_keys store - """ + """A client has tried to upload to a non-current version of the room_keys store""" def __init__(self, current_version: str): """ @@ -426,7 +424,9 @@ class UnsupportedRoomVersionError(SynapseError): def __init__(self, msg: str = "Homeserver does not support this room version"): super().__init__( - code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION, + code=400, + msg=msg, + errcode=Codes.UNSUPPORTED_ROOM_VERSION, ) @@ -461,8 +461,7 @@ def error_dict(self): class PasswordRefusedError(SynapseError): - """A password has been refused, either during password reset/change or registration. - """ + """A password has been refused, either during password reset/change or registration.""" def __init__( self, @@ -470,7 +469,9 @@ def __init__( errcode: str = Codes.WEAK_PASSWORD, ): super().__init__( - code=400, msg=msg, errcode=errcode, + code=400, + msg=msg, + errcode=errcode, ) @@ -493,7 +494,7 @@ def __init__(self, inner_exception, can_retry): def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs): - """ Utility method for constructing an error response for client-server + """Utility method for constructing an error response for client-server interactions. Args: @@ -510,7 +511,7 @@ def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs): class FederationError(RuntimeError): - """ This class is used to inform remote homeservers about erroneous + """This class is used to inform remote homeservers about erroneous PDUs they sent us. FATAL: The remote server could not interpret the source event. diff --git a/synapse/api/presence.py b/synapse/api/presence.py index 18a462f0eeb7..b9a8e294609e 100644 --- a/synapse/api/presence.py +++ b/synapse/api/presence.py @@ -56,8 +56,7 @@ def copy_and_replace(self, **kwargs): @classmethod def default(cls, user_id): - """Returns a default presence state. - """ + """Returns a default presence state.""" return cls( user_id=user_id, state=PresenceState.OFFLINE, diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 9840a9d55b1b..43b1f1e94bdf 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -58,7 +58,7 @@ def register_sighup(func, *args, **kwargs): def start_worker_reactor(appname, config, run_command=reactor.run): - """ Run the reactor in the main process + """Run the reactor in the main process Daemonizes if necessary, and then configures some resources, before starting the reactor. Pulls configuration from the 'worker' settings in 'config'. @@ -93,7 +93,7 @@ def start_reactor( logger, run_command=reactor.run, ): - """ Run the reactor in the main process + """Run the reactor in the main process Daemonizes if necessary, and then configures some resources, before starting the reactor @@ -313,9 +313,7 @@ def run_sighup(*args, **kwargs): refresh_certificate(hs) # Start the tracer - synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa - hs - ) + synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa # It is now safe to start your Synapse. hs.start_listening(listeners) @@ -370,8 +368,7 @@ def setup_sentry(hs): def setup_sdnotify(hs): - """Adds process state hooks to tell systemd what we are up to. - """ + """Adds process state hooks to tell systemd what we are up to.""" # Tell systemd our state, if we're using it. This will silently fail if # we're not using systemd. @@ -405,8 +402,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100): class _LimitedHostnameResolver: - """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups. - """ + """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.""" def __init__(self, resolver, max_dns_requests_in_flight): self._resolver = resolver diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 516f2464b4f4..6526acb2f285 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -421,8 +421,7 @@ def get_currently_syncing_users_for_replication(self) -> Iterable[str]: ] async def set_state(self, target_user, state, ignore_status_msg=False): - """Set the presence state of the user. - """ + """Set the presence state of the user.""" presence = state["presence"] valid_presence = ( diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 3944780a4235..0bfc5e445f5f 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -166,7 +166,10 @@ async def _matches_user( @cached(num_args=1, cache_context=True) async def matches_user_in_member_list( - self, room_id: str, store: "DataStore", cache_context: _CacheContext, + self, + room_id: str, + store: "DataStore", + cache_context: _CacheContext, ) -> bool: """Check if this service is interested a room based upon it's membership diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 11aee50f7a0d..93c2aabcca6c 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -227,7 +227,9 @@ async def push_bulk( try: await self.put_json( - uri=uri, json_body=body, args={"access_token": service.hs_token}, + uri=uri, + json_body=body, + args={"access_token": service.hs_token}, ) sent_transactions_counter.labels(service.id).inc() sent_events_counter.labels(service.id).inc(len(events)) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 58291afc2231..366c476f807a 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -68,7 +68,7 @@ class ApplicationServiceScheduler: - """ Public facing API for this module. Does the required DI to tie the + """Public facing API for this module. Does the required DI to tie the components together. This also serves as the "event_pool", which in this case is a simple array. """ diff --git a/synapse/config/_base.py b/synapse/config/_base.py index a851f8801d7e..97399eb9ba94 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -224,7 +224,9 @@ def read_template(self, filename: str) -> jinja2.Template: return self.read_templates([filename])[0] def read_templates( - self, filenames: List[str], custom_template_directory: Optional[str] = None, + self, + filenames: List[str], + custom_template_directory: Optional[str] = None, ) -> List[jinja2.Template]: """Load a list of template files from disk using the given variables. @@ -264,7 +266,10 @@ def read_templates( # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(search_directories) - env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),) + env = jinja2.Environment( + loader=loader, + autoescape=jinja2.select_autoescape(), + ) # Update the environment with our custom filters env.filters.update( @@ -825,8 +830,7 @@ class ShardedWorkerHandlingConfig: instances = attr.ib(type=List[str]) def should_handle(self, instance_name: str, key: str) -> bool: - """Whether this instance is responsible for handling the given key. - """ + """Whether this instance is responsible for handling the given key.""" # If multiple instances are not defined we always return true if not self.instances or len(self.instances) == 1: return True diff --git a/synapse/config/auth.py b/synapse/config/auth.py index 1f4c090cde89..7fa64b821a1d 100644 --- a/synapse/config/auth.py +++ b/synapse/config/auth.py @@ -18,8 +18,7 @@ class AuthConfig(Config): - """Password and login configuration - """ + """Password and login configuration""" section = "auth" diff --git a/synapse/config/database.py b/synapse/config/database.py index 8a18a9ca2a7b..e7889b9c20a3 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -207,8 +207,7 @@ def add_arguments(parser): ) def get_single_database(self) -> DatabaseConnectionConfig: - """Returns the database if there is only one, useful for e.g. tests - """ + """Returns the database if there is only one, useful for e.g. tests""" if not self.databases: raise Exception("More than one database exists") diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index d4328c46b9b6..52505ac5d2b5 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -289,7 +289,8 @@ def read_config(self, config, **kwargs): self.email_notif_template_html, self.email_notif_template_text, ) = self.read_templates( - [notif_template_html, notif_template_text], template_dir, + [notif_template_html, notif_template_text], + template_dir, ) self.email_notif_for_new_users = email_config.get( @@ -311,7 +312,8 @@ def read_config(self, config, **kwargs): self.account_validity_template_html, self.account_validity_template_text, ) = self.read_templates( - [expiry_template_html, expiry_template_text], template_dir, + [expiry_template_html, expiry_template_text], + template_dir, ) subjects_config = email_config.get("subjects", {}) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 4df3f93c1cd8..e56cf846f516 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -162,7 +162,10 @@ def add_arguments(parser): ) logging_group.add_argument( - "-f", "--log-file", dest="log_file", help=argparse.SUPPRESS, + "-f", + "--log-file", + dest="log_file", + help=argparse.SUPPRESS, ) def generate_files(self, config, config_dir_path): diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index d081f36fa5d3..a27594befc87 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -355,9 +355,10 @@ def _parse_oidc_config_dict( ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) ump_config.setdefault("config", {}) - (user_mapping_provider_class, user_mapping_provider_config,) = load_module( - ump_config, config_path + ("user_mapping_provider",) - ) + ( + user_mapping_provider_class, + user_mapping_provider_config, + ) = load_module(ump_config, config_path + ("user_mapping_provider",)) # Ensure loaded user mapping module has defined all necessary methods required_methods = [ @@ -372,7 +373,11 @@ def _parse_oidc_config_dict( if missing_methods: raise ConfigError( "Class %s is missing required " - "methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),), + "methods: %s" + % ( + user_mapping_provider_class, + ", ".join(missing_methods), + ), config_path + ("user_mapping_provider", "module"), ) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index fcaea8fb93c4..52849c325633 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -52,7 +52,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): - """ Takes a list of dictionaries with "width", "height", and "method" keys + """Takes a list of dictionaries with "width", "height", and "method" keys and creates a map from image media types to the thumbnail size, thumbnailing method, and thumbnail media type to precalculate diff --git a/synapse/config/server.py b/synapse/config/server.py index a635b8a7dc69..6f3325ff81c4 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -52,7 +52,12 @@ def _6to4(network: IPNetwork) -> IPNetwork: hex_network = hex(network.first)[2:] hex_network = ("0" * (8 - len(hex_network))) + hex_network return IPNetwork( - "2002:%s:%s::/%d" % (hex_network[:4], hex_network[4:], 16 + network.prefixlen,) + "2002:%s:%s::/%d" + % ( + hex_network[:4], + hex_network[4:], + 16 + network.prefixlen, + ) ) @@ -254,7 +259,8 @@ def read_config(self, config, **kwargs): # Whether to require sharing a room with a user to retrieve their # profile data self.limit_profile_requests_to_users_who_share_rooms = config.get( - "limit_profile_requests_to_users_who_share_rooms", False, + "limit_profile_requests_to_users_who_share_rooms", + False, ) if "restrict_public_rooms_to_local_users" in config and ( @@ -614,7 +620,9 @@ class LimitRemoteRoomsConfig: if manhole: self.listeners.append( ListenerConfig( - port=manhole, bind_addresses=["127.0.0.1"], type="manhole", + port=manhole, + bind_addresses=["127.0.0.1"], + type="manhole", ) ) @@ -650,7 +658,8 @@ class LimitRemoteRoomsConfig: # and letting the client know which email address is bound to an account and # which one isn't. self.request_token_inhibit_3pid_errors = config.get( - "request_token_inhibit_3pid_errors", False, + "request_token_inhibit_3pid_errors", + False, ) # List of users trialing the new experimental default push rules. This setting is diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 07ba217f89b8..243cc681e88d 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -35,8 +35,7 @@ class SsoAttributeRequirement: class SSOConfig(Config): - """SSO Configuration - """ + """SSO Configuration""" section = "sso" diff --git a/synapse/config/workers.py b/synapse/config/workers.py index f10e33f7b8d8..7a0ca16da8b7 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -33,8 +33,7 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: @attr.s class InstanceLocationConfig: - """The host and port to talk to an instance via HTTP replication. - """ + """The host and port to talk to an instance via HTTP replication.""" host = attr.ib(type=str) port = attr.ib(type=int) @@ -54,13 +53,19 @@ class WriterLocations: ) typing = attr.ib(default="master", type=str) to_device = attr.ib( - default=["master"], type=List[str], converter=_instance_to_list_converter, + default=["master"], + type=List[str], + converter=_instance_to_list_converter, ) account_data = attr.ib( - default=["master"], type=List[str], converter=_instance_to_list_converter, + default=["master"], + type=List[str], + converter=_instance_to_list_converter, ) receipts = attr.ib( - default=["master"], type=List[str], converter=_instance_to_list_converter, + default=["master"], + type=List[str], + converter=_instance_to_list_converter, ) @@ -107,7 +112,9 @@ def read_config(self, config, **kwargs): if manhole: self.worker_listeners.append( ListenerConfig( - port=manhole, bind_addresses=["127.0.0.1"], type="manhole", + port=manhole, + bind_addresses=["127.0.0.1"], + type="manhole", ) ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 56f8dc9caf9e..91ad5b3d3cf0 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -42,7 +42,7 @@ def check( do_sig_check: bool = True, do_size_check: bool = True, ) -> None: - """ Checks if this event is correctly authed. + """Checks if this event is correctly authed. Args: room_version_obj: the version of the room @@ -423,7 +423,9 @@ def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool: def check_redaction( - room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase], + room_version_obj: RoomVersion, + event: EventBase, + auth_events: StateMap[EventBase], ) -> bool: """Check whether the event sender is allowed to redact the target event. @@ -459,7 +461,9 @@ def check_redaction( def _check_power_levels( - room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase], + room_version_obj: RoomVersion, + event: EventBase, + auth_events: StateMap[EventBase], ) -> None: user_list = event.content.get("users", {}) # Validate users diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 07df258e6eed..c1c0426f6ea0 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -98,7 +98,9 @@ def is_state(self): return self._state_key is not None async def build( - self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]], + self, + prev_event_ids: List[str], + auth_event_ids: Optional[List[str]], ) -> EventBase: """Transform into a fully signed and hashed event diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index afecafe15c3e..7295df74fed6 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -341,8 +341,7 @@ def _encode_state_dict(state_dict): def _decode_state_dict(input): - """Decodes a state dict encoded using `_encode_state_dict` above - """ + """Decodes a state dict encoded using `_encode_state_dict` above""" if input is None: return None diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 77fbd3f68a59..02bce8b5c914 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -40,7 +40,8 @@ def __init__(self, hs): if module is not None: self.third_party_rules = module( - config=config, module_api=hs.get_module_api(), + config=config, + module_api=hs.get_module_api(), ) async def check_event_allowed( diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 9c22e3381378..7ca5c9940a3a 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -34,7 +34,7 @@ def prune_event(event: EventBase) -> EventBase: - """ Returns a pruned version of the given event, which removes all keys we + """Returns a pruned version of the given event, which removes all keys we don't know about or think could potentially be dodgy. This is used when we "redact" an event. We want to remove all fields that diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 40e14512017a..bee81fc01946 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -750,7 +750,11 @@ async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict: return resp[1] async def send_invite( - self, destination: str, room_id: str, event_id: str, pdu: EventBase, + self, + destination: str, + room_id: str, + event_id: str, + pdu: EventBase, ) -> EventBase: room_version = await self.store.get_room_version(room_id) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 171d25c9454a..8d4bb621e739 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -85,7 +85,8 @@ ) pdu_process_time = Histogram( - "synapse_federation_server_pdu_process_time", "Time taken to process an event", + "synapse_federation_server_pdu_process_time", + "Time taken to process an event", ) @@ -204,7 +205,7 @@ async def _on_incoming_transaction_inner( async def _handle_incoming_transaction( self, origin: str, transaction: Transaction, request_time: int ) -> Tuple[int, Dict[str, Any]]: - """ Process an incoming transaction and return the HTTP response + """Process an incoming transaction and return the HTTP response Args: origin: the server making the request @@ -373,8 +374,7 @@ async def process_pdus_for_room(room_id: str): return pdu_results async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): - """Process the EDUs in a received transaction. - """ + """Process the EDUs in a received transaction.""" async def _process_edu(edu_dict): received_edus_counter.inc() @@ -437,7 +437,10 @@ async def on_state_ids_request( raise AuthError(403, "Host not in room.") resp = await self._state_ids_resp_cache.wrap( - (room_id, event_id), self._on_state_ids_request_compute, room_id, event_id, + (room_id, event_id), + self._on_state_ids_request_compute, + room_id, + event_id, ) return 200, resp @@ -679,7 +682,7 @@ def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction: ) async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: - """ Process a PDU received in a federation /send/ transaction. + """Process a PDU received in a federation /send/ transaction. If the event is invalid, then this method throws a FederationError. (The error will then be logged and sent back to the sender (which @@ -906,13 +909,11 @@ def register_query_handler( self.query_handlers[query_type] = handler def register_instance_for_edu(self, edu_type: str, instance_name: str): - """Register that the EDU handler is on a different instance than master. - """ + """Register that the EDU handler is on a different instance than master.""" self._edu_type_to_instance[edu_type] = [instance_name] def register_instances_for_edu(self, edu_type: str, instance_names: List[str]): - """Register that the EDU handler is on multiple instances. - """ + """Register that the EDU handler is on multiple instances.""" self._edu_type_to_instance[edu_type] = instance_names async def on_edu(self, edu_type: str, origin: str, content: dict): diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 079e2b2fe0ad..ce5fc758f0e6 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -30,8 +30,7 @@ class TransactionActions: - """ Defines persistence actions that relate to handling Transactions. - """ + """Defines persistence actions that relate to handling Transactions.""" def __init__(self, datastore): self.store = datastore @@ -57,8 +56,7 @@ async def have_responded( async def set_response( self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: - """Persist how we responded to a transaction. - """ + """Persist how we responded to a transaction.""" transaction_id = transaction.transaction_id # type: ignore if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 5f1bf492c1d0..3e993b428b71 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -468,8 +468,7 @@ def add_to_buffer(self, buff): class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu - """Streams EDUs that don't have keys. See KeyedEduRow - """ + """Streams EDUs that don't have keys. See KeyedEduRow""" TypeId = "e" @@ -519,7 +518,10 @@ def process_rows_for_federation(transaction_queue, rows): # them into the appropriate collection and then send them off. buff = ParsedFederationStreamData( - presence=[], presence_destinations=[], keyed_edus={}, edus={}, + presence=[], + presence_destinations=[], + keyed_edus={}, + edus={}, ) # Parse the rows in the stream and add to the buffer diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 643b26ae6d2d..97fc4d0a82b4 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -328,7 +328,9 @@ async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None: # to allow us to perform catch-up later on if the remote is unreachable # for a while. await self.store.store_destination_rooms_entries( - destinations, pdu.room_id, pdu.internal_metadata.stream_ordering, + destinations, + pdu.room_id, + pdu.internal_metadata.stream_ordering, ) for destination in destinations: @@ -475,7 +477,7 @@ def send_presence_to_destinations( self, states: List[UserPresenceState], destinations: List[str] ) -> None: """Send the given presence states to the given destinations. - destinations (list[str]) + destinations (list[str]) """ if not states or not self.hs.config.use_presence: @@ -616,8 +618,8 @@ async def _wake_destinations_needing_catchup(self): last_processed = None # type: Optional[str] while True: - destinations_to_wake = await self.store.get_catch_up_outstanding_destinations( - last_processed + destinations_to_wake = ( + await self.store.get_catch_up_outstanding_destinations(last_processed) ) if not destinations_to_wake: diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index db8e456fe8dc..deb519f3efb6 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -85,7 +85,8 @@ def __init__( # processing. We have a guard in `attempt_new_transaction` that # ensure we don't start sending stuff. logger.error( - "Create a per destination queue for %s on wrong worker", destination, + "Create a per destination queue for %s on wrong worker", + destination, ) self._should_send_on_this_instance = False @@ -440,8 +441,10 @@ async def _catch_up_transmission_loop(self) -> None: if first_catch_up_check: # first catchup so get last_successful_stream_ordering from database - self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering( - self._destination + self._last_successful_stream_ordering = ( + await self._store.get_destination_last_successful_stream_ordering( + self._destination + ) ) if self._last_successful_stream_ordering is None: @@ -457,7 +460,8 @@ async def _catch_up_transmission_loop(self) -> None: # get at most 50 catchup room/PDUs while True: event_ids = await self._store.get_catch_up_room_event_ids( - self._destination, self._last_successful_stream_ordering, + self._destination, + self._last_successful_stream_ordering, ) if not event_ids: diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 3e07f925e00b..763aff296c87 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -65,7 +65,10 @@ def __init__(self, hs: "synapse.server.HomeServer"): @measure_func("_send_new_transaction") async def send_new_transaction( - self, destination: str, pdus: List[EventBase], edus: List[Edu], + self, + destination: str, + pdus: List[EventBase], + edus: List[Edu], ) -> bool: """ Args: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index abe9168c7866..10c4747f9749 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -39,7 +39,7 @@ def __init__(self, hs): @log_function def get_room_state_ids(self, destination, room_id, event_id): - """ Requests all state for a given room from the given server at the + """Requests all state for a given room from the given server at the given event. Returns the state's event_id's Args: @@ -63,7 +63,7 @@ def get_room_state_ids(self, destination, room_id, event_id): @log_function def get_event(self, destination, event_id, timeout=None): - """ Requests the pdu with give id and origin from the given server. + """Requests the pdu with give id and origin from the given server. Args: destination (str): The host name of the remote homeserver we want @@ -84,7 +84,7 @@ def get_event(self, destination, event_id, timeout=None): @log_function def backfill(self, destination, room_id, event_tuples, limit): - """ Requests `limit` previous PDUs in a given context before list of + """Requests `limit` previous PDUs in a given context before list of PDUs. Args: @@ -118,7 +118,7 @@ def backfill(self, destination, room_id, event_tuples, limit): @log_function async def send_transaction(self, transaction, json_data_callback=None): - """ Sends the given Transaction to its destination + """Sends the given Transaction to its destination Args: transaction (Transaction) @@ -551,8 +551,7 @@ async def get_missing_events( @log_function def get_group_profile(self, destination, group_id, requester_user_id): - """Get a group profile - """ + """Get a group profile""" path = _create_v1_path("/groups/%s/profile", group_id) return self.client.get_json( @@ -584,8 +583,7 @@ def update_group_profile(self, destination, group_id, requester_user_id, content @log_function def get_group_summary(self, destination, group_id, requester_user_id): - """Get a group summary - """ + """Get a group summary""" path = _create_v1_path("/groups/%s/summary", group_id) return self.client.get_json( @@ -597,8 +595,7 @@ def get_group_summary(self, destination, group_id, requester_user_id): @log_function def get_rooms_in_group(self, destination, group_id, requester_user_id): - """Get all rooms in a group - """ + """Get all rooms in a group""" path = _create_v1_path("/groups/%s/rooms", group_id) return self.client.get_json( @@ -611,8 +608,7 @@ def get_rooms_in_group(self, destination, group_id, requester_user_id): def add_room_to_group( self, destination, group_id, requester_user_id, room_id, content ): - """Add a room to a group - """ + """Add a room to a group""" path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) return self.client.post_json( @@ -626,8 +622,7 @@ def add_room_to_group( def update_room_in_group( self, destination, group_id, requester_user_id, room_id, config_key, content ): - """Update room in group - """ + """Update room in group""" path = _create_v1_path( "/groups/%s/room/%s/config/%s", group_id, room_id, config_key ) @@ -641,8 +636,7 @@ def update_room_in_group( ) def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): - """Remove a room from a group - """ + """Remove a room from a group""" path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) return self.client.delete_json( @@ -654,8 +648,7 @@ def remove_room_from_group(self, destination, group_id, requester_user_id, room_ @log_function def get_users_in_group(self, destination, group_id, requester_user_id): - """Get users in a group - """ + """Get users in a group""" path = _create_v1_path("/groups/%s/users", group_id) return self.client.get_json( @@ -667,8 +660,7 @@ def get_users_in_group(self, destination, group_id, requester_user_id): @log_function def get_invited_users_in_group(self, destination, group_id, requester_user_id): - """Get users that have been invited to a group - """ + """Get users that have been invited to a group""" path = _create_v1_path("/groups/%s/invited_users", group_id) return self.client.get_json( @@ -680,8 +672,7 @@ def get_invited_users_in_group(self, destination, group_id, requester_user_id): @log_function def accept_group_invite(self, destination, group_id, user_id, content): - """Accept a group invite - """ + """Accept a group invite""" path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id) return self.client.post_json( @@ -690,8 +681,7 @@ def accept_group_invite(self, destination, group_id, user_id, content): @log_function def join_group(self, destination, group_id, user_id, content): - """Attempts to join a group - """ + """Attempts to join a group""" path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) return self.client.post_json( @@ -702,8 +692,7 @@ def join_group(self, destination, group_id, user_id, content): def invite_to_group( self, destination, group_id, user_id, requester_user_id, content ): - """Invite a user to a group - """ + """Invite a user to a group""" path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) return self.client.post_json( @@ -730,8 +719,7 @@ def invite_to_group_notification(self, destination, group_id, user_id, content): def remove_user_from_group( self, destination, group_id, requester_user_id, user_id, content ): - """Remove a user from a group - """ + """Remove a user from a group""" path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) return self.client.post_json( @@ -772,8 +760,7 @@ def renew_group_attestation(self, destination, group_id, user_id, content): def update_group_summary_room( self, destination, group_id, user_id, room_id, category_id, content ): - """Update a room entry in a group summary - """ + """Update a room entry in a group summary""" if category_id: path = _create_v1_path( "/groups/%s/summary/categories/%s/rooms/%s", @@ -796,8 +783,7 @@ def update_group_summary_room( def delete_group_summary_room( self, destination, group_id, user_id, room_id, category_id ): - """Delete a room entry in a group summary - """ + """Delete a room entry in a group summary""" if category_id: path = _create_v1_path( "/groups/%s/summary/categories/%s/rooms/%s", @@ -817,8 +803,7 @@ def delete_group_summary_room( @log_function def get_group_categories(self, destination, group_id, requester_user_id): - """Get all categories in a group - """ + """Get all categories in a group""" path = _create_v1_path("/groups/%s/categories", group_id) return self.client.get_json( @@ -830,8 +815,7 @@ def get_group_categories(self, destination, group_id, requester_user_id): @log_function def get_group_category(self, destination, group_id, requester_user_id, category_id): - """Get category info in a group - """ + """Get category info in a group""" path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.get_json( @@ -845,8 +829,7 @@ def get_group_category(self, destination, group_id, requester_user_id, category_ def update_group_category( self, destination, group_id, requester_user_id, category_id, content ): - """Update a category in a group - """ + """Update a category in a group""" path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.post_json( @@ -861,8 +844,7 @@ def update_group_category( def delete_group_category( self, destination, group_id, requester_user_id, category_id ): - """Delete a category in a group - """ + """Delete a category in a group""" path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.delete_json( @@ -874,8 +856,7 @@ def delete_group_category( @log_function def get_group_roles(self, destination, group_id, requester_user_id): - """Get all roles in a group - """ + """Get all roles in a group""" path = _create_v1_path("/groups/%s/roles", group_id) return self.client.get_json( @@ -887,8 +868,7 @@ def get_group_roles(self, destination, group_id, requester_user_id): @log_function def get_group_role(self, destination, group_id, requester_user_id, role_id): - """Get a roles info - """ + """Get a roles info""" path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.get_json( @@ -902,8 +882,7 @@ def get_group_role(self, destination, group_id, requester_user_id, role_id): def update_group_role( self, destination, group_id, requester_user_id, role_id, content ): - """Update a role in a group - """ + """Update a role in a group""" path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.post_json( @@ -916,8 +895,7 @@ def update_group_role( @log_function def delete_group_role(self, destination, group_id, requester_user_id, role_id): - """Delete a role in a group - """ + """Delete a role in a group""" path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.delete_json( @@ -931,8 +909,7 @@ def delete_group_role(self, destination, group_id, requester_user_id, role_id): def update_group_summary_user( self, destination, group_id, requester_user_id, user_id, role_id, content ): - """Update a users entry in a group - """ + """Update a users entry in a group""" if role_id: path = _create_v1_path( "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id @@ -950,8 +927,7 @@ def update_group_summary_user( @log_function def set_group_join_policy(self, destination, group_id, requester_user_id, content): - """Sets the join policy for a group - """ + """Sets the join policy for a group""" path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id) return self.client.put_json( @@ -966,8 +942,7 @@ def set_group_join_policy(self, destination, group_id, requester_user_id, conten def delete_group_summary_user( self, destination, group_id, requester_user_id, user_id, role_id ): - """Delete a users entry in a group - """ + """Delete a users entry in a group""" if role_id: path = _create_v1_path( "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id @@ -983,8 +958,7 @@ def delete_group_summary_user( ) def bulk_get_publicised_groups(self, destination, user_ids): - """Get the groups a list of users are publicising - """ + """Get the groups a list of users are publicising""" path = _create_v1_path("/get_groups_publicised") diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 95c64510a927..0b30efe9935d 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -364,7 +364,10 @@ def register(self, server): continue server.register_paths( - method, (pattern,), self._wrap(code), self.__class__.__name__, + method, + (pattern,), + self._wrap(code), + self.__class__.__name__, ) @@ -381,7 +384,7 @@ def __init__(self, handler, server_name, **kwargs): # This is when someone is trying to send us a bunch of data. async def on_PUT(self, origin, content, query, transaction_id): - """ Called on PUT /send// + """Called on PUT /send// Args: request (twisted.web.http.Request): The HTTP request. @@ -855,8 +858,7 @@ async def on_GET(self, origin, content, query): class FederationGroupsProfileServlet(BaseFederationServlet): - """Get/set the basic profile of a group on behalf of a user - """ + """Get/set the basic profile of a group on behalf of a user""" PATH = "/groups/(?P[^/]*)/profile" @@ -895,8 +897,7 @@ async def on_GET(self, origin, content, query, group_id): class FederationGroupsRoomsServlet(BaseFederationServlet): - """Get the rooms in a group on behalf of a user - """ + """Get the rooms in a group on behalf of a user""" PATH = "/groups/(?P[^/]*)/rooms" @@ -911,8 +912,7 @@ async def on_GET(self, origin, content, query, group_id): class FederationGroupsAddRoomsServlet(BaseFederationServlet): - """Add/remove room from group - """ + """Add/remove room from group""" PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" @@ -940,8 +940,7 @@ async def on_DELETE(self, origin, content, query, group_id, room_id): class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): - """Update room config in group - """ + """Update room config in group""" PATH = ( "/groups/(?P[^/]*)/room/(?P[^/]*)" @@ -961,8 +960,7 @@ async def on_POST(self, origin, content, query, group_id, room_id, config_key): class FederationGroupsUsersServlet(BaseFederationServlet): - """Get the users in a group on behalf of a user - """ + """Get the users in a group on behalf of a user""" PATH = "/groups/(?P[^/]*)/users" @@ -977,8 +975,7 @@ async def on_GET(self, origin, content, query, group_id): class FederationGroupsInvitedUsersServlet(BaseFederationServlet): - """Get the users that have been invited to a group - """ + """Get the users that have been invited to a group""" PATH = "/groups/(?P[^/]*)/invited_users" @@ -995,8 +992,7 @@ async def on_GET(self, origin, content, query, group_id): class FederationGroupsInviteServlet(BaseFederationServlet): - """Ask a group server to invite someone to the group - """ + """Ask a group server to invite someone to the group""" PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" @@ -1013,8 +1009,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): class FederationGroupsAcceptInviteServlet(BaseFederationServlet): - """Accept an invitation from the group server - """ + """Accept an invitation from the group server""" PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" @@ -1028,8 +1023,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): class FederationGroupsJoinServlet(BaseFederationServlet): - """Attempt to join a group - """ + """Attempt to join a group""" PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" @@ -1043,8 +1037,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): class FederationGroupsRemoveUserServlet(BaseFederationServlet): - """Leave or kick a user from the group - """ + """Leave or kick a user from the group""" PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" @@ -1061,8 +1054,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): class FederationGroupsLocalInviteServlet(BaseFederationServlet): - """A group server has invited a local user - """ + """A group server has invited a local user""" PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" @@ -1076,8 +1068,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): - """A group server has removed a local user - """ + """A group server has removed a local user""" PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" @@ -1093,8 +1084,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): - """A group or user's server renews their attestation - """ + """A group or user's server renews their attestation""" PATH = "/groups/(?P[^/]*)/renew_attestation/(?P[^/]*)" @@ -1156,8 +1146,7 @@ async def on_DELETE(self, origin, content, query, group_id, category_id, room_id class FederationGroupsCategoriesServlet(BaseFederationServlet): - """Get all categories for a group - """ + """Get all categories for a group""" PATH = "/groups/(?P[^/]*)/categories/?" @@ -1172,8 +1161,7 @@ async def on_GET(self, origin, content, query, group_id): class FederationGroupsCategoryServlet(BaseFederationServlet): - """Add/remove/get a category in a group - """ + """Add/remove/get a category in a group""" PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" @@ -1218,8 +1206,7 @@ async def on_DELETE(self, origin, content, query, group_id, category_id): class FederationGroupsRolesServlet(BaseFederationServlet): - """Get roles in a group - """ + """Get roles in a group""" PATH = "/groups/(?P[^/]*)/roles/?" @@ -1234,8 +1221,7 @@ async def on_GET(self, origin, content, query, group_id): class FederationGroupsRoleServlet(BaseFederationServlet): - """Add/remove/get a role in a group - """ + """Add/remove/get a role in a group""" PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" @@ -1325,8 +1311,7 @@ async def on_DELETE(self, origin, content, query, group_id, role_id, user_id): class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): - """Get roles in a group - """ + """Get roles in a group""" PATH = "/get_groups_publicised" @@ -1339,8 +1324,7 @@ async def on_POST(self, origin, content, query): class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): - """Sets whether a group is joinable without an invite or knock - """ + """Sets whether a group is joinable without an invite or knock""" PATH = "/groups/(?P[^/]*)/settings/m.join_policy" diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 64d98fc8f675..b662c4262120 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -29,7 +29,7 @@ @attr.s(slots=True) class Edu(JsonEncodedObject): - """ An Edu represents a piece of data sent from one homeserver to another. + """An Edu represents a piece of data sent from one homeserver to another. In comparison to Pdus, Edus are not persisted for a long time on disk, are not meaningful beyond a given pair of homeservers, and don't have an @@ -63,7 +63,7 @@ def strip_context(self): class Transaction(JsonEncodedObject): - """ A transaction is a list of Pdus and Edus to be sent to a remote home + """A transaction is a list of Pdus and Edus to be sent to a remote home server with some extra metadata. Example transaction:: @@ -99,7 +99,7 @@ class Transaction(JsonEncodedObject): ] def __init__(self, transaction_id=None, pdus=[], **kwargs): - """ If we include a list of pdus then we decode then as PDU's + """If we include a list of pdus then we decode then as PDU's automatically. """ @@ -111,7 +111,7 @@ def __init__(self, transaction_id=None, pdus=[], **kwargs): @staticmethod def create_new(pdus, **kwargs): - """ Used to create a new transaction. Will auto fill out + """Used to create a new transaction. Will auto fill out transaction_id and origin_server_ts keys. """ if "origin_server_ts" not in kwargs: diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 41cf07cc881b..db69bb7c06bf 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -61,8 +61,7 @@ class GroupAttestationSigning: - """Creates and verifies group attestations. - """ + """Creates and verifies group attestations.""" def __init__(self, hs): self.keyring = hs.get_keyring() @@ -125,8 +124,7 @@ def create_attestation(self, group_id, user_id): class GroupAttestionRenewer: - """Responsible for sending and receiving attestation updates. - """ + """Responsible for sending and receiving attestation updates.""" def __init__(self, hs): self.clock = hs.get_clock() @@ -142,8 +140,7 @@ def __init__(self, hs): ) async def on_renew_attestation(self, group_id, user_id, content): - """When a remote updates an attestation - """ + """When a remote updates an attestation""" attestation = content["attestation"] if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): @@ -161,8 +158,7 @@ def _start_renew_attestations(self): return run_as_background_process("renew_attestations", self._renew_attestations) async def _renew_attestations(self): - """Called periodically to check if we need to update any of our attestations - """ + """Called periodically to check if we need to update any of our attestations""" now = self.clock.time_msec() diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 76bf52ea235b..4e8695aa7cb8 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -165,16 +165,14 @@ async def get_group_summary(self, group_id, requester_user_id): } async def get_group_categories(self, group_id, requester_user_id): - """Get all categories in a group (as seen by user) - """ + """Get all categories in a group (as seen by user)""" await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) categories = await self.store.get_group_categories(group_id=group_id) return {"categories": categories} async def get_group_category(self, group_id, requester_user_id, category_id): - """Get a specific category in a group (as seen by user) - """ + """Get a specific category in a group (as seen by user)""" await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) res = await self.store.get_group_category( @@ -186,24 +184,21 @@ async def get_group_category(self, group_id, requester_user_id, category_id): return res async def get_group_roles(self, group_id, requester_user_id): - """Get all roles in a group (as seen by user) - """ + """Get all roles in a group (as seen by user)""" await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) roles = await self.store.get_group_roles(group_id=group_id) return {"roles": roles} async def get_group_role(self, group_id, requester_user_id, role_id): - """Get a specific role in a group (as seen by user) - """ + """Get a specific role in a group (as seen by user)""" await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) res = await self.store.get_group_role(group_id=group_id, role_id=role_id) return res async def get_group_profile(self, group_id, requester_user_id): - """Get the group profile as seen by requester_user_id - """ + """Get the group profile as seen by requester_user_id""" await self.check_group_is_ours(group_id, requester_user_id) @@ -350,8 +345,7 @@ def __init__(self, hs): async def update_group_summary_room( self, group_id, requester_user_id, room_id, category_id, content ): - """Add/update a room to the group summary - """ + """Add/update a room to the group summary""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -375,8 +369,7 @@ async def update_group_summary_room( async def delete_group_summary_room( self, group_id, requester_user_id, room_id, category_id ): - """Remove a room from the summary - """ + """Remove a room from the summary""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -409,8 +402,7 @@ async def set_group_join_policy(self, group_id, requester_user_id, content): async def update_group_category( self, group_id, requester_user_id, category_id, content ): - """Add/Update a group category - """ + """Add/Update a group category""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -428,8 +420,7 @@ async def update_group_category( return {} async def delete_group_category(self, group_id, requester_user_id, category_id): - """Delete a group category - """ + """Delete a group category""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -441,8 +432,7 @@ async def delete_group_category(self, group_id, requester_user_id, category_id): return {} async def update_group_role(self, group_id, requester_user_id, role_id, content): - """Add/update a role in a group - """ + """Add/update a role in a group""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -458,8 +448,7 @@ async def update_group_role(self, group_id, requester_user_id, role_id, content) return {} async def delete_group_role(self, group_id, requester_user_id, role_id): - """Remove role from group - """ + """Remove role from group""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -471,8 +460,7 @@ async def delete_group_role(self, group_id, requester_user_id, role_id): async def update_group_summary_user( self, group_id, requester_user_id, user_id, role_id, content ): - """Add/update a users entry in the group summary - """ + """Add/update a users entry in the group summary""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -494,8 +482,7 @@ async def update_group_summary_user( async def delete_group_summary_user( self, group_id, requester_user_id, user_id, role_id ): - """Remove a user from the group summary - """ + """Remove a user from the group summary""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -507,8 +494,7 @@ async def delete_group_summary_user( return {} async def update_group_profile(self, group_id, requester_user_id, content): - """Update the group profile - """ + """Update the group profile""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -539,8 +525,7 @@ async def update_group_profile(self, group_id, requester_user_id, content): await self.store.update_group_profile(group_id, profile) async def add_room_to_group(self, group_id, requester_user_id, room_id, content): - """Add room to group - """ + """Add room to group""" RoomID.from_string(room_id) # Ensure valid room id await self.check_group_is_ours( @@ -556,8 +541,7 @@ async def add_room_to_group(self, group_id, requester_user_id, room_id, content) async def update_room_in_group( self, group_id, requester_user_id, room_id, config_key, content ): - """Update room in group - """ + """Update room in group""" RoomID.from_string(room_id) # Ensure valid room id await self.check_group_is_ours( @@ -576,8 +560,7 @@ async def update_room_in_group( return {} async def remove_room_from_group(self, group_id, requester_user_id, room_id): - """Remove room from group - """ + """Remove room from group""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -587,8 +570,7 @@ async def remove_room_from_group(self, group_id, requester_user_id, room_id): return {} async def invite_to_group(self, group_id, user_id, requester_user_id, content): - """Invite user to group - """ + """Invite user to group""" group = await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id @@ -724,8 +706,7 @@ async def join_group(self, group_id, requester_user_id, content): return {"state": "join", "attestation": local_attestation} async def knock(self, group_id, requester_user_id, content): - """A user requests becoming a member of the group - """ + """A user requests becoming a member of the group""" await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) raise NotImplementedError() @@ -918,8 +899,7 @@ async def _kick_user_from_group(user_id): def _parse_join_policy_from_contents(content): - """Given a content for a request, return the specified join policy or None - """ + """Given a content for a request, return the specified join policy or None""" join_policy_dict = content.get("m.join_policy") if join_policy_dict: @@ -929,8 +909,7 @@ def _parse_join_policy_from_contents(content): def _parse_join_policy_dict(join_policy_dict): - """Given a dict for the "m.join_policy" config return the join policy specified - """ + """Given a dict for the "m.join_policy" config return the join policy specified""" join_policy_type = join_policy_dict.get("type") if not join_policy_type: return "invite" diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 37e63da9b12a..db68c94c50a7 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -203,13 +203,11 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> class ExfiltrationWriter(metaclass=abc.ABCMeta): - """Interface used to specify how to write exported data. - """ + """Interface used to specify how to write exported data.""" @abc.abstractmethod def write_events(self, room_id: str, events: List[EventBase]) -> None: - """Write a batch of events for a room. - """ + """Write a batch of events for a room.""" raise NotImplementedError() @abc.abstractmethod diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 5c6458eb52fc..deab8ff2d032 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -290,7 +290,9 @@ async def _handle_presence( if not interested: continue presence_events, _ = await presence_source.get_new_events( - user=user, service=service, from_key=from_key, + user=user, + service=service, + from_key=from_key, ) time_now = self.clock.time_msec() events.extend( diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 648fe91f53b5..9ba9f591d985 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -120,7 +120,9 @@ def convert_client_dict_legacy_fields_to_identifier( # Ensure the identifier has a type if "type" not in identifier: raise SynapseError( - 400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM, + 400, + "'identifier' dict has no key 'type'", + errcode=Codes.MISSING_PARAM, ) return identifier @@ -351,7 +353,11 @@ def get_new_session_data() -> JsonDict: try: result, params, session_id = await self.check_ui_auth( - flows, request, request_body, description, get_new_session_data, + flows, + request, + request_body, + description, + get_new_session_data, ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). @@ -379,8 +385,7 @@ def get_new_session_data() -> JsonDict: return params, session_id async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: - """Get a list of the authentication types this user can use - """ + """Get a list of the authentication types this user can use""" ui_auth_types = set() @@ -723,7 +728,9 @@ def _get_params_terms(self) -> dict: } def _auth_dict_for_flows( - self, flows: List[List[str]], session_id: str, + self, + flows: List[List[str]], + session_id: str, ) -> Dict[str, Any]: public_flows = [] for f in flows: @@ -880,7 +887,9 @@ def get_supported_login_types(self) -> Iterable[str]: return self._supported_login_types async def validate_login( - self, login_submission: Dict[str, Any], ratelimit: bool = False, + self, + login_submission: Dict[str, Any], + ratelimit: bool = False, ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Authenticates the user for the /login API @@ -1023,7 +1032,9 @@ async def validate_login( raise async def _validate_userid_login( - self, username: str, login_submission: Dict[str, Any], + self, + username: str, + login_submission: Dict[str, Any], ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Helper for validate_login @@ -1446,7 +1457,8 @@ def _complete_sso_login( # is considered OK since the newest SSO attributes should be most valid. if extra_attributes: self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes( - self._clock.time_msec(), extra_attributes, + self._clock.time_msec(), + extra_attributes, ) # Create a login token @@ -1702,5 +1714,9 @@ async def on_logged_out( # This might return an awaitable, if it does block the log out # until it completes. await maybe_awaitable( - g(user_id=user_id, device_id=device_id, access_token=access_token,) + g( + user_id=user_id, + device_id=device_id, + access_token=access_token, + ) ) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 81ed44ac877a..04972f9cf0b6 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -33,8 +33,7 @@ class CasError(Exception): - """Used to catch errors when validating the CAS ticket. - """ + """Used to catch errors when validating the CAS ticket.""" def __init__(self, error, error_description=None): self.error = error @@ -100,7 +99,10 @@ def _build_service_param(self, args: Dict[str, str]) -> str: Returns: The URL to use as a "service" parameter. """ - return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),) + return "%s?%s" % ( + self._cas_service_url, + urllib.parse.urlencode(args), + ) async def _validate_ticket( self, ticket: str, service_args: Dict[str, str] @@ -296,7 +298,10 @@ async def _handle_cas_response( # first check if we're doing a UIA if session: return await self._sso_handler.complete_sso_ui_auth_request( - self.idp_id, cas_response.username, session, request, + self.idp_id, + cas_response.username, + session, + request, ) # otherwise, we're handling a login request. @@ -366,7 +371,8 @@ async def grandfather_existing_users() -> Optional[str]: user_id = UserID(localpart, self._hostname).to_string() logger.debug( - "Looking for existing account based on mapped %s", user_id, + "Looking for existing account based on mapped %s", + user_id, ) users = await self._store.get_users_by_id_case_insensitive(user_id) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index c4a3b26a8486..94f3f3163f11 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -196,8 +196,7 @@ def _start_user_parting(self) -> None: run_as_background_process("user_parter_loop", self._user_parter_loop) async def _user_parter_loop(self) -> None: - """Loop that parts deactivated users from rooms - """ + """Loop that parts deactivated users from rooms""" self._user_parter_running = True logger.info("Starting user parter") try: @@ -214,8 +213,7 @@ async def _user_parter_loop(self) -> None: self._user_parter_running = False async def _part_user(self, user_id: str) -> None: - """Causes the given user_id to leave all the rooms they're joined to - """ + """Causes the given user_id to leave all the rooms they're joined to""" user = UserID.from_string(user_id) rooms_for_user = await self.store.get_rooms_for_user(user_id) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 0863154f7aba..df3cdc8fba11 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -86,7 +86,7 @@ async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: @trace async def get_device(self, user_id: str, device_id: str) -> JsonDict: - """ Retrieve the given device + """Retrieve the given device Args: user_id: The user to get the device from @@ -341,7 +341,7 @@ async def check_device_registered( @trace async def delete_device(self, user_id: str, device_id: str) -> None: - """ Delete the given device + """Delete the given device Args: user_id: The user to delete the device from. @@ -386,7 +386,7 @@ async def delete_all_devices_for_user( await self.delete_devices(user_id, device_ids) async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: - """ Delete several devices + """Delete several devices Args: user_id: The user to delete devices from. @@ -417,7 +417,7 @@ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: await self.notify_device_update(user_id, device_ids) async def update_device(self, user_id: str, device_id: str, content: dict) -> None: - """ Update the given device + """Update the given device Args: user_id: The user to update devices of. @@ -534,7 +534,9 @@ async def store_dehydrated_device( device id of the dehydrated device """ device_id = await self.check_device_registered( - user_id, None, initial_device_display_name, + user_id, + None, + initial_device_display_name, ) old_device_id = await self.store.store_dehydrated_device( user_id, device_id, device_data @@ -803,7 +805,8 @@ async def _maybe_retry_device_resync(self) -> None: try: # Try to resync the current user's devices list. result = await self.user_device_resync( - user_id=user_id, mark_failed_as_stale=False, + user_id=user_id, + mark_failed_as_stale=False, ) # user_device_resync only returns a result if it managed to @@ -813,14 +816,17 @@ async def _maybe_retry_device_resync(self) -> None: # self.store.update_remote_device_list_cache). if result: logger.debug( - "Successfully resynced the device list for %s", user_id, + "Successfully resynced the device list for %s", + user_id, ) except Exception as e: # If there was an issue resyncing this user, e.g. if the remote # server sent a malformed result, just log the error instead of # aborting all the subsequent resyncs. logger.debug( - "Could not resync the device list for %s: %s", user_id, e, + "Could not resync the device list for %s: %s", + user_id, + e, ) finally: # Allow future calls to retry resyncinc out of sync device lists. @@ -855,7 +861,9 @@ async def user_device_resync( return None except (RequestSendFailed, HttpResponseException) as e: logger.warning( - "Failed to handle device list update for %s: %s", user_id, e, + "Failed to handle device list update for %s: %s", + user_id, + e, ) if mark_failed_as_stale: @@ -931,7 +939,9 @@ async def user_device_resync( # Handle cross-signing keys. cross_signing_device_ids = await self.process_cross_signing_key_update( - user_id, master_key, self_signing_key, + user_id, + master_key, + self_signing_key, ) device_ids = device_ids + cross_signing_device_ids diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 0c7737e09d7e..1aa7d803b573 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -62,7 +62,8 @@ def __init__(self, hs: "HomeServer"): ) else: hs.get_federation_registry().register_instances_for_edu( - "m.direct_to_device", hs.config.worker.writers.to_device, + "m.direct_to_device", + hs.config.worker.writers.to_device, ) # The handler to call when we think a user's device list might be out of @@ -73,8 +74,8 @@ def __init__(self, hs: "HomeServer"): hs.get_device_handler().device_list_updater.user_device_resync ) else: - self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( - hs + self._user_device_resync = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 8f3a6b35a433..9a946a3cfe5d 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -61,8 +61,8 @@ def __init__(self, hs: "HomeServer"): self._is_master = hs.config.worker_app is None if not self._is_master: - self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client( - hs + self._user_device_resync_client = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) ) else: # Only register this edu handler on master as it requires writing @@ -85,7 +85,7 @@ def __init__(self, hs: "HomeServer"): async def query_devices( self, query_body: JsonDict, timeout: int, from_user_id: str ) -> JsonDict: - """ Handle a device key query from a client + """Handle a device key query from a client { "device_keys": { @@ -391,8 +391,7 @@ async def query_local_devices( async def on_federation_query_client_keys( self, query_body: Dict[str, Dict[str, Optional[List[str]]]] ) -> JsonDict: - """ Handle a device key query from a federated server - """ + """Handle a device key query from a federated server""" device_keys_query = query_body.get( "device_keys", {} ) # type: Dict[str, Optional[List[str]]] @@ -1065,7 +1064,9 @@ async def _get_e2e_cross_signing_verify_key( return key, key_id, verify_key async def _retrieve_cross_signing_keys_for_remote_user( - self, user: UserID, desired_key_type: str, + self, + user: UserID, + desired_key_type: str, ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]: """Queries cross-signing keys for a remote user and saves them to the database @@ -1269,8 +1270,7 @@ def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool: @attr.s(slots=True) class SignatureListItem: - """An item in the signature list as used by upload_signatures_for_device_keys. - """ + """An item in the signature list as used by upload_signatures_for_device_keys.""" signing_key_id = attr.ib(type=str) target_user_id = attr.ib(type=str) @@ -1355,8 +1355,12 @@ async def _handle_signing_key_updates(self, user_id: str) -> None: logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = await device_list_updater.process_cross_signing_key_update( - user_id, master_key, self_signing_key, + new_device_ids = ( + await device_list_updater.process_cross_signing_key_update( + user_id, + master_key, + self_signing_key, + ) ) device_ids = device_ids + new_device_ids diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 539b4fc32e95..3e23f82cf756 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -57,8 +57,7 @@ async def get_stream( room_id: Optional[str] = None, is_guest: bool = False, ) -> JsonDict: - """Fetches the events stream for a given user. - """ + """Fetches the events stream for a given user.""" if room_id: blocked = await self.store.is_room_blocked(room_id) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5581e06bb464..2ead626a4d5a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -111,13 +111,13 @@ class _NewEventInfo: class FederationHandler(BaseHandler): """Handles events that originated from federation. - Responsible for: - a) handling received Pdus before handing them on as Events to the rest - of the homeserver (including auth and state conflict resolutions) - b) converting events that were produced by local clients that may need - to be sent to remote homeservers. - c) doing the necessary dances to invite remote users and join remote - rooms. + Responsible for: + a) handling received Pdus before handing them on as Events to the rest + of the homeserver (including auth and state conflict resolutions) + b) converting events that were produced by local clients that may need + to be sent to remote homeservers. + c) doing the necessary dances to invite remote users and join remote + rooms. """ def __init__(self, hs: "HomeServer"): @@ -150,11 +150,11 @@ def __init__(self, hs: "HomeServer"): ) if hs.config.worker_app: - self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( - hs + self._user_device_resync = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) ) - self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client( - hs + self._maybe_store_room_on_outlier_membership = ( + ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs) ) else: self._device_list_updater = hs.get_device_handler().device_list_updater @@ -172,7 +172,7 @@ def __init__(self, hs: "HomeServer"): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: - """ Process a PDU received via a federation /send/ transaction, or + """Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events Args: @@ -368,7 +368,8 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: # know about for p in prevs - seen: logger.info( - "Requesting state at missing prev_event %s", event_id, + "Requesting state at missing prev_event %s", + event_id, ) with nested_logging_context(p): @@ -388,12 +389,14 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: event_map[x.event_id] = x room_version = await self.store.get_room_version_id(room_id) - state_map = await self._state_resolution_handler.resolve_events_with_store( - room_id, - room_version, - state_maps, - event_map, - state_res_store=StateResolutionStore(self.store), + state_map = ( + await self._state_resolution_handler.resolve_events_with_store( + room_id, + room_version, + state_maps, + event_map, + state_res_store=StateResolutionStore(self.store), + ) ) # We need to give _process_received_pdu the actual state events @@ -687,9 +690,12 @@ async def _get_events_from_store_or_dest( return fetched_events async def _process_received_pdu( - self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], ): - """ Called when we have a new pdu. We need to do auth checks and put it + """Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. Args: @@ -801,7 +807,7 @@ async def _resync_device(self, sender: str) -> None: @log_function async def backfill(self, dest, room_id, limit, extremities): - """ Trigger a backfill request to `dest` for the given `room_id` + """Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side has no new events to offer, this will return an empty list. @@ -1204,11 +1210,16 @@ async def get_event(event_id: str): with nested_logging_context(event_id): try: event = await self.federation_client.get_pdu( - [destination], event_id, room_version, outlier=True, + [destination], + event_id, + room_version, + outlier=True, ) if event is None: logger.warning( - "Server %s didn't return event %s", destination, event_id, + "Server %s didn't return event %s", + destination, + event_id, ) return @@ -1235,7 +1246,8 @@ async def get_event(event_id: str): if aid not in event_map ] persisted_events = await self.store.get_events( - auth_events, allow_rejected=True, + auth_events, + allow_rejected=True, ) event_infos = [] @@ -1251,7 +1263,9 @@ async def get_event(event_id: str): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, room_id, event_infos, + destination, + room_id, + event_infos, ) def _sanity_check_event(self, ev): @@ -1287,7 +1301,7 @@ def _sanity_check_event(self, ev): raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") async def send_invite(self, target_host, event): - """ Sends the invite to the remote server for signing. + """Sends the invite to the remote server for signing. Invites must be signed by the invitee's server before distribution. """ @@ -1310,7 +1324,7 @@ async def on_event_auth(self, event_id: str) -> List[EventBase]: async def do_invite_join( self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict ) -> Tuple[str, int]: - """ Attempts to join the `joinee` to the room `room_id` via the + """Attempts to join the `joinee` to the room `room_id` via the servers contained in `target_hosts`. This first triggers a /make_join/ request that returns a partial @@ -1388,7 +1402,8 @@ async def do_invite_join( # so we can rely on it now. # await self.store.upsert_room_on_join( - room_id=room_id, room_version=room_version_obj, + room_id=room_id, + room_version=room_version_obj, ) max_stream_id = await self._persist_auth_tree( @@ -1458,7 +1473,7 @@ async def _handle_queued_pdus(self, room_queue): async def on_make_join_request( self, origin: str, room_id: str, user_id: str ) -> EventBase: - """ We've received a /make_join/ request, so we create a partial + """We've received a /make_join/ request, so we create a partial join event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. @@ -1483,7 +1498,8 @@ async def on_make_join_request( is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( - "Got /make_join request for room %s we are no longer in", room_id, + "Got /make_join request for room %s we are no longer in", + room_id, ) raise NotFoundError("Not an active room on this server") @@ -1517,7 +1533,7 @@ async def on_make_join_request( return event async def on_send_join_request(self, origin, pdu): - """ We have received a join event for a room. Fully process it and + """We have received a join event for a room. Fully process it and respond with the current state and auth chains. """ event = pdu @@ -1573,7 +1589,7 @@ async def on_send_join_request(self, origin, pdu): async def on_invite_request( self, origin: str, event: EventBase, room_version: RoomVersion ): - """ We've got an invite event. Process and persist it. Sign it. + """We've got an invite event. Process and persist it. Sign it. Respond with the now signed event. """ @@ -1700,7 +1716,7 @@ async def _make_and_verify_event( async def on_make_leave_request( self, origin: str, room_id: str, user_id: str ) -> EventBase: - """ We've received a /make_leave/ request, so we create a partial + """We've received a /make_leave/ request, so we create a partial leave event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. @@ -1776,8 +1792,7 @@ async def on_send_leave_request(self, origin, pdu): return None async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: - """Returns the state at the event. i.e. not including said event. - """ + """Returns the state at the event. i.e. not including said event.""" event = await self.store.get_event(event_id, check_room_id=room_id) @@ -1803,8 +1818,7 @@ async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase return [] async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: - """Returns the state at the event. i.e. not including said event. - """ + """Returns the state at the event. i.e. not including said event.""" event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2010,7 +2024,11 @@ async def _persist_auth_tree( for e_id in missing_auth_events: m_ev = await self.federation_client.get_pdu( - [origin], e_id, room_version=room_version, outlier=True, timeout=10000, + [origin], + e_id, + room_version=room_version, + outlier=True, + timeout=10000, ) if m_ev and m_ev.event_id == e_id: event_map[e_id] = m_ev @@ -2160,7 +2178,9 @@ async def _check_for_soft_fail( ) logger.debug( - "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids, + "Doing soft-fail check for %s: state %s", + event.event_id, + current_state_ids, ) # Now check if event pass auth against said current state @@ -2513,7 +2533,7 @@ async def _update_context_for_auth_events( async def construct_auth_difference( self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase] ) -> Dict: - """ Given a local and remote auth chain, find the differences. This + """Given a local and remote auth chain, find the differences. This assumes that we have already processed all events in remote_auth Params: diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 71f11ef94aad..bfb95e3eee53 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -146,8 +146,7 @@ async def get_group_summary( async def get_users_in_group( self, group_id: str, requester_user_id: str ) -> JsonDict: - """Get users in a group - """ + """Get users in a group""" if self.is_mine_id(group_id): return await self.groups_server_handler.get_users_in_group( group_id, requester_user_id @@ -283,8 +282,7 @@ def __init__(self, hs: "HomeServer"): async def create_group( self, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: - """Create a group - """ + """Create a group""" logger.info("Asking to create group with ID: %r", group_id) @@ -314,8 +312,7 @@ async def create_group( async def join_group( self, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: - """Request to join a group - """ + """Request to join a group""" if self.is_mine_id(group_id): await self.groups_server_handler.join_group(group_id, user_id, content) local_attestation = None @@ -361,8 +358,7 @@ async def join_group( async def accept_invite( self, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: - """Accept an invite to a group - """ + """Accept an invite to a group""" if self.is_mine_id(group_id): await self.groups_server_handler.accept_invite(group_id, user_id, content) local_attestation = None @@ -408,8 +404,7 @@ async def accept_invite( async def invite( self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict ) -> JsonDict: - """Invite a user to a group - """ + """Invite a user to a group""" content = {"requester_user_id": requester_user_id, "config": config} if self.is_mine_id(group_id): res = await self.groups_server_handler.invite_to_group( @@ -434,8 +429,7 @@ async def invite( async def on_invite( self, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: - """One of our users were invited to a group - """ + """One of our users were invited to a group""" # TODO: Support auto join and rejection if not self.is_mine_id(user_id): @@ -466,8 +460,7 @@ async def on_invite( async def remove_user_from_group( self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict ) -> JsonDict: - """Remove a user from a group - """ + """Remove a user from a group""" if user_id == requester_user_id: token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" @@ -501,8 +494,7 @@ async def remove_user_from_group( async def user_removed_from_group( self, group_id: str, user_id: str, content: JsonDict ) -> None: - """One of our users was removed/kicked from a group - """ + """One of our users was removed/kicked from a group""" # TODO: Check if user in group token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 8fc1e8b91c0a..5f346f6d6d28 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -72,7 +72,10 @@ def __init__(self, hs): ) def ratelimit_request_token_requests( - self, request: SynapseRequest, medium: str, address: str, + self, + request: SynapseRequest, + medium: str, + address: str, ): """Used to ratelimit requests to `/requestToken` by IP and address. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index fbd8df9dcc09..78c3e5a10bb9 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -124,7 +124,8 @@ async def _snapshot_all_rooms( joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] receipt = await self.store.get_linearized_receipts_for_rooms( - joined_rooms, to_key=int(now_token.receipt_key), + joined_rooms, + to_key=int(now_token.receipt_key), ) tags_by_room = await self.store.get_tags_for_user(user_id) @@ -169,7 +170,10 @@ async def handle_room(event: RoomsForUser): self.state_handler.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: - room_end_token = RoomStreamToken(None, event.stream_ordering,) + room_end_token = RoomStreamToken( + None, + event.stream_ordering, + ) deferred_room_state = run_in_background( self.state_store.get_state_for_events, [event.event_id] ) @@ -284,7 +288,9 @@ async def room_initial_sync( membership, member_event_id, ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True, + room_id, + user_id, + allow_departed_users=True, ) is_peeking = member_event_id is None diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a15336bf00a8..c03f6c997b29 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -65,8 +65,7 @@ class MessageHandler: - """Contains some read only APIs to get state about a room - """ + """Contains some read only APIs to get state about a room""" def __init__(self, hs): self.auth = hs.get_auth() @@ -88,9 +87,13 @@ def __init__(self, hs): ) async def get_room_data( - self, user_id: str, room_id: str, event_type: str, state_key: str, + self, + user_id: str, + room_id: str, + event_type: str, + state_key: str, ) -> dict: - """ Get data from a room. + """Get data from a room. Args: user_id @@ -174,7 +177,10 @@ async def get_state_events( raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = await filter_events_for_client( - self.storage, user_id, last_events, filter_send_to_client=False, + self.storage, + user_id, + last_events, + filter_send_to_client=False, ) event = last_events[0] @@ -571,7 +577,7 @@ async def create_event( async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester ) -> bool: - """"Determine if an event to be sent is exempt from having to consent + """ "Determine if an event to be sent is exempt from having to consent to the privacy policy Args: @@ -793,9 +799,10 @@ async def create_new_client_event( """ if prev_event_ids is not None: - assert len(prev_event_ids) <= 10, ( - "Attempting to create an event with %i prev_events" - % (len(prev_event_ids),) + assert ( + len(prev_event_ids) <= 10 + ), "Attempting to create an event with %i prev_events" % ( + len(prev_event_ids), ) else: prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) @@ -821,7 +828,8 @@ async def create_new_client_event( ) if not third_party_result: logger.info( - "Event %s forbidden by third-party rules", event, + "Event %s forbidden by third-party rules", + event, ) raise SynapseError( 403, "This event is not allowed in this context", Codes.FORBIDDEN @@ -1316,7 +1324,11 @@ async def _send_dummy_event_for_room(self, room_id: str) -> bool: # Since this is a dummy-event it is OK if it is sent by a # shadow-banned user. await self.handle_new_client_event( - requester, event, context, ratelimit=False, ignore_shadow_ban=True, + requester, + event, + context, + ratelimit=False, + ignore_shadow_ban=True, ) return True except AuthError: diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 5f3e8a77f5df..702bfb8bc95f 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -73,8 +73,7 @@ class OidcHandler: - """Handles requests related to the OpenID Connect login flow. - """ + """Handles requests related to the OpenID Connect login flow.""" def __init__(self, hs: "HomeServer"): self._sso_handler = hs.get_sso_handler() @@ -216,8 +215,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None: class OidcError(Exception): - """Used to catch errors when calling the token_endpoint - """ + """Used to catch errors when calling the token_endpoint""" def __init__(self, error, error_description=None): self.error = error @@ -252,7 +250,9 @@ def __init__( self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method self._client_auth = ClientAuth( - provider.client_id, provider.client_secret, provider.client_auth_method, + provider.client_id, + provider.client_secret, + provider.client_auth_method, ) # type: ClientAuth self._client_auth_method = provider.client_auth_method @@ -509,7 +509,10 @@ async def _exchange_code(self, code: str) -> Token: # We're not using the SimpleHttpClient util methods as we don't want to # check the HTTP status code and we do the body encoding ourself. response = await self._http_client.request( - method="POST", uri=uri, data=body.encode("utf-8"), headers=headers, + method="POST", + uri=uri, + data=body.encode("utf-8"), + headers=headers, ) # This is used in multiple error messages below @@ -966,7 +969,9 @@ def generate_oidc_session_token( A signed macaroon token with the session information. """ macaroon = pymacaroons.Macaroon( - location=self._server_name, identifier="key", key=self._macaroon_secret_key, + location=self._server_name, + identifier="key", + key=self._macaroon_secret_key, ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = session") diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5372753707f3..059064a4eb44 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -197,7 +197,8 @@ async def purge_history_for_rooms_in_range( stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) r = await self.store.get_room_event_before_stream_ordering( - room_id, stream_ordering, + room_id, + stream_ordering, ) if not r: logger.warning( @@ -223,7 +224,12 @@ async def purge_history_for_rooms_in_range( # the background so that it's not blocking any other operation apart from # other purges in the same room. run_as_background_process( - "_purge_history", self._purge_history, purge_id, room_id, token, True, + "_purge_history", + self._purge_history, + purge_id, + room_id, + token, + True, ) def start_purge_history( @@ -389,7 +395,9 @@ async def get_messages( ) await self.hs.get_federation_handler().maybe_backfill( - room_id, curr_topo, limit=pagin_config.limit, + room_id, + curr_topo, + limit=pagin_config.limit, ) to_room_key = None diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 22d1e9d35c1d..7ba22d511f33 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -635,8 +635,7 @@ async def update_external_syncs_clear(self, process_id): self.external_process_last_updated_ms.pop(process_id, None) async def current_state_for_user(self, user_id): - """Get the current presence state for a user. - """ + """Get the current presence state for a user.""" res = await self.current_state_for_users([user_id]) return res[user_id] @@ -678,8 +677,7 @@ def _push_to_remotes(self, states): self.federation.send_presence(states) async def incoming_presence(self, origin, content): - """Called when we receive a `m.presence` EDU from a remote server. - """ + """Called when we receive a `m.presence` EDU from a remote server.""" if not self._presence_enabled: return @@ -729,8 +727,7 @@ async def incoming_presence(self, origin, content): await self._update_states(updates) async def set_state(self, target_user, state, ignore_status_msg=False): - """Set the presence state of the user. - """ + """Set the presence state of the user.""" status_msg = state.get("status_msg", None) presence = state["presence"] @@ -758,8 +755,7 @@ async def set_state(self, target_user, state, ignore_status_msg=False): await self._update_states([prev_state.copy_and_replace(**new_fields)]) async def is_visible(self, observed_user, observer_user): - """Returns whether a user can see another user's presence. - """ + """Returns whether a user can see another user's presence.""" observer_room_ids = await self.store.get_rooms_for_user( observer_user.to_string() ) @@ -953,8 +949,7 @@ async def _on_user_joined_room(self, room_id: str, user_id: str) -> None: def should_notify(old_state, new_state): - """Decides if a presence state change should be sent to interested parties. - """ + """Decides if a presence state change should be sent to interested parties.""" if old_state == new_state: return False diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index c02b95103199..2f62d84fb510 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -207,7 +207,8 @@ async def set_displayname( # This must be done by the target user himself. if by_admin: requester = create_requester( - target_user, authenticated_entity=requester.authenticated_entity, + target_user, + authenticated_entity=requester.authenticated_entity, ) await self.store.set_profile_displayname( diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index cc21fc228474..6a6c52884983 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -49,15 +49,15 @@ def __init__(self, hs: "HomeServer"): ) else: hs.get_federation_registry().register_instances_for_edu( - "m.receipt", hs.config.worker.writers.receipts, + "m.receipt", + hs.config.worker.writers.receipts, ) self.clock = self.hs.get_clock() self.state = hs.get_state_handler() async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: - """Called when we receive an EDU of type m.receipt from a remote HS. - """ + """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] for room_id, room_values in content.items(): for receipt_type, users in room_values.items(): @@ -83,8 +83,7 @@ async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None await self._handle_new_receipts(receipts) async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: - """Takes a list of receipts, stores them and informs the notifier. - """ + """Takes a list of receipts, stores them and informs the notifier.""" min_batch_id = None # type: Optional[int] max_batch_id = None # type: Optional[int] diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 49b085269bcb..3cda89657ef0 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -62,8 +62,8 @@ def __init__(self, hs: "HomeServer"): self._register_device_client = RegisterDeviceReplicationServlet.make_client( hs ) - self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client( - hs + self._post_registration_client = ( + ReplicationPostRegisterActionsServlet.make_client(hs) ) else: self.device_handler = hs.get_device_handler() @@ -189,12 +189,15 @@ async def register_user( self.check_registration_ratelimit(address) result = await self.spam_checker.check_registration_for_spam( - threepid, localpart, user_agent_ips or [], + threepid, + localpart, + user_agent_ips or [], ) if result == RegistrationBehaviour.DENY: logger.info( - "Blocked registration of %r", localpart, + "Blocked registration of %r", + localpart, ) # We return a 429 to make it not obvious that they've been # denied. @@ -203,7 +206,8 @@ async def register_user( shadow_banned = result == RegistrationBehaviour.SHADOW_BAN if shadow_banned: logger.info( - "Shadow banning registration of %r", localpart, + "Shadow banning registration of %r", + localpart, ) # do not check_auth_blocking if the call is coming through the Admin API @@ -369,7 +373,9 @@ async def _create_and_join_rooms(self, user_id: str) -> None: config["room_alias_name"] = room_alias.localpart info, _ = await room_creation_handler.create_room( - fake_requester, config=config, ratelimit=False, + fake_requester, + config=config, + ratelimit=False, ) # If the room does not require an invite, but another user @@ -753,7 +759,10 @@ async def _register_email_threepid( return await self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"], + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], ) # And we add an email pusher for them by default, but only @@ -805,5 +814,8 @@ async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None: raise await self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"], + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 591a82f45953..a488df10d678 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -198,7 +198,9 @@ async def _upgrade_room( if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) new_room_id = await self._generate_room_id( - creator_id=user_id, is_public=r["is_public"], room_version=new_version, + creator_id=user_id, + is_public=r["is_public"], + room_version=new_version, ) logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) @@ -236,7 +238,9 @@ async def _upgrade_room( # now send the tombstone await self.event_creation_handler.handle_new_client_event( - requester=requester, event=tombstone_event, context=tombstone_context, + requester=requester, + event=tombstone_event, + context=tombstone_context, ) old_room_state = await tombstone_context.get_current_state_ids() @@ -257,7 +261,10 @@ async def _upgrade_room( # finally, shut down the PLs in the old room, and update them in the new # room. await self._update_upgraded_room_pls( - requester, old_room_id, new_room_id, old_room_state, + requester, + old_room_id, + new_room_id, + old_room_state, ) return new_room_id @@ -570,7 +577,7 @@ async def create_room( ratelimit: bool = True, creator_join_profile: Optional[JsonDict] = None, ) -> Tuple[dict, int]: - """ Creates a new room. + """Creates a new room. Args: requester: @@ -691,7 +698,9 @@ async def create_room( is_public = visibility == "public" room_id = await self._generate_room_id( - creator_id=user_id, is_public=is_public, room_version=room_version, + creator_id=user_id, + is_public=is_public, + room_version=room_version, ) # Check whether this visibility value is blocked by a third party module @@ -884,7 +893,10 @@ async def send(etype: str, content: JsonDict, **kwargs) -> int: _, last_stream_id, ) = await self.event_creation_handler.create_and_send_nonmember_event( - creator, event, ratelimit=False, ignore_shadow_ban=True, + creator, + event, + ratelimit=False, + ignore_shadow_ban=True, ) return last_stream_id @@ -984,7 +996,10 @@ async def send(etype: str, content: JsonDict, **kwargs) -> int: return last_sent_stream_id async def _generate_room_id( - self, creator_id: str, is_public: bool, room_version: RoomVersion, + self, + creator_id: str, + is_public: bool, + room_version: RoomVersion, ): # autogen room IDs and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a5da97cfe0b2..1660921306c0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -191,7 +191,10 @@ async def _local_membership_update( # do it up front for efficiency.) if txn_id and requester.access_token_id: existing_event_id = await self.store.get_event_id_from_transaction_id( - room_id, requester.user.to_string(), requester.access_token_id, txn_id, + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, ) if existing_event_id: event_pos = await self.store.get_position_for_event(existing_event_id) @@ -238,7 +241,11 @@ async def _local_membership_update( ) result_event = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target], ratelimit=ratelimit, + requester, + event, + context, + extra_users=[target], + ratelimit=ratelimit, ) if event.membership == Membership.LEAVE: @@ -583,7 +590,10 @@ async def update_membership_locked( # send the rejection to the inviter's HS (with fallback to # local event) return await self.remote_reject_invite( - invite.event_id, txn_id, requester, content, + invite.event_id, + txn_id, + requester, + content, ) # the inviter was on our server, but has now left. Carry on @@ -1056,8 +1066,7 @@ async def _remote_join( user: UserID, content: dict, ) -> Tuple[str, int]: - """Implements RoomMemberHandler._remote_join - """ + """Implements RoomMemberHandler._remote_join""" # filter ourselves out of remote_room_hosts: do_invite_join ignores it # and if it is the only entry we'd like to return a 404 rather than a # 500. @@ -1211,7 +1220,10 @@ async def _generate_local_out_of_band_leave( event.internal_metadata.out_of_band_membership = True result_event = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[UserID.from_string(target_user)], + requester, + event, + context, + extra_users=[UserID.from_string(target_user)], ) # we know it was persisted, so must have a stream ordering assert result_event.internal_metadata.stream_ordering @@ -1219,8 +1231,7 @@ async def _generate_local_out_of_band_leave( return result_event.event_id, result_event.internal_metadata.stream_ordering async def _user_left_room(self, target: UserID, room_id: str) -> None: - """Implements RoomMemberHandler._user_left_room - """ + """Implements RoomMemberHandler._user_left_room""" user_left_room(self.distributor, target, room_id) async def forget(self, user: UserID, room_id: str) -> None: diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index f2e88f6a5b5d..108730a7a112 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -44,8 +44,7 @@ async def _remote_join( user: UserID, content: dict, ) -> Tuple[str, int]: - """Implements RoomMemberHandler._remote_join - """ + """Implements RoomMemberHandler._remote_join""" if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") @@ -80,8 +79,7 @@ async def remote_reject_invite( return ret["event_id"], ret["stream_id"] async def _user_left_room(self, target: UserID, room_id: str) -> None: - """Implements RoomMemberHandler._user_left_room - """ + """Implements RoomMemberHandler._user_left_room""" await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="left" ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 78f130e15243..a9645b77d808 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -121,7 +121,8 @@ async def handle_redirect_request( now = self.clock.time_msec() self._outstanding_requests_dict[reqid] = Saml2SessionData( - creation_time=now, ui_auth_session_id=ui_auth_session_id, + creation_time=now, + ui_auth_session_id=ui_auth_session_id, ) for key, value in info["headers"]: @@ -450,7 +451,8 @@ def saml_response_to_user_attributes( mxid_source = saml_response.ava[self._mxid_source_attribute][0] except KeyError: logger.warning( - "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, + "SAML2 response lacks a '%s' attestation", + self._mxid_source_attribute, ) raise SynapseError( 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index a63fd5248576..514b1f69d8ee 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -327,7 +327,8 @@ async def get_sso_user_by_remote_user_id( # Check if we already have a mapping for this user. previously_registered_user_id = await self._store.get_user_by_external_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) # A match was found, return the user ID. @@ -416,7 +417,8 @@ async def complete_sso_login_request( with await self._mapping_lock.queue(auth_provider_id): # first of all, check if we already have a mapping for this user user_id = await self.get_sso_user_by_remote_user_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) # Check for grandfathering of users. @@ -461,7 +463,8 @@ async def complete_sso_login_request( ) async def _call_attribute_mapper( - self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], + self, + sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], ) -> UserAttributes: """Call the attribute mapper function in a loop, until we get a unique userid""" for i in range(self._MAP_USERNAME_RETRIES): @@ -632,7 +635,8 @@ async def complete_sso_ui_auth_request( """ user_id = await self.get_sso_user_by_remote_user_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) user_id_to_verify = await self._auth_handler.get_session_data( @@ -671,7 +675,8 @@ async def complete_sso_ui_auth_request( # render an error page. html = self._bad_user_template.render( - server_name=self._server_name, user_id_to_verify=user_id_to_verify, + server_name=self._server_name, + user_id_to_verify=user_id_to_verify, ) respond_with_html(request, 200, html) @@ -695,7 +700,9 @@ def get_mapping_session(self, session_id: str) -> UsernameMappingSession: raise SynapseError(400, "unknown session") async def check_username_availability( - self, localpart: str, session_id: str, + self, + localpart: str, + session_id: str, ) -> bool: """Handle an "is username available" callback check @@ -833,7 +840,8 @@ async def register_sso_user(self, request: Request, session_id: str) -> None: ) attributes = UserAttributes( - localpart=session.chosen_localpart, emails=session.emails_to_use, + localpart=session.chosen_localpart, + emails=session.emails_to_use, ) if session.use_display_name: diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index d261d7cd4e84..924281144c62 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -63,8 +63,7 @@ def __init__(self, hs: "HomeServer"): self.clock.call_later(0, self.notify_new_event) def notify_new_event(self) -> None: - """Called when there may be more deltas to process - """ + """Called when there may be more deltas to process""" if not self.stats_enabled or self._is_processing: return diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5c7590f38e4d..4e8ed7b33f62 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -339,8 +339,7 @@ async def current_sync_for_user( since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Get the sync for client needed to match what the server has now. - """ + """Get the sync for client needed to match what the server has now.""" return await self.generate_sync_result(sync_config, since_token, full_state) async def push_rules_for_user(self, user: UserID) -> JsonDict: @@ -564,7 +563,7 @@ async def get_state_at( stream_position: StreamToken, state_filter: StateFilter = StateFilter.all(), ) -> StateMap[str]: - """ Get the room state at a particular stream position + """Get the room state at a particular stream position Args: room_id: room for which to get state @@ -598,7 +597,7 @@ async def compute_summary( state: MutableStateMap[EventBase], now_token: StreamToken, ) -> Optional[JsonDict]: - """ Works out a room summary block for this room, summarising the number + """Works out a room summary block for this room, summarising the number of joined members in the room, and providing the 'hero' members if the room has no name so clients can consistently name rooms. Also adds state events to 'state' if needed to describe the heroes. @@ -743,7 +742,7 @@ async def compute_state_delta( now_token: StreamToken, full_state: bool, ) -> MutableStateMap[EventBase]: - """ Works out the difference in state between the start of the timeline + """Works out the difference in state between the start of the timeline and the previous sync. Args: @@ -820,8 +819,10 @@ async def compute_state_delta( ) elif batch.limited: if batch: - state_at_timeline_start = await self.state_store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + state_at_timeline_start = ( + await self.state_store.get_state_ids_for_event( + batch.events[0].event_id, state_filter=state_filter + ) ) else: # We can get here if the user has ignored the senders of all @@ -955,8 +956,7 @@ async def generate_sync_result( since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Generates a sync result. - """ + """Generates a sync result.""" # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability # to query up to a given point. @@ -1030,8 +1030,8 @@ async def generate_sync_result( one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) - unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types( - user_id, device_id + unused_fallback_key_types = ( + await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) ) logger.debug("Fetching group data") @@ -1176,8 +1176,10 @@ async def _generate_sync_entry_for_device_list( # weren't in the previous sync *or* they left and rejoined. users_that_have_changed.update(newly_joined_or_invited_users) - user_signatures_changed = await self.store.get_users_whose_signatures_changed( - user_id, since_token.device_list_key + user_signatures_changed = ( + await self.store.get_users_whose_signatures_changed( + user_id, since_token.device_list_key + ) ) users_that_have_changed.update(user_signatures_changed) @@ -1393,8 +1395,10 @@ async def _generate_sync_entry_for_rooms( logger.debug("no-oping sync") return set(), set(), set(), set() - ignored_account_data = await self.store.get_global_account_data_by_type_for_user( - AccountDataTypes.IGNORED_USER_LIST, user_id=user_id + ignored_account_data = ( + await self.store.get_global_account_data_by_type_for_user( + AccountDataTypes.IGNORED_USER_LIST, user_id=user_id + ) ) # If there is ignored users account data and it matches the proper type, @@ -1499,8 +1503,7 @@ async def _have_rooms_changed( async def _get_rooms_changed( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: - """Gets the the changes that have happened since the last sync. - """ + """Gets the the changes that have happened since the last sync.""" user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 3f0dfc7a74dc..096d199f4cf1 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -61,7 +61,8 @@ def __init__(self, hs: "HomeServer"): if hs.config.worker.writers.typing != hs.get_instance_name(): hs.get_federation_registry().register_instance_for_edu( - "m.typing", hs.config.worker.writers.typing, + "m.typing", + hs.config.worker.writers.typing, ) # map room IDs to serial numbers @@ -76,8 +77,7 @@ def __init__(self, hs: "HomeServer"): self.clock.looping_call(self._handle_timeouts, 5000) def _reset(self) -> None: - """Reset the typing handler's data caches. - """ + """Reset the typing handler's data caches.""" # map room IDs to serial numbers self._room_serials = {} # map room IDs to sets of users currently typing @@ -149,8 +149,7 @@ async def _push_remote(self, member: RoomMember, typing: bool) -> None: def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] ) -> None: - """Should be called whenever we receive updates for typing stream. - """ + """Should be called whenever we receive updates for typing stream.""" if self._latest_room_serial > token: # The master has gone backwards. To prevent inconsistent data, just diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 8aedf5072e6b..3dfb0a26c2a4 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -97,8 +97,7 @@ async def search_users( return results def notify_new_event(self) -> None: - """Called when there may be more deltas to process - """ + """Called when there may be more deltas to process""" if not self.update_user_directory: return @@ -134,8 +133,7 @@ async def handle_local_profile_change( ) async def handle_user_deactivated(self, user_id: str) -> None: - """Called when a user ID is deactivated - """ + """Called when a user ID is deactivated""" # FIXME(#3714): We should probably do this in the same worker as all # the other changes. await self.store.remove_from_user_dir(user_id) @@ -172,8 +170,7 @@ async def _unsafe_process(self) -> None: await self.store.update_user_directory_stream_pos(max_pos) async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None: - """Called with the state deltas to process - """ + """Called with the state deltas to process""" for delta in deltas: typ = delta["type"] state_key = delta["state_key"] diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 4bc3cb53f0f9..c658862fe65f 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -54,8 +54,7 @@ def stopProducing(self): def get_request_user_agent(request: IRequest, default: str = "") -> str: - """Return the last User-Agent header, or the given default. - """ + """Return the last User-Agent header, or the given default.""" # There could be raw utf-8 bytes in the User-Agent header. # N.B. if you don't do this, the logger explodes cryptically diff --git a/synapse/http/client.py b/synapse/http/client.py index 37ccf5ab98f7..73b414ccffa0 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -398,7 +398,8 @@ async def request( body_producer = None if data is not None: body_producer = QuieterFileBodyProducer( - BytesIO(data), cooperator=self._cooperator, + BytesIO(data), + cooperator=self._cooperator, ) request_deferred = treq.request( @@ -413,7 +414,9 @@ async def request( # we use our own timeout mechanism rather than treq's as a workaround # for https://twistedmatrix.com/trac/ticket/9534. request_deferred = timeout_deferred( - request_deferred, 60, self.hs.get_reactor(), + request_deferred, + 60, + self.hs.get_reactor(), ) # turn timeouts into RequestTimedOutErrors diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 113fd4713434..2e83fa6773a1 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -195,8 +195,7 @@ def request( @implementer(IAgentEndpointFactory) class MatrixHostnameEndpointFactory: - """Factory for MatrixHostnameEndpoint for parsing to an Agent. - """ + """Factory for MatrixHostnameEndpoint for parsing to an Agent.""" def __init__( self, @@ -261,8 +260,7 @@ def __init__( self._srv_resolver = srv_resolver def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred: - """Implements IStreamClientEndpoint interface - """ + """Implements IStreamClientEndpoint interface""" return run_in_background(self._do_connect, protocol_factory) diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index b3b6dbcab045..4def7d763304 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -81,8 +81,7 @@ class WellKnownLookupResult: class WellKnownResolver: - """Handles well-known lookups for matrix servers. - """ + """Handles well-known lookups for matrix servers.""" def __init__( self, diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 19293bf6739d..cde42e9f5e7b 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -254,7 +254,8 @@ def __init__(self, hs, tls_client_options_factory): # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( - self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist, + self.agent, + ip_blacklist=hs.config.federation_ip_range_blacklist, ) self.clock = hs.get_clock() @@ -652,7 +653,7 @@ async def put_json( backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, ) -> Union[JsonDict, list]: - """ Sends the specified json data using PUT + """Sends the specified json data using PUT Args: destination: The remote server to send the HTTP request to. @@ -740,7 +741,7 @@ async def post_json( ignore_backoff: bool = False, args: Optional[QueryArgs] = None, ) -> Union[JsonDict, list]: - """ Sends the specified json data using POST + """Sends the specified json data using POST Args: destination: The remote server to send the HTTP request to. @@ -799,7 +800,11 @@ async def post_json( _sec_timeout = self.default_timeout body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms, + self.reactor, + _sec_timeout, + request, + response, + start_ms, ) return body @@ -813,7 +818,7 @@ async def get_json( ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, ) -> Union[JsonDict, list]: - """ GETs some json from the given host homeserver and path + """GETs some json from the given host homeserver and path Args: destination: The remote server to send the HTTP request to. @@ -994,7 +999,10 @@ async def get_file( except BodyExceededMaxSize: msg = "Requested file is too large > %r bytes" % (max_size,) logger.warning( - "{%s} [%s] %s", request.txn_id, request.destination, msg, + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, ) raise SynapseError(502, msg, Codes.TOO_LARGE) except Exception as e: diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 7c5defec826e..0ec5d941b8fe 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -213,8 +213,7 @@ def stop(self, time_sec, response_code, sent_bytes): self.update_metrics() def update_metrics(self): - """Updates the in flight metrics with values from this request. - """ + """Updates the in flight metrics with values from this request.""" new_stats = self.start_context.get_resource_usage() diff = new_stats - self._request_stats diff --git a/synapse/http/server.py b/synapse/http/server.py index 8249732b27ac..845db9b78d5f 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -76,8 +76,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: - """Sends a JSON error response to clients. - """ + """Sends a JSON error response to clients.""" if f.check(SynapseError): error_code = f.value.code @@ -106,12 +105,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: pass else: respond_with_json( - request, error_code, error_dict, send_cors=True, + request, + error_code, + error_dict, + send_cors=True, ) def return_html_error( - f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template], + f: failure.Failure, + request: Request, + error_template: Union[str, jinja2.Template], ) -> None: """Sends an HTML error page corresponding to the given failure. @@ -189,8 +193,7 @@ async def wrapped_async_request_handler(self, request): class HttpServer(Protocol): - """ Interface for registering callbacks on a HTTP server - """ + """Interface for registering callbacks on a HTTP server""" def register_paths( self, @@ -199,7 +202,7 @@ def register_paths( callback: ServletCallback, servlet_classname: str, ) -> None: - """ Register a callback that gets fired if we receive a http request + """Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. If the regex contains groups these gets passed to the callback via @@ -235,8 +238,7 @@ def __init__(self, extract_context=False): self._extract_context = extract_context def render(self, request): - """ This gets called by twisted every time someone sends us a request. - """ + """This gets called by twisted every time someone sends us a request.""" defer.ensureDeferred(self._async_render_wrapper(request)) return NOT_DONE_YET @@ -287,13 +289,18 @@ async def _async_render(self, request: Request): @abc.abstractmethod def _send_response( - self, request: SynapseRequest, code: int, response_object: Any, + self, + request: SynapseRequest, + code: int, + response_object: Any, ) -> None: raise NotImplementedError() @abc.abstractmethod def _send_error_response( - self, f: failure.Failure, request: SynapseRequest, + self, + f: failure.Failure, + request: SynapseRequest, ) -> None: raise NotImplementedError() @@ -308,10 +315,12 @@ def __init__(self, canonical_json=False, extract_context=False): self.canonical_json = canonical_json def _send_response( - self, request: Request, code: int, response_object: Any, + self, + request: Request, + code: int, + response_object: Any, ): - """Implements _AsyncResource._send_response - """ + """Implements _AsyncResource._send_response""" # TODO: Only enable CORS for the requests that need it. respond_with_json( request, @@ -322,15 +331,16 @@ def _send_response( ) def _send_error_response( - self, f: failure.Failure, request: SynapseRequest, + self, + f: failure.Failure, + request: SynapseRequest, ) -> None: - """Implements _AsyncResource._send_error_response - """ + """Implements _AsyncResource._send_error_response""" return_json_error(f, request) class JsonResource(DirectServeJsonResource): - """ This implements the HttpServer interface and provides JSON support for + """This implements the HttpServer interface and provides JSON support for Resources. Register callbacks via register_paths() @@ -443,10 +453,12 @@ class DirectServeHtmlResource(_AsyncResource): ERROR_TEMPLATE = HTML_ERROR_TEMPLATE def _send_response( - self, request: SynapseRequest, code: int, response_object: Any, + self, + request: SynapseRequest, + code: int, + response_object: Any, ): - """Implements _AsyncResource._send_response - """ + """Implements _AsyncResource._send_response""" # We expect to get bytes for us to write assert isinstance(response_object, bytes) html_bytes = response_object @@ -454,10 +466,11 @@ def _send_response( respond_with_html_bytes(request, 200, html_bytes) def _send_error_response( - self, f: failure.Failure, request: SynapseRequest, + self, + f: failure.Failure, + request: SynapseRequest, ) -> None: - """Implements _AsyncResource._send_error_response - """ + """Implements _AsyncResource._send_error_response""" return_html_error(f, request, self.ERROR_TEMPLATE) @@ -534,7 +547,9 @@ class _ByteProducer: min_chunk_size = 1024 def __init__( - self, request: Request, iterator: Iterator[bytes], + self, + request: Request, + iterator: Iterator[bytes], ): self._request = request self._iterator = iterator @@ -654,7 +669,10 @@ def respond_with_json( def respond_with_json_bytes( - request: Request, code: int, json_bytes: bytes, send_cors: bool = False, + request: Request, + code: int, + json_bytes: bytes, + send_cors: bool = False, ): """Sends encoded JSON in response to the given request. @@ -769,7 +787,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None: def finish_request(request: Request): - """ Finish writing the response to the request. + """Finish writing the response to the request. Twisted throws a RuntimeException if the connection closed before the response was written but doesn't provide a convenient or reliable way to diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index b361b7cbaf43..0e637f47016f 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -258,7 +258,7 @@ def assert_params_in_dict(body, required): class RestServlet: - """ A Synapse REST Servlet. + """A Synapse REST Servlet. An implementing class can either provide its own custom 'register' method, or use the automatic pattern handling provided by the base class. diff --git a/synapse/http/site.py b/synapse/http/site.py index 12ec3f851fd3..4a4fb5ef264b 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -249,8 +249,7 @@ def _started_processing(self, servlet_name): ) def _finished_processing(self): - """Log the completion of this request and update the metrics - """ + """Log the completion of this request and update the metrics""" assert self.logcontext is not None usage = self.logcontext.get_resource_usage() @@ -276,7 +275,8 @@ def _finished_processing(self): # authenticated (e.g. and admin is puppetting a user) then we log both. if self.requester.user.to_string() != authenticated_entity: authenticated_entity = "{},{}".format( - authenticated_entity, self.requester.user.to_string(), + authenticated_entity, + self.requester.user.to_string(), ) elif self.requester is not None: # This shouldn't happen, but we log it so we don't lose information @@ -322,8 +322,7 @@ def _finished_processing(self): logger.warning("Failed to stop metrics: %r", e) def _should_log_request(self) -> bool: - """Whether we should log at INFO that we processed the request. - """ + """Whether we should log at INFO that we processed the request.""" if self.path == b"/health": return False diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index fb937b3f2847..f8e9112b56b1 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -174,7 +174,9 @@ def writer(result: Protocol) -> None: # Make a new producer and start it. self._producer = LogProducer( - buffer=self._buffer, transport=result.transport, format=self.format, + buffer=self._buffer, + transport=result.transport, + format=self.format, ) result.transport.registerProducer(self._producer, True) self._producer.resumeProducing() diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 14d9c104c2e0..3e054f615c48 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -60,7 +60,10 @@ def parse_drain_configs( ) # Either use the default formatter or the tersejson one. - if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,): + if logging_type in ( + DrainType.CONSOLE_JSON, + DrainType.FILE_JSON, + ): formatter = "json" # type: Optional[str] elif logging_type in ( DrainType.CONSOLE_JSON_TERSE, @@ -131,7 +134,9 @@ def parse_drain_configs( ) -def setup_structured_logging(log_config: dict,) -> dict: +def setup_structured_logging( + log_config: dict, +) -> dict: """ Convert a legacy structured logging configuration (from Synapse < v1.23.0) to one compatible with the new standard library handlers. diff --git a/synapse/logging/context.py b/synapse/logging/context.py index c2db8b45f3f7..78e27bfb00ed 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -338,7 +338,10 @@ def __enter__(self) -> "LoggingContext": if self.previous_context != old_context: logcontext_error( "Expected previous context %r, found %r" - % (self.previous_context, old_context,) + % ( + self.previous_context, + old_context, + ) ) return self @@ -562,7 +565,7 @@ def filter(self, record: logging.LogRecord) -> Literal[True]: class PreserveLoggingContext: """Context manager which replaces the logging context - The previous logging context is restored on exit.""" + The previous logging context is restored on exit.""" __slots__ = ["_old_context", "_new_context"] @@ -585,7 +588,10 @@ def __exit__(self, type, value, traceback) -> None: else: logcontext_error( "Expected logging context %s but found %s" - % (self._new_context, context,) + % ( + self._new_context, + context, + ) ) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 0538350f38b7..10bd4a14614b 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -238,8 +238,7 @@ class _DummyTagNames: @attr.s(slots=True, frozen=True) class _WrappedRustReporter: - """Wrap the reporter to ensure `report_span` never throws. - """ + """Wrap the reporter to ensure `report_span` never throws.""" _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter)) @@ -326,8 +325,7 @@ def noop_context_manager(*args, **kwargs): def init_tracer(hs: "HomeServer"): - """Set the whitelists and initialise the JaegerClient tracer - """ + """Set the whitelists and initialise the JaegerClient tracer""" global opentracing if not hs.config.opentracer_enabled: # We don't have a tracer @@ -384,7 +382,7 @@ def whitelisted_homeserver(destination): Args: destination (str) - """ + """ if _homeserver_whitelist: return _homeserver_whitelist.match(destination) diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index becf66dd86c8..fd3543ab0428 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -43,8 +43,7 @@ def _log_debug_as_f(f, msg, msg_args): def log_function(f): - """ Function decorator that logs every call to that function. - """ + """Function decorator that logs every call to that function.""" func_name = f.__name__ @wraps(f) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index cbf0dbb871e7..a8cb49d5b4be 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -155,8 +155,7 @@ def register(self, key, callback): self._registrations.setdefault(key, set()).add(callback) def unregister(self, key, callback): - """Registers that we've exited a block with labels `key`. - """ + """Registers that we've exited a block with labels `key`.""" with self._lock: self._registrations.setdefault(key, set()).discard(callback) @@ -402,7 +401,9 @@ def collect(self): # Total time spent in GC: 0.073 # s.total_gc_time pypy_gc_time = CounterMetricFamily( - "pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[], + "pypy_gc_time_seconds_total", + "Total time spent in PyPy GC", + labels=[], ) pypy_gc_time.add_metric([], s.total_gc_time / 1000) yield pypy_gc_time diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index 734271e765ae..71320a140223 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -216,7 +216,7 @@ def log_message(self, format, *args): @classmethod def factory(cls, registry): """Returns a dynamic MetricsHandler class tied - to the passed registry. + to the passed registry. """ # This implementation relies on MetricsHandler.registry # (defined above and defaulted to REGISTRY). diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 70e0fa45d9d1..b56986d8e753 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -208,7 +208,8 @@ async def run(): return await maybe_awaitable(func(*args, **kwargs)) except Exception: logger.exception( - "Background process '%s' threw an exception", desc, + "Background process '%s' threw an exception", + desc, ) finally: _background_process_in_flight_count.labels(desc).dec() @@ -249,8 +250,7 @@ def __init__(self, name: str, request: Optional[str] = None): self._proc = _BackgroundProcess(name, self) def start(self, rusage: "Optional[resource._RUsage]"): - """Log context has started running (again). - """ + """Log context has started running (again).""" super().start(rusage) @@ -261,8 +261,7 @@ def start(self, rusage: "Optional[resource._RUsage]"): _background_processes_active_since_last_scrape.add(self._proc) def __exit__(self, type, value, traceback) -> None: - """Log context has finished. - """ + """Log context has finished.""" super().__exit__(type, value, traceback) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 401d57729377..2e3b311c4a08 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -275,7 +275,9 @@ def complete_sso_login( redirect them directly if whitelisted). """ self._auth_handler._complete_sso_login( - registered_user_id, request, client_redirect_url, + registered_user_id, + request, + client_redirect_url, ) async def complete_sso_login_async( @@ -352,7 +354,10 @@ async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBa event, _, ) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event( - requester, event_dict, ratelimit=False, ignore_shadow_ban=True, + requester, + event_dict, + ratelimit=False, + ignore_shadow_ban=True, ) return event diff --git a/synapse/notifier.py b/synapse/notifier.py index 0745899b480b..1374aae49051 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -75,7 +75,7 @@ def count(func: Callable[[T], bool], it: Iterable[T]) -> int: class _NotificationListener: - """ This represents a single client connection to the events stream. + """This represents a single client connection to the events stream. The events stream handler will have yielded to the deferred, so to notify the handler it is sufficient to resolve the deferred. """ @@ -119,7 +119,10 @@ def __init__( self.notify_deferred = ObservableDeferred(defer.Deferred()) def notify( - self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int, + self, + stream_key: str, + stream_id: Union[int, RoomStreamToken], + time_now_ms: int, ): """Notify any listeners for this user of a new event from an event source. @@ -140,7 +143,7 @@ def notify( noify_deferred.callback(self.current_token) def remove(self, notifier: "Notifier"): - """ Remove this listener from all the indexes in the Notifier + """Remove this listener from all the indexes in the Notifier it knows about. """ @@ -186,7 +189,7 @@ class _PendingRoomEventEntry: class Notifier: - """ This class is responsible for notifying any listeners when there are + """This class is responsible for notifying any listeners when there are new events available for it. Primarily used from the /events stream. @@ -265,8 +268,7 @@ def on_new_room_event( max_room_stream_token: RoomStreamToken, extra_users: Collection[UserID] = [], ): - """Unwraps event and calls `on_new_room_event_args`. - """ + """Unwraps event and calls `on_new_room_event_args`.""" self.on_new_room_event_args( event_pos=event_pos, room_id=event.room_id, @@ -341,7 +343,10 @@ def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken if users or rooms: self.on_new_event( - "room_key", max_room_stream_token, users=users, rooms=rooms, + "room_key", + max_room_stream_token, + users=users, + rooms=rooms, ) self._on_updated_room_token(max_room_stream_token) @@ -392,7 +397,7 @@ def on_new_event( users: Collection[Union[str, UserID]] = [], rooms: Collection[str] = [], ): - """ Used to inform listeners that something has happened event wise. + """Used to inform listeners that something has happened event wise. Will wake up all listeners for the given users and rooms. """ @@ -418,7 +423,9 @@ def on_new_event( # Notify appservices self._notify_app_services_ephemeral( - stream_key, new_token, users, + stream_key, + new_token, + users, ) def on_new_replication_data(self) -> None: @@ -502,7 +509,7 @@ async def get_events_for( is_guest: bool = False, explicit_room_id: str = None, ) -> EventStreamResult: - """ For the given user and rooms, return any new events for them. If + """For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. @@ -651,8 +658,7 @@ def notify_replication(self) -> None: cb() def notify_remote_server_up(self, server: str): - """Notify any replication that a remote server has come back up - """ + """Notify any replication that a remote server has come back up""" # We call federation_sender directly rather than registering as a # callback as a) we already have a reference to it and b) it introduces # circular dependencies. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 6317f22d3cea..c016a83909cd 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -144,8 +144,7 @@ async def _get_rules_for_event( @lru_cache() def _get_rules_for_room(self, room_id: str) -> "RulesForRoom": - """Get the current RulesForRoom object for the given room id - """ + """Get the current RulesForRoom object for the given room id""" # It's important that RulesForRoom gets added to self._get_rules_for_room.cache # before any lookup methods get called on it as otherwise there may be # a race if invalidate_all gets called (which assumes its in the cache) @@ -252,7 +251,9 @@ async def action_for_event_by_user( # notified for this event. (This will then get handled when we persist # the event) await self.store.add_push_actions_to_staging( - event.event_id, actions_by_user, count_as_unread, + event.event_id, + actions_by_user, + count_as_unread, ) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 4ac1b3174832..5fec2aaf5dc8 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -116,8 +116,7 @@ def _pause_processing(self) -> None: self._is_processing = True def _resume_processing(self) -> None: - """Used by tests to resume processing of events after pausing. - """ + """Used by tests to resume processing of events after pausing.""" assert self._is_processing self._is_processing = False self._start_processing() @@ -157,8 +156,10 @@ async def _unsafe_process(self) -> None: being run. """ start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering - unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email( - self.user_id, start, self.max_stream_ordering + unprocessed = ( + await self.store.get_unread_push_actions_for_user_in_range_for_email( + self.user_id, start, self.max_stream_ordering + ) ) soonest_due_at = None # type: Optional[int] @@ -222,12 +223,14 @@ async def save_last_stream_ordering_and_success( self, last_stream_ordering: int ) -> None: self.last_stream_ordering = last_stream_ordering - pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.email, - self.user_id, - last_stream_ordering, - self.clock.time_msec(), + pusher_still_exists = ( + await self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.email, + self.user_id, + last_stream_ordering, + self.clock.time_msec(), + ) ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so @@ -298,7 +301,8 @@ async def sent_notif_update_throttle( current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS ) self.throttle_params[room_id] = ThrottleParams( - self.clock.time_msec(), new_throttle_ms, + self.clock.time_msec(), + new_throttle_ms, ) assert self.pusher_id is not None await self.store.set_throttle_params( diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index e048b0d59ee1..b9d3da2e0a56 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -176,8 +176,10 @@ async def _unsafe_process(self) -> None: Never call this directly: use _process which will only allow this to run once per pusher. """ - unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http( - self.user_id, self.last_stream_ordering, self.max_stream_ordering + unprocessed = ( + await self.store.get_unread_push_actions_for_user_in_range_for_http( + self.user_id, self.last_stream_ordering, self.max_stream_ordering + ) ) logger.info( @@ -204,12 +206,14 @@ async def _unsafe_process(self) -> None: http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.pushkey, - self.user_id, - self.last_stream_ordering, - self.clock.time_msec(), + pusher_still_exists = ( + await self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.pushkey, + self.user_id, + self.last_stream_ordering, + self.clock.time_msec(), + ) ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so @@ -290,7 +294,8 @@ async def _process_one(self, push_action: dict) -> bool: # for sanity, we only remove the pushkey if it # was the one we actually sent... logger.warning( - ("Ignoring rejected pushkey %s because we didn't send it"), pk, + ("Ignoring rejected pushkey %s because we didn't send it"), + pk, ) else: logger.info("Pushkey %s was rejected: removing", pk) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index eed16dbfb5c1..ae1145be0e42 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -78,8 +78,7 @@ def __init__(self, hs: "HomeServer"): self.pushers = {} # type: Dict[str, Dict[str, Pusher]] def start(self) -> None: - """Starts the pushers off in a background process. - """ + """Starts the pushers off in a background process.""" if not self._should_start_pushers: logger.info("Not starting pushers because they are disabled in the config") return @@ -297,8 +296,7 @@ async def start_pusher_by_id( return pusher async def _start_pushers(self) -> None: - """Start all the pushers - """ + """Start all the pushers""" pushers = await self.store.get_all_pushers() # Stagger starting up the pushers so we don't completely drown the @@ -335,7 +333,8 @@ async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]: return None except Exception: logger.exception( - "Couldn't start pusher id %i: caught Exception", pusher_config.id, + "Couldn't start pusher id %i: caught Exception", + pusher_config.id, ) return None diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 288727a56605..8a3f113e7640 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -273,7 +273,10 @@ def register(self, http_server): pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) http_server.register_paths( - method, [pattern], self._check_auth_and_handle, self.__class__.__name__, + method, + [pattern], + self._check_auth_and_handle, + self.__class__.__name__, ) def _check_auth_and_handle(self, request, **kwargs): diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 52d32528ee17..60899b6ad622 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -175,7 +175,11 @@ async def _serialize_payload(user_id, room_id, tag): return {} async def _handle_request(self, request, user_id, room_id, tag): - max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,) + max_stream_id = await self.handler.remove_tag_from_room( + user_id, + room_id, + tag, + ) return 200, {"max_stream_id": max_stream_id} diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 84e002f93448..439881be6716 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -160,7 +160,10 @@ async def _handle_request( # type: ignore # hopefully we're now on the master, so this won't recurse! event_id, stream_id = await self.member_handler.remote_reject_invite( - invite_event_id, txn_id, requester, event_content, + invite_event_id, + txn_id, + requester, + event_content, ) return 200, {"event_id": event_id, "stream_id": stream_id} diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 7b12ec906025..d005f3876717 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -22,8 +22,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): - """Register a new user - """ + """Register a new user""" NAME = "register_user" PATH_ARGS = ("user_id",) @@ -97,8 +96,7 @@ async def _handle_request(self, request, user_id): class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): - """Run any post registration actions - """ + """Run any post registration actions""" NAME = "post_register" PATH_ARGS = ("user_id",) diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index ac532ed5887f..0a9da79c32a3 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -196,8 +196,7 @@ class ErrorCommand(_SimpleCommand): class PingCommand(_SimpleCommand): - """Sent by either side as a keep alive. The data is arbitrary (often timestamp) - """ + """Sent by either side as a keep alive. The data is arbitrary (often timestamp)""" NAME = "PING" diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py index 34fa3ff5b3c4..d89a36f25a59 100644 --- a/synapse/replication/tcp/external_cache.py +++ b/synapse/replication/tcp/external_cache.py @@ -60,8 +60,7 @@ def is_enabled(self) -> bool: return self._redis_connection is not None async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None: - """Add the key/value to the named cache, with the expiry time given. - """ + """Add the key/value to the named cache, with the expiry time given.""" if self._redis_connection is None: return @@ -76,13 +75,14 @@ async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> No return await make_deferred_yieldable( self._redis_connection.set( - self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms, + self._get_redis_key(cache_name, key), + encoded_value, + pexpire=expiry_ms, ) ) async def get(self, cache_name: str, key: str) -> Optional[Any]: - """Look up a key/value in the named cache. - """ + """Look up a key/value in the named cache.""" if self._redis_connection is None: return None diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 8ea8dcd587c5..d1d00c371769 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -303,7 +303,9 @@ def start_replication(self, hs): hs, outbound_redis_connection ) hs.get_reactor().connectTCP( - hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, + hs.config.redis.redis_host, + hs.config.redis.redis_port, + self._factory, ) else: client_name = hs.get_instance_name() @@ -313,13 +315,11 @@ def start_replication(self, hs): hs.get_reactor().connectTCP(host, port, self._factory) def get_streams(self) -> Dict[str, Stream]: - """Get a map from stream name to all streams. - """ + """Get a map from stream name to all streams.""" return self._streams def get_streams_to_replicate(self) -> List[Stream]: - """Get a list of streams that this instances replicates. - """ + """Get a list of streams that this instances replicates.""" return self._streams_to_replicate def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): @@ -340,7 +340,10 @@ def send_positions_to_connection(self, conn: AbstractConnection): current_token = stream.current_token(self._instance_name) self.send_command( PositionCommand( - stream.NAME, self._instance_name, current_token, current_token, + stream.NAME, + self._instance_name, + current_token, + current_token, ) ) @@ -592,8 +595,7 @@ def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpComma self.send_command(cmd, ignore_conn=conn) def new_connection(self, connection: AbstractConnection): - """Called when we have a new connection. - """ + """Called when we have a new connection.""" self._connections.append(connection) # If we are connected to replication as a client (rather than a server) @@ -620,8 +622,7 @@ def new_connection(self, connection: AbstractConnection): ) def lost_connection(self, connection: AbstractConnection): - """Called when a connection is closed/lost. - """ + """Called when a connection is closed/lost.""" # we no longer need _streams_by_connection for this connection. streams = self._streams_by_connection.pop(connection, None) if streams: @@ -678,15 +679,13 @@ def send_federation_ack(self, token: int): def send_user_sync( self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int ): - """Poke the master that a user has started/stopped syncing. - """ + """Poke the master that a user has started/stopped syncing.""" self.send_command( UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) ) def send_remove_pusher(self, app_id: str, push_key: str, user_id: str): - """Poke the master to remove a pusher for a user - """ + """Poke the master to remove a pusher for a user""" cmd = RemovePusherCommand(app_id, push_key, user_id) self.send_command(cmd) @@ -699,8 +698,7 @@ def send_user_ip( device_id: str, last_seen: int, ): - """Tell the master that the user made a request. - """ + """Tell the master that the user made a request.""" cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) self.send_command(cmd) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 804da994eab3..e0b4ad314dfe 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -222,8 +222,7 @@ def send_ping(self): self.send_error("ping timeout") def lineReceived(self, line: bytes): - """Called when we've received a line - """ + """Called when we've received a line""" with PreserveLoggingContext(self._logging_context): self._parse_and_dispatch_line(line) @@ -299,8 +298,7 @@ def close(self): self.on_connection_closed() def send_error(self, error_string, *args): - """Send an error to remote and close the connection. - """ + """Send an error to remote and close the connection.""" self.send_command(ErrorCommand(error_string % args)) self.close() @@ -341,8 +339,7 @@ def send_command(self, cmd, do_buffer=True): self.last_sent_command = self.clock.time_msec() def _queue_command(self, cmd): - """Queue the command until the connection is ready to write to again. - """ + """Queue the command until the connection is ready to write to again.""" logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) @@ -355,8 +352,7 @@ def _queue_command(self, cmd): self.close() def _send_pending_commands(self): - """Send any queued commandes - """ + """Send any queued commandes""" pending = self.pending_commands self.pending_commands = [] for cmd in pending: @@ -380,8 +376,7 @@ def pauseProducing(self): self.state = ConnectionStates.PAUSED def resumeProducing(self): - """The remote has caught up after we started buffering! - """ + """The remote has caught up after we started buffering!""" logger.info("[%s] Resume producing", self.id()) self.state = ConnectionStates.ESTABLISHED self._send_pending_commands() @@ -440,8 +435,7 @@ def id(self): return "%s-%s" % (self.name, self.conn_id) def lineLengthExceeded(self, line): - """Called when we receive a line that is above the maximum line length - """ + """Called when we receive a line that is above the maximum line length""" self.send_error("Line length exceeded") @@ -495,21 +489,18 @@ def on_SERVER(self, cmd): self.send_error("Wrong remote") def replicate(self): - """Send the subscription request to the server - """ + """Send the subscription request to the server""" logger.info("[%s] Subscribing to replication streams", self.id()) self.send_command(ReplicateCommand()) class AbstractConnection(abc.ABC): - """An interface for replication connections. - """ + """An interface for replication connections.""" @abc.abstractmethod def send_command(self, cmd: Command): - """Send the command down the connection - """ + """Send the command down the connection""" pass diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 89f8af0f364b..0e6155cf530a 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -123,8 +123,7 @@ async def _send_subscribe(self): self.synapse_handler.send_positions_to_connection(self) def messageReceived(self, pattern: str, channel: str, message: str): - """Received a message from redis. - """ + """Received a message from redis.""" with PreserveLoggingContext(self._logging_context): self._parse_and_dispatch_message(message) @@ -137,7 +136,8 @@ def _parse_and_dispatch_message(self, message: str): cmd = parse_command_from_line(message) except Exception: logger.exception( - "Failed to parse replication line: %r", message, + "Failed to parse replication line: %r", + message, ) return diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 1d4ceac0f123..2018f9f29ed5 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -36,8 +36,7 @@ class ReplicationStreamProtocolFactory(Factory): - """Factory for new replication connections. - """ + """Factory for new replication connections.""" def __init__(self, hs): self.command_handler = hs.get_tcp_replication() @@ -181,7 +180,8 @@ async def _run_notifier_loop(self): raise logger.debug( - "Sending %d updates", len(updates), + "Sending %d updates", + len(updates), ) if updates: diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 61b282ab2dab..38809b5b7c7c 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -183,7 +183,10 @@ async def get_updates_since( return [], upto_token, False updates, upto_token, limited = await self.update_function( - instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, + instance_name, + from_token, + upto_token, + _STREAM_UPDATE_TARGET_ROW_COUNT, ) return updates, upto_token, limited @@ -339,8 +342,7 @@ def __init__(self, hs): class PushRulesStream(Stream): - """A user has changed their push rules - """ + """A user has changed their push rules""" PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str @@ -362,8 +364,7 @@ def _current_token(self, instance_name: str) -> int: class PushersStream(Stream): - """A user has added/changed/removed a pusher - """ + """A user has added/changed/removed a pusher""" PushersStreamRow = namedtuple( "PushersStreamRow", @@ -416,8 +417,7 @@ def __init__(self, hs): class PublicRoomsStream(Stream): - """The public rooms list changed - """ + """The public rooms list changed""" PublicRoomsStreamRow = namedtuple( "PublicRoomsStreamRow", @@ -463,8 +463,7 @@ def __init__(self, hs): class ToDeviceStream(Stream): - """New to_device messages for a client - """ + """New to_device messages for a client""" ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str @@ -481,8 +480,7 @@ def __init__(self, hs): class TagAccountDataStream(Stream): - """Someone added/removed a tag for a room - """ + """Someone added/removed a tag for a room""" TagAccountDataStreamRow = namedtuple( "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict @@ -501,8 +499,7 @@ def __init__(self, hs): class AccountDataStream(Stream): - """Global or per room account data was changed - """ + """Global or per room account data was changed""" AccountDataStreamRow = namedtuple( "AccountDataStream", @@ -589,8 +586,7 @@ def __init__(self, hs): class UserSignatureStream(Stream): - """A user has signed their own device with their user-signing key - """ + """A user has signed their own device with their user-signing key""" UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 86a62b71eb87..fa5e37ba7bc5 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -113,8 +113,7 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow): class EventsStream(Stream): - """We received a new event, or an event went from being an outlier to not - """ + """We received a new event, or an event went from being an outlier to not""" NAME = "events" diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py index d0c86b204a1c..ebc587aa0603 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py @@ -22,8 +22,7 @@ class DeleteGroupAdminRestServlet(RestServlet): - """Allows deleting of local groups - """ + """Allows deleting of local groups""" PATTERNS = admin_patterns("/delete_group/(?P[^/]*)") diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 8720b1401fbc..b996862c0525 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -119,8 +119,7 @@ async def on_POST( class ProtectMediaByID(RestServlet): - """Protect local media from being quarantined. - """ + """Protect local media from being quarantined.""" PATTERNS = admin_patterns("/media/protect/(?P[^/]+)") @@ -141,8 +140,7 @@ async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict] class ListMediaInRoom(RestServlet): - """Lists all of the media in a given room. - """ + """Lists all of the media in a given room.""" PATTERNS = admin_patterns("/room/(?P[^/]+)/media") @@ -180,8 +178,7 @@ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: class DeleteMediaByID(RestServlet): - """Delete local media by a given ID. Removes it from this server. - """ + """Delete local media by a given ID. Removes it from this server.""" PATTERNS = admin_patterns("/media/(?P[^/]+)/(?P[^/]+)") diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 491f9ca09578..1a3a36f6cf0c 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -482,7 +482,8 @@ async def on_POST(self, request, room_identifier): if not admin_user_id: raise SynapseError( - 400, "No local admin user in room", + 400, + "No local admin user in room", ) pl_content = power_levels.content @@ -492,7 +493,8 @@ async def on_POST(self, request, room_identifier): admin_user_id = create_event.sender if not self.is_mine_id(admin_user_id): raise SynapseError( - 400, "No local admin user in room", + 400, + "No local admin user in room", ) # Grant the user power equal to the room admin by attempting to send an @@ -502,7 +504,8 @@ async def on_POST(self, request, room_identifier): new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id] fake_requester = create_requester( - admin_user_id, authenticated_entity=requester.authenticated_entity, + admin_user_id, + authenticated_entity=requester.authenticated_entity, ) try: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 9350c704b981..998a0ef671cc 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -579,7 +579,7 @@ class ResetPasswordRestServlet(RestServlet): } Returns: 200 OK with empty object if success otherwise an error. - """ + """ PATTERNS = admin_patterns("/reset_password/(?P[^/]*)") diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 0fb9419e58ad..6e2fbedd99bf 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -310,7 +310,9 @@ async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: except jwt.PyJWTError as e: # A JWT error occurred, return some info back to the client. raise LoginError( - 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN, + 403, + "JWT validation failed: %s" % (str(e),), + errcode=Codes.FORBIDDEN, ) user = payload.get("sub", None) @@ -375,7 +377,9 @@ async def on_GET( request, "redirectUrl", required=True, encoding=None ) sso_url = await self._sso_handler.handle_redirect_request( - request, client_redirect_url, idp_id, + request, + client_redirect_url, + idp_id, ) logger.info("Redirecting to %s", sso_url) request.redirect(sso_url) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 85a66458c5bb..717c5f2b108a 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -60,7 +60,9 @@ async def on_PUT(self, request, user_id): new_name = content["displayname"] except Exception: raise SynapseError( - code=400, msg="Unable to parse name", errcode=Codes.BAD_JSON, + code=400, + msg="Unable to parse name", + errcode=Codes.BAD_JSON, ) await self.profile_handler.set_displayname(user, requester, new_name, is_admin) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 89823fcc39a5..0c148a213db4 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -159,7 +159,9 @@ async def on_GET(self, request): self.notifier.on_new_replication_data() respond_with_html_bytes( - request, 200, PushersRemoveRestServlet.SUCCESS_HTML, + request, + 200, + PushersRemoveRestServlet.SUCCESS_HTML, ) return None diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 90fd98c53e2f..9a1df30c2999 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -362,7 +362,9 @@ async def on_GET(self, request): parse_and_validate_server_name(server) except ValueError: raise SynapseError( - 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM, + 400, + "Invalid server name: %s" % (server,), + Codes.INVALID_PARAM, ) try: @@ -413,7 +415,9 @@ async def on_POST(self, request): parse_and_validate_server_name(server) except ValueError: raise SynapseError( - 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM, + 400, + "Invalid server name: %s" % (server,), + Codes.INVALID_PARAM, ) try: diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index a84a2fb3855d..adf1d397282f 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -193,7 +193,10 @@ async def on_POST(self, request): requester = await self.auth.get_user_by_req(request) try: params, session_id = await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "modify your account password", + requester, + request, + body, + "modify your account password", ) except InteractiveAuthIncompleteError as e: # The user needs to provide more steps to complete auth, but @@ -312,7 +315,10 @@ async def on_POST(self, request): return 200, {} await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "deactivate your account", + requester, + request, + body, + "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), @@ -703,7 +709,10 @@ async def on_POST(self, request): assert_valid_client_secret(client_secret) await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "add a third-party identifier to your account", + requester, + request, + body, + "add a third-party identifier to your account", ) validation_session = await self.identity_handler.validate_threepid_session( diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 314e01dfe443..3d07aadd39ba 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -83,7 +83,10 @@ async def on_POST(self, request): assert_params_in_dict(body, ["devices"]) await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "remove device(s) from your account", + requester, + request, + body, + "remove device(s) from your account", ) await self.device_handler.delete_devices( @@ -129,7 +132,10 @@ async def on_DELETE(self, request, device_id): raise await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "remove a device from your account", + requester, + request, + body, + "remove a device from your account", ) await self.device_handler.delete_device(requester.user.to_string(), device_id) @@ -206,7 +212,9 @@ async def on_PUT(self, request: SynapseRequest): if "device_data" not in submission: raise errors.SynapseError( - 400, "device_data missing", errcode=errors.Codes.MISSING_PARAM, + 400, + "device_data missing", + errcode=errors.Codes.MISSING_PARAM, ) elif not isinstance(submission["device_data"], dict): raise errors.SynapseError( @@ -259,11 +267,15 @@ async def on_POST(self, request: SynapseRequest): if "device_id" not in submission: raise errors.SynapseError( - 400, "device_id missing", errcode=errors.Codes.MISSING_PARAM, + 400, + "device_id missing", + errcode=errors.Codes.MISSING_PARAM, ) elif not isinstance(submission["device_id"], str): raise errors.SynapseError( - 400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM, + 400, + "device_id must be a string", + errcode=errors.Codes.INVALID_PARAM, ) result = await self.device_handler.rehydrate_device( diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 4fe712b30cc9..7cbfae84266e 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -54,8 +54,7 @@ def wrapper(self, request: Request, group_id: str, *args, **kwargs): class GroupServlet(RestServlet): - """Get the group profile - """ + """Get the group profile""" PATTERNS = client_patterns("/groups/(?P[^/]*)/profile$") @@ -94,8 +93,7 @@ async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict] class GroupSummaryServlet(RestServlet): - """Get the full group summary - """ + """Get the full group summary""" PATTERNS = client_patterns("/groups/(?P[^/]*)/summary$") @@ -172,8 +170,7 @@ async def on_DELETE( class GroupCategoryServlet(RestServlet): - """Get/add/update/delete a group category - """ + """Get/add/update/delete a group category""" PATTERNS = client_patterns( "/groups/(?P[^/]*)/categories/(?P[^/]+)$" @@ -229,8 +226,7 @@ async def on_DELETE( class GroupCategoriesServlet(RestServlet): - """Get all group categories - """ + """Get all group categories""" PATTERNS = client_patterns("/groups/(?P[^/]*)/categories/$") @@ -253,8 +249,7 @@ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: class GroupRoleServlet(RestServlet): - """Get/add/update/delete a group role - """ + """Get/add/update/delete a group role""" PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/(?P[^/]+)$") @@ -308,8 +303,7 @@ async def on_DELETE( class GroupRolesServlet(RestServlet): - """Get all group roles - """ + """Get all group roles""" PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/$") @@ -386,8 +380,7 @@ async def on_DELETE( class GroupRoomServlet(RestServlet): - """Get all rooms in a group - """ + """Get all rooms in a group""" PATTERNS = client_patterns("/groups/(?P[^/]*)/rooms$") @@ -410,8 +403,7 @@ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: class GroupUsersServlet(RestServlet): - """Get all users in a group - """ + """Get all users in a group""" PATTERNS = client_patterns("/groups/(?P[^/]*)/users$") @@ -434,8 +426,7 @@ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: class GroupInvitedUsersServlet(RestServlet): - """Get users invited to a group - """ + """Get users invited to a group""" PATTERNS = client_patterns("/groups/(?P[^/]*)/invited_users$") @@ -458,8 +449,7 @@ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: class GroupSettingJoinPolicyServlet(RestServlet): - """Set group join policy - """ + """Set group join policy""" PATTERNS = client_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") @@ -484,8 +474,7 @@ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: class GroupCreateServlet(RestServlet): - """Create a group - """ + """Create a group""" PATTERNS = client_patterns("/create_group$") @@ -514,8 +503,7 @@ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: class GroupAdminRoomsServlet(RestServlet): - """Add a room to the group - """ + """Add a room to the group""" PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)$" @@ -558,8 +546,7 @@ async def on_DELETE( class GroupAdminRoomsConfigServlet(RestServlet): - """Update the config of a room in a group - """ + """Update the config of a room in a group""" PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)" @@ -589,8 +576,7 @@ async def on_PUT( class GroupAdminUsersInviteServlet(RestServlet): - """Invite a user to the group - """ + """Invite a user to the group""" PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/users/invite/(?P[^/]*)$" @@ -620,8 +606,7 @@ async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDi class GroupAdminUsersKickServlet(RestServlet): - """Kick a user from the group - """ + """Kick a user from the group""" PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/users/remove/(?P[^/]*)$" @@ -648,8 +633,7 @@ async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDi class GroupSelfLeaveServlet(RestServlet): - """Leave a joined group - """ + """Leave a joined group""" PATTERNS = client_patterns("/groups/(?P[^/]*)/self/leave$") @@ -674,8 +658,7 @@ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: class GroupSelfJoinServlet(RestServlet): - """Attempt to join a group, or knock - """ + """Attempt to join a group, or knock""" PATTERNS = client_patterns("/groups/(?P[^/]*)/self/join$") @@ -700,8 +683,7 @@ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: class GroupSelfAcceptInviteServlet(RestServlet): - """Accept a group invite - """ + """Accept a group invite""" PATTERNS = client_patterns("/groups/(?P[^/]*)/self/accept_invite$") @@ -726,8 +708,7 @@ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: class GroupSelfUpdatePublicityServlet(RestServlet): - """Update whether we publicise a users membership of a group - """ + """Update whether we publicise a users membership of a group""" PATTERNS = client_patterns("/groups/(?P[^/]*)/self/update_publicity$") @@ -750,8 +731,7 @@ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: class PublicisedGroupsForUserServlet(RestServlet): - """Get the list of groups a user is advertising - """ + """Get the list of groups a user is advertising""" PATTERNS = client_patterns("/publicised_groups/(?P[^/]*)$") @@ -771,8 +751,7 @@ async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: class PublicisedGroupsForUsersServlet(RestServlet): - """Get the list of groups a user is advertising - """ + """Get the list of groups a user is advertising""" PATTERNS = client_patterns("/publicised_groups$") @@ -795,8 +774,7 @@ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: class GroupsForUserServlet(RestServlet): - """Get all groups the logged in user is joined to - """ + """Get all groups the logged in user is joined to""" PATTERNS = client_patterns("/joined_groups$") diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index a6134ead8a70..f092e5b3a2de 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -271,7 +271,10 @@ async def on_POST(self, request): body = parse_json_object_from_request(request) await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "add a device signing key to your account", + requester, + request, + body, + "add a device signing key to your account", ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index e3d322f2ac3a..8f68d8dfc8fa 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -522,7 +522,10 @@ async def on_POST(self, request): # not this will raise a user-interactive auth error. try: auth_result, params, session_id = await self.auth_handler.check_ui_auth( - self._registration_flows, request, body, "register a new account", + self._registration_flows, + request, + body, + "register a new account", ) except InteractiveAuthIncompleteError as e: # The user needs to provide more steps to complete auth. @@ -665,7 +668,9 @@ async def _do_appservice_registration(self, username, as_token, body): username, as_token ) return await self._create_registration_details( - user_id, body, is_appservice_ghost=True, + user_id, + body, + is_appservice_ghost=True, ) async def _create_registration_details( diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 18c75738f87d..fe765da23c5b 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -244,7 +244,9 @@ async def on_GET( requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string(), allow_departed_users=True, + room_id, + requester.user.to_string(), + allow_departed_users=True, ) # This checks that a) the event exists and b) the user is allowed to @@ -322,7 +324,9 @@ async def on_GET(self, request, room_id, parent_id, relation_type, event_type, k requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string(), allow_departed_users=True, + room_id, + requester.user.to_string(), + allow_departed_users=True, ) # This checks that a) the event exists and b) the user is allowed to diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 3ed219ae43ff..48f4433155fb 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -51,7 +51,8 @@ async def _async_render_GET(self, request: Request) -> None: b" object-src 'self';", ) request.setHeader( - b"Referrer-Policy", b"no-referrer", + b"Referrer-Policy", + b"no-referrer", ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 635bccf77565..a0162d4255a3 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -325,7 +325,10 @@ async def _get_remote_media_impl( # Failed to find the file anywhere, lets download it. try: - media_info = await self._download_remote_file(server_name, media_id,) + media_info = await self._download_remote_file( + server_name, + media_id, + ) except SynapseError: raise except Exception as e: @@ -351,7 +354,11 @@ async def _get_remote_media_impl( responder = await self.media_storage.fetch_media(file_info) return responder, media_info - async def _download_remote_file(self, server_name: str, media_id: str,) -> dict: + async def _download_remote_file( + self, + server_name: str, + media_id: str, + ) -> dict: """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -773,7 +780,11 @@ async def _generate_thumbnails( ) except Exception as e: thumbnail_exists = await self.store.get_remote_media_thumbnail( - server_name, media_id, t_width, t_height, t_type, + server_name, + media_id, + t_width, + t_height, + t_type, ) if not thumbnail_exists: raise e @@ -832,7 +843,10 @@ async def delete_local_media(self, media_id: str) -> Tuple[List[str], int]: return await self._remove_local_media_from_disk([media_id]) async def delete_old_local_media( - self, before_ts: int, size_gt: int = 0, keep_profiles: bool = True, + self, + before_ts: int, + size_gt: int = 0, + keep_profiles: bool = True, ) -> Tuple[List[str], int]: """ Delete local or remote media from this server by size and timestamp. Removes @@ -849,7 +863,9 @@ async def delete_old_local_media( A tuple of (list of deleted media IDs, total deleted media IDs). """ old_media = await self.store.get_local_media_before( - before_ts, size_gt, keep_profiles, + before_ts, + size_gt, + keep_profiles, ) return await self._remove_local_media_from_disk(old_media) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index aba6d689a800..1057e638beba 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -85,8 +85,7 @@ async def store_file(self, source: IO, file_info: FileInfo) -> str: return fname async def write_to_file(self, source: IO, output: IO): - """Asynchronously write the `source` to `output`. - """ + """Asynchronously write the `source` to `output`.""" await defer_to_thread(self.reactor, _write_file_synchronously, source, output) @contextlib.contextmanager @@ -342,8 +341,7 @@ class ReadableFileWrapper: path = attr.ib(type=str) async def write_chunks_to(self, callback: Callable[[bytes], None]): - """Reads the file in chunks and calls the callback with each chunk. - """ + """Reads the file in chunks and calls the callback with each chunk.""" with open(self.path, "rb") as file: while True: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index ae53b1d23ff7..6104ef4e46f6 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -580,8 +580,7 @@ def _start_expire_url_cache_data(self): ) async def _expire_url_cache_data(self) -> None: - """Clean up expired url cache content, media and thumbnails. - """ + """Clean up expired url cache content, media and thumbnails.""" # TODO: Delete from backup media store assert self._worker_run_media_background_jobs diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 8dd01fce769d..665245134614 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -28,7 +28,7 @@ class ResourceLimitsServerNotices: - """ Keeps track of whether the server has reached it's resource limit and + """Keeps track of whether the server has reached it's resource limit and ensures that the client is kept up to date. """ diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 28544ccb9282..c3d6e80c49f7 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -398,7 +398,7 @@ async def compute_event_context( async def resolve_state_groups_for_events( self, room_id: str, event_ids: Iterable[str] ) -> _StateCacheEntry: - """ Given a list of event_ids this method fetches the state at each + """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. Args: @@ -570,7 +570,9 @@ async def resolve_state_groups( return cache logger.info( - "Resolving state for %s with groups %s", room_id, list(group_names), + "Resolving state for %s with groups %s", + room_id, + list(group_names), ) state_groups_histogram.observe(len(state_groups_ids)) @@ -656,11 +658,15 @@ def _report_metrics(self): return self._report_biggest( - lambda i: i.cpu_time, "CPU time", _biggest_room_by_cpu_counter, + lambda i: i.cpu_time, + "CPU time", + _biggest_room_by_cpu_counter, ) self._report_biggest( - lambda i: i.db_time, "DB time", _biggest_room_by_db_counter, + lambda i: i.db_time, + "DB time", + _biggest_room_by_db_counter, ) self._state_res_metrics.clear() diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 85edae053dfe..ce255da6fd07 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -95,7 +95,11 @@ async def resolve_events_with_store( if event.room_id != room_id: raise Exception( "Attempting to state-resolve for room %s with event %s which is in %s" - % (room_id, event.event_id, event.room_id,) + % ( + room_id, + event.event_id, + event.room_id, + ) ) # get the ids of the auth events which allow us to authenticate the @@ -119,7 +123,11 @@ async def resolve_events_with_store( if event.room_id != room_id: raise Exception( "Attempting to state-resolve for room %s with event %s which is in %s" - % (room_id, event.event_id, event.room_id,) + % ( + room_id, + event.event_id, + event.room_id, + ) ) state_map.update(state_map_new) @@ -243,7 +251,7 @@ def _resolve_with_state( def _resolve_state_events( conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase] ) -> StateMap[EventBase]: - """ This is where we actually decide which of the conflicted state to + """This is where we actually decide which of the conflicted state to use. We resolve conflicts in the following order: diff --git a/synapse/state/v2.py b/synapse/state/v2.py index e585954bd883..e73a548ee412 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -118,7 +118,11 @@ async def resolve_events_with_store( if event.room_id != room_id: raise Exception( "Attempting to state-resolve for room %s with event %s which is in %s" - % (room_id, event.event_id, event.room_id,) + % ( + room_id, + event.event_id, + event.room_id, + ) ) full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map} diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index c0d9d1240f16..a3c52695e984 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -43,8 +43,7 @@ class Storage: - """The high level interfaces for talking to various storage layers. - """ + """The high level interfaces for talking to various storage layers.""" def __init__(self, hs: "HomeServer", stores: Databases): # We include the main data store here mainly so that we don't have to diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 29b8ca676a9b..329660cf0faf 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -77,7 +77,7 @@ def total_items_per_ms(self) -> Optional[float]: class BackgroundUpdater: - """ Background updates are updates to the database that run in the + """Background updates are updates to the database that run in the background. Each update processes a batch of data at once. We attempt to limit the impact of each update by monitoring how long each batch takes to process and autotuning the batch size. @@ -158,8 +158,7 @@ async def has_completed_background_updates(self) -> bool: return False async def has_completed_background_update(self, update_name: str) -> bool: - """Check if the given background update has finished running. - """ + """Check if the given background update has finished running.""" if self._all_done: return True @@ -198,7 +197,8 @@ def get_background_updates_txn(txn): if not self._current_background_update: all_pending_updates = await self.db_pool.runInteraction( - "background_updates", get_background_updates_txn, + "background_updates", + get_background_updates_txn, ) if not all_pending_updates: # no work left to do diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ae4bf1a54fcf..464692644900 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -85,8 +85,7 @@ def make_pool( reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine ) -> adbapi.ConnectionPool: - """Get the connection pool for the database. - """ + """Get the connection pool for the database.""" # By default enable `cp_reconnect`. We need to fiddle with db_args in case # someone has explicitly set `cp_reconnect`. @@ -432,8 +431,7 @@ def get_chain_id_txn(txn): ) def is_running(self) -> bool: - """Is the database pool currently running - """ + """Is the database pool currently running""" return self._db_pool.running async def _check_safe_to_upsert(self) -> None: @@ -546,7 +544,11 @@ def new_transaction( # This can happen if the database disappears mid # transaction. transaction_logger.warning( - "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N, + "[TXN OPERROR] {%s} %s %d/%d", + name, + e, + i, + N, ) if i < N: i += 1 @@ -567,7 +569,9 @@ def new_transaction( conn.rollback() except self.engine.module.Error as e1: transaction_logger.warning( - "[TXN EROLL] {%s} %s", name, e1, + "[TXN EROLL] {%s} %s", + name, + e1, ) continue raise @@ -1406,7 +1410,10 @@ def simple_select_one_onecol_txn( @staticmethod def simple_select_onecol_txn( - txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcol: str, ) -> List[Any]: sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} @@ -1716,7 +1723,11 @@ async def simple_delete_one( desc: description of the transaction, for logging and metrics """ await self.runInteraction( - desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True, + desc, + self.simple_delete_one_txn, + table, + keyvalues, + db_autocommit=True, ) @staticmethod diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 0c243250117e..e84f8b42f734 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -56,7 +56,10 @@ def __init__(self, main_store_class, hs): database_config.databases, ) prepare_database( - db_conn, engine, hs.config, databases=database_config.databases, + db_conn, + engine, + hs.config, + databases=database_config.databases, ) database = DatabasePool(hs, database_config, engine) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index e550cbc86690..03a38422a113 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -73,8 +73,7 @@ def get_app_services(self): return self.services_cache def get_if_app_services_interested_in_user(self, user_id: str) -> bool: - """Check if the user is one associated with an app service (exclusively) - """ + """Check if the user is one associated with an app service (exclusively)""" if self.exclusive_user_regex: return bool(self.exclusive_user_regex.match(user_id)) else: diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index ea1e8fb5808b..6d18e692b0a9 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -280,8 +280,7 @@ def remove(txn): return batch_size async def _devices_last_seen_update(self, progress, batch_size): - """Background update to insert last seen info into devices table - """ + """Background update to insert last seen info into devices table""" last_user_id = progress.get("last_user_id", "") last_device_id = progress.get("last_device_id", "") @@ -363,8 +362,7 @@ def __init__(self, database: DatabasePool, db_conn, hs): @wrap_as_background_process("prune_old_user_ips") async def _prune_old_user_ips(self): - """Removes entries in user IPs older than the configured period. - """ + """Removes entries in user IPs older than the configured period.""" if self.user_ips_max_age is None: # Nothing to do @@ -565,7 +563,11 @@ async def get_user_ip_and_agents( results = {} for key in self._batch_row_update: - uid, access_token, ip, = key + ( + uid, + access_token, + ip, + ) = key if uid == user_id: user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 659d8f245fe5..d327e9aa0b8c 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -315,7 +315,8 @@ async def _get_device_update_edus_by_remote( # make sure we go through the devices in stream order device_ids = sorted( - user_devices.keys(), key=lambda i: query_map[(user_id, i)][0], + user_devices.keys(), + key=lambda i: query_map[(user_id, i)][0], ) for device_id in device_ids: @@ -366,8 +367,7 @@ def f(txn): async def mark_as_sent_devices_by_remote( self, destination: str, stream_id: int ) -> None: - """Mark that updates have successfully been sent to the destination. - """ + """Mark that updates have successfully been sent to the destination.""" await self.db_pool.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, @@ -681,7 +681,8 @@ async def get_device_list_last_stream_id_for_remotes(self, user_ids: str): return results async def get_user_ids_requiring_device_list_resync( - self, user_ids: Optional[Collection[str]] = None, + self, + user_ids: Optional[Collection[str]] = None, ) -> Set[str]: """Given a list of remote users return the list of users that we should resync the device lists for. If None is given instead of a list, @@ -721,8 +722,7 @@ async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: ) async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: - """Mark that we no longer track device lists for remote user. - """ + """Mark that we no longer track device lists for remote user.""" def _mark_remote_user_device_list_as_unsubscribed_txn(txn): self.db_pool.simple_delete_txn( @@ -902,7 +902,8 @@ def _prune_txn(txn): logger.info("Pruned %d device list outbound pokes", count) await self.db_pool.runInteraction( - "_prune_old_outbound_device_pokes", _prune_txn, + "_prune_old_outbound_device_pokes", + _prune_txn, ) @@ -943,7 +944,8 @@ def __init__(self, database: DatabasePool, db_conn, hs): # clear out duplicate device list outbound pokes self.db_pool.updates.register_background_update_handler( - BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, + self._remove_duplicate_outbound_pokes, ) # a pair of background updates that were added during the 1.14 release cycle, @@ -1004,17 +1006,23 @@ def _txn(txn): row = None for row in rows: self.db_pool.simple_delete_txn( - txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS}, + txn, + "device_lists_outbound_pokes", + {x: row[x] for x in KEY_COLS}, ) row["sent"] = False self.db_pool.simple_insert_txn( - txn, "device_lists_outbound_pokes", row, + txn, + "device_lists_outbound_pokes", + row, ) if row: self.db_pool.updates._background_update_progress_txn( - txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row}, + txn, + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, + {"last_row": row}, ) return len(rows) @@ -1286,7 +1294,9 @@ def _update_remote_device_list_cache_txn( # we've done a full resync, so we remove the entry that says we need # to resync self.db_pool.simple_delete_txn( - txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, + txn, + table="device_lists_remote_resync", + keyvalues={"user_id": user_id}, ) async def add_device_change_to_streams( @@ -1336,7 +1346,9 @@ def _add_device_change_to_stream_txn( stream_ids: List[str], ): txn.call_after( - self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1], + self._device_list_stream_cache.entity_has_changed, + user_id, + stream_ids[-1], ) min_stream_id = stream_ids[0] diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index e5060d4c4685..267b948397be 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -85,7 +85,7 @@ async def create_room_alias_association( servers: Iterable[str], creator: Optional[str] = None, ) -> None: - """ Creates an association between a room alias and room_id/servers + """Creates an association between a room alias and room_id/servers Args: room_alias: The alias to create. @@ -160,7 +160,10 @@ def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str: return room_id async def update_aliases_for_room( - self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, + self, + old_room_id: str, + new_room_id: str, + creator: Optional[str] = None, ) -> None: """Repoint all of the aliases for a given room, to a different room. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 309f1e865b7c..f1e7859d26e7 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -361,7 +361,7 @@ def _add_e2e_one_time_keys(txn): async def count_e2e_one_time_keys( self, user_id: str, device_id: str ) -> Dict[str, int]: - """ Count the number of one time keys the server has for a device + """Count the number of one time keys the server has for a device Returns: A mapping from algorithm to number of keys for that algorithm. """ @@ -494,7 +494,9 @@ async def _get_bare_e2e_cross_signing_keys_bulk( ) def _get_bare_e2e_cross_signing_keys_bulk_txn( - self, txn: Connection, user_ids: List[str], + self, + txn: Connection, + user_ids: List[str], ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if @@ -556,7 +558,10 @@ def _get_bare_e2e_cross_signing_keys_bulk_txn( return result def _get_e2e_cross_signing_signatures_txn( - self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str, + self, + txn: Connection, + keys: Dict[str, Dict[str, dict]], + from_user_id: str, ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing signatures made by a user on a set of keys. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ddfb13e3ad50..18ddb92fcca5 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -71,7 +71,9 @@ async def get_auth_chain( return await self.get_events_as_list(event_ids) async def get_auth_chain_ids( - self, event_ids: Collection[str], include_given: bool = False, + self, + event_ids: Collection[str], + include_given: bool = False, ) -> List[str]: """Get auth events for given event_ids. The events *must* be state events. @@ -273,7 +275,8 @@ def _get_auth_chain_difference_using_cover_index_txn( # origin chain. if origin_sequence_number <= chains.get(origin_chain_id, 0): chains[target_chain_id] = max( - target_sequence_number, chains.get(target_chain_id, 0), + target_sequence_number, + chains.get(target_chain_id, 0), ) seen_chains.add(target_chain_id) @@ -632,8 +635,7 @@ async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: ) async def get_min_depth(self, room_id: str) -> int: - """For the given room, get the minimum depth we have seen for it. - """ + """For the given room, get the minimum depth we have seen for it.""" return await self.db_pool.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) @@ -858,12 +860,13 @@ def _delete_old_forward_extrem_cache_txn(txn): ) await self.db_pool.runInteraction( - "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn, + "_delete_old_forward_extrem_cache", + _delete_old_forward_extrem_cache_txn, ) class EventFederationStore(EventFederationWorkerStore): - """ Responsible for storing and serving up the various graphs associated + """Responsible for storing and serving up the various graphs associated with an event. Including the main event graph and the auth chains for an event. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 438383abe17a..78245ad5bd30 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -54,8 +54,7 @@ def _serialize_action(actions, is_highlight): def _deserialize_action(actions, is_highlight): - """Custom deserializer for actions. This allows us to "compress" common actions - """ + """Custom deserializer for actions. This allows us to "compress" common actions""" if actions: return db_to_json(actions) @@ -91,7 +90,10 @@ def __init__(self, database: DatabasePool, db_conn, hs): @cached(num_args=3, tree=True, max_entries=5000) async def get_unread_event_push_actions_by_room_for_user( - self, room_id: str, user_id: str, last_read_event_id: Optional[str], + self, + room_id: str, + user_id: str, + last_read_event_id: Optional[str], ) -> Dict[str, int]: """Get the notification count, the highlight count and the unread message count for a given user in a given room after the given read receipt. @@ -120,13 +122,19 @@ async def get_unread_event_push_actions_by_room_for_user( ) def _get_unread_counts_by_receipt_txn( - self, txn, room_id, user_id, last_read_event_id, + self, + txn, + room_id, + user_id, + last_read_event_id, ): stream_ordering = None if last_read_event_id is not None: stream_ordering = self.get_stream_id_for_event_txn( - txn, last_read_event_id, allow_none=True, + txn, + last_read_event_id, + allow_none=True, ) if stream_ordering is None: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 7abfb9112e0b..287606cb4f07 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -399,7 +399,9 @@ def _persist_events_txn( self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) def _persist_event_auth_chain_txn( - self, txn: LoggingTransaction, events: List[EventBase], + self, + txn: LoggingTransaction, + events: List[EventBase], ) -> None: # We only care about state events, so this if there are no state events. @@ -470,7 +472,11 @@ def _persist_event_auth_chain_txn( event_to_room_id = {e.event_id: e.room_id for e in state_events.values()} self._add_chain_cover_index( - txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, + txn, + self.db_pool, + event_to_room_id, + event_to_types, + event_to_auth_chain, ) @classmethod @@ -517,7 +523,10 @@ def _add_chain_cover_index( # simple_select_many, but this case happens rarely and almost always # with a single row.) auth_events = db_pool.simple_select_onecol_txn( - txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id", + txn, + "event_auth", + keyvalues={"event_id": event_id}, + retcol="auth_id", ) events_to_calc_chain_id_for.add(event_id) @@ -550,7 +559,9 @@ def _add_chain_cover_index( WHERE """ clause, args = make_in_list_sql_clause( - txn.database_engine, "event_id", missing_auth_chains, + txn.database_engine, + "event_id", + missing_auth_chains, ) txn.execute(sql + clause, args) @@ -704,7 +715,8 @@ def _add_chain_cover_index( if chain_map[a_id][0] != chain_id } for start_auth_id, end_auth_id in itertools.permutations( - event_to_auth_chain.get(event_id, []), r=2, + event_to_auth_chain.get(event_id, []), + r=2, ): if chain_links.exists_path_from( chain_map[start_auth_id], chain_map[end_auth_id] @@ -888,8 +900,7 @@ def _persist_transaction_ids_txn( txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], ): - """Persist the mapping from transaction IDs to event IDs (if defined). - """ + """Persist the mapping from transaction IDs to event IDs (if defined).""" to_insert = [] for event, _ in events_and_contexts: @@ -909,7 +920,9 @@ def _persist_transaction_ids_txn( if to_insert: self.db_pool.simple_insert_many_txn( - txn, table="event_txn_id", values=to_insert, + txn, + table="event_txn_id", + values=to_insert, ) def _update_current_state_txn( @@ -941,7 +954,9 @@ def _update_current_state_txn( txn.execute(sql, (stream_id, self._instance_name, room_id)) self.db_pool.simple_delete_txn( - txn, table="current_state_events", keyvalues={"room_id": room_id}, + txn, + table="current_state_events", + keyvalues={"room_id": room_id}, ) else: # We're still in the room, so we update the current state as normal. @@ -1608,8 +1623,7 @@ def _store_event_reference_hashes_txn(self, txn, events): ) def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. - """ + """Store a room member in the database.""" def str_or_none(val: Any) -> Optional[str]: return val if isinstance(val, str) else None @@ -2001,8 +2015,7 @@ def _update_backward_extremeties(self, txn, events): @attr.s(slots=True) class _LinkMap: - """A helper type for tracking links between chains. - """ + """A helper type for tracking links between chains.""" # Stores the set of links as nested maps: source chain ID -> target chain ID # -> source sequence number -> target sequence number. @@ -2108,7 +2121,9 @@ def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]: yield (src_chain, src_seq, target_chain, target_seq) def exists_path_from( - self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int], + self, + src_tuple: Tuple[int, int], + target_tuple: Tuple[int, int], ) -> bool: """Checks if there is a path between the source chain ID/sequence and target chain ID/sequence. diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 5ca4fa681721..89274e75f778 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -32,8 +32,7 @@ @attr.s(slots=True, frozen=True) class _CalculateChainCover: - """Return value for _calculate_chain_cover_txn. - """ + """Return value for _calculate_chain_cover_txn.""" # The last room_id/depth/stream processed. room_id = attr.ib(type=str) @@ -127,11 +126,13 @@ def __init__(self, database: DatabasePool, db_conn, hs): ) self.db_pool.updates.register_background_update_handler( - "rejected_events_metadata", self._rejected_events_metadata, + "rejected_events_metadata", + self._rejected_events_metadata, ) self.db_pool.updates.register_background_update_handler( - "chain_cover", self._chain_cover_index, + "chain_cover", + self._chain_cover_index, ) async def _background_reindex_fields_sender(self, progress, batch_size): @@ -462,8 +463,7 @@ def _drop_table_txn(txn): return num_handled async def _redactions_received_ts(self, progress, batch_size): - """Handles filling out the `received_ts` column in redactions. - """ + """Handles filling out the `received_ts` column in redactions.""" last_event_id = progress.get("last_event_id", "") def _redactions_received_ts_txn(txn): @@ -518,8 +518,7 @@ def _redactions_received_ts_txn(txn): return count async def _event_fix_redactions_bytes(self, progress, batch_size): - """Undoes hex encoded censored redacted event JSON. - """ + """Undoes hex encoded censored redacted event JSON.""" def _event_fix_redactions_bytes_txn(txn): # This update is quite fast due to new index. @@ -642,7 +641,13 @@ def get_rejected_events( LIMIT ? """ - txn.execute(sql, (last_event_id, batch_size,)) + txn.execute( + sql, + ( + last_event_id, + batch_size, + ), + ) return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore @@ -910,7 +915,11 @@ def _calculate_chain_cover_txn( # Annoyingly we need to gut wrench into the persit event store so that # we can reuse the function to calculate the chain cover for rooms. PersistEventsStore._add_chain_cover_index( - txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, + txn, + self.db_pool, + event_to_room_id, + event_to_types, + event_to_auth_chain, ) return _CalculateChainCover( diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index 0ac1da9c3522..b3703ae161bd 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -71,7 +71,9 @@ def delete_forward_extremities_for_room_txn(txn): if txn.rowcount > 0: # Invalidate the cache self._invalidate_cache_and_stream( - txn, self.get_latest_event_ids_in_room, (room_id,), + txn, + self.get_latest_event_ids_in_room, + (room_id,), ) return txn.rowcount @@ -97,5 +99,6 @@ def get_forward_extremities_for_room_txn(txn): return self.db_pool.cursor_to_dict(txn) return await self.db_pool.runInteraction( - "get_forward_extremities_for_room", get_forward_extremities_for_room_txn, + "get_forward_extremities_for_room", + get_forward_extremities_for_room_txn, ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 71d823be72f6..c8850a4707f6 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -120,7 +120,9 @@ def __init__(self, database: DatabasePool, db_conn, hs): # SQLite). if hs.get_instance_name() in hs.config.worker.writers.events: self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", + db_conn, + "events", + "stream_ordering", ) self._backfill_id_gen = StreamIdGenerator( db_conn, @@ -140,7 +142,8 @@ def __init__(self, database: DatabasePool, db_conn, hs): if hs.config.run_background_tasks: # We periodically clean out old transaction ID mappings self._clock.looping_call( - self._cleanup_old_transaction_ids, 5 * 60 * 1000, + self._cleanup_old_transaction_ids, + 5 * 60 * 1000, ) self._get_event_cache = LruCache( @@ -1325,8 +1328,7 @@ def get_deltas_for_stream_id_txn(txn, stream_id): return rows, to_token, True async def is_event_after(self, event_id1, event_id2): - """Returns True if event_id1 is after event_id2 in the stream - """ + """Returns True if event_id1 is after event_id2 in the stream""" to_1, so_1 = await self.get_event_ordering(event_id1) to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @@ -1428,8 +1430,7 @@ async def get_already_persisted_events( @wrap_as_background_process("_cleanup_old_transaction_ids") async def _cleanup_old_transaction_ids(self): - """Cleans out transaction id mappings older than 24hrs. - """ + """Cleans out transaction id mappings older than 24hrs.""" def _cleanup_old_transaction_ids_txn(txn): sql = """ @@ -1440,5 +1441,6 @@ def _cleanup_old_transaction_ids_txn(txn): txn.execute(sql, (one_day_ago,)) return await self.db_pool.runInteraction( - "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn, + "_cleanup_old_transaction_ids", + _cleanup_old_transaction_ids_txn, ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 721819196530..abc19f71acc4 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -123,7 +123,9 @@ def _get_rooms_in_group_txn(txn): ) async def get_rooms_for_summary_by_category( - self, group_id: str, include_private: bool = False, + self, + group_id: str, + include_private: bool = False, ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: """Get the rooms and categories that should be included in a summary request @@ -368,8 +370,7 @@ async def is_user_admin_in_group( async def is_user_invited_to_local_group( self, group_id: str, user_id: str ) -> Optional[bool]: - """Has the group server invited a user? - """ + """Has the group server invited a user?""" return await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -427,8 +428,7 @@ def _get_users_membership_in_group_txn(txn): ) async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: - """Get all groups a user is publicising - """ + """Get all groups a user is publicising""" return await self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, @@ -437,8 +437,7 @@ async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: ) async def get_attestations_need_renewals(self, valid_until_ms): - """Get all attestations that need to be renewed until givent time - """ + """Get all attestations that need to be renewed until givent time""" def _get_attestations_need_renewals_txn(txn): sql = """ @@ -781,8 +780,7 @@ async def upsert_group_category( profile: Optional[JsonDict], is_public: Optional[bool], ) -> None: - """Add/update room category for group - """ + """Add/update room category for group""" insertion_values = {} update_values = {"category_id": category_id} # This cannot be empty @@ -818,8 +816,7 @@ async def upsert_group_role( profile: Optional[JsonDict], is_public: Optional[bool], ) -> None: - """Add/remove user role - """ + """Add/remove user role""" insertion_values = {} update_values = {"role_id": role_id} # This cannot be empty @@ -1012,8 +1009,7 @@ async def remove_user_from_summary( ) async def add_group_invite(self, group_id: str, user_id: str) -> None: - """Record that the group server has invited a user - """ + """Record that the group server has invited a user""" await self.db_pool.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, @@ -1156,8 +1152,7 @@ def _remove_room_from_group_txn(txn): async def update_group_publicity( self, group_id: str, user_id: str, publicise: bool ) -> None: - """Update whether the user is publicising their membership of the group - """ + """Update whether the user is publicising their membership of the group""" await self.db_pool.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -1300,8 +1295,7 @@ async def update_group_profile(self, group_id, profile): async def update_attestation_renewal( self, group_id: str, user_id: str, attestation: dict ) -> None: - """Update an attestation that we have renewed - """ + """Update an attestation that we have renewed""" await self.db_pool.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -1312,8 +1306,7 @@ async def update_attestation_renewal( async def update_remote_attestion( self, group_id: str, user_id: str, attestation: dict ) -> None: - """Update an attestation that a remote has renewed - """ + """Update an attestation that a remote has renewed""" await self.db_pool.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index e97026dc2e30..d504323b0330 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -33,8 +33,7 @@ class KeyStore(SQLBaseStore): - """Persistence for signature verification keys - """ + """Persistence for signature verification keys""" @cached() def _get_server_verify_key(self, server_name_and_key_id): diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index e017177655fe..a0313c3ccff4 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -169,7 +169,10 @@ def get_local_media_by_user_paginate_txn(txn): ) async def get_local_media_before( - self, before_ts: int, size_gt: int, keep_profiles: bool, + self, + before_ts: int, + size_gt: int, + keep_profiles: bool, ) -> List[str]: # to find files that have never been accessed (last_access_ts IS NULL) @@ -454,10 +457,14 @@ async def get_remote_media_thumbnails( ) async def get_remote_media_thumbnail( - self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str, + self, + origin: str, + media_id: str, + t_width: int, + t_height: int, + t_type: str, ) -> Optional[Dict[str, Any]]: - """Fetch the thumbnail info of given width, height and type. - """ + """Fetch the thumbnail info of given width, height and type.""" return await self.db_pool.simple_select_one( table="remote_media_cache_thumbnails", diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index dbbb99cb95fb..29edab34d47d 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -130,7 +130,9 @@ def _get_presence_for_user(self, user_id): raise NotImplementedError() @cachedList( - cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1, + cached_method_name="_get_presence_for_user", + list_name="user_ids", + num_args=1, ) async def get_presence_for_users(self, user_ids): rows = await self.db_pool.simple_select_many_batch( diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 54ef0f1f5499..ba01d3108a96 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -118,8 +118,7 @@ async def maybe_delete_remote_profile_cache(self, user_id): ) async def is_subscribed_remote_profile_for_user(self, user_id): - """Check whether we are interested in a remote user's profile. - """ + """Check whether we are interested in a remote user's profile.""" res = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, @@ -145,8 +144,7 @@ async def is_subscribed_remote_profile_for_user(self, user_id): async def get_remote_profile_cache_entries_that_expire( self, last_checked: int ) -> List[Dict[str, str]]: - """Get all users who haven't been checked since `last_checked` - """ + """Get all users who haven't been checked since `last_checked`""" def _get_remote_profile_cache_entries_that_expire_txn(txn): sql = """ diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 711d5aa23d6a..9e58dc0e6ae4 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -168,7 +168,9 @@ def have_push_rules_changed_txn(txn): ) @cachedList( - cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, + cached_method_name="get_push_rules_for_user", + list_name="user_ids", + num_args=1, ) async def bulk_get_push_rules(self, user_ids): if not user_ids: @@ -195,7 +197,9 @@ async def bulk_get_push_rules(self, user_ids): use_new_defaults = user_id in self._users_new_default_push_rules results[user_id] = _load_rules( - rules, enabled_map_by_user.get(user_id, {}), use_new_defaults, + rules, + enabled_map_by_user.get(user_id, {}), + use_new_defaults, ) return results diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 2687ef3e431d..7cb69dd6bd71 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -179,7 +179,9 @@ async def get_if_user_has_pusher(self, user_id: str): raise NotImplementedError() @cachedList( - cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1, + cached_method_name="get_if_user_has_pusher", + list_name="user_ids", + num_args=1, ) async def get_if_users_have_pushers( self, user_ids: Iterable[str] @@ -263,7 +265,8 @@ async def get_throttle_params_by_room( params_by_room = {} for row in res: params_by_room[row["room_id"]] = ThrottleParams( - row["last_sent_ts"], row["throttle_ms"], + row["last_sent_ts"], + row["throttle_ms"], ) return params_by_room diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index ae9283f52d81..43c852c96c00 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -208,8 +208,7 @@ async def get_linearized_receipts_for_room( async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None ) -> List[dict]: - """See get_linearized_receipts_for_room - """ + """See get_linearized_receipts_for_room""" def f(txn): if from_key: @@ -304,7 +303,9 @@ def f(txn): } return results - @cached(num_args=2,) + @cached( + num_args=2, + ) async def get_linearized_receipts_for_all_rooms( self, to_key: int, from_key: Optional[int] = None ) -> Dict[str, JsonDict]: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 8405dd460fb2..07e219aaed58 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -79,13 +79,16 @@ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer" # call `find_max_generated_user_id_localpart` each time, which is # expensive if there are many entries. self._user_id_seq = build_sequence_generator( - database.engine, find_max_generated_user_id_localpart, "user_id_seq", + database.engine, + find_max_generated_user_id_localpart, + "user_id_seq", ) self._account_validity = hs.config.account_validity if hs.config.run_background_tasks and self._account_validity.enabled: self._clock.call_later( - 0.0, self._set_expiration_date_when_missing, + 0.0, + self._set_expiration_date_when_missing, ) # Create a background job for culling expired 3PID validity tokens diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index cba343aa6873..9cbcd53026d8 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -193,8 +193,7 @@ def _count_public_rooms_txn(txn): ) async def get_room_count(self) -> int: - """Retrieve the total number of rooms. - """ + """Retrieve the total number of rooms.""" def f(txn): sql = "SELECT count(*) FROM rooms" @@ -517,7 +516,8 @@ def _get_rooms_paginate_txn(txn): return rooms, room_count[0] return await self.db_pool.runInteraction( - "get_rooms_paginate", _get_rooms_paginate_txn, + "get_rooms_paginate", + _get_rooms_paginate_txn, ) @cached(max_entries=10000) @@ -578,7 +578,8 @@ def get_retention_policy_for_room_txn(txn): return self.db_pool.cursor_to_dict(txn) ret = await self.db_pool.runInteraction( - "get_retention_policy_for_room", get_retention_policy_for_room_txn, + "get_retention_policy_for_room", + get_retention_policy_for_room_txn, ) # If we don't know this room ID, ret will be None, in this case return the default @@ -707,7 +708,10 @@ def _get_media_mxcs_in_room_txn(self, txn, room_id): return local_media_mxcs, remote_media_mxcs async def quarantine_media_by_id( - self, server_name: str, media_id: str, quarantined_by: str, + self, + server_name: str, + media_id: str, + quarantined_by: str, ) -> int: """quarantines a single local or remote media id @@ -961,7 +965,8 @@ def __init__(self, database: DatabasePool, db_conn, hs): self.config = hs.config self.db_pool.updates.register_background_update_handler( - "insert_room_retention", self._background_insert_retention, + "insert_room_retention", + self._background_insert_retention, ) self.db_pool.updates.register_background_update_handler( @@ -1033,7 +1038,8 @@ def _background_insert_retention_txn(txn): return False end = await self.db_pool.runInteraction( - "insert_room_retention", _background_insert_retention_txn, + "insert_room_retention", + _background_insert_retention_txn, ) if end: @@ -1588,7 +1594,8 @@ def _get_event_reports_paginate_txn(txn): LIMIT ? OFFSET ? """.format( - where_clause=where_clause, order=order, + where_clause=where_clause, + order=order, ) args += [limit, start] diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 92382bed28ee..a9216ca9ae52 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -70,10 +70,12 @@ def __init__(self, database: DatabasePool, db_conn, hs): ): self._known_servers_count = 1 self.hs.get_clock().looping_call( - self._count_known_servers, 60 * 1000, + self._count_known_servers, + 60 * 1000, ) self.hs.get_clock().call_later( - 1000, self._count_known_servers, + 1000, + self._count_known_servers, ) LaterGauge( "synapse_federation_known_servers", @@ -174,7 +176,7 @@ def get_users_in_room_txn(self, txn, room_id: str) -> List[str]: @cached(max_entries=100000) async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: - """ Get the details of a room roughly suitable for use by the room + """Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: room_id: The room ID to query @@ -488,8 +490,7 @@ async def get_rooms_for_user(self, user_id: str, on_invalidate=None): async def get_users_who_share_room_with_user( self, user_id: str, cache_context: _CacheContext ) -> Set[str]: - """Returns the set of users who share a room with `user_id` - """ + """Returns the set of users who share a room with `user_id`""" room_ids = await self.get_rooms_for_user( user_id, on_invalidate=cache_context.invalidate ) @@ -618,7 +619,8 @@ def _get_joined_profile_from_event_id(self, event_id): raise NotImplementedError() @cachedList( - cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids", + cached_method_name="_get_joined_profile_from_event_id", + list_name="event_ids", ) async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): """For given set of member event_ids check if they point to a join @@ -802,8 +804,7 @@ async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]: async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] ) -> List[dict]: - """Get user_id and membership of a set of event IDs. - """ + """Get user_id and membership of a set of event IDs.""" return await self.db_pool.simple_select_many_batch( table="room_memberships", diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py index ad875c733a9c..3907189e29fc 100644 --- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py @@ -23,5 +23,6 @@ def run_create(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute( - "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),), + "UPDATE remote_media_cache SET last_access_ts = ?", + (int(time.time() * 1000),), ) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 3c1e33819b88..a7f371732fd7 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -52,8 +52,7 @@ def __len__(self): # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): - """The parts of StateGroupStore that can be called from workers. - """ + """The parts of StateGroupStore that can be called from workers.""" def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -276,8 +275,7 @@ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: num_args=1, ) async def _get_state_group_for_events(self, event_ids): - """Returns mapping event_id -> state_group - """ + """Returns mapping event_id -> state_group""" rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", column="event_id", @@ -338,7 +336,8 @@ def __init__(self, database: DatabasePool, db_conn, hs): columns=["state_group"], ) self.db_pool.updates.register_background_update_handler( - self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms, + self.DELETE_CURRENT_STATE_UPDATE_NAME, + self._background_remove_left_rooms, ) async def _background_remove_left_rooms(self, progress, batch_size): @@ -487,7 +486,7 @@ def _background_remove_left_rooms_txn(txn): class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): - """ Keeps track of the state at a given event. + """Keeps track of the state at a given event. This is done by the concept of `state groups`. Every event is a assigned a state group (identified by an arbitrary string), which references a diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index d421d18f8d1a..1c99393c657c 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -1001,7 +1001,9 @@ def get_users_media_usage_paginate_txn(txn): ORDER BY {order_by_column} {order} LIMIT ? OFFSET ? """.format( - sql_base=sql_base, order_by_column=order_by_column, order=order, + sql_base=sql_base, + order_by_column=order_by_column, + order=order, ) args += [limit, start] diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index e3b9ff5ca6b3..91f8abb67d59 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -565,7 +565,14 @@ def f(txn): AND e.stream_ordering > ? AND e.stream_ordering <= ? ORDER BY e.stream_ordering ASC """ - txn.execute(sql, (user_id, min_from_id, max_to_id,)) + txn.execute( + sql, + ( + user_id, + min_from_id, + max_to_id, + ), + ) rows = [ _EventDictReturn(event_id, None, stream_ordering) @@ -695,7 +702,10 @@ async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: return "t%d-%d" % (topo, token) def get_stream_id_for_event_txn( - self, txn: LoggingTransaction, event_id: str, allow_none=False, + self, + txn: LoggingTransaction, + event_id: str, + allow_none=False, ) -> int: return self.db_pool.simple_select_one_onecol_txn( txn=txn, @@ -706,8 +716,7 @@ def get_stream_id_for_event_txn( ) async def get_position_for_event(self, event_id: str) -> PersistedEventPosition: - """Get the persisted position for an event - """ + """Get the persisted position for an event""" row = await self.db_pool.simple_select_one( table="events", keyvalues={"event_id": event_id}, @@ -897,19 +906,19 @@ async def get_all_new_events_stream( ) -> Tuple[int, List[EventBase]]: """Get all new events - Returns all events with from_id < stream_ordering <= current_id. + Returns all events with from_id < stream_ordering <= current_id. - Args: - from_id: the stream_ordering of the last event we processed - current_id: the stream_ordering of the most recently processed event - limit: the maximum number of events to return + Args: + from_id: the stream_ordering of the last event we processed + current_id: the stream_ordering of the most recently processed event + limit: the maximum number of events to return - Returns: - A tuple of (next_id, events), where `next_id` is the next value to - pass as `from_id` (it will either be the stream_ordering of the - last returned event, or, if fewer than `limit` events were found, - the `current_id`). - """ + Returns: + A tuple of (next_id, events), where `next_id` is the next value to + pass as `from_id` (it will either be the stream_ordering of the + last returned event, or, if fewer than `limit` events were found, + the `current_id`). + """ def get_all_new_events_stream_txn(txn): sql = ( @@ -1238,8 +1247,7 @@ async def paginate_room_events( @cached() async def get_id_for_instance(self, instance_name: str) -> int: - """Get a unique, immutable ID that corresponds to the given Synapse worker instance. - """ + """Get a unique, immutable ID that corresponds to the given Synapse worker instance.""" def _get_id_for_instance_txn(txn): instance_id = self.db_pool.simple_select_one_onecol_txn( diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 248a6c3f25ff..b921d63d3051 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -64,8 +64,7 @@ def _cleanup_transactions_txn(txn): class TransactionStore(TransactionWorkerStore): - """A collection of queries for handling PDUs. - """ + """A collection of queries for handling PDUs.""" def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -299,7 +298,10 @@ def _set_destination_retry_timings_emulated( ) async def store_destination_rooms_entries( - self, destinations: Iterable[str], room_id: str, stream_ordering: int, + self, + destinations: Iterable[str], + room_id: str, + stream_ordering: int, ) -> None: """ Updates or creates `destination_rooms` entries in batch for a single event. @@ -394,7 +396,9 @@ async def set_destination_last_successful_stream_ordering( ) async def get_catch_up_room_event_ids( - self, destination: str, last_successful_stream_ordering: int, + self, + destination: str, + last_successful_stream_ordering: int, ) -> List[str]: """ Returns at most 50 event IDs and their corresponding stream_orderings @@ -418,7 +422,9 @@ async def get_catch_up_room_event_ids( @staticmethod def _get_catch_up_room_event_ids_txn( - txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int, + txn: LoggingTransaction, + destination: str, + last_successful_stream_ordering: int, ) -> List[str]: q = """ SELECT event_id FROM destination_rooms @@ -429,7 +435,8 @@ def _get_catch_up_room_event_ids_txn( LIMIT 50 """ txn.execute( - q, (destination, last_successful_stream_ordering), + q, + (destination, last_successful_stream_ordering), ) event_ids = [row[0] for row in txn] return event_ids diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 79b7ece3302a..5473ec1485d0 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -44,7 +44,11 @@ class UIAuthWorkerStore(SQLBaseStore): """ async def create_ui_auth_session( - self, clientdict: JsonDict, uri: str, method: str, description: str, + self, + clientdict: JsonDict, + uri: str, + method: str, + description: str, ) -> UIAuthSessionData: """ Creates a new user interactive authentication session. @@ -123,7 +127,10 @@ async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData: return UIAuthSessionData(session_id, **result) async def mark_ui_auth_stage_complete( - self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict], + self, + session_id: str, + stage_type: str, + result: Union[str, bool, JsonDict], ): """ Mark a session stage as completed. @@ -261,10 +268,12 @@ async def get_ui_auth_session_data( return serverdict.get(key, default) async def add_user_agent_ip_to_ui_auth_session( - self, session_id: str, user_agent: str, ip: str, + self, + session_id: str, + user_agent: str, + ip: str, ): - """Add the given user agent / IP to the tracking table - """ + """Add the given user agent / IP to the tracking table""" await self.db_pool.simple_upsert( table="ui_auth_sessions_ips", keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip}, @@ -273,7 +282,8 @@ async def add_user_agent_ip_to_ui_auth_session( ) async def get_user_agents_ips_to_ui_auth_session( - self, session_id: str, + self, + session_id: str, ) -> List[Tuple[str, str]]: """Get the given user agents / IPs used during the ui auth process diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 7b9729da0958..3a1fe3ed52b1 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -336,8 +336,7 @@ def _get_next_batch(txn): return len(users_to_work_on) async def is_room_world_readable_or_publicly_joinable(self, room_id): - """Check if the room is either world_readable or publically joinable - """ + """Check if the room is either world_readable or publically joinable""" # Create a state filter that only queries join and history state event types_to_filter = ( @@ -516,8 +515,7 @@ async def add_users_in_public_rooms( ) async def delete_all_from_user_dir(self) -> None: - """Delete the entire user directory - """ + """Delete the entire user directory""" def _delete_all_from_user_dir_txn(txn): txn.execute("DELETE FROM user_directory") diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 89cdc84a9cda..b16b9905d887 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -48,8 +48,7 @@ def __len__(self): class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): - """A data store for fetching/storing state groups. - """ + """A data store for fetching/storing state groups.""" def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -89,7 +88,8 @@ def __init__(self, database: DatabasePool, db_conn, hs): 50000, ) self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", 500000, + "*stateGroupMembersCache*", + 500000, ) def get_max_state_group_txn(txn: Cursor): diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index d6d632dc10f6..cca839c70f59 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -94,14 +94,12 @@ def lock_table(self, txn, table: str) -> None: @property @abc.abstractmethod def server_version(self) -> str: - """Gets a string giving the server version. For example: '3.22.0' - """ + """Gets a string giving the server version. For example: '3.22.0'""" ... @abc.abstractmethod def in_transaction(self, conn: Connection) -> bool: - """Whether the connection is currently in a transaction. - """ + """Whether the connection is currently in a transaction.""" ... @abc.abstractmethod diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 7719ac32f764..80a3558aec3e 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -138,8 +138,7 @@ def supports_tuple_comparison(self): @property def supports_using_any_list(self): - """Do we support using `a = ANY(?)` and passing a list - """ + """Do we support using `a = ANY(?)` and passing a list""" return True def is_deadlock(self, error): diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index b3d1834efbec..b87e7798daab 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -29,7 +29,10 @@ def __init__(self, database_module, database_config): super().__init__(database_module, database_config) database = database_config.get("args", {}).get("database") - self._is_in_memory = database in (None, ":memory:",) + self._is_in_memory = database in ( + None, + ":memory:", + ) if platform.python_implementation() == "PyPy": # pypy's sqlite3 module doesn't handle bytearrays, convert them @@ -63,8 +66,7 @@ def supports_tuple_comparison(self): @property def supports_using_any_list(self): - """Do we support using `a = ANY(?)` and passing a list - """ + """Do we support using `a = ANY(?)` and passing a list""" return False def check_database(self, db_conn, allow_outdated_version: bool = False): diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 61fc49c69c8d..3a0d6fb32e84 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -411,8 +411,8 @@ async def _persist_events( ) for room_id, ev_ctx_rm in events_by_room.items(): - latest_event_ids = await self.main_store.get_latest_event_ids_in_room( - room_id + latest_event_ids = ( + await self.main_store.get_latest_event_ids_in_room(room_id) ) new_latest_event_ids = await self._calculate_new_extremities( room_id, ev_ctx_rm, latest_event_ids @@ -889,7 +889,8 @@ async def _prune_extremities( continue logger.debug( - "Not dropping as too new and not in new_senders: %s", new_senders, + "Not dropping as too new and not in new_senders: %s", + new_senders, ) return new_latest_event_ids @@ -1004,7 +1005,10 @@ async def _is_server_still_joined( remote_event_ids = [ event_id - for (typ, state_key,), event_id in current_state.items() + for ( + typ, + state_key, + ), event_id in current_state.items() if typ == EventTypes.Member and not self.is_mine_id(state_key) ] rows = await self.main_store.get_membership_from_event_ids(remote_event_ids) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index cd30e6b80a8a..6c3c2da5201f 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -425,7 +425,10 @@ def _upgrade_existing_database( # We don't support using the same file name in the same delta version. raise PrepareDatabaseException( "Found multiple delta files with the same name in v%d: %s" - % (v, duplicates,) + % ( + v, + duplicates, + ) ) # We sort to ensure that we apply the delta files in a consistent @@ -532,7 +535,8 @@ def _apply_module_schema_files( names_and_streams: the names and streams of schemas to be applied """ cur.execute( - "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,), + "SELECT file FROM applied_module_schemas WHERE module_name = ?", + (modname,), ) applied_deltas = {d for d, in cur} for (name, stream) in names_and_streams: diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index 6c359c1aaedb..3c4908865f88 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -26,15 +26,13 @@ class PurgeEventsStorage: - """High level interface for purging rooms and event history. - """ + """High level interface for purging rooms and event history.""" def __init__(self, hs: "HomeServer", stores: Databases): self.stores = stores async def purge_room(self, room_id: str) -> None: - """Deletes all record of a room - """ + """Deletes all record of a room""" state_groups_to_delete = await self.stores.main.purge_room(room_id) await self.stores.state.purge_room_state(room_id, state_groups_to_delete) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 31ccbf23dc20..d179a4188449 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -340,8 +340,7 @@ def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: class StateGroupStorage: - """High level interface to fetching state for event. - """ + """High level interface to fetching state for event.""" def __init__(self, hs: "HomeServer", stores: "Databases"): self.stores = stores @@ -400,7 +399,7 @@ async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: async def get_state_groups( self, room_id: str, event_ids: Iterable[str] ) -> Dict[int, List[EventBase]]: - """ Get the state groups for the given list of event_ids + """Get the state groups for the given list of event_ids Args: room_id: ID of the room for these events. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 9dd537bf6674..d4643c4fdf30 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -277,7 +277,9 @@ def __init__( self._load_current_ids(db_conn, tables) def _load_current_ids( - self, db_conn, tables: List[Tuple[str, str, str]], + self, + db_conn, + tables: List[Tuple[str, str, str]], ): cur = db_conn.cursor(txn_name="_load_current_ids") @@ -364,7 +366,10 @@ def _load_current_ids( rows.sort() with self._lock: - for (instance, stream_id,) in rows: + for ( + instance, + stream_id, + ) in rows: stream_id = self._return_factor * stream_id self._add_persisted_position(stream_id) @@ -481,8 +486,7 @@ def get_current_token(self) -> int: return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer. - """ + """Returns the position of the given writer.""" # If we don't have an entry for the given instance name, we assume it's a # new writer. @@ -581,8 +585,7 @@ def _add_persisted_position(self, new_id: int): break def _update_stream_positions_table_txn(self, txn: Cursor): - """Update the `stream_positions` table with newly persisted position. - """ + """Update the `stream_positions` table with newly persisted position.""" if not self._writers: return @@ -622,8 +625,7 @@ async def __aexit__(self, exc_type, exc, tb): @attr.s(slots=True) class _MultiWriterCtxManager: - """Async context manager returned by MultiWriterIdGenerator - """ + """Async context manager returned by MultiWriterIdGenerator""" id_gen = attr.ib(type=MultiWriterIdGenerator) multiple_ids = attr.ib(type=Optional[int], default=None) diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index e2b316a2182c..3ea637b28128 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -124,8 +124,7 @@ def check_consistency( stream_name: Optional[str] = None, positive: bool = True, ): - """See SequenceGenerator.check_consistency for docstring. - """ + """See SequenceGenerator.check_consistency for docstring.""" txn = db_conn.cursor(txn_name="sequence.check_consistency") diff --git a/synapse/types.py b/synapse/types.py index c695558a8611..721343f0b5ea 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -469,8 +469,7 @@ class RoomStreamToken: ) def __attrs_post_init__(self): - """Validates that both `topological` and `instance_map` aren't set. - """ + """Validates that both `topological` and `instance_map` aren't set.""" if self.instance_map and self.topological: raise ValueError( @@ -498,7 +497,11 @@ async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": instance_name = await store.get_name_from_instance_id(instance_id) instance_map[instance_name] = pos - return cls(topological=None, stream=stream, instance_map=instance_map,) + return cls( + topological=None, + stream=stream, + instance_map=instance_map, + ) except Exception: pass raise SynapseError(400, "Invalid token %r" % (string,)) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 691dde9a014f..719e35b78d72 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -252,8 +252,7 @@ def __init__( self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry] def is_queued(self, key: Hashable) -> bool: - """Checks whether there is a process queued up waiting - """ + """Checks whether there is a process queued up waiting""" entry = self.key_to_defer.get(key) if not entry: # No entry so nothing is waiting. @@ -452,7 +451,9 @@ def _ctx_manager(): def timeout_deferred( - deferred: defer.Deferred, timeout: float, reactor: IReactorTime, + deferred: defer.Deferred, + timeout: float, + reactor: IReactorTime, ) -> defer.Deferred: """The in built twisted `Deferred.addTimeout` fails to time out deferreds that have a canceller that throws exceptions. This method creates a new @@ -529,8 +530,7 @@ def failure_cb(val): @attr.s(slots=True, frozen=True) class DoneAwaitable: - """Simple awaitable that returns the provided value. - """ + """Simple awaitable that returns the provided value.""" value = attr.ib() @@ -545,8 +545,7 @@ def __next__(self): def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: - """Convert a value to an awaitable if not already an awaitable. - """ + """Convert a value to an awaitable if not already an awaitable.""" if inspect.isawaitable(value): assert isinstance(value, Awaitable) return value diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 89f0b385357d..e676c2cac46b 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -149,8 +149,7 @@ def register_cache( def intern_string(string): - """Takes a (potentially) unicode string and interns it if it's ascii - """ + """Takes a (potentially) unicode string and interns it if it's ascii""" if string is None: return None @@ -161,8 +160,7 @@ def intern_string(string): def intern_dict(dictionary): - """Takes a dictionary and interns well known keys and their values - """ + """Takes a dictionary and interns well known keys and their values""" return { KNOWN_KEYS.get(key, key): _intern_known_values(key, value) for key, value in dictionary.items() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index a924140cdf5e..4e843799147d 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -122,7 +122,8 @@ class _LruCachedFunction(Generic[F]): def lru_cache( - max_entries: int = 1000, cache_context: bool = False, + max_entries: int = 1000, + cache_context: bool = False, ) -> Callable[[F], _LruCachedFunction[F]]: """A method decorator that applies a memoizing cache around the function. @@ -156,7 +157,9 @@ def foo(self, key, cache_context): def func(orig: F) -> _LruCachedFunction[F]: desc = LruCacheDescriptor( - orig, max_entries=max_entries, cache_context=cache_context, + orig, + max_entries=max_entries, + cache_context=cache_context, ) return cast(_LruCachedFunction[F], desc) @@ -170,14 +173,18 @@ class _Sentinel(enum.Enum): sentinel = object() def __init__( - self, orig, max_entries: int = 1000, cache_context: bool = False, + self, + orig, + max_entries: int = 1000, + cache_context: bool = False, ): super().__init__(orig, num_args=None, cache_context=cache_context) self.max_entries = max_entries def __get__(self, obj, owner): cache = LruCache( - cache_name=self.orig.__name__, max_size=self.max_entries, + cache_name=self.orig.__name__, + max_size=self.max_entries, ) # type: LruCache[CacheKey, Any] get_cache_key = self.cache_key_builder @@ -212,7 +219,7 @@ def _wrapped(*args, **kwargs): class DeferredCacheDescriptor(_CacheDescriptorBase): - """ A method decorator that applies a memoizing cache around the function. + """A method decorator that applies a memoizing cache around the function. This caches deferreds, rather than the results themselves. Deferreds that fail are removed from the cache. diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index c541bf45797d..644e9e778a29 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -84,8 +84,7 @@ def set_cache_factor(self, factor: float) -> bool: return False def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: - """Returns True if the entity may have been updated since stream_pos - """ + """Returns True if the entity may have been updated since stream_pos""" assert isinstance(stream_pos, int) if stream_pos < self._earliest_known_stream_pos: @@ -133,8 +132,7 @@ def get_entities_changed( return result def has_any_entity_changed(self, stream_pos: int) -> bool: - """Returns if any entity has changed - """ + """Returns if any entity has changed""" assert type(stream_pos) is int if not self._cache: diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index a6ee9edaec90..3c47285d05ca 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -108,7 +108,10 @@ async def do(observer): return await maybe_awaitable(observer(*args, **kwargs)) except Exception as e: logger.warning( - "%s signal observer %s failed: %r", self.name, observer, e, + "%s signal observer %s failed: %r", + self.name, + observer, + e, ) deferreds = [run_in_background(do, o) for o in self.observers] diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 733f5e26e63b..68dc632491dc 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -83,15 +83,13 @@ def registerProducer(self, producer, streaming): self._producer.resumeProducing() def unregisterProducer(self): - """Part of IProducer interface - """ + """Part of IProducer interface""" self._producer = None if not self._finished_deferred.called: self._bytes_queue.put_nowait(None) def write(self, bytes): - """Part of IProducer interface - """ + """Part of IProducer interface""" if self._write_exception: raise self._write_exception @@ -107,8 +105,7 @@ def write(self, bytes): self._producer.pauseProducing() def _writer(self): - """This is run in a background thread to write to the file. - """ + """This is run in a background thread to write to the file.""" try: while self._producer or not self._bytes_queue.empty(): # If we've paused the producer check if we should resume the @@ -135,13 +132,11 @@ def _writer(self): self._file_obj.close() def wait(self): - """Returns a deferred that resolves when finished writing to file - """ + """Returns a deferred that resolves when finished writing to file""" return make_deferred_yieldable(self._finished_deferred) def _resume_paused_producer(self): - """Gets called if we should resume producing after being paused - """ + """Gets called if we should resume producing after being paused""" if self._paused_producer and self._producer: self._paused_producer = False self._producer.resumeProducing() diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 8d2411513fd7..98707c119deb 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -62,7 +62,8 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: def sorted_topologically( - nodes: Iterable[T], graph: Mapping[T, Collection[T]], + nodes: Iterable[T], + graph: Mapping[T, Collection[T]], ) -> Generator[T, None, None]: """Given a set of nodes and a graph, yield the nodes in toplogical order. diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index 50516926f3ef..e3a8ed5b2f27 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -15,7 +15,7 @@ class JsonEncodedObject: - """ A common base class for defining protocol units that are represented + """A common base class for defining protocol units that are represented as JSON. Attributes: @@ -39,7 +39,7 @@ class JsonEncodedObject: """ def __init__(self, **kwargs): - """ Takes the dict of `kwargs` and loads all keys that are *valid* + """Takes the dict of `kwargs` and loads all keys that are *valid* (i.e., are included in the `valid_keys` list) into the dictionary` instance variable. @@ -61,7 +61,7 @@ def __init__(self, **kwargs): self.unrecognized_keys[k] = v def get_dict(self): - """ Converts this protocol unit into a :py:class:`dict`, ready to be + """Converts this protocol unit into a :py:class:`dict`, ready to be encoded as JSON. The keys it encodes are: `valid_keys` - `internal_keys` diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index f4de6b9f5407..1023c856d143 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -161,8 +161,7 @@ def get_resource_usage(self) -> ContextResourceUsage: return self._logging_context.get_resource_usage() def _update_in_flight(self, metrics): - """Gets called when processing in flight metrics - """ + """Gets called when processing in flight metrics""" duration = self.clock.time() - self.start metrics.real_time_max = max(metrics.real_time_max, duration) diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 09b094ded7d9..d184e2a90cb6 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -25,7 +25,7 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: - """ Loads a synapse module with its config + """Loads a synapse module with its config Args: provider: a dict with keys 'module' (the module name) and 'config' diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 72574d3af257..d9f9ae99d639 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -204,16 +204,13 @@ def check_yield_points_inner(*args, **kwargs): # We don't raise here as its perfectly valid for contexts to # change in a function, as long as it sets the correct context # on resolving (which is checked separately). - err = ( - "%s changed context from %s to %s, happened between lines %d and %d in %s" - % ( - frame.f_code.co_name, - expected_context, - current_context(), - last_yield_line_no, - frame.f_lineno, - frame.f_code.co_filename, - ) + err = "%s changed context from %s to %s, happened between lines %d and %d in %s" % ( + frame.f_code.co_name, + expected_context, + current_context(), + last_yield_line_no, + frame.f_lineno, + frame.f_code.co_filename, ) changes.append(err) diff --git a/synmark/__main__.py b/synmark/__main__.py index de13c1a9094c..f55968a5a420 100644 --- a/synmark/__main__.py +++ b/synmark/__main__.py @@ -96,5 +96,6 @@ def add_cmdline_args(cmd, args): runner.args.loops = orig_loops loops = "auto" runner.bench_time_func( - suite.__name__ + "_" + str(loops), make_test(suite.main), + suite.__name__ + "_" + str(loops), + make_test(suite.main), ) diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py index c9d9cf761ef5..c306891b27f4 100644 --- a/synmark/suites/logging.py +++ b/synmark/suites/logging.py @@ -98,7 +98,9 @@ class Config: logger = logging.getLogger("synapse.logging.test_terse_json") _setup_stdlib_logging( - hs_config, log_config, logBeginner=beginner, + hs_config, + log_config, + logBeginner=beginner, ) # Wait for it to connect... diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index b1a8c58e1cf9..34f72ae795c7 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -276,7 +276,10 @@ async def get_user(tok): if token != tok: return None return TokenLookupResult( - user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE", + user_id=USER_ID, + is_guest=False, + token_id=1234, + device_id="DEVICE", ) self.store.get_user_by_access_token = get_user diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index fe504d0869c1..483418192c4b 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -43,7 +43,11 @@ def test_allowed_user_via_can_requester_do_action(self): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=True, sender="@as:example.com", + None, + "example.com", + id="foo", + rate_limited=True, + sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) @@ -68,7 +72,11 @@ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=False, sender="@as:example.com", + None, + "example.com", + id="foo", + rate_limited=False, + sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) @@ -113,12 +121,18 @@ def test_allowed_via_can_do_action_and_overriding_parameters(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) # First attempt should be allowed - allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,) + allowed, time_allowed = limiter.can_do_action( + ("test_id",), + _time_now_s=0, + ) self.assertTrue(allowed) self.assertEqual(10.0, time_allowed) # Second attempt, 1s later, will fail - allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,) + allowed, time_allowed = limiter.can_do_action( + ("test_id",), + _time_now_s=1, + ) self.assertFalse(allowed) self.assertEqual(10.0, time_allowed) diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index d3ec24c975df..2b7f09c14b27 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -127,8 +127,7 @@ def test_global_instantiated_after_config_load(self): self.assertEqual(cache.max_size, 150) def test_cache_with_asterisk_in_name(self): - """Some caches have asterisks in their name, test that they are set correctly. - """ + """Some caches have asterisks in their name, test that they are set correctly.""" config = { "caches": { @@ -164,7 +163,8 @@ def test_apply_cache_factor_from_config(self): t.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache( - max_size=t.caches.event_cache_size, apply_cache_factor_from_config=False, + max_size=t.caches.event_cache_size, + apply_cache_factor_from_config=False, ) add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 1d65ea2f9c13..30fcc4c1bfcc 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -400,7 +400,10 @@ def make_homeserver(self, reactor, clock): ) def build_perspectives_response( - self, server_name: str, signing_key: SigningKey, valid_until_ts: int, + self, + server_name: str, + signing_key: SigningKey, + valid_until_ts: int, ) -> dict: """ Build a valid perspectives server response to a request for the given key @@ -455,7 +458,9 @@ def test_get_keys_from_perspectives(self): VALID_UNTIL_TS = 200 * 1000 response = self.build_perspectives_response( - SERVER_NAME, testkey, VALID_UNTIL_TS, + SERVER_NAME, + testkey, + VALID_UNTIL_TS, ) self.expect_outgoing_key_query(SERVER_NAME, "key1", response) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 3a8062622496..ec85324c0c62 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -43,7 +43,10 @@ def test_serialize_deserialize_msg(self): event, context = self.get_success( create_event( - self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + self.hs, + room_id=self.room_id, + type="m.test", + sender=self.user_id, ) ) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 9ccd2d76b8ed..8186b8ca013c 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -150,8 +150,8 @@ def test_join_too_large_once_joined(self): ) # Artificially raise the complexity - self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable( - 600 + self.hs.get_datastore().get_current_state_event_counts = ( + lambda x: make_awaitable(600) ) d = handler._remote_join( diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 917762e6b658..ecc3faa57218 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -279,7 +279,8 @@ def test_upload_signatures(self): ret = self.get_success( e2e_handler.upload_signatures_for_device_keys( - u1, {u1: {"D1": d1_json, "D2": d2_json}}, + u1, + {u1: {"D1": d1_json, "D2": d2_json}}, ) ) self.assertEqual(ret["failures"], {}) @@ -486,9 +487,11 @@ def check_device_update_edu( self.assertGreaterEqual(content["stream_id"], prev_stream_id) return content["stream_id"] - def check_signing_key_update_txn(self, txn: JsonDict,) -> None: - """Check that the txn has an EDU with a signing key update. - """ + def check_signing_key_update_txn( + self, + txn: JsonDict, + ) -> None: + """Check that the txn has an EDU with a signing key update.""" edus = txn["edus"] self.assertEqual(len(edus), 1) @@ -502,7 +505,9 @@ def generate_and_upload_device_signing_key( self.get_success( self.hs.get_e2e_keys_handler().upload_keys_for_user( - user_id, device_id, {"device_keys": device_dict}, + user_id, + device_id, + {"device_keys": device_dict}, ) ) return sk diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 5c2b4de1a665..a01fdd083981 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -44,8 +44,7 @@ def prepare(self, reactor, clock, hs): self.token2 = self.login("user2", "password") def test_single_public_joined_room(self): - """Test that we write *all* events for a public room - """ + """Test that we write *all* events for a public room""" room_id = self.helper.create_room_as( self.user1, tok=self.token1, is_public=True ) @@ -116,8 +115,7 @@ def test_single_private_joined_room(self): self.assertEqual(counter[(EventTypes.Member, self.user2)], 1) def test_single_left_room(self): - """Tests that we don't see events in the room after we leave. - """ + """Tests that we don't see events in the room after we leave.""" room_id = self.helper.create_room_as(self.user1, tok=self.token1) self.helper.send(room_id, body="Hello!", tok=self.token1) self.helper.join(room_id, self.user2, tok=self.token2) @@ -190,8 +188,7 @@ def test_single_left_rejoined_private_room(self): self.assertEqual(counter[(EventTypes.Member, self.user2)], 3) def test_invite(self): - """Tests that pending invites get handled correctly. - """ + """Tests that pending invites get handled correctly.""" room_id = self.helper.create_room_as(self.user1, tok=self.token1) self.helper.send(room_id, body="Hello!", tok=self.token1) self.helper.invite(room_id, self.user1, self.user2, tok=self.token1) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 5dfeccfeb6e9..821629bc38a3 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -260,7 +260,9 @@ def test_dehydrate_and_rehydrate_device(self): # Create a new login for the user and dehydrated the device device_id, access_token = self.get_success( self.registration.register_device( - user_id=user_id, device_id=None, initial_display_name="new device", + user_id=user_id, + device_id=None, + initial_display_name="new device", ) ) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index a39f89860817..863d8737b2f8 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -131,7 +131,9 @@ def test_create_alias_joined_room(self): """A user can create an alias for a room they're in.""" self.get_success( self.handler.create_association( - create_requester(self.test_user), self.room_alias, self.room_id, + create_requester(self.test_user), + self.room_alias, + self.room_id, ) ) @@ -143,7 +145,9 @@ def test_create_alias_other_room(self): self.get_failure( self.handler.create_association( - create_requester(self.test_user), self.room_alias, other_room_id, + create_requester(self.test_user), + self.room_alias, + other_room_id, ), synapse.api.errors.SynapseError, ) @@ -156,7 +160,9 @@ def test_create_alias_admin(self): self.get_success( self.handler.create_association( - create_requester(self.admin_user), self.room_alias, other_room_id, + create_requester(self.admin_user), + self.room_alias, + other_room_id, ) ) @@ -275,8 +281,7 @@ def test_delete_alias_sufficient_power(self): class CanonicalAliasTestCase(unittest.HomeserverTestCase): - """Test modifications of the canonical alias when delete aliases. - """ + """Test modifications of the canonical alias when delete aliases.""" servlets = [ synapse.rest.admin.register_servlets, @@ -317,7 +322,10 @@ def _add_alias(self, alias: str) -> RoomAlias: def _set_canonical_alias(self, content): """Configure the canonical alias state on the room.""" self.helper.send_state( - self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok, + self.room_id, + "m.room.canonical_alias", + content, + tok=self.admin_user_tok, ) def _get_canonical_alias(self): diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index c1a13aeb7176..5e86c5e56bf4 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -33,8 +33,7 @@ def prepare(self, reactor, clock, hs): self.store = self.hs.get_datastore() def test_query_local_devices_no_devices(self): - """If the user has no devices, we expect an empty list. - """ + """If the user has no devices, we expect an empty list.""" local_user = "@boris:" + self.hs.hostname res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) @@ -102,7 +101,9 @@ def test_change_one_time_keys(self): # Error when replacing string key with dict self.get_failure( self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}, + local_user, + device_id, + {"one_time_keys": {"alg1:k1": {"key": "key"}}}, ), SynapseError, ) @@ -215,7 +216,8 @@ def test_fallback_key(self): ) ) self.assertEqual( - res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, ) res = self.get_success( diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 58773a0c3812..d7498aa51a80 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -70,8 +70,7 @@ def test_get_missing_version_info(self): self.assertEqual(res, 404) def test_create_version(self): - """Check that we can create and then retrieve versions. - """ + """Check that we can create and then retrieve versions.""" res = self.get_success( self.handler.create_version( self.local_user, @@ -138,8 +137,7 @@ def test_create_version(self): ) def test_update_version(self): - """Check that we can update versions. - """ + """Check that we can update versions.""" version = self.get_success( self.handler.create_version( self.local_user, @@ -178,8 +176,7 @@ def test_update_version(self): ) def test_update_missing_version(self): - """Check that we get a 404 on updating nonexistent versions - """ + """Check that we get a 404 on updating nonexistent versions""" e = self.get_failure( self.handler.update_version( self.local_user, @@ -196,8 +193,7 @@ def test_update_missing_version(self): self.assertEqual(res, 404) def test_update_omitted_version(self): - """Check that the update succeeds if the version is missing from the body - """ + """Check that the update succeeds if the version is missing from the body""" version = self.get_success( self.handler.create_version( self.local_user, @@ -234,8 +230,7 @@ def test_update_omitted_version(self): ) def test_update_bad_version(self): - """Check that we get a 400 if the version in the body doesn't match - """ + """Check that we get a 400 if the version in the body doesn't match""" version = self.get_success( self.handler.create_version( self.local_user, @@ -263,8 +258,7 @@ def test_update_bad_version(self): self.assertEqual(res, 400) def test_delete_missing_version(self): - """Check that we get a 404 on deleting nonexistent versions - """ + """Check that we get a 404 on deleting nonexistent versions""" e = self.get_failure( self.handler.delete_version(self.local_user, "1"), SynapseError ) @@ -272,15 +266,13 @@ def test_delete_missing_version(self): self.assertEqual(res, 404) def test_delete_missing_current_version(self): - """Check that we get a 404 on deleting nonexistent current version - """ + """Check that we get a 404 on deleting nonexistent current version""" e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError) res = e.value.code self.assertEqual(res, 404) def test_delete_version(self): - """Check that we can create and then delete versions. - """ + """Check that we can create and then delete versions.""" res = self.get_success( self.handler.create_version( self.local_user, @@ -303,8 +295,7 @@ def test_delete_version(self): self.assertEqual(res, 404) def test_get_missing_backup(self): - """Check that we get a 404 on querying missing backup - """ + """Check that we get a 404 on querying missing backup""" e = self.get_failure( self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError ) @@ -312,8 +303,7 @@ def test_get_missing_backup(self): self.assertEqual(res, 404) def test_get_missing_room_keys(self): - """Check we get an empty response from an empty backup - """ + """Check we get an empty response from an empty backup""" version = self.get_success( self.handler.create_version( self.local_user, @@ -332,8 +322,7 @@ def test_get_missing_room_keys(self): # although this is probably best done in sytest def test_upload_room_keys_no_versions(self): - """Check that we get a 404 on uploading keys when no versions are defined - """ + """Check that we get a 404 on uploading keys when no versions are defined""" e = self.get_failure( self.handler.upload_room_keys(self.local_user, "no_version", room_keys), SynapseError, @@ -364,8 +353,7 @@ def test_upload_room_keys_bogus_version(self): self.assertEqual(res, 404) def test_upload_room_keys_wrong_version(self): - """Check that we get a 403 on uploading keys for an old version - """ + """Check that we get a 403 on uploading keys for an old version""" version = self.get_success( self.handler.create_version( self.local_user, @@ -395,8 +383,7 @@ def test_upload_room_keys_wrong_version(self): self.assertEqual(res, 403) def test_upload_room_keys_insert(self): - """Check that we can insert and retrieve keys for a session - """ + """Check that we can insert and retrieve keys for a session""" version = self.get_success( self.handler.create_version( self.local_user, @@ -510,8 +497,7 @@ def test_upload_room_keys_merge(self): # TODO: check edge cases as well as the common variations here def test_delete_room_keys(self): - """Check that we can insert and delete keys for a session - """ + """Check that we can insert and delete keys for a session""" version = self.get_success( self.handler.create_version( self.local_user, diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 983e36859297..3af361195b57 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -226,12 +226,20 @@ def create_invite(): for i in range(3): event = create_invite() self.get_success( - self.handler.on_invite_request(other_server, event, event.room_version,) + self.handler.on_invite_request( + other_server, + event, + event.room_version, + ) ) event = create_invite() self.get_failure( - self.handler.on_invite_request(other_server, event, event.room_version,), + self.handler.on_invite_request( + other_server, + event, + event.room_version, + ), exc=LimitExceededError, ) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index f955dfa490ae..a0d1ebdbe3c1 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -44,7 +44,9 @@ def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) self.info = self.get_success( - self.hs.get_datastore().get_user_by_access_token(self.access_token,) + self.hs.get_datastore().get_user_by_access_token( + self.access_token, + ) ) self.token_id = self.info.token_id @@ -169,8 +171,7 @@ def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) def test_allow_server_acl(self): - """Test that sending an ACL that blocks everyone but ourselves works. - """ + """Test that sending an ACL that blocks everyone but ourselves works.""" self.helper.send_state( self.room_id, @@ -181,8 +182,7 @@ def test_allow_server_acl(self): ) def test_deny_server_acl_block_outselves(self): - """Test that sending an ACL that blocks ourselves does not work. - """ + """Test that sending an ACL that blocks ourselves does not work.""" self.helper.send_state( self.room_id, EventTypes.ServerACL, @@ -192,8 +192,7 @@ def test_deny_server_acl_block_outselves(self): ) def test_deny_redact_server_acl(self): - """Test that attempting to redact an ACL is blocked. - """ + """Test that attempting to redact an ACL is blocked.""" body = self.helper.send_state( self.room_id, diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 27bb50e3b5d4..bdd2e02eae12 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -512,7 +512,9 @@ def test_callback_session(self): # Mismatching session session = self._generate_oidc_session_token( - state="state", nonce="nonce", client_redirect_url="http://client/redirect", + state="state", + nonce="nonce", + client_redirect_url="http://client/redirect", ) request.args = {} request.args[b"state"] = [b"mismatching state"] @@ -567,7 +569,9 @@ def test_exchange_code(self): # Internal server error with no JSON body self.http_client.request = simple_async_mock( return_value=FakeResponse( - code=500, phrase=b"Internal Server Error", body=b"Not JSON", + code=500, + phrase=b"Internal Server Error", + body=b"Not JSON", ) ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) @@ -587,7 +591,11 @@ def test_exchange_code(self): # 4xx error without "error" field self.http_client.request = simple_async_mock( - return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) + return_value=FakeResponse( + code=400, + phrase=b"Bad request", + body=b"{}", + ) ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") @@ -595,7 +603,9 @@ def test_exchange_code(self): # 2xx error with "error" field self.http_client.request = simple_async_mock( return_value=FakeResponse( - code=200, phrase=b"OK", body=b'{"error": "some_error"}', + code=200, + phrase=b"OK", + body=b'{"error": "some_error"}', ) ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) @@ -632,7 +642,9 @@ def test_extra_attributes(self): state = "state" client_redirect_url = "http://client/redirect" session = self._generate_oidc_session_token( - state=state, nonce="nonce", client_redirect_url=client_redirect_url, + state=state, + nonce="nonce", + client_redirect_url=client_redirect_url, ) request = _build_callback_request("code", state, session) @@ -895,7 +907,9 @@ async def _make_callback_with_userinfo( session = handler._token_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( - idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url, + idp_id="oidc", + nonce="nonce", + client_redirect_url=client_redirect_url, ), ) request = _build_callback_request("code", state, session) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index f816594ee4d5..a98a65ae67e4 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -231,8 +231,7 @@ def test_local_user_fallback_ui_auth(self): } ) def test_no_local_user_fallback_login(self): - """localdb_enabled can block login with the local password - """ + """localdb_enabled can block login with the local password""" self.register_user("localuser", "localpass") # check_password must return an awaitable @@ -251,8 +250,7 @@ def test_no_local_user_fallback_login(self): } ) def test_no_local_user_fallback_ui_auth(self): - """localdb_enabled can block ui auth with the local password - """ + """localdb_enabled can block ui auth with the local password""" self.register_user("localuser", "localpass") # allow login via the auth provider @@ -594,7 +592,10 @@ def _authed_delete_device( ) def _delete_device( - self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"", + self, + access_token: str, + device: str, + body: Union[JsonDict, bytes] = b"", ) -> FakeChannel: """Delete an individual device.""" channel = self.make_request( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 0794b32c9c69..be2ee26f07cf 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -589,8 +589,7 @@ def test_remote_gets_presence_when_local_user_joins(self): ) def _add_new_user(self, room_id, user_id): - """Add new user to the room by creating an event and poking the federation API. - """ + """Add new user to the room by creating an event and poking the federation API.""" hostname = get_domain_from_id(user_id) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 787fab78754a..18ca8b84f5e1 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -208,7 +208,9 @@ def test_set_my_avatar(self): # Set avatar to an empty string self.get_success( self.handler.set_avatar_url( - self.frank, synapse.types.create_requester(self.frank), "", + self.frank, + synapse.types.create_requester(self.frank), + "", ) ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 96e5bdac4a80..24e71381965a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -143,14 +143,14 @@ async def get_users_in_room(room_id): self.datastore.get_current_state_deltas = Mock(return_value=(0, None)) self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( - ([], 0) + self.datastore.get_new_device_msgs_for_remote = ( + lambda *args, **kargs: make_awaitable(([], 0)) ) - self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( - None + self.datastore.delete_device_msgs_for_remote = ( + lambda *args, **kargs: make_awaitable(None) ) - self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable( - None + self.datastore.set_received_txn_response = ( + lambda *args, **kwargs: make_awaitable(None) ) def test_started_typing_local(self): diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 9c886d671a1b..3572e54c5d26 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -200,7 +200,9 @@ def test_encrypted_by_default_config_option_all(self): # Check that the room has an encryption state event event_content = self.helper.get_state( - room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token, + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, ) self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) @@ -209,7 +211,9 @@ def test_encrypted_by_default_config_option_all(self): # Check that the room has an encryption state event event_content = self.helper.get_state( - room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token, + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, ) self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) @@ -227,7 +231,9 @@ def test_encrypted_by_default_config_option_invite(self): # Check that the room has an encryption state event event_content = self.helper.get_state( - room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token, + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, ) self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 686012dd25e0..4c56253da549 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -518,8 +518,7 @@ def test_get_no_srv_no_well_known(self): self.successResultOf(test_d) def test_get_well_known(self): - """Test the behaviour when the .well-known delegates elsewhere - """ + """Test the behaviour when the .well-known delegates elsewhere""" self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -1135,8 +1134,7 @@ def test_well_known_too_large(self): self.assertIsNone(r.delegated_server) def test_srv_fallbacks(self): - """Test that other SRV results are tried if the first one fails. - """ + """Test that other SRV results are tried if the first one fails.""" self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [ Server(host=b"target.com", port=8443), diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 27206ca3dbe5..edacd1b566ba 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -100,7 +100,10 @@ def test_sending_events_into_room(self): # Check that the event was sent self.event_creation_handler.create_and_send_nonmember_event.assert_called_with( - expected_requester, event_dict, ratelimit=False, ignore_shadow_ban=True, + expected_requester, + event_dict, + ratelimit=False, + ignore_shadow_ban=True, ) # Create and send a state event diff --git a/tests/replication/_base.py b/tests/replication/_base.py index d5dce1f83fc0..f6a6aed35e2b 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -79,7 +79,11 @@ def prepare(self, reactor, clock, hs): repl_handler = ReplicationCommandHandler(self.worker_hs) self.client = ClientReplicationStreamProtocol( - self.worker_hs, "client", "test", clock, repl_handler, + self.worker_hs, + "client", + "test", + clock, + repl_handler, ) self._client_transport = None @@ -228,7 +232,9 @@ def setUp(self): if self.hs.config.redis.redis_enabled: # Handle attempts to connect to fake redis server. self.reactor.add_tcp_client_callback( - "localhost", 6379, self.connect_any_redis_attempts, + "localhost", + 6379, + self.connect_any_redis_attempts, ) self.hs.get_tcp_replication().start_replication(self.hs) @@ -246,8 +252,7 @@ def setUp(self): ) def create_test_resource(self): - """Overrides `HomeserverTestCase.create_test_resource`. - """ + """Overrides `HomeserverTestCase.create_test_resource`.""" # We override this so that it automatically registers all the HTTP # replication servlets, without having to explicitly do that in all # subclassses. @@ -296,7 +301,10 @@ def make_worker_hs( if instance_loc.host not in self.reactor.lookups: raise Exception( "Host does not have an IP for instance_map[%r].host = %r" - % (instance_name, instance_loc.host,) + % ( + instance_name, + instance_loc.host, + ) ) self.reactor.add_tcp_client_callback( @@ -315,7 +323,11 @@ def make_worker_hs( if not worker_hs.config.redis_enabled: repl_handler = ReplicationCommandHandler(worker_hs) client = ClientReplicationStreamProtocol( - worker_hs, "client", "test", self.clock, repl_handler, + worker_hs, + "client", + "test", + self.clock, + repl_handler, ) server = self.server_factory.buildProtocol(None) @@ -485,8 +497,7 @@ def unregisterProducer(self): self._pull_to_push_producer.stop() def checkPersistence(self, request, version): - """Check whether the connection can be re-used - """ + """Check whether the connection can be re-used""" # We hijack this to always say no for ease of wiring stuff up in # `handle_http_replication_attempt`. request.responseHeaders.setRawHeaders(b"connection", [b"close"]) @@ -494,8 +505,7 @@ def checkPersistence(self, request, version): class _PullToPushProducer: - """A push producer that wraps a pull producer. - """ + """A push producer that wraps a pull producer.""" def __init__( self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer @@ -512,39 +522,33 @@ def __init__( self._start_loop() def _start_loop(self): - """Start the looping call to - """ + """Start the looping call to""" if not self._looping_call: # Start a looping call which runs every tick. self._looping_call = self._clock.looping_call(self._run_once, 0) def stop(self): - """Stops calling resumeProducing. - """ + """Stops calling resumeProducing.""" if self._looping_call: self._looping_call.stop() self._looping_call = None def pauseProducing(self): - """Implements IPushProducer - """ + """Implements IPushProducer""" self.stop() def resumeProducing(self): - """Implements IPushProducer - """ + """Implements IPushProducer""" self._start_loop() def stopProducing(self): - """Implements IPushProducer - """ + """Implements IPushProducer""" self.stop() self._producer.stopProducing() def _run_once(self): - """Calls resumeProducing on producer once. - """ + """Calls resumeProducing on producer once.""" try: self._producer.resumeProducing() @@ -559,25 +563,21 @@ def _run_once(self): class FakeRedisPubSubServer: - """A fake Redis server for pub/sub. - """ + """A fake Redis server for pub/sub.""" def __init__(self): self._subscribers = set() def add_subscriber(self, conn): - """A connection has called SUBSCRIBE - """ + """A connection has called SUBSCRIBE""" self._subscribers.add(conn) def remove_subscriber(self, conn): - """A connection has called UNSUBSCRIBE - """ + """A connection has called UNSUBSCRIBE""" self._subscribers.discard(conn) def publish(self, conn, channel, msg) -> int: - """A connection want to publish a message to subscribers. - """ + """A connection want to publish a message to subscribers.""" for sub in self._subscribers: sub.send(["message", channel, msg]) @@ -588,8 +588,7 @@ def buildProtocol(self, addr): class FakeRedisPubSubProtocol(Protocol): - """A connection from a client talking to the fake Redis server. - """ + """A connection from a client talking to the fake Redis server.""" def __init__(self, server: FakeRedisPubSubServer): self._server = server @@ -613,8 +612,7 @@ def dataReceived(self, data): self.handle_command(msg[0], *msg[1:]) def handle_command(self, command, *args): - """Received a Redis command from the client. - """ + """Received a Redis command from the client.""" # We currently only support pub/sub. if command == b"PUBLISH": @@ -635,8 +633,7 @@ def handle_command(self, command, *args): raise Exception("Unknown command") def send(self, msg): - """Send a message back to the client. - """ + """Send a message back to the client.""" raw = self.encode(msg).encode("utf-8") self.transport.write(raw) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index c0ee1cfbd6f2..0ceb0f935cd4 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -66,7 +66,10 @@ def prepare(self, *args, **kwargs): self.get_success( self.master_store.store_room( - ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1, + ROOM_ID, + USER_ID, + is_public=False, + room_version=RoomVersions.V1, ) ) diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py index 6a5116dd2a05..153634d4eeaf 100644 --- a/tests/replication/tcp/streams/test_account_data.py +++ b/tests/replication/tcp/streams/test_account_data.py @@ -23,8 +23,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase): def test_update_function_room_account_data_limit(self): - """Test replication with many room account data updates - """ + """Test replication with many room account data updates""" store = self.hs.get_datastore() # generate lots of account data updates @@ -70,8 +69,7 @@ def test_update_function_room_account_data_limit(self): self.assertEqual([], received_rows) def test_update_function_global_account_data_limit(self): - """Test replication with many global account data updates - """ + """Test replication with many global account data updates""" store = self.hs.get_datastore() # generate lots of account data updates diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index bad0df08cf02..77856fc30445 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -129,7 +129,10 @@ def test_update_function_huge_state_change(self): ) pls["users"][OTHER_USER] = 50 self.helper.send_state( - self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + self.room_id, + EventTypes.PowerLevels, + pls, + tok=self.user_tok, ) # this is the point in the DAG where we make a fork @@ -255,8 +258,7 @@ def test_update_function_huge_state_change(self): self.assertIsNone(sr.event_id) def test_update_function_state_row_limit(self): - """Test replication with many state events over several stream ids. - """ + """Test replication with many state events over several stream ids.""" # we want to generate lots of state changes, but for this test, we want to # spread out the state changes over a few stream IDs. @@ -282,7 +284,10 @@ def test_update_function_state_row_limit(self): ) pls["users"].update({u: 50 for u in user_ids}) self.helper.send_state( - self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + self.room_id, + EventTypes.PowerLevels, + pls, + tok=self.user_tok, ) # this is the point in the DAG where we make a fork diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py index d1c15caeb001..1fe9d5b4d076 100644 --- a/tests/replication/tcp/test_remote_server_up.py +++ b/tests/replication/tcp/test_remote_server_up.py @@ -28,8 +28,7 @@ def prepare(self, reactor, clock, hs): self.factory = ReplicationStreamProtocolFactory(hs) def _make_client(self) -> Tuple[IProtocol, StringTransport]: - """Create a new direct TCP replication connection - """ + """Create a new direct TCP replication connection""" proto = self.factory.buildProtocol(("127.0.0.1", 0)) transport = StringTransport() diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index f35a5235e1fd..f8fd8a843c82 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -79,8 +79,7 @@ def _test_register(self) -> FakeChannel: ) def test_no_auth(self): - """With no authentication the request should finish. - """ + """With no authentication the request should finish.""" channel = self._test_register() self.assertEqual(channel.code, 200) @@ -89,8 +88,7 @@ def test_no_auth(self): @override_config({"main_replication_secret": "my-secret"}) def test_missing_auth(self): - """If the main process expects a secret that is not provided, an error results. - """ + """If the main process expects a secret that is not provided, an error results.""" channel = self._test_register() self.assertEqual(channel.code, 500) @@ -101,15 +99,13 @@ def test_missing_auth(self): } ) def test_unauthorized(self): - """If the main process receives the wrong secret, an error results. - """ + """If the main process receives the wrong secret, an error results.""" channel = self._test_register() self.assertEqual(channel.code, 500) @override_config({"worker_replication_secret": "my-secret"}) def test_authorized(self): - """The request should finish when the worker provides the authentication header. - """ + """The request should finish when the worker provides the authentication header.""" channel = self._test_register() self.assertEqual(channel.code, 200) diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index 4608b65a0cbc..5da1d5dc4d8b 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -35,8 +35,7 @@ def _get_worker_hs_config(self) -> dict: return config def test_register_single_worker(self): - """Test that registration works when using a single client reader worker. - """ + """Test that registration works when using a single client reader worker.""" worker_hs = self.make_worker_hs("synapse.app.client_reader") site = self._hs_to_site[worker_hs] @@ -66,8 +65,7 @@ def test_register_single_worker(self): self.assertEqual(channel_2.json_body["user_id"], "@user:test") def test_register_multi_worker(self): - """Test that registration works when using multiple client reader workers. - """ + """Test that registration works when using multiple client reader workers.""" worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index d1feca961fa4..7ff11cde10b2 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -36,8 +36,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): - """Checks running multiple media repos work correctly. - """ + """Checks running multiple media repos work correctly.""" servlets = [ admin.register_servlets_for_client_rest_resource, @@ -124,8 +123,7 @@ def _get_media_req( return channel, request def test_basic(self): - """Test basic fetching of remote media from a single worker. - """ + """Test basic fetching of remote media from a single worker.""" hs1 = self.make_worker_hs("synapse.app.generic_worker") channel, request = self._get_media_req(hs1, "example.com:443", "ABC123") @@ -223,16 +221,14 @@ def test_download_image_race(self): self.assertEqual(start_count + 3, self._count_remote_thumbnails()) def _count_remote_media(self) -> int: - """Count the number of files in our remote media directory. - """ + """Count the number of files in our remote media directory.""" path = os.path.join( self.hs.get_media_repository().primary_base_path, "remote_content" ) return sum(len(files) for _, _, files in os.walk(path)) def _count_remote_thumbnails(self) -> int: - """Count the number of files in our remote thumbnails directory. - """ + """Count the number of files in our remote thumbnails directory.""" path = os.path.join( self.hs.get_media_repository().primary_base_path, "remote_thumbnail" ) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 800ad94a04a5..f118fe32af60 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -27,8 +27,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): - """Checks pusher sharding works - """ + """Checks pusher sharding works""" servlets = [ admin.register_servlets_for_client_rest_resource, @@ -88,11 +87,10 @@ def _create_pusher_and_send_msg(self, localpart): return event_id def test_send_push_single_worker(self): - """Test that registration works when using a pusher worker. - """ + """Test that registration works when using a pusher worker.""" http_client_mock = Mock(spec_set=["post_json_get_json"]) - http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( - {} + http_client_mock.post_json_get_json.side_effect = ( + lambda *_, **__: defer.succeed({}) ) self.make_worker_hs( @@ -119,11 +117,10 @@ def test_send_push_single_worker(self): ) def test_send_push_multiple_workers(self): - """Test that registration works when using sharded pusher workers. - """ + """Test that registration works when using sharded pusher workers.""" http_client_mock1 = Mock(spec_set=["post_json_get_json"]) - http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( - {} + http_client_mock1.post_json_get_json.side_effect = ( + lambda *_, **__: defer.succeed({}) ) self.make_worker_hs( @@ -137,8 +134,8 @@ def test_send_push_multiple_workers(self): ) http_client_mock2 = Mock(spec_set=["post_json_get_json"]) - http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( - {} + http_client_mock2.post_json_get_json.side_effect = ( + lambda *_, **__: defer.succeed({}) ) self.make_worker_hs( diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 8d494ebc038d..c9b773fbd215 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -29,8 +29,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): - """Checks event persisting sharding works - """ + """Checks event persisting sharding works""" # Event persister sharding requires postgres (due to needing # `MutliWriterIdGenerator`). @@ -63,8 +62,7 @@ def default_config(self): return conf def _create_room(self, room_id: str, user_id: str, tok: str): - """Create a room with given room_id - """ + """Create a room with given room_id""" # We control the room ID generation by patching out the # `_generate_room_id` method @@ -91,11 +89,13 @@ def test_basic(self): """ self.make_worker_hs( - "synapse.app.generic_worker", {"worker_name": "worker1"}, + "synapse.app.generic_worker", + {"worker_name": "worker1"}, ) self.make_worker_hs( - "synapse.app.generic_worker", {"worker_name": "worker2"}, + "synapse.app.generic_worker", + {"worker_name": "worker2"}, ) persisted_on_1 = False @@ -139,15 +139,18 @@ def test_vector_clock_token(self): """ self.make_worker_hs( - "synapse.app.generic_worker", {"worker_name": "worker1"}, + "synapse.app.generic_worker", + {"worker_name": "worker1"}, ) worker_hs2 = self.make_worker_hs( - "synapse.app.generic_worker", {"worker_name": "worker2"}, + "synapse.app.generic_worker", + {"worker_name": "worker2"}, ) sync_hs = self.make_worker_hs( - "synapse.app.generic_worker", {"worker_name": "sync"}, + "synapse.app.generic_worker", + {"worker_name": "sync"}, ) sync_hs_site = self._hs_to_site[sync_hs] @@ -323,7 +326,9 @@ def test_vector_clock_token(self): sync_hs_site, "GET", "/rooms/{}/messages?from={}&to={}&dir=f".format( - room_id2, vector_clock_token, prev_batch2, + room_id2, + vector_clock_token, + prev_batch2, ), access_token=access_token, ) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 9d22c04073cb..057e27372e1e 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -130,8 +130,7 @@ def _check_group(self, group_id, expect_code): ) def _get_groups_user_is_in(self, access_token): - """Returns the list of groups the user is in (given their access token) - """ + """Returns the list of groups the user is in (given their access token)""" channel = self.make_request( "GET", "/joined_groups".encode("ascii"), access_token=access_token ) @@ -142,8 +141,7 @@ def _get_groups_user_is_in(self, access_token): class QuarantineMediaTestCase(unittest.HomeserverTestCase): - """Test /quarantine_media admin API. - """ + """Test /quarantine_media admin API.""" servlets = [ synapse.rest.admin.register_servlets, @@ -237,7 +235,9 @@ def test_quarantine_media_requires_admin(self): # Attempt quarantine media APIs as non-admin url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345" channel = self.make_request( - "POST", url.encode("ascii"), access_token=non_admin_user_tok, + "POST", + url.encode("ascii"), + access_token=non_admin_user_tok, ) # Expect a forbidden error @@ -250,7 +250,9 @@ def test_quarantine_media_requires_admin(self): # And the roomID/userID endpoint url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine" channel = self.make_request( - "POST", url.encode("ascii"), access_token=non_admin_user_tok, + "POST", + url.encode("ascii"), + access_token=non_admin_user_tok, ) # Expect a forbidden error @@ -294,7 +296,11 @@ def test_quarantine_media_by_id(self): urllib.parse.quote(server_name), urllib.parse.quote(media_id), ) - channel = self.make_request("POST", url, access_token=admin_user_tok,) + channel = self.make_request( + "POST", + url, + access_token=admin_user_tok, + ) self.pump(1.0) self.assertEqual(200, int(channel.code), msg=channel.result["body"]) @@ -346,7 +352,11 @@ def test_quarantine_all_media_in_room(self, override_url_template=None): url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote( room_id ) - channel = self.make_request("POST", url, access_token=admin_user_tok,) + channel = self.make_request( + "POST", + url, + access_token=admin_user_tok, + ) self.pump(1.0) self.assertEqual(200, int(channel.code), msg=channel.result["body"]) self.assertEqual( @@ -391,7 +401,9 @@ def test_quarantine_all_media_by_user(self): non_admin_user ) channel = self.make_request( - "POST", url.encode("ascii"), access_token=admin_user_tok, + "POST", + url.encode("ascii"), + access_token=admin_user_tok, ) self.pump(1.0) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -437,7 +449,9 @@ def test_cannot_quarantine_safe_media(self): non_admin_user ) channel = self.make_request( - "POST", url.encode("ascii"), access_token=admin_user_tok, + "POST", + url.encode("ascii"), + access_token=admin_user_tok, ) self.pump(1.0) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 248c4442c3a2..2a1bcf1760ed 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -70,21 +70,27 @@ def test_requester_is_no_admin(self): If the user is not a server admin, an error is returned. """ channel = self.make_request( - "GET", self.url, access_token=self.other_user_token, + "GET", + self.url, + access_token=self.other_user_token, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) channel = self.make_request( - "PUT", self.url, access_token=self.other_user_token, + "PUT", + self.url, + access_token=self.other_user_token, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) channel = self.make_request( - "DELETE", self.url, access_token=self.other_user_token, + "DELETE", + self.url, + access_token=self.other_user_token, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -99,17 +105,29 @@ def test_user_does_not_exist(self): % self.other_user_device_id ) - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -123,17 +141,29 @@ def test_user_is_not_local(self): % self.other_user_device_id ) - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -146,16 +176,28 @@ def test_unknown_device(self): self.other_user ) - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) - channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) # Delete unknown device returns status 200 self.assertEqual(200, channel.code, msg=channel.json_body) @@ -190,7 +232,11 @@ def test_update_device_too_long_display_name(self): self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) @@ -207,12 +253,20 @@ def test_update_no_display_name(self): ) ) - channel = self.make_request("PUT", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "PUT", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure the display name was not updated. - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) @@ -233,7 +287,11 @@ def test_update_display_name(self): self.assertEqual(200, channel.code, msg=channel.json_body) # Check new display_name - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) @@ -242,7 +300,11 @@ def test_get_device(self): """ Tests that a normal lookup for a device is successfully """ - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) @@ -264,7 +326,9 @@ def test_delete_device(self): # Delete device channel = self.make_request( - "DELETE", self.url, access_token=self.admin_user_tok, + "DELETE", + self.url, + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -306,7 +370,11 @@ def test_requester_is_no_admin(self): """ other_user_token = self.login("user", "pass") - channel = self.make_request("GET", self.url, access_token=other_user_token,) + channel = self.make_request( + "GET", + self.url, + access_token=other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -316,7 +384,11 @@ def test_user_does_not_exist(self): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/devices" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -327,7 +399,11 @@ def test_user_is_not_local(self): """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -339,7 +415,11 @@ def test_user_has_no_devices(self): """ # Get devices - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -355,7 +435,11 @@ def test_get_devices(self): self.login("user", "pass") # Get devices - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_devices, channel.json_body["total"]) @@ -404,7 +488,11 @@ def test_requester_is_no_admin(self): """ other_user_token = self.login("user", "pass") - channel = self.make_request("POST", self.url, access_token=other_user_token,) + channel = self.make_request( + "POST", + self.url, + access_token=other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -414,7 +502,11 @@ def test_user_does_not_exist(self): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" - channel = self.make_request("POST", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "POST", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -425,7 +517,11 @@ def test_user_is_not_local(self): """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" - channel = self.make_request("POST", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "POST", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index d0090faa4fa2..e30ffe4fa0c1 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -51,19 +51,23 @@ def prepare(self, reactor, clock, hs): # Two rooms and two users. Every user sends and reports every room event for i in range(5): self._create_event_and_report( - room_id=self.room_id1, user_tok=self.other_user_tok, + room_id=self.room_id1, + user_tok=self.other_user_tok, ) for i in range(5): self._create_event_and_report( - room_id=self.room_id2, user_tok=self.other_user_tok, + room_id=self.room_id2, + user_tok=self.other_user_tok, ) for i in range(5): self._create_event_and_report( - room_id=self.room_id1, user_tok=self.admin_user_tok, + room_id=self.room_id1, + user_tok=self.admin_user_tok, ) for i in range(5): self._create_event_and_report( - room_id=self.room_id2, user_tok=self.admin_user_tok, + room_id=self.room_id2, + user_tok=self.admin_user_tok, ) self.url = "/_synapse/admin/v1/event_reports" @@ -82,7 +86,11 @@ def test_requester_is_no_admin(self): If the user is not a server admin, an error 403 is returned. """ - channel = self.make_request("GET", self.url, access_token=self.other_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.other_user_tok, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -92,7 +100,11 @@ def test_default_success(self): Testing list of reported events """ - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -106,7 +118,9 @@ def test_limit(self): """ channel = self.make_request( - "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -121,7 +135,9 @@ def test_from(self): """ channel = self.make_request( - "GET", self.url + "?from=5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -136,7 +152,9 @@ def test_limit_and_from(self): """ channel = self.make_request( - "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -213,7 +231,9 @@ def test_valid_search_order(self): # fetch the most recent first, largest timestamp channel = self.make_request( - "GET", self.url + "?dir=b", access_token=self.admin_user_tok, + "GET", + self.url + "?dir=b", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -229,7 +249,9 @@ def test_valid_search_order(self): # fetch the oldest first, smallest timestamp channel = self.make_request( - "GET", self.url + "?dir=f", access_token=self.admin_user_tok, + "GET", + self.url + "?dir=f", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -249,7 +271,9 @@ def test_invalid_search_order(self): """ channel = self.make_request( - "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -262,7 +286,9 @@ def test_limit_is_negative(self): """ channel = self.make_request( - "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -274,7 +300,9 @@ def test_from_is_negative(self): """ channel = self.make_request( - "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -288,7 +316,9 @@ def test_next_token(self): # `next_token` does not appear # Number of results is the number of entries channel = self.make_request( - "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=20", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -299,7 +329,9 @@ def test_next_token(self): # `next_token` does not appear # Number of max results is larger than the number of entries channel = self.make_request( - "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=21", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -310,7 +342,9 @@ def test_next_token(self): # `next_token` does appear # Number of max results is smaller than the number of entries channel = self.make_request( - "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -322,7 +356,9 @@ def test_next_token(self): # Set `from` to value of `next_token` for request remaining entries # `next_token` does not appear channel = self.make_request( - "GET", self.url + "?from=19", access_token=self.admin_user_tok, + "GET", + self.url + "?from=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -331,8 +367,7 @@ def test_next_token(self): self.assertNotIn("next_token", channel.json_body) def _create_event_and_report(self, room_id, user_tok): - """Create and report events - """ + """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -345,8 +380,7 @@ def _create_event_and_report(self, room_id, user_tok): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) def _check_fields(self, content): - """Checks that all attributes are present in an event report - """ + """Checks that all attributes are present in an event report""" for c in content: self.assertIn("id", c) self.assertIn("received_ts", c) @@ -381,7 +415,8 @@ def prepare(self, reactor, clock, hs): self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok) self._create_event_and_report( - room_id=self.room_id1, user_tok=self.other_user_tok, + room_id=self.room_id1, + user_tok=self.other_user_tok, ) # first created event report gets `id`=2 @@ -401,7 +436,11 @@ def test_requester_is_no_admin(self): If the user is not a server admin, an error 403 is returned. """ - channel = self.make_request("GET", self.url, access_token=self.other_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.other_user_tok, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -411,7 +450,11 @@ def test_default_success(self): Testing get a reported event """ - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self._check_fields(channel.json_body) @@ -479,8 +522,7 @@ def test_report_id_not_found(self): self.assertEqual("Event report not found", channel.json_body["error"]) def _create_event_and_report(self, room_id, user_tok): - """Create and report events - """ + """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -493,8 +535,7 @@ def _create_event_and_report(self, room_id, user_tok): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) def _check_fields(self, content): - """Checks that all attributes are present in a event report - """ + """Checks that all attributes are present in a event report""" self.assertIn("id", content) self.assertIn("received_ts", content) self.assertIn("room_id", content) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 51a7731693cd..31db472cd32b 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -63,7 +63,11 @@ def test_requester_is_no_admin(self): url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") - channel = self.make_request("DELETE", url, access_token=self.other_user_token,) + channel = self.make_request( + "DELETE", + url, + access_token=self.other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -74,7 +78,11 @@ def test_media_does_not_exist(self): """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") - channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -85,7 +93,11 @@ def test_media_is_not_local(self): """ url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") - channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) @@ -139,12 +151,17 @@ def test_delete_media(self): url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id) # Delete media - channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( - media_id, channel.json_body["deleted_media"][0], + media_id, + channel.json_body["deleted_media"][0], ) # Attempt to access media @@ -207,7 +224,9 @@ def test_requester_is_no_admin(self): self.other_user_token = self.login("user", "pass") channel = self.make_request( - "POST", self.url, access_token=self.other_user_token, + "POST", + self.url, + access_token=self.other_user_token, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -220,7 +239,9 @@ def test_media_is_not_local(self): url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" channel = self.make_request( - "POST", url + "?before_ts=1234", access_token=self.admin_user_tok, + "POST", + url + "?before_ts=1234", + access_token=self.admin_user_tok, ) self.assertEqual(400, channel.code, msg=channel.json_body) @@ -230,7 +251,11 @@ def test_missing_parameter(self): """ If the parameter `before_ts` is missing, an error is returned. """ - channel = self.make_request("POST", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) @@ -243,7 +268,9 @@ def test_invalid_parameter(self): If parameters are invalid, an error is returned. """ channel = self.make_request( - "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok, + "POST", + self.url + "?before_ts=-1234", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -304,7 +331,8 @@ def test_delete_media_never_accessed(self): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( - media_id, channel.json_body["deleted_media"][0], + media_id, + channel.json_body["deleted_media"][0], ) self._access_media(server_and_media_id, False) @@ -340,7 +368,8 @@ def test_keep_media_by_date(self): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( - server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + server_and_media_id.split("/")[1], + channel.json_body["deleted_media"][0], ) self._access_media(server_and_media_id, False) @@ -374,7 +403,8 @@ def test_keep_media_by_size(self): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( - server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + server_and_media_id.split("/")[1], + channel.json_body["deleted_media"][0], ) self._access_media(server_and_media_id, False) @@ -417,7 +447,8 @@ def test_keep_media_by_user_avatar(self): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( - server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + server_and_media_id.split("/")[1], + channel.json_body["deleted_media"][0], ) self._access_media(server_and_media_id, False) @@ -461,7 +492,8 @@ def test_keep_media_by_room_avatar(self): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( - server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + server_and_media_id.split("/")[1], + channel.json_body["deleted_media"][0], ) self._access_media(server_and_media_id, False) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 2a217b1ce05f..b55160b70afa 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -127,8 +127,7 @@ def test_shutdown_room_block_peek(self): self._assert_peek(room_id, expect_code=403) def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room. - """ + """Assert that the admin user can (or cannot) peek into the room.""" url = "rooms/%s/initialSync" % (room_id,) channel = self.make_request( @@ -186,7 +185,10 @@ def test_requester_is_no_admin(self): """ channel = self.make_request( - "POST", self.url, json.dumps({}), access_token=self.other_user_tok, + "POST", + self.url, + json.dumps({}), + access_token=self.other_user_tok, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -199,7 +201,10 @@ def test_room_does_not_exist(self): url = "/_synapse/admin/v1/rooms/!unknown:test/delete" channel = self.make_request( - "POST", url, json.dumps({}), access_token=self.admin_user_tok, + "POST", + url, + json.dumps({}), + access_token=self.admin_user_tok, ) self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) @@ -212,12 +217,16 @@ def test_room_is_not_valid(self): url = "/_synapse/admin/v1/rooms/invalidroom/delete" channel = self.make_request( - "POST", url, json.dumps({}), access_token=self.admin_user_tok, + "POST", + url, + json.dumps({}), + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - "invalidroom is not a legal room ID", channel.json_body["error"], + "invalidroom is not a legal room ID", + channel.json_body["error"], ) def test_new_room_user_does_not_exist(self): @@ -254,7 +263,8 @@ def test_new_room_user_is_not_local(self): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - "User must be our own: @not:exist.bla", channel.json_body["error"], + "User must be our own: @not:exist.bla", + channel.json_body["error"], ) def test_block_is_not_bool(self): @@ -491,8 +501,7 @@ def test_shutdown_room_block_peek(self): self._assert_peek(self.room_id, expect_code=403) def _is_blocked(self, room_id, expect=True): - """Assert that the room is blocked or not - """ + """Assert that the room is blocked or not""" d = self.store.is_room_blocked(room_id) if expect: self.assertTrue(self.get_success(d)) @@ -500,20 +509,17 @@ def _is_blocked(self, room_id, expect=True): self.assertIsNone(self.get_success(d)) def _has_no_members(self, room_id): - """Assert there is now no longer anyone in the room - """ + """Assert there is now no longer anyone in the room""" users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertEqual([], users_in_room) def _is_member(self, room_id, user_id): - """Test that user is member of the room - """ + """Test that user is member of the room""" users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertIn(user_id, users_in_room) def _is_purged(self, room_id): - """Test that the following tables have been purged of all rows related to the room. - """ + """Test that the following tables have been purged of all rows related to the room.""" for table in PURGE_TABLES: count = self.get_success( self.store.db_pool.simple_select_one_onecol( @@ -527,8 +533,7 @@ def _is_purged(self, room_id): self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room. - """ + """Assert that the admin user can (or cannot) peek into the room.""" url = "rooms/%s/initialSync" % (room_id,) channel = self.make_request( @@ -548,8 +553,7 @@ def _assert_peek(self, room_id, expect_code): class PurgeRoomTestCase(unittest.HomeserverTestCase): - """Test /purge_room admin API. - """ + """Test /purge_room admin API.""" servlets = [ synapse.rest.admin.register_servlets, @@ -594,8 +598,7 @@ def test_purge_room(self): class RoomTestCase(unittest.HomeserverTestCase): - """Test /room admin API. - """ + """Test /room admin API.""" servlets = [ synapse.rest.admin.register_servlets, @@ -623,7 +626,9 @@ def test_list_rooms(self): # Request the list of rooms url = "/_synapse/admin/v1/rooms" channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) # Check request completed successfully @@ -685,7 +690,10 @@ def test_list_rooms_pagination(self): # Set the name of the rooms so we get a consistent returned ordering for idx, room_id in enumerate(room_ids): self.helper.send_state( - room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, + room_id, + "m.room.name", + {"name": str(idx)}, + tok=self.admin_user_tok, ) # Request the list of rooms @@ -704,7 +712,9 @@ def test_list_rooms_pagination(self): "name", ) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual( 200, int(channel.result["code"]), msg=channel.result["body"] @@ -744,7 +754,9 @@ def test_list_rooms_pagination(self): url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -788,13 +800,18 @@ def test_correct_room_attributes(self): # Set a name for the room self.helper.send_state( - room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, + room_id, + "m.room.name", + {"name": test_room_name}, + tok=self.admin_user_tok, ) # Request the list of rooms url = "/_synapse/admin/v1/rooms" channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -860,7 +877,9 @@ def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str): ) def _order_test( - order_type: str, expected_room_list: List[str], reverse: bool = False, + order_type: str, + expected_room_list: List[str], + reverse: bool = False, ): """Request the list of rooms in a certain order. Assert that order is what we expect @@ -875,7 +894,9 @@ def _order_test( if reverse: url += "&dir=b" channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -907,13 +928,22 @@ def _order_test( # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C self.helper.send_state( - room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, + room_id_1, + "m.room.name", + {"name": "A"}, + tok=self.admin_user_tok, ) self.helper.send_state( - room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, + room_id_2, + "m.room.name", + {"name": "B"}, + tok=self.admin_user_tok, ) self.helper.send_state( - room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, + room_id_3, + "m.room.name", + {"name": "C"}, + tok=self.admin_user_tok, ) # Set room canonical room aliases @@ -990,10 +1020,16 @@ def test_search_term(self): # Set the name for each room self.helper.send_state( - room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + room_id_1, + "m.room.name", + {"name": room_name_1}, + tok=self.admin_user_tok, ) self.helper.send_state( - room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + room_id_2, + "m.room.name", + {"name": room_name_2}, + tok=self.admin_user_tok, ) def _search_test( @@ -1011,7 +1047,9 @@ def _search_test( """ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) @@ -1071,15 +1109,23 @@ def test_single_room(self): # Set the name for each room self.helper.send_state( - room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + room_id_1, + "m.room.name", + {"name": room_name_1}, + tok=self.admin_user_tok, ) self.helper.send_state( - room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + room_id_2, + "m.room.name", + {"name": room_name_2}, + tok=self.admin_user_tok, ) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1109,7 +1155,9 @@ def test_single_room_devices(self): url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["joined_local_devices"]) @@ -1121,7 +1169,9 @@ def test_single_room_devices(self): url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(2, channel.json_body["joined_local_devices"]) @@ -1131,7 +1181,9 @@ def test_single_room_devices(self): self.helper.leave(room_id_1, user_1, tok=user_tok_1) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["joined_local_devices"]) @@ -1160,7 +1212,9 @@ def test_room_members(self): url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1171,7 +1225,9 @@ def test_room_members(self): url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1187,7 +1243,9 @@ def test_room_state(self): url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("state", channel.json_body) @@ -1342,7 +1400,9 @@ def test_join_public_room(self): # Validate if user is a member of the room channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + "GET", + "/_matrix/client/r0/joined_rooms", + access_token=self.second_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) @@ -1389,7 +1449,9 @@ def test_join_private_room_if_member(self): # Validate if server admin is a member of the room channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, + "GET", + "/_matrix/client/r0/joined_rooms", + access_token=self.admin_user_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) @@ -1411,7 +1473,9 @@ def test_join_private_room_if_member(self): # Validate if user is a member of the room channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + "GET", + "/_matrix/client/r0/joined_rooms", + access_token=self.second_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) @@ -1440,7 +1504,9 @@ def test_join_private_room_if_owner(self): # Validate if user is a member of the room channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + "GET", + "/_matrix/client/r0/joined_rooms", + access_token=self.second_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) @@ -1555,8 +1621,7 @@ def prepare(self, reactor, clock, homeserver): ) def test_public_room(self): - """Test that getting admin in a public room works. - """ + """Test that getting admin in a public room works.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True ) @@ -1581,10 +1646,11 @@ def test_public_room(self): ) def test_private_room(self): - """Test that getting admin in a private room works and we get invited. - """ + """Test that getting admin in a private room works and we get invited.""" room_id = self.helper.create_room_as( - self.creator, tok=self.creator_tok, is_public=False, + self.creator, + tok=self.creator_tok, + is_public=False, ) channel = self.make_request( @@ -1608,8 +1674,7 @@ def test_private_room(self): ) def test_other_user(self): - """Test that giving admin in a public room works to a non-admin user works. - """ + """Test that giving admin in a public room works to a non-admin user works.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True ) @@ -1634,8 +1699,7 @@ def test_other_user(self): ) def test_not_enough_power(self): - """Test that we get a sensible error if there are no local room admins. - """ + """Test that we get a sensible error if there are no local room admins.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True ) diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index f48be3d65aa0..1f1d11f527d5 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -55,7 +55,10 @@ def test_requester_is_no_admin(self): If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( - "GET", self.url, json.dumps({}), access_token=self.other_user_tok, + "GET", + self.url, + json.dumps({}), + access_token=self.other_user_tok, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -67,7 +70,9 @@ def test_invalid_parameter(self): """ # unkown order_by channel = self.make_request( - "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok, + "GET", + self.url + "?order_by=bar", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -75,7 +80,9 @@ def test_invalid_parameter(self): # negative from channel = self.make_request( - "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -83,7 +90,9 @@ def test_invalid_parameter(self): # negative limit channel = self.make_request( - "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -91,7 +100,9 @@ def test_invalid_parameter(self): # negative from_ts channel = self.make_request( - "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok, + "GET", + self.url + "?from_ts=-1234", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -99,7 +110,9 @@ def test_invalid_parameter(self): # negative until_ts channel = self.make_request( - "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok, + "GET", + self.url + "?until_ts=-1234", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -117,7 +130,9 @@ def test_invalid_parameter(self): # empty search term channel = self.make_request( - "GET", self.url + "?search_term=", access_token=self.admin_user_tok, + "GET", + self.url + "?search_term=", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -125,7 +140,9 @@ def test_invalid_parameter(self): # invalid search order channel = self.make_request( - "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -138,7 +155,9 @@ def test_limit(self): self._create_users_with_media(10, 2) channel = self.make_request( - "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -154,7 +173,9 @@ def test_from(self): self._create_users_with_media(20, 2) channel = self.make_request( - "GET", self.url + "?from=5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -170,7 +191,9 @@ def test_limit_and_from(self): self._create_users_with_media(20, 2) channel = self.make_request( - "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -190,7 +213,9 @@ def test_next_token(self): # `next_token` does not appear # Number of results is the number of entries channel = self.make_request( - "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=20", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -201,7 +226,9 @@ def test_next_token(self): # `next_token` does not appear # Number of max results is larger than the number of entries channel = self.make_request( - "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=21", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -212,7 +239,9 @@ def test_next_token(self): # `next_token` does appear # Number of max results is smaller than the number of entries channel = self.make_request( - "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -223,7 +252,9 @@ def test_next_token(self): # Set `from` to value of `next_token` for request remaining entries # Check `next_token` does not appear channel = self.make_request( - "GET", self.url + "?from=19", access_token=self.admin_user_tok, + "GET", + self.url + "?from=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -237,7 +268,11 @@ def test_no_media(self): if users have no media created """ - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -264,10 +299,14 @@ def test_order_by(self): # order by user_id self._order_test("user_id", ["@user_a:test", "@user_b:test", "@user_c:test"]) self._order_test( - "user_id", ["@user_a:test", "@user_b:test", "@user_c:test"], "f", + "user_id", + ["@user_a:test", "@user_b:test", "@user_c:test"], + "f", ) self._order_test( - "user_id", ["@user_c:test", "@user_b:test", "@user_a:test"], "b", + "user_id", + ["@user_c:test", "@user_b:test", "@user_a:test"], + "b", ) # order by displayname @@ -275,32 +314,46 @@ def test_order_by(self): "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"] ) self._order_test( - "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"], "f", + "displayname", + ["@user_c:test", "@user_b:test", "@user_a:test"], + "f", ) self._order_test( - "displayname", ["@user_a:test", "@user_b:test", "@user_c:test"], "b", + "displayname", + ["@user_a:test", "@user_b:test", "@user_c:test"], + "b", ) # order by media_length self._order_test( - "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"], + "media_length", + ["@user_a:test", "@user_c:test", "@user_b:test"], ) self._order_test( - "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"], "f", + "media_length", + ["@user_a:test", "@user_c:test", "@user_b:test"], + "f", ) self._order_test( - "media_length", ["@user_b:test", "@user_c:test", "@user_a:test"], "b", + "media_length", + ["@user_b:test", "@user_c:test", "@user_a:test"], + "b", ) # order by media_count self._order_test( - "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"], + "media_count", + ["@user_a:test", "@user_c:test", "@user_b:test"], ) self._order_test( - "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"], "f", + "media_count", + ["@user_a:test", "@user_c:test", "@user_b:test"], + "f", ) self._order_test( - "media_count", ["@user_b:test", "@user_c:test", "@user_a:test"], "b", + "media_count", + ["@user_b:test", "@user_c:test", "@user_a:test"], + "b", ) def test_from_until_ts(self): @@ -313,14 +366,20 @@ def test_from_until_ts(self): ts1 = self.clock.time_msec() # list all media when filter is not set - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media starting at `ts1` after creating first media # result is 0 channel = self.make_request( - "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, + "GET", + self.url + "?from_ts=%s" % (ts1,), + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 0) @@ -342,7 +401,9 @@ def test_from_until_ts(self): # filter media until `ts2` and earlier channel = self.make_request( - "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, + "GET", + self.url + "?until_ts=%s" % (ts2,), + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 6) @@ -351,7 +412,11 @@ def test_search_term(self): self._create_users_with_media(20, 1) # check without filter get all users - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -376,7 +441,9 @@ def test_search_term(self): # filter and get empty result channel = self.make_request( - "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok, + "GET", + self.url + "?search_term=foobar", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 0) @@ -441,7 +508,9 @@ def _order_test( if dir is not None and dir in ("b", "f"): url += "&dir=%s" % (dir,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ee05ee60bc6c..ff75199c8e0e 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -528,9 +528,14 @@ def _search_test( search_field: Field which is to request: `name` or `user_id` expected_http_code: The expected http code for the request """ - url = self.url + "?%s=%s" % (search_field, search_term,) + url = self.url + "?%s=%s" % ( + search_field, + search_term, + ) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) @@ -590,7 +595,9 @@ def test_invalid_parameter(self): # negative limit channel = self.make_request( - "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -598,7 +605,9 @@ def test_invalid_parameter(self): # negative from channel = self.make_request( - "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -606,7 +615,9 @@ def test_invalid_parameter(self): # invalid guests channel = self.make_request( - "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok, + "GET", + self.url + "?guests=not_bool", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -614,7 +625,9 @@ def test_invalid_parameter(self): # invalid deactivated channel = self.make_request( - "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok, + "GET", + self.url + "?deactivated=not_bool", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -630,7 +643,9 @@ def test_limit(self): self._create_users(number_users - 1) channel = self.make_request( - "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -649,7 +664,9 @@ def test_from(self): self._create_users(number_users - 1) channel = self.make_request( - "GET", self.url + "?from=5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -668,7 +685,9 @@ def test_limit_and_from(self): self._create_users(number_users - 1) channel = self.make_request( - "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -689,7 +708,9 @@ def test_next_token(self): # `next_token` does not appear # Number of results is the number of entries channel = self.make_request( - "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=20", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -700,7 +721,9 @@ def test_next_token(self): # `next_token` does not appear # Number of max results is larger than the number of entries channel = self.make_request( - "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=21", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -711,7 +734,9 @@ def test_next_token(self): # `next_token` does appear # Number of max results is smaller than the number of entries channel = self.make_request( - "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -723,7 +748,9 @@ def test_next_token(self): # Set `from` to value of `next_token` for request remaining entries # `next_token` does not appear channel = self.make_request( - "GET", self.url + "?from=19", access_token=self.admin_user_tok, + "GET", + self.url + "?from=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -753,7 +780,10 @@ def _create_users(self, number_users: int): """ for i in range(1, number_users + 1): self.register_user( - "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i, + "user%d" % i, + "pass%d" % i, + admin=False, + displayname="Name %d" % i, ) @@ -808,7 +838,10 @@ def test_requester_is_not_admin(self): self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( - "POST", url, access_token=self.other_user_token, content=b"{}", + "POST", + url, + access_token=self.other_user_token, + content=b"{}", ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -862,7 +895,9 @@ def test_deactivate_user_erase_true(self): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -886,7 +921,9 @@ def test_deactivate_user_erase_true(self): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -905,7 +942,9 @@ def test_deactivate_user_erase_false(self): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -929,7 +968,9 @@ def test_deactivate_user_erase_false(self): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -942,8 +983,7 @@ def test_deactivate_user_erase_false(self): self._is_erased("@user:test", False) def _is_erased(self, user_id: str, expect: bool) -> None: - """Assert that the user is erased or not - """ + """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) if expect: self.assertTrue(self.get_success(d)) @@ -977,13 +1017,20 @@ def test_requester_is_no_admin(self): """ url = "/_synapse/admin/v2/users/@bob:test" - channel = self.make_request("GET", url, access_token=self.other_user_token,) + channel = self.make_request( + "GET", + url, + access_token=self.other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( - "PUT", url, access_token=self.other_user_token, content=b"{}", + "PUT", + url, + access_token=self.other_user_token, + content=b"{}", ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -1036,7 +1083,11 @@ def test_create_server_admin(self): self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -1081,7 +1132,11 @@ def test_create_user(self): self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -1306,7 +1361,9 @@ def test_set_displayname(self): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1337,7 +1394,9 @@ def test_set_threepid(self): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1360,7 +1419,9 @@ def test_deactivate_user(self): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1390,7 +1451,9 @@ def test_deactivate_user(self): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1488,7 +1551,9 @@ def test_reactivate_user(self): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1517,7 +1582,9 @@ def test_set_user_as_admin(self): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1546,7 +1613,11 @@ def test_accidental_deactivation_prevention(self): self.assertEqual("bob", channel.json_body["displayname"]) # Get user - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -1566,7 +1637,11 @@ def test_accidental_deactivation_prevention(self): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) # Check user is not deactivated - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -1576,8 +1651,7 @@ def test_accidental_deactivation_prevention(self): self.assertEqual(0, channel.json_body["deactivated"]) def _is_erased(self, user_id, expect): - """Assert that the user is erased or not - """ + """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) if expect: self.assertTrue(self.get_success(d)) @@ -1617,7 +1691,11 @@ def test_requester_is_no_admin(self): """ other_user_token = self.login("user", "pass") - channel = self.make_request("GET", self.url, access_token=other_user_token,) + channel = self.make_request( + "GET", + self.url, + access_token=other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1627,7 +1705,11 @@ def test_user_does_not_exist(self): Tests that a lookup for a user that does not exist returns an empty list """ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1639,7 +1721,11 @@ def test_user_is_not_local(self): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1651,7 +1737,11 @@ def test_no_memberships(self): if user has no memberships """ # Get rooms - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1668,7 +1758,11 @@ def test_get_rooms(self): self.helper.create_room_as(self.other_user, tok=other_user_tok) # Get rooms - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) @@ -1711,7 +1805,11 @@ def test_get_rooms_with_nonlocal_user(self): # Now get rooms url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) @@ -1751,7 +1849,11 @@ def test_requester_is_no_admin(self): """ other_user_token = self.login("user", "pass") - channel = self.make_request("GET", self.url, access_token=other_user_token,) + channel = self.make_request( + "GET", + self.url, + access_token=other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1761,7 +1863,11 @@ def test_user_does_not_exist(self): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1772,7 +1878,11 @@ def test_user_is_not_local(self): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1783,7 +1893,11 @@ def test_get_pushers(self): """ # Get pushers - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1810,7 +1924,11 @@ def test_get_pushers(self): ) # Get pushers - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) @@ -1859,7 +1977,11 @@ def test_requester_is_no_admin(self): """ other_user_token = self.login("user", "pass") - channel = self.make_request("GET", self.url, access_token=other_user_token,) + channel = self.make_request( + "GET", + self.url, + access_token=other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1869,7 +1991,11 @@ def test_user_does_not_exist(self): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/media" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1880,7 +2006,11 @@ def test_user_is_not_local(self): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1895,7 +2025,9 @@ def test_limit(self): self._create_media(other_user_tok, number_media) channel = self.make_request( - "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1914,7 +2046,9 @@ def test_from(self): self._create_media(other_user_tok, number_media) channel = self.make_request( - "GET", self.url + "?from=5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1933,7 +2067,9 @@ def test_limit_and_from(self): self._create_media(other_user_tok, number_media) channel = self.make_request( - "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1948,7 +2084,9 @@ def test_limit_is_negative(self): """ channel = self.make_request( - "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -1960,7 +2098,9 @@ def test_from_is_negative(self): """ channel = self.make_request( - "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -1978,7 +2118,9 @@ def test_next_token(self): # `next_token` does not appear # Number of results is the number of entries channel = self.make_request( - "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=20", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1989,7 +2131,9 @@ def test_next_token(self): # `next_token` does not appear # Number of max results is larger than the number of entries channel = self.make_request( - "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=21", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -2000,7 +2144,9 @@ def test_next_token(self): # `next_token` does appear # Number of max results is smaller than the number of entries channel = self.make_request( - "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -2012,7 +2158,9 @@ def test_next_token(self): # Set `from` to value of `next_token` for request remaining entries # `next_token` does not appear channel = self.make_request( - "GET", self.url + "?from=19", access_token=self.admin_user_tok, + "GET", + self.url + "?from=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -2026,7 +2174,11 @@ def test_user_has_no_media(self): if user has no media created """ - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -2041,7 +2193,11 @@ def test_get_media(self): other_user_tok = self.login("user", "pass") self._create_media(other_user_tok, number_media) - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) @@ -2068,8 +2224,7 @@ def _create_media(self, user_token, number_media): ) def _check_fields(self, content): - """Checks that all attributes are present in content - """ + """Checks that all attributes are present in content""" for m in content: self.assertIn("media_id", m) self.assertIn("media_type", m) @@ -2082,8 +2237,7 @@ def _check_fields(self, content): class UserTokenRestTestCase(unittest.HomeserverTestCase): - """Test for /_synapse/admin/v1/users//login - """ + """Test for /_synapse/admin/v1/users//login""" servlets = [ synapse.rest.admin.register_servlets, @@ -2114,16 +2268,14 @@ def _get_token(self) -> str: return channel.json_body["access_token"] def test_no_auth(self): - """Try to login as a user without authentication. - """ + """Try to login as a user without authentication.""" channel = self.make_request("POST", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_not_admin(self): - """Try to login as a user as a non-admin user. - """ + """Try to login as a user as a non-admin user.""" channel = self.make_request( "POST", self.url, b"{}", access_token=self.other_user_tok ) @@ -2131,8 +2283,7 @@ def test_not_admin(self): self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) def test_send_event(self): - """Test that sending event as a user works. - """ + """Test that sending event as a user works.""" # Create a room. room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok) @@ -2146,8 +2297,7 @@ def test_send_event(self): self.assertEqual(event.sender, self.other_user) def test_devices(self): - """Tests that logging in as a user doesn't create a new device for them. - """ + """Tests that logging in as a user doesn't create a new device for them.""" # Login in as the user self._get_token() @@ -2161,8 +2311,7 @@ def test_devices(self): self.assertEqual(len(channel.json_body["devices"]), 1) def test_logout(self): - """Test that calling `/logout` with the token works. - """ + """Test that calling `/logout` with the token works.""" # Login in as the user puppet_token = self._get_token() @@ -2252,8 +2401,7 @@ def test_admin_logout_all(self): } ) def test_consent(self): - """Test that sending a message is not subject to the privacy policies. - """ + """Test that sending a message is not subject to the privacy policies.""" # Have the admin user accept the terms. self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0")) @@ -2328,11 +2476,19 @@ def test_requester_is_not_admin(self): self.register_user("user2", "pass") other_user2_token = self.login("user2", "pass") - channel = self.make_request("GET", self.url1, access_token=other_user2_token,) + channel = self.make_request( + "GET", + self.url1, + access_token=other_user2_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - channel = self.make_request("GET", self.url2, access_token=other_user2_token,) + channel = self.make_request( + "GET", + self.url2, + access_token=other_user2_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -2343,11 +2499,19 @@ def test_user_is_not_local(self): url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain" - channel = self.make_request("GET", url1, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url1, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only whois a local user", channel.json_body["error"]) - channel = self.make_request("GET", url2, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url2, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only whois a local user", channel.json_body["error"]) @@ -2355,12 +2519,20 @@ def test_get_whois_admin(self): """ The lookup should succeed for an admin. """ - channel = self.make_request("GET", self.url1, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url1, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) - channel = self.make_request("GET", self.url2, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url2, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -2371,12 +2543,20 @@ def test_get_whois_user(self): """ other_user_token = self.login("user", "pass") - channel = self.make_request("GET", self.url1, access_token=other_user_token,) + channel = self.make_request( + "GET", + self.url1, + access_token=other_user_token, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) - channel = self.make_request("GET", self.url2, access_token=other_user_token,) + channel = self.make_request( + "GET", + self.url2, + access_token=other_user_token, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index 913ea3c98e19..5256c11fe672 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -73,7 +73,9 @@ def prepare(self, reactor, clock, hs): # Mod the mod room_power_levels = self.helper.get_state( - self.room_id, "m.room.power_levels", tok=self.admin_access_token, + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, ) # Update existing power levels with mod at PL50 diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index f0707646bb33..e0c74591b643 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -181,8 +181,7 @@ def test_redact_create_event(self): ) def test_redact_event_as_moderator_ratelimit(self): - """Tests that the correct ratelimiting is applied to redactions - """ + """Tests that the correct ratelimiting is applied to redactions""" message_ids = [] # as a regular user, send messages to redact diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 31dc832fd550..aee99bb6a0aa 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -250,7 +250,8 @@ def make_homeserver(self, reactor, clock): mock_federation_client = Mock(spec=["backfill"]) self.hs = self.setup_test_homeserver( - config=config, federation_client=mock_federation_client, + config=config, + federation_client=mock_federation_client, ) return self.hs diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 0ebdf1415b2d..d2cce44032fa 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -260,7 +260,10 @@ def test_displayname(self): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, room_id, "m.room.member", self.banned_user_id, + self.banned_user_id, + room_id, + "m.room.member", + self.banned_user_id, ) ) self.assertEqual( @@ -292,7 +295,10 @@ def test_room_displayname(self): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, room_id, "m.room.member", self.banned_user_id, + self.banned_user_id, + room_id, + "m.room.member", + self.banned_user_id, ) ) self.assertEqual( diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 0a5ca317ea8b..2ae896db1ec9 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -150,6 +150,8 @@ def test_get_event_via_events(self): event_id = resp["event_id"] channel = self.make_request( - "GET", "/events/" + event_id, access_token=self.token, + "GET", + "/events/" + event_id, + access_token=self.token, ) self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 49543d9acb78..fb29eaed6f08 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -611,7 +611,9 @@ def test_login_via_oidc(self): # matrix access token, mxid, and device id. login_token = params[2][1] chan = self.make_request( - "POST", "/login", content={"type": "m.login.token", "token": login_token}, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, ) self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@user1:test") @@ -619,7 +621,8 @@ def test_login_via_oidc(self): def test_multi_sso_redirect_to_unknown(self): """An unknown IdP should cause a 400""" channel = self.make_request( - "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", + "GET", + "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", ) self.assertEqual(channel.code, 400, channel.result) @@ -719,7 +722,8 @@ async def get_raw(uri, args): mocked_http_client.get_raw.side_effect = get_raw self.hs = self.setup_test_homeserver( - config=config, proxied_http_client=mocked_http_client, + config=config, + proxied_http_client=mocked_http_client, ) return self.hs @@ -1244,7 +1248,9 @@ def test_username_picker(self): # looks ok. username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions self.assertIn( - session_id, username_mapping_sessions, "session id not found in map", + session_id, + username_mapping_sessions, + "session id not found in map", ) session = username_mapping_sessions[session_id] self.assertEqual(session.remote_user_id, "tester") @@ -1299,7 +1305,9 @@ def test_username_picker(self): # finally, submit the matrix login token to the login API, which gives us our # matrix access token, mxid, and device id. chan = self.make_request( - "POST", "/login", content={"type": "m.login.token", "token": login_token}, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, ) self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2548b3a80c97..ed65f645fc2c 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -46,7 +46,9 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", federation_http_client=None, federation_client=Mock(), + "red", + federation_http_client=None, + federation_client=Mock(), ) self.hs.get_federation_handler = Mock() @@ -1480,7 +1482,9 @@ def test_search_filter_labels(self): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), 2, [result["result"]["content"] for result in results], + len(results), + 2, + [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], @@ -1515,7 +1519,9 @@ def test_search_filter_not_labels(self): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), 4, [result["result"]["content"] for result in results], + len(results), + 4, + [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], @@ -1562,7 +1568,9 @@ def test_search_filter_labels_not_labels(self): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), 1, [result["result"]["content"] for result in results], + len(results), + 1, + [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index f6f3b9a356c5..329dbd06def2 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -37,7 +37,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", federation_http_client=None, federation_client=Mock(), + "red", + federation_http_client=None, + federation_client=Mock(), ) self.event_source = hs.get_event_sources().sources["typing"] diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index b1333df82daa..8231a423f336 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -166,9 +166,12 @@ def change_membership( json.dumps(data).encode("utf8"), ) - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], ) self.auth_user_id = temp_id @@ -201,9 +204,12 @@ def send_event( json.dumps(content).encode("utf8"), ) - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], ) return channel.json_body @@ -251,9 +257,12 @@ def _read_write_state( channel = make_request(self.hs.get_reactor(), self.site, method, path, content) - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], ) return channel.json_body @@ -447,7 +456,10 @@ def auth_via_oidc( return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) def complete_oidc_auth( - self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, + self, + oauth_uri: str, + cookies: Mapping[str, str], + user_info_dict: JsonDict, ) -> FakeChannel: """Mock out an OIDC authentication flow @@ -491,7 +503,9 @@ async def mock_req(method: str, uri: str, data=None, headers=None): (expected_uri, resp_obj) = expected_requests.pop(0) assert uri == expected_uri resp = FakeResponse( - code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"), + code=200, + phrase=b"OK", + body=json.dumps(resp_obj).encode("utf-8"), ) return resp diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 177dc476da78..e72b61963d11 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -75,8 +75,7 @@ def prepare(self, reactor, clock, hs): self.submit_token_resource = PasswordResetSubmitTokenResource(hs) def test_basic_password_reset(self): - """Test basic password reset flow - """ + """Test basic password reset flow""" old_password = "monkey" new_password = "kangeroo" @@ -114,8 +113,7 @@ def test_basic_password_reset(self): @override_config({"rc_3pid_validation": {"burst_count": 3}}) def test_ratelimit_by_email(self): - """Test that we ratelimit /requestToken for the same email. - """ + """Test that we ratelimit /requestToken for the same email.""" old_password = "monkey" new_password = "kangeroo" @@ -203,8 +201,7 @@ def test_basic_password_reset_canonicalise_email(self): self.attempt_wrong_password_login("kermit", old_password) def test_cant_reset_password_without_clicking_link(self): - """Test that we do actually need to click the link in the email - """ + """Test that we do actually need to click the link in the email""" old_password = "monkey" new_password = "kangeroo" @@ -299,7 +296,9 @@ def _request_token(self, email, client_secret, ip="127.0.0.1"): if channel.code != 200: raise HttpResponseException( - channel.code, channel.result["reason"], channel.result["body"], + channel.code, + channel.result["reason"], + channel.result["body"], ) return channel.json_body["sid"] @@ -566,8 +565,7 @@ def test_address_trim(self): @override_config({"rc_3pid_validation": {"burst_count": 3}}) def test_ratelimit_by_ip(self): - """Tests that adding emails is ratelimited by IP - """ + """Tests that adding emails is ratelimited by IP""" # We expect to be able to set three emails before getting ratelimited. self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) @@ -580,8 +578,7 @@ def test_ratelimit_by_ip(self): self.assertEqual(cm.exception.code, 429) def test_add_email_if_disabled(self): - """Test adding email to profile when doing so is disallowed - """ + """Test adding email to profile when doing so is disallowed""" self.hs.config.enable_3pid_changes = False client_secret = "foobar" @@ -611,15 +608,16 @@ def test_add_email_if_disabled(self): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email(self): - """Test deleting an email from profile - """ + """Test deleting an email from profile""" # Add a threepid self.get_success( self.store.user_add_threepid( @@ -641,15 +639,16 @@ def test_delete_email(self): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email_if_disabled(self): - """Test deleting an email from profile when disallowed - """ + """Test deleting an email from profile when disallowed""" self.hs.config.enable_3pid_changes = False # Add a threepid @@ -675,7 +674,9 @@ def test_delete_email_if_disabled(self): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -683,8 +684,7 @@ def test_delete_email_if_disabled(self): self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) def test_cant_add_email_without_clicking_link(self): - """Test that we do actually need to click the link in the email - """ + """Test that we do actually need to click the link in the email""" client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -710,7 +710,9 @@ def test_cant_add_email_without_clicking_link(self): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -743,7 +745,9 @@ def test_no_valid_token(self): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -788,7 +792,10 @@ def test_next_link_domain_whitelist(self): # Ensure not providing a next_link parameter still works self._request_token( - "something@example.com", "some_secret", next_link=None, expect_code=200, + "something@example.com", + "some_secret", + next_link=None, + expect_code=200, ) self._request_token( @@ -846,17 +853,27 @@ def _request_token( if next_link: body["next_link"] = next_link - channel = self.make_request("POST", b"account/3pid/email/requestToken", body,) + channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + body, + ) if channel.code != expect_code: raise HttpResponseException( - channel.code, channel.result["reason"], channel.result["body"], + channel.code, + channel.result["reason"], + channel.result["body"], ) return channel.json_body.get("sid") def _request_token_invalid_email( - self, email, expected_errcode, expected_error, client_secret="foobar", + self, + email, + expected_errcode, + expected_error, + client_secret="foobar", ): channel = self.make_request( "POST", @@ -895,8 +912,7 @@ def _get_link_from_email(self): return match.group(0) def _add_email(self, request_email, expected_email): - """Test adding an email to profile - """ + """Test adding an email to profile""" previous_email_attempts = len(self.email_attempts) client_secret = "foobar" @@ -926,7 +942,9 @@ def _add_email(self, request_email, expected_email): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 3f50c567455b..501f09203fe1 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -102,7 +102,8 @@ def test_fallback_captcha(self): """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"}, + 401, + {"username": "user", "type": "m.login.password", "password": "bar"}, ) # Grab the session @@ -191,7 +192,10 @@ def delete_device( ) -> FakeChannel: """Delete an individual device.""" channel = self.make_request( - "DELETE", "devices/" + device, body, access_token=access_token, + "DELETE", + "devices/" + device, + body, + access_token=access_token, ) # Ensure the response is sane. @@ -204,7 +208,10 @@ def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: # Note that this uses the delete_devices endpoint so that we can modify # the payload half-way through some tests. channel = self.make_request( - "POST", "delete_devices", body, access_token=self.user_tok, + "POST", + "delete_devices", + body, + access_token=self.user_tok, ) # Ensure the response is sane. @@ -417,7 +424,10 @@ def test_ui_auth_via_sso(self): # and now the delete request should succeed. self.delete_device( - self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}}, + self.user_tok, + self.device_id, + 200, + body={"auth": {"session": session_id}}, ) @skip_unless(HAS_OIDC, "requires OIDC") @@ -443,8 +453,7 @@ def test_does_not_offer_sso_for_password_user(self): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self): - """A user that had a password and then logged in with SSO should get both flows - """ + """A user that had a password and then logged in with SSO should get both flows""" login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) @@ -459,8 +468,7 @@ def test_offers_both_flows_for_upgraded_user(self): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_ui_auth_fails_for_incorrect_sso_user(self): - """If the user tries to authenticate with the wrong SSO user, they get an error - """ + """If the user tries to authenticate with the wrong SSO user, they get an error""" # log the user in login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py index fba34def30f7..5ebc5707a5f1 100644 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -91,7 +91,9 @@ def test_password_too_short(self): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_TOO_SHORT, + channel.result, ) def test_password_no_digit(self): @@ -100,7 +102,9 @@ def test_password_no_digit(self): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_NO_DIGIT, + channel.result, ) def test_password_no_symbol(self): @@ -109,7 +113,9 @@ def test_password_no_symbol(self): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_NO_SYMBOL, + channel.result, ) def test_password_no_uppercase(self): @@ -118,7 +124,9 @@ def test_password_no_uppercase(self): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_NO_UPPERCASE, + channel.result, ) def test_password_no_lowercase(self): @@ -127,7 +135,9 @@ def test_password_no_lowercase(self): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_NO_LOWERCASE, + channel.result, ) def test_password_compliant(self): diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index bd574077e7bb..7c457754f1ff 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -83,14 +83,12 @@ def test_send_relation(self): ) def test_deny_membership(self): - """Test that we deny relations on membership events - """ + """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) self.assertEquals(400, channel.code, channel.json_body) def test_deny_double_react(self): - """Test that we deny relations on membership events - """ + """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") self.assertEquals(200, channel.code, channel.json_body) @@ -98,8 +96,7 @@ def test_deny_double_react(self): self.assertEquals(400, channel.code, channel.json_body) def test_basic_paginate_relations(self): - """Tests that calling pagination API correctly the latest relations. - """ + """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") self.assertEquals(200, channel.code, channel.json_body) @@ -174,8 +171,7 @@ def test_repeated_paginate_relations(self): self.assertEquals(found_event_ids, expected_event_ids) def test_aggregation_pagination_groups(self): - """Test that we can paginate annotation groups correctly. - """ + """Test that we can paginate annotation groups correctly.""" # We need to create ten separate users to send each reaction. access_tokens = [self.user_token, self.user2_token] @@ -240,8 +236,7 @@ def test_aggregation_pagination_groups(self): self.assertEquals(sent_groups, found_groups) def test_aggregation_pagination_within_group(self): - """Test that we can paginate within an annotation group. - """ + """Test that we can paginate within an annotation group.""" # We need to create ten separate users to send each reaction. access_tokens = [self.user_token, self.user2_token] @@ -311,8 +306,7 @@ def test_aggregation_pagination_within_group(self): self.assertEquals(found_event_ids, expected_event_ids) def test_aggregation(self): - """Test that annotations get correctly aggregated. - """ + """Test that annotations get correctly aggregated.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -344,8 +338,7 @@ def test_aggregation(self): ) def test_aggregation_redactions(self): - """Test that annotations get correctly aggregated after a redaction. - """ + """Test that annotations get correctly aggregated after a redaction.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -379,8 +372,7 @@ def test_aggregation_redactions(self): ) def test_aggregation_must_be_annotation(self): - """Test that aggregations must be annotations. - """ + """Test that aggregations must be annotations.""" channel = self.make_request( "GET", @@ -437,8 +429,7 @@ def test_aggregation_get_event(self): ) def test_edit(self): - """Test that a simple edit works. - """ + """Test that a simple edit works.""" new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 512e36c2362d..2dbf42397a6b 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -388,13 +388,19 @@ def test_unread_counts(self): # Check that room name changes increase the unread counter. self.helper.send_state( - self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, + self.room_id, + "m.room.name", + {"name": "my super room"}, + tok=self.tok2, ) self._check_unread_count(1) # Check that room topic changes increase the unread counter. self.helper.send_state( - self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, + self.room_id, + "m.room.topic", + {"topic": "welcome!!!"}, + tok=self.tok2, ) self._check_unread_count(2) @@ -404,7 +410,10 @@ def test_unread_counts(self): # Check that custom events with a body increase the unread counter. self.helper.send_event( - self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, + self.room_id, + "org.matrix.custom_type", + {"body": "hello"}, + tok=self.tok2, ) self._check_unread_count(4) @@ -443,14 +452,18 @@ def _check_unread_count(self, expected_count: True): """Syncs and compares the unread count with the expected value.""" channel = self.make_request( - "GET", self.url % self.next_batch, access_token=self.tok, + "GET", + self.url % self.next_batch, + access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) room_entry = channel.json_body["rooms"]["join"][self.room_id] self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, + room_entry["org.matrix.msc2654.unread_count"], + expected_count, + room_entry, ) # Store the next batch for the next request. diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py index 7c22293d6d58..d890d11863a5 100644 --- a/tests/rest/client/v2_alpha/test_upgrade_room.py +++ b/tests/rest/client/v2_alpha/test_upgrade_room.py @@ -85,7 +85,9 @@ def test_power_levels(self): # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( - self.room_id, "m.room.power_levels", tok=self.creator_token, + self.room_id, + "m.room.power_levels", + tok=self.creator_token, ) power_levels["users"][self.other] = 100 self.helper.send_state( @@ -109,7 +111,9 @@ def test_power_levels_user_default(self): # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( - self.room_id, "m.room.power_levels", tok=self.creator_token, + self.room_id, + "m.room.power_levels", + tok=self.creator_token, ) power_levels["users_default"] = 100 self.helper.send_state( @@ -133,7 +137,9 @@ def test_power_levels_tombstone(self): # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( - self.room_id, "m.room.power_levels", tok=self.creator_token, + self.room_id, + "m.room.power_levels", + tok=self.creator_token, ) power_levels["events"]["m.room.tombstone"] = 0 self.helper.send_state( @@ -148,6 +154,8 @@ def test_power_levels_tombstone(self): self.assertEquals(200, channel.code, channel.result) power_levels = self.helper.get_state( - self.room_id, "m.room.power_levels", tok=self.creator_token, + self.room_id, + "m.room.power_levels", + tok=self.creator_token, ) self.assertNotIn(self.other, power_levels["users"]) diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 5e90d656f7c5..9d0d0ef41466 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -180,7 +180,8 @@ def prepare(self, reactor, clock, homeserver): async def post_json(destination, path, data): self.assertEqual(destination, self.hs.hostname) self.assertEqual( - path, "/_matrix/key/v2/query", + path, + "/_matrix/key/v2/query", ) channel = FakeChannel(self.site, self.reactor) @@ -188,7 +189,9 @@ async def post_json(destination, path, data): req.content = BytesIO(encode_canonical_json(data)) req.requestReceived( - b"POST", path.encode("utf-8"), b"1.1", + b"POST", + path.encode("utf-8"), + b"1.1", ) channel.await_result() self.assertEqual(channel.code, 200) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index c279eb49e366..0789b12392f5 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -167,7 +167,16 @@ class _TestImage: ), ), # an empty file - (_TestImage(b"", b"image/gif", b".gif", None, None, False,),), + ( + _TestImage( + b"", + b"image/gif", + b".gif", + None, + None, + False, + ), + ), ], ) class MediaRepoTests(unittest.HomeserverTestCase): @@ -469,8 +478,7 @@ def default_config(self): return config def test_upload_innocent(self): - """Attempt to upload some innocent data that should be allowed. - """ + """Attempt to upload some innocent data that should be allowed.""" image_data = unhexlify( b"89504e470d0a1a0a0000000d4948445200000001000000010806" diff --git a/tests/server.py b/tests/server.py index 6419c445ec20..d4ece5c448ac 100644 --- a/tests/server.py +++ b/tests/server.py @@ -347,8 +347,7 @@ def add_tcp_client_callback(self, host, port, callback): self._tcp_callbacks[(host, port)] = callback def connectTCP(self, host, port, factory, timeout=30, bindAddress=None): - """Fake L{IReactorTCP.connectTCP}. - """ + """Fake L{IReactorTCP.connectTCP}.""" conn = super().connectTCP( host, port, factory, timeout=timeout, bindAddress=None diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index fea54464af79..d40d65b06a8b 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -353,7 +353,11 @@ def _trigger_notice_and_join(self): tok = self.login(localpart, "password") # Sync with the user's token to mark the user as active. - channel = self.make_request("GET", "/sync?timeout=0", access_token=tok,) + channel = self.make_request( + "GET", + "/sync?timeout=0", + access_token=tok, + ) # Also retrieves the list of invites for this user. We don't care about that # one except if we're processing the last user, which should have received an diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 77c72834f2f1..66e3cafe8e9f 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -382,8 +382,7 @@ def test_topic(self): self.do_check(events, edges, expected_state_ids) def test_mainline_sort(self): - """Tests that the mainline ordering works correctly. - """ + """Tests that the mainline ordering works correctly.""" events = [ FakeEvent( @@ -660,15 +659,27 @@ def test_simple(self): # C -|-> B -> A a = FakeEvent( - id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + id="A", + sender=ALICE, + type=EventTypes.Member, + state_key="", + content={}, ).to_event([], []) b = FakeEvent( - id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + id="B", + sender=ALICE, + type=EventTypes.Member, + state_key="", + content={}, ).to_event([a.event_id], []) c = FakeEvent( - id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + id="C", + sender=ALICE, + type=EventTypes.Member, + state_key="", + content={}, ).to_event([b.event_id], []) persisted_events = {a.event_id: a, b.event_id: b} @@ -694,19 +705,35 @@ def test_multiple_unpersisted_chain(self): # D -> C -|-> B -> A a = FakeEvent( - id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + id="A", + sender=ALICE, + type=EventTypes.Member, + state_key="", + content={}, ).to_event([], []) b = FakeEvent( - id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + id="B", + sender=ALICE, + type=EventTypes.Member, + state_key="", + content={}, ).to_event([a.event_id], []) c = FakeEvent( - id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + id="C", + sender=ALICE, + type=EventTypes.Member, + state_key="", + content={}, ).to_event([b.event_id], []) d = FakeEvent( - id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + id="D", + sender=ALICE, + type=EventTypes.Member, + state_key="", + content={}, ).to_event([c.event_id], []) persisted_events = {a.event_id: a, b.event_id: b} @@ -737,23 +764,43 @@ def test_unpersisted_events_different_sets(self): # | a = FakeEvent( - id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + id="A", + sender=ALICE, + type=EventTypes.Member, + state_key="", + content={}, ).to_event([], []) b = FakeEvent( - id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + id="B", + sender=ALICE, + type=EventTypes.Member, + state_key="", + content={}, ).to_event([a.event_id], []) c = FakeEvent( - id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + id="C", + sender=ALICE, + type=EventTypes.Member, + state_key="", + content={}, ).to_event([b.event_id], []) d = FakeEvent( - id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + id="D", + sender=ALICE, + type=EventTypes.Member, + state_key="", + content={}, ).to_event([c.event_id], []) e = FakeEvent( - id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + id="E", + sender=ALICE, + type=EventTypes.Member, + state_key="", + content={}, ).to_event([c.event_id, b.event_id], []) persisted_events = {a.event_id: a, b.event_id: b} diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index 673e1fe3e339..38444e48e295 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -96,7 +96,9 @@ def test_invalid_data(self): # No ignored_users key. self.get_success( self.store.add_account_data_for_user( - self.user, AccountDataTypes.IGNORED_USER_LIST, {}, + self.user, + AccountDataTypes.IGNORED_USER_LIST, + {}, ) ) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 02aae1c13d21..1b4fae0bb555 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -67,7 +67,9 @@ async def update(progress, count): async def update(progress, count): self.assertEqual(progress, {"my_key": 2}) self.assertAlmostEqual( - count, target_background_update_duration_ms / duration_ms, places=0, + count, + target_background_update_duration_ms / duration_ms, + places=0, ) await self.updates._end_background_update("test_update") return count diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index c13a57dad185..779113868880 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -43,8 +43,7 @@ def prepare(self, reactor, clock, homeserver): self.room_id = info["room_id"] def run_background_update(self): - """Re run the background update to clean up the extremities. - """ + """Re run the background update to clean up the extremities.""" # Make sure we don't clash with in progress updates. self.assertTrue( self.store.db_pool.updates._all_done, "Background updates are still ongoing" diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index a69117c5a9fa..34e65260970e 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -41,7 +41,13 @@ def test_insert_new_client_ip(self): device_id = "MY_DEVICE" # Insert a user IP - self.get_success(self.store.store_device(user_id, device_id, "display name",)) + self.get_success( + self.store.store_device( + user_id, + device_id, + "display name", + ) + ) self.get_success( self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", device_id @@ -214,7 +220,13 @@ def test_devices_last_seen_bg_update(self): device_id = "MY_DEVICE" # Insert a user IP - self.get_success(self.store.store_device(user_id, device_id, "display name",)) + self.get_success( + self.store.store_device( + user_id, + device_id, + "display name", + ) + ) self.get_success( self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", device_id @@ -303,7 +315,13 @@ def test_old_user_ips_pruned(self): device_id = "MY_DEVICE" # Insert a user IP - self.get_success(self.store.store_device(user_id, device_id, "display name",)) + self.get_success( + self.store.store_device( + user_id, + device_id, + "display name", + ) + ) self.get_success( self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", device_id diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 0c46ad595bbe..16daa66cc919 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -90,7 +90,8 @@ def test_simple(self): "content": {"tag": "power"}, }, ).build( - prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id], + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id], ) ) @@ -226,7 +227,8 @@ def test_simple(self): self.assertFalse( link_map.exists_path_from( - chain_map[create.event_id], chain_map[event.event_id], + chain_map[create.event_id], + chain_map[event.event_id], ), ) @@ -287,7 +289,8 @@ def test_out_of_order_events(self): "content": {"tag": "power"}, }, ).build( - prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id], + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id], ) ) @@ -373,7 +376,8 @@ def test_out_of_order_events(self): ) def persist( - self, events: List[EventBase], + self, + events: List[EventBase], ): """Persist the given events and check that the links generated match those given. @@ -394,7 +398,10 @@ def _persist(txn): persist_events_store._persist_event_auth_chain_txn(txn, events) self.get_success( - persist_events_store.db_pool.runInteraction("_persist", _persist,) + persist_events_store.db_pool.runInteraction( + "_persist", + _persist, + ) ) def fetch_chains( @@ -447,8 +454,7 @@ def fetch_chains( class LinkMapTestCase(unittest.TestCase): def test_simple(self): - """Basic tests for the LinkMap. - """ + """Basic tests for the LinkMap.""" link_map = _LinkMap() link_map.add_link((1, 1), (2, 1), new=False) @@ -490,8 +496,7 @@ def prepare(self, reactor, clock, hs): self.requester = create_requester(self.user_id) def _generate_room(self) -> Tuple[str, List[Set[str]]]: - """Insert a room without a chain cover index. - """ + """Insert a room without a chain cover index.""" room_id = self.helper.create_room_as(self.user_id, tok=self.token) # Mark the room as not having a chain cover index diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 9d04a066d838..06000f81a63d 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -215,7 +215,12 @@ def insert_event(txn): ], ) - self.get_success(self.store.db_pool.runInteraction("insert", insert_event,)) + self.get_success( + self.store.db_pool.runInteraction( + "insert", + insert_event, + ) + ) # Now actually test that various combinations give the right result: @@ -370,7 +375,8 @@ def insert_event(txn): ) self.hs.datastores.persist_events._persist_event_auth_chain_txn( - txn, [FakeEvent("b", room_id, auth_graph["b"])], + txn, + [FakeEvent("b", room_id, auth_graph["b"])], ) self.store.db_pool.simple_update_txn( @@ -380,7 +386,12 @@ def insert_event(txn): updatevalues={"has_auth_chain_index": True}, ) - self.get_success(self.store.db_pool.runInteraction("insert", insert_event,)) + self.get_success( + self.store.db_pool.runInteraction( + "insert", + insert_event, + ) + ) # Now actually test that various combinations give the right result: diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index c0595963dd91..485f1ee033c4 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -84,7 +84,9 @@ def _inject_actions(stream, action): yield defer.ensureDeferred( self.store.add_push_actions_to_staging( - event.event_id, {user_id: action}, False, + event.event_id, + {user_id: action}, + False, ) ) yield defer.ensureDeferred( diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 71210ce6065c..ed898b8dbb70 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -68,16 +68,14 @@ def prepare(self, reactor, clock, homeserver): self.assert_extremities([self.remote_event_1.event_id]) def persist_event(self, event, state=None): - """Persist the event, with optional state - """ + """Persist the event, with optional state""" context = self.get_success( self.state.compute_event_context(event, old_state=state) ) self.get_success(self.persistence.persist_event(event, context)) def assert_extremities(self, expected_extremities): - """Assert the current extremities for the room - """ + """Assert the current extremities for the room""" extremities = self.get_success( self.store.get_prev_events_for_room(self.room_id) ) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 3e2fd4da0189..aad6bc907e43 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -86,7 +86,11 @@ def _insert_row_with_id(self, instance_name: str, stream_id: int): def _insert(txn): txn.execute( - "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), + "INSERT INTO foobar VALUES (?, ?)", + ( + stream_id, + instance_name, + ), ) txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) txn.execute( @@ -138,8 +142,7 @@ async def _get_next_async(): self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) def test_out_of_order_finish(self): - """Test that IDs persisted out of order are correctly handled - """ + """Test that IDs persisted out of order are correctly handled""" # Prefill table with 7 rows written by 'master' self._insert_rows("master", 7) @@ -246,8 +249,7 @@ async def _get_next_async(): self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) def test_get_next_txn(self): - """Test that the `get_next_txn` function works correctly. - """ + """Test that the `get_next_txn` function works correctly.""" # Prefill table with 7 rows written by 'master' self._insert_rows("master", 7) @@ -386,8 +388,7 @@ def test_restart_during_out_of_order_persistence(self): self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) def test_writer_config_change(self): - """Test that changing the writer config correctly works. - """ + """Test that changing the writer config correctly works.""" self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) @@ -434,8 +435,7 @@ async def _get_next_async(): self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6) def test_sequence_consistency(self): - """Test that we error out if the table and sequence diverges. - """ + """Test that we error out if the table and sequence diverges.""" # Prefill with some rows self._insert_row_with_id("master", 3) @@ -452,8 +452,7 @@ def _insert(txn): class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): - """Tests MultiWriterIdGenerator that produce *negative* stream IDs. - """ + """Tests MultiWriterIdGenerator that produce *negative* stream IDs.""" if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" @@ -494,12 +493,15 @@ def _create(conn): return self.get_success(self.db_pool.runWithConnection(_create)) def _insert_row(self, instance_name: str, stream_id: int): - """Insert one row as the given instance with given stream_id. - """ + """Insert one row as the given instance with given stream_id.""" def _insert(txn): txn.execute( - "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), + "INSERT INTO foobar VALUES (?, ?)", + ( + stream_id, + instance_name, + ), ) txn.execute( """ diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 8d97b6d4cdf4..5858c7fcc4d2 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -198,7 +198,7 @@ def test_reap_monthly_active_users(self): # value, although it gets stored on the config object as mau_limits. @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)}) def test_reap_monthly_active_users_reserved_users(self): - """ Tests that reaping correctly handles reaping where reserved users are + """Tests that reaping correctly handles reaping where reserved users are present""" threepids = self.hs.config.mau_limits_reserved_threepids initial_users = len(threepids) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index a6303bf0ee2e..b2a0e6085678 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -299,8 +299,7 @@ def type(self): ) def test_redact_censor(self): - """Test that a redacted event gets censored in the DB after a month - """ + """Test that a redacted event gets censored in the DB after a month""" self.get_success( self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) @@ -370,8 +369,7 @@ def test_redact_censor(self): self.assert_dict({"content": {}}, json.loads(event_json)) def test_redact_redaction(self): - """Tests that we can redact a redaction and can fetch it again. - """ + """Tests that we can redact a redaction and can fetch it again.""" self.get_success( self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) @@ -404,8 +402,7 @@ def test_redact_redaction(self): ) def test_store_redacted_redaction(self): - """Tests that we can store a redacted redaction. - """ + """Tests that we can store a redacted redaction.""" self.get_success( self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index c8c7a90e5dd7..abbaed7cdc03 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -145,7 +145,10 @@ def test_3pid_inhibit_invalid_validation_session_error(self): try: yield defer.ensureDeferred( self.store.validate_threepid_session( - "fake_sid", "fake_client_secret", "fake_token", 0, + "fake_sid", + "fake_client_secret", + "fake_token", + 0, ) ) except ThreepidValidationError as e: @@ -158,7 +161,10 @@ def test_3pid_inhibit_invalid_validation_session_error(self): try: yield defer.ensureDeferred( self.store.validate_threepid_session( - "fake_sid", "fake_client_secret", "fake_token", 0, + "fake_sid", + "fake_client_secret", + "fake_token", + 0, ) ) except ThreepidValidationError as e: diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 69b4c5d6c2f9..3f2691ee6bec 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -85,7 +85,10 @@ def test_state_default_level(self): # king should be able to send state event_auth.check( - RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False, + RoomVersions.V1, + _random_state_event(king), + auth_events, + do_sig_check=False, ) def test_alias_event(self): @@ -99,7 +102,10 @@ def test_alias_event(self): # creator should be able to send aliases event_auth.check( - RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False, + RoomVersions.V1, + _alias_event(creator), + auth_events, + do_sig_check=False, ) # Reject an event with no state key. @@ -122,7 +128,10 @@ def test_alias_event(self): # Note that the member does *not* need to be in the room. event_auth.check( - RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False, + RoomVersions.V1, + _alias_event(other), + auth_events, + do_sig_check=False, ) def test_msc2432_alias_event(self): @@ -136,7 +145,10 @@ def test_msc2432_alias_event(self): # creator should be able to send aliases event_auth.check( - RoomVersions.V6, _alias_event(creator), auth_events, do_sig_check=False, + RoomVersions.V6, + _alias_event(creator), + auth_events, + do_sig_check=False, ) # No particular checks are done on the state key. @@ -156,7 +168,10 @@ def test_msc2432_alias_event(self): # Per standard auth rules, the member must be in the room. with self.assertRaises(AuthError): event_auth.check( - RoomVersions.V6, _alias_event(other), auth_events, do_sig_check=False, + RoomVersions.V6, + _alias_event(other), + auth_events, + do_sig_check=False, ) def test_msc2209(self): diff --git a/tests/test_mau.py b/tests/test_mau.py index 51660b51d5fb..75d28a42dfe5 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -242,7 +242,10 @@ def create_user(self, localpart, token=None): ) channel = self.make_request( - "POST", "/register", request_data, access_token=token, + "POST", + "/register", + request_data, + access_token=token, ) if channel.code != 200: diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 759e4cd0480f..f696fcf89ef9 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -21,7 +21,7 @@ def get_sample_labels_value(sample): - """ Extract the labels and values of a sample. + """Extract the labels and values of a sample. prometheus_client 0.5 changed the sample type to a named tuple with more members than the plain tuple had in 0.4 and earlier. This function can diff --git a/tests/test_server.py b/tests/test_server.py index 815da18e6575..55cde7f62f48 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -166,7 +166,10 @@ def _callback(request, **kwargs): res = JsonResource(self.homeserver) res.register_paths( - "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet", + "GET", + [re.compile("^/_matrix/foo$")], + _callback, + "test_servlet", ) # The path was registered as GET, but this is a HEAD request. diff --git a/tests/unittest.py b/tests/unittest.py index 767d5d607738..ca7031c724b3 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -255,7 +255,10 @@ def setUp(self): # We need a valid token ID to satisfy foreign key constraints. token_id = self.get_success( self.hs.get_datastore().add_access_token_to_user( - self.helper.auth_user_id, "some_fake_token", None, None, + self.helper.auth_user_id, + "some_fake_token", + None, + None, ) ) diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index ecd9efc4dfc3..c24c33ee9132 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -232,7 +232,10 @@ def test_eviction_lru(self): def test_eviction_iterable(self): cache = DeferredCache( - "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True, + "test", + max_entries=3, + apply_cache_factor_from_config=False, + iterable=True, ) cache.prefill(1, ["one", "two"]) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index cf1e3203a4b3..afb11b9caf2d 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -143,8 +143,7 @@ def fn(self, arg1, arg2): obj.mock.assert_not_called() def test_cache_with_sync_exception(self): - """If the wrapped function throws synchronously, things should continue to work - """ + """If the wrapped function throws synchronously, things should continue to work""" class Cls: @cached() @@ -165,8 +164,7 @@ def fn(self, arg1): self.failureResultOf(d, SynapseError) def test_cache_with_async_exception(self): - """The wrapped function returns a failure - """ + """The wrapped function returns a failure""" class Cls: result = None @@ -282,7 +280,8 @@ def do_lookup(): try: d = obj.fn(1) self.assertEqual( - current_context(), SENTINEL_CONTEXT, + current_context(), + SENTINEL_CONTEXT, ) yield d self.fail("No exception thrown") @@ -374,8 +373,7 @@ def fn(self, arg1, arg2): obj.mock.assert_not_called() def test_cache_iterable_with_sync_exception(self): - """If the wrapped function throws synchronously, things should continue to work - """ + """If the wrapped function throws synchronously, things should continue to work""" class Cls: @descriptors.cached(iterable=True) diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 1ef0af8e8f03..e931a7ec1852 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -24,28 +24,32 @@ def test_short_seq(self): parts = chunk_seq("123", 8) self.assertEqual( - list(parts), ["123"], + list(parts), + ["123"], ) def test_long_seq(self): parts = chunk_seq("abcdefghijklmnop", 8) self.assertEqual( - list(parts), ["abcdefgh", "ijklmnop"], + list(parts), + ["abcdefgh", "ijklmnop"], ) def test_uneven_parts(self): parts = chunk_seq("abcdefghijklmnop", 5) self.assertEqual( - list(parts), ["abcde", "fghij", "klmno", "p"], + list(parts), + ["abcde", "fghij", "klmno", "p"], ) def test_empty_input(self): parts = chunk_seq([], 5) self.assertEqual( - list(parts), [], + list(parts), + [], ) diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 13b753e367fa..9ed01f7e0c97 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -70,7 +70,8 @@ def test_entity_has_changed_pops_off_start(self): self.assertTrue("user@foo.com" not in cache._entity_to_key) self.assertEqual( - cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"], + cache.get_all_entities_changed(2), + ["bar@baz.net", "user@elsewhere.org"], ) self.assertIsNone(cache.get_all_entities_changed(1)) @@ -80,7 +81,8 @@ def test_entity_has_changed_pops_off_start(self): {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) self.assertEqual( - cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"], + cache.get_all_entities_changed(2), + ["user@elsewhere.org", "bar@baz.net"], ) self.assertIsNone(cache.get_all_entities_changed(1)) @@ -222,7 +224,8 @@ def test_get_entities_changed(self): # Query a subset of the entries mid-way through the stream. We should # only get back the subset. self.assertEqual( - cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"}, + cache.get_entities_changed(["bar@baz.net"], stream_pos=2), + {"bar@baz.net"}, ) def test_max_pos(self): diff --git a/tests/utils.py b/tests/utils.py index 840b657f825d..4fb5098550a3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -263,7 +263,10 @@ def setup_test_homeserver( db_conn.close() hs = homeserver_to_use( - name, config=config, version_string="Synapse/tests", reactor=reactor, + name, + config=config, + version_string="Synapse/tests", + reactor=reactor, ) # Install @cache_in_self attributes @@ -365,7 +368,7 @@ def trigger_get(self, path): def trigger( self, http_method, path, content, mock_request, federation_auth_origin=None ): - """ Fire an HTTP event. + """Fire an HTTP event. Args: http_method : The HTTP method @@ -528,8 +531,7 @@ def time_bound_deferred(self, d, *args, **kwargs): async def create_room(hs, room_id: str, creator_id: str): - """Creates and persist a creation event for the given room - """ + """Creates and persist a creation event for the given room""" persistence_store = hs.get_storage().persistence store = hs.get_datastore()