diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index cf1989eff923..000000000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,78 +0,0 @@ -version: 2.1 -jobs: - dockerhubuploadrelease: - docker: - - image: docker:git - steps: - - checkout - - docker_prepare - - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD - # for release builds, we want to get the amd64 image out asap, so first - # we do an amd64-only build, before following up with a multiarch build. - - docker_build: - tag: -t matrixdotorg/synapse:${CIRCLE_TAG} - platforms: linux/amd64 - - docker_build: - tag: -t matrixdotorg/synapse:${CIRCLE_TAG} - platforms: linux/amd64,linux/arm64 - - dockerhubuploadlatest: - docker: - - image: docker:git - steps: - - checkout - - docker_prepare - - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD - # for `latest`, we don't want the arm images to disappear, so don't update the tag - # until all of the platforms are built. - - docker_build: - tag: -t matrixdotorg/synapse:latest - platforms: linux/amd64,linux/arm64 - -workflows: - build: - jobs: - - dockerhubuploadrelease: - filters: - tags: - only: /v[0-9].[0-9]+.[0-9]+.*/ - branches: - ignore: /.*/ - - dockerhubuploadlatest: - filters: - branches: - only: [ master, main ] - -commands: - docker_prepare: - description: Sets up a remote docker server, downloads the buildx cli plugin, and enables multiarch images - parameters: - buildx_version: - type: string - default: "v0.4.1" - steps: - - setup_remote_docker: - # 19.03.13 was the most recent available on circleci at the time of - # writing. - version: 19.03.13 - - run: apk add --no-cache curl - - run: mkdir -vp ~/.docker/cli-plugins/ ~/dockercache - - run: curl --silent -L "https://github.com/docker/buildx/releases/download/<< parameters.buildx_version >>/buildx-<< parameters.buildx_version >>.linux-amd64" > ~/.docker/cli-plugins/docker-buildx - - run: chmod a+x ~/.docker/cli-plugins/docker-buildx - # install qemu links in /proc/sys/fs/binfmt_misc on the docker instance running the circleci job - - run: docker run --rm --privileged multiarch/qemu-user-static --reset -p yes - # create a context named `builder` for the builds - - run: docker context create builder - # create a buildx builder using the new context, and set it as the default - - run: docker buildx create builder --use - - docker_build: - description: Builds and pushed images to dockerhub using buildx - parameters: - platforms: - type: string - default: linux/amd64 - tag: - type: string - steps: - - run: docker buildx build -f docker/Dockerfile --push --platform << parameters.platforms >> --label gitsha1=${CIRCLE_SHA1} << parameters.tag >> --progress=plain . diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml new file mode 100644 index 000000000000..af7ed21fce7f --- /dev/null +++ b/.github/workflows/docker.yml @@ -0,0 +1,72 @@ +# GitHub actions workflow which builds and publishes the docker images. + +name: Build docker images + +on: + push: + tags: ["v*"] + branches: [ master, main ] + workflow_dispatch: + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Set up QEMU + id: qemu + uses: docker/setup-qemu-action@v1 + with: + platforms: arm64 + + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v1 + + - name: Inspect builder + run: docker buildx inspect + + - name: Log in to DockerHub + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Calculate docker image tag + id: set-tag + run: | + case "${GITHUB_REF}" in + refs/heads/master|refs/heads/main) + tag=latest + ;; + refs/tags/*) + tag=${GITHUB_REF#refs/tags/} + ;; + *) + tag=${GITHUB_SHA} + ;; + esac + echo "::set-output name=tag::$tag" + + # for release builds, we want to get the amd64 image out asap, so first + # we do an amd64-only build, before following up with a multiarch build. + - name: Build and push amd64 + uses: docker/build-push-action@v2 + if: "${{ startsWith(github.ref, 'refs/tags/v') }}" + with: + push: true + labels: "gitsha1=${{ github.sha }}" + tags: "matrixdotorg/synapse:${{ steps.set-tag.outputs.tag }}" + file: "docker/Dockerfile" + platforms: linux/amd64 + + - name: Build and push all platforms + uses: docker/build-push-action@v2 + with: + push: true + labels: "gitsha1=${{ github.sha }}" + tags: "matrixdotorg/synapse:${{ steps.set-tag.outputs.tag }}" + file: "docker/Dockerfile" + platforms: linux/amd64,linux/arm64 diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index f292d703ed22..325c1f7d3940 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -3,28 +3,33 @@ name: Build release artifacts on: + # we build on PRs and develop to (hopefully) get early warning + # of things breaking (but only build one set of debs) + pull_request: push: - # we build on develop and release branches to (hopefully) get early warning - # of things breaking - branches: ["develop", "release-*"] + branches: ["develop"] - # we also rebuild on tags, so that we can be sure of picking the artifacts - # from the right tag. + # we do the full build on tags. tags: ["v*"] permissions: contents: write jobs: - # first get the list of distros to build for. get-distros: + name: "Calculate list of debian distros" runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 - id: set-distros run: | - echo "::set-output name=distros::$(scripts-dev/build_debian_packages --show-dists-json)" + # if we're running from a tag, get the full list of distros; otherwise just use debian:sid + dists='["debian:sid"]' + if [[ $GITHUB_REF == refs/tags/* ]]; then + dists=$(scripts-dev/build_debian_packages --show-dists-json) + fi + echo "::set-output name=distros::$dists" # map the step outputs to job outputs outputs: distros: ${{ steps.set-distros.outputs.distros }} @@ -66,7 +71,7 @@ jobs: # if it's a tag, create a release and attach the artifacts to it attach-assets: name: "Attach assets to release" - if: startsWith(github.ref, 'refs/tags/') + if: ${{ !failure() && !cancelled() && startsWith(github.ref, 'refs/tags/') }} needs: - build-debs - build-sdist diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bf36ee1cdfff..cef4439477b8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -65,14 +65,14 @@ jobs: # Dummy step to gate other tests on without repeating the whole list linting-done: - if: ${{ always() }} # Run this even if prior jobs were skipped + if: ${{ !cancelled() }} # Run this even if prior jobs were skipped needs: [lint, lint-crlf, lint-newsfile, lint-sdist] runs-on: ubuntu-latest steps: - run: "true" trial: - if: ${{ !failure() }} # Allow previous steps to be skipped, but not fail + if: ${{ !cancelled() && !failure() }} # Allow previous steps to be skipped, but not fail needs: linting-done runs-on: ubuntu-latest strategy: @@ -131,7 +131,7 @@ jobs: || true trial-olddeps: - if: ${{ !failure() }} # Allow previous steps to be skipped, but not fail + if: ${{ !cancelled() && !failure() }} # Allow previous steps to be skipped, but not fail needs: linting-done runs-on: ubuntu-latest steps: @@ -156,7 +156,7 @@ jobs: trial-pypy: # Very slow; only run if the branch name includes 'pypy' - if: ${{ contains(github.ref, 'pypy') && !failure() }} + if: ${{ contains(github.ref, 'pypy') && !failure() && !cancelled() }} needs: linting-done runs-on: ubuntu-latest strategy: @@ -185,7 +185,7 @@ jobs: || true sytest: - if: ${{ !failure() }} + if: ${{ !failure() && !cancelled() }} needs: linting-done runs-on: ubuntu-latest container: @@ -245,7 +245,7 @@ jobs: /logs/**/*.log* portdb: - if: ${{ !failure() }} # Allow previous steps to be skipped, but not fail + if: ${{ !failure() && !cancelled() }} # Allow previous steps to be skipped, but not fail needs: linting-done runs-on: ubuntu-latest strategy: @@ -286,7 +286,7 @@ jobs: - run: .buildkite/scripts/test_synapse_port_db.sh complement: - if: ${{ !failure() }} + if: ${{ !failure() && !cancelled() }} needs: linting-done runs-on: ubuntu-latest container: @@ -344,3 +344,15 @@ jobs: env: COMPLEMENT_BASE_IMAGE: complement-synapse:latest working-directory: complement + + # a job which marks all the other jobs as complete, thus allowing PRs to be merged. + tests-done: + needs: + - trial + - trial-olddeps + - sytest + - portdb + - complement + runs-on: ubuntu-latest + steps: + - run: "true" \ No newline at end of file diff --git a/CHANGES.md b/CHANGES.md index 6e5a72041111..b512d9ff3f4a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,36 @@ +Synapse 1.39.0rc3 (2021-07-28) +============================== + +Bugfixes +-------- + +- Fix a bug introduced in Synapse 1.38 which caused an exception at startup when SAML authentication was enabled. ([\#10477](https://github.com/matrix-org/synapse/issues/10477)) +- Fix a long-standing bug where Synapse would not inform clients that a device had exhausted its one-time-key pool, potentially causing problems decrypting events. ([\#10485](https://github.com/matrix-org/synapse/issues/10485)) +- Fix reporting old R30 stats as R30v2 stats. Introduced in v1.39.0rc1. ([\#10486](https://github.com/matrix-org/synapse/issues/10486)) + + +Internal Changes +---------------- + +- Fix an error which prevented the Github Actions workflow to build the docker images from running. ([\#10461](https://github.com/matrix-org/synapse/issues/10461)) +- Fix release script to correctly version debian changelog when doing RCs. ([\#10465](https://github.com/matrix-org/synapse/issues/10465)) + + +Synapse 1.39.0rc2 (2021-07-22) +============================== + +Bugfixes +-------- + +- Always include `device_one_time_keys_count` key in `/sync` response to work around a bug in Element Android that broke encryption for new devices. ([\#10457](https://github.com/matrix-org/synapse/issues/10457)) + + +Internal Changes +---------------- + +- Move docker image build to Github Actions. ([\#10416](https://github.com/matrix-org/synapse/issues/10416)) + + Synapse 1.38.1 (2021-07-22) =========================== @@ -7,6 +40,75 @@ Bugfixes - Always include `device_one_time_keys_count` key in `/sync` response to work around a bug in Element Android that broke encryption for new devices. ([\#10457](https://github.com/matrix-org/synapse/issues/10457)) +Synapse 1.39.0rc1 (2021-07-20) +============================== + +The Third-Party Event Rules module interface has been deprecated in favour of the generic module interface introduced in Synapse v1.37.0. Support for the old interface is planned to be removed in September 2021. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html#upgrading-to-v1390) for more information. + +Features +-------- + +- Add the ability to override the account validity feature with a module. ([\#9884](https://github.com/matrix-org/synapse/issues/9884)) +- The spaces summary API now returns any joinable rooms, not only rooms which are world-readable. ([\#10298](https://github.com/matrix-org/synapse/issues/10298), [\#10305](https://github.com/matrix-org/synapse/issues/10305)) +- Add a new version of the R30 phone-home metric, which removes a false impression of retention given by the old R30 metric. ([\#10332](https://github.com/matrix-org/synapse/issues/10332), [\#10427](https://github.com/matrix-org/synapse/issues/10427)) +- Allow providing credentials to `http_proxy`. ([\#10360](https://github.com/matrix-org/synapse/issues/10360)) + + +Bugfixes +-------- + +- Fix error while dropping locks on shutdown. Introduced in v1.38.0. ([\#10433](https://github.com/matrix-org/synapse/issues/10433)) +- Add base starting insertion event when no chunk ID is specified in the historical batch send API. ([\#10250](https://github.com/matrix-org/synapse/issues/10250)) +- Fix historical batch send endpoint (MSC2716) rejecting batches with messages from multiple senders. ([\#10276](https://github.com/matrix-org/synapse/issues/10276)) +- Fix purging rooms that other homeservers are still sending events for. Contributed by @ilmari. ([\#10317](https://github.com/matrix-org/synapse/issues/10317)) +- Fix errors during backfill caused by previously purged redaction events. Contributed by Andreas Rammhold (@andir). ([\#10343](https://github.com/matrix-org/synapse/issues/10343)) +- Fix the user directory becoming broken (and noisy errors being logged) when knocking and room statistics are in use. ([\#10344](https://github.com/matrix-org/synapse/issues/10344)) +- Fix newly added `synapse_federation_server_oldest_inbound_pdu_in_staging` prometheus metric to measure age rather than timestamp. ([\#10355](https://github.com/matrix-org/synapse/issues/10355)) +- Fix PostgreSQL sometimes using table scans for queries against `state_groups_state` table, taking a long time and a large amount of IO. ([\#10359](https://github.com/matrix-org/synapse/issues/10359)) +- Fix `make_room_admin` failing for users that have left a private room. ([\#10367](https://github.com/matrix-org/synapse/issues/10367)) +- Fix a number of logged errors caused by remote servers being down. ([\#10400](https://github.com/matrix-org/synapse/issues/10400), [\#10414](https://github.com/matrix-org/synapse/issues/10414)) +- Responses from `/make_{join,leave,knock}` no longer include signatures, which will turn out to be invalid after events are returned to `/send_{join,leave,knock}`. ([\#10404](https://github.com/matrix-org/synapse/issues/10404)) + + +Improved Documentation +---------------------- + +- Updated installation dependencies for newer macOS versions and ARM Macs. Contributed by Luke Walsh. ([\#9971](https://github.com/matrix-org/synapse/issues/9971)) +- Simplify structure of room admin API. ([\#10313](https://github.com/matrix-org/synapse/issues/10313)) +- Refresh the logcontext dev documentation. ([\#10353](https://github.com/matrix-org/synapse/issues/10353)), ([\#10337](https://github.com/matrix-org/synapse/issues/10337)) +- Add delegation example for caddy in the reverse proxy documentation. Contributed by @moritzdietz. ([\#10368](https://github.com/matrix-org/synapse/issues/10368)) +- Fix and clarify some links in `docs` and `contrib`. ([\#10370](https://github.com/matrix-org/synapse/issues/10370)), ([\#10322](https://github.com/matrix-org/synapse/issues/10322)), ([\#10399](https://github.com/matrix-org/synapse/issues/10399)) +- Make deprecation notice of the spam checker doc more obvious. ([\#10395](https://github.com/matrix-org/synapse/issues/10395)) +- Add instructions on installing Debian packages for release candidates. ([\#10396](https://github.com/matrix-org/synapse/issues/10396)) + + +Deprecations and Removals +------------------------- + +- Remove functionality associated with the unused `room_stats_historical` and `user_stats_historical` tables. Contributed by @xmunoz. ([\#9721](https://github.com/matrix-org/synapse/issues/9721)) +- The third-party event rules module interface is deprecated in favour of the generic module interface introduced in Synapse v1.37.0. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html#upgrading-to-v1390) for more information. ([\#10386](https://github.com/matrix-org/synapse/issues/10386)) + + +Internal Changes +---------------- + +- Convert `room_depth.min_depth` column to a `BIGINT`. ([\#10289](https://github.com/matrix-org/synapse/issues/10289)) +- Add tests to characterise the current behaviour of R30 phone-home metrics. ([\#10315](https://github.com/matrix-org/synapse/issues/10315)) +- Rebuild event context and auth when processing specific results from `ThirdPartyEventRules` modules. ([\#10316](https://github.com/matrix-org/synapse/issues/10316)) +- Minor change to the code that populates `user_daily_visits`. ([\#10324](https://github.com/matrix-org/synapse/issues/10324)) +- Re-enable Sytests that were disabled for the 1.37.1 release. ([\#10345](https://github.com/matrix-org/synapse/issues/10345), [\#10357](https://github.com/matrix-org/synapse/issues/10357)) +- Run `pyupgrade` on the codebase. ([\#10347](https://github.com/matrix-org/synapse/issues/10347), [\#10348](https://github.com/matrix-org/synapse/issues/10348)) +- Switch `application_services_txns.txn_id` database column to `BIGINT`. ([\#10349](https://github.com/matrix-org/synapse/issues/10349)) +- Convert internal type variable syntax to reflect wider ecosystem use. ([\#10350](https://github.com/matrix-org/synapse/issues/10350), [\#10380](https://github.com/matrix-org/synapse/issues/10380), [\#10381](https://github.com/matrix-org/synapse/issues/10381), [\#10382](https://github.com/matrix-org/synapse/issues/10382), [\#10418](https://github.com/matrix-org/synapse/issues/10418)) +- Make the Github Actions workflow configuration more efficient. ([\#10383](https://github.com/matrix-org/synapse/issues/10383)) +- Add type hints to `get_{domain,localpart}_from_id`. ([\#10385](https://github.com/matrix-org/synapse/issues/10385)) +- When building Debian packages for prerelease versions, set the Section accordingly. ([\#10391](https://github.com/matrix-org/synapse/issues/10391)) +- Add type hints and comments to event auth code. ([\#10393](https://github.com/matrix-org/synapse/issues/10393)) +- Stagger sending of presence update to remote servers, reducing CPU spikes caused by starting many connections to remote servers at once. ([\#10398](https://github.com/matrix-org/synapse/issues/10398)) +- Remove unused `events_by_room` code (tech debt). ([\#10421](https://github.com/matrix-org/synapse/issues/10421)) +- Add a github actions job which records success of other jobs. ([\#10430](https://github.com/matrix-org/synapse/issues/10430)) + + Synapse 1.38.0 (2021-07-13) =========================== diff --git a/changelog.d/10395.doc b/changelog.d/10395.doc deleted file mode 100644 index 4bdaea76c5dd..000000000000 --- a/changelog.d/10395.doc +++ /dev/null @@ -1 +0,0 @@ -Make deprecation notice of the spam checker doc more obvious. diff --git a/contrib/docker/docker-compose.yml b/contrib/docker/docker-compose.yml index d1ecd453db03..26d640c44887 100644 --- a/contrib/docker/docker-compose.yml +++ b/contrib/docker/docker-compose.yml @@ -56,7 +56,7 @@ services: - POSTGRES_USER=synapse - POSTGRES_PASSWORD=changeme # ensure the database gets created correctly - # https://github.com/matrix-org/synapse/blob/master/docs/postgres.md#set-up-database + # https://matrix-org.github.io/synapse/latest/postgres.html#set-up-database - POSTGRES_INITDB_ARGS=--encoding=UTF-8 --lc-collate=C --lc-ctype=C volumes: # You may store the database tables in a local folder.. diff --git a/contrib/grafana/README.md b/contrib/grafana/README.md index 460879339427..0d4e1b59b27f 100644 --- a/contrib/grafana/README.md +++ b/contrib/grafana/README.md @@ -1,6 +1,6 @@ # Using the Synapse Grafana dashboard 0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/ -1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md +1. Have your Prometheus scrape your Synapse. https://matrix-org.github.io/synapse/latest/metrics-howto.html 2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/ -3. Set up required recording rules. https://github.com/matrix-org/synapse/tree/master/contrib/prometheus +3. Set up required recording rules. [contrib/prometheus](../prometheus) diff --git a/contrib/prometheus/README.md b/contrib/prometheus/README.md index b3f23bcc80fd..4dbf648df83d 100644 --- a/contrib/prometheus/README.md +++ b/contrib/prometheus/README.md @@ -34,7 +34,7 @@ Add a new job to the main prometheus.yml file: ``` An example of a Prometheus configuration with workers can be found in -[metrics-howto.md](https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md). +[metrics-howto.md](https://matrix-org.github.io/synapse/latest/metrics-howto.html). To use `synapse.rules` add diff --git a/contrib/purge_api/README.md b/contrib/purge_api/README.md index 06b4cdb9f7f4..2f2e5c58cda0 100644 --- a/contrib/purge_api/README.md +++ b/contrib/purge_api/README.md @@ -3,8 +3,9 @@ Purge history API examples # `purge_history.sh` -A bash file, that uses the [purge history API](/docs/admin_api/purge_history_api.rst) to -purge all messages in a list of rooms up to a certain event. You can select a +A bash file, that uses the +[purge history API](https://matrix-org.github.io/synapse/latest/admin_api/purge_history_api.html) +to purge all messages in a list of rooms up to a certain event. You can select a timeframe or a number of messages that you want to keep in the room. Just configure the variables DOMAIN, ADMIN, ROOMS_ARRAY and TIME at the top of @@ -12,5 +13,6 @@ the script. # `purge_remote_media.sh` -A bash file, that uses the [purge history API](/docs/admin_api/purge_history_api.rst) to -purge all old cached remote media. +A bash file, that uses the +[purge history API](https://matrix-org.github.io/synapse/latest/admin_api/purge_history_api.html) +to purge all old cached remote media. diff --git a/contrib/purge_api/purge_history.sh b/contrib/purge_api/purge_history.sh index c45136ff535c..9d5324ea1c47 100644 --- a/contrib/purge_api/purge_history.sh +++ b/contrib/purge_api/purge_history.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash # this script will use the api: -# https://github.com/matrix-org/synapse/blob/master/docs/admin_api/purge_history_api.rst +# https://matrix-org.github.io/synapse/latest/admin_api/purge_history_api.html # # It will purge all messages in a list of rooms up to a cetrain event diff --git a/contrib/systemd-with-workers/README.md b/contrib/systemd-with-workers/README.md index 8d21d532bd40..9b19b042e949 100644 --- a/contrib/systemd-with-workers/README.md +++ b/contrib/systemd-with-workers/README.md @@ -1,2 +1,3 @@ The documentation for using systemd to manage synapse workers is now part of -the main synapse distribution. See [docs/systemd-with-workers](../../docs/systemd-with-workers). +the main synapse distribution. See +[docs/systemd-with-workers](https://matrix-org.github.io/synapse/latest/systemd-with-workers/index.html). diff --git a/debian/changelog b/debian/changelog index 5598fdea31ba..4944e5571413 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,9 +1,21 @@ +matrix-synapse-py3 (1.39.0~rc3) stable; urgency=medium + + * New synapse release 1.39.0~rc3. + + -- Synapse Packaging team Wed, 28 Jul 2021 13:30:58 +0100 + matrix-synapse-py3 (1.38.1) stable; urgency=medium * New synapse release 1.38.1. -- Synapse Packaging team Thu, 22 Jul 2021 15:37:06 +0100 +matrix-synapse-py3 (1.39.0~rc1) stable; urgency=medium + + * New synapse release 1.39.0rc1. + + -- Synapse Packaging team Tue, 20 Jul 2021 14:28:34 +0100 + matrix-synapse-py3 (1.38.0) stable; urgency=medium * New synapse release 1.38.0. diff --git a/docker/build_debian.sh b/docker/build_debian.sh index f426d2b77b20..f572ed9aa0ef 100644 --- a/docker/build_debian.sh +++ b/docker/build_debian.sh @@ -15,6 +15,20 @@ cd /synapse/build dch -M -l "+$DIST" "build for $DIST" dch -M -r "" --force-distribution --distribution "$DIST" +# if this is a prerelease, set the Section accordingly. +# +# When the package is later added to the package repo, reprepro will use the +# Section to determine which "component" it should go into (see +# https://manpages.debian.org/stretch/reprepro/reprepro.1.en.html#GUESSING) + +DEB_VERSION=`dpkg-parsechangelog -SVersion` +case $DEB_VERSION in + *rc*|*a*|*b*|*c*) + sed -ie '/^Section:/c\Section: prerelease' debian/control + ;; +esac + + dpkg-buildpackage -us -uc ls -l .. diff --git a/docs/MSC1711_certificates_FAQ.md b/docs/MSC1711_certificates_FAQ.md index 283f288aaf02..7d71c190abed 100644 --- a/docs/MSC1711_certificates_FAQ.md +++ b/docs/MSC1711_certificates_FAQ.md @@ -132,7 +132,7 @@ your domain, you can simply route all traffic through the reverse proxy by updating the SRV record appropriately (or removing it, if the proxy listens on 8448). -See [reverse_proxy.md](reverse_proxy.md) for information on setting up a +See [the reverse proxy documentation](reverse_proxy.md) for information on setting up a reverse proxy. #### Option 3: add a .well-known file to delegate your matrix traffic @@ -303,7 +303,7 @@ We no longer actively recommend against using a reverse proxy. Many admins will find it easier to direct federation traffic to a reverse proxy and manage their own TLS certificates, and this is a supported configuration. -See [reverse_proxy.md](reverse_proxy.md) for information on setting up a +See [the reverse proxy documentation](reverse_proxy.md) for information on setting up a reverse proxy. ### Do I still need to give my TLS certificates to Synapse if I am using a reverse proxy? diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index b033fc03efd0..61bed1e0d5d8 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -47,7 +47,7 @@ The API returns a JSON body like the following: ## List all media uploaded by a user Listing all media that has been uploaded by a local user can be achieved through -the use of the [List media of a user](user_admin_api.rst#list-media-of-a-user) +the use of the [List media of a user](user_admin_api.md#list-media-of-a-user) Admin API. # Quarantine media @@ -257,7 +257,7 @@ URL Parameters * `server_name`: string - The name of your local server (e.g `matrix.org`). * `before_ts`: string representing a positive integer - Unix timestamp in ms. Files that were last used before this timestamp will be deleted. It is the timestamp of -last access and not the timestamp creation. +last access and not the timestamp creation. * `size_gt`: Optional - string representing a positive integer - Size of the media in bytes. Files that are larger will be deleted. Defaults to `0`. * `keep_profiles`: Optional - string representing a boolean - Switch to also delete files diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index bb7828a52529..48777dd23106 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -1,13 +1,9 @@ # Contents - [List Room API](#list-room-api) - * [Parameters](#parameters) - * [Usage](#usage) - [Room Details API](#room-details-api) - [Room Members API](#room-members-api) - [Room State API](#room-state-api) - [Delete Room API](#delete-room-api) - * [Parameters](#parameters-1) - * [Response](#response) * [Undoing room shutdowns](#undoing-room-shutdowns) - [Make Room Admin API](#make-room-admin-api) - [Forward Extremities Admin API](#forward-extremities-admin-api) @@ -19,7 +15,7 @@ The List Room admin API allows server admins to get a list of rooms on their server. There are various parameters available that allow for filtering and sorting the returned list. This API supports pagination. -## Parameters +**Parameters** The following query parameters are available: @@ -46,6 +42,8 @@ The following query parameters are available: * `search_term` - Filter rooms by their room name. Search term can be contained in any part of the room name. Defaults to no filtering. +**Response** + The following fields are possible in the JSON response body: * `rooms` - An array of objects, each containing information about a room. @@ -79,17 +77,15 @@ The following fields are possible in the JSON response body: Use `prev_batch` for the `from` value in the next request to get the "previous page" of results. -## Usage +The API is: A standard request with no filtering: ``` GET /_synapse/admin/v1/rooms - -{} ``` -Response: +A response body like the following is returned: ```jsonc { @@ -137,11 +133,9 @@ Filtering by room name: ``` GET /_synapse/admin/v1/rooms?search_term=TWIM - -{} ``` -Response: +A response body like the following is returned: ```json { @@ -172,11 +166,9 @@ Paginating through a list of rooms: ``` GET /_synapse/admin/v1/rooms?order_by=size - -{} ``` -Response: +A response body like the following is returned: ```jsonc { @@ -228,11 +220,9 @@ parameter to the value of `next_token`. ``` GET /_synapse/admin/v1/rooms?order_by=size&from=100 - -{} ``` -Response: +A response body like the following is returned: ```jsonc { @@ -304,17 +294,13 @@ The following fields are possible in the JSON response body: * `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"]. * `state_events` - Total number of state_events of a room. Complexity of the room. -## Usage - -A standard request: +The API is: ``` GET /_synapse/admin/v1/rooms/ - -{} ``` -Response: +A response body like the following is returned: ```json { @@ -347,17 +333,13 @@ The response includes the following fields: * `members` - A list of all the members that are present in the room, represented by their ids. * `total` - Total number of members in the room. -## Usage - -A standard request: +The API is: ``` GET /_synapse/admin/v1/rooms//members - -{} ``` -Response: +A response body like the following is returned: ```json { @@ -378,17 +360,13 @@ The response includes the following fields: * `state` - The current state of the room at the time of request. -## Usage - -A standard request: +The API is: ``` GET /_synapse/admin/v1/rooms//state - -{} ``` -Response: +A response body like the following is returned: ```json { @@ -432,6 +410,7 @@ DELETE /_synapse/admin/v1/rooms/ ``` with a body of: + ```json { "new_room_user_id": "@someuser:example.com", @@ -461,7 +440,7 @@ A response body like the following is returned: } ``` -## Parameters +**Parameters** The following parameters should be set in the URL: @@ -491,7 +470,7 @@ The following JSON body parameters are available: The JSON body must not be empty. The body must be at least `{}`. -## Response +**Response** The following fields are returned in the JSON response body: @@ -548,10 +527,10 @@ By default the server admin (the caller) is granted power, but another user can optionally be specified, e.g.: ``` - POST /_synapse/admin/v1/rooms//make_room_admin - { - "user_id": "@foo:example.com" - } +POST /_synapse/admin/v1/rooms//make_room_admin +{ + "user_id": "@foo:example.com" +} ``` # Forward Extremities Admin API @@ -565,7 +544,7 @@ extremities accumulate in a room, performance can become degraded. For details, To check the status of forward extremities for a room: ``` - GET /_synapse/admin/v1/rooms//forward_extremities +GET /_synapse/admin/v1/rooms//forward_extremities ``` A response as follows will be returned: @@ -581,7 +560,7 @@ A response as follows will be returned: "received_ts": 1611263016761 } ] -} +} ``` ## Deleting forward extremities @@ -594,7 +573,7 @@ If a room has lots of forward extremities, the extra can be deleted as follows: ``` - DELETE /_synapse/admin/v1/rooms//forward_extremities +DELETE /_synapse/admin/v1/rooms//forward_extremities ``` A response as follows will be returned, indicating the amount of forward extremities diff --git a/docs/admin_api/server_notices.md b/docs/admin_api/server_notices.md index 858b052b84c7..323138491a9c 100644 --- a/docs/admin_api/server_notices.md +++ b/docs/admin_api/server_notices.md @@ -45,4 +45,4 @@ Once the notice has been sent, the API will return the following response: ``` Note that server notices must be enabled in `homeserver.yaml` before this API -can be used. See [server_notices.md](../server_notices.md) for more information. +can be used. See [the server notices documentation](../server_notices.md) for more information. diff --git a/docs/consent_tracking.md b/docs/consent_tracking.md index 3f997e590308..911a1f95dbb4 100644 --- a/docs/consent_tracking.md +++ b/docs/consent_tracking.md @@ -152,7 +152,7 @@ version of the policy. To do so: * ensure that the consent resource is configured, as in the previous section - * ensure that server notices are configured, as in [server_notices.md](server_notices.md). + * ensure that server notices are configured, as in [the server notice documentation](server_notices.md). * Add `server_notice_content` under `user_consent` in `homeserver.yaml`. For example: diff --git a/docs/delegate.md b/docs/delegate.md index 208ddb627745..05cb6350472b 100644 --- a/docs/delegate.md +++ b/docs/delegate.md @@ -74,7 +74,7 @@ We no longer actively recommend against using a reverse proxy. Many admins will find it easier to direct federation traffic to a reverse proxy and manage their own TLS certificates, and this is a supported configuration. -See [reverse_proxy.md](reverse_proxy.md) for information on setting up a +See [the reverse proxy documentation](reverse_proxy.md) for information on setting up a reverse proxy. ### Do I still need to give my TLS certificates to Synapse if I am using a reverse proxy? diff --git a/docs/federate.md b/docs/federate.md index 89c2b196385e..5107f995be98 100644 --- a/docs/federate.md +++ b/docs/federate.md @@ -14,7 +14,7 @@ you set the `server_name` to match your machine's public DNS hostname. For this default configuration to work, you will need to listen for TLS connections on port 8448. The preferred way to do that is by using a -reverse proxy: see [reverse_proxy.md](reverse_proxy.md) for instructions +reverse proxy: see [the reverse proxy documentation](reverse_proxy.md) for instructions on how to correctly set one up. In some cases you might not want to run Synapse on the machine that has @@ -23,7 +23,7 @@ traffic to use a different port than 8448. For example, you might want to have your user names look like `@user:example.com`, but you want to run Synapse on `synapse.example.com` on port 443. This can be done using delegation, which allows an admin to control where federation traffic should -be sent. See [delegate.md](delegate.md) for instructions on how to set this up. +be sent. See [the delegation documentation](delegate.md) for instructions on how to set this up. Once federation has been configured, you should be able to join a room over federation. A good place to start is `#synapse:matrix.org` - a room for @@ -44,8 +44,8 @@ a complicated dance which requires connections in both directions). Another common problem is that people on other servers can't join rooms that you invite them to. This can be caused by an incorrectly-configured reverse -proxy: see [reverse_proxy.md](reverse_proxy.md) for instructions on how to correctly -configure a reverse proxy. +proxy: see [the reverse proxy documentation](reverse_proxy.md) for instructions on how +to correctly configure a reverse proxy. ### Known issues diff --git a/docs/log_contexts.md b/docs/log_contexts.md index fe30ca27916b..d49dce8830a1 100644 --- a/docs/log_contexts.md +++ b/docs/log_contexts.md @@ -14,12 +14,16 @@ The `synapse.logging.context` module provides a facilities for managing the current log context (as well as providing the `LoggingContextFilter` class). -Deferreds make the whole thing complicated, so this document describes +Asynchronous functions make the whole thing complicated, so this document describes how it all works, and how to write code which follows the rules. -##Logcontexts without Deferreds +In this document, "awaitable" refers to any object which can be `await`ed. In the context of +Synapse, that normally means either a coroutine or a Twisted +[`Deferred`](https://twistedmatrix.com/documents/current/api/twisted.internet.defer.Deferred.html). -In the absence of any Deferred voodoo, things are simple enough. As with +## Logcontexts without asynchronous code + +In the absence of any asynchronous voodoo, things are simple enough. As with any code of this nature, the rule is that our function should leave things as it found them: @@ -55,126 +59,109 @@ def do_request_handling(): logger.debug("phew") ``` -## Using logcontexts with Deferreds +## Using logcontexts with awaitables -Deferreds --- and in particular, `defer.inlineCallbacks` --- break the -linear flow of code so that there is no longer a single entry point -where we should set the logcontext and a single exit point where we -should remove it. +Awaitables break the linear flow of code so that there is no longer a single entry point +where we should set the logcontext and a single exit point where we should remove it. Consider the example above, where `do_request_handling` needs to do some -blocking operation, and returns a deferred: +blocking operation, and returns an awaitable: ```python -@defer.inlineCallbacks -def handle_request(request_id): +async def handle_request(request_id): with context.LoggingContext() as request_context: request_context.request = request_id - yield do_request_handling() + await do_request_handling() logger.debug("finished") ``` In the above flow: - The logcontext is set -- `do_request_handling` is called, and returns a deferred -- `handle_request` yields the deferred -- The `inlineCallbacks` wrapper of `handle_request` returns a deferred +- `do_request_handling` is called, and returns an awaitable +- `handle_request` awaits the awaitable +- Execution of `handle_request` is suspended So we have stopped processing the request (and will probably go on to start processing the next), without clearing the logcontext. To circumvent this problem, synapse code assumes that, wherever you have -a deferred, you will want to yield on it. To that end, whereever -functions return a deferred, we adopt the following conventions: +an awaitable, you will want to `await` it. To that end, whereever +functions return awaitables, we adopt the following conventions: -**Rules for functions returning deferreds:** +**Rules for functions returning awaitables:** -> - If the deferred is already complete, the function returns with the +> - If the awaitable is already complete, the function returns with the > same logcontext it started with. -> - If the deferred is incomplete, the function clears the logcontext -> before returning; when the deferred completes, it restores the +> - If the awaitable is incomplete, the function clears the logcontext +> before returning; when the awaitable completes, it restores the > logcontext before running any callbacks. That sounds complicated, but actually it means a lot of code (including the example above) "just works". There are two cases: -- If `do_request_handling` returns a completed deferred, then the +- If `do_request_handling` returns a completed awaitable, then the logcontext will still be in place. In this case, execution will - continue immediately after the `yield`; the "finished" line will + continue immediately after the `await`; the "finished" line will be logged against the right context, and the `with` block restores the original context before we return to the caller. -- If the returned deferred is incomplete, `do_request_handling` clears +- If the returned awaitable is incomplete, `do_request_handling` clears the logcontext before returning. The logcontext is therefore clear - when `handle_request` yields the deferred. At that point, the - `inlineCallbacks` wrapper adds a callback to the deferred, and - returns another (incomplete) deferred to the caller, and it is safe - to begin processing the next request. - - Once `do_request_handling`'s deferred completes, it will reinstate - the logcontext, before running the callback added by the - `inlineCallbacks` wrapper. That callback runs the second half of - `handle_request`, so again the "finished" line will be logged - against the right context, and the `with` block restores the - original context. + when `handle_request` `await`s the awaitable. + + Once `do_request_handling`'s awaitable completes, it will reinstate + the logcontext, before running the second half of `handle_request`, + so again the "finished" line will be logged against the right context, + and the `with` block restores the original context. As an aside, it's worth noting that `handle_request` follows our rules --though that only matters if the caller has its own logcontext which it +- though that only matters if the caller has its own logcontext which it cares about. The following sections describe pitfalls and helpful patterns when implementing these rules. -Always yield your deferreds ---------------------------- +Always await your awaitables +---------------------------- -Whenever you get a deferred back from a function, you should `yield` on -it as soon as possible. (Returning it directly to your caller is ok too, -if you're not doing `inlineCallbacks`.) Do not pass go; do not do any -logging; do not call any other functions. +Whenever you get an awaitable back from a function, you should `await` on +it as soon as possible. Do not pass go; do not do any logging; do not +call any other functions. ```python -@defer.inlineCallbacks -def fun(): +async def fun(): logger.debug("starting") - yield do_some_stuff() # just like this + await do_some_stuff() # just like this - d = more_stuff() - result = yield d # also fine, of course + coro = more_stuff() + result = await coro # also fine, of course return result - -def nonInlineCallbacksFun(): - logger.debug("just a wrapper really") - return do_some_stuff() # this is ok too - the caller will yield on - # it anyway. ``` Provided this pattern is followed all the way back up to the callchain to where the logcontext was set, this will make things work out ok: provided `do_some_stuff` and `more_stuff` follow the rules above, then -so will `fun` (as wrapped by `inlineCallbacks`) and -`nonInlineCallbacksFun`. +so will `fun`. -It's all too easy to forget to `yield`: for instance if we forgot that -`do_some_stuff` returned a deferred, we might plough on regardless. This +It's all too easy to forget to `await`: for instance if we forgot that +`do_some_stuff` returned an awaitable, we might plough on regardless. This leads to a mess; it will probably work itself out eventually, but not before a load of stuff has been logged against the wrong context. (Normally, other things will break, more obviously, if you forget to -`yield`, so this tends not to be a major problem in practice.) +`await`, so this tends not to be a major problem in practice.) Of course sometimes you need to do something a bit fancier with your -Deferreds - not all code follows the linear A-then-B-then-C pattern. +awaitable - not all code follows the linear A-then-B-then-C pattern. Notes on implementing more complex patterns are in later sections. -## Where you create a new Deferred, make it follow the rules +## Where you create a new awaitable, make it follow the rules -Most of the time, a Deferred comes from another synapse function. -Sometimes, though, we need to make up a new Deferred, or we get a -Deferred back from external code. We need to make it follow our rules. +Most of the time, an awaitable comes from another synapse function. +Sometimes, though, we need to make up a new awaitable, or we get an awaitable +back from external code. We need to make it follow our rules. -The easy way to do it is with a combination of `defer.inlineCallbacks`, -and `context.PreserveLoggingContext`. Suppose we want to implement +The easy way to do it is by using `context.make_deferred_yieldable`. Suppose we want to implement `sleep`, which returns a deferred which will run its callbacks after a given number of seconds. That might look like: @@ -186,25 +173,12 @@ def get_sleep_deferred(seconds): return d ``` -That doesn't follow the rules, but we can fix it by wrapping it with -`PreserveLoggingContext` and `yield` ing on it: +That doesn't follow the rules, but we can fix it by calling it through +`context.make_deferred_yieldable`: ```python -@defer.inlineCallbacks -def sleep(seconds): - with PreserveLoggingContext(): - yield get_sleep_deferred(seconds) -``` - -This technique works equally for external functions which return -deferreds, or deferreds we have made ourselves. - -You can also use `context.make_deferred_yieldable`, which just does the -boilerplate for you, so the above could be written: - -```python -def sleep(seconds): - return context.make_deferred_yieldable(get_sleep_deferred(seconds)) +async def sleep(seconds): + return await context.make_deferred_yieldable(get_sleep_deferred(seconds)) ``` ## Fire-and-forget @@ -213,20 +187,18 @@ Sometimes you want to fire off a chain of execution, but not wait for its result. That might look a bit like this: ```python -@defer.inlineCallbacks -def do_request_handling(): - yield foreground_operation() +async def do_request_handling(): + await foreground_operation() # *don't* do this background_operation() logger.debug("Request handling complete") -@defer.inlineCallbacks -def background_operation(): - yield first_background_step() +async def background_operation(): + await first_background_step() logger.debug("Completed first step") - yield second_background_step() + await second_background_step() logger.debug("Completed second step") ``` @@ -235,13 +207,13 @@ The above code does a couple of steps in the background after against the `request_context` logcontext, which may or may not be desirable. There are two big problems with the above, however. The first problem is that, if `background_operation` returns an incomplete -Deferred, it will expect its caller to `yield` immediately, so will have +awaitable, it will expect its caller to `await` immediately, so will have cleared the logcontext. In this example, that means that 'Request handling complete' will be logged without any context. The second problem, which is potentially even worse, is that when the -Deferred returned by `background_operation` completes, it will restore -the original logcontext. There is nothing waiting on that Deferred, so +awaitable returned by `background_operation` completes, it will restore +the original logcontext. There is nothing waiting on that awaitable, so the logcontext will leak into the reactor and possibly get attached to some arbitrary future operation. @@ -254,9 +226,8 @@ deferred completes will be the empty logcontext), and will restore the current logcontext before continuing the foreground process: ```python -@defer.inlineCallbacks -def do_request_handling(): - yield foreground_operation() +async def do_request_handling(): + await foreground_operation() # start background_operation off in the empty logcontext, to # avoid leaking the current context into the reactor. @@ -274,16 +245,15 @@ Obviously that option means that the operations done in The second option is to use `context.run_in_background`, which wraps a function so that it doesn't reset the logcontext even when it returns -an incomplete deferred, and adds a callback to the returned deferred to +an incomplete awaitable, and adds a callback to the returned awaitable to reset the logcontext. In other words, it turns a function that follows -the Synapse rules about logcontexts and Deferreds into one which behaves +the Synapse rules about logcontexts and awaitables into one which behaves more like an external function --- the opposite operation to that described in the previous section. It can be used like this: ```python -@defer.inlineCallbacks -def do_request_handling(): - yield foreground_operation() +async def do_request_handling(): + await foreground_operation() context.run_in_background(background_operation) @@ -294,152 +264,53 @@ def do_request_handling(): ## Passing synapse deferreds into third-party functions A typical example of this is where we want to collect together two or -more deferred via `defer.gatherResults`: +more awaitables via `defer.gatherResults`: ```python -d1 = operation1() -d2 = operation2() -d3 = defer.gatherResults([d1, d2]) +a1 = operation1() +a2 = operation2() +a3 = defer.gatherResults([a1, a2]) ``` This is really a variation of the fire-and-forget problem above, in that -we are firing off `d1` and `d2` without yielding on them. The difference +we are firing off `a1` and `a2` without awaiting on them. The difference is that we now have third-party code attached to their callbacks. Anyway either technique given in the [Fire-and-forget](#fire-and-forget) section will work. -Of course, the new Deferred returned by `gatherResults` needs to be +Of course, the new awaitable returned by `gather` needs to be wrapped in order to make it follow the logcontext rules before we can -yield it, as described in [Where you create a new Deferred, make it +yield it, as described in [Where you create a new awaitable, make it follow the -rules](#where-you-create-a-new-deferred-make-it-follow-the-rules). +rules](#where-you-create-a-new-awaitable-make-it-follow-the-rules). So, option one: reset the logcontext before starting the operations to be gathered: ```python -@defer.inlineCallbacks -def do_request_handling(): +async def do_request_handling(): with PreserveLoggingContext(): - d1 = operation1() - d2 = operation2() - result = yield defer.gatherResults([d1, d2]) + a1 = operation1() + a2 = operation2() + result = await defer.gatherResults([a1, a2]) ``` In this case particularly, though, option two, of using -`context.preserve_fn` almost certainly makes more sense, so that +`context.run_in_background` almost certainly makes more sense, so that `operation1` and `operation2` are both logged against the original logcontext. This looks like: ```python -@defer.inlineCallbacks -def do_request_handling(): - d1 = context.preserve_fn(operation1)() - d2 = context.preserve_fn(operation2)() +async def do_request_handling(): + a1 = context.run_in_background(operation1) + a2 = context.run_in_background(operation2) - with PreserveLoggingContext(): - result = yield defer.gatherResults([d1, d2]) + result = await make_deferred_yieldable(defer.gatherResults([a1, a2])) ``` -## Was all this really necessary? - -The conventions used work fine for a linear flow where everything -happens in series via `defer.inlineCallbacks` and `yield`, but are -certainly tricky to follow for any more exotic flows. It's hard not to -wonder if we could have done something else. - -We're not going to rewrite Synapse now, so the following is entirely of -academic interest, but I'd like to record some thoughts on an -alternative approach. - -I briefly prototyped some code following an alternative set of rules. I -think it would work, but I certainly didn't get as far as thinking how -it would interact with concepts as complicated as the cache descriptors. - -My alternative rules were: - -- functions always preserve the logcontext of their caller, whether or - not they are returning a Deferred. -- Deferreds returned by synapse functions run their callbacks in the - same context as the function was orignally called in. - -The main point of this scheme is that everywhere that sets the -logcontext is responsible for clearing it before returning control to -the reactor. - -So, for example, if you were the function which started a -`with LoggingContext` block, you wouldn't `yield` within it --- instead -you'd start off the background process, and then leave the `with` block -to wait for it: - -```python -def handle_request(request_id): - with context.LoggingContext() as request_context: - request_context.request = request_id - d = do_request_handling() - - def cb(r): - logger.debug("finished") - - d.addCallback(cb) - return d -``` - -(in general, mixing `with LoggingContext` blocks and -`defer.inlineCallbacks` in the same function leads to slighly -counter-intuitive code, under this scheme). - -Because we leave the original `with` block as soon as the Deferred is -returned (as opposed to waiting for it to be resolved, as we do today), -the logcontext is cleared before control passes back to the reactor; so -if there is some code within `do_request_handling` which needs to wait -for a Deferred to complete, there is no need for it to worry about -clearing the logcontext before doing so: - -```python -def handle_request(): - r = do_some_stuff() - r.addCallback(do_some_more_stuff) - return r -``` - ---- and provided `do_some_stuff` follows the rules of returning a -Deferred which runs its callbacks in the original logcontext, all is -happy. - -The business of a Deferred which runs its callbacks in the original -logcontext isn't hard to achieve --- we have it today, in the shape of -`context._PreservingContextDeferred`: - -```python -def do_some_stuff(): - deferred = do_some_io() - pcd = _PreservingContextDeferred(LoggingContext.current_context()) - deferred.chainDeferred(pcd) - return pcd -``` - -It turns out that, thanks to the way that Deferreds chain together, we -automatically get the property of a context-preserving deferred with -`defer.inlineCallbacks`, provided the final Defered the function -`yields` on has that property. So we can just write: - -```python -@defer.inlineCallbacks -def handle_request(): - yield do_some_stuff() - yield do_some_more_stuff() -``` - -To conclude: I think this scheme would have worked equally well, with -less danger of messing it up, and probably made some more esoteric code -easier to write. But again --- changing the conventions of the entire -Synapse codebase is not a sensible option for the marginal improvement -offered. - -## A note on garbage-collection of Deferred chains +## A note on garbage-collection of awaitable chains -It turns out that our logcontext rules do not play nicely with Deferred +It turns out that our logcontext rules do not play nicely with awaitable chains which get orphaned and garbage-collected. Imagine we have some code that looks like this: @@ -451,13 +322,12 @@ def on_something_interesting(): for d in listener_queue: d.callback("foo") -@defer.inlineCallbacks -def await_something_interesting(): - new_deferred = defer.Deferred() - listener_queue.append(new_deferred) +async def await_something_interesting(): + new_awaitable = defer.Deferred() + listener_queue.append(new_awaitable) with PreserveLoggingContext(): - yield new_deferred + await new_awaitable ``` Obviously, the idea here is that we have a bunch of things which are @@ -476,18 +346,19 @@ def reset_listener_queue(): listener_queue.clear() ``` -So, both ends of the deferred chain have now dropped their references, -and the deferred chain is now orphaned, and will be garbage-collected at -some point. Note that `await_something_interesting` is a generator -function, and when Python garbage-collects generator functions, it gives -them a chance to clean up by making the `yield` raise a `GeneratorExit` +So, both ends of the awaitable chain have now dropped their references, +and the awaitable chain is now orphaned, and will be garbage-collected at +some point. Note that `await_something_interesting` is a coroutine, +which Python implements as a generator function. When Python +garbage-collects generator functions, it gives them a chance to +clean up by making the `async` (or `yield`) raise a `GeneratorExit` exception. In our case, that means that the `__exit__` handler of `PreserveLoggingContext` will carefully restore the request context, but there is now nothing waiting for its return, so the request context is never cleared. -To reiterate, this problem only arises when *both* ends of a deferred -chain are dropped. Dropping the the reference to a deferred you're -supposed to be calling is probably bad practice, so this doesn't +To reiterate, this problem only arises when *both* ends of a awaitable +chain are dropped. Dropping the the reference to an awaitable you're +supposed to be awaiting is bad practice, so this doesn't actually happen too much. Unfortunately, when it does happen, it will lead to leaked logcontexts which are incredibly hard to track down. diff --git a/docs/modules.md b/docs/modules.md index bec1c06d15f4..9a430390a4cb 100644 --- a/docs/modules.md +++ b/docs/modules.md @@ -63,7 +63,7 @@ Modules can register web resources onto Synapse's web server using the following API method: ```python -def ModuleApi.register_web_resource(path: str, resource: IResource) +def ModuleApi.register_web_resource(path: str, resource: IResource) -> None ``` The path is the full absolute path to register the resource at. For example, if you @@ -91,12 +91,17 @@ are split in categories. A single module may implement callbacks from multiple c and is under no obligation to implement all callbacks from the categories it registers callbacks for. +Modules can register callbacks using one of the module API's `register_[...]_callbacks` +methods. The callback functions are passed to these methods as keyword arguments, with +the callback name as the argument name and the function as its value. This is demonstrated +in the example below. A `register_[...]_callbacks` method exists for each module type +documented in this section. + #### Spam checker callbacks -To register one of the callbacks described in this section, a module needs to use the -module API's `register_spam_checker_callbacks` method. The callback functions are passed -to `register_spam_checker_callbacks` as keyword arguments, with the callback name as the -argument name and the function as its value. This is demonstrated in the example below. +Spam checker callbacks allow module developers to implement spam mitigation actions for +Synapse instances. Spam checker callbacks can be registered using the module API's +`register_spam_checker_callbacks` method. The available spam checker callbacks are: @@ -115,7 +120,7 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool Called when processing an invitation. The module must return a `bool` indicating whether the inviter can invite the invitee to the given room. Both inviter and invitee are -represented by their Matrix user ID (i.e. `@alice:example.com`). +represented by their Matrix user ID (e.g. `@alice:example.com`). ```python async def user_may_create_room(user: str) -> bool @@ -181,13 +186,103 @@ The arguments passed to this callback are: ```python async def check_media_file_for_spam( file_wrapper: "synapse.rest.media.v1.media_storage.ReadableFileWrapper", - file_info: "synapse.rest.media.v1._base.FileInfo" + file_info: "synapse.rest.media.v1._base.FileInfo", ) -> bool ``` Called when storing a local or remote file. The module must return a boolean indicating whether the given file can be stored in the homeserver's media store. +#### Account validity callbacks + +Account validity callbacks allow module developers to add extra steps to verify the +validity on an account, i.e. see if a user can be granted access to their account on the +Synapse instance. Account validity callbacks can be registered using the module API's +`register_account_validity_callbacks` method. + +The available account validity callbacks are: + +```python +async def is_user_expired(user: str) -> Optional[bool] +``` + +Called when processing any authenticated request (except for logout requests). The module +can return a `bool` to indicate whether the user has expired and should be locked out of +their account, or `None` if the module wasn't able to figure it out. The user is +represented by their Matrix user ID (e.g. `@alice:example.com`). + +If the module returns `True`, the current request will be denied with the error code +`ORG_MATRIX_EXPIRED_ACCOUNT` and the HTTP status code 403. Note that this doesn't +invalidate the user's access token. + +```python +async def on_user_registration(user: str) -> None +``` + +Called after successfully registering a user, in case the module needs to perform extra +operations to keep track of them. (e.g. add them to a database table). The user is +represented by their Matrix user ID. + +#### Third party rules callbacks + +Third party rules callbacks allow module developers to add extra checks to verify the +validity of incoming events. Third party event rules callbacks can be registered using +the module API's `register_third_party_rules_callbacks` method. + +The available third party rules callbacks are: + +```python +async def check_event_allowed( + event: "synapse.events.EventBase", + state_events: "synapse.types.StateMap", +) -> Tuple[bool, Optional[dict]] +``` + +** +This callback is very experimental and can and will break without notice. Module developers +are encouraged to implement `check_event_for_spam` from the spam checker category instead. +** + +Called when processing any incoming event, with the event and a `StateMap` +representing the current state of the room the event is being sent into. A `StateMap` is +a dictionary that maps tuples containing an event type and a state key to the +corresponding state event. For example retrieving the room's `m.room.create` event from +the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`. +The module must return a boolean indicating whether the event can be allowed. + +Note that this callback function processes incoming events coming via federation +traffic (on top of client traffic). This means denying an event might cause the local +copy of the room's history to diverge from that of remote servers. This may cause +federation issues in the room. It is strongly recommended to only deny events using this +callback function if the sender is a local user, or in a private federation in which all +servers are using the same module, with the same configuration. + +If the boolean returned by the module is `True`, it may also tell Synapse to replace the +event with new data by returning the new event's data as a dictionary. In order to do +that, it is recommended the module calls `event.get_dict()` to get the current event as a +dictionary, and modify the returned dictionary accordingly. + +Note that replacing the event only works for events sent by local users, not for events +received over federation. + +```python +async def on_create_room( + requester: "synapse.types.Requester", + request_content: dict, + is_requester_admin: bool, +) -> None +``` + +Called when processing a room creation request, with the `Requester` object for the user +performing the request, a dictionary representing the room creation request's JSON body +(see [the spec](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-createroom) +for a list of possible parameters), and a boolean indicating whether the user performing +the request is a server admin. + +Modules can modify the `request_content` (by e.g. adding events to its `initial_state`), +or deny the room's creation by raising a `module_api.errors.SynapseError`. + + ### Porting an existing module that uses the old interface In order to port a module that uses Synapse's old module interface, its author needs to: diff --git a/docs/replication.md b/docs/replication.md index ed8823315726..e82df0de8a30 100644 --- a/docs/replication.md +++ b/docs/replication.md @@ -28,7 +28,7 @@ minimal. ### The Replication Protocol -See [tcp_replication.md](tcp_replication.md) +See [the TCP replication documentation](tcp_replication.md). ### The Slaved DataStore diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index 01db466f96e5..76bb45aff2e1 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -21,7 +21,7 @@ port 8448. Where these are different, we refer to the 'client port' and the 'federation port'. See [the Matrix specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names) for more details of the algorithm used for federation connections, and -[delegate.md](delegate.md) for instructions on setting up delegation. +[Delegation](delegate.md) for instructions on setting up delegation. **NOTE**: Your reverse proxy must not `canonicalise` or `normalise` the requested URI in any way (for example, by decoding `%xx` escapes). @@ -98,6 +98,33 @@ example.com:8448 { reverse_proxy http://localhost:8008 } ``` +[Delegation](delegate.md) example: +``` +(matrix-well-known-header) { + # Headers + header Access-Control-Allow-Origin "*" + header Access-Control-Allow-Methods "GET, POST, PUT, DELETE, OPTIONS" + header Access-Control-Allow-Headers "Origin, X-Requested-With, Content-Type, Accept, Authorization" + header Content-Type "application/json" +} + +example.com { + handle /.well-known/matrix/server { + import matrix-well-known-header + respond `{"m.server":"matrix.example.com:443"}` + } + + handle /.well-known/matrix/client { + import matrix-well-known-header + respond `{"m.homeserver":{"base_url":"https://matrix.example.com"},"m.identity_server":{"base_url":"https://identity.example.com"}}` + } +} + +matrix.example.com { + reverse_proxy /_matrix/* http://localhost:8008 + reverse_proxy /_synapse/client/* http://localhost:8008 +} +``` ### Apache diff --git a/docs/room_and_user_statistics.md b/docs/room_and_user_statistics.md index e1facb38d414..cc38c890bb59 100644 --- a/docs/room_and_user_statistics.md +++ b/docs/room_and_user_statistics.md @@ -1,9 +1,9 @@ Room and User Statistics ======================== -Synapse maintains room and user statistics (as well as a cache of room state), -in various tables. These can be used for administrative purposes but are also -used when generating the public room directory. +Synapse maintains room and user statistics in various tables. These can be used +for administrative purposes but are also used when generating the public room +directory. # Synapse Developer Documentation @@ -15,48 +15,8 @@ used when generating the public room directory. * **subject**: Something we are tracking stats about – currently a room or user. * **current row**: An entry for a subject in the appropriate current statistics table. Each subject can have only one. -* **historical row**: An entry for a subject in the appropriate historical - statistics table. Each subject can have any number of these. ### Overview -Stats are maintained as time series. There are two kinds of column: - -* absolute columns – where the value is correct for the time given by `end_ts` - in the stats row. (Imagine a line graph for these values) - * They can also be thought of as 'gauges' in Prometheus, if you are familiar. -* per-slice columns – where the value corresponds to how many of the occurrences - occurred within the time slice given by `(end_ts − bucket_size)
end_ts` - or `start_ts
end_ts`. (Imagine a histogram for these values) - -Stats are maintained in two tables (for each type): current and historical. - -Current stats correspond to the present values. Each subject can only have one -entry. - -Historical stats correspond to values in the past. Subjects may have multiple -entries. - -## Concepts around the management of stats - -### Current rows - -Current rows contain the most up-to-date statistics for a room. -They only contain absolute columns - -### Historical rows - -Historical rows can always be considered to be valid for the time slice and -end time specified. - -* historical rows will not exist for every time slice – they will be omitted - if there were no changes. In this case, the following assumptions can be - made to interpolate/recreate missing rows: - - absolute fields have the same values as in the preceding row - - per-slice fields are zero (`0`) -* historical rows will not be retained forever – rows older than a configurable - time will be purged. - -#### Purge - -The purging of historical rows is not yet implemented. +Stats correspond to the present values. Current rows contain the most up-to-date +statistics for a room. Each subject can only have one entry. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 054770f71ff6..853c2f68990f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1310,91 +1310,6 @@ account_threepid_delegates: #auto_join_rooms_for_guests: false -## Account Validity ## - -# Optional account validity configuration. This allows for accounts to be denied -# any request after a given period. -# -# Once this feature is enabled, Synapse will look for registered users without an -# expiration date at startup and will add one to every account it found using the -# current settings at that time. -# This means that, if a validity period is set, and Synapse is restarted (it will -# then derive an expiration date from the current validity period), and some time -# after that the validity period changes and Synapse is restarted, the users' -# expiration dates won't be updated unless their account is manually renewed. This -# date will be randomly selected within a range [now + period - d ; now + period], -# where d is equal to 10% of the validity period. -# -account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - # The currently available templates are: - # - # * account_renewed.html: Displayed to the user after they have successfully - # renewed their account. - # - # * account_previously_renewed.html: Displayed to the user if they attempt to - # renew their account with a token that is valid, but that has already - # been used. In this case the account is not renewed again. - # - # * invalid_token.html: Displayed to the user when they try to renew an account - # with an unknown or invalid renewal token. - # - # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for - # default template contents. - # - # The file name of some of these templates can be configured below for legacy - # reasons. - # - #template_dir: "res/templates" - - # A custom file name for the 'account_renewed.html' template. - # - # If not set, the file is assumed to be named "account_renewed.html". - # - #account_renewed_html_path: "account_renewed.html" - - # A custom file name for the 'invalid_token.html' template. - # - # If not set, the file is assumed to be named "invalid_token.html". - # - #invalid_token_html_path: "invalid_token.html" - - ## Metrics ### # Enable collection and rendering of performance metrics @@ -2653,11 +2568,6 @@ stats: # #enabled: false - # The size of each timeslice in the room_stats_historical and - # user_stats_historical tables, as a time period. Defaults to "1d". - # - #bucket_size: 1h - # Server Notices room configuration # @@ -2744,19 +2654,6 @@ stats: # action: allow -# Server admins can define a Python module that implements extra rules for -# allowing or denying incoming events. In order to work, this module needs to -# override the methods defined in synapse/events/third_party_rules.py. -# -# This feature is designed to be used in closed federations only, where each -# participating server enforces the same rules. -# -#third_party_event_rules: -# module: "my_custom_project.SuperRulesSet" -# config: -# example_option: 'things' - - ## Opentracing ## # These settings enable opentracing, which implements distributed tracing. diff --git a/docs/server_notices.md b/docs/server_notices.md index 950a6608e9bb..339d10a0ab3f 100644 --- a/docs/server_notices.md +++ b/docs/server_notices.md @@ -3,8 +3,8 @@ 'Server Notices' are a new feature introduced in Synapse 0.30. They provide a channel whereby server administrators can send messages to users on the server. -They are used as part of communication of the server polices(see -[consent_tracking.md](consent_tracking.md)), however the intention is that +They are used as part of communication of the server polices (see +[Consent Tracking](consent_tracking.md)), however the intention is that they may also find a use for features such as "Message of the day". This is a feature specific to Synapse, but it uses standard Matrix diff --git a/docs/setup/installation.md b/docs/setup/installation.md index d041d083339b..8540a7b0c10d 100644 --- a/docs/setup/installation.md +++ b/docs/setup/installation.md @@ -166,13 +166,16 @@ sudo dnf groupinstall "Development Tools" Installing prerequisites on macOS: +You may need to install the latest Xcode developer tools: ```sh xcode-select --install -sudo easy_install pip -sudo pip install virtualenv -brew install pkg-config libffi ``` +On ARM-based Macs you may need to explicitly install libjpeg which is a pillow dependency. You can use Homebrew (https://brew.sh): +```sh + brew install jpeg + ``` + On macOS Catalina (10.15) you may need to explicitly install OpenSSL via brew and inform `pip` about it so that `psycopg2` builds: @@ -268,9 +271,8 @@ For more details, see ##### Matrix.org packages -Matrix.org provides Debian/Ubuntu packages of the latest stable version of -Synapse via . They are available for Debian -9 (Stretch), Ubuntu 16.04 (Xenial), and later. To use them: +Matrix.org provides Debian/Ubuntu packages of Synapse via +. To install the latest release: ```sh sudo apt install -y lsb-release wget apt-transport-https @@ -281,12 +283,16 @@ sudo apt update sudo apt install matrix-synapse-py3 ``` -**Note**: if you followed a previous version of these instructions which -recommended using `apt-key add` to add an old key from -`https://matrix.org/packages/debian/`, you should note that this key has been -revoked. You should remove the old key with `sudo apt-key remove -C35EB17E1EAE708E6603A9B3AD0592FE47F0DF61`, and follow the above instructions to -update your configuration. +Packages are also published for release candidates. To enable the prerelease +channel, add `prerelease` to the `sources.list` line. For example: + +```sh +sudo wget -O /usr/share/keyrings/matrix-org-archive-keyring.gpg https://packages.matrix.org/debian/matrix-org-archive-keyring.gpg +echo "deb [signed-by=/usr/share/keyrings/matrix-org-archive-keyring.gpg] https://packages.matrix.org/debian/ $(lsb_release -cs) main prerelease" | + sudo tee /etc/apt/sources.list.d/matrix-org.list +sudo apt update +sudo apt install matrix-synapse-py3 +``` The fingerprint of the repository signing key (as shown by `gpg /usr/share/keyrings/matrix-org-archive-keyring.gpg`) is @@ -409,7 +415,7 @@ instead. Advantages include: - allowing the DB to be run on separate hardware For information on how to install and use PostgreSQL in Synapse, please see -[docs/postgres.md](../postgres.md) +[Using Postgres](../postgres.md) SQLite is only acceptable for testing purposes. SQLite should not be used in a production server. Synapse will perform poorly when using @@ -424,7 +430,7 @@ over HTTPS. The recommended way to do so is to set up a reverse proxy on port `8448`. You can find documentation on doing so in -[docs/reverse_proxy.md](../reverse_proxy.md). +[the reverse proxy documentation](../reverse_proxy.md). Alternatively, you can configure Synapse to expose an HTTPS port. To do so, you will need to edit `homeserver.yaml`, as follows: @@ -451,7 +457,7 @@ so, you will need to edit `homeserver.yaml`, as follows: `cert.pem`). For a more detailed guide to configuring your server for federation, see -[federate.md](../federate.md). +[Federation](../federate.md). ### Client Well-Known URI @@ -563,9 +569,7 @@ on your server even if `enable_registration` is `false`. ### Setting up a TURN server For reliable VoIP calls to be routed via this homeserver, you MUST configure -a TURN server. See -[docs/turn-howto.md](../turn-howto.md) -for details. +a TURN server. See [TURN setup](../turn-howto.md) for details. ### URL previews diff --git a/docs/systemd-with-workers/README.md b/docs/systemd-with-workers/README.md index a7de2de88aca..b160d9352840 100644 --- a/docs/systemd-with-workers/README.md +++ b/docs/systemd-with-workers/README.md @@ -14,10 +14,12 @@ contains an example configuration for the `federation_reader` worker. ## Synapse configuration files -See [workers.md](../workers.md) for information on how to set up the -configuration files and reverse-proxy correctly. You can find an example worker -config in the [workers](https://github.com/matrix-org/synapse/tree/develop/docs/systemd-with-workers/workers/) -folder. +See [the worker documentation](../workers.md) for information on how to set up the +configuration files and reverse-proxy correctly. +Below is a sample `federation_reader` worker configuration file. +```yaml +{{#include workers/federation_reader.yaml}} +``` Systemd manages daemonization itself, so ensure that none of the configuration files set either `daemonize` or `worker_daemonize`. @@ -72,12 +74,12 @@ systemctl restart matrix-synapse.target **Optional:** If further hardening is desired, the file `override-hardened.conf` may be copied from -`contrib/systemd/override-hardened.conf` in this repository to the location +[contrib/systemd/override-hardened.conf](https://github.com/matrix-org/synapse/tree/develop/contrib/systemd/) +in this repository to the location `/etc/systemd/system/matrix-synapse.service.d/override-hardened.conf` (the directory may have to be created). It enables certain sandboxing features in systemd to further secure the synapse service. You may read the comments to -understand what the override file is doing. The same file will need to be copied -to +understand what the override file is doing. The same file will need to be copied to `/etc/systemd/system/matrix-synapse-worker@.service.d/override-hardened-worker.conf` (this directory may also have to be created) in order to apply the same hardening options to any worker processes. diff --git a/docs/upgrade.md b/docs/upgrade.md index db0450f563a2..c8f4a2c17172 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -86,6 +86,19 @@ process, for example: ``` +# Upgrading to v1.39.0 + +## Deprecation of the current third-party rules module interface + +The current third-party rules module interface is deprecated in favour of the new generic +modules system introduced in Synapse v1.37.0. Authors of third-party rules modules can refer +to [this documentation](modules.md#porting-an-existing-module-that-uses-the-old-interface) +to update their modules. Synapse administrators can refer to [this documentation](modules.md#using-modules) +to update their configuration once the modules they are using have been updated. + +We plan to remove support for the current third-party rules interface in September 2021. + + # Upgrading to v1.38.0 ## Re-indexing of `events` table on Postgres databases diff --git a/docs/usage/configuration/logging_sample_config.md b/docs/usage/configuration/logging_sample_config.md index 4c4bc6fc1639..a673d487b85e 100644 --- a/docs/usage/configuration/logging_sample_config.md +++ b/docs/usage/configuration/logging_sample_config.md @@ -11,4 +11,4 @@ a fresh config using Synapse by following the instructions in ```yaml {{#include ../../sample_log_config.yaml}} -``__` \ No newline at end of file +``` \ No newline at end of file diff --git a/docs/workers.md b/docs/workers.md index 797758ee84b3..d8672324c301 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -73,7 +73,7 @@ https://hub.docker.com/r/matrixdotorg/synapse/. To make effective use of the workers, you will need to configure an HTTP reverse-proxy such as nginx or haproxy, which will direct incoming requests to the correct worker, or to the main synapse instance. See -[reverse_proxy.md](reverse_proxy.md) for information on setting up a reverse +[the reverse proxy documentation](reverse_proxy.md) for information on setting up a reverse proxy. When using workers, each worker process has its own configuration file which @@ -170,8 +170,8 @@ Finally, you need to start your worker processes. This can be done with either `synctl` or your distribution's preferred service manager such as `systemd`. We recommend the use of `systemd` where available: for information on setting up `systemd` to start synapse workers, see -[systemd-with-workers](systemd-with-workers). To use `synctl`, see -[synctl_workers.md](synctl_workers.md). +[Systemd with Workers](systemd-with-workers). To use `synctl`, see +[Using synctl with Workers](synctl_workers.md). ## Available worker applications diff --git a/mypy.ini b/mypy.ini index 72ce932d7302..8717ae738e31 100644 --- a/mypy.ini +++ b/mypy.ini @@ -83,6 +83,7 @@ files = synapse/util/stringutils.py, synapse/visibility.py, tests/replication, + tests/test_event_auth.py, tests/test_utils, tests/handlers/test_password_providers.py, tests/rest/client/v1/test_login.py, diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 5bfaa4ad2f62..cff433af2af2 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -139,6 +139,11 @@ def run(): # Switch to the release branch. parsed_new_version = version.parse(new_version) + + # We assume for debian changelogs that we only do RCs or full releases. + assert not parsed_new_version.is_devrelease + assert not parsed_new_version.is_postrelease + release_branch_name = ( f"release-v{parsed_new_version.major}.{parsed_new_version.minor}" ) @@ -190,12 +195,21 @@ def run(): # Generate changelogs subprocess.run("python3 -m towncrier", shell=True) - # Generate debian changelogs if its not an RC. - if not rc: - subprocess.run( - f'dch -M -v {new_version} "New synapse release {new_version}."', shell=True - ) - subprocess.run('dch -M -r -D stable ""', shell=True) + # Generate debian changelogs + if parsed_new_version.pre is not None: + # If this is an RC then we need to coerce the version string to match + # Debian norms, e.g. 1.39.0rc2 gets converted to 1.39.0~rc2. + base_ver = parsed_new_version.base_version + pre_type, pre_num = parsed_new_version.pre + debian_version = f"{base_ver}~{pre_type}{pre_num}" + else: + debian_version = new_version + + subprocess.run( + f'dch -M -v {debian_version} "New synapse release {debian_version}."', + shell=True, + ) + subprocess.run('dch -M -r -D stable ""', shell=True) # Show the user the changes and ask if they want to edit the change log. repo.git.add("-u") diff --git a/synapse/__init__.py b/synapse/__init__.py index 7ea5a790db95..c9a445c8fe47 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ except ImportError: pass -__version__ = "1.38.1" +__version__ = "1.39.0rc3" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 307f5f9a9463..05699714eea7 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -62,16 +62,14 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() + self._account_validity_handler = hs.get_account_validity_handler() - self.token_cache = LruCache( + self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache( 10000, "token_cache" - ) # type: LruCache[str, Tuple[str, bool]] + ) self._auth_blocking = AuthBlocking(self.hs) - self._account_validity_enabled = ( - hs.config.account_validity.account_validity_enabled - ) self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._macaroon_secret_key = hs.config.macaroon_secret_key self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users @@ -187,12 +185,17 @@ async def get_user_by_req( shadow_banned = user_info.shadow_banned # Deny the request if the user account has expired. - if self._account_validity_enabled and not allow_expired: - if await self.store.is_account_expired( - user_info.user_id, self.clock.time_msec() + if not allow_expired: + if await self._account_validity_handler.is_user_expired( + user_info.user_id ): + # Raise the error if either an account validity module has determined + # the account has expired, or the legacy account validity + # implementation is enabled and determined the account has expired raise AuthError( - 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT + 403, + "User account has expired", + errcode=Codes.EXPIRED_ACCOUNT, ) device_id = user_info.device_id @@ -240,6 +243,37 @@ async def get_user_by_req( except KeyError: raise MissingClientTokenError() + async def validate_appservice_can_control_user_id( + self, app_service: ApplicationService, user_id: str + ): + """Validates that the app service is allowed to control + the given user. + + Args: + app_service: The app service that controls the user + user_id: The author MXID that the app service is controlling + + Raises: + AuthError: If the application service is not allowed to control the user + (user namespace regex does not match, wrong homeserver, etc) + or if the user has not been registered yet. + """ + + # It's ok if the app service is trying to use the sender from their registration + if app_service.sender == user_id: + pass + # Check to make sure the app service is allowed to control the user + elif not app_service.is_interested_in_user(user_id): + raise AuthError( + 403, + "Application service cannot masquerade as this user (%s)." % user_id, + ) + # Check to make sure the user is already registered on the homeserver + elif not (await self.store.get_user_by_id(user_id)): + raise AuthError( + 403, "Application service has not registered this user (%s)" % user_id + ) + async def _get_appservice_user_id( self, request: Request ) -> Tuple[Optional[str], Optional[ApplicationService]]: @@ -261,13 +295,11 @@ async def _get_appservice_user_id( return app_service.sender, app_service user_id = request.args[b"user_id"][0].decode("utf8") + await self.validate_appservice_can_control_user_id(app_service, user_id) + if app_service.sender == user_id: return app_service.sender, app_service - if not app_service.is_interested_in_user(user_id): - raise AuthError(403, "Application service cannot masquerade as this user.") - if not (await self.store.get_user_by_id(user_id)): - raise AuthError(403, "Application service has not registered this user") return user_id, app_service async def get_user_by_access_token( diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 8363c2bb0f5f..8c7ad2a40762 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -127,6 +127,14 @@ class ToDeviceEventTypes: RoomKeyRequest = "m.room_key_request" +class DeviceKeyAlgorithms: + """Spec'd algorithms for the generation of per-device keys""" + + ED25519 = "ed25519" + CURVE25519 = "curve25519" + SIGNED_CURVE25519 = "signed_curve25519" + + class EduTypes: Presence = "m.presence" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 4cb8bbaf701e..054ab14ab6e0 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -118,7 +118,7 @@ def __init__(self, location: bytes, http_code: int = http.FOUND): super().__init__(code=http_code, msg=msg) self.location = location - self.cookies = [] # type: List[bytes] + self.cookies: List[bytes] = [] class SynapseError(CodeMessageException): @@ -160,7 +160,7 @@ def __init__( ): super().__init__(code, msg, errcode) if additional_fields is None: - self._additional_fields = {} # type: Dict + self._additional_fields: Dict = {} else: self._additional_fields = dict(additional_fields) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index ce49a0ad58d1..ad1ff6a9df6c 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -289,7 +289,7 @@ def check(self, event): room_id = None ev_type = "m.presence" contains_url = False - labels = [] # type: List[str] + labels: List[str] = [] else: sender = event.get("sender", None) if not sender: diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index b9a10283f44a..3e3d09bbd244 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -46,9 +46,7 @@ def __init__( # * How many times an action has occurred since a point in time # * The point in time # * The rate_hz of this particular entry. This can vary per request - self.actions = ( - OrderedDict() - ) # type: OrderedDict[Hashable, Tuple[float, int, float]] + self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict() async def can_do_action( self, diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index f6c1c97b40ca..a20abc5a655f 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -195,7 +195,7 @@ class RoomVersions: ) -KNOWN_ROOM_VERSIONS = { +KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { v.identifier: v for v in ( RoomVersions.V1, @@ -209,4 +209,4 @@ class RoomVersions: RoomVersions.V7, ) # Note that we do not include MSC2043 here unless it is enabled in the config. -} # type: Dict[str, RoomVersion] +} diff --git a/synapse/app/_base.py b/synapse/app/_base.py index b30571fe495b..50a02f51f54a 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -38,6 +38,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.events.spamcheck import load_legacy_spam_checkers +from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.logging.context import PreserveLoggingContext from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats @@ -368,6 +369,7 @@ def run_sighup(*args, **kwargs): module(config=config, api=module_api) load_legacy_spam_checkers(hs) + load_legacy_third_party_event_rules(hs) # If we've configured an expiry time for caches, start the background job now. setup_expire_lru_cache_entries(hs) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 5b041fcaade2..c3d49925189e 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -270,7 +270,7 @@ def _listen_http(self, listener_config: ListenerConfig): site_tag = port # We always include a health resource. - resources = {"/health": HealthResource()} # type: Dict[str, IResource] + resources: Dict[str, IResource] = {"/health": HealthResource()} for res in listener_config.http_options.resources: for name in res.names: @@ -395,10 +395,8 @@ def start_listening(self): elif listener.type == "metrics": if not self.config.enable_metrics: logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) + "Metrics listener configured, but " + "enable_metrics is not True!" ) else: _base.listen_metrics(listener.bind_addresses, listener.port) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 7af56ac13640..920b34d97bf4 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -305,10 +305,8 @@ def start_listening(self): elif listener.type == "metrics": if not self.config.enable_metrics: logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) + "Metrics listener configured, but " + "enable_metrics is not True!" ) else: _base.listen_metrics(listener.bind_addresses, listener.port) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 8f86cecb7698..86ad7337a995 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -71,6 +71,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): # General statistics # + store = hs.get_datastore() + stats["homeserver"] = hs.config.server_name stats["server_context"] = hs.config.server_context stats["timestamp"] = now @@ -79,34 +81,38 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): stats["python_version"] = "{}.{}.{}".format( version.major, version.minor, version.micro ) - stats["total_users"] = await hs.get_datastore().count_all_users() + stats["total_users"] = await store.count_all_users() - total_nonbridged_users = await hs.get_datastore().count_nonbridged_users() + total_nonbridged_users = await store.count_nonbridged_users() stats["total_nonbridged_users"] = total_nonbridged_users - daily_user_type_results = await hs.get_datastore().count_daily_user_type() + daily_user_type_results = await store.count_daily_user_type() for name, count in daily_user_type_results.items(): stats["daily_user_type_" + name] = count - room_count = await hs.get_datastore().get_room_count() + room_count = await store.get_room_count() stats["total_room_count"] = room_count - stats["daily_active_users"] = await hs.get_datastore().count_daily_users() - stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() - daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms() + stats["daily_active_users"] = await store.count_daily_users() + stats["monthly_active_users"] = await store.count_monthly_users() + daily_active_e2ee_rooms = await store.count_daily_active_e2ee_rooms() stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms - stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages() - daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages() + stats["daily_e2ee_messages"] = await store.count_daily_e2ee_messages() + daily_sent_e2ee_messages = await store.count_daily_sent_e2ee_messages() stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages - stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() - stats["daily_messages"] = await hs.get_datastore().count_daily_messages() - daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() + stats["daily_active_rooms"] = await store.count_daily_active_rooms() + stats["daily_messages"] = await store.count_daily_messages() + daily_sent_messages = await store.count_daily_sent_messages() stats["daily_sent_messages"] = daily_sent_messages - r30_results = await hs.get_datastore().count_r30_users() + r30_results = await store.count_r30_users() for name, count in r30_results.items(): stats["r30_users_" + name] = count + r30v2_results = await store.count_r30v2_users() + for name, count in r30v2_results.items(): + stats["r30v2_users_" + name] = count + stats["cache_factor"] = hs.config.caches.global_factor stats["event_cache_size"] = hs.config.caches.event_cache_size @@ -115,8 +121,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): # # This only reports info about the *main* database. - stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__ - stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version + stats["database_engine"] = store.db_pool.engine.module.__name__ + stats["database_server_version"] = store.db_pool.engine.server_version # # Logging configuration diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 61152b2c46cf..935f24263c98 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -88,9 +88,9 @@ def __init__(self, hs): super().__init__(hs) self.clock = hs.get_clock() - self.protocol_meta_cache = ResponseCache( + self.protocol_meta_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS - ) # type: ResponseCache[Tuple[str, str]] + ) async def query_user(self, service, user_id): if service.url is None: diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index 957de7f3a613..6be4eafe5582 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -18,6 +18,21 @@ class AccountValidityConfig(Config): section = "account_validity" def read_config(self, config, **kwargs): + """Parses the old account validity config. The config format looks like this: + + account_validity: + enabled: true + period: 6w + renew_at: 1w + renew_email_subject: "Renew your %(app)s account" + template_dir: "res/templates" + account_renewed_html_path: "account_renewed.html" + invalid_token_html_path: "invalid_token.html" + + We expect admins to use modules for this feature (which is why it doesn't appear + in the sample config file), but we want to keep support for it around for a bit + for backwards compatibility. + """ account_validity_config = config.get("account_validity") or {} self.account_validity_enabled = account_validity_config.get("enabled", False) self.account_validity_renew_by_email_enabled = ( @@ -75,90 +90,3 @@ def read_config(self, config, **kwargs): ], account_validity_template_dir, ) - - def generate_config_section(self, **kwargs): - return """\ - ## Account Validity ## - - # Optional account validity configuration. This allows for accounts to be denied - # any request after a given period. - # - # Once this feature is enabled, Synapse will look for registered users without an - # expiration date at startup and will add one to every account it found using the - # current settings at that time. - # This means that, if a validity period is set, and Synapse is restarted (it will - # then derive an expiration date from the current validity period), and some time - # after that the validity period changes and Synapse is restarted, the users' - # expiration dates won't be updated unless their account is manually renewed. This - # date will be randomly selected within a range [now + period - d ; now + period], - # where d is equal to 10% of the validity period. - # - account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - # The currently available templates are: - # - # * account_renewed.html: Displayed to the user after they have successfully - # renewed their account. - # - # * account_previously_renewed.html: Displayed to the user if they attempt to - # renew their account with a token that is valid, but that has already - # been used. In this case the account is not renewed again. - # - # * invalid_token.html: Displayed to the user when they try to renew an account - # with an unknown or invalid renewal token. - # - # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for - # default template contents. - # - # The file name of some of these templates can be configured below for legacy - # reasons. - # - #template_dir: "res/templates" - - # A custom file name for the 'account_renewed.html' template. - # - # If not set, the file is assumed to be named "account_renewed.html". - # - #account_renewed_html_path: "account_renewed.html" - - # A custom file name for the 'invalid_token.html' template. - # - # If not set, the file is assumed to be named "invalid_token.html". - # - #invalid_token_html_path: "invalid_token.html" - """ diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 746fc3cc02f6..1ebea88db2a2 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -57,14 +57,14 @@ def load_appservices(hostname, config_files): return [] # Dicts of value -> filename - seen_as_tokens = {} # type: Dict[str, str] - seen_ids = {} # type: Dict[str, str] + seen_as_tokens: Dict[str, str] = {} + seen_ids: Dict[str, str] = {} appservices = [] for config_file in config_files: try: - with open(config_file, "r") as f: + with open(config_file) as f: appservice = _load_appservice(hostname, yaml.safe_load(f), config_file) if appservice.id in seen_ids: raise ConfigError( diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 7789b4032323..8d5f38b5d934 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -25,7 +25,7 @@ _CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" # Map from canonicalised cache name to cache. -_CACHES = {} # type: Dict[str, Callable[[float], None]] +_CACHES: Dict[str, Callable[[float], None]] = {} # a lock on the contents of _CACHES _CACHES_LOCK = threading.Lock() @@ -157,7 +157,7 @@ def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size( config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) ) - self.cache_factors = {} # type: Dict[str, float] + self.cache_factors: Dict[str, float] = {} cache_config = config.get("caches") or {} self.global_factor = cache_config.get( diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 5564d7d097d4..bcecbfec0322 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -134,9 +134,9 @@ def read_config(self, config, **kwargs): # trusted_third_party_id_servers does not contain a scheme whereas # account_threepid_delegate_email is expected to. Presume https - self.account_threepid_delegate_email = ( + self.account_threepid_delegate_email: Optional[str] = ( "https://" + first_trusted_identity_server - ) # type: Optional[str] + ) self.using_identity_server_from_trusted_list = True else: raise ConfigError( diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 7fb1f7021f5e..e25ccba9ace0 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -25,10 +25,10 @@ def read_config(self, config: JsonDict, **kwargs): experimental = config.get("experimental_features") or {} # MSC2858 (multiple SSO identity providers) - self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool + self.msc2858_enabled: bool = experimental.get("msc2858_enabled", False) # MSC3026 (busy presence state) - self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool + self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) # MSC2716 (backfill existing history) - self.msc2716_enabled = experimental.get("msc2716_enabled", False) # type: bool + self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) diff --git a/synapse/config/federation.py b/synapse/config/federation.py index cdd7a1ef054e..7d64993e22fb 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -22,7 +22,7 @@ class FederationConfig(Config): def read_config(self, config, **kwargs): # FIXME: federation_domain_whitelist needs sytests - self.federation_domain_whitelist = None # type: Optional[dict] + self.federation_domain_whitelist: Optional[dict] = None federation_domain_whitelist = config.get("federation_domain_whitelist", None) if federation_domain_whitelist is not None: diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 942e2672a961..ba89d11cf016 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -460,7 +460,7 @@ def _parse_oidc_config_dict( ) from e client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key") - client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey] + client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] = None if client_secret_jwt_key_config is not None: keyfile = client_secret_jwt_key_config.get("key_file") if keyfile: diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index fd90b7977292..0f5b2b397754 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config): section = "authproviders" def read_config(self, config, **kwargs): - self.password_providers = [] # type: List[Any] + self.password_providers: List[Any] = [] providers = [] # We want to be backwards compatible with the old `ldap_config` diff --git a/synapse/config/repository.py b/synapse/config/repository.py index a7a82742ac13..0dfb3a227a3b 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -62,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): Dictionary mapping from media type string to list of ThumbnailRequirement tuples. """ - requirements = {} # type: Dict[str, List] + requirements: Dict[str, List] = {} for size in thumbnail_sizes: width = size["width"] height = size["height"] @@ -141,7 +141,7 @@ def read_config(self, config, **kwargs): # # We don't create the storage providers here as not all workers need # them to be started. - self.media_storage_providers = [] # type: List[tuple] + self.media_storage_providers: List[tuple] = [] for i, provider_config in enumerate(storage_providers): # We special case the module "file_system" so as not to need to diff --git a/synapse/config/server.py b/synapse/config/server.py index 6bff715230ca..b9e0c0b30093 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -505,7 +505,7 @@ def read_config(self, config, **kwargs): " greater than 'allowed_lifetime_max'" ) - self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]] + self.retention_purge_jobs: List[Dict[str, Optional[int]]] = [] for purge_job_config in retention_config.get("purge_jobs", []): interval_config = purge_job_config.get("interval") @@ -688,23 +688,21 @@ class LimitRemoteRoomsConfig: # not included in the sample configuration file on purpose as it's a temporary # hack, so that some users can trial the new defaults without impacting every # user on the homeserver. - users_new_default_push_rules = ( + users_new_default_push_rules: list = ( config.get("users_new_default_push_rules") or [] - ) # type: list + ) if not isinstance(users_new_default_push_rules, list): raise ConfigError("'users_new_default_push_rules' must be a list") # Turn the list into a set to improve lookup speed. - self.users_new_default_push_rules = set( - users_new_default_push_rules - ) # type: set + self.users_new_default_push_rules: set = set(users_new_default_push_rules) # Whitelist of domain names that given next_link parameters must have - next_link_domain_whitelist = config.get( + next_link_domain_whitelist: Optional[List[str]] = config.get( "next_link_domain_whitelist" - ) # type: Optional[List[str]] + ) - self.next_link_domain_whitelist = None # type: Optional[Set[str]] + self.next_link_domain_whitelist: Optional[Set[str]] = None if next_link_domain_whitelist is not None: if not isinstance(next_link_domain_whitelist, list): raise ConfigError("'next_link_domain_whitelist' must be a list") diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py index cb7716c83701..a233a9ce0388 100644 --- a/synapse/config/spam_checker.py +++ b/synapse/config/spam_checker.py @@ -34,7 +34,7 @@ class SpamCheckerConfig(Config): section = "spamchecker" def read_config(self, config, **kwargs): - self.spam_checkers = [] # type: List[Tuple[Any, Dict]] + self.spam_checkers: List[Tuple[Any, Dict]] = [] spam_checkers = config.get("spam_checker") or [] if isinstance(spam_checkers, dict): diff --git a/synapse/config/sso.py b/synapse/config/sso.py index e4346e02aa7c..d0f04cf8e6b2 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -39,7 +39,7 @@ class SSOConfig(Config): section = "sso" def read_config(self, config, **kwargs): - sso_config = config.get("sso") or {} # type: Dict[str, Any] + sso_config: Dict[str, Any] = config.get("sso") or {} # The sso-specific template_dir self.sso_template_dir = sso_config.get("template_dir") diff --git a/synapse/config/stats.py b/synapse/config/stats.py index 78f61fe9da29..6f253e00c0da 100644 --- a/synapse/config/stats.py +++ b/synapse/config/stats.py @@ -38,13 +38,9 @@ class StatsConfig(Config): def read_config(self, config, **kwargs): self.stats_enabled = True - self.stats_bucket_size = 86400 * 1000 stats_config = config.get("stats", None) if stats_config: self.stats_enabled = stats_config.get("enabled", self.stats_enabled) - self.stats_bucket_size = self.parse_duration( - stats_config.get("bucket_size", "1d") - ) if not self.stats_enabled: logger.warning(ROOM_STATS_DISABLED_WARN) @@ -59,9 +55,4 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): # correctly. # #enabled: false - - # The size of each timeslice in the room_stats_historical and - # user_stats_historical tables, as a time period. Defaults to "1d". - # - #bucket_size: 1h """ diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py index f502ff539e27..a3fae0242082 100644 --- a/synapse/config/third_party_event_rules.py +++ b/synapse/config/third_party_event_rules.py @@ -28,18 +28,3 @@ def read_config(self, config, **kwargs): self.third_party_event_rules = load_module( provider, ("third_party_event_rules",) ) - - def generate_config_section(self, **kwargs): - return """\ - # Server admins can define a Python module that implements extra rules for - # allowing or denying incoming events. In order to work, this module needs to - # override the methods defined in synapse/events/third_party_rules.py. - # - # This feature is designed to be used in closed federations only, where each - # participating server enforces the same rules. - # - #third_party_event_rules: - # module: "my_custom_project.SuperRulesSet" - # config: - # example_option: 'things' - """ diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 9a16a8fbae96..5679f05e4270 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -66,10 +66,8 @@ def read_config(self, config: dict, config_dir_path: str, **kwargs): if self.federation_client_minimum_tls_version == "1.3": if getattr(SSL, "OP_NO_TLSv1_3", None) is None: raise ConfigError( - ( - "federation_client_minimum_tls_version cannot be 1.3, " - "your OpenSSL does not support it" - ) + "federation_client_minimum_tls_version cannot be 1.3, " + "your OpenSSL does not support it" ) # Whitelist of domains to not verify certificates for @@ -80,7 +78,7 @@ def read_config(self, config: dict, config_dir_path: str, **kwargs): fed_whitelist_entries = [] # Support globs (*) in whitelist values - self.federation_certificate_verification_whitelist = [] # type: List[Pattern] + self.federation_certificate_verification_whitelist: List[Pattern] = [] for entry in fed_whitelist_entries: try: entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii")) @@ -132,8 +130,8 @@ def read_config(self, config: dict, config_dir_path: str, **kwargs): "use_insecure_ssl_client_just_for_testing_do_not_use" ) - self.tls_certificate = None # type: Optional[crypto.X509] - self.tls_private_key = None # type: Optional[crypto.PKey] + self.tls_certificate: Optional[crypto.X509] = None + self.tls_private_key: Optional[crypto.PKey] = None def is_disk_cert_valid(self, allow_self_signed=True): """ diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index e5a4685ed49f..9e9b1c1c8630 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -170,11 +170,13 @@ def __init__( ) self._key_fetchers = key_fetchers - self._server_queue = BatchingQueue( + self._server_queue: BatchingQueue[ + _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]] + ] = BatchingQueue( "keyring_server", clock=hs.get_clock(), process_batch_callback=self._inner_fetch_key_requests, - ) # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]] + ) async def verify_json_for_server( self, @@ -330,7 +332,7 @@ async def _inner_fetch_key_requests( # First we need to deduplicate requests for the same key. We do this by # taking the *maximum* requested `minimum_valid_until_ts` for each pair # of server name/key ID. - server_to_key_to_ts = {} # type: Dict[str, Dict[str, int]] + server_to_key_to_ts: Dict[str, Dict[str, int]] = {} for request in requests: by_server = server_to_key_to_ts.setdefault(request.server_name, {}) for key_id in request.key_ids: @@ -355,7 +357,7 @@ async def _inner_fetch_key_requests( # We now convert the returned list of results into a map from server # name to key ID to FetchKeyResult, to return. - to_return = {} # type: Dict[str, Dict[str, FetchKeyResult]] + to_return: Dict[str, Dict[str, FetchKeyResult]] = {} for (request, results) in zip(deduped_requests, results_per_request): to_return_by_server = to_return.setdefault(request.server_name, {}) for key_id, key_result in results.items(): @@ -455,7 +457,7 @@ async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]): ) res = await self.store.get_server_verify_keys(key_ids_to_fetch) - keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] + keys: Dict[str, Dict[str, FetchKeyResult]] = {} for (server_name, key_id), key in res.items(): keys.setdefault(server_name, {})[key_id] = key return keys @@ -603,7 +605,7 @@ async def get_key(key_server: TrustedKeyServer) -> Dict: ).addErrback(unwrapFirstError) ) - union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] + union_of_keys: Dict[str, Dict[str, FetchKeyResult]] = {} for result in results: for server_name, keys in result.items(): union_of_keys.setdefault(server_name, {}).update(keys) @@ -656,8 +658,8 @@ async def get_server_verify_key_v2_indirect( except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) - keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] - added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]] + keys: Dict[str, Dict[str, FetchKeyResult]] = {} + added_keys: List[Tuple[str, str, FetchKeyResult]] = [] time_now_ms = self.clock.time_msec() @@ -805,7 +807,7 @@ async def get_server_verify_key_v2_direct( Raises: KeyLookupError if there was a problem making the lookup """ - keys = {} # type: Dict[str, FetchKeyResult] + keys: Dict[str, FetchKeyResult] = {} for requested_key_id in key_ids: # we may have found this key as a side-effect of asking for another. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 89bcf8151589..137dff251370 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -48,6 +48,9 @@ def check( room_version_obj: the version of the room event: the event being checked. auth_events: the existing room state. + do_sig_check: True if it should be verified that the sending server + signed the event. + do_size_check: True if the size of the event fields should be verified. Raises: AuthError if the checks fail @@ -528,7 +531,7 @@ def _check_power_levels( user_level = get_user_power_level(event.user_id, auth_events) # Check other levels: - levels_to_check = [ + levels_to_check: List[Tuple[str, Optional[str]]] = [ ("users_default", None), ("events_default", None), ("state_default", None), @@ -536,7 +539,7 @@ def _check_power_levels( ("redact", None), ("kick", None), ("invite", None), - ] # type: List[Tuple[str, Optional[str]]] + ] old_list = current_state.content.get("users", {}) for user in set(list(old_list) + list(user_list)): @@ -566,12 +569,12 @@ def _check_power_levels( new_loc = new_loc.get(dir, {}) if level_to_check in old_loc: - old_level = int(old_loc[level_to_check]) # type: Optional[int] + old_level: Optional[int] = int(old_loc[level_to_check]) else: old_level = None if level_to_check in new_loc: - new_level = int(new_loc[level_to_check]) # type: Optional[int] + new_level: Optional[int] = int(new_loc[level_to_check]) else: new_level = None diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 6286ad999a85..0298af4c02d7 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -105,28 +105,28 @@ def __init__(self, internal_metadata_dict: JsonDict): self._dict = dict(internal_metadata_dict) # the stream ordering of this event. None, until it has been persisted. - self.stream_ordering = None # type: Optional[int] + self.stream_ordering: Optional[int] = None # whether this event is an outlier (ie, whether we have the state at that point # in the DAG) self.outlier = False - out_of_band_membership = DictProperty("out_of_band_membership") # type: bool - send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str - recheck_redaction = DictProperty("recheck_redaction") # type: bool - soft_failed = DictProperty("soft_failed") # type: bool - proactively_send = DictProperty("proactively_send") # type: bool - redacted = DictProperty("redacted") # type: bool - txn_id = DictProperty("txn_id") # type: str - token_id = DictProperty("token_id") # type: int - historical = DictProperty("historical") # type: bool + out_of_band_membership: bool = DictProperty("out_of_band_membership") + send_on_behalf_of: str = DictProperty("send_on_behalf_of") + recheck_redaction: bool = DictProperty("recheck_redaction") + soft_failed: bool = DictProperty("soft_failed") + proactively_send: bool = DictProperty("proactively_send") + redacted: bool = DictProperty("redacted") + txn_id: str = DictProperty("txn_id") + token_id: int = DictProperty("token_id") + historical: bool = DictProperty("historical") # XXX: These are set by StreamWorkerStore._set_before_and_after. # I'm pretty sure that these are never persisted to the database, so shouldn't # be here - before = DictProperty("before") # type: RoomStreamToken - after = DictProperty("after") # type: RoomStreamToken - order = DictProperty("order") # type: Tuple[int, int] + before: RoomStreamToken = DictProperty("before") + after: RoomStreamToken = DictProperty("after") + order: Tuple[int, int] = DictProperty("order") def get_dict(self) -> JsonDict: return dict(self._dict) @@ -291,6 +291,20 @@ def get_pdu_json(self, time_now=None) -> JsonDict: return pdu_json + def get_templated_pdu_json(self) -> JsonDict: + """ + Return a JSON object suitable for a templated event, as used in the + make_{join,leave,knock} workflow. + """ + # By using _dict directly we don't pull in signatures/unsigned. + template_json = dict(self._dict) + # The hashes (similar to the signature) need to be recalculated by the + # joining/leaving/knocking server after (potentially) modifying the + # event. + template_json.pop("hashes") + + return template_json + def __set__(self, instance, value): raise AttributeError("Unrecognized attribute %s" % (instance,)) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 26e39508596e..87e2bb123b6d 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -132,12 +132,12 @@ async def build( format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: # The types of auth/prev events changes between event versions. - auth_events = await self._store.add_event_hashes( - auth_event_ids - ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] - prev_events = await self._store.add_event_hashes( - prev_event_ids - ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] + auth_events: Union[ + List[str], List[Tuple[str, Dict[str, str]]] + ] = await self._store.add_event_hashes(auth_event_ids) + prev_events: Union[ + List[str], List[Tuple[str, Dict[str, str]]] + ] = await self._store.add_event_hashes(prev_event_ids) else: auth_events = auth_event_ids prev_events = prev_event_ids @@ -156,7 +156,7 @@ async def build( # the db) depth = min(depth, MAX_DEPTH) - event_dict = { + event_dict: Dict[str, Any] = { "auth_events": auth_events, "prev_events": prev_events, "type": self.type, @@ -166,7 +166,7 @@ async def build( "unsigned": self.unsigned, "depth": depth, "prev_state": [], - } # type: Dict[str, Any] + } if self.is_state(): event_dict["state_key"] = self._state_key diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index efec16c226a2..57f1d53fa863 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -76,7 +76,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): """Wrapper that loads spam checkers configured using the old configuration, and registers the spam checker hooks they implement. """ - spam_checkers = [] # type: List[Any] + spam_checkers: List[Any] = [] api = hs.get_module_api() for module, config in hs.config.spam_checkers: # Older spam checkers don't accept the `api` argument, so we @@ -239,7 +239,7 @@ async def check_event_for_spam( will be used as the error message returned to the user. """ for callback in self._check_event_for_spam_callbacks: - res = await callback(event) # type: Union[bool, str] + res: Union[bool, str] = await callback(event) if res: return res diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index f7944fd8344f..7a6eb3e51667 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -11,16 +11,124 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple -from typing import TYPE_CHECKING, Union - +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import Requester, StateMap +from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + + +CHECK_EVENT_ALLOWED_CALLBACK = Callable[ + [EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]] +] +ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable] +CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[ + [str, str, StateMap[EventBase]], Awaitable[bool] +] +CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ + [str, StateMap[EventBase], str], Awaitable[bool] +] + + +def load_legacy_third_party_event_rules(hs: "HomeServer"): + """Wrapper that loads a third party event rules module configured using the old + configuration, and registers the hooks they implement. + """ + if hs.config.third_party_event_rules is None: + return + + module, config = hs.config.third_party_event_rules + + api = hs.get_module_api() + third_party_rules = module(config=config, module_api=api) + + # The known hooks. If a module implements a method which name appears in this set, + # we'll want to register it. + third_party_event_rules_methods = { + "check_event_allowed", + "on_create_room", + "check_threepid_can_be_invited", + "check_visibility_can_be_modified", + } + + def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + # f might be None if the callback isn't implemented by the module. In this + # case we don't want to register a callback at all so we return None. + if f is None: + return None + + # We return a separate wrapper for these methods because, in order to wrap them + # correctly, we need to await its result. Therefore it doesn't make a lot of + # sense to make it go through the run() wrapper. + if f.__name__ == "check_event_allowed": + + # We need to wrap check_event_allowed because its old form would return either + # a boolean or a dict, but now we want to return the dict separately from the + # boolean. + async def wrap_check_event_allowed( + event: EventBase, + state_events: StateMap[EventBase], + ) -> Tuple[bool, Optional[dict]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + res = await f(event, state_events) + if isinstance(res, dict): + return True, res + else: + return res, None + + return wrap_check_event_allowed + + if f.__name__ == "on_create_room": + + # We need to wrap on_create_room because its old form would return a boolean + # if the room creation is denied, but now we just want it to raise an + # exception. + async def wrap_on_create_room( + requester: Requester, config: dict, is_requester_admin: bool + ) -> None: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + res = await f(requester, config, is_requester_admin) + if res is False: + raise SynapseError( + 403, + "Room creation forbidden with these parameters", + ) + + return wrap_on_create_room + + def run(*args, **kwargs): + # mypy doesn't do well across function boundaries so we need to tell it + # f is definitely not None. + assert f is not None + + return maybe_awaitable(f(*args, **kwargs)) + + return run + + # Register the hooks through the module API. + hooks = { + hook: async_wrapper(getattr(third_party_rules, hook, None)) + for hook in third_party_event_rules_methods + } + + api.register_third_party_rules_callbacks(**hooks) + class ThirdPartyEventRules: """Allows server admins to provide a Python module implementing an extra @@ -35,36 +143,65 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - module = None - config = None - if hs.config.third_party_event_rules: - module, config = hs.config.third_party_event_rules + self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] + self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] + self._check_threepid_can_be_invited_callbacks: List[ + CHECK_THREEPID_CAN_BE_INVITED_CALLBACK + ] = [] + self._check_visibility_can_be_modified_callbacks: List[ + CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK + ] = [] + + def register_third_party_rules_callbacks( + self, + check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, + on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, + check_threepid_can_be_invited: Optional[ + CHECK_THREEPID_CAN_BE_INVITED_CALLBACK + ] = None, + check_visibility_can_be_modified: Optional[ + CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK + ] = None, + ): + """Register callbacks from modules for each hook.""" + if check_event_allowed is not None: + self._check_event_allowed_callbacks.append(check_event_allowed) + + if on_create_room is not None: + self._on_create_room_callbacks.append(on_create_room) + + if check_threepid_can_be_invited is not None: + self._check_threepid_can_be_invited_callbacks.append( + check_threepid_can_be_invited, + ) - if module is not None: - self.third_party_rules = module( - config=config, - module_api=hs.get_module_api(), + if check_visibility_can_be_modified is not None: + self._check_visibility_can_be_modified_callbacks.append( + check_visibility_can_be_modified, ) async def check_event_allowed( self, event: EventBase, context: EventContext - ) -> Union[bool, dict]: + ) -> Tuple[bool, Optional[dict]]: """Check if a provided event should be allowed in the given context. The module can return: * True: the event is allowed. * False: the event is not allowed, and should be rejected with M_FORBIDDEN. - * a dict: replacement event data. + + If the event is allowed, the module can also return a dictionary to use as a + replacement for the event. Args: event: The event to be checked. context: The context of the event. Returns: - The result from the ThirdPartyRules module, as above + The result from the ThirdPartyRules module, as above. """ - if self.third_party_rules is None: - return True + # Bail out early without hitting the store if we don't have any callbacks to run. + if len(self._check_event_allowed_callbacks) == 0: + return True, None prev_state_ids = await context.get_prev_state_ids() @@ -77,29 +214,46 @@ async def check_event_allowed( # the hashes and signatures. event.freeze() - return await self.third_party_rules.check_event_allowed(event, state_events) + for callback in self._check_event_allowed_callbacks: + try: + res, replacement_data = await callback(event, state_events) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue + + # Return if the event shouldn't be allowed or if the module came up with a + # replacement dict for the event. + if res is False: + return res, None + elif isinstance(replacement_data, dict): + return True, replacement_data + + return True, None async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool - ) -> bool: - """Intercept requests to create room to allow, deny or update the - request config. + ) -> None: + """Intercept requests to create room to maybe deny it (via an exception) or + update the request config. Args: requester config: The creation config from the client. is_requester_admin: If the requester is an admin - - Returns: - Whether room creation is allowed or denied. """ - - if self.third_party_rules is None: - return True - - return await self.third_party_rules.on_create_room( - requester, config, is_requester_admin - ) + for callback in self._on_create_room_callbacks: + try: + await callback(requester, config, is_requester_admin) + except Exception as e: + # Don't silence the errors raised by this callback since we expect it to + # raise an exception to deny the creation of the room; instead make sure + # it's a SynapseError we can send to clients. + if not isinstance(e, SynapseError): + e = SynapseError( + 403, "Room creation forbidden with these parameters" + ) + + raise e async def check_threepid_can_be_invited( self, medium: str, address: str, room_id: str @@ -114,15 +268,20 @@ async def check_threepid_can_be_invited( Returns: True if the 3PID can be invited, False if not. """ - - if self.third_party_rules is None: + # Bail out early without hitting the store if we don't have any callbacks to run. + if len(self._check_threepid_can_be_invited_callbacks) == 0: return True state_events = await self._get_state_map_for_room(room_id) - return await self.third_party_rules.check_threepid_can_be_invited( - medium, address, state_events - ) + for callback in self._check_threepid_can_be_invited_callbacks: + try: + if await callback(medium, address, state_events) is False: + return False + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + + return True async def check_visibility_can_be_modified( self, room_id: str, new_visibility: str @@ -137,18 +296,20 @@ async def check_visibility_can_be_modified( Returns: True if the room's visibility can be modified, False if not. """ - if self.third_party_rules is None: - return True - - check_func = getattr( - self.third_party_rules, "check_visibility_can_be_modified", None - ) - if not check_func or not callable(check_func): + # Bail out early without hitting the store if we don't have any callback + if len(self._check_visibility_can_be_modified_callbacks) == 0: return True state_events = await self._get_state_map_for_room(room_id) - return await check_func(room_id, state_events, new_visibility) + for callback in self._check_visibility_can_be_modified_callbacks: + try: + if await callback(room_id, state_events, new_visibility) is False: + return False + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + + return True async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: """Given a room ID, return the state events of that room. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index ed09c6af1f43..c767d30627a6 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -86,7 +86,7 @@ class FederationClient(FederationBase): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]] + self.pdu_destination_tried: Dict[str, Dict[str, int]] = {} self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() @@ -94,13 +94,13 @@ def __init__(self, hs: "HomeServer"): self.hostname = hs.hostname self.signing_key = hs.signing_key - self._get_pdu_cache = ExpiringCache( + self._get_pdu_cache: ExpiringCache[str, EventBase] = ExpiringCache( cache_name="get_pdu_cache", clock=self._clock, max_len=1000, expiry_ms=120 * 1000, reset_expiry_on_get=False, - ) # type: ExpiringCache[str, EventBase] + ) def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" @@ -293,10 +293,10 @@ async def get_pdu( transaction_data, ) - pdu_list = [ + pdu_list: List[EventBase] = [ event_from_pdu_json(p, room_version, outlier=outlier) for p in transaction_data["pdus"] - ] # type: List[EventBase] + ] if pdu_list and pdu_list[0]: pdu = pdu_list[0] diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index ac0f2ccfb3ec..29619aeeb86e 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -122,12 +122,12 @@ def __init__(self, hs: "HomeServer"): # origins that we are currently processing a transaction from. # a dict from origin to txn id. - self._active_transactions = {} # type: Dict[str, str] + self._active_transactions: Dict[str, str] = {} # We cache results for transaction with the same ID - self._transaction_resp_cache = ResponseCache( + self._transaction_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "fed_txn_handler", timeout_ms=30000 - ) # type: ResponseCache[Tuple[str, str]] + ) self.transaction_actions = TransactionActions(self.store) @@ -135,12 +135,12 @@ def __init__(self, hs: "HomeServer"): # We cache responses to state queries, as they take a while and often # come in waves. - self._state_resp_cache = ResponseCache( - hs.get_clock(), "state_resp", timeout_ms=30000 - ) # type: ResponseCache[Tuple[str, Optional[str]]] - self._state_ids_resp_cache = ResponseCache( + self._state_resp_cache: ResponseCache[ + Tuple[str, Optional[str]] + ] = ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000) + self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "state_ids_resp", timeout_ms=30000 - ) # type: ResponseCache[Tuple[str, str]] + ) self._federation_metrics_domains = ( hs.config.federation.federation_metrics_domains @@ -337,7 +337,7 @@ async def _handle_pdus_in_txn( origin_host, _ = parse_server_name(origin) - pdus_by_room = {} # type: Dict[str, List[EventBase]] + pdus_by_room: Dict[str, List[EventBase]] = {} newest_pdu_ts = 0 @@ -516,9 +516,9 @@ async def _on_context_state_request_compute( self, room_id: str, event_id: Optional[str] ) -> Dict[str, list]: if event_id: - pdus = await self.handler.get_state_for_pdu( + pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu( room_id, event_id - ) # type: Iterable[EventBase] + ) else: pdus = (await self.state.get_current_state(room_id)).values() @@ -562,8 +562,7 @@ async def on_make_join_request( raise IncompatibleRoomVersionError(room_version=room_version) pdu = await self.handler.on_make_join_request(origin, room_id, user_id) - time_now = self._clock.time_msec() - return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + return {"event": pdu.get_templated_pdu_json(), "room_version": room_version} async def on_invite_request( self, origin: str, content: JsonDict, room_version_id: str @@ -611,8 +610,7 @@ async def on_make_leave_request( room_version = await self.store.get_room_version_id(room_id) - time_now = self._clock.time_msec() - return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + return {"event": pdu.get_templated_pdu_json(), "room_version": room_version} async def on_send_leave_request( self, origin: str, content: JsonDict, room_id: str @@ -659,9 +657,8 @@ async def on_make_knock_request( ) pdu = await self.handler.on_make_knock_request(origin, room_id, user_id) - time_now = self._clock.time_msec() return { - "event": pdu.get_pdu_json(time_now), + "event": pdu.get_templated_pdu_json(), "room_version": room_version.identifier, } @@ -791,7 +788,7 @@ async def on_claim_client_keys( log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) results = await self.store.claim_e2e_one_time_keys(query) - json_result = {} # type: Dict[str, Dict[str, dict]] + json_result: Dict[str, Dict[str, dict]] = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): for key_id, json_str in keys.items(): @@ -1119,17 +1116,13 @@ def __init__(self, hs: "HomeServer"): self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) - self.edu_handlers = ( - {} - ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] - self.query_handlers = ( - {} - ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]] + self.edu_handlers: Dict[str, Callable[[str, dict], Awaitable[None]]] = {} + self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} # Map from type to instance names that we should route EDU handling to. # We randomly choose one instance from the list to route to for each new # EDU received. - self._edu_type_to_instance = {} # type: Dict[str, List[str]] + self._edu_type_to_instance: Dict[str, List[str]] = {} def register_edu_handler( self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 65d76ea97415..1fbf325fdc9a 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -71,34 +71,32 @@ def __init__(self, hs: "HomeServer"): # We may have multiple federation sender instances, so we need to track # their positions separately. self._sender_instances = hs.config.worker.federation_shard_config.instances - self._sender_positions = {} # type: Dict[str, int] + self._sender_positions: Dict[str, int] = {} # Pending presence map user_id -> UserPresenceState - self.presence_map = {} # type: Dict[str, UserPresenceState] + self.presence_map: Dict[str, UserPresenceState] = {} # Stores the destinations we need to explicitly send presence to about a # given user. # Stream position -> (user_id, destinations) - self.presence_destinations = ( - SortedDict() - ) # type: SortedDict[int, Tuple[str, Iterable[str]]] + self.presence_destinations: SortedDict[ + int, Tuple[str, Iterable[str]] + ] = SortedDict() # (destination, key) -> EDU - self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu] + self.keyed_edu: Dict[Tuple[str, tuple], Edu] = {} # stream position -> (destination, key) - self.keyed_edu_changed = ( - SortedDict() - ) # type: SortedDict[int, Tuple[str, tuple]] + self.keyed_edu_changed: SortedDict[int, Tuple[str, tuple]] = SortedDict() - self.edus = SortedDict() # type: SortedDict[int, Edu] + self.edus: SortedDict[int, Edu] = SortedDict() # stream ID for the next entry into keyed_edu_changed/edus. self.pos = 1 # map from stream ID to the time that stream entry was generated, so that we # can clear out entries after a while - self.pos_time = SortedDict() # type: SortedDict[int, int] + self.pos_time: SortedDict[int, int] = SortedDict() # EVERYTHING IS SAD. In particular, python only makes new scopes when # we make a new function, so we need to make a new function so the inner @@ -291,7 +289,7 @@ async def get_replication_rows( # list of tuple(int, BaseFederationRow), where the first is the position # of the federation stream. - rows = [] # type: List[Tuple[int, BaseFederationRow]] + rows: List[Tuple[int, BaseFederationRow]] = [] # Fetch presence to send to destinations i = self.presence_destinations.bisect_right(from_token) @@ -445,11 +443,11 @@ def add_to_buffer(self, buff): buff.edus.setdefault(self.edu.destination, []).append(self.edu) -_rowtypes = ( +_rowtypes: Tuple[Type[BaseFederationRow], ...] = ( PresenceDestinationsRow, KeyedEduRow, EduRow, -) # type: Tuple[Type[BaseFederationRow], ...] +) TypeToRow = {Row.TypeId: Row for Row in _rowtypes} diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index deb40f461096..d980e0d9866a 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -14,9 +14,12 @@ import abc import logging +from collections import OrderedDict from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple +import attr from prometheus_client import Counter +from typing_extensions import Literal from twisted.internet import defer @@ -33,8 +36,12 @@ event_processing_loop_room_count, events_processed_counter, ) -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.types import JsonDict, ReadReceipt, RoomStreamToken +from synapse.util import Clock from synapse.util.metrics import Measure if TYPE_CHECKING: @@ -137,6 +144,84 @@ async def get_replication_rows( raise NotImplementedError() +@attr.s +class _PresenceQueue: + """A queue of destinations that need to be woken up due to new presence + updates. + + Staggers waking up of per destination queues to ensure that we don't attempt + to start TLS connections with many hosts all at once, leading to pinned CPU. + """ + + # The maximum duration in seconds between queuing up a destination and it + # being woken up. + _MAX_TIME_IN_QUEUE = 30.0 + + # The maximum duration in seconds between waking up consecutive destination + # queues. + _MAX_DELAY = 0.1 + + sender: "FederationSender" = attr.ib() + clock: Clock = attr.ib() + queue: "OrderedDict[str, Literal[None]]" = attr.ib(factory=OrderedDict) + processing: bool = attr.ib(default=False) + + def add_to_queue(self, destination: str) -> None: + """Add a destination to the queue to be woken up.""" + + self.queue[destination] = None + + if not self.processing: + self._handle() + + @wrap_as_background_process("_PresenceQueue.handle") + async def _handle(self) -> None: + """Background process to drain the queue.""" + + if not self.queue: + return + + assert not self.processing + self.processing = True + + try: + # We start with a delay that should drain the queue quickly enough that + # we process all destinations in the queue in _MAX_TIME_IN_QUEUE + # seconds. + # + # We also add an upper bound to the delay, to gracefully handle the + # case where the queue only has a few entries in it. + current_sleep_seconds = min( + self._MAX_DELAY, self._MAX_TIME_IN_QUEUE / len(self.queue) + ) + + while self.queue: + destination, _ = self.queue.popitem(last=False) + + queue = self.sender._get_per_destination_queue(destination) + + if not queue._new_data_to_send: + # The per destination queue has already been woken up. + continue + + queue.attempt_new_transaction() + + await self.clock.sleep(current_sleep_seconds) + + if not self.queue: + break + + # More destinations may have been added to the queue, so we may + # need to reduce the delay to ensure everything gets processed + # within _MAX_TIME_IN_QUEUE seconds. + current_sleep_seconds = min( + current_sleep_seconds, self._MAX_TIME_IN_QUEUE / len(self.queue) + ) + + finally: + self.processing = False + + class FederationSender(AbstractFederationSender): def __init__(self, hs: "HomeServer"): self.hs = hs @@ -148,14 +233,14 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id - self._presence_router = None # type: Optional[PresenceRouter] + self._presence_router: Optional["PresenceRouter"] = None self._transaction_manager = TransactionManager(hs) self._instance_name = hs.get_instance_name() self._federation_shard_config = hs.config.worker.federation_shard_config # map from destination to PerDestinationQueue - self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue] + self._per_destination_queues: Dict[str, PerDestinationQueue] = {} LaterGauge( "synapse_federation_transaction_queue_pending_destinations", @@ -192,9 +277,7 @@ def __init__(self, hs: "HomeServer"): # awaiting a call to flush_read_receipts_for_room. The presence of an entry # here for a given room means that we are rate-limiting RR flushes to that room, # and that there is a pending call to _flush_rrs_for_room in the system. - self._queues_awaiting_rr_flush_by_room = ( - {} - ) # type: Dict[str, Set[PerDestinationQueue]] + self._queues_awaiting_rr_flush_by_room: Dict[str, Set[PerDestinationQueue]] = {} self._rr_txn_interval_per_room_ms = ( 1000.0 / hs.config.federation_rr_transactions_per_room_per_second @@ -210,6 +293,8 @@ def __init__(self, hs: "HomeServer"): self._external_cache = hs.get_external_cache() + self._presence_queue = _PresenceQueue(self, self.clock) + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -265,7 +350,7 @@ async def handle_event(event: EventBase) -> None: if not event.internal_metadata.should_proactively_send(): return - destinations = None # type: Optional[Set[str]] + destinations: Optional[Set[str]] = None if not event.prev_event_ids(): # If there are no prev event IDs then the state is empty # and so no remote servers in the room @@ -331,7 +416,7 @@ async def handle_room_events(events: Iterable[EventBase]) -> None: for event in events: await handle_event(event) - events_by_room = {} # type: Dict[str, List[EventBase]] + events_by_room: Dict[str, List[EventBase]] = {} for event in events: events_by_room.setdefault(event.room_id, []).append(event) @@ -519,7 +604,12 @@ def send_presence_to_destinations( self._instance_name, destination ): continue - self._get_per_destination_queue(destination).send_presence(states) + + self._get_per_destination_queue(destination).send_presence( + states, start_loop=False + ) + + self._presence_queue.add_to_queue(destination) def build_and_send_edu( self, @@ -628,7 +718,7 @@ async def _wake_destinations_needing_catchup(self) -> None: In order to reduce load spikes, adds a delay between each destination. """ - last_processed = None # type: Optional[str] + last_processed: Optional[str] = None while True: destinations_to_wake = ( diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 3a2efd56eea9..c11d1f6d3152 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -105,34 +105,34 @@ def __init__( # catch-up at startup. # New events will only be sent once this is finished, at which point # _catching_up is flipped to False. - self._catching_up = True # type: bool + self._catching_up: bool = True # The stream_ordering of the most recent PDU that was discarded due to # being in catch-up mode. - self._catchup_last_skipped = 0 # type: int + self._catchup_last_skipped: int = 0 # Cache of the last successfully-transmitted stream ordering for this # destination (we are the only updater so this is safe) - self._last_successful_stream_ordering = None # type: Optional[int] + self._last_successful_stream_ordering: Optional[int] = None # a queue of pending PDUs - self._pending_pdus = [] # type: List[EventBase] + self._pending_pdus: List[EventBase] = [] # XXX this is never actually used: see # https://github.com/matrix-org/synapse/issues/7549 - self._pending_edus = [] # type: List[Edu] + self._pending_edus: List[Edu] = [] # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # based on their key (e.g. typing events by room_id) # Map of (edu_type, key) -> Edu - self._pending_edus_keyed = {} # type: Dict[Tuple[str, Hashable], Edu] + self._pending_edus_keyed: Dict[Tuple[str, Hashable], Edu] = {} # Map of user_id -> UserPresenceState of pending presence to be sent to this # destination - self._pending_presence = {} # type: Dict[str, UserPresenceState] + self._pending_presence: Dict[str, UserPresenceState] = {} # room_id -> receipt_type -> user_id -> receipt_dict - self._pending_rrs = {} # type: Dict[str, Dict[str, Dict[str, dict]]] + self._pending_rrs: Dict[str, Dict[str, Dict[str, dict]]] = {} self._rrs_pending_flush = False # stream_id of last successfully sent to-device message. @@ -171,14 +171,24 @@ def send_pdu(self, pdu: EventBase) -> None: self.attempt_new_transaction() - def send_presence(self, states: Iterable[UserPresenceState]) -> None: - """Add presence updates to the queue. Start the transmission loop if necessary. + def send_presence( + self, states: Iterable[UserPresenceState], start_loop: bool = True + ) -> None: + """Add presence updates to the queue. + + Args: + states: Presence updates to send + start_loop: Whether to start the transmission loop if not already + running. Args: states: presence to send """ self._pending_presence.update({state.user_id: state for state in states}) - self.attempt_new_transaction() + self._new_data_to_send = True + + if start_loop: + self.attempt_new_transaction() def queue_read_receipt(self, receipt: ReadReceipt) -> None: """Add a RR to the list to be sent. Doesn't start the transmission loop yet @@ -243,7 +253,7 @@ def attempt_new_transaction(self) -> None: ) async def _transaction_transmission_loop(self) -> None: - pending_pdus = [] # type: List[EventBase] + pending_pdus: List[EventBase] = [] try: self.transmission_loop_running = True diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index c9e7c5746171..98b1bf77fdea 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -395,9 +395,9 @@ async def get_public_rooms( # this uses MSC2197 (Search Filtering over Federation) path = _create_v1_path("/publicRooms") - data = { + data: Dict[str, Any] = { "include_all_networks": "true" if include_all_networks else "false" - } # type: Dict[str, Any] + } if third_party_instance_id: data["third_party_instance_id"] = third_party_instance_id if limit: @@ -423,9 +423,9 @@ async def get_public_rooms( else: path = _create_v1_path("/publicRooms") - args = { + args: Dict[str, Any] = { "include_all_networks": "true" if include_all_networks else "false" - } # type: Dict[str, Any] + } if third_party_instance_id: args["third_party_instance_id"] = (third_party_instance_id,) if limit: diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index d37d9565fc00..2974d4d0cc1e 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -1013,7 +1013,7 @@ async def on_POST( if not self.allow_access: raise FederationDeniedError(origin) - limit = int(content.get("limit", 100)) # type: Optional[int] + limit: Optional[int] = int(content.get("limit", 100)) since_token = content.get("since", None) search_filter = content.get("filter", None) @@ -1095,7 +1095,9 @@ async def on_GET( query: Dict[bytes, List[bytes]], group_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1110,7 +1112,9 @@ async def on_POST( query: Dict[bytes, List[bytes]], group_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1131,7 +1135,9 @@ async def on_GET( query: Dict[bytes, List[bytes]], group_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1152,7 +1158,9 @@ async def on_GET( query: Dict[bytes, List[bytes]], group_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1174,7 +1182,9 @@ async def on_POST( group_id: str, room_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1192,7 +1202,9 @@ async def on_DELETE( group_id: str, room_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1220,7 +1232,9 @@ async def on_POST( room_id: str, config_key: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1243,7 +1257,9 @@ async def on_GET( query: Dict[bytes, List[bytes]], group_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1264,7 +1280,9 @@ async def on_GET( query: Dict[bytes, List[bytes]], group_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1288,7 +1306,9 @@ async def on_POST( group_id: str, user_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1354,7 +1374,9 @@ async def on_POST( group_id: str, user_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1487,7 +1509,9 @@ async def on_POST( category_id: str, room_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1523,7 +1547,9 @@ async def on_DELETE( category_id: str, room_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1549,7 +1575,9 @@ async def on_GET( query: Dict[bytes, List[bytes]], group_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1571,7 +1599,9 @@ async def on_GET( group_id: str, category_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1589,7 +1619,9 @@ async def on_POST( group_id: str, category_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1618,7 +1650,9 @@ async def on_DELETE( group_id: str, category_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1644,7 +1678,9 @@ async def on_GET( query: Dict[bytes, List[bytes]], group_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1666,7 +1702,9 @@ async def on_GET( group_id: str, role_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1682,7 +1720,9 @@ async def on_POST( group_id: str, role_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1713,7 +1753,9 @@ async def on_DELETE( group_id: str, role_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1750,7 +1792,9 @@ async def on_POST( role_id: str, user_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1784,7 +1828,9 @@ async def on_DELETE( role_id: str, user_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1825,7 +1871,9 @@ async def on_PUT( query: Dict[bytes, List[bytes]], group_id: str, ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args(query, "requester_user_id") + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1943,7 +1991,7 @@ async def on_GET( return 200, complexity -FEDERATION_SERVLET_CLASSES = ( +FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationSendServlet, FederationEventServlet, FederationStateV1Servlet, @@ -1971,15 +2019,13 @@ async def on_GET( FederationSpaceSummaryServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, -) # type: Tuple[Type[BaseFederationServlet], ...] +) -OPENID_SERVLET_CLASSES = ( - OpenIdUserInfo, -) # type: Tuple[Type[BaseFederationServlet], ...] +OPENID_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (OpenIdUserInfo,) -ROOM_LIST_CLASSES = (PublicRoomList,) # type: Tuple[Type[PublicRoomList], ...] +ROOM_LIST_CLASSES: Tuple[Type[PublicRoomList], ...] = (PublicRoomList,) -GROUP_SERVER_SERVLET_CLASSES = ( +GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationGroupsProfileServlet, FederationGroupsSummaryServlet, FederationGroupsRoomsServlet, @@ -1998,19 +2044,19 @@ async def on_GET( FederationGroupsAddRoomsServlet, FederationGroupsAddRoomsConfigServlet, FederationGroupsSettingJoinPolicyServlet, -) # type: Tuple[Type[BaseFederationServlet], ...] +) -GROUP_LOCAL_SERVLET_CLASSES = ( +GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationGroupsLocalInviteServlet, FederationGroupsRemoveLocalUserServlet, FederationGroupsBulkPublicisedServlet, -) # type: Tuple[Type[BaseFederationServlet], ...] +) -GROUP_ATTESTATION_SERVLET_CLASSES = ( +GROUP_ATTESTATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationGroupsRenewAttestaionServlet, -) # type: Tuple[Type[BaseFederationServlet], ...] +) DEFAULT_SERVLET_GROUPS = ( diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index a06d060ebf4a..3dc55ab861a5 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -707,9 +707,9 @@ async def _add_user( See accept_invite, join_group. """ if not self.hs.is_mine_id(user_id): - local_attestation = self.attestations.create_attestation( - group_id, user_id - ) # type: Optional[JsonDict] + local_attestation: Optional[ + JsonDict + ] = self.attestations.create_attestation(group_id, user_id) remote_attestation = content["attestation"] @@ -868,9 +868,9 @@ async def create_group( remote_attestation, user_id=requester_user_id, group_id=group_id ) - local_attestation = self.attestations.create_attestation( - group_id, requester_user_id - ) # type: Optional[JsonDict] + local_attestation: Optional[ + JsonDict + ] = self.attestations.create_attestation(group_id, requester_user_id) else: local_attestation = None remote_attestation = None diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d800e1691280..6a05a6530599 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -15,8 +15,6 @@ import logging from typing import TYPE_CHECKING, Optional -import synapse.state -import synapse.storage import synapse.types from synapse.api.constants import EventTypes, Membership from synapse.api.ratelimiting import Ratelimiter @@ -38,10 +36,10 @@ class BaseHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() # type: synapse.storage.DataStore + self.store = hs.get_datastore() self.auth = hs.get_auth() self.notifier = hs.get_notifier() - self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler + self.state_handler = hs.get_state_handler() self.distributor = hs.get_distributor() self.clock = hs.get_clock() self.hs = hs @@ -55,12 +53,12 @@ def __init__(self, hs: "HomeServer"): # Check whether ratelimiting room admin message redaction is enabled # by the presence of rate limits in the config if self.hs.config.rc_admin_redaction: - self.admin_redaction_ratelimiter = Ratelimiter( + self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_admin_redaction.per_second, burst_count=self.hs.config.rc_admin_redaction.burst_count, - ) # type: Optional[Ratelimiter] + ) else: self.admin_redaction_ratelimiter = None diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d752cf34f0cc..078accd634f1 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -15,9 +15,11 @@ import email.mime.multipart import email.utils import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple -from synapse.api.errors import StoreError, SynapseError +from twisted.web.http import Request + +from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.types import UserID from synapse.util import stringutils @@ -27,6 +29,15 @@ logger = logging.getLogger(__name__) +# Types for callbacks to be registered via the module api +IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]] +ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable] +# Temporary hooks to allow for a transition from `/_matrix/client` endpoints +# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`. +ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable] +ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]] +ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable] + class AccountValidityHandler: def __init__(self, hs: "HomeServer"): @@ -70,6 +81,99 @@ def __init__(self, hs: "HomeServer"): if hs.config.run_background_tasks: self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) + self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = [] + self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = [] + self._on_legacy_send_mail_callback: Optional[ + ON_LEGACY_SEND_MAIL_CALLBACK + ] = None + self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None + + # The legacy admin requests callback isn't a protected attribute because we need + # to access it from the admin servlet, which is outside of this handler. + self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None + + def register_account_validity_callbacks( + self, + is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, + on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, + on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, + on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, + on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, + ): + """Register callbacks from module for each hook.""" + if is_user_expired is not None: + self._is_user_expired_callbacks.append(is_user_expired) + + if on_user_registration is not None: + self._on_user_registration_callbacks.append(on_user_registration) + + # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and + # an admin one). As part of moving the feature into a module, we need to change + # the path from /_matrix/client/unstable/account_validity/... to + # /_synapse/client/account_validity, because: + # + # * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix + # * the way we register servlets means that modules can't register resources + # under /_matrix/client + # + # We need to allow for a transition period between the old and new endpoints + # in order to allow for clients to update (and for emails to be processed). + # + # Once the email-account-validity module is loaded, it will take control of account + # validity by moving the rows from our `account_validity` table into its own table. + # + # Therefore, we need to allow modules (in practice just the one implementing the + # email-based account validity) to temporarily hook into the legacy endpoints so we + # can route the traffic coming into the old endpoints into the module, which is + # why we have the following three temporary hooks. + if on_legacy_send_mail is not None: + if self._on_legacy_send_mail_callback is not None: + raise RuntimeError("Tried to register on_legacy_send_mail twice") + + self._on_legacy_send_mail_callback = on_legacy_send_mail + + if on_legacy_renew is not None: + if self._on_legacy_renew_callback is not None: + raise RuntimeError("Tried to register on_legacy_renew twice") + + self._on_legacy_renew_callback = on_legacy_renew + + if on_legacy_admin_request is not None: + if self.on_legacy_admin_request_callback is not None: + raise RuntimeError("Tried to register on_legacy_admin_request twice") + + self.on_legacy_admin_request_callback = on_legacy_admin_request + + async def is_user_expired(self, user_id: str) -> bool: + """Checks if a user has expired against third-party modules. + + Args: + user_id: The user to check the expiry of. + + Returns: + Whether the user has expired. + """ + for callback in self._is_user_expired_callbacks: + expired = await callback(user_id) + if expired is not None: + return expired + + if self._account_validity_enabled: + # If no module could determine whether the user has expired and the legacy + # configuration is enabled, fall back to it. + return await self.store.is_account_expired(user_id, self.clock.time_msec()) + + return False + + async def on_user_registration(self, user_id: str): + """Tell third-party modules about a user's registration. + + Args: + user_id: The ID of the newly registered user. + """ + for callback in self._on_user_registration_callbacks: + await callback(user_id) + @wrap_as_background_process("send_renewals") async def _send_renewal_emails(self) -> None: """Gets the list of users whose account is expiring in the amount of time @@ -95,6 +199,17 @@ async def send_renewal_email_to_user(self, user_id: str) -> None: Raises: SynapseError if the user is not set to renew. """ + # If a module supports sending a renewal email from here, do that, otherwise do + # the legacy dance. + if self._on_legacy_send_mail_callback is not None: + await self._on_legacy_send_mail_callback(user_id) + return + + if not self._account_validity_renew_by_email_enabled: + raise AuthError( + 403, "Account renewal via email is disabled on this server." + ) + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) # If this user isn't set to be expired, raise an error. @@ -209,6 +324,10 @@ async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]: token is considered stale. A token is stale if the 'token_used_ts_ms' db column is non-null. + This method exists to support handling the legacy account validity /renew + endpoint. If a module implements the on_legacy_renew callback, then this process + is delegated to the module instead. + Args: renewal_token: Token sent with the renewal request. Returns: @@ -218,6 +337,11 @@ async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]: * An int representing the user's expiry timestamp as milliseconds since the epoch, or 0 if the token was invalid. """ + # If a module supports triggering a renew from here, do that, otherwise do the + # legacy dance. + if self._on_legacy_renew_callback is not None: + return await self._on_legacy_renew_callback(renewal_token) + try: ( user_id, diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index d75a8b15c347..bfa7f2c545c6 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -139,7 +139,7 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> to_key = RoomStreamToken(None, stream_ordering) # Events that we've processed in this room - written_events = set() # type: Set[str] + written_events: Set[str] = set() # We need to track gaps in the events stream so that we can then # write out the state at those events. We do this by keeping track @@ -152,7 +152,7 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> # The reverse mapping to above, i.e. map from unseen event to events # that have the unseen event in their prev_events, i.e. the unseen # events "children". - unseen_to_child_events = {} # type: Dict[str, Set[str]] + unseen_to_child_events: Dict[str, Set[str]] = {} # We fetch events in the room the user could see by fetching *all* # events that we have and then filtering, this isn't the most diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 862638cc4f15..21a17cd2e834 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -96,7 +96,7 @@ async def _notify_interested_services(self, max_token: RoomStreamToken): self.current_max, limit ) - events_by_room = {} # type: Dict[str, List[EventBase]] + events_by_room: Dict[str, List[EventBase]] = {} for event in events: events_by_room.setdefault(event.room_id, []).append(event) @@ -275,7 +275,7 @@ async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]: async def _handle_presence( self, service: ApplicationService, users: Collection[Union[str, UserID]] ) -> List[JsonDict]: - events = [] # type: List[JsonDict] + events: List[JsonDict] = [] presence_source = self.event_sources.sources["presence"] from_key = await self.store.get_type_stream_id_for_appservice( service, "presence" @@ -375,7 +375,7 @@ async def get_3pe_protocols( self, only_protocol: Optional[str] = None ) -> Dict[str, JsonDict]: services = self.store.get_app_services() - protocols = {} # type: Dict[str, List[JsonDict]] + protocols: Dict[str, List[JsonDict]] = {} # Collect up all the individual protocol responses out of the ASes for s in services: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index e2ac595a6214..22a855224188 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -191,7 +191,7 @@ class AuthHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] + self.checkers: Dict[str, UserInteractiveAuthChecker] = {} for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) if inst.is_enabled(): @@ -296,7 +296,7 @@ def __init__(self, hs: "HomeServer"): # A mapping of user ID to extra attributes to include in the login # response. - self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes] + self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {} async def validate_user_via_ui_auth( self, @@ -500,7 +500,7 @@ async def check_ui_auth( all the stages in any of the permitted flows. """ - sid = None # type: Optional[str] + sid: Optional[str] = None authdict = clientdict.pop("auth", {}) if "session" in authdict: sid = authdict["session"] @@ -588,9 +588,9 @@ async def check_ui_auth( ) # check auth type currently being presented - errordict = {} # type: Dict[str, Any] + errordict: Dict[str, Any] = {} if "type" in authdict: - login_type = authdict["type"] # type: str + login_type: str = authdict["type"] try: result = await self._check_auth_dict(authdict, clientip) if result: @@ -766,7 +766,7 @@ def _auth_dict_for_flows( LoginType.TERMS: self._get_params_terms, } - params = {} # type: Dict[str, Any] + params: Dict[str, Any] = {} for f in public_flows: for stage in f: @@ -1530,9 +1530,9 @@ async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> s except StoreError: raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - user_id_to_verify = await self.get_session_data( + user_id_to_verify: str = await self.get_session_data( session_id, UIAuthSessionDataConstants.REQUEST_USER_ID - ) # type: str + ) idps = await self.hs.get_sso_handler().get_identity_providers_for_user( user_id_to_verify diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index 7346ccfe9396..0325f86e20a1 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -40,7 +40,7 @@ def __init__(self, error, error_description=None): def __str__(self): if self.error_description: - return "{}: {}".format(self.error, self.error_description) + return f"{self.error}: {self.error_description}" return self.error @@ -171,7 +171,7 @@ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse: # Iterate through the nodes and pull out the user and any extra attributes. user = None - attributes = {} # type: Dict[str, List[Optional[str]]] + attributes: Dict[str, List[Optional[str]]] = {} for child in root[0]: if child.tag.endswith("user"): user = child.text diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 95bdc5902a2b..46ee83440741 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -452,7 +452,7 @@ async def notify_device_update( user_id ) - hosts = set() # type: Set[str] + hosts: Set[str] = set() if self.hs.is_mine_id(user_id): hosts.update(get_domain_from_id(u) for u in users_who_share_room) hosts.discard(self.server_name) @@ -613,20 +613,20 @@ def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): self._remote_edu_linearizer = Linearizer(name="remote_device_list") # user_id -> list of updates waiting to be handled. - self._pending_updates = ( - {} - ) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]] + self._pending_updates: Dict[ + str, List[Tuple[str, str, Iterable[str], JsonDict]] + ] = {} # Recently seen stream ids. We don't bother keeping these in the DB, # but they're useful to have them about to reduce the number of spurious # resyncs. - self._seen_updates = ExpiringCache( + self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache( cache_name="device_update_edu", clock=self.clock, max_len=10000, expiry_ms=30 * 60 * 1000, iterable=True, - ) # type: ExpiringCache[str, Set[str]] + ) # Attempt to resync out of sync device lists every 30s. self._resync_retry_in_progress = False @@ -755,7 +755,7 @@ async def _need_to_do_resync( """Given a list of updates for a user figure out if we need to do a full resync, or whether we have enough data that we can just apply the delta. """ - seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str] + seen_updates: Set[str] = self._seen_updates.get(user_id, set()) extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 580b941595cf..679b47f0815f 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -203,7 +203,7 @@ async def send_device_message( log_kv({"number_of_to_device_messages": len(messages)}) set_tag("sender", sender_user_id) local_messages = {} - remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] + remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for user_id, by_device in messages.items(): # Ratelimit local cross-user key requests by the sending device. if ( diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 4064a2b85913..d487fee627f2 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -22,6 +22,7 @@ CodeMessageException, Codes, NotFoundError, + RequestSendFailed, ShadowBanError, StoreError, SynapseError, @@ -236,9 +237,9 @@ async def _delete_association(self, room_alias: RoomAlias) -> str: async def get_association(self, room_alias: RoomAlias) -> JsonDict: room_id = None if self.hs.is_mine(room_alias): - result = await self.get_association_from_room_alias( - room_alias - ) # type: Optional[RoomAliasMapping] + result: Optional[ + RoomAliasMapping + ] = await self.get_association_from_room_alias(room_alias) if result: room_id = result.room_id @@ -252,12 +253,14 @@ async def get_association(self, room_alias: RoomAlias) -> JsonDict: retry_on_dns_fail=False, ignore_backoff=True, ) + except RequestSendFailed: + raise SynapseError(502, "Failed to fetch alias") except CodeMessageException as e: logging.warning("Error retrieving alias") if e.code == 404: fed_result = None else: - raise + raise SynapseError(502, "Failed to fetch alias") if fed_result and "room_id" in fed_result and "servers" in fed_result: room_id = fed_result["room_id"] diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 3972849d4d16..d92370859fb6 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -115,9 +115,9 @@ async def query_devices( the number of in-flight queries at a time. """ with await self._query_devices_linearizer.queue((from_user_id, from_device_id)): - device_keys_query = query_body.get( + device_keys_query: Dict[str, Iterable[str]] = query_body.get( "device_keys", {} - ) # type: Dict[str, Iterable[str]] + ) # separate users by domain. # make a map from domain to user_id to device_ids @@ -136,7 +136,7 @@ async def query_devices( # First get local devices. # A map of destination -> failure response. - failures = {} # type: Dict[str, JsonDict] + failures: Dict[str, JsonDict] = {} results = {} if local_query: local_result = await self.query_local_devices(local_query) @@ -151,11 +151,9 @@ async def query_devices( # Now attempt to get any remote devices from our local cache. # A map of destination -> user ID -> device IDs. - remote_queries_not_in_cache = ( - {} - ) # type: Dict[str, Dict[str, Iterable[str]]] + remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {} if remote_queries: - query_list = [] # type: List[Tuple[str, Optional[str]]] + query_list: List[Tuple[str, Optional[str]]] = [] for user_id, device_ids in remote_queries.items(): if device_ids: query_list.extend( @@ -362,9 +360,9 @@ async def query_local_devices( A map from user_id -> device_id -> device details """ set_tag("local_query", query) - local_query = [] # type: List[Tuple[str, Optional[str]]] + local_query: List[Tuple[str, Optional[str]]] = [] - result_dict = {} # type: Dict[str, Dict[str, dict]] + result_dict: Dict[str, Dict[str, dict]] = {} for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): @@ -402,9 +400,9 @@ async def on_federation_query_client_keys( self, query_body: Dict[str, Dict[str, Optional[List[str]]]] ) -> JsonDict: """Handle a device key query from a federated server""" - device_keys_query = query_body.get( + device_keys_query: Dict[str, Optional[List[str]]] = query_body.get( "device_keys", {} - ) # type: Dict[str, Optional[List[str]]] + ) res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} @@ -421,8 +419,8 @@ async def on_federation_query_client_keys( async def claim_one_time_keys( self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int ) -> JsonDict: - local_query = [] # type: List[Tuple[str, str, str]] - remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]] + local_query: List[Tuple[str, str, str]] = [] + remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} for user_id, one_time_keys in query.get("one_time_keys", {}).items(): # we use UserID.from_string to catch invalid user ids @@ -439,8 +437,8 @@ async def claim_one_time_keys( results = await self.store.claim_e2e_one_time_keys(local_query) # A map of user ID -> device ID -> key ID -> key. - json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] - failures = {} # type: Dict[str, JsonDict] + json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + failures: Dict[str, JsonDict] = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): for key_id, json_str in keys.items(): @@ -768,8 +766,8 @@ async def _process_self_signatures( Raises: SynapseError: if the input is malformed """ - signature_list = [] # type: List[SignatureListItem] - failures = {} # type: Dict[str, Dict[str, JsonDict]] + signature_list: List["SignatureListItem"] = [] + failures: Dict[str, Dict[str, JsonDict]] = {} if not signatures: return signature_list, failures @@ -930,8 +928,8 @@ async def _process_other_signatures( Raises: SynapseError: if the input is malformed """ - signature_list = [] # type: List[SignatureListItem] - failures = {} # type: Dict[str, Dict[str, JsonDict]] + signature_list: List["SignatureListItem"] = [] + failures: Dict[str, Dict[str, JsonDict]] = {} if not signatures: return signature_list, failures @@ -1300,7 +1298,7 @@ def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): self._remote_edu_linearizer = Linearizer(name="remote_signing_key") # user_id -> list of updates waiting to be handled. - self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] + self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {} async def incoming_signing_key_update( self, origin: str, edu_content: JsonDict @@ -1349,7 +1347,7 @@ async def _handle_signing_key_updates(self, user_id: str) -> None: # This can happen since we batch updates return - device_ids = [] # type: List[str] + device_ids: List[str] = [] logger.info("pending updates: %r", pending_updates) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index f134f1e234b6..4b3f037072d9 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -93,7 +93,7 @@ async def get_stream( # When the user joins a new room, or another user joins a currently # joined room, we need to send down presence for those users. - to_add = [] # type: List[JsonDict] + to_add: List[JsonDict] = [] for event in events: if not isinstance(event, EventBase): continue @@ -103,9 +103,9 @@ async def get_stream( # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = await self.store.get_users_in_room( + users: Iterable[str] = await self.store.get_users_in_room( event.room_id - ) # type: Iterable[str] + ) else: users = [event.state_key] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 991ec9919a95..57287199091d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -181,7 +181,7 @@ def __init__(self, hs: "HomeServer"): # When joining a room we need to queue any events for that room up. # For each room, a list of (pdu, origin) tuples. - self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]] + self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {} self._room_pdu_linearizer = Linearizer("fed_room_pdu") self._room_backfill = Linearizer("room_backfill") @@ -368,7 +368,7 @@ async def on_receive_pdu( ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id - state_maps = list(ours.values()) # type: List[StateMap[str]] + state_maps: List[StateMap[str]] = list(ours.values()) # we don't need this any more, let's delete it. del ours @@ -735,7 +735,7 @@ async def _get_state_after_missing_prev_event( # we need to make sure we re-load from the database to get the rejected # state correct. fetched_events.update( - (await self.store.get_events(missing_desired_events, allow_rejected=True)) + await self.store.get_events(missing_desired_events, allow_rejected=True) ) # check for events which were in the wrong room. @@ -845,7 +845,7 @@ async def _process_received_pdu( # exact key to expect. Otherwise check it matches any key we # have for that device. - current_keys = [] # type: Container[str] + current_keys: Container[str] = [] if device: keys = device.get("keys", {}).get("keys", {}) @@ -1185,7 +1185,7 @@ def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: if e_type == EventTypes.Member and event.membership == Membership.JOIN ] - joined_domains = {} # type: Dict[str, int] + joined_domains: Dict[str, int] = {} for u, d in joined_users: try: dom = get_domain_from_id(u) @@ -1314,7 +1314,7 @@ async def _get_events_and_persist( room_version = await self.store.get_room_version(room_id) - event_map = {} # type: Dict[str, EventBase] + event_map: Dict[str, EventBase] = {} async def get_event(event_id: str): with nested_logging_context(event_id): @@ -1414,12 +1414,15 @@ async def send_invite(self, target_host: str, event: EventBase) -> EventBase: Invites must be signed by the invitee's server before distribution. """ - pdu = await self.federation_client.send_invite( - destination=target_host, - room_id=event.room_id, - event_id=event.event_id, - pdu=event, - ) + try: + pdu = await self.federation_client.send_invite( + destination=target_host, + room_id=event.room_id, + event_id=event.event_id, + pdu=event, + ) + except RequestSendFailed: + raise SynapseError(502, f"Can't connect to server {target_host}") return pdu @@ -1593,7 +1596,7 @@ async def do_knock( # Ask the remote server to create a valid knock event for us. Once received, # we sign the event - params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]] + params: Dict[str, Iterable[str]] = {"ver": supported_room_versions} origin, event, event_format_version = await self._make_and_verify_event( target_hosts, room_id, knockee, Membership.KNOCK, content, params=params ) @@ -1931,7 +1934,7 @@ async def on_make_knock_request( builder=builder ) - event_allowed = await self.third_party_event_rules.check_event_allowed( + event_allowed, _ = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -2023,7 +2026,7 @@ async def on_send_membership_event( # for knock events, we run the third-party event rules. It's not entirely clear # why we don't do this for other sorts of membership events. if event.membership == Membership.KNOCK: - event_allowed = await self.third_party_event_rules.check_event_allowed( + event_allowed, _ = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -2450,14 +2453,14 @@ async def _check_for_soft_fail( state_sets_d = await self.state_store.get_state_groups( event.room_id, extrem_ids ) - state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]] + state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) state_sets.append(state) current_states = await self.state_handler.resolve_events( room_version, state_sets, event ) - current_state_ids = { + current_state_ids: StateMap[str] = { k: e.event_id for k, e in current_states.items() - } # type: StateMap[str] + } else: current_state_ids = await self.state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids @@ -2814,7 +2817,7 @@ async def _update_context_for_auth_events( """ # exclude the state key of the new event from the current_state in the context. if event.is_state(): - event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]] + event_key: Optional[Tuple[str, str]] = (event.type, event.state_key) else: event_key = None state_updates = { @@ -3031,9 +3034,13 @@ async def exchange_third_party_invite( await member_handler.send_membership_event(None, event, context) else: destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} - await self.federation_client.forward_third_party_invite( - destinations, room_id, event_dict - ) + + try: + await self.federation_client.forward_third_party_invite( + destinations, room_id, event_dict + ) + except (RequestSendFailed, HttpResponseException): + raise SynapseError(502, "Failed to forward third party invite") async def on_exchange_third_party_invite_request( self, event_dict: JsonDict @@ -3149,7 +3156,7 @@ async def _check_signature(self, event: EventBase, context: EventContext) -> Non logger.debug("Checking auth on event %r", event.content) - last_exception = None # type: Optional[Exception] + last_exception: Optional[Exception] = None # for each public key in the 3pid invite event for public_key_object in event_auth.get_public_keys(invite_event): diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 157f2ff2189b..1a6c5c64a278 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -214,7 +214,7 @@ async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict: async def bulk_get_publicised_groups( self, user_ids: Iterable[str], proxy: bool = True ) -> JsonDict: - destinations = {} # type: Dict[str, Set[str]] + destinations: Dict[str, Set[str]] = {} local_users = set() for user_id in user_ids: @@ -227,7 +227,7 @@ async def bulk_get_publicised_groups( raise SynapseError(400, "Some user_ids are not local") results = {} - failed_results = [] # type: List[str] + failed_results: List[str] = [] for destination, dest_user_ids in destinations.items(): try: r = await self.transport_client.bulk_get_publicised_groups( diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 33d16fbf9c36..0961dec5ab5c 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -302,7 +302,7 @@ async def try_unbind_threepid_with_id_server( ) url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) - url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") + url_bytes = b"/_matrix/identity/api/v1/3pid/unbind" content = { "mxid": mxid, @@ -695,7 +695,7 @@ async def _lookup_3pid_v1( return data["mxid"] except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - except IOError as e: + except OSError as e: logger.warning("Error from v1 identity server lookup: %s" % (e,)) return None diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 76242865ae28..5d4964076084 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -46,9 +46,17 @@ def __init__(self, hs: "HomeServer"): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = ResponseCache( - hs.get_clock(), "initial_sync_cache" - ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]] + self.snapshot_cache: ResponseCache[ + Tuple[ + str, + Optional[StreamToken], + Optional[StreamToken], + str, + Optional[int], + bool, + bool, + ] + ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 66e40a915d04..8a0024ce8485 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -81,7 +81,7 @@ def __init__(self, hs: "HomeServer"): # The scheduled call to self._expire_event. None if no call is currently # scheduled. - self._scheduled_expiry = None # type: Optional[IDelayedCall] + self._scheduled_expiry: Optional[IDelayedCall] = None if not hs.config.worker_app: run_as_background_process( @@ -196,9 +196,7 @@ async def get_state_events( room_state_events = await self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) - room_state = room_state_events[ - event.event_id - ] # type: Mapping[Any, EventBase] + room_state: Mapping[Any, EventBase] = room_state_events[event.event_id] else: raise AuthError( 403, @@ -421,9 +419,9 @@ def __init__(self, hs: "HomeServer"): self.action_generator = hs.get_action_generator() self.spam_checker = hs.get_spam_checker() - self.third_party_event_rules = ( + self.third_party_event_rules: "ThirdPartyEventRules" = ( self.hs.get_third_party_event_rules() - ) # type: ThirdPartyEventRules + ) self._block_events_without_consent_error = ( self.config.block_events_without_consent_error @@ -440,7 +438,7 @@ def __init__(self, hs: "HomeServer"): # # map from room id to time-of-last-attempt. # - self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int] + self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {} # The number of forward extremeities before a dummy event is sent. self._dummy_events_threshold = hs.config.dummy_events_threshold @@ -465,9 +463,7 @@ def __init__(self, hs: "HomeServer"): # Stores the state groups we've recently added to the joined hosts # external cache. Note that the timeout must be significantly less than # the TTL on the external cache. - self._external_cache_joined_hosts_updates = ( - None - ) # type: Optional[ExpiringCache] + self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None if self._external_cache.is_enabled(): self._external_cache_joined_hosts_updates = ExpiringCache( "_external_cache_joined_hosts_updates", @@ -518,6 +514,9 @@ async def create_event( outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted + back in time around some existing events. This is used to skip + a few checks and mark the event as backfilled. depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. @@ -772,6 +771,7 @@ async def create_and_send_nonmember_event( txn_id: Optional[str] = None, ignore_shadow_ban: bool = False, outlier: bool = False, + historical: bool = False, depth: Optional[int] = None, ) -> Tuple[EventBase, int]: """ @@ -799,6 +799,9 @@ async def create_and_send_nonmember_event( outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted + back in time around some existing events. This is used to skip + a few checks and mark the event as backfilled. depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. @@ -847,6 +850,7 @@ async def create_and_send_nonmember_event( prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, outlier=outlier, + historical=historical, depth=depth, ) @@ -945,10 +949,10 @@ async def create_new_client_event( if requester: context.app_service = requester.app_service - third_party_result = await self.third_party_event_rules.check_event_allowed( + res, new_content = await self.third_party_event_rules.check_event_allowed( event, context ) - if not third_party_result: + if res is False: logger.info( "Event %s forbidden by third-party rules", event, @@ -956,11 +960,11 @@ async def create_new_client_event( raise SynapseError( 403, "This event is not allowed in this context", Codes.FORBIDDEN ) - elif isinstance(third_party_result, dict): + elif new_content is not None: # the third-party rules want to replace the event. We'll need to build a new # event. event, context = await self._rebuild_event_after_third_party_rules( - third_party_result, event + new_content, event ) self.validator.validate_new(event, self.config) @@ -1291,7 +1295,7 @@ async def persist_and_notify_client_event( # Validate a newly added alias or newly added alt_aliases. original_alias = None - original_alt_aliases = [] # type: List[str] + original_alt_aliases: List[str] = [] original_event_id = event.unsigned.get("replaces_state") if original_event_id: @@ -1594,11 +1598,13 @@ async def _rebuild_event_after_third_party_rules( for k, v in original_event.internal_metadata.get_dict().items(): setattr(builder.internal_metadata, k, v) - # the event type hasn't changed, so there's no point in re-calculating the - # auth events. + # modules can send new state events, so we re-calculate the auth events just in + # case. + prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) + event = await builder.build( - prev_event_ids=original_event.prev_event_ids(), - auth_event_ids=original_event.auth_event_ids(), + prev_event_ids=prev_event_ids, + auth_event_ids=None, ) # we rebuild the event context, to be on the safe side. If nothing else, diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index ee6e41c0e4d9..eca8f1604001 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -72,26 +72,26 @@ (b"oidc_session_no_samesite", b"HttpOnly"), ] + #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and #: OpenID.Core sec 3.1.3.3. -Token = TypedDict( - "Token", - { - "access_token": str, - "token_type": str, - "id_token": Optional[str], - "refresh_token": Optional[str], - "expires_in": int, - "scope": Optional[str], - }, -) +class Token(TypedDict): + access_token: str + token_type: str + id_token: Optional[str] + refresh_token: Optional[str] + expires_in: int + scope: Optional[str] + #: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but #: there is no real point of doing this in our case. JWK = Dict[str, str] + #: A JWK Set, as per RFC7517 sec 5. -JWKS = TypedDict("JWKS", {"keys": List[JWK]}) +class JWKS(TypedDict): + keys: List[JWK] class OidcHandler: @@ -105,9 +105,9 @@ def __init__(self, hs: "HomeServer"): assert provider_confs self._token_generator = OidcSessionTokenGenerator(hs) - self._providers = { + self._providers: Dict[str, "OidcProvider"] = { p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs - } # type: Dict[str, OidcProvider] + } async def load_metadata(self) -> None: """Validate the config and load the metadata from the remote endpoint. @@ -178,7 +178,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None: # are two. for cookie_name, _ in _SESSION_COOKIES: - session = request.getCookie(cookie_name) # type: Optional[bytes] + session: Optional[bytes] = request.getCookie(cookie_name) if session is not None: break else: @@ -255,7 +255,7 @@ def __init__(self, error, error_description=None): def __str__(self): if self.error_description: - return "{}: {}".format(self.error, self.error_description) + return f"{self.error}: {self.error_description}" return self.error @@ -277,7 +277,7 @@ def __init__( self._token_generator = token_generator self._config = provider - self._callback_url = hs.config.oidc_callback_url # type: str + self._callback_url: str = hs.config.oidc_callback_url # Calculate the prefix for OIDC callback paths based on the public_baseurl. # We'll insert this into the Path= parameter of any session cookies we set. @@ -290,7 +290,7 @@ def __init__( self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method - client_secret = None # type: Union[None, str, JwtClientSecret] + client_secret: Optional[Union[str, JwtClientSecret]] = None if provider.client_secret: client_secret = provider.client_secret elif provider.client_secret_jwt_key: @@ -305,7 +305,7 @@ def __init__( provider.client_id, client_secret, provider.client_auth_method, - ) # type: ClientAuth + ) self._client_auth_method = provider.client_auth_method # cache of metadata for the identity provider (endpoint uris, mostly). This is @@ -324,7 +324,7 @@ def __init__( self._allow_existing_users = provider.allow_existing_users self._http_client = hs.get_proxied_http_client() - self._server_name = hs.config.server_name # type: str + self._server_name: str = hs.config.server_name # identifier for the external_ids table self.idp_id = provider.idp_id @@ -639,7 +639,7 @@ async def _exchange_code(self, code: str) -> Token: ) logger.warning(description) # Body was still valid JSON. Might be useful to log it for debugging. - logger.warning("Code exchange response: {resp!r}".format(resp=resp)) + logger.warning("Code exchange response: %r", resp) raise OidcError("server_error", description) return resp @@ -1217,10 +1217,12 @@ class OidcSessionData: ui_auth_session_id = attr.ib(type=str) -UserAttributeDict = TypedDict( - "UserAttributeDict", - {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]}, -) +class UserAttributeDict(TypedDict): + localpart: Optional[str] + display_name: Optional[str] + emails: List[str] + + C = TypeVar("C") @@ -1381,7 +1383,7 @@ def render_template_field(template: Optional[Template]) -> Optional[str]: if display_name == "": display_name = None - emails = [] # type: List[str] + emails: List[str] = [] email = render_template_field(self._config.email_template) if email: emails.append(email) @@ -1391,7 +1393,7 @@ def render_template_field(template: Optional[Template]) -> Optional[str]: ) async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: - extras = {} # type: Dict[str, str] + extras: Dict[str, str] = {} for key, template in self._config.extra_attributes.items(): try: extras[key] = template.render(user=userinfo).strip() diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 1e1186c29e7d..1dbafd253d37 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -81,9 +81,9 @@ def __init__(self, hs: "HomeServer"): self._server_name = hs.hostname self.pagination_lock = ReadWriteLock() - self._purges_in_progress_by_room = set() # type: Set[str] + self._purges_in_progress_by_room: Set[str] = set() # map from purge id to PurgeStatus - self._purges_by_id = {} # type: Dict[str, PurgeStatus] + self._purges_by_id: Dict[str, PurgeStatus] = {} self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 44ed7a071231..016c5df2ca75 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -378,14 +378,14 @@ def __init__(self, hs: "HomeServer"): # The number of ongoing syncs on this process, by user id. # Empty if _presence_enabled is false. - self._user_to_num_current_syncs = {} # type: Dict[str, int] + self._user_to_num_current_syncs: Dict[str, int] = {} self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() # user_id -> last_sync_ms. Lists the users that have stopped syncing but # we haven't notified the presence writer of that yet - self.users_going_offline = {} # type: Dict[str, int] + self.users_going_offline: Dict[str, int] = {} self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) self._set_state_client = ReplicationPresenceSetState.make_client(hs) @@ -650,7 +650,7 @@ def __init__(self, hs: "HomeServer"): # Set of users who have presence in the `user_to_current_state` that # have not yet been persisted - self.unpersisted_users_changes = set() # type: Set[str] + self.unpersisted_users_changes: Set[str] = set() hs.get_reactor().addSystemEventTrigger( "before", @@ -664,7 +664,7 @@ def __init__(self, hs: "HomeServer"): # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self.user_to_num_current_syncs = {} # type: Dict[str, int] + self.user_to_num_current_syncs: Dict[str, int] = {} # Keeps track of the number of *ongoing* syncs on other processes. # While any sync is ongoing on another process the user will never @@ -674,8 +674,8 @@ def __init__(self, hs: "HomeServer"): # we assume that all the sync requests on that process have stopped. # Stored as a dict from process_id to set of user_id, and a dict of # process_id to millisecond timestamp last updated. - self.external_process_to_current_syncs = {} # type: Dict[str, Set[str]] - self.external_process_last_updated_ms = {} # type: Dict[str, int] + self.external_process_to_current_syncs: Dict[str, Set[str]] = {} + self.external_process_last_updated_ms: Dict[str, int] = {} self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") @@ -1581,9 +1581,7 @@ async def get_new_events( # The set of users that we're interested in and that have had a presence update. # We'll actually pull the presence updates for these users at the end. - interested_and_updated_users = ( - set() - ) # type: Union[Set[str], FrozenSet[str]] + interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set() if from_key: # First get all users that have had a presence update @@ -1950,8 +1948,8 @@ async def get_interested_parties( A 2-tuple of `(room_ids_to_states, users_to_states)`, with each item being a dict of `entity_name` -> `[UserPresenceState]` """ - room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] - users_to_states = {} # type: Dict[str, List[UserPresenceState]] + room_ids_to_states: Dict[str, List[UserPresenceState]] = {} + users_to_states: Dict[str, List[UserPresenceState]] = {} for state in states: room_ids = await store.get_rooms_for_user(state.user_id) for room_id in room_ids: @@ -2063,12 +2061,12 @@ def __init__(self, hs: "HomeServer", presence_handler: BasePresenceHandler): # stream_id, destinations, user_ids)`. We don't store the full states # for efficiency, and remote workers will already have the full states # cached. - self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]] + self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = [] self._next_id = 1 # Map from instance name to current token - self._current_tokens = {} # type: Dict[str, int] + self._current_tokens: Dict[str, int] = {} if self._queue_presence_updates: self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS) @@ -2168,7 +2166,7 @@ async def get_replication_rows( # handle the case where `from_token` stream ID has already been dropped. start_idx = max(from_token + 1 - self._next_id, -len(self._queue)) - to_send = [] # type: List[Tuple[int, Tuple[str, str]]] + to_send: List[Tuple[int, Tuple[str, str]]] = [] limited = False new_id = upto_token for _, stream_id, destinations, user_ids in self._queue[start_idx:]: @@ -2216,7 +2214,7 @@ async def process_replication_rows( if not self._federation: return - hosts_to_users = {} # type: Dict[str, Set[str]] + hosts_to_users: Dict[str, Set[str]] = {} for row in rows: hosts_to_users.setdefault(row.destination, set()).add(row.user_id) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 05b4a97b590b..20a033d0ba5f 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -197,7 +197,7 @@ async def set_displayname( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) ) - displayname_to_set = new_displayname # type: Optional[str] + displayname_to_set: Optional[str] = new_displayname if new_displayname == "": displayname_to_set = None @@ -286,7 +286,7 @@ async def set_avatar_url( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) - avatar_url_to_set = new_avatar_url # type: Optional[str] + avatar_url_to_set: Optional[str] = new_avatar_url if new_avatar_url == "": avatar_url_to_set = None diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index f782d9db3205..283483fc2c71 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -30,6 +30,8 @@ def __init__(self, hs: "HomeServer"): self.server_name = hs.config.server_name self.store = hs.get_datastore() + self.event_auth_handler = hs.get_event_auth_handler() + self.hs = hs # We only need to poke the federation sender explicitly if its on the @@ -59,6 +61,19 @@ async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] for room_id, room_values in content.items(): + # If we're not in the room just ditch the event entirely. This is + # probably an old server that has come back and thinks we're still in + # the room (or we've been rejoined to the room by a state reset). + is_in_room = await self.event_auth_handler.check_host_in_room( + room_id, self.server_name + ) + if not is_in_room: + logger.info( + "Ignoring receipt from %s as we're not in the room", + origin, + ) + continue + for receipt_type, users in room_values.items(): for user_id, user_values in users.items(): if get_domain_from_id(user_id) != origin: @@ -83,8 +98,8 @@ async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: """Takes a list of receipts, stores them and informs the notifier.""" - min_batch_id = None # type: Optional[int] - max_batch_id = None # type: Optional[int] + min_batch_id: Optional[int] = None + max_batch_id: Optional[int] = None for receipt in receipts: res = await self.store.insert_receipt( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 26ef0161796a..8cf614136ebb 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -55,15 +55,12 @@ ["guest", "auth_provider"], ) -LoginDict = TypedDict( - "LoginDict", - { - "device_id": str, - "access_token": str, - "valid_until_ms": Optional[int], - "refresh_token": Optional[str], - }, -) + +class LoginDict(TypedDict): + device_id: str + access_token: str + valid_until_ms: Optional[int] + refresh_token: Optional[str] class RegistrationHandler(BaseHandler): @@ -77,6 +74,7 @@ def __init__(self, hs: "HomeServer"): self.identity_handler = self.hs.get_identity_handler() self.ratelimiter = hs.get_registration_ratelimiter() self.macaroon_gen = hs.get_macaroon_generator() + self._account_validity_handler = hs.get_account_validity_handler() self._server_notices_mxid = hs.config.server_notices_mxid self._server_name = hs.hostname @@ -700,6 +698,10 @@ async def register_with_store( shadow_banned=shadow_banned, ) + # Only call the account validity module(s) on the main process, to avoid + # repeating e.g. database writes on all of the workers. + await self._account_validity_handler.on_user_registration(user_id) + async def register_device( self, user_id: str, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 579b1b93c5fa..370561e54935 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -87,7 +87,7 @@ def __init__(self, hs: "HomeServer"): self.config = hs.config # Room state based off defined presets - self._presets_dict = { + self._presets_dict: Dict[str, Dict[str, Any]] = { RoomCreationPreset.PRIVATE_CHAT: { "join_rules": JoinRules.INVITE, "history_visibility": HistoryVisibility.SHARED, @@ -109,7 +109,7 @@ def __init__(self, hs: "HomeServer"): "guest_can_join": False, "power_level_content_override": {}, }, - } # type: Dict[str, Dict[str, Any]] + } # Modify presets to selectively enable encryption by default per homeserver config for preset_name, preset_config in self._presets_dict.items(): @@ -127,9 +127,9 @@ def __init__(self, hs: "HomeServer"): # If a user tries to update the same room multiple times in quick # succession, only process the first attempt and return its result to # subsequent requests - self._upgrade_response_cache = ResponseCache( + self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS - ) # type: ResponseCache[Tuple[str, str]] + ) self._server_notices_mxid = hs.config.server_notices_mxid self.third_party_event_rules = hs.get_third_party_event_rules() @@ -377,10 +377,10 @@ async def clone_existing_room( if not await self.spam_checker.user_may_create_room(user_id): raise SynapseError(403, "You are not permitted to create rooms") - creation_content = { + creation_content: JsonDict = { "room_version": new_room_version.identifier, "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, - } # type: JsonDict + } # Check if old room was non-federatable @@ -618,15 +618,11 @@ async def create_room( else: is_requester_admin = await self.auth.is_server_admin(requester.user) - # Check whether the third party rules allows/changes the room create - # request. - event_allowed = await self.third_party_event_rules.on_create_room( + # Let the third party rules modify the room creation config if needed, or abort + # the room creation entirely with an exception. + await self.third_party_event_rules.on_create_room( requester, config, is_requester_admin=is_requester_admin ) - if not event_allowed: - raise SynapseError( - 403, "You are not permitted to create rooms", Codes.FORBIDDEN - ) if not is_requester_admin and not await self.spam_checker.user_may_create_room( user_id @@ -936,7 +932,7 @@ async def send(etype: str, content: JsonDict, **kwargs) -> int: etype=EventTypes.PowerLevels, content=pl_content ) else: - power_level_content = { + power_level_content: JsonDict = { "users": {creator_id: 100}, "users_default": 0, "events": { @@ -955,7 +951,7 @@ async def send(etype: str, content: JsonDict, **kwargs) -> int: "kick": 50, "redact": 50, "invite": 50, - } # type: JsonDict + } if config["original_invitees_have_ops"]: for invitee in invite_list: diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 5e3ef7ce3a72..fae2c098e32e 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -20,7 +20,12 @@ from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules -from synapse.api.errors import Codes, HttpResponseException +from synapse.api.errors import ( + Codes, + HttpResponseException, + RequestSendFailed, + SynapseError, +) from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.descriptors import cached from synapse.util.caches.response_cache import ResponseCache @@ -42,12 +47,12 @@ class RoomListHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search - self.response_cache = ResponseCache( - hs.get_clock(), "room_list" - ) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]] - self.remote_response_cache = ResponseCache( - hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000 - ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]] + self.response_cache: ResponseCache[ + Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]] + ] = ResponseCache(hs.get_clock(), "room_list") + self.remote_response_cache: ResponseCache[ + Tuple[str, Optional[int], Optional[str], bool, Optional[str]] + ] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000) async def get_local_public_room_list( self, @@ -134,10 +139,10 @@ async def _get_public_room_list( if since_token: batch_token = RoomListNextBatch.from_token(since_token) - bounds = ( + bounds: Optional[Tuple[int, str]] = ( batch_token.last_joined_members, batch_token.last_room_id, - ) # type: Optional[Tuple[int, str]] + ) forwards = batch_token.direction_is_forward has_batch_token = True else: @@ -177,7 +182,7 @@ def build_room_entry(room): results = [build_room_entry(r) for r in results] - response = {} # type: JsonDict + response: JsonDict = {} num_results = len(results) if limit is not None: more_to_come = num_results == probing_limit @@ -378,7 +383,11 @@ async def get_remote_public_room_list( ): logger.debug("Falling back to locally-filtered /publicRooms") else: - raise # Not an error that should trigger a fallback. + # Not an error that should trigger a fallback. + raise SynapseError(502, "Failed to fetch room list") + except RequestSendFailed: + # Not an error that should trigger a fallback. + raise SynapseError(502, "Failed to fetch room list") # if we reach this point, then we fall back to the situation where # we currently don't support searching across federation, so we have @@ -417,14 +426,17 @@ async def _get_remote_list_cached( repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search - return await repl_layer.get_public_rooms( - server_name, - limit=limit, - since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) + try: + return await repl_layer.get_public_rooms( + server_name, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + except (RequestSendFailed, HttpResponseException): + raise SynapseError(502, "Failed to fetch room list") key = ( server_name, diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 80ba65b9e019..e6e71e972964 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -83,7 +83,7 @@ def __init__(self, hs: "HomeServer"): self.unstable_idp_brand = None # a map from saml session id to Saml2SessionData object - self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] + self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {} self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) @@ -372,7 +372,7 @@ def expire_sessions(self): DOT_REPLACE_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) + "[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),) ) @@ -386,10 +386,10 @@ def dot_replace_for_mxid(username: str) -> str: return username -MXID_MAPPER_MAP = { +MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = { "hexencode": map_username_to_mxid_localpart, "dotreplace": dot_replace_for_mxid, -} # type: Dict[str, Callable[[str], str]] +} @attr.s diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 4e718d3f633b..8226d6f5a1a8 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -192,7 +192,7 @@ async def search( # If doing a subset of all rooms seearch, check if any of the rooms # are from an upgraded room, and search their contents as well if search_filter.rooms: - historical_room_ids = [] # type: List[str] + historical_room_ids: List[str] = [] for room_id in search_filter.rooms: # Add any previous rooms to the search if they exist ids = await self.get_old_rooms_from_upgraded_room(room_id) @@ -216,9 +216,9 @@ async def search( rank_map = {} # event_id -> rank of event allowed_events = [] # Holds result of grouping by room, if applicable - room_groups = {} # type: Dict[str, JsonDict] + room_groups: Dict[str, JsonDict] = {} # Holds result of grouping by sender, if applicable - sender_group = {} # type: Dict[str, JsonDict] + sender_group: Dict[str, JsonDict] = {} # Holds the next_batch for the entire result set if one of those exists global_next_batch = None @@ -262,7 +262,7 @@ async def search( s["results"].append(e.event_id) elif order_by == "recent": - room_events = [] # type: List[EventBase] + room_events: List[EventBase] = [] i = 0 pagination_token = batch_token diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index b585057ec3a5..5f7d4602bd8d 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -24,6 +24,7 @@ EventContentFields, EventTypes, HistoryVisibility, + JoinRules, Membership, RoomTypes, ) @@ -89,14 +90,14 @@ async def get_space_summary( room_queue = deque((_RoomQueueEntry(room_id, ()),)) # rooms we have already processed - processed_rooms = set() # type: Set[str] + processed_rooms: Set[str] = set() # events we have already processed. We don't necessarily have their event ids, # so instead we key on (room id, state key) - processed_events = set() # type: Set[Tuple[str, str]] + processed_events: Set[Tuple[str, str]] = set() - rooms_result = [] # type: List[JsonDict] - events_result = [] # type: List[JsonDict] + rooms_result: List[JsonDict] = [] + events_result: List[JsonDict] = [] while room_queue and len(rooms_result) < MAX_ROOMS: queue_entry = room_queue.popleft() @@ -150,14 +151,21 @@ async def get_space_summary( # The room should only be included in the summary if: # a. the user is in the room; # b. the room is world readable; or - # c. the user is in a space that has been granted access to - # the room. + # c. the user could join the room, e.g. the join rules + # are set to public or the user is in a space that + # has been granted access to the room. # # Note that we know the user is not in the root room (which is # why the remote call was made in the first place), but the user # could be in one of the children rooms and we just didn't know # about the link. - include_room = room.get("world_readable") is True + + # The API doesn't return the room version so assume that a + # join rule of knock is valid. + include_room = ( + room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) + or room.get("world_readable") is True + ) # Check if the user is a member of any of the allowed spaces # from the response. @@ -264,10 +272,10 @@ async def federation_space_summary( # the set of rooms that we should not walk further. Initialise it with the # excluded-rooms list; we will add other rooms as we process them so that # we do not loop. - processed_rooms = set(exclude_rooms) # type: Set[str] + processed_rooms: Set[str] = set(exclude_rooms) - rooms_result = [] # type: List[JsonDict] - events_result = [] # type: List[JsonDict] + rooms_result: List[JsonDict] = [] + events_result: List[JsonDict] = [] while room_queue and len(rooms_result) < MAX_ROOMS: room_id = room_queue.popleft() @@ -345,7 +353,7 @@ async def _summarize_local_room( max_children = MAX_ROOMS_PER_SPACE now = self._clock.time_msec() - events_result = [] # type: List[JsonDict] + events_result: List[JsonDict] = [] for edge_event in itertools.islice(child_events, max_children): events_result.append( await self._event_serializer.serialize_event( @@ -420,9 +428,8 @@ async def _is_room_accessible( It should be included if: - * The requester is joined or invited to the room. - * The requester can join without an invite (per MSC3083). - * The origin server has any user that is joined or invited to the room. + * The requester is joined or can join the room (per MSC3173). + * The origin server has any user that is joined or can join the room. * The history visibility is set to world readable. Args: @@ -441,13 +448,39 @@ async def _is_room_accessible( # If there's no state for the room, it isn't known. if not state_ids: + # The user might have a pending invite for the room. + if requester and await self._store.get_invite_for_local_user_in_room( + requester, room_id + ): + return True + logger.info("room %s is unknown, omitting from summary", room_id) return False room_version = await self._store.get_room_version(room_id) - # if we have an authenticated requesting user, first check if they are able to view - # stripped state in the room. + # Include the room if it has join rules of public or knock. + join_rules_event_id = state_ids.get((EventTypes.JoinRules, "")) + if join_rules_event_id: + join_rules_event = await self._store.get_event(join_rules_event_id) + join_rule = join_rules_event.content.get("join_rule") + if join_rule == JoinRules.PUBLIC or ( + room_version.msc2403_knocking and join_rule == JoinRules.KNOCK + ): + return True + + # Include the room if it is peekable. + hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, "")) + if hist_vis_event_id: + hist_vis_ev = await self._store.get_event(hist_vis_event_id) + hist_vis = hist_vis_ev.content.get("history_visibility") + if hist_vis == HistoryVisibility.WORLD_READABLE: + return True + + # Otherwise we need to check information specific to the user or server. + + # If we have an authenticated requesting user, check if they are a member + # of the room (or can join the room). if requester: member_event_id = state_ids.get((EventTypes.Member, requester), None) @@ -470,9 +503,11 @@ async def _is_room_accessible( return True # If this is a request over federation, check if the host is in the room or - # is in one of the spaces specified via the join rules. + # has a user who could join the room. elif origin: - if await self._event_auth_handler.check_host_in_room(room_id, origin): + if await self._event_auth_handler.check_host_in_room( + room_id, origin + ) or await self._store.is_host_invited(room_id, origin): return True # Alternately, if the host has a user in any of the spaces specified @@ -490,18 +525,10 @@ async def _is_room_accessible( ): return True - # otherwise, check if the room is peekable - hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""), None) - if hist_vis_event_id: - hist_vis_ev = await self._store.get_event(hist_vis_event_id) - hist_vis = hist_vis_ev.content.get("history_visibility") - if hist_vis == HistoryVisibility.WORLD_READABLE: - return True - logger.info( - "room %s is unpeekable and user %s is not a member / not allowed to join, omitting from summary", + "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary", room_id, - requester, + requester or origin, ) return False @@ -535,6 +562,7 @@ async def _build_room_entry(self, room_id: str) -> JsonDict: "canonical_alias": stats["canonical_alias"], "num_joined_members": stats["joined_members"], "avatar_url": stats["avatar"], + "join_rules": stats["join_rules"], "world_readable": ( stats["history_visibility"] == HistoryVisibility.WORLD_READABLE ), diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 0b297e54c496..1b855a685c4a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -202,10 +202,10 @@ def __init__(self, hs: "HomeServer"): self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) # a map from session id to session data - self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession] + self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {} # map from idp_id to SsoIdentityProvider - self._identity_providers = {} # type: Dict[str, SsoIdentityProvider] + self._identity_providers: Dict[str, SsoIdentityProvider] = {} self._consent_at_registration = hs.config.consent.user_consent_at_registration @@ -296,7 +296,7 @@ async def handle_redirect_request( ) # if the client chose an IdP, use that - idp = None # type: Optional[SsoIdentityProvider] + idp: Optional[SsoIdentityProvider] = None if idp_id: idp = self._identity_providers.get(idp_id) if not idp: @@ -669,9 +669,9 @@ async def complete_sso_ui_auth_request( remote_user_id, ) - user_id_to_verify = await self._auth_handler.get_session_data( + user_id_to_verify: str = await self._auth_handler.get_session_data( ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID - ) # type: str + ) if not user_id: logger.warning( @@ -793,7 +793,7 @@ async def handle_submit_username_request( session.use_display_name = use_display_name emails_from_idp = set(session.emails) - filtered_emails = set() # type: Set[str] + filtered_emails: Set[str] = set() # we iterate through the list rather than just building a set conjunction, so # that we can log attempts to use unknown addresses diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 4e45d1da573c..3fd89af2a4e2 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -45,12 +45,11 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id - self.stats_bucket_size = hs.config.stats_bucket_size self.stats_enabled = hs.config.stats_enabled # The current position in the current_state_delta stream - self.pos = None # type: Optional[int] + self.pos: Optional[int] = None # Guard to ensure we only process deltas one at a time self._is_processing = False @@ -106,20 +105,6 @@ async def _unsafe_process(self) -> None: room_deltas = {} user_deltas = {} - # Then count deltas for total_events and total_event_bytes. - ( - room_count, - user_count, - ) = await self.store.get_changes_room_total_events_and_bytes( - self.pos, max_pos - ) - - for room_id, fields in room_count.items(): - room_deltas.setdefault(room_id, Counter()).update(fields) - - for user_id, fields in user_count.items(): - user_deltas.setdefault(user_id, Counter()).update(fields) - logger.debug("room_deltas: %s", room_deltas) logger.debug("user_deltas: %s", user_deltas) @@ -146,10 +131,10 @@ async def _handle_deltas( mapping from room/user ID to changes in the various fields. """ - room_to_stats_deltas = {} # type: Dict[str, CounterType[str]] - user_to_stats_deltas = {} # type: Dict[str, CounterType[str]] + room_to_stats_deltas: Dict[str, CounterType[str]] = {} + user_to_stats_deltas: Dict[str, CounterType[str]] = {} - room_to_state_updates = {} # type: Dict[str, Dict[str, Any]] + room_to_state_updates: Dict[str, Dict[str, Any]] = {} for delta in deltas: typ = delta["type"] @@ -179,14 +164,12 @@ async def _handle_deltas( ) continue - event_content = {} # type: JsonDict + event_content: JsonDict = {} - sender = None if event_id is not None: event = await self.store.get_event(event_id, allow_none=True) if event: event_content = event.content or {} - sender = event.sender # All the values in this dict are deltas (RELATIVE changes) room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter()) @@ -244,12 +227,6 @@ async def _handle_deltas( room_stats_delta["joined_members"] += 1 elif membership == Membership.INVITE: room_stats_delta["invited_members"] += 1 - - if sender and self.is_mine_id(sender): - user_to_stats_deltas.setdefault(sender, Counter())[ - "invites_sent" - ] += 1 - elif membership == Membership.LEAVE: room_stats_delta["left_members"] += 1 elif membership == Membership.BAN: @@ -279,10 +256,6 @@ async def _handle_deltas( room_state["is_federatable"] = ( event_content.get("m.federate", True) is True ) - if sender and self.is_mine_id(sender): - user_to_stats_deltas.setdefault(sender, Counter())[ - "rooms_created" - ] += 1 elif typ == EventTypes.JoinRules: room_state["join_rules"] = event_content.get("join_rule") elif typ == EventTypes.RoomHistoryVisibility: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b9a036105959..f30bfcc93cf2 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -278,12 +278,14 @@ def __init__(self, hs: "HomeServer"): self.state_store = self.storage.state # ExpiringCache((User, Device)) -> LruCache(user_id => event_id) - self.lazy_loaded_members_cache = ExpiringCache( + self.lazy_loaded_members_cache: ExpiringCache[ + Tuple[str, Optional[str]], LruCache[str, str] + ] = ExpiringCache( "lazy_loaded_members_cache", self.clock, max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, - ) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]] + ) async def wait_for_sync_for_user( self, @@ -440,7 +442,7 @@ async def ephemeral_by_room( ) now_token = now_token.copy_and_replace("typing_key", typing_key) - ephemeral_by_room = {} # type: JsonDict + ephemeral_by_room: JsonDict = {} for event in typing: # we want to exclude the room_id from the event, but modifying the @@ -502,7 +504,7 @@ async def _load_filtered_recents( # We check if there are any state events, if there are then we pass # all current state events to the filter_events function. This is to # ensure that we always include current state in the timeline - current_state_ids = frozenset() # type: FrozenSet[str] + current_state_ids: FrozenSet[str] = frozenset() if any(e.is_state() for e in recents): current_state_ids_map = await self.store.get_current_state_ids( room_id @@ -783,9 +785,9 @@ async def compute_summary( def get_lazy_loaded_members_cache( self, cache_key: Tuple[str, Optional[str]] ) -> LruCache[str, str]: - cache = self.lazy_loaded_members_cache.get( + cache: Optional[LruCache[str, str]] = self.lazy_loaded_members_cache.get( cache_key - ) # type: Optional[LruCache[str, str]] + ) if cache is None: logger.debug("creating LruCache for %r", cache_key) cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) @@ -984,7 +986,7 @@ async def compute_state_delta( if t[0] == EventTypes.Member: cache.set(t[1], event_id) - state = {} # type: Dict[str, EventBase] + state: Dict[str, EventBase] = {} if state_ids: state = await self.store.get_events(list(state_ids.values())) @@ -1088,9 +1090,13 @@ async def generate_sync_result( logger.debug("Fetching OTK data") device_id = sync_config.device_id - one_time_key_counts = {} # type: JsonDict - unused_fallback_key_types = [] # type: List[str] + one_time_key_counts: JsonDict = {} + unused_fallback_key_types: List[str] = [] if device_id: + # TODO: We should have a way to let clients differentiate between the states of: + # * no change in OTK count since the provided since token + # * the server has zero OTKs left for this device + # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298 one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) @@ -1437,7 +1443,7 @@ async def _generate_sync_entry_for_rooms( ) if block_all_room_ephemeral: - ephemeral_by_room = {} # type: Dict[str, List[JsonDict]] + ephemeral_by_room: Dict[str, List[JsonDict]] = {} else: now_token, ephemeral_by_room = await self.ephemeral_by_room( sync_result_builder, @@ -1468,7 +1474,7 @@ async def _generate_sync_entry_for_rooms( # If there is ignored users account data and it matches the proper type, # then use it. - ignored_users = frozenset() # type: FrozenSet[str] + ignored_users: FrozenSet[str] = frozenset() if ignored_account_data: ignored_users_data = ignored_account_data.get("ignored_users", {}) if isinstance(ignored_users_data, dict): @@ -1586,7 +1592,7 @@ async def _get_rooms_changed( user_id, since_token.room_key, now_token.room_key ) - mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]] + mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} for event in rooms_changed: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) @@ -1599,7 +1605,7 @@ async def _get_rooms_changed( logger.debug( "Membership changes in %s: [%s]", room_id, - ", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)), + ", ".join("%s (%s)" % (e.event_id, e.membership) for e in events), ) non_joins = [e for e in events if e.membership != Membership.JOIN] @@ -1722,7 +1728,7 @@ async def _get_rooms_changed( # This is all screaming out for a refactor, as the logic here is # subtle and the moving parts numerous. if leave_event.internal_metadata.is_out_of_band_membership(): - batch_events = [leave_event] # type: Optional[List[EventBase]] + batch_events: Optional[List[EventBase]] = [leave_event] else: batch_events = None @@ -1971,7 +1977,7 @@ async def _generate_room_entry( room_id, batch, sync_config, since_token, now_token, full_state=full_state ) - summary = {} # type: Optional[JsonDict] + summary: Optional[JsonDict] = {} # we include a summary in room responses when we're lazy loading # members (as the client otherwise doesn't have enough info to form @@ -1995,7 +2001,7 @@ async def _generate_room_entry( ) if room_builder.rtype == "joined": - unread_notifications = {} # type: Dict[str, int] + unread_notifications: Dict[str, int] = {} room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e22393adc48d..0cb651a40009 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -68,11 +68,11 @@ def __init__(self, hs: "HomeServer"): ) # map room IDs to serial numbers - self._room_serials = {} # type: Dict[str, int] + self._room_serials: Dict[str, int] = {} # map room IDs to sets of users currently typing - self._room_typing = {} # type: Dict[str, Set[str]] + self._room_typing: Dict[str, Set[str]] = {} - self._member_last_federation_poke = {} # type: Dict[RoomMember, int] + self._member_last_federation_poke: Dict[RoomMember, int] = {} self.wheel_timer = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 @@ -208,6 +208,7 @@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.notifier = hs.get_notifier() + self.event_auth_handler = hs.get_event_auth_handler() self.hs = hs @@ -216,7 +217,7 @@ def __init__(self, hs: "HomeServer"): hs.get_distributor().observe("user_left_room", self.user_left_room) # clock time we expect to stop - self._member_typing_until = {} # type: Dict[RoomMember, int] + self._member_typing_until: Dict[RoomMember, int] = {} # caches which room_ids changed at which serials self._typing_stream_change_cache = StreamChangeCache( @@ -326,6 +327,19 @@ async def _recv_edu(self, origin: str, content: JsonDict) -> None: room_id = content["room_id"] user_id = content["user_id"] + # If we're not in the room just ditch the event entirely. This is + # probably an old server that has come back and thinks we're still in + # the room (or we've been rejoined to the room by a state reset). + is_in_room = await self.event_auth_handler.check_host_in_room( + room_id, self.server_name + ) + if not is_in_room: + logger.info( + "Ignoring typing update from %s as we're not in the room", + origin, + ) + return + member = RoomMember(user_id=user_id, room_id=room_id) # Check that the string is a valid user id @@ -391,9 +405,9 @@ async def get_all_typing_updates( if last_id == current_id: return [], current_id, False - changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( - last_id - ) # type: Optional[Iterable[str]] + changed_rooms: Optional[ + Iterable[str] + ] = self._typing_stream_change_cache.get_all_entities_changed(last_id) if changed_rooms is None: changed_rooms = self._room_serials diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index dacc4f3076e7..6edb1da50a98 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -52,7 +52,7 @@ def __init__(self, hs: "HomeServer"): self.search_all_users = hs.config.user_directory_search_all_users self.spam_checker = hs.get_spam_checker() # The current position in the current_state_delta stream - self.pos = None # type: Optional[int] + self.pos: Optional[int] = None # Guard to ensure we only process deltas one at a time self._is_processing = False diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index ed4671b7deee..578fc48ef454 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -69,7 +69,7 @@ def _get_requested_host(request: IRequest) -> bytes: return hostname # no Host header, use the address/port that the request arrived on - host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address] + host: Union[address.IPv4Address, address.IPv6Address] = request.getHost() hostname = host.host.encode("ascii") diff --git a/synapse/http/client.py b/synapse/http/client.py index 1ca6624fd5da..2ac76b15c2d2 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -160,7 +160,7 @@ def __init__( def resolveHostName( self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0 ) -> IResolutionReceiver: - addresses = [] # type: List[IAddress] + addresses: List[IAddress] = [] def _callback() -> None: has_bad_ip = False @@ -333,9 +333,9 @@ def __init__( if self._ip_blacklist: # If we have an IP blacklist, we need to use a DNS resolver which # filters out blacklisted IP addresses, to prevent DNS rebinding. - self.reactor = BlacklistingReactorWrapper( + self.reactor: ISynapseReactor = BlacklistingReactorWrapper( hs.get_reactor(), self._ip_whitelist, self._ip_blacklist - ) # type: ISynapseReactor + ) else: self.reactor = hs.get_reactor() @@ -349,14 +349,14 @@ def __init__( pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5)) pool.cachedConnectionTimeout = 2 * 60 - self.agent = ProxyAgent( + self.agent: IAgent = ProxyAgent( self.reactor, hs.get_reactor(), connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, use_proxy=use_proxy, - ) # type: IAgent + ) if self._ip_blacklist: # If we have an IP blacklist, we then install the blacklisting Agent @@ -411,7 +411,7 @@ async def request( cooperator=self._cooperator, ) - request_deferred = treq.request( + request_deferred: defer.Deferred = treq.request( method, uri, agent=self.agent, @@ -421,7 +421,7 @@ async def request( # response bodies. unbuffered=True, **self._extra_treq_args, - ) # type: defer.Deferred + ) # we use our own timeout mechanism rather than treq's as a workaround # for https://twistedmatrix.com/trac/ticket/9534. @@ -772,7 +772,7 @@ class BodyExceededMaxSize(Exception): class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which immediately errors upon receiving data.""" - transport = None # type: Optional[ITCPTransport] + transport: Optional[ITCPTransport] = None def __init__(self, deferred: defer.Deferred): self.deferred = deferred @@ -798,7 +798,7 @@ def connectionLost(self, reason: Failure = connectionDone) -> None: class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" - transport = None # type: Optional[ITCPTransport] + transport: Optional[ITCPTransport] = None def __init__( self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int] diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 20d39a4ea64f..43f2140429b5 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -70,10 +70,8 @@ logger = logging.getLogger(__name__) -_well_known_cache = TTLCache("well-known") # type: TTLCache[bytes, Optional[bytes]] -_had_valid_well_known_cache = TTLCache( - "had-valid-well-known" -) # type: TTLCache[bytes, bool] +_well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known") +_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known") @attr.s(slots=True, frozen=True) @@ -130,9 +128,10 @@ async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult: # requests for the same server in parallel? try: with Measure(self._clock, "get_well_known"): - result, cache_period = await self._fetch_well_known( - server_name - ) # type: Optional[bytes], float + result: Optional[bytes] + cache_period: float + + result, cache_period = await self._fetch_well_known(server_name) except _FetchWellKnownFailure as e: if prev_result and e.temporary: diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index b8849c0150ae..2efa15bf0470 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -43,6 +43,7 @@ from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import IReactorTime from twisted.internet.task import _EPSILON, Cooperator +from twisted.web.client import ResponseFailed from twisted.web.http_headers import Headers from twisted.web.iweb import IBodyProducer, IResponse @@ -105,7 +106,7 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC): the parsed data. """ - CONTENT_TYPE = abc.abstractproperty() # type: str # type: ignore + CONTENT_TYPE: str = abc.abstractproperty() # type: ignore """The expected content type of the response, e.g. `application/json`. If the content type doesn't match we fail the request. """ @@ -262,6 +263,15 @@ async def _handle_response( request.uri.decode("ascii"), ) raise RequestSendFailed(e, can_retry=True) from e + except ResponseFailed as e: + logger.warning( + "{%s} [%s] Failed to read response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=True) from e except Exception as e: logger.warning( "{%s} [%s] Error reading response %s %s: %s", @@ -317,11 +327,11 @@ def __init__(self, hs, tls_client_options_factory): # We need to use a DNS resolver which filters out blacklisted IP # addresses, to prevent DNS rebinding. - self.reactor = BlacklistingReactorWrapper( + self.reactor: ISynapseReactor = BlacklistingReactorWrapper( hs.get_reactor(), hs.config.federation_ip_range_whitelist, hs.config.federation_ip_range_blacklist, - ) # type: ISynapseReactor + ) user_agent = hs.version_string if hs.config.user_agent_suffix: @@ -494,7 +504,7 @@ async def _send_request( ) # Inject the span into the headers - headers_dict = {} # type: Dict[bytes, List[bytes]] + headers_dict: Dict[bytes, List[bytes]] = {} opentracing.inject_header_dict(headers_dict, request.destination) headers_dict[b"User-Agent"] = [self.version_string_bytes] @@ -523,9 +533,9 @@ async def _send_request( destination_bytes, method_bytes, url_to_sign_bytes, json ) data = encode_canonical_json(json) - producer = QuieterFileBodyProducer( + producer: Optional[IBodyProducer] = QuieterFileBodyProducer( BytesIO(data), cooperator=self._cooperator - ) # type: Optional[IBodyProducer] + ) else: producer = None auth_headers = self.build_auth_headers( @@ -1137,6 +1147,24 @@ async def get_file( msg, ) raise SynapseError(502, msg, Codes.TOO_LARGE) + except defer.TimeoutError as e: + logger.warning( + "{%s} [%s] Timed out reading response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=True) from e + except ResponseFailed as e: + logger.warning( + "{%s} [%s] Failed to read response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=True) from e except Exception as e: logger.warning( "{%s} [%s] Error reading response: %s", diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 7dfae8b786b9..f7193e60bdbc 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -117,7 +117,8 @@ def __init__( https_proxy = proxies["https"].encode() if "https" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None - # Parse credentials from https proxy connection string if present + # Parse credentials from http and https proxy connection string if present + self.http_proxy_creds, http_proxy = parse_username_password(http_proxy) self.https_proxy_creds, https_proxy = parse_username_password(https_proxy) self.http_proxy_endpoint = _http_proxy_endpoint( @@ -171,7 +172,7 @@ def request(self, method, uri, headers=None, bodyProducer=None): """ uri = uri.strip() if not _VALID_URI.match(uri): - raise ValueError("Invalid URI {!r}".format(uri)) + raise ValueError(f"Invalid URI {uri!r}") parsed_uri = URI.fromBytes(uri) pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) @@ -189,6 +190,15 @@ def request(self, method, uri, headers=None, bodyProducer=None): and self.http_proxy_endpoint and not should_skip_proxy ): + # Determine whether we need to set Proxy-Authorization headers + if self.http_proxy_creds: + # Set a Proxy-Authorization header + if headers is None: + headers = Headers() + headers.addRawHeader( + b"Proxy-Authorization", + self.http_proxy_creds.as_proxy_authorization_value(), + ) # Cache *all* connections under the same key, since we are only # connecting to a single destination, the proxy: pool_key = ("http-proxy", self.http_proxy_endpoint) diff --git a/synapse/http/server.py b/synapse/http/server.py index efbc6d5b2541..b79fa722e931 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -81,7 +81,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: if f.check(SynapseError): # mypy doesn't understand that f.check asserts the type. - exc = f.value # type: SynapseError # type: ignore + exc: SynapseError = f.value # type: ignore error_code = exc.code error_dict = exc.error_dict() @@ -132,7 +132,7 @@ def return_html_error( """ if f.check(CodeMessageException): # mypy doesn't understand that f.check asserts the type. - cme = f.value # type: CodeMessageException # type: ignore + cme: CodeMessageException = f.value # type: ignore code = cme.code msg = cme.msg @@ -404,7 +404,7 @@ def _get_handler_for_request( key word arguments to pass to the callback """ # At this point the path must be bytes. - request_path_bytes = request.path # type: bytes # type: ignore + request_path_bytes: bytes = request.path # type: ignore request_path = request_path_bytes.decode("ascii") # Treat HEAD requests as GET requests. request_method = request.method @@ -557,7 +557,7 @@ def __init__( request: Request, iterator: Iterator[bytes], ): - self._request = request # type: Optional[Request] + self._request: Optional[Request] = request self._iterator = iterator self._paused = False diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 6ba2ce1e53a2..04560fb58977 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -205,7 +205,7 @@ def parse_string( parameter is present, must be one of a list of allowed values and is not one of those allowed values. """ - args = request.args # type: Dict[bytes, List[bytes]] # type: ignore + args: Dict[bytes, List[bytes]] = request.args # type: ignore return parse_string_from_args( args, name, diff --git a/synapse/http/site.py b/synapse/http/site.py index 40754b7bea7e..190084e8aa68 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -64,16 +64,16 @@ class SynapseRequest(Request): def __init__(self, channel, *args, max_request_body_size=1024, **kw): Request.__init__(self, channel, *args, **kw) self._max_request_body_size = max_request_body_size - self.site = channel.site # type: SynapseSite + self.site: SynapseSite = channel.site self._channel = channel # this is used by the tests self.start_time = 0.0 # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. - self._requester = None # type: Optional[Union[Requester, str]] + self._requester: Optional[Union[Requester, str]] = None # we can't yet create the logcontext, as we don't know the method. - self.logcontext = None # type: Optional[LoggingContext] + self.logcontext: Optional[LoggingContext] = None global _next_request_seq self.request_seq = _next_request_seq @@ -152,7 +152,7 @@ def get_redacted_uri(self) -> str: Returns: The redacted URI as a string. """ - uri = self.uri # type: Union[bytes, str] + uri: Union[bytes, str] = self.uri if isinstance(uri, bytes): uri = uri.decode("ascii", errors="replace") return redact_uri(uri) @@ -167,7 +167,7 @@ def get_method(self) -> str: Returns: The request method as a string. """ - method = self.method # type: Union[bytes, str] + method: Union[bytes, str] = self.method if isinstance(method, bytes): return self.method.decode("ascii") return method @@ -384,7 +384,7 @@ def _finished_processing(self): # authenticated (e.g. and admin is puppetting a user) then we log both. requester, authenticated_entity = self.get_authenticated_entity() if authenticated_entity: - requester = "{}.{}".format(authenticated_entity, requester) + requester = f"{authenticated_entity}.{requester}" self.site.access_logger.log( log_level, @@ -434,8 +434,8 @@ class XForwardedForRequest(SynapseRequest): """ # the client IP and ssl flag, as extracted from the headers. - _forwarded_for = None # type: Optional[_XForwardedForAddress] - _forwarded_https = False # type: bool + _forwarded_for: "Optional[_XForwardedForAddress]" = None + _forwarded_https: bool = False def requestReceived(self, command, path, version): # this method is called by the Channel once the full request has been diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index c515690b38f0..8202d0494d72 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -110,9 +110,9 @@ def __init__( self.port = port self.maximum_buffer = maximum_buffer - self._buffer = deque() # type: Deque[logging.LogRecord] - self._connection_waiter = None # type: Optional[Deferred] - self._producer = None # type: Optional[LogProducer] + self._buffer: Deque[logging.LogRecord] = deque() + self._connection_waiter: Optional[Deferred] = None + self._producer: Optional[LogProducer] = None # Connect without DNS lookups if it's a direct IP. if _reactor is None: @@ -123,9 +123,9 @@ def __init__( try: ip = ip_address(self.host) if isinstance(ip, IPv4Address): - endpoint = TCP4ClientEndpoint( + endpoint: IStreamClientEndpoint = TCP4ClientEndpoint( _reactor, self.host, self.port - ) # type: IStreamClientEndpoint + ) elif isinstance(ip, IPv6Address): endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port) else: @@ -165,7 +165,7 @@ def fail(failure: Failure) -> None: def writer(result: Protocol) -> None: # Force recognising transport as a Connection and not the more # generic ITransport. - transport = result.transport # type: Connection # type: ignore + transport: Connection = result.transport # type: ignore # We have a connection. If we already have a producer, and its # transport is the same, just trigger a resumeProducing. @@ -188,7 +188,7 @@ def writer(result: Protocol) -> None: self._producer.resumeProducing() self._connection_waiter = None - deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred + deferred: Deferred = self._service.whenConnected(failAfterFailures=1) deferred.addCallbacks(writer, fail) self._connection_waiter = deferred diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index c7a971a9d60b..b9933a152847 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -63,7 +63,7 @@ def parse_drain_configs( DrainType.CONSOLE_JSON, DrainType.FILE_JSON, ): - formatter = "json" # type: Optional[str] + formatter: Optional[str] = "json" elif logging_type in ( DrainType.CONSOLE_JSON_TERSE, DrainType.NETWORK_JSON_TERSE, diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 7fc11a9ac2f8..18ac50780285 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -113,13 +113,13 @@ def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None: self.reset() else: # FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now - self.ru_utime = copy_from.ru_utime # type: float - self.ru_stime = copy_from.ru_stime # type: float - self.db_txn_count = copy_from.db_txn_count # type: int + self.ru_utime: float = copy_from.ru_utime + self.ru_stime: float = copy_from.ru_stime + self.db_txn_count: int = copy_from.db_txn_count - self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float - self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float - self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int + self.db_txn_duration_sec: float = copy_from.db_txn_duration_sec + self.db_sched_duration_sec: float = copy_from.db_sched_duration_sec + self.evt_db_fetch_count: int = copy_from.evt_db_fetch_count def copy(self) -> "ContextResourceUsage": return ContextResourceUsage(copy_from=self) @@ -289,12 +289,12 @@ def __init__( # The thread resource usage when the logcontext became active. None # if the context is not currently active. - self.usage_start = None # type: Optional[resource._RUsage] + self.usage_start: Optional[resource._RUsage] = None self.main_thread = get_thread_id() self.request = None self.tag = "" - self.scope = None # type: Optional[_LogContextScope] + self.scope: Optional["_LogContextScope"] = None # keep track of whether we have hit the __exit__ block for this context # (suggesting that the the thing that created the context thinks it should diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 140ed711e373..ecd51f1b4a26 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -251,7 +251,7 @@ def report_span(self, span): except Exception: logger.exception("Failed to report span") - RustReporter = _WrappedRustReporter # type: Optional[Type[_WrappedRustReporter]] + RustReporter: Optional[Type[_WrappedRustReporter]] = _WrappedRustReporter except ImportError: RustReporter = None @@ -286,7 +286,7 @@ class SynapseBaggage: # Block everything by default # A regex which matches the server_names to expose traces for. # None means 'block everything'. -_homeserver_whitelist = None # type: Optional[Pattern[str]] +_homeserver_whitelist: Optional[Pattern[str]] = None # Util methods @@ -374,7 +374,7 @@ def init_tracer(hs: "HomeServer"): config = JaegerConfig( config=hs.config.jaeger_config, - service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()), + service_name=f"{hs.config.server_name} {hs.get_instance_name()}", scope_manager=LogContextScopeManager(hs.config), metrics_factory=PrometheusMetricsFactory(), ) @@ -662,7 +662,7 @@ def inject_header_dict( span = opentracing.tracer.active_span - carrier = {} # type: Dict[str, str] + carrier: Dict[str, str] = {} opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier) for key, value in carrier.items(): @@ -704,7 +704,7 @@ def get_active_span_text_map(destination=None): if destination and not whitelisted_homeserver(destination): return {} - carrier = {} # type: Dict[str, str] + carrier: Dict[str, str] = {} opentracing.tracer.inject( opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier ) @@ -718,7 +718,7 @@ def active_span_context_as_string(): Returns: The active span context encoded as a string. """ - carrier = {} # type: Dict[str, str] + carrier: Dict[str, str] = {} if opentracing: opentracing.tracer.inject( opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index fef28466694a..f237b8a2369e 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -46,7 +46,7 @@ METRICS_PREFIX = "/_synapse/metrics" running_on_pypy = platform.python_implementation() == "PyPy" -all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge]] +all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {} HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") @@ -130,7 +130,7 @@ def __init__(self, name, desc, labels, sub_metrics): ) # Counts number of in flight blocks for a given set of label values - self._registrations = {} # type: Dict + self._registrations: Dict = {} # Protects access to _registrations self._lock = threading.Lock() @@ -248,7 +248,7 @@ def __init__( # We initially set this to None. We won't report metrics until # this has been initialised after a successful data update - self._metric = None # type: Optional[GaugeHistogramMetricFamily] + self._metric: Optional[GaugeHistogramMetricFamily] = None registry.register(self) diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index 8002be56e0e8..bb9bcb5592ed 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -34,7 +34,7 @@ from synapse.util import caches -CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") +CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" INF = float("inf") @@ -55,8 +55,8 @@ def floatToGoString(d): # Go switches to exponents sooner than Python. # We only need to care about positive values for le/quantile. if d > 0 and dot > 6: - mantissa = "{0}.{1}{2}".format(s[0], s[1:dot], s[dot + 1 :]).rstrip("0.") - return "{0}e+0{1}".format(mantissa, dot - 1) + mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.") + return f"{mantissa}e+0{dot - 1}" return s @@ -65,7 +65,7 @@ def sample_line(line, name): labelstr = "{{{0}}}".format( ",".join( [ - '{0}="{1}"'.format( + '{}="{}"'.format( k, v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""), ) @@ -78,10 +78,8 @@ def sample_line(line, name): timestamp = "" if line.timestamp is not None: # Convert to milliseconds. - timestamp = " {0:d}".format(int(float(line.timestamp) * 1000)) - return "{0}{1} {2}{3}\n".format( - name, labelstr, floatToGoString(line.value), timestamp - ) + timestamp = f" {int(float(line.timestamp) * 1000):d}" + return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) def generate_latest(registry, emit_help=False): @@ -118,14 +116,14 @@ def generate_latest(registry, emit_help=False): # Output in the old format for compatibility. if emit_help: output.append( - "# HELP {0} {1}\n".format( + "# HELP {} {}\n".format( mname, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append("# TYPE {0} {1}\n".format(mname, mtype)) + output.append(f"# TYPE {mname} {mtype}\n") - om_samples = {} # type: Dict[str, List[str]] + om_samples: Dict[str, List[str]] = {} for s in metric.samples: for suffix in ["_created", "_gsum", "_gcount"]: if s.name == metric.name + suffix: @@ -143,13 +141,13 @@ def generate_latest(registry, emit_help=False): for suffix, lines in sorted(om_samples.items()): if emit_help: output.append( - "# HELP {0}{1} {2}\n".format( + "# HELP {}{} {}\n".format( metric.name, suffix, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append("# TYPE {0}{1} gauge\n".format(metric.name, suffix)) + output.append(f"# TYPE {metric.name}{suffix} gauge\n") output.extend(lines) # Get rid of the weird colon things while we're at it @@ -163,12 +161,12 @@ def generate_latest(registry, emit_help=False): # Also output in the new format, if it's different. if emit_help: output.append( - "# HELP {0} {1}\n".format( + "# HELP {} {}\n".format( mnewname, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append("# TYPE {0} {1}\n".format(mnewname, mtype)) + output.append(f"# TYPE {mnewname} {mtype}\n") for s in metric.samples: # Get rid of the OpenMetrics specific samples (we should already have diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index de96ca08212a..3a14260752ed 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -93,7 +93,7 @@ # map from description to a counter, so that we can name our logcontexts # incrementally. (It actually duplicates _background_process_start_count, but # it's much simpler to do so than to try to combine them.) -_background_process_counts = {} # type: Dict[str, int] +_background_process_counts: Dict[str, int] = {} # Set of all running background processes that became active active since the # last time metrics were scraped (i.e. background processes that performed some @@ -103,7 +103,7 @@ # background processes stacking up behind a lock or linearizer, where we then # only need to iterate over and update metrics for the process that have # actually been active and can ignore the idle ones. -_background_processes_active_since_last_scrape = set() # type: Set[_BackgroundProcess] +_background_processes_active_since_last_scrape: "Set[_BackgroundProcess]" = set() # A lock that covers the above set and dict _bg_metrics_lock = threading.Lock() @@ -137,8 +137,7 @@ def collect(self): _background_process_db_txn_duration, _background_process_db_sched_duration, ): - for r in m.collect(): - yield r + yield from m.collect() REGISTRY.register(_Collector()) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 721c45abac70..1259fc2d906a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -12,18 +12,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import email.utils import logging -from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, +) + +import jinja2 from twisted.internet import defer from twisted.web.resource import IResource from synapse.events import EventBase from synapse.http.client import SimpleHttpClient +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + respond_with_html, +) +from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.util import Clock +from synapse.util.caches.descriptors import cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -33,7 +57,20 @@ are loaded into Synapse. """ -__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"] +__all__ = [ + "errors", + "make_deferred_yieldable", + "parse_json_object_from_request", + "respond_with_html", + "run_in_background", + "cached", + "UserID", + "DatabasePool", + "LoggingTransaction", + "DirectServeHtmlResource", + "DirectServeJsonResource", + "ModuleApi", +] logger = logging.getLogger(__name__) @@ -52,12 +89,28 @@ def __init__(self, hs: "HomeServer", auth_handler): self._server_name = hs.hostname self._presence_stream = hs.get_event_sources().sources["presence"] self._state = hs.get_state_handler() + self._clock: Clock = hs.get_clock() + self._send_email_handler = hs.get_send_email_handler() + + try: + app_name = self._hs.config.email_app_name + + self._from_string = self._hs.config.email_notif_from % {"app": app_name} + except (KeyError, TypeError): + # If substitution failed (which can happen if the string contains + # placeholders other than just "app", or if the type of the placeholder is + # not a string), fall back to the bare strings. + self._from_string = self._hs.config.email_notif_from + + self._raw_from = email.utils.parseaddr(self._from_string)[1] # We expose these as properties below in order to attach a helpful docstring. - self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient + self._http_client: SimpleHttpClient = hs.get_simple_http_client() self._public_room_list_manager = PublicRoomListManager(hs) self._spam_checker = hs.get_spam_checker() + self._account_validity_handler = hs.get_account_validity_handler() + self._third_party_event_rules = hs.get_third_party_event_rules() ################################################################################# # The following methods should only be called during the module's initialisation. @@ -67,6 +120,16 @@ def register_spam_checker_callbacks(self): """Registers callbacks for spam checking capabilities.""" return self._spam_checker.register_callbacks + @property + def register_account_validity_callbacks(self): + """Registers callbacks for account validity capabilities.""" + return self._account_validity_handler.register_account_validity_callbacks + + @property + def register_third_party_rules_callbacks(self): + """Registers callbacks for third party event rules capabilities.""" + return self._third_party_event_rules.register_third_party_rules_callbacks + def register_web_resource(self, path: str, resource: IResource): """Registers a web resource to be served at the given path. @@ -101,22 +164,56 @@ def public_room_list_manager(self): """ return self._public_room_list_manager - def get_user_by_req(self, req, allow_guest=False): + @property + def public_baseurl(self) -> str: + """The configured public base URL for this homeserver.""" + return self._hs.config.public_baseurl + + @property + def email_app_name(self) -> str: + """The application name configured in the homeserver's configuration.""" + return self._hs.config.email.email_app_name + + async def get_user_by_req( + self, + req: SynapseRequest, + allow_guest: bool = False, + allow_expired: bool = False, + ) -> Requester: """Check the access_token provided for a request Args: - req (twisted.web.server.Request): Incoming HTTP request - allow_guest (bool): True if guest users should be allowed. If this + req: Incoming HTTP request + allow_guest: True if guest users should be allowed. If this is False, and the access token is for a guest user, an AuthError will be thrown + allow_expired: True if expired users should be allowed. If this + is False, and the access token is for an expired user, an + AuthError will be thrown + Returns: - twisted.internet.defer.Deferred[synapse.types.Requester]: - the requester for this request + The requester for this request + Raises: - synapse.api.errors.AuthError: if no user by that token exists, + InvalidClientCredentialsError: if no user by that token exists, or the token is invalid. """ - return self._auth.get_user_by_req(req, allow_guest) + return await self._auth.get_user_by_req( + req, + allow_guest, + allow_expired=allow_expired, + ) + + async def is_user_admin(self, user_id: str) -> bool: + """Checks if a user is a server admin. + + Args: + user_id: The Matrix ID of the user to check. + + Returns: + True if the user is a server admin, False otherwise. + """ + return await self._store.is_server_admin(UserID.from_string(user_id)) def get_qualified_user_id(self, username): """Qualify a user id, if necessary @@ -134,6 +231,32 @@ def get_qualified_user_id(self, username): return username return UserID(username, self._hs.hostname).to_string() + async def get_profile_for_user(self, localpart: str) -> ProfileInfo: + """Look up the profile info for the user with the given localpart. + + Args: + localpart: The localpart to look up profile information for. + + Returns: + The profile information (i.e. display name and avatar URL). + """ + return await self._store.get_profileinfo(localpart) + + async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]: + """Look up the threepids (email addresses and phone numbers) associated with the + given Matrix user ID. + + Args: + user_id: The Matrix user ID to look up threepids for. + + Returns: + A list of threepids, each threepid being represented by a dictionary + containing a "medium" key which value is "email" for email addresses and + "msisdn" for phone numbers, and an "address" key which value is the + threepid's address. + """ + return await self._store.user_get_threepids(user_id) + def check_user_exists(self, user_id): """Check if user exists. @@ -464,6 +587,88 @@ async def send_local_online_presence_to(self, users: Iterable[str]) -> None: presence_events, destination ) + def looping_background_call( + self, + f: Callable, + msec: float, + *args, + desc: Optional[str] = None, + **kwargs, + ): + """Wraps a function as a background process and calls it repeatedly. + + Waits `msec` initially before calling `f` for the first time. + + Args: + f: The function to call repeatedly. f can be either synchronous or + asynchronous, and must follow Synapse's logcontext rules. + More info about logcontexts is available at + https://matrix-org.github.io/synapse/latest/log_contexts.html + msec: How long to wait between calls in milliseconds. + *args: Positional arguments to pass to function. + desc: The background task's description. Default to the function's name. + **kwargs: Key arguments to pass to function. + """ + if desc is None: + desc = f.__name__ + + if self._hs.config.run_background_tasks: + self._clock.looping_call( + run_as_background_process, + msec, + desc, + f, + *args, + **kwargs, + ) + else: + logger.warning( + "Not running looping call %s as the configuration forbids it", + f, + ) + + async def send_mail( + self, + recipient: str, + subject: str, + html: str, + text: str, + ): + """Send an email on behalf of the homeserver. + + Args: + recipient: The email address for the recipient. + subject: The email's subject. + html: The email's HTML content. + text: The email's text content. + """ + await self._send_email_handler.send_email( + email_address=recipient, + subject=subject, + app_name=self.email_app_name, + html=html, + text=text, + ) + + def read_templates( + self, + filenames: List[str], + custom_template_directory: Optional[str] = None, + ) -> List[jinja2.Template]: + """Read and load the content of the template files at the given location. + By default, Synapse will look for these templates in its configured template + directory, but another directory to search in can be provided. + + Args: + filenames: The name of the template files to look for. + custom_template_directory: An additional directory to look for the files in. + + Returns: + A list containing the loaded templates, with the orders matching the one of + the filenames parameter. + """ + return self._hs.config.read_templates(filenames, custom_template_directory) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py index 02bbb0be39fc..98ea911a8195 100644 --- a/synapse/module_api/errors.py +++ b/synapse/module_api/errors.py @@ -14,5 +14,9 @@ """Exception types which are exposed as part of the stable module API""" -from synapse.api.errors import RedirectException, SynapseError # noqa: F401 +from synapse.api.errors import ( # noqa: F401 + InvalidClientCredentialsError, + RedirectException, + SynapseError, +) from synapse.config._base import ConfigError # noqa: F401 diff --git a/synapse/notifier.py b/synapse/notifier.py index 3c3cc476318b..c5fbebc17d75 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -203,21 +203,21 @@ class Notifier: UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 def __init__(self, hs: "synapse.server.HomeServer"): - self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream] - self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]] + self.user_to_user_stream: Dict[str, _NotifierUserStream] = {} + self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {} self.hs = hs self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() - self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry] + self.pending_new_room_events: List[_PendingRoomEventEntry] = [] # Called when there are new things to stream over replication - self.replication_callbacks = [] # type: List[Callable[[], None]] + self.replication_callbacks: List[Callable[[], None]] = [] # Called when remote servers have come back online after having been # down. - self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]] + self.remote_server_up_callbacks: List[Callable[[str], None]] = [] self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() @@ -237,7 +237,7 @@ def __init__(self, hs: "synapse.server.HomeServer"): # when rendering the metrics page, which is likely once per minute at # most when scraping it. def count_listeners(): - all_user_streams = set() # type: Set[_NotifierUserStream] + all_user_streams: Set[_NotifierUserStream] = set() for streams in list(self.room_to_user_streams.values()): all_user_streams |= streams @@ -329,8 +329,8 @@ def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken pending = self.pending_new_room_events self.pending_new_room_events = [] - users = set() # type: Set[UserID] - rooms = set() # type: Set[str] + users: Set[UserID] = set() + rooms: Set[str] = set() for entry in pending: if entry.event_pos.persisted_after(max_room_stream_token): @@ -580,7 +580,7 @@ async def check_for_updates( if after_token == before_token: return EventStreamResult([], (from_token, from_token)) - events = [] # type: List[EventBase] + events: List[EventBase] = [] end_token = from_token for name, source in self.event_sources.sources.items(): diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 669ea462e29e..c337e530d3cf 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -194,7 +194,7 @@ async def action_for_event_by_user( count_as_unread = _should_count_as_unread(event, context) rules_by_user = await self._get_rules_for_event(event, context) - actions_by_user = {} # type: Dict[str, List[Union[dict, str]]] + actions_by_user: Dict[str, List[Union[dict, str]]] = {} room_members = await self.store.get_joined_users_from_context(event, context) @@ -207,7 +207,7 @@ async def action_for_event_by_user( event, len(room_members), sender_power_level, power_levels ) - condition_cache = {} # type: Dict[str, bool] + condition_cache: Dict[str, bool] = {} # If the event is not a state event check if any users ignore the sender. if not event.is_state(): diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 2ee0ccd58aa9..1fc9716a3422 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -26,10 +26,10 @@ def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, l # We're going to be mutating this a lot, so do a deep copy ruleslist = copy.deepcopy(ruleslist) - rules = { + rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = { "global": {}, "device": {}, - } # type: Dict[str, Dict[str, List[Dict[str, Any]]]] + } rules["global"] = _add_empty_priority_class_arrays(rules["global"]) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 99a18874d1cf..e08e125cb8a5 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -66,8 +66,8 @@ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer self.store = self.hs.get_datastore() self.email = pusher_config.pushkey - self.timed_call = None # type: Optional[IDelayedCall] - self.throttle_params = {} # type: Dict[str, ThrottleParams] + self.timed_call: Optional[IDelayedCall] = None + self.throttle_params: Dict[str, ThrottleParams] = {} self._inited = False self._is_processing = False @@ -168,7 +168,7 @@ async def _unsafe_process(self) -> None: ) ) - soonest_due_at = None # type: Optional[int] + soonest_due_at: Optional[int] = None if not unprocessed: await self.save_last_stream_ordering_and_success(self.max_stream_ordering) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 06bf5f8ada84..36aabd84220b 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -71,7 +71,7 @@ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): self.data = pusher_config.data self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.failing_since = pusher_config.failing_since - self.timed_call = None # type: Optional[IDelayedCall] + self.timed_call: Optional[IDelayedCall] = None self._is_processing = False self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room self._pusherpool = hs.get_pusherpool() diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 5f9ea5003a82..7be5fe1e9ba1 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -110,7 +110,7 @@ def __init__( self.state_handler = self.hs.get_state_handler() self.storage = hs.get_storage() self.app_name = app_name - self.email_subjects = hs.config.email_subjects # type: EmailSubjectConfig + self.email_subjects: EmailSubjectConfig = hs.config.email_subjects logger.info("Created Mailer for app_name %s" % app_name) @@ -230,7 +230,7 @@ async def send_notification_mail( [pa["event_id"] for pa in push_actions] ) - notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]] + notifs_by_room: Dict[str, List[Dict[str, Any]]] = {} for pa in push_actions: notifs_by_room.setdefault(pa["room_id"], []).append(pa) @@ -356,13 +356,13 @@ async def _get_room_vars( room_name = await calculate_room_name(self.store, room_state_ids, user_id) - room_vars = { + room_vars: Dict[str, Any] = { "title": room_name, "hash": string_ordinal_total(room_id), # See sender avatar hash "notifs": [], "invite": is_invite, "link": self._make_room_link(room_id), - } # type: Dict[str, Any] + } if not is_invite: for n in notifs: @@ -460,9 +460,9 @@ async def _get_message_vars( type_state_key = ("m.room.member", event.sender) sender_state_event_id = room_state_ids.get(type_state_key) if sender_state_event_id: - sender_state_event = await self.store.get_event( + sender_state_event: Optional[EventBase] = await self.store.get_event( sender_state_event_id - ) # type: Optional[EventBase] + ) else: # Attempt to check the historical state for the room. historical_state = await self.state_store.get_state_for_event( diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 412941393fe3..0510c1cbd50f 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -199,7 +199,7 @@ def name_from_member_event(member_event: EventBase) -> str: def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]: - ret = {} # type: Dict[str, Dict[str, str]] + ret: Dict[str, Dict[str, str]] = {} for k, v in state.items(): ret.setdefault(k[0], {})[k[1]] = v return ret diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 98b90a4f516e..7a8dc63976e2 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -195,9 +195,9 @@ def _get_value(self, dotted_key: str) -> Optional[str]: # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches -regex_cache = LruCache( +regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache( 50000, "regex_push_cache" -) # type: LruCache[Tuple[str, bool, bool], Pattern] +) def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index c51938b8cf48..021275437c5a 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -31,13 +31,13 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.config = hs.config - self.pusher_types = { + self.pusher_types: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] = { "http": HttpPusher - } # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] + } logger.info("email enable notifs: %r", hs.config.email_enable_notifs) if hs.config.email_enable_notifs: - self.mailers = {} # type: Dict[str, Mailer] + self.mailers: Dict[str, Mailer] = {} self._notif_template_html = hs.config.email_notif_template_html self._notif_template_text = hs.config.email_notif_template_text diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 579fcdf47250..85621f33ef1e 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -62,10 +62,6 @@ def __init__(self, hs: "HomeServer"): self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self._account_validity_enabled = ( - hs.config.account_validity.account_validity_enabled - ) - # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() @@ -87,7 +83,9 @@ def __init__(self, hs: "HomeServer"): self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering() # map from user id to app_id:pushkey to pusher - self.pushers = {} # type: Dict[str, Dict[str, Pusher]] + self.pushers: Dict[str, Dict[str, Pusher]] = {} + + self._account_validity_handler = hs.get_account_validity_handler() def start(self) -> None: """Starts the pushers off in a background process.""" @@ -238,12 +236,9 @@ async def _on_new_notifications(self, max_token: RoomStreamToken) -> None: for u in users_affected: # Don't push if the user account has expired - if self._account_validity_enabled: - expired = await self.store.is_account_expired( - u, self.clock.time_msec() - ) - if expired: - continue + expired = await self._account_validity_handler.is_user_expired(u) + if expired: + continue if u in self.pushers: for p in self.pushers[u].values(): @@ -268,12 +263,9 @@ async def on_new_receipts( for u in users_affected: # Don't push if the user account has expired - if self._account_validity_enabled: - expired = await self.store.is_account_expired( - u, self.clock.time_msec() - ) - if expired: - continue + expired = await self._account_validity_handler.is_user_expired(u) + if expired: + continue if u in self.pushers: for p in self.pushers[u].values(): diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 271c17c22615..cdcbdd772b14 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -115,7 +115,7 @@ "cache_memory": ["pympler"], } -ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] +ALL_OPTIONAL_REQUIREMENTS: Set[str] = set() for name, optional_deps in CONDITIONAL_REQUIREMENTS.items(): # Exclude systemd as it's a system-based requirement. @@ -193,7 +193,7 @@ def check_requirements(for_feature=None): if not for_feature: # Check the optional dependencies are up to date. We allow them to not be # installed. - OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str] + OPTS: List[str] = sum(CONDITIONAL_REQUIREMENTS.values(), []) for dependency in OPTS: try: diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index f13a7c23b4a6..25589b0042a0 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -85,17 +85,17 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): is received. """ - NAME = abc.abstractproperty() # type: str # type: ignore - PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore + NAME: str = abc.abstractproperty() # type: ignore + PATH_ARGS: Tuple[str, ...] = abc.abstractproperty() # type: ignore METHOD = "POST" CACHE = True RETRY_ON_TIMEOUT = True def __init__(self, hs: "HomeServer"): if self.CACHE: - self.response_cache = ResponseCache( + self.response_cache: ResponseCache[str] = ResponseCache( hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000 - ) # type: ResponseCache[str] + ) # We reserve `instance_name` as a parameter to sending requests, so we # assert here that sub classes don't try and use the name. @@ -232,7 +232,7 @@ async def send_request(*, instance_name="master", **kwargs): # have a good idea that the request has either succeeded or failed on # the master, and so whether we should clean up or not. while True: - headers = {} # type: Dict[bytes, List[bytes]] + headers: Dict[bytes, List[bytes]] = {} # Add an authorization header, if configured. if replication_secret: headers[b"Authorization"] = [b"Bearer " + replication_secret] diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index faa99387a745..e460dd85cd83 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -27,7 +27,9 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = MultiWriterIdGenerator( + self._cache_id_gen: Optional[ + MultiWriterIdGenerator + ] = MultiWriterIdGenerator( db_conn, database, stream_name="caches", @@ -41,7 +43,7 @@ def __init__(self, database: DatabasePool, db_conn, hs): ], sequence_name="cache_invalidation_stream_seq", writers=[], - ) # type: Optional[MultiWriterIdGenerator] + ) else: self._cache_id_gen = None diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 13ed87adc4ab..436d39c3203f 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -23,9 +23,9 @@ class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.client_ip_last_seen = LruCache( + self.client_ip_last_seen: LruCache[tuple, int] = LruCache( cache_name="client_ip_last_seen", max_size=50000 - ) # type: LruCache[tuple, int] + ) async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 62d780917500..9d4859798bd3 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -121,13 +121,13 @@ def __init__(self, hs: "HomeServer"): self._pusher_pool = hs.get_pusherpool() self._presence_handler = hs.get_presence_handler() - self.send_handler = None # type: Optional[FederationSenderHandler] + self.send_handler: Optional[FederationSenderHandler] = None if hs.should_send_federation(): self.send_handler = FederationSenderHandler(hs) # Map from stream to list of deferreds waiting for the stream to # arrive at a particular position. The lists are sorted by stream position. - self._streams_to_waiters = {} # type: Dict[str, List[Tuple[int, Deferred]]] + self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {} async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -173,7 +173,7 @@ async def on_rdata( if entities: self.notifier.on_new_event("to_device_key", token, users=entities) elif stream_name == DeviceListsStream.NAME: - all_room_ids = set() # type: Set[str] + all_room_ids: Set[str] = set() for row in rows: if row.entity.startswith("@"): room_ids = await self.store.get_rooms_for_user(row.entity) @@ -201,7 +201,7 @@ async def on_rdata( if row.data.rejected: continue - extra_users = () # type: Tuple[UserID, ...] + extra_users: Tuple[UserID, ...] = () if row.data.type == EventTypes.Member and row.data.state_key: extra_users = (UserID.from_string(row.data.state_key),) @@ -348,7 +348,7 @@ def __init__(self, hs: "HomeServer"): # Stores the latest position in the federation stream we've gotten up # to. This is always set before we use it. - self.federation_position = None # type: Optional[int] + self.federation_position: Optional[int] = None self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 505d450e1991..1311b013dae7 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -34,7 +34,7 @@ class Command(metaclass=abc.ABCMeta): A full command line on the wire is constructed from `NAME + " " + to_line()` """ - NAME = None # type: str + NAME: str @classmethod @abc.abstractmethod @@ -380,7 +380,7 @@ class RemoteServerUpCommand(_SimpleCommand): NAME = "REMOTE_SERVER_UP" -_COMMANDS = ( +_COMMANDS: Tuple[Type[Command], ...] = ( ServerCommand, RdataCommand, PositionCommand, @@ -393,7 +393,7 @@ class RemoteServerUpCommand(_SimpleCommand): UserIpCommand, RemoteServerUpCommand, ClearUserSyncsCommand, -) # type: Tuple[Type[Command], ...] +) # Map of command name to command type. COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS} diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 2ad7a200bb6a..eae4515363f2 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -105,12 +105,12 @@ def __init__(self, hs: "HomeServer"): hs.get_instance_name() in hs.config.worker.writers.presence ) - self._streams = { + self._streams: Dict[str, Stream] = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() - } # type: Dict[str, Stream] + } # List of streams that this instance is the source of - self._streams_to_replicate = [] # type: List[Stream] + self._streams_to_replicate: List[Stream] = [] for stream in self._streams.values(): if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME: @@ -180,14 +180,14 @@ def __init__(self, hs: "HomeServer"): # Map of stream name to batched updates. See RdataCommand for info on # how batching works. - self._pending_batches = {} # type: Dict[str, List[Any]] + self._pending_batches: Dict[str, List[Any]] = {} # The factory used to create connections. - self._factory = None # type: Optional[ReconnectingClientFactory] + self._factory: Optional[ReconnectingClientFactory] = None # The currently connected connections. (The list of places we need to send # outgoing replication commands to.) - self._connections = [] # type: List[IReplicationConnection] + self._connections: List[IReplicationConnection] = [] LaterGauge( "synapse_replication_tcp_resource_total_connections", @@ -200,7 +200,7 @@ def __init__(self, hs: "HomeServer"): # them in order in a separate background process. # the streams which are currently being processed by _unsafe_process_queue - self._processing_streams = set() # type: Set[str] + self._processing_streams: Set[str] = set() # for each stream, a queue of commands that are awaiting processing, and the # connection that they arrived on. @@ -210,7 +210,7 @@ def __init__(self, hs: "HomeServer"): # For each connection, the incoming stream names that have received a POSITION # from that connection. - self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]] + self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {} LaterGauge( "synapse_replication_tcp_command_queue", diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 6e3705364f47..8c80153ab6e0 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -102,7 +102,7 @@ # A list of all connected protocols. This allows us to send metrics about the # connections. -connected_connections = [] # type: List[BaseReplicationStreamProtocol] +connected_connections: "List[BaseReplicationStreamProtocol]" = [] logger = logging.getLogger(__name__) @@ -146,15 +146,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The transport is going to be an ITCPTransport, but that doesn't have the # (un)registerProducer methods, those are only on the implementation. - transport = None # type: Connection + transport: Connection delimiter = b"\n" # Valid commands we expect to receive - VALID_INBOUND_COMMANDS = [] # type: Collection[str] + VALID_INBOUND_COMMANDS: Collection[str] = [] # Valid commands we can send - VALID_OUTBOUND_COMMANDS = [] # type: Collection[str] + VALID_OUTBOUND_COMMANDS: Collection[str] = [] max_line_buffer = 10000 @@ -165,7 +165,7 @@ def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"): self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 # When we requested the connection be closed - self.time_we_closed = None # type: Optional[int] + self.time_we_closed: Optional[int] = None self.received_ping = False # Have we received a ping from the other side @@ -175,10 +175,10 @@ def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"): self.conn_id = random_string(5) # To dedupe in case of name clashes. # List of pending commands to send once we've established the connection - self.pending_commands = [] # type: List[Command] + self.pending_commands: List[Command] = [] # The LoopingCall for sending pings. - self._send_ping_loop = None # type: Optional[task.LoopingCall] + self._send_ping_loop: Optional[task.LoopingCall] = None # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 6a2c2655e4bd..8c0df627c8d0 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -57,7 +57,7 @@ class ConstantProperty(Generic[T, V]): it. """ - constant = attr.ib() # type: V + constant: V = attr.ib() def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V: return self.constant @@ -91,9 +91,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): commands. """ - synapse_handler = None # type: ReplicationCommandHandler - synapse_stream_name = None # type: str - synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol + synapse_handler: "ReplicationCommandHandler" + synapse_stream_name: str + synapse_outbound_redis_connection: txredisapi.RedisProtocol def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b03824925a91..3716c41bea7b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -85,9 +85,9 @@ class Stream: time it was called. """ - NAME = None # type: str # The name of the stream + NAME: str # The name of the stream # The type of the row. Used by the default impl of parse_row. - ROW_TYPE = None # type: Any + ROW_TYPE: Any = None @classmethod def parse_row(cls, row: StreamRow): @@ -283,9 +283,7 @@ def __init__(self, hs: "HomeServer"): assert isinstance(presence_handler, PresenceHandler) - update_function = ( - presence_handler.get_all_presence_updates - ) # type: UpdateFunction + update_function: UpdateFunction = presence_handler.get_all_presence_updates else: # Query presence writer process update_function = make_http_update_function(hs, self.NAME) @@ -334,9 +332,9 @@ def __init__(self, hs: "HomeServer"): if writer_instance == hs.get_instance_name(): # On the writer, query the typing handler typing_writer_handler = hs.get_typing_writer_handler() - update_function = ( - typing_writer_handler.get_all_typing_updates - ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] + update_function: Callable[ + [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]] + ] = typing_writer_handler.get_all_typing_updates current_token_function = typing_writer_handler.get_current_token else: # Query the typing writer process diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index e7e87bac9241..a030e9299edc 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -65,7 +65,7 @@ class BaseEventsStreamRow: """ # Unique string that ids the type. Must be overridden in sub classes. - TypeId = None # type: str + TypeId: str @classmethod def from_data(cls, data): @@ -103,10 +103,10 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow): event_id = attr.ib() # str, optional -_EventRows = ( +_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = ( EventsStreamEventRow, EventsStreamCurrentStateRow, -) # type: Tuple[Type[BaseEventsStreamRow], ...] +) TypeToRow = {Row.TypeId: Row for Row in _EventRows} @@ -157,9 +157,9 @@ async def _update_function( # now we fetch up to that many rows from the events table - event_rows = await self._store.get_all_new_forward_event_rows( + event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows( instance_name, from_token, current_token, target_row_count - ) # type: List[Tuple] + ) # we rely on get_all_new_forward_event_rows strictly honouring the limit, so # that we know it is safe to just take upper_limit = event_rows[-1][0]. @@ -172,7 +172,7 @@ async def _update_function( if len(event_rows) == target_row_count: limited = True - upper_limit = event_rows[-1][0] # type: int + upper_limit: int = event_rows[-1][0] else: limited = False upper_limit = current_token @@ -191,30 +191,30 @@ async def _update_function( # finally, fetch the ex-outliers rows. We assume there are few enough of these # not to bother with the limit. - ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( + ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows( instance_name, from_token, upper_limit - ) # type: List[Tuple] + ) # we now need to turn the raw database rows returned into tuples suitable # for the replication protocol (basically, we add an identifier to # distinguish the row type). At the same time, we can limit the event_rows # to the max stream_id from state_rows. - event_updates = ( + event_updates: Iterable[Tuple[int, Tuple]] = ( (stream_id, (EventsStreamEventRow.TypeId, rest)) for (stream_id, *rest) in event_rows if stream_id <= upper_limit - ) # type: Iterable[Tuple[int, Tuple]] + ) - state_updates = ( + state_updates: Iterable[Tuple[int, Tuple]] = ( (stream_id, (EventsStreamCurrentStateRow.TypeId, rest)) for (stream_id, *rest) in state_rows - ) # type: Iterable[Tuple[int, Tuple]] + ) - ex_outliers_updates = ( + ex_outliers_updates: Iterable[Tuple[int, Tuple]] = ( (stream_id, (EventsStreamEventRow.TypeId, rest)) for (stream_id, *rest) in ex_outliers_rows - ) # type: Iterable[Tuple[int, Tuple]] + ) # we need to return a sorted list, so merge them together. updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates)) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 096a85d36306..c445af9bd986 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -51,9 +51,9 @@ def __init__(self, hs: "HomeServer"): current_token = current_token_without_instance( federation_sender.get_current_token ) - update_function = ( - federation_sender.get_replication_rows - ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] + update_function: Callable[ + [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]] + ] = federation_sender.get_replication_rows elif hs.should_send_federation(): # federation sender: Query master process diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f0cddd2d2c7f..40ee33646cb2 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -402,9 +402,9 @@ async def on_POST( # Get the room ID from the identifier. try: - remote_room_hosts = [ + remote_room_hosts: Optional[List[str]] = [ x.decode("ascii") for x in request.args[b"server_name"] - ] # type: Optional[List[str]] + ] except Exception: remote_room_hosts = None room_id, remote_room_hosts = await self.resolve_room_id( @@ -462,6 +462,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self.auth = hs.get_auth() + self.store = hs.get_datastore() self.event_creation_handler = hs.get_event_creation_handler() self.state_handler = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -500,7 +501,13 @@ async def on_POST( admin_user_id = None for admin_user in reversed(admin_users): - if room_state.get((EventTypes.Member, admin_user)): + ( + current_membership_type, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + admin_user, room_id + ) + if current_membership_type == "join": admin_user_id = admin_user break @@ -652,9 +659,7 @@ async def on_GET( filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter( - json_decoder.decode(filter_json) - ) # type: Optional[Filter] + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) else: event_filter = None diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 7d75564758bf..589e47fa473c 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -357,7 +357,7 @@ class UserRegisterServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth_handler = hs.get_auth_handler() self.reactor = hs.get_reactor() - self.nonces = {} # type: Dict[str, int] + self.nonces: Dict[str, int] = {} self.hs = hs def _clear_old_nonces(self): @@ -560,16 +560,24 @@ def __init__(self, hs: "HomeServer"): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - body = parse_json_object_from_request(request) + if self.account_activity_handler.on_legacy_admin_request_callback: + expiration_ts = await ( + self.account_activity_handler.on_legacy_admin_request_callback(request) + ) + else: + body = parse_json_object_from_request(request) - if "user_id" not in body: - raise SynapseError(400, "Missing property 'user_id' in the request body") + if "user_id" not in body: + raise SynapseError( + 400, + "Missing property 'user_id' in the request body", + ) - expiration_ts = await self.account_activity_handler.renew_account_for_user( - body["user_id"], - body.get("expiration_ts"), - not body.get("enable_renewal_emails", True), - ) + expiration_ts = await self.account_activity_handler.renew_account_for_user( + body["user_id"], + body.get("expiration_ts"), + not body.get("enable_renewal_emails", True), + ) res = {"expiration_ts": expiration_ts} return 200, res diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index cbcb60fe314f..11567bf32cef 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -44,19 +44,14 @@ logger = logging.getLogger(__name__) -LoginResponse = TypedDict( - "LoginResponse", - { - "user_id": str, - "access_token": str, - "home_server": str, - "expires_in_ms": Optional[int], - "refresh_token": Optional[str], - "device_id": str, - "well_known": Optional[Dict[str, Any]], - }, - total=False, -) +class LoginResponse(TypedDict, total=False): + user_id: str + access_token: str + home_server: str + expires_in_ms: Optional[int] + refresh_token: Optional[str] + device_id: str + well_known: Optional[Dict[str, Any]] class LoginRestServlet(RestServlet): @@ -121,7 +116,7 @@ def on_GET(self, request: SynapseRequest): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - sso_flow = { + sso_flow: JsonDict = { "type": LoginRestServlet.SSO_TYPE, "identity_providers": [ _get_auth_flow_dict_for_idp( @@ -129,7 +124,7 @@ def on_GET(self, request: SynapseRequest): ) for idp in self._sso_handler.get_identity_providers().values() ], - } # type: JsonDict + } if self._msc2858_enabled: # backwards-compatibility support for clients which don't @@ -150,9 +145,7 @@ def on_GET(self, request: SynapseRequest): # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - flows.extend( - ({"type": t} for t in self.auth_handler.get_supported_login_types()) - ) + flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) @@ -447,7 +440,7 @@ def _get_auth_flow_dict_for_idp( use_unstable_brands: whether we should use brand identifiers suitable for the unstable API """ - e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict + e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name} if idp.idp_icon: e["icon"] = idp.idp_icon if idp.idp_brand: @@ -561,7 +554,7 @@ async def on_GET( finish_request(request) return - args = request.args # type: Dict[bytes, List[bytes]] # type: ignore + args: Dict[bytes, List[bytes]] = request.args # type: ignore client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True) sso_url = await self._sso_handler.handle_redirect_request( request, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 92ebe838fd84..31a1193cd3c0 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -29,6 +29,7 @@ SynapseError, ) from synapse.api.filtering import Filter +from synapse.appservice import ApplicationService from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( RestServlet, @@ -47,11 +48,13 @@ from synapse.streams.config import PaginationConfig from synapse.types import ( JsonDict, + Requester, RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID, + create_requester, ) from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string @@ -309,7 +312,7 @@ def __init__(self, hs): self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - async def inherit_depth_from_prev_ids(self, prev_event_ids) -> int: + async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: ( most_recent_prev_event_id, most_recent_prev_event_depth, @@ -349,6 +352,54 @@ async def inherit_depth_from_prev_ids(self, prev_event_ids) -> int: return depth + def _create_insertion_event_dict( + self, sender: str, room_id: str, origin_server_ts: int + ): + """Creates an event dict for an "insertion" event with the proper fields + and a random chunk ID. + + Args: + sender: The event author MXID + room_id: The room ID that the event belongs to + origin_server_ts: Timestamp when the event was sent + + Returns: + Tuple of event ID and stream ordering position + """ + + next_chunk_id = random_string(8) + insertion_event = { + "type": EventTypes.MSC2716_INSERTION, + "sender": sender, + "room_id": room_id, + "content": { + EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, + EventContentFields.MSC2716_HISTORICAL: True, + }, + "origin_server_ts": origin_server_ts, + } + + return insertion_event + + async def _create_requester_for_user_id_from_app_service( + self, user_id: str, app_service: ApplicationService + ) -> Requester: + """Creates a new requester for the given user_id + and validates that the app service is allowed to control + the given user. + + Args: + user_id: The author MXID that the app service is controlling + app_service: The app service that controls the user + + Returns: + Requester object + """ + + await self.auth.validate_appservice_can_control_user_id(app_service, user_id) + + return create_requester(user_id, app_service=app_service) + async def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=False) @@ -414,7 +465,9 @@ async def on_POST(self, request, room_id): if event_dict["type"] == EventTypes.Member: membership = event_dict["content"].get("membership", None) event_id, _ = await self.room_member_handler.update_membership( - requester, + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), target=UserID.from_string(event_dict["state_key"]), room_id=room_id, action=membership, @@ -434,7 +487,9 @@ async def on_POST(self, request, room_id): event, _, ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), event_dict, outlier=True, prev_event_ids=[fake_prev_event_id], @@ -449,37 +504,73 @@ async def on_POST(self, request, room_id): events_to_create = body["events"] - # If provided, connect the chunk to the last insertion point - # The chunk ID passed in comes from the chunk_id in the - # "insertion" event from the previous chunk. + prev_event_ids = prev_events_from_query + inherited_depth = await self._inherit_depth_from_prev_ids( + prev_events_from_query + ) + + # Figure out which chunk to connect to. If they passed in + # chunk_id_from_query let's use it. The chunk ID passed in comes + # from the chunk_id in the "insertion" event from the previous chunk. + last_event_in_chunk = events_to_create[-1] + chunk_id_to_connect_to = chunk_id_from_query + base_insertion_event = None if chunk_id_from_query: - last_event_in_chunk = events_to_create[-1] - last_event_in_chunk["content"][ - EventContentFields.MSC2716_CHUNK_ID - ] = chunk_id_from_query + # TODO: Verify the chunk_id_from_query corresponds to an insertion event + pass + # Otherwise, create an insertion event to act as a starting point. + # + # We don't always have an insertion event to start hanging more history + # off of (ideally there would be one in the main DAG, but that's not the + # case if we're wanting to add history to e.g. existing rooms without + # an insertion event), in which case we just create a new insertion event + # that can then get pointed to by a "marker" event later. + else: + base_insertion_event_dict = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + origin_server_ts=last_event_in_chunk["origin_server_ts"], + ) + base_insertion_event_dict["prev_events"] = prev_event_ids.copy() + + ( + base_insertion_event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self._create_requester_for_user_id_from_app_service( + base_insertion_event_dict["sender"], + requester.app_service, + ), + base_insertion_event_dict, + prev_event_ids=base_insertion_event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + + chunk_id_to_connect_to = base_insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ] - # Add an "insertion" event to the start of each chunk (next to the oldest + # Connect this current chunk to the insertion event from the previous chunk + last_event_in_chunk["content"][ + EventContentFields.MSC2716_CHUNK_ID + ] = chunk_id_to_connect_to + + # Add an "insertion" event to the start of each chunk (next to the oldest-in-time # event in the chunk) so the next chunk can be connected to this one. - next_chunk_id = random_string(64) - insertion_event = { - "type": EventTypes.MSC2716_INSERTION, - "sender": requester.user.to_string(), - "content": { - EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, - EventContentFields.MSC2716_HISTORICAL: True, - }, + insertion_event = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, # Since the insertion event is put at the start of the chunk, - # where the oldest event is, copy the origin_server_ts from + # where the oldest-in-time event is, copy the origin_server_ts from # the first event we're inserting - "origin_server_ts": events_to_create[0]["origin_server_ts"], - } + origin_server_ts=events_to_create[0]["origin_server_ts"], + ) # Prepend the insertion event to the start of the chunk events_to_create = [insertion_event] + events_to_create - inherited_depth = await self.inherit_depth_from_prev_ids(prev_events_from_query) - event_ids = [] - prev_event_ids = prev_events_from_query events_to_persist = [] for ev in events_to_create: assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) @@ -498,7 +589,9 @@ async def on_POST(self, request, room_id): } event, context = await self.event_creation_handler.create_event( - requester, + await self._create_requester_for_user_id_from_app_service( + ev["sender"], requester.app_service + ), event_dict, prev_event_ids=event_dict.get("prev_events"), auth_event_ids=auth_event_ids, @@ -528,15 +621,23 @@ async def on_POST(self, request, room_id): # where topological_ordering is just depth. for (event, context) in reversed(events_to_persist): ev = await self.event_creation_handler.handle_new_client_event( - requester=requester, + await self._create_requester_for_user_id_from_app_service( + event["sender"], requester.app_service + ), event=event, context=context, ) + # Add the base_insertion_event to the bottom of the list we return + if base_insertion_event is not None: + event_ids.append(base_insertion_event.event_id) + return 200, { "state_events": auth_event_ids, "events": event_ids, - "next_chunk_id": next_chunk_id, + "next_chunk_id": insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ], } def on_GET(self, request, room_id): @@ -682,7 +783,7 @@ async def on_POST(self, request): server = parse_string(request, "server", default=None) content = parse_json_object_from_request(request) - limit = int(content.get("limit", 100)) # type: Optional[int] + limit: Optional[int] = int(content.get("limit", 100)) since_token = content.get("since", None) search_filter = content.get("filter", None) @@ -828,9 +929,7 @@ async def on_GET(self, request, room_id): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter( - json_decoder.decode(filter_json) - ) # type: Optional[Filter] + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) if ( event_filter and event_filter.filter_json.get("event_format", "client") @@ -943,9 +1042,7 @@ async def on_GET(self, request, room_id, event_id): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter( - json_decoder.decode(filter_json) - ) # type: Optional[Filter] + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) else: event_filter = None diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 2d1ad3d3fbf0..3ebe40186153 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -14,7 +14,7 @@ import logging -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import SynapseError from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet @@ -92,11 +92,6 @@ def __init__(self, hs): ) async def on_POST(self, request): - if not self.account_validity_renew_by_email_enabled: - raise AuthError( - 403, "Account renewal via email is disabled on this server." - ) - requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() await self.account_activity_handler.send_renewal_email_to_user(user_id) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index f8dcee603ce8..d537d811d8a1 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -59,7 +59,7 @@ async def _put(self, request, message_type, txn_id): requester, message_type, content["messages"] ) - response = (200, {}) # type: Tuple[int, dict] + response: Tuple[int, dict] = (200, {}) return response diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index e52570cd8ea9..4282e2b228b1 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -117,7 +117,7 @@ async def _async_render_GET(self, request): has_consented = False public_version = username == "" if not public_version: - args = request.args # type: Dict[bytes, List[bytes]] + args: Dict[bytes, List[bytes]] = request.args userhmac_bytes = parse_bytes_from_args(args, "h", required=True) self._check_hash(username, userhmac_bytes) @@ -154,7 +154,7 @@ async def _async_render_POST(self, request): """ version = parse_string(request, "v", required=True) username = parse_string(request, "u", required=True) - args = request.args # type: Dict[bytes, List[bytes]] + args: Dict[bytes, List[bytes]] = request.args userhmac = parse_bytes_from_args(args, "h", required=True) self._check_hash(username, userhmac) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index d56a1ae48279..63a40b18528f 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -97,7 +97,7 @@ def __init__(self, hs): async def _async_render_GET(self, request): if len(request.postpath) == 1: (server,) = request.postpath - query = {server.decode("ascii"): {}} # type: dict + query: dict = {server.decode("ascii"): {}} elif len(request.postpath) == 2: server, key_id = request.postpath minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") @@ -141,7 +141,7 @@ async def query_keys(self, request, query, query_remote_on_cache_miss=False): time_now_ms = self.clock.time_msec() # Note that the value is unused. - cache_misses = {} # type: Dict[str, Dict[str, int]] + cache_misses: Dict[str, Dict[str, int]] = {} for (server_name, key_id, _), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] diff --git a/synapse/rest/media/v1/__init__.py b/synapse/rest/media/v1/__init__.py index d20186bbd095..3dd16d4bb542 100644 --- a/synapse/rest/media/v1/__init__.py +++ b/synapse/rest/media/v1/__init__.py @@ -17,7 +17,7 @@ # check for JPEG support. try: PIL.Image._getdecoder("rgb", "jpeg", None) -except IOError as e: +except OSError as e: if str(e).startswith("decoder jpeg not available"): raise Exception( "FATAL: jpeg codec not supported. Install pillow correctly! " @@ -32,7 +32,7 @@ # check for PNG support. try: PIL.Image._getdecoder("rgb", "zip", None) -except IOError as e: +except OSError as e: if str(e).startswith("decoder zip not available"): raise Exception( "FATAL: zip codec not supported. Install pillow correctly! " diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 0fb4cd81f1a2..90364ebcf70d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -49,7 +49,7 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: try: # The type on postpath seems incorrect in Twisted 21.2.0. - postpath = request.postpath # type: List[bytes] # type: ignore + postpath: List[bytes] = request.postpath # type: ignore assert postpath # This allows users to append e.g. /test.png to the URL. Useful for diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 21c43c340c3b..4f702f890c1c 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -78,16 +78,16 @@ def __init__(self, hs: "HomeServer"): Thumbnailer.set_limits(self.max_image_pixels) - self.primary_base_path = hs.config.media_store_path # type: str - self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths + self.primary_base_path: str = hs.config.media_store_path + self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path) self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") - self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]] - self.recently_accessed_locals = set() # type: Set[str] + self.recently_accessed_remotes: Set[Tuple[str, str]] = set() + self.recently_accessed_locals: Set[str] = set() self.federation_domain_whitelist = hs.config.federation_domain_whitelist @@ -711,7 +711,7 @@ async def _generate_thumbnails( # We deduplicate the thumbnail sizes by ignoring the cropped versions if # they have the same dimensions of a scaled one. - thumbnails = {} # type: Dict[Tuple[int, int, str], str] + thumbnails: Dict[Tuple[int, int, str], str] = {} for r_width, r_height, r_method, r_type in requirements: if r_method == "crop": thumbnails.setdefault((r_width, r_height, r_type), r_method) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index c7fd97c46c9d..56cdc1b4edd3 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -191,7 +191,7 @@ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: for provider in self.storage_providers: for path in paths: - res = await provider.fetch(path, file_info) # type: Any + res: Any = await provider.fetch(path, file_info) if res: logger.debug("Streaming %s from %s", path, provider) return res @@ -233,7 +233,7 @@ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: os.makedirs(dirname) for provider in self.storage_providers: - res = await provider.fetch(path, file_info) # type: Any + res: Any = await provider.fetch(path, file_info) if res: with res: consumer = BackgroundFileConsumer( diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0adfb1a70f73..8e7fead3a286 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -169,12 +169,12 @@ def __init__( # memory cache mapping urls to an ObservableDeferred returning # JSON-encoded OG metadata - self._cache = ExpiringCache( + self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache( cache_name="url_previews", clock=self.clock, # don't spider URLs more often than once an hour expiry_ms=ONE_HOUR, - ) # type: ExpiringCache[str, ObservableDeferred] + ) if self._worker_run_media_background_jobs: self._cleaner_loop = self.clock.looping_call( @@ -460,7 +460,7 @@ async def _download_url(self, url: str, user: str) -> Dict[str, Any]: file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) # If this URL can be accessed via oEmbed, use that instead. - url_to_download = url # type: Optional[str] + url_to_download: Optional[str] = url oembed_url = self._get_oembed_url(url) if oembed_url: # The result might be a new URL to download, or it might be HTML content. @@ -788,7 +788,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: # "og:video:height" : "720", # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", - og = {} # type: Dict[str, Optional[str]] + og: Dict[str, Optional[str]] = {} for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): if "content" in tag.attrib: # if we've got more than 50 tags, someone is taking the piss diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 62dc4aae2d69..146adca8f1bf 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -61,11 +61,11 @@ async def _async_render_POST(self, request: SynapseRequest) -> None: errcode=Codes.TOO_LARGE, ) - args = request.args # type: Dict[bytes, List[bytes]] # type: ignore + args: Dict[bytes, List[bytes]] = request.args # type: ignore upload_name_bytes = parse_bytes_from_args(args, "filename") if upload_name_bytes: try: - upload_name = upload_name_bytes.decode("utf8") # type: Optional[str] + upload_name: Optional[str] = upload_name_bytes.decode("utf8") except UnicodeDecodeError: raise SynapseError( msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400 @@ -89,7 +89,7 @@ async def _async_render_POST(self, request: SynapseRequest) -> None: # TODO(markjh): parse content-dispostion try: - content = request.content # type: IO # type: ignore + content: IO = request.content # type: ignore content_uri = await self.media_repo.create_content( media_type, upload_name, content, content_length, requester.user ) diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 9b002cc15ebe..ab24ec0a8e68 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -118,9 +118,9 @@ async def _async_render_POST(self, request: SynapseRequest): use_display_name = parse_boolean(request, "use_display_name", default=False) try: - emails_to_use = [ + emails_to_use: List[str] = [ val.decode("utf-8") for val in request.args.get(b"use_email", []) - ] # type: List[str] + ] except ValueError: raise SynapseError(400, "Query parameter use_email must be utf-8") except SynapseError as e: diff --git a/synapse/server.py b/synapse/server.py index 2c27d2a7e8bb..095dba9ad038 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -247,15 +247,15 @@ def __init__( # the key we use to sign events and requests self.signing_key = config.key.signing_key[0] self.config = config - self._listening_services = [] # type: List[twisted.internet.tcp.Port] - self.start_time = None # type: Optional[int] + self._listening_services: List[twisted.internet.tcp.Port] = [] + self.start_time: Optional[int] = None self._instance_id = random_string(5) self._instance_name = config.worker.instance_name self.version_string = version_string - self.datastores = None # type: Optional[Databases] + self.datastores: Optional[Databases] = None self._module_web_resources: Dict[str, IResource] = {} self._module_web_resources_consumed = False diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index e65f6f88fe74..4e0f814035b3 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -34,7 +34,7 @@ def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() self._store = hs.get_datastore() - self._users_in_progress = set() # type: Set[str] + self._users_in_progress: Set[str] = set() self._current_consent_version = hs.config.user_consent_version self._server_notice_content = hs.config.user_consent_server_notice_content diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index e4b0bc5c7235..073b0d754fd6 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -205,7 +205,7 @@ async def _is_room_currently_blocked(self, room_id: str) -> Tuple[bool, List[str # The user has yet to join the server notices room pass - referenced_events = [] # type: List[str] + referenced_events: List[str] = [] if pinned_state_event is not None: referenced_events = list(pinned_state_event.content.get("pinned", [])) diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index c875b15b326d..cdf0973d057e 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -32,10 +32,12 @@ class ServerNoticesSender(WorkerServerNoticesSender): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self._server_notices = ( + self._server_notices: Iterable[ + Union[ConsentServerNotices, ResourceLimitsServerNotices] + ] = ( ConsentServerNotices(hs), ResourceLimitsServerNotices(hs), - ) # type: Iterable[Union[ConsentServerNotices, ResourceLimitsServerNotices]] + ) async def on_user_syncing(self, user_id: str) -> None: """Called when the user performs a sync operation. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a1770f620e59..6223daf5228d 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -309,9 +309,9 @@ async def compute_event_context( if old_state: # if we're given the state before the event, then we use that - state_ids_before_event = { + state_ids_before_event: StateMap[str] = { (s.type, s.state_key): s.event_id for s in old_state - } # type: StateMap[str] + } state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None @@ -513,23 +513,25 @@ def __init__(self, hs): self.resolve_linearizer = Linearizer(name="state_resolve_lock") # dict of set of event_ids -> _StateCacheEntry. - self._state_cache = ExpiringCache( + self._state_cache: ExpiringCache[ + FrozenSet[int], _StateCacheEntry + ] = ExpiringCache( cache_name="state_cache", clock=self.clock, max_len=100000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, iterable=True, reset_expiry_on_get=True, - ) # type: ExpiringCache[FrozenSet[int], _StateCacheEntry] + ) # # stuff for tracking time spent on state-res by room # # tracks the amount of work done on state res per room - self._state_res_metrics = defaultdict( + self._state_res_metrics: DefaultDict[str, _StateResMetrics] = defaultdict( _StateResMetrics - ) # type: DefaultDict[str, _StateResMetrics] + ) self.clock.looping_call(self._report_metrics, 120 * 1000) @@ -700,9 +702,9 @@ def _report_biggest( items = self._state_res_metrics.items() # log the N biggest rooms - biggest = heapq.nlargest( + biggest: List[Tuple[str, _StateResMetrics]] = heapq.nlargest( n_to_log, items, key=lambda i: extract_key(i[1]) - ) # type: List[Tuple[str, _StateResMetrics]] + ) metrics_logger.debug( "%i biggest rooms for state-res by %s: %s", len(biggest), @@ -754,7 +756,7 @@ def _make_state_cache_entry( # failing that, look for the closest match. prev_group = None - delta_ids = None # type: Optional[StateMap[str]] + delta_ids: Optional[StateMap[str]] = None for old_group, old_state in state_groups_ids.items(): n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v} diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 318e9988130f..267193cedf73 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -159,7 +159,7 @@ def _seperate( """ state_set_iterator = iter(state_sets) unconflicted_state = dict(next(state_set_iterator)) - conflicted_state = {} # type: MutableStateMap[Set[str]] + conflicted_state: MutableStateMap[Set[str]] = {} for state_set in state_set_iterator: for key, value in state_set.items(): diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 008644cd9862..e66e6571c8d9 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -276,7 +276,7 @@ async def _get_auth_chain_difference( # event IDs if they appear in the `event_map`. This is the intersection of # the event's auth chain with the events in the `event_map` *plus* their # auth event IDs. - events_to_auth_chain = {} # type: Dict[str, Set[str]] + events_to_auth_chain: Dict[str, Set[str]] = {} for event in event_map.values(): chain = {event.event_id} events_to_auth_chain[event.event_id] = chain @@ -301,17 +301,17 @@ async def _get_auth_chain_difference( # ((type, state_key)->event_id) mappings; and (b) we have stripped out # unpersisted events and replaced them with the persisted events in # their auth chain. - state_sets_ids = [] # type: List[Set[str]] + state_sets_ids: List[Set[str]] = [] # For each state set, the unpersisted event IDs reachable (by their auth # chain) from the events in that set. - unpersisted_set_ids = [] # type: List[Set[str]] + unpersisted_set_ids: List[Set[str]] = [] for state_set in state_sets: - set_ids = set() # type: Set[str] + set_ids: Set[str] = set() state_sets_ids.append(set_ids) - unpersisted_ids = set() # type: Set[str] + unpersisted_ids: Set[str] = set() unpersisted_set_ids.append(unpersisted_ids) for event_id in state_set.values(): @@ -334,7 +334,7 @@ async def _get_auth_chain_difference( union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) - difference_from_event_map = union - intersection # type: Collection[str] + difference_from_event_map: Collection[str] = union - intersection else: difference_from_event_map = () state_sets_ids = [set(state_set.values()) for state_set in state_sets] @@ -458,7 +458,7 @@ async def _reverse_topological_power_sort( The sorted list """ - graph = {} # type: Dict[str, Set[str]] + graph: Dict[str, Set[str]] = {} for idx, event_id in enumerate(event_ids, start=1): await _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff @@ -657,7 +657,7 @@ async def _get_mainline_depth_for_event( """ room_id = event.room_id - tmp_event = event # type: Optional[EventBase] + tmp_event: Optional[EventBase] = event # We do an iterative search, replacing `event with the power level in its # auth events (if any) @@ -767,7 +767,7 @@ def lexicographical_topological_sort( # outgoing edges, c.f. # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm outdegree_map = graph - reverse_graph = {} # type: Dict[str, Set[str]] + reverse_graph: Dict[str, Set[str]] = {} # Lists of nodes with zero out degree. Is actually a tuple of # `(key(node), node)` so that sorting does the right thing diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 142787fdfd1c..82b31d24f1fa 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -92,14 +92,12 @@ def __init__(self, hs: "HomeServer", database: "DatabasePool"): self.db_pool = database # if a background update is currently running, its name. - self._current_background_update = None # type: Optional[str] - - self._background_update_performance = ( - {} - ) # type: Dict[str, BackgroundUpdatePerformance] - self._background_update_handlers = ( - {} - ) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]] + self._current_background_update: Optional[str] = None + + self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {} + self._background_update_handlers: Dict[ + str, Callable[[JsonDict, int], Awaitable[int]] + ] = {} self._all_done = False def start_doing_background_updates(self) -> None: @@ -411,7 +409,7 @@ def create_index_sqlite(conn: Connection) -> None: c.execute(sql) if isinstance(self.db_pool.engine, engines.PostgresEngine): - runner = create_index_psql # type: Optional[Callable[[Connection], None]] + runner: Optional[Callable[[Connection], None]] = create_index_psql elif psql_only: runner = None else: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 33c42cf95a88..ccf9ac51ef8c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -670,8 +670,8 @@ async def runInteraction( Returns: The result of func """ - after_callbacks = [] # type: List[_CallbackListEntry] - exception_callbacks = [] # type: List[_CallbackListEntry] + after_callbacks: List[_CallbackListEntry] = [] + exception_callbacks: List[_CallbackListEntry] = [] if not current_context(): logger.warning("Starting db txn '%s' from sentinel context", desc) @@ -907,7 +907,7 @@ def simple_insert_many_txn( # The sort is to ensure that we don't rely on dictionary iteration # order. keys, vals = zip( - *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] + *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i) ) for k in keys: @@ -1090,7 +1090,7 @@ def _getwhere(key): return False # We didn't find any existing rows, so insert a new one - allvalues = {} # type: Dict[str, Any] + allvalues: Dict[str, Any] = {} allvalues.update(keyvalues) allvalues.update(values) allvalues.update(insertion_values) @@ -1121,7 +1121,7 @@ def simple_upsert_txn_native_upsert( values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting """ - allvalues = {} # type: Dict[str, Any] + allvalues: Dict[str, Any] = {} allvalues.update(keyvalues) allvalues.update(insertion_values or {}) @@ -1257,7 +1257,7 @@ def simple_upsert_many_txn_native_upsert( value_values: A list of each row's value column values. Ignored if value_names is empty. """ - allnames = [] # type: List[str] + allnames: List[str] = [] allnames.extend(key_names) allnames.extend(value_names) @@ -1566,7 +1566,7 @@ async def simple_select_many_batch( """ keyvalues = keyvalues or {} - results = [] # type: List[Dict[str, Any]] + results: List[Dict[str, Any]] = [] if not iterable: return results @@ -1978,7 +1978,7 @@ def simple_select_list_paginate_txn( raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else "" - arg_list = [] # type: List[Any] + arg_list: List[Any] = [] if filters: where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) arg_list += list(filters.values()) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 9f182c2a890f..e2d1b758bddc 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -48,9 +48,7 @@ def _make_exclusive_regex( ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) - exclusive_user_pattern = re.compile( - exclusive_user_regex - ) # type: Optional[Pattern] + exclusive_user_pattern: Optional[Pattern] = re.compile(exclusive_user_regex) else: # We handle this case specially otherwise the constructed regex # will always match diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 50e7ddd7355b..c55508867d07 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -203,9 +203,7 @@ def delete_messages_for_device_txn(txn): "delete_messages_for_device", delete_messages_for_device_txn ) - log_kv( - {"message": "deleted {} messages for device".format(count), "count": count} - ) + log_kv({"message": f"deleted {count} messages for device", "count": count}) # Update the cache, ensuring that we only ever increase the value last_deleted_stream_id = self._last_device_delete_cache.get( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 0e3dd4e9cac9..1edc96042bbe 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -21,6 +21,7 @@ from twisted.enterprise.adbapi import Connection +from synapse.api.constants import DeviceKeyAlgorithms from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, make_in_list_sql_clause @@ -247,7 +248,7 @@ def _get_e2e_device_keys_txn( txn.execute(sql, query_params) - result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] + result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {} for (user_id, device_id, display_name, key_json) in txn: if include_deleted_devices: deleted_devices.remove((user_id, device_id)) @@ -381,9 +382,15 @@ def _count_e2e_one_time_keys(txn): " GROUP BY algorithm" ) txn.execute(sql, (user_id, device_id)) - result = {} + + # Initially set the key count to 0. This ensures that the client will always + # receive *some count*, even if it's 0. + result = {DeviceKeyAlgorithms.SIGNED_CURVE25519: 0} + + # Override entries with the count of any keys we pulled from the database for algorithm, key_count in txn: result[algorithm] = key_count + return result return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index c4474df9758c..d39368c20ed6 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -62,9 +62,9 @@ def __init__(self, database: DatabasePool, db_conn, hs): ) # Cache of event ID to list of auth event IDs and their depths. - self._event_auth_cache = LruCache( + self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache( 500000, "_event_auth_cache", size_callback=len - ) # type: LruCache[str, List[Tuple[str, int]]] + ) self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) @@ -137,10 +137,10 @@ def _get_auth_chain_ids_using_cover_index_txn( initial_events = set(event_ids) # All the events that we've found that are reachable from the events. - seen_events = set() # type: Set[str] + seen_events: Set[str] = set() # A map from chain ID to max sequence number of the given events. - event_chains = {} # type: Dict[int, int] + event_chains: Dict[int, int] = {} sql = """ SELECT event_id, chain_id, sequence_number @@ -182,7 +182,7 @@ def _get_auth_chain_ids_using_cover_index_txn( """ # A map from chain ID to max sequence number *reachable* from any event ID. - chains = {} # type: Dict[int, int] + chains: Dict[int, int] = {} # Add all linked chains reachable from initial set of chains. for batch in batch_iter(event_chains, 1000): @@ -353,14 +353,14 @@ def _get_auth_chain_difference_using_cover_index_txn( initial_events = set(state_sets[0]).union(*state_sets[1:]) # Map from event_id -> (chain ID, seq no) - chain_info = {} # type: Dict[str, Tuple[int, int]] + chain_info: Dict[str, Tuple[int, int]] = {} # Map from chain ID -> seq no -> event Id - chain_to_event = {} # type: Dict[int, Dict[int, str]] + chain_to_event: Dict[int, Dict[int, str]] = {} # All the chains that we've found that are reachable from the state # sets. - seen_chains = set() # type: Set[int] + seen_chains: Set[int] = set() sql = """ SELECT event_id, chain_id, sequence_number @@ -392,9 +392,9 @@ def _get_auth_chain_difference_using_cover_index_txn( # Corresponds to `state_sets`, except as a map from chain ID to max # sequence number reachable from the state set. - set_to_chain = [] # type: List[Dict[int, int]] + set_to_chain: List[Dict[int, int]] = [] for state_set in state_sets: - chains = {} # type: Dict[int, int] + chains: Dict[int, int] = {} set_to_chain.append(chains) for event_id in state_set: @@ -446,7 +446,7 @@ def _get_auth_chain_difference_using_cover_index_txn( # Mapping from chain ID to the range of sequence numbers that should be # pulled from the database. - chain_to_gap = {} # type: Dict[int, Tuple[int, int]] + chain_to_gap: Dict[int, Tuple[int, int]] = {} for chain_id in seen_chains: min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain) @@ -555,7 +555,7 @@ def _get_auth_chain_difference_txn( } # The sorted list of events whose auth chains we should walk. - search = [] # type: List[Tuple[int, str]] + search: List[Tuple[int, str]] = [] # We need to get the depth of the initial events for sorting purposes. sql = """ @@ -578,7 +578,7 @@ def _get_auth_chain_difference_txn( search.sort() # Map from event to its auth events - event_to_auth_events = {} # type: Dict[str, Set[str]] + event_to_auth_events: Dict[str, Set[str]] = {} base_sql = """ SELECT a.event_id, auth_id, depth @@ -1230,7 +1230,9 @@ def _get_stats_for_federation_staging_txn(txn): "SELECT coalesce(min(received_ts), 0) FROM federation_inbound_events_staging" ) - (age,) = txn.fetchone() + (received_ts,) = txn.fetchone() + + age = self._clock.time_msec() - received_ts return count, age diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index d1237c65cc85..55caa6bbe78c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -759,7 +759,7 @@ def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): # object because we might not have the same amount of rows in each of them. To do # this, we use a dict indexed on the user ID and room ID to make it easier to # populate. - summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary] + summaries: Dict[Tuple[str, str], _EventPushSummary] = {} for row in txn: summaries[(row[0], row[1])] = _EventPushSummary( unread_count=row[2], diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 897fa06639cb..a396a201d4fc 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -109,10 +109,8 @@ def __init__( # Ideally we'd move these ID gens here, unfortunately some other ID # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen = ( - self.store._backfill_id_gen - ) # type: MultiWriterIdGenerator - self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator + self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen + self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen # This should only exist on instances that are configured to write assert ( @@ -221,7 +219,7 @@ async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[st Returns: Filtered event ids """ - results = [] # type: List[str] + results: List[str] = [] def _get_events_which_are_prevs_txn(txn, batch): sql = """ @@ -508,7 +506,7 @@ def _add_chain_cover_index( """ # Map from event ID to chain ID/sequence number. - chain_map = {} # type: Dict[str, Tuple[int, int]] + chain_map: Dict[str, Tuple[int, int]] = {} # Set of event IDs to calculate chain ID/seq numbers for. events_to_calc_chain_id_for = set(event_to_room_id) @@ -817,8 +815,8 @@ def _allocate_chain_ids( # new chain if the sequence number has already been allocated. # - existing_chains = set() # type: Set[int] - tree = [] # type: List[Tuple[str, Optional[str]]] + existing_chains: Set[int] = set() + tree: List[Tuple[str, Optional[str]]] = [] # We need to do this in a topologically sorted order as we want to # generate chain IDs/sequence numbers of an event's auth events before @@ -848,7 +846,7 @@ def _allocate_chain_ids( ) txn.execute(sql % (clause,), args) - chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int] + chain_to_max_seq_no: Dict[Any, int] = {row[0]: row[1] for row in txn} # Allocate the new events chain ID/sequence numbers. # @@ -858,8 +856,8 @@ def _allocate_chain_ids( # number of new chain IDs in one call, replacing all temporary # objects with real allocated chain IDs. - unallocated_chain_ids = set() # type: Set[object] - new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]] + unallocated_chain_ids: Set[object] = set() + new_chain_tuples: Dict[str, Tuple[Any, int]] = {} for event_id, auth_event_id in tree: # If we reference an auth_event_id we fetch the allocated chain ID, # either from the existing `chain_map` or the newly generated @@ -870,7 +868,7 @@ def _allocate_chain_ids( if not existing_chain_id: existing_chain_id = chain_map[auth_event_id] - new_chain_tuple = None # type: Optional[Tuple[Any, int]] + new_chain_tuple: Optional[Tuple[Any, int]] = None if existing_chain_id: # We found a chain ID/sequence number candidate, check its # not already taken. @@ -897,9 +895,9 @@ def _allocate_chain_ids( ) # Map from potentially temporary chain ID to real chain ID - chain_id_to_allocated_map = dict( + chain_id_to_allocated_map: Dict[Any, int] = dict( zip(unallocated_chain_ids, newly_allocated_chain_ids) - ) # type: Dict[Any, int] + ) chain_id_to_allocated_map.update((c, c) for c in existing_chains) return { @@ -1175,9 +1173,9 @@ def _filter_events_and_contexts_for_duplicates( Returns: list[(EventBase, EventContext)]: filtered list """ - new_events_and_contexts = ( - OrderedDict() - ) # type: OrderedDict[str, Tuple[EventBase, EventContext]] + new_events_and_contexts: OrderedDict[ + str, Tuple[EventBase, EventContext] + ] = OrderedDict() for event, context in events_and_contexts: prev_event_context = new_events_and_contexts.get(event.event_id) if prev_event_context: @@ -1205,7 +1203,7 @@ def _update_room_depths_txn( we are persisting backfilled (bool): True if the events were backfilled """ - depth_updates = {} # type: Dict[str, int] + depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids txn.call_after(self.store._invalidate_get_event_cache, event.event_id) @@ -1580,11 +1578,11 @@ def _store_redaction(self, txn, event): # invalidate the cache for the redacted event txn.call_after(self.store._invalidate_get_event_cache, event.redacts) - self.db_pool.simple_insert_txn( + self.db_pool.simple_upsert_txn( txn, table="redactions", + keyvalues={"event_id": event.event_id}, values={ - "event_id": event.event_id, "redacts": event.redacts, "received_ts": self._clock.time_msec(), }, @@ -1885,7 +1883,7 @@ def _set_push_actions_for_event_and_users_txn( ), ) - room_to_event_ids = {} # type: Dict[str, List[str]] + room_to_event_ids: Dict[str, List[str]] = {} for e, _ in events_and_contexts: room_to_event_ids.setdefault(e.room_id, []).append(e.event_id) @@ -2012,10 +2010,6 @@ def _update_backward_extremeties(self, txn, events): Forward extremities are handled when we first start persisting the events. """ - events_by_room = {} # type: Dict[str, List[EventBase]] - for ev in events: - events_by_room.setdefault(ev.room_id, []).append(ev) - query = ( "INSERT INTO event_backward_extremities (event_id, room_id)" " SELECT ?, ? WHERE NOT EXISTS (" diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 29f33bac55bc..6fcb2b8353aa 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -960,9 +960,9 @@ def _calculate_chain_cover_txn( event_to_types = {row[0]: (row[1], row[2]) for row in rows} # Calculate the new last position we've processed up to. - new_last_depth = rows[-1][3] if rows else last_depth # type: int - new_last_stream = rows[-1][4] if rows else last_stream # type: int - new_last_room_id = rows[-1][5] if rows else "" # type: str + new_last_depth: int = rows[-1][3] if rows else last_depth + new_last_stream: int = rows[-1][4] if rows else last_stream + new_last_room_id: str = rows[-1][5] if rows else "" # Map from room_id to last depth/stream_ordering processed for the room, # excluding the last room (which we're likely still processing). We also @@ -989,7 +989,7 @@ def _calculate_chain_cover_txn( retcols=("event_id", "auth_id"), ) - event_to_auth_chain = {} # type: Dict[str, List[str]] + event_to_auth_chain: Dict[str, List[str]] = {} for row in auth_events: event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"]) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 403a5ddaba3e..3c86adab5650 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1365,10 +1365,10 @@ def get_deltas_for_stream_id_txn(txn, stream_id): # we need to make sure that, for every stream id in the results, we get *all* # the rows with that stream id. - rows = await self.db_pool.runInteraction( + rows: List[Tuple] = await self.db_pool.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, - ) # type: List[Tuple] + ) # if we've got fewer rows than the limit, we're good if len(rows) < target_row_count: @@ -1469,7 +1469,7 @@ async def get_already_persisted_events( """ mapping = {} - txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str] + txn_id_to_event: Dict[Tuple[str, int, str], str] = {} for event in events: token_id = getattr(event.internal_metadata, "token_id", None) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 66ad363bfb29..e70d3649ff5a 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -27,8 +27,11 @@ _DEFAULT_CATEGORY_ID = "" _DEFAULT_ROLE_ID = "" + # A room in a group. -_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool}) +class _RoomInGroup(TypedDict): + room_id: str + is_public: bool class GroupServerWorkerStore(SQLBaseStore): @@ -92,6 +95,7 @@ async def get_rooms_in_group( "is_public": False # Whether this is a public room or not } """ + # TODO: Pagination def _get_rooms_in_group_txn(txn): diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index 774861074c8a..3d1dff660bd9 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -78,7 +78,11 @@ async def _on_shutdown(self) -> None: """Called when the server is shutting down""" logger.info("Dropping held locks due to shutdown") - for (lock_name, lock_key), token in self._live_tokens.items(): + # We need to take a copy of the tokens dict as dropping the locks will + # cause the dictionary to change. + tokens = dict(self._live_tokens) + + for (lock_name, lock_key), token in tokens.items(): await self._drop_lock(lock_name, lock_key, token) logger.info("Dropped locks due to shutdown") diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index c3f551d37756..dc0bbc56ac8b 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -316,11 +316,140 @@ def _count_r30_users(txn): return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) + async def count_r30v2_users(self) -> Dict[str, int]: + """ + Counts the number of 30 day retained users, defined as users that: + - Appear more than once in the past 60 days + - Have more than 30 days between the most and least recent appearances that + occurred in the past 60 days. + + (This is the second version of this metric, hence R30'v2') + + Returns: + A mapping from client type to the number of 30-day retained users for that client. + + The dict keys are: + - "all" (a combined number of users across any and all clients) + - "android" (Element Android) + - "ios" (Element iOS) + - "electron" (Element Desktop) + - "web" (any web application -- it's not possible to distinguish Element Web here) + """ + + def _count_r30v2_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs + one_day_from_now_in_secs = now + 86400 + + # This is the 'per-platform' count. + sql = """ + SELECT + client_type, + count(client_type) + FROM + ( + SELECT + user_id, + CASE + WHEN + LOWER(user_agent) LIKE '%%riot%%' OR + LOWER(user_agent) LIKE '%%element%%' + THEN CASE + WHEN + LOWER(user_agent) LIKE '%%electron%%' + THEN 'electron' + WHEN + LOWER(user_agent) LIKE '%%android%%' + THEN 'android' + WHEN + LOWER(user_agent) LIKE '%%ios%%' + THEN 'ios' + ELSE 'unknown' + END + WHEN + LOWER(user_agent) LIKE '%%mozilla%%' OR + LOWER(user_agent) LIKE '%%gecko%%' + THEN 'web' + ELSE 'unknown' + END as client_type + FROM + user_daily_visits + WHERE + timestamp > ? + AND + timestamp < ? + GROUP BY + user_id, + client_type + HAVING + max(timestamp) - min(timestamp) > ? + ) AS temp + GROUP BY + client_type + ; + """ + + # We initialise all the client types to zero, so we get an explicit + # zero if they don't appear in the query results + results = {"ios": 0, "android": 0, "web": 0, "electron": 0} + txn.execute( + sql, + ( + sixty_days_ago_in_secs * 1000, + one_day_from_now_in_secs * 1000, + thirty_days_in_secs * 1000, + ), + ) + + for row in txn: + if row[0] == "unknown": + continue + results[row[0]] = row[1] + + # This is the 'all users' count. + sql = """ + SELECT COUNT(*) FROM ( + SELECT + 1 + FROM + user_daily_visits + WHERE + timestamp > ? + AND + timestamp < ? + GROUP BY + user_id + HAVING + max(timestamp) - min(timestamp) > ? + ) AS r30_users + """ + + txn.execute( + sql, + ( + sixty_days_ago_in_secs * 1000, + one_day_from_now_in_secs * 1000, + thirty_days_in_secs * 1000, + ), + ) + row = txn.fetchone() + if row is None: + results["all"] = 0 + else: + results["all"] = row[0] + + return results + + return await self.db_pool.runInteraction( + "count_r30v2_users", _count_r30v2_users + ) + def _get_start_of_day(self): """ Returns millisecond unixtime for start of UTC day. """ - now = time.gmtime() + now = time.gmtime(self._clock.time()) today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) return today_start * 1000 @@ -352,7 +481,7 @@ def _generate_user_daily_visits(txn): ) udv ON u.user_id = udv.user_id AND u.device_id=udv.device_id INNER JOIN users ON users.name=u.user_id - WHERE last_seen > ? AND last_seen <= ? + WHERE ? <= last_seen AND last_seen < ? AND udv.timestamp IS NULL AND users.is_guest=0 AND users.appservice_id IS NULL GROUP BY u.user_id, u.device_id diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 7fb7780d0ffd..664c65dac5a6 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -115,7 +115,7 @@ def _purge_history_txn( logger.info("[purge] looking for events to delete") should_delete_expr = "state_key IS NULL" - should_delete_params = () # type: Tuple[Any, ...] + should_delete_params: Tuple[Any, ...] = () if not delete_local_events: should_delete_expr += " AND event_id NOT LIKE ?" @@ -215,6 +215,7 @@ def _purge_history_txn( "event_relations", "event_search", "rejections", + "redactions", ): logger.info("[purge] removing events from %s", table) @@ -392,7 +393,6 @@ def _purge_room_txn(self, txn, room_id: str) -> List[int]: "room_memberships", "room_stats_state", "room_stats_current", - "room_stats_historical", "room_stats_earliest_token", "rooms", "stream_ordering_to_exterm", diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index db5217633796..a7fb8cd84889 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -79,9 +79,9 @@ def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) # type: Union[StreamIdGenerator, SlavedIdTracker] + self._push_rules_stream_id_gen: Union[ + StreamIdGenerator, SlavedIdTracker + ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id") else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e31c5864acef..6ad1a0cf7fbb 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1744,7 +1744,7 @@ def f(txn): items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) - values = [v for _, v in items] # type: List[Union[str, int]] + values: List[Union[str, int]] = [v for _, v in items] # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where # clause and values before we handle that. This seems to be only used in the "set password" handler. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 9f0d64a32542..6ddafe543412 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -25,6 +25,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore +from synapse.storage.types import Cursor from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -1022,10 +1023,22 @@ def get_rooms_for_retention_period_in_range_txn(txn): ) -class RoomBackgroundUpdateStore(SQLBaseStore): +class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" + POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2" + REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth" + + +_REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( + "DROP TRIGGER populate_min_depth2_trigger ON room_depth", + "DROP FUNCTION populate_min_depth2()", + "ALTER TABLE room_depth DROP COLUMN min_depth", + "ALTER TABLE room_depth RENAME COLUMN min_depth2 TO min_depth", +) + +class RoomBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -1037,15 +1050,25 @@ def __init__(self, database: DatabasePool, db_conn, hs): ) self.db_pool.updates.register_background_update_handler( - self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, + _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, self._remove_tombstoned_rooms_from_directory, ) self.db_pool.updates.register_background_update_handler( - self.ADD_ROOMS_ROOM_VERSION_COLUMN, + _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN, self._background_add_rooms_room_version_column, ) + # BG updates to change the type of room_depth.min_depth + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2, + self._background_populate_room_depth_min_depth2, + ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.REPLACE_ROOM_DEPTH_MIN_DEPTH, + self._background_replace_room_depth_min_depth, + ) + async def _background_insert_retention(self, progress, batch_size): """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. @@ -1164,7 +1187,9 @@ def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): new_last_room_id = room_id self.db_pool.updates._background_update_progress_txn( - txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id} + txn, + _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN, + {"room_id": new_last_room_id}, ) return False @@ -1176,7 +1201,7 @@ def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): if end: await self.db_pool.updates._end_background_update( - self.ADD_ROOMS_ROOM_VERSION_COLUMN + _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN ) return batch_size @@ -1215,7 +1240,7 @@ def _get_rooms(txn): if not rooms: await self.db_pool.updates._end_background_update( - self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE + _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE ) return 0 @@ -1224,7 +1249,7 @@ def _get_rooms(txn): await self.set_room_is_public(room_id, False) await self.db_pool.updates._background_update_progress( - self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]} + _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]} ) return len(rooms) @@ -1268,6 +1293,71 @@ async def has_auth_chain_index(self, room_id: str) -> bool: return max_ordering is None + async def _background_populate_room_depth_min_depth2( + self, progress: JsonDict, batch_size: int + ) -> int: + """Populate room_depth.min_depth2 + + This is to deal with the fact that min_depth was initially created as a + 32-bit integer field. + """ + + def process(txn: Cursor) -> int: + last_room = progress.get("last_room", "") + txn.execute( + """ + UPDATE room_depth SET min_depth2=min_depth + WHERE room_id IN ( + SELECT room_id FROM room_depth WHERE room_id > ? + ORDER BY room_id LIMIT ? + ) + RETURNING room_id; + """, + (last_room, batch_size), + ) + row_count = txn.rowcount + if row_count == 0: + return 0 + last_room = max(row[0] for row in txn) + logger.info("populated room_depth up to %s", last_room) + + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2, + {"last_room": last_room}, + ) + return row_count + + result = await self.db_pool.runInteraction( + "_background_populate_min_depth2", process + ) + + if result != 0: + return result + + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2 + ) + return 0 + + async def _background_replace_room_depth_min_depth( + self, progress: JsonDict, batch_size: int + ) -> int: + """Drop the old 'min_depth' column and rename 'min_depth2' into its place.""" + + def process(txn: Cursor) -> None: + for sql in _REPLACE_ROOM_DEPTH_SQL_COMMANDS: + logger.info("completing room_depth migration: %s", sql) + txn.execute(sql) + + await self.db_pool.runInteraction("_background_replace_room_depth", process) + + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.REPLACE_ROOM_DEPTH_MIN_DEPTH, + ) + + return 0 + class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def __init__(self, database: DatabasePool, db_conn, hs): diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 2796354a1f58..68f1b40ea693 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -649,7 +649,7 @@ async def _get_joined_users_from_context( event_to_memberships = await self._get_joined_profiles_from_event_ids( missing_member_event_ids ) - users_in_room.update((row for row in event_to_memberships.values() if row)) + users_in_room.update(row for row in event_to_memberships.values() if row) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: @@ -703,13 +703,22 @@ async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): @cached(max_entries=10000) async def is_host_joined(self, room_id: str, host: str) -> bool: + return await self._check_host_room_membership(room_id, host, Membership.JOIN) + + @cached(max_entries=10000) + async def is_host_invited(self, room_id: str, host: str) -> bool: + return await self._check_host_room_membership(room_id, host, Membership.INVITE) + + async def _check_host_room_membership( + self, room_id: str, host: str, membership: str + ) -> bool: if "%" in host or "_" in host: raise Exception("Invalid host name") sql = """ SELECT state_key FROM current_state_events AS c INNER JOIN room_memberships AS m USING (event_id) - WHERE m.membership = 'join' + WHERE m.membership = ? AND type = 'm.room.member' AND c.room_id = ? AND state_key LIKE ? @@ -722,7 +731,7 @@ async def is_host_joined(self, room_id: str, host: str) -> bool: like_clause = "%:" + host rows = await self.db_pool.execute( - "is_host_joined", None, sql, room_id, like_clause + "is_host_joined", None, sql, membership, room_id, like_clause ) if not rows: diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 82a18335091e..59d67c255b6d 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -26,7 +26,6 @@ from synapse.api.errors import StoreError from synapse.storage.database import DatabasePool from synapse.storage.databases.main.state_deltas import StateDeltasStore -from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -49,14 +48,6 @@ "user": ("joined_rooms",), } -# these fields are per-timeslice and so should be reset to 0 upon a new slice -# You can draw these stats on a histogram. -# Example: number of events sent locally during a time slice -PER_SLICE_FIELDS = { - "room": ("total_events", "total_event_bytes"), - "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"), -} - TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} # these are the tables (& ID columns) which contain our actual subjects @@ -106,7 +97,6 @@ def __init__(self, database: DatabasePool, db_conn, hs): self.server_name = hs.hostname self.clock = self.hs.get_clock() self.stats_enabled = hs.config.stats_enabled - self.stats_bucket_size = hs.config.stats_bucket_size self.stats_delta_processing_lock = DeferredLock() @@ -122,22 +112,6 @@ def __init__(self, database: DatabasePool, db_conn, hs): self.db_pool.updates.register_noop_background_update("populate_stats_cleanup") self.db_pool.updates.register_noop_background_update("populate_stats_prepare") - def quantise_stats_time(self, ts): - """ - Quantises a timestamp to be a multiple of the bucket size. - - Args: - ts (int): the timestamp to quantise, in milliseconds since the Unix - Epoch - - Returns: - int: a timestamp which - - is divisible by the bucket size; - - is no later than `ts`; and - - is the largest such timestamp. - """ - return (ts // self.stats_bucket_size) * self.stats_bucket_size - async def _populate_stats_process_users(self, progress, batch_size): """ This is a background update which regenerates statistics for users. @@ -288,56 +262,6 @@ async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None: desc="update_room_state", ) - async def get_statistics_for_subject( - self, stats_type: str, stats_id: str, start: str, size: int = 100 - ) -> List[dict]: - """ - Get statistics for a given subject. - - Args: - stats_type: The type of subject - stats_id: The ID of the subject (e.g. room_id or user_id) - start: Pagination start. Number of entries, not timestamp. - size: How many entries to return. - - Returns: - A list of dicts, where the dict has the keys of - ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". - """ - return await self.db_pool.runInteraction( - "get_statistics_for_subject", - self._get_statistics_for_subject_txn, - stats_type, - stats_id, - start, - size, - ) - - def _get_statistics_for_subject_txn( - self, txn, stats_type, stats_id, start, size=100 - ): - """ - Transaction-bound version of L{get_statistics_for_subject}. - """ - - table, id_col = TYPE_TO_TABLE[stats_type] - selected_columns = list( - ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] - ) - - slice_list = self.db_pool.simple_select_list_paginate_txn( - txn, - table + "_historical", - "end_ts", - start, - size, - retcols=selected_columns + ["bucket_size", "end_ts"], - keyvalues={id_col: stats_id}, - order_direction="DESC", - ) - - return slice_list - @cached() async def get_earliest_token_for_stats( self, stats_type: str, id: str @@ -451,14 +375,10 @@ def _update_stats_delta_txn( table, id_col = TYPE_TO_TABLE[stats_type] - quantised_ts = self.quantise_stats_time(int(ts)) - end_ts = quantised_ts + self.stats_bucket_size - # Lets be paranoid and check that all the given field names are known abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type] - slice_field_names = PER_SLICE_FIELDS[stats_type] for field in chain(fields.keys(), absolute_field_overrides.keys()): - if field not in abs_field_names and field not in slice_field_names: + if field not in abs_field_names: # guard against potential SQL injection dodginess raise ValueError( "%s is not a recognised field" @@ -491,20 +411,6 @@ def _update_stats_delta_txn( additive_relatives=deltas_of_absolute_fields, ) - per_slice_additive_relatives = { - key: fields.get(key, 0) for key in slice_field_names - } - self._upsert_copy_from_table_with_additive_relatives_txn( - txn=txn, - into_table=table + "_historical", - keyvalues={id_col: stats_id}, - extra_dst_insvalues={"bucket_size": self.stats_bucket_size}, - extra_dst_keyvalues={"end_ts": end_ts}, - additive_relatives=per_slice_additive_relatives, - src_table=table + "_current", - copy_columns=abs_field_names, - ) - def _upsert_with_additive_relatives_txn( self, txn, table, keyvalues, absolutes, additive_relatives ): @@ -528,7 +434,7 @@ def _upsert_with_additive_relatives_txn( ] relative_updates = [ - "%(field)s = EXCLUDED.%(field)s + %(table)s.%(field)s" + "%(field)s = EXCLUDED.%(field)s + COALESCE(%(table)s.%(field)s, 0)" % {"table": table, "field": field} for field in additive_relatives.keys() ] @@ -568,205 +474,13 @@ def _upsert_with_additive_relatives_txn( self.db_pool.simple_insert_txn(txn, table, merged_dict) else: for (key, val) in additive_relatives.items(): - current_row[key] += val + if current_row[key] is None: + current_row[key] = val + else: + current_row[key] += val current_row.update(absolutes) self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row) - def _upsert_copy_from_table_with_additive_relatives_txn( - self, - txn, - into_table, - keyvalues, - extra_dst_keyvalues, - extra_dst_insvalues, - additive_relatives, - src_table, - copy_columns, - ): - """Updates the historic stats table with latest updates. - - This involves copying "absolute" fields from the `_current` table, and - adding relative fields to any existing values. - - Args: - txn: Transaction - into_table (str): The destination table to UPSERT the row into - keyvalues (dict[str, any]): Row-identifying key values - extra_dst_keyvalues (dict[str, any]): Additional keyvalues - for `into_table`. - extra_dst_insvalues (dict[str, any]): Additional values to insert - on new row creation for `into_table`. - additive_relatives (dict[str, any]): Fields that will be added onto - if existing row present. (Must be disjoint from copy_columns.) - src_table (str): The source table to copy from - copy_columns (iterable[str]): The list of columns to copy - """ - if self.database_engine.can_native_upsert: - ins_columns = chain( - keyvalues, - copy_columns, - additive_relatives, - extra_dst_keyvalues, - extra_dst_insvalues, - ) - sel_exprs = chain( - keyvalues, - copy_columns, - ( - "?" - for _ in chain( - additive_relatives, extra_dst_keyvalues, extra_dst_insvalues - ) - ), - ) - keyvalues_where = ("%s = ?" % f for f in keyvalues) - - sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns) - sets_ar = ( - "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f) - for f in additive_relatives - ) - - sql = """ - INSERT INTO %(into_table)s (%(ins_columns)s) - SELECT %(sel_exprs)s - FROM %(src_table)s - WHERE %(keyvalues_where)s - ON CONFLICT (%(keyvalues)s) - DO UPDATE SET %(sets)s - """ % { - "into_table": into_table, - "ins_columns": ", ".join(ins_columns), - "sel_exprs": ", ".join(sel_exprs), - "keyvalues_where": " AND ".join(keyvalues_where), - "src_table": src_table, - "keyvalues": ", ".join( - chain(keyvalues.keys(), extra_dst_keyvalues.keys()) - ), - "sets": ", ".join(chain(sets_cc, sets_ar)), - } - - qargs = list( - chain( - additive_relatives.values(), - extra_dst_keyvalues.values(), - extra_dst_insvalues.values(), - keyvalues.values(), - ) - ) - txn.execute(sql, qargs) - else: - self.database_engine.lock_table(txn, into_table) - src_row = self.db_pool.simple_select_one_txn( - txn, src_table, keyvalues, copy_columns - ) - all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self.db_pool.simple_select_one_txn( - txn, - into_table, - keyvalues=all_dest_keyvalues, - retcols=list(chain(additive_relatives.keys(), copy_columns)), - allow_none=True, - ) - - if dest_current_row is None: - merged_dict = { - **keyvalues, - **extra_dst_keyvalues, - **extra_dst_insvalues, - **src_row, - **additive_relatives, - } - self.db_pool.simple_insert_txn(txn, into_table, merged_dict) - else: - for (key, val) in additive_relatives.items(): - src_row[key] = dest_current_row[key] + val - self.db_pool.simple_update_txn( - txn, into_table, all_dest_keyvalues, src_row - ) - - async def get_changes_room_total_events_and_bytes( - self, min_pos: int, max_pos: int - ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: - """Fetches the counts of events in the given range of stream IDs. - - Args: - min_pos - max_pos - - Returns: - Mapping of room ID to field changes. - """ - - return await self.db_pool.runInteraction( - "stats_incremental_total_events_and_bytes", - self.get_changes_room_total_events_and_bytes_txn, - min_pos, - max_pos, - ) - - def get_changes_room_total_events_and_bytes_txn( - self, txn, low_pos: int, high_pos: int - ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: - """Gets the total_events and total_event_bytes counts for rooms and - senders, in a range of stream_orderings (including backfilled events). - - Args: - txn - low_pos: Low stream ordering - high_pos: High stream ordering - - Returns: - The room and user deltas for total_events/total_event_bytes in the - format of `stats_id` -> fields - """ - - if low_pos >= high_pos: - # nothing to do here. - return {}, {} - - if isinstance(self.database_engine, PostgresEngine): - new_bytes_expression = "OCTET_LENGTH(json)" - else: - new_bytes_expression = "LENGTH(CAST(json AS BLOB))" - - sql = """ - SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes - FROM events INNER JOIN event_json USING (event_id) - WHERE (? < stream_ordering AND stream_ordering <= ?) - OR (? <= stream_ordering AND stream_ordering <= ?) - GROUP BY events.room_id - """ % ( - new_bytes_expression, - ) - - txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) - - room_deltas = { - room_id: {"total_events": new_events, "total_event_bytes": new_bytes} - for room_id, new_events, new_bytes in txn - } - - sql = """ - SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes - FROM events INNER JOIN event_json USING (event_id) - WHERE (? < stream_ordering AND stream_ordering <= ?) - OR (? <= stream_ordering AND stream_ordering <= ?) - GROUP BY events.sender - """ % ( - new_bytes_expression, - ) - - txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) - - user_deltas = { - user_id: {"total_events": new_events, "total_event_bytes": new_bytes} - for user_id, new_events, new_bytes in txn - if self.hs.is_mine_id(user_id) - } - - return room_deltas, user_deltas - async def _calculate_and_set_initial_state_for_room( self, room_id: str ) -> Tuple[dict, dict, int]: @@ -893,6 +607,7 @@ def _fetch_current_state_stats(txn): "invited_members": membership_counts.get(Membership.INVITE, 0), "left_members": membership_counts.get(Membership.LEAVE, 0), "banned_members": membership_counts.get(Membership.BAN, 0), + "knocked_members": membership_counts.get(Membership.KNOCK, 0), "local_users_in_room": len(local_users_in_room), }, ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 7581c7d3ff92..959f13de479f 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1085,9 +1085,7 @@ def _paginate_room_events_txn( # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and # then filtering the results. if from_token.topological is not None: - from_bound = ( - from_token.as_historical_tuple() - ) # type: Tuple[Optional[int], int] + from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple() elif direction == "b": from_bound = ( None, @@ -1099,7 +1097,7 @@ def _paginate_room_events_txn( from_token.stream, ) - to_bound = None # type: Optional[Tuple[Optional[int], int]] + to_bound: Optional[Tuple[Optional[int], int]] = None if to_token: if to_token.topological is not None: to_bound = to_token.as_historical_tuple() diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 1d62c6140f00..f93ff0a54536 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -42,7 +42,7 @@ async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict] "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) - tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]] + tags_by_room: Dict[str, Dict[str, JsonDict]] = {} for row in rows: room_tags = tags_by_room.setdefault(row["room_id"], {}) room_tags[row["tag"]] = db_to_json(row["content"]) diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 22c05cdde7df..38bfdf5dad3a 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -224,12 +224,12 @@ def _set_ui_auth_session_data_txn( self, txn: LoggingTransaction, session_id: str, key: str, value: Any ): # Get the current value. - result = self.db_pool.simple_select_one_txn( + result: Dict[str, Any] = self.db_pool.simple_select_one_txn( # type: ignore txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("serverdict",), - ) # type: Dict[str, Any] # type: ignore + ) # Update it and add it back to the database. serverdict = db_to_json(result["serverdict"]) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 051095fea915..a39877f0d5b2 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -307,7 +307,7 @@ async def persist_events( matched the transcation ID; the existing event is returned in such a case. """ - partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]] + partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {} for event, ctx in events_and_contexts: partitioned.setdefault(event.room_id, []).append((event, ctx)) @@ -384,7 +384,7 @@ async def _persist_event_batch( A dictionary of event ID to event ID we didn't persist as we already had another event persisted with the same TXN ID. """ - replaced_events = {} # type: Dict[str, str] + replaced_events: Dict[str, str] = {} if not events_and_contexts: return replaced_events @@ -440,16 +440,14 @@ async def _persist_event_batch( # Set of remote users which were in rooms the server has left. We # should check if we still share any rooms and if not we mark their # device lists as stale. - potentially_left_users = set() # type: Set[str] + potentially_left_users: Set[str] = set() if not backfilled: with Measure(self._clock, "_calculate_state_and_extrem"): # Work out the new "current state" for each room. # We do this by working out what the new extremities are and then # calculating the state from that. - events_by_room = ( - {} - ) # type: Dict[str, List[Tuple[EventBase, EventContext]]] + events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {} for event, context in chunk: events_by_room.setdefault(event.room_id, []).append( (event, context) @@ -622,9 +620,9 @@ async def _calculate_new_extremities( ) # Remove any events which are prev_events of any existing events. - existing_prevs = await self.persist_events_store._get_events_which_are_prevs( - result - ) # type: Collection[str] + existing_prevs: Collection[ + str + ] = await self.persist_events_store._get_events_which_are_prevs(result) result.difference_update(existing_prevs) # Finally handle the case where the new events have soft-failed prev diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 683e5e3b90b4..61392b9639c3 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -256,7 +256,7 @@ def _setup_new_database( for database in databases ) - directory_entries = [] # type: List[_DirectoryListing] + directory_entries: List[_DirectoryListing] = [] for directory in directories: directory_entries.extend( _DirectoryListing(file_name, os.path.join(directory, file_name)) @@ -424,10 +424,10 @@ def _upgrade_existing_database( directories.append(os.path.join(schema_path, database, "delta", str(v))) # Used to check if we have any duplicate file names - file_name_counter = Counter() # type: CounterType[str] + file_name_counter: CounterType[str] = Counter() # Now find which directories have anything of interest. - directory_entries = [] # type: List[_DirectoryListing] + directory_entries: List[_DirectoryListing] = [] for directory in directories: logger.debug("Looking for schema deltas in %s", directory) try: @@ -639,7 +639,7 @@ def get_statements(f: Iterable[str]) -> Generator[str, None, None]: def executescript(txn: Cursor, schema_path: str) -> None: - with open(schema_path, "r") as f: + with open(schema_path) as f: execute_statements_from_stream(txn, f) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 0a53b73ccc4e..36340a652aac 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 60 +SCHEMA_VERSION = 61 """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -21,6 +21,10 @@ See `README.md `_ for more information on how this works. + +Changes in SCHEMA_VERSION = 61: + - The `user_stats_historical` and `room_stats_historical` tables are not written and + are not read (previously, they were written but not read). """ diff --git a/synapse/storage/schema/main/delta/61/01change_appservices_txns.sql.postgres b/synapse/storage/schema/main/delta/61/01change_appservices_txns.sql.postgres new file mode 100644 index 000000000000..c8aec78e6034 --- /dev/null +++ b/synapse/storage/schema/main/delta/61/01change_appservices_txns.sql.postgres @@ -0,0 +1,23 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- we use bigint elsewhere in the database for appservice txn ids (notably, +-- application_services_state.last_txn), and generally we use bigints everywhere else +-- we have monotonic counters, so let's bring this one in line. +-- +-- assuming there aren't thousands of rows for decommisioned/non-functional ASes, this +-- table should be pretty small, so safe to do a synchronous ALTER TABLE. + +ALTER TABLE application_services_txns ALTER COLUMN txn_id SET DATA TYPE BIGINT; diff --git a/synapse/storage/schema/main/delta/61/02drop_redundant_room_depth_index.sql b/synapse/storage/schema/main/delta/61/02drop_redundant_room_depth_index.sql new file mode 100644 index 000000000000..35ca7a40c026 --- /dev/null +++ b/synapse/storage/schema/main/delta/61/02drop_redundant_room_depth_index.sql @@ -0,0 +1,18 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this index is redundant; there is another UNIQUE index on this table. +DROP INDEX IF EXISTS room_depth_room; + diff --git a/synapse/storage/schema/main/delta/61/03recreate_min_depth.py b/synapse/storage/schema/main/delta/61/03recreate_min_depth.py new file mode 100644 index 000000000000..f8d7db9f2ef3 --- /dev/null +++ b/synapse/storage/schema/main/delta/61/03recreate_min_depth.py @@ -0,0 +1,70 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This migration handles the process of changing the type of `room_depth.min_depth` to +a BIGINT. +""" +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.types import Cursor + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + # this only applies to postgres - sqlite does not distinguish between big and + # little ints. + return + + # First add a new column to contain the bigger min_depth + cur.execute("ALTER TABLE room_depth ADD COLUMN min_depth2 BIGINT") + + # Create a trigger which will keep it populated. + cur.execute( + """ + CREATE OR REPLACE FUNCTION populate_min_depth2() RETURNS trigger AS $BODY$ + BEGIN + new.min_depth2 := new.min_depth; + RETURN NEW; + END; + $BODY$ LANGUAGE plpgsql + """ + ) + + cur.execute( + """ + CREATE TRIGGER populate_min_depth2_trigger BEFORE INSERT OR UPDATE ON room_depth + FOR EACH ROW + EXECUTE PROCEDURE populate_min_depth2() + """ + ) + + # Start a bg process to populate it for old rooms + cur.execute( + """ + INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6103, 'populate_room_depth_min_depth2', '{}') + """ + ) + + # and another to switch them over once it completes. + cur.execute( + """ + INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (6103, 'replace_room_depth_min_depth', '{}', 'populate_room_depth2') + """ + ) + + +def run_upgrade(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): + pass diff --git a/synapse/storage/schema/state/delta/61/02state_groups_state_n_distinct.sql.postgres b/synapse/storage/schema/state/delta/61/02state_groups_state_n_distinct.sql.postgres new file mode 100644 index 000000000000..35a153da7b92 --- /dev/null +++ b/synapse/storage/schema/state/delta/61/02state_groups_state_n_distinct.sql.postgres @@ -0,0 +1,34 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- By default the postgres statistics collector massively underestimates the +-- number of distinct state groups are in the `state_groups_state`, which can +-- cause postgres to use table scans for queries for multiple state groups. +-- +-- To work around this we can manually tell postgres the number of distinct state +-- groups there are by setting `n_distinct` (a negative value here is the number +-- of distinct values divided by the number of rows, so -0.02 means on average +-- there are 50 rows per distinct value). We don't need a particularly +-- accurate number here, as a) we just want it to always use index scans and b) +-- our estimate is going to be better than the one made by the statistics +-- collector. + +ALTER TABLE state_groups_state ALTER COLUMN state_group SET (n_distinct = -0.02); + +-- Ideally we'd do an `ANALYZE state_groups_state (state_group)` here so that +-- the above gets picked up immediately, but that can take a bit of time so we +-- rely on the autovacuum eventually getting run and doing that in the +-- background for us. diff --git a/synapse/storage/state.py b/synapse/storage/state.py index c9dce726cb0e..f8fbba9d381e 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -91,7 +91,7 @@ def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": Returns: The new state filter. """ - type_dict = {} # type: Dict[str, Optional[Set[str]]] + type_dict: Dict[str, Optional[Set[str]]] = {} for typ, s in types: if typ in type_dict: if type_dict[typ] is None: @@ -194,7 +194,7 @@ def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """ where_clause = "" - where_args = [] # type: List[str] + where_args: List[str] = [] if self.is_full(): return where_clause, where_args diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f1e62f9e855d..c768fdea56df 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -112,7 +112,7 @@ def __init__( # insertion ordering will ensure its in the correct ordering. # # The key and values are the same, but we never look at the values. - self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int] + self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def get_next(self): """ @@ -236,15 +236,15 @@ def __init__( # Note: If we are a negative stream then we still store all the IDs as # positive to make life easier for us, and simply negate the IDs when we # return them. - self._current_positions = {} # type: Dict[str, int] + self._current_positions: Dict[str, int] = {} # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). - self._unfinished_ids = set() # type: Set[int] + self._unfinished_ids: Set[int] = set() # Set of local IDs that we've processed that are larger than the current # position, due to there being smaller unpersisted IDs. - self._finished_ids = set() # type: Set[int] + self._finished_ids: Set[int] = set() # We track the max position where we know everything before has been # persisted. This is done by a) looking at the min across all instances @@ -265,7 +265,7 @@ def __init__( self._persisted_upto_position = ( min(self._current_positions.values()) if self._current_positions else 1 ) - self._known_persisted_positions = [] # type: List[int] + self._known_persisted_positions: List[int] = [] self._sequence_gen = PostgresSequenceGenerator(sequence_name) @@ -465,7 +465,7 @@ def _mark_id_as_finished(self, next_id: int): self._unfinished_ids.discard(next_id) self._finished_ids.add(next_id) - new_cur = None # type: Optional[int] + new_cur: Optional[int] = None if self._unfinished_ids: # If there are unfinished IDs then the new position will be the diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 30b6b8e0caea..bb33e04fb10a 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -208,10 +208,10 @@ def __init__(self, get_first_callback: GetFirstCallbackType): get_next_id_txn; should return the curreent maximum id """ # the callback. this is cleared after it is called, so that it can be GCed. - self._callback = get_first_callback # type: Optional[GetFirstCallbackType] + self._callback: Optional[GetFirstCallbackType] = get_first_callback # The current max value, or None if we haven't looked in the DB yet. - self._current_max_id = None # type: Optional[int] + self._current_max_id: Optional[int] = None self._lock = threading.Lock() def get_next_id_txn(self, txn: Cursor) -> int: @@ -274,7 +274,7 @@ def build_sequence_generator( `check_consistency` details. """ if isinstance(database_engine, PostgresEngine): - seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator + seq: SequenceGenerator = PostgresSequenceGenerator(sequence_name) else: seq = LocalSequenceGenerator(get_first_callback) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 20fceaa93528..99b0aac2fb12 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -32,9 +32,9 @@ class EventSources: } def __init__(self, hs): - self.sources = { + self.sources: Dict[str, Any] = { name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items() - } # type: Dict[str, Any] + } self.store = hs.get_datastore() def get_current_token(self) -> StreamToken: diff --git a/synapse/types.py b/synapse/types.py index 8d2fa00f7140..429bb013d2cf 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -182,14 +182,14 @@ def create_requester( ) -def get_domain_from_id(string): +def get_domain_from_id(string: str) -> str: idx = string.find(":") if idx == -1: raise SynapseError(400, "Invalid ID: %r" % (string,)) return string[idx + 1 :] -def get_localpart_from_id(string): +def get_localpart_from_id(string: str) -> str: idx = string.find(":") if idx == -1: raise SynapseError(400, "Invalid ID: %r" % (string,)) @@ -210,7 +210,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta): 'domain' : The domain part of the name """ - SIGIL = abc.abstractproperty() # type: str # type: ignore + SIGIL: str = abc.abstractproperty() # type: ignore localpart = attr.ib(type=str) domain = attr.ib(type=str) @@ -304,7 +304,7 @@ class GroupID(DomainSpecificString): @classmethod def from_string(cls: Type[DS], s: str) -> DS: - group_id = super().from_string(s) # type: DS # type: ignore + group_id: DS = super().from_string(s) # type: ignore if not group_id.localpart: raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) @@ -577,10 +577,10 @@ async def to_string(self, store: "DataStore") -> str: entries = [] for name, pos in self.instance_map.items(): instance_id = await store.get_id_for_instance(name) - entries.append("{}.{}".format(instance_id, pos)) + entries.append(f"{instance_id}.{pos}") encoded_map = "~".join(entries) - return "m{}~{}".format(self.stream, encoded_map) + return f"m{self.stream}~{encoded_map}" else: return "s%d" % (self.stream,) @@ -600,7 +600,7 @@ class StreamToken: groups_key = attr.ib(type=int) _SEPARATOR = "_" - START = None # type: StreamToken + START: "StreamToken" @classmethod async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 061102c3c894..014db1355b13 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -257,7 +257,7 @@ def __init__( max_count: The maximum number of concurrent accesses """ if name is None: - self.name = id(self) # type: Union[str, int] + self.name: Union[str, int] = id(self) else: self.name = name @@ -269,7 +269,7 @@ def __init__( self.max_count = max_count # key_to_defer is a map from the key to a _LinearizerEntry. - self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry] + self.key_to_defer: Dict[Hashable, _LinearizerEntry] = {} def is_queued(self, key: Hashable) -> bool: """Checks whether there is a process queued up waiting""" @@ -409,10 +409,10 @@ class ReadWriteLock: def __init__(self): # Latest readers queued - self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]] + self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {} # Latest writer queued - self.key_to_current_writer = {} # type: Dict[str, defer.Deferred] + self.key_to_current_writer: Dict[str, defer.Deferred] = {} async def read(self, key: str) -> ContextManager: new_defer = defer.Deferred() diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py index 8fd5bfb69b75..274cea7eb709 100644 --- a/synapse/util/batching_queue.py +++ b/synapse/util/batching_queue.py @@ -93,11 +93,11 @@ def __init__( self._clock = clock # The set of keys currently being processed. - self._processing_keys = set() # type: Set[Hashable] + self._processing_keys: Set[Hashable] = set() # The currently pending batch of values by key, with a Deferred to call # with the result of the corresponding `_process_batch_callback` call. - self._next_values = {} # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]] + self._next_values: Dict[Hashable, List[Tuple[V, defer.Deferred]]] = {} # The function to call with batches of values. self._process_batch_callback = process_batch_callback @@ -108,9 +108,7 @@ def __init__( number_of_keys.labels(self._name).set_function(lambda: len(self._next_values)) - self._number_in_flight_metric = number_in_flight.labels( - self._name - ) # type: Gauge + self._number_in_flight_metric: Gauge = number_in_flight.labels(self._name) async def add_to_queue(self, value: V, key: Hashable = ()) -> R: """Adds the value to the queue with the given key, returning the result diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index ca36f07c205b..9012034b7aa8 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -29,8 +29,8 @@ TRACK_MEMORY_USAGE = False -caches_by_name = {} # type: Dict[str, Sized] -collectors_by_name = {} # type: Dict[str, CacheMetric] +caches_by_name: Dict[str, Sized] = {} +collectors_by_name: Dict[str, "CacheMetric"] = {} cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py index a301c9e89bf9..891bee0b33ae 100644 --- a/synapse/util/caches/cached_call.py +++ b/synapse/util/caches/cached_call.py @@ -63,9 +63,9 @@ def __init__(self, f: Callable[[], Awaitable[TV]]): f: The underlying function. Only one call to this function will be alive at once (per instance of CachedCall) """ - self._callable = f # type: Optional[Callable[[], Awaitable[TV]]] - self._deferred = None # type: Optional[Deferred] - self._result = None # type: Union[None, Failure, TV] + self._callable: Optional[Callable[[], Awaitable[TV]]] = f + self._deferred: Optional[Deferred] = None + self._result: Union[None, Failure, TV] = None async def get(self) -> TV: """Kick off the call if necessary, and return the result""" diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 1044139119c5..8c6fafc6770a 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -80,25 +80,25 @@ def __init__( cache_type = TreeCache if tree else dict # _pending_deferred_cache maps from the key value to a `CacheEntry` object. - self._pending_deferred_cache = ( - cache_type() - ) # type: Union[TreeCache, MutableMapping[KT, CacheEntry]] + self._pending_deferred_cache: Union[ + TreeCache, "MutableMapping[KT, CacheEntry]" + ] = cache_type() def metrics_cb(): cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) # cache is used for completed results and maps to the result itself, rather than # a Deferred. - self.cache = LruCache( + self.cache: LruCache[KT, VT] = LruCache( max_size=max_entries, cache_name=name, cache_type=cache_type, size_callback=(lambda d: len(d) or 1) if iterable else None, metrics_collection_callback=metrics_cb, apply_cache_factor_from_config=apply_cache_factor_from_config, - ) # type: LruCache[KT, VT] + ) - self.thread = None # type: Optional[threading.Thread] + self.thread: Optional[threading.Thread] = None @property def max_entries(self): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index d77e8edeeadf..1e8e6b1d01c6 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -46,17 +46,17 @@ class _CachedFunction(Generic[F]): - invalidate = None # type: Any - invalidate_all = None # type: Any - prefill = None # type: Any - cache = None # type: Any - num_args = None # type: Any + invalidate: Any = None + invalidate_all: Any = None + prefill: Any = None + cache: Any = None + num_args: Any = None - __name__ = None # type: str + __name__: str # Note: This function signature is actually fiddled with by the synapse mypy # plugin to a) make it a bound method, and b) remove any `cache_context` arg. - __call__ = None # type: F + __call__: F class _CacheDescriptorBase: @@ -115,8 +115,8 @@ def __init__(self, orig: Callable[..., Any], num_args, cache_context=False): class _LruCachedFunction(Generic[F]): - cache = None # type: LruCache[CacheKey, Any] - __call__ = None # type: F + cache: LruCache[CacheKey, Any] + __call__: F def lru_cache( @@ -180,10 +180,10 @@ def __init__( self.max_entries = max_entries def __get__(self, obj, owner): - cache = LruCache( + cache: LruCache[CacheKey, Any] = LruCache( cache_name=self.orig.__name__, max_size=self.max_entries, - ) # type: LruCache[CacheKey, Any] + ) get_cache_key = self.cache_key_builder sentinel = LruCacheDescriptor._Sentinel.sentinel @@ -271,12 +271,12 @@ def __init__( self.iterable = iterable def __get__(self, obj, owner): - cache = DeferredCache( + cache: DeferredCache[CacheKey, Any] = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, tree=self.tree, iterable=self.iterable, - ) # type: DeferredCache[CacheKey, Any] + ) get_cache_key = self.cache_key_builder @@ -359,7 +359,7 @@ def __init__(self, orig, cached_method_name, list_name, num_args=None): def __get__(self, obj, objtype=None): cached_method = getattr(obj, self.cached_method_name) - cache = cached_method.cache # type: DeferredCache[CacheKey, Any] + cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args @functools.wraps(self.orig) @@ -472,15 +472,15 @@ class _CacheContext: Cache = Union[DeferredCache, LruCache] - _cache_context_objects = ( - WeakValueDictionary() - ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext] + _cache_context_objects: """WeakValueDictionary[ + Tuple["_CacheContext.Cache", CacheKey], "_CacheContext" + ]""" = WeakValueDictionary() def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None: self._cache = cache self._cache_key = cache_key - def invalidate(self): # type: () -> None + def invalidate(self) -> None: """Invalidates the cache entry referred to by the context.""" self._cache.invalidate(self._cache_key) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 56d94d96ce0b..3f852edd7fcf 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -62,13 +62,13 @@ class DictionaryCache(Generic[KT, DKT]): """ def __init__(self, name: str, max_entries: int = 1000): - self.cache = LruCache( + self.cache: LruCache[KT, DictionaryEntry] = LruCache( max_size=max_entries, cache_name=name, size_callback=len - ) # type: LruCache[KT, DictionaryEntry] + ) self.name = name self.sequence = 0 - self.thread = None # type: Optional[threading.Thread] + self.thread: Optional[threading.Thread] = None def check_thread(self) -> None: expected_thread = self.thread diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index ac47a31cd7ef..bde16b8577d3 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) -SENTINEL = object() # type: Any +SENTINEL: Any = object() T = TypeVar("T") @@ -71,7 +71,7 @@ def __init__( self._expiry_ms = expiry_ms self._reset_expiry_on_get = reset_expiry_on_get - self._cache = OrderedDict() # type: OrderedDict[KT, _CacheEntry] + self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict() self.iterable = iterable diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 4b9d0433fffc..5c65d187b6da 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -90,8 +90,7 @@ def enumerate_leaves(node, depth): yield node else: for n in node.values(): - for m in enumerate_leaves(n, depth - 1): - yield m + yield from enumerate_leaves(n, depth - 1) P = TypeVar("P") @@ -226,7 +225,7 @@ def __init__( # footprint down. Storing `None` is free as its a singleton, while empty # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive # thing and used sets). - self.callbacks = None # type: Optional[List[Callable[[], None]]] + self.callbacks: Optional[List[Callable[[], None]]] = None self.add_callbacks(callbacks) @@ -362,15 +361,15 @@ def __init__( # register_cache might call our "set_cache_factor" callback; there's nothing to # do yet when we get resized. - self._on_resize = None # type: Optional[Callable[[],None]] + self._on_resize: Optional[Callable[[], None]] = None if cache_name is not None: - metrics = register_cache( + metrics: Optional[CacheMetric] = register_cache( "lru_cache", cache_name, self, collect_callback=metrics_collection_callback, - ) # type: Optional[CacheMetric] + ) else: metrics = None diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 34c662c4dbd7..ed7204336f7f 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -66,7 +66,7 @@ def __init__(self, clock: Clock, name: str, timeout_ms: float = 0): # This is poorly-named: it includes both complete and incomplete results. # We keep complete results rather than switching to absolute values because # that makes it easier to cache Failure results. - self.pending_result_cache = {} # type: Dict[KV, ObservableDeferred] + self.pending_result_cache: Dict[KV, ObservableDeferred] = {} self.clock = clock self.timeout_sec = timeout_ms / 1000.0 diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index e81e468899ea..3a41a8baa603 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -45,10 +45,10 @@ def __init__( ): self._original_max_size = max_size self._max_size = math.floor(max_size) - self._entity_to_key = {} # type: Dict[EntityType, int] + self._entity_to_key: Dict[EntityType, int] = {} # map from stream id to the a set of entities which changed at that stream id. - self._cache = SortedDict() # type: SortedDict[int, Set[EntityType]] + self._cache: SortedDict[int, Set[EntityType]] = SortedDict() # the earliest stream_pos for which we can reliably answer # get_all_entities_changed. In other words, one less than the earliest @@ -155,7 +155,7 @@ def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType] if stream_pos < self._earliest_known_stream_pos: return None - changed_entities = [] # type: List[EntityType] + changed_entities: List[EntityType] = [] for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)): changed_entities.extend(self._cache[k]) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index a6df81ebffdc..4138931e7bc1 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -138,7 +138,6 @@ def iterate_tree_cache_entry(d): """ if isinstance(d, TreeCacheNode): for value_d in d.values(): - for value in iterate_tree_cache_entry(value_d): - yield value + yield from iterate_tree_cache_entry(value_d) else: yield d diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index c276107d5690..46afe3f934ab 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) -SENTINEL = object() # type: Any +SENTINEL: Any = object() T = TypeVar("T") KT = TypeVar("KT") @@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]): def __init__(self, cache_name: str, timer: Callable[[], float] = time.time): # map from key to _CacheEntry - self._data = {} # type: Dict[KT, _CacheEntry] + self._data: Dict[KT, _CacheEntry] = {} # the _CacheEntries, sorted by expiry time - self._expiry_list = SortedList() # type: SortedList[_CacheEntry] + self._expiry_list: SortedList[_CacheEntry] = SortedList() self._timer = timer diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index 31b24dd18882..d8532411c2c5 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -31,13 +31,13 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # If pidfile already exists, we should read pid from there; to overwrite it, if # locking will fail, because locking attempt somehow purges the file contents. if os.path.isfile(pid_file): - with open(pid_file, "r") as pid_fh: + with open(pid_file) as pid_fh: old_pid = pid_fh.read() # Create a lockfile so that only one instance of this daemon is running at any time. try: lock_fh = open(pid_file, "w") - except IOError: + except OSError: print("Unable to create the pidfile.") sys.exit(1) @@ -45,7 +45,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # Try to get an exclusive lock on the file. This will fail if another process # has the file locked. fcntl.flock(lock_fh, fcntl.LOCK_EX | fcntl.LOCK_NB) - except IOError: + except OSError: print("Unable to lock on the pidfile.") # We need to overwrite the pidfile if we got here. # @@ -113,7 +113,7 @@ def excepthook(type_, value, traceback): try: lock_fh.write("%s" % (os.getpid())) lock_fh.flush() - except IOError: + except OSError: logger.error("Unable to write pid to the pidfile.") print("Unable to write pid to the pidfile.") sys.exit(1) diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 886afa9d1931..8ac3eab2f54f 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -68,7 +68,7 @@ def sorted_topologically( # This is implemented by Kahn's algorithm. degree_map = {node: 0 for node in nodes} - reverse_graph = {} # type: Dict[T, Set[T]] + reverse_graph: Dict[T, Set[T]] = {} for node, edges in graph.items(): if node not in degree_map: diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index f6ebfd7e7d41..d1f76e3dc54f 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -39,7 +39,7 @@ def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: caveat in the macaroon, or if the caveat was not found in the macaroon. """ prefix = key + " = " - result = None # type: Optional[str] + result: Optional[str] = None for caveat in macaroon.caveats: if not caveat.caveat_id.startswith(prefix): continue diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 45353d41c55e..1b82dca81b08 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -124,7 +124,7 @@ def __init__(self, clock, name: str): assert isinstance(curr_context, LoggingContext) parent_context = curr_context self._logging_context = LoggingContext(str(curr_context), parent_context) - self.start = None # type: Optional[int] + self.start: Optional[int] = None def __enter__(self) -> "Measure": if self.start is not None: diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index eed0291cae88..99f01e325cf6 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -41,7 +41,7 @@ def new_inline_callbacks(f): @functools.wraps(f) def wrapped(*args, **kwargs): start_context = current_context() - changes = [] # type: List[str] + changes: List[str] = [] orig = orig_inline_callbacks(_check_yield_points(f, changes)) try: @@ -131,7 +131,7 @@ def check_yield_points_inner(*args, **kwargs): gen = f(*args, **kwargs) last_yield_line_no = gen.gi_frame.f_lineno - result = None # type: Any + result: Any = None while True: expected_context = current_context() diff --git a/synapse/visibility.py b/synapse/visibility.py index 490fb26e8114..17532059e9f8 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -90,13 +90,13 @@ async def filter_events_for_client( AccountDataTypes.IGNORED_USER_LIST, user_id ) - ignore_list = frozenset() # type: FrozenSet[str] + ignore_list: FrozenSet[str] = frozenset() if ignore_dict_content: ignored_users_dict = ignore_dict_content.get("ignored_users", {}) if isinstance(ignored_users_dict, dict): ignore_list = frozenset(ignored_users_dict.keys()) - erased_senders = await storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased(e.sender for e in events) if filter_send_to_client: room_ids = {e.room_id for e in events} @@ -353,7 +353,7 @@ def check_event_is_visible(event: EventBase, state: StateMap[EventBase]) -> bool ) if not check_history_visibility_only: - erased_senders = await storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased(e.sender for e in events) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. diff --git a/sytest-blacklist b/sytest-blacklist index 566ef96711ea..de9986357b9a 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -41,8 +41,3 @@ We can't peek into rooms with invited history_visibility We can't peek into rooms with joined history_visibility Local users can peek by room alias Peeked rooms only turn up in the sync for the device who peeked them - - -# Blacklisted due to changes made in #10272 -Outbound federation will ignore a missing event with bad JSON for room version 6 -Federation rejects inbound events where the prev_events cannot be found diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py new file mode 100644 index 000000000000..5527e278dbc3 --- /dev/null +++ b/tests/app/test_phone_stats_home.py @@ -0,0 +1,395 @@ +import synapse +from synapse.app.phone_stats_home import start_phone_stats_home +from synapse.rest.client.v1 import login, room + +from tests import unittest +from tests.unittest import HomeserverTestCase + +FIVE_MINUTES_IN_SECONDS = 300 +ONE_DAY_IN_SECONDS = 86400 + + +class PhoneHomeTestCase(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + # Override the retention time for the user_ips table because otherwise it + # gets pruned too aggressively for our R30 test. + @unittest.override_config({"user_ips_max_age": "365d"}) + def test_r30_minimum_usage(self): + """ + Tests the minimum amount of interaction necessary for the R30 metric + to consider a user 'retained'. + """ + + # Register a user, log it in, create a room and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + + # Check the R30 results do not count that user. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + # Advance 30 days (+ 1 second, because strict inequality causes issues if we are + # bang on 30 days later). + self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) + + # (Make sure the user isn't somehow counted by this point.) + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + # Send a message (this counts as activity) + self.helper.send(room_id, "message2", tok=access_token) + + # We have to wait some time for _update_client_ips_batch to get + # called and update the user_ips table. + self.reactor.advance(2 * 60 * 60) + + # *Now* the user is counted. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + # Advance 29 days. The user has now not posted for 29 days. + self.reactor.advance(29 * ONE_DAY_IN_SECONDS) + + # The user is still counted. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + # Advance another day. The user has now not posted for 30 days. + self.reactor.advance(ONE_DAY_IN_SECONDS) + + # The user is now no longer counted in R30. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + def test_r30_minimum_usage_using_default_config(self): + """ + Tests the minimum amount of interaction necessary for the R30 metric + to consider a user 'retained'. + + N.B. This test does not override the `user_ips_max_age` config setting, + which defaults to 28 days. + """ + + # Register a user, log it in, create a room and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + + # Check the R30 results do not count that user. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + # Advance 30 days (+ 1 second, because strict inequality causes issues if we are + # bang on 30 days later). + self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) + + # (Make sure the user isn't somehow counted by this point.) + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + # Send a message (this counts as activity) + self.helper.send(room_id, "message2", tok=access_token) + + # We have to wait some time for _update_client_ips_batch to get + # called and update the user_ips table. + self.reactor.advance(2 * 60 * 60) + + # *Now* the user is counted. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + # Advance 27 days. The user has now not posted for 27 days. + self.reactor.advance(27 * ONE_DAY_IN_SECONDS) + + # The user is still counted. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + # Advance another day. The user has now not posted for 28 days. + self.reactor.advance(ONE_DAY_IN_SECONDS) + + # The user is now no longer counted in R30. + # (This is because the user_ips table has been pruned, which by default + # only preserves the last 28 days of entries.) + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + def test_r30_user_must_be_retained_for_at_least_a_month(self): + """ + Tests that a newly-registered user must be retained for a whole month + before appearing in the R30 statistic, even if they post every day + during that time! + """ + # Register a user and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + + # Check the user does not contribute to R30 yet. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + for _ in range(30): + # This loop posts a message every day for 30 days + self.reactor.advance(ONE_DAY_IN_SECONDS) + self.helper.send(room_id, "I'm still here", tok=access_token) + + # Notice that the user *still* does not contribute to R30! + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + self.reactor.advance(ONE_DAY_IN_SECONDS) + self.helper.send(room_id, "Still here!", tok=access_token) + + # *Now* the user appears in R30. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + +class PhoneHomeR30V2TestCase(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def _advance_to(self, desired_time_secs: float): + now = self.hs.get_clock().time() + assert now < desired_time_secs + self.reactor.advance(desired_time_secs - now) + + def make_homeserver(self, reactor, clock): + hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock) + + # We don't want our tests to actually report statistics, so check + # that it's not enabled + assert not hs.config.report_stats + + # This starts the needed data collection that we rely on to calculate + # R30v2 metrics. + start_phone_stats_home(hs) + return hs + + def test_r30v2_minimum_usage(self): + """ + Tests the minimum amount of interaction necessary for the R30v2 metric + to consider a user 'retained'. + """ + + # Register a user, log it in, create a room and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + first_post_at = self.hs.get_clock().time() + + # Give time for user_daily_visits table to be updated. + # (user_daily_visits is updated every 5 minutes using a looping call.) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + store = self.hs.get_datastore() + + # Check the R30 results do not count that user. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Advance 31 days. + # (R30v2 includes users with **more** than 30 days between the two visits, + # and user_daily_visits records the timestamp as the start of the day.) + self.reactor.advance(31 * ONE_DAY_IN_SECONDS) + # Also advance 5 minutes to let another user_daily_visits update occur + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # (Make sure the user isn't somehow counted by this point.) + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Send a message (this counts as activity) + self.helper.send(room_id, "message2", tok=access_token) + + # We have to wait a few minutes for the user_daily_visits table to + # be updated by a background process. + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # *Now* the user is counted. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Advance to JUST under 60 days after the user's first post + self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS - 5) + + # Check the user is still counted. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Advance into the next day. The user's first activity is now more than 60 days old. + self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS + 5) + + # Check the user is now no longer counted in R30. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + def test_r30v2_user_must_be_retained_for_at_least_a_month(self): + """ + Tests that a newly-registered user must be retained for a whole month + before appearing in the R30v2 statistic, even if they post every day + during that time! + """ + + # set a custom user-agent to impersonate Element/Android. + headers = ( + ( + "User-Agent", + "Element/1.1 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)", + ), + ) + + # Register a user and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!", custom_headers=headers) + room_id = self.helper.create_room_as( + room_creator=user_id, tok=access_token, custom_headers=headers + ) + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # Give time for user_daily_visits table to be updated. + # (user_daily_visits is updated every 5 minutes using a looping call.) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + store = self.hs.get_datastore() + + # Check the user does not contribute to R30 yet. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + for _ in range(30): + # This loop posts a message every day for 30 days + self.reactor.advance(ONE_DAY_IN_SECONDS - FIVE_MINUTES_IN_SECONDS) + self.helper.send( + room_id, "I'm still here", tok=access_token, custom_headers=headers + ) + + # give time for user_daily_visits to update + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # Notice that the user *still* does not contribute to R30! + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # advance yet another day with more activity + self.reactor.advance(ONE_DAY_IN_SECONDS) + self.helper.send( + room_id, "Still here!", tok=access_token, custom_headers=headers + ) + + # give time for user_daily_visits to update + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # *Now* the user appears in R30. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "android": 1, "electron": 0, "ios": 0, "web": 0} + ) + + def test_r30v2_returning_dormant_users_not_counted(self): + """ + Tests that dormant users (users inactive for a long time) do not + contribute to R30v2 when they return for just a single day. + This is a key difference between R30 and R30v2. + """ + + # set a custom user-agent to impersonate Element/iOS. + headers = ( + ( + "User-Agent", + "Riot/1.4 (iPhone; iOS 13; Scale/4.00)", + ), + ) + + # Register a user and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!", custom_headers=headers) + room_id = self.helper.create_room_as( + room_creator=user_id, tok=access_token, custom_headers=headers + ) + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # the user goes inactive for 2 months + self.reactor.advance(60 * ONE_DAY_IN_SECONDS) + + # the user returns for one day, perhaps just to check out a new feature + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # Give time for user_daily_visits table to be updated. + # (user_daily_visits is updated every 5 minutes using a looping call.) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + store = self.hs.get_datastore() + + # Check that the user does not contribute to R30v2, even though it's been + # more than 30 days since registration. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Check that this is a situation where old R30 differs: + # old R30 DOES count this as 'retained'. + r30_results = self.get_success(store.count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "ios": 1}) + + # Now we want to check that the user will still be able to appear in + # R30v2 as long as the user performs some other activity between + # 30 and 60 days later. + self.reactor.advance(32 * ONE_DAY_IN_SECONDS) + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # (give time for tables to update) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # Check the user now satisfies the requirements to appear in R30v2. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0} + ) + + # Advance to 59.5 days after the user's first R30v2-eligible activity. + self.reactor.advance(27.5 * ONE_DAY_IN_SECONDS) + + # Check the user still appears in R30v2. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0} + ) + + # Advance to 60.5 days after the user's first R30v2-eligible activity. + self.reactor.advance(ONE_DAY_IN_SECONDS) + + # Check the user no longer appears in R30v2. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index ebe2c0516570..903c69127d13 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -43,7 +43,7 @@ def test_load_fails_if_server_name_missing(self): def test_generates_and_loads_macaroon_secret_key(self): self.generate_config() - with open(self.file, "r") as f: + with open(self.file) as f: raw = yaml.safe_load(f) self.assertIn("macaroon_secret_key", raw) @@ -120,7 +120,7 @@ def generate_config(self): def generate_config_and_remove_lines_containing(self, needle): self.generate_config() - with open(self.file, "r") as f: + with open(self.file) as f: contents = f.readlines() contents = [line for line in contents if needle not in line] with open(self.file, "w") as f: diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 875b0d0a114b..3f41e9995039 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -152,7 +152,7 @@ def test_receiving_all_presence(self): ) self.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] self.assertEqual(presence_update.user_id, self.other_user_one_id) self.assertEqual(presence_update.state, "online") self.assertEqual(presence_update.status_msg, "boop") @@ -274,7 +274,7 @@ def test_send_local_online_presence_to_with_module(self): presence_updates, _ = sync_presence(self, self.other_user_id) self.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] self.assertEqual(presence_update.user_id, self.other_user_id) self.assertEqual(presence_update.state, "online") self.assertEqual(presence_update.status_msg, "I'm online!") @@ -285,6 +285,10 @@ def test_send_local_online_presence_to_with_module(self): presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id) self.assertEqual(len(presence_updates), 3) + # We stagger sending of presence, so we need to wait a bit for them to + # get sent out. + self.reactor.advance(60) + # Test that sending to a remote user works remote_user_id = "@far_away_person:island" @@ -301,6 +305,10 @@ def test_send_local_online_presence_to_with_module(self): self.module_api.send_local_online_presence_to([remote_user_id]) ) + # We stagger sending of presence, so we need to wait a bit for them to + # get sent out. + self.reactor.advance(60) + # Check that the expected presence updates were sent # We explicitly compare using sets as we expect that calling # module_api.send_local_online_presence_to will create a presence @@ -320,7 +328,7 @@ def test_send_local_online_presence_to_with_module(self): ) for call in calls: call_args = call[0] - federation_transaction = call_args[0] # type: Transaction + federation_transaction: Transaction = call_args[0] # Get the sent EDUs in this transaction edus = federation_transaction.get_dict()["edus"] diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index e0a24824cc8f..39e7b1ab254e 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -47,12 +47,16 @@ def test_reupload_one_time_keys(self): "alg2:k3": {"key": "key3"}, } + # Note that "signed_curve25519" is always returned in key count responses. This is necessary until + # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. res = self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": keys} ) ) - self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) + self.assertDictEqual( + res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} + ) # we should be able to change the signature without a problem keys["alg2:k2"]["signatures"]["k1"] = "sig2" @@ -61,7 +65,9 @@ def test_reupload_one_time_keys(self): local_user, device_id, {"one_time_keys": keys} ) ) - self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) + self.assertDictEqual( + res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} + ) def test_change_one_time_keys(self): """attempts to change one-time-keys should be rejected""" @@ -79,7 +85,9 @@ def test_change_one_time_keys(self): local_user, device_id, {"one_time_keys": keys} ) ) - self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) + self.assertDictEqual( + res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} + ) # Error when changing string key self.get_failure( @@ -89,7 +97,7 @@ def test_change_one_time_keys(self): SynapseError, ) - # Error when replacing dict key with strin + # Error when replacing dict key with string self.get_failure( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} @@ -131,7 +139,9 @@ def test_claim_one_time_key(self): local_user, device_id, {"one_time_keys": keys} ) ) - self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) + self.assertDictEqual( + res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}} + ) res2 = self.get_success( self.handler.claim_one_time_keys( diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index cdb41101b3c4..2928c4f48cca 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -103,7 +103,7 @@ def test_set_my_name(self): ) self.assertIsNone( - (self.get_success(self.store.get_profile_displayname(self.frank.localpart))) + self.get_success(self.store.get_profile_displayname(self.frank.localpart)) ) def test_set_my_name_if_disabled(self): diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 9771d3fb3b48..3f73ad7f9478 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -14,8 +14,18 @@ from typing import Any, Iterable, Optional, Tuple from unittest import mock -from synapse.api.constants import EventContentFields, RoomTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + HistoryVisibility, + JoinRules, + Membership, + RestrictedJoinRuleTypes, + RoomTypes, +) from synapse.api.errors import AuthError +from synapse.api.room_versions import RoomVersions +from synapse.events import make_event_from_dict from synapse.handlers.space_summary import _child_events_comparison_key from synapse.rest import admin from synapse.rest.client.v1 import login, room @@ -117,7 +127,7 @@ def _add_child(self, space_id: str, room_id: str, token: str) -> None: """Add a child room to a space.""" self.helper.send_state( space_id, - event_type="m.space.child", + event_type=EventTypes.SpaceChild, body={"via": [self.hs.hostname]}, tok=token, state_key=room_id, @@ -155,26 +165,379 @@ def test_visibility(self): # The user cannot see the space. self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - # Joining the room causes it to be visible. - self.helper.join(self.space, user2, tok=token2) + # If the space is made world-readable it should return a result. + self.helper.send_state( + self.space, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - - # The result should only have the space, but includes the link to the room. - self._assert_rooms(result, [self.space]) + self._assert_rooms(result, [self.space, self.room]) self._assert_events(result, [(self.space, self.room)]) - def test_world_readable(self): - """A world-readable room is visible to everyone.""" + # Make it not world-readable again and confirm it results in an error. self.helper.send_state( self.space, - event_type="m.room.history_visibility", - body={"history_visibility": "world_readable"}, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, + tok=self.token, + ) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + + # Join the space and results should be returned. + self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, [self.space, self.room]) + self._assert_events(result, [(self.space, self.room)]) + + def _create_room_with_join_rule( + self, join_rule: str, room_version: Optional[str] = None, **extra_content + ) -> str: + """Create a room with the given join rule and add it to the space.""" + room_id = self.helper.create_room_as( + self.user, + room_version=room_version, tok=self.token, + extra_content={ + "initial_state": [ + { + "type": EventTypes.JoinRules, + "state_key": "", + "content": { + "join_rule": join_rule, + **extra_content, + }, + } + ] + }, ) + self._add_child(self.space, room_id, self.token) + return room_id + def test_filtering(self): + """ + Rooms should be properly filtered to only include rooms the user has access to. + """ user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") - # The space should be visible, as well as the link to the room. + # Create a few rooms which will have different properties. + public_room = self._create_room_with_join_rule(JoinRules.PUBLIC) + knock_room = self._create_room_with_join_rule( + JoinRules.KNOCK, room_version=RoomVersions.V7.identifier + ) + not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(invited_room, targ=user2, tok=self.token) + restricted_room = self._create_room_with_join_rule( + JoinRules.MSC3083_RESTRICTED, + room_version=RoomVersions.MSC3083.identifier, + allow=[], + ) + restricted_accessible_room = self._create_room_with_join_rule( + JoinRules.MSC3083_RESTRICTED, + room_version=RoomVersions.MSC3083.identifier, + allow=[ + { + "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP, + "room_id": self.space, + "via": [self.hs.hostname], + } + ], + ) + world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.send_state( + world_readable_room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + joined_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(joined_room, targ=user2, tok=self.token) + self.helper.join(joined_room, user2, tok=token2) + + # Join the space. + self.helper.join(self.space, user2, tok=token2) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, [self.space]) - self._assert_events(result, [(self.space, self.room)]) + + self._assert_rooms( + result, + [ + self.space, + self.room, + public_room, + knock_room, + invited_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, public_room), + (self.space, knock_room), + (self.space, not_invited_room), + (self.space, invited_room), + (self.space, restricted_room), + (self.space, restricted_accessible_room), + (self.space, world_readable_room), + (self.space, joined_room), + ], + ) + + def test_complex_space(self): + """ + Create a "complex" space to see how it handles things like loops and subspaces. + """ + # Create an inaccessible room. + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + room2 = self.helper.create_room_as(user2, is_public=False, tok=token2) + # This is a bit odd as "user" is adding a room they don't know about, but + # it works for the tests. + self._add_child(self.space, room2, self.token) + + # Create a subspace under the space with an additional room in it. + subspace = self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} + }, + ) + subroom = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, subspace, token=self.token) + self._add_child(subspace, subroom, token=self.token) + # Also add the two rooms from the space into this subspace (causing loops). + self._add_child(subspace, self.room, token=self.token) + self._add_child(subspace, room2, self.token) + + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) + + # The result should include each room a single time and each link. + self._assert_rooms(result, [self.space, self.room, subspace, subroom]) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, room2), + (self.space, subspace), + (subspace, subroom), + (subspace, self.room), + (subspace, room2), + ], + ) + + def test_fed_complex(self): + """ + Return data over federation and ensure that it is handled properly. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + subroom = "#subroom:" + fed_hostname + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + # Return some good data, and some bad data: + # + # * Event *back* to the root room. + # * Unrelated events / rooms + # * Multiple levels of events (in a not-useful order, e.g. grandchild + # events before child events). + + # Note that these entries are brief, but should contain enough info. + rooms = [ + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + { + "room_id": subroom, + "world_readable": True, + }, + ] + event_content = {"via": [fed_hostname]} + events = [ + { + "room_id": subspace, + "state_key": subroom, + "content": event_content, + }, + ] + return rooms, events + + # Add a room to the space which is on another server. + self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + self._assert_rooms(result, [self.space, self.room, subspace, subroom]) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, subspace), + (subspace, subroom), + ], + ) + + def test_fed_filtering(self): + """ + Rooms returned over federation should be properly filtered to only include + rooms the user has access to. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + + # Create a few rooms which will have different properties. + public_room = "#public:" + fed_hostname + knock_room = "#knock:" + fed_hostname + not_invited_room = "#not_invited:" + fed_hostname + invited_room = "#invited:" + fed_hostname + restricted_room = "#restricted:" + fed_hostname + restricted_accessible_room = "#restricted_accessible:" + fed_hostname + world_readable_room = "#world_readable:" + fed_hostname + joined_room = self.helper.create_room_as(self.user, tok=self.token) + + # Poke an invite over federation into the database. + fed_handler = self.hs.get_federation_handler() + event = make_event_from_dict( + { + "room_id": invited_room, + "event_id": "!abcd:" + fed_hostname, + "type": EventTypes.Member, + "sender": "@remote:" + fed_hostname, + "state_key": self.user, + "content": {"membership": Membership.INVITE}, + "prev_events": [], + "auth_events": [], + "depth": 1, + "origin_server_ts": 1234, + } + ) + self.get_success( + fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) + ) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + # Note that these entries are brief, but should contain enough info. + rooms = [ + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.MSC3083_RESTRICTED, + "allowed_spaces": [], + }, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.MSC3083_RESTRICTED, + "allowed_spaces": [self.room], + }, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ] + + # Place each room in the sub-space. + event_content = {"via": [fed_hostname]} + events = [ + { + "room_id": subspace, + "state_key": room["room_id"], + "content": event_content, + } + for room in rooms + ] + + # Also include the subspace. + rooms.insert( + 0, + { + "room_id": subspace, + "world_readable": True, + }, + ) + return rooms, events + + # Add a room to the space which is on another server. + self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + self._assert_rooms( + result, + [ + self.space, + self.room, + subspace, + public_room, + knock_room, + invited_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, subspace), + (subspace, public_room), + (subspace, knock_room), + (subspace, not_invited_room), + (subspace, invited_room), + (subspace, restricted_room), + (subspace, restricted_accessible_room), + (subspace, world_readable_room), + (subspace, joined_room), + ], + ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index c9d4fd93368f..e4059acda356 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -88,16 +88,12 @@ async def get_all_room_state(self): def _get_current_stats(self, stats_type, stat_id): table, id_col = stats.TYPE_TO_TABLE[stats_type] - cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list( - stats.PER_SLICE_FIELDS[stats_type] - ) - - end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) + cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) return self.get_success( self.store.db_pool.simple_select_one( - table + "_historical", - {id_col: stat_id, end_ts: end_ts}, + table + "_current", + {id_col: stat_id}, cols, allow_none=True, ) @@ -156,115 +152,6 @@ def test_initial_room(self): self.assertEqual(len(r), 1) self.assertEqual(r[0]["topic"], "foo") - def test_initial_earliest_token(self): - """ - Ingestion via notify_new_event will ignore tokens that the background - update have already processed. - """ - - self.reactor.advance(86401) - - self.hs.config.stats_enabled = False - self.handler.stats_enabled = False - - u1 = self.register_user("u1", "pass") - u1_token = self.login("u1", "pass") - - u2 = self.register_user("u2", "pass") - u2_token = self.login("u2", "pass") - - u3 = self.register_user("u3", "pass") - u3_token = self.login("u3", "pass") - - room_1 = self.helper.create_room_as(u1, tok=u1_token) - self.helper.send_state( - room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token - ) - - # Begin the ingestion by creating the temp tables. This will also store - # the position that the deltas should begin at, once they take over. - self.hs.config.stats_enabled = True - self.handler.stats_enabled = True - self.store.db_pool.updates._all_done = False - self.get_success( - self.store.db_pool.simple_update_one( - table="stats_incremental_position", - keyvalues={}, - updatevalues={"stream_id": 0}, - ) - ) - - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - {"update_name": "populate_stats_prepare", "progress_json": "{}"}, - ) - ) - - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) - - # Now, before the table is actually ingested, add some more events. - self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) - self.helper.join(room=room_1, user=u2, tok=u2_token) - - # orig_delta_processor = self.store. - - # Now do the initial ingestion. - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_stats_cleanup", - "progress_json": "{}", - "depends_on": "populate_stats_process_rooms", - }, - ) - ) - - self.store.db_pool.updates._all_done = False - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) - - self.reactor.advance(86401) - - # Now add some more events, triggering ingestion. Because of the stream - # position being set to before the events sent in the middle, a simpler - # implementation would reprocess those events, and say there were four - # users, not three. - self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token) - self.helper.join(room=room_1, user=u3, tok=u3_token) - - # self.handler.notify_new_event() - - # We need to let the delta processor advance
 - self.reactor.advance(10 * 60) - - # Get the slices! There should be two -- day 1, and day 2. - r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0)) - - self.assertEqual(len(r), 2) - - # The oldest has 2 joined members - self.assertEqual(r[-1]["joined_members"], 2) - - # The newest has 3 - self.assertEqual(r[0]["joined_members"], 3) - def test_create_user(self): """ When we create a user, it should have statistics already ready. @@ -296,22 +183,6 @@ def test_create_room(self): self.assertIsNotNone(r1stats) self.assertIsNotNone(r2stats) - # contains the default things you'd expect in a fresh room - self.assertEqual( - r1stats["total_events"], - EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM, - "Wrong number of total_events in new room's stats!" - " You may need to update this if more state events are added to" - " the room creation process.", - ) - self.assertEqual( - r2stats["total_events"], - EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM, - "Wrong number of total_events in new room's stats!" - " You may need to update this if more state events are added to" - " the room creation process.", - ) - self.assertEqual( r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM ) @@ -327,24 +198,6 @@ def test_create_room(self): self.assertEqual(r2stats["invited_members"], 0) self.assertEqual(r2stats["banned_members"], 0) - def test_send_message_increments_total_events(self): - """ - When we send a message, it increments total_events. - """ - - self._perform_background_initial_update() - - u1 = self.register_user("u1", "pass") - u1token = self.login("u1", "pass") - r1 = self.helper.create_room_as(u1, tok=u1token) - r1stats_ante = self._get_current_stats("room", r1) - - self.helper.send(r1, "hiss", tok=u1token) - - r1stats_post = self._get_current_stats("room", r1) - - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) - def test_updating_profile_information_does_not_increase_joined_members_count(self): """ Check that the joined_members count does not increase when a user changes their @@ -378,7 +231,7 @@ def test_updating_profile_information_does_not_increase_joined_members_count(sel def test_send_state_event_nonoverwriting(self): """ - When we send a non-overwriting state event, it increments total_events AND current_state_events + When we send a non-overwriting state event, it increments current_state_events """ self._perform_background_initial_update() @@ -399,44 +252,14 @@ def test_send_state_event_nonoverwriting(self): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 1, ) - def test_send_state_event_overwriting(self): - """ - When we send an overwriting state event, it increments total_events ONLY - """ - - self._perform_background_initial_update() - - u1 = self.register_user("u1", "pass") - u1token = self.login("u1", "pass") - r1 = self.helper.create_room_as(u1, tok=u1token) - - self.helper.send_state( - r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby" - ) - - r1stats_ante = self._get_current_stats("room", r1) - - self.helper.send_state( - r1, "cat.hissing", {"value": False}, tok=u1token, state_key="tabby" - ) - - r1stats_post = self._get_current_stats("room", r1) - - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) - self.assertEqual( - r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], - 0, - ) - def test_join_first_time(self): """ - When a user joins a room for the first time, total_events, current_state_events and + When a user joins a room for the first time, current_state_events and joined_members should increase by exactly 1. """ @@ -455,7 +278,6 @@ def test_join_first_time(self): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 1, @@ -466,7 +288,7 @@ def test_join_first_time(self): def test_join_after_leave(self): """ - When a user joins a room after being previously left, total_events and + When a user joins a room after being previously left, joined_members should increase by exactly 1. current_state_events should not increase. left_members should decrease by exactly 1. @@ -490,7 +312,6 @@ def test_join_after_leave(self): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 0, @@ -504,7 +325,7 @@ def test_join_after_leave(self): def test_invited(self): """ - When a user invites another user, current_state_events, total_events and + When a user invites another user, current_state_events and invited_members should increase by exactly 1. """ @@ -522,7 +343,6 @@ def test_invited(self): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 1, @@ -533,7 +353,7 @@ def test_invited(self): def test_join_after_invite(self): """ - When a user joins a room after being invited, total_events and + When a user joins a room after being invited and joined_members should increase by exactly 1. current_state_events should not increase. invited_members should decrease by exactly 1. @@ -556,7 +376,6 @@ def test_join_after_invite(self): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 0, @@ -570,7 +389,7 @@ def test_join_after_invite(self): def test_left(self): """ - When a user leaves a room after joining, total_events and + When a user leaves a room after joining and left_members should increase by exactly 1. current_state_events should not increase. joined_members should decrease by exactly 1. @@ -593,7 +412,6 @@ def test_left(self): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 0, @@ -607,7 +425,7 @@ def test_left(self): def test_banned(self): """ - When a user is banned from a room after joining, total_events and + When a user is banned from a room after joining and left_members should increase by exactly 1. current_state_events should not increase. banned_members should decrease by exactly 1. @@ -630,7 +448,6 @@ def test_banned(self): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 0, diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f58afbc244a8..fa3cff598ed6 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -38,6 +38,9 @@ # Test room id ROOM_ID = "a-room" +# Room we're not in +OTHER_ROOM_ID = "another-room" + def _expect_edu_transaction(edu_type, content, origin="test"): return { @@ -115,6 +118,11 @@ async def check_user_in_room(room_id, user_id): hs.get_auth().check_user_in_room = check_user_in_room + async def check_host_in_room(room_id, server_name): + return room_id == ROOM_ID + + hs.get_event_auth_handler().check_host_in_room = check_host_in_room + def get_joined_hosts_for_room(room_id): return {member.domain for member in self.room_members} @@ -244,6 +252,35 @@ def test_started_typing_remote_recv(self): ], ) + def test_started_typing_remote_recv_not_in_room(self): + self.room_members = [U_APPLE, U_ONION] + + self.assertEquals(self.event_source.get_current_key(), 0) + + channel = self.make_request( + "PUT", + "/_matrix/federation/v1/send/1000000", + _make_edu_transaction_json( + "m.typing", + content={ + "room_id": OTHER_ROOM_ID, + "user_id": U_ONION.to_string(), + "typing": True, + }, + ), + federation_auth_origin=b"farm", + ) + self.assertEqual(channel.code, 200) + + self.on_new_event.assert_not_called() + + self.assertEquals(self.event_source.get_current_key(), 0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[OTHER_ROOM_ID], from_key=0) + ) + self.assertEquals(events[0], []) + self.assertEquals(events[1], 0) + @override_config({"send_federation": True}) def test_stopped_typing(self): self.room_members = [U_APPLE, U_BANANA, U_ONION] diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index e45980316b6d..a37bce08c33a 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -273,7 +273,7 @@ def test_get(self): self.assertEqual(response.code, 200) # Send the body - request.write('{ "a": 1 }'.encode("ascii")) + request.write(b'{ "a": 1 }') request.finish() self.reactor.pump((0.1,)) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index ed9a884d761b..d9a8b077d390 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -102,7 +102,7 @@ def do_request(): self.assertNoResult(test_d) # Send it the HTTP response - res_json = '{ "a": 1 }'.encode("ascii") + res_json = b'{ "a": 1 }' protocol.dataReceived( b"HTTP/1.1 200 OK\r\n" b"Server: Fake\r\n" @@ -339,10 +339,8 @@ def test_timeout_reading_body(self, method_name: str): # Send it the HTTP response client.dataReceived( - ( - b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" - b"Server: Fake\r\n\r\n" - ) + b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + b"Server: Fake\r\n\r\n" ) # Push by enough to time it out diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index fefc8099c9d2..437113929a13 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -205,6 +205,41 @@ def test_https_request_via_no_proxy_star(self): @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"}) def test_http_request_via_proxy(self): + """ + Tests that requests can be made through a proxy. + """ + self._do_http_request_via_proxy(auth_credentials=None) + + @patch.dict( + os.environ, + {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"}, + ) + def test_http_request_via_proxy_with_auth(self): + """ + Tests that authenticated requests can be made through a proxy. + """ + self._do_http_request_via_proxy(auth_credentials="bob:pinkponies") + + @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) + def test_https_request_via_proxy(self): + """Tests that TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, + ) + def test_https_request_via_proxy_with_auth(self): + """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(auth_credentials="bob:pinkponies") + + def _do_http_request_via_proxy( + self, + auth_credentials: Optional[str] = None, + ): + """ + Tests that requests can be made through a proxy. + """ agent = ProxyAgent(self.reactor, use_proxy=True) self.reactor.lookups["proxy.com"] = "1.2.3.5" @@ -229,6 +264,23 @@ def test_http_request_via_proxy(self): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] + + # Check whether auth credentials have been supplied to the proxy + proxy_auth_header_values = request.requestHeaders.getRawHeaders( + b"Proxy-Authorization" + ) + + if auth_credentials is not None: + # Compute the correct header value for Proxy-Authorization + encoded_credentials = base64.b64encode(b"bob:pinkponies") + expected_header_value = b"Basic " + encoded_credentials + + # Validate the header's value + self.assertIn(expected_header_value, proxy_auth_header_values) + else: + # Check that the Proxy-Authorization header has not been supplied to the proxy + self.assertIsNone(proxy_auth_header_values) + self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"http://test.com") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) @@ -241,19 +293,6 @@ def test_http_request_via_proxy(self): body = self.successResultOf(treq.content(resp)) self.assertEqual(body, b"result") - @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) - def test_https_request_via_proxy(self): - """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(auth_credentials=None) - - @patch.dict( - os.environ, - {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, - ) - def test_https_request_via_proxy_with_auth(self): - """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(auth_credentials="bob:pinkponies") - def _do_https_request_via_proxy( self, auth_credentials: Optional[str] = None, diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 2c68b9a13c85..81d9e2f48474 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -100,9 +100,9 @@ def test_sending_events_into_room(self): "content": content, "sender": user_id, } - event = self.get_success( + event: EventBase = self.get_success( self.module_api.create_and_send_event_into_room(event_dict) - ) # type: EventBase + ) self.assertEqual(event.sender, user_id) self.assertEqual(event.type, "m.room.message") self.assertEqual(event.room_id, room_id) @@ -136,9 +136,9 @@ def test_sending_events_into_room(self): "sender": user_id, "state_key": "", } - event = self.get_success( + event: EventBase = self.get_success( self.module_api.create_and_send_event_into_room(event_dict) - ) # type: EventBase + ) self.assertEqual(event.sender, user_id) self.assertEqual(event.type, "m.room.power_levels") self.assertEqual(event.room_id, room_id) @@ -281,7 +281,7 @@ def test_send_local_online_presence_to_federation(self): ) for call in calls: call_args = call[0] - federation_transaction = call_args[0] # type: Transaction + federation_transaction: Transaction = call_args[0] # Get the sent EDUs in this transaction edus = federation_transaction.get_dict()["edus"] @@ -390,7 +390,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") @@ -443,7 +443,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") @@ -454,7 +454,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 624bd1b92722..e9fd991718d9 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -53,9 +53,9 @@ def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) self.streamer = hs.get_replication_streamer() - self.server = server_factory.buildProtocol( + self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol( None - ) # type: ServerReplicationStreamProtocol + ) # Make a new HomeServer object for the worker self.reactor.lookups["testserv"] = "1.2.3.4" @@ -195,7 +195,7 @@ def assert_request_is_get_repl_stream_updates( fetching updates for given stream. """ - path = request.path # type: bytes # type: ignore + path: bytes = request.path # type: ignore self.assertRegex( path, br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$" @@ -212,7 +212,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): unlike `BaseStreamTestCase`. """ - servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]] + servlets: List[Callable[[HomeServer, JsonResource], None]] = [] def setUp(self): super().setUp() @@ -448,7 +448,7 @@ def __init__(self, hs: HomeServer): super().__init__(hs) # list of received (stream_name, token, row) tuples - self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] + self.received_rdata_rows: List[Tuple[str, int, Any]] = [] async def on_rdata(self, stream_name, instance_name, token, rows): await super().on_rdata(stream_name, instance_name, token, rows) @@ -484,7 +484,7 @@ def buildProtocol(self, addr): class FakeRedisPubSubProtocol(Protocol): """A connection from a client talking to the fake Redis server.""" - transport = None # type: Optional[FakeTransport] + transport: Optional[FakeTransport] = None def __init__(self, server: FakeRedisPubSubServer): self._server = server @@ -550,12 +550,12 @@ def encode(self, obj): if obj is None: return "$-1\r\n" if isinstance(obj, str): - return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) + return f"${len(obj)}\r\n{obj}\r\n" if isinstance(obj, int): - return ":{val}\r\n".format(val=obj) + return f":{obj}\r\n" if isinstance(obj, (list, tuple)): items = "".join(self.encode(a) for a in obj) - return "*{len}\r\n{items}".format(len=len(obj), items=items) + return f"*{len(obj)}\r\n{items}" raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index f51fa0a79e9e..666008425a58 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -135,9 +135,9 @@ def test_update_function_huge_state_change(self): ) # this is the point in the DAG where we make a fork - fork_point = self.get_success( + fork_point: List[str] = self.get_success( self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) - ) # type: List[str] + ) events = [ self._inject_state_event(sender=OTHER_USER) @@ -238,7 +238,7 @@ def test_update_function_huge_state_change(self): self.assertEqual(row.data.event_id, pl_event.event_id) # the state rows are unsorted - state_rows = [] # type: List[EventsStreamCurrentStateRow] + state_rows: List[EventsStreamCurrentStateRow] = [] for stream_name, _, row in received_rows: self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) @@ -290,11 +290,11 @@ def test_update_function_state_row_limit(self): ) # this is the point in the DAG where we make a fork - fork_point = self.get_success( + fork_point: List[str] = self.get_success( self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) - ) # type: List[str] + ) - events = [] # type: List[EventBase] + events: List[EventBase] = [] for user in user_ids: events.extend( self._inject_state_event(sender=user) for _ in range(STATES_PER_USER) @@ -355,7 +355,7 @@ def test_update_function_state_row_limit(self): self.assertEqual(row.data.event_id, pl_events[i].event_id) # the state rows are unsorted - state_rows = [] # type: List[EventsStreamCurrentStateRow] + state_rows: List[EventsStreamCurrentStateRow] = [] for _ in range(STATES_PER_USER + 1): stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 7f5d932f0bb2..38e292c1ab67 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -43,7 +43,7 @@ def test_receipt(self): stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0] self.assertEqual("!room:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) @@ -75,7 +75,7 @@ def test_receipt(self): self.assertEqual(token, 3) self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0] self.assertEqual("!room2:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index ecd360c2d068..3ff5afc6e501 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -47,7 +47,7 @@ def test_typing(self): stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row: TypingStream.TypingStreamRow = rdata_rows[0] self.assertEqual(ROOM_ID, row.room_id) self.assertEqual([USER_ID], row.user_ids) @@ -102,7 +102,7 @@ def test_reset(self): stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row: TypingStream.TypingStreamRow = rdata_rows[0] self.assertEqual(ROOM_ID, row.room_id) self.assertEqual([USER_ID], row.user_ids) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 76e6644353d0..ffa425328f0d 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) -test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory] +test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): @@ -70,7 +70,7 @@ def _get_media_req( self.reactor, FakeSite(resource), "GET", - "/{}/{}".format(target, media_id), + f"/{target}/{media_id}", shorthand=False, access_token=self.access_token, await_result=False, @@ -113,7 +113,7 @@ def _get_media_req( self.assertEqual(request.method, b"GET") self.assertEqual( request.path, - "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"), + f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"), ) self.assertEqual( request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 5eca5c165d0a..f3615af97e86 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -211,7 +211,7 @@ def test_vector_clock_token(self): self.reactor, sync_hs_site, "GET", - "/sync?since={}".format(next_batch), + f"/sync?since={next_batch}", access_token=access_token, ) @@ -241,7 +241,7 @@ def test_vector_clock_token(self): self.reactor, sync_hs_site, "GET", - "/sync?since={}".format(vector_clock_token), + f"/sync?since={vector_clock_token}", access_token=access_token, ) @@ -266,7 +266,7 @@ def test_vector_clock_token(self): self.reactor, sync_hs_site, "GET", - "/sync?since={}".format(next_batch), + f"/sync?since={next_batch}", access_token=access_token, ) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 2f7090e5543b..a7c6e595b983 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -66,7 +66,7 @@ def test_delete_group(self): # Create a new group channel = self.make_request( "POST", - "/create_group".encode("ascii"), + b"/create_group", access_token=self.admin_user_tok, content={"localpart": "test"}, ) @@ -129,9 +129,7 @@ def _check_group(self, group_id, expect_code): def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token)""" - channel = self.make_request( - "GET", "/joined_groups".encode("ascii"), access_token=access_token - ) + channel = self.make_request("GET", b"/joined_groups", access_token=access_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index ee071c2477b4..17ec8bfd3b93 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -535,7 +535,7 @@ def _is_purged(self, room_id): ) ) - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") def _assert_peek(self, room_id, expect_code): """Assert that the admin user can (or cannot) peek into the room.""" @@ -599,7 +599,7 @@ def test_purge_room(self): ) ) - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") class RoomTestCase(unittest.HomeserverTestCase): @@ -1280,7 +1280,7 @@ def prepare(self, reactor, clock, homeserver): self.public_room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True ) - self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + self.url = f"/_synapse/admin/v1/join/{self.public_room_id}" def test_requester_is_no_admin(self): """ @@ -1420,7 +1420,7 @@ def test_join_private_room_if_not_member(self): private_room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=False ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1463,7 +1463,7 @@ def test_join_private_room_if_member(self): # Join user to room. - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1493,7 +1493,7 @@ def test_join_private_room_if_owner(self): private_room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok, is_public=False ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1633,7 +1633,7 @@ def test_public_room(self): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) @@ -1660,7 +1660,7 @@ def test_private_room(self): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) @@ -1686,7 +1686,7 @@ def test_other_user(self): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) @@ -1720,7 +1720,7 @@ def test_not_enough_power(self): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) @@ -1753,7 +1753,6 @@ def test_not_enough_power(self): "room_memberships", "room_stats_state", "room_stats_current", - "room_stats_historical", "room_stats_earliest_token", "rooms", "stream_ordering_to_exterm", diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index e1fe72fc5d4c..28dd47a28bf4 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -16,17 +16,19 @@ from unittest.mock import Mock from synapse.events import EventBase +from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.types import Requester, StateMap +from synapse.util.frozenutils import unfreeze from tests import unittest thread_local = threading.local() -class ThirdPartyRulesTestModule: +class LegacyThirdPartyRulesTestModule: def __init__(self, config: Dict, module_api: ModuleApi): # keep a record of the "current" rules module, so that the test can patch # it if desired. @@ -46,8 +48,26 @@ def parse_config(config): return config -def current_rules_module() -> ThirdPartyRulesTestModule: - return thread_local.rules_module +class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule): + def __init__(self, config: Dict, module_api: ModuleApi): + super().__init__(config, module_api) + + def on_create_room( + self, requester: Requester, config: dict, is_requester_admin: bool + ): + return False + + +class LegacyChangeEvents(LegacyThirdPartyRulesTestModule): + def __init__(self, config: Dict, module_api: ModuleApi): + super().__init__(config, module_api) + + async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + d = event.get_dict() + content = unfreeze(event.content) + content["foo"] = "bar" + d["content"] = content + return d class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): @@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def default_config(self): - config = super().default_config() - config["third_party_event_rules"] = { - "module": __name__ + ".ThirdPartyRulesTestModule", - "config": {}, - } - return config + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + + load_legacy_third_party_event_rules(hs) + + return hs def prepare(self, reactor, clock, homeserver): # Create a user and room to play with during the tests self.user_id = self.register_user("kermit", "monkey") self.tok = self.login("kermit", "monkey") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + # Some tests might prevent room creation on purpose. + try: + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + except Exception: + pass def test_third_party_rules(self): """Tests that a forbidden event is forbidden from being sent, but an allowed one @@ -79,10 +102,12 @@ def test_third_party_rules(self): # patch the rules module with a Mock which will return False for some event # types async def check(ev, state): - return ev.type != "foo.bar.forbidden" + return ev.type != "foo.bar.forbidden", None callback = Mock(spec=[], side_effect=check) - current_rules_module().check_event_allowed = callback + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [ + callback + ] channel = self.make_request( "PUT", @@ -116,9 +141,9 @@ def test_cannot_modify_event(self): # first patch the event checker so that it will try to modify the event async def check(ev: EventBase, state): ev.content = {"x": "y"} - return True + return True, None - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # now send the event channel = self.make_request( @@ -127,7 +152,19 @@ async def check(ev: EventBase, state): {"x": "x"}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"500", channel.result) + # check_event_allowed has some error handling, so it shouldn't 500 just because a + # module did something bad. + self.assertEqual(channel.code, 200, channel.result) + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["x"], "x") def test_modify_event(self): """The module can return a modified version of the event""" @@ -135,9 +172,9 @@ def test_modify_event(self): async def check(ev: EventBase, state): d = ev.get_dict() d["content"] = {"x": "y"} - return d + return True, d - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # now send the event channel = self.make_request( @@ -168,9 +205,9 @@ async def check(ev: EventBase, state): "msgtype": "m.text", "body": d["content"]["body"].upper(), } - return d + return True, d - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # Send an event, then edit it. channel = self.make_request( @@ -222,7 +259,7 @@ async def check(ev: EventBase, state): self.assertEqual(ev["content"]["body"], "EDITED BODY") def test_send_event(self): - """Tests that the module can send an event into a room via the module api""" + """Tests that a module can send an event into a room via the module api""" content = { "msgtype": "m.text", "body": "Hello!", @@ -233,13 +270,60 @@ def test_send_event(self): "content": content, "sender": self.user_id, } - event = self.get_success( - current_rules_module().module_api.create_and_send_event_into_room( - event_dict - ) - ) # type: EventBase + event: EventBase = self.get_success( + self.hs.get_module_api().create_and_send_event_into_room(event_dict) + ) self.assertEquals(event.sender, self.user_id) self.assertEquals(event.room_id, self.room_id) self.assertEquals(event.type, "m.room.message") self.assertEquals(event.content, content) + + @unittest.override_config( + { + "third_party_event_rules": { + "module": __name__ + ".LegacyChangeEvents", + "config": {}, + } + } + ) + def test_legacy_check_event_allowed(self): + """Tests that the wrapper for legacy check_event_allowed callbacks works + correctly. + """ + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id, + { + "msgtype": "m.text", + "body": "Original body", + }, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + + self.assertIn("foo", channel.json_body["content"].keys()) + self.assertEqual(channel.json_body["content"]["foo"], "bar") + + @unittest.override_config( + { + "third_party_event_rules": { + "module": __name__ + ".LegacyDenyNewRooms", + "config": {}, + } + } + ) + def test_legacy_on_create_room(self): + """Tests that the wrapper for legacy on_create_room callbacks works + correctly. + """ + self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 605b9523162e..7eba69642a6b 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -453,7 +453,7 @@ def test_get_msc2858_login_flows(self): self.assertEqual(channel.code, 200, channel.result) # stick the flows results in a dict by type - flow_results = {} # type: Dict[str, Any] + flow_results: Dict[str, Any] = {} for f in channel.json_body["flows"]: flow_type = f["type"] self.assertNotIn( @@ -501,7 +501,7 @@ def test_multi_sso_redirect(self): p.close() # there should be a link for each href - returned_idps = [] # type: List[str] + returned_idps: List[str] = [] for link in p.links: path, query = link.split("?", 1) self.assertEqual(path, "pick_idp") @@ -582,7 +582,7 @@ def test_login_via_oidc(self): # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") assert cookie_headers - cookies = {} # type: Dict[str, str] + cookies: Dict[str, str] = {} for h in cookie_headers: key, value = h.split(";")[0].split("=", maxsplit=1) cookies[key] = value @@ -874,9 +874,7 @@ def make_homeserver(self, reactor, clock): def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result = jwt.encode( - payload, secret, self.jwt_algorithm - ) # type: Union[str, bytes] + result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) if isinstance(result, bytes): return result.decode("ascii") return result @@ -1084,7 +1082,7 @@ def make_homeserver(self, reactor, clock): def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str] + result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") if isinstance(result, bytes): return result.decode("ascii") return result @@ -1272,7 +1270,7 @@ def test_username_picker(self): self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") # ... with a username_mapping_session cookie - cookies = {} # type: Dict[str,str] + cookies: Dict[str, str] = {} channel.extract_cookies(cookies) self.assertIn("username_mapping_session", cookies) session_id = cookies["username_mapping_session"] diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index e94566ffd70b..3df070c93653 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1206,7 +1206,7 @@ def test_join_reason(self): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/join".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/join", content={"reason": reason}, access_token=self.second_tok, ) @@ -1220,7 +1220,7 @@ def test_leave_reason(self): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/leave", content={"reason": reason}, access_token=self.second_tok, ) @@ -1234,7 +1234,7 @@ def test_kick_reason(self): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/kick", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) @@ -1248,7 +1248,7 @@ def test_ban_reason(self): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/ban", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1260,7 +1260,7 @@ def test_unban_reason(self): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/unban", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1272,7 +1272,7 @@ def test_invite_reason(self): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/invite", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1291,7 +1291,7 @@ def test_reject_invite_reason(self): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/leave", content={"reason": reason}, access_token=self.second_tok, ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 69798e95c3f7..fc2d35596ea3 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -19,7 +19,7 @@ import re import time import urllib.parse -from typing import Any, Dict, Mapping, MutableMapping, Optional +from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union from unittest.mock import patch import attr @@ -53,6 +53,9 @@ def create_room_as( tok: str = None, expect_code: int = 200, extra_content: Optional[Dict] = None, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ) -> str: """ Create a room. @@ -87,6 +90,7 @@ def create_room_as( "POST", path, json.dumps(content).encode("utf8"), + custom_headers=custom_headers, ) assert channel.result["code"] == b"%d" % expect_code, channel.result @@ -175,14 +179,30 @@ def change_membership( self.auth_user_id = temp_id - def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): + def send( + self, + room_id, + body=None, + txn_id=None, + tok=None, + expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): if body is None: body = "body_text_here" content = {"msgtype": "m.text", "body": body} return self.send_event( - room_id, "m.room.message", content, txn_id, tok, expect_code + room_id, + "m.room.message", + content, + txn_id, + tok, + expect_code, + custom_headers=custom_headers, ) def send_event( @@ -193,6 +213,9 @@ def send_event( txn_id=None, tok=None, expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -207,6 +230,7 @@ def send_event( "PUT", path, json.dumps(content or {}).encode("utf8"), + custom_headers=custom_headers, ) assert ( diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 856aa8682f7e..2e2f94742ef8 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -273,7 +273,7 @@ def test_aggregation_pagination_within_group(self): prev_token = None found_event_ids = [] - encoded_key = urllib.parse.quote_plus("👍".encode("utf-8")) + encoded_key = urllib.parse.quote_plus("👍".encode()) for _ in range(20): from_token = "" if prev_token: diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py index 1ec6b05e5bd2..a76a6fef1e3f 100644 --- a/tests/rest/client/v2_alpha/test_report_event.py +++ b/tests/rest/client/v2_alpha/test_report_event.py @@ -41,7 +41,7 @@ def prepare(self, reactor, clock, hs): self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) resp = self.helper.send(self.room_id, tok=self.admin_user_tok) self.event_id = resp["event_id"] - self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id) + self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" def test_reason_str_and_score_int(self): data = {"reason": "this makes me sad", "score": -100} diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 95e707584160..2d6b49692ee7 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -310,7 +310,7 @@ def test_disposition_filenamestar_utf8escaped(self): correctly decode it as the UTF-8 string, and use filename* in the response. """ - filename = parse.quote("\u2603".encode("utf8")).encode("ascii") + filename = parse.quote("\u2603".encode()).encode("ascii") channel = self._req( b"inline; filename*=utf-8''" + filename + self.test_image.extension ) diff --git a/tests/server.py b/tests/server.py index f32d8dc37568..6fddd3b30558 100644 --- a/tests/server.py +++ b/tests/server.py @@ -52,7 +52,7 @@ class FakeChannel: _reactor = attr.ib() result = attr.ib(type=dict, default=attr.Factory(dict)) _ip = attr.ib(type=str, default="127.0.0.1") - _producer = None # type: Optional[Union[IPullProducer, IPushProducer]] + _producer: Optional[Union[IPullProducer, IPushProducer]] = None @property def json_body(self): @@ -316,8 +316,10 @@ def __init__(self): self._tcp_callbacks = {} self._udp = [] - lookups = self.lookups = {} # type: Dict[str, str] - self._thread_callbacks = deque() # type: Deque[Callable[[], None]] + self.lookups: Dict[str, str] = {} + self._thread_callbacks: Deque[Callable[[], None]] = deque() + + lookups = self.lookups @implementer(IResolverSimple) class FakeResolver: diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index 9ca70e7367b4..d326a1d6a642 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -98,3 +98,16 @@ def test_drop(self): lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) self.assertIsNotNone(lock2) + + def test_shutdown(self): + """Test that shutting down Synapse releases the locks""" + # Acquire two locks + lock = self.get_success(self.store.try_acquire_lock("name", "key1")) + self.assertIsNotNone(lock) + lock2 = self.get_success(self.store.try_acquire_lock("name", "key2")) + self.assertIsNotNone(lock2) + + # Now call the shutdown code + self.get_success(self.store._on_shutdown()) + + self.assertEqual(self.store._live_tokens, {}) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 069db0edc43a..0da42b5ac559 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -7,9 +7,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates = ( - self.hs.get_datastore().db_pool.updates - ) # type: BackgroundUpdater + self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 41bef62ca8fc..43628ce44f41 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -59,5 +59,5 @@ def test_delete_alias(self): self.assertEqual(self.room.to_string(), room_id) self.assertIsNone( - (self.get_success(self.store.get_association_from_room_alias(self.alias))) + self.get_success(self.store.get_association_from_room_alias(self.alias)) ) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 792b1c44c198..748607828496 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -27,7 +27,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db_pool = self.store.db_pool # type: DatabasePool + self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -460,7 +460,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db_pool = self.store.db_pool # type: DatabasePool + self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -586,7 +586,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db_pool = self.store.db_pool # type: DatabasePool + self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 8a446da848d7..a1ba99ff140e 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -45,11 +45,7 @@ def test_displayname(self): ) self.assertIsNone( - ( - self.get_success( - self.store.get_profile_displayname(self.u_frank.localpart) - ) - ) + self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) ) def test_avatar_url(self): @@ -76,9 +72,5 @@ def test_avatar_url(self): ) self.assertIsNone( - ( - self.get_success( - self.store.get_profile_avatar_url(self.u_frank.localpart) - ) - ) + self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart)) ) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 54c5b470c789..e5574063f17f 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -75,7 +75,7 @@ def test_purge_history_wont_delete_extrems(self): token = self.get_success( self.store.get_topological_token_for_event(last["event_id"]) ) - event = "t{}-{}".format(token.topological + 1, token.stream + 1) + event = f"t{token.topological + 1}-{token.stream + 1}" # Purge everything before this topological token f = self.get_failure( diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 70257bf21027..31ce7f62528f 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -49,7 +49,7 @@ def test_get_room(self): ) def test_get_room_unknown_room(self): - self.assertIsNone((self.get_success(self.store.get_room("!uknown:test")))) + self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) def test_get_room_with_stats(self): self.assertDictContainsSubset( diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 88888319ccb7..f73306ecc4e7 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -13,12 +13,13 @@ # limitations under the License. import unittest +from typing import Optional from synapse import event_auth from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict -from synapse.types import get_domain_from_id +from synapse.events import EventBase, make_event_from_dict +from synapse.types import JsonDict, get_domain_from_id class EventAuthTestCase(unittest.TestCase): @@ -432,7 +433,7 @@ def test_join_rules_msc3083_restricted(self): TEST_ROOM_ID = "!test:room" -def _create_event(user_id): +def _create_event(user_id: str) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -444,7 +445,9 @@ def _create_event(user_id): ) -def _member_event(user_id, membership, sender=None): +def _member_event( + user_id: str, membership: str, sender: Optional[str] = None +) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -458,11 +461,11 @@ def _member_event(user_id, membership, sender=None): ) -def _join_event(user_id): +def _join_event(user_id: str) -> EventBase: return _member_event(user_id, "join") -def _power_levels_event(sender, content): +def _power_levels_event(sender: str, content: JsonDict) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -475,7 +478,7 @@ def _power_levels_event(sender, content): ) -def _alias_event(sender, **kwargs): +def _alias_event(sender: str, **kwargs) -> EventBase: data = { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), @@ -488,7 +491,7 @@ def _alias_event(sender, **kwargs): return make_event_from_dict(data) -def _random_state_event(sender): +def _random_state_event(sender: str) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -501,7 +504,7 @@ def _random_state_event(sender): ) -def _join_rules_event(sender, join_rule): +def _join_rules_event(sender: str, join_rule: str) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -519,7 +522,7 @@ def _join_rules_event(sender, join_rule): event_count = 0 -def _get_event_id(): +def _get_event_id() -> str: global event_count c = event_count event_count += 1 diff --git a/tests/test_state.py b/tests/test_state.py index 62f70958732c..e5488df1ac71 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -168,6 +168,7 @@ def setUp(self): "get_state_handler", "get_clock", "get_state_resolution_handler", + "get_account_validity_handler", "hostname", ] ) @@ -199,7 +200,7 @@ def test_branch_no_conflict(self): self.store.register_events(graph.walk()) - context_store = {} # type: dict[str, EventContext] + context_store: dict[str, EventContext] = {} for event in graph.walk(): context = yield defer.ensureDeferred( diff --git a/tests/test_types.py b/tests/test_types.py index d7881021d3f5..0d0c00d97a06 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -103,6 +103,4 @@ def testLeadingUnderscore(self): def testNonAscii(self): # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart("tĂȘst"), "t=c3=aast") - self.assertEqual( - map_username_to_mxid_localpart("tĂȘst".encode("utf-8")), "t=c3=aast" - ) + self.assertEqual(map_username_to_mxid_localpart("tĂȘst".encode()), "t=c3=aast") diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py index 1fbb38f4be03..e878af5f12e7 100644 --- a/tests/test_utils/html_parsers.py +++ b/tests/test_utils/html_parsers.py @@ -23,13 +23,13 @@ def __init__(self): super().__init__() # a list of links found in the doc - self.links = [] # type: List[str] + self.links: List[str] = [] # the values of any hidden s: map from name to value - self.hiddens = {} # type: Dict[str, Optional[str]] + self.hiddens: Dict[str, Optional[str]] = {} # the values of any radio buttons: map from name to list of values - self.radios = {} # type: Dict[str, List[Optional[str]]] + self.radios: Dict[str, List[Optional[str]]] = {} def handle_starttag( self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] diff --git a/tests/unittest.py b/tests/unittest.py index 74db7c08f1ee..3eec9c4d5b6f 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -140,7 +140,7 @@ def assertObjectHasAttributes(self, attrs, obj): try: self.assertEquals(attrs[key], getattr(obj, key)) except AssertionError as e: - raise (type(e))("Assert error for '.{}':".format(key)) from e + raise (type(e))(f"Assert error for '.{key}':") from e def assert_dict(self, required, actual): """Does a partial assert of a dict. @@ -520,7 +520,7 @@ def get_success_or_raise(self, d, by=0.0): if not isinstance(deferred, Deferred): return d - results = [] # type: list + results: list = [] deferred.addBoth(results.append) self.pump(by=by) @@ -594,7 +594,15 @@ def register_user( user_id = channel.json_body["user_id"] return user_id - def login(self, username, password, device_id=None): + def login( + self, + username, + password, + device_id=None, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): """ Log in a user, and get an access token. Requires the Login API be registered. @@ -605,7 +613,10 @@ def login(self, username, password, device_id=None): body["device_id"] = device_id channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") + "POST", + "/_matrix/client/r0/login", + json.dumps(body).encode("utf8"), + custom_headers=custom_headers, ) self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 0277998cbe59..39947a166b9b 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -174,7 +174,7 @@ def fn(self, arg1): return self.result obj = Cls() - callbacks = set() # type: Set[str] + callbacks: Set[str] = set() # set off an asynchronous request obj.result = origin_d = defer.Deferred() diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index e712eb42ea4b..3c0ddd4f1839 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -44,7 +44,7 @@ def test_uneven_parts(self): ) def test_empty_input(self): - parts = chunk_seq([], 5) # type: Iterable[Sequence] + parts: Iterable[Sequence] = chunk_seq([], 5) self.assertEqual( list(parts), @@ -56,13 +56,13 @@ class SortTopologically(TestCase): def test_empty(self): "Test that an empty graph works correctly" - graph = {} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {} self.assertEqual(list(sorted_topologically([], graph)), []) def test_handle_empty_graph(self): "Test that a graph where a node doesn't have an entry is treated as empty" - graph = {} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {} # For disconnected nodes the output is simply sorted. self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) @@ -70,7 +70,7 @@ def test_handle_empty_graph(self): def test_disconnected(self): "Test that a graph with no edges work" - graph = {1: [], 2: []} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: []} # For disconnected nodes the output is simply sorted. self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) @@ -78,19 +78,19 @@ def test_disconnected(self): def test_linear(self): "Test that a simple `4 -> 3 -> 2 -> 1` graph works" - graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) def test_subset(self): "Test that only sorting a subset of the graph works" - graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4]) def test_fork(self): "Test that a forked graph works" - graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]} # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should # always get the same one. @@ -98,12 +98,12 @@ def test_fork(self): def test_duplicates(self): "Test that a graph with duplicate edges work" - graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) def test_multiple_paths(self): "Test that a graph with multiple paths between two nodes work" - graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])