diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 6c3a99849925..0dfab4e087cf 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -8,6 +8,7 @@
- Use markdown where necessary, mostly for `code blocks`.
- End with either a period (.) or an exclamation mark (!).
- Start with a capital letter.
+ - Feel free to credit yourself by adding a sentence "Contributed by @github_username." or "Contributed by [Your Name]." to the end of the entry.
* [ ] Pull request includes a [sign off](https://matrix-org.github.io/synapse/latest/development/contributing_guide.html#sign-off)
* [ ] [Code style](https://matrix-org.github.io/synapse/latest/code_style.html) is correct
(run the [linters](https://matrix-org.github.io/synapse/latest/development/contributing_guide.html#run-the-linters))
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 21c9ee7823c7..4f58069702dc 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -76,7 +76,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]
+ python-version: ["3.7", "3.8", "3.9", "3.10"]
database: ["sqlite"]
toxenv: ["py"]
include:
@@ -85,9 +85,9 @@ jobs:
toxenv: "py-noextras"
# Oldest Python with PostgreSQL
- - python-version: "3.6"
+ - python-version: "3.7"
database: "postgres"
- postgres-version: "9.6"
+ postgres-version: "10"
toxenv: "py"
# Newest Python with newest PostgreSQL
@@ -167,7 +167,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["pypy-3.6"]
+ python-version: ["pypy-3.7"]
steps:
- uses: actions/checkout@v2
@@ -291,8 +291,8 @@ jobs:
strategy:
matrix:
include:
- - python-version: "3.6"
- postgres-version: "9.6"
+ - python-version: "3.7"
+ postgres-version: "10"
- python-version: "3.10"
postgres-version: "14"
@@ -366,6 +366,8 @@ jobs:
# Build initial Synapse image
- run: docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile .
working-directory: synapse
+ env:
+ DOCKER_BUILDKIT: 1
# Build a ready-to-run Synapse image based on the initial image above.
# This new image includes a config file, keys for signing and TLS, and
@@ -374,7 +376,8 @@ jobs:
working-directory: complement/dockerfiles
# Run Complement
- - run: go test -v -tags synapse_blacklist,msc2403 ./tests/...
+ - run: set -o pipefail && go test -v -json -tags synapse_blacklist,msc2403 ./tests/... 2>&1 | gotestfmt
+ shell: bash
env:
COMPLEMENT_BASE_IMAGE: complement-synapse:latest
working-directory: complement
diff --git a/.gitignore b/.gitignore
index fe137f337019..3bd6b1a08c57 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,3 +50,7 @@ __pycache__/
# docs
book/
+
+# complement
+/complement-*
+/master.tar.gz
diff --git a/CHANGES.md b/CHANGES.md
index 9f6e29631df6..d3dcd0ed2ef4 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,228 @@
+Synapse 1.51.0 (2022-01-25)
+===========================
+
+No significant changes since 1.51.0rc2.
+
+Synapse 1.51.0 deprecates `webclient` listeners and non-HTTP(S) `web_client_location`s. Support for these will be removed in Synapse 1.53.0, at which point Synapse will not be capable of directly serving a web client for Matrix.
+
+Synapse 1.51.0rc2 (2022-01-24)
+==============================
+
+Bugfixes
+--------
+
+- Fix a bug introduced in Synapse 1.40.0 that caused Synapse to fail to process incoming federation traffic after handling a large amount of events in a v1 room. ([\#11806](https://github.com/matrix-org/synapse/issues/11806))
+
+
+Synapse 1.51.0rc1 (2022-01-21)
+==============================
+
+Features
+--------
+
+- Add `track_puppeted_user_ips` config flag to record client IP addresses against puppeted users, and include the puppeted users in monthly active user counts. ([\#11561](https://github.com/matrix-org/synapse/issues/11561), [\#11749](https://github.com/matrix-org/synapse/issues/11749), [\#11757](https://github.com/matrix-org/synapse/issues/11757))
+- Include whether the requesting user has participated in a thread when generating a summary for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). ([\#11577](https://github.com/matrix-org/synapse/issues/11577))
+- Return an `M_FORBIDDEN` error code instead of `M_UNKNOWN` when a spam checker module prevents a user from creating a room. ([\#11672](https://github.com/matrix-org/synapse/issues/11672))
+- Add a flag to the `synapse_review_recent_signups` script to ignore and filter appservice users. ([\#11675](https://github.com/matrix-org/synapse/issues/11675), [\#11770](https://github.com/matrix-org/synapse/issues/11770))
+
+
+Bugfixes
+--------
+
+- Fix a long-standing issue which could cause Synapse to incorrectly accept data in the unsigned field of events
+ received over federation. ([\#11530](https://github.com/matrix-org/synapse/issues/11530))
+- Fix a long-standing bug where Synapse wouldn't cache a response indicating that a remote user has no devices. ([\#11587](https://github.com/matrix-org/synapse/issues/11587))
+- Fix an error that occurs whilst trying to get the federation status of a destination server that was working normally. This admin API was newly introduced in Synapse v1.49.0. ([\#11593](https://github.com/matrix-org/synapse/issues/11593))
+- Fix bundled aggregations not being included in the `/sync` response, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). ([\#11612](https://github.com/matrix-org/synapse/issues/11612), [\#11659](https://github.com/matrix-org/synapse/issues/11659), [\#11791](https://github.com/matrix-org/synapse/issues/11791))
+- Fix the `/_matrix/client/v1/room/{roomId}/hierarchy` endpoint returning incorrect fields which have been present since Synapse 1.49.0. ([\#11667](https://github.com/matrix-org/synapse/issues/11667))
+- Fix preview of some GIF URLs (like tenor.com). Contributed by Philippe Daouadi. ([\#11669](https://github.com/matrix-org/synapse/issues/11669))
+- Fix a bug where only the first 50 rooms from a space were returned from the `/hierarchy` API. This has existed since the introduction of the API in Synapse v1.41.0. ([\#11695](https://github.com/matrix-org/synapse/issues/11695))
+- Fix a bug introduced in Synapse v1.18.0 where password reset and address validation emails would not be sent if their subject was configured to use the 'app' template variable. Contributed by @br4nnigan. ([\#11710](https://github.com/matrix-org/synapse/issues/11710), [\#11745](https://github.com/matrix-org/synapse/issues/11745))
+- Make the 'List Rooms' Admin API sort stable. Contributed by Daniël Sonck. ([\#11737](https://github.com/matrix-org/synapse/issues/11737))
+- Fix a long-standing bug where space hierarchy over federation would only work correctly some of the time. ([\#11775](https://github.com/matrix-org/synapse/issues/11775))
+- Fix a bug introduced in Synapse v1.46.0 that prevented `on_logged_out` module callbacks from being correctly awaited by Synapse. ([\#11786](https://github.com/matrix-org/synapse/issues/11786))
+
+
+Improved Documentation
+----------------------
+
+- Warn against using a Let's Encrypt certificate for TLS/DTLS TURN server client connections, and suggest using a ZeroSSL certificate instead. This works around client-side connectivity errors caused by WebRTC libraries that reject Let's Encrypt certificates. Contributed by @AndrewFerr. ([\#11686](https://github.com/matrix-org/synapse/issues/11686))
+- Document the new `SYNAPSE_TEST_PERSIST_SQLITE_DB` environment variable in the contributing guide. ([\#11715](https://github.com/matrix-org/synapse/issues/11715))
+- Document that the minimum supported PostgreSQL version is now 10. ([\#11725](https://github.com/matrix-org/synapse/issues/11725))
+- Fix typo in demo docs: differnt. ([\#11735](https://github.com/matrix-org/synapse/issues/11735))
+- Update room spec URL in config files. ([\#11739](https://github.com/matrix-org/synapse/issues/11739))
+- Mention `python3-venv` and `libpq-dev` dependencies in the contribution guide. ([\#11740](https://github.com/matrix-org/synapse/issues/11740))
+- Update documentation for configuring login with Facebook. ([\#11755](https://github.com/matrix-org/synapse/issues/11755))
+- Update installation instructions to note that Python 3.6 is no longer supported. ([\#11781](https://github.com/matrix-org/synapse/issues/11781))
+
+
+Deprecations and Removals
+-------------------------
+
+- Remove the unstable `/send_relation` endpoint. ([\#11682](https://github.com/matrix-org/synapse/issues/11682))
+- Remove `python_twisted_reactor_pending_calls` Prometheus metric. ([\#11724](https://github.com/matrix-org/synapse/issues/11724))
+- Remove the `password_hash` field from the response dictionaries of the [Users Admin API](https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html). ([\#11576](https://github.com/matrix-org/synapse/issues/11576))
+- **Deprecate support for `webclient` listeners and non-HTTP(S) `web_client_location` configuration. ([\#11774](https://github.com/matrix-org/synapse/issues/11774), [\#11783](https://github.com/matrix-org/synapse/issues/11783))**
+
+
+Internal Changes
+----------------
+
+- Run `pyupgrade --py37-plus --keep-percent-format` on Synapse. ([\#11685](https://github.com/matrix-org/synapse/issues/11685))
+- Use buildkit's cache feature to speed up docker builds. ([\#11691](https://github.com/matrix-org/synapse/issues/11691))
+- Use `auto_attribs` and native type hints for attrs classes. ([\#11692](https://github.com/matrix-org/synapse/issues/11692), [\#11768](https://github.com/matrix-org/synapse/issues/11768))
+- Remove debug logging for #4422, which has been closed since Synapse 0.99. ([\#11693](https://github.com/matrix-org/synapse/issues/11693))
+- Remove fallback code for Python 2. ([\#11699](https://github.com/matrix-org/synapse/issues/11699))
+- Add a test for [an edge case](https://github.com/matrix-org/synapse/pull/11532#discussion_r769104461) in the `/sync` logic. ([\#11701](https://github.com/matrix-org/synapse/issues/11701))
+- Add the option to write SQLite test dbs to disk when running tests. ([\#11702](https://github.com/matrix-org/synapse/issues/11702))
+- Improve Complement test output for GitHub Actions. ([\#11707](https://github.com/matrix-org/synapse/issues/11707))
+- Fix docstring on `add_account_data_for_user`. ([\#11716](https://github.com/matrix-org/synapse/issues/11716))
+- Complement environment variable name change and update `.gitignore`. ([\#11718](https://github.com/matrix-org/synapse/issues/11718))
+- Simplify calculation of Prometheus metrics for garbage collection. ([\#11723](https://github.com/matrix-org/synapse/issues/11723))
+- Improve accuracy of `python_twisted_reactor_tick_time` Prometheus metric. ([\#11724](https://github.com/matrix-org/synapse/issues/11724), [\#11771](https://github.com/matrix-org/synapse/issues/11771))
+- Minor efficiency improvements when inserting many values into the database. ([\#11742](https://github.com/matrix-org/synapse/issues/11742))
+- Invite PR authors to give themselves credit in the changelog. ([\#11744](https://github.com/matrix-org/synapse/issues/11744))
+- Add optional debugging to investigate [issue 8631](https://github.com/matrix-org/synapse/issues/8631). ([\#11760](https://github.com/matrix-org/synapse/issues/11760))
+- Remove `log_function` utility function and its uses. ([\#11761](https://github.com/matrix-org/synapse/issues/11761))
+- Add a unit test that checks both `client` and `webclient` resources will function when simultaneously enabled. ([\#11765](https://github.com/matrix-org/synapse/issues/11765))
+- Allow overriding complement commit using `COMPLEMENT_REF`. ([\#11766](https://github.com/matrix-org/synapse/issues/11766))
+- Add some comments and type annotations for `_update_outliers_txn`. ([\#11776](https://github.com/matrix-org/synapse/issues/11776))
+
+
+Synapse 1.50.1 (2022-01-18)
+===========================
+
+This release fixes a bug in Synapse 1.50.0 that could prevent clients from being able to connect to Synapse if the `webclient` resource was enabled. Further details are available in [this issue](https://github.com/matrix-org/synapse/issues/11763).
+
+Bugfixes
+--------
+
+- Fix a bug introduced in Synapse 1.50.0rc1 that could cause Matrix clients to be unable to connect to Synapse instances with the `webclient` resource enabled. ([\#11764](https://github.com/matrix-org/synapse/issues/11764))
+
+
+Synapse 1.50.0 (2022-01-18)
+===========================
+
+**This release contains a critical bug that may prevent clients from being able to connect.
+As such, it is not recommended to upgrade to 1.50.0. Instead, please upgrade straight to
+1.50.1. Further details are available in [this issue](https://github.com/matrix-org/synapse/issues/11763).**
+
+Please note that we now only support Python 3.7+ and PostgreSQL 10+ (if applicable), because Python 3.6 and PostgreSQL 9.6 have reached end-of-life.
+
+No significant changes since 1.50.0rc2.
+
+
+Synapse 1.50.0rc2 (2022-01-14)
+==============================
+
+This release candidate fixes a federation-breaking regression introduced in Synapse 1.50.0rc1.
+
+Bugfixes
+--------
+
+- Fix a bug introduced in Synapse v1.0.0 whereby some device list updates would not be sent to remote homeservers if there were too many to send at once. ([\#11729](https://github.com/matrix-org/synapse/issues/11729))
+- Fix a bug introduced in Synapse v1.50.0rc1 whereby outbound federation could fail because too many EDUs were produced for device updates. ([\#11730](https://github.com/matrix-org/synapse/issues/11730))
+
+
+Improved Documentation
+----------------------
+
+- Document that the minimum supported PostgreSQL version is now 10. ([\#11725](https://github.com/matrix-org/synapse/issues/11725))
+
+
+Internal Changes
+----------------
+
+- Fix a typechecker problem related to our (ab)use of `nacl.signing.SigningKey`s. ([\#11714](https://github.com/matrix-org/synapse/issues/11714))
+
+
+Synapse 1.50.0rc1 (2022-01-05)
+==============================
+
+
+Features
+--------
+
+- Allow guests to send state events per [MSC3419](https://github.com/matrix-org/matrix-doc/pull/3419). ([\#11378](https://github.com/matrix-org/synapse/issues/11378))
+- Add experimental support for part of [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): allowing application services to masquerade as specific devices. ([\#11538](https://github.com/matrix-org/synapse/issues/11538))
+- Add admin API to get users' account data. ([\#11664](https://github.com/matrix-org/synapse/issues/11664))
+- Include the room topic in the stripped state included with invites and knocking. ([\#11666](https://github.com/matrix-org/synapse/issues/11666))
+- Send and handle cross-signing messages using the stable prefix. ([\#10520](https://github.com/matrix-org/synapse/issues/10520))
+- Support unprefixed versions of fallback key property names. ([\#11541](https://github.com/matrix-org/synapse/issues/11541))
+
+
+Bugfixes
+--------
+
+- Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event. ([\#11516](https://github.com/matrix-org/synapse/issues/11516))
+- Fix a long-standing bug which could cause `AssertionError`s to be written to the log when Synapse was restarted after purging events from the database. ([\#11536](https://github.com/matrix-org/synapse/issues/11536), [\#11642](https://github.com/matrix-org/synapse/issues/11642))
+- Fix a bug introduced in Synapse 1.17.0 where a pusher created for an email with capital letters would fail to be created. ([\#11547](https://github.com/matrix-org/synapse/issues/11547))
+- Fix a long-standing bug where responses included bundled aggregations when they should not, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). ([\#11592](https://github.com/matrix-org/synapse/issues/11592), [\#11623](https://github.com/matrix-org/synapse/issues/11623))
+- Fix a long-standing bug that some unknown endpoints would return HTML error pages instead of JSON `M_UNRECOGNIZED` errors. ([\#11602](https://github.com/matrix-org/synapse/issues/11602))
+- Fix a bug introduced in Synapse 1.19.3 which could sometimes cause `AssertionError`s when backfilling rooms over federation. ([\#11632](https://github.com/matrix-org/synapse/issues/11632))
+
+
+Improved Documentation
+----------------------
+
+- Update Synapse install command for FreeBSD as the package is now prefixed with `py38`. Contributed by @itchychips. ([\#11267](https://github.com/matrix-org/synapse/issues/11267))
+- Document the usage of refresh tokens. ([\#11427](https://github.com/matrix-org/synapse/issues/11427))
+- Add details for how to configure a TURN server when behind a NAT. Contributed by @AndrewFerr. ([\#11553](https://github.com/matrix-org/synapse/issues/11553))
+- Add references for using Postgres to the Docker documentation. ([\#11640](https://github.com/matrix-org/synapse/issues/11640))
+- Fix the documentation link in newly-generated configuration files. ([\#11678](https://github.com/matrix-org/synapse/issues/11678))
+- Correct the documentation for `nginx` to use a case-sensitive url pattern. Fixes an error introduced in v1.21.0. ([\#11680](https://github.com/matrix-org/synapse/issues/11680))
+- Clarify SSO mapping provider documentation by writing `def` or `async def` before the names of methods, as appropriate. ([\#11681](https://github.com/matrix-org/synapse/issues/11681))
+
+
+Deprecations and Removals
+-------------------------
+
+- Replace `mock` package by its standard library version. ([\#11588](https://github.com/matrix-org/synapse/issues/11588))
+- Drop support for Python 3.6 and Ubuntu 18.04. ([\#11633](https://github.com/matrix-org/synapse/issues/11633))
+
+
+Internal Changes
+----------------
+
+- Allow specific, experimental events to be created without `prev_events`. Used by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). ([\#11243](https://github.com/matrix-org/synapse/issues/11243))
+- A test helper (`wait_for_background_updates`) no longer depends on classes defining a `store` property. ([\#11331](https://github.com/matrix-org/synapse/issues/11331))
+- Add type hints to `synapse.appservice`. ([\#11360](https://github.com/matrix-org/synapse/issues/11360))
+- Add missing type hints to `synapse.config` module. ([\#11480](https://github.com/matrix-org/synapse/issues/11480))
+- Add test to ensure we share the same `state_group` across the whole historical batch when using the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint. ([\#11487](https://github.com/matrix-org/synapse/issues/11487))
+- Refactor `tests.util.setup_test_homeserver` and `tests.server.setup_test_homeserver`. ([\#11503](https://github.com/matrix-org/synapse/issues/11503))
+- Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common`. ([\#11505](https://github.com/matrix-org/synapse/issues/11505), [\#11687](https://github.com/matrix-org/synapse/issues/11687))
+- Use `HTTPStatus` constants in place of literals in `tests.rest.client.test_auth`. ([\#11520](https://github.com/matrix-org/synapse/issues/11520))
+- Add a receipt types constant for `m.read`. ([\#11531](https://github.com/matrix-org/synapse/issues/11531))
+- Clean up `synapse.rest.admin`. ([\#11535](https://github.com/matrix-org/synapse/issues/11535))
+- Add missing `errcode` to `parse_string` and `parse_boolean`. ([\#11542](https://github.com/matrix-org/synapse/issues/11542))
+- Use `HTTPStatus` constants in place of literals in `synapse.http`. ([\#11543](https://github.com/matrix-org/synapse/issues/11543))
+- Add missing type hints to storage classes. ([\#11546](https://github.com/matrix-org/synapse/issues/11546), [\#11549](https://github.com/matrix-org/synapse/issues/11549), [\#11551](https://github.com/matrix-org/synapse/issues/11551), [\#11555](https://github.com/matrix-org/synapse/issues/11555), [\#11575](https://github.com/matrix-org/synapse/issues/11575), [\#11589](https://github.com/matrix-org/synapse/issues/11589), [\#11594](https://github.com/matrix-org/synapse/issues/11594), [\#11652](https://github.com/matrix-org/synapse/issues/11652), [\#11653](https://github.com/matrix-org/synapse/issues/11653), [\#11654](https://github.com/matrix-org/synapse/issues/11654), [\#11657](https://github.com/matrix-org/synapse/issues/11657))
+- Fix an inaccurate and misleading comment in the `/sync` code. ([\#11550](https://github.com/matrix-org/synapse/issues/11550))
+- Add missing type hints to `synapse.logging.context`. ([\#11556](https://github.com/matrix-org/synapse/issues/11556))
+- Stop populating unused database column `state_events.prev_state`. ([\#11558](https://github.com/matrix-org/synapse/issues/11558))
+- Minor efficiency improvements in event persistence. ([\#11560](https://github.com/matrix-org/synapse/issues/11560))
+- Add some safety checks that storage functions are used correctly. ([\#11564](https://github.com/matrix-org/synapse/issues/11564), [\#11580](https://github.com/matrix-org/synapse/issues/11580))
+- Make `get_device` return `None` if the device doesn't exist rather than raising an exception. ([\#11565](https://github.com/matrix-org/synapse/issues/11565))
+- Split the HTML parsing code from the URL preview resource code. ([\#11566](https://github.com/matrix-org/synapse/issues/11566))
+- Remove redundant `COALESCE()`s around `COUNT()`s in database queries. ([\#11570](https://github.com/matrix-org/synapse/issues/11570))
+- Add missing type hints to `synapse.http`. ([\#11571](https://github.com/matrix-org/synapse/issues/11571))
+- Add [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) and [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) to `/versions` -> `unstable_features` to detect server support. ([\#11582](https://github.com/matrix-org/synapse/issues/11582))
+- Add type hints to `synapse/tests/rest/admin`. ([\#11590](https://github.com/matrix-org/synapse/issues/11590))
+- Drop end-of-life Python 3.6 and Postgres 9.6 from CI. ([\#11595](https://github.com/matrix-org/synapse/issues/11595))
+- Update black version and run it on all the files. ([\#11596](https://github.com/matrix-org/synapse/issues/11596))
+- Add opentracing type stubs and fix associated mypy errors. ([\#11603](https://github.com/matrix-org/synapse/issues/11603), [\#11622](https://github.com/matrix-org/synapse/issues/11622))
+- Improve OpenTracing support for requests which use a `ResponseCache`. ([\#11607](https://github.com/matrix-org/synapse/issues/11607))
+- Improve OpenTracing support for incoming HTTP requests. ([\#11618](https://github.com/matrix-org/synapse/issues/11618))
+- A number of improvements to opentracing support. ([\#11619](https://github.com/matrix-org/synapse/issues/11619))
+- Refactor the way that the `outlier` flag is set on events received over federation. ([\#11634](https://github.com/matrix-org/synapse/issues/11634))
+- Improve the error messages from `get_create_event_for_room`. ([\#11638](https://github.com/matrix-org/synapse/issues/11638))
+- Remove redundant `get_current_events_token` method. ([\#11643](https://github.com/matrix-org/synapse/issues/11643))
+- Convert `namedtuples` to `attrs`. ([\#11665](https://github.com/matrix-org/synapse/issues/11665), [\#11574](https://github.com/matrix-org/synapse/issues/11574))
+- Update the `/capabilities` response to include whether support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) is available. ([\#11690](https://github.com/matrix-org/synapse/issues/11690))
+- Send the `Accept` header in HTTP requests made using `SimpleHttpClient.get_json`. ([\#11677](https://github.com/matrix-org/synapse/issues/11677))
+- Work around Mjolnir compatibility issue by adding an import for `glob_to_regex` in `synapse.util`, where it moved from. ([\#11696](https://github.com/matrix-org/synapse/issues/11696))
+
+
Synapse 1.49.2 (2021-12-21)
===========================
diff --git a/contrib/docker/docker-compose.yml b/contrib/docker/docker-compose.yml
index 26d640c44887..5ac41139e31d 100644
--- a/contrib/docker/docker-compose.yml
+++ b/contrib/docker/docker-compose.yml
@@ -14,6 +14,7 @@ services:
# failure
restart: unless-stopped
# See the readme for a full documentation of the environment settings
+    # NOTE: You must edit homeserver.yaml to use postgres; it defaults to sqlite
environment:
- SYNAPSE_CONFIG_PATH=/data/homeserver.yaml
volumes:
diff --git a/contrib/prometheus/consoles/synapse.html b/contrib/prometheus/consoles/synapse.html
index cd9ad15231fc..d17c8a08d9e3 100644
--- a/contrib/prometheus/consoles/synapse.html
+++ b/contrib/prometheus/consoles/synapse.html
@@ -92,22 +92,6 @@
Average reactor tick time
})
-
Pending calls per tick
-
-
-
Storage
Queries
diff --git a/debian/changelog b/debian/changelog
index ebe3e0cbf958..7b51a0b2d6fa 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,45 @@
+matrix-synapse-py3 (1.51.0) stable; urgency=medium
+
+ * New synapse release 1.51.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 25 Jan 2022 11:28:51 +0000
+
+matrix-synapse-py3 (1.51.0~rc2) stable; urgency=medium
+
+ * New synapse release 1.51.0~rc2.
+
+ -- Synapse Packaging team <packages@matrix.org>  Mon, 24 Jan 2022 12:25:00 +0000
+
+matrix-synapse-py3 (1.51.0~rc1) stable; urgency=medium
+
+ * New synapse release 1.51.0~rc1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Fri, 21 Jan 2022 10:46:02 +0000
+
+matrix-synapse-py3 (1.50.1) stable; urgency=medium
+
+ * New synapse release 1.50.1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 18 Jan 2022 16:06:26 +0000
+
+matrix-synapse-py3 (1.50.0) stable; urgency=medium
+
+ * New synapse release 1.50.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 18 Jan 2022 10:40:38 +0000
+
+matrix-synapse-py3 (1.50.0~rc2) stable; urgency=medium
+
+ * New synapse release 1.50.0~rc2.
+
+ -- Synapse Packaging team <packages@matrix.org>  Fri, 14 Jan 2022 11:18:06 +0000
+
+matrix-synapse-py3 (1.50.0~rc1) stable; urgency=medium
+
+ * New synapse release 1.50.0~rc1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Wed, 05 Jan 2022 12:36:17 +0000
+
matrix-synapse-py3 (1.49.2) stable; urgency=medium
* New synapse release 1.49.2.
diff --git a/demo/README b/demo/README
index 0bec820ad657..a5a95bd19666 100644
--- a/demo/README
+++ b/demo/README
@@ -22,5 +22,5 @@ Logs and sqlitedb will be stored in demo/808{0,1,2}.{log,db}
-Also note that when joining a public room on a differnt HS via "#foo:bar.net", then you are (in the current impl) joining a room with room_id "foo". This means that it won't work if your HS already has a room with that name.
+Also note that when joining a public room on a different HS via "#foo:bar.net", then you are (in the current impl) joining a room with room_id "foo". This means that it won't work if your HS already has a room with that name.
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 1665d2be7d50..546c87da42cd 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,14 +1,17 @@
# Dockerfile to build the matrixdotorg/synapse docker images.
#
+# Note that it uses features which are only available in BuildKit - see
+# https://docs.docker.com/go/buildkit/ for more information.
+#
# To build the image, run `docker build` command from the root of the
# synapse repository:
#
-# docker build -f docker/Dockerfile .
+# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile .
#
# There is an optional PYTHON_VERSION build argument which sets the
# version of python to build against: for example:
#
-# docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.6 .
+# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.9 .
#
ARG PYTHON_VERSION=3.8
@@ -19,7 +22,16 @@ ARG PYTHON_VERSION=3.8
FROM docker.io/python:${PYTHON_VERSION}-slim as builder
# install the OS build deps
-RUN apt-get update && apt-get install -y \
+#
+# RUN --mount is specific to buildkit and is documented at
+# https://github.com/moby/buildkit/blob/master/frontend/dockerfile/docs/syntax.md#build-mounts-run---mount.
+# Here we use it to set up a cache for apt, to improve rebuild speeds on
+# slow connections.
+#
+RUN \
+ --mount=type=cache,target=/var/cache/apt,sharing=locked \
+ --mount=type=cache,target=/var/lib/apt,sharing=locked \
+ apt-get update && apt-get install -y \
build-essential \
libffi-dev \
libjpeg-dev \
@@ -44,7 +56,8 @@ COPY synapse/python_dependencies.py /synapse/synapse/python_dependencies.py
# used while you develop on the source
#
# This is aiming at installing the `install_requires` and `extras_require` from `setup.py`
-RUN pip install --prefix="/install" --no-warn-script-location \
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install --prefix="/install" --no-warn-script-location \
/synapse[all]
# Copy over the rest of the project
@@ -66,7 +79,10 @@ LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/syna
LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git'
LABEL org.opencontainers.image.licenses='Apache-2.0'
-RUN apt-get update && apt-get install -y \
+RUN \
+ --mount=type=cache,target=/var/cache/apt,sharing=locked \
+ --mount=type=cache,target=/var/lib/apt,sharing=locked \
+ apt-get update && apt-get install -y \
curl \
gosu \
libjpeg62-turbo \
diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv
index 1dd88140c7a4..fbc1d2346fb8 100644
--- a/docker/Dockerfile-dhvirtualenv
+++ b/docker/Dockerfile-dhvirtualenv
@@ -16,7 +16,7 @@ ARG distro=""
### Stage 0: build a dh-virtualenv
###
-# This is only really needed on bionic and focal, since other distributions we
+# This is only really needed on focal, since other distributions we
# care about have a recent version of dh-virtualenv by default. Unfortunately,
# it looks like focal is going to be with us for a while.
#
@@ -36,9 +36,8 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \
wget
# fetch and unpack the package
-# TODO: Upgrade to 1.2.2 once bionic is dropped (1.2.2 requires debhelper 12; bionic has only 11)
RUN mkdir /dh-virtualenv
-RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz
+RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/refs/tags/1.2.2.tar.gz
RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz
# install its build deps. We do another apt-cache-update here, because we might
@@ -86,12 +85,12 @@ RUN apt-get update -qq -o Acquire::Languages=none \
libpq-dev \
xmlsec1
-COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb /
+COPY --from=builder /dh-virtualenv_1.2.2-1_all.deb /
# install dhvirtualenv. Update the apt cache again first, in case we got a
# cached cache from docker the first time.
RUN apt-get update -qq -o Acquire::Languages=none \
- && apt-get install -yq /dh-virtualenv_1.2~dev-1_all.deb
+ && apt-get install -yq /dh-virtualenv_1.2.2-1_all.deb
WORKDIR /synapse/source
ENTRYPOINT ["bash","/synapse/source/docker/build_debian.sh"]
diff --git a/docker/README.md b/docker/README.md
index 4349e71f87bb..67c3bc65f095 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -68,6 +68,10 @@ The following environment variables are supported in `generate` mode:
directories. If unset, and no user is set via `docker run --user`, defaults
to `991`, `991`.
+## Postgres
+
+By default the config will use SQLite. See the [docs on using Postgres](https://github.com/matrix-org/synapse/blob/develop/docs/postgres.md) for more information on how to use Postgres. Until this section is improved, [this issue](https://github.com/matrix-org/synapse/issues/8304) may provide useful information.
+
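+As a rough sketch (the image name, credentials, and Postgres version below are
+examples rather than Synapse defaults), a suitable Postgres instance could be
+started alongside the Synapse container with:
+
+```sh
+docker run -d --name synapse-postgres \
+    -e POSTGRES_USER=synapse \
+    -e POSTGRES_PASSWORD=changeme \
+    -e POSTGRES_INITDB_ARGS="--encoding=UTF8 --lc-collate=C --lc-ctype=C" \
+    postgres:14
+```
+
+The `POSTGRES_INITDB_ARGS` above ensure the database is created with the `C` locale,
+which Synapse requires.
+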
## Running synapse
Once you have a valid configuration file, you can start synapse as follows:
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md
index b05af6d69051..11f597b3edb8 100644
--- a/docs/SUMMARY.md
+++ b/docs/SUMMARY.md
@@ -30,6 +30,7 @@
- [SSO Mapping Providers](sso_mapping_providers.md)
- [Password Auth Providers](password_auth_providers.md)
- [JSON Web Tokens](jwt.md)
+ - [Refresh Tokens](usage/configuration/user_authentication/refresh_tokens.md)
- [Registration Captcha](CAPTCHA_SETUP.md)
- [Application Services](application_services.md)
- [Server Notices](server_notices.md)
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index ba574d795fcc..c514cadb9dae 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -15,9 +15,10 @@ server admin: [Admin API](../usage/administration/admin_api)
It returns a JSON body like the following:
-```json
+```jsonc
{
- "displayname": "User",
+ "name": "@user:example.com",
+ "displayname": "User", // can be null if not set
"threepids": [
{
"medium": "email",
@@ -32,11 +33,11 @@ It returns a JSON body like the following:
"validated_at": 1586458409743
}
],
- "avatar_url": "",
+ "avatar_url": "", // can be null if not set
+ "is_guest": 0,
"admin": 0,
"deactivated": 0,
"shadow_banned": 0,
- "password_hash": "$2b$12$p9B4GkqYdRTPGD",
"creation_ts": 1560432506,
"appservice_id": null,
"consent_server_notice_sent": null,
@@ -480,6 +481,81 @@ The following fields are returned in the JSON response body:
- `joined_rooms` - An array of `room_id`.
- `total` - Number of rooms.
+## Account Data
+Gets information about account data for a specific `user_id`.
+
+The API is:
+
+```
+GET /_synapse/admin/v1/users/<user_id>/accountdata
+```
+
+A response body like the following is returned:
+
+```json
+{
+ "account_data": {
+ "global": {
+ "m.secret_storage.key.LmIGHTg5W": {
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+ "iv": "fwjNZatxg==",
+ "mac": "eWh9kNnLWZUNOgnc="
+ },
+ "im.vector.hide_profile": {
+ "hide_profile": true
+ },
+ "org.matrix.preview_urls": {
+ "disable": false
+ },
+ "im.vector.riot.breadcrumb_rooms": {
+ "rooms": [
+ "!LxcBDAsDUVAfJDEo:matrix.org",
+ "!MAhRxqasbItjOqxu:matrix.org"
+ ]
+ },
+ "m.accepted_terms": {
+ "accepted": [
+ "https://example.org/somewhere/privacy-1.2-en.html",
+ "https://example.org/somewhere/terms-2.0-en.html"
+ ]
+ },
+ "im.vector.setting.breadcrumbs": {
+ "recent_rooms": [
+ "!MAhRxqasbItqxuEt:matrix.org",
+ "!ZtSaPCawyWtxiImy:matrix.org"
+ ]
+ }
+ },
+ "rooms": {
+ "!GUdfZSHUJibpiVqHYd:matrix.org": {
+ "m.fully_read": {
+ "event_id": "$156334540fYIhZ:matrix.org"
+ }
+ },
+ "!tOZwOOiqwCYQkLhV:matrix.org": {
+ "m.fully_read": {
+ "event_id": "$xjsIyp4_NaVl2yPvIZs_k1Jl8tsC_Sp23wjqXPno"
+ }
+ }
+ }
+ }
+}
+```
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- `user_id` - fully qualified: for example, `@user:server.com`.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- `account_data` - A map containing the account data for the user
+ - `global` - A map containing the global account data for the user
+ - `rooms` - A map containing the account data per room for the user
+
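+For example, the account data for a user could be fetched with `curl` as sketched
+below (the server URL, user ID and admin access token are placeholders; as with
+other admin APIs, the request must be authenticated with a server admin's access
+token):
+
+```sh
+curl --header "Authorization: Bearer <admin_access_token>" \
+    "http://localhost:8008/_synapse/admin/v1/users/%40user%3Aserver.com/accountdata"
+```
+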
## User media
### List media uploaded by a user
diff --git a/docs/development/contributing_guide.md b/docs/development/contributing_guide.md
index abdb8084382b..c14298169351 100644
--- a/docs/development/contributing_guide.md
+++ b/docs/development/contributing_guide.md
@@ -20,7 +20,9 @@ recommended for development. More information about WSL can be found at
. Running Synapse natively
on Windows is not officially supported.
-The code of Synapse is written in Python 3. To do pretty much anything, you'll need [a recent version of Python 3](https://wiki.python.org/moin/BeginnersGuide/Download).
+The code of Synapse is written in Python 3. To do pretty much anything, you'll need [a recent version of Python 3](https://www.python.org/downloads/). Your Python also needs support for [virtual environments](https://docs.python.org/3/library/venv.html). This is usually built-in, but some Linux distributions like Debian and Ubuntu split it out into its own package. Running `sudo apt install python3-venv` should be enough.
+
+Synapse can connect to PostgreSQL via the [psycopg2](https://pypi.org/project/psycopg2/) Python library. Building this library from source requires access to PostgreSQL's C header files. On Debian or Ubuntu Linux, these can be installed with `sudo apt install libpq-dev`.
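+
+On Debian-based systems, both of these dependencies can be installed in one go (a
+convenience sketch of the two commands above):
+
+```sh
+sudo apt install python3-venv libpq-dev
+```
+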
The source code of Synapse is hosted on GitHub. You will also need [a recent version of git](https://github.com/git-guides/install-git).
@@ -169,6 +171,27 @@ To increase the log level for the tests, set `SYNAPSE_TEST_LOG_LEVEL`:
SYNAPSE_TEST_LOG_LEVEL=DEBUG trial tests
```
+By default, tests will use an in-memory SQLite database for test data. For additional
+help with debugging, one can use an on-disk SQLite database file instead, in order to
+review database state during and after running tests. This can be done by setting
+the `SYNAPSE_TEST_PERSIST_SQLITE_DB` environment variable. Doing so will cause the
+database state to be stored in a file named `test.db` under the trial process'
+working directory. Typically, this ends up being `_trial_temp/test.db`. For example:
+
+```sh
+SYNAPSE_TEST_PERSIST_SQLITE_DB=1 trial tests
+```
+
+The database file can then be inspected with:
+
+```sh
+sqlite3 _trial_temp/test.db
+```
+
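+As a quick sanity check (a hypothetical example rather than part of the documented
+workflow), the tables populated by a test run can be listed non-interactively:
+
+```sh
+sqlite3 _trial_temp/test.db ".tables"
+```
+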
+Note that the database file is cleared at the beginning of each test run. Thus it
+will only ever contain the data generated by the *most recently run test*, though
+generally one is only running a single test when debugging anyway.
+
### Running tests under PostgreSQL
Invoking `trial` as above will use an in-memory SQLite database. This is great for
diff --git a/docs/development/url_previews.md b/docs/development/url_previews.md
index aff38136091d..154b9a5e12f4 100644
--- a/docs/development/url_previews.md
+++ b/docs/development/url_previews.md
@@ -35,7 +35,12 @@ When Synapse is asked to preview a URL it does the following:
5. If the media is HTML:
1. Decodes the HTML via the stored file.
2. Generates an Open Graph response from the HTML.
- 3. If an image exists in the Open Graph response:
+ 3. If a JSON oEmbed URL was found in the HTML via autodiscovery:
+ 1. Downloads the URL and stores it into a file via the media storage provider
+ and saves the local media metadata.
+      2. Converts the oEmbed response to an Open Graph response.
+      3. Overrides any Open Graph data from the HTML with data from oEmbed.
+ 4. If an image exists in the Open Graph response:
1. Downloads the URL and stores it into a file via the media storage
provider and saves the local media metadata.
2. Generates thumbnails.
diff --git a/docs/openid.md b/docs/openid.md
index ff9de9d5b8bf..171ea3b7128b 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -390,9 +390,6 @@ oidc_providers:
### Facebook
-Like Github, Facebook provide a custom OAuth2 API rather than an OIDC-compliant
-one so requires a little more configuration.
-
0. You will need a Facebook developer account. You can register for one
[here](https://developers.facebook.com/async/registration/).
1. On the [apps](https://developers.facebook.com/apps/) page of the developer
@@ -412,24 +409,28 @@ Synapse config:
idp_name: Facebook
idp_brand: "facebook" # optional: styling hint for clients
discover: false
- issuer: "https://facebook.com"
+ issuer: "https://www.facebook.com"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
scopes: ["openid", "email"]
- authorization_endpoint: https://facebook.com/dialog/oauth
- token_endpoint: https://graph.facebook.com/v9.0/oauth/access_token
- user_profile_method: "userinfo_endpoint"
- userinfo_endpoint: "https://graph.facebook.com/v9.0/me?fields=id,name,email,picture"
+ authorization_endpoint: "https://facebook.com/dialog/oauth"
+ token_endpoint: "https://graph.facebook.com/v9.0/oauth/access_token"
+ jwks_uri: "https://www.facebook.com/.well-known/oauth/openid/jwks/"
user_mapping_provider:
config:
- subject_claim: "id"
display_name_template: "{{ user.name }}"
+ email_template: "{{ '{{ user.email }}' }}"
```
Relevant documents:
- * https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow
- * Using Facebook's Graph API: https://developers.facebook.com/docs/graph-api/using-graph-api/
- * Reference to the User endpoint: https://developers.facebook.com/docs/graph-api/reference/user
+ * [Manually Build a Login Flow](https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow)
+ * [Using Facebook's Graph API](https://developers.facebook.com/docs/graph-api/using-graph-api/)
+ * [Reference to the User endpoint](https://developers.facebook.com/docs/graph-api/reference/user)
+
+Facebook do have an [OIDC discovery endpoint](https://www.facebook.com/.well-known/openid-configuration),
+but it has a `response_types_supported` which excludes "code" (which we rely on, and
+is even mentioned in their [documentation](https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow#login)),
+so we have to disable discovery and configure the URIs manually.
### Gitea
diff --git a/docs/postgres.md b/docs/postgres.md
index e4861c1f127f..0562021da526 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -1,6 +1,6 @@
# Using Postgres
-Synapse supports PostgreSQL versions 9.6 or later.
+Synapse supports PostgreSQL versions 10 or later.
## Install postgres client libraries
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index f3b3aea732c7..1a89da50fd97 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -63,7 +63,7 @@ server {
server_name matrix.example.com;
- location ~* ^(\/_matrix|\/_synapse\/client) {
+ location ~ ^(/_matrix|/_synapse/client) {
# note: do not add a path (even a single /) after the port in `proxy_pass`,
# otherwise nginx will canonicalise the URI and cause signature verification
# errors.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 6696ed5d1ef9..1b86d0295d73 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -37,7 +37,7 @@
# Server admins can expand Synapse's functionality with external modules.
#
-# See https://matrix-org.github.io/synapse/latest/modules.html for more
+# See https://matrix-org.github.io/synapse/latest/modules/index.html for more
# documentation on how to configure or create custom modules for Synapse.
#
modules:
@@ -74,13 +74,7 @@ server_name: "SERVERNAME"
#
pid_file: DATADIR/homeserver.pid
-# The absolute URL to the web client which /_matrix/client will redirect
-# to if 'webclient' is configured under the 'listeners' configuration.
-#
-# This option can be also set to the filesystem path to the web client
-# which will be served at /_matrix/client/ if 'webclient' is configured
-# under the 'listeners' configuration, however this is a security risk:
-# https://github.com/matrix-org/synapse#security-note
+# The absolute URL to the web client which / will redirect to.
#
#web_client_location: https://riot.example.com/
@@ -164,7 +158,7 @@ presence:
# The default room version for newly created rooms.
#
# Known room versions are listed here:
-# https://matrix.org/docs/spec/#complete-list-of-room-versions
+# https://spec.matrix.org/latest/rooms/#complete-list-of-room-versions
#
# For example, for room version 1, default_room_version should be set
# to "1".
@@ -310,8 +304,6 @@ presence:
# static: static resources under synapse/static (/_matrix/static). (Mostly
# useful for 'fallback authentication'.)
#
-# webclient: A web client. Requires web_client_location to be set.
-#
listeners:
# TLS-enabled listener: for when matrix traffic is sent directly to synapse.
#
@@ -1488,6 +1480,7 @@ room_prejoin_state:
# - m.room.encryption
# - m.room.name
# - m.room.create
+ # - m.room.topic
#
# Uncomment the following to disable these defaults (so that only the event
# types listed in 'additional_event_types' are shared). Defaults to 'false'.
@@ -1502,6 +1495,21 @@ room_prejoin_state:
#additional_event_types:
# - org.example.custom.event.type
+# We record the IP address of clients used to access the API for various
+# reasons, including displaying it to the user in the "Where you're signed in"
+# dialog.
+#
+# By default, when puppeting another user via the admin API, the client IP
+# address is recorded against the user who created the access token (ie, the
+# admin user), and *not* the puppeted user.
+#
+# Uncomment the following to also record the IP address against the puppeted
+# user. (This also means that the puppeted user will count as an "active" user
+# for the purpose of monthly active user tracking - see 'limit_usage_by_mau' etc
+# above.)
+#
+#track_puppeted_user_ips: true
+
# A list of application service config files to use
#
@@ -1869,10 +1877,13 @@ saml2_config:
# Defaults to false. Avoid this in production.
#
# user_profile_method: Whether to fetch the user profile from the userinfo
-# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+# endpoint, or to rely on the data returned in the id_token from the
+# token_endpoint.
+#
+# Valid values are: 'auto' or 'userinfo_endpoint'.
#
-# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
-# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+# Defaults to 'auto', which uses the userinfo endpoint if 'openid' is
+# not included in 'scopes'. Set to 'userinfo_endpoint' to always use the
# userinfo endpoint.
#
# allow_existing_users: set to 'true' to allow a user logging in via OIDC to
diff --git a/docs/setup/installation.md b/docs/setup/installation.md
index 16562be95388..fe657a15dfa6 100644
--- a/docs/setup/installation.md
+++ b/docs/setup/installation.md
@@ -164,7 +164,7 @@ xbps-install -S synapse
Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
- Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
-- Packages: `pkg install py37-matrix-synapse`
+- Packages: `pkg install py38-matrix-synapse`
#### OpenBSD
@@ -194,7 +194,7 @@ When following this route please make sure that the [Platform-specific prerequis
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
-- Python 3.6 or later, up to Python 3.9.
+- Python 3.7 or later, up to Python 3.9.
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
To install the Synapse homeserver run:
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index 7a407012e0b1..7b4ddc5b7423 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -49,12 +49,12 @@ comment these options out and use those specified by the module instead.
A custom mapping provider must specify the following methods:
-* `__init__(self, parsed_config)`
+* `def __init__(self, parsed_config)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
-* `parse_config(config)`
+* `def parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
- `config` - A `dict` representing the parsed content of the
@@ -63,13 +63,13 @@ A custom mapping provider must specify the following methods:
any option values they need here.
- Whatever is returned will be passed back to the user mapping provider module's
`__init__` method during construction.
-* `get_remote_user_id(self, userinfo)`
+* `def get_remote_user_id(self, userinfo)`
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
- This method must return a string, which is the unique, immutable identifier
for the user. Commonly the `sub` claim of the response.
-* `map_user_attributes(self, userinfo, token, failures)`
+* `async def map_user_attributes(self, userinfo, token, failures)`
- This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
@@ -91,7 +91,7 @@ A custom mapping provider must specify the following methods:
during a user's first login. Once a localpart has been associated with a
remote user ID (see `get_remote_user_id`) it cannot be updated.
- `displayname`: An optional string, the display name for the user.
-* `get_extra_attributes(self, userinfo, token)`
+* `async def get_extra_attributes(self, userinfo, token)`
- This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
@@ -125,15 +125,15 @@ comment these options out and use those specified by the module instead.
A custom mapping provider must specify the following methods:
-* `__init__(self, parsed_config, module_api)`
+* `def __init__(self, parsed_config, module_api)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
- `module_api` - a `synapse.module_api.ModuleApi` object which provides the
stable API available for extension modules.
-* `parse_config(config)`
- - This method should have the `@staticmethod` decoration.
+* `def parse_config(config)`
+ - **This method should have the `@staticmethod` decoration.**
- Arguments:
- `config` - A `dict` representing the parsed content of the
`saml_config.user_mapping_provider.config` homeserver config option.
@@ -141,15 +141,15 @@ A custom mapping provider must specify the following methods:
any option values they need here.
- Whatever is returned will be passed back to the user mapping provider module's
`__init__` method during construction.
-* `get_saml_attributes(config)`
- - This method should have the `@staticmethod` decoration.
+* `def get_saml_attributes(config)`
+ - **This method should have the `@staticmethod` decoration.**
- Arguments:
- `config` - A object resulting from a call to `parse_config`.
- Returns a tuple of two sets. The first set equates to the SAML auth
response attributes that are required for the module to function, whereas
the second set consists of those attributes which can be used if available,
but are not necessary.
-* `get_remote_user_id(self, saml_response, client_redirect_url)`
+* `def get_remote_user_id(self, saml_response, client_redirect_url)`
- Arguments:
- `saml_response` - A `saml2.response.AuthnResponse` object to extract user
information from.
@@ -157,7 +157,7 @@ A custom mapping provider must specify the following methods:
redirected to.
- This method must return a string, which is the unique, immutable identifier
for the user. Commonly the `uid` claim of the response.
-* `saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
+* `def saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
- Arguments:
- `saml_response` - A `saml2.response.AuthnResponse` object to extract user
information from.
diff --git a/docs/turn-howto.md b/docs/turn-howto.md
index e6812de69e6b..eba7ca6124a5 100644
--- a/docs/turn-howto.md
+++ b/docs/turn-howto.md
@@ -15,8 +15,8 @@ The following sections describe how to install [coturn](TURN->TURN->client flows work
+ # this should be one of the turn server's listening IPs
allowed-peer-ip=10.0.0.1
# consider whether you want to limit the quota of relayed streams per user (or total) to avoid risk of DoS.
@@ -121,34 +137,58 @@ This will install and start a systemd service called `coturn`.
# TLS private key file
pkey=/path/to/privkey.pem
+
+ # Ensure the configuration lines that disable TLS/DTLS are commented-out or removed
+ #no-tls
+ #no-dtls
```
- In this case, replace the `turn:` schemes in the `turn_uri` settings below
+ In this case, replace the `turn:` schemes in the `turn_uris` settings below
with `turns:`.
We recommend that you only try to set up TLS/DTLS once you have set up a
basic installation and got it working.
+ NB: If your TLS certificate was provided by Let's Encrypt, TLS/DTLS will
+ not work with any Matrix client that uses Chromium's WebRTC library. This
+ currently includes Element Android & iOS; for more details, see their
+ [respective](https://github.com/vector-im/element-android/issues/1533)
+ [issues](https://github.com/vector-im/element-ios/issues/2712) as well as the underlying
+ [WebRTC issue](https://bugs.chromium.org/p/webrtc/issues/detail?id=11710).
+ Consider using a ZeroSSL certificate for your TURN server as a working alternative.
+
1. Ensure your firewall allows traffic into the TURN server on the ports
you've configured it to listen on (By default: 3478 and 5349 for TURN
traffic (remember to allow both TCP and UDP traffic), and ports 49152-65535
for the UDP relay.)
-1. We do not recommend running a TURN server behind NAT, and are not aware of
- anyone doing so successfully.
+1. If your TURN server is behind NAT, the NAT gateway must have an external,
+ publicly-reachable IP address. You must configure coturn to advertise that
+ address to connecting clients:
- If you want to try it anyway, you will at least need to tell coturn its
- external IP address:
+ ```
+ external-ip=EXTERNAL_NAT_IPv4_ADDRESS
+ ```
+
+ You may optionally limit the TURN server to listen only on the local
+ address that is mapped by NAT to the external address:
```
- external-ip=192.88.99.1
+ listening-ip=INTERNAL_TURNSERVER_IPv4_ADDRESS
```
- ... and your NAT gateway must forward all of the relayed ports directly
- (eg, port 56789 on the external IP must be always be forwarded to port
- 56789 on the internal IP).
+ If your NAT gateway is reachable over both IPv4 and IPv6, you may
+ configure coturn to advertise each available address:
+
+ ```
+ external-ip=EXTERNAL_NAT_IPv4_ADDRESS
+ external-ip=EXTERNAL_NAT_IPv6_ADDRESS
+ ```
- If you get this working, let us know!
+ When advertising an external IPv6 address, ensure that the firewall and
+ network settings of the system running your TURN server are configured to
+ accept IPv6 traffic, and that the TURN server is listening on the local
+ IPv6 address that is mapped by NAT to the external IPv6 address.
1. (Re)start the turn server:
@@ -216,15 +256,16 @@ connecting". Unfortunately, troubleshooting this can be tricky.
Here are a few things to try:
- * Check that your TURN server is not behind NAT. As above, we're not aware of
- anyone who has successfully set this up.
-
* Check that you have opened your firewall to allow TCP and UDP traffic to the
TURN ports (normally 3478 and 5349).
* Check that you have opened your firewall to allow UDP traffic to the UDP
relay ports (49152-65535 by default).
+ * Try disabling `coturn`'s TLS/DTLS listeners and enable only its (unencrypted)
+ TCP/UDP listeners. (This will only leave signaling traffic unencrypted;
+ voice & video WebRTC traffic is always encrypted.)
+
* Some WebRTC implementations (notably, that of Google Chrome) appear to get
confused by TURN servers which are reachable over IPv6 (this appears to be
an unexpected side-effect of its handling of multiple IP addresses as
@@ -234,6 +275,18 @@ Here are a few things to try:
Try removing any AAAA records for your TURN server, so that it is only
reachable over IPv4.
+ * If your TURN server is behind NAT:
+
+ * double-check that your NAT gateway is correctly forwarding all TURN
+ ports (normally 3478 & 5349 for TCP & UDP TURN traffic, and 49152-65535 for the UDP
+ relay) to the NAT-internal address of your TURN server. If advertising
+ both IPv4 and IPv6 external addresses via the `external-ip` option, ensure
+ that the NAT is forwarding both IPv4 and IPv6 traffic to the IPv4 and IPv6
+ internal addresses of your TURN server. When in doubt, remove AAAA records
+ for your TURN server and specify only an IPv4 address as your `external-ip`.
+
+ * ensure that your TURN server uses the NAT gateway as its default route.
+
* Enable more verbose logging in coturn via the `verbose` setting:
```
diff --git a/docs/upgrade.md b/docs/upgrade.md
index 136c806c417a..f455d257babf 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -85,6 +85,28 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
```
+# Upgrading to v1.51.0
+
+## Deprecation of `webclient` listeners and non-HTTP(S) `web_client_location`
+
+Listeners of type `webclient` are deprecated and scheduled to be removed in
+Synapse v1.53.0.
+
+Similarly, a non-HTTP(S) `web_client_location` configuration is deprecated and
+will become a configuration error in Synapse v1.53.0.
+
+
+# Upgrading to v1.50.0
+
+## Dropping support for old Python and Postgres versions
+
+In line with our [deprecation policy](deprecation_policy.md),
+we've dropped support for Python 3.6 and PostgreSQL 9.6, as they are no
+longer supported upstream.
+
+This release of Synapse requires Python 3.7+ and PostgreSQL 10+.
+
+
# Upgrading to v1.47.0
## Removal of old Room Admin API
diff --git a/docs/usage/configuration/user_authentication/refresh_tokens.md b/docs/usage/configuration/user_authentication/refresh_tokens.md
new file mode 100644
index 000000000000..23b3cddae054
--- /dev/null
+++ b/docs/usage/configuration/user_authentication/refresh_tokens.md
@@ -0,0 +1,139 @@
+# Refresh Tokens
+
+Synapse has supported refresh tokens since version 1.49 (some earlier versions had support for an earlier, experimental draft of [MSC2918], which is not compatible).
+
+
+[MSC2918]: https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens
+
+
+## Background and motivation
+
+Synapse users' sessions are identified by **access tokens**; access tokens are
+issued to users on login. Each session gets a unique access token which identifies
+it; the access token must be kept secret as it grants access to the user's account.
+
+Traditionally, these access tokens were eternally valid (at least until the user
+explicitly chose to log out).
+
+In some cases, it may be desirable for these access tokens to expire so that the
+potential damage caused by leaking an access token is reduced.
+On the other hand, forcing a user to re-authenticate (log in again) frequently
+might be too much of an inconvenience.
+
+**Refresh tokens** are a mechanism to avoid some of this inconvenience whilst
+still getting most of the benefits of short access token lifetimes.
+Refresh tokens are also a concept present in OAuth 2 — further reading is available
+[here](https://datatracker.ietf.org/doc/html/rfc6749#section-1.5).
+
+When refresh tokens are in use, both an access token and a refresh token will be
+issued to users on login. The access token will expire after a predetermined amount
+of time, but otherwise works in the same way as before. When the access token is
+close to expiring (or has expired), the user's client should present the homeserver
+(Synapse) with the refresh token.
+
+The homeserver will then generate a new access token and refresh token for the user
+and return them. The old refresh token is invalidated and cannot be used again*.
+
+Finally, refresh tokens also make it possible for sessions to be logged out if they
+are inactive for too long, before the session naturally ends; see the configuration
+guide below.
+
+
+*To prevent issues if clients lose connection halfway through refreshing a token,
+the refresh token is only invalidated once the new access token has been used at
+least once. For all intents and purposes, the above simplification is sufficient.
+
+
+## Caveats
+
+There are some caveats:
+
+* If a third party gets both your access token and refresh token, they will be able to
+ continue to enjoy access to your session.
+ * This is still an improvement because you (the user) will notice when *your*
+ session expires and you're not able to use your refresh token.
+ That would be a giveaway that someone else has compromised your session.
+ You would be able to log in again and terminate that session.
+ Previously (with long-lived access tokens), a third party that has your access
+ token could go undetected for a very long time.
+* Clients need to implement support for refresh tokens in order for them to be a
+ useful mechanism.
+ * It is up to homeserver administrators whether they want to issue long-lived access
+ tokens to clients not implementing refresh tokens.
+ * For compatibility, it is likely that they should, at least until client support
+ is widespread.
+ * Users with clients that support refresh tokens will still benefit from the
+ added security; it's not possible to downgrade a session to using long-lived
+ access tokens, so this effectively gives users the choice.
+ * In a closed environment where all users use known clients, this may not be
+ an issue, as the homeserver administrator can know whether the clients have refresh
+ token support. In that case, the non-refreshable access token lifetime
+ may be set to a short duration so that a similar level of security is provided.
+
+
+## Configuration Guide
+
+The following configuration options, in the `registration` section, are related:
+
+* `session_lifetime`: maximum length of a session, even if it's refreshed.
+ In other words, the client must log in again after this time period.
+ In most cases, this can be unset (infinite) or set to a long time (years or months).
+* `refreshable_access_token_lifetime`: lifetime of access tokens that are created
+ by clients supporting refresh tokens.
+ This should be short; a good value might be 5 minutes (`5m`).
+* `nonrefreshable_access_token_lifetime`: lifetime of access tokens that are created
+ by clients which don't support refresh tokens.
+ Make this short if you want to effectively force use of refresh tokens.
+ Make this long if you don't want to inconvenience users of clients which don't
+ support refresh tokens (by forcing them to frequently re-authenticate using
+ login credentials).
+* `refresh_token_lifetime`: lifetime of refresh tokens.
+ In other words, the client must refresh within this time period to maintain its session.
+ Unless you want to log inactive sessions out, it is often fine to use a long
+ value here or even leave it unset (infinite).
+ Beware that making it too short will inconvenience clients that do not connect
+ very often, including mobile clients and clients of infrequent users (by making
+ it more difficult for them to refresh in time, which may force them to
+ re-authenticate using login credentials).
+
+**Note:** All four options above only apply when tokens are created (by logging in or refreshing).
+Changes to these settings do not apply retroactively.
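+
+As a minimal illustrative sketch (the durations below are examples, not
+recommendations), these options might look like the following in
+`homeserver.yaml`:
+
+```
+# Access tokens issued to refresh-capable clients expire quickly, limiting the
+# damage from a leaked access token.
+refreshable_access_token_lifetime: 5m
+
+# Clients without refresh token support get longer-lived (but still finite)
+# access tokens, to avoid frequent re-authentication.
+nonrefreshable_access_token_lifetime: 24h
+
+# Sessions that do not refresh within a month are logged out; leave unset for
+# no limit.
+refresh_token_lifetime: 30d
+
+# Optionally cap the overall session length; when unset, sessions last as long
+# as they keep refreshing.
+#session_lifetime: 1y
+```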
+
+
+### Using refresh token expiry to log out inactive sessions
+
+If you'd like to force sessions to be logged out upon inactivity, you can enable
+refreshable access token expiry and refresh token expiry.
+
+This works because a client must refresh at least once within a period of
+`refresh_token_lifetime` in order to maintain valid credentials to access the
+account.
+
+(It's suggested that `refresh_token_lifetime` should be longer than
+`refreshable_access_token_lifetime`, and this section assumes that to be the case
+for simplicity.)
+
+Note: this will only affect sessions using refresh tokens. You may wish to
+set a short `nonrefreshable_access_token_lifetime` to prevent this from being bypassed
+by clients that do not support refresh tokens.
+
+
+#### Choosing values that guarantee permitting some inactivity
+
+It may be desirable to permit some short periods of inactivity, for example to
+accommodate brief outages in client connectivity.
+
+The following model aims to provide guidance for choosing `refresh_token_lifetime`
+and `refreshable_access_token_lifetime` to satisfy requirements of the form:
+
+1. inactivity longer than `L` **MUST** cause the session to be logged out; and
+2. inactivity shorter than `S` **MUST NOT** cause the session to be logged out.
+
+This model makes the weakest assumption: that all active clients will refresh as
+needed to maintain a valid access token, but no sooner.
+*In reality, clients may refresh more often than this model assumes, but the
+above requirements will still hold.*
+
+To satisfy the above requirements under this model:
+* `refresh_token_lifetime` should be set to `L`; and
+* `refreshable_access_token_lifetime` should be set to `L - S`.
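+
+As a worked example (with illustrative numbers): suppose inactivity of more than
+28 days must end the session (`L = 28d`), while up to 7 days of inactivity must
+always be tolerated (`S = 7d`). A sketch of the corresponding settings:
+
+```
+# L: inactivity longer than this always logs the session out.
+refresh_token_lifetime: 28d
+
+# L - S: in the worst case the access token expires just as the client goes
+# offline, leaving 28d - 21d = 7d of refresh token validity, so any inactivity
+# shorter than 7 days can still be recovered from.
+refreshable_access_token_lifetime: 21d
+```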
diff --git a/mypy.ini b/mypy.ini
index 1caf807e8505..85fa22d28f2b 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -25,14 +25,9 @@ exclude = (?x)
^(
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/__init__.py
- |synapse/storage/databases/main/account_data.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
- |synapse/storage/databases/main/e2e_room_keys.py
- |synapse/storage/databases/main/end_to_end_keys.py
|synapse/storage/databases/main/event_federation.py
- |synapse/storage/databases/main/event_push_actions.py
- |synapse/storage/databases/main/events_bg_updates.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
@@ -40,12 +35,9 @@ exclude = (?x)
|synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
- |synapse/storage/databases/main/room.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
- |synapse/storage/databases/main/stats.py
- |synapse/storage/databases/main/transactions.py
|synapse/storage/databases/main/user_directory.py
|synapse/storage/schema/
@@ -107,7 +99,6 @@ exclude = (?x)
|tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py
|tests/state/test_v2.py
- |tests/storage/test_account_data.py
|tests/storage/test_background_update.py
|tests/storage/test_base.py
|tests/storage/test_client_ips.py
@@ -145,6 +136,9 @@ disallow_untyped_defs = True
[mypy-synapse.app.*]
disallow_untyped_defs = True
+[mypy-synapse.appservice.*]
+disallow_untyped_defs = True
+
[mypy-synapse.config._base]
disallow_untyped_defs = True
@@ -163,6 +157,12 @@ disallow_untyped_defs = False
[mypy-synapse.handlers.*]
disallow_untyped_defs = True
+[mypy-synapse.http.server]
+disallow_untyped_defs = True
+
+[mypy-synapse.logging.context]
+disallow_untyped_defs = True
+
[mypy-synapse.metrics.*]
disallow_untyped_defs = True
@@ -181,24 +181,48 @@ disallow_untyped_defs = True
[mypy-synapse.state.*]
disallow_untyped_defs = True
+[mypy-synapse.storage.databases.main.account_data]
+disallow_untyped_defs = True
+
[mypy-synapse.storage.databases.main.client_ips]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.directory]
disallow_untyped_defs = True
+[mypy-synapse.storage.databases.main.e2e_room_keys]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.databases.main.end_to_end_keys]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.databases.main.event_push_actions]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.databases.main.events_bg_updates]
+disallow_untyped_defs = True
+
[mypy-synapse.storage.databases.main.events_worker]
disallow_untyped_defs = True
+[mypy-synapse.storage.databases.main.room]
+disallow_untyped_defs = True
+
[mypy-synapse.storage.databases.main.room_batch]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.profile]
disallow_untyped_defs = True
+[mypy-synapse.storage.databases.main.stats]
+disallow_untyped_defs = True
+
[mypy-synapse.storage.databases.main.state_deltas]
disallow_untyped_defs = True
+[mypy-synapse.storage.databases.main.transactions]
+disallow_untyped_defs = True
+
[mypy-synapse.storage.databases.main.user_erasure_store]
disallow_untyped_defs = True
@@ -223,6 +247,9 @@ disallow_untyped_defs = True
[mypy-tests.storage.test_user_directory]
disallow_untyped_defs = True
+[mypy-tests.rest.admin.*]
+disallow_untyped_defs = True
+
[mypy-tests.rest.client.test_directory]
disallow_untyped_defs = True
@@ -286,9 +313,6 @@ ignore_missing_imports = True
[mypy-netaddr]
ignore_missing_imports = True
-[mypy-opentracing]
-ignore_missing_imports = True
-
[mypy-parameterized.*]
ignore_missing_imports = True
diff --git a/pyproject.toml b/pyproject.toml
index 8bca1fa4efd9..963f149c6af6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@
showcontent = true
[tool.black]
-target-version = ['py36']
+target-version = ['py37', 'py38', 'py39', 'py310']
exclude = '''
(
diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages
index 3a9a2d257c6f..4d34e9070363 100755
--- a/scripts-dev/build_debian_packages
+++ b/scripts-dev/build_debian_packages
@@ -24,7 +24,6 @@ DISTS = (
"debian:bullseye",
"debian:bookworm",
"debian:sid",
- "ubuntu:bionic", # 18.04 LTS (our EOL forced by Py36 on 2021-12-23)
"ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14)
"ubuntu:hirsute", # 21.04 (EOL 2022-01-05)
"ubuntu:impish", # 21.10 (EOL 2022-07)
diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment
index af4de345df57..c764011d6adb 100755
--- a/scripts-dev/check-newsfragment
+++ b/scripts-dev/check-newsfragment
@@ -42,8 +42,8 @@ echo "--------------------------"
echo
matched=0
-for f in $(git diff --name-only FETCH_HEAD... -- changelog.d); do
- # check that any modified newsfiles on this branch end with a full stop.
+for f in $(git diff --diff-filter=d --name-only FETCH_HEAD... -- changelog.d); do
+ # check that any added newsfiles on this branch end with a full stop.
lastchar=$(tr -d '\n' < "$f" | tail -c 1)
if [ "$lastchar" != '.' ] && [ "$lastchar" != '!' ]; then
echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2
diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh
index 53295b58fca9..e08ffedaf33a 100755
--- a/scripts-dev/complement.sh
+++ b/scripts-dev/complement.sh
@@ -8,7 +8,8 @@
# By default the script will fetch the latest Complement master branch and
# run tests with that. This can be overridden to use a custom Complement
# checkout by setting the COMPLEMENT_DIR environment variable to the
-# filepath of a local Complement checkout.
+# filepath of a local Complement checkout or by setting the COMPLEMENT_REF
+# environment variable to pull a different branch or commit.
#
# By default Synapse is run in monolith mode. This can be overridden by
# setting the WORKERS environment variable.
@@ -23,16 +24,20 @@
# Exit if a line returns a non-zero exit code
set -e
+# enable buildkit for the docker builds
+export DOCKER_BUILDKIT=1
+
# Change to the repository root
cd "$(dirname $0)/.."
# Check for a user-specified Complement checkout
if [[ -z "$COMPLEMENT_DIR" ]]; then
- echo "COMPLEMENT_DIR not set. Fetching the latest Complement checkout..."
- wget -Nq https://github.com/matrix-org/complement/archive/master.tar.gz
- tar -xzf master.tar.gz
- COMPLEMENT_DIR=complement-master
- echo "Checkout available at 'complement-master'"
+ COMPLEMENT_REF=${COMPLEMENT_REF:-master}
+ echo "COMPLEMENT_DIR not set. Fetching Complement checkout from ${COMPLEMENT_REF}..."
+ wget -Nq https://github.com/matrix-org/complement/archive/${COMPLEMENT_REF}.tar.gz
+ tar -xzf ${COMPLEMENT_REF}.tar.gz
+ COMPLEMENT_DIR=complement-${COMPLEMENT_REF}
+ echo "Checkout available at 'complement-${COMPLEMENT_REF}'"
fi
# Build the base Synapse image from the local checkout
@@ -47,7 +52,7 @@ if [[ -n "$WORKERS" ]]; then
COMPLEMENT_DOCKERFILE=SynapseWorkers.Dockerfile
# And provide some more configuration to complement.
export COMPLEMENT_CA=true
- export COMPLEMENT_VERSION_CHECK_ITERATIONS=500
+ export COMPLEMENT_SPAWN_HS_TIMEOUT_SECS=25
else
export COMPLEMENT_BASE_IMAGE=complement-synapse
COMPLEMENT_DOCKERFILE=Synapse.Dockerfile
@@ -65,4 +70,5 @@ if [[ -n "$1" ]]; then
fi
# Run the tests!
+echo "Images built; running complement"
go test -v -tags synapse_blacklist,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/...
diff --git a/setup.py b/setup.py
index 2c6fb9aacb45..e618ff898b46 100755
--- a/setup.py
+++ b/setup.py
@@ -96,7 +96,7 @@ def exec_file(path_segments):
# We pin black so that our tests don't start failing on new releases.
CONDITIONAL_REQUIREMENTS["lint"] = [
"isort==5.7.0",
- "black==21.6b0",
+ "black==21.12b0",
"flake8-comprehensions",
"flake8-bugbear==21.3.2",
"flake8",
@@ -107,6 +107,7 @@ def exec_file(path_segments):
"mypy-zope==0.3.2",
"types-bleach>=4.1.0",
"types-jsonschema>=3.2.0",
+ "types-opentracing>=2.4.2",
"types-Pillow>=8.3.4",
"types-pyOpenSSL>=20.0.7",
"types-PyYAML>=5.4.10",
@@ -119,9 +120,7 @@ def exec_file(path_segments):
# Tests assume that all optional dependencies are installed.
#
# parameterized_class decorator was introduced in parameterized 0.7.0
-#
-# We use `mock` library as that backports `AsyncMock` to Python 3.6
-CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0", "mock>=4.0.0"]
+CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"]
CONDITIONAL_REQUIREMENTS["dev"] = (
CONDITIONAL_REQUIREMENTS["lint"]
@@ -163,7 +162,6 @@ def exec_file(path_segments):
"Topic :: Communications :: Chat",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3 :: Only",
- "Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 4ff3c6de5feb..429234d7ae7f 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -17,11 +17,12 @@
from typing import Any, List, Optional, Type, Union
from twisted.internet import protocol
+from twisted.internet.defer import Deferred
class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ...
- async def ping(self) -> None: ...
- async def set(
+ def ping(self) -> "Deferred[None]": ...
+ def set(
self,
key: str,
value: Any,
@@ -29,8 +30,8 @@ class RedisProtocol(protocol.Protocol):
pexpire: Optional[int] = None,
only_if_not_exists: bool = False,
only_if_exists: bool = False,
- ) -> None: ...
- async def get(self, key: str) -> Any: ...
+ ) -> "Deferred[None]": ...
+ def get(self, key: str) -> "Deferred[Any]": ...
class SubscriberProtocol(RedisProtocol):
def __init__(self, *args, **kwargs): ...
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 95a49c20befc..26bdfec33ae3 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@
except ImportError:
pass
-__version__ = "1.49.2"
+__version__ = "1.51.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py
index 093af4327ae0..e207f154f3aa 100644
--- a/synapse/_scripts/review_recent_signups.py
+++ b/synapse/_scripts/review_recent_signups.py
@@ -46,7 +46,9 @@ class UserInfo:
ips: List[str] = attr.Factory(list)
-def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]:
+def get_recent_users(
+ txn: LoggingTransaction, since_ms: int, exclude_app_service: bool
+) -> List[UserInfo]:
"""Fetches recently registered users and some info on them."""
sql = """
@@ -56,6 +58,9 @@ def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]:
AND deactivated = 0
"""
+ if exclude_app_service:
+ sql += " AND appservice_id IS NULL"
+
txn.execute(sql, (since_ms / 1000,))
user_infos = [UserInfo(user_id, creation_ts) for user_id, creation_ts in txn]
@@ -113,7 +118,7 @@ def main() -> None:
"-e",
"--exclude-emails",
action="store_true",
- help="Exclude users that have validated email addresses",
+ help="Exclude users that have validated email addresses.",
)
parser.add_argument(
"-u",
@@ -121,6 +126,12 @@ def main() -> None:
action="store_true",
help="Only print user IDs that match.",
)
+ parser.add_argument(
+ "-a",
+ "--exclude-app-service",
+ help="Exclude appservice users.",
+ action="store_true",
+ )
config = ReviewConfig()
@@ -133,6 +144,7 @@ def main() -> None:
since_ms = time.time() * 1000 - Config.parse_duration(config_args.since)
exclude_users_with_email = config_args.exclude_emails
+ exclude_users_with_appservice = config_args.exclude_app_service
include_context = not config_args.only_users
for database_config in config.database.databases:
@@ -143,7 +155,7 @@ def main() -> None:
with make_conn(database_config, engine, "review_recent_signups") as db_conn:
# This generates a type of Cursor, not LoggingTransaction.
- user_infos = get_recent_users(db_conn.cursor(), since_ms) # type: ignore[arg-type]
+ user_infos = get_recent_users(db_conn.cursor(), since_ms, exclude_users_with_appservice) # type: ignore[arg-type]
for user_info in user_infos:
if exclude_users_with_email and user_info.emails:
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 44883c6663ff..683241201ca6 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -32,7 +32,7 @@
from synapse.events import EventBase
from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
-from synapse.logging import opentracing as opentracing
+from synapse.logging.opentracing import active_span, force_tracing, start_active_span
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester, StateMap, UserID, create_requester
from synapse.util.caches.lrucache import LruCache
@@ -71,6 +71,7 @@ def __init__(self, hs: "HomeServer"):
self._auth_blocking = AuthBlocking(self.hs)
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
+ self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips
self._macaroon_secret_key = hs.config.key.macaroon_secret_key
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
@@ -149,13 +150,53 @@ async def get_user_by_req(
is invalid.
AuthError if access is denied for the user in the access token
"""
+ parent_span = active_span()
+ with start_active_span("get_user_by_req"):
+ requester = await self._wrapped_get_user_by_req(
+ request, allow_guest, rights, allow_expired
+ )
+
+ if parent_span:
+ if requester.authenticated_entity in self._force_tracing_for_users:
+ # request tracing is enabled for this user, so we need to force it
+ # tracing on for the parent span (which will be the servlet span).
+ #
+ # It's too late for the get_user_by_req span to inherit the setting,
+ # so we also force it on for that.
+ force_tracing()
+ force_tracing(parent_span)
+ parent_span.set_tag(
+ "authenticated_entity", requester.authenticated_entity
+ )
+ parent_span.set_tag("user_id", requester.user.to_string())
+ if requester.device_id is not None:
+ parent_span.set_tag("device_id", requester.device_id)
+ if requester.app_service is not None:
+ parent_span.set_tag("appservice_id", requester.app_service.id)
+ return requester
+
+ async def _wrapped_get_user_by_req(
+ self,
+ request: SynapseRequest,
+ allow_guest: bool,
+ rights: str,
+ allow_expired: bool,
+ ) -> Requester:
+ """Helper for get_user_by_req
+
+ Once get_user_by_req has set up the opentracing span, this does the actual work.
+ """
try:
ip_addr = request.getClientIP()
user_agent = get_request_user_agent(request)
access_token = self.get_access_token_from_request(request)
- user_id, app_service = await self._get_appservice_user_id(request)
+ (
+ user_id,
+ device_id,
+ app_service,
+ ) = await self._get_appservice_user_id_and_device_id(request)
if user_id and app_service:
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
@@ -163,18 +204,16 @@ async def get_user_by_req(
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
- device_id="dummy-device", # stubbed
+ device_id="dummy-device"
+ if device_id is None
+ else device_id, # stubbed
)
- requester = create_requester(user_id, app_service=app_service)
+ requester = create_requester(
+ user_id, app_service=app_service, device_id=device_id
+ )
request.requester = user_id
- if user_id in self._force_tracing_for_users:
- opentracing.force_tracing()
- opentracing.set_tag("authenticated_entity", user_id)
- opentracing.set_tag("user_id", user_id)
- opentracing.set_tag("appservice_id", app_service.id)
-
return requester
user_info = await self.get_user_by_access_token(
@@ -208,6 +247,18 @@ async def get_user_by_req(
user_agent=user_agent,
device_id=device_id,
)
+ # Also track the puppeted user's client IP if enabled and the requester is puppeting another user
+ if (
+ user_info.user_id != user_info.token_owner
+ and self._track_puppeted_user_ips
+ ):
+ await self.store.insert_client_ip(
+ user_id=user_info.user_id,
+ access_token=access_token,
+ ip=ip_addr,
+ user_agent=user_agent,
+ device_id=device_id,
+ )
if is_guest and not allow_guest:
raise AuthError(
@@ -232,13 +283,6 @@ async def get_user_by_req(
)
request.requester = requester
- if user_info.token_owner in self._force_tracing_for_users:
- opentracing.force_tracing()
- opentracing.set_tag("authenticated_entity", user_info.token_owner)
- opentracing.set_tag("user_id", user_info.user_id)
- if device_id:
- opentracing.set_tag("device_id", device_id)
-
return requester
except KeyError:
raise MissingClientTokenError()
@@ -274,33 +318,81 @@ async def validate_appservice_can_control_user_id(
403, "Application service has not registered this user (%s)" % user_id
)
- async def _get_appservice_user_id(
+ async def _get_appservice_user_id_and_device_id(
self, request: Request
- ) -> Tuple[Optional[str], Optional[ApplicationService]]:
+ ) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]:
+ """
+ Given a request, reads the request parameters to determine:
+ - whether it's an application service that's making this request
+ - what user the application service should be treated as controlling
+ (the user_id URI parameter allows an application service to masquerade
+ any applicable user in its namespace)
+ - what device the application service should be treated as controlling
+ (the device_id[^1] URI parameter allows an application service to masquerade
+ as any device that exists for the relevant user)
+
+ [^1] Unstable and provided by MSC3202.
+ Must use `org.matrix.msc3202.device_id` in place of `device_id` for now.
+
+ Returns:
+ 3-tuple of
+ (user ID?, device ID?, application service?)
+
+ Postconditions:
+ - If an application service is returned, so is a user ID
+ - A user ID is never returned without an application service
+ - A device ID is never returned without a user ID or an application service
+ - The returned application service, if present, is permitted to control the
+ returned user ID.
+ - The returned device ID, if present, has been checked to be a valid device ID
+ for the returned user ID.
+ """
+ DEVICE_ID_ARG_NAME = b"org.matrix.msc3202.device_id"
+
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
if app_service is None:
- return None, None
+ return None, None, None
if app_service.ip_range_whitelist:
ip_address = IPAddress(request.getClientIP())
if ip_address not in app_service.ip_range_whitelist:
- return None, None
+ return None, None, None
# This will always be set by the time Twisted calls us.
assert request.args is not None
- if b"user_id" not in request.args:
- return app_service.sender, app_service
+ if b"user_id" in request.args:
+ effective_user_id = request.args[b"user_id"][0].decode("utf8")
+ await self.validate_appservice_can_control_user_id(
+ app_service, effective_user_id
+ )
+ else:
+ effective_user_id = app_service.sender
- user_id = request.args[b"user_id"][0].decode("utf8")
- await self.validate_appservice_can_control_user_id(app_service, user_id)
+ effective_device_id: Optional[str] = None
- if app_service.sender == user_id:
- return app_service.sender, app_service
+ if (
+ self.hs.config.experimental.msc3202_device_masquerading_enabled
+ and DEVICE_ID_ARG_NAME in request.args
+ ):
+ effective_device_id = request.args[DEVICE_ID_ARG_NAME][0].decode("utf8")
+ # We only just set this so it can't be None!
+ assert effective_device_id is not None
+ device_opt = await self.store.get_device(
+ effective_user_id, effective_device_id
+ )
+ if device_opt is None:
+ # For now, use 400 M_EXCLUSIVE if the device doesn't exist.
+ # This is an open thread of discussion on MSC3202 as of 2021-12-09.
+ raise AuthError(
+ 400,
+ f"Application service trying to use a device that doesn't exist ('{effective_device_id}' for {effective_user_id})",
+ Codes.EXCLUSIVE,
+ )
- return user_id, app_service
+ return effective_user_id, effective_device_id, app_service
async def get_user_by_access_token(
self,
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index f7d29b431936..52c083a20b9c 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -253,5 +253,9 @@ class GuestAccess:
FORBIDDEN: Final = "forbidden"
+class ReceiptTypes:
+ READ: Final = "m.read"
+
+
class ReadReceiptEventFields:
MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 13dd6ce248e1..d087c816db57 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -351,8 +351,7 @@ def _check(self, event: FilterEvent) -> bool:
True if the event matches the filter.
"""
# We usually get the full "events" as dictionaries coming through,
- # except for presence which actually gets passed around as its own
- # namedtuple type.
+ # except for presence which actually gets passed around as its own type.
if isinstance(event, UserPresenceState):
user_id = event.user_id
field_matchers = {
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 0a895bba480a..a747a4081497 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -46,41 +46,41 @@ class RoomDisposition:
UNSTABLE = "unstable"
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomVersion:
"""An object which describes the unique attributes of a room version."""
- identifier = attr.ib(type=str) # the identifier for this version
- disposition = attr.ib(type=str) # one of the RoomDispositions
- event_format = attr.ib(type=int) # one of the EventFormatVersions
- state_res = attr.ib(type=int) # one of the StateResolutionVersions
- enforce_key_validity = attr.ib(type=bool)
+ identifier: str # the identifier for this version
+ disposition: str # one of the RoomDispositions
+ event_format: int # one of the EventFormatVersions
+ state_res: int # one of the StateResolutionVersions
+ enforce_key_validity: bool
# Before MSC2432, m.room.aliases had special auth rules and redaction rules
- special_case_aliases_auth = attr.ib(type=bool)
+ special_case_aliases_auth: bool
# Strictly enforce canonicaljson, do not allow:
# * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
# * Floats
# * NaN, Infinity, -Infinity
- strict_canonicaljson = attr.ib(type=bool)
+ strict_canonicaljson: bool
# MSC2209: Check 'notifications' key while verifying
# m.room.power_levels auth rules.
- limit_notifications_power_levels = attr.ib(type=bool)
+ limit_notifications_power_levels: bool
# MSC2174/MSC2176: Apply updated redaction rules algorithm.
- msc2176_redaction_rules = attr.ib(type=bool)
+ msc2176_redaction_rules: bool
# MSC3083: Support the 'restricted' join_rule.
- msc3083_join_rules = attr.ib(type=bool)
+ msc3083_join_rules: bool
# MSC3375: Support for the proper redaction rules for MSC3083. This mustn't
# be enabled if MSC3083 is not.
- msc3375_redaction_rules = attr.ib(type=bool)
+ msc3375_redaction_rules: bool
# MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
# m.room.membership event with membership 'knock'.
- msc2403_knocking = attr.ib(type=bool)
+ msc2403_knocking: bool
# MSC2716: Adds m.room.power_levels -> content.historical field to control
# whether "insertion", "chunk", "marker" events can be sent
- msc2716_historical = attr.ib(type=bool)
+ msc2716_historical: bool
# MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events
- msc2716_redactions = attr.ib(type=bool)
+ msc2716_redactions: bool
class RoomVersions:
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 5fc59c1be11d..579adbbca02d 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -60,7 +60,7 @@
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.logging.context import PreserveLoggingContext
-from synapse.metrics import register_threadpool
+from synapse.metrics import install_gc_manager, register_threadpool
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
from synapse.types import ISynapseReactor
@@ -159,6 +159,7 @@ def run() -> None:
change_resource_limit(soft_file_limit)
if gc_thresholds:
gc.set_threshold(*gc_thresholds)
+ install_gc_manager()
run_command()
# make sure that we run the reactor with the sentinel log context,
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index dd76e0732108..efedcc88894b 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -131,9 +131,18 @@ def _listener_http(
resources.update(self._module_web_resources)
self._module_web_resources_consumed = True
- # try to find something useful to redirect '/' to
- if WEB_CLIENT_PREFIX in resources:
- root_resource: Resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
+ # Try to find something useful to serve at '/':
+ #
+ # 1. Redirect to the web client if it is an HTTP(S) URL.
+ # 2. Redirect to the web client served via Synapse.
+ # 3. Redirect to the static "Synapse is running" page.
+ # 4. Do not redirect and use a blank resource.
+ if self.config.server.web_client_location_is_redirect:
+ root_resource: Resource = RootOptionsRedirectResource(
+ self.config.server.web_client_location
+ )
+ elif WEB_CLIENT_PREFIX in resources:
+ root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
elif STATIC_PREFIX in resources:
root_resource = RootOptionsRedirectResource(STATIC_PREFIX)
else:
@@ -262,15 +271,15 @@ def _configure_named_resource(
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "webclient":
+ # webclient listeners are deprecated as of Synapse v1.51.0; remove them
+ # in > v1.53.0.
webclient_loc = self.config.server.web_client_location
if webclient_loc is None:
logger.warning(
"Not enabling webclient resource, as web_client_location is unset."
)
- elif webclient_loc.startswith("http://") or webclient_loc.startswith(
- "https://"
- ):
+ elif self.config.server.web_client_location_is_redirect:
resources[WEB_CLIENT_PREFIX] = RootRedirect(webclient_loc)
else:
logger.warning(
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index f9d3bd337d3b..8c9ff93b2c13 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -11,10 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
import re
from enum import Enum
-from typing import TYPE_CHECKING, Iterable, List, Match, Optional
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern
+
+import attr
+from netaddr import IPSet
from synapse.api.constants import EventTypes
from synapse.events import EventBase
@@ -33,6 +37,13 @@ class ApplicationServiceState(Enum):
UP = "up"
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Namespace:
+ exclusive: bool
+ group_id: Optional[str]
+ regex: Pattern[str]
+
+
class ApplicationService:
"""Defines an application service. This definition is mostly what is
provided to the /register AS API.
@@ -50,17 +61,17 @@ class ApplicationService:
def __init__(
self,
- token,
- hostname,
- id,
- sender,
- url=None,
- namespaces=None,
- hs_token=None,
- protocols=None,
- rate_limited=True,
- ip_range_whitelist=None,
- supports_ephemeral=False,
+ token: str,
+ hostname: str,
+ id: str,
+ sender: str,
+ url: Optional[str] = None,
+ namespaces: Optional[JsonDict] = None,
+ hs_token: Optional[str] = None,
+ protocols: Optional[Iterable[str]] = None,
+ rate_limited: bool = True,
+ ip_range_whitelist: Optional[IPSet] = None,
+ supports_ephemeral: bool = False,
):
self.token = token
self.url = (
@@ -85,27 +96,33 @@ def __init__(
self.rate_limited = rate_limited
- def _check_namespaces(self, namespaces):
+ def _check_namespaces(
+ self, namespaces: Optional[JsonDict]
+ ) -> Dict[str, List[Namespace]]:
# Sanity check that it is of the form:
# {
# users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# }
- if not namespaces:
+ if namespaces is None:
namespaces = {}
+ result: Dict[str, List[Namespace]] = {}
+
for ns in ApplicationService.NS_LIST:
+ result[ns] = []
+
if ns not in namespaces:
- namespaces[ns] = []
continue
- if type(namespaces[ns]) != list:
+ if not isinstance(namespaces[ns], list):
raise ValueError("Bad namespace value for '%s'" % ns)
for regex_obj in namespaces[ns]:
if not isinstance(regex_obj, dict):
raise ValueError("Expected dict regex for ns '%s'" % ns)
- if not isinstance(regex_obj.get("exclusive"), bool):
+ exclusive = regex_obj.get("exclusive")
+ if not isinstance(exclusive, bool):
raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
group_id = regex_obj.get("group_id")
if group_id:
@@ -126,22 +143,26 @@ def _check_namespaces(self, namespaces):
)
regex = regex_obj.get("regex")
- if isinstance(regex, str):
- regex_obj["regex"] = re.compile(regex) # Pre-compile regex
- else:
+ if not isinstance(regex, str):
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
- return namespaces
- def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
- for regex_obj in self.namespaces[namespace_key]:
- if regex_obj["regex"].match(test_string):
- return regex_obj
+ # Pre-compile regex.
+ result[ns].append(Namespace(exclusive, group_id, re.compile(regex)))
+
+ return result
+
+ def _matches_regex(
+ self, namespace_key: str, test_string: str
+ ) -> Optional[Namespace]:
+ for namespace in self.namespaces[namespace_key]:
+ if namespace.regex.match(test_string):
+ return namespace
return None
- def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
- regex_obj = self._matches_regex(test_string, ns_key)
- if regex_obj:
- return regex_obj["exclusive"]
+ def _is_exclusive(self, namespace_key: str, test_string: str) -> bool:
+ namespace = self._matches_regex(namespace_key, test_string)
+ if namespace:
+ return namespace.exclusive
return False
async def _matches_user(
@@ -260,15 +281,15 @@ async def is_interested_in_presence(
def is_interested_in_user(self, user_id: str) -> bool:
return (
- bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
+ bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
or user_id == self.sender
)
def is_interested_in_alias(self, alias: str) -> bool:
- return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
+ return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias))
def is_interested_in_room(self, room_id: str) -> bool:
- return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
+ return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id))
def is_exclusive_user(self, user_id: str) -> bool:
return (
@@ -285,14 +306,14 @@ def is_exclusive_alias(self, alias: str) -> bool:
def is_exclusive_room(self, room_id: str) -> bool:
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
- def get_exclusive_user_regexes(self):
+ def get_exclusive_user_regexes(self) -> List[Pattern[str]]:
"""Get the list of regexes used to determine if a user is exclusively
registered by the AS
"""
return [
- regex_obj["regex"]
- for regex_obj in self.namespaces[ApplicationService.NS_USERS]
- if regex_obj["exclusive"]
+ namespace.regex
+ for namespace in self.namespaces[ApplicationService.NS_USERS]
+ if namespace.exclusive
]
def get_groups_for_user(self, user_id: str) -> Iterable[str]:
@@ -305,15 +326,15 @@ def get_groups_for_user(self, user_id: str) -> Iterable[str]:
An iterable that yields group_id strings.
"""
return (
- regex_obj["group_id"]
- for regex_obj in self.namespaces[ApplicationService.NS_USERS]
- if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
+ namespace.group_id
+ for namespace in self.namespaces[ApplicationService.NS_USERS]
+ if namespace.group_id and namespace.regex.match(user_id)
)
def is_rate_limited(self) -> bool:
return self.rate_limited
- def __str__(self):
+ def __str__(self) -> str:
# copy dictionary and redact token fields so they don't get logged
dict_copy = self.__dict__.copy()
dict_copy["token"] = ""
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index f51b636417bb..def4424af0ee 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import urllib
-from typing import TYPE_CHECKING, List, Optional, Tuple
+import urllib.parse
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from prometheus_client import Counter
@@ -53,7 +53,7 @@
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
-def _is_valid_3pe_metadata(info):
+def _is_valid_3pe_metadata(info: JsonDict) -> bool:
if "instances" not in info:
return False
if not isinstance(info["instances"], list):
@@ -61,7 +61,7 @@ def _is_valid_3pe_metadata(info):
return True
-def _is_valid_3pe_result(r, field):
+def _is_valid_3pe_result(r: JsonDict, field: str) -> bool:
if not isinstance(r, dict):
return False
@@ -93,9 +93,13 @@ def __init__(self, hs: "HomeServer"):
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
)
- async def query_user(self, service, user_id):
+ async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
if service.url is None:
return False
+
+ # This is required by the configuration.
+ assert service.hs_token is not None
+
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
try:
response = await self.get_json(uri, {"access_token": service.hs_token})
@@ -109,9 +113,13 @@ async def query_user(self, service, user_id):
logger.warning("query_user to %s threw exception %s", uri, ex)
return False
- async def query_alias(self, service, alias):
+ async def query_alias(self, service: "ApplicationService", alias: str) -> bool:
if service.url is None:
return False
+
+ # This is required by the configuration.
+ assert service.hs_token is not None
+
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
try:
response = await self.get_json(uri, {"access_token": service.hs_token})
@@ -125,7 +133,13 @@ async def query_alias(self, service, alias):
logger.warning("query_alias to %s threw exception %s", uri, ex)
return False
- async def query_3pe(self, service, kind, protocol, fields):
+ async def query_3pe(
+ self,
+ service: "ApplicationService",
+ kind: str,
+ protocol: str,
+ fields: Dict[bytes, List[bytes]],
+ ) -> List[JsonDict]:
if kind == ThirdPartyEntityKind.USER:
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
@@ -205,11 +219,14 @@ async def push_bulk(
events: List[EventBase],
ephemeral: List[JsonDict],
txn_id: Optional[int] = None,
- ):
+ ) -> bool:
if service.url is None:
return True
- events = self._serialize(service, events)
+ # This is required by the configuration.
+ assert service.hs_token is not None
+
+ serialized_events = self._serialize(service, events)
if txn_id is None:
logger.warning(
@@ -221,9 +238,12 @@ async def push_bulk(
# Never send ephemeral events to appservices that do not support it
if service.supports_ephemeral:
- body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
+ body = {
+ "events": serialized_events,
+ "de.sorunome.msc2409.ephemeral": ephemeral,
+ }
else:
- body = {"events": events}
+ body = {"events": serialized_events}
try:
await self.put_json(
@@ -238,7 +258,7 @@ async def push_bulk(
[event.get("event_id") for event in events],
)
sent_transactions_counter.labels(service.id).inc()
- sent_events_counter.labels(service.id).inc(len(events))
+ sent_events_counter.labels(service.id).inc(len(serialized_events))
return True
except CodeMessageException as e:
logger.warning(
@@ -260,7 +280,9 @@ async def push_bulk(
failed_transactions_counter.labels(service.id).inc()
return False
- def _serialize(self, service, events):
+ def _serialize(
+ self, service: "ApplicationService", events: Iterable[EventBase]
+ ) -> List[JsonDict]:
time_now = self.clock.time_msec()
return [
serialize_event(
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 6a2ce99b55dc..185e3a527815 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -48,13 +48,19 @@
components.
"""
import logging
-from typing import List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
from synapse.appservice import ApplicationService, ApplicationServiceState
+from synapse.appservice.api import ApplicationServiceApi
from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict
+from synapse.util import Clock
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -72,7 +78,7 @@ class ApplicationServiceScheduler:
case is a simple array.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.as_api = hs.get_application_service_api()
@@ -80,7 +86,7 @@ def __init__(self, hs):
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
- async def start(self):
+ async def start(self) -> None:
logger.info("Starting appservice scheduler")
# check for any DOWN ASes and start recoverers for them.
@@ -91,12 +97,14 @@ async def start(self):
for service in services:
self.txn_ctrl.start_recoverer(service)
- def submit_event_for_as(self, service: ApplicationService, event: EventBase):
+ def submit_event_for_as(
+ self, service: ApplicationService, event: EventBase
+ ) -> None:
self.queuer.enqueue_event(service, event)
def submit_ephemeral_events_for_as(
self, service: ApplicationService, events: List[JsonDict]
- ):
+ ) -> None:
self.queuer.enqueue_ephemeral(service, events)
@@ -108,16 +116,18 @@ class _ServiceQueuer:
appservice at a given time.
"""
- def __init__(self, txn_ctrl, clock):
- self.queued_events = {} # dict of {service_id: [events]}
- self.queued_ephemeral = {} # dict of {service_id: [events]}
+ def __init__(self, txn_ctrl: "_TransactionController", clock: Clock):
+ # dict of {service_id: [events]}
+ self.queued_events: Dict[str, List[EventBase]] = {}
+ # dict of {service_id: [events]}
+ self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
# the appservices which currently have a transaction in flight
- self.requests_in_flight = set()
+ self.requests_in_flight: Set[str] = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
- def _start_background_request(self, service):
+ def _start_background_request(self, service: ApplicationService) -> None:
# start a sender for this appservice if we don't already have one
if service.id in self.requests_in_flight:
return
@@ -126,15 +136,17 @@ def _start_background_request(self, service):
"as-sender-%s" % (service.id,), self._send_request, service
)
- def enqueue_event(self, service: ApplicationService, event: EventBase):
+ def enqueue_event(self, service: ApplicationService, event: EventBase) -> None:
self.queued_events.setdefault(service.id, []).append(event)
self._start_background_request(service)
- def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]):
+ def enqueue_ephemeral(
+ self, service: ApplicationService, events: List[JsonDict]
+ ) -> None:
self.queued_ephemeral.setdefault(service.id, []).extend(events)
self._start_background_request(service)
- async def _send_request(self, service: ApplicationService):
+ async def _send_request(self, service: ApplicationService) -> None:
# sanity-check: we shouldn't get here if this service already has a sender
# running.
assert service.id not in self.requests_in_flight
@@ -168,20 +180,15 @@ class _TransactionController:
if a transaction fails.
(Note that we only have one of these in the homeserver.)
-
- Args:
- clock (synapse.util.Clock):
- store (synapse.storage.DataStore):
- as_api (synapse.appservice.api.ApplicationServiceApi):
"""
- def __init__(self, clock, store, as_api):
+ def __init__(self, clock: Clock, store: DataStore, as_api: ApplicationServiceApi):
self.clock = clock
self.store = store
self.as_api = as_api
# map from service id to recoverer instance
- self.recoverers = {}
+ self.recoverers: Dict[str, "_Recoverer"] = {}
# for UTs
self.RECOVERER_CLASS = _Recoverer
@@ -191,7 +198,7 @@ async def send(
service: ApplicationService,
events: List[EventBase],
ephemeral: Optional[List[JsonDict]] = None,
- ):
+ ) -> None:
try:
txn = await self.store.create_appservice_txn(
service=service, events=events, ephemeral=ephemeral or []
@@ -207,7 +214,7 @@ async def send(
logger.exception("Error creating appservice transaction")
run_in_background(self._on_txn_fail, service)
- async def on_recovered(self, recoverer):
+ async def on_recovered(self, recoverer: "_Recoverer") -> None:
logger.info(
"Successfully recovered application service AS ID %s", recoverer.service.id
)
@@ -217,18 +224,18 @@ async def on_recovered(self, recoverer):
recoverer.service, ApplicationServiceState.UP
)
- async def _on_txn_fail(self, service):
+ async def _on_txn_fail(self, service: ApplicationService) -> None:
try:
await self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
self.start_recoverer(service)
except Exception:
logger.exception("Error starting AS recoverer")
- def start_recoverer(self, service):
+ def start_recoverer(self, service: ApplicationService) -> None:
"""Start a Recoverer for the given service
Args:
- service (synapse.appservice.ApplicationService):
+ service:
"""
logger.info("Starting recoverer for AS ID %s", service.id)
assert service.id not in self.recoverers
@@ -257,7 +264,14 @@ class _Recoverer:
callback (callable[_Recoverer]): called once the service recovers.
"""
- def __init__(self, clock, store, as_api, service, callback):
+ def __init__(
+ self,
+ clock: Clock,
+ store: DataStore,
+ as_api: ApplicationServiceApi,
+ service: ApplicationService,
+ callback: Callable[["_Recoverer"], Awaitable[None]],
+ ):
self.clock = clock
self.store = store
self.as_api = as_api
@@ -265,8 +279,8 @@ def __init__(self, clock, store, as_api, service, callback):
self.callback = callback
self.backoff_counter = 1
- def recover(self):
- def _retry():
+ def recover(self) -> None:
+ def _retry() -> None:
run_as_background_process(
"as-recoverer-%s" % (self.service.id,), self.retry
)
@@ -275,13 +289,13 @@ def _retry():
logger.info("Scheduling retries on %s in %fs", self.service.id, delay)
self.clock.call_later(delay, _retry)
- def _backoff(self):
+ def _backoff(self) -> None:
# cap the backoff to be around 8.5min => (2^9) = 512 secs
if self.backoff_counter < 9:
self.backoff_counter += 1
self.recover()
- async def retry(self):
+ async def retry(self) -> None:
logger.info("Starting retries on %s", self.service.id)
try:
while True:
diff --git a/synapse/config/api.py b/synapse/config/api.py
index b18044f9822a..8133b6b62402 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -29,6 +29,7 @@ class ApiConfig(Config):
def read_config(self, config: JsonDict, **kwargs):
validate_config(_MAIN_SCHEMA, config, ())
self.room_prejoin_state = list(self._get_prejoin_state_types(config))
+ self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
def generate_config_section(cls, **kwargs) -> str:
formatted_default_state_types = "\n".join(
@@ -59,6 +60,21 @@ def generate_config_section(cls, **kwargs) -> str:
#
#additional_event_types:
# - org.example.custom.event.type
+
+ # We record the IP address of clients used to access the API for various
+ # reasons, including displaying it to the user in the "Where you're signed in"
+ # dialog.
+ #
+ # By default, when puppeting another user via the admin API, the client IP
+ # address is recorded against the user who created the access token (ie, the
+ # admin user), and *not* the puppeted user.
+ #
+ # Uncomment the following to also record the IP address against the puppeted
+ # user. (This also means that the puppeted user will count as an "active" user
+ # for the purpose of monthly active user tracking - see 'limit_usage_by_mau' etc
+ # above.)
+ #
+ #track_puppeted_user_ips: true
""" % {
"formatted_default_state_types": formatted_default_state_types
}
@@ -107,6 +123,8 @@ def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
EventTypes.Name,
# Per MSC1772.
EventTypes.Create,
+ # Per MSC3173.
+ EventTypes.Topic,
]
@@ -136,5 +154,8 @@ def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
"properties": {
"room_prejoin_state": _ROOM_PREJOIN_STATE_CONFIG_SCHEMA,
"room_invite_state_types": _ROOM_INVITE_STATE_TYPES_SCHEMA,
+ "track_puppeted_user_ips": {
+ "type": "boolean",
+ },
},
}
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index e4bb7224a410..7fad2e0422b8 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -147,8 +147,7 @@ def _load_appservice(
# protocols check
protocols = as_info.get("protocols")
if protocols:
- # Because strings are lists in python
- if isinstance(protocols, str) or not isinstance(protocols, list):
+ if not isinstance(protocols, list):
raise KeyError("Optional 'protocols' must be a list if present.")
for p in protocols:
if not isinstance(p, str):
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 7ea3c06af166..028bb3a97027 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -55,19 +55,19 @@
---------------------------------------------------------------------------------------"""
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class EmailSubjectConfig:
- message_from_person_in_room = attr.ib(type=str)
- message_from_person = attr.ib(type=str)
- messages_from_person = attr.ib(type=str)
- messages_in_room = attr.ib(type=str)
- messages_in_room_and_others = attr.ib(type=str)
- messages_from_person_and_others = attr.ib(type=str)
- invite_from_person = attr.ib(type=str)
- invite_from_person_to_room = attr.ib(type=str)
- invite_from_person_to_space = attr.ib(type=str)
- password_reset = attr.ib(type=str)
- email_validation = attr.ib(type=str)
+ message_from_person_in_room: str
+ message_from_person: str
+ messages_from_person: str
+ messages_in_room: str
+ messages_in_room_and_others: str
+ messages_from_person_and_others: str
+ invite_from_person: str
+ invite_from_person_to_room: str
+ invite_from_person_to_space: str
+ password_reset: str
+ email_validation: str
class EmailConfig(Config):
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index d78a15097c87..dbaeb1091861 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -32,7 +32,7 @@ def read_config(self, config: JsonDict, **kwargs):
# MSC3026 (busy presence state)
self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
- # MSC2716 (backfill existing history)
+ # MSC2716 (importing historical messages)
self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
# MSC2285 (hidden read receipts)
@@ -49,3 +49,8 @@ def read_config(self, config: JsonDict, **kwargs):
# MSC3030 (Jump to date API endpoint)
self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False)
+
+ # The portion of MSC3202 which is related to device masquerading.
+ self.msc3202_device_masquerading_enabled: bool = experimental.get(
+ "msc3202_device_masquerading", False
+ )
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 035ee2416bd6..ee83c6c06b7f 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -16,12 +16,14 @@
import hashlib
import logging
import os
-from typing import Any, Dict
+from typing import Any, Dict, Iterator, List, Optional
import attr
import jsonschema
from signedjson.key import (
NACL_ED25519,
+ SigningKey,
+ VerifyKey,
decode_signing_key_base64,
decode_verify_key_bytes,
generate_signing_key,
@@ -31,6 +33,7 @@
)
from unpaddedbase64 import decode_base64
+from synapse.types import JsonDict
from synapse.util.stringutils import random_string, random_string_with_symbols
from ._base import Config, ConfigError
@@ -81,14 +84,13 @@
logger = logging.getLogger(__name__)
-@attr.s
+@attr.s(slots=True, auto_attribs=True)
class TrustedKeyServer:
- # string: name of the server.
- server_name = attr.ib()
+ # name of the server.
+ server_name: str
- # dict[str,VerifyKey]|None: map from key id to key object, or None to disable
- # signature verification.
- verify_keys = attr.ib(default=None)
+ # map from key id to key object, or None to disable signature verification.
+ verify_keys: Optional[Dict[str, VerifyKey]] = None
class KeyConfig(Config):
@@ -279,15 +281,15 @@ def generate_config_section(
% locals()
)
- def read_signing_keys(self, signing_key_path, name):
+ def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]:
"""Read the signing keys in the given path.
Args:
- signing_key_path (str)
- name (str): Associated config key name
+ signing_key_path
+ name: Associated config key name
Returns:
- list[SigningKey]
+ The signing keys read from the given path.
"""
signing_keys = self.read_file(signing_key_path, name)
@@ -296,7 +298,9 @@ def read_signing_keys(self, signing_key_path, name):
except Exception as e:
raise ConfigError("Error reading %s: %s" % (name, str(e)))
- def read_old_signing_keys(self, old_signing_keys):
+ def read_old_signing_keys(
+ self, old_signing_keys: Optional[JsonDict]
+ ) -> Dict[str, VerifyKey]:
if old_signing_keys is None:
return {}
keys = {}
@@ -340,7 +344,7 @@ def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
write_signing_keys(signing_key_file, (key,))
-def _perspectives_to_key_servers(config):
+def _perspectives_to_key_servers(config: JsonDict) -> Iterator[JsonDict]:
"""Convert old-style 'perspectives' configs into new-style 'trusted_key_servers'
Returns an iterable of entries to add to trusted_key_servers.
@@ -402,7 +406,9 @@ def _perspectives_to_key_servers(config):
}
-def _parse_key_servers(key_servers, federation_verify_certificates):
+def _parse_key_servers(
+ key_servers: List[Any], federation_verify_certificates: bool
+) -> Iterator[TrustedKeyServer]:
try:
jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA)
except jsonschema.ValidationError as e:
@@ -444,7 +450,7 @@ def _parse_key_servers(key_servers, federation_verify_certificates):
yield result
-def _assert_keyserver_has_verify_keys(trusted_key_server):
+def _assert_keyserver_has_verify_keys(trusted_key_server: TrustedKeyServer) -> None:
if not trusted_key_server.verify_keys:
raise ConfigError(INSECURE_NOTARY_ERROR)
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 7ac82edb0ed1..1cc26e757812 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -22,10 +22,12 @@
@attr.s
class MetricsFlags:
- known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool))
+ known_servers: bool = attr.ib(
+ default=False, validator=attr.validators.instance_of(bool)
+ )
@classmethod
- def all_off(cls):
+ def all_off(cls) -> "MetricsFlags":
"""
Instantiate the flags with all options set to off.
"""
diff --git a/synapse/config/modules.py b/synapse/config/modules.py
index ae0821e5a504..85fb05890d7c 100644
--- a/synapse/config/modules.py
+++ b/synapse/config/modules.py
@@ -37,7 +37,7 @@ def generate_config_section(self, **kwargs):
# Server admins can expand Synapse's functionality with external modules.
#
- # See https://matrix-org.github.io/synapse/latest/modules.html for more
+ # See https://matrix-org.github.io/synapse/latest/modules/index.html for more
# documentation on how to configure or create custom modules for Synapse.
#
modules:
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 79c400fe30b8..e783b1131501 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -148,10 +148,13 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str
# Defaults to false. Avoid this in production.
#
# user_profile_method: Whether to fetch the user profile from the userinfo
- # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+ # endpoint, or to rely on the data returned in the id_token from the
+ # token_endpoint.
#
- # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
- # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+ # Valid values are: 'auto' or 'userinfo_endpoint'.
+ #
+ # Defaults to 'auto', which uses the userinfo endpoint if 'openid' is
+ # not included in 'scopes'. Set to 'userinfo_endpoint' to always use the
# userinfo endpoint.
#
# allow_existing_users: set to 'true' to allow a user logging in via OIDC to
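A short sketch of the `user_profile_method` decision as the corrected comment describes it. This mirrors the documented semantics only, not the actual provider implementation:

```python
from typing import List


def uses_userinfo_endpoint(user_profile_method: str, scopes: List[str]) -> bool:
    # 'userinfo_endpoint' always fetches the userinfo endpoint.
    if user_profile_method == "userinfo_endpoint":
        return True
    # 'auto' only fetches it when "openid" is not in the requested scopes,
    # since an id_token with profile data is only issued when "openid" is requested.
    return "openid" not in scopes


assert uses_userinfo_endpoint("auto", ["openid", "profile"]) is False
assert uses_userinfo_endpoint("auto", ["profile"]) is True
```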
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index b129b9dd681c..1980351e7711 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -14,10 +14,11 @@
import logging
import os
-from collections import namedtuple
from typing import Dict, List, Tuple
from urllib.request import getproxies_environment # type: ignore
+import attr
+
from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import JsonDict
@@ -44,18 +45,20 @@
HTTP_PROXY_SET_WARNING = """\
The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured."""
-ThumbnailRequirement = namedtuple(
- "ThumbnailRequirement", ["width", "height", "method", "media_type"]
-)
-MediaStorageProviderConfig = namedtuple(
- "MediaStorageProviderConfig",
- (
- "store_local", # Whether to store newly uploaded local files
- "store_remote", # Whether to store newly downloaded remote files
- "store_synchronous", # Whether to wait for successful storage for local uploads
- ),
-)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThumbnailRequirement:
+ width: int
+ height: int
+ method: str
+ media_type: str
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class MediaStorageProviderConfig:
+ store_local: bool # Whether to store newly uploaded local files
+ store_remote: bool # Whether to store newly downloaded remote files
+ store_synchronous: bool # Whether to wait for successful storage for local uploads
def parse_thumbnail_requirements(
@@ -66,11 +69,10 @@ def parse_thumbnail_requirements(
method, and thumbnail media type to precalculate
Args:
- thumbnail_sizes(list): List of dicts with "width", "height", and
- "method" keys
+ thumbnail_sizes: List of dicts with "width", "height", and "method" keys
+
Returns:
- Dictionary mapping from media type string to list of
- ThumbnailRequirement tuples.
+ Dictionary mapping from media type string to list of ThumbnailRequirement.
"""
requirements: Dict[str, List[ThumbnailRequirement]] = {}
for size in thumbnail_sizes:
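A self-contained sketch of the mapping described in the docstring above, using the frozen attrs `ThumbnailRequirement` from this diff. The size dicts are example input; the real parser derives requirements per media type and format:

```python
from typing import Dict, List

import attr


@attr.s(frozen=True, slots=True, auto_attribs=True)
class ThumbnailRequirement:
    width: int
    height: int
    method: str
    media_type: str


thumbnail_sizes = [{"width": 32, "height": 32, "method": "crop"}]
requirements: Dict[str, List[ThumbnailRequirement]] = {}
for size in thumbnail_sizes:
    # Example: register a JPEG requirement for each configured size.
    requirements.setdefault("image/jpeg", []).append(
        ThumbnailRequirement(size["width"], size["height"], size["method"], "image/jpeg")
    )
assert requirements["image/jpeg"][0].method == "crop"
```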
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 57316c59b6a0..3c5e0f7ce73c 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -15,8 +15,9 @@
from typing import List
+from matrix_common.regex import glob_to_regex
+
from synapse.types import JsonDict
-from synapse.util import glob_to_regex
from ._base import Config, ConfigError
diff --git a/synapse/config/server.py b/synapse/config/server.py
index ba5b95426338..f200d0c1f1cf 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -200,8 +200,8 @@ class HttpListenerConfig:
"""Object describing the http-specific parts of the config of a listener"""
x_forwarded: bool = False
- resources: List[HttpResourceConfig] = attr.ib(factory=list)
- additional_resources: Dict[str, dict] = attr.ib(factory=dict)
+ resources: List[HttpResourceConfig] = attr.Factory(list)
+ additional_resources: Dict[str, dict] = attr.Factory(dict)
tag: Optional[str] = None
@@ -259,7 +259,6 @@ def read_config(self, config, **kwargs):
raise ConfigError(str(e))
self.pid_file = self.abspath(config.get("pid_file"))
- self.web_client_location = config.get("web_client_location", None)
self.soft_file_limit = config.get("soft_file_limit", 0)
self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile")
@@ -506,8 +505,17 @@ def read_config(self, config, **kwargs):
l2.append(listener)
self.listeners = l2
- if not self.web_client_location:
- _warn_if_webclient_configured(self.listeners)
+ self.web_client_location = config.get("web_client_location", None)
+ self.web_client_location_is_redirect = self.web_client_location and (
+ self.web_client_location.startswith("http://")
+ or self.web_client_location.startswith("https://")
+ )
+ # A non-HTTP(S) web client location is deprecated.
+ if self.web_client_location and not self.web_client_location_is_redirect:
+ logger.warning(NO_MORE_NONE_HTTP_WEB_CLIENT_LOCATION_WARNING)
+
+ # Warn if webclient is configured for a worker.
+ _warn_if_webclient_configured(self.listeners)
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None))
@@ -793,13 +801,7 @@ def generate_config_section(
#
pid_file: %(pid_file)s
- # The absolute URL to the web client which /_matrix/client will redirect
- # to if 'webclient' is configured under the 'listeners' configuration.
- #
- # This option can be also set to the filesystem path to the web client
- # which will be served at /_matrix/client/ if 'webclient' is configured
- # under the 'listeners' configuration, however this is a security risk:
- # https://github.com/matrix-org/synapse#security-note
+ # The absolute URL to the web client which / will redirect to.
#
#web_client_location: https://riot.example.com/
@@ -883,7 +885,7 @@ def generate_config_section(
# The default room version for newly created rooms.
#
# Known room versions are listed here:
- # https://matrix.org/docs/spec/#complete-list-of-room-versions
+ # https://spec.matrix.org/latest/rooms/#complete-list-of-room-versions
#
# For example, for room version 1, default_room_version should be set
# to "1".
@@ -1011,8 +1013,6 @@ def generate_config_section(
# static: static resources under synapse/static (/_matrix/static). (Mostly
# useful for 'fallback authentication'.)
#
- # webclient: A web client. Requires web_client_location to be set.
- #
listeners:
# TLS-enabled listener: for when matrix traffic is sent directly to synapse.
#
@@ -1257,7 +1257,7 @@ def add_arguments(parser: argparse.ArgumentParser) -> None:
help="Turn on the twisted telnet manhole service on the given port.",
)
- def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]:
+ def read_gc_intervals(self, durations: Any) -> Optional[Tuple[float, float, float]]:
"""Reads the three durations for the GC min interval option, returning seconds."""
if durations is None:
return None
@@ -1349,9 +1349,15 @@ def parse_listener_def(listener: Any) -> ListenerConfig:
return ListenerConfig(port, bind_addresses, listener_type, tls, http_config)
+NO_MORE_NONE_HTTP_WEB_CLIENT_LOCATION_WARNING = """
+Synapse no longer supports serving a web client. To remove this warning,
+configure 'web_client_location' with an HTTP(S) URL.
+"""
+
+
NO_MORE_WEB_CLIENT_WARNING = """
-Synapse no longer includes a web client. To enable a web client, configure
-web_client_location. To remove this warning, remove 'webclient' from the 'listeners'
+Synapse no longer includes a web client. To redirect the root resource to a web client, configure
+'web_client_location'. To remove this warning, remove 'webclient' from the 'listeners'
configuration.
"""
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 4ca111618fe9..6e673d65a711 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -16,11 +16,12 @@
import os
from typing import List, Optional, Pattern
+from matrix_common.regex import glob_to_regex
+
from OpenSSL import SSL, crypto
from twisted.internet._sslverify import Certificate, trustRootFromCertificates
from synapse.config._base import Config, ConfigError
-from synapse.util import glob_to_regex
logger = logging.getLogger(__name__)
@@ -132,7 +133,7 @@ def read_config(self, config: dict, config_dir_path: str, **kwargs):
self.tls_certificate: Optional[crypto.X509] = None
self.tls_private_key: Optional[crypto.PKey] = None
- def read_certificate_from_disk(self):
+ def read_certificate_from_disk(self) -> None:
"""
Read the certificates and private key from disk.
"""
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 576f519188bb..bdaba6db3787 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -51,12 +51,12 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
return obj
-@attr.s
+@attr.s(auto_attribs=True)
class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication."""
- host = attr.ib(type=str)
- port = attr.ib(type=int)
+ host: str
+ port: int
@attr.s
@@ -77,34 +77,28 @@ class WriterLocations:
can only be a single instance.
"""
- events = attr.ib(
+ events: List[str] = attr.ib(
default=["master"],
- type=List[str],
converter=_instance_to_list_converter,
)
- typing = attr.ib(
+ typing: List[str] = attr.ib(
default=["master"],
- type=List[str],
converter=_instance_to_list_converter,
)
- to_device = attr.ib(
+ to_device: List[str] = attr.ib(
default=["master"],
- type=List[str],
converter=_instance_to_list_converter,
)
- account_data = attr.ib(
+ account_data: List[str] = attr.ib(
default=["master"],
- type=List[str],
converter=_instance_to_list_converter,
)
- receipts = attr.ib(
+ receipts: List[str] = attr.ib(
default=["master"],
- type=List[str],
converter=_instance_to_list_converter,
)
- presence = attr.ib(
+ presence: List[str] = attr.ib(
default=["master"],
- type=List[str],
converter=_instance_to_list_converter,
)
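A sketch of the converter pattern retained in the `WriterLocations` conversion above. Only the `return obj` tail of the helper appears in the diff context, so the body below is an assumption consistent with its signature: a bare instance name is normalised to a single-element list.

```python
from typing import List, Union

import attr


def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
    # Assumed body: wrap a single instance name in a list, pass lists through.
    if isinstance(obj, str):
        return [obj]
    return obj


@attr.s(auto_attribs=True)
class WriterLocations:
    events: List[str] = attr.ib(
        default=["master"],
        converter=_instance_to_list_converter,
    )


assert WriterLocations(events="worker1").events == ["worker1"]
assert WriterLocations().events == ["master"]
```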
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 993b04099e28..72d4a69aac35 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -58,7 +58,7 @@
logger = logging.getLogger(__name__)
-@attr.s(slots=True, cmp=False)
+@attr.s(slots=True, frozen=True, cmp=False, auto_attribs=True)
class VerifyJsonRequest:
"""
A request to verify a JSON object.
@@ -78,10 +78,10 @@ class VerifyJsonRequest:
key_ids: The set of key_ids that could be used to verify the JSON object
"""
- server_name = attr.ib(type=str)
- get_json_object = attr.ib(type=Callable[[], JsonDict])
- minimum_valid_until_ts = attr.ib(type=int)
- key_ids = attr.ib(type=List[str])
+ server_name: str
+ get_json_object: Callable[[], JsonDict]
+ minimum_valid_until_ts: int
+ key_ids: List[str]
@staticmethod
def from_json_object(
@@ -124,7 +124,7 @@ class KeyLookupError(ValueError):
pass
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class _FetchKeyRequest:
"""A request for keys for a given server.
@@ -138,9 +138,9 @@ class _FetchKeyRequest:
key_ids: The IDs of the keys to attempt to fetch
"""
- server_name = attr.ib(type=str)
- minimum_valid_until_ts = attr.ib(type=int)
- key_ids = attr.ib(type=List[str])
+ server_name: str
+ minimum_valid_until_ts: int
+ key_ids: List[str]
class Keyring:
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index f251402ed8f2..0eab1aefd637 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -28,7 +28,7 @@
from synapse.storage.databases.main import DataStore
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class EventContext:
"""
Holds information relevant to persisting an event
@@ -103,15 +103,15 @@ class EventContext:
accessed via get_prev_state_ids.
"""
- rejected = attr.ib(default=False, type=Union[bool, str])
- _state_group = attr.ib(default=None, type=Optional[int])
- state_group_before_event = attr.ib(default=None, type=Optional[int])
- prev_group = attr.ib(default=None, type=Optional[int])
- delta_ids = attr.ib(default=None, type=Optional[StateMap[str]])
- app_service = attr.ib(default=None, type=Optional[ApplicationService])
+ rejected: Union[bool, str] = False
+ _state_group: Optional[int] = None
+ state_group_before_event: Optional[int] = None
+ prev_group: Optional[int] = None
+ delta_ids: Optional[StateMap[str]] = None
+ app_service: Optional[ApplicationService] = None
- _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
- _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
+ _current_state_ids: Optional[StateMap[str]] = None
+ _prev_state_ids: Optional[StateMap[str]] = None
@staticmethod
def with_state(
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 84ef69df679b..918adeecf8cd 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -14,17 +14,7 @@
# limitations under the License.
import collections.abc
import re
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Dict,
- Iterable,
- List,
- Mapping,
- Optional,
- Union,
-)
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
from frozendict import frozendict
@@ -32,14 +22,10 @@
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.types import JsonDict
-from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.frozenutils import unfreeze
from . import EventBase
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
"""Serializes a single event.
@@ -418,63 +399,41 @@ async def serialize_event(
serialized_event = serialize_event(event, time_now, **kwargs)
# Check if there are any bundled aggregations to include with the event.
- #
- # Do not bundle aggregations if any of the following are true:
- #
- # * Support is disabled via the configuration or the caller.
- # * The event is a state event.
- # * The event has been redacted.
- if (
- self._msc1849_enabled
- and bundle_aggregations
- and not event.is_state()
- and not event.internal_metadata.is_redacted()
- ):
- await self._injected_bundled_aggregations(event, time_now, serialized_event)
+ if bundle_aggregations:
+ event_aggregations = bundle_aggregations.get(event.event_id)
+ if event_aggregations:
+ self._inject_bundled_aggregations(
+ event,
+ time_now,
+ bundle_aggregations[event.event_id],
+ serialized_event,
+ )
return serialized_event
- async def _injected_bundled_aggregations(
- self, event: EventBase, time_now: int, serialized_event: JsonDict
+ def _inject_bundled_aggregations(
+ self,
+ event: EventBase,
+ time_now: int,
+ aggregations: JsonDict,
+ serialized_event: JsonDict,
) -> None:
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
Args:
event: The event being serialized.
time_now: The current time in milliseconds
+ aggregations: The bundled aggregation to serialize.
serialized_event: The serialized event which may be modified.
"""
- # Do not bundle aggregations for an event which represents an edit or an
- # annotation. It does not make sense for them to have related events.
- relates_to = event.content.get("m.relates_to")
- if isinstance(relates_to, (dict, frozendict)):
- relation_type = relates_to.get("rel_type")
- if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
- return
-
- event_id = event.event_id
-
- # The bundled aggregations to include.
- aggregations = {}
+ # Make a copy in case the object is cached.
+ aggregations = aggregations.copy()
- annotations = await self.store.get_aggregation_groups_for_event(event_id)
- if annotations.chunk:
- aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
-
- references = await self.store.get_relations_for_event(
- event_id, RelationTypes.REFERENCE, direction="f"
- )
- if references.chunk:
- aggregations[RelationTypes.REFERENCE] = references.to_dict()
-
- edit = None
- if event.type == EventTypes.Message:
- edit = await self.store.get_applicable_edit(event_id)
-
- if edit:
+ if RelationTypes.REPLACE in aggregations:
# If there is an edit replace the content, preserving existing
# relations.
+ edit = aggregations[RelationTypes.REPLACE]
# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
@@ -499,27 +458,19 @@ async def _injected_bundled_aggregations(
}
# If this event is the start of a thread, include a summary of the replies.
- if self._msc3440_enabled:
- (
- thread_count,
- latest_thread_event,
- ) = await self.store.get_thread_summary(event_id)
- if latest_thread_event:
- aggregations[RelationTypes.THREAD] = {
- # Don't bundle aggregations as this could recurse forever.
- "latest_event": await self.serialize_event(
- latest_thread_event, time_now, bundle_aggregations=False
- ),
- "count": thread_count,
- }
-
- # If any bundled aggregations were found, include them.
- if aggregations:
- serialized_event["unsigned"].setdefault("m.relations", {}).update(
- aggregations
+ if RelationTypes.THREAD in aggregations:
+ # Serialize the latest thread event.
+ latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]
+
+ # Don't bundle aggregations as this could recurse forever.
+ aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
+ latest_thread_event, time_now, bundle_aggregations=None
)
- async def serialize_events(
+ # Include the bundled aggregations in the event.
+ serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)
+
+ def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
) -> List[JsonDict]:
"""Serializes multiple events.
@@ -532,9 +483,9 @@ async def serialize_events(
Returns:
The list of serialized events
"""
- return await yieldable_gather_results(
- self.serialize_event, events, time_now=time_now, **kwargs
- )
+ return [
+ self.serialize_event(event, time_now=time_now, **kwargs) for event in events
+ ]
def copy_power_levels_contents(
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index f56344a3b94f..896168c05c0a 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from collections import namedtuple
from typing import TYPE_CHECKING
from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
@@ -104,10 +103,6 @@ async def _check_sigs_and_hash(
return pdu
-class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
- pass
-
-
async def _check_sigs_on_pdu(
keyring: Keyring, room_version: RoomVersion, pdu: EventBase
) -> None:
@@ -220,15 +215,12 @@ def _is_invite_via_3pid(event: EventBase) -> bool:
)
-def event_from_pdu_json(
- pdu_json: JsonDict, room_version: RoomVersion, outlier: bool = False
-) -> EventBase:
+def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventBase:
"""Construct an EventBase from an event json received over federation
Args:
pdu_json: pdu as received over federation
room_version: The version of the room this event belongs to
- outlier: True to mark this event as an outlier
Raises:
SynapseError: if the pdu is missing required fields or is otherwise
@@ -238,6 +230,10 @@ def event_from_pdu_json(
# origin, etc etc)
assert_params_in_dict(pdu_json, ("type", "depth"))
+ # Strip any unauthorized values from "unsigned" if they exist
+ if "unsigned" in pdu_json:
+ _strip_unsigned_values(pdu_json)
+
depth = pdu_json["depth"]
if not isinstance(depth, int):
raise SynapseError(400, "Depth %r not an integer" % (depth,), Codes.BAD_JSON)
@@ -252,6 +248,25 @@ def event_from_pdu_json(
validate_canonicaljson(pdu_json)
event = make_event_from_dict(pdu_json, room_version)
- event.internal_metadata.outlier = outlier
-
return event
+
+
+def _strip_unsigned_values(pdu_dict: JsonDict) -> None:
+ """
+ Strip any unsigned values unless specifically allowed, as defined by the whitelist.
+
+ pdu_dict: the json dict to strip values from. Note that the dict is mutated by this
+ function
+ """
+ unsigned = pdu_dict["unsigned"]
+
+ if not isinstance(unsigned, dict):
+ pdu_dict["unsigned"] = {}
+
+ if pdu_dict["type"] == "m.room.member":
+ whitelist = ["knock_room_state", "invite_room_state", "age"]
+ else:
+ whitelist = ["age"]
+
+ filtered_unsigned = {k: v for k, v in unsigned.items() if k in whitelist}
+ pdu_dict["unsigned"] = filtered_unsigned
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index fee1477ab684..74f17aa4daa3 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -56,7 +56,6 @@
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.transport.client import SendJoinResponse
-from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
@@ -119,7 +118,8 @@ def __init__(self, hs: "HomeServer"):
# It is a map of (room ID, suggested-only) -> the response of
# get_room_hierarchy.
self._get_room_hierarchy_cache: ExpiringCache[
- Tuple[str, bool], Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]
+ Tuple[str, bool],
+ Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]],
] = ExpiringCache(
cache_name="get_room_hierarchy_cache",
clock=self._clock,
@@ -144,7 +144,6 @@ def _clear_tried_cache(self) -> None:
if destination_dict:
self.pdu_destination_tried[event_id] = destination_dict
- @log_function
async def make_query(
self,
destination: str,
@@ -178,7 +177,6 @@ async def make_query(
ignore_backoff=ignore_backoff,
)
- @log_function
async def query_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
@@ -196,7 +194,6 @@ async def query_client_keys(
destination, content, timeout
)
- @log_function
async def query_user_devices(
self, destination: str, user_id: str, timeout: int = 30000
) -> JsonDict:
@@ -208,7 +205,6 @@ async def query_user_devices(
destination, user_id, timeout
)
- @log_function
async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
@@ -265,14 +261,11 @@ async def backfill(
room_version = await self.store.get_room_version(room_id)
- pdus = [
- event_from_pdu_json(p, room_version, outlier=False)
- for p in transaction_data_pdus
- ]
+ pdus = [event_from_pdu_json(p, room_version) for p in transaction_data_pdus]
# Check signatures and hash of pdus, removing any from the list that fail checks
pdus[:] = await self._check_sigs_and_hash_and_fetch(
- dest, pdus, outlier=True, room_version=room_version
+ dest, pdus, room_version=room_version
)
return pdus
@@ -282,7 +275,6 @@ async def get_pdu_from_destination_raw(
destination: str,
event_id: str,
room_version: RoomVersion,
- outlier: bool = False,
timeout: Optional[int] = None,
) -> Optional[EventBase]:
"""Requests the PDU with given origin and ID from the remote home
@@ -292,9 +284,6 @@ async def get_pdu_from_destination_raw(
destination: Which homeserver to query
event_id: event to fetch
room_version: version of the room
- outlier: Indicates whether the PDU is an `outlier`, i.e. if
- it's from an arbitrary point in the context as opposed to part
- of the current block of PDUs. Defaults to `False`
timeout: How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
@@ -316,8 +305,7 @@ async def get_pdu_from_destination_raw(
)
pdu_list: List[EventBase] = [
- event_from_pdu_json(p, room_version, outlier=outlier)
- for p in transaction_data["pdus"]
+ event_from_pdu_json(p, room_version) for p in transaction_data["pdus"]
]
if pdu_list and pdu_list[0]:
@@ -334,7 +322,6 @@ async def get_pdu(
destinations: Iterable[str],
event_id: str,
room_version: RoomVersion,
- outlier: bool = False,
timeout: Optional[int] = None,
) -> Optional[EventBase]:
"""Requests the PDU with given origin and ID from the remote home
@@ -347,9 +334,6 @@ async def get_pdu(
destinations: Which homeservers to query
event_id: event to fetch
room_version: version of the room
- outlier: Indicates whether the PDU is an `outlier`, i.e. if
- it's from an arbitrary point in the context as opposed to part
- of the current block of PDUs. Defaults to `False`
timeout: How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
@@ -377,7 +361,6 @@ async def get_pdu(
destination=destination,
event_id=event_id,
room_version=room_version,
- outlier=outlier,
timeout=timeout,
)
@@ -435,7 +418,6 @@ async def _check_sigs_and_hash_and_fetch(
origin: str,
pdus: Collection[EventBase],
room_version: RoomVersion,
- outlier: bool = False,
) -> List[EventBase]:
"""Takes a list of PDUs and checks the signatures and hashes of each
one. If a PDU fails its signature check then we check if we have it in
@@ -451,7 +433,6 @@ async def _check_sigs_and_hash_and_fetch(
origin
pdu
room_version
- outlier: Whether the events are outliers or not
Returns:
A list of PDUs that have valid signatures and hashes.
@@ -466,7 +447,6 @@ async def _execute(pdu: EventBase) -> None:
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
pdu=pdu,
origin=origin,
- outlier=outlier,
room_version=room_version,
)
@@ -482,7 +462,6 @@ async def _check_sigs_and_hash_and_fetch_one(
pdu: EventBase,
origin: str,
room_version: RoomVersion,
- outlier: bool = False,
) -> Optional[EventBase]:
"""Takes a PDU and checks its signatures and hashes. If the PDU fails
its signature check then we check if we have it in the database and if
@@ -494,9 +473,6 @@ async def _check_sigs_and_hash_and_fetch_one(
origin
pdu
room_version
- outlier: Whether the events are outliers or not
- include_none: Whether to include None in the returned list
- for events that have failed their checks
Returns:
The PDU (possibly redacted) if it has valid signatures and hashes.
@@ -521,7 +497,6 @@ async def _check_sigs_and_hash_and_fetch_one(
destinations=[pdu_origin],
event_id=pdu.event_id,
room_version=room_version,
- outlier=outlier,
timeout=10000,
)
except SynapseError:
@@ -541,13 +516,10 @@ async def get_event_auth(
room_version = await self.store.get_room_version(room_id)
- auth_chain = [
- event_from_pdu_json(p, room_version, outlier=True)
- for p in res["auth_chain"]
- ]
+ auth_chain = [event_from_pdu_json(p, room_version) for p in res["auth_chain"]]
signed_auth = await self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True, room_version=room_version
+ destination, auth_chain, room_version=room_version
)
return signed_auth
@@ -816,7 +788,6 @@ async def send_request(destination: str) -> SendJoinResult:
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
pdu=event,
origin=destination,
- outlier=True,
room_version=room_version,
)
@@ -864,7 +835,6 @@ async def _execute(pdu: EventBase) -> None:
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
pdu=pdu,
origin=destination,
- outlier=True,
room_version=room_version,
)
@@ -1235,7 +1205,7 @@ async def get_missing_events(
]
signed_events = await self._check_sigs_and_hash_and_fetch(
- destination, events, outlier=False, room_version=room_version
+ destination, events, room_version=room_version
)
except HttpResponseException as e:
if not e.code == 400:
@@ -1364,7 +1334,7 @@ async def get_room_hierarchy(
destinations: Iterable[str],
room_id: str,
suggested_only: bool,
- ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
+ ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]]:
"""
Call other servers to get a hierarchy of the given room.
@@ -1379,7 +1349,8 @@ async def get_room_hierarchy(
Returns:
A tuple of:
- The room as a JSON dictionary.
+ The room as a JSON dictionary, without a "children_state" key.
+ A list of `m.space.child` state events.
A list of children rooms, as JSON dictionaries.
A list of inaccessible children room IDs.
@@ -1394,7 +1365,7 @@ async def get_room_hierarchy(
async def send_request(
destination: str,
- ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
+ ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]]:
try:
res = await self.transport_layer.get_room_hierarchy(
destination=destination,
@@ -1423,7 +1394,7 @@ async def send_request(
raise InvalidResponseError("'room' must be a dict")
# Validate children_state of the room.
- children_state = room.get("children_state", [])
+ children_state = room.pop("children_state", [])
if not isinstance(children_state, Sequence):
raise InvalidResponseError("'room.children_state' must be a list")
if any(not isinstance(e, dict) for e in children_state):
@@ -1452,7 +1423,7 @@ async def send_request(
"Invalid room ID in 'inaccessible_children' list"
)
- return room, children, inaccessible_children
+ return room, children_state, children, inaccessible_children
try:
result = await self._try_destination_list(
@@ -1500,8 +1471,6 @@ async def send_request(
if event.room_id == room_id:
children_events.append(event.data)
children_room_ids.add(event.state_key)
- # And add them under the requested room.
- requested_room["children_state"] = children_events
# Find the children rooms.
children = []
@@ -1511,7 +1480,7 @@ async def send_request(
# It isn't clear from the response whether some of the rooms are
# not accessible.
- result = (requested_room, children, ())
+ result = (requested_room, children_events, children, ())
# Cache the result to avoid fetching data over federation every time.
self._get_room_hierarchy_cache[(room_id, suggested_only)] = result
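A hedged sketch of consuming the widened return type above: callers now receive the `m.space.child` state as a separate element instead of under a `children_state` key on the room. The client object and arguments are placeholders:

```python
from typing import Iterable


async def fetch_hierarchy(client, destinations: Iterable[str], room_id: str):
    # Four elements now, matching the new Tuple[JsonDict, Sequence[JsonDict],
    # Sequence[JsonDict], Sequence[str]] annotation above.
    room, children_state, children, inaccessible_children = (
        await client.get_room_hierarchy(destinations, room_id, suggested_only=False)
    )
    assert "children_state" not in room  # popped during response validation
    return room, children_state, children, inaccessible_children
```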
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 8e37e76206ac..af9cb98f67b1 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -28,9 +28,9 @@
Union,
)
+from matrix_common.regex import glob_to_regex
from prometheus_client import Counter, Gauge, Histogram
-from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
@@ -58,7 +58,6 @@
run_in_background,
)
from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
-from synapse.logging.utils import log_function
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
@@ -66,8 +65,8 @@
)
from synapse.storage.databases.main.lock import Lock
from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer, concurrently_execute
+from synapse.util import json_decoder, unwrapFirstError
+from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_server_name
@@ -360,13 +359,13 @@ async def _handle_incoming_transaction(
# want to block things like to device messages from reaching clients
# behind the potentially expensive handling of PDUs.
pdu_results, _ = await make_deferred_yieldable(
- defer.gatherResults(
- [
+ gather_results(
+ (
run_in_background(
self._handle_pdus_in_txn, origin, transaction, request_time
),
run_in_background(self._handle_edus_in_txn, origin, transaction),
- ],
+ ),
consumeErrors=True,
).addErrback(unwrapFirstError)
)
@@ -859,7 +858,6 @@ async def on_event_auth(
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
return 200, res
- @log_function
async def on_query_client_keys(
self, origin: str, content: Dict[str, str]
) -> Tuple[int, Dict[str, Any]]:
@@ -940,7 +938,6 @@ async def on_get_missing_events(
return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
- @log_function
async def on_openid_userinfo(self, token: str) -> Optional[str]:
ts_now_ms = self._clock.time_msec()
return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 523ab1c51ed1..60e2e6cf019f 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -23,7 +23,6 @@
from typing import Optional, Tuple
from synapse.federation.units import Transaction
-from synapse.logging.utils import log_function
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict
@@ -36,7 +35,6 @@ class TransactionActions:
def __init__(self, datastore: DataStore):
self.store = datastore
- @log_function
async def have_responded(
self, origin: str, transaction: Transaction
) -> Optional[Tuple[int, JsonDict]]:
@@ -53,7 +51,6 @@ async def have_responded(
return await self.store.get_received_txn_response(transaction_id, origin)
- @log_function
async def set_response(
self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None:
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 63289a5a334f..0d7c4f506758 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -30,7 +30,6 @@
"""
import logging
-from collections import namedtuple
from typing import (
TYPE_CHECKING,
Dict,
@@ -43,6 +42,7 @@
Type,
)
+import attr
from sortedcontainers import SortedDict
from synapse.api.presence import UserPresenceState
@@ -382,13 +382,11 @@ def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
raise NotImplementedError()
-class PresenceDestinationsRow(
- BaseFederationRow,
- namedtuple(
- "PresenceDestinationsRow",
- ("state", "destinations"), # UserPresenceState # list[str]
- ),
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PresenceDestinationsRow(BaseFederationRow):
+ state: UserPresenceState
+ destinations: List[str]
+
TypeId = "pd"
@staticmethod
@@ -404,17 +402,15 @@ def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
buff.presence_destinations.append((self.state, self.destinations))
-class KeyedEduRow(
- BaseFederationRow,
- namedtuple(
- "KeyedEduRow",
- ("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu
- ),
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class KeyedEduRow(BaseFederationRow):
"""Streams EDUs that have an associated key that is ued to clobber. For example,
typing EDUs clobber based on room_id.
"""
+ key: Tuple[str, ...] # the edu key passed to send_edu
+ edu: Edu
+
TypeId = "k"
@staticmethod
@@ -428,9 +424,12 @@ def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
-class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EduRow(BaseFederationRow):
"""Streams EDUs that don't have keys. See KeyedEduRow"""
+ edu: Edu
+
TypeId = "e"
@staticmethod
@@ -453,14 +452,14 @@ def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
-ParsedFederationStreamData = namedtuple(
- "ParsedFederationStreamData",
- (
- "presence_destinations", # list of tuples of UserPresenceState and destinations
- "keyed_edus", # dict of destination -> { key -> Edu }
- "edus", # dict of destination -> [Edu]
- ),
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ParsedFederationStreamData:
+ # list of tuples of UserPresenceState and destinations
+ presence_destinations: List[Tuple[UserPresenceState, List[str]]]
+ # dict of destination -> { key -> Edu }
+ keyed_edus: Dict[str, Dict[Tuple[str, ...], Edu]]
+ # dict of destination -> [Edu]
+ edus: Dict[str, List[Edu]]
def process_rows_for_federation(
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 391b30fbb559..8152e80b88d2 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -607,18 +607,18 @@ def _start_catching_up(self) -> None:
self._pending_pdus = []
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _TransactionQueueManager:
"""A helper async context manager for pulling stuff off the queues and
tracking what was last successfully sent, etc.
"""
- queue = attr.ib(type=PerDestinationQueue)
+ queue: PerDestinationQueue
- _device_stream_id = attr.ib(type=Optional[int], default=None)
- _device_list_id = attr.ib(type=Optional[int], default=None)
- _last_stream_ordering = attr.ib(type=Optional[int], default=None)
- _pdus = attr.ib(type=List[EventBase], factory=list)
+ _device_stream_id: Optional[int] = None
+ _device_list_id: Optional[int] = None
+ _last_stream_ordering: Optional[int] = None
+ _pdus: List[EventBase] = attr.Factory(list)
async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]:
# First we calculate the EDUs we want to send, if any.
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index ab935e5a7eda..742ee572558d 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -35,6 +35,7 @@
import synapse.server
logger = logging.getLogger(__name__)
+issue_8631_logger = logging.getLogger("synapse.8631_debug")
last_pdu_ts_metric = Gauge(
"synapse_federation_last_sent_pdu_time",
@@ -124,6 +125,17 @@ async def send_new_transaction(
len(pdus),
len(edus),
)
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ DEVICE_UPDATE_EDUS = {"m.device_list_update", "m.signing_key_update"}
+ device_list_updates = [
+ edu.content for edu in edus if edu.edu_type in DEVICE_UPDATE_EDUS
+ ]
+ if device_list_updates:
+ issue_8631_logger.debug(
+ "about to send txn [%s] including device list updates: %s",
+ transaction.transaction_id,
+ device_list_updates,
+ )
# Actually send the transaction
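The new device-list debug output above is gated on a dedicated logger rather than the module logger, so it can be switched on by itself. A minimal sketch; the logger name comes from the diff, the handler setup is illustrative:

```python
import logging

# Enable only the targeted debug output without raising the global log level.
logging.basicConfig(level=logging.INFO)
logging.getLogger("synapse.8631_debug").setLevel(logging.DEBUG)

assert logging.getLogger("synapse.8631_debug").isEnabledFor(logging.DEBUG)
```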
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9fc4c31c93f6..8782586cd6b4 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -44,7 +44,6 @@
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser
-from synapse.logging.utils import log_function
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -62,7 +61,6 @@ def __init__(self, hs):
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
- @log_function
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
) -> JsonDict:
@@ -88,7 +86,6 @@ async def get_room_state_ids(
try_trailing_slash_on_400=True,
)
- @log_function
async def get_event(
self, destination: str, event_id: str, timeout: Optional[int] = None
) -> JsonDict:
@@ -111,7 +108,6 @@ async def get_event(
destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
)
- @log_function
async def backfill(
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
) -> Optional[JsonDict]:
@@ -149,7 +145,6 @@ async def backfill(
destination, path=path, args=args, try_trailing_slash_on_400=True
)
- @log_function
async def timestamp_to_event(
self, destination: str, room_id: str, timestamp: int, direction: str
) -> Union[JsonDict, List]:
@@ -185,7 +180,6 @@ async def timestamp_to_event(
return remote_response
- @log_function
async def send_transaction(
self,
transaction: Transaction,
@@ -234,7 +228,6 @@ async def send_transaction(
try_trailing_slash_on_400=True,
)
- @log_function
async def make_query(
self,
destination: str,
@@ -254,7 +247,6 @@ async def make_query(
ignore_backoff=ignore_backoff,
)
- @log_function
async def make_membership_event(
self,
destination: str,
@@ -317,7 +309,6 @@ async def make_membership_event(
ignore_backoff=ignore_backoff,
)
- @log_function
async def send_join_v1(
self,
room_version: RoomVersion,
@@ -336,7 +327,6 @@ async def send_join_v1(
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
)
- @log_function
async def send_join_v2(
self,
room_version: RoomVersion,
@@ -355,7 +345,6 @@ async def send_join_v2(
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
)
- @log_function
async def send_leave_v1(
self, destination: str, room_id: str, event_id: str, content: JsonDict
) -> Tuple[int, JsonDict]:
@@ -372,7 +361,6 @@ async def send_leave_v1(
ignore_backoff=True,
)
- @log_function
async def send_leave_v2(
self, destination: str, room_id: str, event_id: str, content: JsonDict
) -> JsonDict:
@@ -389,7 +377,6 @@ async def send_leave_v2(
ignore_backoff=True,
)
- @log_function
async def send_knock_v1(
self,
destination: str,
@@ -423,7 +410,6 @@ async def send_knock_v1(
destination=destination, path=path, data=content
)
- @log_function
async def send_invite_v1(
self, destination: str, room_id: str, event_id: str, content: JsonDict
) -> Tuple[int, JsonDict]:
@@ -433,7 +419,6 @@ async def send_invite_v1(
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
async def send_invite_v2(
self, destination: str, room_id: str, event_id: str, content: JsonDict
) -> JsonDict:
@@ -443,7 +428,6 @@ async def send_invite_v2(
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
async def get_public_rooms(
self,
remote_server: str,
@@ -516,7 +500,6 @@ async def get_public_rooms(
return response
- @log_function
async def exchange_third_party_invite(
self, destination: str, room_id: str, event_dict: JsonDict
) -> JsonDict:
@@ -526,7 +509,6 @@ async def exchange_third_party_invite(
destination=destination, path=path, data=event_dict
)
- @log_function
async def get_event_auth(
self, destination: str, room_id: str, event_id: str
) -> JsonDict:
@@ -534,7 +516,6 @@ async def get_event_auth(
return await self.client.get_json(destination=destination, path=path)
- @log_function
async def query_client_keys(
self, destination: str, query_content: JsonDict, timeout: int
) -> JsonDict:
@@ -576,7 +557,6 @@ async def query_client_keys(
destination=destination, path=path, data=query_content, timeout=timeout
)
- @log_function
async def query_user_devices(
self, destination: str, user_id: str, timeout: int
) -> JsonDict:
@@ -616,7 +596,6 @@ async def query_user_devices(
destination=destination, path=path, timeout=timeout
)
- @log_function
async def claim_client_keys(
self, destination: str, query_content: JsonDict, timeout: int
) -> JsonDict:
@@ -655,7 +634,6 @@ async def claim_client_keys(
destination=destination, path=path, data=query_content, timeout=timeout
)
- @log_function
async def get_missing_events(
self,
destination: str,
@@ -680,7 +658,6 @@ async def get_missing_events(
timeout=timeout,
)
- @log_function
async def get_group_profile(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -694,7 +671,6 @@ async def get_group_profile(
ignore_backoff=True,
)
- @log_function
async def update_group_profile(
self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
@@ -716,7 +692,6 @@ async def update_group_profile(
ignore_backoff=True,
)
- @log_function
async def get_group_summary(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -730,7 +705,6 @@ async def get_group_summary(
ignore_backoff=True,
)
- @log_function
async def get_rooms_in_group(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -798,7 +772,6 @@ async def remove_room_from_group(
ignore_backoff=True,
)
- @log_function
async def get_users_in_group(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -812,7 +785,6 @@ async def get_users_in_group(
ignore_backoff=True,
)
- @log_function
async def get_invited_users_in_group(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -826,7 +798,6 @@ async def get_invited_users_in_group(
ignore_backoff=True,
)
- @log_function
async def accept_group_invite(
self, destination: str, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
@@ -837,7 +808,6 @@ async def accept_group_invite(
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
def join_group(
self, destination: str, group_id: str, user_id: str, content: JsonDict
) -> Awaitable[JsonDict]:
@@ -848,7 +818,6 @@ def join_group(
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
async def invite_to_group(
self,
destination: str,
@@ -868,7 +837,6 @@ async def invite_to_group(
ignore_backoff=True,
)
- @log_function
async def invite_to_group_notification(
self, destination: str, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
@@ -882,7 +850,6 @@ async def invite_to_group_notification(
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
async def remove_user_from_group(
self,
destination: str,
@@ -902,7 +869,6 @@ async def remove_user_from_group(
ignore_backoff=True,
)
- @log_function
async def remove_user_from_group_notification(
self, destination: str, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
@@ -916,7 +882,6 @@ async def remove_user_from_group_notification(
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
async def renew_group_attestation(
self, destination: str, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
@@ -930,7 +895,6 @@ async def renew_group_attestation(
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
async def update_group_summary_room(
self,
destination: str,
@@ -959,7 +923,6 @@ async def update_group_summary_room(
ignore_backoff=True,
)
- @log_function
async def delete_group_summary_room(
self,
destination: str,
@@ -986,7 +949,6 @@ async def delete_group_summary_room(
ignore_backoff=True,
)
- @log_function
async def get_group_categories(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -1000,7 +962,6 @@ async def get_group_categories(
ignore_backoff=True,
)
- @log_function
async def get_group_category(
self, destination: str, group_id: str, requester_user_id: str, category_id: str
) -> JsonDict:
@@ -1014,7 +975,6 @@ async def get_group_category(
ignore_backoff=True,
)
- @log_function
async def update_group_category(
self,
destination: str,
@@ -1034,7 +994,6 @@ async def update_group_category(
ignore_backoff=True,
)
- @log_function
async def delete_group_category(
self, destination: str, group_id: str, requester_user_id: str, category_id: str
) -> JsonDict:
@@ -1048,7 +1007,6 @@ async def delete_group_category(
ignore_backoff=True,
)
- @log_function
async def get_group_roles(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -1062,7 +1020,6 @@ async def get_group_roles(
ignore_backoff=True,
)
- @log_function
async def get_group_role(
self, destination: str, group_id: str, requester_user_id: str, role_id: str
) -> JsonDict:
@@ -1076,7 +1033,6 @@ async def get_group_role(
ignore_backoff=True,
)
- @log_function
async def update_group_role(
self,
destination: str,
@@ -1096,7 +1052,6 @@ async def update_group_role(
ignore_backoff=True,
)
- @log_function
async def delete_group_role(
self, destination: str, group_id: str, requester_user_id: str, role_id: str
) -> JsonDict:
@@ -1110,7 +1065,6 @@ async def delete_group_role(
ignore_backoff=True,
)
- @log_function
async def update_group_summary_user(
self,
destination: str,
@@ -1136,7 +1090,6 @@ async def update_group_summary_user(
ignore_backoff=True,
)
- @log_function
async def set_group_join_policy(
self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
@@ -1151,7 +1104,6 @@ async def set_group_join_policy(
ignore_backoff=True,
)
- @log_function
async def delete_group_summary_user(
self,
destination: str,
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index dc39e3537bf6..da1fbf8b6361 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -22,13 +22,11 @@
from synapse.http.server import HttpServer, ServletCallback
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.logging import opentracing
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
- SynapseTags,
- start_active_span,
- start_active_span_from_request,
- tags,
+ set_tag,
+ span_context_from_request,
+ start_active_span_follows_from,
whitelisted_homeserver,
)
from synapse.server import HomeServer
@@ -279,30 +277,19 @@ async def new_func(
logger.warning("authenticate_request failed: %s", e)
raise
- request_tags = {
- SynapseTags.REQUEST_ID: request.get_request_id(),
- tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
- tags.HTTP_METHOD: request.get_method(),
- tags.HTTP_URL: request.get_redacted_uri(),
- tags.PEER_HOST_IPV6: request.getClientIP(),
- "authenticated_entity": origin,
- "servlet_name": request.request_metrics.name,
- }
-
- # Only accept the span context if the origin is authenticated
- # and whitelisted
+ # update the active opentracing span with the authenticated entity
+ set_tag("authenticated_entity", origin)
+
+ # if the origin is authenticated and whitelisted, link to its span context
+ context = None
if origin and whitelisted_homeserver(origin):
- scope = start_active_span_from_request(
- request, "incoming-federation-request", tags=request_tags
- )
- else:
- scope = start_active_span(
- "incoming-federation-request", tags=request_tags
- )
+ context = span_context_from_request(request)
- with scope:
- opentracing.inject_response_headers(request.responseHeaders)
+ scope = start_active_span_follows_from(
+ "incoming-federation-request", contexts=(context,) if context else ()
+ )
+ with scope:
if origin and self.RATELIMIT:
with ratelimiter.ratelimit(origin) as d:
await d
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 77bfd88ad052..beadfa422ba3 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -36,6 +36,7 @@
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
+issue_8631_logger = logging.getLogger("synapse.8631_debug")
class BaseFederationServerServlet(BaseFederationServlet):
@@ -95,6 +96,20 @@ async def on_PUT(
len(transaction_data.get("edus", [])),
)
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ DEVICE_UPDATE_EDUS = {"m.device_list_update", "m.signing_key_update"}
+ device_list_updates = [
+ edu.content
+ for edu in transaction_data.get("edus", [])
+ if edu.edu_type in DEVICE_UPDATE_EDUS
+ ]
+ if device_list_updates:
+ issue_8631_logger.debug(
+ "received transaction [%s] including device list updates: %s",
+ transaction_id,
+ device_list_updates,
+ )
+
except Exception as e:
logger.exception(e)
return 400, {"error": "Invalid transaction"}
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 96273e2f81c3..bad48713bcb1 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -77,7 +77,7 @@ async def add_account_data_to_room(
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
- """Add some account_data to a room for a user.
+ """Add some global account_data for a user.
Args:
user_id: The user to add a tag for.
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 85157a138b71..00ab5e79bf2e 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -55,21 +55,47 @@ async def get_whois(self, user: UserID) -> JsonDict:
async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
- ret = await self.store.get_user_by_id(user.to_string())
- if ret:
- profile = await self.store.get_profileinfo(user.localpart)
- threepids = await self.store.user_get_threepids(user.to_string())
- external_ids = [
- ({"auth_provider": auth_provider, "external_id": external_id})
- for auth_provider, external_id in await self.store.get_external_ids_by_user(
- user.to_string()
- )
- ]
- ret["displayname"] = profile.display_name
- ret["avatar_url"] = profile.avatar_url
- ret["threepids"] = threepids
- ret["external_ids"] = external_ids
- return ret
+ user_info_dict = await self.store.get_user_by_id(user.to_string())
+ if user_info_dict is None:
+ return None
+
+ # Restrict returned information to a known set of fields. This prevents additional
+ # fields added to get_user_by_id from modifying Synapse's external API surface.
+ user_info_to_return = {
+ "name",
+ "admin",
+ "deactivated",
+ "shadow_banned",
+ "creation_ts",
+ "appservice_id",
+ "consent_server_notice_sent",
+ "consent_version",
+ "user_type",
+ "is_guest",
+ }
+
+ # Restrict returned keys to a known set.
+ user_info_dict = {
+ key: value
+ for key, value in user_info_dict.items()
+ if key in user_info_to_return
+ }
+
+ # Add additional user metadata
+ profile = await self.store.get_profileinfo(user.localpart)
+ threepids = await self.store.user_get_threepids(user.to_string())
+ external_ids = [
+ ({"auth_provider": auth_provider, "external_id": external_id})
+ for auth_provider, external_id in await self.store.get_external_ids_by_user(
+ user.to_string()
+ )
+ ]
+ user_info_dict["displayname"] = profile.display_name
+ user_info_dict["avatar_url"] = profile.avatar_url
+ user_info_dict["threepids"] = threepids
+ user_info_dict["external_ids"] = external_ids
+
+ return user_info_dict
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
"""Write all data we have on the user to the given writer.
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 9abdad262b78..7833e77e2b6b 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -462,9 +462,9 @@ async def query_room_alias_exists(
Args:
room_alias: The room alias to query.
+
Returns:
- namedtuple: with keys "room_id" and "servers" or None if no
- association can be found.
+ RoomAliasMapping or None if no association can be found.
"""
room_alias_str = room_alias.to_string()
services = self.store.get_app_services()
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 61607cf2bad7..bd1a3225638a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -168,25 +168,25 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
}
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class SsoLoginExtraAttributes:
"""Data we track about SAML2 sessions"""
# time the session was created, in milliseconds
- creation_time = attr.ib(type=int)
- extra_attributes = attr.ib(type=JsonDict)
+ creation_time: int
+ extra_attributes: JsonDict
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class LoginTokenAttributes:
"""Data we store in a short-term login token"""
- user_id = attr.ib(type=str)
+ user_id: str
- auth_provider_id = attr.ib(type=str)
+ auth_provider_id: str
"""The SSO Identity Provider that the user authenticated with, to get this token."""
- auth_provider_session_id = attr.ib(type=Optional[str])
+ auth_provider_session_id: Optional[str]
"""The session ID advertised by the SSO Identity Provider."""
@@ -997,9 +997,7 @@ async def create_access_token_for_user_id(
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
- try:
- await self.store.get_device(user_id, device_id)
- except StoreError:
+ if await self.store.get_device(user_id, device_id) is None:
await self.store.delete_access_token(access_token)
raise StoreError(400, "Login raced against device deletion")
@@ -2283,7 +2281,7 @@ async def on_logged_out(
# call all of the on_logged_out callbacks
for callback in self.on_logged_out_callbacks:
try:
- callback(user_id, device_id, access_token)
+ await callback(user_id, device_id, access_token)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
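Several classes in this diff move from explicit `attr.ib(type=...)` declarations to `auto_attribs=True` with ordinary annotations. The two spellings define equivalent classes; a small comparison sketch (class names are illustrative, not Synapse's):

```python
from typing import Optional

import attr


# Old style: attributes declared with attr.ib() and a `type` argument.
@attr.s(slots=True, frozen=True)
class TokenOld:
    user_id = attr.ib(type=str)
    session_id = attr.ib(type=Optional[str], default=None)


# New style: auto_attribs picks up plain annotations; defaults become assignments.
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TokenNew:
    user_id: str
    session_id: Optional[str] = None


assert TokenOld("@a:b").user_id == TokenNew("@a:b").user_id
```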
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 82ee11e921e6..b184a48cb16c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -106,10 +106,10 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict:
Raises:
errors.NotFoundError: if the device was not found
"""
- try:
- device = await self.store.get_device(user_id, device_id)
- except errors.StoreError:
- raise errors.NotFoundError
+ device = await self.store.get_device(user_id, device_id)
+ if device is None:
+ raise errors.NotFoundError()
+
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
@@ -602,6 +602,8 @@ async def rehydrate_device(
access_token, device_id
)
old_device = await self.store.get_device(user_id, old_device_id)
+ if old_device is None:
+ raise errors.NotFoundError()
await self.store.update_device(user_id, device_id, old_device["display_name"])
# can't call self.delete_device because that will clobber the
# access token so call the storage layer directly
@@ -946,8 +948,16 @@ async def user_device_resync(
devices = []
ignore_devices = True
else:
+ prev_stream_id = await self.store.get_device_list_last_stream_id_for_remote(
+ user_id
+ )
cached_devices = await self.store.get_cached_devices_for_user(user_id)
- if cached_devices == {d["device_id"]: d for d in devices}:
+
+ # To ensure that a user with no devices is cached, we skip the resync only
+ # if we have a stream_id from previously writing a cache entry.
+ if prev_stream_id is not None and cached_devices == {
+ d["device_id"]: d for d in devices
+ }:
logging.info(
"Skipping device list resync for %s, as our cache matches already",
user_id,
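`get_device` now returns `None` for a missing device instead of raising `StoreError`, so callers check the return value explicitly and raise `NotFoundError` themselves. A standalone sketch of the calling pattern (the store and error class below are stand-ins, not Synapse's):

```python
from typing import Dict, Optional, Tuple


class NotFoundError(Exception):
    """Raised when the requested device does not exist."""


# Stand-in for the storage layer: maps (user_id, device_id) -> device info.
_DEVICES: Dict[Tuple[str, str], dict] = {
    ("@alice:example.org", "DEV1"): {"display_name": "phone"},
}


def get_device(user_id: str, device_id: str) -> Optional[dict]:
    # Returns None rather than raising when the row is absent.
    return _DEVICES.get((user_id, device_id))


def get_device_or_404(user_id: str, device_id: str) -> dict:
    device = get_device(user_id, device_id)
    if device is None:
        raise NotFoundError()
    return device


print(get_device_or_404("@alice:example.org", "DEV1"))
```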
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 7ee5c47fd96b..082f521791cd 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -278,13 +278,15 @@ async def get_association(self, room_alias: RoomAlias) -> JsonDict:
users = await self.store.get_users_in_room(room_id)
extra_servers = {get_domain_from_id(u) for u in users}
- servers = set(extra_servers) | set(servers)
+ servers_set = set(extra_servers) | set(servers)
# If this server is in the list of servers, return it first.
- if self.server_name in servers:
- servers = [self.server_name] + [s for s in servers if s != self.server_name]
+ if self.server_name in servers_set:
+ servers = [self.server_name] + [
+ s for s in servers_set if s != self.server_name
+ ]
else:
- servers = list(servers)
+ servers = list(servers_set)
return {"room_id": room_id, "servers": servers}
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 60c11e3d2128..d4dfddf63fb4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -65,8 +65,12 @@ def __init__(self, hs: "HomeServer"):
else:
# Only register this edu handler on master as it requires writing
# device updates to the db
- #
- # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ federation_registry.register_edu_handler(
+ "m.signing_key_update",
+ self._edu_updater.incoming_signing_key_update,
+ )
+ # also handle the unstable version
+ # FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
"org.matrix.signing_key_update",
self._edu_updater.incoming_signing_key_update,
@@ -576,7 +580,9 @@ async def upload_keys_for_user(
log_kv(
{"message": "Did not update one_time_keys", "reason": "no keys given"}
)
- fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
+ fallback_keys = keys.get("fallback_keys") or keys.get(
+ "org.matrix.msc2732.fallback_keys"
+ )
if fallback_keys and isinstance(fallback_keys, dict):
log_kv(
{
@@ -1315,14 +1321,14 @@ def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
return old_key == new_key_copy
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class SignatureListItem:
"""An item in the signature list as used by upload_signatures_for_device_keys."""
- signing_key_id = attr.ib(type=str)
- target_user_id = attr.ib(type=str)
- target_device_id = attr.ib(type=str)
- signature = attr.ib(type=JsonDict)
+ signing_key_id: str
+ target_user_id: str
+ target_device_id: str
+ signature: JsonDict
class SigningKeyEduUpdater:
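Two things happen in the e2e_keys hunks: the signing-key updater is registered under both the stable `m.signing_key_update` EDU type and the unstable `org.matrix.signing_key_update` name, and fallback keys are read from the stable field first, falling back to the MSC2732-prefixed field. The lookup pattern on its own:

```python
from typing import Any, Dict, Optional


def get_fallback_keys(keys: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    # Prefer the stable field name; fall back to the unstable MSC2732 prefix
    # so clients that have not migrated keep working during the transition.
    return keys.get("fallback_keys") or keys.get("org.matrix.msc2732.fallback_keys")


print(get_fallback_keys({"org.matrix.msc2732.fallback_keys": {"signed_curve25519:AAAA": {}}}))
```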
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 31742236a94d..12614b2c5d5a 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -14,7 +14,9 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Dict, Optional
+
+from typing_extensions import Literal
from synapse.api.errors import (
Codes,
@@ -24,6 +26,7 @@
SynapseError,
)
from synapse.logging.opentracing import log_kv, trace
+from synapse.storage.databases.main.e2e_room_keys import RoomKey
from synapse.types import JsonDict
from synapse.util.async_helpers import Linearizer
@@ -58,7 +61,9 @@ async def get_room_keys(
version: str,
room_id: Optional[str] = None,
session_id: Optional[str] = None,
- ) -> List[JsonDict]:
+ ) -> Dict[
+ Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+ ]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
@@ -72,8 +77,8 @@ async def get_room_keys(
Raises:
NotFoundError: if the backup version does not exist
Returns:
- A list of dicts giving the session_data and message metadata for
- these room keys.
+ A dict giving the session_data and message metadata for these room keys.
+ `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
"""
# we deliberately take the lock to get keys so that changing the version
@@ -273,7 +278,7 @@ async def upload_room_keys(
@staticmethod
def _should_replace_room_key(
- current_room_key: Optional[JsonDict], room_key: JsonDict
+ current_room_key: Optional[RoomKey], room_key: RoomKey
) -> bool:
"""
Determine whether to replace a given current_room_key (if any)
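`get_room_keys` now advertises the nested shape it returns: `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`. A small sketch of how that structure can be spelled with `Literal` keys (the `RoomKey` type is Synapse's; a plain dict stands in for it below):

```python
from typing import Dict

from typing_extensions import Literal

# Stand-in for synapse.storage.databases.main.e2e_room_keys.RoomKey.
RoomKey = Dict[str, object]

RoomKeysResult = Dict[
    Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
]

example: RoomKeysResult = {
    "rooms": {
        "!room:example.org": {
            "sessions": {
                "session1": {"first_message_index": 1, "session_data": {}},
            }
        }
    }
}
print(example["rooms"]["!room:example.org"]["sessions"].keys())
```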
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 32b0254c5f08..bac5de052609 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -20,7 +20,6 @@
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
-from synapse.logging.utils import log_function
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
@@ -43,7 +42,6 @@ def __init__(self, hs: "HomeServer"):
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
- @log_function
async def get_stream(
self,
auth_user_id: str,
@@ -79,13 +77,14 @@ async def get_stream(
# thundering herds on restart.
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
- events, tokens = await self.notifier.get_events_for(
+ stream_result = await self.notifier.get_events_for(
auth_user,
pagin_config,
timeout,
is_guest=is_guest,
explicit_room_id=room_id,
)
+ events = stream_result.events
time_now = self.clock.time_msec()
@@ -118,18 +117,16 @@ async def get_stream(
events.extend(to_add)
- chunks = await self._event_serializer.serialize_events(
+ chunks = self._event_serializer.serialize_events(
events,
time_now,
as_client_event=as_client_event,
- # Don't bundle aggregations as this is a deprecated API.
- bundle_aggregations=False,
)
chunk = {
"chunk": chunks,
- "start": await tokens[0].to_string(self.store),
- "end": await tokens[1].to_string(self.store),
+ "start": await stream_result.start_token.to_string(self.store),
+ "end": await stream_result.end_token.to_string(self.store),
}
return chunk
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1ea837d08211..a37ae0ca094f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -51,7 +51,6 @@
preserve_fn,
run_in_background,
)
-from synapse.logging.utils import log_function
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationStoreRoomOnOutlierMembershipRestServlet,
@@ -360,31 +359,34 @@ async def try_backfill(domains: List[str]) -> bool:
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
- states = await make_deferred_yieldable(
+ states_list = await make_deferred_yieldable(
defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
)
)
- # dict[str, dict[tuple, str]], a map from event_id to state map of
- # event_ids.
- states = dict(zip(event_ids, [s.state for s in states]))
+ # A map from event_id to state map of event_ids.
+ state_ids: Dict[str, StateMap[str]] = dict(
+ zip(event_ids, [s.state for s in states_list])
+ )
state_map = await self.store.get_events(
- [e_id for ids in states.values() for e_id in ids.values()],
+ [e_id for ids in state_ids.values() for e_id in ids.values()],
get_prev_content=False,
)
- states = {
+
+ # A map from event_id to state map of events.
+ state_events: Dict[str, StateMap[EventBase]] = {
key: {
k: state_map[e_id]
for k, e_id in state_dict.items()
if e_id in state_map
}
- for key, state_dict in states.items()
+ for key, state_dict in state_ids.items()
}
for e_id in event_ids:
- likely_extremeties_domains = get_domains_from_state(states[e_id])
+ likely_extremeties_domains = get_domains_from_state(state_events[e_id])
success = await try_backfill(
[
@@ -553,7 +555,6 @@ async def do_invite_join(
run_in_background(self._handle_queued_pdus, room_queue)
- @log_function
async def do_knock(
self,
target_hosts: List[str],
@@ -925,7 +926,6 @@ async def on_make_leave_request(
return event
- @log_function
async def on_make_knock_request(
self, origin: str, room_id: str, user_id: str
) -> EventBase:
@@ -1036,7 +1036,6 @@ async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
else:
return []
- @log_function
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
) -> List[EventBase]:
@@ -1053,7 +1052,6 @@ async def on_backfill_request(
return events
- @log_function
async def get_persisted_pdu(
self, origin: str, event_id: str
) -> Optional[EventBase]:
@@ -1115,7 +1113,6 @@ async def on_get_missing_events(
return missing_events
- @log_function
async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
) -> None:
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9917613298c6..3905f60b3a78 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -56,7 +56,6 @@
from synapse.events.snapshot import EventContext
from synapse.federation.federation_client import InvalidResponseError
from synapse.logging.context import nested_logging_context, run_in_background
-from synapse.logging.utils import log_function
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.federation import (
@@ -275,7 +274,6 @@ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
await self._process_received_pdu(origin, pdu, state=None)
- @log_function
async def on_send_membership_event(
self, origin: str, event: EventBase
) -> Tuple[EventBase, EventContext]:
@@ -421,9 +419,6 @@ async def process_remote_join(
Raises:
SynapseError if the response is in some way invalid.
"""
- for e in itertools.chain(auth_events, state):
- e.internal_metadata.outlier = True
-
event_map = {e.event_id: e for e in itertools.chain(auth_events, state)}
create_event = None
@@ -475,7 +470,6 @@ async def process_remote_join(
return await self.persist_events_and_notify(room_id, [(event, context)])
- @log_function
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> None:
@@ -666,7 +660,9 @@ async def _process_pulled_event(
logger.info("Processing pulled event %s", event)
# these should not be outliers.
- assert not event.internal_metadata.is_outlier()
+ assert (
+ not event.internal_metadata.is_outlier()
+ ), "pulled event unexpectedly flagged as outlier"
event_id = event.event_id
@@ -1192,7 +1188,6 @@ async def get_event(event_id: str) -> None:
[destination],
event_id,
room_version,
- outlier=True,
)
if event is None:
logger.warning(
@@ -1221,9 +1216,10 @@ async def _auth_and_persist_outliers(
"""Persist a batch of outlier events fetched from remote servers.
We first sort the events to make sure that we process each event's auth_events
- before the event itself, and then auth and persist them.
+ before the event itself.
- Notifies about the events where appropriate.
+ We then mark the events as outliers, persist them to the database, and, where
+ appropriate (eg, an invite), awake the notifier.
Params:
room_id: the room that the events are meant to be in (though this has
@@ -1274,7 +1270,8 @@ async def _auth_and_persist_outliers_inner(
Persists a batch of events where we have (theoretically) already persisted all
of their auth events.
- Notifies about the events where appropriate.
+ Marks the events as outliers, auths them, persists them to the database, and,
+ where appropriate (eg, an invite), awakes the notifier.
Params:
origin: where the events came from
@@ -1312,6 +1309,9 @@ def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
return None
auth.append(ae)
+ # we're not bothering about room state, so flag the event as an outlier.
+ event.internal_metadata.outlier = True
+
context = EventContext.for_outlier()
try:
validate_event_for_room_version(room_version_obj, event)
@@ -1838,7 +1838,7 @@ async def persist_events_and_notify(
The stream ID after which all events have been persisted.
"""
if not event_and_contexts:
- return self._store.get_current_events_token()
+ return self._store.get_room_max_stream_ordering()
instance = self._config.worker.events_shard_config.get_instance(room_id)
if instance != self._instance_name:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 9cd21e7f2b3c..346a06ff49b7 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -13,21 +13,27 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import SynapseError
+from synapse.events import EventBase
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.receipts import ReceiptEventSource
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+ JsonDict,
+ Requester,
+ RoomStreamToken,
+ StateMap,
+ StreamToken,
+ UserID,
+)
from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import concurrently_execute
+from synapse.util.async_helpers import concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
@@ -164,11 +170,9 @@ async def handle_room(event: RoomsForUser) -> None:
d["inviter"] = event.sender
invite_event = await self.store.get_event(event.event_id)
- d["invite"] = await self._event_serializer.serialize_event(
+ d["invite"] = self._event_serializer.serialize_event(
invite_event,
time_now,
- # Don't bundle aggregations as this is a deprecated API.
- bundle_aggregations=False,
as_client_event=as_client_event,
)
@@ -190,14 +194,13 @@ async def handle_room(event: RoomsForUser) -> None:
)
deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id]
- )
- deferred_room_state.addCallback(
- lambda states: states[event.event_id]
+ ).addCallback(
+ lambda states: cast(StateMap[EventBase], states[event.event_id])
)
(messages, token), current_state = await make_deferred_yieldable(
- defer.gatherResults(
- [
+ gather_results(
+ (
run_in_background(
self.store.get_recent_events_for_room,
event.room_id,
@@ -205,7 +208,7 @@ async def handle_room(event: RoomsForUser) -> None:
end_token=room_end_token,
),
deferred_room_state,
- ]
+ )
)
).addErrback(unwrapFirstError)
@@ -219,11 +222,9 @@ async def handle_room(event: RoomsForUser) -> None:
d["messages"] = {
"chunk": (
- await self._event_serializer.serialize_events(
+ self._event_serializer.serialize_events(
messages,
time_now=time_now,
- # Don't bundle aggregations as this is a deprecated API.
- bundle_aggregations=False,
as_client_event=as_client_event,
)
),
@@ -231,11 +232,9 @@ async def handle_room(event: RoomsForUser) -> None:
"end": await end_token.to_string(self.store),
}
- d["state"] = await self._event_serializer.serialize_events(
+ d["state"] = self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
- # Don't bundle aggregations as this is a deprecated API.
- bundle_aggregations=False,
as_client_event=as_client_event,
)
@@ -377,18 +376,14 @@ async def _room_initial_sync_parted(
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
- await self._event_serializer.serialize_events(
- messages, time_now, bundle_aggregations=False
- )
+ self._event_serializer.serialize_events(messages, time_now)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
},
"state": (
# Don't bundle aggregations as this is a deprecated API.
- await self._event_serializer.serialize_events(
- room_state.values(), time_now, bundle_aggregations=False
- )
+ self._event_serializer.serialize_events(room_state.values(), time_now)
),
"presence": [],
"receipts": [],
@@ -407,8 +402,8 @@ async def _room_initial_sync_joined(
# TODO: These concurrently
time_now = self.clock.time_msec()
# Don't bundle aggregations as this is a deprecated API.
- state = await self._event_serializer.serialize_events(
- current_state.values(), time_now, bundle_aggregations=False
+ state = self._event_serializer.serialize_events(
+ current_state.values(), time_now
)
now_token = self.hs.get_event_sources().get_current_token()
@@ -454,8 +449,8 @@ async def get_receipts() -> List[JsonDict]:
return receipts
presence, receipts, (messages, token) = await make_deferred_yieldable(
- defer.gatherResults(
- [
+ gather_results(
+ (
run_in_background(get_presence),
run_in_background(get_receipts),
run_in_background(
@@ -464,7 +459,7 @@ async def get_receipts() -> List[JsonDict]:
limit=limit,
end_token=now_token.room_key,
),
- ],
+ ),
consumeErrors=True,
).addErrback(unwrapFirstError)
)
@@ -483,9 +478,7 @@ async def get_receipts() -> List[JsonDict]:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
- await self._event_serializer.serialize_events(
- messages, time_now, bundle_aggregations=False
- )
+ self._event_serializer.serialize_events(messages, time_now)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 87f671708c4e..b37250aa3895 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -21,7 +21,6 @@
from canonicaljson import encode_canonical_json
-from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall
from synapse import event_auth
@@ -57,7 +56,7 @@
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.util import json_decoder, json_encoder, log_failure
-from synapse.util.async_helpers import Linearizer, unwrapFirstError
+from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
@@ -247,7 +246,7 @@ async def get_state_events(
room_state = room_state_events[membership_event_id]
now = self.clock.time_msec()
- events = await self._event_serializer.serialize_events(room_state.values(), now)
+ events = self._event_serializer.serialize_events(room_state.values(), now)
return events
async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
@@ -496,6 +495,7 @@ async def create_event(
require_consent: bool = True,
outlier: bool = False,
historical: bool = False,
+ allow_no_prev_events: bool = False,
depth: Optional[int] = None,
) -> Tuple[EventBase, EventContext]:
"""
@@ -607,6 +607,7 @@ async def create_event(
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
depth=depth,
+ allow_no_prev_events=allow_no_prev_events,
)
# In an ideal world we wouldn't need the second part of this condition. However,
@@ -882,6 +883,7 @@ async def create_new_client_event(
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
+ allow_no_prev_events: bool = False,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
@@ -912,6 +914,7 @@ async def create_new_client_event(
full_state_ids_at_event = None
if auth_event_ids is not None:
# If auth events are provided, prev events must be also.
+ # prev_event_ids could be an empty array though.
assert prev_event_ids is not None
# Copy the full auth state before it stripped down
@@ -943,14 +946,22 @@ async def create_new_client_event(
else:
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
- # we now ought to have some prev_events (unless it's a create event).
- #
- # do a quick sanity check here, rather than waiting until we've created the
+ # Do a quick sanity check here, rather than waiting until we've created the
# event and then try to auth it (which fails with a somewhat confusing "No
# create event in auth events")
- assert (
- builder.type == EventTypes.Create or len(prev_event_ids) > 0
- ), "Attempting to create an event with no prev_events"
+ if allow_no_prev_events:
+ # We allow events with no `prev_events`, but they had better have some `auth_events`
+ assert (
+ builder.type == EventTypes.Create
+ # Allow an event to have empty list of prev_event_ids
+ # only if it has auth_event_ids.
+ or auth_event_ids
+ ), "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids"
+ else:
+ # we now ought to have some prev_events (unless it's a create event).
+ assert (
+ builder.type == EventTypes.Create or prev_event_ids
+ ), "Attempting to create a non-m.room.create event with no prev_events"
event = await builder.build(
prev_event_ids=prev_event_ids,
@@ -1156,9 +1167,9 @@ async def handle_new_client_event(
# We now persist the event (and update the cache in parallel, since we
# don't want to block on it).
- result = await make_deferred_yieldable(
- defer.gatherResults(
- [
+ result, _ = await make_deferred_yieldable(
+ gather_results(
+ (
run_in_background(
self._persist_event,
requester=requester,
@@ -1170,12 +1181,12 @@ async def handle_new_client_event(
run_in_background(
self.cache_joined_hosts_for_event, event, context
).addErrback(log_failure, "cache_joined_hosts_for_event failed"),
- ],
+ ),
consumeErrors=True,
)
).addErrback(unwrapFirstError)
- return result[0]
+ return result
async def _persist_event(
self,
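The new `allow_no_prev_events` flag relaxes the sanity check in `create_new_client_event`: an event may carry an empty `prev_events` list only when it is an `m.room.create` event or when explicit `auth_events` were supplied. The rule pulled out into a standalone check (a sketch, not Synapse's code):

```python
from typing import Sequence


def check_prev_events(
    event_type: str,
    prev_event_ids: Sequence[str],
    auth_event_ids: Sequence[str],
    allow_no_prev_events: bool,
) -> None:
    if allow_no_prev_events:
        # An empty prev_events list is acceptable, but only alongside auth_events.
        assert event_type == "m.room.create" or auth_event_ids, (
            "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids"
        )
    else:
        # The usual case: everything except the create event needs prev_events.
        assert event_type == "m.room.create" or prev_event_ids, (
            "Attempting to create a non-m.room.create event with no prev_events"
        )


check_prev_events("m.room.message", [], ["$auth1"], allow_no_prev_events=True)  # passes
```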
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 4f424380533b..973f262964c1 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -537,12 +537,17 @@ async def get_messages(
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()
+ aggregations = await self.store.get_bundled_aggregations(events, user_id)
+
time_now = self.clock.time_msec()
chunk = {
"chunk": (
- await self._event_serializer.serialize_events(
- events, time_now, as_client_event=as_client_event
+ self._event_serializer.serialize_events(
+ events,
+ time_now,
+ bundle_aggregations=aggregations,
+ as_client_event=as_client_event,
)
),
"start": await from_token.to_string(self.store),
@@ -550,7 +555,7 @@ async def get_messages(
}
if state:
- chunk["state"] = await self._event_serializer.serialize_events(
+ chunk["state"] = self._event_serializer.serialize_events(
state, time_now, as_client_event=as_client_event
)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index dc00f28038ea..f74efb375d0d 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -55,7 +55,6 @@
from synapse.appservice import ApplicationService
from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background
-from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.presence import (
@@ -729,7 +728,7 @@ def run_persister() -> Awaitable[None]:
# Presence is best effort and quickly heals itself, so lets just always
# stream from the current state when we restart.
- self._event_pos = self.store.get_current_events_token()
+ self._event_pos = self.store.get_room_max_stream_ordering()
self._event_processing = False
async def _on_shutdown(self) -> None:
@@ -1542,7 +1541,6 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- @log_function
async def get_new_events(
self,
user: UserID,
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 4911a1153519..5cb1ff749d92 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -14,7 +14,7 @@
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.appservice import ApplicationService
from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
@@ -178,7 +178,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:
for event_id in content.keys():
event_content = content.get(event_id, {})
- m_read = event_content.get("m.read", {})
+ m_read = event_content.get(ReceiptTypes.READ, {})
# If m_read is missing, copy over the original event_content as there is nothing to process here
if not m_read:
@@ -206,7 +206,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:
# Set new users unless empty
if len(new_users.keys()) > 0:
- new_event["content"][event_id] = {"m.read": new_users}
+ new_event["content"][event_id] = {ReceiptTypes.READ: new_users}
# Append new_event to visible_events unless empty
if len(new_event["content"].keys()) > 0:
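String literals such as `"m.read"` are replaced with the `ReceiptTypes` constants throughout. A minimal sketch of such a constants class, assuming the stable value `"m.read"` (Synapse defines the real one in `synapse.api.constants`):

```python
class ReceiptTypes:
    """Receipt types known to the server (sketch only)."""

    READ = "m.read"


event_content = {"m.read": {"@alice:example.org": {"ts": 1}}}
# Using the constant instead of the bare string literal.
m_read = event_content.get(ReceiptTypes.READ, {})
print(m_read)
```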
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ead2198e14fe..f963078e596c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -172,7 +172,7 @@ async def upgrade_room(
user_id = requester.user.to_string()
# Check if this room is already being upgraded by another person
- for key in self._upgrade_response_cache.pending_result_cache:
+ for key in self._upgrade_response_cache.keys():
if key[0] == old_room_id and key[1] != user_id:
# Two different people are trying to upgrade the same room.
# Send the second an error.
@@ -393,7 +393,9 @@ async def clone_existing_room(
user_id = requester.user.to_string()
if not await self.spam_checker.user_may_create_room(user_id):
- raise SynapseError(403, "You are not permitted to create rooms")
+ raise SynapseError(
+ 403, "You are not permitted to create rooms", Codes.FORBIDDEN
+ )
creation_content: JsonDict = {
"room_version": new_room_version.identifier,
@@ -685,7 +687,9 @@ async def create_room(
invite_3pid_list,
)
):
- raise SynapseError(403, "You are not permitted to create rooms")
+ raise SynapseError(
+ 403, "You are not permitted to create rooms", Codes.FORBIDDEN
+ )
if ratelimit:
await self.request_ratelimiter.ratelimit(requester)
@@ -1177,6 +1181,22 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
# `filtered` rather than the event we retrieved from the datastore.
results["event"] = filtered[0]
+ # Fetch the aggregations.
+ aggregations = await self.store.get_bundled_aggregations(
+ [results["event"]], user.to_string()
+ )
+ aggregations.update(
+ await self.store.get_bundled_aggregations(
+ results["events_before"], user.to_string()
+ )
+ )
+ aggregations.update(
+ await self.store.get_bundled_aggregations(
+ results["events_after"], user.to_string()
+ )
+ )
+ results["aggregations"] = aggregations
+
if results["events_after"]:
last_event_id = results["events_after"][-1].event_id
else:
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index ba7a14d651c5..1a33211a1fe1 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -13,9 +13,9 @@
# limitations under the License.
import logging
-from collections import namedtuple
from typing import TYPE_CHECKING, Any, Optional, Tuple
+import attr
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
@@ -474,16 +474,12 @@ async def _get_remote_list_cached(
)
-class RoomListNextBatch(
- namedtuple(
- "RoomListNextBatch",
- (
- "last_joined_members", # The count to get rooms after/before
- "last_room_id", # The room_id to get rooms after/before
- "direction_is_forward", # Bool if this is a next_batch, false if prev_batch
- ),
- )
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomListNextBatch:
+ last_joined_members: int # The count to get rooms after/before
+ last_room_id: str # The room_id to get rooms after/before
+ direction_is_forward: bool # True if this is a next_batch, false if prev_batch
+
KEY_DICT = {
"last_joined_members": "m",
"last_room_id": "r",
@@ -502,12 +498,12 @@ def from_token(cls, token: str) -> "RoomListNextBatch":
def to_token(self) -> str:
return encode_base64(
msgpack.dumps(
- {self.KEY_DICT[key]: val for key, val in self._asdict().items()}
+ {self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()}
)
)
def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch":
- return self._replace(**kwds)
+ return attr.evolve(self, **kwds)
def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
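`RoomListNextBatch` moves from a `namedtuple` to a frozen attrs class, so `._asdict()` becomes `attr.asdict()` and `._replace()` becomes `attr.evolve()`. A self-contained sketch of the round trip (msgpack and base64 encoding omitted for brevity):

```python
import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class NextBatch:
    last_joined_members: int
    last_room_id: str
    direction_is_forward: bool


batch = NextBatch(last_joined_members=5, last_room_id="!r:example.org", direction_is_forward=True)

# attr.asdict replaces namedtuple._asdict ...
print(attr.asdict(batch))
# ... and attr.evolve replaces namedtuple._replace.
print(attr.evolve(batch, direction_is_forward=False))
```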
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index a6dbff637f54..6aa910dd10f9 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -658,7 +658,8 @@ async def update_membership_locked(
if block_invite:
raise SynapseError(403, "Invites have been disabled on this server")
- if prev_event_ids:
+ # An empty prev_events list is allowed as long as the auth_event_ids are present
+ if prev_event_ids is not None:
return await self._local_membership_update(
requester=requester,
target=target,
@@ -1019,7 +1020,7 @@ async def transfer_room_state_on_room_upgrade(
# Add new room to the room directory if the old room was there
# Remove old room from the room directory
old_room = await self.store.get_room(old_room_id)
- if old_room and old_room["is_public"]:
+ if old_room is not None and old_room["is_public"]:
await self.store.set_room_is_public(old_room_id, False)
await self.store.set_room_is_public(room_id, True)
@@ -1030,7 +1031,9 @@ async def transfer_room_state_on_room_upgrade(
local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
for group_id in local_group_ids:
# Add the new room to those groups
- await self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
+ await self.store.add_room_to_group(
+ group_id, room_id, old_room is not None and old_room["is_public"]
+ )
# Remove the old room from those groups
await self.store.remove_room_from_group(group_id, old_room_id)
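The membership change distinguishes "no prev_events supplied" from "an explicitly empty list": `if prev_event_ids:` treats `[]` and `None` the same, whereas `if prev_event_ids is not None:` lets an empty list through. Illustrated in isolation:

```python
from typing import List, Optional


def takes_local_path(prev_event_ids: Optional[List[str]]) -> bool:
    # `is not None` accepts an explicitly empty list, which a truthiness check would reject.
    return prev_event_ids is not None


print(takes_local_path(None))    # False: caller did not specify prev_events
print(takes_local_path([]))      # True: caller explicitly passed an empty list
print(takes_local_path(["$e"]))  # True
```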
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index b2cfe537dfb1..4844b69a0345 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -153,6 +153,9 @@ async def get_space_summary(
rooms_result: List[JsonDict] = []
events_result: List[JsonDict] = []
+ if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE:
+ max_rooms_per_space = MAX_ROOMS_PER_SPACE
+
while room_queue and len(rooms_result) < MAX_ROOMS:
queue_entry = room_queue.popleft()
room_id = queue_entry.room_id
@@ -167,7 +170,7 @@ async def get_space_summary(
# The client-specified max_rooms_per_space limit doesn't apply to the
# room_id specified in the request, so we ignore it if this is the
# first room we are processing.
- max_children = max_rooms_per_space if processed_rooms else None
+ max_children = max_rooms_per_space if processed_rooms else MAX_ROOMS
if is_in_room:
room_entry = await self._summarize_local_room(
@@ -209,7 +212,7 @@ async def get_space_summary(
# Before returning to the client, remove the allowed_room_ids
# and allowed_spaces keys.
room.pop("allowed_room_ids", None)
- room.pop("allowed_spaces", None)
+ room.pop("allowed_spaces", None) # historical
rooms_result.append(room)
events.extend(room_entry.children_state_events)
@@ -395,7 +398,7 @@ async def _get_room_hierarchy(
None,
room_id,
suggested_only,
- # TODO Handle max children.
+ # Do not limit the maximum number of children.
max_children=None,
)
@@ -525,6 +528,10 @@ async def federation_space_summary(
rooms_result: List[JsonDict] = []
events_result: List[JsonDict] = []
+ # Set a limit on the number of rooms to return.
+ if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE:
+ max_rooms_per_space = MAX_ROOMS_PER_SPACE
+
while room_queue and len(rooms_result) < MAX_ROOMS:
room_id = room_queue.popleft()
if room_id in processed_rooms:
@@ -583,7 +590,9 @@ async def get_federation_hierarchy(
# Iterate through each child and potentially add it, but not its children,
# to the response.
- for child_room in root_room_entry.children_state_events:
+ for child_room in itertools.islice(
+ root_room_entry.children_state_events, MAX_ROOMS_PER_SPACE
+ ):
room_id = child_room.get("state_key")
assert isinstance(room_id, str)
# If the room is unknown, skip it.
@@ -633,8 +642,8 @@ async def _summarize_local_room(
suggested_only: True if only suggested children should be returned.
Otherwise, all children are returned.
max_children:
- The maximum number of children rooms to include. This is capped
- to a server-set limit.
+ The maximum number of children rooms to include. A value of None
+ means no limit.
Returns:
A room entry if the room should be returned. None, otherwise.
@@ -656,8 +665,13 @@ async def _summarize_local_room(
# we only care about suggested children
child_events = filter(_is_suggested_child_event, child_events)
- if max_children is None or max_children > MAX_ROOMS_PER_SPACE:
- max_children = MAX_ROOMS_PER_SPACE
+ # TODO max_children is legacy code for the /spaces endpoint.
+ if max_children is not None:
+ child_iter: Iterable[EventBase] = itertools.islice(
+ child_events, max_children
+ )
+ else:
+ child_iter = child_events
stripped_events: List[JsonDict] = [
{
@@ -668,7 +682,7 @@ async def _summarize_local_room(
"sender": e.sender,
"origin_server_ts": e.origin_server_ts,
}
- for e in itertools.islice(child_events, max_children)
+ for e in child_iter
]
return _RoomEntry(room_id, room_entry, stripped_events)
@@ -766,6 +780,7 @@ async def _summarize_remote_room_hierarchy(
try:
(
room_response,
+ children_state_events,
children,
inaccessible_children,
) = await self._federation_client.get_room_hierarchy(
@@ -790,7 +805,7 @@ async def _summarize_remote_room_hierarchy(
}
return (
- _RoomEntry(room_id, room_response, room_response.pop("children_state", ())),
+ _RoomEntry(room_id, room_response, children_state_events),
children_by_room_id,
set(inaccessible_children),
)
@@ -988,12 +1003,14 @@ async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDic
"canonical_alias": stats["canonical_alias"],
"num_joined_members": stats["joined_members"],
"avatar_url": stats["avatar"],
+ # plural join_rules is a documentation error but kept for historical
+ # purposes. Should match /publicRooms.
"join_rules": stats["join_rules"],
+ "join_rule": stats["join_rules"],
"world_readable": (
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
),
"guest_can_join": stats["guest_access"] == "can_join",
- "creation_ts": create_event.origin_server_ts,
"room_type": create_event.content.get(EventContentFields.ROOM_TYPE),
}
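Several hunks above cap potentially unbounded child lists with `itertools.islice`, while still allowing `None` to mean "no limit". The pattern on its own:

```python
import itertools
from typing import Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")


def cap(items: Iterable[T], max_items: Optional[int]) -> Iterator[T]:
    """Yield at most max_items entries; a limit of None means unlimited."""
    if max_items is not None:
        return itertools.islice(items, max_items)
    return iter(items)


print(list(cap(range(10), 3)))    # [0, 1, 2]
print(list(cap(range(5), None)))  # [0, 1, 2, 3, 4]
```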
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index ab7eaab2fb56..0b153a682261 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -420,10 +420,10 @@ async def search(
time_now = self.clock.time_msec()
for context in contexts.values():
- context["events_before"] = await self._event_serializer.serialize_events(
+ context["events_before"] = self._event_serializer.serialize_events(
context["events_before"], time_now
)
- context["events_after"] = await self._event_serializer.serialize_events(
+ context["events_after"] = self._event_serializer.serialize_events(
context["events_after"], time_now
)
@@ -441,9 +441,7 @@ async def search(
results.append(
{
"rank": rank_map[e.event_id],
- "result": (
- await self._event_serializer.serialize_event(e, time_now)
- ),
+ "result": self._event_serializer.serialize_event(e, time_now),
"context": contexts.get(e.event_id, {}),
}
)
@@ -457,7 +455,7 @@ async def search(
if state_results:
s = {}
for room_id, state_events in state_results.items():
- s[room_id] = await self._event_serializer.serialize_events(
+ s[room_id] = self._event_serializer.serialize_events(
state_events, time_now
)
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 65c27bc64a5e..0bb8b0929e7e 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -126,45 +126,45 @@ async def handle_redirect_request(
raise NotImplementedError()
-@attr.s
+@attr.s(auto_attribs=True)
class UserAttributes:
# the localpart of the mxid that the mapper has assigned to the user.
# if `None`, the mapper has not picked a userid, and the user should be prompted to
# enter one.
- localpart = attr.ib(type=Optional[str])
- display_name = attr.ib(type=Optional[str], default=None)
- emails = attr.ib(type=Collection[str], default=attr.Factory(list))
+ localpart: Optional[str]
+ display_name: Optional[str] = None
+ emails: Collection[str] = attr.Factory(list)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class UsernameMappingSession:
"""Data we track about SSO sessions"""
# A unique identifier for this SSO provider, e.g. "oidc" or "saml".
- auth_provider_id = attr.ib(type=str)
+ auth_provider_id: str
# user ID on the IdP server
- remote_user_id = attr.ib(type=str)
+ remote_user_id: str
# attributes returned by the ID mapper
- display_name = attr.ib(type=Optional[str])
- emails = attr.ib(type=Collection[str])
+ display_name: Optional[str]
+ emails: Collection[str]
# An optional dictionary of extra attributes to be provided to the client in the
# login response.
- extra_login_attributes = attr.ib(type=Optional[JsonDict])
+ extra_login_attributes: Optional[JsonDict]
# where to redirect the client back to
- client_redirect_url = attr.ib(type=str)
+ client_redirect_url: str
# expiry time for the session, in milliseconds
- expiry_time_ms = attr.ib(type=int)
+ expiry_time_ms: int
# choices made by the user
- chosen_localpart = attr.ib(type=Optional[str], default=None)
- use_display_name = attr.ib(type=bool, default=True)
- emails_to_use = attr.ib(type=Collection[str], default=())
- terms_accepted_version = attr.ib(type=Optional[str], default=None)
+ chosen_localpart: Optional[str] = None
+ use_display_name: bool = True
+ emails_to_use: Collection[str] = ()
+ terms_accepted_version: Optional[str] = None
# the HTTP cookie used to track the mapping session id
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index bd3e6f2ec77b..29e41a4c796c 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -80,6 +80,17 @@ async def _unsafe_process(self) -> None:
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = await self.store.get_stats_positions()
+ room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+ if self.pos > room_max_stream_ordering:
+ # apparently, we've processed more events than exist in the database!
+ # this can happen if events are removed with history purge or similar.
+ logger.warning(
+ "Event stream ordering appears to have gone backwards (%i -> %i): "
+ "rewinding stats processor",
+ self.pos,
+ room_max_stream_ordering,
+ )
+ self.pos = room_max_stream_ordering
# Loop round handling deltas until we're up to date
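Both the stats processor and the user-directory processor (further down in this diff) gain the same guard: if the persisted stream position is ahead of the room's current maximum stream ordering, which can happen after a history purge, the position is rewound rather than left pointing past the end. A standalone sketch of the clamp:

```python
import logging

logger = logging.getLogger(__name__)


def clamp_position(saved_pos: int, room_max_stream_ordering: int) -> int:
    """Rewind a persisted stream position that has ended up past the current maximum."""
    if saved_pos > room_max_stream_ordering:
        logger.warning(
            "Event stream ordering appears to have gone backwards (%i -> %i): rewinding",
            saved_pos,
            room_max_stream_ordering,
        )
        return room_max_stream_ordering
    return saved_pos


print(clamp_position(120, 100))  # 100
print(clamp_position(80, 100))   # 80
```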
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f3039c3c3fb7..ffc6b748e84e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,7 +28,7 @@
import attr
from prometheus_client import Counter
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -36,6 +36,7 @@
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
+from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
@@ -59,10 +60,6 @@
logger = logging.getLogger(__name__)
-# Debug logger for https://github.com/matrix-org/synapse/issues/4422
-issue4422_logger = logging.getLogger("synapse.handler.sync.4422_debug")
-
-
# Counts the number of times we returned a non-empty sync. `type` is one of
# "initial_sync", "full_state_sync" or "incremental_sync", `lazy_loaded` is
# "true" or "false" depending on if the request asked for lazy loaded members or
@@ -101,6 +98,9 @@ class TimelineBatch:
prev_batch: StreamToken
events: List[EventBase]
limited: bool
+ # A mapping of event ID to the bundled aggregations for the above events.
+ # This is only calculated if limited is true.
+ bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -421,7 +421,7 @@ async def current_sync_for_user(
span to track the sync. See `generate_sync_result` for the next part of your
indoctrination.
"""
- with start_active_span("current_sync_for_user"):
+ with start_active_span("sync.current_sync_for_user"):
log_kv({"since_token": since_token})
sync_result = await self.generate_sync_result(
sync_config, since_token, full_state
@@ -633,10 +633,19 @@ async def _load_filtered_recents(
prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+ # Don't bother to bundle aggregations if the timeline is unlimited,
+ # as clients will have all the necessary information.
+ bundled_aggregations = None
+ if limited or newly_joined_room:
+ bundled_aggregations = await self.store.get_bundled_aggregations(
+ recents, sync_config.user.to_string()
+ )
+
return TimelineBatch(
events=recents,
prev_batch=prev_batch_token,
limited=limited or newly_joined_room,
+ bundled_aggregations=bundled_aggregations,
)
async def get_state_after_event(
@@ -1041,18 +1050,17 @@ async def compute_state_delta(
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
- ) -> Dict[str, int]:
+ ) -> NotifCounts:
with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
room_id=room_id,
- receipt_type="m.read",
+ receipt_type=ReceiptTypes.READ,
)
- notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
+ return await self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
- return notifs
async def generate_sync_result(
self,
@@ -1161,13 +1169,8 @@ async def generate_sync_result(
num_events = 0
- # debug for https://github.com/matrix-org/synapse/issues/4422
+ # debug for https://github.com/matrix-org/synapse/issues/9424
for joined_room in sync_result_builder.joined:
- room_id = joined_room.room_id
- if room_id in newly_joined_rooms:
- issue4422_logger.debug(
- "Sync result for newly joined room %s: %r", room_id, joined_room
- )
num_events += len(joined_room.timeline.events)
log_kv(
@@ -1585,7 +1588,8 @@ async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
)
logger.debug("Generated room entry for %s", room_entry.room_id)
- await concurrently_execute(handle_room_entries, room_entries, 10)
+ with start_active_span("sync.generate_room_entries"):
+ await concurrently_execute(handle_room_entries, room_entries, 10)
sync_result_builder.invited.extend(invited)
sync_result_builder.knocked.extend(knocked)
@@ -1662,20 +1666,20 @@ async def _get_rooms_changed(
) -> _RoomChanges:
"""Determine the changes in rooms to report to the user.
- Ideally, we want to report all events whose stream ordering `s` lies in the
- range `since_token < s <= now_token`, where the two tokens are read from the
- sync_result_builder.
+ This function is a first pass at generating the rooms part of the sync response.
+ It determines which rooms have changed during the sync period, and categorises
+ them into four buckets: "knock", "invite", "join" and "leave".
- If there are too many events in that range to report, things get complicated.
- In this situation we return a truncated list of the most recent events, and
- indicate in the response that there is a "gap" of omitted events. Additionally:
+ 1. Finds all membership changes for the user in the sync period (from
+ `since_token` up to `now_token`).
+ 2. Uses those to place the room in one of the four categories above.
+ 3. Builds a `_RoomChanges` struct to record this, and return that struct.
- - we include a "state_delta", to describe the changes in state over the gap,
- - we include all membership events applying to the user making the request,
- even those in the gap.
-
- See the spec for the rationale:
- https://spec.matrix.org/v1.1/client-server-api/#syncing
+ For rooms classified as "knock", "invite" or "leave", we just need to report
+ a single membership event in the eventual /sync response. For "join" we need
+ to fetch additional non-membership events, e.g. messages in the room. That is
+ more complicated, so instead we report an intermediary `RoomSyncResultBuilder`
+ struct, and leave the additional work to `_generate_room_entry`.
The sync_result_builder is not modified by this function.
"""
@@ -1686,16 +1690,6 @@ async def _get_rooms_changed(
assert since_token
- # The spec
- # https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
- # notes that membership events need special consideration:
- #
- # > When a sync is limited, the server MUST return membership events for events
- # > in the gap (between since and the start of the returned timeline), regardless
- # > as to whether or not they are redundant.
- #
- # We fetch such events here, but we only seem to use them for categorising rooms
- # as newly joined, newly left, invited or knocked.
# TODO: we've already called this function and ran this query in
# _have_rooms_changed. We could keep the results in memory to avoid a
# second query, at the cost of more complicated source code.
@@ -1749,18 +1743,6 @@ async def _get_rooms_changed(
old_mem_ev_id, allow_none=True
)
- # debug for #4422
- if has_join:
- prev_membership = None
- if old_mem_ev:
- prev_membership = old_mem_ev.membership
- issue4422_logger.debug(
- "Previous membership for room %s with join: %s (event %s)",
- room_id,
- prev_membership,
- old_mem_ev_id,
- )
-
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id)
@@ -1902,13 +1884,6 @@ async def _get_rooms_changed(
upto_token=since_token,
)
- if newly_joined:
- # debugging for https://github.com/matrix-org/synapse/issues/4422
- issue4422_logger.debug(
- "RoomSyncResultBuilder events for newly joined room %s: %r",
- room_id,
- entry.events,
- )
room_entries.append(entry)
return _RoomChanges(
@@ -2009,6 +1984,23 @@ async def _generate_room_entry(
"""Populates the `joined` and `archived` section of `sync_result_builder`
based on the `room_builder`.
+ Ideally, we want to report all events whose stream ordering `s` lies in the
+ range `since_token < s <= now_token`, where the two tokens are read from the
+ sync_result_builder.
+
+ If there are too many events in that range to report, things get complicated.
+ In this situation we return a truncated list of the most recent events, and
+ indicate in the response that there is a "gap" of omitted events. Lots of this
+ is handled in `_load_filtered_recents`, but some of it is handled in this method.
+
+ Additionally:
+ - we include a "state_delta", to describe the changes in state over the gap,
+ - we include all membership events applying to the user making the request,
+ even those in the gap.
+
+ See the spec for the rationale:
+ https://spec.matrix.org/v1.1/client-server-api/#syncing
+
Args:
sync_result_builder
ignored_users: Set of users ignored by user.
@@ -2038,7 +2030,7 @@ async def _generate_room_entry(
since_token = room_builder.since_token
upto_token = room_builder.upto_token
- with start_active_span("generate_room_entry"):
+ with start_active_span("sync.generate_room_entry"):
set_tag("room_id", room_id)
log_kv({"events": len(events or ())})
@@ -2069,14 +2061,6 @@ async def _generate_room_entry(
# `_load_filtered_recents` can't find any events the user should see
# (e.g. due to having ignored the sender of the last 50 events).
- if newly_joined:
- # debug for https://github.com/matrix-org/synapse/issues/4422
- issue4422_logger.debug(
- "Timeline events after filtering in newly-joined room %s: %r",
- room_id,
- batch,
- )
-
# When we join the room (or the client requests full_state), we should
# send down any existing tags. Usually the user won't have tags in a
# newly joined room, unless either a) they've joined before or b) the
@@ -2166,10 +2150,10 @@ async def _generate_room_entry(
if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
- unread_notifications["notification_count"] = notifs["notify_count"]
- unread_notifications["highlight_count"] = notifs["highlight_count"]
+ unread_notifications["notification_count"] = notifs.notify_count
+ unread_notifications["highlight_count"] = notifs.highlight_count
- room_sync.unread_count = notifs["unread_count"]
+ room_sync.unread_count = notifs.unread_count
sync_result_builder.joined.append(room_sync)
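`unread_notifs_for_room_id` now returns a structured `NotifCounts` object rather than a dict, so the sync code reads `notifs.notify_count` instead of `notifs["notify_count"]`. A sketch of what such a container might look like (field names are taken from the diff; the real class lives in `synapse.storage.databases.main.event_push_actions` and may be defined differently):

```python
import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class NotifCounts:
    notify_count: int
    highlight_count: int
    unread_count: int


notifs = NotifCounts(notify_count=3, highlight_count=1, unread_count=4)

unread_notifications = {
    "notification_count": notifs.notify_count,
    "highlight_count": notifs.highlight_count,
}
print(unread_notifications, notifs.unread_count)
```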
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 1676ebd057c1..e43c22832da6 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -13,9 +13,10 @@
# limitations under the License.
import logging
import random
-from collections import namedtuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+import attr
+
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import (
@@ -37,7 +38,10 @@
# A tiny object useful for storing a user's membership in a room, as a mapping
# key
-RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomMember:
+ room_id: str
+ user_id: str
# How often we expect remote servers to resend us presence.
@@ -119,7 +123,7 @@ def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
def is_typing(self, member: RoomMember) -> bool:
- return member.user_id in self._room_typing.get(member.room_id, [])
+ return member.user_id in self._room_typing.get(member.room_id, set())
async def _push_remote(self, member: RoomMember, typing: bool) -> None:
if not self.federation:
@@ -166,9 +170,9 @@ def process_replication_rows(
for row in rows:
self._room_serials[row.room_id] = token
- prev_typing = set(self._room_typing.get(row.room_id, []))
+ prev_typing = self._room_typing.get(row.room_id, set())
now_typing = set(row.user_ids)
- self._room_typing[row.room_id] = row.user_ids
+ self._room_typing[row.room_id] = now_typing
if self.federation:
run_as_background_process(
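`RoomMember` switches from a `namedtuple` to a frozen attrs class. Because the class is frozen and comparable, attrs generates `__hash__`, so instances still work as dictionary and wheel-timer keys. A quick check of that property:

```python
import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomMember:
    room_id: str
    user_id: str


member = RoomMember("!room:example.org", "@alice:example.org")

# Frozen attrs instances are hashable, so they can key dicts and sets
# just like the old namedtuple did.
typing_serials = {member: 42}
assert RoomMember("!room:example.org", "@alice:example.org") in typing_serials
```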
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index a0eb45446f56..1565e034cb25 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -148,9 +148,21 @@ async def _unsafe_process(self) -> None:
if self.pos is None:
self.pos = await self.store.get_user_directory_stream_pos()
- # If still None then the initial background update hasn't happened yet.
- if self.pos is None:
- return None
+ # If still None then the initial background update hasn't happened yet.
+ if self.pos is None:
+ return None
+
+ room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+ if self.pos > room_max_stream_ordering:
+ # apparently, we've processed more events than exist in the database!
+ # this can happen if events are removed with history purge or similar.
+ logger.warning(
+ "Event stream ordering appears to have gone backwards (%i -> %i): "
+ "rewinding user directory processor",
+ self.pos,
+ room_max_stream_ordering,
+ )
+ self.pos = room_max_stream_ordering
# Loop round handling deltas until we're up to date
while True:
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 578fc48ef454..efecb089c135 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -25,7 +25,7 @@
class RequestTimedOutError(SynapseError):
"""Exception representing timeout of an outbound request"""
- def __init__(self, msg):
+ def __init__(self, msg: str):
super().__init__(504, msg)
@@ -33,7 +33,7 @@ def __init__(self, msg):
CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$")
-def redact_uri(uri):
+def redact_uri(uri: str) -> str:
"""Strips sensitive information from the uri replaces with """
uri = ACCESS_TOKEN_RE.sub(r"\1\3", uri)
return CLIENT_SECRET_RE.sub(r"\1\3", uri)
@@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
https://twistedmatrix.com/trac/ticket/6528
"""
- def stopProducing(self):
+ def stopProducing(self) -> None:
try:
FileBodyProducer.stopProducing(self)
except task.TaskStopped:
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 9a2684aca432..6a9f6635d2c0 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
from twisted.web.server import Request
@@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource):
and exception handling.
"""
- def __init__(self, hs: "HomeServer", handler):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]],
+ ):
"""Initialise AdditionalResource
The ``handler`` should return a deferred which completes when it has
@@ -47,7 +51,7 @@ def __init__(self, hs: "HomeServer", handler):
super().__init__()
self._handler = handler
- def _async_render(self, request: Request):
+ async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]:
# Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not.
- return self._handler(request)
+ return await self._handler(request)
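
With the tightened signature, any handler passed to `AdditionalResource` must be an async callable that takes the `Request` and returns either `None` (if it has written the response itself) or a `(code, json_object)` pair. A hedged sketch of a conforming handler; the handler name and route are invented for illustration:

```python
from typing import Any, Optional, Tuple

from twisted.web.server import Request


async def hello_handler(request: Request) -> Optional[Tuple[int, Any]]:
    # Returning a (status code, JSON-serialisable object) tuple lets
    # DirectServeJsonResource format and send the response for us.
    return 200, {"hello": "world"}


# Hypothetical wiring, e.g. when building the resource tree:
#   resources["/_synapse/client/hello"] = AdditionalResource(hs, hello_handler)
```
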
diff --git a/synapse/http/client.py b/synapse/http/client.py
index b5a2d333a6ce..ca33b45cb21f 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
+from http import HTTPStatus
from io import BytesIO
from typing import (
TYPE_CHECKING,
@@ -280,7 +281,9 @@ def request(
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info("Blocking access to %s due to blacklist" % (ip_address,))
- e = SynapseError(403, "IP address blocked by IP blacklist entry")
+ e = SynapseError(
+ HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry"
+ )
return defer.fail(Failure(e))
return self._agent.request(
@@ -585,7 +588,7 @@ async def get_json(
if headers:
actual_headers.update(headers) # type: ignore
- body = await self.get_raw(uri, args, headers=headers)
+ body = await self.get_raw(uri, args, headers=actual_headers)
return json_decoder.decode(body.decode("utf-8"))
async def put_json(
@@ -719,7 +722,9 @@ async def get_file(
if response.code > 299:
logger.warning("Got %d when downloading %s" % (response.code, url))
- raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN
+ )
# TODO: if our Content-Type is HTML or something, just read the first
# N bytes into RAM rather than saving it all to disk only to read it
@@ -731,12 +736,14 @@ async def get_file(
)
except BodyExceededMaxSize:
raise SynapseError(
- 502,
+ HTTPStatus.BAD_GATEWAY,
"Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
except Exception as e:
- raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e)
+ ) from e
return (
length,
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index fbafffd69bd6..203e995bb77d 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -32,9 +32,9 @@ class ProxyConnectError(ConnectError):
pass
-@attr.s
+@attr.s(auto_attribs=True)
class ProxyCredentials:
- username_password = attr.ib(type=bytes)
+ username_password: bytes
def as_proxy_authorization_value(self) -> bytes:
"""
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 1238bfd28726..a8a520f80944 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -25,6 +25,7 @@
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import (
+ IProtocol,
IProtocolFactory,
IReactorCore,
IStreamClientEndpoint,
@@ -309,12 +310,14 @@ def __init__(
self._srv_resolver = srv_resolver
- def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
+ def connect(
+ self, protocol_factory: IProtocolFactory
+ ) -> "defer.Deferred[IProtocol]":
"""Implements IStreamClientEndpoint interface"""
return run_in_background(self._do_connect, protocol_factory)
- async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
+ async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
first_exception = None
server_list = await self._resolve_server()
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 203d723d4120..2e668363b2f5 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -19,6 +19,7 @@
import sys
import typing
import urllib.parse
+from http import HTTPStatus
from io import BytesIO, StringIO
from typing import (
TYPE_CHECKING,
@@ -122,37 +123,37 @@ def finish(self) -> T:
pass
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class MatrixFederationRequest:
- method = attr.ib(type=str)
+ method: str
"""HTTP method
"""
- path = attr.ib(type=str)
+ path: str
"""HTTP path
"""
- destination = attr.ib(type=str)
+ destination: str
"""The remote server to send the HTTP request to.
"""
- json = attr.ib(default=None, type=Optional[JsonDict])
+ json: Optional[JsonDict] = None
"""JSON to send in the body.
"""
- json_callback = attr.ib(default=None, type=Optional[Callable[[], JsonDict]])
+ json_callback: Optional[Callable[[], JsonDict]] = None
"""A callback to generate the JSON.
"""
- query = attr.ib(default=None, type=Optional[dict])
+ query: Optional[dict] = None
"""Query arguments.
"""
- txn_id = attr.ib(default=None, type=Optional[str])
+ txn_id: Optional[str] = None
"""Unique ID for this request (for logging)
"""
- uri = attr.ib(init=False, type=bytes)
+ uri: bytes = attr.ib(init=False)
"""The URI of this request
"""
@@ -1154,7 +1155,7 @@ async def get_file(
request.destination,
msg,
)
- raise SynapseError(502, msg, Codes.TOO_LARGE)
+ raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
except defer.TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response - %s %s",
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 37af9b514f31..6d0658562782 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -14,7 +14,6 @@
# limitations under the License.
import abc
-import collections
import html
import logging
import types
@@ -30,12 +29,14 @@
Iterable,
Iterator,
List,
+ NoReturn,
Optional,
Pattern,
Tuple,
Union,
)
+import attr
import jinja2
from canonicaljson import encode_canonical_json
from typing_extensions import Protocol
@@ -57,12 +58,14 @@
)
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
-from synapse.logging.opentracing import trace_servlet
+from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
from synapse.util import json_encoder
from synapse.util.caches import intern_dict
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
+ import opentracing
+
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -170,7 +173,9 @@ def return_html_error(
respond_with_html(request, code, body)
-def wrap_async_request_handler(h):
+def wrap_async_request_handler(
+ h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]]
+) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]:
"""Wraps an async request handler so that it calls request.processing.
This helps ensure that work done by the request handler after the request is completed
@@ -183,7 +188,9 @@ def wrap_async_request_handler(h):
logged until the deferred completes.
"""
- async def wrapped_async_request_handler(self, request):
+ async def wrapped_async_request_handler(
+ self: "_AsyncResource", request: SynapseRequest
+ ) -> None:
with request.processing():
await h(self, request)
@@ -240,18 +247,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
context from the request the servlet is handling.
"""
- def __init__(self, extract_context=False):
+ def __init__(self, extract_context: bool = False):
super().__init__()
self._extract_context = extract_context
- def render(self, request):
+ def render(self, request: SynapseRequest) -> int:
"""This gets called by twisted every time someone sends us a request."""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET
@wrap_async_request_handler
- async def _async_render_wrapper(self, request: SynapseRequest):
+ async def _async_render_wrapper(self, request: SynapseRequest) -> None:
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
@@ -271,7 +278,7 @@ async def _async_render_wrapper(self, request: SynapseRequest):
f = failure.Failure()
self._send_error_response(f, request)
- async def _async_render(self, request: Request):
+ async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]:
"""Delegates to `_async_render_` methods, or returns a 400 if
no appropriate method exists. Can be overridden in sub classes for
different routing.
@@ -318,7 +325,7 @@ class DirectServeJsonResource(_AsyncResource):
formatting responses and errors as JSON.
"""
- def __init__(self, canonical_json=False, extract_context=False):
+ def __init__(self, canonical_json: bool = False, extract_context: bool = False):
super().__init__(extract_context)
self.canonical_json = canonical_json
@@ -327,7 +334,7 @@ def _send_response(
request: SynapseRequest,
code: int,
response_object: Any,
- ):
+ ) -> None:
"""Implements _AsyncResource._send_response"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
@@ -347,9 +354,11 @@ def _send_error_response(
return_json_error(f, request)
-_PathEntry = collections.namedtuple(
- "_PathEntry", ["pattern", "callback", "servlet_classname"]
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _PathEntry:
+ pattern: Pattern
+ callback: ServletCallback
+ servlet_classname: str
class JsonResource(DirectServeJsonResource):
@@ -368,34 +377,45 @@ class JsonResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ canonical_json: bool = True,
+ extract_context: bool = False,
+ ):
super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
self.hs = hs
- def register_paths(self, method, path_patterns, callback, servlet_classname):
+ def register_paths(
+ self,
+ method: str,
+ path_patterns: Iterable[Pattern],
+ callback: ServletCallback,
+ servlet_classname: str,
+ ) -> None:
"""
Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate
handler for that request.
Args:
- method (str): GET, POST etc
+ method: GET, POST etc
- path_patterns (Iterable[str]): A list of regular expressions to which
- the request URLs are compared.
+ path_patterns: A list of regular expressions to which the request
+ URLs are compared.
- callback (function): The handler for the request. Usually a Servlet
+ callback: The handler for the request. Usually a Servlet
- servlet_classname (str): The name of the handler to be used in prometheus
+ servlet_classname: The name of the handler to be used in prometheus
and opentracing logs.
"""
- method = method.encode("utf-8") # method is bytes on py3
+ method_bytes = method.encode("utf-8")
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
- self.path_regexs.setdefault(method, []).append(
+ self.path_regexs.setdefault(method_bytes, []).append(
_PathEntry(path_pattern, callback, servlet_classname)
)
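
With the new annotations, `register_paths` expects pre-compiled patterns and an async `ServletCallback`, which is what `RestServlet.register` already passes it. A small sketch of an equivalent direct registration; the pattern and servlet name are invented for illustration:

```python
import re


# A servlet callback receives the request plus the named groups from the
# matched pattern as keyword arguments, and returns an optional
# (code, json_object) tuple.
async def on_GET(request, room_id):
    return 200, {"room_id": room_id}


HELLO_PATTERN = re.compile("^/_matrix/client/v3/rooms/(?P<room_id>[^/]*)/hello$")

# Hypothetical registration, mirroring what RestServlet.register() does:
#   json_resource.register_paths("GET", [HELLO_PATTERN], on_GET, "HelloRestServlet")
```
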
@@ -427,7 +447,7 @@ def _get_handler_for_request(
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, "unrecognised_request_handler", {}
- async def _async_render(self, request):
+ async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
# Make sure we have an appropriate name for this handler in prometheus
@@ -468,7 +488,7 @@ def _send_response(
request: SynapseRequest,
code: int,
response_object: Any,
- ):
+ ) -> None:
"""Implements _AsyncResource._send_response"""
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
@@ -492,12 +512,12 @@ class StaticResource(File):
Differs from the File resource by adding clickjacking protection.
"""
- def render_GET(self, request: Request):
+ def render_GET(self, request: Request) -> bytes:
set_clickjacking_protection_headers(request)
return super().render_GET(request)
-def _unrecognised_request_handler(request):
+def _unrecognised_request_handler(request: Request) -> NoReturn:
"""Request handler for unrecognised requests
This is a request handler suitable for return from
@@ -505,7 +525,7 @@ def _unrecognised_request_handler(request):
UnrecognizedRequestError.
Args:
- request (twisted.web.http.Request):
+ request: Unused, but passed in to match the signature of ServletCallback.
"""
raise UnrecognizedRequestError()
@@ -513,23 +533,23 @@ def _unrecognised_request_handler(request):
class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path."""
- def __init__(self, path):
- resource.Resource.__init__(self)
+ def __init__(self, path: str):
+ super().__init__()
self.url = path
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
return redirectTo(self.url.encode("ascii"), request)
- def getChild(self, name, request):
+ def getChild(self, name: str, request: Request) -> resource.Resource:
if len(name) == 0:
return self # select ourselves as the child to render
- return resource.Resource.getChild(self, name, request)
+ return super().getChild(name, request)
class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children."""
- def render_OPTIONS(self, request):
+ def render_OPTIONS(self, request: Request) -> bytes:
request.setResponseCode(204)
request.setHeader(b"Content-Length", b"0")
@@ -537,10 +557,10 @@ def render_OPTIONS(self, request):
return b""
- def getChildWithDefault(self, path, request):
+ def getChildWithDefault(self, path: str, request: Request) -> resource.Resource:
if request.method == b"OPTIONS":
return self # select ourselves as the child to render
- return resource.Resource.getChildWithDefault(self, path, request)
+ return super().getChildWithDefault(path, request)
class RootOptionsRedirectResource(OptionsResource, RootRedirect):
@@ -649,7 +669,7 @@ def respond_with_json(
json_object: Any,
send_cors: bool = False,
canonical_json: bool = True,
-):
+) -> Optional[int]:
"""Sends encoded JSON in response to the given request.
Args:
@@ -696,7 +716,7 @@ def respond_with_json_bytes(
code: int,
json_bytes: bytes,
send_cors: bool = False,
-):
+) -> Optional[int]:
"""Sends encoded JSON in response to the given request.
Args:
@@ -713,7 +733,7 @@ def respond_with_json_bytes(
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
- return
+ return None
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
@@ -731,7 +751,7 @@ async def _async_write_json_to_request_in_thread(
request: SynapseRequest,
json_encoder: Callable[[Any], bytes],
json_object: Any,
-):
+) -> None:
"""Encodes the given JSON object on a thread and then writes it to the
request.
@@ -743,7 +763,20 @@ async def _async_write_json_to_request_in_thread(
expensive.
"""
- json_str = await defer_to_thread(request.reactor, json_encoder, json_object)
+ def encode(opentracing_span: "Optional[opentracing.Span]") -> bytes:
+ # it might take a while for the threadpool to schedule us, so we write
+ # opentracing logs once we actually get scheduled, so that we can see how
+ # much that contributed.
+ if opentracing_span:
+ opentracing_span.log_kv({"event": "scheduled"})
+ res = json_encoder(json_object)
+ if opentracing_span:
+ opentracing_span.log_kv({"event": "encoded"})
+ return res
+
+ with start_active_span("encode_json_response"):
+ span = active_span()
+ json_str = await defer_to_thread(request.reactor, encode, span)
_write_bytes_to_request(request, json_str)
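
The closure above moves JSON encoding onto the reactor's threadpool and records `scheduled`/`encoded` events against the span, so slow threadpool scheduling becomes visible in traces. A rough standalone sketch of the same pattern using the plain `opentracing` API and a `ThreadPoolExecutor`; Synapse itself goes through `defer_to_thread` and its logcontext helpers instead:

```python
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional

import opentracing  # the no-op global tracer is used if none is configured

_executor = ThreadPoolExecutor(max_workers=1)


def encode_on_thread(json_object: Any) -> bytes:
    """Encode JSON on a worker thread, logging span events around the work."""
    with opentracing.global_tracer().start_active_span("encode_json_response") as scope:
        span: Optional[opentracing.Span] = scope.span

        def encode() -> bytes:
            # The threadpool may take a while to get to us; logging here
            # (rather than before submission) makes that delay visible.
            if span:
                span.log_kv({"event": "scheduled"})
            res = json.dumps(json_object).encode("utf-8")
            if span:
                span.log_kv({"event": "encoded"})
            return res

        return _executor.submit(encode).result()


print(encode_on_thread({"hello": "world"}))
```
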
@@ -773,7 +806,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
_ByteProducer(request, bytes_generator)
-def set_cors_headers(request: Request):
+def set_cors_headers(request: Request) -> None:
"""Set the CORS headers so that javascript running in a web browsers can
use this API
@@ -790,14 +823,14 @@ def set_cors_headers(request: Request):
)
-def respond_with_html(request: Request, code: int, html: str):
+def respond_with_html(request: Request, code: int, html: str) -> None:
"""
Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes.
"""
respond_with_html_bytes(request, code, html.encode("utf-8"))
-def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
+def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None:
"""
Sends HTML (encoded as UTF-8 bytes) as the response to the given request.
@@ -815,7 +848,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
- return
+ return None
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
@@ -828,7 +861,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
finish_request(request)
-def set_clickjacking_protection_headers(request: Request):
+def set_clickjacking_protection_headers(request: Request) -> None:
"""
Set headers to guard against clickjacking of embedded content.
@@ -850,7 +883,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None:
finish_request(request)
-def finish_request(request: Request):
+def finish_request(request: Request) -> None:
"""Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 6dd9b9ad0358..4ff840ca0ef8 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,6 +14,7 @@
""" This module contains base REST classes for constructing REST servlets. """
import logging
+from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Iterable,
@@ -30,6 +31,7 @@
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.types import JsonDict, RoomAlias, RoomID
from synapse.util import json_decoder
@@ -137,11 +139,15 @@ def parse_integer_from_args(
return int(args[name_bytes][0])
except Exception:
message = "Query parameter %r must be an integer" % (name,)
- raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+ )
else:
if required:
message = "Missing integer query parameter %r" % (name,)
- raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+ )
else:
return default
@@ -246,11 +252,15 @@ def parse_boolean_from_args(
message = (
"Boolean query parameter %r must be one of ['true', 'false']"
) % (name,)
- raise SynapseError(400, message)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+ )
else:
if required:
message = "Missing boolean query parameter %r" % (name,)
- raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+ )
else:
return default
@@ -313,7 +323,7 @@ def parse_bytes_from_args(
return args[name_bytes][0]
elif required:
message = "Missing string query parameter %s" % (name,)
- raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+ raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
return default
@@ -407,14 +417,16 @@ def _parse_string_value(
try:
value_str = value.decode(encoding)
except ValueError:
- raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Query parameter %r must be %s" % (name, encoding)
+ )
if allowed_values is not None and value_str not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name,
", ".join(repr(v) for v in allowed_values),
)
- raise SynapseError(400, message)
+ raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
else:
return value_str
@@ -510,7 +522,9 @@ def parse_strings_from_args(
else:
if required:
message = "Missing string query parameter %r" % (name,)
- raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+ )
return default
@@ -638,7 +652,7 @@ def parse_json_value_from_request(
try:
content_bytes = request.content.read() # type: ignore
except Exception:
- raise SynapseError(400, "Error reading JSON content.")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Error reading JSON content.")
if not content_bytes and allow_empty_body:
return None
@@ -647,7 +661,9 @@ def parse_json_value_from_request(
content = json_decoder.decode(content_bytes.decode("utf-8"))
except Exception as e:
logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
- raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON
+ )
return content
@@ -673,7 +689,7 @@ def parse_json_object_from_request(
if not isinstance(content, dict):
message = "Content must be a JSON object."
- raise SynapseError(400, message, errcode=Codes.BAD_JSON)
+ raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON)
return content
@@ -685,7 +701,9 @@ def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
absent.append(k)
if len(absent) > 0:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Missing params: %r" % absent, Codes.MISSING_PARAM
+ )
class RestServlet:
@@ -709,7 +727,7 @@ class attribute containing a pre-compiled regular expression. The automatic
into the appropriate HTTP response.
"""
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
"""Register this servlet with the given HTTP server."""
patterns = getattr(self, "PATTERNS", None)
if patterns:
@@ -758,10 +776,12 @@ async def resolve_room_id(
resolved_room_id = room_id.to_string()
else:
raise SynapseError(
- 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ HTTPStatus.BAD_REQUEST,
+ "%s was not legal room ID or room alias" % (room_identifier,),
)
if not resolved_room_id:
raise SynapseError(
- 400, "Unknown room ID or room alias %s" % room_identifier
+ HTTPStatus.BAD_REQUEST,
+ "Unknown room ID or room alias %s" % room_identifier,
)
return resolved_room_id, remote_room_hosts
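
Throughout `servlet.py` the bare `400`s become `HTTPStatus.BAD_REQUEST` with an explicit `Codes` errcode. A simplified stand-alone illustration of the resulting error shape; the `Codes` and `SynapseError` classes below are cut-down stand-ins for `synapse.api.errors`, and the parser is far simpler than Synapse's own:

```python
from http import HTTPStatus
from typing import Dict, List


class Codes:
    MISSING_PARAM = "M_MISSING_PARAM"
    INVALID_PARAM = "M_INVALID_PARAM"


class SynapseError(Exception):
    def __init__(self, code: int, msg: str, errcode: str = "M_UNKNOWN"):
        super().__init__(msg)
        self.code = code
        self.errcode = errcode


def parse_integer(args: Dict[str, List[str]], name: str) -> int:
    values = args.get(name)
    if not values:
        raise SynapseError(
            HTTPStatus.BAD_REQUEST,
            "Missing integer query parameter %r" % (name,),
            errcode=Codes.MISSING_PARAM,
        )
    try:
        return int(values[0])
    except ValueError:
        raise SynapseError(
            HTTPStatus.BAD_REQUEST,
            "Query parameter %r must be an integer" % (name,),
            errcode=Codes.INVALID_PARAM,
        )


try:
    parse_integer({"limit": ["ten"]}, "limit")
except SynapseError as e:
    print(int(e.code), e.errcode)  # 400 M_INVALID_PARAM
```
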
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 755ad56637da..c180a1d3231b 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
import contextlib
import logging
import time
-from typing import Generator, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union
import attr
from zope.interface import implementer
@@ -35,6 +35,9 @@
)
from synapse.types import Requester
+if TYPE_CHECKING:
+ import opentracing
+
logger = logging.getLogger(__name__)
_next_request_seq = 0
@@ -66,9 +69,9 @@ def __init__(
self,
channel: HTTPChannel,
site: "SynapseSite",
- *args,
+ *args: Any,
max_request_body_size: int = 1024,
- **kw,
+ **kw: Any,
):
super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size
@@ -81,6 +84,10 @@ def __init__(
# server name, for client requests this is the Requester object.
self._requester: Optional[Union[Requester, str]] = None
+ # An opentracing span for this request. Will be closed when the request is
+ # completely processed.
+ self._opentracing_span: "Optional[opentracing.Span]" = None
+
# we can't yet create the logcontext, as we don't know the method.
self.logcontext: Optional[LoggingContext] = None
@@ -148,6 +155,13 @@ def requester(self, value: Union[Requester, str]) -> None:
# If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester
+ def set_opentracing_span(self, span: "opentracing.Span") -> None:
+ """attach an opentracing span to this request
+
+ Doing so will cause the span to be closed when we finish processing the request
+ """
+ self._opentracing_span = span
+
def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq)
@@ -286,6 +300,9 @@ async def handle_request(request):
self._processing_finished_time = time.time()
self._is_processing = False
+ if self._opentracing_span:
+ self._opentracing_span.log_kv({"event": "finished processing"})
+
# if we've already sent the response, log it now; otherwise, we wait for the
# response to be sent.
if self.finish_time is not None:
@@ -299,6 +316,8 @@ def finish(self) -> None:
"""
self.finish_time = time.time()
Request.finish(self)
+ if self._opentracing_span:
+ self._opentracing_span.log_kv({"event": "response sent"})
if not self._is_processing:
assert self.logcontext is not None
with PreserveLoggingContext(self.logcontext):
@@ -333,6 +352,11 @@ def connectionLost(self, reason: Union[Failure, Exception]) -> None:
with PreserveLoggingContext(self.logcontext):
logger.info("Connection from client lost before response was sent")
+ if self._opentracing_span:
+ self._opentracing_span.log_kv(
+ {"event": "client connection lost", "reason": str(reason.value)}
+ )
+
if not self._is_processing:
self._finished_processing()
@@ -421,6 +445,10 @@ def _finished_processing(self) -> None:
usage.evt_db_fetch_count,
)
+ # complete the opentracing span, if any.
+ if self._opentracing_span:
+ self._opentracing_span.finish()
+
try:
self.request_metrics.stop(self.finish_time, self.code, self.sentLength)
except Exception as e:
@@ -506,9 +534,9 @@ def getClientAddress(self) -> IAddress:
@implementer(IAddress)
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class _XForwardedForAddress:
- host = attr.ib(type=str)
+ host: str
class SynapseSite(Site):
@@ -557,7 +585,7 @@ def __init__(
proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest
- def request_factory(channel, queued: bool) -> Request:
+ def request_factory(channel: HTTPChannel, queued: bool) -> Request:
return request_class(
channel,
self,
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 8202d0494d72..475756f1db64 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -39,7 +39,7 @@
logger = logging.getLogger(__name__)
-@attr.s
+@attr.s(slots=True, auto_attribs=True)
@implementer(IPushProducer)
class LogProducer:
"""
@@ -54,10 +54,10 @@ class LogProducer:
# This is essentially ITCPTransport, but that is missing certain fields
# (connected and registerProducer) which are part of the implementation.
- transport = attr.ib(type=Connection)
- _format = attr.ib(type=Callable[[logging.LogRecord], str])
- _buffer = attr.ib(type=deque)
- _paused = attr.ib(default=False, type=bool, init=False)
+ transport: Connection
+ _format: Callable[[logging.LogRecord], str]
+ _buffer: Deque[logging.LogRecord]
+ _paused: bool = attr.ib(default=False, init=False)
def pauseProducing(self):
self._paused = True
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index d8ae3188b7da..c31c2960ad95 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -22,20 +22,33 @@
See doc/log_contexts.rst for details on how this works.
"""
-import inspect
import logging
import threading
import typing
import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
+from types import TracebackType
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+)
import attr
from typing_extensions import Literal
from twisted.internet import defer, threads
+from twisted.python.threadpool import ThreadPool
if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope
+ from synapse.types import ISynapseReactor
logger = logging.getLogger(__name__)
@@ -55,7 +68,6 @@
def get_thread_resource_usage() -> "Optional[resource.struct_rusage]":
return resource.getrusage(RUSAGE_THREAD)
-
except Exception:
# If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
# won't track resource usage.
@@ -66,7 +78,7 @@ def get_thread_resource_usage() -> "Optional[resource.struct_rusage]":
# a hook which can be set during testing to assert that we aren't abusing logcontexts.
-def logcontext_error(msg: str):
+def logcontext_error(msg: str) -> None:
logger.warning(msg)
@@ -181,7 +193,7 @@ def __sub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
return res
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class ContextRequest:
"""
A bundle of attributes from the SynapseRequest object.
@@ -193,15 +205,15 @@ class ContextRequest:
their children.
"""
- request_id = attr.ib(type=str)
- ip_address = attr.ib(type=str)
- site_tag = attr.ib(type=str)
- requester = attr.ib(type=Optional[str])
- authenticated_entity = attr.ib(type=Optional[str])
- method = attr.ib(type=str)
- url = attr.ib(type=str)
- protocol = attr.ib(type=str)
- user_agent = attr.ib(type=str)
+ request_id: str
+ ip_address: str
+ site_tag: str
+ requester: Optional[str]
+ authenticated_entity: Optional[str]
+ method: str
+ url: str
+ protocol: str
+ user_agent: str
LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
@@ -223,22 +235,19 @@ def __init__(self) -> None:
def __str__(self) -> str:
return "sentinel"
- def copy_to(self, record):
- pass
-
- def start(self, rusage: "Optional[resource.struct_rusage]"):
+ def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
pass
- def stop(self, rusage: "Optional[resource.struct_rusage]"):
+ def stop(self, rusage: "Optional[resource.struct_rusage]") -> None:
pass
- def add_database_transaction(self, duration_sec):
+ def add_database_transaction(self, duration_sec: float) -> None:
pass
- def add_database_scheduled(self, sched_sec):
+ def add_database_scheduled(self, sched_sec: float) -> None:
pass
- def record_event_fetch(self, event_count):
+ def record_event_fetch(self, event_count: int) -> None:
pass
def __bool__(self) -> Literal[False]:
@@ -379,7 +388,12 @@ def __enter__(self) -> "LoggingContext":
)
return self
- def __exit__(self, type, value, traceback) -> None:
+ def __exit__(
+ self,
+ type: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
@@ -399,17 +413,6 @@ def __exit__(self, type, value, traceback) -> None:
# recorded against the correct metrics.
self.finished = True
- def copy_to(self, record) -> None:
- """Copy logging fields from this context to a log record or
- another LoggingContext
- """
-
- # we track the current request
- record.request = self.request
-
- # we also track the current scope:
- record.scope = self.scope
-
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""
Record that this logcontext is currently running.
@@ -626,7 +629,12 @@ def __init__(
def __enter__(self) -> None:
self._old_context = set_current_context(self._new_context)
- def __exit__(self, type, value, traceback) -> None:
+ def __exit__(
+ self,
+ type: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
context = set_current_context(self._old_context)
if context != self._new_context:
@@ -711,16 +719,61 @@ def nested_logging_context(suffix: str) -> LoggingContext:
)
-def preserve_fn(f):
+R = TypeVar("R")
+
+
+@overload
+def preserve_fn( # type: ignore[misc]
+ f: Callable[..., Awaitable[R]],
+) -> Callable[..., "defer.Deferred[R]"]:
+ # The `type: ignore[misc]` above suppresses
+ # "Overloaded function signatures 1 and 2 overlap with incompatible return types"
+ ...
+
+
+@overload
+def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]:
+ ...
+
+
+def preserve_fn(
+ f: Union[
+ Callable[..., R],
+ Callable[..., Awaitable[R]],
+ ]
+) -> Callable[..., "defer.Deferred[R]"]:
"""Function decorator which wraps the function with run_in_background"""
- def g(*args, **kwargs):
+ def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]":
return run_in_background(f, *args, **kwargs)
return g
-def run_in_background(f, *args, **kwargs) -> defer.Deferred:
+@overload
+def run_in_background( # type: ignore[misc]
+ f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any
+) -> "defer.Deferred[R]":
+ # The `type: ignore[misc]` above suppresses
+ # "Overloaded function signatures 1 and 2 overlap with incompatible return types"
+ ...
+
+
+@overload
+def run_in_background(
+ f: Callable[..., R], *args: Any, **kwargs: Any
+) -> "defer.Deferred[R]":
+ ...
+
+
+def run_in_background(
+ f: Union[
+ Callable[..., R],
+ Callable[..., Awaitable[R]],
+ ],
+ *args: Any,
+ **kwargs: Any,
+) -> "defer.Deferred[R]":
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@@ -751,6 +804,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
# At this point we should have a Deferred, if not then f was a synchronous
# function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
+ # `res` is not a `Deferred` and not a `Coroutine`.
+ # There are no other types of `Awaitable`s we expect to encounter in Synapse.
+ assert not isinstance(res, Awaitable)
+
return defer.succeed(res)
if res.called and not res.paused:
@@ -778,13 +835,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
return res
-def make_deferred_yieldable(deferred):
- """Given a deferred (or coroutine), make it follow the Synapse logcontext
- rules:
+T = TypeVar("T")
- If the deferred has completed (or is not actually a Deferred), essentially
- does nothing (just returns another completed deferred with the
- result/failure).
+
+def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
+ """Given a deferred, make it follow the Synapse logcontext rules:
+
+ If the deferred has completed, essentially does nothing (just returns another
+ completed deferred with the result/failure).
If the deferred has not yet completed, resets the logcontext before
returning a deferred. Then, when the deferred completes, restores the
@@ -792,16 +850,6 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.)
"""
- if inspect.isawaitable(deferred):
- # If we're given a coroutine we convert it to a deferred so that we
- # run it and find out if it immediately finishes, it it does then we
- # don't need to fiddle with log contexts at all and can return
- # immediately.
- deferred = defer.ensureDeferred(deferred)
-
- if not isinstance(deferred, defer.Deferred):
- return deferred
-
if deferred.called and not deferred.paused:
# it looks like this deferred is ready to run any callbacks we give it
# immediately. We may as well optimise out the logcontext faffery.
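
Taken together, the new signatures pin down the usual logcontext idiom: `run_in_background` accepts a plain or async function and hands back a `Deferred`, while `make_deferred_yieldable` now accepts only a `Deferred` (bare coroutines must be wrapped with `defer.ensureDeferred` first). A sketch of that idiom; `fetch_remote` is an illustrative async function, not part of the patch:

```python
from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable, run_in_background


async def do_lookup(fetch_remote, key: str) -> str:
    # Kick the request off without awaiting it; run_in_background does the
    # logcontext bookkeeping and returns a Deferred for the result.
    d: "defer.Deferred[str]" = run_in_background(fetch_remote, key)

    # ... other work can happen here ...

    # When the result is needed, make the Deferred follow the logcontext
    # rules before awaiting it.
    return await make_deferred_yieldable(d)
```
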
@@ -823,7 +871,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
return result
-def defer_to_thread(reactor, f, *args, **kwargs):
+def defer_to_thread(
+ reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any
+) -> "defer.Deferred[R]":
"""
Calls the function `f` using a thread from the reactor's default threadpool and
returns the result as a Deferred.
@@ -855,7 +905,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
-def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
+def defer_to_threadpool(
+ reactor: "ISynapseReactor",
+ threadpool: ThreadPool,
+ f: Callable[..., R],
+ *args: Any,
+ **kwargs: Any,
+) -> "defer.Deferred[R]":
"""
A wrapper for twisted.internet.threads.deferToThreadpool, which handles
logcontexts correctly.
@@ -897,7 +953,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
- def g():
+ def g() -> R:
with LoggingContext(str(curr_context), parent_context=parent_context):
return f(*args, **kwargs)
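
`defer_to_thread` and `defer_to_threadpool` now advertise `Deferred[R]` for a callable returning `R`, so awaiting them is type-checked. A small sketch of pushing CPU-bound work onto the reactor's default threadpool; `expensive_serialise` and the surrounding coroutine are invented for illustration, and `hs` is assumed to be a `HomeServer`:

```python
import json
from typing import Any

from synapse.logging.context import defer_to_thread


def expensive_serialise(obj: Any) -> bytes:
    # Placeholder for CPU-bound work we don't want on the reactor thread.
    return json.dumps(obj, sort_keys=True).encode("utf-8")


async def serialise_off_thread(hs, obj: Any) -> bytes:
    # defer_to_thread already handles the logcontext rules internally, so it
    # can be awaited directly; the new annotations mean mypy knows this
    # yields `bytes`.
    return await defer_to_thread(hs.get_reactor(), expensive_serialise, obj)
```
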
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 20d23a426064..b240d2d21da2 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -173,6 +173,7 @@ def set_fates(clotho, lachesis, atropos, father="Zues", mother="Themis"):
import attr
from twisted.internet import defer
+from twisted.web.http import Request
from twisted.web.http_headers import Headers
from synapse.config import ConfigError
@@ -219,11 +220,12 @@ class _DummyTagNames:
try:
import opentracing
+ import opentracing.tags
tags = opentracing.tags
except ImportError:
- opentracing = None
- tags = _DummyTagNames
+ opentracing = None # type: ignore[assignment]
+ tags = _DummyTagNames # type: ignore[assignment]
try:
from jaeger_client import Config as JaegerConfig
@@ -245,11 +247,11 @@ class _DummyTagNames:
class BaseReporter: # type: ignore[no-redef]
pass
- @attr.s(slots=True, frozen=True)
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
class _WrappedRustReporter(BaseReporter):
"""Wrap the reporter to ensure `report_span` never throws."""
- _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
+ _reporter: Reporter = attr.Factory(Reporter)
def set_process(self, *args, **kwargs):
return self._reporter.set_process(*args, **kwargs)
@@ -366,7 +368,7 @@ def init_tracer(hs: "HomeServer"):
global opentracing
if not hs.config.tracing.opentracer_enabled:
# We don't have a tracer
- opentracing = None
+ opentracing = None # type: ignore[assignment]
return
if not opentracing or not JaegerConfig:
@@ -452,7 +454,7 @@ def start_active_span(
"""
if opentracing is None:
- return noop_context_manager()
+ return noop_context_manager() # type: ignore[unreachable]
return opentracing.tracer.start_active_span(
operation_name,
@@ -477,7 +479,7 @@ def start_active_span_follows_from(
forced, the new span will also have tracing forced.
"""
if opentracing is None:
- return noop_context_manager()
+ return noop_context_manager() # type: ignore[unreachable]
references = [opentracing.follows_from(context) for context in contexts]
scope = start_active_span(operation_name, references=references)
@@ -490,48 +492,6 @@ def start_active_span_follows_from(
return scope
-def start_active_span_from_request(
- request,
- operation_name,
- references=None,
- tags=None,
- start_time=None,
- ignore_active_span=False,
- finish_on_close=True,
-):
- """
- Extracts a span context from a Twisted Request.
- args:
- headers (twisted.web.http.Request)
-
- For the other args see opentracing.tracer
-
- returns:
- span_context (opentracing.span.SpanContext)
- """
- # Twisted encodes the values as lists whereas opentracing doesn't.
- # So, we take the first item in the list.
- # Also, twisted uses byte arrays while opentracing expects strings.
-
- if opentracing is None:
- return noop_context_manager()
-
- header_dict = {
- k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
- }
- context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
-
- return opentracing.tracer.start_active_span(
- operation_name,
- child_of=context,
- references=references,
- tags=tags,
- start_time=start_time,
- ignore_active_span=ignore_active_span,
- finish_on_close=finish_on_close,
- )
-
-
def start_active_span_from_edu(
edu_content,
operation_name,
@@ -553,7 +513,7 @@ def start_active_span_from_edu(
references = references or []
if opentracing is None:
- return noop_context_manager()
+ return noop_context_manager() # type: ignore[unreachable]
carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
"opentracing", {}
@@ -594,18 +554,21 @@ def active_span():
@ensure_active_span("set a tag")
def set_tag(key, value):
"""Sets a tag on the active span"""
+ assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_tag(key, value)
@ensure_active_span("log")
def log_kv(key_values, timestamp=None):
"""Log to the active span"""
+ assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.log_kv(key_values, timestamp)
@ensure_active_span("set the traces operation name")
def set_operation_name(operation_name):
"""Sets the operation name of the active span"""
+ assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_operation_name(operation_name)
@@ -674,6 +637,7 @@ def inject_header_dict(
span = opentracing.tracer.active_span
carrier: Dict[str, str] = {}
+ assert span is not None
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@@ -716,6 +680,7 @@ def get_active_span_text_map(destination=None):
return {}
carrier: Dict[str, str] = {}
+ assert opentracing.tracer.active_span is not None
opentracing.tracer.inject(
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
)
@@ -731,12 +696,27 @@ def active_span_context_as_string():
"""
carrier: Dict[str, str] = {}
if opentracing:
+ assert opentracing.tracer.active_span is not None
opentracing.tracer.inject(
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
)
return json_encoder.encode(carrier)
+def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]":
+ """Extract an opentracing context from the headers on an HTTP request
+
+ This is useful when we have received an HTTP request from another part of our
+ system, and want to link our spans to those of the remote system.
+ """
+ if not opentracing:
+ return None
+ header_dict = {
+ k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
+ }
+ return opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
+
+
@only_if_tracing
def span_context_from_string(carrier):
"""
@@ -773,7 +753,7 @@ def trace(func=None, opname=None):
def decorator(func):
if opentracing is None:
- return func
+ return func # type: ignore[unreachable]
_opname = opname if opname else func.__name__
@@ -864,7 +844,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
"""
if opentracing is None:
- yield
+ yield # type: ignore[unreachable]
return
request_tags = {
@@ -876,10 +856,13 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
}
request_name = request.request_metrics.name
- if extract_context:
- scope = start_active_span_from_request(request, request_name)
- else:
- scope = start_active_span(request_name)
+ context = span_context_from_request(request) if extract_context else None
+
+ # we configure the scope not to finish the span immediately on exit, and instead
+ # pass the span into the SynapseRequest, which will finish it once we've finished
+ # sending the response to the client.
+ scope = start_active_span(request_name, child_of=context, finish_on_close=False)
+ request.set_opentracing_span(scope.span)
with scope:
inject_response_headers(request.responseHeaders)
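
The net effect of this file and the `synapse/http/site.py` changes is that a servlet's span now outlives the `with scope:` block: `trace_servlet` opens it with `finish_on_close=False`, hands it to the request via `set_opentracing_span`, and the request logs `finished processing` / `response sent` events and finishes the span once the response has gone out. A condensed sketch of that lifecycle, with Synapse's request class replaced by a minimal stand-in:

```python
from typing import Optional

import opentracing  # falls back to the no-op tracer if none is configured


class FakeRequest:
    """Minimal stand-in for SynapseRequest's span handling."""

    def __init__(self) -> None:
        self._span: Optional[opentracing.Span] = None

    def set_opentracing_span(self, span: opentracing.Span) -> None:
        # The request, not the tracing scope, is responsible for finishing it.
        self._span = span

    def finished_processing(self) -> None:
        if self._span:
            self._span.log_kv({"event": "finished processing"})
            self._span.finish()


def handle(request: FakeRequest) -> None:
    tracer = opentracing.global_tracer()
    # finish_on_close=False: leaving the scope deactivates the span but does
    # not finish it, mirroring the change to trace_servlet above.
    scope = tracer.start_active_span("servlet_request", finish_on_close=False)
    request.set_opentracing_span(scope.span)
    with scope:
        pass  # ... render and start sending the response ...
    # later, once the response has actually been sent:
    request.finished_processing()


handle(FakeRequest())
```
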
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index b1e8e08fe96f..db8ca2c0497b 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -71,7 +71,7 @@ def activate(self, span, finish_on_close):
if not ctx:
# We don't want this scope to affect.
logger.error("Tried to activate scope outside of loggingcontext")
- return Scope(None, span)
+ return Scope(None, span) # type: ignore[arg-type]
elif ctx.scope is not None:
# We want the logging scope to look exactly the same so we give it
# a blank suffix
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
deleted file mode 100644
index 4a01b902c255..000000000000
--- a/synapse/logging/utils.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-from functools import wraps
-from inspect import getcallargs
-from typing import Callable, TypeVar, cast
-
-_TIME_FUNC_ID = 0
-
-
-def _log_debug_as_f(f, msg, msg_args):
- name = f.__module__
- logger = logging.getLogger(name)
-
- if logger.isEnabledFor(logging.DEBUG):
- lineno = f.__code__.co_firstlineno
- pathname = f.__code__.co_filename
-
- record = logger.makeRecord(
- name=name,
- level=logging.DEBUG,
- fn=pathname,
- lno=lineno,
- msg=msg,
- args=msg_args,
- exc_info=None,
- )
-
- logger.handle(record)
-
-
-F = TypeVar("F", bound=Callable)
-
-
-def log_function(f: F) -> F:
- """Function decorator that logs every call to that function."""
- func_name = f.__name__
-
- @wraps(f)
- def wrapped(*args, **kwargs):
- name = f.__module__
- logger = logging.getLogger(name)
- level = logging.DEBUG
-
- if logger.isEnabledFor(level):
- bound_args = getcallargs(f, *args, **kwargs)
-
- def format(value):
- r = str(value)
- if len(r) > 50:
- r = r[:50] + "..."
- return r
-
- func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()]
-
- msg_args = {"func_name": func_name, "args": ", ".join(func_args)}
-
- _log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args)
-
- return f(*args, **kwargs)
-
- wrapped.__name__ = func_name
- return cast(F, wrapped)
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index ceef57ad883f..9e6c1b2f3b54 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -12,16 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import functools
-import gc
import itertools
import logging
import os
import platform
import threading
-import time
from typing import (
- Any,
Callable,
Dict,
Generic,
@@ -34,35 +30,31 @@
Type,
TypeVar,
Union,
- cast,
)
import attr
from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric
from prometheus_client.core import (
REGISTRY,
- CounterMetricFamily,
GaugeHistogramMetricFamily,
GaugeMetricFamily,
)
-from twisted.internet import reactor
-from twisted.internet.base import ReactorBase
from twisted.python.threadpool import ThreadPool
-import synapse
+import synapse.metrics._reactor_metrics
from synapse.metrics._exposition import (
MetricsResource,
generate_latest,
start_http_server,
)
+from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
METRICS_PREFIX = "/_synapse/metrics"
-running_on_pypy = platform.python_implementation() == "PyPy"
all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {}
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
@@ -76,19 +68,17 @@ def collect() -> Iterable[Metric]:
yield metric
-@attr.s(slots=True, hash=True)
+@attr.s(slots=True, hash=True, auto_attribs=True)
class LaterGauge:
- name = attr.ib(type=str)
- desc = attr.ib(type=str)
- labels = attr.ib(hash=False, type=Optional[Iterable[str]])
+ name: str
+ desc: str
+ labels: Optional[Iterable[str]] = attr.ib(hash=False)
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
- caller = attr.ib(
- type=Callable[
- [], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
- ]
- )
+ caller: Callable[
+ [], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
+ ]
def collect(self) -> Iterable[Metric]:
@@ -157,7 +147,9 @@ def __init__(
# Create a class which have the sub_metrics values as attributes, which
# default to 0 on initialization. Used to pass to registered callbacks.
self._metrics_class: Type[MetricsEntry] = attr.make_class(
- "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True
+ "_MetricsEntry",
+ attrs={x: attr.ib(default=0) for x in sub_metrics},
+ slots=True,
)
# Counts number of in flight blocks for a given set of label values
@@ -369,136 +361,6 @@ def collect(self) -> Iterable[Metric]:
REGISTRY.register(CPUMetrics())
-#
-# Python GC metrics
-#
-
-gc_unreachable = Gauge("python_gc_unreachable_total", "Unreachable GC objects", ["gen"])
-gc_time = Histogram(
- "python_gc_time",
- "Time taken to GC (sec)",
- ["gen"],
- buckets=[
- 0.0025,
- 0.005,
- 0.01,
- 0.025,
- 0.05,
- 0.10,
- 0.25,
- 0.50,
- 1.00,
- 2.50,
- 5.00,
- 7.50,
- 15.00,
- 30.00,
- 45.00,
- 60.00,
- ],
-)
-
-
-class GCCounts:
- def collect(self) -> Iterable[Metric]:
- cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
- for n, m in enumerate(gc.get_count()):
- cm.add_metric([str(n)], m)
-
- yield cm
-
-
-if not running_on_pypy:
- REGISTRY.register(GCCounts())
-
-
-#
-# PyPy GC / memory metrics
-#
-
-
-class PyPyGCStats:
- def collect(self) -> Iterable[Metric]:
-
- # @stats is a pretty-printer object with __str__() returning a nice table,
- # plus some fields that contain data from that table.
- # unfortunately, fields are pretty-printed themselves (i. e. '4.5MB').
- stats = gc.get_stats(memory_pressure=False) # type: ignore
- # @s contains same fields as @stats, but as actual integers.
- s = stats._s # type: ignore
-
- # also note that field naming is completely braindead
- # and only vaguely correlates with the pretty-printed table.
- # >>>> gc.get_stats(False)
- # Total memory consumed:
- # GC used: 8.7MB (peak: 39.0MB) # s.total_gc_memory, s.peak_memory
- # in arenas: 3.0MB # s.total_arena_memory
- # rawmalloced: 1.7MB # s.total_rawmalloced_memory
- # nursery: 4.0MB # s.nursery_size
- # raw assembler used: 31.0kB # s.jit_backend_used
- # -----------------------------
- # Total: 8.8MB # stats.memory_used_sum
- #
- # Total memory allocated:
- # GC allocated: 38.7MB (peak: 41.1MB) # s.total_allocated_memory, s.peak_allocated_memory
- # in arenas: 30.9MB # s.peak_arena_memory
- # rawmalloced: 4.1MB # s.peak_rawmalloced_memory
- # nursery: 4.0MB # s.nursery_size
- # raw assembler allocated: 1.0MB # s.jit_backend_allocated
- # -----------------------------
- # Total: 39.7MB # stats.memory_allocated_sum
- #
- # Total time spent in GC: 0.073 # s.total_gc_time
-
- pypy_gc_time = CounterMetricFamily(
- "pypy_gc_time_seconds_total",
- "Total time spent in PyPy GC",
- labels=[],
- )
- pypy_gc_time.add_metric([], s.total_gc_time / 1000)
- yield pypy_gc_time
-
- pypy_mem = GaugeMetricFamily(
- "pypy_memory_bytes",
- "Memory tracked by PyPy allocator",
- labels=["state", "class", "kind"],
- )
- # memory used by JIT assembler
- pypy_mem.add_metric(["used", "", "jit"], s.jit_backend_used)
- pypy_mem.add_metric(["allocated", "", "jit"], s.jit_backend_allocated)
- # memory used by GCed objects
- pypy_mem.add_metric(["used", "", "arenas"], s.total_arena_memory)
- pypy_mem.add_metric(["allocated", "", "arenas"], s.peak_arena_memory)
- pypy_mem.add_metric(["used", "", "rawmalloced"], s.total_rawmalloced_memory)
- pypy_mem.add_metric(["allocated", "", "rawmalloced"], s.peak_rawmalloced_memory)
- pypy_mem.add_metric(["used", "", "nursery"], s.nursery_size)
- pypy_mem.add_metric(["allocated", "", "nursery"], s.nursery_size)
- # totals
- pypy_mem.add_metric(["used", "totals", "gc"], s.total_gc_memory)
- pypy_mem.add_metric(["allocated", "totals", "gc"], s.total_allocated_memory)
- pypy_mem.add_metric(["used", "totals", "gc_peak"], s.peak_memory)
- pypy_mem.add_metric(["allocated", "totals", "gc_peak"], s.peak_allocated_memory)
- yield pypy_mem
-
-
-if running_on_pypy:
- REGISTRY.register(PyPyGCStats())
-
-
-#
-# Twisted reactor metrics
-#
-
-tick_time = Histogram(
- "python_twisted_reactor_tick_time",
- "Tick time of the Twisted reactor (sec)",
- buckets=[0.001, 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.5, 1, 2, 5],
-)
-pending_calls_metric = Histogram(
- "python_twisted_reactor_pending_calls",
- "Pending calls",
- buckets=[1, 2, 5, 10, 25, 50, 100, 250, 500, 1000],
-)
#
# Federation Metrics
@@ -551,8 +413,6 @@ def collect(self) -> Iterable[Metric]:
" ".join([platform.system(), platform.release()]),
).set(1)
-last_ticked = time.time()
-
# 3PID send info
threepid_send_requests = Histogram(
"synapse_threepid_send_requests_with_tries",
@@ -600,116 +460,6 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None:
)
-class ReactorLastSeenMetric:
- def collect(self) -> Iterable[Metric]:
- cm = GaugeMetricFamily(
- "python_twisted_reactor_last_seen",
- "Seconds since the Twisted reactor was last seen",
- )
- cm.add_metric([], time.time() - last_ticked)
- yield cm
-
-
-REGISTRY.register(ReactorLastSeenMetric())
-
-# The minimum time in seconds between GCs for each generation, regardless of the current GC
-# thresholds and counts.
-MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
-
-# The time (in seconds since the epoch) of the last time we did a GC for each generation.
-_last_gc = [0.0, 0.0, 0.0]
-
-
-F = TypeVar("F", bound=Callable[..., Any])
-
-
-def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F:
- @functools.wraps(func)
- def f(*args: Any, **kwargs: Any) -> Any:
- now = reactor.seconds()
- num_pending = 0
-
- # _newTimedCalls is one long list of *all* pending calls. Below loop
- # is based off of impl of reactor.runUntilCurrent
- for delayed_call in reactor._newTimedCalls:
- if delayed_call.time > now:
- break
-
- if delayed_call.delayed_time > 0:
- continue
-
- num_pending += 1
-
- num_pending += len(reactor.threadCallQueue)
- start = time.time()
- ret = func(*args, **kwargs)
- end = time.time()
-
- # record the amount of wallclock time spent running pending calls.
- # This is a proxy for the actual amount of time between reactor polls,
- # since about 25% of time is actually spent running things triggered by
- # I/O events, but that is harder to capture without rewriting half the
- # reactor.
- tick_time.observe(end - start)
- pending_calls_metric.observe(num_pending)
-
- # Update the time we last ticked, for the metric to test whether
- # Synapse's reactor has frozen
- global last_ticked
- last_ticked = end
-
- if running_on_pypy:
- return ret
-
- # Check if we need to do a manual GC (since its been disabled), and do
- # one if necessary. Note we go in reverse order as e.g. a gen 1 GC may
- # promote an object into gen 2, and we don't want to handle the same
- # object multiple times.
- threshold = gc.get_threshold()
- counts = gc.get_count()
- for i in (2, 1, 0):
- # We check if we need to do one based on a straightforward
- # comparison between the threshold and count. We also do an extra
- # check to make sure that we don't a GC too often.
- if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]:
- if i == 0:
- logger.debug("Collecting gc %d", i)
- else:
- logger.info("Collecting gc %d", i)
-
- start = time.time()
- unreachable = gc.collect(i)
- end = time.time()
-
- _last_gc[i] = end
-
- gc_time.labels(i).observe(end - start)
- gc_unreachable.labels(i).set(unreachable)
-
- return ret
-
- return cast(F, f)
-
-
-try:
- # Ensure the reactor has all the attributes we expect
- reactor.seconds # type: ignore
- reactor.runUntilCurrent # type: ignore
- reactor._newTimedCalls # type: ignore
- reactor.threadCallQueue # type: ignore
-
- # runUntilCurrent is called when we have pending calls. It is called once
- # per iteratation after fd polling.
- reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent) # type: ignore
-
- # We manually run the GC each reactor tick so that we can get some metrics
- # about time spent doing GC,
- if not running_on_pypy:
- gc.disable()
-except AttributeError:
- pass
-
-
__all__ = [
"MetricsResource",
"generate_latest",
@@ -717,4 +467,6 @@ def f(*args: Any, **kwargs: Any) -> Any:
"LaterGauge",
"InFlightGauge",
"GaugeBucketCollector",
+ "MIN_TIME_BETWEEN_GCS",
+ "install_gc_manager",
]
diff --git a/synapse/metrics/_gc.py b/synapse/metrics/_gc.py
new file mode 100644
index 000000000000..2bc909efa0d3
--- /dev/null
+++ b/synapse/metrics/_gc.py
@@ -0,0 +1,203 @@
+# Copyright 2015-2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import gc
+import logging
+import platform
+import time
+from typing import Iterable
+
+from prometheus_client.core import (
+ REGISTRY,
+ CounterMetricFamily,
+ Gauge,
+ GaugeMetricFamily,
+ Histogram,
+ Metric,
+)
+
+from twisted.internet import task
+
+"""Prometheus metrics for garbage collection"""
+
+
+logger = logging.getLogger(__name__)
+
+# The minimum time in seconds between GCs for each generation, regardless of the current GC
+# thresholds and counts.
+MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
+
+running_on_pypy = platform.python_implementation() == "PyPy"
+
+#
+# Python GC metrics
+#
+
+gc_unreachable = Gauge("python_gc_unreachable_total", "Unreachable GC objects", ["gen"])
+gc_time = Histogram(
+ "python_gc_time",
+ "Time taken to GC (sec)",
+ ["gen"],
+ buckets=[
+ 0.0025,
+ 0.005,
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.10,
+ 0.25,
+ 0.50,
+ 1.00,
+ 2.50,
+ 5.00,
+ 7.50,
+ 15.00,
+ 30.00,
+ 45.00,
+ 60.00,
+ ],
+)
+
+
+class GCCounts:
+ def collect(self) -> Iterable[Metric]:
+ cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
+ for n, m in enumerate(gc.get_count()):
+ cm.add_metric([str(n)], m)
+
+ yield cm
+
+
+def install_gc_manager() -> None:
+ """Disable automatic GC, and replace it with a task that runs every 100ms
+
+ This means that (a) we can limit how often GC runs; (b) we can get some metrics
+ about GC activity.
+
+ It does nothing on PyPy.
+ """
+
+ if running_on_pypy:
+ return
+
+ REGISTRY.register(GCCounts())
+
+ gc.disable()
+
+ # The time (in seconds since the epoch) of the last time we did a GC for each generation.
+ _last_gc = [0.0, 0.0, 0.0]
+
+ def _maybe_gc() -> None:
+ # Check if we need to do a manual GC (since it's been disabled), and do
+ # one if necessary. Note we go in reverse order as e.g. a gen 1 GC may
+ # promote an object into gen 2, and we don't want to handle the same
+ # object multiple times.
+ threshold = gc.get_threshold()
+ counts = gc.get_count()
+ end = time.time()
+ for i in (2, 1, 0):
+ # We check if we need to do one based on a straightforward
+ # comparison between the threshold and count. We also do an extra
+ # check to make sure that we don't do a GC too often.
+ if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]:
+ if i == 0:
+ logger.debug("Collecting gc %d", i)
+ else:
+ logger.info("Collecting gc %d", i)
+
+ start = time.time()
+ unreachable = gc.collect(i)
+ end = time.time()
+
+ _last_gc[i] = end
+
+ gc_time.labels(i).observe(end - start)
+ gc_unreachable.labels(i).set(unreachable)
+
+ gc_task = task.LoopingCall(_maybe_gc)
+ gc_task.start(0.1)
+
+
+#
+# PyPy GC / memory metrics
+#
+
+
+class PyPyGCStats:
+ def collect(self) -> Iterable[Metric]:
+
+ # @stats is a pretty-printer object with __str__() returning a nice table,
+ # plus some fields that contain data from that table.
+ # unfortunately, fields are pretty-printed themselves (i.e. '4.5MB').
+ stats = gc.get_stats(memory_pressure=False) # type: ignore
+ # @s contains same fields as @stats, but as actual integers.
+ s = stats._s # type: ignore
+
+ # also note that field naming is completely braindead
+ # and only vaguely correlates with the pretty-printed table.
+ # >>>> gc.get_stats(False)
+ # Total memory consumed:
+ # GC used: 8.7MB (peak: 39.0MB) # s.total_gc_memory, s.peak_memory
+ # in arenas: 3.0MB # s.total_arena_memory
+ # rawmalloced: 1.7MB # s.total_rawmalloced_memory
+ # nursery: 4.0MB # s.nursery_size
+ # raw assembler used: 31.0kB # s.jit_backend_used
+ # -----------------------------
+ # Total: 8.8MB # stats.memory_used_sum
+ #
+ # Total memory allocated:
+ # GC allocated: 38.7MB (peak: 41.1MB) # s.total_allocated_memory, s.peak_allocated_memory
+ # in arenas: 30.9MB # s.peak_arena_memory
+ # rawmalloced: 4.1MB # s.peak_rawmalloced_memory
+ # nursery: 4.0MB # s.nursery_size
+ # raw assembler allocated: 1.0MB # s.jit_backend_allocated
+ # -----------------------------
+ # Total: 39.7MB # stats.memory_allocated_sum
+ #
+ # Total time spent in GC: 0.073 # s.total_gc_time
+
+ pypy_gc_time = CounterMetricFamily(
+ "pypy_gc_time_seconds_total",
+ "Total time spent in PyPy GC",
+ labels=[],
+ )
+ pypy_gc_time.add_metric([], s.total_gc_time / 1000)
+ yield pypy_gc_time
+
+ pypy_mem = GaugeMetricFamily(
+ "pypy_memory_bytes",
+ "Memory tracked by PyPy allocator",
+ labels=["state", "class", "kind"],
+ )
+ # memory used by JIT assembler
+ pypy_mem.add_metric(["used", "", "jit"], s.jit_backend_used)
+ pypy_mem.add_metric(["allocated", "", "jit"], s.jit_backend_allocated)
+ # memory used by GCed objects
+ pypy_mem.add_metric(["used", "", "arenas"], s.total_arena_memory)
+ pypy_mem.add_metric(["allocated", "", "arenas"], s.peak_arena_memory)
+ pypy_mem.add_metric(["used", "", "rawmalloced"], s.total_rawmalloced_memory)
+ pypy_mem.add_metric(["allocated", "", "rawmalloced"], s.peak_rawmalloced_memory)
+ pypy_mem.add_metric(["used", "", "nursery"], s.nursery_size)
+ pypy_mem.add_metric(["allocated", "", "nursery"], s.nursery_size)
+ # totals
+ pypy_mem.add_metric(["used", "totals", "gc"], s.total_gc_memory)
+ pypy_mem.add_metric(["allocated", "totals", "gc"], s.total_allocated_memory)
+ pypy_mem.add_metric(["used", "totals", "gc_peak"], s.peak_memory)
+ pypy_mem.add_metric(["allocated", "totals", "gc_peak"], s.peak_allocated_memory)
+ yield pypy_mem
+
+
+if running_on_pypy:
+ REGISTRY.register(PyPyGCStats())
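For context, a minimal sketch of how the GC manager added above is typically wired in at startup. The scaffolding around the call is an assumption (the call site is not part of this diff); only `install_gc_manager` itself, exported from `synapse.metrics`, comes from the patch.

```python
# Illustrative only: wiring up the periodic manual-GC task at startup.
from twisted.internet import reactor

from synapse.metrics import install_gc_manager


def start() -> None:
    # On CPython this disables automatic GC and schedules the _maybe_gc()
    # check every 100ms; on PyPy it is a no-op.
    install_gc_manager()
    reactor.run()


if __name__ == "__main__":
    start()
```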
diff --git a/synapse/metrics/_reactor_metrics.py b/synapse/metrics/_reactor_metrics.py
new file mode 100644
index 000000000000..f38f7983131f
--- /dev/null
+++ b/synapse/metrics/_reactor_metrics.py
@@ -0,0 +1,83 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import select
+import time
+from typing import Any, Iterable, List, Tuple
+
+from prometheus_client import Histogram, Metric
+from prometheus_client.core import REGISTRY, GaugeMetricFamily
+
+from twisted.internet import reactor
+
+#
+# Twisted reactor metrics
+#
+
+tick_time = Histogram(
+ "python_twisted_reactor_tick_time",
+ "Tick time of the Twisted reactor (sec)",
+ buckets=[0.001, 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.5, 1, 2, 5],
+)
+
+
+class EpollWrapper:
+ """a wrapper for an epoll object which records the time between polls"""
+
+ def __init__(self, poller: "select.epoll"): # type: ignore[name-defined]
+ self.last_polled = time.time()
+ self._poller = poller
+
+ def poll(self, *args, **kwargs) -> List[Tuple[int, int]]: # type: ignore[no-untyped-def]
+ # record the time since poll() was last called. This gives a good proxy for
+ # how long it takes to run everything in the reactor - ie, how long anything
+ # waiting for the next tick will have to wait.
+ tick_time.observe(time.time() - self.last_polled)
+
+ ret = self._poller.poll(*args, **kwargs)
+
+ self.last_polled = time.time()
+ return ret
+
+ def __getattr__(self, item: str) -> Any:
+ return getattr(self._poller, item)
+
+
+class ReactorLastSeenMetric:
+ def __init__(self, epoll_wrapper: EpollWrapper):
+ self._epoll_wrapper = epoll_wrapper
+
+ def collect(self) -> Iterable[Metric]:
+ cm = GaugeMetricFamily(
+ "python_twisted_reactor_last_seen",
+ "Seconds since the Twisted reactor was last seen",
+ )
+ cm.add_metric([], time.time() - self._epoll_wrapper.last_polled)
+ yield cm
+
+
+try:
+ # if the reactor has a `_poller` attribute, which is an `epoll` object
+ # (ie, it's an EPollReactor), we wrap the `epoll` with a thing that will
+ # measure the time between ticks
+ from select import epoll # type: ignore[attr-defined]
+
+ poller = reactor._poller # type: ignore[attr-defined]
+except (AttributeError, ImportError):
+ pass
+else:
+ if isinstance(poller, epoll):
+ poller = EpollWrapper(poller)
+ reactor._poller = poller # type: ignore[attr-defined]
+ REGISTRY.register(ReactorLastSeenMetric(poller))
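The `_reactor_metrics` module relies on the standard `prometheus_client` custom-collector hook: any object exposing a `collect()` method can be registered with the global `REGISTRY`. A minimal, Twisted-free sketch of that pattern (class and metric names here are illustrative, not part of the patch):

```python
import time
from typing import Iterable

from prometheus_client.core import REGISTRY, GaugeMetricFamily, Metric


class LastSeenCollector:
    """Reports the number of seconds since activity was last observed."""

    def __init__(self) -> None:
        self.last_seen = time.time()

    def mark(self) -> None:
        """Call this whenever the monitored activity happens."""
        self.last_seen = time.time()

    def collect(self) -> Iterable[Metric]:
        g = GaugeMetricFamily(
            "example_last_seen_seconds",
            "Seconds since activity was last observed",
        )
        g.add_metric([], time.time() - self.last_seen)
        yield g


REGISTRY.register(LastSeenCollector())
```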
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 60e5409895e6..632b2245ef55 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -13,7 +13,6 @@
# limitations under the License.
import logging
-from collections import namedtuple
from typing import (
Awaitable,
Callable,
@@ -41,10 +40,15 @@
from synapse.logging import issue9533_logger
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.opentracing import log_kv, start_active_span
-from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.streams.config import PaginationConfig
-from synapse.types import PersistedEventPosition, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+ JsonDict,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StreamToken,
+ UserID,
+)
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@@ -178,20 +182,25 @@ def new_listener(self, token: StreamToken) -> _NotificationListener:
return _NotificationListener(self.notify_deferred.observe())
-class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventStreamResult:
+ events: List[Union[JsonDict, EventBase]]
+ start_token: StreamToken
+ end_token: StreamToken
+
def __bool__(self):
return bool(self.events)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class _PendingRoomEventEntry:
- event_pos = attr.ib(type=PersistedEventPosition)
- extra_users = attr.ib(type=Collection[UserID])
+ event_pos: PersistedEventPosition
+ extra_users: Collection[UserID]
- room_id = attr.ib(type=str)
- type = attr.ib(type=str)
- state_key = attr.ib(type=Optional[str])
- membership = attr.ib(type=Optional[str])
+ room_id: str
+ type: str
+ state_key: Optional[str]
+ membership: Optional[str]
class Notifier:
@@ -582,9 +591,12 @@ async def check_for_updates(
before_token: StreamToken, after_token: StreamToken
) -> EventStreamResult:
if after_token == before_token:
- return EventStreamResult([], (from_token, from_token))
+ return EventStreamResult([], from_token, from_token)
- events: List[EventBase] = []
+ # The events fetched from each source are a JsonDict, EventBase, or
+ # UserPresenceState, but see below for UserPresenceState being
+ # converted to JsonDict.
+ events: List[Union[JsonDict, EventBase]] = []
end_token = from_token
for name, source in self.event_sources.sources.get_sources():
@@ -623,7 +635,7 @@ async def check_for_updates(
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
- return EventStreamResult(events, (from_token, end_token))
+ return EventStreamResult(events, from_token, end_token)
user_id_for_stream = user.to_string()
if is_peeking:
@@ -673,7 +685,6 @@ async def _is_world_readable(self, room_id: str) -> bool:
else:
return False
- @log_function
def remove_expired_streams(self) -> None:
time_now_ms = self.clock.time_msec()
expired_streams = []
@@ -687,7 +698,6 @@ def remove_expired_streams(self) -> None:
for expired_stream in expired_streams:
expired_stream.remove(self)
- @log_function
def _register_with_keys(self, user_stream: _NotifierUserStream):
self.user_to_user_stream[user_stream.user_id] = user_stream
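The `EventStreamResult` change above keeps the truthiness check while moving from a namedtuple to an attrs class. A standalone sketch of the same idiom, with illustrative names and values:

```python
from typing import List

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class StreamResult:
    events: List[str]
    start_token: str
    end_token: str

    def __bool__(self) -> bool:
        # Truthy only when there is something to deliver.
        return bool(self.events)


assert not StreamResult([], "s0", "s0")
assert StreamResult(["$event:example.org"], "s0", "s1")
```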
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 820f6f3f7ec0..5176a1c1861d 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -23,25 +23,25 @@
from synapse.server import HomeServer
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class PusherConfig:
"""Parameters necessary to configure a pusher."""
- id = attr.ib(type=Optional[str])
- user_name = attr.ib(type=str)
- access_token = attr.ib(type=Optional[int])
- profile_tag = attr.ib(type=str)
- kind = attr.ib(type=str)
- app_id = attr.ib(type=str)
- app_display_name = attr.ib(type=str)
- device_display_name = attr.ib(type=str)
- pushkey = attr.ib(type=str)
- ts = attr.ib(type=int)
- lang = attr.ib(type=Optional[str])
- data = attr.ib(type=Optional[JsonDict])
- last_stream_ordering = attr.ib(type=int)
- last_success = attr.ib(type=Optional[int])
- failing_since = attr.ib(type=Optional[int])
+ id: Optional[str]
+ user_name: str
+ access_token: Optional[int]
+ profile_tag: str
+ kind: str
+ app_id: str
+ app_display_name: str
+ device_display_name: str
+ pushkey: str
+ ts: int
+ lang: Optional[str]
+ data: Optional[JsonDict]
+ last_stream_ordering: int
+ last_success: Optional[int]
+ failing_since: Optional[int]
def as_dict(self) -> Dict[str, Any]:
"""Information that can be retrieved about a pusher after creation."""
@@ -57,12 +57,12 @@ def as_dict(self) -> Dict[str, Any]:
}
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class ThrottleParams:
"""Parameters for controlling the rate of sending pushes via email."""
- last_sent_ts = attr.ib(type=int)
- throttle_ms = attr.ib(type=int)
+ last_sent_ts: int
+ throttle_ms: int
class Pusher(metaclass=abc.ABCMeta):
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 009d8e77b05b..bee660893be2 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -298,7 +298,7 @@ def _condition_checker(
StateGroup = Union[object, int]
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class RulesForRoomData:
"""The data stored in the cache by `RulesForRoom`.
@@ -307,29 +307,29 @@ class RulesForRoomData:
"""
# event_id -> (user_id, state)
- member_map = attr.ib(type=MemberMap, factory=dict)
+ member_map: MemberMap = attr.Factory(dict)
# user_id -> rules
- rules_by_user = attr.ib(type=RulesByUser, factory=dict)
+ rules_by_user: RulesByUser = attr.Factory(dict)
# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
# result.
# On invalidation of the rules themselves (if the user changes them),
# we invalidate everything and set state_group to `object()`
- state_group = attr.ib(type=StateGroup, factory=object)
+ state_group: StateGroup = attr.Factory(object)
# A sequence number to keep track of when we're allowed to update the
# cache. We bump the sequence number when we invalidate the cache. If
# the sequence number changes while we're calculating stuff we should
# not update the cache with it.
- sequence = attr.ib(type=int, default=0)
+ sequence: int = 0
# A cache of user_ids that we *know* aren't interesting, e.g. user_ids
# owned by AS's, or remote users, etc. (I.e. users we will never need to
# calculate push for)
# These never need to be invalidated as we will never set up push for
# them.
- uninteresting_user_set = attr.ib(type=Set[str], factory=set)
+ uninteresting_user_set: Set[str] = attr.Factory(set)
class RulesForRoom:
@@ -553,7 +553,7 @@ def update_cache(
self.data.state_group = state_group
-@attr.attrs(slots=True, frozen=True)
+@attr.attrs(slots=True, frozen=True, auto_attribs=True)
class _Invalidation:
# _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
# which means that it it is stored on the bulk_get_push_rules cache entry. In order
@@ -564,8 +564,8 @@ class _Invalidation:
# attrs provides suitable __hash__ and __eq__ methods, provided we remember to
# set `frozen=True`.
- cache = attr.ib(type=LruCache)
- room_id = attr.ib(type=str)
+ cache: LruCache
+ room_id: str
def __call__(self) -> None:
rules_data = self.cache.get(self.room_id, None, update_metrics=False)
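The `RulesForRoomData` conversion above uses `attr.Factory` for mutable defaults once `auto_attribs` is enabled. A small sketch of why that matters (field names are illustrative):

```python
from typing import Dict, Set

import attr


@attr.s(slots=True, auto_attribs=True)
class CacheData:
    # Mutable defaults must go through attr.Factory so that every instance
    # gets its own dict/set rather than sharing one object.
    member_map: Dict[str, str] = attr.Factory(dict)
    uninteresting_users: Set[str] = attr.Factory(set)
    sequence: int = 0


a = CacheData()
b = CacheData()
a.member_map["$event:example.org"] = "@alice:example.org"
assert b.member_map == {}  # no state shared between instances
```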
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 4f13c0418ab9..39bb2acae4b5 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -177,12 +177,12 @@ async def _unsafe_process(self) -> None:
return
for push_action in unprocessed:
- received_at = push_action["received_ts"]
+ received_at = push_action.received_ts
if received_at is None:
received_at = 0
notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS
- room_ready_at = self.room_ready_to_notify_at(push_action["room_id"])
+ room_ready_at = self.room_ready_to_notify_at(push_action.room_id)
should_notify_at = max(notif_ready_at, room_ready_at)
@@ -193,23 +193,23 @@ async def _unsafe_process(self) -> None:
# to be delivered.
reason: EmailReason = {
- "room_id": push_action["room_id"],
+ "room_id": push_action.room_id,
"now": self.clock.time_msec(),
"received_at": received_at,
"delay_before_mail_ms": DELAY_BEFORE_MAIL_MS,
- "last_sent_ts": self.get_room_last_sent_ts(push_action["room_id"]),
- "throttle_ms": self.get_room_throttle_ms(push_action["room_id"]),
+ "last_sent_ts": self.get_room_last_sent_ts(push_action.room_id),
+ "throttle_ms": self.get_room_throttle_ms(push_action.room_id),
}
await self.send_notification(unprocessed, reason)
await self.save_last_stream_ordering_and_success(
- max(ea["stream_ordering"] for ea in unprocessed)
+ max(ea.stream_ordering for ea in unprocessed)
)
# we update the throttle on all the possible unprocessed push actions
for ea in unprocessed:
- await self.sent_notif_update_throttle(ea["room_id"], ea)
+ await self.sent_notif_update_throttle(ea.room_id, ea)
break
else:
if soonest_due_at is None or should_notify_at < soonest_due_at:
@@ -284,10 +284,10 @@ async def sent_notif_update_throttle(
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
# notif, we release the throttle. Otherwise, the throttle is increased.
time_of_previous_notifs = await self.store.get_time_of_last_push_action_before(
- notified_push_action["stream_ordering"]
+ notified_push_action.stream_ordering
)
- time_of_this_notifs = notified_push_action["received_ts"]
+ time_of_this_notifs = notified_push_action.received_ts
if time_of_previous_notifs is not None and time_of_this_notifs is not None:
gap = time_of_this_notifs - time_of_previous_notifs
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 3fa603ccb7f7..96559081d0bf 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -199,7 +199,7 @@ async def _unsafe_process(self) -> None:
"http-push",
tags={
"authenticated_entity": self.user_id,
- "event_id": push_action["event_id"],
+ "event_id": push_action.event_id,
"app_id": self.app_id,
"app_display_name": self.app_display_name,
},
@@ -209,7 +209,7 @@ async def _unsafe_process(self) -> None:
if processed:
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
- self.last_stream_ordering = push_action["stream_ordering"]
+ self.last_stream_ordering = push_action.stream_ordering
pusher_still_exists = (
await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
@@ -252,7 +252,7 @@ async def _unsafe_process(self) -> None:
self.pushkey,
)
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
- self.last_stream_ordering = push_action["stream_ordering"]
+ self.last_stream_ordering = push_action.stream_ordering
await self.store.update_pusher_last_stream_ordering(
self.app_id,
self.pushkey,
@@ -275,17 +275,17 @@ async def _unsafe_process(self) -> None:
break
async def _process_one(self, push_action: HttpPushAction) -> bool:
- if "notify" not in push_action["actions"]:
+ if "notify" not in push_action.actions:
return True
- tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
+ tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions)
badge = await push_tools.get_badge_count(
self.hs.get_datastore(),
self.user_id,
group_by_room=self._group_unread_count_by_room,
)
- event = await self.store.get_event(push_action["event_id"], allow_none=True)
+ event = await self.store.get_event(push_action.event_id, allow_none=True)
if event is None:
return True # It's been redacted
rejected = await self.dispatch_push(event, tweaks, badge)
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index ba4f866487ec..dadfc574134c 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -178,7 +178,7 @@ async def send_registration_mail(
await self.send_email(
email_address,
self.email_subjects.email_validation
- % {"server_name": self.hs.config.server.server_name},
+ % {"server_name": self.hs.config.server.server_name, "app": self.app_name},
template_vars,
)
@@ -209,7 +209,7 @@ async def send_add_threepid_mail(
await self.send_email(
email_address,
self.email_subjects.email_validation
- % {"server_name": self.hs.config.server.server_name},
+ % {"server_name": self.hs.config.server.server_name, "app": self.app_name},
template_vars,
)
@@ -232,15 +232,13 @@ async def send_notification_mail(
reason: The notification that was ready and is the cause of an email
being sent.
"""
- rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
+ rooms_in_order = deduped_ordered_list([pa.room_id for pa in push_actions])
- notif_events = await self.store.get_events(
- [pa["event_id"] for pa in push_actions]
- )
+ notif_events = await self.store.get_events([pa.event_id for pa in push_actions])
notifs_by_room: Dict[str, List[EmailPushAction]] = {}
for pa in push_actions:
- notifs_by_room.setdefault(pa["room_id"], []).append(pa)
+ notifs_by_room.setdefault(pa.room_id, []).append(pa)
# collect the current state for all the rooms in which we have
# notifications
@@ -264,7 +262,7 @@ async def _fetch_room_state(room_id: str) -> None:
await concurrently_execute(_fetch_room_state, rooms_in_order, 3)
# actually sort our so-called rooms_in_order list, most recent room first
- rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
+ rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1].received_ts or 0))
rooms: List[RoomVars] = []
@@ -356,7 +354,7 @@ async def _get_room_vars(
# Check if one of the notifs is an invite event for the user.
is_invite = False
for n in notifs:
- ev = notif_events[n["event_id"]]
+ ev = notif_events[n.event_id]
if ev.type == EventTypes.Member and ev.state_key == user_id:
if ev.content.get("membership") == Membership.INVITE:
is_invite = True
@@ -376,7 +374,7 @@ async def _get_room_vars(
if not is_invite:
for n in notifs:
notifvars = await self._get_notif_vars(
- n, user_id, notif_events[n["event_id"]], room_state_ids
+ n, user_id, notif_events[n.event_id], room_state_ids
)
# merge overlapping notifs together.
@@ -444,15 +442,15 @@ async def _get_notif_vars(
"""
results = await self.store.get_events_around(
- notif["room_id"],
- notif["event_id"],
+ notif.room_id,
+ notif.event_id,
before_limit=CONTEXT_BEFORE,
after_limit=CONTEXT_AFTER,
)
ret: NotifVars = {
"link": self._make_notif_link(notif),
- "ts": notif["received_ts"],
+ "ts": notif.received_ts,
"messages": [],
}
@@ -516,7 +514,7 @@ async def _get_message_vars(
ret: MessageVars = {
"event_type": event.type,
- "is_historical": event.event_id != notif["event_id"],
+ "is_historical": event.event_id != notif.event_id,
"id": event.event_id,
"ts": event.origin_server_ts,
"sender_name": sender_name,
@@ -610,7 +608,7 @@ async def _make_summary_text_single_room(
# See if one of the notifs is an invite event for the user
invite_event = None
for n in notifs:
- ev = notif_events[n["event_id"]]
+ ev = notif_events[n.event_id]
if ev.type == EventTypes.Member and ev.state_key == user_id:
if ev.content.get("membership") == Membership.INVITE:
invite_event = ev
@@ -659,7 +657,7 @@ async def _make_summary_text_single_room(
if len(notifs) == 1:
# There is just the one notification, so give some detail
sender_name = None
- event = notif_events[notifs[0]["event_id"]]
+ event = notif_events[notifs[0].event_id]
if ("m.room.member", event.sender) in room_state_ids:
state_event_id = room_state_ids[("m.room.member", event.sender)]
state_event = await self.store.get_event(state_event_id)
@@ -753,9 +751,9 @@ async def _make_summary_text_from_member_events(
# are already in descending received_ts.
sender_ids = {}
for n in notifs:
- sender = notif_events[n["event_id"]].sender
+ sender = notif_events[n.event_id].sender
if sender not in sender_ids:
- sender_ids[sender] = n["event_id"]
+ sender_ids[sender] = n.event_id
# Get the actual member events (in order to calculate a pretty name for
# the room).
@@ -830,17 +828,17 @@ def _make_notif_link(self, notif: EmailPushAction) -> str:
if self.hs.config.email.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email.email_riot_base_url,
- notif["room_id"],
- notif["event_id"],
+ notif.room_id,
+ notif.event_id,
)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
return "https://vector.im/beta/#/room/%s/%s" % (
- notif["room_id"],
- notif["event_id"],
+ notif.room_id,
+ notif.event_id,
)
else:
- return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
+ return "https://matrix.to/#/%s/%s" % (notif.room_id, notif.event_id)
def _make_unsubscribe_link(
self, user_id: str, app_id: str, email_address: str
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 7f68092ec5e5..659a53805df1 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -17,9 +17,10 @@
import re
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
+from matrix_common.regex import glob_to_regex, to_word_pattern
+
from synapse.events import EventBase
from synapse.types import JsonDict, UserID
-from synapse.util import glob_to_regex, re_word_boundary
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -184,7 +185,7 @@ def _contains_display_name(self, display_name: Optional[str]) -> bool:
r = regex_cache.get((display_name, False, True), None)
if not r:
r1 = re.escape(display_name)
- r1 = re_word_boundary(r1)
+ r1 = to_word_pattern(r1)
r = re.compile(r1, flags=re.IGNORECASE)
regex_cache[(display_name, False, True)] = r
@@ -213,7 +214,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
try:
r = regex_cache.get((glob, True, word_boundary), None)
if not r:
- r = glob_to_regex(glob, word_boundary)
+ r = glob_to_regex(glob, word_boundary=word_boundary)
regex_cache[(glob, True, word_boundary)] = r
return bool(r.search(value))
except re.error:
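The `matrix-common` helpers slot into the same caching scheme the evaluator already uses. The sketch below mirrors the calls made in this diff; the cache key shape and helper names come from the patch, while the small wrapper functions are illustrative:

```python
import re
from typing import Dict, Pattern, Tuple

from matrix_common.regex import glob_to_regex, to_word_pattern

# Cache keyed the same way as in the evaluator: (pattern, is_glob, word_boundary).
regex_cache: Dict[Tuple[str, bool, bool], Pattern] = {}


def compile_glob(glob: str, word_boundary: bool) -> Pattern:
    r = regex_cache.get((glob, True, word_boundary))
    if not r:
        r = glob_to_regex(glob, word_boundary=word_boundary)
        regex_cache[(glob, True, word_boundary)] = r
    return r


def compile_display_name(display_name: str) -> Pattern:
    r = regex_cache.get((display_name, False, True))
    if not r:
        r = re.compile(to_word_pattern(re.escape(display_name)), flags=re.IGNORECASE)
        regex_cache[(display_name, False, True)] = r
    return r
```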
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 9c85200c0fb4..957c9b780b94 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -13,6 +13,7 @@
# limitations under the License.
from typing import Dict
+from synapse.api.constants import ReceiptTypes
from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
@@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)
- my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
+ my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ)
badge = len(invites)
@@ -36,7 +37,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
room_id, user_id, last_unread_event_id
)
)
- if notifs["notify_count"] == 0:
+ if notifs.notify_count == 0:
continue
if group_by_room:
@@ -44,7 +45,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge += 1
else:
# increment the badge count by the number of unread messages in the room
- badge += notifs["notify_count"]
+ badge += notifs.notify_count
return badge
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 26735447a6f1..7912311d2401 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -27,6 +27,7 @@
from synapse.replication.http.push import ReplicationRemovePusherRestServlet
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
+from synapse.util.threepids import canonicalise_email
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -113,7 +114,9 @@ async def add_pusher(
"""
if kind == "email":
- email_owner = await self.store.get_user_id_by_threepid("email", pushkey)
+ email_owner = await self.store.get_user_id_by_threepid(
+ "email", canonicalise_email(pushkey)
+ )
if email_owner != user_id:
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 13fb69460ea9..d844fbb3b3d6 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -88,6 +88,7 @@
# with the latest security patches.
"cryptography>=3.4.7",
"ijson>=3.1",
+ "matrix-common==1.0.0",
]
CONDITIONAL_REQUIREMENTS = {
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 7ecb446e7c78..7644146dbadb 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Optional
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -27,7 +27,12 @@
class BaseSlavedStore(CacheInvalidationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen: Optional[
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 61cd7e522800..bc888ce1a871 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.lrucache import LruCache
@@ -25,7 +25,12 @@
class SlavedClientIpStore(BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 0a582960896d..a2aff75b7075 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -17,7 +17,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -27,7 +27,12 @@
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.hs = hs
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 63ed50caa5eb..0f0837269486 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
@@ -58,7 +58,12 @@ class SlavedEventStore(
RelationsWorkerStore,
BaseSlavedStore,
):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
@@ -75,12 +80,3 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
-
- # Cached functions can't be accessed through a class instance so we need
- # to reach inside the __dict__ to extract them.
-
- def get_room_max_stream_ordering(self):
- return self._stream_id_gen.get_current_token()
-
- def get_room_min_stream_ordering(self):
- return self._backfill_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 90284c202d55..4d185e2b56c7 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.filtering import FilteringStore
from ._base import BaseSlavedStore
@@ -24,7 +24,12 @@
class SlavedFilteringStore(BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 497e16c69e6a..9d90e26375f0 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -17,7 +17,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import GroupServerStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -26,7 +26,12 @@
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.hs = hs
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 743a01da08f0..5a2d90c5309f 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -15,7 +15,6 @@
import heapq
import logging
-from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
@@ -30,6 +29,7 @@
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
+from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -226,17 +226,14 @@ class BackfillStream(Stream):
or it went from being an outlier to not.
"""
- BackfillStreamRow = namedtuple(
- "BackfillStreamRow",
- (
- "event_id", # str
- "room_id", # str
- "type", # str
- "state_key", # str, optional
- "redacts", # str, optional
- "relates_to", # str, optional
- ),
- )
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class BackfillStreamRow:
+ event_id: str
+ room_id: str
+ type: str
+ state_key: Optional[str]
+ redacts: Optional[str]
+ relates_to: Optional[str]
NAME = "backfill"
ROW_TYPE = BackfillStreamRow
@@ -256,18 +253,15 @@ def _current_token(self, instance_name: str) -> int:
class PresenceStream(Stream):
- PresenceStreamRow = namedtuple(
- "PresenceStreamRow",
- (
- "user_id", # str
- "state", # str
- "last_active_ts", # int
- "last_federation_update_ts", # int
- "last_user_sync_ts", # int
- "status_msg", # str
- "currently_active", # bool
- ),
- )
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class PresenceStreamRow:
+ user_id: str
+ state: str
+ last_active_ts: int
+ last_federation_update_ts: int
+ last_user_sync_ts: int
+ status_msg: str
+ currently_active: bool
NAME = "presence"
ROW_TYPE = PresenceStreamRow
@@ -302,7 +296,7 @@ class PresenceFederationStream(Stream):
send.
"""
- @attr.s(slots=True, auto_attribs=True)
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceFederationStreamRow:
destination: str
user_id: str
@@ -320,9 +314,10 @@ def __init__(self, hs: "HomeServer"):
class TypingStream(Stream):
- TypingStreamRow = namedtuple(
- "TypingStreamRow", ("room_id", "user_ids") # str # list(str)
- )
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class TypingStreamRow:
+ room_id: str
+ user_ids: List[str]
NAME = "typing"
ROW_TYPE = TypingStreamRow
@@ -348,16 +343,13 @@ def __init__(self, hs: "HomeServer"):
class ReceiptsStream(Stream):
- ReceiptsStreamRow = namedtuple(
- "ReceiptsStreamRow",
- (
- "room_id", # str
- "receipt_type", # str
- "user_id", # str
- "event_id", # str
- "data", # dict
- ),
- )
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class ReceiptsStreamRow:
+ room_id: str
+ receipt_type: str
+ user_id: str
+ event_id: str
+ data: dict
NAME = "receipts"
ROW_TYPE = ReceiptsStreamRow
@@ -374,7 +366,9 @@ def __init__(self, hs: "HomeServer"):
class PushRulesStream(Stream):
"""A user has changed their push rules"""
- PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class PushRulesStreamRow:
+ user_id: str
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow
@@ -396,10 +390,12 @@ def _current_token(self, instance_name: str) -> int:
class PushersStream(Stream):
"""A user has added/changed/removed a pusher"""
- PushersStreamRow = namedtuple(
- "PushersStreamRow",
- ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
- )
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class PushersStreamRow:
+ user_id: str
+ app_id: str
+ pushkey: str
+ deleted: bool
NAME = "pushers"
ROW_TYPE = PushersStreamRow
@@ -419,7 +415,7 @@ class CachesStream(Stream):
the cache on the workers
"""
- @attr.s(slots=True)
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
class CachesStreamRow:
"""Stream to inform workers they should invalidate their cache.
@@ -430,9 +426,9 @@ class CachesStreamRow:
invalidation_ts: Timestamp of when the invalidation took place.
"""
- cache_func = attr.ib(type=str)
- keys = attr.ib(type=Optional[List[Any]])
- invalidation_ts = attr.ib(type=int)
+ cache_func: str
+ keys: Optional[List[Any]]
+ invalidation_ts: int
NAME = "caches"
ROW_TYPE = CachesStreamRow
@@ -451,9 +447,9 @@ class DeviceListsStream(Stream):
told about a device update.
"""
- @attr.s(slots=True)
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow:
- entity = attr.ib(type=str)
+ entity: str
NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow
@@ -470,7 +466,9 @@ def __init__(self, hs: "HomeServer"):
class ToDeviceStream(Stream):
"""New to_device messages for a client"""
- ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class ToDeviceStreamRow:
+ entity: str
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow
@@ -487,9 +485,11 @@ def __init__(self, hs: "HomeServer"):
class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room"""
- TagAccountDataStreamRow = namedtuple(
- "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
- )
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class TagAccountDataStreamRow:
+ user_id: str
+ room_id: str
+ data: JsonDict
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow
@@ -506,10 +506,11 @@ def __init__(self, hs: "HomeServer"):
class AccountDataStream(Stream):
"""Global or per room account data was changed"""
- AccountDataStreamRow = namedtuple(
- "AccountDataStreamRow",
- ("user_id", "room_id", "data_type"), # str # Optional[str] # str
- )
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class AccountDataStreamRow:
+ user_id: str
+ room_id: Optional[str]
+ data_type: str
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
@@ -573,10 +574,12 @@ async def _update_function(
class GroupServerStream(Stream):
- GroupsStreamRow = namedtuple(
- "GroupsStreamRow",
- ("group_id", "user_id", "type", "content"), # str # str # str # dict
- )
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class GroupsStreamRow:
+ group_id: str
+ user_id: str
+ type: str
+ content: JsonDict
NAME = "groups"
ROW_TYPE = GroupsStreamRow
@@ -593,7 +596,9 @@ def __init__(self, hs: "HomeServer"):
class UserSignatureStream(Stream):
"""A user has signed their own device with their user-signing key"""
- UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class UserSignatureStreamRow:
+ user_id: str
NAME = "user_signature"
ROW_TYPE = UserSignatureStreamRow
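These stream-row conversions work because attrs classes accept positional arguments in field order, so existing code that builds rows with `ROW_TYPE(*row)` keeps working unchanged. A quick illustration (the literal row data is made up):

```python
from typing import List

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class TypingStreamRow:
    room_id: str
    user_ids: List[str]


# attrs classes take positional arguments in field order, so code that built
# the old namedtuples with ROW_TYPE(*row) keeps working unchanged.
row_data = ["!room:example.org", ["@alice:example.org"]]
row = TypingStreamRow(*row_data)
assert row.room_id == "!room:example.org"
```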
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index a390cfcb74d5..4f4f1ad45378 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -50,12 +50,12 @@
"""
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamRow:
"""A parsed row from the events replication stream"""
- type = attr.ib() # str: the TypeId of one of the *EventsStreamRows
- data = attr.ib() # BaseEventsStreamRow
+ type: str # the TypeId of one of the *EventsStreamRows
+ data: "BaseEventsStreamRow"
class BaseEventsStreamRow:
@@ -79,28 +79,28 @@ def from_data(cls, data):
return cls(*data)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
- event_id = attr.ib(type=str)
- room_id = attr.ib(type=str)
- type = attr.ib(type=str)
- state_key = attr.ib(type=Optional[str])
- redacts = attr.ib(type=Optional[str])
- relates_to = attr.ib(type=Optional[str])
- membership = attr.ib(type=Optional[str])
- rejected = attr.ib(type=bool)
+ event_id: str
+ room_id: str
+ type: str
+ state_key: Optional[str]
+ redacts: Optional[str]
+ relates_to: Optional[str]
+ membership: Optional[str]
+ rejected: bool
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamCurrentStateRow(BaseEventsStreamRow):
TypeId = "state"
- room_id = attr.ib() # str
- type = attr.ib() # str
- state_key = attr.ib() # str
- event_id = attr.ib() # str, optional
+ room_id: str
+ type: str
+ state_key: str
+ event_id: Optional[str]
_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 0600cdbf363d..4046bdec6931 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -12,14 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import namedtuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple
+import attr
+
from synapse.replication.tcp.streams._base import (
Stream,
current_token_without_instance,
make_http_update_function,
)
+from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -30,13 +32,10 @@ class FederationStream(Stream):
sending disabled.
"""
- FederationStreamRow = namedtuple(
- "FederationStreamRow",
- (
- "type", # str, the type of data as defined in the BaseFederationRows
- "data", # dict, serialization of a federation.send_queue.BaseFederationRow
- ),
- )
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class FederationStreamRow:
+ type: str # the type of data as defined in the BaseFederationRows
+ data: JsonDict # serialization of a federation.send_queue.BaseFederationRow
NAME = "federation"
ROW_TYPE = FederationStreamRow
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index c499afd4be57..465e06772b26 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -69,6 +69,7 @@
from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
from synapse.rest.admin.username_available import UsernameAvailableRestServlet
from synapse.rest.admin.users import (
+ AccountDataRestServlet,
AccountValidityRenewServlet,
DeactivateAccountRestServlet,
PushersRestServlet,
@@ -108,7 +109,7 @@ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
class PurgeHistoryRestServlet(RestServlet):
PATTERNS = admin_patterns(
- "/purge_history/(?P[^/]*)(/(?P[^/]+))?"
+ "/purge_history/(?P[^/]*)(/(?P[^/]*))?$"
)
def __init__(self, hs: "HomeServer"):
@@ -195,7 +196,7 @@ async def on_POST(
class PurgeHistoryStatusRestServlet(RestServlet):
- PATTERNS = admin_patterns("/purge_history_status/(?P[^/]+)")
+ PATTERNS = admin_patterns("/purge_history_status/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.pagination_handler = hs.get_pagination_handler()
@@ -255,6 +256,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserMediaStatisticsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
+ AccountDataRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
MakeRoomAdminRestServlet(hs).register(http_server)
ShadowBanRestServlet(hs).register(http_server)
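Several admin servlets in this change tighten their URL regexes by switching `[^/]+` to `[^/]*` and adding a trailing `$`. The bare regexes below are simplified (the real patterns are wrapped by `admin_patterns`, which adds the `/_synapse/admin` prefix), but they show why the anchor matters:

```python
import re

# Simplified: without the trailing "$", extra path segments after the captured
# group are silently accepted because re.match() only requires a prefix match.
unanchored = re.compile("^/delete_group/(?P<group_id>[^/]*)")
anchored = re.compile("^/delete_group/(?P<group_id>[^/]*)$")

assert unanchored.match("/delete_group/+group:example.org/extra")    # matches anyway
assert not anchored.match("/delete_group/+group:example.org/extra")  # rejected
assert anchored.match("/delete_group/+group:example.org")
```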
diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py
index 479672d4d568..e9bce22a347b 100644
--- a/synapse/rest/admin/background_updates.py
+++ b/synapse/rest/admin/background_updates.py
@@ -22,7 +22,7 @@
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
-from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -41,8 +41,7 @@ def __init__(self, hs: "HomeServer"):
self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self._auth.get_user_by_req(request)
- await assert_user_is_admin(self._auth, requester.user)
+ await assert_requester_is_admin(self._auth, request)
# We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.)
@@ -51,8 +50,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return HTTPStatus.OK, {"enabled": enabled}
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self._auth.get_user_by_req(request)
- await assert_user_is_admin(self._auth, requester.user)
+ await assert_requester_is_admin(self._auth, request)
body = parse_json_object_from_request(request)
@@ -84,8 +82,7 @@ def __init__(self, hs: "HomeServer"):
self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self._auth.get_user_by_req(request)
- await assert_user_is_admin(self._auth, requester.user)
+ await assert_requester_is_admin(self._auth, request)
# We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.)
@@ -111,15 +108,14 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
class BackgroundUpdateStartJobRestServlet(RestServlet):
"""Allows to start specific background updates"""
- PATTERNS = admin_patterns("/background_updates/start_job")
+ PATTERNS = admin_patterns("/background_updates/start_job$")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._store = hs.get_datastore()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self._auth.get_user_by_req(request)
- await assert_user_is_admin(self._auth, requester.user)
+ await assert_requester_is_admin(self._auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["job_name"])
@@ -127,34 +123,25 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
job_name = body["job_name"]
if job_name == "populate_stats_process_rooms":
- jobs = [
- {
- "update_name": "populate_stats_process_rooms",
- "progress_json": "{}",
- },
- ]
+ jobs = [("populate_stats_process_rooms", "{}", "")]
elif job_name == "regenerate_directory":
jobs = [
- {
- "update_name": "populate_user_directory_createtables",
- "progress_json": "{}",
- "depends_on": "",
- },
- {
- "update_name": "populate_user_directory_process_rooms",
- "progress_json": "{}",
- "depends_on": "populate_user_directory_createtables",
- },
- {
- "update_name": "populate_user_directory_process_users",
- "progress_json": "{}",
- "depends_on": "populate_user_directory_process_rooms",
- },
- {
- "update_name": "populate_user_directory_cleanup",
- "progress_json": "{}",
- "depends_on": "populate_user_directory_process_users",
- },
+ ("populate_user_directory_createtables", "{}", ""),
+ (
+ "populate_user_directory_process_rooms",
+ "{}",
+ "populate_user_directory_createtables",
+ ),
+ (
+ "populate_user_directory_process_users",
+ "{}",
+ "populate_user_directory_process_rooms",
+ ),
+ (
+ "populate_user_directory_cleanup",
+ "{}",
+ "populate_user_directory_process_users",
+ ),
]
else:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid job_name")
@@ -162,6 +149,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
try:
await self._store.db_pool.simple_insert_many(
table="background_updates",
+ keys=("update_name", "progress_json", "depends_on"),
values=jobs,
desc=f"admin_api_run_{job_name}",
)
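The `simple_insert_many` call above reflects an updated signature where the column names are passed once via `keys` and each row is a plain tuple. A hedged sketch of the call shape (the `store` parameter and `desc` string are stand-ins, not part of the patch):

```python
from typing import Any


async def queue_stats_job(store: Any) -> None:
    # `store` stands in for a Synapse datastore; only the keys/values call
    # shape below is taken from the patch.
    await store.db_pool.simple_insert_many(
        table="background_updates",
        keys=("update_name", "progress_json", "depends_on"),
        values=[("populate_stats_process_rooms", "{}", "")],
        desc="admin_api_example",
    )
```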
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 2e5a6600d337..d9905ff560cb 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
- self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
+ self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str, device_id: str
@@ -53,7 +53,7 @@ async def on_GET(
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@@ -63,6 +63,8 @@ async def on_GET(
device = await self.device_handler.get_device(
target_user.to_string(), device_id
)
+ if device is None:
+ raise NotFoundError("No device found")
return HTTPStatus.OK, device
async def on_DELETE(
@@ -71,7 +73,7 @@ async def on_DELETE(
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@@ -87,7 +89,7 @@ async def on_PUT(
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@@ -109,14 +111,10 @@ class DevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P[^/]*)/devices$", "v2")
def __init__(self, hs: "HomeServer"):
- """
- Args:
- hs: server
- """
- self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
+ self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str
@@ -124,7 +122,7 @@ async def on_GET(
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@@ -144,10 +142,10 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P[^/]*)/delete_devices$", "v2")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
+ self.is_mine = hs.is_mine
async def on_POST(
self, request: SynapseRequest, user_id: str
@@ -155,7 +153,7 @@ async def on_POST(
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index 5ee8b11110e0..38477f8eadeb 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 744687be35fc..8cd3fa189e8d 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet):
200 OK with details of a destination if success otherwise an error.
"""
- PATTERNS = admin_patterns("/federation/destinations/(?P[^/]+)$")
+ PATTERNS = admin_patterns("/federation/destinations/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
@@ -111,25 +111,37 @@ async def on_GET(
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self._auth, request)
+ if not await self._store.is_destination_known(destination):
+ raise NotFoundError("Unknown destination")
+
destination_retry_timings = await self._store.get_destination_retry_timings(
destination
)
- if not destination_retry_timings:
- raise NotFoundError("Unknown destination")
-
last_successful_stream_ordering = (
await self._store.get_destination_last_successful_stream_ordering(
destination
)
)
- response = {
+ response: JsonDict = {
"destination": destination,
- "failure_ts": destination_retry_timings.failure_ts,
- "retry_last_ts": destination_retry_timings.retry_last_ts,
- "retry_interval": destination_retry_timings.retry_interval,
"last_successful_stream_ordering": last_successful_stream_ordering,
}
+ if destination_retry_timings:
+ response = {
+ **response,
+ "failure_ts": destination_retry_timings.failure_ts,
+ "retry_last_ts": destination_retry_timings.retry_last_ts,
+ "retry_interval": destination_retry_timings.retry_interval,
+ }
+ else:
+ response = {
+ **response,
+ "failure_ts": None,
+ "retry_last_ts": 0,
+ "retry_interval": 0,
+ }
+
return HTTPStatus.OK, response
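The hunk above makes the destination details endpoint return a response for any known destination, defaulting the retry fields when no retry timings are stored. A minimal sketch of that response-building logic outside the servlet (the function name and the loose typing are illustrative, not part of Synapse):

from typing import Any, Dict, Optional

def build_destination_response(
    destination: str,
    last_successful_stream_ordering: Optional[int],
    retry_timings: Optional[Any],
) -> Dict[str, Any]:
    # Base fields are always present.
    response: Dict[str, Any] = {
        "destination": destination,
        "last_successful_stream_ordering": last_successful_stream_ordering,
    }
    if retry_timings:
        # Mirror the stored retry timings when they exist.
        response["failure_ts"] = retry_timings.failure_ts
        response["retry_last_ts"] = retry_timings.retry_last_ts
        response["retry_interval"] = retry_timings.retry_interval
    else:
        # Destination is known but has never failed: use neutral defaults.
        response["failure_ts"] = None
        response["retry_last_ts"] = 0
        response["retry_interval"] = 0
    return response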
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index a27110388f4f..cd697e180ef6 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -30,7 +30,7 @@
class DeleteGroupAdminRestServlet(RestServlet):
"""Allows deleting of local groups"""
- PATTERNS = admin_patterns("/delete_group/(?P[^/]*)")
+ PATTERNS = admin_patterns("/delete_group/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.group_server = hs.get_groups_server_handler()
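Many of the pattern changes in this patch follow the same shape as the one above: anchor the regex with `$` and relax `[^/]+` to `[^/]*`, so a malformed or empty identifier reaches the servlet's own validation instead of falling through to a generic 404. A quick illustration with plain `re` (the strings here are simplified examples, not the exact patterns `admin_patterns` produces):

import re

unanchored = re.compile("^/delete_group/(?P<group_id>[^/]+)")
anchored = re.compile("^/delete_group/(?P<group_id>[^/]*)$")

# The unanchored pattern happily matches a path with trailing segments.
print(bool(unanchored.match("/delete_group/abc/extra")))  # True
# The anchored pattern rejects it...
print(bool(anchored.match("/delete_group/abc/extra")))    # False
# ...but still matches an empty group_id, so the handler can return a clear error.
print(bool(anchored.match("/delete_group/")))             # True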
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 9e23e2d8fc00..299f5c9eb0f2 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,7 +17,7 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
-from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
@@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet):
"""
PATTERNS = [
- *admin_patterns("/room/(?P[^/]+)/media/quarantine$"),
+ *admin_patterns("/room/(?P[^/]*)/media/quarantine$"),
# This path kept around for legacy reasons
- *admin_patterns("/quarantine_media/(?P[^/]+)"),
+ *admin_patterns("/quarantine_media/(?P[^/]*)$"),
]
def __init__(self, hs: "HomeServer"):
@@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet):
this server.
"""
- PATTERNS = admin_patterns("/user/(?P[^/]+)/media/quarantine$")
+ PATTERNS = admin_patterns("/user/(?P[^/]*)/media/quarantine$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet):
"""
PATTERNS = admin_patterns(
- "/media/quarantine/(?P[^/]+)/(?P[^/]+)"
+ "/media/quarantine/(?P[^/]*)/(?P[^/]*)$"
)
def __init__(self, hs: "HomeServer"):
@@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet):
"""
PATTERNS = admin_patterns(
- "/media/unquarantine/(?P[^/]+)/(?P[^/]+)"
+ "/media/unquarantine/(?P[^/]*)/(?P[^/]*)$"
)
def __init__(self, hs: "HomeServer"):
@@ -138,8 +138,7 @@ def __init__(self, hs: "HomeServer"):
async def on_POST(
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
logging.info(
"Remove from quarantine local media by ID: %s/%s", server_name, media_id
@@ -154,7 +153,7 @@ async def on_POST(
class ProtectMediaByID(RestServlet):
"""Protect local media from being quarantined."""
- PATTERNS = admin_patterns("/media/protect/(?P[^/]+)")
+ PATTERNS = admin_patterns("/media/protect/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -163,8 +162,7 @@ def __init__(self, hs: "HomeServer"):
async def on_POST(
self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
logging.info("Protecting local media by ID: %s", media_id)
@@ -177,7 +175,7 @@ async def on_POST(
class UnprotectMediaByID(RestServlet):
"""Unprotect local media from being quarantined."""
- PATTERNS = admin_patterns("/media/unprotect/(?P[^/]+)")
+ PATTERNS = admin_patterns("/media/unprotect/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -186,8 +184,7 @@ def __init__(self, hs: "HomeServer"):
async def on_POST(
self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
logging.info("Unprotecting local media by ID: %s", media_id)
@@ -200,7 +197,7 @@ async def on_POST(
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room."""
- PATTERNS = admin_patterns("/room/(?P[^/]+)/media$")
+ PATTERNS = admin_patterns("/room/(?P[^/]*)/media$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -209,10 +206,7 @@ def __init__(self, hs: "HomeServer"):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- is_admin = await self.auth.is_server_admin(requester.user)
- if not is_admin:
- raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
+ await assert_requester_is_admin(self.auth, request)
local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
@@ -254,7 +248,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
class DeleteMediaByID(RestServlet):
"""Delete local media by a given ID. Removes it from this server."""
- PATTERNS = admin_patterns("/media/(?P[^/]+)/(?P[^/]+)")
+ PATTERNS = admin_patterns("/media/(?P[^/]*)/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet):
timestamp and size.
"""
- PATTERNS = admin_patterns("/media/(?P[^/]+)/delete$")
+ PATTERNS = admin_patterns("/media/(?P[^/]*)/delete$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet):
media that exist given for this user
"""
- PATTERNS = admin_patterns("/users/(?P[^/]+)/media$")
+ PATTERNS = admin_patterns("/users/(?P[^/]*)/media$")
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
@@ -403,16 +397,7 @@ async def on_GET(
request,
"order_by",
default=MediaSortOrder.CREATED_TS.value,
- allowed_values=(
- MediaSortOrder.MEDIA_ID.value,
- MediaSortOrder.UPLOAD_NAME.value,
- MediaSortOrder.CREATED_TS.value,
- MediaSortOrder.LAST_ACCESS_TS.value,
- MediaSortOrder.MEDIA_LENGTH.value,
- MediaSortOrder.MEDIA_TYPE.value,
- MediaSortOrder.QUARANTINED_BY.value,
- MediaSortOrder.SAFE_FROM_QUARANTINE.value,
- ),
+ allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")
@@ -470,16 +455,7 @@ async def on_DELETE(
request,
"order_by",
default=MediaSortOrder.CREATED_TS.value,
- allowed_values=(
- MediaSortOrder.MEDIA_ID.value,
- MediaSortOrder.UPLOAD_NAME.value,
- MediaSortOrder.CREATED_TS.value,
- MediaSortOrder.LAST_ACCESS_TS.value,
- MediaSortOrder.MEDIA_LENGTH.value,
- MediaSortOrder.MEDIA_TYPE.value,
- MediaSortOrder.QUARANTINED_BY.value,
- MediaSortOrder.SAFE_FROM_QUARANTINE.value,
- ),
+ allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")
@@ -490,7 +466,7 @@ async def on_DELETE(
)
deleted_media, total = await self.media_repository.delete_local_media_ids(
- ([row["media_id"] for row in media])
+ [row["media_id"] for row in media]
)
return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index 891b98c0888a..04948b640834 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet):
PATTERNS = admin_patterns("/registration_tokens$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet):
PATTERNS = admin_patterns("/registration_tokens/new$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet):
PATTERNS = admin_patterns("/registration_tokens/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 829e86675aba..efe25fe7ebf7 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet):
If 'purge' is true, it will remove all traces of a room from the database.
"""
- PATTERNS = admin_patterns("/rooms/(?P[^/]+)$", "v2")
+ PATTERNS = admin_patterns("/rooms/(?P[^/]*)$", "v2")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
@@ -123,7 +123,7 @@ async def on_DELETE(
class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
"""Get the status of the delete room background task."""
- PATTERNS = admin_patterns("/rooms/(?P[^/]+)/delete_status$", "v2")
+ PATTERNS = admin_patterns("/rooms/(?P[^/]*)/delete_status$", "v2")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
@@ -160,7 +160,7 @@ async def on_GET(
class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
"""Get the status of the delete room background task."""
- PATTERNS = admin_patterns("/rooms/delete_status/(?P[^/]+)$", "v2")
+ PATTERNS = admin_patterns("/rooms/delete_status/(?P[^/]*)$", "v2")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
@@ -193,35 +193,17 @@ def __init__(self, hs: "HomeServer"):
self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
# Extract query parameters
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
- order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value)
- if order_by not in (
- RoomSortOrder.ALPHABETICAL.value,
- RoomSortOrder.SIZE.value,
- RoomSortOrder.NAME.value,
- RoomSortOrder.CANONICAL_ALIAS.value,
- RoomSortOrder.JOINED_MEMBERS.value,
- RoomSortOrder.JOINED_LOCAL_MEMBERS.value,
- RoomSortOrder.VERSION.value,
- RoomSortOrder.CREATOR.value,
- RoomSortOrder.ENCRYPTION.value,
- RoomSortOrder.FEDERATABLE.value,
- RoomSortOrder.PUBLIC.value,
- RoomSortOrder.JOIN_RULES.value,
- RoomSortOrder.GUEST_ACCESS.value,
- RoomSortOrder.HISTORY_VISIBILITY.value,
- RoomSortOrder.STATE_EVENTS.value,
- ):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Unknown value for order_by: %s" % (order_by,),
- errcode=Codes.INVALID_PARAM,
- )
+ order_by = parse_string(
+ request,
+ "order_by",
+ default=RoomSortOrder.NAME.value,
+ allowed_values=[sort_order.value for sort_order in RoomSortOrder],
+ )
search_term = parse_string(request, "search_term", encoding="utf-8")
if search_term == "":
@@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet):
TODO: Add on_POST to allow room creation without joining the room
"""
- PATTERNS = admin_patterns("/rooms/(?P[^/]+)$")
+ PATTERNS = admin_patterns("/rooms/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
@@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet):
Get members list of a room.
"""
- PATTERNS = admin_patterns("/rooms/(?P[^/]+)/members")
+ PATTERNS = admin_patterns("/rooms/(?P[^/]*)/members$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet):
Get full state within a room.
"""
- PATTERNS = admin_patterns("/rooms/(?P[^/]+)/state")
+ PATTERNS = admin_patterns("/rooms/(?P[^/]*)/state$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -436,8 +415,7 @@ def __init__(self, hs: "HomeServer"):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
if not ret:
@@ -446,7 +424,7 @@ async def on_GET(
event_ids = await self.store.get_current_state_ids(room_id)
events = await self.store.get_events(event_ids.values())
now = self.clock.time_msec()
- room_state = await self._event_serializer.serialize_events(events.values(), now)
+ room_state = self._event_serializer.serialize_events(events.values(), now)
ret = {"state": room_state}
return HTTPStatus.OK, ret
@@ -454,14 +432,14 @@ async def on_GET(
class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
- PATTERNS = admin_patterns("/join/(?P[^/]*)")
+ PATTERNS = admin_patterns("/join/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
+ self.is_mine = hs.is_mine
async def on_POST(
self, request: SynapseRequest, room_identifier: str
@@ -477,7 +455,7 @@ async def on_POST(
assert_params_in_dict(content, ["user_id"])
target_user = UserID.from_string(content["user_id"])
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"This endpoint can only be used with local users",
@@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
}
"""
- PATTERNS = admin_patterns("/rooms/(?P[^/]*)/make_room_admin")
+ PATTERNS = admin_patterns("/rooms/(?P[^/]*)/make_room_admin$")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
"""
- PATTERNS = admin_patterns("/rooms/(?P[^/]*)/forward_extremities")
+ PATTERNS = admin_patterns("/rooms/(?P[^/]*)/forward_extremities$")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_DELETE(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
room_id, _ = await self.resolve_room_id(room_identifier)
@@ -710,8 +685,7 @@ async def on_DELETE(
async def on_GET(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
room_id, _ = await self.resolve_room_id(room_identifier)
@@ -770,16 +744,17 @@ async def on_GET(
)
time_now = self.clock.time_msec()
- results["events_before"] = await self._event_serializer.serialize_events(
- results["events_before"], time_now
+ aggregations = results.pop("aggregations", None)
+ results["events_before"] = self._event_serializer.serialize_events(
+ results["events_before"], time_now, bundle_aggregations=aggregations
)
- results["event"] = await self._event_serializer.serialize_event(
- results["event"], time_now
+ results["event"] = self._event_serializer.serialize_event(
+ results["event"], time_now, bundle_aggregations=aggregations
)
- results["events_after"] = await self._event_serializer.serialize_events(
- results["events_after"], time_now
+ results["events_after"] = self._event_serializer.serialize_events(
+ results["events_after"], time_now, bundle_aggregations=aggregations
)
- results["state"] = await self._event_serializer.serialize_events(
+ results["state"] = self._event_serializer.serialize_events(
results["state"], time_now
)
@@ -793,7 +768,7 @@ class BlockRoomRestServlet(RestServlet):
On GET: Get blocking status of room and user who has blocked this room.
"""
- PATTERNS = admin_patterns("/rooms/(?P[^/]+)/block$")
+ PATTERNS = admin_patterns("/rooms/(?P[^/]*)/block$")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
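Several hunks in this file (and in media.py above) replace the two-step `get_user_by_req` / `assert_user_is_admin` dance with a single `assert_requester_is_admin` call. A hedged sketch of what such a helper amounts to, reconstructed from the lines being removed; the real helpers live in `synapse/rest/admin/_base.py` and use Synapse's own `AuthError`:

from http import HTTPStatus

class AuthError(Exception):
    # Stub standing in for synapse.api.errors.AuthError.
    def __init__(self, code: int, msg: str):
        super().__init__(msg)
        self.code = code

async def assert_user_is_admin(auth, user) -> None:
    # The same check the removed lines performed inline.
    if not await auth.is_server_admin(user):
        raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")

async def assert_requester_is_admin(auth, request) -> None:
    # Resolve the requester from the request, then apply the admin check.
    requester = await auth.get_user_by_req(request)
    await assert_user_is_admin(auth, requester.user)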
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index b295fb078bc7..15da9cd88153 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet):
"""
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.server_notices_manager = hs.get_server_notices_manager()
self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
+ self.is_mine = hs.is_mine
def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice"
@@ -88,7 +88,7 @@ async def on_POST(
)
target_user = UserID.from_string(body["user_id"])
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users"
)
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index ca41fd45f2bd..7a6546372eef 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet):
PATTERNS = admin_patterns("/statistics/users/media$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -45,19 +44,16 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
order_by = parse_string(
- request, "order_by", default=UserSortOrder.USER_ID.value
+ request,
+ "order_by",
+ default=UserSortOrder.USER_ID.value,
+ allowed_values=(
+ UserSortOrder.MEDIA_LENGTH.value,
+ UserSortOrder.MEDIA_COUNT.value,
+ UserSortOrder.USER_ID.value,
+ UserSortOrder.DISPLAYNAME.value,
+ ),
)
- if order_by not in (
- UserSortOrder.MEDIA_LENGTH.value,
- UserSortOrder.MEDIA_COUNT.value,
- UserSortOrder.USER_ID.value,
- UserSortOrder.DISPLAYNAME.value,
- ):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Unknown value for order_by: %s" % (order_by,),
- errcode=Codes.INVALID_PARAM,
- )
start = parse_integer(request, "from", default=0)
if start < 0:
diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py
index 2bf1472967dd..5353dc368235 100644
--- a/synapse/rest/admin/username_available.py
+++ b/synapse/rest/admin/username_available.py
@@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet):
}
"""
- PATTERNS = admin_patterns("/username_available")
+ PATTERNS = admin_patterns("/username_available$")
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 2a60b602b1f8..c2617ee30c48 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet):
"""
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
@@ -126,7 +125,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
class UserRestServletV2(RestServlet):
- PATTERNS = admin_patterns("/users/(?P[^/]+)$", "v2")
+ PATTERNS = admin_patterns("/users/(?P[^/]*)$", "v2")
"""Get request to list user details.
This needs user to have administrator access in Synapse.
@@ -174,12 +173,11 @@ async def on_GET(
if not self.hs.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
- ret = await self.admin_handler.get_user(target_user)
-
- if not ret:
+ user_info_dict = await self.admin_handler.get_user(target_user)
+ if not user_info_dict:
raise NotFoundError("User not found")
- return HTTPStatus.OK, ret
+ return HTTPStatus.OK, user_info_dict
async def on_PUT(
self, request: SynapseRequest, user_id: str
@@ -400,10 +398,10 @@ async def on_PUT(
target_user, requester, body["avatar_url"], True
)
- user = await self.admin_handler.get_user(target_user)
- assert user is not None
+ user_info_dict = await self.admin_handler.get_user(target_user)
+ assert user_info_dict is not None
- return 201, user
+ return HTTPStatus.CREATED, user_info_dict
class UserRegisterServlet(RestServlet):
@@ -414,7 +412,7 @@ class UserRegisterServlet(RestServlet):
nonce to the time it was generated, in int seconds.
"""
- PATTERNS = admin_patterns("/register")
+ PATTERNS = admin_patterns("/register$")
NONCE_TIMEOUT = 60
def __init__(self, hs: "HomeServer"):
@@ -561,9 +559,9 @@ class WhoisRestServlet(RestServlet):
]
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
+ self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str
@@ -575,7 +573,7 @@ async def on_GET(
if target_user != auth_user:
await assert_user_is_admin(self.auth, auth_user)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
ret = await self.admin_handler.get_whois(target_user)
@@ -584,7 +582,7 @@ async def on_GET(
class DeactivateAccountRestServlet(RestServlet):
- PATTERNS = admin_patterns("/deactivate/(?P[^/]*)")
+ PATTERNS = admin_patterns("/deactivate/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@@ -630,7 +628,6 @@ class AccountValidityRenewServlet(RestServlet):
PATTERNS = admin_patterns("/account_validity/validity$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
@@ -674,11 +671,10 @@ class ResetPasswordRestServlet(RestServlet):
200 OK with empty object if success otherwise an error.
"""
- PATTERNS = admin_patterns("/reset_password/(?P[^/]*)")
+ PATTERNS = admin_patterns("/reset_password/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler()
@@ -718,12 +714,12 @@ class SearchUsersRestServlet(RestServlet):
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
- PATTERNS = admin_patterns("/search_users/(?P[^/]*)")
+ PATTERNS = admin_patterns("/search_users/(?P[^/]*)$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, target_user_id: str
@@ -740,7 +736,7 @@ async def on_GET(
# if not is_admin and target_user != auth_user:
# raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user")
term = parse_string(request, "term", required=True)
@@ -779,9 +775,9 @@ class UserAdminServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P[^/]*)/admin$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str
@@ -790,7 +786,7 @@ async def on_GET(
target_user = UserID.from_string(user_id)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Only local users can be admins of this homeserver",
@@ -813,7 +809,7 @@ async def on_PUT(
assert_params_in_dict(body, ["admin"])
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Only local users can be admins of this homeserver",
@@ -834,7 +830,7 @@ class UserMembershipRestServlet(RestServlet):
Get room list of an user.
"""
- PATTERNS = admin_patterns("/users/(?P[^/]+)/joined_rooms$")
+ PATTERNS = admin_patterns("/users/(?P[^/]*)/joined_rooms$")
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
@@ -909,10 +905,10 @@ class UserTokenRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P[^/]*)/login$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.is_mine_id = hs.is_mine_id
async def on_POST(
self, request: SynapseRequest, user_id: str
@@ -921,7 +917,7 @@ async def on_POST(
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
- if not self.hs.is_mine_id(user_id):
+ if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be logged in as"
)
@@ -975,19 +971,19 @@ class ShadowBanRestServlet(RestServlet):
{}
"""
- PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban")
+ PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self.is_mine_id = hs.is_mine_id
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- if not self.hs.is_mine_id(user_id):
+ if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
)
@@ -1001,7 +997,7 @@ async def on_DELETE(
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- if not self.hs.is_mine_id(user_id):
+ if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
)
@@ -1027,19 +1023,19 @@ class RateLimitRestServlet(RestServlet):
}
"""
- PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit")
+ PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self.is_mine_id = hs.is_mine_id
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- if not self.hs.is_mine_id(user_id):
+ if not self.is_mine_id(user_id):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
if not await self.store.get_user_by_id(user_id):
@@ -1068,7 +1064,7 @@ async def on_POST(
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- if not self.hs.is_mine_id(user_id):
+ if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
)
@@ -1113,7 +1109,7 @@ async def on_DELETE(
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- if not self.hs.is_mine_id(user_id):
+ if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
)
@@ -1124,3 +1120,33 @@ async def on_DELETE(
await self.store.delete_ratelimit_for_user(user_id)
return HTTPStatus.OK, {}
+
+
+class AccountDataRestServlet(RestServlet):
+ """Retrieve the given user's account data"""
+
+ PATTERNS = admin_patterns("/users/(?P[^/]*)/accountdata")
+
+ def __init__(self, hs: "HomeServer"):
+ self._auth = hs.get_auth()
+ self._store = hs.get_datastore()
+ self._is_mine_id = hs.is_mine_id
+
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self._auth, request)
+
+ if not self._is_mine_id(user_id):
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
+
+ if not await self._store.get_user_by_id(user_id):
+ raise NotFoundError("User not found")
+
+ global_data, by_room_data = await self._store.get_account_data_for_user(user_id)
+ return HTTPStatus.OK, {
+ "account_data": {
+ "global": global_data,
+ "rooms": by_room_data,
+ },
+ }
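A hypothetical call against the new account-data endpoint added above (host, user ID and token are placeholders; `admin_patterns` registers the route under the v1 admin prefix by default):

import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:8008/_synapse/admin/v1/users/@alice:example.com/accountdata",
    headers={"Authorization": "Bearer <admin_access_token>"},
)
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)

# Shape per the servlet above: {"account_data": {"global": {...}, "rooms": {...}}}
print(body["account_data"]["global"])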
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 2a3e24ae7e55..5c0e3a568007 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -73,6 +73,9 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"enabled": self.config.registration.enable_3pid_changes
}
+ if self.config.experimental.msc3440_enabled:
+ response["capabilities"]["io.element.thread"] = {"enabled": True}
+
return 200, response
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 8566dc5cb594..ad6fd6492baa 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -17,6 +17,7 @@
from typing import TYPE_CHECKING, Tuple
from synapse.api import errors
+from synapse.api.errors import NotFoundError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -24,10 +25,9 @@
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict
-from ._base import client_patterns, interactive_auth_handler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -116,6 +116,8 @@ async def on_GET(
device = await self.device_handler.get_device(
requester.user.to_string(), device_id
)
+ if device is None:
+ raise NotFoundError("No device found")
return 200, device
@interactive_auth_handler
diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 13b72a045a4a..672c821061ff 100644
--- a/synapse/rest/client/events.py
+++ b/synapse/rest/client/events.py
@@ -91,7 +91,7 @@ async def on_GET(
time_now = self.clock.time_msec()
if event:
- result = await self._event_serializer.serialize_event(event, time_now)
+ result = self._event_serializer.serialize_event(event, time_now)
return 200, result
else:
return 404, "Event not found."
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index d1d8a984c630..8e427a96a320 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
+from synapse.api.constants import ReceiptTypes
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -54,10 +55,10 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
)
receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
- user_id, "m.read"
+ user_id, ReceiptTypes.READ
)
- notif_event_ids = [pa["event_id"] for pa in push_actions]
+ notif_event_ids = [pa.event_id for pa in push_actions]
notif_events = await self.store.get_events(notif_event_ids)
returned_push_actions = []
@@ -66,30 +67,30 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
for pa in push_actions:
returned_pa = {
- "room_id": pa["room_id"],
- "profile_tag": pa["profile_tag"],
- "actions": pa["actions"],
- "ts": pa["received_ts"],
+ "room_id": pa.room_id,
+ "profile_tag": pa.profile_tag,
+ "actions": pa.actions,
+ "ts": pa.received_ts,
"event": (
- await self._event_serializer.serialize_event(
- notif_events[pa["event_id"]],
+ self._event_serializer.serialize_event(
+ notif_events[pa.event_id],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
)
),
}
- if pa["room_id"] not in receipts_by_room:
+ if pa.room_id not in receipts_by_room:
returned_pa["read"] = False
else:
- receipt = receipts_by_room[pa["room_id"]]
+ receipt = receipts_by_room[pa.room_id]
returned_pa["read"] = (
receipt["topological_ordering"],
receipt["stream_ordering"],
- ) >= (pa["topological_ordering"], pa["stream_ordering"])
+ ) >= (pa.topological_ordering, pa.stream_ordering)
returned_push_actions.append(returned_pa)
- next_token = str(pa["stream_ordering"])
+ next_token = str(pa.stream_ordering)
return 200, {"notifications": returned_push_actions, "next_token": next_token}
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 43c04fac6fdb..f51be511d1f4 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -48,7 +48,7 @@ async def on_POST(
await self.presence_handler.bump_presence_active_time(requester.user)
body = parse_json_object_from_request(request)
- read_event_id = body.get("m.read", None)
+ read_event_id = body.get(ReceiptTypes.READ, None)
hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
if not isinstance(hidden, bool):
@@ -62,7 +62,7 @@ async def on_POST(
if read_event_id:
await self.receipts_handler.received_client_receipt(
room_id,
- "m.read",
+ ReceiptTypes.READ,
user_id=requester.user.to_string(),
event_id=read_event_id,
hidden=hidden,
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 2b25b9aad6a3..b24ad2d1be13 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -16,7 +16,7 @@
import re
from typing import TYPE_CHECKING, Tuple
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http import get_request_user_agent
from synapse.http.server import HttpServer
@@ -53,7 +53,7 @@ async def on_POST(
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- if receipt_type != "m.read":
+ if receipt_type != ReceiptTypes.READ:
raise SynapseError(400, "Receipt type must be 'm.read'")
# Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
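The `ReceiptTypes.READ` constant used above replaces the bare `"m.read"` literal; a minimal stand-in showing the intent (the real constants class is in `synapse.api.constants` and may define more members):

class ReceiptTypes:
    READ = "m.read"

def validate_receipt_type(receipt_type: str) -> None:
    # Same guard as the servlet above, with a plain exception for the sketch.
    if receipt_type != ReceiptTypes.READ:
        raise ValueError("Receipt type must be 'm.read'")

validate_receipt_type(ReceiptTypes.READ)  # passes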
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index fc4e6921c5e6..8cf5ebaa07b7 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -19,28 +19,20 @@
"""
import logging
-from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
-from synapse.api.constants import EventTypes, RelationTypes
-from synapse.api.errors import ShadowBanError, SynapseError
+from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
-from synapse.http.servlet import (
- RestServlet,
- parse_integer,
- parse_json_object_from_request,
- parse_string,
-)
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
-from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.rest.client._base import client_patterns
from synapse.storage.relations import (
AggregationPaginationToken,
PaginationChunk,
RelationPaginationToken,
)
from synapse.types import JsonDict
-from synapse.util.stringutils import random_string
-
-from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -48,112 +40,6 @@
logger = logging.getLogger(__name__)
-class RelationSendServlet(RestServlet):
- """Helper API for sending events that have relation data.
-
- Example API shape to send a 👍 reaction to a room:
-
- POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D
- {}
-
- {
- "event_id": "$foobar"
- }
- """
-
- PATTERN = (
- "/rooms/(?P[^/]*)/send_relation"
- "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)"
- )
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.event_creation_handler = hs.get_event_creation_handler()
- self.txns = HttpTransactionCache(hs)
-
- def register(self, http_server: HttpServer) -> None:
- http_server.register_paths(
- "POST",
- client_patterns(self.PATTERN + "$", releases=()),
- self.on_PUT_or_POST,
- self.__class__.__name__,
- )
- http_server.register_paths(
- "PUT",
- client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
- self.on_PUT,
- self.__class__.__name__,
- )
-
- def on_PUT(
- self,
- request: SynapseRequest,
- room_id: str,
- parent_id: str,
- relation_type: str,
- event_type: str,
- txn_id: Optional[str] = None,
- ) -> Awaitable[Tuple[int, JsonDict]]:
- return self.txns.fetch_or_execute_request(
- request,
- self.on_PUT_or_POST,
- request,
- room_id,
- parent_id,
- relation_type,
- event_type,
- txn_id,
- )
-
- async def on_PUT_or_POST(
- self,
- request: SynapseRequest,
- room_id: str,
- parent_id: str,
- relation_type: str,
- event_type: str,
- txn_id: Optional[str] = None,
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
- if event_type == EventTypes.Member:
- # Add relations to a membership is meaningless, so we just deny it
- # at the CS API rather than trying to handle it correctly.
- raise SynapseError(400, "Cannot send member events with relations")
-
- content = parse_json_object_from_request(request)
-
- aggregation_key = parse_string(request, "key", encoding="utf-8")
-
- content["m.relates_to"] = {
- "event_id": parent_id,
- "rel_type": relation_type,
- }
- if aggregation_key is not None:
- content["m.relates_to"]["key"] = aggregation_key
-
- event_dict = {
- "type": event_type,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- }
-
- try:
- (
- event,
- _,
- ) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict=event_dict, txn_id=txn_id
- )
- event_id = event.event_id
- except ShadowBanError:
- event_id = "$" + random_string(43)
-
- return 200, {"event_id": event_id}
-
-
class RelationPaginationServlet(RestServlet):
"""API to paginate relations on an event by topological ordering, optionally
filtered by relation type and event type.
@@ -212,6 +98,7 @@ async def on_GET(
pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
+ room_id=room_id,
relation_type=relation_type,
event_type=event_type,
limit=limit,
@@ -226,12 +113,17 @@ async def on_GET(
now = self.clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
- original_event = await self._event_serializer.serialize_event(
- event, now, bundle_aggregations=False
+ original_event = self._event_serializer.serialize_event(
+ event, now, bundle_aggregations=None
)
# The relations returned for the requested event do include their
# bundled aggregations.
- serialized_events = await self._event_serializer.serialize_events(events, now)
+ aggregations = await self.store.get_bundled_aggregations(
+ events, requester.user.to_string()
+ )
+ serialized_events = self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ )
return_value = pagination_chunk.to_dict()
return_value["chunk"] = serialized_events
@@ -317,6 +209,7 @@ async def on_GET(
pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id,
+ room_id=room_id,
event_type=event_type,
limit=limit,
from_token=from_token,
@@ -383,7 +276,9 @@ async def on_GET(
# This checks that a) the event exists and b) the user is allowed to
# view it.
- await self.event_handler.get_event(requester.user, room_id, parent_id)
+ event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -402,6 +297,7 @@ async def on_GET(
result = await self.store.get_relations_for_event(
event_id=parent_id,
+ room_id=room_id,
relation_type=relation_type,
event_type=event_type,
aggregation_key=key,
@@ -415,7 +311,7 @@ async def on_GET(
)
now = self.clock.time_msec()
- serialized_events = await self._event_serializer.serialize_events(events, now)
+ serialized_events = self._event_serializer.serialize_events(events, now)
return_value = result.to_dict()
return_value["chunk"] = serialized_events
@@ -424,7 +320,6 @@ async def on_GET(
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
- RelationSendServlet(hs).register(http_server)
RelationPaginationServlet(hs).register(http_server)
RelationAggregationPaginationServlet(hs).register(http_server)
RelationAggregationGroupPaginationServlet(hs).register(http_server)
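The recurring pattern in this file and the ones that follow is that bundled aggregations are now fetched explicitly from the store and handed to the (now synchronous) serializer. A condensed sketch of that flow using the names from the hunks above; the exact signatures belong to Synapse internals:

async def serialize_with_aggregations(store, serializer, events, user_id, time_now):
    # Fetch bundled aggregations (reactions, edits, threads, ...) once, up front.
    aggregations = await store.get_bundled_aggregations(events, user_id)
    # The serializer itself no longer awaits anything.
    return serializer.serialize_events(
        events, time_now, bundle_aggregations=aggregations
    )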
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index f48e2e6ca248..90bb9142a098 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -187,7 +187,7 @@ async def on_PUT(
state_key: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
if txn_id:
set_tag("txn_id", txn_id)
@@ -642,6 +642,7 @@ class RoomEventServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
+ self._store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
@@ -660,9 +661,16 @@ async def on_GET(
# https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-rooms-roomid-event-eventid
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
- time_now = self.clock.time_msec()
if event:
- event_dict = await self._event_serializer.serialize_event(event, time_now)
+ # Ensure there are bundled aggregations available.
+ aggregations = await self._store.get_bundled_aggregations(
+ [event], requester.user.to_string()
+ )
+
+ time_now = self.clock.time_msec()
+ event_dict = self._event_serializer.serialize_event(
+ event, time_now, bundle_aggregations=aggregations
+ )
return 200, event_dict
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
@@ -706,16 +714,17 @@ async def on_GET(
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec()
- results["events_before"] = await self._event_serializer.serialize_events(
- results["events_before"], time_now
+ aggregations = results.pop("aggregations", None)
+ results["events_before"] = self._event_serializer.serialize_events(
+ results["events_before"], time_now, bundle_aggregations=aggregations
)
- results["event"] = await self._event_serializer.serialize_event(
- results["event"], time_now
+ results["event"] = self._event_serializer.serialize_event(
+ results["event"], time_now, bundle_aggregations=aggregations
)
- results["events_after"] = await self._event_serializer.serialize_events(
- results["events_after"], time_now
+ results["events_after"] = self._event_serializer.serialize_events(
+ results["events_after"], time_now, bundle_aggregations=aggregations
)
- results["state"] = await self._event_serializer.serialize_events(
+ results["state"] = self._event_serializer.serialize_events(
results["state"], time_now
)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 7f5846d38934..d20ae1421e19 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -17,7 +17,6 @@
from typing import (
TYPE_CHECKING,
Any,
- Awaitable,
Callable,
Dict,
Iterable,
@@ -48,6 +47,7 @@
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
+from synapse.logging.opentracing import trace
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder
@@ -222,6 +222,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
logger.debug("Event formatting complete")
return 200, response_content
+ @trace(opname="sync.encode_response")
async def encode_response(
self,
time_now: int,
@@ -293,6 +294,9 @@ async def encode_response(
response[
"org.matrix.msc2732.device_unused_fallback_key_types"
] = sync_result.device_unused_fallback_key_types
+ response[
+ "device_unused_fallback_key_types"
+ ] = sync_result.device_unused_fallback_key_types
if joined:
response["rooms"][Membership.JOIN] = joined
@@ -329,6 +333,7 @@ def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
]
}
+ @trace(opname="sync.encode_joined")
async def encode_joined(
self,
rooms: List[JoinedSyncResult],
@@ -365,6 +370,7 @@ async def encode_joined(
return joined
+ @trace(opname="sync.encode_invited")
async def encode_invited(
self,
rooms: List[InvitedSyncResult],
@@ -388,7 +394,7 @@ async def encode_invited(
"""
invited = {}
for room in rooms:
- invite = await self._event_serializer.serialize_event(
+ invite = self._event_serializer.serialize_event(
room.invite,
time_now,
token_id=token_id,
@@ -403,6 +409,7 @@ async def encode_invited(
return invited
+ @trace(opname="sync.encode_knocked")
async def encode_knocked(
self,
rooms: List[KnockedSyncResult],
@@ -424,7 +431,7 @@ async def encode_knocked(
"""
knocked = {}
for room in rooms:
- knock = await self._event_serializer.serialize_event(
+ knock = self._event_serializer.serialize_event(
room.knock,
time_now,
token_id=token_id,
@@ -457,6 +464,7 @@ async def encode_knocked(
return knocked
+ @trace(opname="sync.encode_archived")
async def encode_archived(
self,
rooms: List[ArchivedSyncResult],
@@ -516,21 +524,14 @@ async def encode_room(
The room, encoded in our response format
"""
- def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]:
+ def serialize(
+ events: Iterable[EventBase],
+ aggregations: Optional[Dict[str, Dict[str, Any]]] = None,
+ ) -> List[JsonDict]:
return self._event_serializer.serialize_events(
events,
time_now=time_now,
- # Don't bother to bundle aggregations if the timeline is unlimited,
- # as clients will have all the necessary information.
- # bundle_aggregations=room.timeline.limited,
- #
- # richvdh 2021-12-15: disable this temporarily as it has too high an
- # overhead for initialsyncs. We need to figure out a way that the
- # bundling can be done *before* the events are stored in the
- # SyncResponseCache so that this part can be synchronous.
- #
- # Ensure to re-enable the test at tests/rest/client/test_relations.py::RelationsTestCase.test_bundled_aggregations.
- bundle_aggregations=False,
+ bundle_aggregations=aggregations,
token_id=token_id,
event_format=event_formatter,
only_event_fields=only_fields,
@@ -552,8 +553,10 @@ def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]:
event.room_id,
)
- serialized_state = await serialize(state_events)
- serialized_timeline = await serialize(timeline_events)
+ serialized_state = serialize(state_events)
+ serialized_timeline = serialize(
+ timeline_events, room.timeline.bundled_aggregations
+ )
account_data = room.account_data
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 8d888f456520..2290c57c126e 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -93,6 +93,10 @@ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
"org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
# Supports receiving hidden read receipts as per MSC2285
"org.matrix.msc2285": self.config.experimental.msc2285_enabled,
+ # Adds support for importing historical messages as per MSC2716
+ "org.matrix.msc2716": self.config.experimental.msc2716_enabled,
+ # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
+ "org.matrix.msc3030": self.config.experimental.msc3030_enabled,
},
},
)
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 12b3ae120cdb..b9bfbea21b83 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
@@ -99,7 +99,7 @@ def response_json_object(self) -> JsonDict:
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object
- def render_GET(self, request: Request) -> int:
+ def render_GET(self, request: Request) -> Optional[int]:
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 244ba261bbc4..71b9a34b145d 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -739,14 +739,21 @@ async def _generate_thumbnails(
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails: Dict[Tuple[int, int, str], str] = {}
- for r_width, r_height, r_method, r_type in requirements:
- if r_method == "crop":
- thumbnails.setdefault((r_width, r_height, r_type), r_method)
- elif r_method == "scale":
- t_width, t_height = thumbnailer.aspect(r_width, r_height)
+ for requirement in requirements:
+ if requirement.method == "crop":
+ thumbnails.setdefault(
+ (requirement.width, requirement.height, requirement.media_type),
+ requirement.method,
+ )
+ elif requirement.method == "scale":
+ t_width, t_height = thumbnailer.aspect(
+ requirement.width, requirement.height
+ )
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
- thumbnails[(t_width, t_height, r_type)] = r_method
+ thumbnails[
+ (t_width, t_height, requirement.media_type)
+ ] = requirement.method
# Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.items():
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index fca239d8c7ec..9f6c251caf21 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -343,7 +343,7 @@ class SpamMediaException(NotFoundError):
"""
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class ReadableFileWrapper:
"""Wrapper that allows reading a file in chunks, yielding to the reactor,
and writing to a callback.
@@ -354,8 +354,8 @@ class ReadableFileWrapper:
CHUNK_SIZE = 2 ** 14
- clock = attr.ib(type=Clock)
- path = attr.ib(type=str)
+ clock: Clock
+ path: str
async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
"""Reads the file in chunks and calls the callback with each chunk."""
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 2a59552c20a3..2177b46c9eba 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -17,6 +17,7 @@
import attr
+from synapse.rest.media.v1.preview_html import parse_html_description
from synapse.types import JsonDict
from synapse.util import json_decoder
@@ -32,6 +33,8 @@
class OEmbedResult:
# The Open Graph result (converted from the oEmbed result).
open_graph_result: JsonDict
+ # The author_name of the oEmbed result
+ author_name: Optional[str]
# Number of milliseconds to cache the content, according to the oEmbed response.
#
# This will be None if no cache-age is provided in the oEmbed response (or
@@ -153,11 +156,12 @@ def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult:
"og:url": url,
}
- # Use either title or author's name as the title.
- title = oembed.get("title") or oembed.get("author_name")
+ title = oembed.get("title")
if title:
open_graph_response["og:title"] = title
+ author_name = oembed.get("author_name")
+
# Use the provider name and as the site.
provider_name = oembed.get("provider_name")
if provider_name:
@@ -192,9 +196,10 @@ def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult:
# Trap any exception and let the code follow as usual.
logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
open_graph_response = {}
+ author_name = None
cache_age = None
- return OEmbedResult(open_graph_response, cache_age)
+ return OEmbedResult(open_graph_response, author_name, cache_age)
def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]:
@@ -245,8 +250,6 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) ->
if video_urls:
open_graph_response["og:video"] = video_urls[0]
- from synapse.rest.media.v1.preview_url_resource import _calc_description
-
- description = _calc_description(tree)
+ description = parse_html_description(tree)
if description:
open_graph_response["og:description"] = description
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
new file mode 100644
index 000000000000..30b067dd4271
--- /dev/null
+++ b/synapse/rest/media/v1/preview_html.py
@@ -0,0 +1,397 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import codecs
+import itertools
+import logging
+import re
+from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
+from urllib import parse as urlparse
+
+if TYPE_CHECKING:
+ from lxml import etree
+
+logger = logging.getLogger(__name__)
+
+_charset_match = re.compile(
+ br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
+)
+_xml_encoding_match = re.compile(
+ br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
+)
+_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
+
+
+def _normalise_encoding(encoding: str) -> Optional[str]:
+ """Use the Python codec's name as the normalised entry."""
+ try:
+ return codecs.lookup(encoding).name
+ except LookupError:
+ return None
+
+
+def _get_html_media_encodings(
+ body: bytes, content_type: Optional[str]
+) -> Iterable[str]:
+ """
+ Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
+
+ The precedence used for finding a character encoding is:
+
+ 1. <meta> tag with a charset declared.
+ 2. The XML document's character encoding attribute.
+ 3. The Content-Type header.
+ 4. Fallback to utf-8.
+ 5. Fallback to windows-1252.
+
+ This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
+
+ Args:
+ body: The HTML document, as bytes.
+ content_type: The Content-Type header.
+
+ Returns:
+ The character encoding of the body, as a string.
+ """
+ # There's no point in returning an encoding more than once.
+ attempted_encodings: Set[str] = set()
+
+ # Limit searches to the first 1kb, since it ought to be at the top.
+ body_start = body[:1024]
+
+ # Check if it has an encoding set in a meta tag.
+ match = _charset_match.search(body_start)
+ if match:
+ encoding = _normalise_encoding(match.group(1).decode("ascii"))
+ if encoding:
+ attempted_encodings.add(encoding)
+ yield encoding
+
+ # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+
+ # Check if it has an XML document with an encoding.
+ match = _xml_encoding_match.match(body_start)
+ if match:
+ encoding = _normalise_encoding(match.group(1).decode("ascii"))
+ if encoding and encoding not in attempted_encodings:
+ attempted_encodings.add(encoding)
+ yield encoding
+
+ # Check the HTTP Content-Type header for a character set.
+ if content_type:
+ content_match = _content_type_match.match(content_type)
+ if content_match:
+ encoding = _normalise_encoding(content_match.group(1))
+ if encoding and encoding not in attempted_encodings:
+ attempted_encodings.add(encoding)
+ yield encoding
+
+ # Finally, fallback to UTF-8, then windows-1252.
+ for fallback in ("utf-8", "cp1252"):
+ if fallback not in attempted_encodings:
+ yield fallback
+
+
+def decode_body(
+ body: bytes, uri: str, content_type: Optional[str] = None
+) -> Optional["etree.Element"]:
+ """
+ This uses lxml to parse the HTML document.
+
+ Args:
+ body: The HTML document, as bytes.
+ uri: The URI used to download the body.
+ content_type: The Content-Type header.
+
+ Returns:
+ The parsed HTML body, or None if an error occurred during processing.
+ """
+ # If there's no body, nothing useful is going to be found.
+ if not body:
+ return None
+
+ # The idea here is that multiple encodings are tried until one works.
+ # Unfortunately the result is never used and then LXML will decode the string
+ # again with the found encoding.
+ for encoding in _get_html_media_encodings(body, content_type):
+ try:
+ body.decode(encoding)
+ except Exception:
+ pass
+ else:
+ break
+ else:
+ logger.warning("Unable to decode HTML body for %s", uri)
+ return None
+
+ from lxml import etree
+
+ # Create an HTML parser.
+ parser = etree.HTMLParser(recover=True, encoding=encoding)
+
+ # Attempt to parse the body. Returns None if the body was successfully
+ # parsed, but no tree was found.
+ return etree.fromstring(body, parser)
+
+
+def parse_html_to_open_graph(
+ tree: "etree.Element", media_uri: str
+) -> Dict[str, Optional[str]]:
+ """
+ Parse the HTML document into an Open Graph response.
+
+ This uses lxml to search the HTML document for Open Graph data (or
+ synthesizes it from the document).
+
+ Args:
+ tree: The parsed HTML document.
+ media_uri: The URI used to download the body.
+
+ Returns:
+ The Open Graph response as a dictionary.
+ """
+
+ # if we see any image URLs in the OG response, then spider them
+ # (although the client could choose to do this by asking for previews of those
+ # URLs to avoid DoSing the server)
+
+ # "og:type" : "video",
+ # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
+ # "og:site_name" : "YouTube",
+ # "og:video:type" : "application/x-shockwave-flash",
+ # "og:description" : "Fun stuff happening here",
+ # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
+ # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
+ # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
+ # "og:video:width" : "1280"
+ # "og:video:height" : "720",
+ # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
+
+ og: Dict[str, Optional[str]] = {}
+ for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
+ if "content" in tag.attrib:
+ # if we've got more than 50 tags, someone is taking the piss
+ if len(og) >= 50:
+ logger.warning("Skipping OG for page with too many 'og:' tags")
+ return {}
+ og[tag.attrib["property"]] = tag.attrib["content"]
+
+ # TODO: grab article: meta tags too, e.g.:
+
+ # "article:publisher" : "https://www.facebook.com/thethudonline" />
+ # "article:author" content="https://www.facebook.com/thethudonline" />
+ # "article:tag" content="baby" />
+ # "article:section" content="Breaking News" />
+ # "article:published_time" content="2016-03-31T19:58:24+00:00" />
+ # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
+
+ if "og:title" not in og:
+ # do some basic spidering of the HTML
+ title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
+ if title and title[0].text is not None:
+ og["og:title"] = title[0].text.strip()
+ else:
+ og["og:title"] = None
+
+ if "og:image" not in og:
+ # TODO: extract a favicon failing all else
+ meta_image = tree.xpath(
+ "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
+ )
+ if meta_image:
+ og["og:image"] = rebase_url(meta_image[0], media_uri)
+ else:
+ # TODO: consider inlined CSS styles as well as width & height attribs
+ images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
+ images = sorted(
+ images,
+ key=lambda i: (
+ -1 * float(i.attrib["width"]) * float(i.attrib["height"])
+ ),
+ )
+ if not images:
+ images = tree.xpath("//img[@src]")
+ if images:
+ og["og:image"] = images[0].attrib["src"]
+
+ if "og:description" not in og:
+ meta_description = tree.xpath(
+ "//*/meta"
+ "[translate(@name, 'DESCRIPTION', 'description')='description']"
+ "/@content"
+ )
+ if meta_description:
+ og["og:description"] = meta_description[0]
+ else:
+ og["og:description"] = parse_html_description(tree)
+ elif og["og:description"]:
+ # This must be a non-empty string at this point.
+ assert isinstance(og["og:description"], str)
+ og["og:description"] = summarize_paragraphs([og["og:description"]])
+
+ # TODO: delete the url downloads to stop diskfilling,
+ # as we only ever cared about its OG
+ return og
+
+
+def parse_html_description(tree: "etree.Element") -> Optional[str]:
+ """
+ Calculate a text description based on an HTML document.
+
+ Grabs any text nodes which are inside the <body/> tag, unless they are within
+ an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
+ if they are within a <script/> or <style/> tag.
+
+ This is a very very very coarse approximation to a plain text render of the page.
+
+ Args:
+ tree: The parsed HTML document.
+
+ Returns:
+ The plain text description, or None if one cannot be generated.
+ """
+ # We don't just use XPATH here as that is slow on some machines.
+
+ from lxml import etree
+
+ TAGS_TO_REMOVE = (
+ "header",
+ "nav",
+ "aside",
+ "footer",
+ "script",
+ "noscript",
+ "style",
+ etree.Comment,
+ )
+
+ # Split all the text nodes into paragraphs (by splitting on new
+ # lines)
+ text_nodes = (
+ re.sub(r"\s+", "\n", el).strip()
+ for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
+ )
+ return summarize_paragraphs(text_nodes)
+
+
+def _iterate_over_text(
+ tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
+ """Iterate over the tree returning text nodes in a depth first fashion,
+ skipping text nodes inside certain tags.
+ """
+ # This is basically a stack that we extend using itertools.chain.
+ # This will either consist of an element to iterate over *or* a string
+ # to be returned.
+ elements = iter([tree])
+ while True:
+ el = next(elements, None)
+ if el is None:
+ return
+
+ if isinstance(el, str):
+ yield el
+ elif el.tag not in tags_to_ignore:
+ # el.text is the text before the first child, so we can immediately
+ # return it if the text exists.
+ if el.text:
+ yield el.text
+
+ # We add to the stack all the elements children, interspersed with
+ # each child's tail text (if it exists). The tail text of a node
+ # is text that comes *after* the node, so we always include it even
+ # if we ignore the child node.
+ elements = itertools.chain(
+ itertools.chain.from_iterable( # Basically a flatmap
+ [child, child.tail] if child.tail else [child]
+ for child in el.iterchildren()
+ ),
+ elements,
+ )
+
+
+def rebase_url(url: str, base: str) -> str:
+ base_parts = list(urlparse.urlparse(base))
+ url_parts = list(urlparse.urlparse(url))
+ if not url_parts[0]: # fix up schema
+ url_parts[0] = base_parts[0] or "http"
+ if not url_parts[1]: # fix up hostname
+ url_parts[1] = base_parts[1]
+ if not url_parts[2].startswith("/"):
+ url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+ return urlparse.urlunparse(url_parts)
+
+
+def summarize_paragraphs(
+ text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
+ """
+ Try to get a summary respecting first paragraph and then word boundaries.
+
+ Args:
+ text_nodes: The paragraphs to summarize.
+ min_size: The minimum number of characters to include.
+ max_size: The maximum number of characters to include.
+
+ Returns:
+ A summary of the text nodes, or None if that was not possible.
+ """
+
+ # TODO: Respect sentences?
+
+ description = ""
+
+ # Keep adding paragraphs until we get to the MIN_SIZE.
+ for text_node in text_nodes:
+ if len(description) < min_size:
+ text_node = re.sub(r"[\t \r\n]+", " ", text_node)
+ description += text_node + "\n\n"
+ else:
+ break
+
+ description = description.strip()
+ description = re.sub(r"[\t ]+", " ", description)
+ description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
+
+ # If the concatenation of paragraphs to get above MIN_SIZE
+ # took us over MAX_SIZE, then we need to truncate mid paragraph
+ if len(description) > max_size:
+ new_desc = ""
+
+ # This splits the paragraph into words, but keeping the
+ # (preceding) whitespace intact so we can easily concat
+ # words back together.
+ for match in re.finditer(r"\s*\S+", description):
+ word = match.group()
+
+ # Keep adding words while the total length is less than
+ # MAX_SIZE.
+ if len(word) + len(new_desc) < max_size:
+ new_desc += word
+ else:
+ # At this point the next word *will* take us over
+ # MAX_SIZE, but we also want to ensure that its not
+ # a huge word. If it is add it anyway and we'll
+ # truncate later.
+ if len(new_desc) < min_size:
+ new_desc += word
+ break
+
+ # Double check that we're not over the limit
+ if len(new_desc) > max_size:
+ new_desc = new_desc[:max_size]
+
+ # We always add an ellipsis because at the very least
+ # we chopped mid paragraph.
+ description = new_desc.strip() + "…"
+ return description if description else None
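The new `preview_html` module is importable on its own, which makes the flow above easy to see end to end: pick an encoding, parse with lxml, then derive Open Graph data. A minimal sketch, assuming Synapse 1.51+ and lxml are installed; the HTML snippet and URL below are invented for illustration:

```python
# Illustrative use of the new module; the page content and URL are made up.
from synapse.rest.media.v1.preview_html import (
    decode_body,
    parse_html_to_open_graph,
)

html = (
    b'<html><head><meta charset="utf-8"><title>Example page</title></head>'
    b"<body><p>Some introductory text that a preview could summarise.</p></body></html>"
)

# decode_body() tries the <meta> charset, an XML encoding declaration, the
# Content-Type header, then utf-8/windows-1252, and parses with lxml.
tree = decode_body(html, "https://example.com/page", "text/html; charset=utf-8")
if tree is not None:
    # parse_html_to_open_graph() synthesises og:title/og:image/og:description
    # when the page declares no og:* meta tags of its own.
    og = parse_html_to_open_graph(tree, "https://example.com/page")
    print(og.get("og:title"), og.get("og:description"))
```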
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 054f3c296da2..e8881bc8709e 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,18 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import codecs
import datetime
import errno
import fnmatch
-import itertools
import logging
import os
import re
import shutil
import sys
import traceback
-from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple
from urllib import parse as urlparse
import attr
@@ -45,6 +43,11 @@
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.oembed import OEmbedProvider
+from synapse.rest.media.v1.preview_html import (
+ decode_body,
+ parse_html_to_open_graph,
+ rebase_url,
+)
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
@@ -54,21 +57,11 @@
from ._base import FileInfo
if TYPE_CHECKING:
- from lxml import etree
-
from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-_charset_match = re.compile(
- br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
-)
-_xml_encoding_match = re.compile(
- br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
-)
-_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
-
OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
@@ -269,6 +262,7 @@ async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes:
# The number of milliseconds that the response should be considered valid.
expiration_ms = media_info.expires
+ author_name: Optional[str] = None
if _is_media(media_info.media_type):
file_id = media_info.filesystem_id
@@ -301,17 +295,25 @@ async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes:
# Check if this HTML document points to oEmbed information and
# defer to that.
oembed_url = self._oembed.autodiscover_from_html(tree)
- og = {}
+ og_from_oembed: JsonDict = {}
if oembed_url:
oembed_info = await self._download_url(oembed_url, user)
- og, expiration_ms = await self._handle_oembed_response(
+ (
+ og_from_oembed,
+ author_name,
+ expiration_ms,
+ ) = await self._handle_oembed_response(
url, oembed_info, expiration_ms
)
- # If there was no oEmbed URL (or oEmbed parsing failed), attempt
- # to generate the Open Graph information from the HTML.
- if not oembed_url or not og:
- og = _calc_og(tree, media_info.uri)
+ # Parse Open Graph information from the HTML in case the oEmbed
+ # response failed or is incomplete.
+ og_from_html = parse_html_to_open_graph(tree, media_info.uri)
+
+ # Compile the Open Graph response by using the scraped
+ # information from the HTML and overlaying any information
+ # from the oEmbed response.
+ og = {**og_from_html, **og_from_oembed}
await self._precache_image_url(user, media_info, og)
else:
@@ -319,7 +321,7 @@ async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes:
elif oembed_url:
# Handle the oEmbed information.
- og, expiration_ms = await self._handle_oembed_response(
+ og, author_name, expiration_ms = await self._handle_oembed_response(
url, media_info, expiration_ms
)
await self._precache_image_url(user, media_info, og)
@@ -328,6 +330,11 @@ async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes:
logger.warning("Failed to find any OG data in %s", url)
og = {}
+ # If we don't have a title but we have author_name, copy it as
+ # title
+ if not og.get("og:title") and author_name:
+ og["og:title"] = author_name
+
# filter out any stupidly long values
keys_to_remove = []
for k, v in og.items():
@@ -468,7 +475,7 @@ async def _precache_image_url(
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
image_info = await self._download_url(
- _rebase_url(og["og:image"], media_info.uri), user
+ rebase_url(og["og:image"], media_info.uri), user
)
if _is_media(image_info.media_type):
@@ -491,7 +498,7 @@ async def _precache_image_url(
async def _handle_oembed_response(
self, url: str, media_info: MediaInfo, expiration_ms: int
- ) -> Tuple[JsonDict, int]:
+ ) -> Tuple[JsonDict, Optional[str], int]:
"""
Parse the downloaded oEmbed info.
@@ -504,11 +511,12 @@ async def _handle_oembed_response(
Returns:
A tuple of:
The Open Graph dictionary, if the oEmbed info can be parsed.
+ The author name if it could be retrieved from oEmbed.
The (possibly updated) length of time, in milliseconds, the media is valid for.
"""
# If JSON was not returned, there's nothing to do.
if not _is_json(media_info.media_type):
- return {}, expiration_ms
+ return {}, None, expiration_ms
with open(media_info.filename, "rb") as file:
body = file.read()
@@ -520,7 +528,7 @@ async def _handle_oembed_response(
if open_graph_result and oembed_response.cache_age is not None:
expiration_ms = oembed_response.cache_age
- return open_graph_result, expiration_ms
+ return open_graph_result, oembed_response.author_name, expiration_ms
def _start_expire_url_cache_data(self) -> Deferred:
return run_as_background_process(
@@ -632,301 +640,6 @@ def try_remove_parent_dirs(dirs: Iterable[str]) -> None:
logger.debug("No media removed from url cache")
-def _normalise_encoding(encoding: str) -> Optional[str]:
- """Use the Python codec's name as the normalised entry."""
- try:
- return codecs.lookup(encoding).name
- except LookupError:
- return None
-
-
-def get_html_media_encodings(body: bytes, content_type: Optional[str]) -> Iterable[str]:
- """
- Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
-
- The precedence used for finding a character encoding is:
-
- 1. <meta> tag with a charset declared.
- 2. The XML document's character encoding attribute.
- 3. The Content-Type header.
- 4. Fallback to utf-8.
- 5. Fallback to windows-1252.
-
- This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
-
- Args:
- body: The HTML document, as bytes.
- content_type: The Content-Type header.
-
- Returns:
- The character encoding of the body, as a string.
- """
- # There's no point in returning an encoding more than once.
- attempted_encodings: Set[str] = set()
-
- # Limit searches to the first 1kb, since it ought to be at the top.
- body_start = body[:1024]
-
- # Check if it has an encoding set in a meta tag.
- match = _charset_match.search(body_start)
- if match:
- encoding = _normalise_encoding(match.group(1).decode("ascii"))
- if encoding:
- attempted_encodings.add(encoding)
- yield encoding
-
- # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
-
- # Check if it has an XML document with an encoding.
- match = _xml_encoding_match.match(body_start)
- if match:
- encoding = _normalise_encoding(match.group(1).decode("ascii"))
- if encoding and encoding not in attempted_encodings:
- attempted_encodings.add(encoding)
- yield encoding
-
- # Check the HTTP Content-Type header for a character set.
- if content_type:
- content_match = _content_type_match.match(content_type)
- if content_match:
- encoding = _normalise_encoding(content_match.group(1))
- if encoding and encoding not in attempted_encodings:
- attempted_encodings.add(encoding)
- yield encoding
-
- # Finally, fallback to UTF-8, then windows-1252.
- for fallback in ("utf-8", "cp1252"):
- if fallback not in attempted_encodings:
- yield fallback
-
-
-def decode_body(
- body: bytes, uri: str, content_type: Optional[str] = None
-) -> Optional["etree.Element"]:
- """
- This uses lxml to parse the HTML document.
-
- Args:
- body: The HTML document, as bytes.
- uri: The URI used to download the body.
- content_type: The Content-Type header.
-
- Returns:
- The parsed HTML body, or None if an error occurred during processing.
- """
- # If there's no body, nothing useful is going to be found.
- if not body:
- return None
-
- # The idea here is that multiple encodings are tried until one works.
- # Unfortunately the result is never used and then LXML will decode the string
- # again with the found encoding.
- for encoding in get_html_media_encodings(body, content_type):
- try:
- body.decode(encoding)
- except Exception:
- pass
- else:
- break
- else:
- logger.warning("Unable to decode HTML body for %s", uri)
- return None
-
- from lxml import etree
-
- # Create an HTML parser.
- parser = etree.HTMLParser(recover=True, encoding=encoding)
-
- # Attempt to parse the body. Returns None if the body was successfully
- # parsed, but no tree was found.
- return etree.fromstring(body, parser)
-
-
-def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
- """
- Calculate metadata for an HTML document.
-
- This uses lxml to search the HTML document for Open Graph data.
-
- Args:
- tree: The parsed HTML document.
- media_uri: The URI used to download the body.
-
- Returns:
- The Open Graph response as a dictionary.
- """
-
- # if we see any image URLs in the OG response, then spider them
- # (although the client could choose to do this by asking for previews of those
- # URLs to avoid DoSing the server)
-
- # "og:type" : "video",
- # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
- # "og:site_name" : "YouTube",
- # "og:video:type" : "application/x-shockwave-flash",
- # "og:description" : "Fun stuff happening here",
- # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
- # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
- # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
- # "og:video:width" : "1280"
- # "og:video:height" : "720",
- # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
-
- og: Dict[str, Optional[str]] = {}
- for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
- if "content" in tag.attrib:
- # if we've got more than 50 tags, someone is taking the piss
- if len(og) >= 50:
- logger.warning("Skipping OG for page with too many 'og:' tags")
- return {}
- og[tag.attrib["property"]] = tag.attrib["content"]
-
- # TODO: grab article: meta tags too, e.g.:
-
- # "article:publisher" : "https://www.facebook.com/thethudonline" />
- # "article:author" content="https://www.facebook.com/thethudonline" />
- # "article:tag" content="baby" />
- # "article:section" content="Breaking News" />
- # "article:published_time" content="2016-03-31T19:58:24+00:00" />
- # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
-
- if "og:title" not in og:
- # do some basic spidering of the HTML
- title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
- if title and title[0].text is not None:
- og["og:title"] = title[0].text.strip()
- else:
- og["og:title"] = None
-
- if "og:image" not in og:
- # TODO: extract a favicon failing all else
- meta_image = tree.xpath(
- "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
- )
- if meta_image:
- og["og:image"] = _rebase_url(meta_image[0], media_uri)
- else:
- # TODO: consider inlined CSS styles as well as width & height attribs
- images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
- images = sorted(
- images,
- key=lambda i: (
- -1 * float(i.attrib["width"]) * float(i.attrib["height"])
- ),
- )
- if not images:
- images = tree.xpath("//img[@src]")
- if images:
- og["og:image"] = images[0].attrib["src"]
-
- if "og:description" not in og:
- meta_description = tree.xpath(
- "//*/meta"
- "[translate(@name, 'DESCRIPTION', 'description')='description']"
- "/@content"
- )
- if meta_description:
- og["og:description"] = meta_description[0]
- else:
- og["og:description"] = _calc_description(tree)
- elif og["og:description"]:
- # This must be a non-empty string at this point.
- assert isinstance(og["og:description"], str)
- og["og:description"] = summarize_paragraphs([og["og:description"]])
-
- # TODO: delete the url downloads to stop diskfilling,
- # as we only ever cared about its OG
- return og
-
-
-def _calc_description(tree: "etree.Element") -> Optional[str]:
- """
- Calculate a text description based on an HTML document.
-
- Grabs any text nodes which are inside the <body/> tag, unless they are within
- an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
- if they are within a <script/> or <style/> tag.
-
- This is a very very very coarse approximation to a plain text render of the page.
-
- Args:
- tree: The parsed HTML document.
-
- Returns:
- The plain text description, or None if one cannot be generated.
- """
- # We don't just use XPATH here as that is slow on some machines.
-
- from lxml import etree
-
- TAGS_TO_REMOVE = (
- "header",
- "nav",
- "aside",
- "footer",
- "script",
- "noscript",
- "style",
- etree.Comment,
- )
-
- # Split all the text nodes into paragraphs (by splitting on new
- # lines)
- text_nodes = (
- re.sub(r"\s+", "\n", el).strip()
- for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
- )
- return summarize_paragraphs(text_nodes)
-
-
-def _iterate_over_text(
- tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
-) -> Generator[str, None, None]:
- """Iterate over the tree returning text nodes in a depth first fashion,
- skipping text nodes inside certain tags.
- """
- # This is basically a stack that we extend using itertools.chain.
- # This will either consist of an element to iterate over *or* a string
- # to be returned.
- elements = iter([tree])
- while True:
- el = next(elements, None)
- if el is None:
- return
-
- if isinstance(el, str):
- yield el
- elif el.tag not in tags_to_ignore:
- # el.text is the text before the first child, so we can immediately
- # return it if the text exists.
- if el.text:
- yield el.text
-
- # We add to the stack all the elements children, interspersed with
- # each child's tail text (if it exists). The tail text of a node
- # is text that comes *after* the node, so we always include it even
- # if we ignore the child node.
- elements = itertools.chain(
- itertools.chain.from_iterable( # Basically a flatmap
- [child, child.tail] if child.tail else [child]
- for child in el.iterchildren()
- ),
- elements,
- )
-
-
-def _rebase_url(url: str, base: str) -> str:
- base_parts = list(urlparse.urlparse(base))
- url_parts = list(urlparse.urlparse(url))
- if not url_parts[0]: # fix up schema
- url_parts[0] = base_parts[0] or "http"
- if not url_parts[1]: # fix up hostname
- url_parts[1] = base_parts[1]
- if not url_parts[2].startswith("/"):
- url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
- return urlparse.urlunparse(url_parts)
-
-
def _is_media(content_type: str) -> bool:
return content_type.lower().startswith("image/")
@@ -940,68 +653,3 @@ def _is_html(content_type: str) -> bool:
def _is_json(content_type: str) -> bool:
return content_type.lower().startswith("application/json")
-
-
-def summarize_paragraphs(
- text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
-) -> Optional[str]:
- """
- Try to get a summary respecting first paragraph and then word boundaries.
-
- Args:
- text_nodes: The paragraphs to summarize.
- min_size: The minimum number of characters to include.
- max_size: The maximum number of characters to include.
-
- Returns:
- A summary of the text nodes, or None if that was not possible.
- """
-
- # TODO: Respect sentences?
-
- description = ""
-
- # Keep adding paragraphs until we get to the MIN_SIZE.
- for text_node in text_nodes:
- if len(description) < min_size:
- text_node = re.sub(r"[\t \r\n]+", " ", text_node)
- description += text_node + "\n\n"
- else:
- break
-
- description = description.strip()
- description = re.sub(r"[\t ]+", " ", description)
- description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
-
- # If the concatenation of paragraphs to get above MIN_SIZE
- # took us over MAX_SIZE, then we need to truncate mid paragraph
- if len(description) > max_size:
- new_desc = ""
-
- # This splits the paragraph into words, but keeping the
- # (preceding) whitespace intact so we can easily concat
- # words back together.
- for match in re.finditer(r"\s*\S+", description):
- word = match.group()
-
- # Keep adding words while the total length is less than
- # MAX_SIZE.
- if len(word) + len(new_desc) < max_size:
- new_desc += word
- else:
- # At this point the next word *will* take us over
- # MAX_SIZE, but we also want to ensure that its not
- # a huge word. If it is add it anyway and we'll
- # truncate later.
- if len(new_desc) < min_size:
- new_desc += word
- break
-
- # Double check that we're not over the limit
- if len(new_desc) > max_size:
- new_desc = new_desc[:max_size]
-
- # We always add an ellipsis because at the very least
- # we chopped mid paragraph.
- description = new_desc.strip() + "…"
- return description if description else None
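The merging behaviour introduced above is worth spelling out: the HTML-scraped Open Graph data is the base, oEmbed values win on conflict, and the oEmbed author name only backfills a missing title. A standalone sketch of that precedence, using invented dictionaries purely for illustration:

```python
# Illustrative data only: mirrors how preview_url_resource.py now compiles
# the final Open Graph response.
og_from_html = {"og:title": "Scraped title", "og:description": "Scraped text"}
og_from_oembed = {"og:title": "oEmbed title"}   # from the oEmbed endpoint
author_name = "Some Author"                     # oembed_response.author_name

# oEmbed values overlay the HTML-scraped ones on conflict.
og = {**og_from_html, **og_from_oembed}
assert og == {"og:title": "oEmbed title", "og:description": "Scraped text"}

# The author name is only copied in when no title was found at all.
if not og.get("og:title") and author_name:
    og["og:title"] = author_name
assert og["og:title"] == "oEmbed title"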
diff --git a/synapse/server.py b/synapse/server.py
index 185e40e4da0f..3032f0b738a8 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -759,7 +759,7 @@ def get_oidc_handler(self) -> "OidcHandler":
@cache_in_self
def get_event_client_serializer(self) -> EventClientSerializer:
- return EventClientSerializer(self)
+ return EventClientSerializer()
@cache_in_self
def get_password_policy_handler(self) -> PasswordPolicyHandler:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 446204dbe52f..67e8bc6ec288 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import heapq
import logging
-from collections import defaultdict, namedtuple
+from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
@@ -45,7 +45,6 @@
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.context import ContextResourceUsage
-from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
@@ -69,9 +68,6 @@
)
-KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
-
-
EVICTION_TIMEOUT_SECONDS = 60 * 60
@@ -453,19 +449,19 @@ async def resolve_events(
return {key: state_map[ev_id] for key, ev_id in new_state.items()}
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _StateResMetrics:
"""Keeps track of some usage metrics about state res."""
# System and User CPU time, in seconds
- cpu_time = attr.ib(type=float, default=0.0)
+ cpu_time: float = 0.0
# time spent on database transactions (excluding scheduling time). This roughly
# corresponds to the amount of work done on the db server, excluding event fetches.
- db_time = attr.ib(type=float, default=0.0)
+ db_time: float = 0.0
# number of events fetched from the db.
- db_events = attr.ib(type=int, default=0)
+ db_events: int = 0
_biggest_room_by_cpu_counter = Counter(
@@ -515,7 +511,6 @@ def __init__(self, hs: "HomeServer"):
self.clock.looping_call(self._report_metrics, 120 * 1000)
- @log_function
async def resolve_state_groups(
self,
room_id: str,
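The `_StateResMetrics` change above is part of a wider move to attrs' `auto_attribs` style. A toy class (not Synapse code) showing that the two spellings are equivalent:

```python
# Toy illustration of the attrs conversion pattern; the class names are invented.
import attr

@attr.s(slots=True)
class MetricsOld:
    cpu_time = attr.ib(type=float, default=0.0)

@attr.s(slots=True, auto_attribs=True)
class MetricsNew:
    cpu_time: float = 0.0

assert MetricsOld().cpu_time == MetricsNew().cpu_time == 0.0
```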
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 3056e64ff570..7967011afdc0 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,10 +17,8 @@
from abc import ABCMeta
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
-from synapse.storage.database import LoggingTransaction # noqa: F401
-from synapse.storage.database import make_in_list_sql_clause # noqa: F401
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import make_in_list_sql_clause # noqa: F401
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import get_domain_from_id
from synapse.util import json_decoder
@@ -38,7 +36,12 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0693d390064f..57cc1d76e02f 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
import time
+import types
from collections import defaultdict
from sys import intern
from time import monotonic as monotonic_time
@@ -53,6 +55,7 @@
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -140,7 +143,7 @@ def make_conn(
return db_conn
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class LoggingDatabaseConnection:
"""A wrapper around a database connection that returns `LoggingTransaction`
as its cursor class.
@@ -148,9 +151,9 @@ class LoggingDatabaseConnection:
This is mainly used on startup to ensure that queries get logged correctly
"""
- conn = attr.ib(type=Connection)
- engine = attr.ib(type=BaseDatabaseEngine)
- default_txn_name = attr.ib(type=str)
+ conn: Connection
+ engine: BaseDatabaseEngine
+ default_txn_name: str
def cursor(
self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
@@ -175,7 +178,7 @@ def commit(self) -> None:
def rollback(self) -> None:
self.conn.rollback()
- def __enter__(self) -> "Connection":
+ def __enter__(self) -> "LoggingDatabaseConnection":
self.conn.__enter__()
return self
@@ -526,6 +529,12 @@ def new_transaction(
the function will correctly handle being aborted and retried half way
through its execution.
+ Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators,
+ since they could be evaluated multiple times (which would produce an empty
+ result on the second or subsequent evaluation). Likewise, the closure of `func`
+ must not reference any generators. This method attempts to detect such usage
+ and will log an error.
+
Args:
conn
desc
@@ -536,6 +545,39 @@ def new_transaction(
**kwargs
"""
+ # Robustness check: ensure that none of the arguments are generators, since that
+ # will fail if we have to repeat the transaction.
+ # For now, we just log an error, and hope that it works on the first attempt.
+ # TODO: raise an exception.
+ for i, arg in enumerate(args):
+ if inspect.isgenerator(arg):
+ logger.error(
+ "Programming error: generator passed to new_transaction as "
+ "argument %i to function %s",
+ i,
+ func,
+ )
+ for name, val in kwargs.items():
+ if inspect.isgenerator(val):
+ logger.error(
+ "Programming error: generator passed to new_transaction as "
+ "argument %s to function %s",
+ name,
+ func,
+ )
+ # also check variables referenced in func's closure
+ if inspect.isfunction(func):
+ f = cast(types.FunctionType, func)
+ if f.__closure__:
+ for i, cell in enumerate(f.__closure__):
+ if inspect.isgenerator(cell.cell_contents):
+ logger.error(
+ "Programming error: function %s references generator %s "
+ "via its closure",
+ f,
+ f.__code__.co_freevars[i],
+ )
+
start = monotonic_time()
txn_id = self._TXN_ID
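The generator check added above guards against a subtle failure mode: if a transaction is retried, a generator argument has already been exhausted by the first attempt, so the retry silently operates on nothing. A self-contained illustration of the pitfall (no Synapse code involved):

```python
import inspect

def run_twice(rows):
    """Simulate a transaction body that gets retried: each attempt re-reads `rows`."""
    return [list(rows) for _ in range(2)]

gen = (i * i for i in range(3))
print(run_twice(gen))        # [[0, 1, 4], []] -- the retry sees an exhausted generator
print(run_twice([0, 1, 4]))  # [[0, 1, 4], [0, 1, 4]] -- a concrete list is safe

# This is the kind of check new_transaction now performs on its arguments.
assert inspect.isgenerator(gen) and not inspect.isgenerator([0, 1, 4])
```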
@@ -892,64 +934,63 @@ def simple_insert_txn(
txn.execute(sql, vals)
async def simple_insert_many(
- self, table: str, values: List[Dict[str, Any]], desc: str
+ self,
+ table: str,
+ keys: Collection[str],
+ values: Collection[Collection[Any]],
+ desc: str,
) -> None:
"""Executes an INSERT query on the named table.
+ The input is given as a list of rows, where each row is a list of values.
+ (Actually any iterable is fine.)
+
Args:
table: string giving the table name
- values: dict of new column names and values for them
+ keys: list of column names
+ values: for each row, a list of values in the same order as `keys`
desc: description of the transaction, for logging and metrics
"""
- await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+ await self.runInteraction(
+ desc, self.simple_insert_many_txn, table, keys, values
+ )
@staticmethod
def simple_insert_many_txn(
- txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
+ txn: LoggingTransaction,
+ table: str,
+ keys: Collection[str],
+ values: Iterable[Iterable[Any]],
) -> None:
"""Executes an INSERT query on the named table.
+ The input is given as a list of rows, where each row is a list of values.
+ (Actually any iterable is fine.)
+
Args:
txn: The transaction to use.
table: string giving the table name
- values: dict of new column names and values for them
+ keys: list of column names
+ values: for each row, a list of values in the same order as `keys`
"""
- if not values:
- return
-
- # This is a *slight* abomination to get a list of tuples of key names
- # and a list of tuples of value names.
- #
- # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
- # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
- #
- # The sort is to ensure that we don't rely on dictionary iteration
- # order.
- keys, vals = zip(
- *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i)
- )
-
- for k in keys:
- if k != keys[0]:
- raise RuntimeError("All items must have the same keys")
if isinstance(txn.database_engine, PostgresEngine):
# We use `execute_values` as it can be a lot faster than `execute_batch`,
# but it's only available on postgres.
sql = "INSERT INTO %s (%s) VALUES ?" % (
table,
- ", ".join(k for k in keys[0]),
+ ", ".join(k for k in keys),
)
- txn.execute_values(sql, vals, fetch=False)
+ txn.execute_values(sql, values, fetch=False)
else:
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
- ", ".join(k for k in keys[0]),
- ", ".join("?" for _ in keys[0]),
+ ", ".join(k for k in keys),
+ ", ".join("?" for _ in keys),
)
- txn.execute_batch(sql, vals)
+ txn.execute_batch(sql, values)
async def simple_upsert(
self,
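The new calling convention for `simple_insert_many` replaces per-row dicts with a shared `keys` tuple plus one value tuple per row. A rough sketch of the call shape, where `store` stands in for any Synapse data store and the table and values mirror the `ignored_users` update elsewhere in this diff:

```python
# Sketch only: `store` is a placeholder; the example values are invented.
async def add_ignored_user(store) -> None:
    # Old shape (Synapse 1.50 and earlier): one dict per row, column names
    # repeated in every row:
    #     values=[{"ignorer_user_id": "@alice:example.com",
    #              "ignored_user_id": "@spam:example.com"}]
    # New shape: column names given once via `keys`, rows as plain tuples.
    await store.db_pool.simple_insert_many(
        table="ignored_users",
        keys=("ignorer_user_id", "ignored_user_id"),
        values=[("@alice:example.com", "@spam:example.com")],
        desc="add_ignored_user",
    )
```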
@@ -1177,9 +1218,9 @@ async def simple_upsert_many(
self,
table: str,
key_names: Collection[str],
- key_values: Collection[Iterable[Any]],
+ key_values: Collection[Collection[Any]],
value_names: Collection[str],
- value_values: Iterable[Iterable[Any]],
+ value_values: Collection[Collection[Any]],
desc: str,
) -> None:
"""
@@ -1337,7 +1378,7 @@ async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
@@ -1348,7 +1389,7 @@ async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
@@ -1358,7 +1399,7 @@ async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
@@ -1528,7 +1569,7 @@ async def simple_select_list(
self,
table: str,
keyvalues: Optional[Dict[str, Any]],
- retcols: Iterable[str],
+ retcols: Collection[str],
desc: str = "simple_select_list",
) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
@@ -1591,7 +1632,7 @@ async def simple_select_many_batch(
table: str,
column: str,
iterable: Iterable[Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
keyvalues: Optional[Dict[str, Any]] = None,
desc: str = "simple_select_many_batch",
batch_size: int = 100,
@@ -1614,16 +1655,7 @@ async def simple_select_many_batch(
results: List[Dict[str, Any]] = []
- if not iterable:
- return results
-
- # iterables can not be sliced, so convert it to a list first
- it_list = list(iterable)
-
- chunks = [
- it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
- ]
- for chunk in chunks:
+ for chunk in batch_iter(iterable, batch_size):
rows = await self.runInteraction(
desc,
self.simple_select_many_txn,
@@ -1763,7 +1795,7 @@ def simple_select_one_txn(
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
allow_none: bool = False,
) -> Optional[Dict[str, Any]]:
select_sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1871,7 +1903,7 @@ async def simple_delete_many(
self,
table: str,
column: str,
- iterable: Iterable[Any],
+ iterable: Collection[Any],
keyvalues: Dict[str, Any],
desc: str,
) -> int:
@@ -1882,7 +1914,8 @@ async def simple_delete_many(
Args:
table: string giving the table name
column: column name to test for inclusion against `iterable`
- iterable: list
+ iterable: list of values to match against `column`. NB cannot be a generator
+ as it may be evaluated multiple times.
keyvalues: dict of column names and values to select the rows with
desc: description of the transaction, for logging and metrics
@@ -2055,7 +2088,7 @@ async def simple_search_list(
table: str,
term: Optional[str],
col: str,
- retcols: Iterable[str],
+ retcols: Collection[str],
desc="simple_search_list",
) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9ff2d8d8c35a..f024761ba7b8 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
@@ -68,7 +68,7 @@
from .signatures import SignatureStore
from .state import StateStore
from .stats import StatsStore
-from .stream import StreamStore
+from .stream import StreamWorkerStore
from .tags import TagsStore
from .transactions import TransactionWorkerStore
from .ui_auth import UIAuthStore
@@ -87,7 +87,7 @@ class DataStore(
RoomStore,
RoomBatchStore,
RegistrationStore,
- StreamStore,
+ StreamWorkerStore,
ProfileStore,
PresenceStore,
TransactionWorkerStore,
@@ -129,7 +129,12 @@ class DataStore(
LockStore,
SessionStore,
):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
@@ -143,11 +148,7 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
("device_lists_outbound_pokes", "stream_id"),
],
)
- self._cross_signing_id_gen = StreamIdGenerator(
- db_conn, "e2e_cross_signing_keys", "stream_id"
- )
- self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._group_updates_id_gen = StreamIdGenerator(
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index f8bec266ac41..ef475e18c788 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,15 +14,25 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage._base import db_to_json
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
+ AbstractStreamIdTracker,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -34,13 +44,19 @@
logger = logging.getLogger(__name__)
-class AccountDataWorkerStore(SQLBaseStore):
- """This is an abstract base class where subclasses must implement
- `get_max_account_data_stream_id` which can be called in the initializer.
- """
+class AccountDataWorkerStore(CacheInvalidationWorkerStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
- self._instance_name = hs.get_instance_name()
+ # `_can_write_to_account_data` indicates whether the current worker is allowed
+ # to write account data. A value of `True` implies that `_account_data_id_gen`
+ # is an `AbstractStreamIdGenerator` and not just a tracker.
+ self._account_data_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
self._can_write_to_account_data = (
@@ -61,8 +77,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
writers=hs.config.worker.writers.account_data,
)
else:
- self._can_write_to_account_data = True
-
# We shouldn't be running in worker mode with SQLite, but its useful
# to support it for unit tests.
#
@@ -70,7 +84,8 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
- if hs.get_instance_name() in hs.config.worker.writers.account_data:
+ if self._instance_name in hs.config.worker.writers.account_data:
+ self._can_write_to_account_data = True
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"room_account_data",
@@ -90,8 +105,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
"AccountDataAndTagsChangeCache", account_max
)
- super().__init__(database, db_conn, hs)
-
def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream ID for account data stream
@@ -113,7 +126,9 @@ async def get_account_data_for_user(
room_id string to per room account_data dicts.
"""
- def get_account_data_for_user_txn(txn):
+ def get_account_data_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
rows = self.db_pool.simple_select_list_txn(
txn,
"account_data",
@@ -132,7 +147,7 @@ def get_account_data_for_user_txn(txn):
["room_id", "account_data_type", "content"],
)
- by_room = {}
+ by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
room_data[row["account_data_type"]] = db_to_json(row["content"])
@@ -177,7 +192,9 @@ async def get_account_data_for_room(
A dict of the room account_data
"""
- def get_account_data_for_room_txn(txn):
+ def get_account_data_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, JsonDict]:
rows = self.db_pool.simple_select_list_txn(
txn,
"room_account_data",
@@ -207,7 +224,9 @@ async def get_account_data_for_room_and_type(
The room account_data for that type, or None if there isn't any set.
"""
- def get_account_data_for_room_and_type_txn(txn):
+ def get_account_data_for_room_and_type_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[JsonDict]:
content_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="room_account_data",
@@ -243,14 +262,16 @@ async def get_updated_global_account_data(
if last_id == current_id:
return []
- def get_updated_global_account_data_txn(txn):
+ def get_updated_global_account_data_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str]]:
sql = (
"SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_updated_global_account_data", get_updated_global_account_data_txn
@@ -273,14 +294,16 @@ async def get_updated_room_account_data(
if last_id == current_id:
return []
- def get_updated_room_account_data_txn(txn):
+ def get_updated_room_account_data_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str]]:
sql = (
"SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_updated_room_account_data", get_updated_room_account_data_txn
@@ -299,7 +322,9 @@ async def get_updated_account_data_for_user(
mapping from room_id string to per room account_data dicts.
"""
- def get_updated_account_data_for_user_txn(txn):
+ def get_updated_account_data_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
sql = (
"SELECT account_data_type, content FROM account_data"
" WHERE user_id = ? AND stream_id > ?"
@@ -316,7 +341,7 @@ def get_updated_account_data_for_user_txn(txn):
txn.execute(sql, (user_id, stream_id))
- account_data_by_room = {}
+ account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = db_to_json(row[2])
@@ -353,12 +378,15 @@ async def ignored_by(self, user_id: str) -> Set[str]:
)
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
- for row in rows:
- self.get_tags_for_user.invalidate((row.user_id,))
- self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
@@ -372,7 +400,8 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
- return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -389,6 +418,7 @@ async def add_account_data_to_room(
The maximum stream ID.
"""
assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
content_json = json_encoder.encode(content)
@@ -420,7 +450,7 @@ async def add_account_data_to_room(
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
- """Add some account_data to a room for a user.
+ """Add some global account_data for a user.
Args:
user_id: The user to add a tag for.
@@ -431,6 +461,7 @@ async def add_account_data_for_user(
The maximum stream ID.
"""
assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
@@ -452,7 +483,7 @@ async def add_account_data_for_user(
def _add_account_data_for_user(
self,
- txn,
+ txn: LoggingTransaction,
next_id: int,
user_id: str,
account_data_type: str,
@@ -505,9 +536,9 @@ def _add_account_data_for_user(
self.db_pool.simple_insert_many_txn(
txn,
table="ignored_users",
+ keys=("ignorer_user_id", "ignored_user_id"),
values=[
- {"ignorer_user_id": user_id, "ignored_user_id": u}
- for u in currently_ignored_users - previously_ignored_users
+ (user_id, u) for u in currently_ignored_users - previously_ignored_users
],
)
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 4a883dc16647..92c95a41d793 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -24,9 +24,8 @@
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -58,7 +57,12 @@ def _make_exclusive_regex(
class ApplicationServiceWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
self.services_cache = load_appservices(
hs.hostname, hs.config.appservice.app_service_config_files
)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 36e8422fc63b..0024348067d5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -25,7 +25,7 @@
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -41,7 +41,12 @@
class CacheInvalidationWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 0f56e10220d0..fd3fc298b37a 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -18,7 +18,11 @@
from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import json_encoder
@@ -31,7 +35,12 @@
class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if (
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 1dc7f0ebe346..8b0c614ecef7 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -26,7 +26,6 @@
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
-from synapse.storage.types import Connection
from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache
@@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict):
class ClientIpBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -394,7 +398,12 @@ def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -532,7 +541,12 @@ def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
# (user_id, access_token, ip,) -> last_seen
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index ab8766c75b62..4eca97189bef 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -432,14 +432,21 @@ def add_messages_txn(txn, now_ms, stream_id):
self.db_pool.simple_insert_many_txn(
txn,
table="device_federation_outbox",
+ keys=(
+ "destination",
+ "stream_id",
+ "queued_ts",
+ "messages_json",
+ "instance_name",
+ ),
values=[
- {
- "destination": destination,
- "stream_id": stream_id,
- "queued_ts": now_ms,
- "messages_json": json_encoder.encode(edu),
- "instance_name": self._instance_name,
- }
+ (
+ destination,
+ stream_id,
+ now_ms,
+ json_encoder.encode(edu),
+ self._instance_name,
+ )
for destination, edu in remote_messages_by_destination.items()
],
)
@@ -571,14 +578,9 @@ def _add_messages_to_local_device_inbox_txn(
self.db_pool.simple_insert_many_txn(
txn,
table="device_inbox",
+ keys=("user_id", "device_id", "stream_id", "message_json", "instance_name"),
values=[
- {
- "user_id": user_id,
- "device_id": device_id,
- "stream_id": stream_id,
- "message_json": message_json,
- "instance_name": self._instance_name,
- }
+ (user_id, device_id, stream_id, message_json, self._instance_name)
for user_id, messages_by_device in local_by_user_then_device.items()
for device_id, message_json in messages_by_device.items()
],
@@ -601,7 +603,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -668,7 +675,7 @@ def _remove_dead_devices_from_device_inbox_txn(
# There's a type mismatch here between how we want to type the row and
# what fetchone says it returns, but we silence it because we know that
# res can't be None.
- res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment]
+ res = cast(Tuple[Optional[int]], txn.fetchone())
if res[0] is None:
# this can only happen if the `device_inbox` table is empty, in which
# case we have no work to do.
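The `simple_insert_many_txn` changes in this file, and throughout the rest of the diff, follow one pattern: instead of repeating column names in a dict for every row, the columns are given once via `keys=` and each row becomes a positional tuple in the same order. A small, self-contained illustration using the `device_inbox` columns from the hunk above (the row values are made up):

```python
# Column names for the device_inbox insert shown above (values are illustrative).
keys = ("user_id", "device_id", "stream_id", "message_json", "instance_name")

old_style_rows = [
    {
        "user_id": "@alice:example.org",
        "device_id": "DEVICE1",
        "stream_id": 42,
        "message_json": "{}",
        "instance_name": "master",
    },
]

# New style: positional tuples, in the same order as `keys`.
new_style_rows = [
    ("@alice:example.org", "DEVICE1", 42, "{}", "master"),
]

# Both encodings carry the same information; only the shape differs.
assert [tuple(row[k] for k in keys) for row in old_style_rows] == new_style_rows
```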
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d5a4a661cd1a..b2a5cd9a6508 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -38,6 +38,7 @@
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
+ LoggingDatabaseConnection,
LoggingTransaction,
make_tuple_comparison_clause,
)
@@ -52,6 +53,7 @@
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
+issue_8631_logger = logging.getLogger("synapse.8631_debug")
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
"drop_device_list_streams_non_unique_indexes"
@@ -61,7 +63,12 @@
class DeviceWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@@ -101,7 +108,9 @@ def count_devices_by_users_txn(txn, user_ids):
"count_devices_by_users", count_devices_by_users_txn, user_ids
)
- async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
+ async def get_device(
+ self, user_id: str, device_id: str
+ ) -> Optional[Dict[str, Any]]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -109,15 +118,35 @@ async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
- A dict containing the device information
- Raises:
- StoreError: if the device is not found
+ A dict containing the device information, or `None` if the device does not
+ exist.
+ """
+ return await self.db_pool.simple_select_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_device",
+ allow_none=True,
+ )
+
+ async def get_device_opt(
+ self, user_id: str, device_id: str
+ ) -> Optional[Dict[str, Any]]:
+ """Retrieve a device. Only returns devices that are not marked as
+ hidden.
+
+ Args:
+ user_id: The ID of the user which owns the device
+ device_id: The ID of the device to retrieve
+ Returns:
+ A dict containing the device information, or None if the device does not exist.
"""
return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
+ allow_none=True,
)
async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
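With `allow_none=True`, `get_device` no longer raises `StoreError` for a missing device; it returns `None`, so callers must check for that themselves. A hedged sketch of what a caller might now look like (the helper below is hypothetical and not part of this diff):

```python
from synapse.api.errors import NotFoundError


async def get_device_display_name(store, user_id: str, device_id: str) -> str:
    # Hypothetical caller: get_device now returns None rather than raising
    # when the device is missing or hidden.
    device = await store.get_device(user_id, device_id)
    if device is None:
        raise NotFoundError("No device found")
    return device.get("display_name") or device_id
```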
@@ -163,7 +192,7 @@ async def get_devices_by_auth_provider_session_id(
@trace
async def get_device_updates_by_remote(
self, destination: str, from_stream_id: int, limit: int
- ) -> Tuple[int, List[Tuple[str, dict]]]:
+ ) -> Tuple[int, List[Tuple[str, JsonDict]]]:
"""Get a stream of device updates to send to the given remote server.
Args:
@@ -172,9 +201,10 @@ async def get_device_updates_by_remote(
limit: Maximum number of device updates to return
Returns:
- A mapping from the current stream id (ie, the stream id of the last
- update included in the response), and the list of updates, where
- each update is a pair of EDU type and EDU contents.
+ - The current stream id (i.e. the stream id of the last update included
+ in the response); and
+ - The list of updates, where each update is a pair of EDU type and
+ EDU contents.
"""
now_stream_id = self.get_device_stream_token()
@@ -193,10 +223,19 @@ async def get_device_updates_by_remote(
limit,
)
+ # We need to ensure `updates` doesn't grow too big.
+ # Currently: `len(updates) <= limit`.
+
# Return an empty list if there are no updates
if not updates:
return now_stream_id, []
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ data = {(user, device): stream_id for user, device, stream_id, _ in updates}
+ issue_8631_logger.debug(
+ "device updates need to be sent to %s: %s", destination, data
+ )
+
# get the cross-signing keys of the users in the list, so that we can
# determine which of the device changes were cross-signing keys
users = {r[0] for r in updates}
@@ -242,19 +281,50 @@ async def get_device_updates_by_remote(
# The most recent request's opentracing_context is used as the
# context which created the Edu.
+ # This is the stream ID that we will return for the consumer to resume
+ # following this stream later.
+ last_processed_stream_id = from_stream_id
+
query_map = {}
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
- if (
+ # Calculate the remaining length budget.
+ # Note that, for now, each entry in `cross_signing_keys_by_user`
+ # gives rise to two device updates in the result, so those cost twice
+ # as much (and are the whole reason we need to separately calculate
+ # the budget; we know len(updates) <= limit otherwise!)
+ # N.B. len() on dicts is cheap since they store their size.
+ remaining_length_budget = limit - (
+ len(query_map) + 2 * len(cross_signing_keys_by_user)
+ )
+ assert remaining_length_budget >= 0
+
+ is_master_key_update = (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
- ):
- result = cross_signing_keys_by_user.setdefault(user_id, {})
- result["master_key"] = master_key_by_user[user_id]["key_info"]
- elif (
+ )
+ is_self_signing_key_update = (
user_id in self_signing_key_by_user
and device_id == self_signing_key_by_user[user_id]["device_id"]
+ )
+
+ is_cross_signing_key_update = (
+ is_master_key_update or is_self_signing_key_update
+ )
+
+ if (
+ is_cross_signing_key_update
+ and user_id not in cross_signing_keys_by_user
):
+ # This will give rise to 2 device updates.
+ # If we don't have the budget, stop here!
+ if remaining_length_budget < 2:
+ break
+
+ if is_master_key_update:
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["master_key"] = master_key_by_user[user_id]["key_info"]
+ elif is_self_signing_key_update:
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["self_signing_key"] = self_signing_key_by_user[user_id][
"key_info"
@@ -262,22 +332,58 @@ async def get_device_updates_by_remote(
else:
key = (user_id, device_id)
+ if key not in query_map and remaining_length_budget < 1:
+ # We don't have space for a new entry
+ break
+
previous_update_stream_id, _ = query_map.get(key, (0, None))
if update_stream_id > previous_update_stream_id:
+ # FIXME If this overwrites an older update, this discards the
+ # previous OpenTracing context.
+ # It might make it harder to track down issues using OpenTracing.
+ # If there's a good reason why it doesn't matter, a comment here
+ # about that would not hurt.
query_map[key] = (update_stream_id, update_context)
+ # As this update has been added to the response, advance the stream
+ # position.
+ last_processed_stream_id = update_stream_id
+
+ # In the worst case scenario, each update is for a distinct user and is
+ # added either to the query_map or to cross_signing_keys_by_user,
+ # but not both:
+ # len(query_map) + len(cross_signing_keys_by_user) <= len(updates) here,
+ # so len(query_map) + len(cross_signing_keys_by_user) <= limit.
+
results = await self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
- # add the updated cross-signing keys to the results list
+ # len(results) <= len(query_map) here,
+ # so len(results) + len(cross_signing_keys_by_user) <= limit.
+
+ # Add the updated cross-signing keys to the results list
for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id
- # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("m.signing_key_update", result))
+ # also send the unstable version
+ # FIXME: remove this when enough servers have upgraded
+ # and remove the length budgeting above.
results.append(("org.matrix.signing_key_update", result))
- return now_stream_id, results
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ for (user_id, edu) in results:
+ issue_8631_logger.debug(
+ "device update to %s for %s from %s to %s: %s",
+ destination,
+ user_id,
+ from_stream_id,
+ last_processed_stream_id,
+ edu,
+ )
+
+ return last_processed_stream_id, results
def _get_device_updates_by_remote_txn(
self,
@@ -286,7 +392,7 @@ def _get_device_updates_by_remote_txn(
from_stream_id: int,
now_stream_id: int,
limit: int,
- ):
+ ) -> List[Tuple[str, str, int, Optional[str]]]:
"""Return device update information for a given remote destination
Args:
@@ -297,7 +403,11 @@ def _get_device_updates_by_remote_txn(
limit: Maximum number of device updates to return
Returns:
- List: List of device updates
+ List: List of device update tuples:
+ - user_id
+ - device_id
+ - stream_id
+ - opentracing_context
"""
# get the list of device updates that need to be sent
sql = """
@@ -321,15 +431,21 @@ async def _get_device_update_edus_by_remote(
Args:
destination: The host the device updates are intended for
from_stream_id: The minimum stream_id to filter updates by, exclusive
- query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
- user_id/device_id to update stream_id and the relevant json-encoded
- opentracing context
+ query_map: Dictionary mapping (user_id, device_id) to
+ (update stream_id, the relevant json-encoded opentracing context)
Returns:
- List of objects representing an device update EDU
+ List of objects representing a device update EDU.
+
+ Postconditions:
+ The returned list has a length not exceeding that of the query_map:
+ len(result) <= len(query_map)
"""
devices = (
await self.get_e2e_device_keys_and_signatures(
+ # Because these are (user_id, device_id) tuples with all
+ # device_ids not being None, the returned list's length will not
+ # exceed that of query_map.
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
@@ -683,7 +799,7 @@ def _get_all_device_list_changes_for_remotes(txn):
@cached(max_entries=10000)
async def get_device_list_last_stream_id_for_remote(
self, user_id: str
- ) -> Optional[Any]:
+ ) -> Optional[str]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
@@ -699,7 +815,9 @@ async def get_device_list_last_stream_id_for_remote(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
)
- async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
+ async def get_device_list_last_stream_id_for_remotes(
+ self, user_ids: Iterable[str]
+ ) -> Dict[str, Optional[str]]:
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
@@ -949,7 +1067,12 @@ def _prune_txn(txn):
class DeviceBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -1081,7 +1204,12 @@ def _txn(txn):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
@@ -1276,6 +1404,7 @@ def _update_remote_device_list_cache_entry_txn(
content: JsonDict,
stream_id: str,
) -> None:
+ """Delete, update or insert a cache entry for this (user, device) pair."""
if content.get("deleted"):
self.db_pool.simple_delete_txn(
txn,
@@ -1335,6 +1464,7 @@ async def update_remote_device_list_cache(
def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
) -> None:
+ """Replace the list of cached devices for this user with the given list."""
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
@@ -1342,12 +1472,9 @@ def _update_remote_device_list_cache_txn(
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
+ keys=("user_id", "device_id", "content"),
values=[
- {
- "user_id": user_id,
- "device_id": content["device_id"],
- "content": json_encoder.encode(content),
- }
+ (user_id, content["device_id"], json_encoder.encode(content))
for content in devices
],
)
@@ -1435,8 +1562,9 @@ def _add_device_change_to_stream_txn(
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_stream",
+ keys=("stream_id", "user_id", "device_id"),
values=[
- {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
+ (stream_id, user_id, device_id)
for stream_id, device_id in zip(stream_ids, device_ids)
],
)
@@ -1463,18 +1591,27 @@ def _add_device_outbound_poke_to_stream_txn(
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
+ keys=(
+ "destination",
+ "stream_id",
+ "user_id",
+ "device_id",
+ "sent",
+ "ts",
+ "opentracing_context",
+ ),
values=[
- {
- "destination": destination,
- "stream_id": next(next_stream_id),
- "user_id": user_id,
- "device_id": device_id,
- "sent": False,
- "ts": now,
- "opentracing_context": json_encoder.encode(context)
+ (
+ destination,
+ next(next_stream_id),
+ user_id,
+ device_id,
+ False,
+ now,
+ json_encoder.encode(context)
if whitelisted_homeserver(destination)
else "{}",
- }
+ )
for destination in hosts
for device_id in device_ids
],
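The budgeting comments in `get_device_updates_by_remote` above boil down to simple arithmetic: a plain device update costs one slot in the result, a user whose cross-signing key changed costs two (one EDU for `m.signing_key_update` plus one for the unstable `org.matrix.signing_key_update` copy), and the loop stops before the combined cost would exceed `limit`. A small, self-contained illustration of that invariant (the numbers are made up):

```python
def remaining_budget(limit: int, n_plain_updates: int, n_cross_signing_users: int) -> int:
    # Mirrors the calculation in the hunk above: plain updates cost 1 each,
    # cross-signing users cost 2 (stable + unstable signing-key EDU).
    return limit - (n_plain_updates + 2 * n_cross_signing_users)


# With a limit of 100, 90 plain updates and 4 cross-signing users leave room
# for two more plain updates, or one more cross-signing user.
assert remaining_budget(100, 90, 4) == 2
```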
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index a3442814d77a..5903fdaf007a 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import namedtuple
from typing import Iterable, List, Optional, Tuple
+import attr
+
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached
-RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomAliasMapping:
+ room_id: str
+ room_alias: str
+ servers: List[str]
class DirectoryWorkerStore(CacheInvalidationWorkerStore):
@@ -106,10 +112,8 @@ def alias_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_insert_many_txn(
txn,
table="room_alias_servers",
- values=[
- {"room_alias": room_alias.to_string(), "server": server}
- for server in servers
- ],
+ keys=("room_alias", "server"),
+ values=[(room_alias.to_string(), server) for server in servers],
)
self._invalidate_cache_and_stream(
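`RoomAliasMapping` keeps the same three fields it had as a `namedtuple`, but as a frozen `attrs` class it gains type annotations and drops tuple-style indexing. A quick sketch of constructing one (the alias and server names are made up):

```python
from typing import List

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomAliasMapping:
    room_id: str
    room_alias: str
    servers: List[str]


mapping = RoomAliasMapping(
    room_id="!abc:example.org",      # made-up room ID
    room_alias="#room:example.org",  # made-up alias
    servers=["example.org"],
)
assert mapping.room_id == "!abc:example.org"
# Unlike the old namedtuple, index access such as mapping[0] is no longer available.
```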
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b15fb71e6258..b789a588a54b 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,35 +13,71 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Dict, Iterable, Mapping, Optional, Tuple, cast
+
+from typing_extensions import Literal, TypedDict
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
+from synapse.types import JsonDict, JsonSerializable
from synapse.util import json_encoder
+class RoomKey(TypedDict):
+ """`KeyBackupData` in the Matrix spec.
+
+ https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid
+ """
+
+ first_message_index: int
+ forwarded_count: int
+ is_verified: bool
+ session_data: JsonSerializable
+
+
class EndToEndRoomKeyStore(SQLBaseStore):
+ """The store for end to end room key backups.
+
+ See https://spec.matrix.org/v1.1/client-server-api/#server-side-key-backups
+
+ As per the spec, backups are identified by an opaque version string. Internally,
+ version identifiers are assigned using incrementing integers. Non-numeric version
+ strings are treated as if they do not exist, since we would have never issued them.
+ """
+
async def update_e2e_room_key(
- self, user_id, version, room_id, session_id, room_key
- ):
+ self,
+ user_id: str,
+ version: str,
+ room_id: str,
+ session_id: str,
+ room_key: RoomKey,
+ ) -> None:
"""Replaces the encrypted E2E room key for a given session in a given backup
Args:
- user_id(str): the user whose backup we're setting
- version(str): the version ID of the backup we're updating
- room_id(str): the ID of the room whose keys we're setting
- session_id(str): the session whose room_key we're setting
- room_key(dict): the room_key being set
+ user_id: the user whose backup we're setting
+ version: the version ID of the backup we're updating
+ room_id: the ID of the room whose keys we're setting
+ session_id: the session whose room_key we're setting
+ room_key: the room_key being set
Raises:
StoreError
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ raise StoreError(404, "No backup with that version exists")
await self.db_pool.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
- "version": version,
+ "version": version_int,
"room_id": room_id,
"session_id": session_id,
},
@@ -54,29 +90,36 @@ async def update_e2e_room_key(
desc="update_e2e_room_key",
)
- async def add_e2e_room_keys(self, user_id, version, room_keys):
+ async def add_e2e_room_keys(
+ self, user_id: str, version: str, room_keys: Iterable[Tuple[str, str, RoomKey]]
+ ) -> None:
"""Bulk add room keys to a given backup.
Args:
- user_id (str): the user whose backup we're adding to
- version (str): the version ID of the backup for the set of keys we're adding to
- room_keys (iterable[(str, str, dict)]): the keys to add, in the form
- (roomID, sessionID, keyData)
+ user_id: the user whose backup we're adding to
+ version: the version ID of the backup for the set of keys we're adding to
+ room_keys: the keys to add, in the form (roomID, sessionID, keyData)
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ raise StoreError(404, "No backup with that version exists")
values = []
for (room_id, session_id, room_key) in room_keys:
values.append(
- {
- "user_id": user_id,
- "version": version,
- "room_id": room_id,
- "session_id": session_id,
- "first_message_index": room_key["first_message_index"],
- "forwarded_count": room_key["forwarded_count"],
- "is_verified": room_key["is_verified"],
- "session_data": json_encoder.encode(room_key["session_data"]),
- }
+ (
+ user_id,
+ version_int,
+ room_id,
+ session_id,
+ room_key["first_message_index"],
+ room_key["forwarded_count"],
+ room_key["is_verified"],
+ json_encoder.encode(room_key["session_data"]),
+ )
)
log_kv(
{
@@ -88,35 +131,55 @@ async def add_e2e_room_keys(self, user_id, version, room_keys):
)
await self.db_pool.simple_insert_many(
- table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
+ table="e2e_room_keys",
+ keys=(
+ "user_id",
+ "version",
+ "room_id",
+ "session_id",
+ "first_message_index",
+ "forwarded_count",
+ "is_verified",
+ "session_data",
+ ),
+ values=values,
+ desc="add_e2e_room_keys",
)
@trace
- async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_e2e_room_keys(
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> Dict[
+ Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+ ]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
Args:
- user_id (str): the user whose backup we're querying
- version (str): the version ID of the backup for the set of keys we're querying
- room_id (str): Optional. the ID of the room whose keys we're querying, if any.
+ user_id: the user whose backup we're querying
+ version: the version ID of the backup for the set of keys we're querying
+ room_id: Optional. the ID of the room whose keys we're querying, if any.
If not specified, we return the keys for all the rooms in the backup.
- session_id (str): Optional. the session whose room_key we're querying, if any.
+ session_id: Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we return all the keys in this version of
the backup (or for the specified room)
Returns:
- A list of dicts giving the session_data and message metadata for
- these room keys.
+ A dict giving the session_data and message metadata for these room keys.
+ `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
"""
try:
- version = int(version)
+ version_int = int(version)
except ValueError:
return {"rooms": {}}
- keyvalues = {"user_id": user_id, "version": version}
+ keyvalues = {"user_id": user_id, "version": version_int}
if room_id:
keyvalues["room_id"] = room_id
if session_id:
@@ -137,7 +200,9 @@ async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=Non
desc="get_e2e_room_keys",
)
- sessions = {"rooms": {}}
+ sessions: Dict[
+ Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+ ] = {"rooms": {}}
for row in rows:
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
room_entry["sessions"][row["session_id"]] = {
@@ -150,7 +215,12 @@ async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=Non
return sessions
- async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+ async def get_e2e_room_keys_multi(
+ self,
+ user_id: str,
+ version: str,
+ room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
+ ) -> Dict[str, Dict[str, RoomKey]]:
"""Get multiple room keys at a time. The difference between this function and
get_e2e_room_keys is that this function can be used to retrieve
multiple specific keys at a time, whereas get_e2e_room_keys is used for
@@ -158,26 +228,36 @@ async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
specific key.
Args:
- user_id (str): the user whose backup we're querying
- version (str): the version ID of the backup we're querying about
- room_keys (dict[str, dict[str, iterable[str]]]): a map from
- room ID -> {"session": [session ids]} indicating the session IDs
- that we want to query
+ user_id: the user whose backup we're querying
+ version: the version ID of the backup we're querying about
+ room_keys: a map from room ID -> {"sessions": [session ids]}
+ indicating the session IDs that we want to query
Returns:
- dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
+ A map of room IDs to session IDs to room key
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ return {}
return await self.db_pool.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
- version,
+ version_int,
room_keys,
)
@staticmethod
- def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
+ def _get_e2e_room_keys_multi_txn(
+ txn: LoggingTransaction,
+ user_id: str,
+ version: int,
+ room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
+ ) -> Dict[str, Dict[str, RoomKey]]:
if not room_keys:
return {}
@@ -209,7 +289,7 @@ def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
txn.execute(sql, params)
- ret = {}
+ ret: Dict[str, Dict[str, RoomKey]] = {}
for row in txn:
room_id = row[0]
@@ -231,36 +311,49 @@ async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
user_id: the user whose backup we're querying
version: the version ID of the backup we're querying about
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ return 0
return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
- keyvalues={"user_id": user_id, "version": version},
+ keyvalues={"user_id": user_id, "version": version_int},
retcol="COUNT(*)",
desc="count_e2e_room_keys",
)
@trace
async def delete_e2e_room_keys(
- self, user_id, version, room_id=None, session_id=None
- ):
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> None:
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
Args:
- user_id(str): the user whose backup we're deleting from
- version(str): the version ID of the backup for the set of keys we're deleting
- room_id(str): Optional. the ID of the room whose keys we're deleting, if any.
+ user_id: the user whose backup we're deleting from
+ version: the version ID of the backup for the set of keys we're deleting
+ room_id: Optional. the ID of the room whose keys we're deleting, if any.
If not specified, we delete the keys for all the rooms in the backup.
- session_id(str): Optional. the session whose room_key we're querying, if any.
+ session_id: Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we delete all the keys in this version of
the backup (or for the specified room)
-
- Returns:
- The deletion transaction
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ return
- keyvalues = {"user_id": user_id, "version": int(version)}
+ keyvalues = {"user_id": user_id, "version": version_int}
if room_id:
keyvalues["room_id"] = room_id
if session_id:
@@ -271,23 +364,27 @@ async def delete_e2e_room_keys(
)
@staticmethod
- def _get_current_version(txn, user_id):
+ def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions "
"WHERE user_id=? AND deleted=0",
(user_id,),
)
- row = txn.fetchone()
- if not row:
+ # `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
+ # be `NULL` when there are no available versions.
+ row = cast(Tuple[Optional[int]], txn.fetchone())
+ if row[0] is None:
raise StoreError(404, "No current backup version")
return row[0]
- async def get_e2e_room_keys_version_info(self, user_id, version=None):
+ async def get_e2e_room_keys_version_info(
+ self, user_id: str, version: Optional[str] = None
+ ) -> JsonDict:
"""Get info metadata about a version of our room_keys backup.
Args:
- user_id(str): the user whose backup we're querying
- version(str): Optional. the version ID of the backup we're querying about
+ user_id: the user whose backup we're querying
+ version: Optional. the version ID of the backup we're querying about
If missing, we return the information about the current version.
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present
@@ -300,7 +397,7 @@ async def get_e2e_room_keys_version_info(self, user_id, version=None):
etag(int): tag of the keys in the backup
"""
- def _get_e2e_room_keys_version_info_txn(txn):
+ def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict:
if version is None:
this_version = self._get_current_version(txn, user_id)
else:
@@ -309,14 +406,16 @@ def _get_e2e_room_keys_version_info_txn(txn):
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it isn't there.
- raise StoreError(404, "No row found")
+ raise StoreError(404, "No backup with that version exists")
result = self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
+ allow_none=False,
)
+ assert result is not None # see comment on `simple_select_one_txn`
result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
@@ -328,28 +427,28 @@ def _get_e2e_room_keys_version_info_txn(txn):
)
@trace
- async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
+ async def create_e2e_room_keys_version(self, user_id: str, info: JsonDict) -> str:
"""Atomically creates a new version of this user's e2e_room_keys store
with the given version info.
Args:
- user_id(str): the user whose backup we're creating a version
- info(dict): the info about the backup version to be created
+ user_id: the user whose backup we're creating a version
+ info: the info about the backup version to be created
Returns:
The newly created version ID
"""
- def _create_e2e_room_keys_version_txn(txn):
+ def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
(user_id,),
)
- current_version = txn.fetchone()[0]
+ current_version = cast(Tuple[Optional[int]], txn.fetchone())[0]
if current_version is None:
- current_version = "0"
+ current_version = 0
- new_version = str(int(current_version) + 1)
+ new_version = current_version + 1
self.db_pool.simple_insert_txn(
txn,
@@ -362,7 +461,7 @@ def _create_e2e_room_keys_version_txn(txn):
},
)
- return new_version
+ return str(new_version)
return await self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
@@ -373,7 +472,7 @@ async def update_e2e_room_keys_version(
self,
user_id: str,
version: str,
- info: Optional[dict] = None,
+ info: Optional[JsonDict] = None,
version_etag: Optional[int] = None,
) -> None:
"""Update a given backup version
@@ -386,7 +485,7 @@ async def update_e2e_room_keys_version(
version_etag: etag of the keys in the backup. If None, then the etag
is not updated.
"""
- updatevalues = {}
+ updatevalues: Dict[str, object] = {}
if info is not None and "auth_data" in info:
updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
@@ -394,9 +493,16 @@ async def update_e2e_room_keys_version(
updatevalues["etag"] = version_etag
if updatevalues:
- await self.db_pool.simple_update(
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ raise StoreError(404, "No backup with that version exists")
+
+ await self.db_pool.simple_update_one(
table="e2e_room_keys_versions",
- keyvalues={"user_id": user_id, "version": version},
+ keyvalues={"user_id": user_id, "version": version_int},
updatevalues=updatevalues,
desc="update_e2e_room_keys_version",
)
@@ -417,13 +523,16 @@ async def delete_e2e_room_keys_version(
or if the version requested doesn't exist.
"""
- def _delete_e2e_room_keys_version_txn(txn):
+ def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None:
if version is None:
this_version = self._get_current_version(txn, user_id)
- if this_version is None:
- raise StoreError(404, "No current backup version")
else:
- this_version = version
+ try:
+ this_version = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it isn't there.
+ raise StoreError(404, "No backup with that version exists")
self.db_pool.simple_delete_txn(
txn,
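Every method in `EndToEndRoomKeyStore` now funnels the client-supplied version string through the same `int()` check before touching the database, differing only in how a non-numeric version is reported (a 404, an empty result, zero, or a silent return, as appropriate to the call). A standalone sketch of the shared pattern, with a hypothetical helper name:

```python
from synapse.api.errors import StoreError


def parse_backup_version(version: str) -> int:
    """Hypothetical helper showing the pattern used in the hunks above."""
    try:
        return int(version)
    except ValueError:
        # Versions are issued as incrementing integers, so a non-numeric
        # version string can never refer to an existing backup.
        raise StoreError(404, "No backup with that version exists")
```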
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b06c1dc45b2d..1f8447b5076f 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,19 +14,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ cast,
+)
import attr
from canonicaljson import encode_canonical_json
-from twisted.enterprise.adbapi import Connection
-
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -37,20 +50,25 @@
from synapse.server import HomeServer
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class DeviceKeyLookupResult:
"""The type returned by get_e2e_device_keys_and_signatures"""
- display_name = attr.ib(type=Optional[str])
+ display_name: Optional[str]
# the key data from e2e_device_keys_json. Typically includes fields like
# "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
# key) and "signatures" (a map from (user id) to (key id/device_id) to signature.)
- keys = attr.ib(type=Optional[JsonDict])
+ keys: Optional[JsonDict]
class EndToEndKeyBackgroundStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -62,8 +80,13 @@ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"
)
-class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._allow_device_name_lookup_over_federation = (
@@ -124,7 +147,7 @@ async def get_e2e_device_keys_for_cs_api(
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
- rv = {}
+ rv: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
@@ -195,6 +218,10 @@ async def get_e2e_device_keys_and_signatures(
# add each cross-signing signature to the correct device in the result dict.
for (user_id, key_id, device_id, signature) in cross_sigs_result:
target_device_result = result[user_id][device_id]
+ # We've only looked up cross-signatures for non-deleted devices with key
+ # data.
+ assert target_device_result is not None
+ assert target_device_result.keys is not None
target_device_signatures = target_device_result.keys.setdefault(
"signatures", {}
)
@@ -207,7 +234,11 @@ async def get_e2e_device_keys_and_signatures(
return result
def _get_e2e_device_keys_txn(
- self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+ self,
+ txn: LoggingTransaction,
+ query_list: Collection[Tuple[str, str]],
+ include_all_devices: bool = False,
+ include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Get information on devices from the database
@@ -263,7 +294,7 @@ def _get_e2e_device_keys_txn(
return result
def _get_e2e_cross_signing_signatures_for_devices_txn(
- self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
+ self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices
@@ -289,7 +320,17 @@ def _get_e2e_cross_signing_signatures_for_devices_txn(
)
txn.execute(signature_sql, signature_query_params)
- return txn.fetchall()
+ return cast(
+ List[
+ Tuple[
+ str,
+ str,
+ str,
+ str,
+ ]
+ ],
+ txn.fetchall(),
+ )
async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str]
@@ -335,7 +376,7 @@ async def add_e2e_one_time_keys(
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
- def _add_e2e_one_time_keys(txn):
+ def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", new_keys)
@@ -346,15 +387,16 @@ def _add_e2e_one_time_keys(txn):
self.db_pool.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
+ keys=(
+ "user_id",
+ "device_id",
+ "algorithm",
+ "key_id",
+ "ts_added_ms",
+ "key_json",
+ ),
values=[
- {
- "user_id": user_id,
- "device_id": device_id,
- "algorithm": algorithm,
- "key_id": key_id,
- "ts_added_ms": time_now,
- "key_json": json_bytes,
- }
+ (user_id, device_id, algorithm, key_id, time_now, json_bytes)
for algorithm, key_id, json_bytes in new_keys
],
)
@@ -375,7 +417,7 @@ async def count_e2e_one_time_keys(
A mapping from algorithm to number of keys for that algorithm.
"""
- def _count_e2e_one_time_keys(txn):
+ def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]:
sql = (
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ?"
@@ -421,7 +463,11 @@ async def set_e2e_fallback_keys(
)
def _set_e2e_fallback_keys_txn(
- self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ fallback_keys: JsonDict,
) -> None:
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
@@ -483,7 +529,7 @@ async def get_e2e_unused_fallback_key_types(
async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
- ) -> Optional[dict]:
+ ) -> Optional[JsonDict]:
"""Returns a user's cross-signing key.
Args:
@@ -504,7 +550,7 @@ async def get_e2e_cross_signing_key(
return user_keys.get(key_type)
@cached(num_args=1)
- def _get_bare_e2e_cross_signing_keys(self, user_id):
+ def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
@@ -517,7 +563,7 @@ def _get_bare_e2e_cross_signing_keys(self, user_id):
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: Iterable[str]
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
@@ -531,32 +577,35 @@ async def _get_bare_e2e_cross_signing_keys_bulk(
their user ID will map to None.
"""
- return await self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
)
+ # The `Optional` comes from the `@cachedList` decorator.
+ return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
- txn: Connection,
+ txn: LoggingTransaction,
user_ids: Iterable[str],
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Dict[str, JsonDict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- user_ids (list[str]): the users whose keys are being requested
+ txn: db connection
+ user_ids: the users whose keys are being requested
Returns:
- dict[str, dict[str, dict]]: mapping from user ID to key type to key
- data. If a user's cross-signing keys were not found, their user
- ID will not be in the dict.
+ Mapping from user ID to key type to key data.
+ If a user's cross-signing keys were not found, their user ID will not be in
+ the dict.
"""
- result = {}
+ result: Dict[str, Dict[str, JsonDict]] = {}
for user_chunk in batch_iter(user_ids, 100):
clause, params = make_in_list_sql_clause(
@@ -596,43 +645,48 @@ def _get_bare_e2e_cross_signing_keys_bulk_txn(
user_id = row["user_id"]
key_type = row["keytype"]
key = db_to_json(row["keydata"])
- user_info = result.setdefault(user_id, {})
- user_info[key_type] = key
+ user_keys = result.setdefault(user_id, {})
+ user_keys[key_type] = key
return result
def _get_e2e_cross_signing_signatures_txn(
self,
- txn: Connection,
- keys: Dict[str, Dict[str, dict]],
+ txn: LoggingTransaction,
+ keys: Dict[str, Optional[Dict[str, JsonDict]]],
from_user_id: str,
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
"""Returns the cross-signing signatures made by a user on a set of keys.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- keys (dict[str, dict[str, dict]]): a map of user ID to key type to
- key data. This dict will be modified to add signatures.
- from_user_id (str): fetch the signatures made by this user
+ txn: db connection
+ keys: a map of user ID to key type to key data.
+ This dict will be modified to add signatures.
+ from_user_id: fetch the signatures made by this user
Returns:
- dict[str, dict[str, dict]]: mapping from user ID to key type to key
- data. The return value will be the same as the keys argument,
- with the modifications included.
+ Mapping from user ID to key type to key data.
+ The return value will be the same as the keys argument, with the
+ modifications included.
"""
# find out what cross-signing keys (a.k.a. devices) we need to get
# signatures for. This is a map of (user_id, device_id) to key type
# (device_id is the key's public part).
- devices = {}
+ devices: Dict[Tuple[str, str], str] = {}
- for user_id, user_info in keys.items():
- if user_info is None:
+ for user_id, user_keys in keys.items():
+ if user_keys is None:
continue
- for key_type, key in user_info.items():
+ for key_type, key in user_keys.items():
device_id = None
for k in key["keys"].values():
device_id = k
+ # `key` ought to be a `CrossSigningKey`, whose .keys property is a
+ # dictionary with a single entry:
+ # "algorithm:base64_public_key": "base64_public_key"
+ # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing
+ assert isinstance(device_id, str)
devices[(user_id, device_id)] = key_type
for batch in batch_iter(devices.keys(), size=100):
@@ -656,15 +710,20 @@ def _get_e2e_cross_signing_signatures_txn(
# and add the signatures to the appropriate keys
for row in rows:
- key_id = row["key_id"]
- target_user_id = row["target_user_id"]
- target_device_id = row["target_device_id"]
+ key_id: str = row["key_id"]
+ target_user_id: str = row["target_user_id"]
+ target_device_id: str = row["target_device_id"]
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
# need to recursively copy the dicts that will be modified.
- user_info = keys[target_user_id] = keys[target_user_id].copy()
- target_user_key = user_info[key_type] = user_info[key_type].copy()
+ user_keys = keys[target_user_id]
+ # `user_keys` cannot be `None` because we only fetched signatures for
+ # users with keys
+ assert user_keys is not None
+ user_keys = keys[target_user_id] = user_keys.copy()
+
+ target_user_key = user_keys[key_type] = user_keys[key_type].copy()
if "signatures" in target_user_key:
signatures = target_user_key["signatures"] = target_user_key[
"signatures"
@@ -683,7 +742,7 @@ def _get_e2e_cross_signing_signatures_txn(
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Optional[Dict[str, dict]]]:
+ ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -741,7 +800,9 @@ async def get_all_user_signature_changes_for_remotes(
if last_id == current_id:
return [], current_id, False
- def _get_all_user_signature_changes_for_remotes_txn(txn):
+ def _get_all_user_signature_changes_for_remotes_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
@@ -785,7 +846,7 @@ async def claim_e2e_one_time_keys(
@trace
def _claim_e2e_one_time_key_simple(
- txn, user_id: str, device_id: str, algorithm: str
+ txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
) -> Optional[Tuple[str, str]]:
"""Claim OTK for device for DBs that don't support RETURNING.
@@ -825,7 +886,7 @@ def _claim_e2e_one_time_key_simple(
@trace
def _claim_e2e_one_time_key_returning(
- txn, user_id: str, device_id: str, algorithm: str
+ txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
) -> Optional[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING.
@@ -860,7 +921,7 @@ def _claim_e2e_one_time_key_returning(
key_id, key_json = otk_row
return f"{algorithm}:{key_id}", key_json
- results = {}
+ results: Dict[str, Dict[str, Dict[str, str]]] = {}
for user_id, device_id, algorithm in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
@@ -930,6 +991,18 @@ def _claim_e2e_one_time_key_returning(
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self._cross_signing_id_gen = StreamIdGenerator(
+ db_conn, "e2e_cross_signing_keys", "stream_id"
+ )
+
async def set_e2e_device_keys(
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
) -> bool:
@@ -937,7 +1010,7 @@ async def set_e2e_device_keys(
or the keys were already in the database.
"""
- def _set_e2e_device_keys_txn(txn):
+ def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
@@ -973,7 +1046,7 @@ def _set_e2e_device_keys_txn(txn):
)
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
- def delete_e2e_keys_by_device_txn(txn):
+ def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
log_kv(
{
"message": "Deleting keys for device",
@@ -1012,17 +1085,24 @@ def delete_e2e_keys_by_device_txn(txn):
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
- def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
+ def _set_e2e_cross_signing_key_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ key_type: str,
+ key: JsonDict,
+ stream_id: int,
+ ) -> None:
"""Set a user's cross-signing key.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- user_id (str): the user to set the signing key for
- key_type (str): the type of key that is being set: either 'master'
+ txn: db connection
+ user_id: the user to set the signing key for
+ key_type: the type of key that is being set: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
- key (dict): the key data
- stream_id (int)
+ key: the key data
+ stream_id
"""
# the 'key' dict will look something like:
# {
@@ -1075,13 +1155,15 @@ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id)
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
- async def set_e2e_cross_signing_key(self, user_id, key_type, key):
+ async def set_e2e_cross_signing_key(
+ self, user_id: str, key_type: str, key: JsonDict
+ ) -> None:
"""Set a user's cross-signing key.
Args:
- user_id (str): the user to set the user-signing key for
- key_type (str): the type of cross-signing key to set
- key (dict): the key data
+ user_id: the user to set the user-signing key for
+ key_type: the type of cross-signing key to set
+ key: the key data
"""
async with self._cross_signing_id_gen.get_next() as stream_id:
@@ -1105,15 +1187,22 @@ async def store_e2e_cross_signing_signatures(
"""
await self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures",
- [
- {
- "user_id": user_id,
- "key_id": item.signing_key_id,
- "target_user_id": item.target_user_id,
- "target_device_id": item.target_device_id,
- "signature": item.signature,
- }
+ keys=(
+ "user_id",
+ "key_id",
+ "target_user_id",
+ "target_device_id",
+ "signature",
+ ),
+ values=[
+ (
+ user_id,
+ item.signing_key_id,
+ item.target_user_id,
+ item.target_device_id,
+ item.signature,
+ )
for item in signatures
],
- "add_e2e_signing_key",
+ desc="add_e2e_signing_key",
)
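Several hunks in this file swap `# type: ignore` comments for `typing.cast`, which records the row shape the code relies on instead of suppressing the check: the driver stubs type `fetchone()`/`fetchall()` loosely, while the query determines the actual shape. A minimal sketch of the idiom, assuming a `COUNT(*)` query like the ones above:

```python
from typing import Tuple, cast


def read_count(txn) -> int:
    # A SELECT COUNT(*) always yields exactly one single-column row, even
    # though fetchone() is typed as returning an optional, untyped row.
    txn.execute("SELECT COUNT(*) FROM e2e_one_time_keys_json")
    row = cast(Tuple[int], txn.fetchone())
    return row[0]
```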
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a4078538..a556f17dac15 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,7 +24,11 @@
from synapse.events import EventBase, make_event_from_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine
@@ -62,7 +66,12 @@ def __init__(self, room_id: str):
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@@ -279,7 +288,7 @@ def _get_auth_chain_ids_txn(
new_front = set()
for chunk in batch_iter(front, 100):
# Pull the auth events either from the cache or DB.
- to_fetch = [] # Event IDs to fetch from DB # type: List[str]
+ to_fetch: List[str] = [] # Event IDs to fetch from DB
for event_id in chunk:
res = self._event_auth_cache.get(event_id)
if res is None:
@@ -606,8 +615,8 @@ def _get_auth_chain_difference_txn(
# currently walking, either from cache or DB.
search, chunk = search[:-100], search[-100:]
- found = [] # Results found # type: List[Tuple[str, str, int]]
- to_fetch = [] # Event IDs to fetch from DB # type: List[str]
+ found: List[Tuple[str, str, int]] = [] # Results found
+ to_fetch: List[str] = [] # Event IDs to fetch from DB
for _, event_id in chunk:
res = self._event_auth_cache.get(event_id)
if res is None:
@@ -1384,7 +1393,7 @@ async def prune_staged_events_in_room(
count = await self.db_pool.simple_select_one_onecol(
table="federation_inbound_events_staging",
keyvalues={"room_id": room_id},
- retcol="COALESCE(COUNT(*), 0)",
+ retcol="COUNT(*)",
desc="prune_staged_events_in_room_count",
)
@@ -1423,7 +1432,10 @@ async def prune_staged_events_in_room(
if room_version.event_format == EventFormatVersions.V1:
for prev_event_tuple in prev_events:
- if not isinstance(prev_event_tuple, list) or len(prev_events) != 2:
+ if (
+ not isinstance(prev_event_tuple, list)
+ or len(prev_event_tuple) != 2
+ ):
logger.info("Invalid prev_events for %s", event_id)
break
@@ -1476,9 +1488,7 @@ async def _get_stats_for_federation_staging(self):
"""Update the prometheus metrics for the inbound federation staging area."""
def _get_stats_for_federation_staging_txn(txn):
- txn.execute(
- "SELECT coalesce(count(*), 0) FROM federation_inbound_events_staging"
- )
+ txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
(count,) = txn.fetchone()
txn.execute(
@@ -1514,7 +1524,12 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
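The `prune_staged_events_in_room` hunk above fixes a genuine bug: the old condition measured `len(prev_events)` (the whole list) instead of `len(prev_event_tuple)` (the individual `[event_id, hashes]` pair used by room-version-1 events), so well-formed pairs were rejected whenever a v1 event had anything other than exactly two prev events. A toy reproduction of the difference (event IDs and hashes are made up):

```python
# Made-up v1-format prev_events: each entry is a [event_id, hashes] pair.
prev_events = [["$event1:example.org", {"sha256": "abc"}]]

# Old check: compares the length of the whole list, so this single,
# well-formed pair is wrongly treated as invalid.
old_check = [not isinstance(t, list) or len(prev_events) != 2 for t in prev_events]
# New check: compares the length of each pair itself.
new_check = [not isinstance(t, list) or len(t) != 2 for t in prev_events]

assert old_check == [True]   # flagged as invalid (the bug)
assert new_check == [False]  # accepted (the fix)
```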
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3efdd0c920f6..b7c4c62222bd 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,14 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
import attr
-from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -30,29 +33,64 @@
logger = logging.getLogger(__name__)
-DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
-DEFAULT_HIGHLIGHT_ACTION = [
+DEFAULT_NOTIF_ACTION: List[Union[dict, str]] = [
+ "notify",
+ {"set_tweak": "highlight", "value": False},
+]
+DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
]
-class BasePushAction(TypedDict):
- event_id: str
- actions: List[Union[dict, str]]
-
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class HttpPushAction:
+ """
+ HttpPushAction instances include the information used to generate HTTP
+ requests to a push gateway.
+ """
-class HttpPushAction(BasePushAction):
+ event_id: str
room_id: str
stream_ordering: int
+ actions: List[Union[dict, str]]
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class EmailPushAction(HttpPushAction):
+ """
+ EmailPushAction instances include the information used to render an email
+ push notification.
+ """
+
received_ts: Optional[int]
-def _serialize_action(actions, is_highlight):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPushAction(EmailPushAction):
+ """
+ UserPushAction instances include the necessary information to respond to
+ /notifications requests.
+ """
+
+ topological_ordering: int
+ highlight: bool
+ profile_tag: str
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class NotifCounts:
+ """
+ The per-user, per-room count of notifications. Used by sync and push.
+ """
+
+ notify_count: int
+ unread_count: int
+ highlight_count: int
+
+
+def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
"""Custom serializer for actions. This allows us to "compress" common actions.
We use the fact that most users have the same actions for notifs (and for
@@ -70,7 +108,7 @@ def _serialize_action(actions, is_highlight):
return json_encoder.encode(actions)
-def _deserialize_action(actions, is_highlight):
+def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, str]]:
"""Custom deserializer for actions. This allows us to "compress" common actions"""
if actions:
return db_to_json(actions)
@@ -82,12 +120,17 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
- self.stream_ordering_month_ago = None
- self.stream_ordering_day_ago = None
+ self.stream_ordering_month_ago: Optional[int] = None
+ self.stream_ordering_day_ago: Optional[int] = None
cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
self._find_stream_orderings_for_times_txn(cur)
@@ -111,7 +154,7 @@ async def get_unread_event_push_actions_by_room_for_user(
room_id: str,
user_id: str,
last_read_event_id: Optional[str],
- ) -> Dict[str, int]:
+ ) -> NotifCounts:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt.
@@ -140,15 +183,15 @@ async def get_unread_event_push_actions_by_room_for_user(
def _get_unread_counts_by_receipt_txn(
self,
- txn,
- room_id,
- user_id,
- last_read_event_id,
- ):
+ txn: LoggingTransaction,
+ room_id: str,
+ user_id: str,
+ last_read_event_id: Optional[str],
+ ) -> NotifCounts:
stream_ordering = None
if last_read_event_id is not None:
- stream_ordering = self.get_stream_id_for_event_txn(
+ stream_ordering = self.get_stream_id_for_event_txn( # type: ignore[attr-defined]
txn,
last_read_event_id,
allow_none=True,
@@ -166,13 +209,15 @@ def _get_unread_counts_by_receipt_txn(
retcol="event_id",
)
- stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
+ stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) # type: ignore[attr-defined]
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
)
- def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
+ def _get_unread_counts_by_pos_txn(
+ self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+ ) -> NotifCounts:
sql = (
"SELECT"
" COUNT(CASE WHEN notif = 1 THEN 1 END),"
@@ -210,16 +255,16 @@ def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
# for this row.
unread_count += row[1]
- return {
- "notify_count": notif_count,
- "unread_count": unread_count,
- "highlight_count": highlight_count,
- }
+ return NotifCounts(
+ notify_count=notif_count,
+ unread_count=unread_count,
+ highlight_count=highlight_count,
+ )
async def get_push_action_users_in_range(
- self, min_stream_ordering, max_stream_ordering
- ):
- def f(txn):
+ self, min_stream_ordering: int, max_stream_ordering: int
+ ) -> List[str]:
+ def f(txn: LoggingTransaction) -> List[str]:
sql = (
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
" stream_ordering >= ? AND stream_ordering <= ? AND notif = 1"
@@ -227,8 +272,7 @@ def f(txn):
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
- ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f)
- return ret
+ return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
async def get_unread_push_actions_for_user_in_range_for_http(
self,
@@ -254,7 +298,9 @@ async def get_unread_push_actions_for_user_in_range_for_http(
"""
# find rooms that have a read receipt in them and return the next
# push actions
- def get_after_receipt(txn):
+ def get_after_receipt(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, str, bool]]:
# find rooms that have a read receipt in them and return the next
# push actions
sql = (
@@ -280,7 +326,7 @@ def get_after_receipt(txn):
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
- return txn.fetchall()
+ return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
@@ -289,7 +335,9 @@ def get_after_receipt(txn):
# There are rooms with push actions in them but you don't have a read receipt in
# them e.g. rooms you've been invited to, so get push actions for rooms which do
# not have read receipts in them too.
- def get_no_receipt(txn):
+ def get_no_receipt(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, str, bool]]:
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight "
@@ -309,19 +357,19 @@ def get_no_receipt(txn):
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
- return txn.fetchall()
+ return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
notifs = [
- {
- "event_id": row[0],
- "room_id": row[1],
- "stream_ordering": row[2],
- "actions": _deserialize_action(row[3], row[4]),
- }
+ HttpPushAction(
+ event_id=row[0],
+ room_id=row[1],
+ stream_ordering=row[2],
+ actions=_deserialize_action(row[3], row[4]),
+ )
for row in after_read_receipt + no_read_receipt
]
@@ -329,7 +377,7 @@ def get_no_receipt(txn):
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by stream_ordering, oldest first.
- notifs.sort(key=lambda r: r["stream_ordering"])
+ notifs.sort(key=lambda r: r.stream_ordering)
# Take only up to the limit. We have to stop at the limit because
# one of the subqueries may have hit the limit.
@@ -359,7 +407,9 @@ async def get_unread_push_actions_for_user_in_range_for_email(
"""
# find rooms that have a read receipt in them and return the most recent
# push actions
- def get_after_receipt(txn):
+ def get_after_receipt(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, str, bool, int]]:
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight, e.received_ts"
@@ -384,7 +434,7 @@ def get_after_receipt(txn):
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
- return txn.fetchall()
+ return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
@@ -393,7 +443,9 @@ def get_after_receipt(txn):
# There are rooms with push actions in them but you don't have a read receipt in
# them e.g. rooms you've been invited to, so get push actions for rooms which do
# not have read receipts in them too.
- def get_no_receipt(txn):
+ def get_no_receipt(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, str, bool, int]]:
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight, e.received_ts"
@@ -413,7 +465,7 @@ def get_no_receipt(txn):
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
- return txn.fetchall()
+ return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
@@ -421,13 +473,13 @@ def get_no_receipt(txn):
# Make a list of dicts from the two sets of results.
notifs = [
- {
- "event_id": row[0],
- "room_id": row[1],
- "stream_ordering": row[2],
- "actions": _deserialize_action(row[3], row[4]),
- "received_ts": row[5],
- }
+ EmailPushAction(
+ event_id=row[0],
+ room_id=row[1],
+ stream_ordering=row[2],
+ actions=_deserialize_action(row[3], row[4]),
+ received_ts=row[5],
+ )
for row in after_read_receipt + no_read_receipt
]
@@ -435,7 +487,7 @@ def get_no_receipt(txn):
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by received_ts (most recent first)
- notifs.sort(key=lambda r: -(r["received_ts"] or 0))
+ notifs.sort(key=lambda r: -(r.received_ts or 0))
# Now return the first `limit`
return notifs[:limit]
@@ -456,7 +508,7 @@ async def get_if_maybe_push_in_range_for_user(
True if there may be push to process, False if there definitely isn't.
"""
- def _get_if_maybe_push_in_range_for_user_txn(txn):
+ def _get_if_maybe_push_in_range_for_user_txn(txn: LoggingTransaction) -> bool:
sql = """
SELECT 1 FROM event_push_actions
WHERE user_id = ? AND stream_ordering > ? AND notif = 1
@@ -490,19 +542,21 @@ async def add_push_actions_to_staging(
# This is a helper function for generating the necessary tuple that
# can be used to insert into the `event_push_actions_staging` table.
- def _gen_entry(user_id, actions):
+ def _gen_entry(
+ user_id: str, actions: List[Union[dict, str]]
+ ) -> Tuple[str, str, str, int, int, int]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
return (
event_id, # event_id column
user_id, # user_id column
- _serialize_action(actions, is_highlight), # actions column
+ _serialize_action(actions, bool(is_highlight)), # actions column
notif, # notif column
is_highlight, # highlight column
int(count_as_unread), # unread column
)
- def _add_push_actions_to_staging_txn(txn):
+ def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
# We don't use simple_insert_many here to avoid the overhead
# of generating lists of dicts.
@@ -530,12 +584,11 @@ async def remove_push_actions_from_staging(self, event_id: str) -> None:
"""
try:
- res = await self.db_pool.simple_delete(
+ await self.db_pool.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
)
- return res
except Exception:
# this method is called from an exception handler, so propagating
# another exception here really isn't helpful - there's nothing
@@ -588,7 +641,9 @@ async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
)
@staticmethod
- def _find_first_stream_ordering_after_ts_txn(txn, ts):
+ def _find_first_stream_ordering_after_ts_txn(
+ txn: LoggingTransaction, ts: int
+ ) -> int:
"""
Find the stream_ordering of the first event that was received on or
after a given timestamp. This is relatively slow as there is no index
@@ -600,14 +655,14 @@ def _find_first_stream_ordering_after_ts_txn(txn, ts):
stream_ordering
Args:
- txn (twisted.enterprise.adbapi.Transaction):
- ts (int): timestamp to search for
+ txn:
+ ts: timestamp to search for
Returns:
- int: stream ordering
+ The stream ordering
"""
txn.execute("SELECT MAX(stream_ordering) FROM events")
- max_stream_ordering = txn.fetchone()[0]
+ max_stream_ordering = cast(Tuple[Optional[int]], txn.fetchone())[0]
if max_stream_ordering is None:
return 0
@@ -663,8 +718,10 @@ def _find_first_stream_ordering_after_ts_txn(txn, ts):
return range_end
- async def get_time_of_last_push_action_before(self, stream_ordering):
- def f(txn):
+ async def get_time_of_last_push_action_before(
+ self, stream_ordering: int
+ ) -> Optional[int]:
+ def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
sql = (
"SELECT e.received_ts"
" FROM event_push_actions AS ep"
@@ -674,7 +731,7 @@ def f(txn):
" LIMIT 1"
)
txn.execute(sql, (stream_ordering,))
- return txn.fetchone()
+ return cast(Optional[Tuple[int]], txn.fetchone())
result = await self.db_pool.runInteraction(
"get_time_of_last_push_action_before", f
@@ -682,7 +739,7 @@ def f(txn):
return result[0] if result else None
@wrap_as_background_process("rotate_notifs")
- async def _rotate_notifs(self):
+ async def _rotate_notifs(self) -> None:
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
self._doing_notif_rotation = True
@@ -700,7 +757,7 @@ async def _rotate_notifs(self):
finally:
self._doing_notif_rotation = False
- def _rotate_notifs_txn(self, txn):
+ def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
"""Archives older notifications into event_push_summary. Returns whether
the archiving process has caught up or not.
"""
@@ -725,6 +782,7 @@ def _rotate_notifs_txn(self, txn):
stream_row = txn.fetchone()
if stream_row:
(offset_stream_ordering,) = stream_row
+ assert self.stream_ordering_day_ago is not None
rotate_to_stream_ordering = min(
self.stream_ordering_day_ago, offset_stream_ordering
)
@@ -740,7 +798,9 @@ def _rotate_notifs_txn(self, txn):
# We have caught up iff we were limited by `stream_ordering_day_ago`
return caught_up
- def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
+ def _rotate_notifs_before_txn(
+ self, txn: LoggingTransaction, rotate_to_stream_ordering: int
+ ) -> None:
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
@@ -815,14 +875,21 @@ def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
self.db_pool.simple_insert_many_txn(
txn,
table="event_push_summary",
+ keys=(
+ "user_id",
+ "room_id",
+ "notif_count",
+ "unread_count",
+ "stream_ordering",
+ ),
values=[
- {
- "user_id": user_id,
- "room_id": room_id,
- "notif_count": summary.notif_count,
- "unread_count": summary.unread_count,
- "stream_ordering": summary.stream_ordering,
- }
+ (
+ user_id,
+ room_id,
+ summary.notif_count,
+ summary.unread_count,
+ summary.stream_ordering,
+ )
for ((user_id, room_id), summary) in summaries.items()
if summary.old_user_id is None
],
@@ -861,8 +928,8 @@ def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
)
def _remove_old_push_actions_before_txn(
- self, txn, room_id, user_id, stream_ordering
- ):
+ self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+ ) -> None:
"""
Purges old push actions for a user and room before a given
stream_ordering.
@@ -910,7 +977,12 @@ def _remove_old_push_actions_before_txn(
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -929,9 +1001,15 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
)
async def get_push_actions_for_user(
- self, user_id, before=None, limit=50, only_highlight=False
- ):
- def f(txn):
+ self,
+ user_id: str,
+ before: Optional[str] = None,
+ limit: int = 50,
+ only_highlight: bool = False,
+ ) -> List[UserPushAction]:
+ def f(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, int, str, bool, str, int]]:
before_clause = ""
if before:
before_clause = "AND epa.stream_ordering < ?"
@@ -958,32 +1036,44 @@ def f(txn):
" LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(
+ List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall()
+ )
push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
- for pa in push_actions:
- pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
- return push_actions
+ return [
+ UserPushAction(
+ event_id=row[0],
+ room_id=row[1],
+ stream_ordering=row[2],
+ actions=_deserialize_action(row[4], row[5]),
+ received_ts=row[7],
+ topological_ordering=row[3],
+ highlight=row[5],
+ profile_tag=row[6],
+ )
+ for row in push_actions
+ ]
-def _action_has_highlight(actions):
+def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
for action in actions:
- try:
- if action.get("set_tweak", None) == "highlight":
- return action.get("value", True)
- except AttributeError:
- pass
+ if not isinstance(action, dict):
+ continue
+
+ if action.get("set_tweak", None) == "highlight":
+ return action.get("value", True)
return False
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _EventPushSummary:
"""Summary of pending event push actions for a given user in a given room.
Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
"""
- unread_count = attr.ib(type=int)
- stream_ordering = attr.ib(type=int)
- old_user_id = attr.ib(type=str)
- notif_count = attr.ib(type=int)
+ unread_count: int
+ stream_ordering: int
+ old_user_id: str
+ notif_count: int
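
The `TypedDict`s and ad-hoc dict rows above are replaced by frozen attrs classes, so callers write `notif.stream_ordering` instead of `notif["stream_ordering"]` and mypy can verify every field name. A minimal, standalone sketch of the pattern (the class mirrors the `HttpPushAction` added above, but the snippet assumes only the `attrs` package):

```python
from typing import List, Union

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class HttpPushAction:
    """Frozen attrs class: plain attributes instead of dict keys."""

    event_id: str
    room_id: str
    stream_ordering: int
    actions: List[Union[dict, str]]


# Rows from the two queries arrive interleaved; sorting by an attribute
# (rather than a dict key) is type-checked and typo-proof.
notifs = [
    HttpPushAction(event_id="$b", room_id="!r", stream_ordering=7, actions=["notify"]),
    HttpPushAction(event_id="$a", room_id="!r", stream_ordering=3, actions=["notify"]),
]
notifs.sort(key=lambda n: n.stream_ordering)
assert [n.event_id for n in notifs] == ["$a", "$b"]
```
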
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612eab7..1ae1ebe10879 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -19,6 +19,7 @@
from typing import (
TYPE_CHECKING,
Any,
+ Collection,
Dict,
Generator,
Iterable,
@@ -38,12 +39,14 @@
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
-from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
-from synapse.storage.types import Connection
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
@@ -65,7 +68,7 @@
)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -76,9 +79,9 @@ class DeltaState:
should e.g. be removed from `current_state_events` table.
"""
- to_delete = attr.ib(type=List[Tuple[str, str]])
- to_insert = attr.ib(type=StateMap[str])
- no_longer_in_room = attr.ib(type=bool, default=False)
+ to_delete: List[Tuple[str, str]]
+ to_insert: StateMap[str]
+ no_longer_in_room: bool = False
class PersistEventsStore:
@@ -94,7 +97,7 @@ def __init__(
hs: "HomeServer",
db: DatabasePool,
main_data_store: "DataStore",
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
):
self.hs = hs
self.db_pool = db
@@ -324,7 +327,6 @@ def _get_prevs_before_rejected_txn(txn, batch):
return existing_prevs
- @log_function
def _persist_events_txn(
self,
txn: LoggingTransaction,
@@ -438,12 +440,9 @@ def _persist_event_auth_chain_txn(
self.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
+ keys=("event_id", "room_id", "auth_id"),
values=[
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "auth_id": auth_id,
- }
+ (event.event_id, event.room_id, auth_id)
for event in events
for auth_id in event.auth_event_ids()
if event.is_state()
@@ -671,8 +670,9 @@ def _add_chain_cover_index(
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chains",
+ keys=("event_id", "chain_id", "sequence_number"),
values=[
- {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+ (event_id, c_id, seq)
for event_id, (c_id, seq) in new_chain_tuples.items()
],
)
@@ -778,13 +778,14 @@ def _add_chain_cover_index(
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
+ keys=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
values=[
- {
- "origin_chain_id": source_id,
- "origin_sequence_number": source_seq,
- "target_chain_id": target_id,
- "target_sequence_number": target_seq,
- }
+ (source_id, source_seq, target_id, target_seq)
for (
source_id,
source_seq,
@@ -939,20 +940,28 @@ def _persist_transaction_ids_txn(
txn_id = getattr(event.internal_metadata, "txn_id", None)
if token_id and txn_id:
to_insert.append(
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "user_id": event.sender,
- "token_id": token_id,
- "txn_id": txn_id,
- "inserted_ts": self._clock.time_msec(),
- }
+ (
+ event.event_id,
+ event.room_id,
+ event.sender,
+ token_id,
+ txn_id,
+ self._clock.time_msec(),
+ )
)
if to_insert:
self.db_pool.simple_insert_many_txn(
txn,
table="event_txn_id",
+ keys=(
+ "event_id",
+ "room_id",
+ "user_id",
+ "token_id",
+ "txn_id",
+ "inserted_ts",
+ ),
values=to_insert,
)
@@ -1157,8 +1166,9 @@ def _update_forward_extremities_txn(
self.db_pool.simple_insert_many_txn(
txn,
table="event_forward_extremities",
+ keys=("event_id", "room_id"),
values=[
- {"event_id": ev_id, "room_id": room_id}
+ (ev_id, room_id)
for room_id, new_extrem in new_forward_extremities.items()
for ev_id in new_extrem
],
@@ -1170,12 +1180,9 @@ def _update_forward_extremities_txn(
self.db_pool.simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
+ keys=("room_id", "event_id", "stream_ordering"),
values=[
- {
- "room_id": room_id,
- "event_id": event_id,
- "stream_ordering": max_stream_order,
- }
+ (room_id, event_id, max_stream_order)
for room_id, new_extrem in new_forward_extremities.items()
for event_id in new_extrem
],
@@ -1247,20 +1254,22 @@ def _update_room_depths_txn(
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
- def _update_outliers_txn(self, txn, events_and_contexts):
+ def _update_outliers_txn(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ) -> List[Tuple[EventBase, EventContext]]:
"""Update any outliers with new event info.
- This turns outliers into ex-outliers (unless the new event was
- rejected).
+ This turns outliers into ex-outliers (unless the new event was rejected), and
+ also removes any other events we have already seen from the list.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
+ txn: db connection
+ events_and_contexts: events we are persisting
Returns:
- list[(EventBase, EventContext)] new list, without events which
- are already in the events table.
+ new list, without events which are already in the events table.
"""
txn.execute(
"SELECT event_id, outlier FROM events WHERE event_id in (%s)"
@@ -1268,7 +1277,9 @@ def _update_outliers_txn(self, txn, events_and_contexts):
[event.event_id for event, _ in events_and_contexts],
)
- have_persisted = {event_id: outlier for event_id, outlier in txn}
+ have_persisted: Dict[str, bool] = {
+ event_id: outlier for event_id, outlier in txn
+ }
to_remove = set()
for event, context in events_and_contexts:
@@ -1278,15 +1289,22 @@ def _update_outliers_txn(self, txn, events_and_contexts):
to_remove.add(event)
if context.rejected:
- # If the event is rejected then we don't care if the event
- # was an outlier or not.
+ # If the incoming event is rejected then we don't care if the event
+ # was an outlier or not - what we have is at least as good.
continue
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
# We received a copy of an event that we had already stored as
- # an outlier in the database. We now have some state at that
+ # an outlier in the database. We now have some state at that event
# so we need to update the state_groups table with that state.
+ #
+ # Note that we do not update the stream_ordering of the event in this
+ # scenario. XXX: does this cause bugs? It will mean we won't send such
+ # events down /sync. In general they will be historical events, so that
+ # doesn't matter too much, but that is not always the case.
+
+ logger.info("Updating state for ex-outlier event %s", event.event_id)
# insert into event_to_state_groups.
try:
@@ -1319,14 +1337,13 @@ def _update_outliers_txn(self, txn, events_and_contexts):
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
- def _store_event_txn(self, txn, events_and_contexts):
+ def _store_event_txn(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+ ) -> None:
"""Insert new events into the event, event_json, redaction and
state_events tables.
-
- Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
"""
if not events_and_contexts:
@@ -1342,43 +1359,55 @@ def event_dict(event):
self.db_pool.simple_insert_many_txn(
txn,
table="event_json",
- values=[
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "internal_metadata": json_encoder.encode(
- event.internal_metadata.get_dict()
- ),
- "json": json_encoder.encode(event_dict(event)),
- "format_version": event.format_version,
- }
+ keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
+ values=(
+ (
+ event.event_id,
+ event.room_id,
+ json_encoder.encode(event.internal_metadata.get_dict()),
+ json_encoder.encode(event_dict(event)),
+ event.format_version,
+ )
for event, _ in events_and_contexts
- ],
+ ),
)
self.db_pool.simple_insert_many_txn(
txn,
table="events",
- values=[
- {
- "instance_name": self._instance_name,
- "stream_ordering": event.internal_metadata.stream_ordering,
- "topological_ordering": event.depth,
- "depth": event.depth,
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "processed": True,
- "outlier": event.internal_metadata.is_outlier(),
- "origin_server_ts": int(event.origin_server_ts),
- "received_ts": self._clock.time_msec(),
- "sender": event.sender,
- "contains_url": (
- "url" in event.content and isinstance(event.content["url"], str)
- ),
- }
+ keys=(
+ "instance_name",
+ "stream_ordering",
+ "topological_ordering",
+ "depth",
+ "event_id",
+ "room_id",
+ "type",
+ "processed",
+ "outlier",
+ "origin_server_ts",
+ "received_ts",
+ "sender",
+ "contains_url",
+ ),
+ values=(
+ (
+ self._instance_name,
+ event.internal_metadata.stream_ordering,
+ event.depth, # topological_ordering
+ event.depth, # depth
+ event.event_id,
+ event.room_id,
+ event.type,
+ True, # processed
+ event.internal_metadata.is_outlier(),
+ int(event.origin_server_ts),
+ self._clock.time_msec(),
+ event.sender,
+ "url" in event.content and isinstance(event.content["url"], str),
+ )
for event, _ in events_and_contexts
- ],
+ ),
)
# If we're persisting an unredacted event we go and ensure
@@ -1397,27 +1426,15 @@ def event_dict(event):
)
txn.execute(sql + clause, [False] + args)
- state_events_and_contexts = [
- ec for ec in events_and_contexts if ec[0].is_state()
- ]
-
- state_values = []
- for event, _ in state_events_and_contexts:
- vals = {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- }
-
- # TODO: How does this work with backfilling?
- if hasattr(event, "replaces_state"):
- vals["prev_state"] = event.replaces_state
-
- state_values.append(vals)
-
self.db_pool.simple_insert_many_txn(
- txn, table="state_events", values=state_values
+ txn,
+ table="state_events",
+ keys=("event_id", "room_id", "type", "state_key"),
+ values=(
+ (event.event_id, event.room_id, event.type, event.state_key)
+ for event, _ in events_and_contexts
+ if event.is_state()
+ ),
)
def _store_rejected_events_txn(self, txn, events_and_contexts):
@@ -1619,14 +1636,9 @@ def insert_labels_for_event_txn(
return self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
+ keys=("event_id", "label", "room_id", "topological_ordering"),
values=[
- {
- "event_id": event_id,
- "label": label,
- "room_id": room_id,
- "topological_ordering": topological_ordering,
- }
- for label in labels
+ (event_id, label, room_id, topological_ordering) for label in labels
],
)
@@ -1654,16 +1666,13 @@ def _store_event_reference_hashes_txn(self, txn, events):
vals = []
for event in events:
ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
- vals.append(
- {
- "event_id": event.event_id,
- "algorithm": ref_alg,
- "hash": memoryview(ref_hash_bytes),
- }
- )
+ vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
self.db_pool.simple_insert_many_txn(
- txn, table="event_reference_hashes", values=vals
+ txn,
+ table="event_reference_hashes",
+ keys=("event_id", "algorithm", "hash"),
+ values=vals,
)
def _store_room_members_txn(
@@ -1686,18 +1695,25 @@ def non_null_str_or_none(val: Any) -> Optional[str]:
self.db_pool.simple_insert_many_txn(
txn,
table="room_memberships",
+ keys=(
+ "event_id",
+ "user_id",
+ "sender",
+ "room_id",
+ "membership",
+ "display_name",
+ "avatar_url",
+ ),
values=[
- {
- "event_id": event.event_id,
- "user_id": event.state_key,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- "display_name": non_null_str_or_none(
- event.content.get("displayname")
- ),
- "avatar_url": non_null_str_or_none(event.content.get("avatar_url")),
- }
+ (
+ event.event_id,
+ event.state_key,
+ event.user_id,
+ event.room_id,
+ event.membership,
+ non_null_str_or_none(event.content.get("displayname")),
+ non_null_str_or_none(event.content.get("avatar_url")),
+ )
for event in events
],
)
@@ -1780,10 +1796,21 @@ def _handle_event_relations(
)
if rel_type == RelationTypes.REPLACE:
- txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+ txn.call_after(
+ self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
+ )
if rel_type == RelationTypes.THREAD:
- txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+ txn.call_after(
+ self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
+ )
+ # It should be safe to only invalidate the cache if the user has not
+ # previously participated in the thread, but that's difficult (and
+ # potentially error-prone) so it is always invalidated.
+ txn.call_after(
+ self.store.get_thread_participated.invalidate,
+ (parent_id, event.room_id, event.sender),
+ )
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
@@ -1969,14 +1996,17 @@ def _store_retention_policy_for_room_txn(self, txn, event):
txn, self.store.get_retention_policy_for_room, (event.room_id,)
)
- def store_event_search_txn(self, txn, event, key, value):
+ def store_event_search_txn(
+ self, txn: LoggingTransaction, event: EventBase, key: str, value: str
+ ) -> None:
"""Add event to the search table
Args:
- txn (cursor):
- event (EventBase):
- key (str):
- value (str):
+ txn: The database transaction.
+ event: The event being added to the search table.
+ key: A key describing the search value (one of "content.name",
+ "content.topic", or "content.body")
+ value: The value from the event's content.
"""
self.store.store_search_entries_txn(
txn,
@@ -2153,13 +2183,9 @@ def _handle_mult_prev_events(self, txn, events):
self.db_pool.simple_insert_many_txn(
txn,
table="event_edges",
+ keys=("event_id", "prev_event_id", "room_id", "is_state"),
values=[
- {
- "event_id": ev.event_id,
- "prev_event_id": e_id,
- "room_id": ev.room_id,
- "is_state": False,
- }
+ (ev.event_id, e_id, ev.room_id, False)
for ev in events
for e_id in ev.prev_event_ids()
],
@@ -2216,17 +2242,17 @@ def _update_backward_extremeties(self, txn, events):
)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _LinkMap:
"""A helper type for tracking links between chains."""
# Stores the set of links as nested maps: source chain ID -> target chain ID
# -> source sequence number -> target sequence number.
- maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+ maps: Dict[int, Dict[int, Dict[int, int]]] = attr.Factory(dict)
# Stores the links that have been added (with new set to true), as tuples of
# `(source chain ID, source sequence no, target chain ID, target sequence no.)`
- additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+ additions: Set[Tuple[int, int, int, int]] = attr.Factory(set)
def add_link(
self,
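
A recurring change in this file (and in several below) is that `simple_insert_many_txn` now takes a shared `keys=` tuple plus positional value tuples instead of one dict per row, which avoids building and re-keying many small dicts. A rough sketch of the shape of that conversion, using a hypothetical `dicts_to_keys_and_rows` helper that is not part of Synapse:

```python
from typing import Dict, Iterable, List, Sequence, Tuple


def dicts_to_keys_and_rows(
    rows: Iterable[Dict[str, object]],
) -> Tuple[Sequence[str], List[Tuple[object, ...]]]:
    """Hypothetical helper showing the migration: one shared column list plus
    positional value tuples, instead of a dict per row."""
    rows = list(rows)
    if not rows:
        return (), []
    keys = tuple(rows[0].keys())
    return keys, [tuple(row[k] for k in keys) for row in rows]


old_style = [
    {"event_id": "$e1", "room_id": "!r", "auth_id": "$a1"},
    {"event_id": "$e1", "room_id": "!r", "auth_id": "$a2"},
]
keys, values = dicts_to_keys_and_rows(old_style)
assert keys == ("event_id", "room_id", "auth_id")
assert values == [("$e1", "!r", "$a1"), ("$e1", "!r", "$a2")]
```
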
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c88fd35e7f3a..d5f005966597 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
import attr
@@ -23,6 +23,7 @@
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
+ LoggingDatabaseConnection,
LoggingTransaction,
make_tuple_comparison_clause,
)
@@ -64,26 +65,31 @@ class _BackgroundUpdates:
REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class _CalculateChainCover:
"""Return value for _calculate_chain_cover_txn."""
# The last room_id/depth/stream processed.
- room_id = attr.ib(type=str)
- depth = attr.ib(type=int)
- stream = attr.ib(type=int)
+ room_id: str
+ depth: int
+ stream: int
# Number of rows processed
- processed_count = attr.ib(type=int)
+ processed_count: int
# Map from room_id to last depth/stream processed for each room that we have
# processed all events for (i.e. the rooms we can flip the
# `has_auth_chain_index` for)
- finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+ finished_room_map: Dict[str, Tuple[int, int]]
class EventsBackgroundUpdatesStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
@@ -234,12 +240,14 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
################################################################################
- async def _background_reindex_fields_sender(self, progress, batch_size):
+ async def _background_reindex_fields_sender(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- def reindex_txn(txn):
+ def reindex_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id, json FROM events"
" INNER JOIN event_json USING (event_id)"
@@ -301,12 +309,14 @@ def reindex_txn(txn):
return result
- async def _background_reindex_origin_server_ts(self, progress, batch_size):
+ async def _background_reindex_origin_server_ts(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- def reindex_search_txn(txn):
+ def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
@@ -375,7 +385,9 @@ def reindex_search_txn(txn):
return result
- async def _cleanup_extremities_bg_update(self, progress, batch_size):
+ async def _cleanup_extremities_bg_update(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Background update to clean out extremities that should have been
deleted previously.
@@ -396,12 +408,12 @@ async def _cleanup_extremities_bg_update(self, progress, batch_size):
# have any descendants, but if they do then we should delete those
# extremities.
- def _cleanup_extremities_bg_update_txn(txn):
+ def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
# The set of extremity event IDs that we're checking this round
original_set = set()
- # A dict[str, set[str]] of event ID to their prev events.
- graph = {}
+ # A dict[str, Set[str]] of event ID to their prev events.
+ graph: Dict[str, Set[str]] = {}
# The set of descendants of the original set that are not rejected
# nor soft-failed. Ancestors of these events should be removed
@@ -530,7 +542,7 @@ def _cleanup_extremities_bg_update_txn(txn):
room_ids = {row["room_id"] for row in rows}
for room_id in room_ids:
txn.call_after(
- self.get_latest_event_ids_in_room.invalidate, (room_id,)
+ self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
)
self.db_pool.simple_delete_many_txn(
@@ -552,7 +564,7 @@ def _cleanup_extremities_bg_update_txn(txn):
_BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
)
- def _drop_table_txn(txn):
+ def _drop_table_txn(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE _extremities_to_check")
await self.db_pool.runInteraction(
@@ -561,11 +573,11 @@ def _drop_table_txn(txn):
return num_handled
- async def _redactions_received_ts(self, progress, batch_size):
+ async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int:
"""Handles filling out the `received_ts` column in redactions."""
last_event_id = progress.get("last_event_id", "")
- def _redactions_received_ts_txn(txn):
+ def _redactions_received_ts_txn(txn: LoggingTransaction) -> int:
# Fetch the set of event IDs that we want to update
sql = """
SELECT event_id FROM redactions
@@ -616,10 +628,12 @@ def _redactions_received_ts_txn(txn):
return count
- async def _event_fix_redactions_bytes(self, progress, batch_size):
+ async def _event_fix_redactions_bytes(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Undoes hex encoded censored redacted event JSON."""
- def _event_fix_redactions_bytes_txn(txn):
+ def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None:
# This update is quite fast due to new index.
txn.execute(
"""
@@ -644,11 +658,11 @@ def _event_fix_redactions_bytes_txn(txn):
return 1
- async def _event_store_labels(self, progress, batch_size):
+ async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int:
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")
- def _event_store_labels_txn(txn):
+ def _event_store_labels_txn(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT event_id, json FROM event_json
@@ -670,13 +684,14 @@ def _event_store_labels_txn(txn):
self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
+ keys=("event_id", "label", "room_id", "topological_ordering"),
values=[
- {
- "event_id": event_id,
- "label": label,
- "room_id": event_json["room_id"],
- "topological_ordering": event_json["depth"],
- }
+ (
+ event_id,
+ label,
+ event_json["room_id"],
+ event_json["depth"],
+ )
for label in event_json["content"].get(
EventContentFields.LABELS, []
)
@@ -748,7 +763,10 @@ def get_rejected_events(
),
)
- return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
+ return cast(
+ List[Tuple[str, str, JsonDict, bool, bool]],
+ [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn],
+ )
results = await self.db_pool.runInteraction(
desc="_rejected_events_metadata_get", func=get_rejected_events
@@ -786,29 +804,19 @@ def get_rejected_events(
if not has_state:
state_events.append(
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- }
+ (event.event_id, event.room_id, event.type, event.state_key)
)
if not has_event_auth:
# Old, dodgy, events may have duplicate auth events, which we
# need to deduplicate as we have a unique constraint.
for auth_id in set(event.auth_event_ids()):
- auth_events.append(
- {
- "room_id": event.room_id,
- "event_id": event.event_id,
- "auth_id": auth_id,
- }
- )
+ auth_events.append((event.event_id, event.room_id, auth_id))
if state_events:
await self.db_pool.simple_insert_many(
table="state_events",
+ keys=("event_id", "room_id", "type", "state_key"),
values=state_events,
desc="_rejected_events_metadata_state_events",
)
@@ -816,6 +824,7 @@ def get_rejected_events(
if auth_events:
await self.db_pool.simple_insert_many(
table="event_auth",
+ keys=("event_id", "room_id", "auth_id"),
values=auth_events,
desc="_rejected_events_metadata_event_auth",
)
@@ -906,7 +915,7 @@ async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
def _calculate_chain_cover_txn(
self,
- txn: Cursor,
+ txn: LoggingTransaction,
last_room_id: str,
last_depth: int,
last_stream: int,
@@ -1017,10 +1026,10 @@ def _calculate_chain_cover_txn(
PersistEventsStore._add_chain_cover_index(
txn,
self.db_pool,
- self.event_chain_id_gen,
+ self.event_chain_id_gen, # type: ignore[attr-defined]
event_to_room_id,
event_to_types,
- event_to_auth_chain,
+ cast(Dict[str, Sequence[str]], event_to_auth_chain),
)
return _CalculateChainCover(
@@ -1040,7 +1049,7 @@ async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> in
"""
current_event_id = progress.get("current_event_id", "")
- def purged_chain_cover_txn(txn) -> int:
+ def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
# The event ID from events will be null if the chain ID / sequence
# number points to a purged event.
sql = """
@@ -1175,14 +1184,14 @@ def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
# Iterate the parent IDs and invalidate caches.
for parent_id in {r[1] for r in relations_to_insert}:
cache_tuple = (parent_id,)
- self._invalidate_cache_and_stream(
- txn, self.get_relations_for_event, cache_tuple
+ self._invalidate_cache_and_stream( # type: ignore[attr-defined]
+ txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
)
- self._invalidate_cache_and_stream(
- txn, self.get_aggregation_groups_for_event, cache_tuple
+ self._invalidate_cache_and_stream( # type: ignore[attr-defined]
+ txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined]
)
- self._invalidate_cache_and_stream(
- txn, self.get_thread_summary, cache_tuple
+ self._invalidate_cache_and_stream( # type: ignore[attr-defined]
+ txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
)
if results:
@@ -1214,7 +1223,7 @@ async def _background_populate_stream_ordering2(
"""
batch_size = max(batch_size, 1)
- def process(txn: Cursor) -> int:
+ def process(txn: LoggingTransaction) -> int:
last_stream = progress.get("last_stream", -(1 << 31))
txn.execute(
"""
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c7b660ac5a6f..8d4287045a8b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1383,10 +1383,6 @@ async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
return {"v1": complexity_v1}
- def get_current_events_token(self) -> int:
- """The current maximum token that events have reached"""
- return self._stream_id_gen.get_current_token()
-
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index cf842803bcd6..cb9ee08fa8e8 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Union
+from typing import Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
@@ -63,7 +63,7 @@ def _do_txn(txn: LoggingTransaction) -> int:
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
- max_id = txn.fetchone()[0] # type: ignore[index]
+ max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
if max_id is None:
filter_id = 0
else:
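
The `# type: ignore[index]` on `fetchone()` is replaced with a `cast`: an aggregate such as `MAX(...)` always yields exactly one row, so the row can be cast to `Tuple[Optional[int]]` and indexed without silencing mypy. A standalone illustration with sqlite3 (the table name echoes the one above, but the snippet is self-contained):

```python
import sqlite3
from typing import Optional, Tuple, cast

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE user_filters (user_id TEXT, filter_id INTEGER)")

cur.execute("SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?", ("@alice:test",))
# An aggregate always returns one row, but the column may be NULL, hence Optional[int].
max_id = cast(Tuple[Optional[int]], cur.fetchone())[0]
filter_id = 0 if max_id is None else max_id + 1
assert filter_id == 0
```
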
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index bb621df0ddb6..3f6086050bb2 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -19,8 +19,7 @@
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -40,7 +39,12 @@ class _RoomInGroup(TypedDict):
class GroupServerWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
database.updates.register_background_index_update(
update_name="local_group_updates_index",
index_name="local_group_updates_stream_id_index",
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index a540f7fb2681..bedacaf0d745 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -20,8 +20,11 @@
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
-from synapse.storage.types import Connection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -54,7 +57,12 @@ class LockStore(SQLBaseStore):
`last_renewed_ts` column with the current time.
"""
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._reactor = hs.get_reactor()
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 1b076683f762..cbba356b4a98 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -23,6 +23,7 @@
Optional,
Tuple,
Union,
+ cast,
)
from synapse.storage._base import SQLBaseStore
@@ -220,7 +221,7 @@ def get_local_media_by_user_paginate_txn(
WHERE user_id = ?
"""
txn.execute(sql, args)
- count = txn.fetchone()[0] # type: ignore[index]
+ count = cast(Tuple[int], txn.fetchone())[0]
sql = """
SELECT
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index d901933ae4f2..1480a0f04829 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -19,7 +19,7 @@
from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
@@ -55,7 +55,12 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
stats and prometheus metrics.
"""
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# Read the extrems every 60 minutes
@@ -100,7 +105,7 @@ async def count_daily_e2ee_messages(self):
def _count_messages(txn):
sql = """
- SELECT COALESCE(COUNT(*), 0) FROM events
+ SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
@@ -117,7 +122,7 @@ def _count_messages(txn):
like_clause = "%:" + self.hs.hostname
sql = """
- SELECT COALESCE(COUNT(*), 0) FROM events
+ SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND sender LIKE ?
AND stream_ordering > ?
@@ -134,7 +139,7 @@ def _count_messages(txn):
async def count_daily_active_e2ee_rooms(self):
def _count(txn):
sql = """
- SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+ SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
@@ -156,7 +161,7 @@ async def count_daily_messages(self):
def _count_messages(txn):
sql = """
- SELECT COALESCE(COUNT(*), 0) FROM events
+ SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
@@ -173,7 +178,7 @@ def _count_messages(txn):
like_clause = "%:" + self.hs.hostname
sql = """
- SELECT COALESCE(COUNT(*), 0) FROM events
+ SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND sender LIKE ?
AND stream_ordering > ?
@@ -190,7 +195,7 @@ def _count_messages(txn):
async def count_daily_active_rooms(self):
def _count(txn):
sql = """
- SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+ SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
@@ -226,7 +231,7 @@ def _count_users(self, txn, time_from):
Returns number of users seen in the past time_from period
"""
sql = """
- SELECT COALESCE(count(*), 0) FROM (
+ SELECT COUNT(*) FROM (
SELECT user_id FROM user_ips
WHERE last_seen > ?
GROUP BY user_id
@@ -253,7 +258,7 @@ def _count_r30_users(txn):
thirty_days_ago_in_secs = now - thirty_days_in_secs
sql = """
- SELECT platform, COALESCE(count(*), 0) FROM (
+ SELECT platform, COUNT(*) FROM (
SELECT
users.name, platform, users.creation_ts * 1000,
MAX(uip.last_seen)
@@ -291,7 +296,7 @@ def _count_r30_users(txn):
results[row[0]] = row[1]
sql = """
- SELECT COALESCE(count(*), 0) FROM (
+ SELECT COUNT(*) FROM (
SELECT users.name, users.creation_ts * 1000,
MAX(uip.last_seen)
FROM users
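
The repeated change in this file drops `COALESCE(COUNT(*), 0)` in favour of plain `COUNT(*)`: a `COUNT` aggregate returns 0 for an empty result set, never NULL, so the `COALESCE` was a no-op. A quick check of that behaviour with sqlite3:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE events (type TEXT, stream_ordering INTEGER)")

# COUNT(*) over zero matching rows is 0, never NULL, so wrapping it in
# COALESCE(COUNT(*), 0) adds nothing.
cur.execute("SELECT COUNT(*) FROM events WHERE stream_ordering > ?", (100,))
assert cur.fetchone() == (0,)
```
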
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index b5284e4f6783..8f09dd8e8751 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -16,8 +16,13 @@
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ make_in_list_sql_clause,
+)
from synapse.util.caches.descriptors import cached
+from synapse.util.threepids import canonicalise_email
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -30,7 +35,12 @@
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
@@ -49,7 +59,7 @@ async def get_monthly_active_count(self) -> int:
def _count_users(txn):
# Exclude app service users
sql = """
- SELECT COALESCE(count(*), 0)
+ SELECT COUNT(*)
FROM monthly_active_users
LEFT JOIN users
ON monthly_active_users.user_id=users.name
@@ -76,7 +86,7 @@ async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
def _count_users_by_service(txn):
sql = """
- SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
+ SELECT COALESCE(appservice_id, 'native'), COUNT(*)
FROM monthly_active_users
LEFT JOIN users ON monthly_active_users.user_id=users.name
GROUP BY appservice_id;
@@ -103,7 +113,7 @@ async def get_registered_reserved_users(self) -> List[str]:
: self.hs.config.server.max_mau_value
]:
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
- tp["medium"], tp["address"]
+ tp["medium"], canonicalise_email(tp["address"])
)
if user_id:
users.append(user_id)
@@ -212,7 +222,12 @@ def _reap_users(txn, reserved_users):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._mau_stats_only = hs.config.server.mau_stats_only
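
`get_registered_reserved_users` now runs each configured reserved email through `canonicalise_email` before the threepid lookup, presumably so that a differently-cased or padded address in the config still matches the canonical form stored for the user. A toy illustration with a simplified, hypothetical stand-in (the real helper lives in `synapse.util.threepids` and does more than this, including validation):

```python
def _canonicalise_email_stub(address: str) -> str:
    """Simplified, hypothetical stand-in for canonicalise_email:
    just trims whitespace and lower-cases."""
    return address.strip().lower()


stored = "alice@example.com"         # how the threepid is stored for the user
configured = " Alice@Example.COM "   # how it might appear in the reserved-users config

assert _canonicalise_email_stub(configured) == stored
```
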
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index cc0eebdb4606..4f05811a77eb 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -17,7 +17,7 @@
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
@@ -33,7 +33,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
@@ -52,7 +52,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
def __init__(
self,
database: DatabasePool,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
@@ -129,18 +129,29 @@ def _update_presence_txn(self, txn, stream_orderings, presence_states):
self.db_pool.simple_insert_many_txn(
txn,
table="presence_stream",
+ keys=(
+ "stream_id",
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ "instance_name",
+ ),
values=[
- {
- "stream_id": stream_id,
- "user_id": state.user_id,
- "state": state.state,
- "last_active_ts": state.last_active_ts,
- "last_federation_update_ts": state.last_federation_update_ts,
- "last_user_sync_ts": state.last_user_sync_ts,
- "status_msg": state.status_msg,
- "currently_active": state.currently_active,
- "instance_name": self._instance_name,
- }
+ (
+ stream_id,
+ state.user_id,
+ state.state,
+ state.last_active_ts,
+ state.last_federation_update_ts,
+ state.last_user_sync_ts,
+ state.status_msg,
+ state.currently_active,
+ self._instance_name,
+ )
for stream_id, state in zip(stream_orderings, presence_states)
],
)
@@ -269,6 +280,7 @@ async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
"""
# Add user entries to the table, updating the presence_stream_id column if the user already
# exists in the table.
+ presence_stream_id = self._presence_id_gen.get_current_token()
await self.db_pool.simple_upsert_many(
table="users_to_send_full_presence_to",
key_names=("user_id",),
@@ -279,9 +291,7 @@ async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
# devices at different times, each device will receive full presence once - when
# the presence stream ID in their sync token is less than the one in the table
# for their user ID.
- value_values=(
- (self._presence_id_gen.get_current_token(),) for _ in user_ids
- ),
+ value_values=[(presence_stream_id,) for _ in user_ids],
desc="add_users_to_send_full_presence_to",
)
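
Two things change in `add_users_to_send_full_presence_to`: the current presence stream ID is read once up front, and `value_values` becomes a list instead of a generator expression. A generator expression evaluates its body lazily, once per item, and can only be consumed a single time; reading the token first and building a list pins every row to the same snapshot and gives a reusable sequence. A small demonstration of the generator behaviour in isolation (the token generator here is a stand-in whose value can move between calls):

```python
from itertools import count

_token = count(1)


def get_current_token() -> int:
    """Stand-in for a stream-ID generator whose value can advance between calls."""
    return next(_token)


user_ids = ["@a:test", "@b:test", "@c:test"]

# Lazy generator: the token is re-read per user, and the generator is single-use.
lazy = ((get_current_token(),) for _ in user_ids)
assert list(lazy) == [(1,), (2,), (3,)]
assert list(lazy) == []  # already exhausted

# Snapshot once, build a list: consistent rows, safe to iterate repeatedly.
token = get_current_token()
eager = [(token,) for _ in user_ids]
assert eager == [(4,), (4,), (4,)]
```
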
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 3b63267395c3..e01c94930aed 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -20,7 +20,7 @@
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -81,7 +81,12 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b73ce53c9156..cf64cd63a46f 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -22,7 +22,7 @@
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -196,27 +196,6 @@ async def get_if_user_has_pusher(self, user_id: str):
# This only exists for the cachedList decorator
raise NotImplementedError()
- @cachedList(
- cached_method_name="get_if_user_has_pusher",
- list_name="user_ids",
- num_args=1,
- )
- async def get_if_users_have_pushers(
- self, user_ids: Iterable[str]
- ) -> Dict[str, bool]:
- rows = await self.db_pool.simple_select_many_batch(
- table="pushers",
- column="user_name",
- iterable=user_ids,
- retcols=["user_name"],
- desc="get_if_users_have_pushers",
- )
-
- result = {user_id: False for user_id in user_ids}
- result.update({r["user_name"]: True for r in rows})
-
- return result
-
async def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
) -> None:
@@ -515,7 +494,7 @@ async def add_pusher(
# invalidate, since the user might not have had a pusher before
await self.db_pool.runInteraction(
"add_pusher",
- self._invalidate_cache_and_stream, # type: ignore
+ self._invalidate_cache_and_stream, # type: ignore[attr-defined]
self.get_if_user_has_pusher,
(user_id,),
)
@@ -524,7 +503,7 @@ async def delete_pusher_by_app_id_pushkey_user_id(
self, app_id: str, pushkey: str, user_id: str
) -> None:
def delete_pusher_txn(txn, stream_id):
- self._invalidate_cache_and_stream( # type: ignore
+ self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)
@@ -569,7 +548,7 @@ async def delete_all_pushers_for_user(self, user_id: str) -> None:
pushers = list(await self.get_pushers_by_user_id(user_id))
def delete_pushers_txn(txn, stream_ids):
- self._invalidate_cache_and_stream( # type: ignore
+ self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)
@@ -582,13 +561,9 @@ def delete_pushers_txn(txn, stream_ids):
self.db_pool.simple_insert_many_txn(
txn,
table="deleted_pushers",
+ keys=("stream_id", "app_id", "pushkey", "user_id"),
values=[
- {
- "stream_id": stream_id,
- "app_id": pusher.app_id,
- "pushkey": pusher.pushkey,
- "user_id": user_id,
- }
+ (stream_id, pusher.app_id, pusher.pushkey, user_id)
for stream_id, pusher in zip(stream_ids, pushers)
],
)
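
The pusher hunks narrow bare `# type: ignore` comments to `# type: ignore[attr-defined]`. A minimal standalone sketch of why scoping the ignore to an error code is preferable; the class names below are made up, only the comment syntax is the point:

```python
class RuntimeMixin:
    def _invalidate_cache_and_stream(self, key: str) -> None:
        print("invalidated", key)


class PusherStore:
    # mypy only sees this class; the attribute arrives from RuntimeMixin when a
    # concrete store combines both (as Synapse's stores do via multiple
    # inheritance), hence the attr-defined error on the call below.
    def add_pusher_like_thing(self) -> None:
        # Scoped ignore: silences exactly the expected error code. A bare
        # "# type: ignore" would also hide any unrelated error mypy later
        # finds on this line.
        self._invalidate_cache_and_stream("pushers")  # type: ignore[attr-defined]


class ConcreteStore(PusherStore, RuntimeMixin):
    pass


ConcreteStore().add_pusher_like_thing()
```
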
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c99f8aebdbdd..bf0b903af2fc 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,14 +14,29 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
from twisted.internet import defer
+from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
@@ -36,7 +51,12 @@
class ReceiptsWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):
@@ -78,17 +98,13 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
- def get_max_receipt_stream_id(self):
- """Get the current max stream ID for receipts stream
-
- Returns:
- int
- """
+ def get_max_receipt_stream_id(self) -> int:
+ """Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()
@cached()
- async def get_users_with_read_receipts_in_room(self, room_id):
- receipts = await self.get_receipts_for_room(room_id, "m.read")
+ async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
+ receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
return {r["user_id"] for r in receipts}
@cached(num_args=2)
@@ -119,7 +135,9 @@ async def get_last_receipt_event_id_for_user(
)
@cached(num_args=2)
- async def get_receipts_for_user(self, user_id, receipt_type):
+ async def get_receipts_for_user(
+ self, user_id: str, receipt_type: str
+ ) -> Dict[str, str]:
rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
@@ -129,8 +147,10 @@ async def get_receipts_for_user(self, user_id, receipt_type):
return {row["room_id"]: row["event_id"] for row in rows}
- async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
- def f(txn):
+ async def get_receipts_for_user_with_orderings(
+ self, user_id: str, receipt_type: str
+ ) -> JsonDict:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
sql = (
"SELECT rl.room_id, rl.event_id,"
" e.topological_ordering, e.stream_ordering"
@@ -209,10 +229,10 @@ async def get_linearized_receipts_for_room(
@cached(num_args=3, tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> List[dict]:
+ ) -> List[JsonDict]:
"""See get_linearized_receipts_for_room"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
@@ -250,11 +270,13 @@ def f(txn):
list_name="room_ids",
num_args=3,
)
- async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def _get_linearized_receipts_for_rooms(
+ self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+ ) -> Dict[str, List[JsonDict]]:
if not room_ids:
return {}
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
@@ -323,7 +345,7 @@ async def get_linearized_receipts_for_all_rooms(
A dictionary of roomids to a list of receipts.
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
@@ -379,7 +401,7 @@ async def get_users_sent_receipts_between(
if last_id == current_id:
return defer.succeed([])
- def _get_users_sent_receipts_between_txn(txn):
+ def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT user_id FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
@@ -419,7 +441,9 @@ async def get_all_updated_receipts(
if last_id == current_id:
return [], current_id, False
- def get_all_updated_receipts_txn(txn):
+ def get_all_updated_receipts_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized
@@ -446,8 +470,8 @@ def get_all_updated_receipts_txn(txn):
def _invalidate_get_users_with_receipts_in_room(
self, room_id: str, receipt_type: str, user_id: str
- ):
- if receipt_type != "m.read":
+ ) -> None:
+ if receipt_type != ReceiptTypes.READ:
return
res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
@@ -461,7 +485,9 @@ def _invalidate_get_users_with_receipts_in_room(
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
- def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+ def invalidate_caches_for_receipt(
+ self, room_id: str, receipt_type: str, user_id: str
+ ) -> None:
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
@@ -482,11 +508,18 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
return super().process_replication_rows(stream_name, instance_name, token, rows)
def insert_linearized_receipt_txn(
- self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
- ):
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_id: str,
+ data: JsonDict,
+ stream_id: int,
+ ) -> Optional[int]:
"""Inserts a read-receipt into the database if it's newer than the current RR
- Returns: int|None
+ Returns:
None if the RR is older than the current RR
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
@@ -550,7 +583,7 @@ def insert_linearized_receipt_txn(
lock=False,
)
- if receipt_type == "m.read" and stream_ordering is not None:
+ if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
@@ -580,7 +613,7 @@ async def insert_receipt(
else:
# we need to convert points in the graph to linearized form.
# TODO: Make this better.
- def graph_to_linear(txn):
+ def graph_to_linear(txn: LoggingTransaction) -> str:
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)
@@ -634,11 +667,16 @@ def graph_to_linear(txn):
return stream_id, max_persisted_id
async def insert_graph_receipt(
- self, room_id, receipt_type, user_id, event_ids, data
- ):
+ self,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: JsonDict,
+ ) -> None:
assert self._can_write_to_receipts
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
@@ -649,8 +687,14 @@ async def insert_graph_receipt(
)
def insert_graph_receipt_txn(
- self, txn, room_id, receipt_type, user_id, event_ids, data
- ):
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: JsonDict,
+ ) -> None:
assert self._can_write_to_receipts
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
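
The receipts store now compares against `ReceiptTypes.READ` rather than the string literal `"m.read"`. A small sketch of that constants pattern (this is not Synapse's real `synapse.api.constants` module, just the shape of the idea):

```python
class ReceiptTypes:
    """One place for well-known receipt type identifiers."""

    READ = "m.read"


def is_read_receipt(receipt_type: str) -> bool:
    # A typo such as ReceiptTypes.REED fails loudly (AttributeError), whereas a
    # typo in a scattered "m.read" literal would just silently never match.
    return receipt_type == ReceiptTypes.READ


assert is_read_receipt("m.read")
assert not is_read_receipt("m.example.other")  # made-up value for the example
```
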
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e1ddf0691646..aac94fa46444 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
import logging
import random
import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import attr
@@ -51,7 +51,7 @@ class ExternalIDReuseException(Exception):
pass
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenLookupResult:
"""Result of looking up an access token.
@@ -69,14 +69,14 @@ class TokenLookupResult:
cached.
"""
- user_id = attr.ib(type=str)
- is_guest = attr.ib(type=bool, default=False)
- shadow_banned = attr.ib(type=bool, default=False)
- token_id = attr.ib(type=Optional[int], default=None)
- device_id = attr.ib(type=Optional[str], default=None)
- valid_until_ms = attr.ib(type=Optional[int], default=None)
- token_owner = attr.ib(type=str)
- token_used = attr.ib(type=bool, default=False)
+ user_id: str
+ is_guest: bool = False
+ shadow_banned: bool = False
+ token_id: Optional[int] = None
+ device_id: Optional[str] = None
+ valid_until_ms: Optional[int] = None
+ token_owner: str = attr.ib()
+ token_used: bool = False
# Make the token owner default to the user ID, which is the common case.
@token_owner.default
@@ -794,7 +794,7 @@ def _count_daily_user_type(txn):
yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """
- SELECT user_type, COALESCE(count(*), 0) AS count FROM (
+ SELECT user_type, COUNT(*) AS count FROM (
SELECT
CASE
WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
@@ -819,7 +819,7 @@ async def count_nonbridged_users(self):
def _count_users(txn):
txn.execute(
"""
- SELECT COALESCE(COUNT(*), 0) FROM users
+ SELECT COUNT(*) FROM users
WHERE appservice_id IS NULL
"""
)
@@ -856,7 +856,8 @@ async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[s
Args:
medium: threepid medium e.g. email
- address: threepid address e.g. me@example.com
+ address: threepid address e.g. me@example.com. This must already be
+ in canonical form.
Returns:
The user ID or None if no user id/threepid mapping exists
@@ -1356,12 +1357,15 @@ def _use_registration_token_txn(txn):
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
- res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
- txn,
- "registration_tokens",
- keyvalues={"token": token},
- retcols=["pending", "completed"],
- ) # type: ignore
+ res = cast(
+ Dict[str, Any],
+ self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ ),
+ )
# Decrement pending and increment completed
self.db_pool.simple_update_one_txn(
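
In the registration hunk above, a blanket `# type: ignore` on `simple_select_one_txn` becomes a `typing.cast`. A self-contained sketch of the trade-off, using a hypothetical helper with an `allow_none`-style Optional return type:

```python
from typing import Any, Dict, Optional, cast


def select_one(allow_none: bool = False) -> Optional[Dict[str, Any]]:
    # Hypothetical stand-in for a DB helper whose return type is Optional only
    # because of its allow_none flag.
    return {"pending": 1, "completed": 2}


# cast() records the assumption ("allow_none is False, so this cannot be None")
# where readers can see it, and unlike a blanket "# type: ignore" it does not
# suppress other errors on the same line. It performs no runtime check.
res = cast(Dict[str, Any], select_one())
print(res["pending"], res["completed"])
```
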
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0a43acda07bb..2cb5d06c1352 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,14 +13,30 @@
# limitations under the License.
import logging
-from typing import List, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
import attr
+from frozendict import frozendict
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
@@ -29,14 +45,29 @@
)
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class RelationsWorkerStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self._msc1849_enabled = hs.config.experimental.msc1849_enabled
+ self._msc3440_enabled = hs.config.experimental.msc3440_enabled
+
@cached(tree=True)
async def get_relations_for_event(
self,
event_id: str,
+ room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
@@ -49,6 +80,7 @@ async def get_relations_for_event(
Args:
event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
@@ -63,8 +95,8 @@ async def get_relations_for_event(
the form `{"event_id": "..."}`.
"""
- where_clause = ["relates_to_id = ?"]
- where_args: List[Union[str, int]] = [event_id]
+ where_clause = ["relates_to_id = ?", "room_id = ?"]
+ where_args: List[Union[str, int]] = [event_id, room_id]
if relation_type is not None:
where_clause.append("relation_type = ?")
@@ -199,6 +231,7 @@ async def event_is_target_of_relation(self, parent_id: str) -> bool:
async def get_aggregation_groups_for_event(
self,
event_id: str,
+ room_id: str,
event_type: Optional[str] = None,
limit: int = 5,
direction: str = "b",
@@ -213,6 +246,7 @@ async def get_aggregation_groups_for_event(
Args:
event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
direction: Whether to fetch the highest count first (`"b"`) or
@@ -225,8 +259,12 @@ async def get_aggregation_groups_for_event(
`type`, `key` and `count` fields.
"""
- where_clause = ["relates_to_id = ?", "relation_type = ?"]
- where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
+ where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
+ where_args: List[Union[str, int]] = [
+ event_id,
+ room_id,
+ RelationTypes.ANNOTATION,
+ ]
if event_type:
where_clause.append("type = ?")
@@ -288,7 +326,9 @@ def _get_aggregation_groups_for_event_txn(
)
@cached()
- async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+ async def get_applicable_edit(
+ self, event_id: str, room_id: str
+ ) -> Optional[EventBase]:
"""Get the most recent edit (if any) that has happened for the given
event.
@@ -296,6 +336,7 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
Args:
event_id: The original event ID
+ room_id: The original event's room ID
Returns:
The most recent edit, if any.
@@ -317,13 +358,14 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
WHERE
relates_to_id = ?
AND relation_type = ?
+ AND edit.room_id = ?
AND edit.type = 'm.room.message'
ORDER by edit.origin_server_ts DESC, edit.event_id DESC
LIMIT 1
"""
def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
- txn.execute(sql, (event_id, RelationTypes.REPLACE))
+ txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
row = txn.fetchone()
if row:
return row[0]
@@ -340,13 +382,13 @@ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
@cached()
async def get_thread_summary(
- self, event_id: str
+ self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
- """Get the number of threaded replies, the senders of those replies, and
- the latest reply (if any) for the given event.
+ """Get the number of threaded replies and the latest reply (if any) for the given event.
Args:
- event_id: The original event ID
+ event_id: Summarize the thread related to this event ID.
+ room_id: The room the event belongs to.
Returns:
The number of items in the thread and the most recent response, if any.
@@ -355,7 +397,7 @@ async def get_thread_summary(
def _get_thread_summary_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
- # Fetch the count of threaded events and the latest event ID.
+ # Fetch the latest event ID in the thread.
# TODO Should this only allow m.room.message events.
sql = """
SELECT event_id
@@ -363,27 +405,31 @@ def _get_thread_summary_txn(
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
+ AND room_id = ?
AND relation_type = ?
ORDER BY topological_ordering DESC, stream_ordering DESC
LIMIT 1
"""
- txn.execute(sql, (event_id, RelationTypes.THREAD))
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
row = txn.fetchone()
if row is None:
return 0, None
latest_event_id = row[0]
+ # Fetch the number of threaded replies.
sql = """
- SELECT COALESCE(COUNT(event_id), 0)
+ SELECT COUNT(event_id)
FROM event_relations
+ INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
+ AND room_id = ?
AND relation_type = ?
"""
- txn.execute(sql, (event_id, RelationTypes.THREAD))
- count = txn.fetchone()[0] # type: ignore[index]
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
+ count = cast(Tuple[int], txn.fetchone())[0]
return count, latest_event_id
@@ -397,6 +443,44 @@ def _get_thread_summary_txn(
return count, latest_event
+ @cached()
+ async def get_thread_participated(
+ self, event_id: str, room_id: str, user_id: str
+ ) -> bool:
+ """Get whether the requesting user participated in a thread.
+
+ This is separate from get_thread_summary since that can be cached across
+ all users while this value is specific to the requester.
+
+ Args:
+ event_id: The thread related to this event ID.
+ room_id: The room the event belongs to.
+ user_id: The user requesting the summary.
+
+ Returns:
+ True if the requesting user participated in the thread, otherwise false.
+ """
+
+ def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
+ # Fetch whether the requester has participated or not.
+ sql = """
+ SELECT 1
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND room_id = ?
+ AND relation_type = ?
+ AND sender = ?
+ """
+
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
+ return bool(txn.fetchone())
+
+ return await self.db_pool.runInteraction(
+ "get_thread_summary", _get_thread_summary_txn
+ )
+
async def events_have_relations(
self,
parent_ids: List[str],
@@ -499,6 +583,104 @@ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
+ async def _get_bundled_aggregation_for_event(
+ self, event: EventBase, user_id: str
+ ) -> Optional[Dict[str, Any]]:
+ """Generate bundled aggregations for an event.
+
+ Note that this does not use a cache, but depends on cached methods.
+
+ Args:
+ event: The event to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ The bundled aggregations for an event, if bundled aggregations are
+ enabled and the event can have bundled aggregations.
+ """
+ # State events and redacted events do not get bundled aggregations.
+ if event.is_state() or event.internal_metadata.is_redacted():
+ return None
+
+ # Do not bundle aggregations for an event which represents an edit or an
+ # annotation. It does not make sense for them to have related events.
+ relates_to = event.content.get("m.relates_to")
+ if isinstance(relates_to, (dict, frozendict)):
+ relation_type = relates_to.get("rel_type")
+ if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+ return None
+
+ event_id = event.event_id
+ room_id = event.room_id
+
+ # The bundled aggregations to include, a mapping of relation type to a
+ # type-specific value. Some types include the direct return type here
+ # while others need more processing during serialization.
+ aggregations: Dict[str, Any] = {}
+
+ annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
+ if annotations.chunk:
+ aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+ references = await self.get_relations_for_event(
+ event_id, room_id, RelationTypes.REFERENCE, direction="f"
+ )
+ if references.chunk:
+ aggregations[RelationTypes.REFERENCE] = references.to_dict()
+
+ edit = None
+ if event.type == EventTypes.Message:
+ edit = await self.get_applicable_edit(event_id, room_id)
+
+ if edit:
+ aggregations[RelationTypes.REPLACE] = edit
+
+ # If this event is the start of a thread, include a summary of the replies.
+ if self._msc3440_enabled:
+ thread_count, latest_thread_event = await self.get_thread_summary(
+ event_id, room_id
+ )
+ participated = await self.get_thread_participated(
+ event_id, room_id, user_id
+ )
+ if latest_thread_event:
+ aggregations[RelationTypes.THREAD] = {
+ "latest_event": latest_thread_event,
+ "count": thread_count,
+ "current_user_participated": participated,
+ }
+
+ # Store the bundled aggregations in the event metadata for later use.
+ return aggregations
+
+ async def get_bundled_aggregations(
+ self,
+ events: Iterable[EventBase],
+ user_id: str,
+ ) -> Dict[str, Dict[str, Any]]:
+ """Generate bundled aggregations for events.
+
+ Args:
+ events: The iterable of events to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ A map of event ID to the bundled aggregation for the event. Not all
+ events may have bundled aggregations in the results.
+ """
+ # If bundled aggregations are disabled, nothing to do.
+ if not self._msc1849_enabled:
+ return {}
+
+ # TODO Parallelize.
+ results = {}
+ for event in events:
+ event_result = await self._get_bundled_aggregation_for_event(event, user_id)
+ if event_result is not None:
+ results[event.event_id] = event_result
+
+ return results
+
class RelationsStore(RelationsWorkerStore):
pass
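
The new `get_bundled_aggregations` helper returns a map of event ID to per-relation-type payloads. A hedged sketch of how a caller might fold that map into already-serialized events; the `unsigned["m.relations"]` key follows the bundled-aggregations convention, and the surrounding function is illustrative rather than Synapse's actual serializer:

```python
from typing import Any, Dict, Iterable, List


def attach_bundled_aggregations(
    serialized_events: Iterable[Dict[str, Any]],
    aggregations: Dict[str, Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Merge per-event aggregation payloads into already-serialized events.

    `aggregations` has the same shape as get_bundled_aggregations() returns:
    event_id -> {relation_type: payload}. Events without an entry pass through
    unchanged.
    """
    out = []
    for event in serialized_events:
        bundle = aggregations.get(event["event_id"])
        if bundle:
            # Bundled aggregations are conventionally exposed to clients under
            # unsigned["m.relations"].
            event.setdefault("unsigned", {})["m.relations"] = bundle
        out.append(event)
    return out


events = [{"event_id": "$a"}, {"event_id": "$b"}]
aggs = {"$a": {"m.annotation": {"chunk": []}}}
print(attach_bundled_aggregations(events, aggs))
```
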
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7d694d852d53..95167116c953 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -13,20 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
import logging
from abc import abstractmethod
from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
+
+import attr
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
-from synapse.storage.databases.main.search import SearchStore
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import IdGenerator
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -38,9 +54,10 @@
logger = logging.getLogger(__name__)
-RatelimitOverride = collections.namedtuple(
- "RatelimitOverride", ("messages_per_second", "burst_count")
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RatelimitOverride:
+ messages_per_second: int
+ burst_count: int
class RoomSortOrder(Enum):
@@ -71,8 +88,13 @@ class RoomSortOrder(Enum):
STATE_EVENTS = "state_events"
-class RoomWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+class RoomWorkerStore(CacheInvalidationWorkerStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -83,7 +105,7 @@ async def store_room(
room_creator_user_id: str,
is_public: bool,
room_version: RoomVersion,
- ):
+ ) -> None:
"""Stores a room.
Args:
@@ -111,7 +133,7 @@ async def store_room(
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
- async def get_room(self, room_id: str) -> dict:
+ async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve a room.
Args:
@@ -136,7 +158,9 @@ async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
A dict containing the room information, or None if the room is unknown.
"""
- def get_room_with_stats_txn(txn, room_id):
+ def get_room_with_stats_txn(
+ txn: LoggingTransaction, room_id: str
+ ) -> Optional[Dict[str, Any]]:
sql = """
SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -185,7 +209,7 @@ async def count_public_rooms(
ignore_non_federatable: If true filters out non-federatable rooms
"""
- def _count_public_rooms_txn(txn):
+ def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
query_args = []
if network_tuple:
@@ -195,6 +219,7 @@ def _count_public_rooms_txn(txn):
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
+ assert network_tuple.network_id is not None
query_args.append(network_tuple.network_id)
else:
published_sql = """
@@ -208,7 +233,7 @@ def _count_public_rooms_txn(txn):
sql = """
SELECT
- COALESCE(COUNT(*), 0)
+ COUNT(*)
FROM (
%(published_sql)s
) published
@@ -226,7 +251,7 @@ def _count_public_rooms_txn(txn):
}
txn.execute(sql, query_args)
- return txn.fetchone()[0]
+ return cast(Tuple[int], txn.fetchone())[0]
return await self.db_pool.runInteraction(
"count_public_rooms", _count_public_rooms_txn
@@ -235,11 +260,11 @@ def _count_public_rooms_txn(txn):
async def get_room_count(self) -> int:
"""Retrieve the total number of rooms."""
- def f(txn):
+ def f(txn: LoggingTransaction) -> int:
sql = "SELECT count(*) FROM rooms"
txn.execute(sql)
- row = txn.fetchone()
- return row[0] or 0
+ row = cast(Tuple[int], txn.fetchone())
+ return row[0]
return await self.db_pool.runInteraction("get_rooms", f)
@@ -251,7 +276,7 @@ async def get_largest_public_rooms(
bounds: Optional[Tuple[int, str]],
forwards: bool,
ignore_non_federatable: bool = False,
- ):
+ ) -> List[Dict[str, Any]]:
"""Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table).
@@ -272,7 +297,7 @@ async def get_largest_public_rooms(
"""
where_clauses = []
- query_args = []
+ query_args: List[Union[str, int]] = []
if network_tuple:
if network_tuple.appservice_id:
@@ -281,6 +306,7 @@ async def get_largest_public_rooms(
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
+ assert network_tuple.network_id is not None
query_args.append(network_tuple.network_id)
else:
published_sql = """
@@ -372,7 +398,9 @@ async def get_largest_public_rooms(
LIMIT ?
"""
- def _get_largest_public_rooms_txn(txn):
+ def _get_largest_public_rooms_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, Any]]:
txn.execute(sql, query_args)
results = self.db_pool.cursor_to_dict(txn)
@@ -435,7 +463,7 @@ async def get_rooms_paginate(
"""
# Filter room names by a string
where_statement = ""
- search_pattern = []
+ search_pattern: List[object] = []
if search_term:
where_statement = """
WHERE LOWER(state.name) LIKE ?
@@ -523,27 +551,29 @@ async def get_rooms_paginate(
FROM room_stats_state state
INNER JOIN room_stats_current curr USING (room_id)
INNER JOIN rooms USING (room_id)
- %s
- ORDER BY %s %s
+ {where}
+ ORDER BY {order_by} {direction}, state.room_id {direction}
LIMIT ?
OFFSET ?
- """ % (
- where_statement,
- order_by_column,
- "ASC" if order_by_asc else "DESC",
+ """.format(
+ where=where_statement,
+ order_by=order_by_column,
+ direction="ASC" if order_by_asc else "DESC",
)
# Use a nested SELECT statement as SQL can't count(*) with an OFFSET
count_sql = """
SELECT count(*) FROM (
SELECT room_id FROM room_stats_state state
- %s
+ {where}
) AS get_room_ids
- """ % (
- where_statement,
+ """.format(
+ where=where_statement,
)
- def _get_rooms_paginate_txn(txn):
+ def _get_rooms_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], int]:
# Add the search term into the WHERE clause
# and execute the data query
txn.execute(info_sql, search_pattern + [limit, start])
@@ -575,7 +605,7 @@ def _get_rooms_paginate_txn(txn):
# Add the search term into the WHERE clause if present
txn.execute(count_sql, search_pattern)
- room_count = txn.fetchone()
+ room_count = cast(Tuple[int], txn.fetchone())
return rooms, room_count[0]
return await self.db_pool.runInteraction(
@@ -620,7 +650,7 @@ async def set_ratelimit_for_user(
burst_count: How many actions that can be performed before being limited.
"""
- def set_ratelimit_txn(txn):
+ def set_ratelimit_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_txn(
txn,
table="ratelimit_override",
@@ -643,7 +673,7 @@ async def delete_ratelimit_for_user(self, user_id: str) -> None:
user_id: user ID of the user
"""
- def delete_ratelimit_txn(txn):
+ def delete_ratelimit_txn(txn: LoggingTransaction) -> None:
row = self.db_pool.simple_select_one_txn(
txn,
table="ratelimit_override",
@@ -667,7 +697,7 @@ def delete_ratelimit_txn(txn):
await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
@cached()
- async def get_retention_policy_for_room(self, room_id):
+ async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
"""Get the retention policy for a given room.
If no retention policy has been found for this room, returns a policy defined
@@ -676,13 +706,15 @@ async def get_retention_policy_for_room(self, room_id):
configuration).
Args:
- room_id (str): The ID of the room to get the retention policy of.
+ room_id: The ID of the room to get the retention policy of.
Returns:
- dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
+ A dict containing "min_lifetime" and "max_lifetime" for this room.
"""
- def get_retention_policy_for_room_txn(txn):
+ def get_retention_policy_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, Optional[int]]]:
txn.execute(
"""
SELECT min_lifetime, max_lifetime FROM room_retention
@@ -707,19 +739,23 @@ def get_retention_policy_for_room_txn(txn):
"max_lifetime": self.config.retention.retention_default_max_lifetime,
}
- row = ret[0]
+ min_lifetime = ret[0]["min_lifetime"]
+ max_lifetime = ret[0]["max_lifetime"]
# If one of the room's policy's attributes isn't defined, use the matching
# attribute from the default policy.
# The default values will be None if no default policy has been defined, or if one
# of the attributes is missing from the default policy.
- if row["min_lifetime"] is None:
- row["min_lifetime"] = self.config.retention.retention_default_min_lifetime
+ if min_lifetime is None:
+ min_lifetime = self.config.retention.retention_default_min_lifetime
- if row["max_lifetime"] is None:
- row["max_lifetime"] = self.config.retention.retention_default_max_lifetime
+ if max_lifetime is None:
+ max_lifetime = self.config.retention.retention_default_max_lifetime
- return row
+ return {
+ "min_lifetime": min_lifetime,
+ "max_lifetime": max_lifetime,
+ }
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
@@ -731,7 +767,9 @@ async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[st
The local and remote media as a lists of the media IDs.
"""
- def _get_media_mxcs_in_room_txn(txn):
+ def _get_media_mxcs_in_room_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[str], List[str]]:
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
remote_media_mxcs = []
@@ -757,7 +795,7 @@ async def quarantine_media_ids_in_room(
logger.info("Quarantining media in room: %s", room_id)
- def _quarantine_media_in_room_txn(txn):
+ def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int:
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
return self._quarantine_media_txn(
txn, local_mxcs, remote_mxcs, quarantined_by
@@ -767,13 +805,11 @@ def _quarantine_media_in_room_txn(txn):
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
- def _get_media_mxcs_in_room_txn(self, txn, room_id):
+ def _get_media_mxcs_in_room_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> Tuple[List[str], List[Tuple[str, str]]]:
"""Retrieves all the local and remote media MXC URIs in a given room
- Args:
- txn (cursor)
- room_id (str)
-
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
@@ -841,7 +877,7 @@ async def quarantine_media_by_id(
logger.info("Quarantining media: %s/%s", server_name, media_id)
is_local = server_name == self.config.server.server_name
- def _quarantine_media_by_id_txn(txn):
+ def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
local_mxcs = [media_id] if is_local else []
remote_mxcs = [(server_name, media_id)] if not is_local else []
@@ -863,7 +899,7 @@ async def quarantine_media_ids_by_user(
quarantined_by: The ID of the user who made the quarantine request
"""
- def _quarantine_media_by_user_txn(txn):
+ def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int:
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
@@ -871,7 +907,9 @@ def _quarantine_media_by_user_txn(txn):
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
- def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+ def _get_media_ids_by_user_txn(
+ self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True
+ ) -> List[str]:
"""Retrieves local media IDs by a given user
Args:
@@ -900,7 +938,7 @@ def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True)
def _quarantine_media_txn(
self,
- txn,
+ txn: LoggingTransaction,
local_mxcs: List[str],
remote_mxcs: List[Tuple[str, str]],
quarantined_by: Optional[str],
@@ -928,12 +966,15 @@ def _quarantine_media_txn(
# set quarantine
if quarantined_by is not None:
sql += "AND safe_from_quarantine = ?"
- rows = [(quarantined_by, media_id, False) for media_id in local_mxcs]
+ txn.executemany(
+ sql, [(quarantined_by, media_id, False) for media_id in local_mxcs]
+ )
# remove from quarantine
else:
- rows = [(quarantined_by, media_id) for media_id in local_mxcs]
+ txn.executemany(
+ sql, [(quarantined_by, media_id) for media_id in local_mxcs]
+ )
- txn.executemany(sql, rows)
# Note that a rowcount of -1 can be used to indicate no rows were affected.
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
@@ -951,7 +992,7 @@ def _quarantine_media_txn(
async def get_rooms_for_retention_period_in_range(
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
- ) -> Dict[str, dict]:
+ ) -> Dict[str, Dict[str, Optional[int]]]:
"""Retrieves all of the rooms within the given retention range.
Optionally includes the rooms which don't have a retention policy.
@@ -971,7 +1012,9 @@ async def get_rooms_for_retention_period_in_range(
"min_lifetime" (int|None), and "max_lifetime" (int|None).
"""
- def get_rooms_for_retention_period_in_range_txn(txn):
+ def get_rooms_for_retention_period_in_range_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Dict[str, Optional[int]]]:
range_conditions = []
args = []
@@ -1050,11 +1093,14 @@ class _BackgroundUpdates:
class RoomBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
- self.config = hs.config
-
self.db_pool.updates.register_background_update_handler(
"insert_room_retention",
self._background_insert_retention,
@@ -1085,7 +1131,9 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._background_populate_rooms_creator_column,
)
- async def _background_insert_retention(self, progress, batch_size):
+ async def _background_insert_retention(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Retrieves a list of all rooms within a range and inserts an entry for each of
them into the room_retention table.
NULLs the property's columns if missing from the retention event in the room's
@@ -1095,7 +1143,7 @@ async def _background_insert_retention(self, progress, batch_size):
last_room = progress.get("room_id", "")
- def _background_insert_retention_txn(txn):
+ def _background_insert_retention_txn(txn: LoggingTransaction) -> bool:
txn.execute(
"""
SELECT state.room_id, state.event_id, events.json
@@ -1154,15 +1202,17 @@ def _background_insert_retention_txn(txn):
return batch_size
async def _background_add_rooms_room_version_column(
- self, progress: dict, batch_size: int
- ):
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Background update to go and add room version information to `rooms`
table from `current_state_events` table.
"""
last_room_id = progress.get("room_id", "")
- def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
+ def _background_add_rooms_room_version_column_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
sql = """
SELECT room_id, json FROM current_state_events
INNER JOIN event_json USING (room_id, event_id)
@@ -1223,7 +1273,7 @@ def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
return batch_size
async def _remove_tombstoned_rooms_from_directory(
- self, progress, batch_size
+ self, progress: JsonDict, batch_size: int
) -> int:
"""Removes any rooms with tombstone events from the room directory
@@ -1233,7 +1283,7 @@ async def _remove_tombstoned_rooms_from_directory(
last_room = progress.get("room_id", "")
- def _get_rooms(txn):
+ def _get_rooms(txn: LoggingTransaction) -> List[str]:
txn.execute(
"""
SELECT room_id
@@ -1271,7 +1321,7 @@ def _get_rooms(txn):
return len(rooms)
@abstractmethod
- def set_room_is_public(self, room_id, is_public):
+ def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]:
# this will need to be implemented if a background update is performed with
# existing (tombstoned, public) rooms in the database.
#
@@ -1318,7 +1368,7 @@ async def _background_populate_room_depth_min_depth2(
32-bit integer field.
"""
- def process(txn: Cursor) -> int:
+ def process(txn: LoggingTransaction) -> int:
last_room = progress.get("last_room", "")
txn.execute(
"""
@@ -1375,15 +1425,17 @@ def process(txn: Cursor) -> None:
return 0
async def _background_populate_rooms_creator_column(
- self, progress: dict, batch_size: int
- ):
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Background update to go and add creator information to `rooms`
table from `current_state_events` table.
"""
last_room_id = progress.get("room_id", "")
- def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction):
+ def _background_populate_rooms_creator_column_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
sql = """
SELECT room_id, json FROM event_json
INNER JOIN rooms AS room USING (room_id)
@@ -1434,15 +1486,20 @@ def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction):
return batch_size
-class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
- self.config = hs.config
+ self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
async def upsert_room_on_join(
self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
- ):
+ ) -> None:
"""Ensure that the room is stored in the table
Called when we join a room over federation, and overwrites any room version
@@ -1488,7 +1545,7 @@ async def upsert_room_on_join(
async def maybe_store_room_on_outlier_membership(
self, room_id: str, room_version: RoomVersion
- ):
+ ) -> None:
"""
When we receive an invite or any other event over federation that may relate to a room
we are not in, store the version of the room if we don't already know the room version.
@@ -1528,8 +1585,8 @@ async def set_room_is_public(self, room_id: str, is_public: bool) -> None:
self.hs.get_notifier().on_new_replication_data()
async def set_room_is_public_appservice(
- self, room_id, appservice_id, network_id, is_public
- ):
+ self, room_id: str, appservice_id: str, network_id: str, is_public: bool
+ ) -> None:
"""Edit the appservice/network specific public room list.
Each appservice can have a number of published room lists associated
@@ -1538,11 +1595,10 @@ async def set_room_is_public_appservice(
network.
Args:
- room_id (str)
- appservice_id (str)
- network_id (str)
- is_public (bool): Whether to publish or unpublish the room from the
- list.
+ room_id
+ appservice_id
+ network_id
+ is_public: Whether to publish or unpublish the room from the list.
"""
if is_public:
@@ -1607,7 +1663,9 @@ async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
event_report: json list of information from event report
"""
- def _get_event_report_txn(txn, report_id):
+ def _get_event_report_txn(
+ txn: LoggingTransaction, report_id: int
+ ) -> Optional[Dict[str, Any]]:
sql = """
SELECT
@@ -1679,9 +1737,11 @@ async def get_event_reports_paginate(
count: total number of event reports matching the filter criteria
"""
- def _get_event_reports_paginate_txn(txn):
+ def _get_event_reports_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], int]:
filters = []
- args = []
+ args: List[object] = []
if user_id:
filters.append("er.user_id LIKE ?")
@@ -1705,7 +1765,7 @@ def _get_event_reports_paginate_txn(txn):
where_clause
)
txn.execute(sql, args)
- count = txn.fetchone()[0]
+ count = cast(Tuple[int], txn.fetchone())[0]
sql = """
SELECT
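
Several count queries above drop `COALESCE(COUNT(*), 0)` (COUNT(*) can never be NULL) and replace `txn.fetchone()[0]  # type: ignore[index]` with `cast(Tuple[int], txn.fetchone())[0]`. A standalone sqlite3 sketch of why the cast is sound for an aggregate query; the table here is made up:

```python
import sqlite3
from typing import Tuple, cast

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE rooms (room_id TEXT)")  # demo table only
conn.executemany("INSERT INTO rooms VALUES (?)", [("!a:x",), ("!b:x",)])

cur = conn.execute("SELECT COUNT(*) FROM rooms")
# An aggregate without GROUP BY always yields exactly one row, so fetchone()
# cannot return None here; cast() states that reasoning for mypy instead of
# sprinkling "# type: ignore[index]".
row = cast(Tuple[int], cur.fetchone())
print(row[0])  # 2
```
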
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 6b2a8d06a67c..4489732fdac4 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -37,7 +37,7 @@
wrap_as_background_process,
)
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
@@ -64,7 +64,12 @@
class RoomMemberWorkerStore(EventsWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@@ -985,7 +990,12 @@ def _is_local_host_in_room_ignoring_users_txn(txn):
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@@ -1135,7 +1145,12 @@ def _background_current_state_membership_txn(txn, last_processed_room):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
async def forget(self, user_id: str, room_id: str) -> None:
@@ -1162,18 +1177,18 @@ def f(txn):
await self.db_pool.runInteraction("forget_membership", f)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _JoinedHostsCache:
"""The cached data used by the `_get_joined_hosts_cache`."""
# Dict of host to the set of their users in the room at the state group.
- hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict)
+ hosts_to_joined_users: Dict[str, Set[str]] = attr.Factory(dict)
# The state group `hosts_to_joined_users` is derived from. Will be an object
# if the instance is newly created or if the state is not based on a state
# group. (An object is used as a sentinel value to ensure that it never is
# equal to anything else).
- state_group = attr.ib(type=Union[object, int], factory=object)
+ state_group: Union[object, int] = attr.Factory(object)
def __len__(self):
return sum(len(v) for v in self.hosts_to_joined_users.values())
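
`_JoinedHostsCache` above switches to `auto_attribs` with `attr.Factory` defaults. A minimal sketch of why factories are used for the mutable dict and the sentinel object; the field names mirror the diff, the rest is illustrative:

```python
from typing import Dict, Set

import attr


@attr.s(slots=True, auto_attribs=True)
class JoinedHostsCache:
    # attr.Factory(dict) builds a fresh dict per instance, avoiding the
    # shared-mutable-default pitfall.
    hosts_to_joined_users: Dict[str, Set[str]] = attr.Factory(dict)
    # attr.Factory(object) gives every instance its own sentinel, so a cache
    # never spuriously compares equal to another one's state group.
    state_group: object = attr.Factory(object)


a, b = JoinedHostsCache(), JoinedHostsCache()
a.hosts_to_joined_users["example.com"] = {"@alice:example.com"}
print(b.hosts_to_joined_users)          # {} - not shared with `a`
print(a.state_group is b.state_group)   # False
```
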
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 7fe233767f76..2d085a5764a7 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,13 +14,18 @@
import logging
import re
-from collections import namedtuple
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
+import attr
+
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -29,10 +34,15 @@
logger = logging.getLogger(__name__)
-SearchEntry = namedtuple(
- "SearchEntry",
- ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"],
-)
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SearchEntry:
+ key: str
+ value: str
+ event_id: str
+ room_id: str
+ stream_ordering: Optional[int]
+ origin_server_ts: int
def _clean_value_for_search(value: str) -> str:
@@ -105,7 +115,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if not hs.config.server.enable_search:
@@ -358,7 +373,12 @@ def reindex_search_txn(txn):
class SearchStore(SearchBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
async def search_msgs(self, room_ids, search_term, keys):
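
`SearchEntry`, like `RatelimitOverride` earlier, is converted from a `namedtuple` into a frozen, slotted attrs class so each field carries a checked type. A small before/after sketch with illustrative values:

```python
from collections import namedtuple
from typing import Optional

import attr

# Before: positional fields, no per-field types for mypy to check.
OldSearchEntry = namedtuple("OldSearchEntry", ["key", "value", "event_id"])
print(OldSearchEntry("content.body", "hello", "$e"))


# After: still immutable (frozen) and compact (slots), but every construction
# site is type-checked.
@attr.s(slots=True, frozen=True, auto_attribs=True)
class SearchEntry:
    key: str
    value: str
    event_id: str
    stream_ordering: Optional[int]


entry = SearchEntry(
    key="content.body", value="hello", event_id="$e", stream_ordering=None
)
print(entry)
# entry.value = "x" would raise attr.exceptions.FrozenInstanceError
```
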
diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py
index 5a97120437d0..e8c776b97a9b 100644
--- a/synapse/storage/databases/main/session.py
+++ b/synapse/storage/databases/main/session.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index fa2c3b1feb91..2fb3e65192f5 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,6 @@
# limitations under the License.
import collections.abc
import logging
-from collections import namedtuple
from typing import TYPE_CHECKING, Iterable, Optional, Set
from synapse.api.constants import EventTypes, Membership
@@ -22,7 +21,11 @@
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
@@ -39,24 +42,16 @@
MAX_STATE_DELTA_HOPS = 100
-class _GetStateGroupDelta(
- namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
- """Return type of get_state_group_delta that implements __len__, which lets
- us use the itrable flag when caching
- """
-
- __slots__ = []
-
- def __len__(self):
- return len(self.delta_ids) if self.delta_ids else 0
-
-
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers."""
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -182,11 +177,15 @@ async def get_create_event_for_room(self, room_id: str) -> EventBase:
NotFoundError if the room is unknown
"""
state_ids = await self.get_current_state_ids(room_id)
+
+ if not state_ids:
+ raise NotFoundError(f"Current state for room {room_id} is empty")
+
create_id = state_ids.get((EventTypes.Create, ""))
# If we can't find the create event, assume we've hit a dead end
if not create_id:
- raise NotFoundError("Unknown room %s" % (room_id,))
+ raise NotFoundError(f"No create event in current state for room {room_id}")
# Retrieve the room's create event and return
create_event = await self.get_event(create_id)
@@ -349,7 +348,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -536,5 +540,10 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events.
"""
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 7f3624b12872..188afec332dd 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -56,7 +56,9 @@ async def get_current_state_deltas(
prev_stream_id = int(prev_stream_id)
# check we're not going backwards
- assert prev_stream_id <= max_stream_id
+ assert (
+ prev_stream_id <= max_stream_id
+ ), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}"
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
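
The stats hunks below, like many above, mostly annotate inner `def f(txn)` callbacks with `LoggingTransaction` and a return type, so the value flowing out of `runInteraction` is typed rather than `Any`. A self-contained sketch of that pattern using sqlite3 as a stand-in; `run_interaction` below is illustrative, not Synapse's `DatabasePool`:

```python
import sqlite3
from typing import Callable, TypeVar

R = TypeVar("R")


def run_interaction(
    conn: sqlite3.Connection, desc: str, func: Callable[[sqlite3.Cursor], R]
) -> R:
    """Illustrative stand-in for DatabasePool.runInteraction: run `func` inside
    a transaction and hand back whatever it returns, preserving its type R.
    `desc` mirrors the human-readable name runInteraction takes; unused here."""
    with conn:  # commits on success, rolls back on error
        return func(conn.cursor())


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (name TEXT)")
conn.execute("INSERT INTO users VALUES ('alice')")


def _count_users_txn(txn: sqlite3.Cursor) -> int:
    # Because the callback is annotated, callers of run_interaction get `int`
    # from the type checker rather than Any.
    txn.execute("SELECT COUNT(*) FROM users")
    return txn.fetchone()[0]


print(run_interaction(conn, "count_users", _count_users_txn))  # 1
```
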
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5d7b59d861c9..427ae1f649b1 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
import logging
from enum import Enum
from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from typing_extensions import Counter
@@ -24,7 +24,11 @@
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import StoreError
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -96,7 +100,12 @@ class UserSortOrder(Enum):
class StatsStore(StateDeltasStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -117,7 +126,9 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
- async def _populate_stats_process_users(self, progress, batch_size):
+ async def _populate_stats_process_users(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""
This is a background update which regenerates statistics for users.
"""
@@ -129,7 +140,7 @@ async def _populate_stats_process_users(self, progress, batch_size):
last_user_id = progress.get("last_user_id", "")
- def _get_next_batch(txn):
+ def _get_next_batch(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT name FROM users
WHERE name > ?
@@ -163,7 +174,9 @@ def _get_next_batch(txn):
return len(users_to_work_on)
- async def _populate_stats_process_rooms(self, progress, batch_size):
+ async def _populate_stats_process_rooms(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""This is a background update which regenerates statistics for rooms."""
if not self.stats_enabled:
await self.db_pool.updates._end_background_update(
@@ -173,7 +186,7 @@ async def _populate_stats_process_rooms(self, progress, batch_size):
last_room_id = progress.get("last_room_id", "")
- def _get_next_batch(txn):
+ def _get_next_batch(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT room_id FROM current_state_events
WHERE room_id > ?
@@ -302,7 +315,7 @@ async def bulk_update_stats_delta(
stream_id: Current position.
"""
- def _bulk_update_stats_delta_txn(txn):
+ def _bulk_update_stats_delta_txn(txn: LoggingTransaction) -> None:
for stats_type, stats_updates in updates.items():
for stats_id, fields in stats_updates.items():
logger.debug(
@@ -334,7 +347,7 @@ async def update_stats_delta(
stats_type: str,
stats_id: str,
fields: Dict[str, int],
- complete_with_stream_id: Optional[int],
+ complete_with_stream_id: int,
absolute_field_overrides: Optional[Dict[str, int]] = None,
) -> None:
"""
@@ -367,14 +380,14 @@ async def update_stats_delta(
def _update_stats_delta_txn(
self,
- txn,
- ts,
- stats_type,
- stats_id,
- fields,
- complete_with_stream_id,
- absolute_field_overrides=None,
- ):
+ txn: LoggingTransaction,
+ ts: int,
+ stats_type: str,
+ stats_id: str,
+ fields: Dict[str, int],
+ complete_with_stream_id: int,
+ absolute_field_overrides: Optional[Dict[str, int]] = None,
+ ) -> None:
if absolute_field_overrides is None:
absolute_field_overrides = {}
@@ -417,20 +430,23 @@ def _update_stats_delta_txn(
)
def _upsert_with_additive_relatives_txn(
- self, txn, table, keyvalues, absolutes, additive_relatives
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ absolutes: Dict[str, Any],
+ additive_relatives: Dict[str, int],
+ ) -> None:
"""Used to update values in the stats tables.
This is basically a slightly convoluted upsert that *adds* to any
existing rows.
Args:
- txn
- table (str): Table name
- keyvalues (dict[str, any]): Row-identifying key values
- absolutes (dict[str, any]): Absolute (set) fields
- additive_relatives (dict[str, int]): Fields that will be added onto
- if existing row present.
+ table: Table name
+ keyvalues: Row-identifying key values
+ absolutes: Absolute (set) fields
+ additive_relatives: Fields that will be added onto if an existing row is present.
"""
if self.database_engine.can_native_upsert:
absolute_updates = [
@@ -486,20 +502,17 @@ def _upsert_with_additive_relatives_txn(
current_row.update(absolutes)
self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)
- async def _calculate_and_set_initial_state_for_room(
- self, room_id: str
- ) -> Tuple[dict, dict, int]:
+ async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None:
"""Calculate and insert an entry into room_stats_current.
Args:
room_id: The room ID under calculation.
-
- Returns:
- A tuple of room state, membership counts and stream position.
"""
- def _fetch_current_state_stats(txn):
- pos = self.get_room_max_stream_ordering()
+ def _fetch_current_state_stats(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
+ pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]
rows = self.db_pool.simple_select_many_txn(
txn,
@@ -519,7 +532,7 @@ def _fetch_current_state_stats(txn):
retcols=["event_id"],
)
- event_ids = [row["event_id"] for row in rows]
+ event_ids = cast(List[str], [row["event_id"] for row in rows])
txn.execute(
"""
@@ -533,15 +546,15 @@ def _fetch_current_state_stats(txn):
txn.execute(
"""
- SELECT COALESCE(count(*), 0) FROM current_state_events
+ SELECT COUNT(*) FROM current_state_events
WHERE room_id = ?
""",
(room_id,),
)
- (current_state_events_count,) = txn.fetchone()
+ current_state_events_count = cast(Tuple[int], txn.fetchone())[0]
- users_in_room = self.get_users_in_room_txn(txn, room_id)
+ users_in_room = self.get_users_in_room_txn(txn, room_id) # type: ignore[attr-defined]
return (
event_ids,
@@ -561,7 +574,7 @@ def _fetch_current_state_stats(txn):
"get_initial_state_for_room", _fetch_current_state_stats
)
- state_event_map = await self.get_events(event_ids, get_prev_content=False)
+ state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined]
room_state = {
"join_rules": None,
@@ -617,8 +630,10 @@ def _fetch_current_state_stats(txn):
},
)
- async def _calculate_and_set_initial_state_for_user(self, user_id):
- def _calculate_and_set_initial_state_for_user_txn(txn):
+ async def _calculate_and_set_initial_state_for_user(self, user_id: str) -> None:
+ def _calculate_and_set_initial_state_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, int]:
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
txn.execute(
@@ -629,7 +644,7 @@ def _calculate_and_set_initial_state_for_user_txn(txn):
""",
(user_id,),
)
- (count,) = txn.fetchone()
+ count = cast(Tuple[int], txn.fetchone())[0]
return count, pos
joined_rooms, pos = await self.db_pool.runInteraction(
@@ -673,7 +688,9 @@ async def get_users_media_usage_paginate(
users that exist given this query
"""
- def get_users_media_usage_paginate_txn(txn):
+ def get_users_media_usage_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
filters = []
args = [self.hs.config.server.server_name]
@@ -728,7 +745,7 @@ def get_users_media_usage_paginate_txn(txn):
sql_base=sql_base,
)
txn.execute(sql, args)
- count = txn.fetchone()[0]
+ count = cast(Tuple[int], txn.fetchone())[0]
sql = """
SELECT
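The `_upsert_with_additive_relatives_txn` docstring above describes a "slightly convoluted upsert" that *adds* to any existing rows. A minimal, standalone sketch of that idea follows; it uses an illustrative table rather than Synapse's actual helper or schema, and assumes SQLite 3.24+ for `ON CONFLICT ... DO UPDATE`.

```python
import sqlite3

# Standalone sketch of an "additive upsert": absolute columns are overwritten,
# additive columns are incremented on top of any existing row.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE room_stats (room_id TEXT PRIMARY KEY, topic TEXT, joined_members INTEGER)"
)

def additive_upsert(room_id: str, topic: str, joined_delta: int) -> None:
    conn.execute(
        """
        INSERT INTO room_stats (room_id, topic, joined_members)
        VALUES (?, ?, ?)
        ON CONFLICT (room_id) DO UPDATE SET
            topic = excluded.topic,  -- absolute (set) field
            joined_members = room_stats.joined_members + excluded.joined_members  -- additive field
        """,
        (room_id, topic, joined_delta),
    )

additive_upsert("!room:example.org", "Welcome", 2)
additive_upsert("!room:example.org", "Welcome", 3)
print(conn.execute("SELECT joined_members FROM room_stats").fetchone())  # (5,)
```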
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 57aab5525937..319464b1fa83 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -34,11 +34,11 @@
- topological tokens: "t%d-%d", where the integers map to the topological
and stream ordering columns respectively.
"""
-import abc
+
import logging
-from collections import namedtuple
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
+import attr
from frozendict import frozendict
from twisted.internet import defer
@@ -49,6 +49,7 @@
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
+ LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
@@ -73,9 +74,11 @@
# Used as return values for pagination APIs
-_EventDictReturn = namedtuple(
- "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventDictReturn:
+ event_id: str
+ topological_ordering: Optional[int]
+ stream_ordering: int
def generate_pagination_where_clause(
@@ -333,13 +336,13 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
return " AND ".join(clauses), args
-class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
- """This is an abstract base class where subclasses must implement
- `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
- which can be called in the initializer.
- """
-
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@@ -371,13 +374,22 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._stream_order_on_start = self.get_room_max_stream_ordering()
- @abc.abstractmethod
def get_room_max_stream_ordering(self) -> int:
- raise NotImplementedError()
+ """Get the stream_ordering of regular events that we have committed up to
+
+ Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
+ """
+ return self._stream_id_gen.get_current_token()
- @abc.abstractmethod
def get_room_min_stream_ordering(self) -> int:
- raise NotImplementedError()
+ """Get the stream_ordering of backfilled events that we have committed up to
+
+ Backfilled events use *negative* stream orderings, so this returns the
+ minimum negative stream id such that all stream ids greater than or
+ equal to it have been successfully persisted.
+ """
+ return self._backfill_id_gen.get_current_token()
def get_room_max_token(self) -> RoomStreamToken:
"""Get a `RoomStreamToken` that marks the current maximum persisted
@@ -819,7 +831,7 @@ def _set_before_and_after(
for event, row in zip(events, rows):
stream = row.stream_ordering
if topo_order and row.topological_ordering:
- topo = row.topological_ordering
+ topo: Optional[int] = row.topological_ordering
else:
topo = None
internal = event.internal_metadata
@@ -1343,11 +1355,3 @@ async def get_name_from_instance_id(self, instance_id: int) -> str:
retcol="instance_name",
desc="get_name_from_instance_id",
)
-
-
-class StreamStore(StreamWorkerStore):
- def get_room_max_stream_ordering(self) -> int:
- return self._stream_id_gen.get_current_token()
-
- def get_room_min_stream_ordering(self) -> int:
- return self._backfill_id_gen.get_current_token()
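The `_EventDictReturn` change above is one instance of a pattern repeated throughout this diff: `namedtuple`s and `attr.ib(type=...)` declarations are replaced by typed `attrs` classes with `auto_attribs=True`. A small before/after sketch with a hypothetical row class, not Synapse's own:

```python
from collections import namedtuple
from typing import Optional

import attr

# Old style: an untyped namedtuple.
_EventRowOld = namedtuple("_EventRowOld", ("event_id", "topological_ordering", "stream_ordering"))

# New style used throughout this change set: a frozen, slotted attrs class
# whose typed attributes are picked up automatically via auto_attribs=True.
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventRowNew:
    event_id: str
    topological_ordering: Optional[int]
    stream_ordering: int

row = _EventRowNew("$event:example.org", None, 42)
print(row.stream_ordering)  # 42
```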
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 8f510de53d43..c8e508a910fb 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,11 +15,13 @@
# limitations under the License.
import logging
-from typing import Dict, List, Tuple, cast
+from typing import Any, Dict, Iterable, List, Tuple, cast
+from synapse.replication.tcp.streams import TagAccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -204,6 +206,7 @@ async def add_tag_to_room(
The next account data ID.
"""
assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
content_json = json_encoder.encode(content)
@@ -230,6 +233,7 @@ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> in
The next account data ID.
"""
assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
sql = (
@@ -258,6 +262,7 @@ def _update_revision_txn(
next_id: The revision to advance to.
"""
assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
txn.call_after(
self._account_data_stream_cache.entity_has_changed, user_id, next_id
@@ -287,6 +292,21 @@ def _update_revision_txn(
# than the id that the client has.
pass
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
+ if stream_name == TagAccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ for row in rows:
+ self.get_tags_for_user.invalidate((row.user_id,))
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
+
class TagsStore(TagsWorkerStore):
pass
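The `assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)` lines added above narrow the declared type so that mypy accepts the write-side calls that follow. A generic sketch of that narrowing pattern, using hypothetical class names rather than Synapse's real id generators:

```python
import abc

class AbstractIdTracker(abc.ABC):
    @abc.abstractmethod
    def get_current_token(self) -> int: ...

class AbstractIdGenerator(AbstractIdTracker):
    @abc.abstractmethod
    def advance(self, instance_name: str, new_id: int) -> None: ...

class Writer:
    def __init__(self, id_gen: AbstractIdTracker, can_write: bool) -> None:
        self._id_gen = id_gen
        self._can_write = can_write

    def record(self, instance_name: str, new_id: int) -> None:
        assert self._can_write
        # The attribute is declared with the read-only tracker type; the assert
        # narrows it so mypy accepts the call to the writer-only method.
        assert isinstance(self._id_gen, AbstractIdGenerator)
        self._id_gen.advance(instance_name, new_id)
```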
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 162282255232..4b78b4d098a9 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -13,16 +13,19 @@
# limitations under the License.
import logging
-from collections import namedtuple
from enum import Enum
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
import attr
from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -35,16 +38,6 @@
logger = logging.getLogger(__name__)
-_TransactionRow = namedtuple(
- "_TransactionRow",
- ("id", "transaction_id", "destination", "ts", "response_code", "response_json"),
-)
-
-_UpdateTransactionRow = namedtuple(
- "_TransactionRow", ("response_code", "response_json")
-)
-
-
class DestinationSortOrder(Enum):
"""Enum to define the sorting method used when returning destinations."""
@@ -71,7 +64,12 @@ class DestinationRetryTimings:
class TransactionWorkerStore(CacheInvalidationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@@ -82,7 +80,7 @@ async def _cleanup_transactions(self) -> None:
now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000
- def _cleanup_transactions_txn(txn):
+ def _cleanup_transactions_txn(txn: LoggingTransaction) -> None:
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
await self.db_pool.runInteraction(
@@ -112,7 +110,9 @@ async def get_received_txn_response(
origin,
)
- def _get_received_txn_response(self, txn, transaction_id, origin):
+ def _get_received_txn_response(
+ self, txn: LoggingTransaction, transaction_id: str, origin: str
+ ) -> Optional[Tuple[int, JsonDict]]:
result = self.db_pool.simple_select_one_txn(
txn,
table="received_transactions",
@@ -187,7 +187,7 @@ async def get_destination_retry_timings(
return result
def _get_destination_retry_timings(
- self, txn, destination: str
+ self, txn: LoggingTransaction, destination: str
) -> Optional[DestinationRetryTimings]:
result = self.db_pool.simple_select_one_txn(
txn,
@@ -222,7 +222,7 @@ async def set_destination_retry_timings(
"""
if self.database_engine.can_native_upsert:
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings_native,
destination,
@@ -232,7 +232,7 @@ async def set_destination_retry_timings(
db_autocommit=True, # Safe as its a single upsert
)
else:
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings_emulated,
destination,
@@ -242,8 +242,13 @@ async def set_destination_retry_timings(
)
def _set_destination_retry_timings_native(
- self, txn, destination, failure_ts, retry_last_ts, retry_interval
- ):
+ self,
+ txn: LoggingTransaction,
+ destination: str,
+ failure_ts: Optional[int],
+ retry_last_ts: int,
+ retry_interval: int,
+ ) -> None:
assert self.database_engine.can_native_upsert
# Upsert retry time interval if retry_interval is zero (i.e. we're
@@ -273,8 +278,13 @@ def _set_destination_retry_timings_native(
)
def _set_destination_retry_timings_emulated(
- self, txn, destination, failure_ts, retry_last_ts, retry_interval
- ):
+ self,
+ txn: LoggingTransaction,
+ destination: str,
+ failure_ts: Optional[int],
+ retry_last_ts: int,
+ retry_interval: int,
+ ) -> None:
self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us
@@ -384,7 +394,7 @@ async def set_destination_last_successful_stream_ordering(
last_successful_stream_ordering: the stream_ordering of the most
recent successfully-sent PDU
"""
- return await self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
"destinations",
keyvalues={"destination": destination},
values={"last_successful_stream_ordering": last_successful_stream_ordering},
@@ -525,7 +535,7 @@ def get_destinations_paginate_txn(
else:
order = "ASC"
- args = []
+ args: List[object] = []
where_statement = ""
if destination:
args.extend(["%" + destination.lower() + "%"])
@@ -534,7 +544,7 @@ def get_destinations_paginate_txn(
sql_base = f"FROM destinations {where_statement} "
sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
txn.execute(sql, args)
- count = txn.fetchone()[0]
+ count = cast(Tuple[int], txn.fetchone())[0]
sql = f"""
SELECT destination, retry_last_ts, retry_interval, failure_ts,
@@ -550,3 +560,14 @@ def get_destinations_paginate_txn(
return await self.db_pool.runInteraction(
"get_destinations_paginate_txn", get_destinations_paginate_txn
)
+
+ async def is_destination_known(self, destination: str) -> bool:
+ """Check if a destination is known to the server."""
+ result = await self.db_pool.simple_select_one_onecol(
+ table="destinations",
+ keyvalues={"destination": destination},
+ retcol="1",
+ allow_none=True,
+ desc="is_destination_known",
+ )
+ return bool(result)
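`is_destination_known` above is an existence check: select a constant column, allow `None`, and coerce the result to `bool`. A standalone sqlite3 sketch of the same idiom, with an illustrative table rather than Synapse's database layer:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE destinations (destination TEXT PRIMARY KEY)")
conn.execute("INSERT INTO destinations VALUES ('matrix.org')")

def is_destination_known(destination: str) -> bool:
    # Select a constant column and coerce "row or None" to a bool,
    # mirroring the retcol="1" / allow_none=True idiom above.
    row = conn.execute(
        "SELECT 1 FROM destinations WHERE destination = ?", (destination,)
    ).fetchone()
    return bool(row)

print(is_destination_known("matrix.org"))   # True
print(is_destination_known("example.com"))  # False
```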
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 340ca9e47d47..2d339b60083b 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
import attr
@@ -23,19 +23,19 @@
from synapse.util import json_encoder, stringutils
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class UIAuthSessionData:
- session_id = attr.ib(type=str)
+ session_id: str
# The dictionary from the client root level, not the 'auth' key.
- clientdict = attr.ib(type=JsonDict)
+ clientdict: JsonDict
# The URI and method the session was initiated with. These are checked at
# each stage of the authentication to ensure that the asked for operation
# has not changed.
- uri = attr.ib(type=str)
- method = attr.ib(type=str)
+ uri: str
+ method: str
# A string description of the operation that the current authentication is
# authorising.
- description = attr.ib(type=str)
+ description: str
class UIAuthWorkerStore(SQLBaseStore):
@@ -225,11 +225,14 @@ def _set_ui_auth_session_data_txn(
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
):
# Get the current value.
- result: Dict[str, Any] = self.db_pool.simple_select_one_txn( # type: ignore
- txn,
- table="ui_auth_sessions",
- keyvalues={"session_id": session_id},
- retcols=("serverdict",),
+ result = cast(
+ Dict[str, Any],
+ self.db_pool.simple_select_one_txn(
+ txn,
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ retcols=("serverdict",),
+ ),
)
# Update it and add it back to the database.
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e98a45b6af60..f7c778bdf22b 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -32,11 +32,14 @@
from synapse.server import HomeServer
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.storage.types import Connection
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
@@ -53,7 +56,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
def __init__(
self,
database: DatabasePool,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
@@ -102,8 +105,10 @@ def _make_staging_area(txn: LoggingTransaction) -> None:
GROUP BY room_id
"""
txn.execute(sql)
- rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
- self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+ rooms = list(txn.fetchall())
+ self.db_pool.simple_insert_many_txn(
+ txn, TEMP_TABLE + "_rooms", keys=("room_id", "events"), values=rooms
+ )
del rooms
sql = (
@@ -114,9 +119,11 @@ def _make_staging_area(txn: LoggingTransaction) -> None:
txn.execute(sql)
txn.execute("SELECT name FROM users")
- users = [{"user_id": x[0]} for x in txn.fetchall()]
+ users = list(txn.fetchall())
- self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+ self.db_pool.simple_insert_many_txn(
+ txn, TEMP_TABLE + "_users", keys=("user_id",), values=users
+ )
new_pos = await self.get_max_stream_id_in_current_state_deltas()
await self.db_pool.runInteraction(
@@ -592,7 +599,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__(
self,
database: DatabasePool,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
) -> None:
super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index eb1118d2cb20..5de70f31d294 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -327,14 +327,15 @@ def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
+ keys=(
+ "state_group",
+ "room_id",
+ "type",
+ "state_key",
+ "event_id",
+ ),
values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
+ (state_group, room_id, key[0], key[1], state_id)
for key, state_id in delta_state.items()
],
)
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index c4c8c0021bca..7614d76ac646 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -460,14 +460,9 @@ def _store_state_group_txn(txn: LoggingTransaction) -> int:
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
+ (state_group, room_id, key[0], key[1], state_id)
for key, state_id in delta_ids.items()
],
)
@@ -475,14 +470,9 @@ def _store_state_group_txn(txn: LoggingTransaction) -> int:
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
+ (state_group, room_id, key[0], key[1], state_id)
for key, state_id in current_state_ids.items()
],
)
@@ -589,14 +579,9 @@ def _purge_unreferenced_state_groups(
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
values=[
- {
- "state_group": sg,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
+ (sg, room_id, key[0], key[1], state_id)
for key, state_id in curr_state.items()
],
)
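The `simple_insert_many_txn` call sites above switch from one dict per row to a shared `keys` tuple plus plain value tuples. A standalone sqlite3 sketch of the equivalent bulk insert, with hypothetical data; Synapse's helper builds the SQL for you:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE state_groups_state "
    "(state_group INTEGER, room_id TEXT, type TEXT, state_key TEXT, event_id TEXT)"
)

keys = ("state_group", "room_id", "type", "state_key", "event_id")
delta_state = {
    ("m.room.member", "@alice:example.org"): "$a",
    ("m.room.name", ""): "$b",
}

# One shared column list plus plain value tuples, instead of a dict per row.
values = [
    (1, "!room:example.org", key[0], key[1], event_id)
    for key, event_id in delta_state.items()
]

conn.executemany(
    "INSERT INTO state_groups_state (%s) VALUES (%s)"
    % (", ".join(keys), ", ".join("?" for _ in keys)),
    values,
)
print(conn.execute("SELECT COUNT(*) FROM state_groups_state").fetchone())  # (2,)
```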
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 540adb878174..71584f3f744b 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -21,7 +21,7 @@
logger = logging.getLogger(__name__)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class FetchKeyResult:
- verify_key = attr.ib(type=VerifyKey) # the key itself
- valid_until_ts = attr.ib(type=int) # how long we can use this key for
+ verify_key: VerifyKey # the key itself
+ valid_until_ts: int # how long we can use this key for
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e45adfcb5569..1823e1872049 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -696,7 +696,7 @@ def _get_or_create_schema_state(
)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _DirectoryListing:
"""Helper class to store schema file name and the
absolute path to it.
@@ -705,5 +705,5 @@ class _DirectoryListing:
`file_name` attr is kept first.
"""
- file_name = attr.ib(type=str)
- absolute_path = attr.ib(type=str)
+ file_name: str
+ absolute_path: str
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 10a46b5e82ed..b1536c1ca491 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -23,7 +23,7 @@
logger = logging.getLogger(__name__)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class PaginationChunk:
"""Returned by relation pagination APIs.
@@ -35,9 +35,9 @@ class PaginationChunk:
None then there are no previous results.
"""
- chunk = attr.ib(type=List[JsonDict])
- next_batch = attr.ib(type=Optional[Any], default=None)
- prev_batch = attr.ib(type=Optional[Any], default=None)
+ chunk: List[JsonDict]
+ next_batch: Optional[Any] = None
+ prev_batch: Optional[Any] = None
def to_dict(self) -> Dict[str, Any]:
d = {"chunk": self.chunk}
@@ -51,7 +51,7 @@ def to_dict(self) -> Dict[str, Any]:
return d
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class RelationPaginationToken:
"""Pagination token for relation pagination API.
@@ -64,8 +64,8 @@ class RelationPaginationToken:
stream: The stream ordering of the boundary event.
"""
- topological = attr.ib(type=int)
- stream = attr.ib(type=int)
+ topological: int
+ stream: int
@staticmethod
def from_string(string: str) -> "RelationPaginationToken":
@@ -82,7 +82,7 @@ def as_tuple(self) -> Tuple[Any, ...]:
return attr.astuple(self)
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class AggregationPaginationToken:
"""Pagination token for relation aggregation pagination API.
@@ -94,8 +94,8 @@ class AggregationPaginationToken:
stream: The MAX stream ordering in the boundary group.
"""
- count = attr.ib(type=int)
- stream = attr.ib(type=int)
+ count: int
+ stream: int
@staticmethod
def from_string(string: str) -> "AggregationPaginationToken":
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 50d08094d52c..2a3d47185ae5 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-SCHEMA_VERSION = 66 # remember to update the list below when updating
+SCHEMA_VERSION = 67 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -50,6 +50,9 @@
Changes in SCHEMA_VERSION = 66:
- Queries on state_key columns are now disambiguated (ie, the codebase can handle
the `events` table having a `state_key` column).
+
+Changes in SCHEMA_VERSION = 67:
+ - state_events.prev_state is no longer written to.
"""
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b5ba1560d139..df8b2f108877 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -45,7 +45,7 @@
T = TypeVar("T")
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class StateFilter:
"""A filter used when querying for state.
@@ -58,8 +58,8 @@ class StateFilter:
appear in `types`.
"""
- types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
- include_others = attr.ib(default=False, type=bool)
+ types: "frozendict[str, Optional[FrozenSet[str]]]"
+ include_others: bool = False
def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 4ff3013908a7..3c13859faabd 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -74,8 +74,6 @@ def get_next(self) -> int:
def _load_current_id(
db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
) -> int:
- # debug logging for https://github.com/matrix-org/synapse/issues/7968
- logger.info("initialising stream generator for %s(%s)", table, column)
cur = db_conn.cursor(txn_name="_load_current_id")
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
@@ -86,7 +84,9 @@ def _load_current_id(
(val,) = result
cur.close()
current_id = int(val) if val else step
- return (max if step > 0 else min)(current_id, step)
+ res = (max if step > 0 else min)(current_id, step)
+ logger.info("Initialising stream generator for %s(%s): %i", table, column, res)
+ return res
class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
@@ -762,13 +762,13 @@ async def __aexit__(
return self.inner.__exit__(exc_type, exc, tb)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _MultiWriterCtxManager:
"""Async context manager returned by MultiWriterIdGenerator"""
- id_gen = attr.ib(type=MultiWriterIdGenerator)
- multiple_ids = attr.ib(type=Optional[int], default=None)
- stream_ids = attr.ib(type=List[int], factory=list)
+ id_gen: MultiWriterIdGenerator
+ multiple_ids: Optional[int] = None
+ stream_ids: List[int] = attr.Factory(list)
async def __aenter__(self) -> Union[int, List[int]]:
# It's safe to run this in autocommit mode as fetching values from a
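`_load_current_id` above clamps the highest (or, for backfill, lowest) id found in the database against `step` before logging and returning it. A standalone rendering of just that clamp, with the database lookup replaced by a parameter:

```python
from typing import Optional

def load_current_id(max_in_db: Optional[int], step: int = 1) -> int:
    # Forward streams (step > 0) never start below `step`; backfill streams
    # (step < 0) never start above it, mirroring the clamp in _load_current_id.
    current_id = int(max_in_db) if max_in_db else step
    return (max if step > 0 else min)(current_id, step)

print(load_current_id(None, step=1))    # 1   (empty table, forward stream)
print(load_current_id(1234, step=1))    # 1234
print(load_current_id(None, step=-1))   # -1  (empty table, backfill stream)
print(load_current_id(-500, step=-1))   # -500
```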
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index c08d591f29d3..b52723e2b89c 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -28,14 +28,14 @@
MAX_LIMIT = 1000
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class PaginationConfig:
"""A configuration object which stores pagination parameters."""
- from_token = attr.ib(type=Optional[StreamToken])
- to_token = attr.ib(type=Optional[StreamToken])
- direction = attr.ib(type=str)
- limit = attr.ib(type=Optional[int])
+ from_token: Optional[StreamToken]
+ to_token: Optional[StreamToken]
+ direction: str
+ limit: Optional[int]
@classmethod
async def from_request(
diff --git a/synapse/types.py b/synapse/types.py
index fb72f193432f..f89fb216a656 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -15,13 +15,14 @@
import abc
import re
import string
-from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
+ List,
Mapping,
+ Match,
MutableMapping,
Optional,
Tuple,
@@ -59,9 +60,11 @@
StateMap = Mapping[StateKey, T]
MutableStateMap = MutableMapping[StateKey, T]
-# the type of a JSON-serialisable dict. This could be made stronger, but it will
-# do for now.
+# JSON types. These could be made stronger, but will do for now.
+# A JSON-serialisable dict.
JsonDict = Dict[str, Any]
+# A JSON-serialisable object.
+JsonSerializable = object
# Note that this seems to require inheriting *directly* from Interface in order
@@ -78,7 +81,7 @@ class ISynapseReactor(
"""The interfaces necessary for Synapse to function."""
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class Requester:
"""
Represents the user making a request
@@ -96,13 +99,13 @@ class Requester:
"puppeting" the user.
"""
- user = attr.ib(type="UserID")
- access_token_id = attr.ib(type=Optional[int])
- is_guest = attr.ib(type=bool)
- shadow_banned = attr.ib(type=bool)
- device_id = attr.ib(type=Optional[str])
- app_service = attr.ib(type=Optional["ApplicationService"])
- authenticated_entity = attr.ib(type=str)
+ user: "UserID"
+ access_token_id: Optional[int]
+ is_guest: bool
+ shadow_banned: bool
+ device_id: Optional[str]
+ app_service: Optional["ApplicationService"]
+ authenticated_entity: str
def serialize(self):
"""Converts self to a type that can be serialized as JSON, and then
@@ -209,7 +212,7 @@ def get_localpart_from_id(string: str) -> str:
DS = TypeVar("DS", bound="DomainSpecificString")
-@attr.s(slots=True, frozen=True, repr=False)
+@attr.s(slots=True, frozen=True, repr=False, auto_attribs=True)
class DomainSpecificString(metaclass=abc.ABCMeta):
"""Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil.
@@ -222,11 +225,10 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore
- localpart = attr.ib(type=str)
- domain = attr.ib(type=str)
+ localpart: str
+ domain: str
- # Because this class is a namedtuple of strings and booleans, it is deeply
- # immutable.
+ # Because this is a frozen class, it is deeply immutable.
def __copy__(self):
return self
@@ -380,7 +382,7 @@ def map_username_to_mxid_localpart(
onto different mxids
Returns:
- unicode: string suitable for a mxid localpart
+ string suitable for a mxid localpart
"""
if not isinstance(username, bytes):
username = username.encode("utf-8")
@@ -388,29 +390,23 @@ def map_username_to_mxid_localpart(
# first we sort out upper-case characters
if case_sensitive:
- def f1(m):
+ def f1(m: Match[bytes]) -> bytes:
return b"_" + m.group().lower()
username = UPPER_CASE_PATTERN.sub(f1, username)
else:
username = username.lower()
- # then we sort out non-ascii characters
- def f2(m):
- g = m.group()[0]
- if isinstance(g, str):
- # on python 2, we need to do a ord(). On python 3, the
- # byte itself will do.
- g = ord(g)
- return b"=%02x" % (g,)
+ # then we sort out non-ascii characters by converting to the hex equivalent.
+ def f2(m: Match[bytes]) -> bytes:
+ return b"=%02x" % (m.group()[0],)
username = NON_MXID_CHARACTER_PATTERN.sub(f2, username)
# we also do the =-escaping to mxids starting with an underscore.
username = re.sub(b"^_", b"=5f", username)
- # we should now only have ascii bytes left, so can decode back to a
- # unicode.
+ # we should now only have ascii bytes left, so can decode back to a string.
return username.decode("ascii")
@@ -466,14 +462,12 @@ class RoomStreamToken:
attributes, must be hashable.
"""
- topological = attr.ib(
- type=Optional[int],
+ topological: Optional[int] = attr.ib(
validator=attr.validators.optional(attr.validators.instance_of(int)),
)
- stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
+ stream: int = attr.ib(validator=attr.validators.instance_of(int))
- instance_map = attr.ib(
- type="frozendict[str, int]",
+ instance_map: "frozendict[str, int]" = attr.ib(
factory=frozendict,
validator=attr.validators.deep_mapping(
key_validator=attr.validators.instance_of(str),
@@ -482,7 +476,7 @@ class RoomStreamToken:
),
)
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
"""Validates that both `topological` and `instance_map` aren't set."""
if self.instance_map and self.topological:
@@ -598,7 +592,7 @@ async def to_string(self, store: "DataStore") -> str:
return "s%d" % (self.stream,)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class StreamToken:
"""A collection of positions within multiple streams.
@@ -606,20 +600,20 @@ class StreamToken:
must be hashable.
"""
- room_key = attr.ib(
- type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
+ room_key: RoomStreamToken = attr.ib(
+ validator=attr.validators.instance_of(RoomStreamToken)
)
- presence_key = attr.ib(type=int)
- typing_key = attr.ib(type=int)
- receipt_key = attr.ib(type=int)
- account_data_key = attr.ib(type=int)
- push_rules_key = attr.ib(type=int)
- to_device_key = attr.ib(type=int)
- device_list_key = attr.ib(type=int)
- groups_key = attr.ib(type=int)
+ presence_key: int
+ typing_key: int
+ receipt_key: int
+ account_data_key: int
+ push_rules_key: int
+ to_device_key: int
+ device_list_key: int
+ groups_key: int
_SEPARATOR = "_"
- START: "StreamToken"
+ START: ClassVar["StreamToken"]
@classmethod
async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
@@ -679,7 +673,7 @@ def copy_and_replace(self, key, new_value) -> "StreamToken":
StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class PersistedEventPosition:
"""Position of a newly persisted event with instance that persisted it.
@@ -687,8 +681,8 @@ class PersistedEventPosition:
RoomStreamToken.
"""
- instance_name = attr.ib(type=str)
- stream = attr.ib(type=int)
+ instance_name: str
+ stream: int
def persisted_after(self, token: RoomStreamToken) -> bool:
return token.get_stream_pos_for_instance(self.instance_name) < self.stream
@@ -706,16 +700,18 @@ def to_room_stream_token(self) -> RoomStreamToken:
return RoomStreamToken(None, self.stream)
-class ThirdPartyInstanceID(
- namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThirdPartyInstanceID:
+ appservice_id: Optional[str]
+ network_id: Optional[str]
+
# Deny iteration because it will bite you if you try to create a singleton
# set by:
# users = set(user)
def __iter__(self):
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
- # Because this class is a namedtuple of strings, it is deeply immutable.
+ # Because this is a frozen class, it is deeply immutable.
def __copy__(self):
return self
@@ -723,32 +719,28 @@ def __deepcopy__(self, memo):
return self
@classmethod
- def from_string(cls, s):
+ def from_string(cls, s: str) -> "ThirdPartyInstanceID":
bits = s.split("|", 2)
if len(bits) != 2:
raise SynapseError(400, "Invalid ID %r" % (s,))
return cls(appservice_id=bits[0], network_id=bits[1])
- def to_string(self):
+ def to_string(self) -> str:
return "%s|%s" % (self.appservice_id, self.network_id)
__str__ = to_string
- @classmethod
- def create(cls, appservice_id, network_id):
- return cls(appservice_id=appservice_id, network_id=network_id)
-
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class ReadReceipt:
"""Information about a read-receipt"""
- room_id = attr.ib()
- receipt_type = attr.ib()
- user_id = attr.ib()
- event_ids = attr.ib()
- data = attr.ib()
+ room_id: str
+ receipt_type: str
+ user_id: str
+ event_ids: List[str]
+ data: JsonDict
def get_verify_key_from_cross_signing_key(key_info):
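The simplified `f2` above relies on Python 3 semantics: indexing a `bytes` object yields an `int`, which can feed `%02x` directly. A tiny sketch of that hex escaping, using an illustrative character pattern rather than Synapse's exact one:

```python
import re
from typing import Match

# Assumption: illustrative pattern; Synapse's NON_MXID_CHARACTER_PATTERN differs in detail.
NON_MXID_CHARACTER_PATTERN = re.compile(b"[^a-z0-9._=/-]")

def escape(m: Match[bytes]) -> bytes:
    # On Python 3, indexing bytes yields an int, so it can feed %02x directly.
    return b"=%02x" % (m.group()[0],)

print(NON_MXID_CHARACTER_PATTERN.sub(escape, b"bob!"))  # b'bob=21'
```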
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 95f23e27b6b1..511f52534b3c 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -14,9 +14,8 @@
import json
import logging
-import re
import typing
-from typing import Any, Callable, Dict, Generator, Optional, Pattern
+from typing import Any, Callable, Dict, Generator, Optional
import attr
from frozendict import frozendict
@@ -32,10 +31,14 @@
if typing.TYPE_CHECKING:
pass
-logger = logging.getLogger(__name__)
-
+# FIXME Mjolnir imports glob_to_regex from this file, but it was moved to
+# matrix_common.
+# As a temporary workaround, we import glob_to_regex here for
+# compatibility with current versions of Mjolnir.
+# See https://github.com/matrix-org/mjolnir/pull/174
+from matrix_common.regex import glob_to_regex # noqa
-_WILDCARD_RUN = re.compile(r"([\?\*]+)")
+logger = logging.getLogger(__name__)
def _reject_invalid_json(val: Any) -> None:
@@ -185,56 +188,3 @@ def log_failure(
if not consumeErrors:
return failure
return None
-
-
-def glob_to_regex(glob: str, word_boundary: bool = False) -> Pattern:
- """Converts a glob to a compiled regex object.
-
- Args:
- glob: pattern to match
- word_boundary: If True, the pattern will be allowed to match at word boundaries
- anywhere in the string. Otherwise, the pattern is anchored at the start and
- end of the string.
-
- Returns:
- compiled regex pattern
- """
-
- # Patterns with wildcards must be simplified to avoid performance cliffs
- # - The glob `?**?**?` is equivalent to the glob `???*`
- # - The glob `???*` is equivalent to the regex `.{3,}`
- chunks = []
- for chunk in _WILDCARD_RUN.split(glob):
- # No wildcards? re.escape()
- if not _WILDCARD_RUN.match(chunk):
- chunks.append(re.escape(chunk))
- continue
-
- # Wildcards? Simplify.
- qmarks = chunk.count("?")
- if "*" in chunk:
- chunks.append(".{%d,}" % qmarks)
- else:
- chunks.append(".{%d}" % qmarks)
-
- res = "".join(chunks)
-
- if word_boundary:
- res = re_word_boundary(res)
- else:
- # \A anchors at start of string, \Z at end of string
- res = r"\A" + res + r"\Z"
-
- return re.compile(res, re.IGNORECASE)
-
-
-def re_word_boundary(r: str) -> str:
- """
- Adds word boundary characters to the start and end of an
- expression to require that the match occur as a whole word,
- but do so respecting the fact that strings starting or ending
- with non-word characters will change word boundaries.
- """
- # we can't use \b as it chokes on unicode. however \W seems to be okay
- # as shorthand for [^0-9A-Za-z_].
- return r"(^|\W)%s(\W|$)" % (r,)
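With the in-tree implementation removed, `glob_to_regex` now comes from the matrix-common package and is only re-exported here for Mjolnir. A usage sketch, assuming the matrix-common implementation behaves like the one removed above (anchored and case-insensitive):

```python
from matrix_common.regex import glob_to_regex

# The implementation now lives in matrix-common; the re-export in synapse.util
# only exists to keep `from synapse.util import glob_to_regex` working for Mjolnir.
pattern = glob_to_regex("*.example.org")
print(bool(pattern.match("server.example.org")))  # True
print(bool(pattern.match("example.com")))         # False
```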
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 20ce294209ad..3f7299aff7eb 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import collections
import inspect
import itertools
@@ -30,9 +31,11 @@
Iterator,
Optional,
Set,
+ Tuple,
TypeVar,
Union,
cast,
+ overload,
)
import attr
@@ -55,7 +58,26 @@
_T = TypeVar("_T")
-class ObservableDeferred(Generic[_T]):
+class AbstractObservableDeferred(Generic[_T], metaclass=abc.ABCMeta):
+ """Abstract base class defining the consumer interface of ObservableDeferred"""
+
+ __slots__ = ()
+
+ @abc.abstractmethod
+ def observe(self) -> "defer.Deferred[_T]":
+ """Add a new observer for this ObservableDeferred
+
+ This returns a brand new deferred that is resolved when the underlying
+ deferred is resolved. Interacting with the returned deferred does not
+ affect the underlying deferred.
+
+ Note that the returned Deferred doesn't follow the Synapse logcontext rules -
+ you will probably want to `make_deferred_yieldable` it.
+ """
+ ...
+
+
+class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
"""Wraps a deferred object so that we can add observer deferreds. These
observer deferreds do not affect the callback chain of the original
deferred.
@@ -234,12 +256,65 @@ def yieldable_gather_results(
).addErrback(unwrapFirstError)
-@attr.s(slots=True)
+T1 = TypeVar("T1")
+T2 = TypeVar("T2")
+T3 = TypeVar("T3")
+
+
+@overload
+def gather_results(
+ deferredList: Tuple[()], consumeErrors: bool = ...
+) -> "defer.Deferred[Tuple[()]]":
+ ...
+
+
+@overload
+def gather_results(
+ deferredList: Tuple["defer.Deferred[T1]"],
+ consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1]]":
+ ...
+
+
+@overload
+def gather_results(
+ deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"],
+ consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1, T2]]":
+ ...
+
+
+@overload
+def gather_results(
+ deferredList: Tuple[
+ "defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]"
+ ],
+ consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1, T2, T3]]":
+ ...
+
+
+def gather_results( # type: ignore[misc]
+ deferredList: Tuple["defer.Deferred[T1]", ...],
+ consumeErrors: bool = False,
+) -> "defer.Deferred[Tuple[T1, ...]]":
+ """Combines a tuple of `Deferred`s into a single `Deferred`.
+
+ Wraps `defer.gatherResults` to provide type annotations that support heterogeneous
+ lists of `Deferred`s.
+ """
+ # The `type: ignore[misc]` above suppresses
+ # "Overloaded function implementation cannot produce return type of signature 1/2/3"
+ deferred = defer.gatherResults(deferredList, consumeErrors=consumeErrors)
+ return deferred.addCallback(tuple)
+
+
+@attr.s(slots=True, auto_attribs=True)
class _LinearizerEntry:
# The number of things executing.
- count = attr.ib(type=int)
+ count: int
# Deferreds for the things blocked from executing.
- deferreds = attr.ib(type=collections.OrderedDict)
+ deferreds: collections.OrderedDict
class Linearizer:
@@ -352,7 +427,7 @@ def _await_lock(self, key: Hashable) -> defer.Deferred:
logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
- new_defer = make_deferred_yieldable(defer.Deferred())
+ new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
entry.deferreds[new_defer] = 1
def cb(_r: None) -> "defer.Deferred[None]":
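The `gather_results` overloads added above let callers combine a heterogeneous tuple of `Deferred`s while keeping per-element types. A small usage sketch; the deferreds are already fired, so the callback runs synchronously:

```python
from twisted.internet import defer

from synapse.util.async_helpers import gather_results

d1: "defer.Deferred[int]" = defer.succeed(1)
d2: "defer.Deferred[str]" = defer.succeed("two")

# The overloads let mypy infer Deferred[Tuple[int, str]] rather than a
# homogeneous list element type.
d = gather_results((d1, d2), consumeErrors=True)
d.addCallback(print)  # prints (1, 'two')
```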
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index 470f4f91a59b..e325f44da328 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -76,6 +76,7 @@ async def get(self) -> TV:
# Fire off the callable now if this is our first time
if not self._deferred:
+ assert self._callable is not None
self._deferred = run_in_background(self._callable)
# we will never need the callable again, so make sure it can be GCed
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 485ddb189373..d267703df0f2 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -33,7 +33,7 @@
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class DictionaryEntry: # should be: Generic[DKT, DV].
"""Returned when getting an entry from the cache
@@ -41,14 +41,13 @@ class DictionaryEntry: # should be: Generic[DKT, DV].
full: Whether the cache has the full dict or just some keys.
If not full then not all requested keys will necessarily be present
in `value`
- known_absent: Keys that were looked up in the dict and were not
- there.
+ known_absent: Keys that were looked up in the dict and were not there.
value: The full or partial dict value
"""
- full = attr.ib(type=bool)
- known_absent = attr.ib(type=Set[Any]) # should be: Set[DKT]
- value = attr.ib(type=Dict[Any, Any]) # should be: Dict[DKT, DV]
+ full: bool
+ known_absent: Set[Any] # should be: Set[DKT]
+ value: Dict[Any, Any] # should be: Dict[DKT, DV]
def __len__(self) -> int:
return len(self.value)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index eb96f7e665e6..3f11a2f9dd5c 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -69,7 +69,6 @@ def _get_size_of(val: Any, *, recurse: bool = True) -> int:
sizer.exclude_refs((), None, "")
return sizer.asizeof(val, limit=100 if recurse else 0)
-
except ImportError:
def _get_size_of(val: Any, *, recurse: bool = True) -> int:
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 88ccf443377c..a3eb5f741bfc 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,19 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Generic,
+ Iterable,
+ Optional,
+ TypeVar,
+)
import attr
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.opentracing import (
+ active_span,
+ start_active_span,
+ start_active_span_follows_from,
+)
from synapse.util import Clock
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred
from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
+if TYPE_CHECKING:
+ import opentracing
+
# the type of the key in the cache
KV = TypeVar("KV")
@@ -54,6 +72,20 @@ class ResponseCacheContext(Generic[KV]):
"""
+@attr.s(auto_attribs=True)
+class ResponseCacheEntry:
+ result: AbstractObservableDeferred
+ """The (possibly incomplete) result of the operation.
+
+ Note that we continue to store an ObservableDeferred even after the operation
+ completes (rather than switching to an immediate value), since that makes it
+ easier to cache Failure results.
+ """
+
+ opentracing_span_context: "Optional[opentracing.SpanContext]"
+ """The opentracing span which generated/is generating the result"""
+
+
class ResponseCache(Generic[KV]):
"""
This caches a deferred response. Until the deferred completes it will be
@@ -63,10 +95,7 @@ class ResponseCache(Generic[KV]):
"""
def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
- # This is poorly-named: it includes both complete and incomplete results.
- # We keep complete results rather than switching to absolute values because
- # that makes it easier to cache Failure results.
- self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
+ self._result_cache: Dict[KV, ResponseCacheEntry] = {}
self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
@@ -75,56 +104,63 @@ def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
self._metrics = register_cache("response_cache", name, self, resizable=False)
def size(self) -> int:
- return len(self.pending_result_cache)
+ return len(self._result_cache)
def __len__(self) -> int:
return self.size()
- def get(self, key: KV) -> Optional[defer.Deferred]:
- """Look up the given key.
+ def keys(self) -> Iterable[KV]:
+ """Get the keys currently in the result cache
- Returns a new Deferred (which also doesn't follow the synapse
- logcontext rules). You will probably want to make_deferred_yieldable the result.
+ Returns both incomplete entries, and (if the timeout on this cache is non-zero),
+ complete entries which are still in the cache.
- If there is no entry for the key, returns None.
+ Note that the returned iterator is not safe in the face of concurrent execution:
+ behaviour is undefined if `wrap` is called during iteration.
+ """
+ return self._result_cache.keys()
+
+ def _get(self, key: KV) -> Optional[ResponseCacheEntry]:
+ """Look up the given key.
Args:
- key: key to get/set in the cache
+ key: key to get in the cache
Returns:
- None if there is no entry for this key; otherwise a deferred which
- resolves to the result.
+ The entry for this key, if any; else None.
"""
- result = self.pending_result_cache.get(key)
- if result is not None:
+ entry = self._result_cache.get(key)
+ if entry is not None:
self._metrics.inc_hits()
- return result.observe()
+ return entry
else:
self._metrics.inc_misses()
return None
def _set(
- self, context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]"
- ) -> "defer.Deferred[RV]":
+ self,
+ context: ResponseCacheContext[KV],
+ deferred: "defer.Deferred[RV]",
+ opentracing_span_context: "Optional[opentracing.SpanContext]",
+ ) -> ResponseCacheEntry:
"""Set the entry for the given key to the given deferred.
*deferred* should run its callbacks in the sentinel logcontext (ie,
you should wrap normal synapse deferreds with
synapse.logging.context.run_in_background).
- Returns a new Deferred (which also doesn't follow the synapse logcontext rules).
- You will probably want to make_deferred_yieldable the result.
-
Args:
context: Information about the cache miss
deferred: The deferred which resolves to the result.
+ opentracing_span_context: An opentracing span wrapping the calculation
Returns:
- A new deferred which resolves to the actual result.
+ The cache entry object.
"""
result = ObservableDeferred(deferred, consumeErrors=True)
key = context.cache_key
- self.pending_result_cache[key] = result
+ entry = ResponseCacheEntry(result, opentracing_span_context)
+ self._result_cache[key] = entry
def on_complete(r: RV) -> RV:
# if this cache has a non-zero timeout, and the callback has not cleared
@@ -132,18 +168,18 @@ def on_complete(r: RV) -> RV:
# its removal later.
if self.timeout_sec and context.should_cache:
self.clock.call_later(
- self.timeout_sec, self.pending_result_cache.pop, key, None
+ self.timeout_sec, self._result_cache.pop, key, None
)
else:
# otherwise, remove the result immediately.
- self.pending_result_cache.pop(key, None)
+ self._result_cache.pop(key, None)
return r
- # make sure we do this *after* adding the entry to pending_result_cache,
+ # make sure we do this *after* adding the entry to _result_cache,
# in case the result is already complete (in which case flipping the order would
# leave us with a stuck entry in the cache).
result.addBoth(on_complete)
- return result.observe()
+ return entry
async def wrap(
self,
@@ -189,20 +225,41 @@ async def handle_request(request):
Returns:
The result of the callback (from the cache, or otherwise)
"""
- result = self.get(key)
- if not result:
+ entry = self._get(key)
+ if not entry:
logger.debug(
"[%s]: no cached result for [%s], calculating new one", self._name, key
)
context = ResponseCacheContext(cache_key=key)
if cache_context:
kwargs["cache_context"] = context
- d = run_in_background(callback, *args, **kwargs)
- result = self._set(context, d)
- elif not isinstance(result, defer.Deferred) or result.called:
+
+ span_context: Optional[opentracing.SpanContext] = None
+
+ async def cb() -> RV:
+ # NB it is important that we do not `await` before setting span_context!
+ nonlocal span_context
+ with start_active_span(f"ResponseCache[{self._name}].calculate"):
+ span = active_span()
+ if span:
+ span_context = span.context
+ return await callback(*args, **kwargs)
+
+ d = run_in_background(cb)
+ entry = self._set(context, d, span_context)
+ return await make_deferred_yieldable(entry.result.observe())
+
+ result = entry.result.observe()
+ if result.called:
logger.info("[%s]: using completed cached result for [%s]", self._name, key)
else:
logger.info(
"[%s]: using incomplete cached result for [%s]", self._name, key
)
- return await make_deferred_yieldable(result)
+
+ span_context = entry.opentracing_span_context
+ with start_active_span_follows_from(
+ f"ResponseCache[{self._name}].wait",
+ contexts=(span_context,) if span_context else (),
+ ):
+ return await make_deferred_yieldable(result)
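The reworked `ResponseCache` still centres on one idea: concurrent callers for the same key share a single in-flight computation. A deliberately simplified asyncio analogue of that idea, without the Twisted, logcontext, or opentracing machinery above:

```python
import asyncio
from typing import Awaitable, Callable, Dict

# Simplified sketch only: concurrent callers for the same key await one shared task.
class MiniResponseCache:
    def __init__(self) -> None:
        self._result_cache: Dict[str, "asyncio.Task[str]"] = {}

    async def wrap(self, key: str, callback: Callable[[], Awaitable[str]]) -> str:
        entry = self._result_cache.get(key)
        if entry is None:
            entry = asyncio.ensure_future(callback())
            self._result_cache[key] = entry
            # Drop the entry once the computation completes (no timeout handling here).
            entry.add_done_callback(lambda _: self._result_cache.pop(key, None))
        return await entry

async def main() -> None:
    cache = MiniResponseCache()

    async def expensive() -> str:
        await asyncio.sleep(0.1)
        return "computed once"

    results = await asyncio.gather(cache.wrap("k", expensive), cache.wrap("k", expensive))
    print(results)  # ['computed once', 'computed once'] from a single call to expensive()

asyncio.run(main())
```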
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index de2adacd70dc..46771a401b50 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -142,6 +142,7 @@ def _writer(self) -> None:
def wait(self) -> "Deferred[None]":
"""Returns a deferred that resolves when finished writing to file"""
+ assert self._finished_deferred is not None
return make_deferred_yieldable(self._finished_deferred)
def _resume_paused_producer(self) -> None:
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 3aa9ba3c43ac..4b53b6d40b31 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -31,6 +31,7 @@
from tests import unittest
from tests.test_utils import simple_async_mock
+from tests.unittest import override_config
from tests.utils import mock_getRawHeaders
@@ -210,6 +211,102 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_failure(self.auth.get_user_by_req(request), AuthError)
+ @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
+ def test_get_user_by_req_appservice_valid_token_valid_device_id(self):
+ """
+ Tests that when an application service passes the device_id URL parameter
+ with the ID of a valid device for the user in question,
+ the requester instance tracks that device ID.
+ """
+ masquerading_user_id = b"@doppelganger:matrix.org"
+ masquerading_device_id = b"DOPPELDEVICE"
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+ )
+ app_service.is_interested_in_user = Mock(return_value=True)
+ self.store.get_app_service_by_token = Mock(return_value=app_service)
+ # This just needs to return a truth-y value.
+ self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+ self.store.get_user_by_access_token = simple_async_mock(None)
+ # This also needs to just return a truth-y value
+ self.store.get_device = simple_async_mock({"hidden": False})
+
+ request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
+ request.args[b"access_token"] = [self.test_token]
+ request.args[b"user_id"] = [masquerading_user_id]
+ request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.assertEquals(
+ requester.user.to_string(), masquerading_user_id.decode("utf8")
+ )
+ self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8"))
+
+ @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
+ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
+ """
+ Tests that when an application service passes the device_id URL parameter
+ with an ID that is not a valid device ID for the user in question,
+ the request fails with the appropriate error code.
+ """
+ masquerading_user_id = b"@doppelganger:matrix.org"
+ masquerading_device_id = b"NOT_A_REAL_DEVICE_ID"
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+ )
+ app_service.is_interested_in_user = Mock(return_value=True)
+ self.store.get_app_service_by_token = Mock(return_value=app_service)
+ # This just needs to return a truthy value.
+ self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+ self.store.get_user_by_access_token = simple_async_mock(None)
+ # This also needs to return a falsy value.
+ self.store.get_device = simple_async_mock(None)
+
+ request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
+ request.args[b"access_token"] = [self.test_token]
+ request.args[b"user_id"] = [masquerading_user_id]
+ request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+ failure = self.get_failure(self.auth.get_user_by_req(request), AuthError)
+ self.assertEquals(failure.value.code, 400)
+ self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE)
+
+ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self):
+ self.store.get_user_by_access_token = simple_async_mock(
+ TokenLookupResult(
+ user_id="@baldrick:matrix.org",
+ device_id="device",
+ token_owner="@admin:matrix.org",
+ )
+ )
+ self.store.insert_client_ip = simple_async_mock(None)
+ request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
+ request.args[b"access_token"] = [self.test_token]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_success(self.auth.get_user_by_req(request))
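+ # Only one client IP entry should be recorded when puppeted-user tracking
+ # is disabled (the default).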
+ self.store.insert_client_ip.assert_called_once()
+
+ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self):
+ self.auth._track_puppeted_user_ips = True
+ self.store.get_user_by_access_token = simple_async_mock(
+ TokenLookupResult(
+ user_id="@baldrick:matrix.org",
+ device_id="device",
+ token_owner="@admin:matrix.org",
+ )
+ )
+ self.store.insert_client_ip = simple_async_mock(None)
+ request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
+ request.args[b"access_token"] = [self.test_token]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_success(self.auth.get_user_by_req(request))
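+ # With puppeted-user IP tracking enabled, an entry is recorded for both the
+ # puppeted user and the puppeting (token owner) user.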
+ self.assertEquals(self.store.insert_client_ip.call_count, 2)
+
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index f386b5e128bf..ba2a2bfd64ad 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -16,13 +16,13 @@
from twisted.internet import defer
-from synapse.appservice import ApplicationService
+from synapse.appservice import ApplicationService, Namespace
from tests import unittest
-def _regex(regex, exclusive=True):
- return {"regex": re.compile(regex), "exclusive": exclusive}
+def _regex(regex: str, exclusive: bool = True) -> Namespace:
+ return Namespace(exclusive, None, re.compile(regex))
class ApplicationServiceTestCase(unittest.TestCase):
@@ -33,11 +33,6 @@ def setUp(self):
url="some_url",
token="some_token",
hostname="matrix.org", # only used by get_groups_for_user
- namespaces={
- ApplicationService.NS_USERS: [],
- ApplicationService.NS_ROOMS: [],
- ApplicationService.NS_ALIASES: [],
- },
)
self.event = Mock(
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index 1c920157f506..a72a0103d3b8 100644
--- a/tests/crypto/test_event_signing.py
+++ b/tests/crypto/test_event_signing.py
@@ -14,6 +14,7 @@
import nacl.signing
+import signedjson.types
from unpaddedbase64 import decode_base64
from synapse.api.room_versions import RoomVersions
@@ -35,7 +36,12 @@
class EventSigningTestCase(unittest.TestCase):
def setUp(self):
- self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED)
+ # NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been
+ # monkeypatched to include new `alg` and `version` attributes. This is captured
+ # by the `signedjson.types.SigningKey` protocol.
+ self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey(
+ SIGNING_KEY_SEED
+ )
self.signing_key.alg = KEY_ALG
self.signing_key.version = KEY_VER
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index b457dad6d263..b2376e2db925 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -266,7 +266,8 @@ def test_upload_signatures(self):
)
# expect signing key update edu
- self.assertEqual(len(self.edus), 1)
+ self.assertEqual(len(self.edus), 2)
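+ # Expect both the stable and the unstable (org.matrix-prefixed) signing key update EDUs.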
+ self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update")
self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update")
# sign the devices
@@ -491,7 +492,7 @@ def check_signing_key_update_txn(
) -> None:
"""Check that the txn has an EDU with a signing key update."""
edus = txn["edus"]
- self.assertEqual(len(edus), 1)
+ self.assertEqual(len(edus), 2)
def generate_and_upload_device_signing_key(
self, user_id: str, device_id: str
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 663960ff534a..bfa156eebbe5 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -108,6 +108,15 @@ def send_example_state_events_to_room(
"state_key": "",
},
),
+ (
+ EventTypes.Topic,
+ {
+ "content": {
+ "topic": "A really cool room",
+ },
+ "state_key": "",
+ },
+ ),
]
)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index f0723892e416..734ed84d78d0 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Iterable
from unittest import mock
+from parameterized import parameterized
from signedjson import key as key, sign as sign
from twisted.internet import defer
@@ -23,6 +25,7 @@
from synapse.api.errors import Codes, SynapseError
from tests import unittest
+from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
@@ -161,8 +164,9 @@ def test_claim_one_time_key(self):
def test_fallback_key(self):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
- fallback_key = {"alg1:k1": "key1"}
- fallback_key2 = {"alg1:k2": "key2"}
+ fallback_key = {"alg1:k1": "fallback_key1"}
+ fallback_key2 = {"alg1:k2": "fallback_key2"}
+ fallback_key3 = {"alg1:k2": "fallback_key3"}
otk = {"alg1:k2": "key2"}
# we shouldn't have any unused fallback keys yet
@@ -175,7 +179,7 @@ def test_fallback_key(self):
self.handler.upload_keys_for_user(
local_user,
device_id,
- {"org.matrix.msc2732.fallback_keys": fallback_key},
+ {"fallback_keys": fallback_key},
)
)
@@ -220,7 +224,7 @@ def test_fallback_key(self):
self.handler.upload_keys_for_user(
local_user,
device_id,
- {"org.matrix.msc2732.fallback_keys": fallback_key},
+ {"fallback_keys": fallback_key},
)
)
@@ -234,7 +238,7 @@ def test_fallback_key(self):
self.handler.upload_keys_for_user(
local_user,
device_id,
- {"org.matrix.msc2732.fallback_keys": fallback_key2},
+ {"fallback_keys": fallback_key2},
)
)
@@ -271,6 +275,25 @@ def test_fallback_key(self):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
)
+ # using the unstable prefix should also set the fallback key
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"org.matrix.msc2732.fallback_keys": fallback_key3},
+ )
+ )
+
+ res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
+ )
+
def test_replace_master_key(self):
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
@@ -745,6 +768,8 @@ def test_query_devices_remote_sync(self):
remote_user_id = "@test:other"
local_user_id = "@test:test"
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"})
)
@@ -811,3 +836,94 @@ def test_query_devices_remote_sync(self):
}
},
)
+
+ @parameterized.expand(
+ [
+ # The remote homeserver's response indicates that this user has 0/1/2 devices.
+ ([],),
+ (["device_1"],),
+ (["device_1", "device_2"],),
+ ]
+ )
+ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+ """Test that requests for all of a remote user's devices are cached.
+
+ We do this by asserting that only one call over federation was made, and that
+ the two queries to the local homeserver produce the same response.
+ """
+ local_user_id = "@test:test"
+ remote_user_id = "@test:other"
+ request_body = {"device_keys": {remote_user_id: []}}
+
+ response_devices = [
+ {
+ "device_id": device_id,
+ "keys": {
+ "algorithms": ["dummy"],
+ "device_id": device_id,
+ "keys": {f"dummy:{device_id}": "dummy"},
+ "signatures": {device_id: {f"dummy:{device_id}": "dummy"}},
+ "unsigned": {},
+ "user_id": "@test:other",
+ },
+ }
+ for device_id in device_ids
+ ]
+
+ response_body = {
+ "devices": response_devices,
+ "user_id": remote_user_id,
+ "stream_id": 12345, # an integer, according to the spec
+ }
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `_query_devices_for_destination` will return early.
+ mock_get_rooms = mock.patch.object(
+ self.store,
+ "get_rooms_for_user",
+ new_callable=mock.MagicMock,
+ return_value=make_awaitable(["some_room_id"]),
+ )
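+ # Stub out the federation client's device query so we can count how many
+ # requests actually go out over federation.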
+ mock_request = mock.patch.object(
+ self.hs.get_federation_client(),
+ "query_user_devices",
+ new_callable=mock.MagicMock,
+ return_value=make_awaitable(response_body),
+ )
+
+ with mock_get_rooms, mock_request as mocked_federation_request:
+ # Make the first query and sanity check it succeeds.
+ response_1 = self.get_success(
+ e2e_handler.query_devices(
+ request_body,
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+ self.assertEqual(response_1["failures"], {})
+
+ # We should have made a federation request to do so.
+ mocked_federation_request.assert_called_once()
+
+ # Reset the mock so we can prove we don't make a second federation request.
+ mocked_federation_request.reset_mock()
+
+ # Repeat the query.
+ response_2 = self.get_success(
+ e2e_handler.query_devices(
+ request_body,
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+ self.assertEqual(response_2["failures"], {})
+
+ # We should not have made a second federation request.
+ mocked_federation_request.assert_not_called()
+
+ # The two requests to the local homeserver should be identical.
+ self.assertEqual(response_1, response_2)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index e1557566e4bc..496b58172643 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -373,9 +373,7 @@ async def get_event_auth(
destination: str, room_id: str, event_id: str
) -> List[EventBase]:
return [
- event_from_pdu_json(
- ae.get_pdu_json(), room_version=room_version, outlier=True
- )
+ event_from_pdu_json(ae.get_pdu_json(), room_version=room_version)
for ae in auth_events
]
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 8a8d369faca1..5816295d8b97 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -23,6 +23,7 @@
from synapse.util.stringutils import random_string
from tests import unittest
+from tests.test_utils.event_injection import create_event
logger = logging.getLogger(__name__)
@@ -51,6 +52,24 @@ def prepare(self, reactor, clock, hs):
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
+ def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]:
+ # Create a member event we can use as an auth_event
+ memberEvent, memberEventContext = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.room.member",
+ sender=self.requester.user.to_string(),
+ state_key=self.requester.user.to_string(),
+ content={"membership": "join"},
+ )
+ )
+ self.get_success(
+ self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ )
+
+ return memberEvent, memberEventContext
+
def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
"""Create a new event with the given transaction ID. All events produced
by this method will be considered duplicates.
@@ -156,6 +175,90 @@ def test_duplicated_txn_id_one_call(self):
self.assertEqual(len(events), 2)
self.assertEqual(events[0].event_id, events[1].event_id)
+ def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self):
+ """When we set allow_no_prev_events=True, should be able to create a
+ event without any prev_events (only auth_events).
+ """
+ # Create a member event we can use as an auth_event
+ memberEvent, _ = self._create_and_persist_member_event()
+
+ # Try to create the event with empty prev_events but with some auth_events
+ event, _ = self.get_success(
+ self.handler.create_event(
+ self.requester,
+ {
+ "type": EventTypes.Message,
+ "room_id": self.room_id,
+ "sender": self.requester.user.to_string(),
+ "content": {"msgtype": "m.text", "body": random_string(5)},
+ },
+ # Empty prev_events is the key thing we're testing here
+ prev_event_ids=[],
+ # But with some auth_events
+ auth_event_ids=[memberEvent.event_id],
+ # Allow no prev_events!
+ allow_no_prev_events=True,
+ )
+ )
+ self.assertIsNotNone(event)
+
+ def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events(
+ self,
+ ):
+ """When we set allow_no_prev_events=False, shouldn't be able to create a
+ event without any prev_events even if it has auth_events. Expect an
+ exception to be raised.
+ """
+ # Create a member event we can use as an auth_event
+ memberEvent, _ = self._create_and_persist_member_event()
+
+ # Try to create the event with empty prev_events but with some auth_events
+ self.get_failure(
+ self.handler.create_event(
+ self.requester,
+ {
+ "type": EventTypes.Message,
+ "room_id": self.room_id,
+ "sender": self.requester.user.to_string(),
+ "content": {"msgtype": "m.text", "body": random_string(5)},
+ },
+ # Empty prev_events is the key thing we're testing here
+ prev_event_ids=[],
+ # But with some auth_events
+ auth_event_ids=[memberEvent.event_id],
+ # We expect event creation to fail because empty prev_events are not
+ # allowed here!
+ allow_no_prev_events=False,
+ ),
+ AssertionError,
+ )
+
+ def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events(
+ self,
+ ):
+ """When we set allow_no_prev_events=True, should be able to create a
+ event without any prev_events or auth_events. Expect an exception to be
+ raised.
+ """
+ # Try to create the event with empty prev_events and empty auth_events
+ self.get_failure(
+ self.handler.create_event(
+ self.requester,
+ {
+ "type": EventTypes.Message,
+ "room_id": self.room_id,
+ "sender": self.requester.user.to_string(),
+ "content": {"msgtype": "m.text", "body": random_string(5)},
+ },
+ prev_event_ids=[],
+ # The event should be rejected when there are no auth_events
+ auth_event_ids=[],
+ # Allow no prev_events!
+ allow_no_prev_events=True,
+ ),
+ AssertionError,
+ )
+
class ServerAclValidationTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 08e9730d4dfa..2add72b28a3a 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -22,7 +22,7 @@
import synapse
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login
+from synapse.rest.client import devices, login, logout
from synapse.types import JsonDict
from tests import unittest
@@ -155,6 +155,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets,
login.register_servlets,
devices.register_servlets,
+ logout.register_servlets,
]
def setUp(self):
@@ -719,6 +720,31 @@ def custom_auth_no_local_user_fallback_test_body(self):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
+ def test_on_logged_out(self):
+ """Tests that the on_logged_out callback is called when the user logs out."""
+ self.register_user("rin", "password")
+ tok = self.login("rin", "password")
+
+ self.called = False
+
+ async def on_logged_out(user_id, device_id, access_token):
+ self.called = True
+
+ on_logged_out = Mock(side_effect=on_logged_out)
+ self.hs.get_password_auth_provider().on_logged_out_callbacks.append(
+ on_logged_out
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/logout",
+ {},
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 200)
+ on_logged_out.assert_called_once()
+ self.assertTrue(self.called)
+
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index e5a6a6c747bf..51b22d299812 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -28,6 +28,7 @@
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
+from synapse.federation.transport.client import TransportLayerClient
from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEntry
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -134,10 +135,18 @@ def prepare(self, reactor, clock, hs: HomeServer):
self._add_child(self.space, self.room, self.token)
def _add_child(
- self, space_id: str, room_id: str, token: str, order: Optional[str] = None
+ self,
+ space_id: str,
+ room_id: str,
+ token: str,
+ order: Optional[str] = None,
+ via: Optional[List[str]] = None,
) -> None:
"""Add a child room to a space."""
- content: JsonDict = {"via": [self.hs.hostname]}
+ if via is None:
+ via = [self.hs.hostname]
+
+ content: JsonDict = {"via": via}
if order is not None:
content["order"] = order
self.helper.send_state(
@@ -253,6 +262,38 @@ def test_simple_space(self):
)
self._assert_hierarchy(result, expected)
+ def test_large_space(self):
+ """Test a space with a large number of rooms."""
+ rooms = [self.room]
+ # Make at least 51 rooms that are part of the space.
+ for _ in range(55):
+ room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(self.space, room, self.token)
+ rooms.append(room)
+
+ result = self.get_success(self.handler.get_space_summary(self.user, self.space))
+ # The spaces result should have the space and the first 50 rooms in it,
+ # along with the links from space -> room for those 50 rooms.
+ expected = [(self.space, rooms[:50])] + [(room, []) for room in rooms[:49]]
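+ # (The space itself counts towards the overall 50-room limit, so only 49 of the
+ # child rooms appear as room entries, even though 50 child links are returned.)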
+ self._assert_rooms(result, expected)
+
+ # The result should have the space and the rooms in it, along with the links
+ # from space -> room.
+ expected = [(self.space, rooms)] + [(room, []) for room in rooms]
+
+ # Make two requests to fully paginate the results.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ result2 = self.get_success(
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, from_token=result["next_batch"]
+ )
+ )
+ # Combine the results.
+ result["rooms"] += result2["rooms"]
+ self._assert_hierarchy(result, expected)
+
def test_visibility(self):
"""A user not in a space cannot inspect it."""
user2 = self.register_user("user2", "pass")
@@ -1004,6 +1045,85 @@ async def summarize_remote_room_hierarchy(_self, room, suggested_only):
)
self._assert_hierarchy(result, expected)
+ def test_fed_caching(self):
+ """
+ Federation `/hierarchy` responses should be cached.
+ """
+ fed_hostname = self.hs.hostname + "2"
+ fed_subspace = "#space:" + fed_hostname
+ fed_room = "#room:" + fed_hostname
+
+ # Add a room to the space which is on another server.
+ self._add_child(self.space, fed_subspace, self.token, via=[fed_hostname])
+
+ federation_requests = 0
+
+ async def get_room_hierarchy(
+ _self: TransportLayerClient,
+ destination: str,
+ room_id: str,
+ suggested_only: bool,
+ ) -> JsonDict:
+ nonlocal federation_requests
+ federation_requests += 1
+
+ return {
+ "room": {
+ "room_id": fed_subspace,
+ "world_readable": True,
+ "room_type": RoomTypes.SPACE,
+ "children_state": [
+ {
+ "type": EventTypes.SpaceChild,
+ "room_id": fed_subspace,
+ "state_key": fed_room,
+ "content": {"via": [fed_hostname]},
+ },
+ ],
+ },
+ "children": [
+ {
+ "room_id": fed_room,
+ "world_readable": True,
+ },
+ ],
+ "inaccessible_children": [],
+ }
+
+ expected = [
+ (self.space, [self.room, fed_subspace]),
+ (self.room, ()),
+ (fed_subspace, [fed_room]),
+ (fed_room, ()),
+ ]
+
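+ # Patch the federation transport client so `/hierarchy` requests over
+ # federation hit the stub above and get counted.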
+ with mock.patch(
+ "synapse.federation.transport.client.TransportLayerClient.get_room_hierarchy",
+ new=get_room_hierarchy,
+ ):
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self.assertEqual(federation_requests, 1)
+ self._assert_hierarchy(result, expected)
+
+ # The previous federation response should be reused.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self.assertEqual(federation_requests, 1)
+ self._assert_hierarchy(result, expected)
+
+ # Expire the response cache
+ self.reactor.advance(5 * 60 + 1)
+
+ # A new federation request should be made.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self.assertEqual(federation_requests, 2)
+ self._assert_hierarchy(result, expected)
+
class RoomSummaryTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 638186f173f0..07a760e91aed 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -11,15 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Optional
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock, patch
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import Filtering
from synapse.api.room_versions import RoomVersions
-from synapse.handlers.sync import SyncConfig
+from synapse.handlers.sync import SyncConfig, SyncResult
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
@@ -27,6 +26,7 @@
import tests.unittest
import tests.utils
+from tests.test_utils import make_awaitable
class SyncTestCase(tests.unittest.HomeserverTestCase):
@@ -186,6 +186,97 @@ def test_unknown_room_version(self):
self.assertNotIn(invite_room, [r.room_id for r in result.invited])
self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+ def test_ban_wins_race_with_join(self):
+ """Rooms shouldn't appear under "joined" if a join loses a race to a ban.
+
+ A complicated edge case. Imagine the following scenario:
+
+ * you attempt to join a room
+ * racing with that is a ban which comes in over federation, which ends up with
+ an earlier stream_ordering than the join.
+ * you get a sync response with a sync token which is _after_ the ban, but before
+ the join
+ * now your join lands; it is a valid event because its `prev_event`s predate the
+ ban, but will not make it into current_state_events (because bans win over
+ joins in state res, essentially).
+ * When we do an incremental sync from that sync token, the only event in the
+ timeline is your join ... and yet you aren't joined.
+
+ The ban coming in over federation isn't crucial for this behaviour; the key
+ requirements are:
+ 1. the homeserver generates a join event with prev_events that precede the ban
+ (so that it passes the "are you banned" test)
+ 2. the join event has a stream_ordering after that of the ban.
+
+ We use monkeypatching to artificially trigger condition (1).
+ """
+ # A local user Alice creates a room.
+ owner = self.register_user("alice", "password")
+ owner_tok = self.login(owner, "password")
+ room_id = self.helper.create_room_as(owner, is_public=True, tok=owner_tok)
+
+ # Do a sync as Alice to get the latest event in the room.
+ alice_sync_result: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ create_requester(owner), generate_sync_config(owner)
+ )
+ )
+ self.assertEqual(len(alice_sync_result.joined), 1)
+ self.assertEqual(alice_sync_result.joined[0].room_id, room_id)
+ last_room_creation_event_id = (
+ alice_sync_result.joined[0].timeline.events[-1].event_id
+ )
+
+ # Eve, a ne'er-do-well, registers.
+ eve = self.register_user("eve", "password")
+ eve_token = self.login(eve, "password")
+
+ # Alice preemptively bans Eve.
+ self.helper.ban(room_id, owner, eve, tok=owner_tok)
+
+ # Eve syncs.
+ eve_requester = create_requester(eve)
+ eve_sync_config = generate_sync_config(eve)
+ eve_sync_after_ban: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(eve_requester, eve_sync_config)
+ )
+
+ # Sanity check this sync result. We shouldn't be joined to the room.
+ self.assertEqual(eve_sync_after_ban.joined, [])
+
+ # Eve tries to join the room. We monkey patch the internal logic which selects
+ # the prev_events used when creating the join event, such that the ban does not
+ # precede the join.
+ mocked_get_prev_events = patch.object(
+ self.hs.get_datastore(),
+ "get_prev_events_for_room",
+ new_callable=MagicMock,
+ return_value=make_awaitable([last_room_creation_event_id]),
+ )
+ with mocked_get_prev_events:
+ self.helper.join(room_id, eve, tok=eve_token)
+
+ # Eve makes a second, incremental sync.
+ eve_incremental_sync_after_join: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ eve_requester,
+ eve_sync_config,
+ since_token=eve_sync_after_ban.next_batch,
+ )
+ )
+ # Eve should not see herself as joined to the room.
+ self.assertEqual(eve_incremental_sync_after_join.joined, [])
+
+ # If we did a third (initial) sync, we should _still_ see that Eve is not joined to the room.
+ eve_initial_sync_after_join: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ eve_requester,
+ eve_sync_config,
+ since_token=None,
+ )
+ )
+ self.assertEqual(eve_initial_sync_after_join.joined, [])
+
_request_key = 0
diff --git a/tests/http/test_webclient.py b/tests/http/test_webclient.py
new file mode 100644
index 000000000000..ee5cf299f64c
--- /dev/null
+++ b/tests/http/test_webclient.py
@@ -0,0 +1,108 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from http import HTTPStatus
+from typing import Dict
+
+from twisted.web.resource import Resource
+
+from synapse.app.homeserver import SynapseHomeServer
+from synapse.config.server import HttpListenerConfig, HttpResourceConfig, ListenerConfig
+from synapse.http.site import SynapseSite
+
+from tests.server import make_request
+from tests.unittest import HomeserverTestCase, create_resource_tree, override_config
+
+
+class WebClientTests(HomeserverTestCase):
+ @override_config(
+ {
+ "web_client_location": "https://example.org",
+ }
+ )
+ def test_webclient_resolves_with_client_resource(self):
+ """
+ Tests that both client and webclient resources can be accessed simultaneously.
+
+ This is a regression test created in response to https://github.com/matrix-org/synapse/issues/11763.
+ """
+ for resource_name_order_list in [
+ ["webclient", "client"],
+ ["client", "webclient"],
+ ]:
+ # Create a dictionary from path regex -> resource
+ resource_dict: Dict[str, Resource] = {}
+
+ for resource_name in resource_name_order_list:
+ resource_dict.update(
+ SynapseHomeServer._configure_named_resource(self.hs, resource_name)
+ )
+
+ # Create a root resource which ties the above resources together into one
+ root_resource = Resource()
+ create_resource_tree(resource_dict, root_resource)
+
+ # Create a site configured with this resource to make HTTP requests against
+ listener_config = ListenerConfig(
+ port=8008,
+ bind_addresses=["127.0.0.1"],
+ type="http",
+ http_options=HttpListenerConfig(
+ resources=[HttpResourceConfig(names=resource_name_order_list)]
+ ),
+ )
+ test_site = SynapseSite(
+ logger_name="synapse.access.http.fake",
+ site_tag=self.hs.config.server.server_name,
+ config=listener_config,
+ resource=root_resource,
+ server_version_string="1",
+ max_request_body_size=1234,
+ reactor=self.reactor,
+ )
+
+ # Attempt to make requests to endpoints on both the webclient and client resources
+ # on test_site.
+ self._request_client_and_webclient_resources(test_site)
+
+ def _request_client_and_webclient_resources(self, test_site: SynapseSite) -> None:
+ """Make a request to an endpoint on both the webclient and client-server resources
+ of the given SynapseSite.
+
+ Args:
+ test_site: The SynapseSite object to make requests against.
+ """
+
+ # Ensure that the *webclient* resource is behaving as expected (we get redirected to
+ # the configured web_client_location)
+ channel = make_request(
+ self.reactor,
+ site=test_site,
+ method="GET",
+ path="/_matrix/client",
+ )
+ # Check that we are being redirected to the webclient location URI.
+ self.assertEqual(channel.code, HTTPStatus.FOUND)
+ self.assertEqual(
+ channel.headers.getRawHeaders("Location"), ["https://example.org"]
+ )
+
+ # Ensure that a request to the *client* resource works.
+ channel = make_request(
+ self.reactor,
+ site=test_site,
+ method="GET",
+ path="/_matrix/client/v3/login",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertIn("flows", channel.json_body)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index b25a06b4271a..eca6a443af78 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,6 +20,7 @@
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
@@ -166,7 +167,7 @@ def test_push_actions_for_user(self):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
+ NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
)
self.persist(
@@ -179,7 +180,7 @@ def test_push_actions_for_user(self):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
+ NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
)
self.persist(
@@ -194,7 +195,7 @@ def test_push_actions_for_user(self):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
+ NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
)
def test_get_rooms_for_user_with_stream_ordering(self):
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 04a869e29549..1b6a4bf4b0b1 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -62,7 +62,11 @@ def test_federation_ack_sent(self):
"federation",
"master",
token=10,
- rows=[FederationStream.FederationStreamRow(type="x", data=[1, 2, 3])],
+ rows=[
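+ # The stream row's `data` field is now expected to be a dict rather than a bare list.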
+ FederationStream.FederationStreamRow(
+ type="x", data={"test": [1, 2, 3]}
+ )
+ ],
)
)
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 4d152c0d66c2..1e3fe9c62c00 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -23,6 +23,7 @@
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.storage.background_updates import BackgroundUpdater
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -96,7 +97,7 @@ def test_invalid_parameter(self) -> None:
def _register_bg_update(self) -> None:
"Adds a bg update but doesn't start it"
- async def _fake_update(progress, batch_size) -> int:
+ async def _fake_update(progress: JsonDict, batch_size: int) -> int:
await self.clock.sleep(0.2)
return batch_size
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 5188499ef2d6..b70350b6f1d7 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -16,11 +16,14 @@
from parameterized import parameterized
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
@@ -31,7 +34,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -44,7 +47,7 @@ def prepare(self, reactor, clock, hs: HomeServer):
("/_synapse/admin/v1/federation/destinations/dummy",),
]
)
- def test_requester_is_no_admin(self, url: str):
+ def test_requester_is_no_admin(self, url: str) -> None:
"""
If the user is not a server admin, an error 403 is returned.
"""
@@ -62,7 +65,7 @@ def test_requester_is_no_admin(self, url: str):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_invalid_parameter(self):
+ def test_invalid_parameter(self) -> None:
"""
If parameters are invalid, an error is returned.
"""
@@ -95,7 +98,7 @@ def test_invalid_parameter(self):
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
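+ # Invalid query parameters are now reported with a dedicated INVALID_PARAM error code.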
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
channel = self.make_request(
@@ -105,7 +108,7 @@ def test_invalid_parameter(self):
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid destination
channel = self.make_request(
@@ -117,7 +120,7 @@ def test_invalid_parameter(self):
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- def test_limit(self):
+ def test_limit(self) -> None:
"""
Testing list of destinations with limit
"""
@@ -137,7 +140,7 @@ def test_limit(self):
self.assertEqual(channel.json_body["next_token"], "5")
self._check_fields(channel.json_body["destinations"])
- def test_from(self):
+ def test_from(self) -> None:
"""
Testing list of destinations with a defined starting point (from)
"""
@@ -157,7 +160,7 @@ def test_from(self):
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["destinations"])
- def test_limit_and_from(self):
+ def test_limit_and_from(self) -> None:
"""
Testing list of destinations with a defined starting point and limit
"""
@@ -177,7 +180,7 @@ def test_limit_and_from(self):
self.assertEqual(len(channel.json_body["destinations"]), 10)
self._check_fields(channel.json_body["destinations"])
- def test_next_token(self):
+ def test_next_token(self) -> None:
"""
Testing that `next_token` appears at the right place
"""
@@ -238,7 +241,7 @@ def test_next_token(self):
self.assertEqual(len(channel.json_body["destinations"]), 1)
self.assertNotIn("next_token", channel.json_body)
- def test_list_all_destinations(self):
+ def test_list_all_destinations(self) -> None:
"""
List all destinations.
"""
@@ -259,7 +262,7 @@ def test_list_all_destinations(self):
# Check that all fields are available
self._check_fields(channel.json_body["destinations"])
- def test_order_by(self):
+ def test_order_by(self) -> None:
"""
Testing order list with parameter `order_by`
"""
@@ -268,7 +271,7 @@ def _order_test(
expected_destination_list: List[str],
order_by: Optional[str],
dir: Optional[str] = None,
- ):
+ ) -> None:
"""Request the list of destinations in a certain order.
Assert that order is what we expect
@@ -311,15 +314,12 @@ def _order_test(
retry_interval,
last_successful_stream_ordering,
) in dest:
- self.get_success(
- self.store.set_destination_retry_timings(
- destination, failure_ts, retry_last_ts, retry_interval
- )
- )
- self.get_success(
- self.store.set_destination_last_successful_stream_ordering(
- destination, last_successful_stream_ordering
- )
+ self._create_destination(
+ destination,
+ failure_ts,
+ retry_last_ts,
+ retry_interval,
+ last_successful_stream_ordering,
)
# order by default (destination)
@@ -358,13 +358,13 @@ def _order_test(
[dest[0][0], dest[2][0], dest[1][0]], "last_successful_stream_ordering", "b"
)
- def test_search_term(self):
+ def test_search_term(self) -> None:
"""Test that searching for a destination works correctly"""
def _search_test(
expected_destination: Optional[str],
search_term: str,
- ):
+ ) -> None:
"""Search for a destination and check that the returned destinationis a match
Args:
@@ -410,11 +410,9 @@ def _search_test(
_search_test(None, "foo")
_search_test(None, "bar")
- def test_get_single_destination(self):
- """
- Get one specific destinations.
- """
- self._create_destinations(5)
+ def test_get_single_destination_with_retry_timings(self) -> None:
+ """Get one specific destination which has retry timings."""
+ self._create_destinations(1)
channel = self.make_request(
"GET",
@@ -429,7 +427,54 @@ def test_get_single_destination(self):
# convert channel.json_body into a List
self._check_fields([channel.json_body])
- def _create_destinations(self, number_destinations: int):
+ def test_get_single_destination_no_retry_timings(self) -> None:
+ """Get one specific destination which has no retry timings."""
+ self._create_destination("sub0.example.com")
+
+ channel = self.make_request(
+ "GET",
+ self.url + "/sub0.example.com",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual("sub0.example.com", channel.json_body["destination"])
+ self.assertEqual(0, channel.json_body["retry_last_ts"])
+ self.assertEqual(0, channel.json_body["retry_interval"])
+ self.assertIsNone(channel.json_body["failure_ts"])
+ self.assertIsNone(channel.json_body["last_successful_stream_ordering"])
+
+ def _create_destination(
+ self,
+ destination: str,
+ failure_ts: Optional[int] = None,
+ retry_last_ts: int = 0,
+ retry_interval: int = 0,
+ last_successful_stream_ordering: Optional[int] = None,
+ ) -> None:
+ """Create one specific destination
+
+ Args:
+ destination: the destination we have successfully sent to
+ failure_ts: when the server started failing (ms since epoch)
+ retry_last_ts: time of last retry attempt in unix epoch ms
+ retry_interval: how long until next retry in ms
+ last_successful_stream_ordering: the stream_ordering of the most
+ recent successfully-sent PDU
+ """
+ self.get_success(
+ self.store.set_destination_retry_timings(
+ destination, failure_ts, retry_last_ts, retry_interval
+ )
+ )
+ if last_successful_stream_ordering is not None:
+ self.get_success(
+ self.store.set_destination_last_successful_stream_ordering(
+ destination, last_successful_stream_ordering
+ )
+ )
+
+ def _create_destinations(self, number_destinations: int) -> None:
"""Create a number of destinations
Args:
@@ -437,12 +482,9 @@ def _create_destinations(self, number_destinations: int):
"""
for i in range(0, number_destinations):
dest = f"sub{i}.example.com"
- self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50))
- self.get_success(
- self.store.set_destination_last_successful_stream_ordering(dest, 100)
- )
+ self._create_destination(dest, 50, 50, 50, 100)
- def _check_fields(self, content: List[JsonDict]):
+ def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that the expected destination attributes are present in content
Args:
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 81e578fd26c1..86aff7575c9a 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -360,7 +360,7 @@ def test_invalid_parameter(self) -> None:
channel.code,
msg=channel.json_body,
)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Boolean query parameter 'keep_profiles' must be one of ['true', 'false']",
channel.json_body["error"],
@@ -580,7 +580,9 @@ def _create_media(self) -> str:
return server_and_media_id
- def _access_media(self, server_and_media_id, expect_success=True) -> None:
+ def _access_media(
+ self, server_and_media_id: str, expect_success: bool = True
+ ) -> None:
"""
Try to access a media and check the result
"""
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 350a62dda672..8513b1d2df53 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -14,6 +14,7 @@
import random
import string
from http import HTTPStatus
+from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -42,21 +43,27 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/_synapse/admin/v1/registration_tokens"
- def _new_token(self, **kwargs) -> str:
+ def _new_token(
+ self,
+ token: Optional[str] = None,
+ uses_allowed: Optional[int] = None,
+ pending: int = 0,
+ completed: int = 0,
+ expiry_time: Optional[int] = None,
+ ) -> str:
"""Helper function to create a token."""
- token = kwargs.get(
- "token",
- "".join(random.choices(string.ascii_letters, k=8)),
- )
+ if token is None:
+ token = "".join(random.choices(string.ascii_letters, k=8))
+
self.get_success(
self.store.db_pool.simple_insert(
"registration_tokens",
{
"token": token,
- "uses_allowed": kwargs.get("uses_allowed", None),
- "pending": kwargs.get("pending", 0),
- "completed": kwargs.get("completed", 0),
- "expiry_time": kwargs.get("expiry_time", None),
+ "uses_allowed": uses_allowed,
+ "pending": pending,
+ "completed": completed,
+ "expiry_time": expiry_time,
},
)
)
@@ -216,20 +223,13 @@ def test_create_unable_to_generate_token(self) -> None:
# Create all possible single character tokens
tokens = []
for c in string.ascii_letters + string.digits + "._~-":
- tokens.append(
- {
- "token": c,
- "uses_allowed": None,
- "pending": 0,
- "completed": 0,
- "expiry_time": None,
- }
- )
+ tokens.append((c, None, 0, 0, None))
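+ # simple_insert_many now takes separate `keys` and `values` arguments (plus a
+ # `desc`) rather than a list of dicts.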
self.get_success(
self.store.db_pool.simple_insert_many(
"registration_tokens",
- tokens,
- "create_all_registration_tokens",
+ keys=("token", "uses_allowed", "pending", "completed", "expiry_time"),
+ values=tokens,
+ desc="create_all_registration_tokens",
)
)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 22f9aa62346a..3495a0366ad3 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -66,7 +66,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
)
self.url = "/_synapse/admin/v1/rooms/%s" % self.room_id
- def test_requester_is_no_admin(self):
+ def test_requester_is_no_admin(self) -> None:
"""
If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
"""
@@ -81,7 +81,7 @@ def test_requester_is_no_admin(self):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_room_does_not_exist(self):
+ def test_room_does_not_exist(self) -> None:
"""
Check that unknown rooms/server return 200
"""
@@ -96,7 +96,7 @@ def test_room_does_not_exist(self):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- def test_room_is_not_valid(self):
+ def test_room_is_not_valid(self) -> None:
"""
Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
"""
@@ -115,7 +115,7 @@ def test_room_is_not_valid(self):
channel.json_body["error"],
)
- def test_new_room_user_does_not_exist(self):
+ def test_new_room_user_does_not_exist(self) -> None:
"""
Tests that the user ID must be from local server but it does not have to exist.
"""
@@ -133,7 +133,7 @@ def test_new_room_user_does_not_exist(self):
self.assertIn("failed_to_kick_users", channel.json_body)
self.assertIn("local_aliases", channel.json_body)
- def test_new_room_user_is_not_local(self):
+ def test_new_room_user_is_not_local(self) -> None:
"""
Check that only local users can create new room to move members.
"""
@@ -151,7 +151,7 @@ def test_new_room_user_is_not_local(self):
channel.json_body["error"],
)
- def test_block_is_not_bool(self):
+ def test_block_is_not_bool(self) -> None:
"""
If parameter `block` is not boolean, return an error
"""
@@ -166,7 +166,7 @@ def test_block_is_not_bool(self):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
- def test_purge_is_not_bool(self):
+ def test_purge_is_not_bool(self) -> None:
"""
If parameter `purge` is not boolean, return an error
"""
@@ -181,7 +181,7 @@ def test_purge_is_not_bool(self):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
- def test_purge_room_and_block(self):
+ def test_purge_room_and_block(self) -> None:
"""Test to purge a room and block it.
Members will not be moved to a new room and will not receive a message.
"""
@@ -212,7 +212,7 @@ def test_purge_room_and_block(self):
self._is_blocked(self.room_id, expect=True)
self._has_no_members(self.room_id)
- def test_purge_room_and_not_block(self):
+ def test_purge_room_and_not_block(self) -> None:
"""Test to purge a room and do not block it.
Members will not be moved to a new room and will not receive a message.
"""
@@ -243,7 +243,7 @@ def test_purge_room_and_not_block(self):
self._is_blocked(self.room_id, expect=False)
self._has_no_members(self.room_id)
- def test_block_room_and_not_purge(self):
+ def test_block_room_and_not_purge(self) -> None:
"""Test to block a room without purging it.
Members will not be moved to a new room and will not receive a message.
The room will not be purged.
@@ -299,7 +299,7 @@ def test_block_unknown_room(self, purge: bool) -> None:
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self._is_blocked(room_id)
- def test_shutdown_room_consent(self):
+ def test_shutdown_room_consent(self) -> None:
"""Test that we can shutdown rooms with local users who have not
yet accepted the privacy policy. This used to fail when we tried to
force part the user from the old room.
@@ -351,7 +351,7 @@ def test_shutdown_room_consent(self):
self._is_purged(self.room_id)
self._has_no_members(self.room_id)
- def test_shutdown_room_block_peek(self):
+ def test_shutdown_room_block_peek(self) -> None:
"""Test that a world_readable room can no longer be peeked into after
it has been shut down.
Members will be moved to a new room and will receive a message.
@@ -400,7 +400,7 @@ def test_shutdown_room_block_peek(self):
# Assert we can no longer peek into the room
self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
- def _is_blocked(self, room_id, expect=True):
+ def _is_blocked(self, room_id: str, expect: bool = True) -> None:
"""Assert that the room is blocked or not"""
d = self.store.is_room_blocked(room_id)
if expect:
@@ -408,17 +408,17 @@ def _is_blocked(self, room_id, expect=True):
else:
self.assertIsNone(self.get_success(d))
- def _has_no_members(self, room_id):
+ def _has_no_members(self, room_id: str) -> None:
"""Assert there is now no longer anyone in the room"""
users_in_room = self.get_success(self.store.get_users_in_room(room_id))
self.assertEqual([], users_in_room)
- def _is_member(self, room_id, user_id):
+ def _is_member(self, room_id: str, user_id: str) -> None:
"""Test that user is member of the room"""
users_in_room = self.get_success(self.store.get_users_in_room(room_id))
self.assertIn(user_id, users_in_room)
- def _is_purged(self, room_id):
+ def _is_purged(self, room_id: str) -> None:
"""Test that the following tables have been purged of all rows related to the room."""
for table in PURGE_TABLES:
count = self.get_success(
@@ -432,7 +432,7 @@ def _is_purged(self, room_id):
self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
- def _assert_peek(self, room_id, expect_code):
+ def _assert_peek(self, room_id: str, expect_code: int) -> None:
"""Assert that the admin user can (or cannot) peek into the room."""
url = "rooms/%s/initialSync" % (room_id,)
@@ -492,7 +492,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
]
)
- def test_requester_is_no_admin(self, method: str, url: str):
+ def test_requester_is_no_admin(self, method: str, url: str) -> None:
"""
If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
"""
@@ -507,7 +507,7 @@ def test_requester_is_no_admin(self, method: str, url: str):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_room_does_not_exist(self):
+ def test_room_does_not_exist(self) -> None:
"""
Check that unknown rooms/server return 200
@@ -544,7 +544,7 @@ def test_room_does_not_exist(self):
("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
]
)
- def test_room_is_not_valid(self, method: str, url: str):
+ def test_room_is_not_valid(self, method: str, url: str) -> None:
"""
Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
"""
@@ -562,7 +562,7 @@ def test_room_is_not_valid(self, method: str, url: str):
channel.json_body["error"],
)
- def test_new_room_user_does_not_exist(self):
+ def test_new_room_user_does_not_exist(self) -> None:
"""
Tests that the user ID must be from local server but it does not have to exist.
"""
@@ -580,7 +580,7 @@ def test_new_room_user_does_not_exist(self):
self._test_result(delete_id, self.other_user, expect_new_room=True)
- def test_new_room_user_is_not_local(self):
+ def test_new_room_user_is_not_local(self) -> None:
"""
Check that only local users can create new room to move members.
"""
@@ -598,7 +598,7 @@ def test_new_room_user_is_not_local(self):
channel.json_body["error"],
)
- def test_block_is_not_bool(self):
+ def test_block_is_not_bool(self) -> None:
"""
If parameter `block` is not boolean, return an error
"""
@@ -613,7 +613,7 @@ def test_block_is_not_bool(self):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
- def test_purge_is_not_bool(self):
+ def test_purge_is_not_bool(self) -> None:
"""
If parameter `purge` is not boolean, return an error
"""
@@ -628,7 +628,7 @@ def test_purge_is_not_bool(self):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
- def test_delete_expired_status(self):
+ def test_delete_expired_status(self) -> None:
"""Test that the task status is removed after expiration."""
# first task, do not purge, that we can create a second task
@@ -699,7 +699,7 @@ def test_delete_expired_status(self):
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- def test_delete_same_room_twice(self):
+ def test_delete_same_room_twice(self) -> None:
"""Test that the call for delete a room at second time gives an exception."""
body = {"new_room_user_id": self.admin_user}
@@ -743,7 +743,7 @@ def test_delete_same_room_twice(self):
expect_new_room=True,
)
- def test_purge_room_and_block(self):
+ def test_purge_room_and_block(self) -> None:
"""Test to purge a room and block it.
Members will not be moved to a new room and will not receive a message.
"""
@@ -774,7 +774,7 @@ def test_purge_room_and_block(self):
self._is_blocked(self.room_id, expect=True)
self._has_no_members(self.room_id)
- def test_purge_room_and_not_block(self):
+ def test_purge_room_and_not_block(self) -> None:
"""Test to purge a room and do not block it.
Members will not be moved to a new room and will not receive a message.
"""
@@ -805,7 +805,7 @@ def test_purge_room_and_not_block(self):
self._is_blocked(self.room_id, expect=False)
self._has_no_members(self.room_id)
- def test_block_room_and_not_purge(self):
+ def test_block_room_and_not_purge(self) -> None:
"""Test to block a room without purging it.
Members will not be moved to a new room and will not receive a message.
The room will not be purged.
@@ -838,7 +838,7 @@ def test_block_room_and_not_purge(self):
self._is_blocked(self.room_id, expect=True)
self._has_no_members(self.room_id)
- def test_shutdown_room_consent(self):
+ def test_shutdown_room_consent(self) -> None:
"""Test that we can shutdown rooms with local users who have not
yet accepted the privacy policy. This used to fail when we tried to
force part the user from the old room.
@@ -899,7 +899,7 @@ def test_shutdown_room_consent(self):
self._is_purged(self.room_id)
self._has_no_members(self.room_id)
- def test_shutdown_room_block_peek(self):
+ def test_shutdown_room_block_peek(self) -> None:
"""Test that a world_readable room can no longer be peeked into after
it has been shut down.
Members will be moved to a new room and will receive a message.
@@ -1089,6 +1089,8 @@ def test_list_rooms(self) -> None:
)
room_ids.append(room_id)
+ room_ids.sort()
+
# Request the list of rooms
url = "/_synapse/admin/v1/rooms"
channel = self.make_request(
@@ -1360,6 +1362,12 @@ def _order_test(
room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ # Also create a list sorted by IDs for properties that are equal (and thus sorted by room_id)
+ sorted_by_room_id_asc = [room_id_1, room_id_2, room_id_3]
+ sorted_by_room_id_asc.sort()
+ sorted_by_room_id_desc = sorted_by_room_id_asc.copy()
+ sorted_by_room_id_desc.reverse()
+
# Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
self.helper.send_state(
room_id_1,
@@ -1405,41 +1413,42 @@ def _order_test(
_order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
_order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+ # Note: joined_member counts are sorted in descending order when dir=f
_order_test("joined_members", [room_id_3, room_id_2, room_id_1])
_order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+ # Note: joined_local_member counts are sorted in descending order when dir=f
_order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
_order_test(
"joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
)
- _order_test("version", [room_id_1, room_id_2, room_id_3])
- _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+ # Note: versions are sorted in descending order when dir=f
+ _order_test("version", sorted_by_room_id_asc, reverse=True)
+ _order_test("version", sorted_by_room_id_desc)
- _order_test("creator", [room_id_1, room_id_2, room_id_3])
- _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("creator", sorted_by_room_id_asc)
+ _order_test("creator", sorted_by_room_id_desc, reverse=True)
- _order_test("encryption", [room_id_1, room_id_2, room_id_3])
- _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("encryption", sorted_by_room_id_asc)
+ _order_test("encryption", sorted_by_room_id_desc, reverse=True)
- _order_test("federatable", [room_id_1, room_id_2, room_id_3])
- _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("federatable", sorted_by_room_id_asc)
+ _order_test("federatable", sorted_by_room_id_desc, reverse=True)
- _order_test("public", [room_id_1, room_id_2, room_id_3])
- # Different sort order of SQlite and PostreSQL
- # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+ _order_test("public", sorted_by_room_id_asc)
+ _order_test("public", sorted_by_room_id_desc, reverse=True)
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("join_rules", sorted_by_room_id_asc)
+ _order_test("join_rules", sorted_by_room_id_desc, reverse=True)
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("guest_access", sorted_by_room_id_asc)
+ _order_test("guest_access", sorted_by_room_id_desc, reverse=True)
- _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
- _order_test(
- "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
- )
+ _order_test("history_visibility", sorted_by_room_id_asc)
+ _order_test("history_visibility", sorted_by_room_id_desc, reverse=True)
+ # Note: state_event counts are sorted in descending order when dir=f
_order_test("state_events", [room_id_3, room_id_2, room_id_1])
_order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 4fedd5fd0851..9711405735db 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -608,7 +608,7 @@ def test_invalid_parameter(self):
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid deactivated
channel = self.make_request(
@@ -618,7 +618,7 @@ def test_invalid_parameter(self):
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# unknown order_by
channel = self.make_request(
@@ -628,7 +628,7 @@ def test_invalid_parameter(self):
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
channel = self.make_request(
@@ -638,7 +638,7 @@ def test_invalid_parameter(self):
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_limit(self):
"""
@@ -1181,6 +1181,7 @@ def prepare(self, reactor, clock, hs):
self.other_user, device_id=None, valid_until_ms=None
)
)
+
self.url_prefix = "/_synapse/admin/v2/users/%s"
self.url_other_user = self.url_prefix % self.other_user
@@ -1188,7 +1189,7 @@ def test_requester_is_no_admin(self):
"""
If the user is not a server admin, an error is returned.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
channel = self.make_request(
"GET",
@@ -1216,7 +1217,7 @@ def test_user_does_not_exist(self):
channel = self.make_request(
"GET",
- "/_synapse/admin/v2/users/@unknown_person:test",
+ self.url_prefix % "@unknown_person:test",
access_token=self.admin_user_tok,
)
@@ -1337,7 +1338,7 @@ def test_create_server_admin(self):
"""
Check that a new admin user is created successfully.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user (server admin)
body = {
@@ -1386,7 +1387,7 @@ def test_create_user(self):
"""
Check that a new regular user is created successfully.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
body = {
@@ -1478,7 +1479,7 @@ def test_create_user_mau_limit_reached_active_admin(self):
)
# Register new user with admin API
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
channel = self.make_request(
@@ -1515,7 +1516,7 @@ def test_create_user_mau_limit_reached_passive_admin(self):
)
# Register new user with admin API
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
channel = self.make_request(
@@ -1545,12 +1546,13 @@ def test_create_user_email_notif_for_new_users(self):
Check that a new regular user is created successfully and
gets an email pusher.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
body = {
"password": "abc123",
- "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ # Note that the given email is not in canonical form.
+ "threepids": [{"medium": "email", "address": "Bob@bob.bob"}],
}
channel = self.make_request(
@@ -1587,7 +1589,7 @@ def test_create_user_email_no_notif_for_new_users(self):
Check that a new regular user is created successfully and
does not get an email pusher.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
body = {
@@ -2084,10 +2086,13 @@ def test_deactivate_user(self):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
+
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
# the user is deactivated, the threepid will be deleted
# Get user
@@ -2100,11 +2105,13 @@ def test_deactivate_user(self):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
def test_change_name_deactivate_user_user_directory(self):
"""
@@ -2176,9 +2183,11 @@ def test_reactivate_user(self):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
- self.assertIsNotNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
@override_config({"password_config": {"localdb_enabled": False}})
def test_reactivate_user_localdb_disabled(self):
"""
@@ -2208,9 +2217,11 @@ def test_reactivate_user_localdb_disabled(self):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
@override_config({"password_config": {"enabled": False}})
def test_reactivate_user_password_disabled(self):
"""
@@ -2240,9 +2251,11 @@ def test_reactivate_user_password_disabled(self):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
def test_set_user_as_admin(self):
"""
Test setting the admin flag on a user.
@@ -2327,7 +2340,7 @@ def test_accidental_deactivation_prevention(self):
Ensure an account can't accidentally be deactivated by using a str value
for the deactivated body parameter
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
channel = self.make_request(
@@ -2391,18 +2404,20 @@ def _deactivate_user(self, user_id: str) -> None:
# Deactivate the user.
channel = self.make_request(
"PUT",
- "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
+ self.url_prefix % urllib.parse.quote(user_id),
access_token=self.admin_user_tok,
content={"deactivated": True},
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self._is_erased(user_id, False)
d = self.store.mark_user_erased(user_id)
self.assertIsNone(self.get_success(d))
self._is_erased(user_id, True)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
@@ -2415,13 +2430,15 @@ def _check_fields(self, content: JsonDict):
self.assertIn("admin", content)
self.assertIn("deactivated", content)
self.assertIn("shadow_banned", content)
- self.assertIn("password_hash", content)
self.assertIn("creation_ts", content)
self.assertIn("appservice_id", content)
self.assertIn("consent_server_notice_sent", content)
self.assertIn("consent_version", content)
self.assertIn("external_ids", content)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", content)
+
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
@@ -2896,7 +2913,7 @@ def test_invalid_parameter(self, method: str):
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
channel = self.make_request(
@@ -2906,7 +2923,7 @@ def test_invalid_parameter(self, method: str):
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
channel = self.make_request(
@@ -3882,3 +3899,93 @@ def test_success(self):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
+
+
+class AccountDataTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs) -> None:
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url = f"/_synapse/admin/v1/users/{self.other_user}/accountdata"
+
+ def test_no_auth(self) -> None:
+ """Try to get information of a user without authentication."""
+ channel = self.make_request("GET", self.url, {})
+
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self) -> None:
+ """If the user is not a server admin, an error is returned."""
+ other_user_token = self.login("user", "pass")
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=other_user_token,
+ )
+
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self) -> None:
+ """Tests that a lookup for a user that does not exist returns a 404"""
+ url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self) -> None:
+ """Tests that a lookup for a user that is not a local returns a 400"""
+ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/accountdata"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only look up local users", channel.json_body["error"])
+
+ def test_success(self) -> None:
+ """Request account data should succeed for an admin."""
+
+ # add account data
+ self.get_success(
+ self.store.add_account_data_for_user(self.other_user, "m.global", {"a": 1})
+ )
+ self.get_success(
+ self.store.add_account_data_to_room(
+ self.other_user, "test_room", "m.per_room", {"b": 2}
+ )
+ )
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ {"a": 1}, channel.json_body["account_data"]["global"]["m.global"]
+ )
+ self.assertEqual(
+ {"b": 2},
+ channel.json_body["account_data"]["rooms"]["test_room"]["m.per_room"],
+ )
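# For illustration only: the fields of the admin account-data response that test_success
# above asserts on (mirroring the test data; the real response may contain further entries).
expected_fields = {
    "account_data": {
        "global": {"m.global": {"a": 1}},
        "rooms": {"test_room": {"m.per_room": {"b": 2}}},
    }
}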
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 72bbc87b4a0c..27cb856b0acd 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -85,7 +85,7 @@ def recaptcha(
channel = self.make_request(
"GET", "auth/m.login.recaptcha/fallback/web?session=" + session
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, HTTPStatus.OK)
channel = self.make_request(
"POST",
@@ -104,7 +104,7 @@ def test_fallback_captcha(self):
"""Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec
channel = self.register(
- 401,
+ HTTPStatus.UNAUTHORIZED,
{"username": "user", "type": "m.login.password", "password": "bar"},
)
@@ -116,15 +116,17 @@ def test_fallback_captcha(self):
)
# Complete the recaptcha step.
- self.recaptcha(session, 200)
+ self.recaptcha(session, HTTPStatus.OK)
# also complete the dummy auth
- self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}})
+ self.register(
+ HTTPStatus.OK, {"auth": {"session": session, "type": "m.login.dummy"}}
+ )
# Now we should have fulfilled a complete auth flow, including
# the recaptcha fallback step, we can then send a
# request to the register API with the session in the authdict.
- channel = self.register(200, {"auth": {"session": session}})
+ channel = self.register(HTTPStatus.OK, {"auth": {"session": session}})
# We're given a registered user.
self.assertEqual(channel.json_body["user_id"], "@user:test")
@@ -137,7 +139,8 @@ def test_complete_operation_unknown_session(self):
# will be used.)
# Returns a 401 as per the spec
channel = self.register(
- 401, {"username": "user", "type": "m.login.password", "password": "bar"}
+ HTTPStatus.UNAUTHORIZED,
+ {"username": "user", "type": "m.login.password", "password": "bar"},
)
# Grab the session
@@ -231,7 +234,9 @@ def test_ui_auth(self):
"""
# Attempt to delete this device.
# Returns a 401 as per the spec
- channel = self.delete_device(self.user_tok, self.device_id, 401)
+ channel = self.delete_device(
+ self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+ )
# Grab the session
session = channel.json_body["session"]
@@ -242,7 +247,7 @@ def test_ui_auth(self):
self.delete_device(
self.user_tok,
self.device_id,
- 200,
+ HTTPStatus.OK,
{
"auth": {
"type": "m.login.password",
@@ -260,14 +265,16 @@ def test_grandfathered_identifier(self):
UIA - check that still works.
"""
- channel = self.delete_device(self.user_tok, self.device_id, 401)
+ channel = self.delete_device(
+ self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+ )
session = channel.json_body["session"]
# Make another request providing the UI auth flow.
self.delete_device(
self.user_tok,
self.device_id,
- 200,
+ HTTPStatus.OK,
{
"auth": {
"type": "m.login.password",
@@ -293,7 +300,9 @@ def test_can_change_body(self):
# Attempt to delete the first device.
# Returns a 401 as per the spec
- channel = self.delete_devices(401, {"devices": [self.device_id]})
+ channel = self.delete_devices(
+ HTTPStatus.UNAUTHORIZED, {"devices": [self.device_id]}
+ )
# Grab the session
session = channel.json_body["session"]
@@ -303,7 +312,7 @@ def test_can_change_body(self):
# Make another request providing the UI auth flow, but try to delete the
# second device.
self.delete_devices(
- 200,
+ HTTPStatus.OK,
{
"devices": ["dev2"],
"auth": {
@@ -324,7 +333,9 @@ def test_cannot_change_uri(self):
# Attempt to delete the first device.
# Returns a 401 as per the spec
- channel = self.delete_device(self.user_tok, self.device_id, 401)
+ channel = self.delete_device(
+ self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+ )
# Grab the session
session = channel.json_body["session"]
@@ -338,7 +349,7 @@ def test_cannot_change_uri(self):
self.delete_device(
self.user_tok,
"dev2",
- 403,
+ HTTPStatus.FORBIDDEN,
{
"auth": {
"type": "m.login.password",
@@ -361,13 +372,13 @@ def test_can_reuse_session(self):
self.login("test", self.user_pass, "dev3")
# Attempt to delete a device. This works since the user just logged in.
- self.delete_device(self.user_tok, "dev2", 200)
+ self.delete_device(self.user_tok, "dev2", HTTPStatus.OK)
# Move the clock forward past the validation timeout.
self.reactor.advance(6)
# Deleting another devices throws the user into UI auth.
- channel = self.delete_device(self.user_tok, "dev3", 401)
+ channel = self.delete_device(self.user_tok, "dev3", HTTPStatus.UNAUTHORIZED)
# Grab the session
session = channel.json_body["session"]
@@ -378,7 +389,7 @@ def test_can_reuse_session(self):
self.delete_device(
self.user_tok,
"dev3",
- 200,
+ HTTPStatus.OK,
{
"auth": {
"type": "m.login.password",
@@ -393,7 +404,7 @@ def test_can_reuse_session(self):
# due to re-using the previous session.
#
# Note that *no auth* information is provided, not even a session ID!
- self.delete_device(self.user_tok, self.device_id, 200)
+ self.delete_device(self.user_tok, self.device_id, HTTPStatus.OK)
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
@@ -413,7 +424,9 @@ def test_ui_auth_via_sso(self):
self.assertEqual(login_resp["user_id"], self.user)
# initiate a UI Auth process by attempting to delete the device
- channel = self.delete_device(self.user_tok, self.device_id, 401)
+ channel = self.delete_device(
+ self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+ )
# check that SSO is offered
flows = channel.json_body["flows"]
@@ -426,13 +439,13 @@ def test_ui_auth_via_sso(self):
)
# that should serve a confirmation page
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# and now the delete request should succeed.
self.delete_device(
self.user_tok,
self.device_id,
- 200,
+ HTTPStatus.OK,
body={"auth": {"session": session_id}},
)
@@ -445,13 +458,15 @@ def test_does_not_offer_password_for_sso_user(self):
# now call the device deletion API: we should get the option to auth with SSO
# and not password.
- channel = self.delete_device(user_tok, device_id, 401)
+ channel = self.delete_device(user_tok, device_id, HTTPStatus.UNAUTHORIZED)
flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
def test_does_not_offer_sso_for_password_user(self):
- channel = self.delete_device(self.user_tok, self.device_id, 401)
+ channel = self.delete_device(
+ self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+ )
flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.password"]}])
@@ -463,7 +478,9 @@ def test_offers_both_flows_for_upgraded_user(self):
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user)
- channel = self.delete_device(self.user_tok, self.device_id, 401)
+ channel = self.delete_device(
+ self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+ )
flows = channel.json_body["flows"]
# we have no particular expectations of ordering here
@@ -480,7 +497,9 @@ def test_ui_auth_fails_for_incorrect_sso_user(self):
self.assertEqual(login_resp["user_id"], self.user)
# start a UI Auth flow by attempting to delete a device
- channel = self.delete_device(self.user_tok, self.device_id, 401)
+ channel = self.delete_device(
+ self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+ )
flows = channel.json_body["flows"]
self.assertIn({"stages": ["m.login.sso"]}, flows)
@@ -496,7 +515,10 @@ def test_ui_auth_fails_for_incorrect_sso_user(self):
# ... and the delete op should now fail with a 403
self.delete_device(
- self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
+ self.user_tok,
+ self.device_id,
+ HTTPStatus.FORBIDDEN,
+ body={"auth": {"session": session_id}},
)
@@ -551,7 +573,9 @@ def test_login_issue_refresh_token(self):
login_without_refresh = self.make_request(
"POST", "/_matrix/client/r0/login", body
)
- self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result)
+ self.assertEqual(
+ login_without_refresh.code, HTTPStatus.OK, login_without_refresh.result
+ )
self.assertNotIn("refresh_token", login_without_refresh.json_body)
login_with_refresh = self.make_request(
@@ -559,7 +583,9 @@ def test_login_issue_refresh_token(self):
"/_matrix/client/r0/login",
{"refresh_token": True, **body},
)
- self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result)
+ self.assertEqual(
+ login_with_refresh.code, HTTPStatus.OK, login_with_refresh.result
+ )
self.assertIn("refresh_token", login_with_refresh.json_body)
self.assertIn("expires_in_ms", login_with_refresh.json_body)
@@ -577,7 +603,9 @@ def test_register_issue_refresh_token(self):
},
)
self.assertEqual(
- register_without_refresh.code, 200, register_without_refresh.result
+ register_without_refresh.code,
+ HTTPStatus.OK,
+ register_without_refresh.result,
)
self.assertNotIn("refresh_token", register_without_refresh.json_body)
@@ -591,7 +619,9 @@ def test_register_issue_refresh_token(self):
"refresh_token": True,
},
)
- self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result)
+ self.assertEqual(
+ register_with_refresh.code, HTTPStatus.OK, register_with_refresh.result
+ )
self.assertIn("refresh_token", register_with_refresh.json_body)
self.assertIn("expires_in_ms", register_with_refresh.json_body)
@@ -610,14 +640,14 @@ def test_token_refresh(self):
"/_matrix/client/r0/login",
body,
)
- self.assertEqual(login_response.code, 200, login_response.result)
+ self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
refresh_response = self.make_request(
"POST",
"/_matrix/client/v1/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
- self.assertEqual(refresh_response.code, 200, refresh_response.result)
+ self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
self.assertIn("access_token", refresh_response.json_body)
self.assertIn("refresh_token", refresh_response.json_body)
self.assertIn("expires_in_ms", refresh_response.json_body)
@@ -648,7 +678,7 @@ def test_refreshable_access_token_expiration(self):
"/_matrix/client/r0/login",
body,
)
- self.assertEqual(login_response.code, 200, login_response.result)
+ self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
self.assertApproximates(
login_response.json_body["expires_in_ms"], 60 * 1000, 100
)
@@ -658,7 +688,7 @@ def test_refreshable_access_token_expiration(self):
"/_matrix/client/v1/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
- self.assertEqual(refresh_response.code, 200, refresh_response.result)
+ self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
self.assertApproximates(
refresh_response.json_body["expires_in_ms"], 60 * 1000, 100
)
@@ -705,7 +735,7 @@ def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self)
"/_matrix/client/r0/login",
{"refresh_token": True, **body},
)
- self.assertEqual(login_response1.code, 200, login_response1.result)
+ self.assertEqual(login_response1.code, HTTPStatus.OK, login_response1.result)
self.assertApproximates(
login_response1.json_body["expires_in_ms"], 60 * 1000, 100
)
@@ -716,7 +746,7 @@ def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self)
"/_matrix/client/r0/login",
body,
)
- self.assertEqual(login_response2.code, 200, login_response2.result)
+ self.assertEqual(login_response2.code, HTTPStatus.OK, login_response2.result)
nonrefreshable_access_token = login_response2.json_body["access_token"]
# Advance 59 seconds in the future (just shy of 1 minute, the time of expiry)
@@ -818,7 +848,7 @@ def test_ultimate_session_expiry(self):
"/_matrix/client/r0/login",
body,
)
- self.assertEqual(login_response.code, 200, login_response.result)
+ self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
refresh_token = login_response.json_body["refresh_token"]
# Advance shy of 2 minutes into the future
@@ -826,7 +856,7 @@ def test_ultimate_session_expiry(self):
# Refresh our session. The refresh token should still be valid right now.
refresh_response = self.use_refresh_token(refresh_token)
- self.assertEqual(refresh_response.code, 200, refresh_response.result)
+ self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
self.assertIn(
"refresh_token",
refresh_response.json_body,
@@ -846,7 +876,9 @@ def test_ultimate_session_expiry(self):
# This should fail because the refresh token's lifetime has also been
# diminished as our session expired.
refresh_response = self.use_refresh_token(refresh_token)
- self.assertEqual(refresh_response.code, 403, refresh_response.result)
+ self.assertEqual(
+ refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
+ )
def test_refresh_token_invalidation(self):
"""Refresh tokens are invalidated after first use of the next token.
@@ -875,7 +907,7 @@ def test_refresh_token_invalidation(self):
"/_matrix/client/r0/login",
body,
)
- self.assertEqual(login_response.code, 200, login_response.result)
+ self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
# This first refresh should work properly
first_refresh_response = self.make_request(
@@ -884,7 +916,7 @@ def test_refresh_token_invalidation(self):
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(
- first_refresh_response.code, 200, first_refresh_response.result
+ first_refresh_response.code, HTTPStatus.OK, first_refresh_response.result
)
# This one as well, since the token in the first one was never used
@@ -894,7 +926,7 @@ def test_refresh_token_invalidation(self):
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(
- second_refresh_response.code, 200, second_refresh_response.result
+ second_refresh_response.code, HTTPStatus.OK, second_refresh_response.result
)
# This one should not, since the token from the first refresh is not valid anymore
@@ -904,7 +936,9 @@ def test_refresh_token_invalidation(self):
{"refresh_token": first_refresh_response.json_body["refresh_token"]},
)
self.assertEqual(
- third_refresh_response.code, 401, third_refresh_response.result
+ third_refresh_response.code,
+ HTTPStatus.UNAUTHORIZED,
+ third_refresh_response.result,
)
# The associated access token should also be invalid
@@ -913,7 +947,9 @@ def test_refresh_token_invalidation(self):
"/_matrix/client/r0/account/whoami",
access_token=first_refresh_response.json_body["access_token"],
)
- self.assertEqual(whoami_response.code, 401, whoami_response.result)
+ self.assertEqual(
+ whoami_response.code, HTTPStatus.UNAUTHORIZED, whoami_response.result
+ )
# But all other tokens should work (they will expire after some time)
for access_token in [
@@ -923,7 +959,9 @@ def test_refresh_token_invalidation(self):
whoami_response = self.make_request(
"GET", "/_matrix/client/r0/account/whoami", access_token=access_token
)
- self.assertEqual(whoami_response.code, 200, whoami_response.result)
+ self.assertEqual(
+ whoami_response.code, HTTPStatus.OK, whoami_response.result
+ )
# Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail
fourth_refresh_response = self.make_request(
@@ -932,7 +970,9 @@ def test_refresh_token_invalidation(self):
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(
- fourth_refresh_response.code, 403, fourth_refresh_response.result
+ fourth_refresh_response.code,
+ HTTPStatus.FORBIDDEN,
+ fourth_refresh_response.result,
)
# But refreshing from the last valid refresh token still works
@@ -942,5 +982,5 @@ def test_refresh_token_invalidation(self):
{"refresh_token": second_refresh_response.json_body["refresh_token"]},
)
self.assertEqual(
- fifth_refresh_response.code, 200, fifth_refresh_response.result
+ fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
)
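# For illustration only: a minimal sketch of the refresh call the tests above repeat,
# assuming the surrounding HomeserverTestCase context (self.make_request, login_response).
refresh_response = self.make_request(
    "POST",
    "/_matrix/client/v1/refresh",
    {"refresh_token": login_response.json_body["refresh_token"]},
)
# A successful refresh returns "access_token", "refresh_token" and "expires_in_ms";
# refreshing with an already-superseded token is rejected (the tests above expect
# UNAUTHORIZED or FORBIDDEN depending on the case).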
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 1b58b73136c5..c9b220e73d1a 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -16,13 +16,17 @@
import itertools
import urllib.parse
from typing import Dict, List, Optional, Tuple
+from unittest.mock import patch
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync
+from synapse.types import JsonDict
from tests import unittest
from tests.server import FakeChannel
+from tests.test_utils import make_awaitable
+from tests.test_utils.event_injection import inject_event
class RelationsTestCase(unittest.HomeserverTestCase):
@@ -90,11 +94,6 @@ def test_send_relation(self):
channel.json_body,
)
- def test_deny_membership(self):
- """Test that we deny relations on membership events"""
- channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
- self.assertEquals(400, channel.code, channel.json_body)
-
def test_deny_invalid_event(self):
"""Test that we deny relations on non-existant events"""
channel = self._send_relation(
@@ -456,7 +455,14 @@ def test_aggregation_must_be_annotation(self):
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_bundled_aggregations(self):
- """Test that annotations, references, and threads get correctly bundled."""
+ """
+ Test that annotations, references, and threads get correctly bundled.
+
+ Note that this doesn't test against /relations since only thread relations
+ get bundled via that API. See test_aggregation_get_event_for_thread.
+
+ See test_edit for a similar test for edits.
+ """
# Setup by sending a variety of relations.
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -484,12 +490,13 @@ def test_bundled_aggregations(self):
self.assertEquals(200, channel.code, channel.json_body)
thread_2 = channel.json_body["event_id"]
- def assert_bundle(actual):
+ def assert_bundle(event_json: JsonDict) -> None:
"""Assert the expected values of the bundled aggregations."""
+ relations_dict = event_json["unsigned"].get("m.relations")
# Ensure the fields are as expected.
self.assertCountEqual(
- actual.keys(),
+ relations_dict.keys(),
(
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
@@ -505,17 +512,20 @@ def assert_bundle(actual):
{"type": "m.reaction", "key": "b", "count": 1},
]
},
- actual[RelationTypes.ANNOTATION],
+ relations_dict[RelationTypes.ANNOTATION],
)
self.assertEquals(
{"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
- actual[RelationTypes.REFERENCE],
+ relations_dict[RelationTypes.REFERENCE],
)
self.assertEquals(
2,
- actual[RelationTypes.THREAD].get("count"),
+ relations_dict[RelationTypes.THREAD].get("count"),
+ )
+ self.assertTrue(
+ relations_dict[RelationTypes.THREAD].get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
self.assert_dict(
@@ -532,20 +542,9 @@ def assert_bundle(actual):
"type": "m.room.test",
"user_id": self.user_id,
},
- actual[RelationTypes.THREAD].get("latest_event"),
+ relations_dict[RelationTypes.THREAD].get("latest_event"),
)
- def _find_and_assert_event(events):
- """
- Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
- """
- for event in events:
- if event["event_id"] == self.parent_id:
- break
- else:
- raise AssertionError(f"Event {self.parent_id} not found in chunk")
- assert_bundle(event["unsigned"].get("m.relations"))
-
# Request the event directly.
channel = self.make_request(
"GET",
@@ -553,7 +552,7 @@ def _find_and_assert_event(events):
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["unsigned"].get("m.relations"))
+ assert_bundle(channel.json_body)
# Request the room messages.
channel = self.make_request(
@@ -562,7 +561,7 @@ def _find_and_assert_event(events):
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
- _find_and_assert_event(channel.json_body["chunk"])
+ assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
# Request the room context.
channel = self.make_request(
@@ -571,17 +570,14 @@ def _find_and_assert_event(events):
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations"))
+ assert_bundle(channel.json_body["event"])
# Request sync.
- # channel = self.make_request("GET", "/sync", access_token=self.user_token)
- # self.assertEquals(200, channel.code, channel.json_body)
- # room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- # self.assertTrue(room_timeline["limited"])
- # _find_and_assert_event(room_timeline["events"])
-
- # Note that /relations is tested separately in test_aggregation_get_event_for_thread
- # since it needs different data configured.
+ channel = self.make_request("GET", "/sync", access_token=self.user_token)
+ self.assertEquals(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ self._find_event_in_chunk(room_timeline["events"])
def test_aggregation_get_event_for_annotation(self):
"""Test that annotations do not get bundled aggregations included
@@ -651,6 +647,118 @@ def test_aggregation_get_event_for_thread(self):
},
)
+ @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+ def test_ignore_invalid_room(self):
+ """Test that we ignore invalid relations over federation."""
+ # Create another room and send a message in it.
+ room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
+ res = self.helper.send(room2, body="Hi!", tok=self.user_token)
+ parent_id = res["event_id"]
+
+ # Disable the validation to pretend this came over federation.
+ with patch(
+ "synapse.handlers.message.EventCreationHandler._validate_event_relation",
+ new=lambda self, event: make_awaitable(None),
+ ):
+ # Generate various relations from a different room.
+ self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room,
+ type="m.reaction",
+ sender=self.user_id,
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.ANNOTATION,
+ "event_id": parent_id,
+ "key": "A",
+ }
+ },
+ )
+ )
+
+ self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room,
+ type="m.room.message",
+ sender=self.user_id,
+ content={
+ "body": "foo",
+ "msgtype": "m.text",
+ "m.relates_to": {
+ "rel_type": RelationTypes.REFERENCE,
+ "event_id": parent_id,
+ },
+ },
+ )
+ )
+
+ self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room,
+ type="m.room.message",
+ sender=self.user_id,
+ content={
+ "body": "foo",
+ "msgtype": "m.text",
+ "m.relates_to": {
+ "rel_type": RelationTypes.THREAD,
+ "event_id": parent_id,
+ },
+ },
+ )
+ )
+
+ self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room,
+ type="m.room.message",
+ sender=self.user_id,
+ content={
+ "body": "foo",
+ "msgtype": "m.text",
+ "new_content": {
+ "body": "new content",
+ "msgtype": "m.text",
+ },
+ "m.relates_to": {
+ "rel_type": RelationTypes.REPLACE,
+ "event_id": parent_id,
+ },
+ },
+ )
+ )
+
+ # They should be ignored when fetching relations.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(channel.json_body["chunk"], [])
+
+ # And when fetching aggregations.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(channel.json_body["chunk"], [])
+
+ # And for bundled aggregations.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{room2}/event/{parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertNotIn("m.relations", channel.json_body["unsigned"])
+
def test_edit(self):
"""Test that a simple edit works."""
@@ -664,25 +772,58 @@ def test_edit(self):
edit_event_id = channel.json_body["event_id"]
+ def assert_bundle(event_json: JsonDict) -> None:
+ """Assert the expected values of the bundled aggregations."""
+ relations_dict = event_json["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ )
+
channel = self.make_request(
"GET",
- "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
-
self.assertEquals(channel.json_body["content"], new_body)
+ assert_bundle(channel.json_body)
- relations_dict = channel.json_body["unsigned"].get("m.relations")
- self.assertIn(RelationTypes.REPLACE, relations_dict)
+ # Request the room messages.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
- m_replace_dict = relations_dict[RelationTypes.REPLACE]
- for key in ["event_id", "sender", "origin_server_ts"]:
- self.assertIn(key, m_replace_dict)
+ # Request the room context.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/context/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body["event"])
- self.assert_dict(
- {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ # Request sync, but limit the timeline so it becomes limited (and includes
+ # bundled aggregations).
+ filter = urllib.parse.quote_plus(
+ '{"room": {"timeline": {"limit": 2}}}'.encode()
+ )
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
)
+ self.assertEquals(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
def test_multi_edit(self):
"""Test that multiple edits, including attempts by people who
@@ -989,6 +1130,16 @@ def test_unknown_relations(self):
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["chunk"], [])
+ def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+ """
+ Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
+ """
+ for event in events:
+ if event["event_id"] == self.parent_id:
+ return event
+
+ raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
def _send_relation(
self,
relation_type: str,
@@ -1004,7 +1155,8 @@ def _send_relation(
relation_type: One of `RelationTypes`
event_type: The type of the event to create
key: The aggregation key used for m.annotation relation type.
- content: The content of the created event.
+ content: The content of the created event. Will be modified to configure
+ the m.relates_to key based on the other provided parameters.
access_token: The access token used to send the relation, defaults
to `self.user_token`
parent_id: The event_id this relation relates to. If None, then self.parent_id
@@ -1015,17 +1167,21 @@ def _send_relation(
if not access_token:
access_token = self.user_token
- query = ""
- if key:
- query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8"))
-
original_id = parent_id if parent_id else self.parent_id
+ if content is None:
+ content = {}
+ content["m.relates_to"] = {
+ "event_id": original_id,
+ "rel_type": relation_type,
+ }
+ if key is not None:
+ content["m.relates_to"]["key"] = key
+
channel = self.make_request(
"POST",
- "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
- % (self.room, original_id, relation_type, event_type, query),
- content or {},
+ f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}",
+ content,
access_token=access_token,
)
return channel
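# For illustration only: the event content that the updated _send_relation helper above
# now POSTs to /_matrix/client/v3/rooms/<room_id>/send/<event_type>, replacing the old
# unstable /send_relation endpoint ($parent is a hypothetical parent event ID).
content = {
    "m.relates_to": {
        "event_id": "$parent",
        "rel_type": "m.annotation",  # i.e. RelationTypes.ANNOTATION
        "key": "A",                  # only included for annotations
    },
}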
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index b58452195a82..fe5b536d9705 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -228,7 +228,7 @@ def get_event(self, event_id, expect_none=False):
self.assertIsNotNone(event)
time_now = self.clock.time_msec()
- serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+ serialized = self.serializer.serialize_event(event, time_now)
return serialized
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
new file mode 100644
index 000000000000..721454c1875f
--- /dev/null
+++ b/tests/rest/client/test_room_batch.py
@@ -0,0 +1,180 @@
+import logging
+from typing import List, Tuple
+from unittest.mock import Mock, patch
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.appservice import ApplicationService
+from synapse.rest import admin
+from synapse.rest.client import login, register, room, room_batch
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests import unittest
+
+logger = logging.getLogger(__name__)
+
+
+def _create_join_state_events_for_batch_send_request(
+ virtual_user_ids: List[str],
+ insert_time: int,
+) -> List[JsonDict]:
+ return [
+ {
+ "type": EventTypes.Member,
+ "sender": virtual_user_id,
+ "origin_server_ts": insert_time,
+ "content": {
+ "membership": "join",
+ "displayname": "display-name-for-%s" % (virtual_user_id,),
+ },
+ "state_key": virtual_user_id,
+ }
+ for virtual_user_id in virtual_user_ids
+ ]
+
+
+def _create_message_events_for_batch_send_request(
+ virtual_user_id: str, insert_time: int, count: int
+) -> List[JsonDict]:
+ return [
+ {
+ "type": EventTypes.Message,
+ "sender": virtual_user_id,
+ "origin_server_ts": insert_time,
+ "content": {
+ "msgtype": "m.text",
+ "body": "Historical %d" % (i),
+ EventContentFields.MSC2716_HISTORICAL: True,
+ },
+ }
+ for i in range(count)
+ ]
+
+
+class RoomBatchTestCase(unittest.HomeserverTestCase):
+ """Test importing batches of historical messages."""
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room_batch.register_servlets,
+ room.register_servlets,
+ register.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+
+ self.appservice = ApplicationService(
+ token="i_am_an_app_service",
+ hostname="test",
+ id="1234",
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ # Note: this user does not have to match the regex above
+ sender="@as_main:test",
+ )
+
+ mock_load_appservices = Mock(return_value=[self.appservice])
+ with patch(
+ "synapse.storage.databases.main.appservice.load_appservices",
+ mock_load_appservices,
+ ):
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.clock = clock
+ self.storage = hs.get_storage()
+
+ self.virtual_user_id = self.register_appservice_user(
+ "as_user_potato", self.appservice.token
+ )
+
+ def _create_test_room(self) -> Tuple[str, str, str, str]:
+ room_id = self.helper.create_room_as(
+ self.appservice.sender, tok=self.appservice.token
+ )
+
+ res_a = self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "A",
+ },
+ tok=self.appservice.token,
+ )
+ event_id_a = res_a["event_id"]
+
+ res_b = self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "B",
+ },
+ tok=self.appservice.token,
+ )
+ event_id_b = res_b["event_id"]
+
+ res_c = self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "C",
+ },
+ tok=self.appservice.token,
+ )
+ event_id_c = res_c["event_id"]
+
+ return room_id, event_id_a, event_id_b, event_id_c
+
+ @unittest.override_config({"experimental_features": {"msc2716_enabled": True}})
+ def test_same_state_groups_for_whole_historical_batch(self):
+ """Make sure that when using the `/batch_send` endpoint to import a
+ bunch of historical messages, it re-uses the same `state_group` across
+ the whole batch. This is an easy optimization to verify because
+ the state for the whole batch is contained in
+ `state_events_at_start` and can be shared across everything.
+ """
+
+ time_before_room = int(self.clock.time_msec())
+ room_id, event_id_a, _, _ = self._create_test_room()
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s"
+ % (room_id, event_id_a),
+ content={
+ "events": _create_message_events_for_batch_send_request(
+ self.virtual_user_id, time_before_room, 3
+ ),
+ "state_events_at_start": _create_join_state_events_for_batch_send_request(
+ [self.virtual_user_id], time_before_room
+ ),
+ },
+ access_token=self.appservice.token,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Get the historical event IDs that we just imported
+ historical_event_ids = channel.json_body["event_ids"]
+ self.assertEqual(len(historical_event_ids), 3)
+
+ # Fetch the state_groups
+ state_group_map = self.get_success(
+ self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
+ )
+
+ # We expect all of the historical events to be using the same state_group
+ # so there should only be a single state_group here!
+ self.assertEqual(
+ len(state_group_map.keys()),
+ 1,
+ "Expected a single state_group to be returned by saw state_groups=%s"
+ % (state_group_map.keys(),),
+ )
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 1af5e5cee504..842438358021 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -196,6 +196,16 @@ def leave(self, room=None, user=None, expect_code=200, tok=None):
expect_code=expect_code,
)
+ def ban(self, room: str, src: str, targ: str, **kwargs: object):
+ """A convenience helper: `change_membership` with `membership` preset to "ban"."""
+ self.change_membership(
+ room=room,
+ src=src,
+ targ=targ,
+ membership=Membership.BAN,
+ **kwargs,
+ )
+
def change_membership(
self,
room: str,
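# For illustration only: typical use of the new ban() helper in a test, assuming the
# usual change_membership keyword arguments (e.g. tok=..., expect_code=...) are
# forwarded via **kwargs.
self.helper.ban(room_id, src=self.admin_user, targ=self.other_user, tok=self.admin_user_tok)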
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 8698135a769d..16e904f15b45 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -1,4 +1,5 @@
# Copyright 2018 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/server.py b/tests/server.py
index 40cf5b12c3a0..a0cd14ea45b4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import hashlib
import json
import logging
+import os
+import os.path
+import time
+import uuid
+import warnings
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
@@ -27,6 +32,7 @@
Type,
Union,
)
+from unittest.mock import Mock
import attr
from typing_extensions import Deque
@@ -53,11 +59,25 @@
from twisted.web.resource import IResource
from twisted.web.server import Request, Site
+from synapse.config.database import DatabaseConnectionConfig
from synapse.http.site import SynapseRequest
+from synapse.server import HomeServer
+from synapse.storage import DataStore
+from synapse.storage.engines import PostgresEngine, create_engine
from synapse.types import JsonDict
from synapse.util import Clock
-from tests.utils import setup_test_homeserver as _sth
+from tests.utils import (
+ LEAVE_DB,
+ POSTGRES_BASE_DB,
+ POSTGRES_HOST,
+ POSTGRES_PASSWORD,
+ POSTGRES_USER,
+ SQLITE_PERSIST_DB,
+ USE_POSTGRES_FOR_TESTS,
+ MockClock,
+ default_config,
+)
logger = logging.getLogger(__name__)
@@ -450,14 +470,11 @@ def _(res):
return d
-def setup_test_homeserver(cleanup_func, *args, **kwargs):
+def _make_test_homeserver_synchronous(server: HomeServer) -> None:
"""
- Set up a synchronous test server, driven by the reactor used by
- the homeserver.
+ Make the given test homeserver's database interactions synchronous.
"""
- server = _sth(cleanup_func, *args, **kwargs)
- # Make the thread pool synchronous.
clock = server.get_clock()
for database in server.get_datastores().databases:
@@ -485,6 +502,7 @@ def runInteraction(interaction, *args, **kwargs):
pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction
+ # Replace the thread pool with a threadless 'thread' pool
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
@@ -492,8 +510,6 @@ def runInteraction(interaction, *args, **kwargs):
# thread, so we need to disable the dedicated thread behaviour.
server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
- return server
-
def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
clock = ThreadedMemoryReactorClock()
@@ -673,3 +689,185 @@ def connect_client(
client.makeConnection(FakeTransport(server, reactor))
return client, server
+
+
+class TestHomeServer(HomeServer):
+ DATASTORE_CLASS = DataStore
+
+
+def setup_test_homeserver(
+ cleanup_func,
+ name="test",
+ config=None,
+ reactor=None,
+ homeserver_to_use: Type[HomeServer] = TestHomeServer,
+ **kwargs,
+):
+ """
+ Set up a homeserver suitable for running tests against. Keyword arguments
+ are passed to the Homeserver constructor.
+
+ If no datastore is supplied, one is created and given to the homeserver.
+
+ Args:
+ cleanup_func : The function used to register a cleanup routine for
+ after the test.
+
+ Calling this method directly is deprecated: you should instead derive from
+ HomeserverTestCase.
+ """
+ if reactor is None:
+ from twisted.internet import reactor
+
+ if config is None:
+ config = default_config(name, parse=True)
+
+ config.ldap_enabled = False
+
+ if "clock" not in kwargs:
+ kwargs["clock"] = MockClock()
+
+ if USE_POSTGRES_FOR_TESTS:
+ test_db = "synapse_test_%s" % uuid.uuid4().hex
+
+ database_config = {
+ "name": "psycopg2",
+ "args": {
+ "database": test_db,
+ "host": POSTGRES_HOST,
+ "password": POSTGRES_PASSWORD,
+ "user": POSTGRES_USER,
+ "cp_min": 1,
+ "cp_max": 5,
+ },
+ }
+ else:
+ if SQLITE_PERSIST_DB:
+ # The current working directory is in _trial_temp, so this gets created within that directory.
+ test_db_location = os.path.abspath("test.db")
+ logger.debug("Will persist db to %s", test_db_location)
+ # Ensure each test gets a clean database.
+ try:
+ os.remove(test_db_location)
+ except FileNotFoundError:
+ pass
+ else:
+ logger.debug("Removed existing DB at %s", test_db_location)
+ else:
+ test_db_location = ":memory:"
+
+ database_config = {
+ "name": "sqlite3",
+ "args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
+ }
+
+ if "db_txn_limit" in kwargs:
+ database_config["txn_limit"] = kwargs["db_txn_limit"]
+
+ database = DatabaseConnectionConfig("master", database_config)
+ config.database.databases = [database]
+
+ db_engine = create_engine(database.config)
+
+ # Create the database before we actually try to connect to it, based on
+ # the template database we generate in setupdb()
+ if isinstance(db_engine, PostgresEngine):
+ db_conn = db_engine.module.connect(
+ database=POSTGRES_BASE_DB,
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ )
+ db_conn.autocommit = True
+ cur = db_conn.cursor()
+ cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+ cur.execute(
+ "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
+ )
+ cur.close()
+ db_conn.close()
+
+ hs = homeserver_to_use(
+ name,
+ config=config,
+ version_string="Synapse/tests",
+ reactor=reactor,
+ )
+
+ # Install @cache_in_self attributes
+ for key, val in kwargs.items():
+ setattr(hs, "_" + key, val)
+
+ # Mock TLS
+ hs.tls_server_context_factory = Mock()
+ hs.tls_client_options_factory = Mock()
+
+ hs.setup()
+ if homeserver_to_use == TestHomeServer:
+ hs.setup_background_tasks()
+
+ if isinstance(db_engine, PostgresEngine):
+ database = hs.get_datastores().databases[0]
+
+ # We need to do cleanup on PostgreSQL
+ def cleanup():
+ import psycopg2
+
+ # Close all the db pools
+ database._db_pool.close()
+
+ dropped = False
+
+ # Drop the test database
+ db_conn = db_engine.module.connect(
+ database=POSTGRES_BASE_DB,
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ )
+ db_conn.autocommit = True
+ cur = db_conn.cursor()
+
+ # Try a few times to drop the DB. Some things may hold on to the
+ # database for a few more seconds due to flakiness, preventing
+ # us from dropping it when the test is over. If we can't drop
+ # it, warn and move on.
+ for _ in range(5):
+ try:
+ cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+ db_conn.commit()
+ dropped = True
+ except psycopg2.OperationalError as e:
+ warnings.warn(
+ "Couldn't drop old db: " + str(e), category=UserWarning
+ )
+ time.sleep(0.5)
+
+ cur.close()
+ db_conn.close()
+
+ if not dropped:
+ warnings.warn("Failed to drop old DB.", category=UserWarning)
+
+ if not LEAVE_DB:
+ # Register the cleanup hook
+ cleanup_func(cleanup)
+
+ # bcrypt is far too slow to be used in unit tests.
+ # We need to let the HS build an auth handler and then patch it, because
+ # AuthHandler's constructor requires the HS, so we can't build one
+ # beforehand and pass it into the HS's constructor (chicken / egg).
+ async def hash(p):
+ return hashlib.md5(p.encode("utf8")).hexdigest()
+
+ hs.get_auth_handler().hash = hash
+
+ async def validate_hash(p, h):
+ return hashlib.md5(p.encode("utf8")).hexdigest() == h
+
+ hs.get_auth_handler().validate_hash = validate_hash
+
+ # Make the threadpool and database transactions synchronous for testing.
+ _make_test_homeserver_synchronous(hs)
+
+ return hs
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 01af49a16b0f..d697d2bc1e79 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable, Set
+from typing import Iterable, Optional, Set
from synapse.api.constants import AccountDataTypes
@@ -25,7 +25,7 @@ def prepare(self, hs, reactor, clock):
self.user = "@user:test"
def _update_ignore_list(
- self, *ignored_user_ids: Iterable[str], ignorer_user_id: str = None
+ self, *ignored_user_ids: Iterable[str], ignorer_user_id: Optional[str] = None
) -> None:
"""Update the account data to block the given users."""
if ignorer_user_id is None:
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index d77c001506c6..6156dfac4e58 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Use backported mock for AsyncMock support on Python 3.6.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet.defer import Deferred, ensureDeferred
from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest
-from tests.test_utils import make_awaitable
+from tests.test_utils import make_awaitable, simple_async_mock
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
@@ -116,14 +115,14 @@ def prepare(self, reactor, clock, homeserver):
)
# Mock out the AsyncContextManager
- self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"])
- self._update_ctx_manager.__aenter__ = Mock(
- return_value=make_awaitable(None),
- )
- self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None))
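+ # (special methods such as __aenter__ are looked up on the type, not the
+ # instance, so they are defined on a class rather than assigned to a Mock)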
+ class MockCM:
+ __aenter__ = simple_async_mock(return_value=None)
+ __aexit__ = simple_async_mock(return_value=None)
+
+ self._update_ctx_manager = MockCM
# Mock out the `update_handler` callback
- self._on_update = Mock(return_value=self._update_ctx_manager)
+ self._on_update = Mock(return_value=self._update_ctx_manager())
# Define a default batch size value that's not the same as the internal default
# value (100).
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index ddad44bd6cbb..3e4f0579c9cb 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -23,7 +23,8 @@
from synapse.storage.engines import create_engine
from tests import unittest
-from tests.utils import TestHomeServer, default_config
+from tests.server import TestHomeServer
+from tests.utils import default_config
class SQLBaseStoreTestCase(unittest.TestCase):
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 6790aa524291..b547bf8d9978 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -94,7 +94,7 @@ def test_count_devices_by_users(self):
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
- # Add two device updates with a single stream_id
+ # Add two device updates with sequential `stream_id`s
self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)
@@ -107,6 +107,164 @@ def test_get_device_updates_by_remote(self):
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
+ def test_get_device_updates_by_remote_can_limit_properly(self):
+ """
+ Tests that `get_device_updates_by_remote` returns an appropriate
+ stream_id to resume fetching from (without skipping any results).
+ """
+
+ # Add some device updates with sequential `stream_id`s
+ device_ids = [
+ "device_id1",
+ "device_id2",
+ "device_id3",
+ "device_id4",
+ "device_id5",
+ ]
+ self.get_success(
+ self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ )
+
+ # Get device updates meant for this remote
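+ # (passing -1 as the stream position means "from the very start of the stream")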
+ next_stream_id, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", -1, limit=3)
+ )
+
+ # Check the first three original device_ids are contained within these updates
+ self._check_devices_in_updates(device_ids[:3], device_updates)
+
+ # Get the next batch of device updates
+ next_stream_id, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+ )
+
+ # Check the last two original device_ids are contained within these updates
+ self._check_devices_in_updates(device_ids[3:], device_updates)
+
+ # Add some more device updates to ensure it still resumes properly
+ device_ids = ["device_id6", "device_id7"]
+ self.get_success(
+ self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ )
+
+ # Get the next batch of device updates
+ next_stream_id, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+ )
+
+ # Check the newly-added device_ids are contained within these updates
+ self._check_devices_in_updates(device_ids, device_updates)
+
+ # Check there are no more device updates left.
+ _, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+ )
+ self.assertEqual(device_updates, [])
+
+ def test_get_device_updates_by_remote_cross_signing_key_updates(
+ self,
+ ) -> None:
+ """
+ Tests that `get_device_updates_by_remote` limits the length of the return value
+ properly when cross-signing key updates are present.
+ Current behaviour is that the cross-signing key updates will always come in pairs,
+ even if that means leaving an earlier batch one EDU short of the limit.
+ """
+
+ assert self.hs.is_mine_id(
+ "@user_id:test"
+ ), "Test not valid: this MXID should be considered local"
+
+ self.get_success(
+ self.store.set_e2e_cross_signing_key(
+ "@user_id:test",
+ "master",
+ {
+ "keys": {
+ "ed25519:fakeMaster": "aaafakefakefake1AAAAAAAAAAAAAAAAAAAAAAAAAAA="
+ },
+ "signatures": {
+ "@user_id:test": {
+ "ed25519:fake2": "aaafakefakefake2AAAAAAAAAAAAAAAAAAAAAAAAAAA="
+ }
+ },
+ },
+ )
+ )
+ self.get_success(
+ self.store.set_e2e_cross_signing_key(
+ "@user_id:test",
+ "self_signing",
+ {
+ "keys": {
+ "ed25519:fakeSelfSigning": "aaafakefakefake3AAAAAAAAAAAAAAAAAAAAAAAAAAA="
+ },
+ "signatures": {
+ "@user_id:test": {
+ "ed25519:fake4": "aaafakefakefake4AAAAAAAAAAAAAAAAAAAAAAAAAAA="
+ }
+ },
+ },
+ )
+ )
+
+ # Add some device updates with sequential `stream_id`s
+ # Note that the public cross-signing keys occupy the same space as device IDs,
+ # so also notify that those have updated.
+ device_ids = [
+ "device_id1",
+ "device_id2",
+ "fakeMaster",
+ "fakeSelfSigning",
+ ]
+
+ self.get_success(
+ self.store.add_device_change_to_streams(
+ "@user_id:test", device_ids, ["somehost"]
+ )
+ )
+
+ # Get device updates meant for this remote
+ next_stream_id, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", -1, limit=3)
+ )
+
+ # Here we expect the device updates for `device_id1` and `device_id2`,
+ # i.e. only 2 updates this time around.
+ # With a higher limit, we would also expect to see the pair of
+ # (unstable-prefixed & unprefixed) signing key updates for the cross-signing
+ # keys represented by `fakeMaster` and `fakeSelfSigning`.
+ # Our implementation only sends these two variants together, so we get
+ # a short batch.
+ self.assertEqual(len(device_updates), 2, device_updates)
+
+ # Check the first two devices (device_id1, device_id2) came out.
+ self._check_devices_in_updates(device_ids[:2], device_updates)
+
+ # Get more device updates meant for this remote
+ next_stream_id, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+ )
+
+ # The next 2 updates should be the cross-signing key update: the master key
+ # update and the self-signing key update are combined into one
+ # 'signing key update', which is emitted twice, once with an unprefixed type
+ # and once with an unstable-prefixed type.
+ # (This is a temporary arrangement for backwards compatibility!)
+ self.assertEqual(len(device_updates), 2, device_updates)
+ self.assertEqual(
+ device_updates[0][0], "m.signing_key_update", device_updates[0]
+ )
+ self.assertEqual(
+ device_updates[1][0], "org.matrix.signing_key_update", device_updates[1]
+ )
+
+ # Check there are no more device updates left.
+ _, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+ )
+ self.assertEqual(device_updates, [])
+
def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 9b6b42542532..7556171d8ac4 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage.databases.main.e2e_room_keys import RoomKey
+
from tests import unittest
# sample room_key data for use in the tests
-room_key = {
+room_key: RoomKey = {
"first_message_index": 1,
"forwarded_count": 1,
"is_verified": False,
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index c3fcf7e7b405..2bc89512f8ae 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Tuple, Union
+
import attr
from parameterized import parameterized
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import (
+ KNOWN_ROOM_VERSIONS,
+ EventFormatVersions,
+ RoomVersion,
+)
from synapse.events import _EventInternalMetadata
from synapse.util import json_encoder
@@ -506,26 +512,44 @@ def insert_event(txn):
)
self.assertSetEqual(difference, set())
- def test_prune_inbound_federation_queue(self):
- "Test that pruning of inbound federation queues work"
+ @parameterized.expand(
+ [(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()]
+ )
+ def test_prune_inbound_federation_queue(self, room_version: RoomVersion):
+ """Test that pruning of inbound federation queues work"""
room_id = "some_room_id"
+ def prev_event_format(prev_event_id: str) -> Union[Tuple[str, dict], str]:
+ """Account for differences in prev_events format across room versions"""
+ if room_version.event_format == EventFormatVersions.V1:
+ return prev_event_id, {}
+
+ return prev_event_id
+
# Insert a bunch of events that all reference the previous one.
self.get_success(
self.store.db_pool.simple_insert_many(
table="federation_inbound_events_staging",
+ keys=(
+ "origin",
+ "room_id",
+ "received_ts",
+ "event_id",
+ "event_json",
+ "internal_metadata",
+ ),
values=[
- {
- "origin": "some_origin",
- "room_id": room_id,
- "received_ts": 0,
- "event_id": f"$fake_event_id_{i + 1}",
- "event_json": json_encoder.encode(
- {"prev_events": [f"$fake_event_id_{i}"]}
+ (
+ "some_origin",
+ room_id,
+ 0,
+ f"$fake_event_id_{i + 1}",
+ json_encoder.encode(
+ {"prev_events": [prev_event_format(f"$fake_event_id_{i}")]}
),
- "internal_metadata": "{}",
- }
+ "{}",
+ )
for i in range(500)
],
desc="test_prune_inbound_federation_queue",
@@ -535,12 +559,12 @@ def test_prune_inbound_federation_queue(self):
# Calling prune once should return True, i.e. a prune happened. The second
# time it shouldn't.
pruned = self.get_success(
- self.store.prune_staged_events_in_room(room_id, RoomVersions.V6)
+ self.store.prune_staged_events_in_room(room_id, room_version)
)
self.assertTrue(pruned)
pruned = self.get_success(
- self.store.prune_staged_events_in_room(room_id, RoomVersions.V6)
+ self.store.prune_staged_events_in_room(room_id, room_version)
)
self.assertFalse(pruned)
@@ -550,7 +574,7 @@ def test_prune_inbound_federation_queue(self):
self.store.db_pool.simple_select_one_onecol(
table="federation_inbound_events_staging",
keyvalues={"room_id": room_id},
- retcol="COALESCE(COUNT(*), 0)",
+ retcol="COUNT(*)",
desc="test_prune_inbound_federation_queue",
)
)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index bb5939ba4a51..738f3ad1dcc7 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -14,6 +14,8 @@
from unittest.mock import Mock
+from synapse.storage.databases.main.event_push_actions import NotifCounts
+
from tests.unittest import HomeserverTestCase
USER_ID = "@user:example.com"
@@ -57,11 +59,11 @@ def _assert_counts(noitf_count, highlight_count):
)
self.assertEquals(
counts,
- {
- "notify_count": noitf_count,
- "unread_count": 0, # Unread counts are tested in the sync tests.
- "highlight_count": highlight_count,
- },
+ NotifCounts(
+ notify_count=noitf_count,
+ unread_count=0, # Unread counts are tested in the sync tests.
+ highlight_count=highlight_count,
+ ),
)
def _inject_actions(stream, action):
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index fccab733c029..5cfdfe9b852e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -19,8 +19,8 @@
from synapse.types import UserID, create_requester
from tests import unittest
+from tests.server import TestHomeServer
from tests.test_utils import event_injection
-from tests.utils import TestHomeServer
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 3eef1c4c05c8..2b9804aba005 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -17,7 +17,9 @@
from twisted.internet.defer import succeed
from synapse.api.errors import FederationError
+from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
+from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext
from synapse.types import UserID, create_requester
from synapse.util import Clock
@@ -276,3 +278,73 @@ def test_cross_signing_keys_retry(self):
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
)
self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values())
+
+
+class StripUnsignedFromEventsTestCase(unittest.TestCase):
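+ """Test that disallowed keys are stripped from an event's `unsigned` section when it is parsed from a PDU."""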
+ def test_strip_unauthorized_unsigned_values(self):
+ event1 = {
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$event1:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.member",
+ "origin": "test.servx",
+ "content": {"membership": "join"},
+ "auth_events": [],
+ "unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"},
+ }
+ filtered_event = event_from_pdu_json(event1, RoomVersions.V1)
+ # Make sure unauthorized fields are stripped from unsigned
+ self.assertNotIn("more warez", filtered_event.unsigned)
+
+ def test_strip_event_maintains_allowed_fields(self):
+ event2 = {
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$event2:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.member",
+ "origin": "test.servx",
+ "auth_events": [],
+ "content": {"membership": "join"},
+ "unsigned": {
+ "malicious garbage": "hackz",
+ "more warez": "more hackz",
+ "age": 14,
+ "invite_room_state": [],
+ },
+ }
+
+ filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1)
+ self.assertIn("age", filtered_event2.unsigned)
+ self.assertEqual(14, filtered_event2.unsigned["age"])
+ self.assertNotIn("more warez", filtered_event2.unsigned)
+ # Invite_room_state is allowed in events of type m.room.member
+ self.assertIn("invite_room_state", filtered_event2.unsigned)
+ self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
+
+ def test_strip_event_removes_fields_based_on_event_type(self):
+ event3 = {
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$event3:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.power_levels",
+ "origin": "test.servx",
+ "content": {},
+ "auth_events": [],
+ "unsigned": {
+ "malicious garbage": "hackz",
+ "more warez": "more hackz",
+ "age": 14,
+ "invite_room_state": [],
+ },
+ }
+ filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1)
+ self.assertIn("age", filtered_event3.unsigned)
+ # Invite_room_state field is only permitted in event type m.room.member
+ self.assertNotIn("invite_room_state", filtered_event3.unsigned)
+ self.assertNotIn("more warez", filtered_event3.unsigned)
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 40b89fb2efa6..46e02f483fef 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.rest.media.v1.preview_url_resource import (
- _calc_og,
+from synapse.rest.media.v1.preview_html import (
+ _get_html_media_encodings,
decode_body,
- get_html_media_encodings,
+ parse_html_to_open_graph,
summarize_paragraphs,
)
@@ -160,7 +160,7 @@ def test_simple(self):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = _calc_og(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -176,7 +176,7 @@ def test_comment(self):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = _calc_og(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -195,7 +195,7 @@ def test_comment2(self):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = _calc_og(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(
og,
@@ -217,7 +217,7 @@ def test_script(self):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = _calc_og(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -231,7 +231,7 @@ def test_missing_title(self):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = _calc_og(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -246,7 +246,7 @@ def test_h1_as_title(self):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = _calc_og(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
@@ -261,7 +261,7 @@ def test_missing_title_and_broken_h1(self):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = _calc_og(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -289,7 +289,7 @@ def test_xml(self):
<head><title>Foo</title></head><body>Some text.</body></html>
""".strip()
tree = decode_body(html, "http://example.com/test.html")
- og = _calc_og(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding(self):
@@ -303,7 +303,7 @@ def test_invalid_encoding(self):
"""
tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
- og = _calc_og(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding2(self):
@@ -318,7 +318,7 @@ def test_invalid_encoding2(self):