diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 15522ce5c..b08ce9bfa 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -76,7 +76,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]
+        python-version: ["3.7", "3.8", "3.9", "3.10"]
         database: ["sqlite"]
         toxenv: ["py"]
         include:
@@ -85,9 +85,9 @@ jobs:
             toxenv: "py-noextras"
 
           # Oldest Python with PostgreSQL
-          - python-version: "3.6"
+          - python-version: "3.7"
             database: "postgres"
-            postgres-version: "9.6"
+            postgres-version: "10"
             toxenv: "py"
 
           # Newest Python with newest PostgreSQL
@@ -167,7 +167,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
      matrix:
-        python-version: ["pypy-3.6"]
+        python-version: ["pypy-3.7"]
 
     steps:
       - uses: actions/checkout@v2
@@ -284,8 +284,8 @@ jobs:
     strategy:
       matrix:
         include:
-          - python-version: "3.6"
-            postgres-version: "9.6"
+          - python-version: "3.7"
+            postgres-version: "10"
 
           - python-version: "3.10"
             postgres-version: "14"
diff --git a/CHANGES.md b/CHANGES.md
index 9f6e29631..8029a9d21 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,122 @@
+Synapse 1.50.0 (2022-01-18)
+===========================
+
+Please note that we now only support Python 3.7+ and PostgreSQL 10+ (if applicable), because Python 3.6 and PostgreSQL 9.6 have reached end-of-life.
+
+No significant changes since 1.50.0rc2.
+
+
+Synapse 1.50.0rc2 (2022-01-14)
+==============================
+
+This release candidate fixes a federation-breaking regression introduced in Synapse 1.50.0rc1.
+
+Bugfixes
+--------
+
+- Fix a bug introduced in Synapse v1.0.0 whereby some device list updates would not be sent to remote homeservers if there were too many to send at once. ([\#11729](https://github.com/matrix-org/synapse/issues/11729))
+- Fix a bug introduced in Synapse v1.50.0rc1 whereby outbound federation could fail because too many EDUs were produced for device updates. ([\#11730](https://github.com/matrix-org/synapse/issues/11730))
+
+
+Improved Documentation
+----------------------
+
+- Document that the minimum supported PostgreSQL version is now 10. ([\#11725](https://github.com/matrix-org/synapse/issues/11725))
+
+
+Internal Changes
+----------------
+
+- Fix a typechecker problem related to our (ab)use of `nacl.signing.SigningKey`s. ([\#11714](https://github.com/matrix-org/synapse/issues/11714))
+
+
+Synapse 1.50.0rc1 (2022-01-05)
+==============================
+
+
+Features
+--------
+
+- Allow guests to send state events per [MSC3419](https://github.com/matrix-org/matrix-doc/pull/3419). ([\#11378](https://github.com/matrix-org/synapse/issues/11378))
+- Add experimental support for part of [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): allowing application services to masquerade as specific devices. ([\#11538](https://github.com/matrix-org/synapse/issues/11538))
+- Add admin API to get users' account data. ([\#11664](https://github.com/matrix-org/synapse/issues/11664))
+- Include the room topic in the stripped state included with invites and knocking. ([\#11666](https://github.com/matrix-org/synapse/issues/11666))
+- Send and handle cross-signing messages using the stable prefix. ([\#10520](https://github.com/matrix-org/synapse/issues/10520))
+- Support unprefixed versions of fallback key property names. ([\#11541](https://github.com/matrix-org/synapse/issues/11541))
+
+
+Bugfixes
+--------
+
+- Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event. ([\#11516](https://github.com/matrix-org/synapse/issues/11516))
+- Fix a long-standing bug which could cause `AssertionError`s to be written to the log when Synapse was restarted after purging events from the database. ([\#11536](https://github.com/matrix-org/synapse/issues/11536), [\#11642](https://github.com/matrix-org/synapse/issues/11642))
+- Fix a bug introduced in Synapse 1.17.0 where a pusher created for an email with capital letters would fail to be created. ([\#11547](https://github.com/matrix-org/synapse/issues/11547))
+- Fix a long-standing bug where responses included bundled aggregations when they should not, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). ([\#11592](https://github.com/matrix-org/synapse/issues/11592), [\#11623](https://github.com/matrix-org/synapse/issues/11623))
+- Fix a long-standing bug where some unknown endpoints would return HTML error pages instead of JSON `M_UNRECOGNIZED` errors. ([\#11602](https://github.com/matrix-org/synapse/issues/11602))
+- Fix a bug introduced in Synapse 1.19.3 which could sometimes cause `AssertionError`s when backfilling rooms over federation. ([\#11632](https://github.com/matrix-org/synapse/issues/11632))
+
+
+Improved Documentation
+----------------------
+
+- Update the Synapse install command for FreeBSD, as the package is now prefixed with `py38`. Contributed by @itchychips. ([\#11267](https://github.com/matrix-org/synapse/issues/11267))
+- Document the usage of refresh tokens. ([\#11427](https://github.com/matrix-org/synapse/issues/11427))
+- Add details for how to configure a TURN server when behind a NAT. Contributed by @AndrewFerr. ([\#11553](https://github.com/matrix-org/synapse/issues/11553))
+- Add references for using Postgres to the Docker documentation. ([\#11640](https://github.com/matrix-org/synapse/issues/11640))
+- Fix the documentation link in newly-generated configuration files. ([\#11678](https://github.com/matrix-org/synapse/issues/11678))
+- Correct the documentation for `nginx` to use a case-sensitive URL pattern. Fixes an error introduced in v1.21.0. ([\#11680](https://github.com/matrix-org/synapse/issues/11680))
+- Clarify SSO mapping provider documentation by writing `def` or `async def` before the names of methods, as appropriate. ([\#11681](https://github.com/matrix-org/synapse/issues/11681))
+
+
+Deprecations and Removals
+-------------------------
+
+- Replace the `mock` package with its standard library version. ([\#11588](https://github.com/matrix-org/synapse/issues/11588))
+- Drop support for Python 3.6 and Ubuntu 18.04. ([\#11633](https://github.com/matrix-org/synapse/issues/11633))
+
+
+Internal Changes
+----------------
+
+- Allow specific, experimental events to be created without `prev_events`. Used by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). ([\#11243](https://github.com/matrix-org/synapse/issues/11243))
+- A test helper (`wait_for_background_updates`) no longer depends on classes defining a `store` property. ([\#11331](https://github.com/matrix-org/synapse/issues/11331))
+- Add type hints to `synapse.appservice`. ([\#11360](https://github.com/matrix-org/synapse/issues/11360))
+- Add missing type hints to the `synapse.config` module. ([\#11480](https://github.com/matrix-org/synapse/issues/11480))
+- Add a test to ensure we share the same `state_group` across the whole historical batch when using the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint.
([\#11487](https://github.com/matrix-org/synapse/issues/11487)) +- Refactor `tests.util.setup_test_homeserver` and `tests.server.setup_test_homeserver`. ([\#11503](https://github.com/matrix-org/synapse/issues/11503)) +- Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common`. ([\#11505](https://github.com/matrix-org/synapse/issues/11505), [\#11687](https://github.com/matrix-org/synapse/issues/11687)) +- Use `HTTPStatus` constants in place of literals in `tests.rest.client.test_auth`. ([\#11520](https://github.com/matrix-org/synapse/issues/11520)) +- Add a receipt types constant for `m.read`. ([\#11531](https://github.com/matrix-org/synapse/issues/11531)) +- Clean up `synapse.rest.admin`. ([\#11535](https://github.com/matrix-org/synapse/issues/11535)) +- Add missing `errcode` to `parse_string` and `parse_boolean`. ([\#11542](https://github.com/matrix-org/synapse/issues/11542)) +- Use `HTTPStatus` constants in place of literals in `synapse.http`. ([\#11543](https://github.com/matrix-org/synapse/issues/11543)) +- Add missing type hints to storage classes. ([\#11546](https://github.com/matrix-org/synapse/issues/11546), [\#11549](https://github.com/matrix-org/synapse/issues/11549), [\#11551](https://github.com/matrix-org/synapse/issues/11551), [\#11555](https://github.com/matrix-org/synapse/issues/11555), [\#11575](https://github.com/matrix-org/synapse/issues/11575), [\#11589](https://github.com/matrix-org/synapse/issues/11589), [\#11594](https://github.com/matrix-org/synapse/issues/11594), [\#11652](https://github.com/matrix-org/synapse/issues/11652), [\#11653](https://github.com/matrix-org/synapse/issues/11653), [\#11654](https://github.com/matrix-org/synapse/issues/11654), [\#11657](https://github.com/matrix-org/synapse/issues/11657)) +- Fix an inaccurate and misleading comment in the `/sync` code. ([\#11550](https://github.com/matrix-org/synapse/issues/11550)) +- Add missing type hints to `synapse.logging.context`. ([\#11556](https://github.com/matrix-org/synapse/issues/11556)) +- Stop populating unused database column `state_events.prev_state`. ([\#11558](https://github.com/matrix-org/synapse/issues/11558)) +- Minor efficiency improvements in event persistence. ([\#11560](https://github.com/matrix-org/synapse/issues/11560)) +- Add some safety checks that storage functions are used correctly. ([\#11564](https://github.com/matrix-org/synapse/issues/11564), [\#11580](https://github.com/matrix-org/synapse/issues/11580)) +- Make `get_device` return `None` if the device doesn't exist rather than raising an exception. ([\#11565](https://github.com/matrix-org/synapse/issues/11565)) +- Split the HTML parsing code from the URL preview resource code. ([\#11566](https://github.com/matrix-org/synapse/issues/11566)) +- Remove redundant `COALESCE()`s around `COUNT()`s in database queries. ([\#11570](https://github.com/matrix-org/synapse/issues/11570)) +- Add missing type hints to `synapse.http`. ([\#11571](https://github.com/matrix-org/synapse/issues/11571)) +- Add [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) and [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) to `/versions` -> `unstable_features` to detect server support. ([\#11582](https://github.com/matrix-org/synapse/issues/11582)) +- Add type hints to `synapse/tests/rest/admin`. ([\#11590](https://github.com/matrix-org/synapse/issues/11590)) +- Drop end-of-life Python 3.6 and Postgres 9.6 from CI. 
([\#11595](https://github.com/matrix-org/synapse/issues/11595))
+- Update the version of black and run it on all the files. ([\#11596](https://github.com/matrix-org/synapse/issues/11596))
+- Add opentracing type stubs and fix associated mypy errors. ([\#11603](https://github.com/matrix-org/synapse/issues/11603), [\#11622](https://github.com/matrix-org/synapse/issues/11622))
+- Improve OpenTracing support for requests which use a `ResponseCache`. ([\#11607](https://github.com/matrix-org/synapse/issues/11607))
+- Improve OpenTracing support for incoming HTTP requests. ([\#11618](https://github.com/matrix-org/synapse/issues/11618))
+- A number of improvements to opentracing support. ([\#11619](https://github.com/matrix-org/synapse/issues/11619))
+- Refactor the way that the `outlier` flag is set on events received over federation. ([\#11634](https://github.com/matrix-org/synapse/issues/11634))
+- Improve the error messages from `get_create_event_for_room`. ([\#11638](https://github.com/matrix-org/synapse/issues/11638))
+- Remove the redundant `get_current_events_token` method. ([\#11643](https://github.com/matrix-org/synapse/issues/11643))
+- Convert `namedtuples` to `attrs`. ([\#11665](https://github.com/matrix-org/synapse/issues/11665), [\#11574](https://github.com/matrix-org/synapse/issues/11574))
+- Update the `/capabilities` response to include whether support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) is available. ([\#11690](https://github.com/matrix-org/synapse/issues/11690))
+- Send the `Accept` header in HTTP requests made using `SimpleHttpClient.get_json`. ([\#11677](https://github.com/matrix-org/synapse/issues/11677))
+- Work around a Mjolnir compatibility issue by adding an import for `glob_to_regex` back into `synapse.util`, which it was moved out of. ([\#11696](https://github.com/matrix-org/synapse/issues/11696))
+
+
 Synapse 1.49.2 (2021-12-21)
 ===========================
diff --git a/contrib/docker/docker-compose.yml b/contrib/docker/docker-compose.yml
index 26d640c44..5ac41139e 100644
--- a/contrib/docker/docker-compose.yml
+++ b/contrib/docker/docker-compose.yml
@@ -14,6 +14,7 @@ services:
     # failure
     restart: unless-stopped
     # See the readme for a full documentation of the environment settings
+    # NOTE: You must edit homeserver.yaml to use postgres; it defaults to sqlite
     environment:
       - SYNAPSE_CONFIG_PATH=/data/homeserver.yaml
     volumes:
diff --git a/debian/changelog b/debian/changelog
index ebe3e0cbf..f1245cd3a 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,21 @@
+matrix-synapse-py3 (1.50.0) stable; urgency=medium
+
+  * New synapse release 1.50.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 18 Jan 2022 10:40:38 +0000
+
+matrix-synapse-py3 (1.50.0~rc2) stable; urgency=medium
+
+  * New synapse release 1.50.0~rc2.
+
+ -- Synapse Packaging team <packages@matrix.org>  Fri, 14 Jan 2022 11:18:06 +0000
+
+matrix-synapse-py3 (1.50.0~rc1) stable; urgency=medium
+
+  * New synapse release 1.50.0~rc1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Wed, 05 Jan 2022 12:36:17 +0000
+
 matrix-synapse-py3 (1.49.2) stable; urgency=medium
 
   * New synapse release 1.49.2.
diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv
index 1dd88140c..fbc1d2346 100644
--- a/docker/Dockerfile-dhvirtualenv
+++ b/docker/Dockerfile-dhvirtualenv
@@ -16,7 +16,7 @@ ARG distro=""
 ### Stage 0: build a dh-virtualenv
 ###
-# This is only really needed on bionic and focal, since other distributions we
+# This is only really needed on focal, since other distributions we
 # care about have a recent version of dh-virtualenv by default. Unfortunately,
 # it looks like focal is going to be with us for a while.
 #
@@ -36,9 +36,8 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \
         wget
 
 # fetch and unpack the package
-# TODO: Upgrade to 1.2.2 once bionic is dropped (1.2.2 requires debhelper 12; bionic has only 11)
 RUN mkdir /dh-virtualenv
-RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz
+RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/refs/tags/1.2.2.tar.gz
 RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz
 
 # install its build deps. We do another apt-cache-update here, because we might
@@ -86,12 +85,12 @@ RUN apt-get update -qq -o Acquire::Languages=none \
         libpq-dev \
         xmlsec1
 
-COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb /
+COPY --from=builder /dh-virtualenv_1.2.2-1_all.deb /
 
 # install dhvirtualenv. Update the apt cache again first, in case we got a
 # cached cache from docker the first time.
 RUN apt-get update -qq -o Acquire::Languages=none \
-    && apt-get install -yq /dh-virtualenv_1.2~dev-1_all.deb
+    && apt-get install -yq /dh-virtualenv_1.2.2-1_all.deb
 
 WORKDIR /synapse/source
 ENTRYPOINT ["bash","/synapse/source/docker/build_debian.sh"]
diff --git a/docker/README.md b/docker/README.md
index 4349e71f8..67c3bc65f 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -68,6 +68,10 @@ The following environment variables are supported in `generate` mode:
   directories. If unset, and no user is set via `docker run --user`, defaults
   to `991`, `991`.
 
+## Postgres
+
+By default the config uses SQLite. See the [docs on using Postgres](https://github.com/matrix-org/synapse/blob/develop/docs/postgres.md) for more information on how to use Postgres instead. Until this section is improved, [this issue](https://github.com/matrix-org/synapse/issues/8304) may provide useful information.
+
 ## Running synapse
 
 Once you have a valid configuration file, you can start synapse as follows:
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md
index b05af6d69..11f597b3e 100644
--- a/docs/SUMMARY.md
+++ b/docs/SUMMARY.md
@@ -30,6 +30,7 @@
     - [SSO Mapping Providers](sso_mapping_providers.md)
     - [Password Auth Providers](password_auth_providers.md)
     - [JSON Web Tokens](jwt.md)
+    - [Refresh Tokens](usage/configuration/user_authentication/refresh_tokens.md)
   - [Registration Captcha](CAPTCHA_SETUP.md)
   - [Application Services](application_services.md)
   - [Server Notices](server_notices.md)
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index ba574d795..74933d2fc 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -480,6 +480,81 @@ The following fields are returned in the JSON response body:
 - `joined_rooms` - An array of `room_id`.
 - `total` - Number of rooms.
 
+## Account Data
+Gets information about account data for a specific `user_id`.
+
+The API is:
+
+```
+GET /_synapse/admin/v1/users/<user_id>/accountdata
+```
+
+A response body like the following is returned:
+
+```json
+{
+    "account_data": {
+        "global": {
+            "m.secret_storage.key.LmIGHTg5W": {
+                "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+                "iv": "fwjNZatxg==",
+                "mac": "eWh9kNnLWZUNOgnc="
+            },
+            "im.vector.hide_profile": {
+                "hide_profile": true
+            },
+            "org.matrix.preview_urls": {
+                "disable": false
+            },
+            "im.vector.riot.breadcrumb_rooms": {
+                "rooms": [
+                    "!LxcBDAsDUVAfJDEo:matrix.org",
+                    "!MAhRxqasbItjOqxu:matrix.org"
+                ]
+            },
+            "m.accepted_terms": {
+                "accepted": [
+                    "https://example.org/somewhere/privacy-1.2-en.html",
+                    "https://example.org/somewhere/terms-2.0-en.html"
+                ]
+            },
+            "im.vector.setting.breadcrumbs": {
+                "recent_rooms": [
+                    "!MAhRxqasbItqxuEt:matrix.org",
+                    "!ZtSaPCawyWtxiImy:matrix.org"
+                ]
+            }
+        },
+        "rooms": {
+            "!GUdfZSHUJibpiVqHYd:matrix.org": {
+                "m.fully_read": {
+                    "event_id": "$156334540fYIhZ:matrix.org"
+                }
+            },
+            "!tOZwOOiqwCYQkLhV:matrix.org": {
+                "m.fully_read": {
+                    "event_id": "$xjsIyp4_NaVl2yPvIZs_k1Jl8tsC_Sp23wjqXPno"
+                }
+            }
+        }
+    }
+}
+```
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- `user_id` - fully qualified: for example, `@user:server.com`.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- `account_data` - A map containing the account data for the user
+  - `global` - A map containing the global account data for the user
+  - `rooms` - A map containing the account data per room for the user
+
 ## User media
 
 ### List media uploaded by a user
diff --git a/docs/postgres.md b/docs/postgres.md
index e4861c1f1..0562021da 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -1,6 +1,6 @@
 # Using Postgres
 
-Synapse supports PostgreSQL versions 9.6 or later.
+Synapse supports PostgreSQL version 10 or later.
 
 ## Install postgres client libraries
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index f3b3aea73..1a89da50f 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -63,7 +63,7 @@ server {
 
     server_name matrix.example.com;
 
-    location ~* ^(\/_matrix|\/_synapse\/client) {
+    location ~ ^(/_matrix|/_synapse/client) {
         # note: do not add a path (even a single /) after the port in `proxy_pass`,
         # otherwise nginx will canonicalise the URI and cause signature verification
         # errors.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 2f167e7ca..160cc8ca5 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -37,7 +37,7 @@
 # Server admins can expand Synapse's functionality with external modules.
 #
-# See https://matrix-org.github.io/synapse/latest/modules.html for more
+# See https://matrix-org.github.io/synapse/latest/modules/index.html for more
 # documentation on how to configure or create custom modules for Synapse.
 #
 modules:
@@ -1591,6 +1591,7 @@ room_prejoin_state:
   # - m.room.encryption
   # - m.room.name
   # - m.room.create
+  # - m.room.topic
   #
   # Uncomment the following to disable these defaults (so that only the event
   # types listed in 'additional_event_types' are shared). Defaults to 'false'.
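As a usage illustration for the new Account Data admin API documented above: a minimal sketch of calling it from Python with only the standard library. The endpoint and response shape come from the documentation hunk; `base_url` and `admin_token` are hypothetical placeholders.

```python
# Sketch: fetch a user's account data via the Account Data admin API.
import json
from urllib.parse import quote
from urllib.request import Request, urlopen

base_url = "https://matrix.example.com"  # placeholder homeserver URL
admin_token = "ADMIN_ACCESS_TOKEN"       # placeholder admin access token
user_id = "@user:server.com"

req = Request(
    f"{base_url}/_synapse/admin/v1/users/{quote(user_id)}/accountdata",
    headers={"Authorization": f"Bearer {admin_token}"},
)
with urlopen(req) as resp:
    account_data = json.load(resp)["account_data"]

# `global` holds account-wide event types; `rooms` maps room IDs to per-room events.
print(sorted(account_data["global"]))
print(sorted(account_data["rooms"]))
```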
diff --git a/docs/setup/installation.md b/docs/setup/installation.md index 16562be95..210c80dac 100644 --- a/docs/setup/installation.md +++ b/docs/setup/installation.md @@ -164,7 +164,7 @@ xbps-install -S synapse Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from: - Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean` -- Packages: `pkg install py37-matrix-synapse` +- Packages: `pkg install py38-matrix-synapse` #### OpenBSD diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index 7a407012e..7b4ddc5b7 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -49,12 +49,12 @@ comment these options out and use those specified by the module instead. A custom mapping provider must specify the following methods: -* `__init__(self, parsed_config)` +* `def __init__(self, parsed_config)` - Arguments: - `parsed_config` - A configuration object that is the return value of the `parse_config` method. You should set any configuration options needed by the module here. -* `parse_config(config)` +* `def parse_config(config)` - This method should have the `@staticmethod` decoration. - Arguments: - `config` - A `dict` representing the parsed content of the @@ -63,13 +63,13 @@ A custom mapping provider must specify the following methods: any option values they need here. - Whatever is returned will be passed back to the user mapping provider module's `__init__` method during construction. -* `get_remote_user_id(self, userinfo)` +* `def get_remote_user_id(self, userinfo)` - Arguments: - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user information from. - This method must return a string, which is the unique, immutable identifier for the user. Commonly the `sub` claim of the response. -* `map_user_attributes(self, userinfo, token, failures)` +* `async def map_user_attributes(self, userinfo, token, failures)` - This method must be async. - Arguments: - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user @@ -91,7 +91,7 @@ A custom mapping provider must specify the following methods: during a user's first login. Once a localpart has been associated with a remote user ID (see `get_remote_user_id`) it cannot be updated. - `displayname`: An optional string, the display name for the user. -* `get_extra_attributes(self, userinfo, token)` +* `async def get_extra_attributes(self, userinfo, token)` - This method must be async. - Arguments: - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user @@ -125,15 +125,15 @@ comment these options out and use those specified by the module instead. A custom mapping provider must specify the following methods: -* `__init__(self, parsed_config, module_api)` +* `def __init__(self, parsed_config, module_api)` - Arguments: - `parsed_config` - A configuration object that is the return value of the `parse_config` method. You should set any configuration options needed by the module here. - `module_api` - a `synapse.module_api.ModuleApi` object which provides the stable API available for extension modules. -* `parse_config(config)` - - This method should have the `@staticmethod` decoration. +* `def parse_config(config)` + - **This method should have the `@staticmethod` decoration.** - Arguments: - `config` - A `dict` representing the parsed content of the `saml_config.user_mapping_provider.config` homeserver config option. 
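To make the `def`/`async def` distinction in the OIDC hunks above concrete, here is a minimal mapping provider skeleton. This is an illustrative sketch, not a module shipped with Synapse; the claim names (`sub`, `preferred_username`, `name`) are the conventional OIDC ones and are assumptions about the identity provider.

```python
# Sketch of an OIDC mapping provider, matching the method kinds documented
# above: `parse_config` is a @staticmethod, `map_user_attributes` and
# `get_extra_attributes` are async, and the rest are plain methods.
class ExampleOidcMappingProvider:
    def __init__(self, parsed_config):
        self._config = parsed_config

    @staticmethod
    def parse_config(config: dict):
        # Validate and extract module options here.
        return config

    def get_remote_user_id(self, userinfo) -> str:
        # The unique, immutable identifier for the user.
        return userinfo["sub"]

    async def map_user_attributes(self, userinfo, token, failures: int) -> dict:
        localpart = userinfo.get("preferred_username")  # assumed claim name
        if localpart and failures:
            localpart += str(failures)  # de-duplicate after a localpart collision
        return {"localpart": localpart, "displayname": userinfo.get("name")}

    async def get_extra_attributes(self, userinfo, token) -> dict:
        return {}
```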
@@ -141,15 +141,15 @@ A custom mapping provider must specify the following methods:
       any option values they need here.
     - Whatever is returned will be passed back to the user mapping provider
       module's `__init__` method during construction.
-* `get_saml_attributes(config)`
-    - This method should have the `@staticmethod` decoration.
+* `def get_saml_attributes(config)`
+    - **This method should have the `@staticmethod` decoration.**
     - Arguments:
       - `config` - A object resulting from a call to `parse_config`.
     - Returns a tuple of two sets. The first set equates to the SAML auth
       response attributes that are required for the module to function, whereas
       the second set consists of those attributes which can be used if
       available, but are not necessary.
-* `get_remote_user_id(self, saml_response, client_redirect_url)`
+* `def get_remote_user_id(self, saml_response, client_redirect_url)`
     - Arguments:
       - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
         information from.
@@ -157,7 +157,7 @@ A custom mapping provider must specify the following methods:
         redirected to.
     - This method must return a string, which is the unique, immutable identifier
       for the user. Commonly the `uid` claim of the response.
-* `saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
+* `def saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
     - Arguments:
       - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
         information from.
diff --git a/docs/turn-howto.md b/docs/turn-howto.md
index e6812de69..e32aaa185 100644
--- a/docs/turn-howto.md
+++ b/docs/turn-howto.md
@@ -15,8 +15,8 @@ The following sections describe how to install [coturn](<https://github.com/coturn/coturn>)
[…]
     # special case the turn server itself so that client->TURN->TURN->client flows work
+    # this should be one of the turn server's listening IPs
     allowed-peer-ip=10.0.0.1
 
     # consider whether you want to limit the quota of relayed streams per user (or total) to avoid risk of DoS.
@@ -123,7 +139,7 @@ This will install and start a systemd service called `coturn`.
        pkey=/path/to/privkey.pem
    ```
 
-   In this case, replace the `turn:` schemes in the `turn_uri` settings below
+   In this case, replace the `turn:` schemes in the `turn_uris` settings below
    with `turns:`.
 
    We recommend that you only try to set up TLS/DTLS once you have set up a
@@ -134,21 +150,33 @@ This will install and start a systemd service called `coturn`.
    traffic (remember to allow both TCP and UDP traffic), and ports 49152-65535
    for the UDP relay.)
 
-1. We do not recommend running a TURN server behind NAT, and are not aware of
-   anyone doing so successfully.
+1. If your TURN server is behind NAT, the NAT gateway must have an external,
+   publicly-reachable IP address. You must configure coturn to advertise that
+   address to connecting clients:
+
+   ```
+   external-ip=EXTERNAL_NAT_IPv4_ADDRESS
+   ```
 
-   If you want to try it anyway, you will at least need to tell coturn its
-   external IP address:
+   You may optionally limit the TURN server to listen only on the local
+   address that is mapped by NAT to the external address:
 
    ```
-   external-ip=192.88.99.1
+   listening-ip=INTERNAL_TURNSERVER_IPv4_ADDRESS
    ```
 
-   ... and your NAT gateway must forward all of the relayed ports directly
-   (eg, port 56789 on the external IP must be always be forwarded to port
-   56789 on the internal IP).
+   If your NAT gateway is reachable over both IPv4 and IPv6, you may
+   configure coturn to advertise each available address:
 
-   If you get this working, let us know!
+   ```
+   external-ip=EXTERNAL_NAT_IPv4_ADDRESS
+   external-ip=EXTERNAL_NAT_IPv6_ADDRESS
+   ```
+
+   When advertising an external IPv6 address, ensure that the firewall and
+   network settings of the system running your TURN server are configured to
+   accept IPv6 traffic, and that the TURN server is listening on the local
+   IPv6 address that is mapped by NAT to the external IPv6 address.
 
 1. (Re)start the turn server:
 
@@ -216,9 +244,6 @@ connecting". Unfortunately, troubleshooting this can be tricky.
 
 Here are a few things to try:
 
-  * Check that your TURN server is not behind NAT. As above, we're not aware of
-    anyone who has successfully set this up.
-
   * Check that you have opened your firewall to allow TCP and UDP traffic to the
     TURN ports (normally 3478 and 5349).
 
@@ -234,6 +259,18 @@ Here are a few things to try:
     Try removing any AAAA records for your TURN server, so that it is only
     reachable over IPv4.
 
+  * If your TURN server is behind NAT:
+
+    * double-check that your NAT gateway is correctly forwarding all TURN
+      ports (normally 3478 & 5349 for TCP & UDP TURN traffic, and 49152-65535 for the UDP
+      relay) to the NAT-internal address of your TURN server. If advertising
+      both IPv4 and IPv6 external addresses via the `external-ip` option, ensure
+      that the NAT is forwarding both IPv4 and IPv6 traffic to the IPv4 and IPv6
+      internal addresses of your TURN server. When in doubt, remove AAAA records
+      for your TURN server and specify only an IPv4 address as your `external-ip`.
+
+    * ensure that your TURN server uses the NAT gateway as its default route.
+
   * Enable more verbose logging in coturn via the `verbose` setting:
 
    ```
diff --git a/docs/upgrade.md b/docs/upgrade.md
index 136c806c4..30bb0dcd9 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -85,6 +85,17 @@ process, for example:
     dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
     ```
 
+# Upgrading to v1.50.0
+
+## Dropping support for old Python and Postgres versions
+
+In line with our [deprecation policy](deprecation_policy.md),
+we've dropped support for Python 3.6 and PostgreSQL 9.6, as they are no
+longer supported upstream.
+
+This release of Synapse requires Python 3.7+ and PostgreSQL 10+.
+
+
 # Upgrading to v1.47.0
 
 ## Removal of old Room Admin API
diff --git a/docs/usage/configuration/user_authentication/refresh_tokens.md b/docs/usage/configuration/user_authentication/refresh_tokens.md
new file mode 100644
index 000000000..23b3cddae
--- /dev/null
+++ b/docs/usage/configuration/user_authentication/refresh_tokens.md
@@ -0,0 +1,139 @@
+# Refresh Tokens
+
+Synapse has supported refresh tokens since version 1.49 (some earlier versions had support for an earlier, experimental draft of [MSC2918], which is not compatible).
+
+
+[MSC2918]: https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens
+
+
+## Background and motivation
+
+Synapse users' sessions are identified by **access tokens**; access tokens are
+issued to users on login. Each session gets a unique access token which identifies
+it; the access token must be kept secret as it grants access to the user's account.
+
+Traditionally, these access tokens were eternally valid (at least until the user
+explicitly chose to log out).
+
+In some cases, it may be desirable for these access tokens to expire so that the
+potential damage caused by leaking an access token is reduced.
+On the other hand, frequently forcing a user to re-authenticate (log in again)
+might be too much of an inconvenience.
+ +**Refresh tokens** are a mechanism to avoid some of this inconvenience whilst +still getting most of the benefits of short access token lifetimes. +Refresh tokens are also a concept present in OAuth 2 — further reading is available +[here](https://datatracker.ietf.org/doc/html/rfc6749#section-1.5). + +When refresh tokens are in use, both an access token and a refresh token will be +issued to users on login. The access token will expire after a predetermined amount +of time, but otherwise works in the same way as before. When the access token is +close to expiring (or has expired), the user's client should present the homeserver +(Synapse) with the refresh token. + +The homeserver will then generate a new access token and refresh token for the user +and return them. The old refresh token is invalidated and can not be used again*. + +Finally, refresh tokens also make it possible for sessions to be logged out if they +are inactive for too long, before the session naturally ends; see the configuration +guide below. + + +*To prevent issues if clients lose connection half-way through refreshing a token, +the refresh token is only invalidated once the new access token has been used at +least once. For all intents and purposes, the above simplification is sufficient. + + +## Caveats + +There are some caveats: + +* If a third party gets both your access token and refresh token, they will be able to + continue to enjoy access to your session. + * This is still an improvement because you (the user) will notice when *your* + session expires and you're not able to use your refresh token. + That would be a giveaway that someone else has compromised your session. + You would be able to log in again and terminate that session. + Previously (with long-lived access tokens), a third party that has your access + token could go undetected for a very long time. +* Clients need to implement support for refresh tokens in order for them to be a + useful mechanism. + * It is up to homeserver administrators if they want to issue long-lived access + tokens to clients not implementing refresh tokens. + * For compatibility, it is likely that they should, at least until client support + is widespread. + * Users with clients that support refresh tokens will still benefit from the + added security; it's not possible to downgrade a session to using long-lived + access tokens so this effectively gives users the choice. + * In a closed environment where all users use known clients, this may not be + an issue as the homeserver administrator can know if the clients have refresh + token support. In that case, the non-refreshable access token lifetime + may be set to a short duration so that a similar level of security is provided. + + +## Configuration Guide + +The following configuration options, in the `registration` section, are related: + +* `session_lifetime`: maximum length of a session, even if it's refreshed. + In other words, the client must log in again after this time period. + In most cases, this can be unset (infinite) or set to a long time (years or months). +* `refreshable_access_token_lifetime`: lifetime of access tokens that are created + by clients supporting refresh tokens. + This should be short; a good value might be 5 minutes (`5m`). +* `nonrefreshable_access_token_lifetime`: lifetime of access tokens that are created + by clients which don't support refresh tokens. + Make this short if you want to effectively force use of refresh tokens. 
+ Make this long if you don't want to inconvenience users of clients which don't + support refresh tokens (by forcing them to frequently re-authenticate using + login credentials). +* `refresh_token_lifetime`: lifetime of refresh tokens. + In other words, the client must refresh within this time period to maintain its session. + Unless you want to log inactive sessions out, it is often fine to use a long + value here or even leave it unset (infinite). + Beware that making it too short will inconvenience clients that do not connect + very often, including mobile clients and clients of infrequent users (by making + it more difficult for them to refresh in time, which may force them to need to + re-authenticate using login credentials). + +**Note:** All four options above only apply when tokens are created (by logging in or refreshing). +Changes to these settings do not apply retroactively. + + +### Using refresh token expiry to log out inactive sessions + +If you'd like to force sessions to be logged out upon inactivity, you can enable +refreshable access token expiry and refresh token expiry. + +This works because a client must refresh at least once within a period of +`refresh_token_lifetime` in order to maintain valid credentials to access the +account. + +(It's suggested that `refresh_token_lifetime` should be longer than +`refreshable_access_token_lifetime` and this section assumes that to be the case +for simplicity.) + +Note: this will only affect sessions using refresh tokens. You may wish to +set a short `nonrefreshable_access_token_lifetime` to prevent this being bypassed +by clients that do not support refresh tokens. + + +#### Choosing values that guarantee permitting some inactivity + +It may be desirable to permit some short periods of inactivity, for example to +accommodate brief outages in client connectivity. + +The following model aims to provide guidance for choosing `refresh_token_lifetime` +and `refreshable_access_token_lifetime` to satisfy requirements of the form: + +1. inactivity longer than `L` **MUST** cause the session to be logged out; and +2. inactivity shorter than `S` **MUST NOT** cause the session to be logged out. + +This model makes the weakest assumption that all active clients will refresh as +needed to maintain an active access token, but no sooner. +*In reality, clients may refresh more often than this model assumes, but the +above requirements will still hold.* + +To satisfy the above model, +* `refresh_token_lifetime` should be set to `L`; and +* `refreshable_access_token_lifetime` should be set to `L - S`. 
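To make the inactivity model above concrete, here is a worked example with hypothetical requirements (the numbers are illustrative, not recommendations): sessions must be logged out after 7 days of inactivity (`L`), while up to 1 day of inactivity must always be tolerated (`S`).

```python
# Worked example of the L/S model above, using hypothetical requirements.
from datetime import timedelta

L = timedelta(days=7)  # inactivity longer than this MUST log the session out
S = timedelta(days=1)  # inactivity shorter than this MUST NOT log the session out

refresh_token_lifetime = L                 # i.e. `refresh_token_lifetime: 7d`
refreshable_access_token_lifetime = L - S  # i.e. `refreshable_access_token_lifetime: 6d`

# Worst case: a client refreshes, stays active until its access token is about
# to expire (L - S later), then goes inactive. Its refresh token stays valid
# until L after the refresh, leaving exactly S of tolerated inactivity.
assert refresh_token_lifetime - refreshable_access_token_lifetime == S
```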
diff --git a/mypy.ini b/mypy.ini index 1caf807e8..85fa22d28 100644 --- a/mypy.ini +++ b/mypy.ini @@ -25,14 +25,9 @@ exclude = (?x) ^( |synapse/storage/databases/__init__.py |synapse/storage/databases/main/__init__.py - |synapse/storage/databases/main/account_data.py |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py - |synapse/storage/databases/main/e2e_room_keys.py - |synapse/storage/databases/main/end_to_end_keys.py |synapse/storage/databases/main/event_federation.py - |synapse/storage/databases/main/event_push_actions.py - |synapse/storage/databases/main/events_bg_updates.py |synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/monthly_active_users.py @@ -40,12 +35,9 @@ exclude = (?x) |synapse/storage/databases/main/purge_events.py |synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/receipts.py - |synapse/storage/databases/main/room.py |synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/search.py |synapse/storage/databases/main/state.py - |synapse/storage/databases/main/stats.py - |synapse/storage/databases/main/transactions.py |synapse/storage/databases/main/user_directory.py |synapse/storage/schema/ @@ -107,7 +99,6 @@ exclude = (?x) |tests/server.py |tests/server_notices/test_resource_limits_server_notices.py |tests/state/test_v2.py - |tests/storage/test_account_data.py |tests/storage/test_background_update.py |tests/storage/test_base.py |tests/storage/test_client_ips.py @@ -145,6 +136,9 @@ disallow_untyped_defs = True [mypy-synapse.app.*] disallow_untyped_defs = True +[mypy-synapse.appservice.*] +disallow_untyped_defs = True + [mypy-synapse.config._base] disallow_untyped_defs = True @@ -163,6 +157,12 @@ disallow_untyped_defs = False [mypy-synapse.handlers.*] disallow_untyped_defs = True +[mypy-synapse.http.server] +disallow_untyped_defs = True + +[mypy-synapse.logging.context] +disallow_untyped_defs = True + [mypy-synapse.metrics.*] disallow_untyped_defs = True @@ -181,24 +181,48 @@ disallow_untyped_defs = True [mypy-synapse.state.*] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.account_data] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.client_ips] disallow_untyped_defs = True [mypy-synapse.storage.databases.main.directory] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.e2e_room_keys] +disallow_untyped_defs = True + +[mypy-synapse.storage.databases.main.end_to_end_keys] +disallow_untyped_defs = True + +[mypy-synapse.storage.databases.main.event_push_actions] +disallow_untyped_defs = True + +[mypy-synapse.storage.databases.main.events_bg_updates] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.events_worker] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.room] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.room_batch] disallow_untyped_defs = True [mypy-synapse.storage.databases.main.profile] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.stats] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.state_deltas] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.transactions] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.user_erasure_store] disallow_untyped_defs = True @@ -223,6 +247,9 @@ disallow_untyped_defs = True [mypy-tests.storage.test_user_directory] disallow_untyped_defs = True +[mypy-tests.rest.admin.*] +disallow_untyped_defs = 
True + [mypy-tests.rest.client.test_directory] disallow_untyped_defs = True @@ -286,9 +313,6 @@ ignore_missing_imports = True [mypy-netaddr] ignore_missing_imports = True -[mypy-opentracing] -ignore_missing_imports = True - [mypy-parameterized.*] ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index 8bca1fa4e..963f149c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ showcontent = true [tool.black] -target-version = ['py36'] +target-version = ['py37', 'py38', 'py39', 'py310'] exclude = ''' ( diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages index 3a9a2d257..4d34e9070 100755 --- a/scripts-dev/build_debian_packages +++ b/scripts-dev/build_debian_packages @@ -24,7 +24,6 @@ DISTS = ( "debian:bullseye", "debian:bookworm", "debian:sid", - "ubuntu:bionic", # 18.04 LTS (our EOL forced by Py36 on 2021-12-23) "ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14) "ubuntu:hirsute", # 21.04 (EOL 2022-01-05) "ubuntu:impish", # 21.10 (EOL 2022-07) diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment index 25bd75ad1..83ad6a7ad 100755 --- a/scripts-dev/check-newsfragment +++ b/scripts-dev/check-newsfragment @@ -42,8 +42,8 @@ echo "--------------------------" echo matched=0 -for f in $(git diff --name-only FETCH_HEAD... -- changelog.d); do - # check that any modified newsfiles on this branch end with a full stop. +for f in $(git diff --diff-filter=d --name-only FETCH_HEAD... -- changelog.d); do + # check that any added newsfiles on this branch end with a full stop. lastchar=$(tr -d '\n' < "$f" | tail -c 1) if [ "$lastchar" != '.' ] && [ "$lastchar" != '!' ]; then echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2 diff --git a/setup.py b/setup.py index 4ac92bd73..e618ff898 100755 --- a/setup.py +++ b/setup.py @@ -107,6 +107,7 @@ def exec_file(path_segments): "mypy-zope==0.3.2", "types-bleach>=4.1.0", "types-jsonschema>=3.2.0", + "types-opentracing>=2.4.2", "types-Pillow>=8.3.4", "types-pyOpenSSL>=20.0.7", "types-PyYAML>=5.4.10", @@ -119,9 +120,7 @@ def exec_file(path_segments): # Tests assume that all optional dependencies are installed. # # parameterized_class decorator was introduced in parameterized 0.7.0 -# -# We use `mock` library as that backports `AsyncMock` to Python 3.6 -CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0", "mock>=4.0.0"] +CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"] CONDITIONAL_REQUIREMENTS["dev"] = ( CONDITIONAL_REQUIREMENTS["lint"] @@ -163,7 +162,6 @@ def exec_file(path_segments): "Topic :: Communications :: Chat", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index 4ff3c6de5..429234d7a 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -17,11 +17,12 @@ from typing import Any, List, Optional, Type, Union from twisted.internet import protocol +from twisted.internet.defer import Deferred class RedisProtocol(protocol.Protocol): def publish(self, channel: str, message: bytes): ... - async def ping(self) -> None: ... - async def set( + def ping(self) -> "Deferred[None]": ... 
+    def set(
         self,
         key: str,
         value: Any,
@@ -29,8 +30,8 @@ class RedisProtocol(protocol.Protocol):
         pexpire: Optional[int] = None,
         only_if_not_exists: bool = False,
         only_if_exists: bool = False,
-    ) -> None: ...
-    async def get(self, key: str) -> Any: ...
+    ) -> "Deferred[None]": ...
+    def get(self, key: str) -> "Deferred[Any]": ...
 
 class SubscriberProtocol(RedisProtocol):
     def __init__(self, *args, **kwargs): ...
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 95a49c20b..201925e91 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ except ImportError:
     pass
 
-__version__ = "1.49.2"
+__version__ = "1.50.0"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 44883c666..4a32d430b 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -32,7 +32,7 @@ from synapse.events import EventBase
 from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
-from synapse.logging import opentracing as opentracing
+from synapse.logging.opentracing import active_span, force_tracing, start_active_span
 from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import Requester, StateMap, UserID, create_requester
 from synapse.util.caches.lrucache import LruCache
@@ -149,13 +149,53 @@ async def get_user_by_req(
             is invalid.
             AuthError if access is denied for the user in the access token
         """
+        parent_span = active_span()
+        with start_active_span("get_user_by_req"):
+            requester = await self._wrapped_get_user_by_req(
+                request, allow_guest, rights, allow_expired
+            )
+
+            if parent_span:
+                if requester.authenticated_entity in self._force_tracing_for_users:
+                    # request tracing is enabled for this user, so we need to force
+                    # tracing on for the parent span (which will be the servlet span).
+                    #
+                    # It's too late for the get_user_by_req span to inherit the setting,
+                    # so we also force it on for that.
+                    force_tracing()
+                    force_tracing(parent_span)
+                parent_span.set_tag(
+                    "authenticated_entity", requester.authenticated_entity
+                )
+                parent_span.set_tag("user_id", requester.user.to_string())
+                if requester.device_id is not None:
+                    parent_span.set_tag("device_id", requester.device_id)
+                if requester.app_service is not None:
+                    parent_span.set_tag("appservice_id", requester.app_service.id)
+            return requester
+
+    async def _wrapped_get_user_by_req(
+        self,
+        request: SynapseRequest,
+        allow_guest: bool,
+        rights: str,
+        allow_expired: bool,
+    ) -> Requester:
+        """Helper for get_user_by_req
+
+        Once get_user_by_req has set up the opentracing span, this does the actual work.
+ """ try: ip_addr = request.getClientIP() user_agent = get_request_user_agent(request) access_token = self.get_access_token_from_request(request) - user_id, app_service = await self._get_appservice_user_id(request) + ( + user_id, + device_id, + app_service, + ) = await self._get_appservice_user_id_and_device_id(request) if user_id and app_service: if ip_addr and self._track_appservice_user_ips: await self.store.insert_client_ip( @@ -163,18 +203,16 @@ async def get_user_by_req( access_token=access_token, ip=ip_addr, user_agent=user_agent, - device_id="dummy-device", # stubbed + device_id="dummy-device" + if device_id is None + else device_id, # stubbed ) - requester = create_requester(user_id, app_service=app_service) + requester = create_requester( + user_id, app_service=app_service, device_id=device_id + ) request.requester = user_id - if user_id in self._force_tracing_for_users: - opentracing.force_tracing() - opentracing.set_tag("authenticated_entity", user_id) - opentracing.set_tag("user_id", user_id) - opentracing.set_tag("appservice_id", app_service.id) - return requester user_info = await self.get_user_by_access_token( @@ -232,13 +270,6 @@ async def get_user_by_req( ) request.requester = requester - if user_info.token_owner in self._force_tracing_for_users: - opentracing.force_tracing() - opentracing.set_tag("authenticated_entity", user_info.token_owner) - opentracing.set_tag("user_id", user_info.user_id) - if device_id: - opentracing.set_tag("device_id", device_id) - return requester except KeyError: raise MissingClientTokenError() @@ -274,33 +305,81 @@ async def validate_appservice_can_control_user_id( 403, "Application service has not registered this user (%s)" % user_id ) - async def _get_appservice_user_id( + async def _get_appservice_user_id_and_device_id( self, request: Request - ) -> Tuple[Optional[str], Optional[ApplicationService]]: + ) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]: + """ + Given a request, reads the request parameters to determine: + - whether it's an application service that's making this request + - what user the application service should be treated as controlling + (the user_id URI parameter allows an application service to masquerade + any applicable user in its namespace) + - what device the application service should be treated as controlling + (the device_id[^1] URI parameter allows an application service to masquerade + as any device that exists for the relevant user) + + [^1] Unstable and provided by MSC3202. + Must use `org.matrix.msc3202.device_id` in place of `device_id` for now. + + Returns: + 3-tuple of + (user ID?, device ID?, application service?) + + Postconditions: + - If an application service is returned, so is a user ID + - A user ID is never returned without an application service + - A device ID is never returned without a user ID or an application service + - The returned application service, if present, is permitted to control the + returned user ID. + - The returned device ID, if present, has been checked to be a valid device ID + for the returned user ID. 
+ """ + DEVICE_ID_ARG_NAME = b"org.matrix.msc3202.device_id" + app_service = self.store.get_app_service_by_token( self.get_access_token_from_request(request) ) if app_service is None: - return None, None + return None, None, None if app_service.ip_range_whitelist: ip_address = IPAddress(request.getClientIP()) if ip_address not in app_service.ip_range_whitelist: - return None, None + return None, None, None # This will always be set by the time Twisted calls us. assert request.args is not None - if b"user_id" not in request.args: - return app_service.sender, app_service + if b"user_id" in request.args: + effective_user_id = request.args[b"user_id"][0].decode("utf8") + await self.validate_appservice_can_control_user_id( + app_service, effective_user_id + ) + else: + effective_user_id = app_service.sender - user_id = request.args[b"user_id"][0].decode("utf8") - await self.validate_appservice_can_control_user_id(app_service, user_id) + effective_device_id: Optional[str] = None - if app_service.sender == user_id: - return app_service.sender, app_service + if ( + self.hs.config.experimental.msc3202_device_masquerading_enabled + and DEVICE_ID_ARG_NAME in request.args + ): + effective_device_id = request.args[DEVICE_ID_ARG_NAME][0].decode("utf8") + # We only just set this so it can't be None! + assert effective_device_id is not None + device_opt = await self.store.get_device( + effective_user_id, effective_device_id + ) + if device_opt is None: + # For now, use 400 M_EXCLUSIVE if the device doesn't exist. + # This is an open thread of discussion on MSC3202 as of 2021-12-09. + raise AuthError( + 400, + f"Application service trying to use a device that doesn't exist ('{effective_device_id}' for {effective_user_id})", + Codes.EXCLUSIVE, + ) - return user_id, app_service + return effective_user_id, effective_device_id, app_service async def get_user_by_access_token( self, diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f7d29b431..52c083a20 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -253,5 +253,9 @@ class GuestAccess: FORBIDDEN: Final = "forbidden" +class ReceiptTypes: + READ: Final = "m.read" + + class ReadReceiptEventFields: MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 13dd6ce24..d087c816d 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -351,8 +351,7 @@ def _check(self, event: FilterEvent) -> bool: True if the event matches the filter. """ # We usually get the full "events" as dictionaries coming through, - # except for presence which actually gets passed around as its own - # namedtuple type. + # except for presence which actually gets passed around as its own type. 
if isinstance(event, UserPresenceState): user_id = event.user_id field_matchers = { diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index dd76e0732..177ce040e 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -27,6 +27,7 @@ import synapse.config.logger from synapse import events from synapse.api.urls import ( + CLIENT_API_PREFIX, FEDERATION_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, @@ -192,13 +193,7 @@ def _configure_named_resource( resources.update( { - "/_matrix/client/api/v1": client_resource, - "/_matrix/client/r0": client_resource, - "/_matrix/client/v1": client_resource, - "/_matrix/client/v3": client_resource, - "/_matrix/client/unstable": client_resource, - "/_matrix/client/v2_alpha": client_resource, - "/_matrix/client/versions": client_resource, + CLIENT_API_PREFIX: client_resource, "/.well-known": well_known_resource(self), "/_synapse/admin": AdminRestResource(self), **build_synapse_client_resource_tree(self), diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index f9d3bd337..8c9ff93b2 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -11,10 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging import re from enum import Enum -from typing import TYPE_CHECKING, Iterable, List, Match, Optional +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern + +import attr +from netaddr import IPSet from synapse.api.constants import EventTypes from synapse.events import EventBase @@ -33,6 +37,13 @@ class ApplicationServiceState(Enum): UP = "up" +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Namespace: + exclusive: bool + group_id: Optional[str] + regex: Pattern[str] + + class ApplicationService: """Defines an application service. This definition is mostly what is provided to the /register AS API. 
@@ -50,17 +61,17 @@ class ApplicationService: def __init__( self, - token, - hostname, - id, - sender, - url=None, - namespaces=None, - hs_token=None, - protocols=None, - rate_limited=True, - ip_range_whitelist=None, - supports_ephemeral=False, + token: str, + hostname: str, + id: str, + sender: str, + url: Optional[str] = None, + namespaces: Optional[JsonDict] = None, + hs_token: Optional[str] = None, + protocols: Optional[Iterable[str]] = None, + rate_limited: bool = True, + ip_range_whitelist: Optional[IPSet] = None, + supports_ephemeral: bool = False, ): self.token = token self.url = ( @@ -85,27 +96,33 @@ def __init__( self.rate_limited = rate_limited - def _check_namespaces(self, namespaces): + def _check_namespaces( + self, namespaces: Optional[JsonDict] + ) -> Dict[str, List[Namespace]]: # Sanity check that it is of the form: # { # users: [ {regex: "[A-z]+.*", exclusive: true}, ...], # aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...], # rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...], # } - if not namespaces: + if namespaces is None: namespaces = {} + result: Dict[str, List[Namespace]] = {} + for ns in ApplicationService.NS_LIST: + result[ns] = [] + if ns not in namespaces: - namespaces[ns] = [] continue - if type(namespaces[ns]) != list: + if not isinstance(namespaces[ns], list): raise ValueError("Bad namespace value for '%s'" % ns) for regex_obj in namespaces[ns]: if not isinstance(regex_obj, dict): raise ValueError("Expected dict regex for ns '%s'" % ns) - if not isinstance(regex_obj.get("exclusive"), bool): + exclusive = regex_obj.get("exclusive") + if not isinstance(exclusive, bool): raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns) group_id = regex_obj.get("group_id") if group_id: @@ -126,22 +143,26 @@ def _check_namespaces(self, namespaces): ) regex = regex_obj.get("regex") - if isinstance(regex, str): - regex_obj["regex"] = re.compile(regex) # Pre-compile regex - else: + if not isinstance(regex, str): raise ValueError("Expected string for 'regex' in ns '%s'" % ns) - return namespaces - def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]: - for regex_obj in self.namespaces[namespace_key]: - if regex_obj["regex"].match(test_string): - return regex_obj + # Pre-compile regex. 
+            result[ns].append(Namespace(exclusive, group_id, re.compile(regex)))
+
+        return result
+
+    def _matches_regex(
+        self, namespace_key: str, test_string: str
+    ) -> Optional[Namespace]:
+        for namespace in self.namespaces[namespace_key]:
+            if namespace.regex.match(test_string):
+                return namespace
         return None
 
-    def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
-        regex_obj = self._matches_regex(test_string, ns_key)
-        if regex_obj:
-            return regex_obj["exclusive"]
+    def _is_exclusive(self, namespace_key: str, test_string: str) -> bool:
+        namespace = self._matches_regex(namespace_key, test_string)
+        if namespace:
+            return namespace.exclusive
         return False
 
     async def _matches_user(
@@ -260,15 +281,15 @@ async def is_interested_in_presence(
 
     def is_interested_in_user(self, user_id: str) -> bool:
         return (
-            bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
+            bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
             or user_id == self.sender
         )
 
     def is_interested_in_alias(self, alias: str) -> bool:
-        return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
+        return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias))
 
     def is_interested_in_room(self, room_id: str) -> bool:
-        return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
+        return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id))
 
     def is_exclusive_user(self, user_id: str) -> bool:
         return (
@@ -285,14 +306,14 @@ def is_exclusive_alias(self, alias: str) -> bool:
     def is_exclusive_room(self, room_id: str) -> bool:
         return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
 
-    def get_exclusive_user_regexes(self):
+    def get_exclusive_user_regexes(self) -> List[Pattern[str]]:
         """Get the list of regexes used to determine if a user is exclusively
         registered by the AS
         """
         return [
-            regex_obj["regex"]
-            for regex_obj in self.namespaces[ApplicationService.NS_USERS]
-            if regex_obj["exclusive"]
+            namespace.regex
+            for namespace in self.namespaces[ApplicationService.NS_USERS]
+            if namespace.exclusive
         ]
 
     def get_groups_for_user(self, user_id: str) -> Iterable[str]:
@@ -305,15 +326,15 @@ def get_groups_for_user(self, user_id: str) -> Iterable[str]:
             An iterable that yields group_id strings.
         """
         return (
-            regex_obj["group_id"]
-            for regex_obj in self.namespaces[ApplicationService.NS_USERS]
-            if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
+            namespace.group_id
+            for namespace in self.namespaces[ApplicationService.NS_USERS]
+            if namespace.group_id and namespace.regex.match(user_id)
         )
 
     def is_rate_limited(self) -> bool:
         return self.rate_limited
 
-    def __str__(self):
+    def __str__(self) -> str:
         # copy dictionary and redact token fields so they don't get logged
         dict_copy = self.__dict__.copy()
         dict_copy["token"] = "<redacted>"
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index f51b63641..def4424af 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
import logging -import urllib -from typing import TYPE_CHECKING, List, Optional, Tuple +import urllib.parse +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from prometheus_client import Counter @@ -53,7 +53,7 @@ APP_SERVICE_PREFIX = "/_matrix/app/unstable" -def _is_valid_3pe_metadata(info): +def _is_valid_3pe_metadata(info: JsonDict) -> bool: if "instances" not in info: return False if not isinstance(info["instances"], list): @@ -61,7 +61,7 @@ def _is_valid_3pe_metadata(info): return True -def _is_valid_3pe_result(r, field): +def _is_valid_3pe_result(r: JsonDict, field: str) -> bool: if not isinstance(r, dict): return False @@ -93,9 +93,13 @@ def __init__(self, hs: "HomeServer"): hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS ) - async def query_user(self, service, user_id): + async def query_user(self, service: "ApplicationService", user_id: str) -> bool: if service.url is None: return False + + # This is required by the configuration. + assert service.hs_token is not None + uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) try: response = await self.get_json(uri, {"access_token": service.hs_token}) @@ -109,9 +113,13 @@ async def query_user(self, service, user_id): logger.warning("query_user to %s threw exception %s", uri, ex) return False - async def query_alias(self, service, alias): + async def query_alias(self, service: "ApplicationService", alias: str) -> bool: if service.url is None: return False + + # This is required by the configuration. + assert service.hs_token is not None + uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) try: response = await self.get_json(uri, {"access_token": service.hs_token}) @@ -125,7 +133,13 @@ async def query_alias(self, service, alias): logger.warning("query_alias to %s threw exception %s", uri, ex) return False - async def query_3pe(self, service, kind, protocol, fields): + async def query_3pe( + self, + service: "ApplicationService", + kind: str, + protocol: str, + fields: Dict[bytes, List[bytes]], + ) -> List[JsonDict]: if kind == ThirdPartyEntityKind.USER: required_field = "userid" elif kind == ThirdPartyEntityKind.LOCATION: @@ -205,11 +219,14 @@ async def push_bulk( events: List[EventBase], ephemeral: List[JsonDict], txn_id: Optional[int] = None, - ): + ) -> bool: if service.url is None: return True - events = self._serialize(service, events) + # This is required by the configuration. 
+ assert service.hs_token is not None + + serialized_events = self._serialize(service, events) if txn_id is None: logger.warning( @@ -221,9 +238,12 @@ async def push_bulk( # Never send ephemeral events to appservices that do not support it if service.supports_ephemeral: - body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral} + body = { + "events": serialized_events, + "de.sorunome.msc2409.ephemeral": ephemeral, + } else: - body = {"events": events} + body = {"events": serialized_events} try: await self.put_json( @@ -238,7 +258,7 @@ async def push_bulk( [event.get("event_id") for event in events], ) sent_transactions_counter.labels(service.id).inc() - sent_events_counter.labels(service.id).inc(len(events)) + sent_events_counter.labels(service.id).inc(len(serialized_events)) return True except CodeMessageException as e: logger.warning( @@ -260,7 +280,9 @@ async def push_bulk( failed_transactions_counter.labels(service.id).inc() return False - def _serialize(self, service, events): + def _serialize( + self, service: "ApplicationService", events: Iterable[EventBase] + ) -> List[JsonDict]: time_now = self.clock.time_msec() return [ serialize_event( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 6a2ce99b5..185e3a527 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -48,13 +48,19 @@ components. """ import logging -from typing import List, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.appservice.api import ApplicationServiceApi from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main import DataStore from synapse.types import JsonDict +from synapse.util import Clock + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -72,7 +78,7 @@ class ApplicationServiceScheduler: case is a simple array. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastore() self.as_api = hs.get_application_service_api() @@ -80,7 +86,7 @@ def __init__(self, hs): self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) - async def start(self): + async def start(self) -> None: logger.info("Starting appservice scheduler") # check for any DOWN ASes and start recoverers for them. @@ -91,12 +97,14 @@ async def start(self): for service in services: self.txn_ctrl.start_recoverer(service) - def submit_event_for_as(self, service: ApplicationService, event: EventBase): + def submit_event_for_as( + self, service: ApplicationService, event: EventBase + ) -> None: self.queuer.enqueue_event(service, event) def submit_ephemeral_events_for_as( self, service: ApplicationService, events: List[JsonDict] - ): + ) -> None: self.queuer.enqueue_ephemeral(service, events) @@ -108,16 +116,18 @@ class _ServiceQueuer: appservice at a given time. 
""" - def __init__(self, txn_ctrl, clock): - self.queued_events = {} # dict of {service_id: [events]} - self.queued_ephemeral = {} # dict of {service_id: [events]} + def __init__(self, txn_ctrl: "_TransactionController", clock: Clock): + # dict of {service_id: [events]} + self.queued_events: Dict[str, List[EventBase]] = {} + # dict of {service_id: [events]} + self.queued_ephemeral: Dict[str, List[JsonDict]] = {} # the appservices which currently have a transaction in flight - self.requests_in_flight = set() + self.requests_in_flight: Set[str] = set() self.txn_ctrl = txn_ctrl self.clock = clock - def _start_background_request(self, service): + def _start_background_request(self, service: ApplicationService) -> None: # start a sender for this appservice if we don't already have one if service.id in self.requests_in_flight: return @@ -126,15 +136,17 @@ def _start_background_request(self, service): "as-sender-%s" % (service.id,), self._send_request, service ) - def enqueue_event(self, service: ApplicationService, event: EventBase): + def enqueue_event(self, service: ApplicationService, event: EventBase) -> None: self.queued_events.setdefault(service.id, []).append(event) self._start_background_request(service) - def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]): + def enqueue_ephemeral( + self, service: ApplicationService, events: List[JsonDict] + ) -> None: self.queued_ephemeral.setdefault(service.id, []).extend(events) self._start_background_request(service) - async def _send_request(self, service: ApplicationService): + async def _send_request(self, service: ApplicationService) -> None: # sanity-check: we shouldn't get here if this service already has a sender # running. assert service.id not in self.requests_in_flight @@ -168,20 +180,15 @@ class _TransactionController: if a transaction fails. (Note we have only have one of these in the homeserver.) 
- - Args: - clock (synapse.util.Clock): - store (synapse.storage.DataStore): - as_api (synapse.appservice.api.ApplicationServiceApi): """ - def __init__(self, clock, store, as_api): + def __init__(self, clock: Clock, store: DataStore, as_api: ApplicationServiceApi): self.clock = clock self.store = store self.as_api = as_api # map from service id to recoverer instance - self.recoverers = {} + self.recoverers: Dict[str, "_Recoverer"] = {} # for UTs self.RECOVERER_CLASS = _Recoverer @@ -191,7 +198,7 @@ async def send( service: ApplicationService, events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, - ): + ) -> None: try: txn = await self.store.create_appservice_txn( service=service, events=events, ephemeral=ephemeral or [] @@ -207,7 +214,7 @@ async def send( logger.exception("Error creating appservice transaction") run_in_background(self._on_txn_fail, service) - async def on_recovered(self, recoverer): + async def on_recovered(self, recoverer: "_Recoverer") -> None: logger.info( "Successfully recovered application service AS ID %s", recoverer.service.id ) @@ -217,18 +224,18 @@ async def on_recovered(self, recoverer): recoverer.service, ApplicationServiceState.UP ) - async def _on_txn_fail(self, service): + async def _on_txn_fail(self, service: ApplicationService) -> None: try: await self.store.set_appservice_state(service, ApplicationServiceState.DOWN) self.start_recoverer(service) except Exception: logger.exception("Error starting AS recoverer") - def start_recoverer(self, service): + def start_recoverer(self, service: ApplicationService) -> None: """Start a Recoverer for the given service Args: - service (synapse.appservice.ApplicationService): + service: """ logger.info("Starting recoverer for AS ID %s", service.id) assert service.id not in self.recoverers @@ -257,7 +264,14 @@ class _Recoverer: callback (callable[_Recoverer]): called once the service recovers. """ - def __init__(self, clock, store, as_api, service, callback): + def __init__( + self, + clock: Clock, + store: DataStore, + as_api: ApplicationServiceApi, + service: ApplicationService, + callback: Callable[["_Recoverer"], Awaitable[None]], + ): self.clock = clock self.store = store self.as_api = as_api @@ -265,8 +279,8 @@ def __init__(self, clock, store, as_api, service, callback): self.callback = callback self.backoff_counter = 1 - def recover(self): - def _retry(): + def recover(self) -> None: + def _retry() -> None: run_as_background_process( "as-recoverer-%s" % (self.service.id,), self.retry ) @@ -275,13 +289,13 @@ def _retry(): logger.info("Scheduling retries on %s in %fs", self.service.id, delay) self.clock.call_later(delay, _retry) - def _backoff(self): + def _backoff(self) -> None: # cap the backoff to be around 8.5min => (2^9) = 512 secs if self.backoff_counter < 9: self.backoff_counter += 1 self.recover() - async def retry(self): + async def retry(self) -> None: logger.info("Starting retries on %s", self.service.id) try: while True: diff --git a/synapse/config/api.py b/synapse/config/api.py index 972f1ffc7..eeb87cc73 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -107,6 +107,8 @@ def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]: EventTypes.Name, # Per MSC1772. EventTypes.Create, + # Per MSC3173. 
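The `_Recoverer` hunks above keep the existing backoff behaviour while adding types: the retry delay doubles per failure and the exponent is capped, so the wait tops out at 2**9 = 512 seconds (about 8.5 minutes):

```python
def next_delay_seconds(backoff_counter: int) -> int:
    # delay doubles per failed retry, capped at 2**9 = 512s
    return 2 ** min(backoff_counter, 9)


# successive failures wait 2s, 4s, 8s, ..., 512s, then 512s forever
assert [next_delay_seconds(n) for n in (1, 2, 3, 9, 20)] == [2, 4, 8, 512, 512]
```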
+ EventTypes.Topic, ] diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index e4bb7224a..7fad2e042 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -147,8 +147,7 @@ def _load_appservice( # protocols check protocols = as_info.get("protocols") if protocols: - # Because strings are lists in python - if isinstance(protocols, str) or not isinstance(protocols, list): + if not isinstance(protocols, list): raise KeyError("Optional 'protocols' must be a list if present.") for p in protocols: if not isinstance(p, str): diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d78a15097..dbaeb1091 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -32,7 +32,7 @@ def read_config(self, config: JsonDict, **kwargs): # MSC3026 (busy presence state) self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) - # MSC2716 (backfill existing history) + # MSC2716 (importing historical messages) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) # MSC2285 (hidden read receipts) @@ -49,3 +49,8 @@ def read_config(self, config: JsonDict, **kwargs): # MSC3030 (Jump to date API endpoint) self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) + + # The portion of MSC3202 which is related to device masquerading. + self.msc3202_device_masquerading_enabled: bool = experimental.get( + "msc3202_device_masquerading", False + ) diff --git a/synapse/config/key.py b/synapse/config/key.py index 035ee2416..ee83c6c06 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -16,12 +16,14 @@ import hashlib import logging import os -from typing import Any, Dict +from typing import Any, Dict, Iterator, List, Optional import attr import jsonschema from signedjson.key import ( NACL_ED25519, + SigningKey, + VerifyKey, decode_signing_key_base64, decode_verify_key_bytes, generate_signing_key, @@ -31,6 +33,7 @@ ) from unpaddedbase64 import decode_base64 +from synapse.types import JsonDict from synapse.util.stringutils import random_string, random_string_with_symbols from ._base import Config, ConfigError @@ -81,14 +84,13 @@ logger = logging.getLogger(__name__) -@attr.s +@attr.s(slots=True, auto_attribs=True) class TrustedKeyServer: - # string: name of the server. - server_name = attr.ib() + # name of the server. + server_name: str - # dict[str,VerifyKey]|None: map from key id to key object, or None to disable - # signature verification. - verify_keys = attr.ib(default=None) + # map from key id to key object, or None to disable signature verification. + verify_keys: Optional[Dict[str, VerifyKey]] = None class KeyConfig(Config): @@ -279,15 +281,15 @@ def generate_config_section( % locals() ) - def read_signing_keys(self, signing_key_path, name): + def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]: """Read the signing keys in the given path. Args: - signing_key_path (str) - name (str): Associated config key name + signing_key_path + name: Associated config key name Returns: - list[SigningKey] + The signing keys read from the given path. 
""" signing_keys = self.read_file(signing_key_path, name) @@ -296,7 +298,9 @@ def read_signing_keys(self, signing_key_path, name): except Exception as e: raise ConfigError("Error reading %s: %s" % (name, str(e))) - def read_old_signing_keys(self, old_signing_keys): + def read_old_signing_keys( + self, old_signing_keys: Optional[JsonDict] + ) -> Dict[str, VerifyKey]: if old_signing_keys is None: return {} keys = {} @@ -340,7 +344,7 @@ def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None: write_signing_keys(signing_key_file, (key,)) -def _perspectives_to_key_servers(config): +def _perspectives_to_key_servers(config: JsonDict) -> Iterator[JsonDict]: """Convert old-style 'perspectives' configs into new-style 'trusted_key_servers' Returns an iterable of entries to add to trusted_key_servers. @@ -402,7 +406,9 @@ def _perspectives_to_key_servers(config): } -def _parse_key_servers(key_servers, federation_verify_certificates): +def _parse_key_servers( + key_servers: List[Any], federation_verify_certificates: bool +) -> Iterator[TrustedKeyServer]: try: jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA) except jsonschema.ValidationError as e: @@ -444,7 +450,7 @@ def _parse_key_servers(key_servers, federation_verify_certificates): yield result -def _assert_keyserver_has_verify_keys(trusted_key_server): +def _assert_keyserver_has_verify_keys(trusted_key_server: TrustedKeyServer) -> None: if not trusted_key_server.verify_keys: raise ConfigError(INSECURE_NOTARY_ERROR) diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 7ac82edb0..1cc26e757 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -22,10 +22,12 @@ @attr.s class MetricsFlags: - known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool)) + known_servers: bool = attr.ib( + default=False, validator=attr.validators.instance_of(bool) + ) @classmethod - def all_off(cls): + def all_off(cls) -> "MetricsFlags": """ Instantiate the flags with all options set to off. """ diff --git a/synapse/config/modules.py b/synapse/config/modules.py index ae0821e5a..85fb05890 100644 --- a/synapse/config/modules.py +++ b/synapse/config/modules.py @@ -37,7 +37,7 @@ def generate_config_section(self, **kwargs): # Server admins can expand Synapse's functionality with external modules. # - # See https://matrix-org.github.io/synapse/latest/modules.html for more + # See https://matrix-org.github.io/synapse/latest/modules/index.html for more # documentation on how to configure or create custom modules for Synapse. 
# modules: diff --git a/synapse/config/repository.py b/synapse/config/repository.py index b129b9dd6..1980351e7 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -14,10 +14,11 @@ import logging import os -from collections import namedtuple from typing import Dict, List, Tuple from urllib.request import getproxies_environment # type: ignore +import attr + from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict @@ -44,18 +45,20 @@ HTTP_PROXY_SET_WARNING = """\ The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured.""" -ThumbnailRequirement = namedtuple( - "ThumbnailRequirement", ["width", "height", "method", "media_type"] -) -MediaStorageProviderConfig = namedtuple( - "MediaStorageProviderConfig", - ( - "store_local", # Whether to store newly uploaded local files - "store_remote", # Whether to store newly downloaded remote files - "store_synchronous", # Whether to wait for successful storage for local uploads - ), -) +@attr.s(frozen=True, slots=True, auto_attribs=True) +class ThumbnailRequirement: + width: int + height: int + method: str + media_type: str + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class MediaStorageProviderConfig: + store_local: bool # Whether to store newly uploaded local files + store_remote: bool # Whether to store newly downloaded remote files + store_synchronous: bool # Whether to wait for successful storage for local uploads def parse_thumbnail_requirements( @@ -66,11 +69,10 @@ def parse_thumbnail_requirements( method, and thumbnail media type to precalculate Args: - thumbnail_sizes(list): List of dicts with "width", "height", and - "method" keys + thumbnail_sizes: List of dicts with "width", "height", and "method" keys + Returns: - Dictionary mapping from media type string to list of - ThumbnailRequirement tuples. + Dictionary mapping from media type string to list of ThumbnailRequirement. 
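These namedtuple-to-attrs conversions (repeated later in this diff for `send_queue.py` and `room_list.py`) all follow one recipe: the same shape, but with per-field types that mypy can check, plus frozen/slots semantics. In miniature, under the same attrs options the diff uses:

```python
import attr


@attr.s(frozen=True, slots=True, auto_attribs=True)
class ThumbnailRequirement:
    width: int
    height: int
    method: str
    media_type: str


req = ThumbnailRequirement(32, 32, "crop", "image/jpeg")
# attr.evolve / attr.asdict stand in for namedtuple's _replace / _asdict,
# as the RoomListNextBatch hunk later in this diff uses them.
bigger = attr.evolve(req, width=96, height=96)
assert attr.asdict(bigger)["width"] == 96
```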
""" requirements: Dict[str, List[ThumbnailRequirement]] = {} for size in thumbnail_sizes: diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 57316c59b..3c5e0f7ce 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -15,8 +15,9 @@ from typing import List +from matrix_common.regex import glob_to_regex + from synapse.types import JsonDict -from synapse.util import glob_to_regex from ._base import Config, ConfigError diff --git a/synapse/config/server.py b/synapse/config/server.py index b89a0c364..0171eff41 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1295,7 +1295,7 @@ def add_arguments(parser: argparse.ArgumentParser) -> None: help="Turn on the twisted telnet manhole service on the given port.", ) - def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]: + def read_gc_intervals(self, durations: Any) -> Optional[Tuple[float, float, float]]: """Reads the three durations for the GC min interval option, returning seconds.""" if durations is None: return None diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 4ca111618..6e673d65a 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -16,11 +16,12 @@ import os from typing import List, Optional, Pattern +from matrix_common.regex import glob_to_regex + from OpenSSL import SSL, crypto from twisted.internet._sslverify import Certificate, trustRootFromCertificates from synapse.config._base import Config, ConfigError -from synapse.util import glob_to_regex logger = logging.getLogger(__name__) @@ -132,7 +133,7 @@ def read_config(self, config: dict, config_dir_path: str, **kwargs): self.tls_certificate: Optional[crypto.X509] = None self.tls_private_key: Optional[crypto.PKey] = None - def read_certificate_from_disk(self): + def read_certificate_from_disk(self) -> None: """ Read the certificates and private key from disk. """ diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 84ef69df6..2038e7292 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -395,7 +395,7 @@ async def serialize_event( event: Union[JsonDict, EventBase], time_now: int, *, - bundle_aggregations: bool = True, + bundle_aggregations: bool = False, **kwargs: Any, ) -> JsonDict: """Serializes a single event. @@ -454,23 +454,26 @@ async def _injected_bundled_aggregations( return event_id = event.event_id + room_id = event.room_id # The bundled aggregations to include. 
aggregations = {} - annotations = await self.store.get_aggregation_groups_for_event(event_id) + annotations = await self.store.get_aggregation_groups_for_event( + event_id, room_id + ) if annotations.chunk: aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() references = await self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCE, direction="f" + event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id) + edit = await self.store.get_applicable_edit(event_id, room_id) if edit: # If there is an edit replace the content, preserving existing @@ -503,7 +506,7 @@ async def _injected_bundled_aggregations( ( thread_count, latest_thread_event, - ) = await self.store.get_thread_summary(event_id) + ) = await self.store.get_thread_summary(event_id, room_id) if latest_thread_event: aggregations[RelationTypes.THREAD] = { # Don't bundle aggregations as this could recurse forever. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index f56344a3b..addc0bf00 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from collections import namedtuple from typing import TYPE_CHECKING from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership @@ -104,10 +103,6 @@ async def _check_sigs_and_hash( return pdu -class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])): - pass - - async def _check_sigs_on_pdu( keyring: Keyring, room_version: RoomVersion, pdu: EventBase ) -> None: @@ -220,15 +215,12 @@ def _is_invite_via_3pid(event: EventBase) -> bool: ) -def event_from_pdu_json( - pdu_json: JsonDict, room_version: RoomVersion, outlier: bool = False -) -> EventBase: +def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventBase: """Construct an EventBase from an event json received over federation Args: pdu_json: pdu as received over federation room_version: The version of the room this event belongs to - outlier: True to mark this event as an outlier Raises: SynapseError: if the pdu is missing required fields or is otherwise @@ -252,6 +244,4 @@ def event_from_pdu_json( validate_canonicaljson(pdu_json) event = make_event_from_dict(pdu_json, room_version) - event.internal_metadata.outlier = outlier - return event diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index fee1477ab..6ea4edfc7 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -265,14 +265,11 @@ async def backfill( room_version = await self.store.get_room_version(room_id) - pdus = [ - event_from_pdu_json(p, room_version, outlier=False) - for p in transaction_data_pdus - ] + pdus = [event_from_pdu_json(p, room_version) for p in transaction_data_pdus] # Check signatures and hash of pdus, removing any from the list that fail checks pdus[:] = await self._check_sigs_and_hash_and_fetch( - dest, pdus, outlier=True, room_version=room_version + dest, pdus, room_version=room_version ) return pdus @@ -282,7 +279,6 @@ async def get_pdu_from_destination_raw( destination: str, event_id: str, room_version: RoomVersion, - outlier: bool = False, timeout: Optional[int] = 
None, ) -> Optional[EventBase]: """Requests the PDU with given origin and ID from the remote home @@ -292,9 +288,6 @@ async def get_pdu_from_destination_raw( destination: Which homeserver to query event_id: event to fetch room_version: version of the room - outlier: Indicates whether the PDU is an `outlier`, i.e. if - it's from an arbitrary point in the context as opposed to part - of the current block of PDUs. Defaults to `False` timeout: How long to try (in ms) each destination for before moving to the next destination. None indicates no timeout. @@ -316,8 +309,7 @@ async def get_pdu_from_destination_raw( ) pdu_list: List[EventBase] = [ - event_from_pdu_json(p, room_version, outlier=outlier) - for p in transaction_data["pdus"] + event_from_pdu_json(p, room_version) for p in transaction_data["pdus"] ] if pdu_list and pdu_list[0]: @@ -334,7 +326,6 @@ async def get_pdu( destinations: Iterable[str], event_id: str, room_version: RoomVersion, - outlier: bool = False, timeout: Optional[int] = None, ) -> Optional[EventBase]: """Requests the PDU with given origin and ID from the remote home @@ -347,9 +338,6 @@ async def get_pdu( destinations: Which homeservers to query event_id: event to fetch room_version: version of the room - outlier: Indicates whether the PDU is an `outlier`, i.e. if - it's from an arbitrary point in the context as opposed to part - of the current block of PDUs. Defaults to `False` timeout: How long to try (in ms) each destination for before moving to the next destination. None indicates no timeout. @@ -377,7 +365,6 @@ async def get_pdu( destination=destination, event_id=event_id, room_version=room_version, - outlier=outlier, timeout=timeout, ) @@ -435,7 +422,6 @@ async def _check_sigs_and_hash_and_fetch( origin: str, pdus: Collection[EventBase], room_version: RoomVersion, - outlier: bool = False, ) -> List[EventBase]: """Takes a list of PDUs and checks the signatures and hashes of each one. If a PDU fails its signature check then we check if we have it in @@ -451,7 +437,6 @@ async def _check_sigs_and_hash_and_fetch( origin pdu room_version - outlier: Whether the events are outliers or not Returns: A list of PDUs that have valid signatures and hashes. @@ -466,7 +451,6 @@ async def _execute(pdu: EventBase) -> None: valid_pdu = await self._check_sigs_and_hash_and_fetch_one( pdu=pdu, origin=origin, - outlier=outlier, room_version=room_version, ) @@ -482,7 +466,6 @@ async def _check_sigs_and_hash_and_fetch_one( pdu: EventBase, origin: str, room_version: RoomVersion, - outlier: bool = False, ) -> Optional[EventBase]: """Takes a PDU and checks its signatures and hashes. If the PDU fails its signature check then we check if we have it in the database and if @@ -494,9 +477,6 @@ async def _check_sigs_and_hash_and_fetch_one( origin pdu room_version - outlier: Whether the events are outliers or not - include_none: Whether to include None in the returned list - for events that have failed their checks Returns: The PDU (possibly redacted) if it has valid signatures and hashes. 
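A small idiom worth a note from the backfill hunk earlier: `pdus[:] = await self._check_sigs_and_hash_and_fetch(...)` assigns into the existing list object rather than rebinding the name, so any alias of the list sees the filtered result:

```python
pdus = ["good", "bad", "good"]
alias = pdus
pdus[:] = [p for p in pdus if p != "bad"]  # mutate in place, don't rebind
assert alias == ["good", "good"] and alias is pdus
```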
@@ -521,7 +501,6 @@ async def _check_sigs_and_hash_and_fetch_one( destinations=[pdu_origin], event_id=pdu.event_id, room_version=room_version, - outlier=outlier, timeout=10000, ) except SynapseError: @@ -541,13 +520,10 @@ async def get_event_auth( room_version = await self.store.get_room_version(room_id) - auth_chain = [ - event_from_pdu_json(p, room_version, outlier=True) - for p in res["auth_chain"] - ] + auth_chain = [event_from_pdu_json(p, room_version) for p in res["auth_chain"]] signed_auth = await self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True, room_version=room_version + destination, auth_chain, room_version=room_version ) return signed_auth @@ -816,7 +792,6 @@ async def send_request(destination: str) -> SendJoinResult: valid_pdu = await self._check_sigs_and_hash_and_fetch_one( pdu=event, origin=destination, - outlier=True, room_version=room_version, ) @@ -864,7 +839,6 @@ async def _execute(pdu: EventBase) -> None: valid_pdu = await self._check_sigs_and_hash_and_fetch_one( pdu=pdu, origin=destination, - outlier=True, room_version=room_version, ) @@ -1235,7 +1209,7 @@ async def get_missing_events( ] signed_events = await self._check_sigs_and_hash_and_fetch( - destination, events, outlier=False, room_version=room_version + destination, events, room_version=room_version ) except HttpResponseException as e: if not e.code == 400: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8e37e7620..ee71f289c 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -28,9 +28,9 @@ Union, ) +from matrix_common.regex import glob_to_regex from prometheus_client import Counter, Gauge, Histogram -from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure @@ -66,8 +66,8 @@ ) from synapse.storage.databases.main.lock import Lock from synapse.types import JsonDict, get_domain_from_id -from synapse.util import glob_to_regex, json_decoder, unwrapFirstError -from synapse.util.async_helpers import Linearizer, concurrently_execute +from synapse.util import json_decoder, unwrapFirstError +from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_server_name @@ -360,13 +360,13 @@ async def _handle_incoming_transaction( # want to block things like to device messages from reaching clients # behind the potentially expensive handling of PDUs. 
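The `federation_server.py` hunk just below swaps `defer.gatherResults` for the new `gather_results` helper, which takes a tuple of deferreds so each position's result type survives type checking. Not the real helper, but a rough asyncio analogue of the idea:

```python
import asyncio
from typing import Awaitable, Tuple, TypeVar

T1 = TypeVar("T1")
T2 = TypeVar("T2")


async def gather2(a: Awaitable[T1], b: Awaitable[T2]) -> Tuple[T1, T2]:
    # a tuple-shaped gather keeps each position's own result type,
    # unlike a homogeneous list of results
    ra, rb = await asyncio.gather(a, b)
    return ra, rb
```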
pdu_results, _ = await make_deferred_yieldable( - defer.gatherResults( - [ + gather_results( + ( run_in_background( self._handle_pdus_in_txn, origin, transaction, request_time ), run_in_background(self._handle_edus_in_txn, origin, transaction), - ], + ), consumeErrors=True, ).addErrback(unwrapFirstError) ) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 63289a5a3..0d7c4f506 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -30,7 +30,6 @@ """ import logging -from collections import namedtuple from typing import ( TYPE_CHECKING, Dict, @@ -43,6 +42,7 @@ Type, ) +import attr from sortedcontainers import SortedDict from synapse.api.presence import UserPresenceState @@ -382,13 +382,11 @@ def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: raise NotImplementedError() -class PresenceDestinationsRow( - BaseFederationRow, - namedtuple( - "PresenceDestinationsRow", - ("state", "destinations"), # UserPresenceState # list[str] - ), -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PresenceDestinationsRow(BaseFederationRow): + state: UserPresenceState + destinations: List[str] + TypeId = "pd" @staticmethod @@ -404,17 +402,15 @@ def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: buff.presence_destinations.append((self.state, self.destinations)) -class KeyedEduRow( - BaseFederationRow, - namedtuple( - "KeyedEduRow", - ("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu - ), -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class KeyedEduRow(BaseFederationRow): """Streams EDUs that have an associated key that is ued to clobber. For example, typing EDUs clobber based on room_id. """ + key: Tuple[str, ...] # the edu key passed to send_edu + edu: Edu + TypeId = "k" @staticmethod @@ -428,9 +424,12 @@ def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu -class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EduRow(BaseFederationRow): """Streams EDUs that don't have keys. 
See KeyedEduRow""" + edu: Edu + TypeId = "e" @staticmethod @@ -453,14 +452,14 @@ def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: TypeToRow = {Row.TypeId: Row for Row in _rowtypes} -ParsedFederationStreamData = namedtuple( - "ParsedFederationStreamData", - ( - "presence_destinations", # list of tuples of UserPresenceState and destinations - "keyed_edus", # dict of destination -> { key -> Edu } - "edus", # dict of destination -> [Edu] - ), -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ParsedFederationStreamData: + # list of tuples of UserPresenceState and destinations + presence_destinations: List[Tuple[UserPresenceState, List[str]]] + # dict of destination -> { key -> Edu } + keyed_edus: Dict[str, Dict[Tuple[str, ...], Edu]] + # dict of destination -> [Edu] + edus: Dict[str, List[Edu]] def process_rows_for_federation( diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index dc39e3537..da1fbf8b6 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -22,13 +22,11 @@ from synapse.http.server import HttpServer, ServletCallback from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.logging import opentracing from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( - SynapseTags, - start_active_span, - start_active_span_from_request, - tags, + set_tag, + span_context_from_request, + start_active_span_follows_from, whitelisted_homeserver, ) from synapse.server import HomeServer @@ -279,30 +277,19 @@ async def new_func( logger.warning("authenticate_request failed: %s", e) raise - request_tags = { - SynapseTags.REQUEST_ID: request.get_request_id(), - tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, - tags.HTTP_METHOD: request.get_method(), - tags.HTTP_URL: request.get_redacted_uri(), - tags.PEER_HOST_IPV6: request.getClientIP(), - "authenticated_entity": origin, - "servlet_name": request.request_metrics.name, - } - - # Only accept the span context if the origin is authenticated - # and whitelisted + # update the active opentracing span with the authenticated entity + set_tag("authenticated_entity", origin) + + # if the origin is authenticated and whitelisted, link to its span context + context = None if origin and whitelisted_homeserver(origin): - scope = start_active_span_from_request( - request, "incoming-federation-request", tags=request_tags - ) - else: - scope = start_active_span( - "incoming-federation-request", tags=request_tags - ) + context = span_context_from_request(request) - with scope: - opentracing.inject_response_headers(request.responseHeaders) + scope = start_active_span_follows_from( + "incoming-federation-request", contexts=(context,) if context else () + ) + with scope: if origin and self.RATELIMIT: with ratelimiter.ratelimit(origin) as d: await d diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 9abdad262..7833e77e2 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -462,9 +462,9 @@ async def query_room_alias_exists( Args: room_alias: The room alias to query. + Returns: - namedtuple: with keys "room_id" and "servers" or None if no - association can be found. + RoomAliasMapping or None if no association can be found. 
""" room_alias_str = room_alias.to_string() services = self.store.get_app_services() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6acab8ca5..75edd7d4d 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -997,9 +997,7 @@ async def create_access_token_for_user_id( # really don't want is active access_tokens without a record of the # device, so we double-check it here. if device_id is not None: - try: - await self.store.get_device(user_id, device_id) - except StoreError: + if await self.store.get_device(user_id, device_id) is None: await self.store.delete_access_token(access_token) raise StoreError(400, "Login raced against device deletion") diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 82ee11e92..766542523 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -106,10 +106,10 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict: Raises: errors.NotFoundError: if the device was not found """ - try: - device = await self.store.get_device(user_id, device_id) - except errors.StoreError: - raise errors.NotFoundError + device = await self.store.get_device(user_id, device_id) + if device is None: + raise errors.NotFoundError() + ips = await self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) @@ -602,6 +602,8 @@ async def rehydrate_device( access_token, device_id ) old_device = await self.store.get_device(user_id, old_device_id) + if old_device is None: + raise errors.NotFoundError() await self.store.update_device(user_id, device_id, old_device["display_name"]) # can't call self.delete_device because that will clobber the # access token so call the storage layer directly diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 7ee5c47fd..082f52179 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -278,13 +278,15 @@ async def get_association(self, room_alias: RoomAlias) -> JsonDict: users = await self.store.get_users_in_room(room_id) extra_servers = {get_domain_from_id(u) for u in users} - servers = set(extra_servers) | set(servers) + servers_set = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. 
- if self.server_name in servers: - servers = [self.server_name] + [s for s in servers if s != self.server_name] + if self.server_name in servers_set: + servers = [self.server_name] + [ + s for s in servers_set if s != self.server_name + ] else: - servers = list(servers) + servers = list(servers_set) return {"room_id": room_id, "servers": servers} diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 60c11e3d2..14360b4e4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -65,8 +65,12 @@ def __init__(self, hs: "HomeServer"): else: # Only register this edu handler on master as it requires writing # device updates to the db - # - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + federation_registry.register_edu_handler( + "m.signing_key_update", + self._edu_updater.incoming_signing_key_update, + ) + # also handle the unstable version + # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( "org.matrix.signing_key_update", self._edu_updater.incoming_signing_key_update, @@ -576,7 +580,9 @@ async def upload_keys_for_user( log_kv( {"message": "Did not update one_time_keys", "reason": "no keys given"} ) - fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None) + fallback_keys = keys.get("fallback_keys") or keys.get( + "org.matrix.msc2732.fallback_keys" + ) if fallback_keys and isinstance(fallback_keys, dict): log_kv( { diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 31742236a..12614b2c5 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -14,7 +14,9 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Dict, Optional + +from typing_extensions import Literal from synapse.api.errors import ( Codes, @@ -24,6 +26,7 @@ SynapseError, ) from synapse.logging.opentracing import log_kv, trace +from synapse.storage.databases.main.e2e_room_keys import RoomKey from synapse.types import JsonDict from synapse.util.async_helpers import Linearizer @@ -58,7 +61,9 @@ async def get_room_keys( version: str, room_id: Optional[str] = None, session_id: Optional[str] = None, - ) -> List[JsonDict]: + ) -> Dict[ + Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]] + ]: """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. @@ -72,8 +77,8 @@ async def get_room_keys( Raises: NotFoundError: if the backup version does not exist Returns: - A list of dicts giving the session_data and message metadata for - these room keys. + A dict giving the session_data and message metadata for these room keys. + `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}` """ # we deliberately take the lock to get keys so that changing the version @@ -273,7 +278,7 @@ async def upload_room_keys( @staticmethod def _should_replace_room_key( - current_room_key: Optional[JsonDict], room_key: JsonDict + current_room_key: Optional[RoomKey], room_key: RoomKey ) -> bool: """ Determine whether to replace a given current_room_key (if any) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 32b0254c5..1b996c420 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -79,13 +79,14 @@ async def get_stream( # thundering herds on restart. 
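The `fallback_keys` change above reads the stable field name first and only then the unstable MSC2732 one; since `or` skips falsy values, an empty stable dict also falls through to the old name:

```python
keys = {"org.matrix.msc2732.fallback_keys": {"signed_curve25519:AAAA": "key"}}
fallback_keys = keys.get("fallback_keys") or keys.get(
    "org.matrix.msc2732.fallback_keys"
)
assert fallback_keys == {"signed_curve25519:AAAA": "key"}
```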
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) - events, tokens = await self.notifier.get_events_for( + stream_result = await self.notifier.get_events_for( auth_user, pagin_config, timeout, is_guest=is_guest, explicit_room_id=room_id, ) + events = stream_result.events time_now = self.clock.time_msec() @@ -122,14 +123,12 @@ async def get_stream( events, time_now, as_client_event=as_client_event, - # Don't bundle aggregations as this is a deprecated API. - bundle_aggregations=False, ) chunk = { "chunk": chunks, - "start": await tokens[0].to_string(self.store), - "end": await tokens[1].to_string(self.store), + "start": await stream_result.start_token.to_string(self.store), + "end": await stream_result.end_token.to_string(self.store), } return chunk diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1ea837d08..26b8e3f43 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -360,31 +360,34 @@ async def try_backfill(domains: List[str]) -> bool: logger.debug("calling resolve_state_groups in _maybe_backfill") resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) - states = await make_deferred_yieldable( + states_list = await make_deferred_yieldable( defer.gatherResults( [resolve(room_id, [e]) for e in event_ids], consumeErrors=True ) ) - # dict[str, dict[tuple, str]], a map from event_id to state map of - # event_ids. - states = dict(zip(event_ids, [s.state for s in states])) + # A map from event_id to state map of event_ids. + state_ids: Dict[str, StateMap[str]] = dict( + zip(event_ids, [s.state for s in states_list]) + ) state_map = await self.store.get_events( - [e_id for ids in states.values() for e_id in ids.values()], + [e_id for ids in state_ids.values() for e_id in ids.values()], get_prev_content=False, ) - states = { + + # A map from event_id to state map of events. + state_events: Dict[str, StateMap[EventBase]] = { key: { k: state_map[e_id] for k, e_id in state_dict.items() if e_id in state_map } - for key, state_dict in states.items() + for key, state_dict in state_ids.items() } for e_id in event_ids: - likely_extremeties_domains = get_domains_from_state(states[e_id]) + likely_extremeties_domains = get_domains_from_state(state_events[e_id]) success = await try_backfill( [ diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 991761329..11771f3c9 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -421,9 +421,6 @@ async def process_remote_join( Raises: SynapseError if the response is in some way invalid. """ - for e in itertools.chain(auth_events, state): - e.internal_metadata.outlier = True - event_map = {e.event_id: e for e in itertools.chain(auth_events, state)} create_event = None @@ -666,7 +663,9 @@ async def _process_pulled_event( logger.info("Processing pulled event %s", event) # these should not be outliers. - assert not event.internal_metadata.is_outlier() + assert ( + not event.internal_metadata.is_outlier() + ), "pulled event unexpectedly flagged as outlier" event_id = event.event_id @@ -1192,7 +1191,6 @@ async def get_event(event_id: str) -> None: [destination], event_id, room_version, - outlier=True, ) if event is None: logger.warning( @@ -1221,9 +1219,10 @@ async def _auth_and_persist_outliers( """Persist a batch of outlier events fetched from remote servers. 
We first sort the events to make sure that we process each event's auth_events - before the event itself, and then auth and persist them. + before the event itself. - Notifies about the events where appropriate. + We then mark the events as outliers, persist them to the database, and, where + appropriate (eg, an invite), awake the notifier. Params: room_id: the room that the events are meant to be in (though this has @@ -1274,7 +1273,8 @@ async def _auth_and_persist_outliers_inner( Persists a batch of events where we have (theoretically) already persisted all of their auth events. - Notifies about the events where appropriate. + Marks the events as outliers, auths them, persists them to the database, and, + where appropriate (eg, an invite), awakes the notifier. Params: origin: where the events came from @@ -1312,6 +1312,9 @@ def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]: return None auth.append(ae) + # we're not bothering about room state, so flag the event as an outlier. + event.internal_metadata.outlier = True + context = EventContext.for_outlier() try: validate_event_for_room_version(room_version_obj, event) @@ -1838,7 +1841,7 @@ async def persist_events_and_notify( The stream ID after which all events have been persisted. """ if not event_and_contexts: - return self._store.get_current_events_token() + return self._store.get_room_max_stream_ordering() instance = self._config.worker.events_shard_config.get_instance(room_id) if instance != self._instance_name: diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 9cd21e7f2..601bab67f 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -13,21 +13,27 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple - -from twisted.internet import defer +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import SynapseError +from synapse.events import EventBase from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.handlers.receipts import ReceiptEventSource from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage.roommember import RoomsForUser from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID +from synapse.types import ( + JsonDict, + Requester, + RoomStreamToken, + StateMap, + StreamToken, + UserID, +) from synapse.util import unwrapFirstError -from synapse.util.async_helpers import concurrently_execute +from synapse.util.async_helpers import concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client @@ -167,8 +173,6 @@ async def handle_room(event: RoomsForUser) -> None: d["invite"] = await self._event_serializer.serialize_event( invite_event, time_now, - # Don't bundle aggregations as this is a deprecated API. 
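Pulling the outlier hunks above together: instead of threading `outlier=True` through `event_from_pdu_json`, `get_pdu` and friends, the flag is now set once in `_auth_and_persist_outliers_inner`, right before persistence. A simplified model of that single choke point, with stand-in types rather than Synapse's own:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class InternalMetadata:
    outlier: bool = False


@dataclass
class Event:
    event_id: str
    internal_metadata: InternalMetadata = field(default_factory=InternalMetadata)


def mark_as_outliers(events: List[Event]) -> None:
    # single choke point: we're not tracking room state for these events,
    # so flag them all as outliers just before they are persisted
    for event in events:
        event.internal_metadata.outlier = True
```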
- bundle_aggregations=False, as_client_event=as_client_event, ) @@ -190,14 +194,13 @@ async def handle_room(event: RoomsForUser) -> None: ) deferred_room_state = run_in_background( self.state_store.get_state_for_events, [event.event_id] - ) - deferred_room_state.addCallback( - lambda states: states[event.event_id] + ).addCallback( + lambda states: cast(StateMap[EventBase], states[event.event_id]) ) (messages, token), current_state = await make_deferred_yieldable( - defer.gatherResults( - [ + gather_results( + ( run_in_background( self.store.get_recent_events_for_room, event.room_id, @@ -205,7 +208,7 @@ async def handle_room(event: RoomsForUser) -> None: end_token=room_end_token, ), deferred_room_state, - ] + ) ) ).addErrback(unwrapFirstError) @@ -222,8 +225,6 @@ async def handle_room(event: RoomsForUser) -> None: await self._event_serializer.serialize_events( messages, time_now=time_now, - # Don't bundle aggregations as this is a deprecated API. - bundle_aggregations=False, as_client_event=as_client_event, ) ), @@ -234,8 +235,6 @@ async def handle_room(event: RoomsForUser) -> None: d["state"] = await self._event_serializer.serialize_events( current_state.values(), time_now=time_now, - # Don't bundle aggregations as this is a deprecated API. - bundle_aggregations=False, as_client_event=as_client_event, ) @@ -377,9 +376,7 @@ async def _room_initial_sync_parted( "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. - await self._event_serializer.serialize_events( - messages, time_now, bundle_aggregations=False - ) + await self._event_serializer.serialize_events(messages, time_now) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), @@ -387,7 +384,7 @@ async def _room_initial_sync_parted( "state": ( # Don't bundle aggregations as this is a deprecated API. await self._event_serializer.serialize_events( - room_state.values(), time_now, bundle_aggregations=False + room_state.values(), time_now ) ), "presence": [], @@ -408,7 +405,7 @@ async def _room_initial_sync_joined( time_now = self.clock.time_msec() # Don't bundle aggregations as this is a deprecated API. state = await self._event_serializer.serialize_events( - current_state.values(), time_now, bundle_aggregations=False + current_state.values(), time_now ) now_token = self.hs.get_event_sources().get_current_token() @@ -454,8 +451,8 @@ async def get_receipts() -> List[JsonDict]: return receipts presence, receipts, (messages, token) = await make_deferred_yieldable( - defer.gatherResults( - [ + gather_results( + ( run_in_background(get_presence), run_in_background(get_receipts), run_in_background( @@ -464,7 +461,7 @@ async def get_receipts() -> List[JsonDict]: limit=limit, end_token=now_token.room_key, ), - ], + ), consumeErrors=True, ).addErrback(unwrapFirstError) ) @@ -483,9 +480,7 @@ async def get_receipts() -> List[JsonDict]: "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. 
- await self._event_serializer.serialize_events( - messages, time_now, bundle_aggregations=False - ) + await self._event_serializer.serialize_events(messages, time_now) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 87f671708..5e3d3886e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -21,7 +21,6 @@ from canonicaljson import encode_canonical_json -from twisted.internet import defer from twisted.internet.interfaces import IDelayedCall from synapse import event_auth @@ -57,7 +56,7 @@ from synapse.storage.state import StateFilter from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester from synapse.util import json_decoder, json_encoder, log_failure -from synapse.util.async_helpers import Linearizer, unwrapFirstError +from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client @@ -496,6 +495,7 @@ async def create_event( require_consent: bool = True, outlier: bool = False, historical: bool = False, + allow_no_prev_events: bool = False, depth: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: """ @@ -607,6 +607,7 @@ async def create_event( prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, depth=depth, + allow_no_prev_events=allow_no_prev_events, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -882,6 +883,7 @@ async def create_new_client_event( prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, + allow_no_prev_events: bool = False, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client @@ -912,6 +914,7 @@ async def create_new_client_event( full_state_ids_at_event = None if auth_event_ids is not None: # If auth events are provided, prev events must be also. + # prev_event_ids could be an empty array though. assert prev_event_ids is not None # Copy the full auth state before it stripped down @@ -943,14 +946,22 @@ async def create_new_client_event( else: prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) - # we now ought to have some prev_events (unless it's a create event). - # - # do a quick sanity check here, rather than waiting until we've created the + # Do a quick sanity check here, rather than waiting until we've created the # event and then try to auth it (which fails with a somewhat confusing "No # create event in auth events") - assert ( - builder.type == EventTypes.Create or len(prev_event_ids) > 0 - ), "Attempting to create an event with no prev_events" + if allow_no_prev_events: + # We allow events with no `prev_events` but it better have some `auth_events` + assert ( + builder.type == EventTypes.Create + # Allow an event to have empty list of prev_event_ids + # only if it has auth_event_ids. + or auth_event_ids + ), "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids" + else: + # we now ought to have some prev_events (unless it's a create event). 
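The assertion being rewritten below encodes the relaxed ancestry rule introduced for `allow_no_prev_events`: only `m.room.create` may lack both, and an empty `prev_events` list is legal only when the caller opts in and supplies `auth_events` (the MSC2716 batch-send case). Restated standalone, with simplified types:

```python
from typing import List, Optional


def check_event_ancestry(
    event_type: str,
    prev_event_ids: List[str],
    auth_event_ids: Optional[List[str]],
    allow_no_prev_events: bool = False,
) -> None:
    if event_type == "m.room.create":
        return  # the create event legitimately has no ancestors
    if allow_no_prev_events:
        # empty prev_events is fine, but only with auth_events to anchor it
        assert auth_event_ids, "no prev_events requires auth_event_ids"
    else:
        assert prev_event_ids, "non-create event needs prev_events"
```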
+ assert ( + builder.type == EventTypes.Create or prev_event_ids + ), "Attempting to create a non-m.room.create event with no prev_events" event = await builder.build( prev_event_ids=prev_event_ids, @@ -1156,9 +1167,9 @@ async def handle_new_client_event( # We now persist the event (and update the cache in parallel, since we # don't want to block on it). - result = await make_deferred_yieldable( - defer.gatherResults( - [ + result, _ = await make_deferred_yieldable( + gather_results( + ( run_in_background( self._persist_event, requester=requester, @@ -1170,12 +1181,12 @@ async def handle_new_client_event( run_in_background( self.cache_joined_hosts_for_event, event, context ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), - ], + ), consumeErrors=True, ) ).addErrback(unwrapFirstError) - return result[0] + return result async def _persist_event( self, diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 4f4243805..7469cc55a 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -542,7 +542,10 @@ async def get_messages( chunk = { "chunk": ( await self._event_serializer.serialize_events( - events, time_now, as_client_event=as_client_event + events, + time_now, + bundle_aggregations=True, + as_client_event=as_client_event, ) ), "start": await from_token.to_string(self.store), diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 454d06c97..c781fefb1 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -729,7 +729,7 @@ def run_persister() -> Awaitable[None]: # Presence is best effort and quickly heals itself, so lets just always # stream from the current state when we restart. - self._event_pos = self.store.get_current_events_token() + self._event_pos = self.store.get_room_max_stream_ordering() self._event_processing = False async def _on_shutdown(self) -> None: diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4911a1153..5cb1ff749 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.appservice import ApplicationService from synapse.streams import EventSource from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id @@ -178,7 +178,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]: for event_id in content.keys(): event_content = content.get(event_id, {}) - m_read = event_content.get("m.read", {}) + m_read = event_content.get(ReceiptTypes.READ, {}) # If m_read is missing copy over the original event_content as there is nothing to process here if not m_read: @@ -206,7 +206,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]: # Set new users unless empty if len(new_users.keys()) > 0: - new_event["content"][event_id] = {"m.read": new_users} + new_event["content"][event_id] = {ReceiptTypes.READ: new_users} # Append new_event to visible_events unless empty if len(new_event["content"].keys()) > 0: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ead2198e1..b9c1cbffa 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -172,7 +172,7 @@ async def upgrade_room( user_id = requester.user.to_string() # Check if this room is already being upgraded by another person - for key 
in self._upgrade_response_cache.pending_result_cache: + for key in self._upgrade_response_cache.keys(): if key[0] == old_room_id and key[1] != user_id: # Two different people are trying to upgrade the same room. # Send the second an error. diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index ba7a14d65..1a33211a1 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -13,9 +13,9 @@ # limitations under the License. import logging -from collections import namedtuple from typing import TYPE_CHECKING, Any, Optional, Tuple +import attr import msgpack from unpaddedbase64 import decode_base64, encode_base64 @@ -474,16 +474,12 @@ async def _get_remote_list_cached( ) -class RoomListNextBatch( - namedtuple( - "RoomListNextBatch", - ( - "last_joined_members", # The count to get rooms after/before - "last_room_id", # The room_id to get rooms after/before - "direction_is_forward", # Bool if this is a next_batch, false if prev_batch - ), - ) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomListNextBatch: + last_joined_members: int # The count to get rooms after/before + last_room_id: str # The room_id to get rooms after/before + direction_is_forward: bool # True if this is a next_batch, false if prev_batch + KEY_DICT = { "last_joined_members": "m", "last_room_id": "r", @@ -502,12 +498,12 @@ def from_token(cls, token: str) -> "RoomListNextBatch": def to_token(self) -> str: return encode_base64( msgpack.dumps( - {self.KEY_DICT[key]: val for key, val in self._asdict().items()} + {self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()} ) ) def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch": - return self._replace(**kwds) + return attr.evolve(self, **kwds) def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b520b2c3a..18f9e6386 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -671,7 +671,8 @@ async def update_membership_locked( if block_invite: raise SynapseError(403, "Invites have been disabled on this server") - if prev_event_ids: + # An empty prev_events list is allowed as long as the auth_event_ids are present + if prev_event_ids is not None: return await self._local_membership_update( requester=requester, target=target, @@ -1032,7 +1033,7 @@ async def transfer_room_state_on_room_upgrade( # Add new room to the room directory if the old room was there # Remove old room from the room directory old_room = await self.store.get_room(old_room_id) - if old_room and old_room["is_public"]: + if old_room is not None and old_room["is_public"]: await self.store.set_room_is_public(old_room_id, False) await self.store.set_room_is_public(room_id, True) @@ -1043,7 +1044,9 @@ async def transfer_room_state_on_room_upgrade( local_group_ids = await self.store.get_local_groups_for_room(old_room_id) for group_id in local_group_ids: # Add new the new room to those groups - await self.store.add_room_to_group(group_id, room_id, old_room["is_public"]) + await self.store.add_room_to_group( + group_id, room_id, old_room is not None and old_room["is_public"] + ) # Remove the old room from those groups await self.store.remove_room_from_group(group_id, old_room_id) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index bd3e6f2ec..29e41a4c7 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -80,6 +80,17 @@ async def _unsafe_process(self) -> None: # If 
self.pos is None then means we haven't fetched it from DB if self.pos is None: self.pos = await self.store.get_stats_positions() + room_max_stream_ordering = self.store.get_room_max_stream_ordering() + if self.pos > room_max_stream_ordering: + # apparently, we've processed more events than exist in the database! + # this can happen if events are removed with history purge or similar. + logger.warning( + "Event stream ordering appears to have gone backwards (%i -> %i): " + "rewinding stats processor", + self.pos, + room_max_stream_ordering, + ) + self.pos = room_max_stream_ordering # Loop round handling deltas until we're up to date diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f3039c3c3..7baf3f199 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -28,7 +28,7 @@ import attr from prometheus_client import Counter -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -36,6 +36,7 @@ from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user +from synapse.storage.databases.main.event_push_actions import NotifCounts from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -421,7 +422,7 @@ async def current_sync_for_user( span to track the sync. See `generate_sync_result` for the next part of your indoctrination. """ - with start_active_span("current_sync_for_user"): + with start_active_span("sync.current_sync_for_user"): log_kv({"since_token": since_token}) sync_result = await self.generate_sync_result( sync_config, since_token, full_state @@ -1041,18 +1042,17 @@ async def compute_state_delta( async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig - ) -> Dict[str, int]: + ) -> NotifCounts: with Measure(self.clock, "unread_notifs_for_room_id"): last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, - receipt_type="m.read", + receipt_type=ReceiptTypes.READ, ) - notifs = await self.store.get_unread_event_push_actions_by_room_for_user( + return await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), last_unread_event_id ) - return notifs async def generate_sync_result( self, @@ -1585,7 +1585,8 @@ async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: ) logger.debug("Generated room entry for %s", room_entry.room_id) - await concurrently_execute(handle_room_entries, room_entries, 10) + with start_active_span("sync.generate_room_entries"): + await concurrently_execute(handle_room_entries, room_entries, 10) sync_result_builder.invited.extend(invited) sync_result_builder.knocked.extend(knocked) @@ -1662,20 +1663,20 @@ async def _get_rooms_changed( ) -> _RoomChanges: """Determine the changes in rooms to report to the user. - Ideally, we want to report all events whose stream ordering `s` lies in the - range `since_token < s <= now_token`, where the two tokens are read from the - sync_result_builder. + This function is a first pass at generating the rooms part of the sync response. 
+ It determines which rooms have changed during the sync period, and categorises + them into four buckets: "knock", "invite", "join" and "leave". - If there are too many events in that range to report, things get complicated. - In this situation we return a truncated list of the most recent events, and - indicate in the response that there is a "gap" of omitted events. Additionally: + 1. Finds all membership changes for the user in the sync period (from + `since_token` up to `now_token`). + 2. Uses those to place the room in one of the four categories above. + 3. Builds a `_RoomChanges` struct to record this, and returns that struct. - - we include a "state_delta", to describe the changes in state over the gap, - - we include all membership events applying to the user making the request, - even those in the gap. - - See the spec for the rationale: - https://spec.matrix.org/v1.1/client-server-api/#syncing + For rooms classified as "knock", "invite" or "leave", we just need to report + a single membership event in the eventual /sync response. For "join" we need + to fetch additional non-membership events, e.g. messages in the room. That is + more complicated, so instead we report an intermediary `RoomSyncResultBuilder` + struct, and leave the additional work to `_generate_room_entry`. The sync_result_builder is not modified by this function. """ @@ -1686,16 +1687,6 @@ async def _get_rooms_changed( assert since_token - # The spec - # https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync - # notes that membership events need special consideration: - # - # > When a sync is limited, the server MUST return membership events for events - # > in the gap (between since and the start of the returned timeline), regardless - # > as to whether or not they are redundant. - # - # We fetch such events here, but we only seem to use them for categorising rooms - # as newly joined, newly left, invited or knocked. # TODO: we've already called this function and ran this query in # _have_rooms_changed. We could keep the results in memory to avoid a # second query, at the cost of more complicated source code. @@ -2009,6 +2000,23 @@ async def _generate_room_entry( """Populates the `joined` and `archived` section of `sync_result_builder` based on the `room_builder`. + Ideally, we want to report all events whose stream ordering `s` lies in the + range `since_token < s <= now_token`, where the two tokens are read from the + sync_result_builder. + + If there are too many events in that range to report, things get complicated. + In this situation we return a truncated list of the most recent events, and + indicate in the response that there is a "gap" of omitted events. Lots of this + is handled in `_load_filtered_recents`, but some of it is handled in this method. + + Additionally: + - we include a "state_delta", to describe the changes in state over the gap, + - we include all membership events applying to the user making the request, + even those in the gap. + + See the spec for the rationale: + https://spec.matrix.org/v1.1/client-server-api/#syncing + Args: sync_result_builder ignored_users: Set of users ignored by user. 
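For readers skimming the new `_get_rooms_changed` docstring, here is a toy distillation of the four-bucket classification it describes. The `categorise_room` helper below is illustrative only and not part of this change; the real code works from membership *changes* over the sync window rather than a single membership value.

```python
# Toy sketch only: the real categorisation in _get_rooms_changed considers
# the user's membership history across the sync period, not one snapshot.
from synapse.api.constants import Membership


def categorise_room(membership: str) -> str:
    """Map a user's latest membership in a room to a sync bucket."""
    if membership == Membership.JOIN:
        return "join"
    if membership == Membership.INVITE:
        return "invite"
    if membership == Membership.KNOCK:
        return "knock"
    # leave and ban are both surfaced under "leave" in the sync response
    return "leave"
```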
@@ -2038,7 +2046,7 @@ async def _generate_room_entry( since_token = room_builder.since_token upto_token = room_builder.upto_token - with start_active_span("generate_room_entry"): + with start_active_span("sync.generate_room_entry"): set_tag("room_id", room_id) log_kv({"events": len(events or ())}) @@ -2166,10 +2174,10 @@ async def _generate_room_entry( if room_sync or always_include: notifs = await self.unread_notifs_for_room_id(room_id, sync_config) - unread_notifications["notification_count"] = notifs["notify_count"] - unread_notifications["highlight_count"] = notifs["highlight_count"] + unread_notifications["notification_count"] = notifs.notify_count + unread_notifications["highlight_count"] = notifs.highlight_count - room_sync.unread_count = notifs["unread_count"] + room_sync.unread_count = notifs.unread_count sync_result_builder.joined.append(room_sync) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 1676ebd05..e43c22832 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -13,9 +13,10 @@ # limitations under the License. import logging import random -from collections import namedtuple from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +import attr + from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import ( @@ -37,7 +38,10 @@ # A tiny object useful for storing a user's membership in a room, as a mapping # key -RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomMember: + room_id: str + user_id: str # How often we expect remote servers to resend us presence. @@ -119,7 +123,7 @@ def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None: self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member: RoomMember) -> bool: - return member.user_id in self._room_typing.get(member.room_id, []) + return member.user_id in self._room_typing.get(member.room_id, set()) async def _push_remote(self, member: RoomMember, typing: bool) -> None: if not self.federation: @@ -166,9 +170,9 @@ def process_replication_rows( for row in rows: self._room_serials[row.room_id] = token - prev_typing = set(self._room_typing.get(row.room_id, [])) + prev_typing = self._room_typing.get(row.room_id, set()) now_typing = set(row.user_ids) - self._room_typing[row.room_id] = row.user_ids + self._room_typing[row.room_id] = now_typing if self.federation: run_as_background_process( diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index a0eb45446..1565e034c 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -148,9 +148,21 @@ async def _unsafe_process(self) -> None: if self.pos is None: self.pos = await self.store.get_user_directory_stream_pos() - # If still None then the initial background update hasn't happened yet. - if self.pos is None: - return None + # If still None then the initial background update hasn't happened yet. + if self.pos is None: + return None + + room_max_stream_ordering = self.store.get_room_max_stream_ordering() + if self.pos > room_max_stream_ordering: + # apparently, we've processed more events than exist in the database! + # this can happen if events are removed with history purge or similar. 
+ logger.warning( + "Event stream ordering appears to have gone backwards (%i -> %i): " + "rewinding user directory processor", + self.pos, + room_max_stream_ordering, + ) + self.pos = room_max_stream_ordering # Loop round handling deltas until we're up to date while True: diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 578fc48ef..efecb089c 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -25,7 +25,7 @@ class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" - def __init__(self, msg): + def __init__(self, msg: str): super().__init__(504, msg) @@ -33,7 +33,7 @@ def __init__(self, msg): CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$") -def redact_uri(uri): +def redact_uri(uri: str) -> str: """Strips sensitive information from the uri replaces with """ uri = ACCESS_TOKEN_RE.sub(r"\1\3", uri) return CLIENT_SECRET_RE.sub(r"\1\3", uri) @@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer): https://twistedmatrix.com/trac/ticket/6528 """ - def stopProducing(self): + def stopProducing(self) -> None: try: FileBodyProducer.stopProducing(self) except task.TaskStopped: diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 9a2684aca..6a9f6635d 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple from twisted.web.server import Request @@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource): and exception handling. """ - def __init__(self, hs: "HomeServer", handler): + def __init__( + self, + hs: "HomeServer", + handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]], + ): """Initialise AdditionalResource The ``handler`` should return a deferred which completes when it has @@ -47,7 +51,7 @@ def __init__(self, hs: "HomeServer", handler): super().__init__() self._handler = handler - def _async_render(self, request: Request): + async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]: # Cheekily pass the result straight through, so we don't need to worry # if its an awaitable or not. - return self._handler(request) + return await self._handler(request) diff --git a/synapse/http/client.py b/synapse/http/client.py index b5a2d333a..ca33b45cb 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -14,6 +14,7 @@ # limitations under the License. 
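The purge-rewind guard added to the stats processor above, and mirrored just now in the user-directory processor, boils down to a small clamp. A standalone sketch, with the store lookups replaced by plain ints (`clamp_stream_position` is a hypothetical name, not a helper this diff adds):

```python
import logging

logger = logging.getLogger(__name__)


def clamp_stream_position(pos: int, room_max_stream_ordering: int, name: str) -> int:
    """Rewind a processor's saved position if events have been purged.

    If `pos` is ahead of the newest event's stream ordering, events must have
    been removed (e.g. by a history purge), so resume from the stream maximum.
    """
    if pos > room_max_stream_ordering:
        logger.warning(
            "Event stream ordering appears to have gone backwards (%i -> %i): "
            "rewinding %s processor",
            pos,
            room_max_stream_ordering,
            name,
        )
        return room_max_stream_ordering
    return pos
```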
import logging import urllib.parse +from http import HTTPStatus from io import BytesIO from typing import ( TYPE_CHECKING, @@ -280,7 +281,9 @@ def request( ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info("Blocking access to %s due to blacklist" % (ip_address,)) - e = SynapseError(403, "IP address blocked by IP blacklist entry") + e = SynapseError( + HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry" + ) return defer.fail(Failure(e)) return self._agent.request( @@ -585,7 +588,7 @@ async def get_json( if headers: actual_headers.update(headers) # type: ignore - body = await self.get_raw(uri, args, headers=headers) + body = await self.get_raw(uri, args, headers=actual_headers) return json_decoder.decode(body.decode("utf-8")) async def put_json( @@ -719,7 +722,9 @@ async def get_file( if response.code > 299: logger.warning("Got %d when downloading %s" % (response.code, url)) - raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) + raise SynapseError( + HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN + ) # TODO: if our Content-Type is HTML or something, just read the first # N bytes into RAM rather than saving it all to disk only to read it @@ -731,12 +736,14 @@ async def get_file( ) except BodyExceededMaxSize: raise SynapseError( - 502, + HTTPStatus.BAD_GATEWAY, "Requested file is too large > %r bytes" % (max_size,), Codes.TOO_LARGE, ) except Exception as e: - raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e + raise SynapseError( + HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e) + ) from e return ( length, diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 1238bfd28..a8a520f80 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -25,6 +25,7 @@ from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.interfaces import ( + IProtocol, IProtocolFactory, IReactorCore, IStreamClientEndpoint, @@ -309,12 +310,14 @@ def __init__( self._srv_resolver = srv_resolver - def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred: + def connect( + self, protocol_factory: IProtocolFactory + ) -> "defer.Deferred[IProtocol]": """Implements IStreamClientEndpoint interface""" return run_in_background(self._do_connect, protocol_factory) - async def _do_connect(self, protocol_factory: IProtocolFactory) -> None: + async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol: first_exception = None server_list = await self._resolve_server() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 203d723d4..deedde0b5 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -19,6 +19,7 @@ import sys import typing import urllib.parse +from http import HTTPStatus from io import BytesIO, StringIO from typing import ( TYPE_CHECKING, @@ -1154,7 +1155,7 @@ async def get_file( request.destination, msg, ) - raise SynapseError(502, msg, Codes.TOO_LARGE) + raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE) except defer.TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response - %s %s", diff --git a/synapse/http/server.py b/synapse/http/server.py index 91badb0b0..09b412548 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -14,7 +14,6 @@ # 
limitations under the License. import abc -import collections import html import logging import types @@ -30,12 +29,14 @@ Iterable, Iterator, List, + NoReturn, Optional, Pattern, Tuple, Union, ) +import attr import jinja2 from canonicaljson import encode_canonical_json from typing_extensions import Protocol @@ -57,12 +58,14 @@ ) from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background -from synapse.logging.opentracing import trace_servlet +from synapse.logging.opentracing import active_span, start_active_span, trace_servlet from synapse.util import json_encoder from synapse.util.caches import intern_dict from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: + import opentracing + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -170,7 +173,9 @@ def return_html_error( respond_with_html(request, code, body) -def wrap_async_request_handler(h): +def wrap_async_request_handler( + h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]] +) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]: """Wraps an async request handler so that it calls request.processing. This helps ensure that work done by the request handler after the request is completed @@ -183,7 +188,9 @@ def wrap_async_request_handler(h): logged until the deferred completes. """ - async def wrapped_async_request_handler(self, request): + async def wrapped_async_request_handler( + self: "_AsyncResource", request: SynapseRequest + ) -> None: with request.processing(): await h(self, request) @@ -240,18 +247,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): context from the request the servlet is handling. """ - def __init__(self, extract_context=False): + def __init__(self, extract_context: bool = False): super().__init__() self._extract_context = extract_context - def render(self, request): + def render(self, request: SynapseRequest) -> int: """This gets called by twisted every time someone sends us a request.""" defer.ensureDeferred(self._async_render_wrapper(request)) return NOT_DONE_YET @wrap_async_request_handler - async def _async_render_wrapper(self, request: SynapseRequest): + async def _async_render_wrapper(self, request: SynapseRequest) -> None: """This is a wrapper that delegates to `_async_render` and handles exceptions, return values, metrics, etc. """ @@ -271,7 +278,7 @@ async def _async_render_wrapper(self, request: SynapseRequest): f = failure.Failure() self._send_error_response(f, request) - async def _async_render(self, request: Request): + async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]: """Delegates to `_async_render_` methods, or returns a 400 if no appropriate method exists. Can be overridden in sub classes for different routing. @@ -318,7 +325,7 @@ class DirectServeJsonResource(_AsyncResource): formatting responses and errors as JSON. """ - def __init__(self, canonical_json=False, extract_context=False): + def __init__(self, canonical_json: bool = False, extract_context: bool = False): super().__init__(extract_context) self.canonical_json = canonical_json @@ -327,7 +334,7 @@ def _send_response( request: SynapseRequest, code: int, response_object: Any, - ): + ) -> None: """Implements _AsyncResource._send_response""" # TODO: Only enable CORS for the requests that need it. 
respond_with_json( @@ -347,9 +354,11 @@ def _send_error_response( return_json_error(f, request) -_PathEntry = collections.namedtuple( - "_PathEntry", ["pattern", "callback", "servlet_classname"] -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PathEntry: + pattern: Pattern + callback: ServletCallback + servlet_classname: str class JsonResource(DirectServeJsonResource): @@ -368,34 +377,45 @@ class JsonResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False): + def __init__( + self, + hs: "HomeServer", + canonical_json: bool = True, + extract_context: bool = False, + ): super().__init__(canonical_json, extract_context) self.clock = hs.get_clock() self.path_regexs: Dict[bytes, List[_PathEntry]] = {} self.hs = hs - def register_paths(self, method, path_patterns, callback, servlet_classname): + def register_paths( + self, + method: str, + path_patterns: Iterable[Pattern], + callback: ServletCallback, + servlet_classname: str, + ) -> None: """ Registers a request handler against a regular expression. Later request URLs are checked against these regular expressions in order to identify an appropriate handler for that request. Args: - method (str): GET, POST etc + method: GET, POST etc - path_patterns (Iterable[str]): A list of regular expressions to which - the request URLs are compared. + path_patterns: A list of regular expressions to which the request + URLs are compared. - callback (function): The handler for the request. Usually a Servlet + callback: The handler for the request. Usually a Servlet - servlet_classname (str): The name of the handler to be used in prometheus + servlet_classname: The name of the handler to be used in prometheus and opentracing logs. """ - method = method.encode("utf-8") # method is bytes on py3 + method_bytes = method.encode("utf-8") for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) - self.path_regexs.setdefault(method, []).append( + self.path_regexs.setdefault(method_bytes, []).append( _PathEntry(path_pattern, callback, servlet_classname) ) @@ -427,7 +447,7 @@ def _get_handler_for_request( # Huh. No one wanted to handle that? Fiiiiiine. Send 400. return _unrecognised_request_handler, "unrecognised_request_handler", {} - async def _async_render(self, request): + async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: callback, servlet_classname, group_dict = self._get_handler_for_request(request) # Make sure we have an appropriate name for this handler in prometheus @@ -468,7 +488,7 @@ def _send_response( request: SynapseRequest, code: int, response_object: Any, - ): + ) -> None: """Implements _AsyncResource._send_response""" # We expect to get bytes for us to write assert isinstance(response_object, bytes) @@ -492,12 +512,12 @@ class StaticResource(File): Differs from the File resource by adding clickjacking protection. """ - def render_GET(self, request: Request): + def render_GET(self, request: Request) -> bytes: set_clickjacking_protection_headers(request) return super().render_GET(request) -def _unrecognised_request_handler(request): +def _unrecognised_request_handler(request: Request) -> NoReturn: """Request handler for unrecognised requests This is a request handler suitable for return from @@ -505,7 +525,7 @@ def _unrecognised_request_handler(request): UnrecognizedRequestError. Args: - request (twisted.web.http.Request): + request: Unused, but passed in to match the signature of ServletCallback. 
""" raise UnrecognizedRequestError() @@ -513,23 +533,23 @@ def _unrecognised_request_handler(request): class RootRedirect(resource.Resource): """Redirects the root '/' path to another path.""" - def __init__(self, path): - resource.Resource.__init__(self) + def __init__(self, path: str): + super().__init__() self.url = path - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: return redirectTo(self.url.encode("ascii"), request) - def getChild(self, name, request): + def getChild(self, name: str, request: Request) -> resource.Resource: if len(name) == 0: return self # select ourselves as the child to render - return resource.Resource.getChild(self, name, request) + return super().getChild(name, request) class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request): + def render_OPTIONS(self, request: Request) -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -537,10 +557,10 @@ def render_OPTIONS(self, request): return b"" - def getChildWithDefault(self, path, request): + def getChildWithDefault(self, path: str, request: Request) -> resource.Resource: if request.method == b"OPTIONS": return self # select ourselves as the child to render - return resource.Resource.getChildWithDefault(self, path, request) + return super().getChildWithDefault(path, request) class RootOptionsRedirectResource(OptionsResource, RootRedirect): @@ -649,7 +669,7 @@ def respond_with_json( json_object: Any, send_cors: bool = False, canonical_json: bool = True, -): +) -> Optional[int]: """Sends encoded JSON in response to the given request. Args: @@ -696,7 +716,7 @@ def respond_with_json_bytes( code: int, json_bytes: bytes, send_cors: bool = False, -): +) -> Optional[int]: """Sends encoded JSON in response to the given request. Args: @@ -713,7 +733,7 @@ def respond_with_json_bytes( logger.warning( "Not sending response to request %s, already disconnected.", request ) - return + return None request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") @@ -731,7 +751,7 @@ async def _async_write_json_to_request_in_thread( request: SynapseRequest, json_encoder: Callable[[Any], bytes], json_object: Any, -): +) -> None: """Encodes the given JSON object on a thread and then writes it to the request. @@ -743,7 +763,20 @@ async def _async_write_json_to_request_in_thread( expensive. """ - json_str = await defer_to_thread(request.reactor, json_encoder, json_object) + def encode(opentracing_span: "Optional[opentracing.Span]") -> bytes: + # it might take a while for the threadpool to schedule us, so we write + # opentracing logs once we actually get scheduled, so that we can see how + # much that contributed. 
+ if opentracing_span: + opentracing_span.log_kv({"event": "scheduled"}) + res = json_encoder(json_object) + if opentracing_span: + opentracing_span.log_kv({"event": "encoded"}) + return res + + with start_active_span("encode_json_response"): + span = active_span() + json_str = await defer_to_thread(request.reactor, encode, span) _write_bytes_to_request(request, json_str) @@ -773,7 +806,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: Request): +def set_cors_headers(request: Request) -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -790,14 +823,14 @@ def set_cors_headers(request: Request): ) -def respond_with_html(request: Request, code: int, html: str): +def respond_with_html(request: Request, code: int, html: str) -> None: """ Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes. """ respond_with_html_bytes(request, code, html.encode("utf-8")) -def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): +def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None: """ Sends HTML (encoded as UTF-8 bytes) as the response to the given request. @@ -815,7 +848,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): logger.warning( "Not sending response to request %s, already disconnected.", request ) - return + return None request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") @@ -828,7 +861,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): finish_request(request) -def set_clickjacking_protection_headers(request: Request): +def set_clickjacking_protection_headers(request: Request) -> None: """ Set headers to guard against clickjacking of embedded content. @@ -850,7 +883,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None: finish_request(request) -def finish_request(request: Request): +def finish_request(request: Request) -> None: """Finish writing the response to the request. Twisted throws a RuntimeException if the connection closed before the diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 47303f5fd..a1ab44173 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -14,6 +14,7 @@ """ This module contains base REST classes for constructing REST servlets. 
""" import logging +from http import HTTPStatus from typing import ( TYPE_CHECKING, Iterable, @@ -31,6 +32,7 @@ from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.types import JsonDict, RoomAlias, RoomID from synapse.util import json_decoder @@ -138,11 +140,15 @@ def parse_integer_from_args( return int(args[name_bytes][0]) except Exception: message = "Query parameter %r must be an integer" % (name,) - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) else: if required: message = "Missing integer query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) else: return default @@ -247,11 +253,15 @@ def parse_boolean_from_args( message = ( "Boolean query parameter %r must be one of ['true', 'false']" ) % (name,) - raise SynapseError(400, message) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) else: if required: message = "Missing boolean query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) else: return default @@ -314,7 +324,7 @@ def parse_bytes_from_args( return args[name_bytes][0] elif required: message = "Missing string query parameter %s" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM) return default @@ -408,14 +418,16 @@ def _parse_string_value( try: value_str = value.decode(encoding) except ValueError: - raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding)) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Query parameter %r must be %s" % (name, encoding) + ) if allowed_values is not None and value_str not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( name, ", ".join(repr(v) for v in allowed_values), ) - raise SynapseError(400, message) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM) else: return value_str @@ -511,7 +523,9 @@ def parse_strings_from_args( else: if required: message = "Missing string query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) return default @@ -639,7 +653,7 @@ def parse_json_value_from_request( try: content_bytes = request.content.read() # type: ignore except Exception: - raise SynapseError(400, "Error reading JSON content.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Error reading JSON content.") if not content_bytes and allow_empty_body: return None @@ -648,7 +662,9 @@ def parse_json_value_from_request( content = json_decoder.decode(content_bytes.decode("utf-8")) except Exception as e: logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes) - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON + ) return content @@ -674,7 +690,7 @@ def parse_json_object_from_request( if not isinstance(content, dict): message = "Content must be a JSON object." 
- raise SynapseError(400, message, errcode=Codes.BAD_JSON) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON) return content @@ -688,7 +704,9 @@ def assert_params_in_dict( absent.append(k) if len(absent) > 0: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Missing params: %r" % absent, Codes.MISSING_PARAM + ) class RestServlet: @@ -712,7 +730,7 @@ class attribute containing a pre-compiled regular expression. The automatic into the appropriate HTTP response. """ - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: """Register this servlet with the given HTTP server.""" patterns = getattr(self, "PATTERNS", None) if patterns: @@ -761,10 +779,12 @@ async def resolve_room_id( resolved_room_id = room_id.to_string() else: raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) + HTTPStatus.BAD_REQUEST, + "%s was not legal room ID or room alias" % (room_identifier,), ) if not resolved_room_id: raise SynapseError( - 400, "Unknown room ID or room alias %s" % room_identifier + HTTPStatus.BAD_REQUEST, + "Unknown room ID or room alias %s" % room_identifier, ) return resolved_room_id, remote_room_hosts diff --git a/synapse/http/site.py b/synapse/http/site.py index 755ad5663..80f7a2ff5 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Generator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union import attr from zope.interface import implementer @@ -35,6 +35,9 @@ ) from synapse.types import Requester +if TYPE_CHECKING: + import opentracing + logger = logging.getLogger(__name__) _next_request_seq = 0 @@ -66,9 +69,9 @@ def __init__( self, channel: HTTPChannel, site: "SynapseSite", - *args, + *args: Any, max_request_body_size: int = 1024, - **kw, + **kw: Any, ): super().__init__(channel, *args, **kw) self._max_request_body_size = max_request_body_size @@ -81,6 +84,10 @@ def __init__( # server name, for client requests this is the Requester object. self._requester: Optional[Union[Requester, str]] = None + # An opentracing span for this request. Will be closed when the request is + # completely processed. + self._opentracing_span: "Optional[opentracing.Span]" = None + # we can't yet create the logcontext, as we don't know the method. self.logcontext: Optional[LoggingContext] = None @@ -148,6 +155,13 @@ def requester(self, value: Union[Requester, str]) -> None: # If there's no authenticated entity, it was the requester. self.logcontext.request.authenticated_entity = authenticated_entity or requester + def set_opentracing_span(self, span: "opentracing.Span") -> None: + """attach an opentracing span to this request + + Doing so will cause the span to be closed when we finish processing the request + """ + self._opentracing_span = span + def get_request_id(self) -> str: return "%s-%i" % (self.get_method(), self.request_seq) @@ -286,6 +300,9 @@ async def handle_request(request): self._processing_finished_time = time.time() self._is_processing = False + if self._opentracing_span: + self._opentracing_span.log_kv({"event": "finished processing"}) + # if we've already sent the response, log it now; otherwise, we wait for the # response to be sent. 
if self.finish_time is not None: @@ -299,6 +316,8 @@ def finish(self) -> None: """ self.finish_time = time.time() Request.finish(self) + if self._opentracing_span: + self._opentracing_span.log_kv({"event": "response sent"}) if not self._is_processing: assert self.logcontext is not None with PreserveLoggingContext(self.logcontext): @@ -333,6 +352,11 @@ def connectionLost(self, reason: Union[Failure, Exception]) -> None: with PreserveLoggingContext(self.logcontext): logger.info("Connection from client lost before response was sent") + if self._opentracing_span: + self._opentracing_span.log_kv( + {"event": "client connection lost", "reason": str(reason.value)} + ) + if not self._is_processing: self._finished_processing() @@ -421,6 +445,10 @@ def _finished_processing(self) -> None: usage.evt_db_fetch_count, ) + # complete the opentracing span, if any. + if self._opentracing_span: + self._opentracing_span.finish() + try: self.request_metrics.stop(self.finish_time, self.code, self.sentLength) except Exception as e: @@ -557,7 +585,7 @@ def __init__( proxied = config.http_options.x_forwarded request_class = XForwardedForRequest if proxied else SynapseRequest - def request_factory(channel, queued: bool) -> Request: + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, self, diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 1cc07eca0..d4ee89337 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -22,20 +22,33 @@ See doc/log_contexts.rst for details on how this works. """ -import inspect import logging import threading import typing import warnings -from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, +) import attr from typing_extensions import Literal from twisted.internet import defer, threads +from twisted.python.threadpool import ThreadPool if TYPE_CHECKING: from synapse.logging.scopecontextmanager import _LogContextScope + from synapse.types import ISynapseReactor logger = logging.getLogger(__name__) @@ -65,7 +78,7 @@ def get_thread_resource_usage() -> "Optional[resource.struct_rusage]": # a hook which can be set during testing to assert that we aren't abusing logcontexts. 
-def logcontext_error(msg: str): +def logcontext_error(msg: str) -> None: logger.warning(msg) @@ -222,22 +235,19 @@ def __init__(self) -> None: def __str__(self) -> str: return "sentinel" - def copy_to(self, record): - pass - - def start(self, rusage: "Optional[resource.struct_rusage]"): + def start(self, rusage: "Optional[resource.struct_rusage]") -> None: pass - def stop(self, rusage: "Optional[resource.struct_rusage]"): + def stop(self, rusage: "Optional[resource.struct_rusage]") -> None: pass - def add_database_transaction(self, duration_sec): + def add_database_transaction(self, duration_sec: float) -> None: pass - def add_database_scheduled(self, sched_sec): + def add_database_scheduled(self, sched_sec: float) -> None: pass - def record_event_fetch(self, event_count): + def record_event_fetch(self, event_count: int) -> None: pass def __bool__(self) -> Literal[False]: @@ -378,7 +388,12 @@ def __enter__(self) -> "LoggingContext": ) return self - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: """Restore the logging context in thread local storage to the state it was before this context was entered. Returns: @@ -398,17 +413,6 @@ def __exit__(self, type, value, traceback) -> None: # recorded against the correct metrics. self.finished = True - def copy_to(self, record) -> None: - """Copy logging fields from this context to a log record or - another LoggingContext - """ - - # we track the current request - record.request = self.request - - # we also track the current scope: - record.scope = self.scope - def start(self, rusage: "Optional[resource.struct_rusage]") -> None: """ Record that this logcontext is currently running. @@ -625,7 +629,12 @@ def __init__( def __enter__(self) -> None: self._old_context = set_current_context(self._new_context) - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: context = set_current_context(self._old_context) if context != self._new_context: @@ -710,16 +719,61 @@ def nested_logging_context(suffix: str) -> LoggingContext: ) -def preserve_fn(f): +R = TypeVar("R") + + +@overload +def preserve_fn( # type: ignore[misc] + f: Callable[..., Awaitable[R]], +) -> Callable[..., "defer.Deferred[R]"]: + # The `type: ignore[misc]` above suppresses + # "Overloaded function signatures 1 and 2 overlap with incompatible return types" + ... + + +@overload +def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]: + ... + + +def preserve_fn( + f: Union[ + Callable[..., R], + Callable[..., Awaitable[R]], + ] +) -> Callable[..., "defer.Deferred[R]"]: """Function decorator which wraps the function with run_in_background""" - def g(*args, **kwargs): + def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]": return run_in_background(f, *args, **kwargs) return g -def run_in_background(f, *args, **kwargs) -> defer.Deferred: +@overload +def run_in_background( # type: ignore[misc] + f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any +) -> "defer.Deferred[R]": + # The `type: ignore[misc]` above suppresses + # "Overloaded function signatures 1 and 2 overlap with incompatible return types" + ... + + +@overload +def run_in_background( + f: Callable[..., R], *args: Any, **kwargs: Any +) -> "defer.Deferred[R]": + ... 
+ + +def run_in_background( + f: Union[ + Callable[..., R], + Callable[..., Awaitable[R]], + ], + *args: Any, + **kwargs: Any, +) -> "defer.Deferred[R]": """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the deferred returned by the function completes. @@ -750,6 +804,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: # At this point we should have a Deferred, if not then f was a synchronous # function, wrap it in a Deferred for consistency. if not isinstance(res, defer.Deferred): + # `res` is not a `Deferred` and not a `Coroutine`. + # There are no other types of `Awaitable`s we expect to encounter in Synapse. + assert not isinstance(res, Awaitable) + return defer.succeed(res) if res.called and not res.paused: @@ -777,13 +835,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: return res -def make_deferred_yieldable(deferred): - """Given a deferred (or coroutine), make it follow the Synapse logcontext - rules: +T = TypeVar("T") + - If the deferred has completed (or is not actually a Deferred), essentially - does nothing (just returns another completed deferred with the - result/failure). +def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Given a deferred, make it follow the Synapse logcontext rules: + + If the deferred has completed, essentially does nothing (just returns another + completed deferred with the result/failure). If the deferred has not yet completed, resets the logcontext before returning a deferred. Then, when the deferred completes, restores the @@ -791,16 +850,6 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ - if inspect.isawaitable(deferred): - # If we're given a coroutine we convert it to a deferred so that we - # run it and find out if it immediately finishes, it it does then we - # don't need to fiddle with log contexts at all and can return - # immediately. - deferred = defer.ensureDeferred(deferred) - - if not isinstance(deferred, defer.Deferred): - return deferred - if deferred.called and not deferred.paused: # it looks like this deferred is ready to run any callbacks we give it # immediately. We may as well optimise out the logcontext faffery. @@ -822,7 +871,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: return result -def defer_to_thread(reactor, f, *args, **kwargs): +def defer_to_thread( + reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any +) -> "defer.Deferred[R]": """ Calls the function `f` using a thread from the reactor's default threadpool and returns the result as a Deferred. @@ -854,7 +905,13 @@ def defer_to_thread(reactor, f, *args, **kwargs): return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) -def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): +def defer_to_threadpool( + reactor: "ISynapseReactor", + threadpool: ThreadPool, + f: Callable[..., R], + *args: Any, + **kwargs: Any, +) -> "defer.Deferred[R]": """ A wrapper for twisted.internet.threads.deferToThreadpool, which handles logcontexts correctly. 
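As context for the newly-typed overloads above, this is the usual pairing of `run_in_background` with `make_deferred_yieldable` (a usage sketch of existing API, following the logcontext rules these helpers enforce, not anything new in this diff):

```python
from synapse.logging.context import make_deferred_yieldable, run_in_background


async def do_work() -> int:
    return 42


async def caller() -> int:
    # Start the work without awaiting it: run_in_background restores our
    # logcontext on return, and resets it once the background work finishes.
    d = run_in_background(do_work)

    # ... do other things concurrently ...

    # Wait for the result: make_deferred_yieldable resets the logcontext
    # while we are suspended and restores it when the Deferred fires.
    return await make_deferred_yieldable(d)
```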
@@ -896,7 +953,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): assert isinstance(curr_context, LoggingContext) parent_context = curr_context - def g(): + def g() -> R: with LoggingContext(str(curr_context), parent_context=parent_context): return f(*args, **kwargs) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 20d23a426..622445e9f 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -173,6 +173,7 @@ def set_fates(clotho, lachesis, atropos, father="Zues", mother="Themis"): import attr from twisted.internet import defer +from twisted.web.http import Request from twisted.web.http_headers import Headers from synapse.config import ConfigError @@ -219,11 +220,12 @@ class _DummyTagNames: try: import opentracing + import opentracing.tags tags = opentracing.tags except ImportError: - opentracing = None - tags = _DummyTagNames + opentracing = None # type: ignore[assignment] + tags = _DummyTagNames # type: ignore[assignment] try: from jaeger_client import Config as JaegerConfig @@ -366,7 +368,7 @@ def init_tracer(hs: "HomeServer"): global opentracing if not hs.config.tracing.opentracer_enabled: # We don't have a tracer - opentracing = None + opentracing = None # type: ignore[assignment] return if not opentracing or not JaegerConfig: @@ -452,7 +454,7 @@ def start_active_span( """ if opentracing is None: - return noop_context_manager() + return noop_context_manager() # type: ignore[unreachable] return opentracing.tracer.start_active_span( operation_name, @@ -477,7 +479,7 @@ def start_active_span_follows_from( forced, the new span will also have tracing forced. """ if opentracing is None: - return noop_context_manager() + return noop_context_manager() # type: ignore[unreachable] references = [opentracing.follows_from(context) for context in contexts] scope = start_active_span(operation_name, references=references) @@ -490,48 +492,6 @@ def start_active_span_follows_from( return scope -def start_active_span_from_request( - request, - operation_name, - references=None, - tags=None, - start_time=None, - ignore_active_span=False, - finish_on_close=True, -): - """ - Extracts a span context from a Twisted Request. - args: - headers (twisted.web.http.Request) - - For the other args see opentracing.tracer - - returns: - span_context (opentracing.span.SpanContext) - """ - # Twisted encodes the values as lists whereas opentracing doesn't. - # So, we take the first item in the list. - # Also, twisted uses byte arrays while opentracing expects strings. 
- - if opentracing is None: - return noop_context_manager() - - header_dict = { - k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders() - } - context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict) - - return opentracing.tracer.start_active_span( - operation_name, - child_of=context, - references=references, - tags=tags, - start_time=start_time, - ignore_active_span=ignore_active_span, - finish_on_close=finish_on_close, - ) - - def start_active_span_from_edu( edu_content, operation_name, @@ -553,7 +513,7 @@ def start_active_span_from_edu( references = references or [] if opentracing is None: - return noop_context_manager() + return noop_context_manager() # type: ignore[unreachable] carrier = json_decoder.decode(edu_content.get("context", "{}")).get( "opentracing", {} @@ -594,18 +554,21 @@ def active_span(): @ensure_active_span("set a tag") def set_tag(key, value): """Sets a tag on the active span""" + assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_tag(key, value) @ensure_active_span("log") def log_kv(key_values, timestamp=None): """Log to the active span""" + assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.log_kv(key_values, timestamp) @ensure_active_span("set the traces operation name") def set_operation_name(operation_name): """Sets the operation name of the active span""" + assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_operation_name(operation_name) @@ -674,6 +637,7 @@ def inject_header_dict( span = opentracing.tracer.active_span carrier: Dict[str, str] = {} + assert span is not None opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier) for key, value in carrier.items(): @@ -716,6 +680,7 @@ def get_active_span_text_map(destination=None): return {} carrier: Dict[str, str] = {} + assert opentracing.tracer.active_span is not None opentracing.tracer.inject( opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier ) @@ -731,12 +696,27 @@ def active_span_context_as_string(): """ carrier: Dict[str, str] = {} if opentracing: + assert opentracing.tracer.active_span is not None opentracing.tracer.inject( opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier ) return json_encoder.encode(carrier) +def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]": + """Extract an opentracing context from the headers on an HTTP request + + This is useful when we have received an HTTP request from another part of our + system, and want to link our spans to those of the remote system. 
+ """ + if not opentracing: + return None + header_dict = { + k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders() + } + return opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict) + + @only_if_tracing def span_context_from_string(carrier): """ @@ -773,7 +753,7 @@ def trace(func=None, opname=None): def decorator(func): if opentracing is None: - return func + return func # type: ignore[unreachable] _opname = opname if opname else func.__name__ @@ -864,7 +844,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False): """ if opentracing is None: - yield + yield # type: ignore[unreachable] return request_tags = { @@ -876,10 +856,13 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False): } request_name = request.request_metrics.name - if extract_context: - scope = start_active_span_from_request(request, request_name) - else: - scope = start_active_span(request_name) + context = span_context_from_request(request) if extract_context else None + + # we configure the scope not to finish the span immediately on exit, and instead + # pass the span into the SynapseRequest, which will finish it once we've finished + # sending the response to the client. + scope = start_active_span(request_name, child_of=context, finish_on_close=False) + request.set_opentracing_span(scope.span) with scope: inject_response_headers(request.responseHeaders) diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index b1e8e08fe..db8ca2c04 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -71,7 +71,7 @@ def activate(self, span, finish_on_close): if not ctx: # We don't want this scope to affect. logger.error("Tried to activate scope outside of loggingcontext") - return Scope(None, span) + return Scope(None, span) # type: ignore[arg-type] elif ctx.scope is not None: # We want the logging scope to look exactly the same so we give it # a blank suffix diff --git a/synapse/notifier.py b/synapse/notifier.py index 60e540989..bbabdb058 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -13,7 +13,6 @@ # limitations under the License. 
import logging -from collections import namedtuple from typing import ( Awaitable, Callable, @@ -44,7 +43,13 @@ from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.streams.config import PaginationConfig -from synapse.types import PersistedEventPosition, RoomStreamToken, StreamToken, UserID +from synapse.types import ( + JsonDict, + PersistedEventPosition, + RoomStreamToken, + StreamToken, + UserID, +) from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -178,7 +183,12 @@ def new_listener(self, token: StreamToken) -> _NotificationListener: return _NotificationListener(self.notify_deferred.observe()) -class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventStreamResult: + events: List[Union[JsonDict, EventBase]] + start_token: StreamToken + end_token: StreamToken + def __bool__(self): return bool(self.events) @@ -582,9 +592,12 @@ async def check_for_updates( before_token: StreamToken, after_token: StreamToken ) -> EventStreamResult: if after_token == before_token: - return EventStreamResult([], (from_token, from_token)) + return EventStreamResult([], from_token, from_token) - events: List[EventBase] = [] + # The events fetched from each source are a JsonDict, EventBase, or + # UserPresenceState, but see below for UserPresenceState being + # converted to JsonDict. + events: List[Union[JsonDict, EventBase]] = [] end_token = from_token for name, source in self.event_sources.sources.get_sources(): @@ -623,7 +636,7 @@ async def check_for_updates( events.extend(new_events) end_token = end_token.copy_and_replace(keyname, new_key) - return EventStreamResult(events, (from_token, end_token)) + return EventStreamResult(events, from_token, end_token) user_id_for_stream = user.to_string() if is_peeking: diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 4f13c0418..39bb2acae 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -177,12 +177,12 @@ async def _unsafe_process(self) -> None: return for push_action in unprocessed: - received_at = push_action["received_ts"] + received_at = push_action.received_ts if received_at is None: received_at = 0 notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS - room_ready_at = self.room_ready_to_notify_at(push_action["room_id"]) + room_ready_at = self.room_ready_to_notify_at(push_action.room_id) should_notify_at = max(notif_ready_at, room_ready_at) @@ -193,23 +193,23 @@ async def _unsafe_process(self) -> None: # to be delivered. 
reason: EmailReason = { - "room_id": push_action["room_id"], + "room_id": push_action.room_id, "now": self.clock.time_msec(), "received_at": received_at, "delay_before_mail_ms": DELAY_BEFORE_MAIL_MS, - "last_sent_ts": self.get_room_last_sent_ts(push_action["room_id"]), - "throttle_ms": self.get_room_throttle_ms(push_action["room_id"]), + "last_sent_ts": self.get_room_last_sent_ts(push_action.room_id), + "throttle_ms": self.get_room_throttle_ms(push_action.room_id), } await self.send_notification(unprocessed, reason) await self.save_last_stream_ordering_and_success( - max(ea["stream_ordering"] for ea in unprocessed) + max(ea.stream_ordering for ea in unprocessed) ) # we update the throttle on all the possible unprocessed push actions for ea in unprocessed: - await self.sent_notif_update_throttle(ea["room_id"], ea) + await self.sent_notif_update_throttle(ea.room_id, ea) break else: if soonest_due_at is None or should_notify_at < soonest_due_at: @@ -284,10 +284,10 @@ async def sent_notif_update_throttle( # THROTTLE_RESET_AFTER_MS after the previous one that triggered a # notif, we release the throttle. Otherwise, the throttle is increased. time_of_previous_notifs = await self.store.get_time_of_last_push_action_before( - notified_push_action["stream_ordering"] + notified_push_action.stream_ordering ) - time_of_this_notifs = notified_push_action["received_ts"] + time_of_this_notifs = notified_push_action.received_ts if time_of_previous_notifs is not None and time_of_this_notifs is not None: gap = time_of_this_notifs - time_of_previous_notifs diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 3fa603ccb..96559081d 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -199,7 +199,7 @@ async def _unsafe_process(self) -> None: "http-push", tags={ "authenticated_entity": self.user_id, - "event_id": push_action["event_id"], + "event_id": push_action.event_id, "app_id": self.app_id, "app_display_name": self.app_display_name, }, @@ -209,7 +209,7 @@ async def _unsafe_process(self) -> None: if processed: http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.last_stream_ordering = push_action["stream_ordering"] + self.last_stream_ordering = push_action.stream_ordering pusher_still_exists = ( await self.store.update_pusher_last_stream_ordering_and_success( self.app_id, @@ -252,7 +252,7 @@ async def _unsafe_process(self) -> None: self.pushkey, ) self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.last_stream_ordering = push_action["stream_ordering"] + self.last_stream_ordering = push_action.stream_ordering await self.store.update_pusher_last_stream_ordering( self.app_id, self.pushkey, @@ -275,17 +275,17 @@ async def _unsafe_process(self) -> None: break async def _process_one(self, push_action: HttpPushAction) -> bool: - if "notify" not in push_action["actions"]: + if "notify" not in push_action.actions: return True - tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) + tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions) badge = await push_tools.get_badge_count( self.hs.get_datastore(), self.user_id, group_by_room=self._group_unread_count_by_room, ) - event = await self.store.get_event(push_action["event_id"], allow_none=True) + event = await self.store.get_event(push_action.event_id, allow_none=True) if event is None: return True # It's been redacted rejected = await self.dispatch_push(event, tweaks, badge) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 
ba4f86648..ff904c2b4 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -232,15 +232,13 @@ async def send_notification_mail( reason: The notification that was ready and is the cause of an email being sent. """ - rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions]) + rooms_in_order = deduped_ordered_list([pa.room_id for pa in push_actions]) - notif_events = await self.store.get_events( - [pa["event_id"] for pa in push_actions] - ) + notif_events = await self.store.get_events([pa.event_id for pa in push_actions]) notifs_by_room: Dict[str, List[EmailPushAction]] = {} for pa in push_actions: - notifs_by_room.setdefault(pa["room_id"], []).append(pa) + notifs_by_room.setdefault(pa.room_id, []).append(pa) # collect the current state for all the rooms in which we have # notifications @@ -264,7 +262,7 @@ async def _fetch_room_state(room_id: str) -> None: await concurrently_execute(_fetch_room_state, rooms_in_order, 3) # actually sort our so-called rooms_in_order list, most recent room first - rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0)) + rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1].received_ts or 0)) rooms: List[RoomVars] = [] @@ -356,7 +354,7 @@ async def _get_room_vars( # Check if one of the notifs is an invite event for the user. is_invite = False for n in notifs: - ev = notif_events[n["event_id"]] + ev = notif_events[n.event_id] if ev.type == EventTypes.Member and ev.state_key == user_id: if ev.content.get("membership") == Membership.INVITE: is_invite = True @@ -376,7 +374,7 @@ async def _get_room_vars( if not is_invite: for n in notifs: notifvars = await self._get_notif_vars( - n, user_id, notif_events[n["event_id"]], room_state_ids + n, user_id, notif_events[n.event_id], room_state_ids ) # merge overlapping notifs together. @@ -444,15 +442,15 @@ async def _get_notif_vars( """ results = await self.store.get_events_around( - notif["room_id"], - notif["event_id"], + notif.room_id, + notif.event_id, before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER, ) ret: NotifVars = { "link": self._make_notif_link(notif), - "ts": notif["received_ts"], + "ts": notif.received_ts, "messages": [], } @@ -516,7 +514,7 @@ async def _get_message_vars( ret: MessageVars = { "event_type": event.type, - "is_historical": event.event_id != notif["event_id"], + "is_historical": event.event_id != notif.event_id, "id": event.event_id, "ts": event.origin_server_ts, "sender_name": sender_name, @@ -610,7 +608,7 @@ async def _make_summary_text_single_room( # See if one of the notifs is an invite event for the user invite_event = None for n in notifs: - ev = notif_events[n["event_id"]] + ev = notif_events[n.event_id] if ev.type == EventTypes.Member and ev.state_key == user_id: if ev.content.get("membership") == Membership.INVITE: invite_event = ev @@ -659,7 +657,7 @@ async def _make_summary_text_single_room( if len(notifs) == 1: # There is just the one notification, so give some detail sender_name = None - event = notif_events[notifs[0]["event_id"]] + event = notif_events[notifs[0].event_id] if ("m.room.member", event.sender) in room_state_ids: state_event_id = room_state_ids[("m.room.member", event.sender)] state_event = await self.store.get_event(state_event_id) @@ -753,9 +751,9 @@ async def _make_summary_text_from_member_events( # are already in descending received_ts. 
sender_ids = {} for n in notifs: - sender = notif_events[n["event_id"]].sender + sender = notif_events[n.event_id].sender if sender not in sender_ids: - sender_ids[sender] = n["event_id"] + sender_ids[sender] = n.event_id # Get the actual member events (in order to calculate a pretty name for # the room). @@ -830,17 +828,17 @@ def _make_notif_link(self, notif: EmailPushAction) -> str: if self.hs.config.email.email_riot_base_url: return "%s/#/room/%s/%s" % ( self.hs.config.email.email_riot_base_url, - notif["room_id"], - notif["event_id"], + notif.room_id, + notif.event_id, ) elif self.app_name == "Vector": # need /beta for Universal Links to work on iOS return "https://vector.im/beta/#/room/%s/%s" % ( - notif["room_id"], - notif["event_id"], + notif.room_id, + notif.event_id, ) else: - return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"]) + return "https://matrix.to/#/%s/%s" % (notif.room_id, notif.event_id) def _make_unsubscribe_link( self, user_id: str, app_id: str, email_address: str diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 7f68092ec..659a53805 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -17,9 +17,10 @@ import re from typing import Any, Dict, List, Optional, Pattern, Tuple, Union +from matrix_common.regex import glob_to_regex, to_word_pattern + from synapse.events import EventBase from synapse.types import JsonDict, UserID -from synapse.util import glob_to_regex, re_word_boundary from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -184,7 +185,7 @@ def _contains_display_name(self, display_name: Optional[str]) -> bool: r = regex_cache.get((display_name, False, True), None) if not r: r1 = re.escape(display_name) - r1 = re_word_boundary(r1) + r1 = to_word_pattern(r1) r = re.compile(r1, flags=re.IGNORECASE) regex_cache[(display_name, False, True)] = r @@ -213,7 +214,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: try: r = regex_cache.get((glob, True, word_boundary), None) if not r: - r = glob_to_regex(glob, word_boundary) + r = glob_to_regex(glob, word_boundary=word_boundary) regex_cache[(glob, True, word_boundary)] = r return bool(r.search(value)) except re.error: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 9c85200c0..957c9b780 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,6 +13,7 @@ # limitations under the License. 
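For the `push_rule_evaluator` change above, a minimal sketch of the `matrix-common` helpers that replace `synapse.util.glob_to_regex` and `re_word_boundary` (semantics as implied by the diff; the exact regexes produced are an implementation detail of `matrix-common`, so treat the matched strings as assumptions):

```python
import re

from matrix_common.regex import glob_to_regex, to_word_pattern

# Matrix-style glob -> compiled regex; word_boundary is now passed by name.
pattern = glob_to_regex("alice*", word_boundary=True)
assert pattern.search("ping alice please")

# to_word_pattern wraps an escaped literal so it only matches as a whole word.
display_name_re = re.compile(to_word_pattern(re.escape("bob")), flags=re.IGNORECASE)
assert display_name_re.search("hey Bob!")
assert not display_name_re.search("bobsled")
```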
from typing import Dict +from synapse.api.constants import ReceiptTypes from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage @@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") + my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ) badge = len(invites) @@ -36,7 +37,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - room_id, user_id, last_unread_event_id ) ) - if notifs["notify_count"] == 0: + if notifs.notify_count == 0: continue if group_by_room: @@ -44,7 +45,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - badge += 1 else: # increment the badge count by the number of unread messages in the room - badge += notifs["notify_count"] + badge += notifs.notify_count return badge diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 26735447a..7912311d2 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -27,6 +27,7 @@ from synapse.replication.http.push import ReplicationRemovePusherRestServlet from synapse.types import JsonDict, RoomStreamToken from synapse.util.async_helpers import concurrently_execute +from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -113,7 +114,9 @@ async def add_pusher( """ if kind == "email": - email_owner = await self.store.get_user_id_by_threepid("email", pushkey) + email_owner = await self.store.get_user_id_by_threepid( + "email", canonicalise_email(pushkey) + ) if email_owner != user_id: raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index d07440191..8f1189489 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -89,6 +89,7 @@ # with the latest security patches. 
"cryptography>=3.4.7", "ijson>=3.1", + "matrix-common==1.0.0", ] CONDITIONAL_REQUIREMENTS = { diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 7ecb446e7..7644146db 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Optional -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -27,7 +27,12 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen: Optional[ diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 61cd7e522..bc888ce1a 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches.lrucache import LruCache @@ -25,7 +25,12 @@ class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.client_ip_last_seen: LruCache[tuple, int] = LruCache( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 0a5829608..a2aff75b7 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -17,7 +17,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -27,7 +27,12 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 63ed50caa..0f0837269 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING -from synapse.storage.database import 
DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.event_federation import EventFederationWorkerStore from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, @@ -58,7 +58,12 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() @@ -75,12 +80,3 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): min_curr_state_delta_id, prefilled_cache=curr_state_delta_prefill, ) - - # Cached functions can't be accessed through a class instance so we need - # to reach inside the __dict__ to extract them. - - def get_room_max_stream_ordering(self): - return self._stream_id_gen.get_current_token() - - def get_room_min_stream_ordering(self): - return self._backfill_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 90284c202..4d185e2b5 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.filtering import FilteringStore from ._base import BaseSlavedStore @@ -24,7 +24,12 @@ class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 497e16c69..9d90e2637 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -17,7 +17,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import GroupServerStream -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.group_server import GroupServerWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -26,7 +26,12 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 743a01da0..5a2d90c53 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -15,7 +15,6 @@ import heapq import logging -from collections import namedtuple from typing import ( TYPE_CHECKING, Any, @@ -30,6 +29,7 @@ import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates +from synapse.types import 
JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -226,17 +226,14 @@ class BackfillStream(Stream): or it went from being an outlier to not. """ - BackfillStreamRow = namedtuple( - "BackfillStreamRow", - ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class BackfillStreamRow: + event_id: str + room_id: str + type: str + state_key: Optional[str] + redacts: Optional[str] + relates_to: Optional[str] NAME = "backfill" ROW_TYPE = BackfillStreamRow @@ -256,18 +253,15 @@ def _current_token(self, instance_name: str) -> int: class PresenceStream(Stream): - PresenceStreamRow = namedtuple( - "PresenceStreamRow", - ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PresenceStreamRow: + user_id: str + state: str + last_active_ts: int + last_federation_update_ts: int + last_user_sync_ts: int + status_msg: str + currently_active: bool NAME = "presence" ROW_TYPE = PresenceStreamRow @@ -302,7 +296,7 @@ class PresenceFederationStream(Stream): send. """ - @attr.s(slots=True, auto_attribs=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class PresenceFederationStreamRow: destination: str user_id: str @@ -320,9 +314,10 @@ def __init__(self, hs: "HomeServer"): class TypingStream(Stream): - TypingStreamRow = namedtuple( - "TypingStreamRow", ("room_id", "user_ids") # str # list(str) - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class TypingStreamRow: + room_id: str + user_ids: List[str] NAME = "typing" ROW_TYPE = TypingStreamRow @@ -348,16 +343,13 @@ def __init__(self, hs: "HomeServer"): class ReceiptsStream(Stream): - ReceiptsStreamRow = namedtuple( - "ReceiptsStreamRow", - ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ReceiptsStreamRow: + room_id: str + receipt_type: str + user_id: str + event_id: str + data: dict NAME = "receipts" ROW_TYPE = ReceiptsStreamRow @@ -374,7 +366,9 @@ def __init__(self, hs: "HomeServer"): class PushRulesStream(Stream): """A user has changed their push rules""" - PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PushRulesStreamRow: + user_id: str NAME = "push_rules" ROW_TYPE = PushRulesStreamRow @@ -396,10 +390,12 @@ def _current_token(self, instance_name: str) -> int: class PushersStream(Stream): """A user has added/changed/removed a pusher""" - PushersStreamRow = namedtuple( - "PushersStreamRow", - ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PushersStreamRow: + user_id: str + app_id: str + pushkey: str + deleted: bool NAME = "pushers" ROW_TYPE = PushersStreamRow @@ -419,7 +415,7 @@ class CachesStream(Stream): the cache on the workers """ - @attr.s(slots=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class CachesStreamRow: """Stream to inform workers they should invalidate their cache. @@ -430,9 +426,9 @@ class CachesStreamRow: invalidation_ts: Timestamp of when the invalidation took place. 
""" - cache_func = attr.ib(type=str) - keys = attr.ib(type=Optional[List[Any]]) - invalidation_ts = attr.ib(type=int) + cache_func: str + keys: Optional[List[Any]] + invalidation_ts: int NAME = "caches" ROW_TYPE = CachesStreamRow @@ -451,9 +447,9 @@ class DeviceListsStream(Stream): told about a device update. """ - @attr.s(slots=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceListsStreamRow: - entity = attr.ib(type=str) + entity: str NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow @@ -470,7 +466,9 @@ def __init__(self, hs: "HomeServer"): class ToDeviceStream(Stream): """New to_device messages for a client""" - ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ToDeviceStreamRow: + entity: str NAME = "to_device" ROW_TYPE = ToDeviceStreamRow @@ -487,9 +485,11 @@ def __init__(self, hs: "HomeServer"): class TagAccountDataStream(Stream): """Someone added/removed a tag for a room""" - TagAccountDataStreamRow = namedtuple( - "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class TagAccountDataStreamRow: + user_id: str + room_id: str + data: JsonDict NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow @@ -506,10 +506,11 @@ def __init__(self, hs: "HomeServer"): class AccountDataStream(Stream): """Global or per room account data was changed""" - AccountDataStreamRow = namedtuple( - "AccountDataStreamRow", - ("user_id", "room_id", "data_type"), # str # Optional[str] # str - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class AccountDataStreamRow: + user_id: str + room_id: Optional[str] + data_type: str NAME = "account_data" ROW_TYPE = AccountDataStreamRow @@ -573,10 +574,12 @@ async def _update_function( class GroupServerStream(Stream): - GroupsStreamRow = namedtuple( - "GroupsStreamRow", - ("group_id", "user_id", "type", "content"), # str # str # str # dict - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class GroupsStreamRow: + group_id: str + user_id: str + type: str + content: JsonDict NAME = "groups" ROW_TYPE = GroupsStreamRow @@ -593,7 +596,9 @@ def __init__(self, hs: "HomeServer"): class UserSignatureStream(Stream): """A user has signed their own device with their user-signing key""" - UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class UserSignatureStreamRow: + user_id: str NAME = "user_signature" ROW_TYPE = UserSignatureStreamRow diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 0600cdbf3..4046bdec6 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -12,14 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple +import attr + from synapse.replication.tcp.streams._base import ( Stream, current_token_without_instance, make_http_update_function, ) +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -30,13 +32,10 @@ class FederationStream(Stream): sending disabled. 
""" - FederationStreamRow = namedtuple( - "FederationStreamRow", - ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class FederationStreamRow: + type: str # the type of data as defined in the BaseFederationRows + data: JsonDict # serialization of a federation.send_queue.BaseFederationRow NAME = "federation" ROW_TYPE = FederationStreamRow diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c499afd4b..465e06772 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -69,6 +69,7 @@ from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet from synapse.rest.admin.username_available import UsernameAvailableRestServlet from synapse.rest.admin.users import ( + AccountDataRestServlet, AccountValidityRenewServlet, DeactivateAccountRestServlet, PushersRestServlet, @@ -108,7 +109,7 @@ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: class PurgeHistoryRestServlet(RestServlet): PATTERNS = admin_patterns( - "/purge_history/(?P[^/]*)(/(?P[^/]+))?" + "/purge_history/(?P[^/]*)(/(?P[^/]*))?$" ) def __init__(self, hs: "HomeServer"): @@ -195,7 +196,7 @@ async def on_POST( class PurgeHistoryStatusRestServlet(RestServlet): - PATTERNS = admin_patterns("/purge_history_status/(?P[^/]+)") + PATTERNS = admin_patterns("/purge_history_status/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.pagination_handler = hs.get_pagination_handler() @@ -255,6 +256,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserMediaStatisticsRestServlet(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) + AccountDataRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) MakeRoomAdminRestServlet(hs).register(http_server) ShadowBanRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index 479672d4d..6ec00ce0b 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -22,7 +22,7 @@ parse_json_object_from_request, ) from synapse.http.site import SynapseRequest -from synapse.rest.admin._base import admin_patterns, assert_user_is_admin +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.types import JsonDict if TYPE_CHECKING: @@ -41,8 +41,7 @@ def __init__(self, hs: "HomeServer"): self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) 
@@ -51,8 +50,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: return HTTPStatus.OK, {"enabled": enabled} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) body = parse_json_object_from_request(request) @@ -84,8 +82,7 @@ def __init__(self, hs: "HomeServer"): self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) @@ -111,15 +108,14 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: class BackgroundUpdateStartJobRestServlet(RestServlet): """Allows starting specific background updates""" - PATTERNS = admin_patterns("/background_updates/start_job") + PATTERNS = admin_patterns("/background_updates/start_job$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() self._store = hs.get_datastore() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["job_name"]) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 2e5a6600d..d9905ff56 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str, device_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -53,7 +53,7 @@ async def on_GET( device = await self.device_handler.get_device( target_user.to_string(), device_id ) + if device is None: + raise NotFoundError("No device found") return HTTPStatus.OK, device async def on_DELETE( @@ -71,7 +73,7 @@ await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -87,7 +89,7 @@ async def on_PUT( await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -109,14 +111,10 @@ class DevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") def __init__(self, hs: "HomeServer"):
- """ - Args: - hs: server - """ - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -124,7 +122,7 @@ async def on_GET( await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -144,10 +142,10 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/delete_devices$", "v2") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, user_id: str @@ -155,7 +153,7 @@ async def on_POST( await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 5ee8b1111..38477f8ea 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 744687be3..50d88c910 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet): 200 OK with details of a destination if success otherwise an error. 
""" - PATTERNS = admin_patterns("/federation/destinations/(?P[^/]+)$") + PATTERNS = admin_patterns("/federation/destinations/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py index a27110388..cd697e180 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py @@ -30,7 +30,7 @@ class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups""" - PATTERNS = admin_patterns("/delete_group/(?P[^/]*)") + PATTERNS = admin_patterns("/delete_group/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.group_server = hs.get_groups_server_handler() diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 9e23e2d8f..7236e4027 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -17,7 +17,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet): """ PATTERNS = [ - *admin_patterns("/room/(?P[^/]+)/media/quarantine$"), + *admin_patterns("/room/(?P[^/]*)/media/quarantine$"), # This path kept around for legacy reasons - *admin_patterns("/quarantine_media/(?P[^/]+)"), + *admin_patterns("/quarantine_media/(?P[^/]*)$"), ] def __init__(self, hs: "HomeServer"): @@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet): this server. """ - PATTERNS = admin_patterns("/user/(?P[^/]+)/media/quarantine$") + PATTERNS = admin_patterns("/user/(?P[^/]*)/media/quarantine$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/quarantine/(?P[^/]+)/(?P[^/]+)" + "/media/quarantine/(?P[^/]*)/(?P[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/unquarantine/(?P[^/]+)/(?P[^/]+)" + "/media/unquarantine/(?P[^/]*)/(?P[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -138,8 +138,7 @@ def __init__(self, hs: "HomeServer"): async def on_POST( self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info( "Remove from quarantine local media by ID: %s/%s", server_name, media_id @@ -154,7 +153,7 @@ async def on_POST( class ProtectMediaByID(RestServlet): """Protect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/protect/(?P[^/]+)") + PATTERNS = admin_patterns("/media/protect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -163,8 +162,7 @@ def __init__(self, hs: "HomeServer"): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Protecting local media by ID: %s", media_id) @@ -177,7 +175,7 @@ async def on_POST( class UnprotectMediaByID(RestServlet): """Unprotect local media 
from being quarantined.""" - PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)") + PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -186,8 +184,7 @@ def __init__(self, hs: "HomeServer"): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Unprotecting local media by ID: %s", media_id) @@ -200,7 +197,7 @@ async def on_POST( class ListMediaInRoom(RestServlet): """Lists all of the media in a given room.""" - PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media$") + PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -209,10 +206,7 @@ def __init__(self, hs: "HomeServer"): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - is_admin = await self.auth.is_server_admin(requester.user) - if not is_admin: - raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") + await assert_requester_is_admin(self.auth, request) local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) @@ -254,7 +248,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: class DeleteMediaByID(RestServlet): """Delete local media by a given ID. Removes it from this server.""" - PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)") + PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet): timestamp and size.
""" - PATTERNS = admin_patterns("/media/(?P[^/]+)/delete$") + PATTERNS = admin_patterns("/media/(?P[^/]*)/delete$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet): media that exist given for this user """ - PATTERNS = admin_patterns("/users/(?P[^/]+)/media$") + PATTERNS = admin_patterns("/users/(?P[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -403,16 +397,7 @@ async def on_GET( request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") @@ -470,16 +455,7 @@ async def on_DELETE( request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 891b98c08..04948b640 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/new$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.clock = hs.get_clock() self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 829e86675..6030373eb 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet): If 'purge' is true, it will remove all traces of a room from the database. 
""" - PATTERNS = admin_patterns("/rooms/(?P[^/]+)$", "v2") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -123,7 +123,7 @@ async def on_DELETE( class DeleteRoomStatusByRoomIdRestServlet(RestServlet): """Get the status of the delete room background task.""" - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/delete_status$", "v2") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/delete_status$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -160,7 +160,7 @@ async def on_GET( class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): """Get the status of the delete room background task.""" - PATTERNS = admin_patterns("/rooms/delete_status/(?P[^/]+)$", "v2") + PATTERNS = admin_patterns("/rooms/delete_status/(?P[^/]*)$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -193,35 +193,17 @@ def __init__(self, hs: "HomeServer"): self.admin_handler = hs.get_admin_handler() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) # Extract query parameters start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) - order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value) - if order_by not in ( - RoomSortOrder.ALPHABETICAL.value, - RoomSortOrder.SIZE.value, - RoomSortOrder.NAME.value, - RoomSortOrder.CANONICAL_ALIAS.value, - RoomSortOrder.JOINED_MEMBERS.value, - RoomSortOrder.JOINED_LOCAL_MEMBERS.value, - RoomSortOrder.VERSION.value, - RoomSortOrder.CREATOR.value, - RoomSortOrder.ENCRYPTION.value, - RoomSortOrder.FEDERATABLE.value, - RoomSortOrder.PUBLIC.value, - RoomSortOrder.JOIN_RULES.value, - RoomSortOrder.GUEST_ACCESS.value, - RoomSortOrder.HISTORY_VISIBILITY.value, - RoomSortOrder.STATE_EVENTS.value, - ): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) + order_by = parse_string( + request, + "order_by", + default=RoomSortOrder.NAME.value, + allowed_values=[sort_order.value for sort_order in RoomSortOrder], + ) search_term = parse_string(request, "search_term", encoding="utf-8") if search_term == "": @@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet): TODO: Add on_POST to allow room creation without joining the room """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)$") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.room_shutdown_handler = hs.get_room_shutdown_handler() @@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet): Get members list of a room. """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/members") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/members$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet): Get full state within a room. 
""" - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/state") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/state$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -436,8 +415,7 @@ def __init__(self, hs: "HomeServer"): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room(room_id) if not ret: @@ -454,14 +432,14 @@ async def on_GET( class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = admin_patterns("/join/(?P[^/]*)") + PATTERNS = admin_patterns("/join/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, room_identifier: str @@ -477,7 +455,7 @@ async def on_POST( assert_params_in_dict(content, ["user_id"]) target_user = UserID.from_string(content["user_id"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "This endpoint can only be used with local users", @@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): } """ - PATTERNS = admin_patterns("/rooms/(?P[^/]*)/make_room_admin") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/make_room_admin$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_creation_handler = hs.get_event_creation_handler() @@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): GET /_synapse/admin/v1/rooms//forward_extremities """ - PATTERNS = admin_patterns("/rooms/(?P[^/]*)/forward_extremities") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/forward_extremities$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() async def on_DELETE( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) @@ -710,8 +685,7 @@ async def on_DELETE( async def on_GET( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) @@ -771,13 +745,19 @@ async def on_GET( time_now = self.clock.time_msec() results["events_before"] = await self._event_serializer.serialize_events( - results["events_before"], time_now + results["events_before"], + time_now, + bundle_aggregations=True, ) results["event"] = await self._event_serializer.serialize_event( - results["event"], time_now + results["event"], + time_now, + bundle_aggregations=True, ) results["events_after"] = await self._event_serializer.serialize_events( - results["events_after"], time_now + results["events_after"], + time_now, + bundle_aggregations=True, ) results["state"] = await 
self._event_serializer.serialize_events( results["state"], time_now @@ -793,7 +773,7 @@ class BlockRoomRestServlet(RestServlet): On GET: Get blocking status of room and user who has blocked this room. """ - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/block$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index b295fb078..15da9cd88 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.server_notices_manager = hs.get_server_notices_manager() self.admin_handler = hs.get_admin_handler() self.txns = HttpTransactionCache(hs) + self.is_mine = hs.is_mine def register(self, json_resource: HttpServer) -> None: PATTERN = "/send_server_notice" @@ -88,7 +88,7 @@ async def on_POST( ) target_user = UserID.from_string(body["user_id"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users" ) diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index ca41fd45f..7a6546372 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet): PATTERNS = admin_patterns("/statistics/users/media$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -45,19 +44,16 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) order_by = parse_string( - request, "order_by", default=UserSortOrder.USER_ID.value + request, + "order_by", + default=UserSortOrder.USER_ID.value, + allowed_values=( + UserSortOrder.MEDIA_LENGTH.value, + UserSortOrder.MEDIA_COUNT.value, + UserSortOrder.USER_ID.value, + UserSortOrder.DISPLAYNAME.value, + ), ) - if order_by not in ( - UserSortOrder.MEDIA_LENGTH.value, - UserSortOrder.MEDIA_COUNT.value, - UserSortOrder.USER_ID.value, - UserSortOrder.DISPLAYNAME.value, - ): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) start = parse_integer(request, "from", default=0) if start < 0: diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py index 2bf147296..5353dc368 100644 --- a/synapse/rest/admin/username_available.py +++ b/synapse/rest/admin/username_available.py @@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/username_available") + PATTERNS = admin_patterns("/username_available$") def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2a60b602b..78e795c34 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -126,7 +125,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: class UserRestServletV2(RestServlet): - PATTERNS =
admin_patterns("/users/(?P<user_id>[^/]+)$", "v2") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$", "v2") """Get request to list user details. This needs the user to have administrator access in Synapse. @@ -414,7 +413,7 @@ class UserRegisterServlet(RestServlet): nonce to the time it was generated, in int seconds. """ - PATTERNS = admin_patterns("/register") + PATTERNS = admin_patterns("/register$") NONCE_TIMEOUT = 60 def __init__(self, hs: "HomeServer"): @@ -561,9 +560,9 @@ class WhoisRestServlet(RestServlet): ] def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: @@ -575,7 +574,7 @@ async def on_GET( if target_user != auth_user: await assert_user_is_admin(self.auth, auth_user) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") ret = await self.admin_handler.get_whois(target_user) @@ -584,7 +583,7 @@ async def on_GET( class DeactivateAccountRestServlet(RestServlet): - PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() @@ -630,7 +629,6 @@ class AccountValidityRenewServlet(RestServlet): PATTERNS = admin_patterns("/account_validity/validity$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() @@ -674,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet): 200 OK with empty object if success otherwise an error. """ - PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() @@ -718,12 +715,12 @@ class SearchUsersRestServlet(RestServlet): 200 OK with json object {list[dict[str, Any]], count} or empty object.
""" - PATTERNS = admin_patterns("/search_users/(?P[^/]*)") + PATTERNS = admin_patterns("/search_users/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, target_user_id: str @@ -740,7 +737,7 @@ async def on_GET( # if not is_admin and target_user != auth_user: # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user") term = parse_string(request, "term", required=True) @@ -779,9 +776,9 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/admin$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -790,7 +787,7 @@ async def on_GET( target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be admins of this homeserver", @@ -813,7 +810,7 @@ async def on_PUT( assert_params_in_dict(body, ["admin"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be admins of this homeserver", @@ -834,7 +831,7 @@ class UserMembershipRestServlet(RestServlet): Get room list of an user. """ - PATTERNS = admin_patterns("/users/(?P[^/]+)/joined_rooms$") + PATTERNS = admin_patterns("/users/(?P[^/]*)/joined_rooms$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -909,10 +906,10 @@ class UserTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/login$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str @@ -921,7 +918,7 @@ async def on_POST( await assert_user_is_admin(self.auth, requester.user) auth_user = requester.user - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be logged in as" ) @@ -975,19 +972,19 @@ class ShadowBanRestServlet(RestServlet): {} """ - PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban") + PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" ) @@ -1001,7 +998,7 @@ async def on_DELETE( ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" ) @@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit") + PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit$") def 
@@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet):
         }
     """

-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit")
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit$")

     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.is_mine_id = hs.is_mine_id

     async def on_GET(
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")

         if not await self.store.get_user_by_id(user_id):
@@ -1068,7 +1065,7 @@ async def on_POST(
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
             )
@@ -1113,7 +1110,7 @@ async def on_DELETE(
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
             )
@@ -1124,3 +1121,33 @@ async def on_DELETE(
         await self.store.delete_ratelimit_for_user(user_id)

         return HTTPStatus.OK, {}
+
+
+class AccountDataRestServlet(RestServlet):
+    """Retrieve the given user's account data"""
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/accountdata")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+        self._is_mine_id = hs.is_mine_id
+
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self._auth, request)
+
+        if not self._is_mine_id(user_id):
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
+
+        if not await self._store.get_user_by_id(user_id):
+            raise NotFoundError("User not found")
+
+        global_data, by_room_data = await self._store.get_account_data_for_user(user_id)
+        return HTTPStatus.OK, {
+            "account_data": {
+                "global": global_data,
+                "rooms": by_room_data,
+            },
+        }
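`AccountDataRestServlet` is the new admin endpoint for reading a user's account data. A client-side sketch of calling it (using `requests`; the homeserver URL, token, and user ID are placeholders):

```python
import requests

BASE = "https://matrix.example.com"     # placeholder homeserver URL
ADMIN_TOKEN = "EXAMPLE_ADMIN_TOKEN"     # placeholder admin access token
USER_ID = "@alice:matrix.example.com"   # placeholder local user

resp = requests.get(
    f"{BASE}/_synapse/admin/v1/users/{USER_ID}/accountdata",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
resp.raise_for_status()

# Response shape mirrors the servlet's return value above:
# {"account_data": {"global": {...}, "rooms": {...}}}
account_data = resp.json()["account_data"]
print(sorted(account_data["global"]))
```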
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 2a3e24ae7..5c0e3a568 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -73,6 +73,9 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
                 "enabled": self.config.registration.enable_3pid_changes
             }

+        if self.config.experimental.msc3440_enabled:
+            response["capabilities"]["io.element.thread"] = {"enabled": True}
+
         return 200, response
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 8566dc5cb..ad6fd6492 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -17,6 +17,7 @@
 from typing import TYPE_CHECKING, Tuple

 from synapse.api import errors
+from synapse.api.errors import NotFoundError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
@@ -24,10 +25,9 @@
     parse_json_object_from_request,
 )
 from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns, interactive_auth_handler
 from synapse.types import JsonDict

-from ._base import client_patterns, interactive_auth_handler
-
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -116,6 +116,8 @@ async def on_GET(
         device = await self.device_handler.get_device(
             requester.user.to_string(), device_id
         )
+        if device is None:
+            raise NotFoundError("No device found")
         return 200, device

     @interactive_auth_handler
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index d1d8a984c..acd0c9e13 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -15,6 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple

+from synapse.api.constants import ReceiptTypes
 from synapse.events.utils import format_event_for_client_v2_without_room_id
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -54,10 +55,10 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         )

         receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
-            user_id, "m.read"
+            user_id, ReceiptTypes.READ
         )

-        notif_event_ids = [pa["event_id"] for pa in push_actions]
+        notif_event_ids = [pa.event_id for pa in push_actions]
         notif_events = await self.store.get_events(notif_event_ids)

         returned_push_actions = []
@@ -66,30 +67,30 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

         for pa in push_actions:
             returned_pa = {
-                "room_id": pa["room_id"],
-                "profile_tag": pa["profile_tag"],
-                "actions": pa["actions"],
-                "ts": pa["received_ts"],
+                "room_id": pa.room_id,
+                "profile_tag": pa.profile_tag,
+                "actions": pa.actions,
+                "ts": pa.received_ts,
                 "event": (
                     await self._event_serializer.serialize_event(
-                        notif_events[pa["event_id"]],
+                        notif_events[pa.event_id],
                         self.clock.time_msec(),
                         event_format=format_event_for_client_v2_without_room_id,
                     )
                 ),
             }

-            if pa["room_id"] not in receipts_by_room:
+            if pa.room_id not in receipts_by_room:
                 returned_pa["read"] = False
             else:
-                receipt = receipts_by_room[pa["room_id"]]
+                receipt = receipts_by_room[pa.room_id]
                 returned_pa["read"] = (
                     receipt["topological_ordering"],
                     receipt["stream_ordering"],
-                ) >= (pa["topological_ordering"], pa["stream_ordering"])
+                ) >= (pa.topological_ordering, pa.stream_ordering)

             returned_push_actions.append(returned_pa)
-            next_token = str(pa["stream_ordering"])
+            next_token = str(pa.stream_ordering)

         return 200, {"notifications": returned_push_actions, "next_token": next_token}
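The notifications hunk swaps dict subscripting (`pa["event_id"]`) for attribute access (`pa.event_id`): the storage layer now hands back structured objects rather than dicts. Roughly the shape implied by the accesses above, as an illustrative `attrs` class (field list inferred from this hunk, not the real definition):

```python
from typing import List

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class HttpPushAction:
    """Illustrative stand-in for the push-action objects used above."""

    event_id: str
    room_id: str
    stream_ordering: int
    topological_ordering: int
    received_ts: int
    profile_tag: str
    actions: List[dict]
```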
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 43c04fac6..f51be511d 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple

-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -48,7 +48,7 @@ async def on_POST(
         await self.presence_handler.bump_presence_active_time(requester.user)

         body = parse_json_object_from_request(request)
-        read_event_id = body.get("m.read", None)
+        read_event_id = body.get(ReceiptTypes.READ, None)
         hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)

         if not isinstance(hidden, bool):
@@ -62,7 +62,7 @@ async def on_POST(
         if read_event_id:
             await self.receipts_handler.received_client_receipt(
                 room_id,
-                "m.read",
+                ReceiptTypes.READ,
                 user_id=requester.user.to_string(),
                 event_id=read_event_id,
                 hidden=hidden,
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 2b25b9aad..b24ad2d1b 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -16,7 +16,7 @@
 import re
 from typing import TYPE_CHECKING, Tuple

-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.http import get_request_user_agent
 from synapse.http.server import HttpServer
@@ -53,7 +53,7 @@ async def on_POST(
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)

-        if receipt_type != "m.read":
+        if receipt_type != ReceiptTypes.READ:
             raise SynapseError(400, "Receipt type must be 'm.read'")

         # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index fc4e6921c..5815650ee 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -212,6 +212,7 @@ async def on_GET(

         pagination_chunk = await self.store.get_relations_for_event(
             event_id=parent_id,
+            room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
             limit=limit,
@@ -231,7 +232,9 @@ async def on_GET(
         )
         # The relations returned for the requested event do include their
         # bundled aggregations.
-        serialized_events = await self._event_serializer.serialize_events(events, now)
+        serialized_events = await self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=True
+        )

         return_value = pagination_chunk.to_dict()
         return_value["chunk"] = serialized_events
@@ -317,6 +320,7 @@ async def on_GET(

         pagination_chunk = await self.store.get_aggregation_groups_for_event(
             event_id=parent_id,
+            room_id=room_id,
             event_type=event_type,
             limit=limit,
             from_token=from_token,
@@ -383,7 +387,9 @@ async def on_GET(

         # This checks that a) the event exists and b) the user is allowed to
         # view it.
-        await self.event_handler.get_event(requester.user, room_id, parent_id)
+        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+        if event is None:
+            raise SynapseError(404, "Unknown parent event.")

         if relation_type != RelationTypes.ANNOTATION:
             raise SynapseError(400, "Relation type must be 'annotation'")
@@ -402,6 +408,7 @@ async def on_GET(

         result = await self.store.get_relations_for_event(
             event_id=parent_id,
+            room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
             aggregation_key=key,
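Two themes run through the relations changes: relation lookups are now scoped by `room_id`, and bundled aggregations are attached only when a call site opts in with `bundle_aggregations=True`. A stand-in signature for the latter (illustrative only; the real serializer is Synapse's event serializer, not this function):

```python
from typing import Any, Dict


async def serialize_event(
    event: Dict[str, Any],
    time_now: int,
    bundle_aggregations: bool = False,
) -> Dict[str, Any]:
    """Serialize an event, attaching bundled aggregations only when asked."""
    serialized = dict(event)
    if bundle_aggregations:
        # Edits, reactions, threads, etc. would be looked up here and
        # attached under unsigned["m.relations"]; elided in this sketch.
        serialized.setdefault("unsigned", {})["m.relations"] = {}
    return serialized
```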
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 8de7d2220..5db643fa3 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -187,7 +187,7 @@ async def on_PUT(
         state_key: str,
         txn_id: Optional[str] = None,
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)

         if txn_id:
             set_tag("txn_id", txn_id)
@@ -662,7 +662,9 @@ async def on_GET(
         time_now = self.clock.time_msec()

         if event:
-            event_dict = await self._event_serializer.serialize_event(event, time_now)
+            event_dict = await self._event_serializer.serialize_event(
+                event, time_now, bundle_aggregations=True
+            )
             return 200, event_dict

         raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
@@ -707,13 +709,13 @@ async def on_GET(

         time_now = self.clock.time_msec()
         results["events_before"] = await self._event_serializer.serialize_events(
-            results["events_before"], time_now
+            results["events_before"], time_now, bundle_aggregations=True
         )
         results["event"] = await self._event_serializer.serialize_event(
-            results["event"], time_now
+            results["event"], time_now, bundle_aggregations=True
         )
         results["events_after"] = await self._event_serializer.serialize_events(
-            results["events_after"], time_now
+            results["events_after"], time_now, bundle_aggregations=True
         )
         results["state"] = await self._event_serializer.serialize_events(
             results["state"], time_now
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 7f5846d38..e99a943d0 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -48,6 +48,7 @@
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
+from synapse.logging.opentracing import trace
 from synapse.types import JsonDict, StreamToken
 from synapse.util import json_decoder
@@ -222,6 +223,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         logger.debug("Event formatting complete")
         return 200, response_content

+    @trace(opname="sync.encode_response")
     async def encode_response(
         self,
         time_now: int,
@@ -293,6 +295,9 @@ async def encode_response(
             response[
                 "org.matrix.msc2732.device_unused_fallback_key_types"
             ] = sync_result.device_unused_fallback_key_types
+            response[
+                "device_unused_fallback_key_types"
+            ] = sync_result.device_unused_fallback_key_types

         if joined:
             response["rooms"][Membership.JOIN] = joined
@@ -329,6 +334,7 @@ def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
         ]
     }

+    @trace(opname="sync.encode_joined")
     async def encode_joined(
         self,
         rooms: List[JoinedSyncResult],
@@ -365,6 +371,7 @@ async def encode_joined(

         return joined

+    @trace(opname="sync.encode_invited")
     async def encode_invited(
         self,
         rooms: List[InvitedSyncResult],
@@ -403,6 +410,7 @@ async def encode_invited(

         return invited

+    @trace(opname="sync.encode_knocked")
     async def encode_knocked(
         self,
         rooms: List[KnockedSyncResult],
@@ -457,6 +465,7 @@ async def encode_knocked(

         return knocked

+    @trace(opname="sync.encode_archived")
     async def encode_archived(
         self,
         rooms: List[ArchivedSyncResult],
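The `@trace(opname=...)` decorators added to the sync encoders above wrap each phase of response encoding in its own named tracing span. A reduced stand-in for such a decorator (Synapse's real one lives in `synapse.logging.opentracing` and integrates with OpenTracing; this sketch only measures and logs):

```python
import functools
import logging
import time

logger = logging.getLogger(__name__)


def trace(opname: str):
    """Stand-in decorator: run an async function under a named span."""

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            start = time.monotonic()
            try:
                return await func(*args, **kwargs)
            finally:
                # A real tracer would finish a span here; we just log timing.
                logger.debug("%s took %.3fs", opname, time.monotonic() - start)

        return wrapper

    return decorator
```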
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 07f605fa6..4e5b95327 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -96,6 +96,10 @@ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
                     "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
                     # Supports receiving hidden read receipts as per MSC2285
                     "org.matrix.msc2285": self.config.experimental.msc2285_enabled,
+                    # Adds support for importing historical messages as per MSC2716
+                    "org.matrix.msc2716": self.config.experimental.msc2716_enabled,
+                    # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
+                    "org.matrix.msc3030": self.config.experimental.msc3030_enabled,
                 },
             },
         )
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 12b3ae120..b9bfbea21 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 from canonicaljson import encode_canonical_json
 from signedjson.sign import sign_json
@@ -99,7 +99,7 @@ def response_json_object(self) -> JsonDict:
             json_object = sign_json(json_object, self.config.server.server_name, key)
         return json_object

-    def render_GET(self, request: Request) -> int:
+    def render_GET(self, request: Request) -> Optional[int]:
         time_now = self.clock.time_msec()
         # Update the expiry time if less than half the interval remains.
         if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 244ba261b..71b9a34b1 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -739,14 +739,21 @@ async def _generate_thumbnails(
         # We deduplicate the thumbnail sizes by ignoring the cropped versions if
         # they have the same dimensions of a scaled one.
         thumbnails: Dict[Tuple[int, int, str], str] = {}
-        for r_width, r_height, r_method, r_type in requirements:
-            if r_method == "crop":
-                thumbnails.setdefault((r_width, r_height, r_type), r_method)
-            elif r_method == "scale":
-                t_width, t_height = thumbnailer.aspect(r_width, r_height)
+        for requirement in requirements:
+            if requirement.method == "crop":
+                thumbnails.setdefault(
+                    (requirement.width, requirement.height, requirement.media_type),
+                    requirement.method,
+                )
+            elif requirement.method == "scale":
+                t_width, t_height = thumbnailer.aspect(
+                    requirement.width, requirement.height
+                )
                 t_width = min(m_width, t_width)
                 t_height = min(m_height, t_height)
-                thumbnails[(t_width, t_height, r_type)] = r_method
+                thumbnails[
+                    (t_width, t_height, requirement.media_type)
+                ] = requirement.method

         # Now we generate the thumbnails for each dimension, store it
         for (t_width, t_height, t_type), t_method in thumbnails.items():
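In the thumbnailing hunk, requirements become structured objects (`requirement.width` instead of positional tuple unpacking), while the dedup rule is unchanged: a crop is kept only if no scaled thumbnail already claims the same `(width, height, type)`. The rule in isolation, with a stand-in requirement type rather than Synapse's real one:

```python
from typing import Callable, Dict, Iterable, NamedTuple, Tuple


class Requirement(NamedTuple):
    """Stand-in for a configured thumbnail requirement."""

    width: int
    height: int
    method: str       # "crop" or "scale"
    media_type: str   # e.g. "image/jpeg"


def dedupe_thumbnails(
    requirements: Iterable[Requirement],
    aspect: Callable[[int, int], Tuple[int, int]],
    m_width: int,
    m_height: int,
) -> Dict[Tuple[int, int, str], str]:
    """Keep one method per (width, height, type); scales trump same-size crops."""
    thumbnails: Dict[Tuple[int, int, str], str] = {}
    for r in requirements:
        if r.method == "crop":
            # setdefault: never replace an existing (scale) entry with a crop.
            thumbnails.setdefault((r.width, r.height, r.media_type), r.method)
        elif r.method == "scale":
            t_width, t_height = aspect(r.width, r.height)
            # Never scale beyond the source media's own dimensions.
            key = (min(m_width, t_width), min(m_height, t_height), r.media_type)
            thumbnails[key] = r.method
    return thumbnails
```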
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 2a59552c2..cce1527ed 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -17,6 +17,7 @@

 import attr

+from synapse.rest.media.v1.preview_html import parse_html_description
 from synapse.types import JsonDict
 from synapse.util import json_decoder
@@ -245,8 +246,6 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) ->
     if video_urls:
         open_graph_response["og:video"] = video_urls[0]

-    from synapse.rest.media.v1.preview_url_resource import _calc_description
-
-    description = _calc_description(tree)
+    description = parse_html_description(tree)
     if description:
         open_graph_response["og:description"] = description
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
new file mode 100644
index 000000000..30b067dd4
--- /dev/null
+++ b/synapse/rest/media/v1/preview_html.py
@@ -0,0 +1,397 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import codecs
+import itertools
+import logging
+import re
+from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
+from urllib import parse as urlparse
+
+if TYPE_CHECKING:
+    from lxml import etree
+
+logger = logging.getLogger(__name__)
+
+_charset_match = re.compile(
+    br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
+)
+_xml_encoding_match = re.compile(
+    br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
+)
+_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
+
+
+def _normalise_encoding(encoding: str) -> Optional[str]:
+    """Use the Python codec's name as the normalised entry."""
+    try:
+        return codecs.lookup(encoding).name
+    except LookupError:
+        return None
+
+
+def _get_html_media_encodings(
+    body: bytes, content_type: Optional[str]
+) -> Iterable[str]:
+    """
+    Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
+
+    The precedence used for finding a character encoding is:
+
+    1. <meta> tag with a charset declared.
+    2. The XML document's character encoding attribute.
+    3. The Content-Type header.
+    4. Fallback to utf-8.
+    5. Fallback to windows-1252.
+
+    This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
+
+    Args:
+        body: The HTML document, as bytes.
+        content_type: The Content-Type header.
+
+    Returns:
+        The character encoding of the body, as a string.
+    """
+    # There's no point in returning an encoding more than once.
+    attempted_encodings: Set[str] = set()
+
+    # Limit searches to the first 1kb, since it ought to be at the top.
+    body_start = body[:1024]
+
+    # Check if it has an encoding set in a meta tag.
+    match = _charset_match.search(body_start)
+    if match:
+        encoding = _normalise_encoding(match.group(1).decode("ascii"))
+        if encoding:
+            attempted_encodings.add(encoding)
+            yield encoding
+
+    # TODO Support <meta http-equiv="Content-Type" ...> declarations.
+
+    # Check if it has an XML document with an encoding.
+    match = _xml_encoding_match.match(body_start)
+    if match:
+        encoding = _normalise_encoding(match.group(1).decode("ascii"))
+        if encoding and encoding not in attempted_encodings:
+            attempted_encodings.add(encoding)
+            yield encoding
+
+    # Check the HTTP Content-Type header for a character set.
+    if content_type:
+        content_match = _content_type_match.match(content_type)
+        if content_match:
+            encoding = _normalise_encoding(content_match.group(1))
+            if encoding and encoding not in attempted_encodings:
+                attempted_encodings.add(encoding)
+                yield encoding
+
+    # Finally, fallback to UTF-8, then windows-1252.
+    for fallback in ("utf-8", "cp1252"):
+        if fallback not in attempted_encodings:
+            yield fallback
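+
+# A worked example of the precedence (illustrative, not part of the original
+# module): a <meta> charset that disagrees with the Content-Type header is
+# tried first, then the header's (normalised) charset, then the two fallbacks:
+#
+#     list(_get_html_media_encodings(
+#         b'<meta charset="ascii">', 'text/html; charset="latin-1"'
+#     ))
+#     # -> ["ascii", "iso8859-1", "utf-8", "cp1252"]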
+
+
+def decode_body(
+    body: bytes, uri: str, content_type: Optional[str] = None
+) -> Optional["etree.Element"]:
+    """
+    This uses lxml to parse the HTML document.
+
+    Args:
+        body: The HTML document, as bytes.
+        uri: The URI used to download the body.
+        content_type: The Content-Type header.
+
+    Returns:
+        The parsed HTML body, or None if an error occurred during processing.
+    """
+    # If there's no body, nothing useful is going to be found.
+    if not body:
+        return None
+
+    # The idea here is that multiple encodings are tried until one works.
+    # Unfortunately the result is never used and then LXML will decode the string
+    # again with the found encoding.
+    for encoding in _get_html_media_encodings(body, content_type):
+        try:
+            body.decode(encoding)
+        except Exception:
+            pass
+        else:
+            break
+    else:
+        logger.warning("Unable to decode HTML body for %s", uri)
+        return None
+
+    from lxml import etree
+
+    # Create an HTML parser.
+    parser = etree.HTMLParser(recover=True, encoding=encoding)
+
+    # Attempt to parse the body. Returns None if the body was successfully
+    # parsed, but no tree was found.
+    return etree.fromstring(body, parser)
og["og:description"] = meta_description[0] + else: + og["og:description"] = parse_html_description(tree) + elif og["og:description"]: + # This must be a non-empty string at this point. + assert isinstance(og["og:description"], str) + og["og:description"] = summarize_paragraphs([og["og:description"]]) + + # TODO: delete the url downloads to stop diskfilling, + # as we only ever cared about its OG + return og + + +def parse_html_description(tree: "etree.Element") -> Optional[str]: + """ + Calculate a text description based on an HTML document. + + Grabs any text nodes which are inside the tag, unless they are within + an HTML5 semantic markup tag (
+
+
+def parse_html_description(tree: "etree.Element") -> Optional[str]:
+    """
+    Calculate a text description based on an HTML document.
+
+    Grabs any text nodes which are inside the <body> tag, unless they are within
+    an HTML5 semantic markup tag (<header>, <nav>, <aside>, <footer>),