diff --git a/.ci/scripts/auditwheel_wrapper.py b/.ci/scripts/auditwheel_wrapper.py index a33b39314fb8..18cd0a7b528d 100755 --- a/.ci/scripts/auditwheel_wrapper.py +++ b/.ci/scripts/auditwheel_wrapper.py @@ -50,7 +50,16 @@ def cpython(wheel_file: str, name: str, version: Version, tag: Tag) -> str: check_is_abi3_compatible(wheel_file) - abi3_tag = Tag(tag.interpreter, "abi3", tag.platform) + # HACK: it seems that some older versions of pip will consider a wheel marked + # as macosx_11_0 as incompatible with Big Sur. I haven't done the full archaeology + # here; there are some clues in + # https://github.com/pantsbuild/pants/pull/12857 + # https://github.com/pypa/pip/issues/9138 + # https://github.com/pypa/packaging/pull/319 + # Empirically this seems to work, note that macOS 11 and 10.16 are the same, + # both versions are valid for backwards compatibility. + platform = tag.platform.replace("macosx_11_0", "macosx_10_16") + abi3_tag = Tag(tag.interpreter, "abi3", platform) dirname = os.path.dirname(wheel_file) new_wheel_file = os.path.join( diff --git a/.ci/scripts/check_lockfile.py b/.ci/scripts/check_lockfile.py new file mode 100755 index 000000000000..dfdc0105d58a --- /dev/null +++ b/.ci/scripts/check_lockfile.py @@ -0,0 +1,23 @@ +#! /usr/bin/env python +import sys + +if sys.version_info < (3, 11): + raise RuntimeError("Requires at least Python 3.11, to import tomllib") + +import tomllib + +with open("poetry.lock", "rb") as f: + lockfile = tomllib.load(f) + +try: + lock_version = lockfile["metadata"]["lock-version"] + assert lock_version == "2.0" +except Exception: + print( + """\ + Lockfile is not version 2.0. You probably need to upgrade poetry on your local box + and re-run `poetry lock --no-update`. See the Poetry cheat sheet at + https://matrix-org.github.io/synapse/develop/development/dependencies.html + """ + ) + raise diff --git a/.ci/scripts/prepare_old_deps.sh b/.ci/scripts/prepare_old_deps.sh index 7e4f060b17b8..3398193ee565 100755 --- a/.ci/scripts/prepare_old_deps.sh +++ b/.ci/scripts/prepare_old_deps.sh @@ -53,7 +53,7 @@ with open('pyproject.toml', 'w') as f: " python3 -c "$REMOVE_DEV_DEPENDENCIES" -pip install poetry==1.2.0 +pip install poetry==1.3.2 poetry lock echo "::group::Patched pyproject.toml" diff --git a/.ci/scripts/test_export_data_command.sh b/.ci/scripts/test_export_data_command.sh index 9f6c49acff73..36f836345cae 100755 --- a/.ci/scripts/test_export_data_command.sh +++ b/.ci/scripts/test_export_data_command.sh @@ -23,8 +23,9 @@ poetry run python -m synapse.app.admin_cmd -c .ci/sqlite-config.yaml export-dat --output-directory /tmp/export_data # Test that the output directory exists and contains the rooms directory -dir="/tmp/export_data/rooms" -if [ -d "$dir" ]; then +dir_r="/tmp/export_data/rooms" +dir_u="/tmp/export_data/user_data" +if [ -d "$dir_r" ] && [ -d "$dir_u" ]; then echo "Command successful, this test passes" else echo "No output directories found, the command fails against a sqlite database." @@ -43,8 +44,9 @@ poetry run python -m synapse.app.admin_cmd -c .ci/postgres-config.yaml export-d --output-directory /tmp/export_data2 # Test that the output directory exists and contains the rooms directory -dir2="/tmp/export_data2/rooms" -if [ -d "$dir2" ]; then +dir_r2="/tmp/export_data2/rooms" +dir_u2="/tmp/export_data2/user_data" +if [ -d "$dir_r2" ] && [ -d "$dir_u2" ]; then echo "Command successful, this test passes" else echo "No output directories found, the command fails against a postgres database." 
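For context on the retagging hunk in `.ci/scripts/auditwheel_wrapper.py` above, here is a minimal sketch of its effect using the real `packaging.tags.Tag` class; the input tag value is invented for illustration and is not taken from the patch.

```python
# Sketch of the macosx_11_0 -> macosx_10_16 retagging performed above.
# The example input tag is made up; Tag comes from the real `packaging` library.
from packaging.tags import Tag

tag = Tag("cp37", "cp37m", "macosx_11_0_x86_64")
platform = tag.platform.replace("macosx_11_0", "macosx_10_16")
abi3_tag = Tag(tag.interpreter, "abi3", platform)
print(abi3_tag)  # cp37-abi3-macosx_10_16_x86_64
```

Older pips then accept the wheel on Big Sur, since (as the comment above notes) macOS 11 and 10.16 refer to the same release.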
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 49427ab50d0f..4bbe5decf88e 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -48,7 +48,7 @@ jobs: type=pep440,pattern={{raw}} - name: Build and push all platforms - uses: docker/build-push-action@v3 + uses: docker/build-push-action@v4 with: push: true labels: "gitsha1=${{ github.sha }}" diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 0b33058337b9..55b4b287f677 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -58,7 +58,7 @@ jobs: # Deploy to the target directory. - name: Deploy to gh pages - uses: peaceiris/actions-gh-pages@64b46b4226a4a12da2239ba3ea5aa73e3163c75b # v3.9.1 + uses: peaceiris/actions-gh-pages@bd8c6b06eba6b3d25d72b7a1767993c0aeee42e7 # v3.9.2 with: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: ./book diff --git a/.github/workflows/latest_deps.yml b/.github/workflows/latest_deps.yml index 5ab9a8af3411..99fc2cee08c4 100644 --- a/.github/workflows/latest_deps.yml +++ b/.github/workflows/latest_deps.yml @@ -27,7 +27,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -37,7 +37,7 @@ jobs: - uses: matrix-org/setup-python-poetry@v1 with: python-version: "3.x" - poetry-version: "1.2.0" + poetry-version: "1.3.2" extras: "all" # Dump installed versions for debugging. - run: poetry run pip list > before.txt @@ -61,7 +61,7 @@ jobs: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -134,7 +134,7 @@ jobs: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: stable - uses: Swatinem/rust-cache@v2 diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index 30ac4c157169..bf57bcab6104 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -127,7 +127,7 @@ jobs: python-version: "3.x" - name: Install cibuildwheel - run: python -m pip install cibuildwheel==2.9.0 poetry==1.2.0 + run: python -m pip install cibuildwheel==2.9.0 - name: Set up QEMU to emulate aarch64 if: matrix.arch == 'aarch64' @@ -148,7 +148,7 @@ jobs: env: # Skip testing for platforms which various libraries don't have wheels # for, and so need extra build deps. 
- CIBW_TEST_SKIP: pp3{7,9}-* *i686* *musl* + CIBW_TEST_SKIP: pp3*-* *i686* *musl* # Fix Rust OOM errors on emulated aarch64: https://github.com/rust-lang/cargo/issues/10583 CARGO_NET_GIT_FETCH_WITH_CLI: true CIBW_ENVIRONMENT_PASS_LINUX: CARGO_NET_GIT_FETCH_WITH_CLI diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5a0c0a0d65c9..e945ffe7f3ca 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,11 +33,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: "3.x" - uses: matrix-org/setup-python-poetry@v1 with: + python-version: "3.x" + poetry-version: "1.3.2" extras: "all" - run: poetry run scripts-dev/generate_sample_config.sh --check - run: poetry run scripts-dev/config-lint.sh @@ -52,6 +51,15 @@ jobs: - run: "pip install 'click==8.1.1' 'GitPython>=3.1.20'" - run: scripts-dev/check_schema_delta.py --force-colors + check-lockfile: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.x" + - run: .ci/scripts/check_lockfile.py + lint: uses: "matrix-org/backend-meta/.github/workflows/python-poetry-ci.yml@v2" with: @@ -88,6 +96,7 @@ jobs: ref: ${{ github.event.pull_request.head.sha }} - uses: matrix-org/setup-python-poetry@v1 with: + poetry-version: "1.3.2" extras: "all" - run: poetry run scripts-dev/check_pydantic_models.py @@ -103,7 +112,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: 1.58.1 components: clippy @@ -125,7 +134,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: nightly-2022-12-01 components: clippy @@ -145,7 +154,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: 1.58.1 components: rustfmt @@ -163,6 +172,7 @@ jobs: - lint-pydantic - check-sampleconfig - check-schema-delta + - check-lockfile - lint-clippy - lint-rustfmt runs-on: ubuntu-latest @@ -211,7 +221,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. 
- uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -219,6 +229,7 @@ jobs: - uses: matrix-org/setup-python-poetry@v1 with: python-version: ${{ matrix.job.python-version }} + poetry-version: "1.3.2" extras: ${{ matrix.job.extras }} - name: Await PostgreSQL if: ${{ matrix.job.postgres-version }} @@ -255,7 +266,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -294,6 +305,7 @@ jobs: - uses: matrix-org/setup-python-poetry@v1 with: python-version: '3.7' + poetry-version: "1.3.2" extras: "all test" - run: poetry run trial -j6 tests @@ -328,6 +340,7 @@ jobs: - uses: matrix-org/setup-python-poetry@v1 with: python-version: ${{ matrix.python-version }} + poetry-version: "1.3.2" extras: ${{ matrix.extras }} - run: poetry run trial --jobs=2 tests - name: Dump logs @@ -373,7 +386,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -419,6 +432,7 @@ jobs: - run: sudo apt-get -qq install xmlsec1 postgresql-client - uses: matrix-org/setup-python-poetry@v1 with: + poetry-version: "1.3.2" extras: "postgres" - run: .ci/scripts/test_export_data_command.sh env: @@ -470,6 +484,7 @@ jobs: - uses: matrix-org/setup-python-poetry@v1 with: python-version: ${{ matrix.python-version }} + poetry-version: "1.3.2" extras: "postgres" - run: .ci/scripts/test_synapse_port_db.sh id: run_tester_script @@ -516,7 +531,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -526,8 +541,11 @@ jobs: - run: | set -o pipefail - POSTGRES=${{ (matrix.database == 'Postgres') && 1 || '' }} WORKERS=${{ (matrix.arrangement == 'workers') && 1 || '' }} COMPLEMENT_DIR=`pwd`/complement synapse/scripts-dev/complement.sh -json 2>&1 | synapse/.ci/scripts/gotestfmt + COMPLEMENT_DIR=`pwd`/complement synapse/scripts-dev/complement.sh -json 2>&1 | synapse/.ci/scripts/gotestfmt shell: bash + env: + POSTGRES: ${{ (matrix.database == 'Postgres') && 1 || '' }} + WORKERS: ${{ (matrix.arrangement == 'workers') && 1 || '' }} name: Run Complement Tests cargo-test: @@ -544,13 +562,36 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. 
- uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 - run: cargo test + # We want to ensure that the cargo benchmarks still compile, which requires a + # nightly compiler. + cargo-bench: + if: ${{ needs.changes.outputs.rust == 'true' }} + runs-on: ubuntu-latest + needs: + - linting-done + - changes + + steps: + - uses: actions/checkout@v3 + + - name: Install Rust + # There don't seem to be versioned releases of this action per se: for each rust + # version there is a branch which gets constantly rebased on top of master. + # We pin to a specific commit for paranoia's sake. + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 + with: + toolchain: nightly-2022-12-01 + - uses: Swatinem/rust-cache@v2 + + - run: cargo bench --no-run + # a job which marks all the other jobs as complete, thus allowing PRs to be merged. tests-done: if: ${{ always() }} @@ -562,6 +603,7 @@ jobs: - portdb - complement - cargo-test + - cargo-bench runs-on: ubuntu-latest steps: - uses: matrix-org/done-action@v2 @@ -573,3 +615,4 @@ jobs: skippable: | lint-newsfile cargo-test + cargo-bench diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml index 0a88f0cd7b6e..a59c8dac09b8 100644 --- a/.github/workflows/twisted_trunk.yml +++ b/.github/workflows/twisted_trunk.yml @@ -18,7 +18,7 @@ jobs: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -43,7 +43,7 @@ jobs: - run: sudo apt-get -qq install xmlsec1 - name: Install Rust - uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -82,7 +82,7 @@ jobs: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955 with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -148,7 +148,7 @@ jobs: run: | set -x DEBIAN_FRONTEND=noninteractive sudo apt-get install -yqq python3 pipx - pipx install poetry==1.2.0 + pipx install poetry==1.3.2 poetry remove -n twisted poetry add -n --extras tls git+https://github.com/twisted/twisted.git#trunk diff --git a/.gitignore b/.gitignore index 2b09bddf18f6..6937de88bcd1 100644 --- a/.gitignore +++ b/.gitignore @@ -69,3 +69,6 @@ book/ # Poetry will create a setup.py, which we don't want to include. /setup.py + +# Don't include users' poetry configs +/poetry.toml diff --git a/CHANGES.md b/CHANGES.md index 30f0a0674761..a62bd4eb28a3 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,222 @@ +Synapse 1.77.0 (2023-02-14) +=========================== + +No significant changes since 1.77.0rc2. + + +Synapse 1.77.0rc2 (2023-02-10) +============================== + +Bugfixes +-------- + +- Fix bug where retried replication requests would return a failure. Introduced in v1.76.0. ([\#15024](https://github.com/matrix-org/synapse/issues/15024)) + + +Internal Changes +---------------- + +- Prepare for future database schema changes. 
([\#15036](https://github.com/matrix-org/synapse/issues/15036)) + + +Synapse 1.77.0rc1 (2023-02-07) +============================== + +Features +-------- + +- Experimental support for [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952): intentional mentions. ([\#14823](https://github.com/matrix-org/synapse/issues/14823), [\#14943](https://github.com/matrix-org/synapse/issues/14943), [\#14957](https://github.com/matrix-org/synapse/issues/14957), [\#14958](https://github.com/matrix-org/synapse/issues/14958)) +- Experimental support to suppress notifications from message edits ([MSC3958](https://github.com/matrix-org/matrix-spec-proposals/pull/3958)). ([\#14960](https://github.com/matrix-org/synapse/issues/14960), [\#15016](https://github.com/matrix-org/synapse/issues/15016)) +- Add profile information, devices and connections to the command line [user data export tool](https://matrix-org.github.io/synapse/v1.77/usage/administration/admin_faq.html#how-can-i-export-user-data). ([\#14894](https://github.com/matrix-org/synapse/issues/14894)) +- Improve performance when joining or sending an event in large rooms. ([\#14962](https://github.com/matrix-org/synapse/issues/14962)) +- Improve performance of joining and leaving large rooms with many local users. ([\#14971](https://github.com/matrix-org/synapse/issues/14971)) + + +Bugfixes +-------- + +- Fix a bug introduced in Synapse 1.53.0 where `next_batch` tokens from `/sync` could not be used with the `/relations` endpoint. ([\#14866](https://github.com/matrix-org/synapse/issues/14866)) +- Fix a bug introduced in Synapse 1.35.0 where the module API's `send_local_online_presence_to` would fail to send presence updates over federation. ([\#14880](https://github.com/matrix-org/synapse/issues/14880)) +- Fix a bug introduced in Synapse 1.70.0 where the background updates to add non-thread unique indexes on receipts could fail when upgrading from 1.67.0 or earlier. ([\#14915](https://github.com/matrix-org/synapse/issues/14915)) +- Fix a regression introduced in Synapse 1.69.0 which can result in database corruption when database migrations are interrupted on sqlite. ([\#14926](https://github.com/matrix-org/synapse/issues/14926)) +- Fix a bug introduced in Synapse 1.68.0 where we were unable to service remote joins in rooms with `@room` notification levels set to `null` in their (malformed) power levels. ([\#14942](https://github.com/matrix-org/synapse/issues/14942)) +- Fix a bug introduced in Synapse 1.64.0 where boolean power levels were erroneously permitted in [v10 rooms](https://spec.matrix.org/v1.5/rooms/v10/). ([\#14944](https://github.com/matrix-org/synapse/issues/14944)) +- Fix a long-standing bug where sending messages on servers with presence enabled would spam "Re-starting finished log context" log lines. ([\#14947](https://github.com/matrix-org/synapse/issues/14947)) +- Fix a bug introduced in Synapse 1.68.0 where logging from the Rust module was not properly logged. ([\#14976](https://github.com/matrix-org/synapse/issues/14976)) +- Fix various long-standing bugs in Synapse's config, event and request handling where booleans were unintentionally accepted where an integer was expected. ([\#14945](https://github.com/matrix-org/synapse/issues/14945)) + + +Internal Changes +---------------- + +- Add missing type hints. 
([\#14879](https://github.com/matrix-org/synapse/issues/14879), [\#14886](https://github.com/matrix-org/synapse/issues/14886), [\#14887](https://github.com/matrix-org/synapse/issues/14887), [\#14904](https://github.com/matrix-org/synapse/issues/14904), [\#14927](https://github.com/matrix-org/synapse/issues/14927), [\#14956](https://github.com/matrix-org/synapse/issues/14956), [\#14983](https://github.com/matrix-org/synapse/issues/14983), [\#14984](https://github.com/matrix-org/synapse/issues/14984), [\#14985](https://github.com/matrix-org/synapse/issues/14985), [\#14987](https://github.com/matrix-org/synapse/issues/14987), [\#14988](https://github.com/matrix-org/synapse/issues/14988), [\#14990](https://github.com/matrix-org/synapse/issues/14990), [\#14991](https://github.com/matrix-org/synapse/issues/14991), [\#14992](https://github.com/matrix-org/synapse/issues/14992), [\#15007](https://github.com/matrix-org/synapse/issues/15007)) +- Use `StrCollection` to avoid potential bugs with `Collection[str]`. ([\#14922](https://github.com/matrix-org/synapse/issues/14922)) +- Allow running the complement tests suites with the asyncio reactor enabled. ([\#14858](https://github.com/matrix-org/synapse/issues/14858)) +- Improve performance of `/sync` in a few situations. ([\#14908](https://github.com/matrix-org/synapse/issues/14908), [\#14970](https://github.com/matrix-org/synapse/issues/14970)) +- Document how to handle Dependabot pull requests. ([\#14916](https://github.com/matrix-org/synapse/issues/14916)) +- Fix typo in release script. ([\#14920](https://github.com/matrix-org/synapse/issues/14920)) +- Update build system requirements to allow building with poetry-core 1.5.0. ([\#14949](https://github.com/matrix-org/synapse/issues/14949), [\#15019](https://github.com/matrix-org/synapse/issues/15019)) +- Add an [lnav](https://lnav.org) config file for Synapse logs to `/contrib/lnav`. ([\#14953](https://github.com/matrix-org/synapse/issues/14953)) +- Faster joins: Refactor internal handling of servers in room to never store an empty list. ([\#14954](https://github.com/matrix-org/synapse/issues/14954)) +- Faster joins: tag `v2/send_join/` requests to indicate if they served a partial join response. ([\#14950](https://github.com/matrix-org/synapse/issues/14950)) +- Allow running `cargo` without the `extension-module` option. ([\#14965](https://github.com/matrix-org/synapse/issues/14965)) +- Preparatory work for adding a denormalised event stream ordering column in the future. Contributed by Nick @ Beeper (@fizzadar). ([\#14979](https://github.com/matrix-org/synapse/issues/14979), [9cd7610](https://github.com/matrix-org/synapse/commit/9cd7610f86ab5051c9365dd38d1eec405a5f8ca6), [f10caa7](https://github.com/matrix-org/synapse/commit/f10caa73eee0caa91cf373966104d1ededae2aee); see [\#15014](https://github.com/matrix-org/synapse/issues/15014)) +- Add tests for `_flatten_dict`. ([\#14981](https://github.com/matrix-org/synapse/issues/14981), [\#15002](https://github.com/matrix-org/synapse/issues/15002)) + +
Dependabot updates + +- Bump dtolnay/rust-toolchain from e645b0cf01249a964ec099494d38d2da0f0b349f to 9cd00a88a73addc8617065438eff914dd08d0955. ([\#14968](https://github.com/matrix-org/synapse/issues/14968)) +- Bump docker/build-push-action from 3 to 4. ([\#14952](https://github.com/matrix-org/synapse/issues/14952)) +- Bump ijson from 3.1.4 to 3.2.0.post0. ([\#14935](https://github.com/matrix-org/synapse/issues/14935)) +- Bump types-pyyaml from 6.0.12.2 to 6.0.12.3. ([\#14936](https://github.com/matrix-org/synapse/issues/14936)) +- Bump types-jsonschema from 4.17.0.2 to 4.17.0.3. ([\#14937](https://github.com/matrix-org/synapse/issues/14937)) +- Bump types-pillow from 9.4.0.3 to 9.4.0.5. ([\#14938](https://github.com/matrix-org/synapse/issues/14938)) +- Bump hiredis from 2.0.0 to 2.1.1. ([\#14939](https://github.com/matrix-org/synapse/issues/14939)) +- Bump hiredis from 2.1.1 to 2.2.1. ([\#14993](https://github.com/matrix-org/synapse/issues/14993)) +- Bump types-setuptools from 65.6.0.3 to 67.1.0.0. ([\#14994](https://github.com/matrix-org/synapse/issues/14994)) +- Bump prometheus-client from 0.15.0 to 0.16.0. ([\#14995](https://github.com/matrix-org/synapse/issues/14995)) +- Bump anyhow from 1.0.68 to 1.0.69. ([\#14996](https://github.com/matrix-org/synapse/issues/14996)) +- Bump serde_json from 1.0.91 to 1.0.92. ([\#14997](https://github.com/matrix-org/synapse/issues/14997)) +- Bump isort from 5.11.4 to 5.11.5. ([\#14998](https://github.com/matrix-org/synapse/issues/14998)) +- Bump phonenumbers from 8.13.4 to 8.13.5. ([\#14999](https://github.com/matrix-org/synapse/issues/14999)) +
+ +Synapse 1.76.0 (2023-01-31) +=========================== + +The 1.76 release is the first to enable faster joins ([MSC3706](https://github.com/matrix-org/matrix-spec-proposals/pull/3706) and [MSC3902](https://github.com/matrix-org/matrix-spec-proposals/pull/3902)) by default. Admins can opt-out: see [the upgrade notes](https://github.com/matrix-org/synapse/blob/release-v1.76/docs/upgrade.md#faster-joins-are-enabled-by-default) for more details. + +The upgrade from 1.75 to 1.76 changes the account data replication streams in a backwards-incompatible manner. Server operators running a multi-worker deployment should consult [the upgrade notes](https://github.com/matrix-org/synapse/blob/release-v1.76/docs/upgrade.md#changes-to-the-account-data-replication-streams). + +Those who are `poetry install`ing from source using our lockfile should ensure their poetry version is 1.3.2 or higher; [see upgrade notes](https://github.com/matrix-org/synapse/blob/release-v1.76/docs/upgrade.md#minimum-version-of-poetry-is-now-132). + + +Notes on faster joins +--------------------- + +The faster joins project sees the most benefit when joining a room with a large number of members (joined or historical). We expect it to be particularly useful for joining large public rooms like the [Matrix HQ](https://matrix.to/#/#matrix:matrix.org) or [Synapse Admins](https://matrix.to/#/#synapse:matrix.org) rooms. + +After a faster join, Synapse considers that room "partially joined". In this state, you should be able to + +- read incoming messages; +- see incoming state changes, e.g. room topic changes; and +- send messages, if the room is unencrypted. + +Synapse has to spend more effort to complete the join in the background. Once this finishes, you will be able to + +- send messages, if the room is in encrypted; +- retrieve room history from before your join, if permitted by the room settings; and +- access the full list of room members. + + +Improved Documentation +---------------------- + +- Describe the ideas and the internal machinery behind faster joins. ([\#14677](https://github.com/matrix-org/synapse/issues/14677)) + + +Synapse 1.76.0rc2 (2023-01-27) +============================== + +Bugfixes +-------- + +- Faster joins: Fix a bug introduced in Synapse 1.69 where device list EDUs could fail to be handled after a restart when a faster join sync is in progress. ([\#14914](https://github.com/matrix-org/synapse/issues/14914)) + + +Internal Changes +---------------- + +- Faster joins: Improve performance of looking up partial-state status of rooms. ([\#14917](https://github.com/matrix-org/synapse/issues/14917)) + + +Synapse 1.76.0rc1 (2023-01-25) +============================== + +Features +-------- + +- Update the default room version to [v10](https://spec.matrix.org/v1.5/rooms/v10/) ([MSC 3904](https://github.com/matrix-org/matrix-spec-proposals/pull/3904)). Contributed by @FSG-Cat. ([\#14111](https://github.com/matrix-org/synapse/issues/14111)) +- Add a `set_displayname()` method to the module API for setting a user's display name. ([\#14629](https://github.com/matrix-org/synapse/issues/14629)) +- Add a dedicated listener configuration for `health` endpoint. ([\#14747](https://github.com/matrix-org/synapse/issues/14747)) +- Implement support for [MSC3890](https://github.com/matrix-org/matrix-spec-proposals/pull/3890): Remotely silence local notifications. 
([\#14775](https://github.com/matrix-org/synapse/issues/14775)) +- Implement experimental support for [MSC3930](https://github.com/matrix-org/matrix-spec-proposals/pull/3930): Push rules for ([MSC3381](https://github.com/matrix-org/matrix-spec-proposals/pull/3381)) Polls. ([\#14787](https://github.com/matrix-org/synapse/issues/14787)) +- Per [MSC3925](https://github.com/matrix-org/matrix-spec-proposals/pull/3925), bundle the whole of the replacement with any edited events, and optionally inhibit server-side replacement. ([\#14811](https://github.com/matrix-org/synapse/issues/14811)) +- Faster joins: always serve a partial join response to servers that request it with the stable query param. ([\#14839](https://github.com/matrix-org/synapse/issues/14839)) +- Faster joins: allow non-lazy-loading ("eager") syncs to complete after a partial join by omitting partial state rooms until they become fully stated. ([\#14870](https://github.com/matrix-org/synapse/issues/14870)) +- Faster joins: request partial joins by default. Admins can opt-out of this for the time being---see the upgrade notes. ([\#14905](https://github.com/matrix-org/synapse/issues/14905)) + + +Bugfixes +-------- + +- Add index to improve performance of the `/timestamp_to_event` endpoint used for jumping to a specific date in the timeline of a room. ([\#14799](https://github.com/matrix-org/synapse/issues/14799)) +- Fix a long-standing bug where Synapse would exhaust the stack when processing many federation requests where the remote homeserver has disconencted early. ([\#14812](https://github.com/matrix-org/synapse/issues/14812), [\#14842](https://github.com/matrix-org/synapse/issues/14842)) +- Fix rare races when using workers. ([\#14820](https://github.com/matrix-org/synapse/issues/14820)) +- Fix a bug introduced in Synapse 1.64.0 when using room version 10 with frozen events enabled. ([\#14864](https://github.com/matrix-org/synapse/issues/14864)) +- Fix a long-standing bug where the `populate_room_stats` background job could fail on broken rooms. ([\#14873](https://github.com/matrix-org/synapse/issues/14873)) +- Faster joins: Fix a bug in worker deployments where the room stats and user directory would not get updated when finishing a fast join until another event is sent or received. ([\#14874](https://github.com/matrix-org/synapse/issues/14874)) +- Faster joins: Fix incompatibility with joins into restricted rooms where no local users have the ability to invite. ([\#14882](https://github.com/matrix-org/synapse/issues/14882)) +- Fix a regression introduced in Synapse 1.69.0 which can result in database corruption when database migrations are interrupted on sqlite. ([\#14910](https://github.com/matrix-org/synapse/issues/14910)) + + +Updates to the Docker image +--------------------------- + +- Bump default Python version in the Dockerfile from 3.9 to 3.11. ([\#14875](https://github.com/matrix-org/synapse/issues/14875)) + + +Improved Documentation +---------------------- + +- Include `x_forwarded` entry in the HTTP listener example configs and remove the remaining `worker_main_http_uri` entries. ([\#14667](https://github.com/matrix-org/synapse/issues/14667)) +- Remove duplicate commands from the Code Style documentation page; point to the Contributing Guide instead. ([\#14773](https://github.com/matrix-org/synapse/issues/14773)) +- Add missing documentation for `tag` to `listeners` section. 
([\#14803](https://github.com/matrix-org/synapse/issues/14803)) +- Updated documentation in configuration manual for `user_directory.search_all_users`. ([\#14818](https://github.com/matrix-org/synapse/issues/14818)) +- Add `worker_manhole` to configuration manual. ([\#14824](https://github.com/matrix-org/synapse/issues/14824)) +- Fix the example config missing the `id` field in [application service documentation](https://matrix-org.github.io/synapse/latest/application_services.html). ([\#14845](https://github.com/matrix-org/synapse/issues/14845)) +- Minor corrections to the logging configuration documentation. ([\#14868](https://github.com/matrix-org/synapse/issues/14868)) +- Document the export user data command. Contributed by @thezaidbintariq. ([\#14883](https://github.com/matrix-org/synapse/issues/14883)) + + +Deprecations and Removals +------------------------- + +- Poetry 1.3.2 or higher is now required when `poetry install`ing from source. ([\#14860](https://github.com/matrix-org/synapse/issues/14860)) + + +Internal Changes +---------------- + +- Faster remote room joins (worker mode): do not populate external hosts-in-room cache when sending events as this requires blocking for full state. ([\#14749](https://github.com/matrix-org/synapse/issues/14749)) +- Enable Complement tests for Faster Remote Room Joins against worker-mode Synapse. ([\#14752](https://github.com/matrix-org/synapse/issues/14752)) +- Add some clarifying comments and refactor a portion of the `Keyring` class for readability. ([\#14804](https://github.com/matrix-org/synapse/issues/14804)) +- Add local poetry config files (`poetry.toml`) to `.gitignore`. ([\#14807](https://github.com/matrix-org/synapse/issues/14807)) +- Add missing type hints. ([\#14816](https://github.com/matrix-org/synapse/issues/14816), [\#14885](https://github.com/matrix-org/synapse/issues/14885), [\#14889](https://github.com/matrix-org/synapse/issues/14889)) +- Refactor push tests. ([\#14819](https://github.com/matrix-org/synapse/issues/14819)) +- Re-enable some linting that was disabled when we switched to ruff. ([\#14821](https://github.com/matrix-org/synapse/issues/14821)) +- Add `cargo fmt` and `cargo clippy` to the lint script. ([\#14822](https://github.com/matrix-org/synapse/issues/14822)) +- Drop unused table `presence`. ([\#14825](https://github.com/matrix-org/synapse/issues/14825)) +- Merge the two account data and the two device list replication streams. ([\#14826](https://github.com/matrix-org/synapse/issues/14826), [\#14833](https://github.com/matrix-org/synapse/issues/14833)) +- Faster joins: use stable identifiers from [MSC3706](https://github.com/matrix-org/matrix-spec-proposals/pull/3706). ([\#14832](https://github.com/matrix-org/synapse/issues/14832), [\#14841](https://github.com/matrix-org/synapse/issues/14841)) +- Add a parameter to control whether the federation client performs a partial state join. ([\#14843](https://github.com/matrix-org/synapse/issues/14843)) +- Add check to avoid starting duplicate partial state syncs. ([\#14844](https://github.com/matrix-org/synapse/issues/14844)) +- Add an early return when handling no-op presence updates. ([\#14855](https://github.com/matrix-org/synapse/issues/14855)) +- Fix `wait_for_stream_position` to correctly wait for the right instance to advance its token. ([\#14856](https://github.com/matrix-org/synapse/issues/14856), [\#14872](https://github.com/matrix-org/synapse/issues/14872)) +- Always notify replication when a stream advances automatically. 
([\#14877](https://github.com/matrix-org/synapse/issues/14877)) +- Reduce max time we wait for stream positions. ([\#14881](https://github.com/matrix-org/synapse/issues/14881)) +- Faster joins: allow the resync process more time to fetch `/state` ids. ([\#14912](https://github.com/matrix-org/synapse/issues/14912)) +- Bump regex from 1.7.0 to 1.7.1. ([\#14848](https://github.com/matrix-org/synapse/issues/14848)) +- Bump peaceiris/actions-gh-pages from 3.9.1 to 3.9.2. ([\#14861](https://github.com/matrix-org/synapse/issues/14861)) +- Bump ruff from 0.0.215 to 0.0.224. ([\#14862](https://github.com/matrix-org/synapse/issues/14862)) +- Bump types-pillow from 9.4.0.0 to 9.4.0.3. ([\#14863](https://github.com/matrix-org/synapse/issues/14863)) +- Bump types-opentracing from 2.4.10 to 2.4.10.1. ([\#14896](https://github.com/matrix-org/synapse/issues/14896)) +- Bump ruff from 0.0.224 to 0.0.230. ([\#14897](https://github.com/matrix-org/synapse/issues/14897)) +- Bump types-requests from 2.28.11.7 to 2.28.11.8. ([\#14899](https://github.com/matrix-org/synapse/issues/14899)) +- Bump types-psycopg2 from 2.9.21.2 to 2.9.21.4. ([\#14900](https://github.com/matrix-org/synapse/issues/14900)) +- Bump types-commonmark from 0.9.2 to 0.9.2.1. ([\#14901](https://github.com/matrix-org/synapse/issues/14901)) + + Synapse 1.75.0 (2023-01-17) =========================== diff --git a/Cargo.lock b/Cargo.lock index ace6a8c50aa0..a9219eac11e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -13,9 +13,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.68" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cb2f989d18dd141ab8ae82f64d1a8cdd37e0840f73a406896cf5e99502fab61" +checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" [[package]] name = "arc-swap" @@ -294,9 +294,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.7.0" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e076559ef8e241f2ae3479e36f97bd5741c0330689e217ad51ce2c76808b868a" +checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" dependencies = [ "aho-corasick", "memchr", @@ -343,9 +343,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.91" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c235533714907a8c2464236f5c4b2a17262ef1bd71f38f35ea592c8da6883" +checksum = "7434af0dc1cbd59268aa98b4c22c131c0584d2232f6fb166efb993e2832e896a" dependencies = [ "itoa", "ryu", diff --git a/contrib/lnav/README.md b/contrib/lnav/README.md new file mode 100644 index 000000000000..5230a191d28e --- /dev/null +++ b/contrib/lnav/README.md @@ -0,0 +1,47 @@ +# `lnav` config for Synapse logs + +[lnav](https://lnav.org/) is a log-viewing tool. It is particularly useful when +you need to interleave multiple log files, or for exploring a large log file +with regex filters. The downside is that it is not as ubiquitous as tools like +`less`, `grep`, etc. + +This directory contains an `lnav` [log format definition]( + https://docs.lnav.org/en/v0.10.1/formats.html#defining-a-new-format +) for Synapse logs as +emitted by Synapse with the default [logging configuration]( + https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#log_config +). It supports lnav 0.10.1 because that's what's packaged by my distribution. 
+ +This should allow lnav: + +- to interpret timestamps, allowing log interleaving; +- to interpret log severity levels, allowing colouring by log level(!!!); +- to interpret request IDs, allowing you to skip through a specific request; and +- to highlight room, event and user IDs in logs. + +See also https://gist.github.com/benje/e2ab750b0a81d11920d83af637d289f7 for a + similar example. + +## Example + +[![asciicast](https://asciinema.org/a/556133.svg)](https://asciinema.org/a/556133) + +## Tips + +- `lnav -i /path/to/synapse/checkout/contrib/lnav/synapse-log-format.json` +- `lnav my_synapse_log_file` or `lnav synapse_log_files.*`, etc. +- `lnav --help` for CLI help. + +Within lnav itself: + +- `?` for help within lnav itself. +- `q` to quit. +- `/` to search a-la `less` and `vim`, then `n` and `N` to continue searching + down and up. +- Use `o` and `O` to skip through logs based on the request ID (`POST-1234`, or + else the value of the [`request_id_header`]( + https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html?highlight=request_id_header#listeners + ) header). This may get confused if the same request ID is repeated among + multiple files or process restarts. +- ??? +- Profit diff --git a/contrib/lnav/synapse-log-format.json b/contrib/lnav/synapse-log-format.json new file mode 100644 index 000000000000..ad7017ee5ec2 --- /dev/null +++ b/contrib/lnav/synapse-log-format.json @@ -0,0 +1,67 @@ +{ + "$schema": "https://lnav.org/schemas/format-v1.schema.json", + "synapse": { + "title": "Synapse logs", + "description": "Logs output by Synapse, a Matrix homesever, under its default logging config.", + "regex": { + "log": { + "pattern": ".*(?\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2},\\d{3}) - (?.+) - (?\\d+) - (?\\w+) - (?.+) - (?.*)" + } + }, + "json": false, + "timestamp-field": "timestamp", + "timestamp-format": [ + "%Y-%m-%d %H:%M:%S,%L" + ], + "level-field": "level", + "body-field": "body", + "opid-field": "context", + "level": { + "critical": "CRITICAL", + "error": "ERROR", + "warning": "WARNING", + "info": "INFO", + "debug": "DEBUG" + }, + "sample": [ + { + "line": "my-matrix-server-generic-worker-4 | 2023-01-27 09:47:09,818 - synapse.replication.tcp.client - 381 - ERROR - PUT-32992 - Timed out waiting for stream receipts", + "level": "error" + }, + { + "line": "my-matrix-server-federation-sender-1 | 2023-01-25 20:56:20,995 - synapse.http.matrixfederationclient - 709 - WARNING - federation_transaction_transmission_loop-3 - {PUT-O-3} [example.com] Request failed: PUT matrix://example.com/_matrix/federation/v1/send/1674680155797: HttpResponseException('403: Forbidden')", + "level": "warning" + }, + { + "line": "my-matrix-server | 2023-01-25 20:55:54,433 - synapse.storage.databases - 66 - INFO - main - [database config 'master']: Checking database server", + "level": "info" + }, + { + "line": "my-matrix-server | 2023-01-26 15:08:40,447 - synapse.access.http.8008 - 460 - INFO - PUT-74929 - 0.0.0.0 - 8008 - {@alice:example.com} Processed request: 0.011sec/0.000sec (0.000sec, 0.000sec) (0.001sec/0.008sec/3) 2B 200 \"PUT /_matrix/client/r0/user/%40alice%3Atexample.com/account_data/im.vector.setting.breadcrumbs HTTP/1.0\" \"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Element/1.11.20 Chrome/108.0.5359.179 Electron/22.0.3 Safari/537.36\" [0 dbevts]", + "level": "info" + } + ], + "highlights": { + "user_id": { + "pattern": "(@|%40)[^:% ]+(:|%3A)[\\[\\]0-9a-zA-Z.\\-:]+(:\\d{1,5})?(? 
Tue, 14 Feb 2023 12:59:02 +0100 + +matrix-synapse-py3 (1.77.0~rc2) stable; urgency=medium + + * New Synapse release 1.77.0rc2. + + -- Synapse Packaging team Fri, 10 Feb 2023 12:44:21 +0000 + +matrix-synapse-py3 (1.77.0~rc1) stable; urgency=medium + + * New Synapse release 1.77.0rc1. + + -- Synapse Packaging team Tue, 07 Feb 2023 13:45:14 +0000 + +matrix-synapse-py3 (1.76.0) stable; urgency=medium + + * New Synapse release 1.76.0. + + -- Synapse Packaging team Tue, 31 Jan 2023 08:21:47 -0800 + +matrix-synapse-py3 (1.76.0~rc2) stable; urgency=medium + + * New Synapse release 1.76.0rc2. + + -- Synapse Packaging team Fri, 27 Jan 2023 11:17:57 +0000 + +matrix-synapse-py3 (1.76.0~rc1) stable; urgency=medium + + * Use Poetry 1.3.2 to manage the bundled virtualenv included with this package. + * New Synapse release 1.76.0rc1. + + -- Synapse Packaging team Wed, 25 Jan 2023 16:21:16 +0000 + matrix-synapse-py3 (1.75.0) stable; urgency=medium * New Synapse release 1.75.0. diff --git a/docker/Dockerfile b/docker/Dockerfile index bfa78cb4b0f2..bfa0cb752324 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -17,13 +17,8 @@ # Irritatingly, there is no blessed guide on how to distribute an application with its # poetry-managed environment in a docker image. We have opted for -# `poetry export | pip install -r /dev/stdin`, but there are known bugs in -# in `poetry export` whose fixes (scheduled for poetry 1.2) have yet to be released. -# In case we get bitten by those bugs in the future, the recommendations here might -# be useful: -# https://github.com/python-poetry/poetry/discussions/1879#discussioncomment-216865 -# https://stackoverflow.com/questions/53835198/integrating-python-poetry-with-docker?answertab=scoredesc - +# `poetry export | pip install -r /dev/stdin`, but beware: we have experienced bugs in +# in `poetry export` in the past. ARG PYTHON_VERSION=3.11 ARG BASE_IMAGE=docker.io/python:${PYTHON_VERSION}-slim-bullseye @@ -57,7 +52,7 @@ RUN curl -sSf https://sh.rustup.rs | sh -s -- -y --no-modify-path --default-tool # We install poetry in its own build stage to avoid its dependencies conflicting with # synapse's dependencies. RUN --mount=type=cache,target=/root/.cache/pip \ - pip install --user "poetry==1.2.0" + pip install --user "poetry==1.3.2" WORKDIR /synapse diff --git a/docker/complement/conf/start_for_complement.sh b/docker/complement/conf/start_for_complement.sh index 49d79745b064..af13209c54e9 100755 --- a/docker/complement/conf/start_for_complement.sh +++ b/docker/complement/conf/start_for_complement.sh @@ -6,7 +6,7 @@ set -e echo "Complement Synapse launcher" echo " Args: $@" -echo " Env: SYNAPSE_COMPLEMENT_DATABASE=$SYNAPSE_COMPLEMENT_DATABASE SYNAPSE_COMPLEMENT_USE_WORKERS=$SYNAPSE_COMPLEMENT_USE_WORKERS" +echo " Env: SYNAPSE_COMPLEMENT_DATABASE=$SYNAPSE_COMPLEMENT_DATABASE SYNAPSE_COMPLEMENT_USE_WORKERS=$SYNAPSE_COMPLEMENT_USE_WORKERS SYNAPSE_COMPLEMENT_USE_ASYNCIO_REACTOR=$SYNAPSE_COMPLEMENT_USE_ASYNCIO_REACTOR" function log { d=$(date +"%Y-%m-%d %H:%M:%S,%3N") @@ -76,6 +76,17 @@ else fi +if [[ -n "$SYNAPSE_COMPLEMENT_USE_ASYNCIO_REACTOR" ]]; then + if [[ -n "$SYNAPSE_USE_EXPERIMENTAL_FORKING_LAUNCHER" ]]; then + export SYNAPSE_COMPLEMENT_FORKING_LAUNCHER_ASYNC_IO_REACTOR="1" + else + export SYNAPSE_ASYNC_IO_REACTOR="1" + fi +else + export SYNAPSE_ASYNC_IO_REACTOR="0" +fi + + # Add Complement's appservice registration directory, if there is one # (It can be absent when there are no application services in this test!) 
if [ -d /complement/appservice ]; then diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index cb839fed078d..63acf86a4619 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -94,16 +94,16 @@ allow_device_name_lookup_over_federation: true experimental_features: # Enable history backfilling support msc2716_enabled: true - # server-side support for partial state in /send_join responses - msc3706_enabled: true - {% if not workers_in_use %} # client-side support for partial state in /send_join responses faster_joins: true - {% endif %} - # Filtering /messages by relation type. - msc3874_enabled: true + # Enable support for polls + msc3381_polls_enabled: true + # Enable deleting device-specific notification settings stored in account data + msc3890_enabled: true # Enable removing account data support msc3391_enabled: true + # Filtering /messages by relation type. + msc3874_enabled: true server_notices: system_mxid_localpart: _server diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 8d68719958d6..ade77d49261c 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -97,6 +97,7 @@ - [Log Contexts](log_contexts.md) - [Replication](replication.md) - [TCP Replication](tcp_replication.md) + - [Faster remote joins](development/synapse_architecture/faster_joins.md) - [Internal Documentation](development/internal_documentation/README.md) - [Single Sign-On]() - [SAML](development/saml.md) diff --git a/docs/application_services.md b/docs/application_services.md index e4592010a2b5..1f988185a95e 100644 --- a/docs/application_services.md +++ b/docs/application_services.md @@ -15,6 +15,7 @@ app_service_config_files: The format of the AS configuration file is as follows: ```yaml +id: url: as_token: hs_token: diff --git a/docs/code_style.md b/docs/code_style.md index 3aa7d0d741ba..026001b8a3cf 100644 --- a/docs/code_style.md +++ b/docs/code_style.md @@ -13,23 +13,14 @@ The necessary tools are: - [ruff](https://github.com/charliermarsh/ruff), which can spot common errors; and - [mypy](https://mypy.readthedocs.io/en/stable/), a type checker. -Install them with: - -```sh -pip install -e ".[lint,mypy]" -``` - -The easiest way to run the lints is to invoke the linter script as follows. - -```sh -scripts-dev/lint.sh -``` +See [the contributing guide](development/contributing_guide.md#run-the-linters) for instructions +on how to install the above tools and run the linters. It's worth noting that modern IDEs and text editors can run these tools automatically on save. It may be worth looking into whether this functionality is supported in your editor for a more convenient development workflow. It is not, however, recommended to run `mypy` -on save as they take a while and can be very resource intensive. +on save as it takes a while and can be very resource intensive. ## General rules diff --git a/docs/development/contributing_guide.md b/docs/development/contributing_guide.md index 4c1067671482..36bc88468468 100644 --- a/docs/development/contributing_guide.md +++ b/docs/development/contributing_guide.md @@ -67,7 +67,7 @@ pipx install poetry but see poetry's [installation instructions](https://python-poetry.org/docs/#installation) for other installation methods. -Synapse requires Poetry version 1.2.0 or later. +Developing Synapse requires Poetry version 1.3.2 or later. 
Next, open a terminal and install dependencies as follows: @@ -332,6 +332,7 @@ The above will run a monolithic (single-process) Synapse with SQLite as the data [here](https://github.com/matrix-org/synapse/blob/develop/docker/configure_workers_and_start.py#L54). A safe example would be `WORKER_TYPES="federation_inbound, federation_sender, synchrotron"`. See the [worker documentation](../workers.md) for additional information on workers. +- Passing `ASYNCIO_REACTOR=1` as an environment variable to use the Twisted asyncio reactor instead of the default one. To increase the log level for the tests, set `SYNAPSE_TEST_LOG_LEVEL`, e.g: ```sh diff --git a/docs/development/dependencies.md b/docs/development/dependencies.md index 8474525480d6..c4449c51f780 100644 --- a/docs/development/dependencies.md +++ b/docs/development/dependencies.md @@ -2,6 +2,13 @@ This is a quick cheat sheet for developers on how to use [`poetry`](https://python-poetry.org/). +# Installing + +See the [contributing guide](contributing_guide.md#4-install-the-dependencies). + +Developers should use Poetry 1.3.2 or higher. If you encounter problems related +to poetry, please [double-check your poetry version](#check-the-version-of-poetry-with-poetry---version). + # Background Synapse uses a variety of third-party Python packages to function as a homeserver. @@ -123,7 +130,7 @@ context of poetry's venv, without having to run `poetry shell` beforehand. ## ...reset my venv to the locked environment? ```shell -poetry install --extras all --remove-untracked +poetry install --all-extras --sync ``` ## ...delete everything and start over from scratch? @@ -183,7 +190,6 @@ Either: - manually update `pyproject.toml`; then `poetry lock --no-update`; or else - `poetry add packagename`. See `poetry add --help`; note the `--dev`, `--extras` and `--optional` flags in particular. - - **NB**: this specifies the new package with a version given by a "caret bound". This won't get forced to its lowest version in the old deps CI job: see [this TODO](https://github.com/matrix-org/synapse/blob/4e1374373857f2f7a911a31c50476342d9070681/.ci/scripts/test_old_deps.sh#L35-L39). Include the updated `pyproject.toml` and `poetry.lock` files in your commit. @@ -196,7 +202,7 @@ poetry remove packagename ``` ought to do the trick. Alternatively, manually update `pyproject.toml` and -`poetry lock --no-update`. Include the updated `pyproject.toml` and poetry.lock` +`poetry lock --no-update`. Include the updated `pyproject.toml` and `poetry.lock` files in your commit. ## ...update the version range for an existing dependency? @@ -240,9 +246,6 @@ poetry export --extras all Be wary of bugs in `poetry export` and `pip install -r requirements.txt`. -Note: `poetry export` will be made a plugin in Poetry 1.2. Additional config may -be required. - ## ...build a test wheel? I usually use @@ -255,12 +258,26 @@ because [`build`](https://github.com/pypa/build) is a standardish tool which doesn't require poetry. (It's what we use in CI too). However, you could try `poetry build` too. +## ...handle a Dependabot pull request? + +Synapse uses Dependabot to keep the `poetry.lock` file up-to-date. When it +creates a pull request a GitHub Action will run to automatically create a changelog +file. 
Ensure that: + +* the lockfile changes look reasonable; +* the upstream changelog file (linked in the description) doesn't include any + breaking changes; +* continuous integration passes (due to permissions, the GitHub Actions run on + the changelog commit will fail, look at the initial commit of the pull request); + +In particular, any updates to the type hints (usually packages which start with `types-`) +should be safe to merge if linting passes. # Troubleshooting ## Check the version of poetry with `poetry --version`. -The minimum version of poetry supported by Synapse is 1.2. +The minimum version of poetry supported by Synapse is 1.3.2. It can also be useful to check the version of `poetry-core` in use. If you've installed `poetry` with `pipx`, try `pipx runpip poetry list | grep diff --git a/docs/development/synapse_architecture/faster_joins.md b/docs/development/synapse_architecture/faster_joins.md new file mode 100644 index 000000000000..c32d713b8af7 --- /dev/null +++ b/docs/development/synapse_architecture/faster_joins.md @@ -0,0 +1,375 @@ +# How do faster joins work? + +This is a work-in-progress set of notes with two goals: +- act as a reference, explaining how Synapse implements faster joins; and +- record the rationale behind our choices. + +See also [MSC3902](https://github.com/matrix-org/matrix-spec-proposals/pull/3902). + +The key idea is described by [MSC3706](https://github.com/matrix-org/matrix-spec-proposals/pull/3706). This allows servers to +request a lightweight response to the federation `/send_join` endpoint. +This is called a **faster join**, also known as a **partial join**. In these +notes we'll usually use the word "partial" as it matches the database schema. + +## Overview: processing events in a partially-joined room + +The response to a partial join consists of +- the requested join event `J`, +- a list of the servers in the room (according to the state before `J`), +- a subset of the state of the room before `J`, +- the full auth chain of that state subset. + +Synapse marks the room as partially joined by adding a row to the database table +`partial_state_rooms`. It also marks the join event `J` as "partially stated", +meaning that we have neither received nor computed the full state before/after +`J`. This is done by adding a row to `partial_state_events`. +
DB schema + +``` +matrix=> \d partial_state_events +Table "matrix.partial_state_events" + Column │ Type │ Collation │ Nullable │ Default +══════════╪══════╪═══════════╪══════════╪═════════ + room_id │ text │ │ not null │ + event_id │ text │ │ not null │ + +matrix=> \d partial_state_rooms + Table "matrix.partial_state_rooms" + Column │ Type │ Collation │ Nullable │ Default +════════════════════════╪════════╪═══════════╪══════════╪═════════ + room_id │ text │ │ not null │ + device_lists_stream_id │ bigint │ │ not null │ 0 + join_event_id │ text │ │ │ + joined_via │ text │ │ │ + +matrix=> \d partial_state_rooms_servers + Table "matrix.partial_state_rooms_servers" + Column │ Type │ Collation │ Nullable │ Default +═════════════╪══════╪═══════════╪══════════╪═════════ + room_id │ text │ │ not null │ + server_name │ text │ │ not null │ +``` + +Indices, foreign-keys and check constraints are omitted for brevity. +
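As a rough illustration of how this bookkeeping is consumed, a room counts as partially joined exactly while it still has a row in `partial_state_rooms`. The sketch below is illustrative only: Synapse reads these tables through its storage layer rather than with raw queries, and the function name is invented.

```python
# Sketch only: uses the schema shown above. `conn` is assumed to be a
# psycopg2-style DB-API connection (hence the %s paramstyle); the function
# name is not a real Synapse API.
def is_partial_state_room(conn, room_id: str) -> bool:
    cur = conn.cursor()
    cur.execute(
        "SELECT 1 FROM partial_state_rooms WHERE room_id = %s",
        (room_id,),
    )
    # A row is present for as long as the room remains partially joined.
    return cur.fetchone() is not None
```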
+ +While partially joined to a room, Synapse receives events `E` from remote +homeservers as normal, and can create events at the request of its local users. +However, we run into trouble when we enforce the [checks on an event]. + +> 1. Is a valid event, otherwise it is dropped. For an event to be valid, it + must contain a room_id, and it must comply with the event format of that +> room version. +> 2. Passes signature checks, otherwise it is dropped. +> 3. Passes hash checks, otherwise it is redacted before being processed further. +> 4. Passes authorization rules based on the event’s auth events, otherwise it +> is rejected. +> 5. **Passes authorization rules based on the state before the event, otherwise +> it is rejected.** +> 6. **Passes authorization rules based on the current state of the room, +> otherwise it is “soft failed”.** + +[checks on an event]: https://spec.matrix.org/v1.5/server-server-api/#checks-performed-on-receipt-of-a-pdu + +We can enforce checks 1--4 without any problems. +But we cannot enforce checks 5 or 6 with complete certainty, since Synapse does +not know the full state before `E`, nor that of the room. + +### Partial state + +Instead, we make a best-effort approximation. +While the room is considered partially joined, Synapse tracks the "partial +state" before events. +This works in a similar way as regular state: + +- The partial state before `J` is that given to us by the partial join response. +- The partial state before an event `E` is the resolution of the partial states + after each of `E`'s `prev_event`s. +- If `E` is rejected or a message event, the partial state after `E` is the + partial state before `E`. +- Otherwise, the partial state after `E` is the partial state before `E`, plus + `E` itself. + +More concisely, partial state propagates just like full state; the only +difference is that we "seed" it with an incomplete initial state. +Synapse records that we have only calculated partial state for this event with +a row in `partial_state_events`. + +While the room remains partially stated, check 5 on incoming events to that +room becomes: + +> 5. Passes authorization rules based on **the resolution between the partial +> state before `E` and `E`'s auth events.** If the event fails to pass +> authorization rules, it is rejected. + +Additionally, check 6 is deleted: no soft-failures are enforced. + +While partially joined, the current partial state of the room is defined as the +resolution across the partial states after all forward extremities in the room. + +_Remark._ Events with partial state are _not_ considered +[outliers](../room-dag-concepts.md#outliers). + +### Approximation error + +Using partial state means the auth checks can fail in a few different ways[^2]. + +[^2]: Is this exhaustive? + +- We may erroneously accept an incoming event in check 5 based on partial state + when it would have been rejected based on full state, or vice versa. +- This means that an event could erroneously be added to the current partial + state of the room when it would not be present in the full state of the room, + or vice versa. +- Additionally, we may have skipped soft-failing an event that would have been + soft-failed based on full state. + +(Note that the discrepancies described in the last two bullets are user-visible.) + +This means that we have to be very careful when we want to lookup pieces of room +state in a partially-joined room. Our approximation of the state may be +incorrect or missing. But we can make some educated guesses. 
If

- our partial state is likely to be correct, or
- the consequences of our partial state being incorrect are minor,

then we proceed as normal, and let the resync process fix up any mistakes (see
below).

When is our partial state likely to be correct?

- It's more accurate the closer we are to the partial join event. (So we should
  ideally complete the resync as soon as possible.)
- Non-member events: we will have received them as part of the partial join
  response, if they were part of the room state at that point. We may
  incorrectly accept or reject updates to that state (at first because we lack
  remote membership information; later because of compounding errors), so these
  can become incorrect over time.
- Local members' memberships: we are the only ones who can create join and
  knock events for our users. We can't be completely confident in the
  correctness of bans, invites and kicks from other homeservers, but the resync
  process should correct any mistakes.
- Remote members' memberships: we did not receive these in the /send_join
  response, so we have essentially no idea if these are correct or not.

In short, we deem it acceptable to trust the partial state for non-membership
and local membership events. For remote membership events, we wait for the
resync to complete, at which point we have the full state of the room and can
proceed as normal.

### Fixing the approximation with a resync

The partial-state approximation is only a temporary affair. In the background,
Synapse begins a "resync" process. This is a continuous loop, starting at the
partial join event and proceeding downwards through the event graph. For each
`E` seen in the room since the partial join, Synapse will fetch

- the event ids in the state of the room before `E`, via
  [`/state_ids`](https://spec.matrix.org/v1.5/server-server-api/#get_matrixfederationv1state_idsroomid);
- the event ids in the full auth chain of `E`, included in the `/state_ids`
  response; and
- any events from the previous two bullets that Synapse hasn't persisted, via
  [`/state`](https://spec.matrix.org/v1.5/server-server-api/#get_matrixfederationv1stateroomid).

This means Synapse has (or can compute) the full state before `E`, which allows
Synapse to properly authorise or reject `E`. At this point, the event
is considered to have "full state" rather than "partial state". We record this
by removing `E` from the `partial_state_events` table.

\[**TODO:** Does Synapse persist a new state group for the full state
before `E`, or do we alter the (partial-)state group in-place? Are state groups
ever marked as partially-stated? \]

This scheme means it is possible for us to have accepted and sent an event to
clients, only to reject it during the resync. From a client's perspective, the
effect is similar to a retroactive
state change due to state resolution---i.e. a "state reset".[^3]

[^3]: Clients should refresh caches to detect such a change. Rumour has it that
sliding sync will fix this.

When all events since the join `J` have been fully-stated, the room resync
process is complete. We record this by removing the room from
`partial_state_rooms`.

## Faster joins on workers

For the time being, the resync process happens on the master worker.
A new replication stream `un_partial_stated_room` is added. Whenever a resync
completes and a partial-state room becomes fully stated, a new message is sent
into that stream containing the room ID.
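
Putting the above together, the resync can be summarised with the following sketch. This is illustrative pseudocode only: the parameters stand in for Synapse's real storage and federation helpers, and all of the names are hypothetical rather than Synapse's actual internal API.

```python
# Illustrative pseudocode only: the parameters are hypothetical stand-ins for
# Synapse's real storage and federation helpers.
from typing import Awaitable, Callable, Iterable, List, Tuple


async def resync_room(
    room_id: str,
    events_since_partial_join: Iterable[str],  # event IDs, oldest first
    fetch_state_ids: Callable[[str], Awaitable[Tuple[List[str], List[str]]]],
    have_event: Callable[[str], Awaitable[bool]],
    fetch_events: Callable[[List[str]], Awaitable[None]],
    authorise_with_full_state: Callable[[str], Awaitable[None]],
    mark_event_fully_stated: Callable[[str], Awaitable[None]],
    mark_room_fully_stated: Callable[[str], Awaitable[None]],
) -> None:
    for event_id in events_since_partial_join:
        # /state_ids gives the IDs of the state before the event plus its
        # full auth chain.
        state_ids, auth_chain_ids = await fetch_state_ids(event_id)

        # Fetch (via /state) anything we have not already persisted.
        missing = [
            e for e in [*state_ids, *auth_chain_ids] if not await have_event(e)
        ]
        if missing:
            await fetch_events(missing)

        # We now know the full state before the event, so we can authorise or
        # reject it properly, then mark it as fully stated.
        await authorise_with_full_state(event_id)
        await mark_event_fully_stated(event_id)

    # Every event since the join is now fully stated, so the room itself is
    # fully stated. (This is also when the `un_partial_stated_room`
    # replication stream is notified.)
    await mark_room_fully_stated(room_id)
```

The important property is the order of operations: an event is only removed from `partial_state_events` once its full state is known, and the room is only removed from `partial_state_rooms` once every such event has been processed.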

## Notes on specific cases

> **NB.** The notes below are rough. Some of them are hidden under `<details>`
> disclosures because they have yet to be implemented in mainline Synapse.

### Creating events during a partial join

When sending out messages during a partial join, we assume our partial state is
accurate and proceed as normal. For this to have any hope of succeeding at all,
our partial state must contain an entry for each of the (type, state key) pairs
[specified by the auth rules](https://spec.matrix.org/v1.3/rooms/v10/#authorization-rules):

- `m.room.create`
- `m.room.join_rules`
- `m.room.power_levels`
- `m.room.third_party_invite`
- `m.room.member`

The first four of these should be present in the state before `J` that is given
to us in the partial join response; only membership events are omitted. In order
for us to consider the user joined, we must have their membership event. That
means the only possible omission is the target's membership in an invite, kick
or ban.

The worst possibility is that we locally invite someone who is banned according to
the full state, because we lack their ban in our current partial state. The rest
of the federation---at least, those who are fully joined---should correctly
enforce the [membership transition constraints](https://spec.matrix.org/v1.3/client-server-api/#room-membership).
So the erroneous invite should be ignored by fully-joined
homeservers and resolved by the resync for partially-joined homeservers.

In more generality, there are two problems we're worrying about here:

- We might create an event that is valid under our partial state, only to later
  find out that it is actually invalid according to the full state.
- Or: we might refuse to create an event that is invalid under our partial
  state, even though it would be perfectly valid under the full state.

However, we expect such problems to be unlikely in practice, because

- We trust that the room has sensible power levels, e.g. that bad actors with
  high power levels are demoted before their ban.
- We trust that the resident server provides us up-to-date power levels, join
  rules, etc.
- State changes in rooms are relatively infrequent, and the resync period is
  relatively quick.

#### Sending out the event over federation

**TODO:** needs prose fleshing out.

Normally: send out in a fed txn to all HSes in the room.
We only know that some HSes were in the room at some point. What do we do?
Send it out to the list of servers from the first join.
**TODO** what do we do here if we have full state?
If the prev event was created by us, we can risk sending it to the wrong HS. (Motivation: privacy concern of the content. Not such a big deal for a public room or an encrypted room. But non-encrypted invite-only...)
But don't want to send out sensitive data in other HS's events in this way.

Suppose we discover after resync that we shouldn't have sent out one of our events (not a prev_event) to a target HS. Not much we can do.
What about if we didn't send them an event but should have?
E.g. what if someone joined from a new HS shortly after you did? We wouldn't talk to them.
Could imagine sending out the "missed" events after the resync but... painful to work out what they should have seen if they joined/left.
Instead, just send them the latest event (if they're still in the room after resync) and let them backfill(?).
- Don't do this currently.
- If anyone who has received our messages sends a message to a HS we missed, they can backfill our messages
- Gap: rooms which are infrequently used and take a long time to resync.
+ +### Joining after a partial join + +**NB.** Not yet implemented. + +

**TODO:** needs prose fleshing out. Liaise with Matthieu. Explain why /send_join
(Rich was surprised we didn't just create it locally. Answer: to try and avoid
a join which then gets rejected after resync.)

We don't know for sure that any join we create would be accepted.
E.g. the joined user might have been banned; the join rules might have changed in a way that we didn't realise... some way in which the partial state was mistaken.
Instead, do another partial make-join/send-join handshake to confirm that the join works.
- Probably going to get a bunch of duplicate state events and auth events... but the point of partial joins is that these should be small. Many are already persisted = good.
- What if the second send_join response includes a different list of resident HSes? Could ignore it.
  - Could even have a special flag that says "just make me a join", i.e. don't bother giving me state or servers in the room. Definitely want the auth chain though.
- SQ: wrt device lists it's a lot safer to ignore it!
- What if the state at the second join is inconsistent with what we have? Ignore it?

+ +### Leaving (and kicks and bans) after a partial join + +**NB.** Not yet implemented. + +

When you're fully joined to a room, to have `U` leave a room, their homeserver
needs to

- create a new leave event for `U` which will be accepted by other homeservers,
  and
- send that event out to the homeservers in the federation.

When is a leave event accepted? See
[v10 auth rules](https://spec.matrix.org/v1.5/rooms/v10/#authorization-rules):

> 4. If type is m.room.member: [...]
>
>    5. If membership is leave:
>
>       1. If the sender matches state_key, allow if and only if that user’s current membership state is invite, join, or knock.
>       2. [...]

I think this means that (well-formed!) self-leaves are governed entirely by
4.5.1. This means that if we correctly calculate state which says that `U` is
invited, joined or knocked and include it in the leave's auth events, our event
is accepted by checks 4 and 5 on incoming events.

> 4. Passes authorization rules based on the event’s auth events, otherwise
>    it is rejected.
> 5. Passes authorization rules based on the state before the event, otherwise
>    it is rejected.

The only way to fail check 6 is if the receiving server's current state of the
room says that `U` is banned, has left, or has no membership event. But this is
fine: the receiving server already thinks that `U` isn't in the room.

> 6. Passes authorization rules based on the current state of the room,
>    otherwise it is “soft failed”.

For the second point (publishing the leave event), the best thing we can do is
to publish to all HSes we know to be currently in the room. If they miss that
event, they might send us traffic in the room that we don't care about. This is
a problem with leaving after a "full" join; we don't seek to fix this with
partial joins.

(With that said: there's nothing machine-readable in the /send response. I don't
think we can deduce "destination has left the room" from a failure to /send an
event into that room?)

#### Can we still do this during a partial join?

We can create leave events and can choose what gets included in our auth events,
so we can be sure that we pass check 4 on incoming events. For check 5, we might
have an incorrect view of the state before an event.
The only way we might erroneously think a leave is valid is if

- the partial state before the leave has `U` joined, invited or knocked, but
- the full state before the leave has `U` banned, left or not present,

in which case the leave doesn't make anything worse: other HSes already consider
us as not in the room, and will continue to do so after seeing the leave.

The remaining obstacle is then: can we safely broadcast the leave event? We may
miss servers or incorrectly think that a server is in the room. Or the
destination server may be offline and miss the transaction containing our leave
event. This should self-heal when they see an event whose `prev_events` descends
from our leave.

Another option we considered was to use federation `/send_leave` to ask a
fully-joined server to send out the event on our behalf. But that introduces
complexity without much benefit. Besides, as Rich put it,

> sending out leaves is pretty best-effort currently

so this is probably good enough as-is.

#### Cleanup after the last leave

**TODO**: what cleanup is necessary? Is it all just nice-to-have to save unused
work?
diff --git a/docs/systemd-with-workers/workers/event_persister.yaml b/docs/systemd-with-workers/workers/event_persister.yaml index 9bc6997bad99..c11d5897b18e 100644 --- a/docs/systemd-with-workers/workers/event_persister.yaml +++ b/docs/systemd-with-workers/workers/event_persister.yaml @@ -17,6 +17,7 @@ worker_listeners: # #- type: http # port: 8035 + # x_forwarded: true # resources: # - names: [client] diff --git a/docs/systemd-with-workers/workers/generic_worker.yaml b/docs/systemd-with-workers/workers/generic_worker.yaml index 6e7b60886e72..a858f99ed1d9 100644 --- a/docs/systemd-with-workers/workers/generic_worker.yaml +++ b/docs/systemd-with-workers/workers/generic_worker.yaml @@ -5,11 +5,10 @@ worker_name: generic_worker1 worker_replication_host: 127.0.0.1 worker_replication_http_port: 9093 -worker_main_http_uri: http://localhost:8008/ - worker_listeners: - type: http port: 8083 + x_forwarded: true resources: - names: [client, federation] diff --git a/docs/systemd-with-workers/workers/media_worker.yaml b/docs/systemd-with-workers/workers/media_worker.yaml index eb34d1249231..8ad046f11a5b 100644 --- a/docs/systemd-with-workers/workers/media_worker.yaml +++ b/docs/systemd-with-workers/workers/media_worker.yaml @@ -8,6 +8,7 @@ worker_replication_http_port: 9093 worker_listeners: - type: http port: 8085 + x_forwarded: true resources: - names: [media] diff --git a/docs/upgrade.md b/docs/upgrade.md index c4bc5889a95f..bc143444bed6 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,39 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.76.0 + +## Faster joins are enabled by default + +When joining a room for the first time, Synapse 1.76.0 will request a partial join from the other server by default. Previously, server admins had to opt-in to this using an experimental config flag. + +Server admins can opt out of this feature for the time being by setting + +```yaml +experimental: + faster_joins: false +``` + +in their server config. + +## Changes to the account data replication streams + +Synapse has changed the format of the account data and devices replication +streams (between workers). This is a forwards- and backwards-incompatible +change: v1.75 workers cannot process account data replicated by v1.76 workers, +and vice versa. + +Once all workers are upgraded to v1.76 (or downgraded to v1.75), account data +and device replication will resume as normal. + +## Minimum version of Poetry is now 1.3.2 + +The minimum supported version of Poetry is now 1.3.2 (previously 1.2.0, [since +Synapse 1.67](#upgrading-to-v1670)). If you have used `poetry install` to +install Synapse from a source checkout, you should upgrade poetry: see its +[installation instructions](https://python-poetry.org/docs/#installation). +For all other installation methods, no acction is required. + # Upgrading to v1.74.0 ## Unicode support in user search diff --git a/docs/usage/administration/admin_faq.md b/docs/usage/administration/admin_faq.md index a6dc6197c90e..7a2774119964 100644 --- a/docs/usage/administration/admin_faq.md +++ b/docs/usage/administration/admin_faq.md @@ -2,13 +2,19 @@ How do I become a server admin? --- -If your server already has an admin account you should use the [User Admin API](../../admin_api/user_admin_api.md#change-whether-a-user-is-a-server-administrator-or-not) to promote other accounts to become admins. 
+If your server already has an admin account you should use the +[User Admin API](../../admin_api/user_admin_api.md#change-whether-a-user-is-a-server-administrator-or-not) +to promote other accounts to become admins. -If you don't have any admin accounts yet you won't be able to use the admin API, so you'll have to edit the database manually. Manually editing the database is generally not recommended so once you have an admin account: use the admin APIs to make further changes. +If you don't have any admin accounts yet you won't be able to use the admin API, +so you'll have to edit the database manually. Manually editing the database is +generally not recommended so once you have an admin account: use the admin APIs +to make further changes. ```sql UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'; ``` + What servers are my server talking to? --- Run this sql query on your db: @@ -32,6 +38,44 @@ What users are registered on my server? SELECT NAME from users; ``` +How can I export user data? +--- +Synapse includes a Python command to export data for a specific user. It takes the homeserver +configuration file and the full Matrix ID of the user to export: + +```console +python -m synapse.app.admin_cmd -c export-data --output-directory +``` + +If you uses [Poetry](../../development/dependencies.md#managing-dependencies-with-poetry) +to run Synapse: + +```console +poetry run python -m synapse.app.admin_cmd -c export-data --output-directory +``` + +The directory to store the export data in can be customised with the +`--output-directory` parameter; ensure that the provided directory is +empty. If this parameter is not provided, Synapse defaults to creating +a temporary directory (which starts with "synapse-exfiltrate") in `/tmp`, +`/var/tmp`, or `/usr/tmp`, in that order. + +The exported data has the following layout: + +``` +output-directory +├───rooms +│ └─── +│ ├───events +│ ├───state +│ ├───invite_state +│ └───knock_state +└───user_data + ├───connections + ├───devices + └───profile +``` + Manually resetting passwords --- Users can reset their password through their client. Alternatively, a server admin @@ -42,21 +86,29 @@ I have a problem with my server. Can I just delete my database and start again? --- Deleting your database is unlikely to make anything better. -It's easy to make the mistake of thinking that you can start again from a clean slate by dropping your database, but things don't work like that in a federated network: lots of other servers have information about your server. +It's easy to make the mistake of thinking that you can start again from a clean +slate by dropping your database, but things don't work like that in a federated +network: lots of other servers have information about your server. -For example: other servers might think that you are in a room, your server will think that you are not, and you'll probably be unable to interact with that room in a sensible way ever again. +For example: other servers might think that you are in a room, your server will +think that you are not, and you'll probably be unable to interact with that room +in a sensible way ever again. -In general, there are better solutions to any problem than dropping the database. Come and seek help in https://matrix.to/#/#synapse:matrix.org. +In general, there are better solutions to any problem than dropping the database. +Come and seek help in https://matrix.to/#/#synapse:matrix.org. 
There are two exceptions when it might be sensible to delete your database and start again: -* You have *never* joined any rooms which are federated with other servers. For instance, a local deployment which the outside world can't talk to. -* You are changing the `server_name` in the homeserver configuration. In effect this makes your server a completely new one from the point of view of the network, so in this case it makes sense to start with a clean database. +* You have *never* joined any rooms which are federated with other servers. For +instance, a local deployment which the outside world can't talk to. +* You are changing the `server_name` in the homeserver configuration. In effect +this makes your server a completely new one from the point of view of the network, +so in this case it makes sense to start with a clean database. (In both cases you probably also want to clear out the media_store.) I've stuffed up access to my room, how can I delete it to free up the alias? --- Using the following curl command: -``` +```console curl -H 'Authorization: Bearer ' -X DELETE https://matrix.org/_matrix/client/r0/directory/room/ ``` `` - can be obtained in riot by looking in the riot settings, down the bottom is: @@ -67,19 +119,25 @@ Access Token:\ How can I find the lines corresponding to a given HTTP request in my homeserver log? --- -Synapse tags each log line according to the HTTP request it is processing. When it finishes processing each request, it logs a line containing the words `Processed request: `. For example: +Synapse tags each log line according to the HTTP request it is processing. When +it finishes processing each request, it logs a line containing the words +`Processed request: `. For example: ``` 2019-02-14 22:35:08,196 - synapse.access.http.8008 - 302 - INFO - GET-37 - ::1 - 8008 - {@richvdh:localhost} Processed request: 0.173sec/0.001sec (0.002sec, 0.000sec) (0.027sec/0.026sec/2) 687B 200 "GET /_matrix/client/r0/sync HTTP/1.1" "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/69.0.3497.100 Safari/537.36" [0 dbevts]" ``` -Here we can see that the request has been tagged with `GET-37`. (The tag depends on the method of the HTTP request, so might start with `GET-`, `PUT-`, `POST-`, `OPTIONS-` or `DELETE-`.) So to find all lines corresponding to this request, we can do: +Here we can see that the request has been tagged with `GET-37`. (The tag depends +on the method of the HTTP request, so might start with `GET-`, `PUT-`, `POST-`, +`OPTIONS-` or `DELETE-`.) So to find all lines corresponding to this request, we can do: -``` +```console grep 'GET-37' homeserver.log ``` -If you want to paste that output into a github issue or matrix room, please remember to surround it with triple-backticks (```) to make it legible (see [quoting code](https://help.github.com/en/articles/basic-writing-and-formatting-syntax#quoting-code)). +If you want to paste that output into a github issue or matrix room, please +remember to surround it with triple-backticks (```) to make it legible +(see [quoting code](https://help.github.com/en/articles/basic-writing-and-formatting-syntax#quoting-code)). What do all those fields in the 'Processed' line mean? @@ -119,7 +177,7 @@ This is normally caused by a misconfiguration in your reverse-proxy. See [the re Help!! Synapse is slow and eats all my RAM/CPU! ------------------------------------------------ +--- First, ensure you are running the latest version of Synapse, using Python 3 with a [PostgreSQL database](../../postgres.md). 
@@ -161,7 +219,7 @@ in the Synapse config file: [see here](../configuration/config_documentation.md# Running out of File Handles ---------------------------- +--- If Synapse runs out of file handles, it typically fails badly - live-locking at 100% CPU, and/or failing to accept new TCP connections (blocking the diff --git a/docs/usage/administration/request_log.md b/docs/usage/administration/request_log.md index 7dd9969d8668..292e3449f195 100644 --- a/docs/usage/administration/request_log.md +++ b/docs/usage/administration/request_log.md @@ -10,10 +10,10 @@ See the following for how to decode the dense data available from the default lo ``` -| Part | Explanation | +| Part | Explanation | | ----- | ------------ | | AAAA | Timestamp request was logged (not received) | -| BBBB | Logger name (`synapse.access.(http\|https).`, where 'tag' is defined in the `listeners` config section, normally the port) | +| BBBB | Logger name (`synapse.access.(http\|https).`, where 'tag' is defined in the [`listeners`](../configuration/config_documentation.md#listeners) config section, normally the port) | | CCCC | Line number in code | | DDDD | Log Level | | EEEE | Request Identifier (This identifier is shared by related log lines)| diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 93d6c7fb02dd..2883f76a26cd 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -295,7 +295,9 @@ Known room versions are listed [here](https://spec.matrix.org/latest/rooms/#comp For example, for room version 1, `default_room_version` should be set to "1". -Currently defaults to "9". +Currently defaults to ["10"](https://spec.matrix.org/v1.5/rooms/v10/). + +_Changed in Synapse 1.76:_ the default version room version was increased from [9](https://spec.matrix.org/v1.5/rooms/v9/) to [10](https://spec.matrix.org/v1.5/rooms/v10/). Example configuration: ```yaml @@ -422,6 +424,10 @@ Sub-options for each listener include: * `port`: the TCP port to bind to. +* `tag`: An alias for the port in the logger name. If set the tag is logged instead +of the port. Default to `None`, is optional and only valid for listener with `type: http`. +See the docs [request log format](../administration/request_log.md). + * `bind_addresses`: a list of local addresses to listen on. The default is 'all local interfaces'. @@ -476,6 +482,12 @@ Valid resource names are: * `static`: static resources under synapse/static (/_matrix/static). (Mostly useful for 'fallback authentication'.) +* `health`: the [health check endpoint](../../reverse_proxy.md#health-check-endpoint). This endpoint + is by default active for all other resources and does not have to be activated separately. + This is only useful if you want to use the health endpoint explicitly on a dedicated port or + for [workers](../../workers.md) and containers without listener e.g. + [application services](../../workers.md#notifying-application-services). + Example configuration #1: ```yaml listeners: @@ -3462,8 +3474,8 @@ This setting defines options related to the user directory. This option has the following sub-options: * `enabled`: Defines whether users can search the user directory. If false then empty responses are returned to all queries. Defaults to true. -* `search_all_users`: Defines whether to search all users visible to your HS when searching - the user directory. 
If false, search results will only contain users +* `search_all_users`: Defines whether to search all users visible to your HS at the time the search is performed. If set to true, will return all users who share a room with the user from the homeserver. + If false, search results will only contain users visible in public rooms and users sharing a room with the requester. Defaults to false. @@ -4019,6 +4031,27 @@ worker_listeners: resources: - names: [client, federation] ``` +--- +### `worker_manhole` + +A worker may have a listener for [`manhole`](../../manhole.md). +It allows server administrators to access a Python shell on the worker. + +Example configuration: +```yaml +worker_manhole: 9000 +``` + +This is a short form for: +```yaml +worker_listeners: + - port: 9000 + bind_addresses: ['127.0.0.1'] + type: manhole +``` + +It needs also an additional [`manhole_settings`](#manhole_settings) configuration. + --- ### `worker_daemonize` diff --git a/docs/usage/configuration/logging_sample_config.md b/docs/usage/configuration/logging_sample_config.md index 499ab7cfe500..895674199738 100644 --- a/docs/usage/configuration/logging_sample_config.md +++ b/docs/usage/configuration/logging_sample_config.md @@ -1,9 +1,11 @@ # Logging Sample Configuration File Below is a sample logging configuration file. This file can be tweaked to control how your -homeserver will output logs. A restart of the server is generally required to apply any -changes made to this file. The value of the `log_config` option in your homeserver -config should be the path to this file. +homeserver will output logs. The value of the `log_config` option in your homeserver config +should be the path to this file. + +To apply changes made to this file, send Synapse a SIGHUP signal (or, if using `systemd`, run +`systemctl reload` on the Synapse service). Note that a default logging configuration (shown below) is created automatically alongside the homeserver config when following the [installation instructions](../../setup/installation.md). 
diff --git a/mypy.ini b/mypy.ini index 013fbbdfc02a..0efafb26b6a6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -32,31 +32,9 @@ exclude = (?x) |synapse/storage/databases/main/cache.py |synapse/storage/schema/ - |tests/api/test_auth.py - |tests/api/test_ratelimiting.py - |tests/app/test_openid_listener.py - |tests/appservice/test_scheduler.py - |tests/events/test_presence_router.py - |tests/events/test_utils.py - |tests/federation/test_federation_catch_up.py - |tests/federation/test_federation_sender.py - |tests/federation/transport/test_knocking.py - |tests/handlers/test_typing.py - |tests/http/federation/test_matrix_federation_agent.py - |tests/http/federation/test_srv_resolver.py - |tests/http/test_proxyagent.py - |tests/logging/__init__.py - |tests/logging/test_terse_json.py |tests/module_api/test_api.py - |tests/push/test_email.py - |tests/push/test_presentable_names.py - |tests/push/test_push_rule_evaluator.py - |tests/rest/client/test_transactions.py |tests/rest/media/v1/test_media_storage.py |tests/server.py - |tests/server_notices/test_resource_limits_server_notices.py - |tests/test_state.py - |tests/test_terms_auth.py )$ [mypy-synapse.federation.transport.client] @@ -86,22 +64,43 @@ disallow_untyped_defs = False [mypy-tests.*] disallow_untyped_defs = False +[mypy-tests.api.*] +disallow_untyped_defs = True + +[mypy-tests.app.*] +disallow_untyped_defs = True + +[mypy-tests.appservice.*] +disallow_untyped_defs = True + [mypy-tests.config.*] disallow_untyped_defs = True [mypy-tests.crypto.*] disallow_untyped_defs = True -[mypy-tests.federation.transport.test_client] +[mypy-tests.events.*] +disallow_untyped_defs = True + +[mypy-tests.federation.*] disallow_untyped_defs = True [mypy-tests.handlers.*] disallow_untyped_defs = True +[mypy-tests.http.*] +disallow_untyped_defs = True + +[mypy-tests.logging.*] +disallow_untyped_defs = True + [mypy-tests.metrics.*] disallow_untyped_defs = True -[mypy-tests.push.test_bulk_push_rule_evaluator] +[mypy-tests.push.*] +disallow_untyped_defs = True + +[mypy-tests.replication.*] disallow_untyped_defs = True [mypy-tests.rest.*] @@ -116,6 +115,12 @@ disallow_untyped_defs = True [mypy-tests.test_server] disallow_untyped_defs = True +[mypy-tests.test_state] +disallow_untyped_defs = True + +[mypy-tests.test_terms_auth] +disallow_untyped_defs = True + [mypy-tests.types.*] disallow_untyped_defs = True @@ -142,9 +147,6 @@ disallow_untyped_defs = True [mypy-authlib.*] ignore_missing_imports = True -[mypy-canonicaljson] -ignore_missing_imports = True - [mypy-ijson.*] ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index 108037a33db5..88884dce3d5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,11 +48,6 @@ line-length = 88 # E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) # -# See https://github.com/charliermarsh/ruff/#pyflakes -# F401: unused import -# F811: Redefinition of unused -# F821: Undefined name -# # flake8-bugbear compatible checks. Its error codes are described at # https://github.com/charliermarsh/ruff/#flake8-bugbear # B019: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks @@ -64,9 +59,6 @@ ignore = [ "B024", "E501", "E731", - "F401", - "F811", - "F821", ] select = [ # pycodestyle checks. 
@@ -97,7 +89,7 @@ manifest-path = "rust/Cargo.toml" [tool.poetry] name = "matrix-synapse" -version = "1.75.0" +version = "1.77.0" description = "Homeserver for the Matrix decentralised comms protocol" authors = ["Matrix.org Team and Contributors "] license = "Apache-2.0" @@ -135,7 +127,9 @@ exclude = [ { path = "synapse/*.so", format = "sdist"} ] -build = "build_rust.py" +[tool.poetry.build] +script = "build_rust.py" +generate-setup-file = true [tool.poetry.scripts] synapse_homeserver = "synapse.app.homeserver:main" @@ -318,7 +312,7 @@ all = [ # We pin black so that our tests don't start failing on new releases. isort = ">=5.10.1" black = ">=22.3.0" -ruff = "0.0.215" +ruff = "0.0.230" # Typechecking mypy = "*" @@ -359,7 +353,7 @@ towncrier = ">=18.6.0rc1" # system changes. # We are happy to raise these upper bounds upon request, # provided we check that it's safe to do so (i.e. that CI passes). -requires = ["poetry-core>=1.0.0,<=1.3.2", "setuptools_rust>=1.3,<=1.5.2"] +requires = ["poetry-core>=1.0.0,<=1.5.0", "setuptools_rust>=1.3,<=1.5.2"] build-backend = "poetry.core.masonry.api" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index cffaa5b51b94..09e2bba5e5c2 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -23,13 +23,17 @@ name = "synapse.synapse_rust" anyhow = "1.0.63" lazy_static = "1.4.0" log = "0.4.17" -pyo3 = { version = "0.17.1", features = ["extension-module", "macros", "anyhow", "abi3", "abi3-py37"] } +pyo3 = { version = "0.17.1", features = ["macros", "anyhow", "abi3", "abi3-py37"] } pyo3-log = "0.7.0" pythonize = "0.17.0" regex = "1.6.0" serde = { version = "1.0.144", features = ["derive"] } serde_json = "1.0.85" +[features] +extension-module = ["pyo3/extension-module"] +default = ["extension-module"] + [build-dependencies] blake2 = "0.10.4" hex = "0.4.3" diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 442a79348fcf..35f7a50bcea0 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -13,6 +13,7 @@ // limitations under the License. 
#![feature(test)] +use std::collections::BTreeSet; use synapse::push::{ evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules, }; @@ -32,6 +33,9 @@ fn bench_match_exact(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, + false, + BTreeSet::new(), + false, 10, Some(0), Default::default(), @@ -68,6 +72,9 @@ fn bench_match_word(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, + false, + BTreeSet::new(), + false, 10, Some(0), Default::default(), @@ -104,6 +111,9 @@ fn bench_match_word_miss(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, + false, + BTreeSet::new(), + false, 10, Some(0), Default::default(), @@ -140,6 +150,9 @@ fn bench_eval_message(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, + false, + BTreeSet::new(), + false, 10, Some(0), Default::default(), @@ -150,8 +163,15 @@ fn bench_eval_message(b: &mut Bencher) { ) .unwrap(); - let rules = - FilteredPushRules::py_new(PushRules::new(Vec::new()), Default::default(), false, false); + let rules = FilteredPushRules::py_new( + PushRules::new(Vec::new()), + Default::default(), + false, + false, + false, + false, + false, + ); b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); } diff --git a/rust/src/lib.rs b/rust/src/lib.rs index c7b60e58a73b..ce67f5861183 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,7 +1,13 @@ +use lazy_static::lazy_static; use pyo3::prelude::*; +use pyo3_log::ResetHandle; pub mod push; +lazy_static! { + static ref LOGGING_HANDLE: ResetHandle = pyo3_log::init(); +} + /// Returns the hash of all the rust source files at the time it was compiled. /// /// Used by python to detect if the rust library is outdated. @@ -17,13 +23,20 @@ fn sum_as_string(a: usize, b: usize) -> PyResult { Ok((a + b).to_string()) } +/// Reset the cached logging configuration of pyo3-log to pick up any changes +/// in the Python logging configuration. +/// +#[pyfunction] +fn reset_logging_config() { + LOGGING_HANDLE.reset(); +} + /// The entry point for defining the Python module. #[pymodule] fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> { - pyo3_log::init(); - m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?; + m.add_function(wrap_pyfunction!(reset_logging_config, m)?)?; push::register_module(py, m)?; diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 9cbb2828990f..c2d15a4285c8 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -95,20 +95,20 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, - // We don't want to notify on edits in Beeper land. Not only can this be confusing - // in real time (2 notifications, one message) but it's also especially confusing - // when a bridge needs to edit a previously backfilled message. + // We don't want to notify on edits. Not only can this be confusing in real + // time (2 notifications, one message) but it's especially confusing + // if a bridge needs to edit a previously backfilled message. 
PushRule { rule_id: Cow::Borrowed("global/override/.com.beeper.suppress_edits"), priority_class: 5, - conditions: Cow::Borrowed(&[ - Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { key: Cow::Borrowed("content.m.relates_to.rel_type"), pattern: Some(Cow::Borrowed("m.replace")), pattern_type: None, - })), - ]), - actions: Cow::Borrowed(&[Action::DontNotify]), + }, + ))]), + actions: Cow::Borrowed(&[]), default: true, default_enabled: true, }, @@ -194,6 +194,14 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed(".org.matrix.msc3952.is_user_mention"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::IsUserMention)]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"), priority_class: 5, @@ -202,6 +210,19 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed(".org.matrix.msc3952.is_room_mention"), + priority_class: 5, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::IsRoomMention), + Condition::Known(KnownCondition::SenderNotificationPermission { + key: Cow::Borrowed("room"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.roomnotif"), priority_class: 5, @@ -282,6 +303,20 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed("global/override/.org.matrix.msc3930.rule.poll_response"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.response")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[]), + default: true, + default_enabled: true, + }, ]; pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule { @@ -713,6 +748,68 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3930.rule.poll_start_one_to_one"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.start")), + pattern_type: None, + })), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3930.rule.poll_start"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.start")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3930.rule.poll_end_one_to_one"), + priority_class: 1, + conditions: 
Cow::Borrowed(&[ + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.end")), + pattern_type: None, + })), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3930.rule.poll_end"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.end")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify]), + default: true, + default_enabled: true, + }, ]; lazy_static! { diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 305df8e8c3ef..78560e34a1f2 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use anyhow::{Context, Error}; use lazy_static::lazy_static; @@ -68,6 +68,13 @@ pub struct PushRuleEvaluator { /// The "content.body", if any. body: String, + /// True if the event has a mentions property and MSC3952 support is enabled. + has_mentions: bool, + /// The user mentions that were part of the message. + user_mentions: BTreeSet, + /// True if the message is a room message. + room_mention: bool, + /// The number of users in the room. room_member_count: u64, @@ -100,6 +107,9 @@ impl PushRuleEvaluator { #[new] pub fn py_new( flattened_keys: BTreeMap, + has_mentions: bool, + user_mentions: BTreeSet, + room_mention: bool, room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, @@ -116,6 +126,9 @@ impl PushRuleEvaluator { Ok(PushRuleEvaluator { flattened_keys, body, + has_mentions, + user_mentions, + room_mention, room_member_count, notification_power_levels, sender_power_level, @@ -146,6 +159,19 @@ impl PushRuleEvaluator { } let rule_id = &push_rule.rule_id().to_string(); + + // For backwards-compatibility the legacy mention rules are disabled + // if the event contains the 'm.mentions' property (and if the + // experimental feature is enabled, both of these are represented + // by the has_mentions flag). + if self.has_mentions + && (rule_id == "global/override/.m.rule.contains_display_name" + || rule_id == "global/content/.m.rule.contains_user_name" + || rule_id == "global/override/.m.rule.roomnotif") + { + continue; + } + let extev_flag = &RoomVersionFeatures::ExtensibleEvents.as_str().to_string(); let supports_extensible_events = self.room_version_feature_flags.contains(extev_flag); let safe_from_rver_condition = SAFE_EXTENSIBLE_EVENTS_RULE_IDS.contains(rule_id); @@ -232,6 +258,14 @@ impl PushRuleEvaluator { KnownCondition::InverseRelatedEventMatch(event_match) => { !self.match_related_event_match(event_match, user_id)? 
} + KnownCondition::IsUserMention => { + if let Some(uid) = user_id { + self.user_mentions.contains(uid) + } else { + false + } + } + KnownCondition::IsRoomMention => self.room_mention, KnownCondition::ContainsDisplayName => { if let Some(dn) = display_name { if !dn.is_empty() { @@ -427,6 +461,9 @@ fn push_rule_evaluator() { flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); let evaluator = PushRuleEvaluator::py_new( flattened_keys, + false, + BTreeSet::new(), + false, 10, Some(0), BTreeMap::new(), @@ -452,6 +489,9 @@ fn test_requires_room_version_supports_condition() { let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()]; let evaluator = PushRuleEvaluator::py_new( flattened_keys, + false, + BTreeSet::new(), + false, 10, Some(0), BTreeMap::new(), @@ -486,7 +526,7 @@ fn test_requires_room_version_supports_condition() { }; let rules = PushRules::new(vec![custom_rule]); result = evaluator.run( - &FilteredPushRules::py_new(rules, BTreeMap::new(), true, true), + &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false, false), None, None, ); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 3dd32bd60709..cea46e9afb1a 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -271,6 +271,10 @@ pub enum KnownCondition { RelatedEventMatch(RelatedEventMatchCondition), #[serde(rename = "im.nheko.msc3664.inverse_related_event_match")] InverseRelatedEventMatch(RelatedEventMatchCondition), + #[serde(rename = "org.matrix.msc3952.is_user_mention")] + IsUserMention, + #[serde(rename = "org.matrix.msc3952.is_room_mention")] + IsRoomMention, ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -413,8 +417,11 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, - msc3664_enabled: bool, msc1767_enabled: bool, + msc3381_polls_enabled: bool, + msc3664_enabled: bool, + msc3952_intentional_mentions: bool, + msc3958_suppress_edits_enabled: bool, } #[pymethods] @@ -423,14 +430,20 @@ impl FilteredPushRules { pub fn py_new( push_rules: PushRules, enabled_map: BTreeMap, - msc3664_enabled: bool, msc1767_enabled: bool, + msc3381_polls_enabled: bool, + msc3664_enabled: bool, + msc3952_intentional_mentions: bool, + msc3958_suppress_edits_enabled: bool, ) -> Self { Self { push_rules, enabled_map, - msc3664_enabled, msc1767_enabled, + msc3381_polls_enabled, + msc3664_enabled, + msc3952_intentional_mentions, + msc3958_suppress_edits_enabled, } } @@ -449,13 +462,28 @@ impl FilteredPushRules { .iter() .filter(|rule| { // Ignore disabled experimental push rules + + if !self.msc1767_enabled && rule.rule_id.contains("org.matrix.msc1767") { + return false; + } + if !self.msc3664_enabled && rule.rule_id == "global/override/.im.nheko.msc3664.reply" { return false; } - if !self.msc1767_enabled && rule.rule_id.contains("org.matrix.msc1767") { + if !self.msc3381_polls_enabled && rule.rule_id.contains("org.matrix.msc3930") { + return false; + } + + if !self.msc3952_intentional_mentions && rule.rule_id.contains("org.matrix.msc3952") + { + return false; + } + if !self.msc3958_suppress_edits_enabled + && rule.rule_id == "global/override/.com.beeper.suppress_edits" + { return false; } @@ -516,6 +544,28 @@ fn test_deserialize_unstable_msc3931_condition() { )); } +#[test] +fn test_deserialize_unstable_msc3952_user_condition() { + let json = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + 
assert!(matches!( + condition, + Condition::Known(KnownCondition::IsUserMention) + )); +} + +#[test] +fn test_deserialize_unstable_msc3952_room_condition() { + let json = r#"{"kind":"org.matrix.msc3952.is_room_mention"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::IsRoomMention) + )); +} + #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 51d1bac6183c..66aaa3d8480d 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -190,7 +190,7 @@ fi extra_test_args=() -test_tags="synapse_blacklist,msc3787,msc3874,msc3391" +test_tags="synapse_blacklist,msc3787,msc3874,msc3890,msc3391,msc3930,faster_joins" # All environment variables starting with PASS_ will be shared. # (The prefix is stripped off before reaching the container.) @@ -223,12 +223,14 @@ else export PASS_SYNAPSE_COMPLEMENT_DATABASE=sqlite fi - # We only test faster room joins on monoliths, because they are purposefully - # being developed without worker support to start with. - # - # The tests for importing historical messages (MSC2716) also only pass with monoliths, - # currently. - test_tags="$test_tags,faster_joins,msc2716" + # The tests for importing historical messages (MSC2716) + # only pass with monoliths, currently. + test_tags="$test_tags,msc2716" +fi + +if [[ -n "$ASYNCIO_REACTOR" ]]; then + # Enable the Twisted asyncio reactor + export PASS_SYNAPSE_COMPLEMENT_USE_ASYNCIO_REACTOR=true fi diff --git a/scripts-dev/database-save.sh b/scripts-dev/database-save.sh index 040c8a494319..91674027ae29 100755 --- a/scripts-dev/database-save.sh +++ b/scripts-dev/database-save.sh @@ -11,6 +11,5 @@ sqlite3 "$1" <<'EOF' >table-save.sql .dump users .dump access_tokens -.dump presence .dump profiles EOF diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 2bf58ac5d4a0..392c509a8ac1 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -101,10 +101,43 @@ echo # Print out the commands being run set -x +# Ensure the sort order of imports. isort "${files[@]}" + +# Ensure Python code conforms to an opinionated style. python3 -m black "${files[@]}" + +# Ensure the sample configuration file conforms to style checks. ./scripts-dev/config-lint.sh + +# Catch any common programming mistakes in Python code. # --quiet suppresses the update check. ruff --quiet "${files[@]}" + +# Catch any common programming mistakes in Rust code. +# +# --bins, --examples, --lib, --tests combined explicitly disable checking +# the benchmarks, which can fail due to `#![feature]` macros not being +# allowed on the stable rust toolchain (rustc error E0554). +# +# --allow-staged and --allow-dirty suppress clippy raising errors +# for uncommitted files. Only needed when using --fix. +# +# -D warnings disables the "warnings" lint. +# +# Using --fix has a tendency to cause subsequent runs of clippy to recompile +# rust code, which can slow down this script. Thus we run clippy without --fix +# first which is quick, and then re-run it with --fix if an error was found. +if ! cargo-clippy --bins --examples --lib --tests -- -D warnings > /dev/null 2>&1; then + cargo-clippy \ + --bins --examples --lib --tests --allow-staged --allow-dirty --fix -- -D warnings +fi + +# Ensure the formatting of Rust code. +cargo-fmt + +# Ensure all Pydantic models use strict types. ./scripts-dev/check_pydantic_models.py lint + +# Ensure type hints are correct. 
mypy diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 6974fd789575..008a5bd965d4 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -438,7 +438,7 @@ def _upload(gh_token: Optional[str]) -> None: repo = get_repo_and_check_clean_checkout() tag = repo.tag(f"refs/tags/{tag_name}") if repo.head.commit != tag.commit: - click.echo("Tag {tag_name} (tag.commit) is not currently checked out!") + click.echo(f"Tag {tag_name} ({tag.commit}) is not currently checked out!") click.get_current_context().abort() # Query all the assets corresponding to this release. diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi index cd4c969849b1..1fe1a136f190 100644 --- a/stubs/sortedcontainers/sortedlist.pyi +++ b/stubs/sortedcontainers/sortedlist.pyi @@ -7,7 +7,6 @@ from __future__ import annotations from typing import ( Any, Callable, - Generic, Iterable, Iterator, List, diff --git a/stubs/sortedcontainers/sortedset.pyi b/stubs/sortedcontainers/sortedset.pyi index d761c438f792..6db11eacbed9 100644 --- a/stubs/sortedcontainers/sortedset.pyi +++ b/stubs/sortedcontainers/sortedset.pyi @@ -5,10 +5,8 @@ from __future__ import annotations from typing import ( - AbstractSet, Any, Callable, - Generic, Hashable, Iterable, Iterator, diff --git a/stubs/synapse/synapse_rust/__init__.pyi b/stubs/synapse/synapse_rust/__init__.pyi index 8658d3138f89..d25c60910662 100644 --- a/stubs/synapse/synapse_rust/__init__.pyi +++ b/stubs/synapse/synapse_rust/__init__.pyi @@ -1,2 +1,3 @@ def sum_as_string(a: int, b: int) -> str: ... def get_rust_file_digest() -> str: ... +def reset_logging_config() -> None: ... diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index dab5d4aff7ce..754acab2f978 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -1,4 +1,18 @@ -from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Tuple, Union +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union from synapse.types import JsonDict @@ -29,8 +43,11 @@ class FilteredPushRules: self, push_rules: PushRules, enabled_map: Dict[str, bool], - msc3664_enabled: bool, msc1767_enabled: bool, + msc3381_polls_enabled: bool, + msc3664_enabled: bool, + msc3952_intentional_mentions: bool, + msc3958_suppress_edits_enabled: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... @@ -40,6 +57,9 @@ class PushRuleEvaluator: def __init__( self, flattened_keys: Mapping[str, str], + has_mentions: bool, + user_mentions: Set[str], + room_mention: bool, room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], @@ -54,3 +74,6 @@ class PushRuleEvaluator: user_id: Optional[str], display_name: Optional[str], ) -> Collection[Union[Mapping, str]]: ... 
+ def matches( + self, condition: JsonDict, user_id: Optional[str], display_name: Optional[str] + ) -> bool: ... diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index c463b60b2620..5e137dbbf711 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -51,6 +51,7 @@ make_deferred_yieldable, run_in_background, ) +from synapse.notifier import ReplicationNotifier from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn from synapse.storage.databases.main import PushRuleStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore @@ -260,6 +261,9 @@ def get_instance_name(self) -> str: def should_send_federation(self) -> bool: return False + def get_replication_notifier(self) -> ReplicationNotifier: + return ReplicationNotifier() + class Porter: def __init__( diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 82150e40cdf9..5a943b02f34a 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -17,6 +17,8 @@ """Contains constants from the specification.""" +import enum + from typing_extensions import Final # the max size of a (canonical-json-encoded) event @@ -231,6 +233,9 @@ class EventContentFields: # The authorising user for joining a restricted room. AUTHORISING_USER: Final = "join_authorised_via_users_server" + # Use for mentioning users. + MSC3952_MENTIONS: Final = "org.matrix.msc3952.mentions" + # an unspecced field added to to-device messages to identify them uniquely-ish TO_DEVICE_MSGID: Final = "org.matrix.msgid" @@ -249,6 +254,7 @@ class RoomEncryptionAlgorithms: class AccountDataTypes: DIRECT: Final = "m.direct" IGNORED_USER_LIST: Final = "m.ignored_user_list" + TAG: Final = "m.tag" class HistoryVisibility: @@ -290,3 +296,8 @@ class ApprovalNoticeMedium: NONE = "org.matrix.msc3866.none" EMAIL = "org.matrix.msc3866.email" + + +class Direction(enum.Enum): + BACKWARDS = "b" + FORWARDS = "f" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 4cf8f0cc8ef9..83c42fc25a22 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -252,9 +252,9 @@ def unread_thread_notifications(self) -> bool: return self._room_timeline_filter.unread_thread_notifications async def filter_presence( - self, events: Iterable[UserPresenceState] + self, presence_states: Iterable[UserPresenceState] ) -> List[UserPresenceState]: - return await self._presence_filter.filter(events) + return await self._presence_filter.filter(presence_states) async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: return await self._account_data.filter(events) @@ -283,6 +283,9 @@ async def filter_room_account_data( await self._room_filter.filter(events) ) + def blocks_all_rooms(self) -> bool: + return self._room_filter.filters_all_rooms() + def blocks_all_presence(self) -> bool: return ( self._presence_filter.filters_all_types() diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 165d1c5db06b..fe7afb94755e 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -35,6 +35,7 @@ ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, ) +from synapse.storage.databases.main.client_ips import ClientIpWorkerStore from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.event_federation import EventFederationWorkerStore @@ -43,6 +44,7 @@ ) from 
synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.filtering import FilteringWorkerStore +from synapse.storage.databases.main.profile import ProfileWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.registration import RegistrationWorkerStore @@ -54,7 +56,7 @@ from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore -from synapse.types import StateMap +from synapse.types import JsonDict, StateMap from synapse.util import SYNAPSE_VERSION from synapse.util.logcontext import LoggingContext @@ -63,6 +65,7 @@ class AdminCmdSlavedStore( FilteringWorkerStore, + ClientIpWorkerStore, DeviceWorkerStore, TagsWorkerStore, DeviceInboxWorkerStore, @@ -82,6 +85,7 @@ class AdminCmdSlavedStore( EventsWorkerStore, RegistrationWorkerStore, RoomWorkerStore, + ProfileWorkerStore, ): def __init__( self, @@ -192,6 +196,32 @@ def write_knock( for event in state.values(): print(json.dumps(event), file=f) + def write_profile(self, profile: JsonDict) -> None: + user_directory = os.path.join(self.base_directory, "user_data") + os.makedirs(user_directory, exist_ok=True) + profile_file = os.path.join(user_directory, "profile") + + with open(profile_file, "a") as f: + print(json.dumps(profile), file=f) + + def write_devices(self, devices: List[JsonDict]) -> None: + user_directory = os.path.join(self.base_directory, "user_data") + os.makedirs(user_directory, exist_ok=True) + device_file = os.path.join(user_directory, "devices") + + for device in devices: + with open(device_file, "a") as f: + print(json.dumps(device), file=f) + + def write_connections(self, connections: List[JsonDict]) -> None: + user_directory = os.path.join(self.base_directory, "user_data") + os.makedirs(user_directory, exist_ok=True) + connection_file = os.path.join(user_directory, "connections") + + for connection in connections: + with open(connection_file, "a") as f: + print(json.dumps(connection), file=f) + def finished(self) -> str: return self.base_directory diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py index 8c0f4a57e70a..920538f44df2 100644 --- a/synapse/app/complement_fork_starter.py +++ b/synapse/app/complement_fork_starter.py @@ -110,6 +110,8 @@ def _worker_entrypoint( and then kick off the worker's main() function. """ + from synapse.util.stringutils import strtobool + sys.argv = args # reset the custom signal handlers that we installed, so that the children start @@ -117,9 +119,24 @@ def _worker_entrypoint( for sig, handler in _original_signal_handlers.items(): signal.signal(sig, handler) - from twisted.internet.epollreactor import EPollReactor + # Install the asyncio reactor if the + # SYNAPSE_COMPLEMENT_FORKING_LAUNCHER_ASYNC_IO_REACTOR is set to 1. The + # SYNAPSE_ASYNC_IO_REACTOR variable would be used, but then causes + # synapse/__init__.py to also try to install an asyncio reactor. 
+ if strtobool( + os.environ.get("SYNAPSE_COMPLEMENT_FORKING_LAUNCHER_ASYNC_IO_REACTOR", "0") + ): + import asyncio + + from twisted.internet.asyncioreactor import AsyncioSelectorReactor + + reactor = AsyncioSelectorReactor(asyncio.get_event_loop()) + proxy_reactor._install_real_reactor(reactor) + else: + from twisted.internet.epollreactor import EPollReactor + + proxy_reactor._install_real_reactor(EPollReactor()) - proxy_reactor._install_real_reactor(EPollReactor()) func() diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 50dc86841921..a9462c9c5262 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -201,6 +201,9 @@ def _listen_http(self, listener_config: ListenerConfig) -> None: "A 'media' listener is configured but the media" " repository is disabled. Ignoring." ) + elif name == "health": + # Skip loading, health resource is always included + continue if name == "openid" and "federation" not in res.names: # Only load the openid resource separately if federation resource @@ -281,13 +284,6 @@ def start(config_options: List[str]) -> None: "synapse.app.user_dir", ) - if config.experimental.faster_joins_enabled: - raise ConfigError( - "You have enabled the experimental `faster_joins` config option, but it is " - "not compatible with worker deployments yet. Please disable `faster_joins` " - "or run Synapse as a single process deployment instead." - ) - synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b9be558c7ea0..6176a70eb2a0 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -96,6 +96,9 @@ def _listener_http( # Skip loading openid resource if federation is defined # since federation resource will include openid continue + if name == "health": + # Skip loading, health resource is always included + continue resources.update(self._configure_named_resource(name, res.compress)) additional_resources = listener_config.http_options.additional_resources diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 288aa5814cb3..c99c0ec899ba 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -16,7 +16,7 @@ import logging import re from enum import Enum -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern, Sequence import attr from netaddr import IPSet @@ -382,7 +382,7 @@ def __init__( self, service: ApplicationService, id: int, - events: List[EventBase], + events: Sequence[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], one_time_keys_count: TransactionOneTimeKeysCount, diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index edafd433cda3..1a6f69e7d3f3 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -14,7 +14,17 @@ # limitations under the License. 
import logging import urllib.parse -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, +) from prometheus_client import Counter from typing_extensions import TypeGuard @@ -259,7 +269,7 @@ async def _get() -> Optional[JsonDict]: async def push_bulk( self, service: "ApplicationService", - events: List[EventBase], + events: Sequence[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], one_time_keys_count: TransactionOneTimeKeysCount, diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index e11a312e6bc8..2c9e587e3a82 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -57,6 +57,7 @@ Iterable, List, Optional, + Sequence, Set, Tuple, ) @@ -364,7 +365,7 @@ def __init__(self, clock: Clock, store: DataStore, as_api: ApplicationServiceApi async def send( self, service: ApplicationService, - events: List[EventBase], + events: Sequence[EventBase], ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1f6362aedd56..2ce60610ca6f 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -174,15 +174,29 @@ def __init__(self, root_config: "RootConfig" = None): @staticmethod def parse_size(value: Union[str, int]) -> int: - if isinstance(value, int): + """Interpret `value` as a number of bytes. + + If an integer is provided it is treated as bytes and is unchanged. + + String byte sizes can have a suffix of 'K' or `M`, representing kibibytes and + mebibytes respectively. No suffix is understood as a plain byte count. + + Raises: + TypeError, if given something other than an integer or a string + ValueError: if given a string not of the form described above. + """ + if type(value) is int: return value - sizes = {"K": 1024, "M": 1024 * 1024} - size = 1 - suffix = value[-1] - if suffix in sizes: - value = value[:-1] - size = sizes[suffix] - return int(value) * size + elif type(value) is str: + sizes = {"K": 1024, "M": 1024 * 1024} + size = 1 + suffix = value[-1] + if suffix in sizes: + value = value[:-1] + size = sizes[suffix] + return int(value) * size + else: + raise TypeError(f"Bad byte size {value!r}") @staticmethod def parse_duration(value: Union[str, int]) -> int: @@ -198,22 +212,36 @@ def parse_duration(value: Union[str, int]) -> int: Returns: The number of milliseconds in the duration. + + Raises: + TypeError, if given something other than an integer or a string + ValueError: if given a string not of the form described above. 
""" - if isinstance(value, int): + if type(value) is int: return value - second = 1000 - minute = 60 * second - hour = 60 * minute - day = 24 * hour - week = 7 * day - year = 365 * day - sizes = {"s": second, "m": minute, "h": hour, "d": day, "w": week, "y": year} - size = 1 - suffix = value[-1] - if suffix in sizes: - value = value[:-1] - size = sizes[suffix] - return int(value) * size + elif type(value) is str: + second = 1000 + minute = 60 * second + hour = 60 * minute + day = 24 * hour + week = 7 * day + year = 365 * day + sizes = { + "s": second, + "m": minute, + "h": hour, + "d": day, + "w": week, + "y": year, + } + size = 1 + suffix = value[-1] + if suffix in sizes: + value = value[:-1] + size = sizes[suffix] + return int(value) * size + else: + raise TypeError(f"Bad duration {value!r}") @staticmethod def abspath(file_path: str) -> str: diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index bd265de53613..b5cec132b4c1 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -1,5 +1,3 @@ -from __future__ import annotations - import argparse from typing import ( Any, @@ -20,7 +18,7 @@ from typing import ( import jinja2 -from synapse.config import ( +from synapse.config import ( # noqa: F401 account_validity, api, appservice, @@ -169,7 +167,7 @@ class RootConfig: self, section_name: Literal["caches"] ) -> cache.CacheConfig: ... @overload - def reload_config_section(self, section_name: str) -> Config: ... + def reload_config_section(self, section_name: str) -> "Config": ... class Config: root: RootConfig @@ -202,9 +200,9 @@ def find_config_files(search_paths: List[str]) -> List[str]: ... class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... - def should_handle(self, instance_name: str, key: str) -> bool: ... + def should_handle(self, instance_name: str, key: str) -> bool: ... # noqa: F811 class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): - def get_instance(self, key: str) -> str: ... + def get_instance(self, key: str) -> str: ... # noqa: F811 def read_file(file_path: Any, config_path: Iterable[str]) -> str: ... 
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 015b2a138e85..05f69cb1bacc 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -126,7 +126,7 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         cache_config = config.get("caches") or {}
         self.global_factor = cache_config.get("global_factor", _DEFAULT_FACTOR_SIZE)
-        if not isinstance(self.global_factor, (int, float)):
+        if type(self.global_factor) not in (int, float):
             raise ConfigError("caches.global_factor must be a number.")
 
         # Load cache factors from the config
@@ -151,7 +151,7 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
             )
 
         for cache, factor in individual_factors.items():
-            if not isinstance(factor, (int, float)):
+            if type(factor) not in (int, float):
                 raise ConfigError(
                     "caches.per_cache_factors.%s must be a number" % (cache,)
                 )
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 59a79dc514c5..12d6f51d12b8 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -17,6 +17,7 @@ import attr
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
+from synapse.config import ConfigError
 from synapse.config._base import Config
 from synapse.types import JsonDict
 
@@ -78,12 +79,16 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         )
 
         # MSC3706 (server-side support for partial state in /send_join responses)
+        # Synapse will always serve partial state responses to requests using the stable
+        # query parameter `omit_members`. If this flag is set, Synapse will also serve
+        # partial state responses to requests using the unstable query parameter
+        # `org.matrix.msc3706.partial_state`.
         self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
 
         # experimental support for faster joins over federation
         # (MSC2775, MSC3706, MSC3895)
-        # requires a target server with msc3706_enabled enabled.
-        self.faster_joins_enabled: bool = experimental.get("faster_joins", False)
+        # requires a target server that can provide a partial join response (MSC3706)
+        self.faster_joins_enabled: bool = experimental.get("faster_joins", True)
 
         # MSC3720 (Account status endpoint)
         self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False)
@@ -97,6 +102,9 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         # MSC2815 (allow room moderators to view redacted event content)
         self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False)
 
+        # MSC3391: Removing account data.
+        self.msc3391_enabled = experimental.get("msc3391_enabled", False)
+
         # MSC3773: Thread notifications
         self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
@@ -131,6 +139,24 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
             "msc3886_endpoint", None
         )
+
+        # MSC3890: Remotely silence local notifications
+        # Note: This option requires "experimental_features.msc3391_enabled" to be
+        # set to "true", in order to communicate account data deletions to clients.
+        self.msc3890_enabled: bool = experimental.get("msc3890_enabled", False)
+        if self.msc3890_enabled and not self.msc3391_enabled:
+            raise ConfigError(
+                "Option 'experimental_features.msc3391' must be set to 'true' to "
+                "enable 'experimental_features.msc3890'. MSC3391 functionality is "
+                "required to communicate account data deletions to clients."
+            )
+
+        # MSC3381: Polls.
+        # In practice, supporting polls in Synapse only requires an implementation of
+        # MSC3930: Push rules for MSC3381 polls; which is what this option enables.
+        self.msc3381_polls_enabled: bool = experimental.get(
+            "msc3381_polls_enabled", False
+        )
+
         # MSC3912: Relation-based redactions.
         self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False)
@@ -148,3 +174,16 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
             "beeper_user_notification_counts_enabled",
             False,
         )
+
+        # MSC3925: do not replace events with their edits
+        self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False)
+
+        # MSC3952: Intentional mentions
+        self.msc3952_intentional_mentions = experimental.get(
+            "msc3952_intentional_mentions", False
+        )
+
+        # MSC3958: Do not generate notifications for edits.
+        self.msc3958_supress_edit_notifs = experimental.get(
+            "msc3958_supress_edit_notifs", False
+        )
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 5468b963a2c1..56db875b25b1 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -34,6 +34,7 @@ import attr
 
 from synapse.logging.context import LoggingContextFilter
 from synapse.logging.filter import MetadataFilter
+from synapse.synapse_rust import reset_logging_config
 from synapse.types import JsonDict
 
 from ..util import SYNAPSE_VERSION
@@ -200,24 +201,6 @@ def _setup_stdlib_logging(
     """
     Set up Python standard library logging.
     """
-    if log_config_path is None:
-        log_format = (
-            "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
-            " - %(message)s"
-        )
-
-        logger = logging.getLogger("")
-        logger.setLevel(logging.INFO)
-        logging.getLogger("synapse.storage.SQL").setLevel(logging.INFO)
-
-        formatter = logging.Formatter(log_format)
-
-        handler = logging.StreamHandler()
-        handler.setFormatter(formatter)
-        logger.addHandler(handler)
-    else:
-        # Load the logging configuration.
-        _load_logging_config(log_config_path)
 
     # We add a log record factory that runs all messages through the
     # LoggingContextFilter so that we get the context *at the time we log*
@@ -237,6 +220,26 @@ def factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
 
     logging.setLogRecordFactory(factory)
 
+    # Configure the logger with the initial configuration.
+    if log_config_path is None:
+        log_format = (
+            "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
+            " - %(message)s"
+        )
+
+        logger = logging.getLogger("")
+        logger.setLevel(logging.INFO)
+        logging.getLogger("synapse.storage.SQL").setLevel(logging.INFO)
+
+        formatter = logging.Formatter(log_format)
+
+        handler = logging.StreamHandler()
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+    else:
+        # Load the logging configuration.
+        _load_logging_config(log_config_path)
+
     # Route Twisted's native logging through to the standard library logging
     # system.
     observer = STDLibLogObserver()
@@ -294,6 +297,9 @@ def _load_logging_config(log_config_path: str) -> None:
 
     logging.config.dictConfig(log_config)
 
+    # Blow away the pyo3-log cache so that it reloads the configuration.
+    reset_logging_config()
+
 
 def _reload_logging_config(log_config_path: Optional[str]) -> None:
     """
diff --git a/synapse/config/server.py b/synapse/config/server.py
index ec46ca63adf5..ecdaa2d9dd24 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -151,7 +151,7 @@ def generate_ip_set(
     "fec0::/10",
 ]
 
-DEFAULT_ROOM_VERSION = "9"
+DEFAULT_ROOM_VERSION = "10"
 
 ROOM_COMPLEXITY_TOO_GREAT = (
    "Your homeserver is unable to join rooms this large or complex. 
" @@ -904,7 +904,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: raise ConfigError(DIRECT_TCP_ERROR, ("listeners", str(num), "type")) port = listener.get("port") - if not isinstance(port, int): + if type(port) is not int: raise ConfigError("Listener configuration is lacking a valid 'port' option") tls = listener.get("tls", False) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 69310d90351c..86cd4af9bd5a 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -154,17 +154,21 @@ def __init__( if key_fetchers is None: key_fetchers = ( + # Fetch keys from the database. StoreKeyFetcher(hs), + # Fetch keys from a configured Perspectives server. PerspectivesKeyFetcher(hs), + # Fetch keys from the origin server directly. ServerKeyFetcher(hs), ) self._key_fetchers = key_fetchers - self._server_queue: BatchingQueue[ + self._fetch_keys_queue: BatchingQueue[ _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]] ] = BatchingQueue( "keyring_server", clock=hs.get_clock(), + # The method called to fetch each key process_batch_callback=self._inner_fetch_key_requests, ) @@ -287,7 +291,7 @@ async def process_request(self, verify_request: VerifyJsonRequest) -> None: minimum_valid_until_ts=verify_request.minimum_valid_until_ts, key_ids=list(key_ids_to_find), ) - found_keys_by_server = await self._server_queue.add_to_queue( + found_keys_by_server = await self._fetch_keys_queue.add_to_queue( key_request, key=verify_request.server_name ) @@ -352,7 +356,17 @@ async def _process_json( async def _inner_fetch_key_requests( self, requests: List[_FetchKeyRequest] ) -> Dict[str, Dict[str, FetchKeyResult]]: - """Processing function for the queue of `_FetchKeyRequest`.""" + """Processing function for the queue of `_FetchKeyRequest`. + + Takes a list of key fetch requests, de-duplicates them and then carries out + each request by invoking self._inner_fetch_key_request. + + Args: + requests: A list of requests for homeserver verify keys. + + Returns: + {server name: {key id: fetch key result}} + """ logger.debug("Starting fetch for %s", requests) @@ -397,8 +411,23 @@ async def _inner_fetch_key_requests( async def _inner_fetch_key_request( self, verify_request: _FetchKeyRequest ) -> Dict[str, FetchKeyResult]: - """Attempt to fetch the given key by calling each key fetcher one by - one. + """Attempt to fetch the given key by calling each key fetcher one by one. + + If a key is found, check whether its `valid_until_ts` attribute satisfies the + `minimum_valid_until_ts` attribute of the `verify_request`. If it does, we + refrain from asking subsequent fetchers for that key. + + Even if the above check fails, we still return the found key - the caller may + still find the invalid key result useful. In this case, we continue to ask + subsequent fetchers for the invalid key, in case they return a valid result + for it. This can happen when fetching a stale key result from the database, + before querying the origin server for an up-to-date result. + + Args: + verify_request: The request for a verify key. Can include multiple key IDs. + + Returns: + A map of {key_id: the key fetch result}. """ logger.debug("Starting fetch for %s", verify_request) @@ -420,26 +449,22 @@ async def _inner_fetch_key_request( if not key: continue - # If we already have a result for the given key ID we keep the + # If we already have a result for the given key ID, we keep the # one with the highest `valid_until_ts`. 
existing_key = found_keys.get(key_id) - if existing_key: - if key.valid_until_ts <= existing_key.valid_until_ts: - continue + if existing_key and existing_key.valid_until_ts > key.valid_until_ts: + continue + + # Check if this key's expiry timestamp is valid for the verify request. + if key.valid_until_ts >= verify_request.minimum_valid_until_ts: + # Stop looking for this key from subsequent fetchers. + missing_key_ids.discard(key_id) - # We always store the returned key even if it doesn't the + # We always store the returned key even if it doesn't meet the # `minimum_valid_until_ts` requirement, as some verification # requests may still be able to be satisfied by it. - # - # We still keep looking for the key from other fetchers in that - # case though. found_keys[key_id] = key - if key.valid_until_ts < verify_request.minimum_valid_until_ts: - continue - - missing_key_ids.discard(key_id) - return found_keys diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c4a7b16413c5..e0be9f88cc9d 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -875,11 +875,11 @@ def _check_power_levels( "kick", "invite", }: - if not isinstance(v, int): + if type(v) is not int: raise SynapseError(400, f"{v!r} must be an integer.") if k in {"events", "notifications", "users"}: if not isinstance(v, collections.abc.Mapping) or not all( - isinstance(v, int) for v in v.values() + type(v) is int for v in v.values() ): raise SynapseError( 400, diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 3a92f6284835..caf6ba0605af 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -408,6 +408,14 @@ class EventClientSerializer: clients. """ + def __init__(self, inhibit_replacement_via_edits: bool = False): + """ + Args: + inhibit_replacement_via_edits: If this is set to True, then events are + never replaced by their edits. + """ + self._inhibit_replacement_via_edits = inhibit_replacement_via_edits + def serialize_event( self, event: Union[JsonDict, EventBase], @@ -427,6 +435,8 @@ def serialize_event( into the event. apply_edits: Whether the content of the event should be modified to reflect any replacement in `bundle_aggregations[].replace`. + See also the `inhibit_replacement_via_edits` constructor arg: if that is + set to True, then this argument is ignored. Returns: The serialized event """ @@ -500,7 +510,8 @@ def _inject_bundled_aggregations( again for additional events in a recursive manner. serialized_event: The serialized event which may be modified. apply_edits: Whether the content of the event should be modified to reflect - any replacement in `aggregations.replace`. + any replacement in `aggregations.replace` (subject to the + `inhibit_replacement_via_edits` constructor arg). """ # We have already checked that aggregations exist for this event. @@ -523,15 +534,21 @@ def _inject_bundled_aggregations( if event_aggregations.replace: # If there is an edit, optionally apply it to the event. edit = event_aggregations.replace - if apply_edits: + if apply_edits and not self._inhibit_replacement_via_edits: self._apply_edit(event, serialized_event, edit) # Include information about it in the relations dict. 
- serialized_aggregations[RelationTypes.REPLACE] = { - "event_id": edit.event_id, - "origin_server_ts": edit.origin_server_ts, - "sender": edit.sender, - } + # + # Matrix spec v1.5 (https://spec.matrix.org/v1.5/client-server-api/#server-side-aggregation-of-mreplace-relationships) + # said that we should only include the `event_id`, `origin_server_ts` and + # `sender` of the edit; however MSC3925 proposes extending it to the whole + # of the edit, which is what we do here. + serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event( + edit, + time_now, + config=config, + apply_edits=False, + ) # Include any threaded replies to this event. if event_aggregations.thread: @@ -593,10 +610,11 @@ def serialize_events( _PowerLevel = Union[str, int] +PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] def copy_and_fixup_power_levels_contents( - old_power_levels: Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] + old_power_levels: PowerLevelsContent, ) -> Dict[str, Union[int, Dict[str, int]]]: """Copy the content of a power_levels event, unfreezing frozendicts along the way. @@ -635,10 +653,10 @@ def _copy_power_level_value_as_integer( ) -> None: """Set `power_levels[key]` to the integer represented by `old_value`. - :raises TypeError: if `old_value` is not an integer, nor a base-10 string + :raises TypeError: if `old_value` is neither an integer nor a base-10 string representation of an integer. """ - if isinstance(old_value, int): + if type(old_value) is int: power_levels[key] = old_value return @@ -666,7 +684,7 @@ def validate_canonicaljson(value: Any) -> None: * Floats * NaN, Infinity, -Infinity """ - if isinstance(value, int): + if type(value) is int: if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value: raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 01a11926f6e8..08d601b8cf99 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -144,7 +144,7 @@ def _validate_retention(self, event: EventBase) -> None: max_lifetime = event.content.get("max_lifetime") if min_lifetime is not None: - if not isinstance(min_lifetime, int): + if type(min_lifetime) is not int: raise SynapseError( code=400, msg="'min_lifetime' must be an integer", @@ -152,7 +152,7 @@ def _validate_retention(self, event: EventBase) -> None: ) if max_lifetime is not None: - if not isinstance(max_lifetime, int): + if type(max_lifetime) is not int: raise SynapseError( code=400, msg="'max_lifetime' must be an integer", diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 6bd4742140c4..29fae716f589 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -280,7 +280,7 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB _strip_unsigned_values(pdu_json) depth = pdu_json["depth"] - if not isinstance(depth, int): + if type(depth) is not int: raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 137cfb3346d2..0ac85a3be717 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -19,6 +19,7 @@ import logging from typing import ( TYPE_CHECKING, + AbstractSet, Awaitable, Callable, Collection, @@ -37,7 +38,7 @@ import attr from prometheus_client import Counter -from 
synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership from synapse.api.errors import ( CodeMessageException, Codes, @@ -110,8 +111,9 @@ class SendJoinResult: # True if 'state' elides non-critical membership events partial_state: bool - # if 'partial_state' is set, a list of the servers in the room (otherwise empty) - servers_in_room: List[str] + # If 'partial_state' is set, a set of the servers in the room (otherwise empty). + # Always contains the server we joined off. + servers_in_room: AbstractSet[str] class FederationClient(FederationBase): @@ -1014,7 +1016,11 @@ async def send_request(destination: str) -> Tuple[str, EventBase, RoomVersion]: ) async def send_join( - self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion + self, + destinations: Iterable[str], + pdu: EventBase, + room_version: RoomVersion, + partial_state: bool = True, ) -> SendJoinResult: """Sends a join event to one of a list of homeservers. @@ -1027,6 +1033,10 @@ async def send_join( pdu: event to be sent room_version: the version of the room (according to the server that did the make_join) + partial_state: whether to ask the remote server to omit membership state + events from the response. If the remote server complies, + `partial_state` in the send join result will be set. Defaults to + `True`. Returns: The result of the send join request. @@ -1037,7 +1047,9 @@ async def send_join( """ async def send_request(destination: str) -> SendJoinResult: - response = await self._do_send_join(room_version, destination, pdu) + response = await self._do_send_join( + room_version, destination, pdu, omit_members=partial_state + ) # If an event was returned (and expected to be returned): # @@ -1142,18 +1154,32 @@ async def _execute(pdu: EventBase) -> None: % (auth_chain_create_events,) ) - if response.partial_state and not response.servers_in_room: - raise InvalidResponseError( - "partial_state was set, but no servers were listed in the room" - ) + servers_in_room = None + if response.servers_in_room is not None: + servers_in_room = set(response.servers_in_room) + + if response.members_omitted: + if not servers_in_room: + raise InvalidResponseError( + "members_omitted was set, but no servers were listed in the room" + ) + + if not partial_state: + raise InvalidResponseError( + "members_omitted was set, but we asked for full state" + ) + + # `servers_in_room` is supposed to be a complete list. + # Fix things up in case the remote homeserver is badly behaved. + servers_in_room.add(destination) return SendJoinResult( event=event, state=signed_state, auth_chain=signed_auth, origin=destination, - partial_state=response.partial_state, - servers_in_room=response.servers_in_room or [], + partial_state=response.members_omitted, + servers_in_room=servers_in_room or frozenset(), ) # MSC3083 defines additional error codes for room joins. 
@@ -1177,7 +1203,11 @@ async def _execute(pdu: EventBase) -> None: ) async def _do_send_join( - self, room_version: RoomVersion, destination: str, pdu: EventBase + self, + room_version: RoomVersion, + destination: str, + pdu: EventBase, + omit_members: bool, ) -> SendJoinResponse: time_now = self._clock.time_msec() @@ -1188,6 +1218,7 @@ async def _do_send_join( room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), + omit_members=omit_members, ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, @@ -1660,7 +1691,12 @@ async def send_request( return result async def timestamp_to_event( - self, *, destinations: List[str], room_id: str, timestamp: int, direction: str + self, + *, + destinations: List[str], + room_id: str, + timestamp: int, + direction: Direction, ) -> Optional["TimestampToEventResponse"]: """ Calls each remote federating server from `destinations` asking for their closest @@ -1673,7 +1709,7 @@ async def timestamp_to_event( room_id: Room to fetch the event from timestamp: The point in time (inclusive) we should navigate from in the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward + direction: indicates whether we should navigate forward or backward from the given timestamp to find the closest event. Returns: @@ -1718,7 +1754,7 @@ async def _timestamp_to_event_from_destination( return None async def _timestamp_to_event_from_destination( - self, destination: str, room_id: str, timestamp: int, direction: str + self, destination: str, room_id: str, timestamp: int, direction: Direction ) -> "TimestampToEventResponse": """ Calls a remote federating server at `destination` asking for their @@ -1731,7 +1767,7 @@ async def _timestamp_to_event_from_destination( room_id: Room to fetch the event from timestamp: The point in time (inclusive) we should navigate from in the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward + direction: indicates whether we should navigate forward or backward from the given timestamp to find the closest event. 
Returns: @@ -1844,7 +1880,7 @@ def from_json_dict(cls, d: JsonDict) -> "TimestampToEventResponse": ) origin_server_ts = d.get("origin_server_ts") - if not isinstance(origin_server_ts, int): + if type(origin_server_ts) is not int: raise ValueError( "Invalid response: 'origin_server_ts' must be a int but received %r" % origin_server_ts diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index bb20af6e91ed..8d36172484d6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -34,7 +34,13 @@ from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership +from synapse.api.constants import ( + Direction, + EduTypes, + EventContentFields, + EventTypes, + Membership, +) from synapse.api.errors import ( AuthError, Codes, @@ -62,7 +68,9 @@ run_in_background, ) from synapse.logging.opentracing import ( + SynapseTags, log_kv, + set_tag, start_active_span_from_edu, tag_args, trace, @@ -216,7 +224,7 @@ async def on_backfill_request( return 200, res async def on_timestamp_to_event_request( - self, origin: str, room_id: str, timestamp: int, direction: str + self, origin: str, room_id: str, timestamp: int, direction: Direction ) -> Tuple[int, Dict[str, Any]]: """When we receive a federated `/timestamp_to_event` request, handle all of the logic for validating and fetching the event. @@ -226,7 +234,7 @@ async def on_timestamp_to_event_request( room_id: Room to fetch the event from timestamp: The point in time (inclusive) we should navigate from in the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward + direction: indicates whether we should navigate forward or backward from the given timestamp to find the closest event. 
Returns: @@ -678,6 +686,10 @@ async def on_send_join_request( room_id: str, caller_supports_partial_state: bool = False, ) -> Dict[str, Any]: + set_tag( + SynapseTags.SEND_JOIN_RESPONSE_IS_PARTIAL_STATE, + caller_supports_partial_state, + ) await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type] requester=None, key=room_id, @@ -725,10 +737,12 @@ async def on_send_join_request( "state": [p.get_pdu_json(time_now) for p in state_events], "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events], "org.matrix.msc3706.partial_state": caller_supports_partial_state, + "members_omitted": caller_supports_partial_state, } if servers_in_room is not None: resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room) + resp["servers_in_room"] = list(servers_in_room) return resp @@ -1500,7 +1514,7 @@ def _get_event_ids_for_partial_state_join( prev_state_ids: StateMap[str], summary: Dict[str, MemberSummary], ) -> Collection[str]: - """Calculate state to be retuned in a partial_state send_join + """Calculate state to be returned in a partial_state send_join Args: join_event: the join event being send_joined diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index b6c69ee8e769..902214f6513e 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -454,7 +454,7 @@ async def handle_event(event: EventBase) -> None: ) ) - if len(partial_state_destinations) > 0: + if partial_state_destinations is not None: destinations = partial_state_destinations if destinations is None: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 77f1f39cacb1..c05d598b70cf 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -32,7 +32,7 @@ import attr import ijson -from synapse.api.constants import Membership +from synapse.api.constants import Direction, Membership from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.room_versions import RoomVersion from synapse.api.urls import ( @@ -102,6 +102,10 @@ async def get_room_state( destination, path=path, args={"event_id": event_id}, + # This can take a looooooong time for large rooms. Give this a generous + # timeout of 10 minutes to avoid the partial state resync timing out early + # and trying a bunch of servers who haven't seen our join yet. + timeout=600_000, parser=_StateParser(room_version), ) @@ -165,7 +169,7 @@ async def backfill( ) async def timestamp_to_event( - self, destination: str, room_id: str, timestamp: int, direction: str + self, destination: str, room_id: str, timestamp: int, direction: Direction ) -> Union[JsonDict, List]: """ Calls a remote federating server at `destination` asking for their @@ -176,7 +180,7 @@ async def timestamp_to_event( room_id: Room to fetch the event from timestamp: The point in time (inclusive) we should navigate from in the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward + direction: indicates whether we should navigate forward or backward from the given timestamp to find the closest event. 
Returns: @@ -190,7 +194,7 @@ async def timestamp_to_event( room_id, ) - args = {"ts": [str(timestamp)], "dir": [direction]} + args = {"ts": [str(timestamp)], "dir": [direction.value]} remote_response = await self.client.get_json( destination, path=path, args=args, try_trailing_slash_on_400=True @@ -351,12 +355,16 @@ async def send_join_v2( room_id: str, event_id: str, content: JsonDict, + omit_members: bool, ) -> "SendJoinResponse": path = _create_v2_path("/send_join/%s/%s", room_id, event_id) query_params: Dict[str, str] = {} if self._faster_joins_enabled: # lazy-load state on join - query_params["org.matrix.msc3706.partial_state"] = "true" + query_params["org.matrix.msc3706.partial_state"] = ( + "true" if omit_members else "false" + ) + query_params["omit_members"] = "true" if omit_members else "false" return await self.client.put_json( destination=destination, @@ -794,7 +802,7 @@ class SendJoinResponse: event: Optional[EventBase] = None # The room state is incomplete - partial_state: bool = False + members_omitted: bool = False # List of servers in the room servers_in_room: Optional[List[str]] = None @@ -834,16 +842,18 @@ def _event_list_parser( @ijson.coroutine -def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]: +def _members_omitted_parser(response: SendJoinResponse) -> Generator[None, Any, None]: """Helper function for use with `ijson.items_coro` - Parses the partial_state field in send_join responses + Parses the members_omitted field in send_join responses """ while True: val = yield if not isinstance(val, bool): - raise TypeError("partial_state must be a boolean") - response.partial_state = val + raise TypeError( + "members_omitted (formerly org.matrix.msc370c.partial_state) must be a boolean" + ) + response.members_omitted = val @ijson.coroutine @@ -904,11 +914,19 @@ def __init__(self, room_version: RoomVersion, v1_api: bool): if not v1_api: self._coros.append( ijson.items_coro( - _partial_state_parser(self._response), + _members_omitted_parser(self._response), "org.matrix.msc3706.partial_state", use_float="True", ) ) + # The stable field name comes last, so it "wins" if the fields disagree + self._coros.append( + ijson.items_coro( + _members_omitted_parser(self._response), + "members_omitted", + use_float="True", + ) + ) self._coros.append( ijson.items_coro( @@ -918,6 +936,15 @@ def __init__(self, room_version: RoomVersion, v1_api: bool): ) ) + # Again, stable field name comes last + self._coros.append( + ijson.items_coro( + _servers_in_room_parser(self._response), + "servers_in_room", + use_float="True", + ) + ) + def write(self, data: bytes) -> int: for c in self._coros: c.send(data) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 53e77b4bb62b..f7ca87adc491 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -26,7 +26,7 @@ from typing_extensions import Literal -from synapse.api.constants import EduTypes +from synapse.api.constants import Direction, EduTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX @@ -234,9 +234,10 @@ async def on_GET( room_id: str, ) -> Tuple[int, JsonDict]: timestamp = parse_integer_from_args(query, "ts", required=True) - direction = parse_string_from_args( - query, "dir", default="f", allowed_values=["f", "b"], required=True + direction_str = 
parse_string_from_args( + query, "dir", allowed_values=["f", "b"], required=True ) + direction = Direction(direction_str) return await self.handler.on_timestamp_to_event_request( origin, room_id, timestamp, direction @@ -422,7 +423,7 @@ def __init__( server_name: str, ): super().__init__(hs, authenticator, ratelimiter, server_name) - self._msc3706_enabled = hs.config.experimental.msc3706_enabled + self._read_msc3706_query_param = hs.config.experimental.msc3706_enabled async def on_PUT( self, @@ -436,10 +437,16 @@ async def on_PUT( # match those given in content partial_state = False - if self._msc3706_enabled: + # The stable query parameter wins, if it disagrees with the unstable + # parameter for some reason. + stable_param = parse_boolean_from_args(query, "omit_members", default=None) + if stable_param is not None: + partial_state = stable_param + elif self._read_msc3706_query_param: partial_state = parse_boolean_from_args( query, "org.matrix.msc3706.partial_state", default=False ) + result = await self.handler.on_send_join_request( origin, content, room_id, caller_supports_partial_state=partial_state ) diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index aba7315cf730..67e789eef70e 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -14,8 +14,9 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple +from synapse.api.constants import AccountDataTypes from synapse.replication.http.account_data import ( ReplicationAddRoomAccountDataRestServlet, ReplicationAddTagRestServlet, @@ -25,7 +26,7 @@ ReplicationRemoveUserAccountDataRestServlet, ) from synapse.streams import EventSource -from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -313,7 +314,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main - def get_current_key(self, direction: str = "f") -> int: + def get_current_key(self) -> int: return self.store.get_max_account_data_stream_id() async def get_new_events( @@ -321,7 +322,7 @@ async def get_new_events( user: UserID, from_key: int, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[JsonDict], int]: @@ -335,7 +336,11 @@ async def get_new_events( for room_id, room_tags in tags.items(): results.append( - {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id} + { + "type": AccountDataTypes.TAG, + "content": {"tags": room_tags}, + "room_id": room_id, + } ) ( diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 5bf8e863875b..b03c214b145a 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -16,7 +16,7 @@ import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set -from synapse.api.constants import Membership +from synapse.api.constants import Direction, Membership from synapse.events import EventBase from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID from synapse.visibility import filter_events_for_client @@ -30,6 +30,7 @@ class AdminHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._device_handler = 
hs.get_device_handler() self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state self._msc3866_enabled = hs.config.experimental.msc3866.enabled @@ -197,7 +198,7 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> # efficient method perhaps but it does guarantee we get everything. while True: events, _ = await self.store.paginate_room_events( - room_id, from_key, to_key, limit=100, direction="f" + room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS ) if not events: break @@ -247,6 +248,21 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> ) writer.write_state(room_id, event_id, state) + # Get the user profile + profile = await self.get_user(UserID.from_string(user_id)) + if profile is not None: + writer.write_profile(profile) + + # Get all devices the user has + devices = await self._device_handler.get_devices_by_user(user_id) + writer.write_devices(devices) + + # Get all connections the user has + connections = await self.get_whois(UserID.from_string(user_id)) + writer.write_connections( + connections["devices"][""]["sessions"][0]["connections"] + ) + return writer.finished() @@ -297,6 +313,33 @@ def write_knock( """ raise NotImplementedError() + @abc.abstractmethod + def write_profile(self, profile: JsonDict) -> None: + """Write the profile of a user. + + Args: + profile: The user profile. + """ + raise NotImplementedError() + + @abc.abstractmethod + def write_devices(self, devices: List[JsonDict]) -> None: + """Write the devices of a user. + + Args: + devices: The list of devices. + """ + raise NotImplementedError() + + @abc.abstractmethod + def write_connections(self, connections: List[JsonDict]) -> None: + """Write the connections of a user. + + Args: + connections: The list of connections / sessions. + """ + raise NotImplementedError() + @abc.abstractmethod def finished(self) -> Any: """Called when all data has successfully been exported and written. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 89864e111941..6f7963df43ae 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -18,7 +18,6 @@ from typing import ( TYPE_CHECKING, Any, - Collection, Dict, Iterable, List, @@ -45,6 +44,7 @@ ) from synapse.types import ( JsonDict, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -146,7 +146,7 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict: @cancellable async def get_device_changes_in_shared_rooms( - self, user_id: str, room_ids: Collection[str], from_token: StreamToken + self, user_id: str, room_ids: StrCollection, from_token: StreamToken ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. @@ -346,6 +346,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.federation_sender = hs.get_federation_sender() + self._account_data_handler = hs.get_account_data_handler() self._storage_controllers = hs.get_storage_controllers() self.device_list_updater = DeviceListUpdater(hs, self) @@ -502,7 +503,7 @@ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: else: raise - # Delete access tokens and e2e keys for each device. Not optimised as it is not + # Delete data specific to each device. Not optimised as it is not # considered as part of a critical path. 
for device_id in device_ids: await self._auth_handler.delete_access_tokens_for_user( @@ -512,6 +513,14 @@ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: user_id=user_id, device_id=device_id ) + if self.hs.config.experimental.msc3890_enabled: + # Remove any local notification settings for this device in accordance + # with MSC3890. + await self._account_data_handler.remove_account_data_for_user( + user_id, + f"org.matrix.msc3890.local_notification_settings.{device_id}", + ) + await self.notify_device_update(user_id, device_ids) async def update_device(self, user_id: str, device_id: str, content: dict) -> None: @@ -542,7 +551,7 @@ async def update_device(self, user_id: str, device_id: str, content: dict) -> No @trace @measure_func("notify_device_update") async def notify_device_update( - self, user_id: str, device_ids: Collection[str] + self, user_id: str, device_ids: StrCollection ) -> None: """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. @@ -850,6 +859,7 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None: known_hosts_at_join = await self.store.get_partial_state_servers_at_join( room_id ) + assert known_hosts_at_join is not None potentially_changed_hosts.difference_update(known_hosts_at_join) potentially_changed_hosts.discard(self.server_name) @@ -965,6 +975,7 @@ def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.device_handler = device_handler + self._notifier = hs.get_notifier() self._remote_edu_linearizer = Linearizer(name="remote_device_list") @@ -1045,6 +1056,7 @@ async def incoming_device_list_update( user_id, device_id, ) + self._notifier.notify_replication() room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index f91dbbecb79c..a23a8ce2a167 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, List, Mapping, Optional, Union from synapse import event_auth from synapse.api.constants import ( @@ -29,7 +29,7 @@ ) from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.types import StateMap, get_domain_from_id +from synapse.types import StateMap, StrCollection, get_domain_from_id if TYPE_CHECKING: from synapse.server import HomeServer @@ -290,7 +290,7 @@ async def has_restricted_join_rules( async def get_rooms_that_allow_join( self, state_ids: StateMap[str] - ) -> Collection[str]: + ) -> StrCollection: """ Generate a list of rooms in which membership allows access to a room. @@ -331,7 +331,7 @@ async def get_rooms_that_allow_join( return result - async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool: + async def is_user_in_rooms(self, room_ids: StrCollection, user_id: str) -> bool: """ Check whether a user is a member of any of the provided rooms. 
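(Aside, for illustration only and not taken from the patch: the admin data-export changes above add three user-data hooks to the abstract `ExfiltrationWriter` in `synapse/handlers/admin.py`. A hypothetical, minimal in-memory writer covering just those hooks might look like the sketch below; a real writer would subclass `ExfiltrationWriter` and also implement the existing room-data hooks — `write_events`, `write_state`, `write_invite`, `write_knock` — as the `FileExfiltrationWriter` in `synapse/app/admin_cmd.py` does.)

from typing import Any, Dict, List

from synapse.types import JsonDict


class InMemoryUserDataWriter:
    """Sketch: collects the newly exported user data in memory instead of on disk."""

    def __init__(self) -> None:
        self.data: Dict[str, Any] = {"devices": [], "connections": []}

    def write_profile(self, profile: JsonDict) -> None:
        # Called once with the user's profile.
        self.data["profile"] = profile

    def write_devices(self, devices: List[JsonDict]) -> None:
        # Called with the devices registered to the user.
        self.data["devices"].extend(devices)

    def write_connections(self, connections: List[JsonDict]) -> None:
        # Called with the user's connections/sessions from the whois data.
        self.data["connections"].extend(connections)

    def finished(self) -> Dict[str, Any]:
        return self.data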
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index eca75f1108d1..7f64130e0aa1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -22,11 +22,12 @@ from http import HTTPStatus from typing import ( TYPE_CHECKING, - Collection, + AbstractSet, Dict, Iterable, List, Optional, + Set, Tuple, Union, ) @@ -47,7 +48,6 @@ FederationError, FederationPullAttemptBackoffError, HttpResponseException, - LimitExceededError, NotFoundError, RequestSendFailed, SynapseError, @@ -70,7 +70,7 @@ ) from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict, StrCollection, get_domain_from_id from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination @@ -171,12 +171,29 @@ def __init__(self, hs: "HomeServer"): self.third_party_event_rules = hs.get_third_party_event_rules() + # Tracks running partial state syncs by room ID. + # Partial state syncs currently only run on the main process, so it's okay to + # track them in-memory for now. + self._active_partial_state_syncs: Set[str] = set() + # Tracks partial state syncs we may want to restart. + # A dictionary mapping room IDs to (initial destination, other destinations) + # tuples. + self._partial_state_syncs_maybe_needing_restart: Dict[ + str, Tuple[Optional[str], AbstractSet[str]] + ] = {} + # A lock guarding the partial state flag for rooms. + # When the lock is held for a given room, no other concurrent code may + # partial state or un-partial state the room. + self._is_partial_state_room_linearizer = Linearizer( + name="_is_partial_state_room_linearizer" + ) + # if this is the main process, fire off a background process to resume # any partial-state-resync operations which were in flight when we # were shut down. if not hs.config.worker.worker_app: run_as_background_process( - "resume_sync_partial_state_room", self._resume_sync_partial_state_room + "resume_sync_partial_state_room", self._resume_partial_state_room_sync ) @trace @@ -420,7 +437,7 @@ async def _maybe_backfill_inner( ) ) - async def try_backfill(domains: Collection[str]) -> bool: + async def try_backfill(domains: StrCollection) -> bool: # TODO: Should we try multiple of these at a time? # Number of contacted remote homeservers that have denied our backfill @@ -587,7 +604,23 @@ async def do_invite_join( self._federation_event_handler.room_queues[room_id] = [] - await self._clean_room_for_join(room_id) + is_host_joined = await self.store.is_host_joined(room_id, self.server_name) + + if not is_host_joined: + # We may have old forward extremities lying around if the homeserver left + # the room completely in the past. Clear them out. + # + # Note that this check-then-clear is subject to races where + # * the homeserver is in the room and stops being in the room just after + # the check. We won't reset the forward extremities, but that's okay, + # since they will be almost up to date. + # * the homeserver is not in the room and starts being in the room just + # after the check. This can't happen, since `RoomMemberHandler` has a + # linearizer lock which prevents concurrent remote joins into the same + # room. + # In short, the races either have an acceptable outcome or should be + # impossible. 
+ await self._clean_room_for_join(room_id) try: # Try the host we successfully got a response to /make_join/ @@ -599,93 +632,115 @@ async def do_invite_join( except ValueError: pass - ret = await self.federation_client.send_join( - host_list, event, room_version_obj - ) - - event = ret.event - origin = ret.origin - state = ret.state - auth_chain = ret.auth_chain - auth_chain.sort(key=lambda e: e.depth) - - logger.debug("do_invite_join auth_chain: %s", auth_chain) - logger.debug("do_invite_join state: %s", state) - - logger.debug("do_invite_join event: %s", event) + async with self._is_partial_state_room_linearizer.queue(room_id): + already_partial_state_room = await self.store.is_partial_state_room( + room_id + ) - # if this is the first time we've joined this room, it's time to add - # a row to `rooms` with the correct room version. If there's already a - # row there, we should override it, since it may have been populated - # based on an invite request which lied about the room version. - # - # federation_client.send_join has already checked that the room - # version in the received create event is the same as room_version_obj, - # so we can rely on it now. - # - await self.store.upsert_room_on_join( - room_id=room_id, - room_version=room_version_obj, - state_events=state, - ) + ret = await self.federation_client.send_join( + host_list, + event, + room_version_obj, + # Perform a full join when we are already in the room and it is a + # full state room, since we are not allowed to persist a partial + # state join event in a full state room. In the future, we could + # optimize this by always performing a partial state join and + # computing the state ourselves or retrieving it from the remote + # homeserver if necessary. + # + # There's a race where we leave the room, then perform a full join + # anyway. This should end up being fast anyway, since we would + # already have the full room state and auth chain persisted. + partial_state=not is_host_joined or already_partial_state_room, + ) - if ret.partial_state: - # Mark the room as having partial state. - # The background process is responsible for unmarking this flag, - # even if the join fails. - await self.store.store_partial_state_room( + event = ret.event + origin = ret.origin + state = ret.state + auth_chain = ret.auth_chain + auth_chain.sort(key=lambda e: e.depth) + + logger.debug("do_invite_join auth_chain: %s", auth_chain) + logger.debug("do_invite_join state: %s", state) + + logger.debug("do_invite_join event: %s", event) + + # if this is the first time we've joined this room, it's time to add + # a row to `rooms` with the correct room version. If there's already a + # row there, we should override it, since it may have been populated + # based on an invite request which lied about the room version. + # + # federation_client.send_join has already checked that the room + # version in the received create event is the same as room_version_obj, + # so we can rely on it now. + # + await self.store.upsert_room_on_join( room_id=room_id, - servers=ret.servers_in_room, - device_lists_stream_id=self.store.get_device_stream_token(), - joined_via=origin, + room_version=room_version_obj, + state_events=state, ) - try: - max_stream_id = ( - await self._federation_event_handler.process_remote_join( - origin, - room_id, - auth_chain, - state, - event, - room_version_obj, - partial_state=ret.partial_state, + if ret.partial_state and not already_partial_state_room: + # Mark the room as having partial state. 
+ # The background process is responsible for unmarking this flag, + # even if the join fails. + # TODO(faster_joins): + # We may want to reset the partial state info if it's from an + # old, failed partial state join. + # https://github.com/matrix-org/synapse/issues/13000 + await self.store.store_partial_state_room( + room_id=room_id, + servers=ret.servers_in_room, + device_lists_stream_id=self.store.get_device_stream_token(), + joined_via=origin, ) - ) - except PartialStateConflictError as e: - # The homeserver was already in the room and it is no longer partial - # stated. We ought to be doing a local join instead. Turn the error into - # a 429, as a hint to the client to try again. - # TODO(faster_joins): `_should_perform_remote_join` suggests that we may - # do a remote join for restricted rooms even if we have full state. - logger.error( - "Room %s was un-partial stated while processing remote join.", - room_id, - ) - raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0) - else: - # Record the join event id for future use (when we finish the full - # join). We have to do this after persisting the event to keep foreign - # key constraints intact. - if ret.partial_state: - await self.store.write_partial_state_rooms_join_event_id( - room_id, event.event_id + + try: + max_stream_id = ( + await self._federation_event_handler.process_remote_join( + origin, + room_id, + auth_chain, + state, + event, + room_version_obj, + partial_state=ret.partial_state, + ) ) - finally: - # Always kick off the background process that asynchronously fetches - # state for the room. - # If the join failed, the background process is responsible for - # cleaning up — including unmarking the room as a partial state room. - if ret.partial_state: - # Kick off the process of asynchronously fetching the state for this - # room. - run_as_background_process( - desc="sync_partial_state_room", - func=self._sync_partial_state_room, - initial_destination=origin, - other_destinations=ret.servers_in_room, - room_id=room_id, + except PartialStateConflictError: + # This should be impossible, since we hold the lock on the room's + # partial statedness. + logger.error( + "Room %s was un-partial stated while processing remote join.", + room_id, ) + raise + else: + # Record the join event id for future use (when we finish the full + # join). We have to do this after persisting the event to keep + # foreign key constraints intact. + if ret.partial_state and not already_partial_state_room: + # TODO(faster_joins): + # We may want to reset the partial state info if it's from + # an old, failed partial state join. + # https://github.com/matrix-org/synapse/issues/13000 + await self.store.write_partial_state_rooms_join_event_id( + room_id, event.event_id + ) + finally: + # Always kick off the background process that asynchronously fetches + # state for the room. + # If the join failed, the background process is responsible for + # cleaning up — including unmarking the room as a partial state + # room. + if ret.partial_state: + # Kick off the process of asynchronously fetching the state for + # this room. + self._start_partial_state_room_sync( + initial_destination=origin, + other_destinations=ret.servers_in_room, + room_id=room_id, + ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. @@ -1660,24 +1715,104 @@ async def get_room_complexity( # well. 
return None - async def _resume_sync_partial_state_room(self) -> None: + async def _resume_partial_state_room_sync(self) -> None: """Resumes resyncing of all partial-state rooms after a restart.""" assert not self.config.worker.worker_app partial_state_rooms = await self.store.get_partial_state_room_resync_info() for room_id, resync_info in partial_state_rooms.items(): - run_as_background_process( - desc="sync_partial_state_room", - func=self._sync_partial_state_room, + self._start_partial_state_room_sync( initial_destination=resync_info.joined_via, other_destinations=resync_info.servers_in_room, room_id=room_id, ) + def _start_partial_state_room_sync( + self, + initial_destination: Optional[str], + other_destinations: AbstractSet[str], + room_id: str, + ) -> None: + """Starts the background process to resync the state of a partial state room, + if it is not already running. + + Args: + initial_destination: the initial homeserver to pull the state from + other_destinations: other homeservers to try to pull the state from, if + `initial_destination` is unavailable + room_id: room to be resynced + """ + + async def _sync_partial_state_room_wrapper() -> None: + if room_id in self._active_partial_state_syncs: + # Another local user has joined the room while there is already a + # partial state sync running. This implies that there is a new join + # event to un-partial state. We might find ourselves in one of a few + # scenarios: + # 1. There is an existing partial state sync. The partial state sync + # un-partial states the new join event before completing and all is + # well. + # 2. Before the latest join, the homeserver was no longer in the room + # and there is an existing partial state sync from our previous + # membership of the room. The partial state sync may have: + # a) succeeded, but not yet terminated. The room will not be + # un-partial stated again unless we restart the partial state + # sync. + # b) failed, because we were no longer in the room and remote + # homeservers were refusing our requests, but not yet + # terminated. After the latest join, remote homeservers may + # start answering our requests again, so we should restart the + # partial state sync. + # In the cases where we would want to restart the partial state sync, + # the room would have the partial state flag when the partial state sync + # terminates. + self._partial_state_syncs_maybe_needing_restart[room_id] = ( + initial_destination, + other_destinations, + ) + return + + self._active_partial_state_syncs.add(room_id) + + try: + await self._sync_partial_state_room( + initial_destination=initial_destination, + other_destinations=other_destinations, + room_id=room_id, + ) + finally: + # Read the room's partial state flag while we still hold the claim to + # being the active partial state sync (so that another partial state + # sync can't come along and mess with it under us). + # Normally, the partial state flag will be gone. If it isn't, then we + # may find ourselves in scenario 2a or 2b as described in the comment + # above, where we want to restart the partial state sync. 
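`_start_partial_state_room_sync` wraps the actual resync so that (a) a second sync is never started for a room that already has one running, and (b) the destinations are remembered so the sync can be restarted if the room is still partial-stated when the current attempt finishes. A stripped-down sketch of that bookkeeping; `do_sync` and `is_partial_state` are stand-ins for the real resync and storage calls:

```python
import asyncio
from typing import Dict, Optional, Set, Tuple

active_syncs: Set[str] = set()
maybe_needs_restart: Dict[str, Tuple[Optional[str], Set[str]]] = {}


async def do_sync(room_id: str) -> None:
    await asyncio.sleep(0)  # stand-in for the real state resync


async def is_partial_state(room_id: str) -> bool:
    return False  # stand-in for a database lookup


def start_sync(initial: Optional[str], others: Set[str], room_id: str) -> None:
    async def wrapper() -> None:
        if room_id in active_syncs:
            # A sync is already running for this room: just note that it may
            # need restarting once the running one terminates.
            maybe_needs_restart[room_id] = (initial, others)
            return

        active_syncs.add(room_id)
        try:
            await do_sync(room_id)
        finally:
            # Read the flag while still holding the "active" claim, then
            # restart if a restart was requested and the room is still partial.
            still_partial = await is_partial_state(room_id)
            active_syncs.discard(room_id)
            restart = maybe_needs_restart.pop(room_id, None)
            if restart is not None and still_partial:
                start_sync(restart[0], restart[1], room_id)

    asyncio.create_task(wrapper())


async def main() -> None:
    start_sync("matrix.example.org", {"other.example.org"}, "!room:example.org")
    await asyncio.sleep(0.1)
    print(active_syncs)  # set(): the sync has finished and released its claim


asyncio.run(main())
```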
+ is_still_partial_state_room = await self.store.is_partial_state_room( + room_id + ) + self._active_partial_state_syncs.remove(room_id) + + if room_id in self._partial_state_syncs_maybe_needing_restart: + ( + restart_initial_destination, + restart_other_destinations, + ) = self._partial_state_syncs_maybe_needing_restart.pop(room_id) + + if is_still_partial_state_room: + self._start_partial_state_room_sync( + initial_destination=restart_initial_destination, + other_destinations=restart_other_destinations, + room_id=room_id, + ) + + run_as_background_process( + desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper + ) + async def _sync_partial_state_room( self, initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: AbstractSet[str], room_id: str, ) -> None: """Background process to resync the state of a partial-state room @@ -1688,6 +1823,12 @@ async def _sync_partial_state_room( `initial_destination` is unavailable room_id: room to be resynced """ + # Assume that we run on the main process for now. + # TODO(faster_joins,multiple workers) + # When moving the sync to workers, we need to ensure that + # * `_start_partial_state_room_sync` still prevents duplicate resyncs + # * `_is_partial_state_room_linearizer` correctly guards partial state flags + # for rooms between the workers doing remote joins and resync. assert not self.config.worker.worker_app # TODO(faster_joins): do we need to lock to avoid races? What happens if other @@ -1725,20 +1866,19 @@ async def _sync_partial_state_room( logger.info("Handling any pending device list updates") await self._device_handler.handle_room_un_partial_stated(room_id) - logger.info("Clearing partial-state flag for %s", room_id) - success = await self.store.clear_partial_state_room(room_id) - if success: + async with self._is_partial_state_room_linearizer.queue(room_id): + logger.info("Clearing partial-state flag for %s", room_id) + new_stream_id = await self.store.clear_partial_state_room(room_id) + + if new_stream_id is not None: logger.info("State resync complete for %s", room_id) self._storage_controllers.state.notify_room_un_partial_stated( room_id ) - # Poke the notifier so that other workers see the write to - # the un-partial-stated rooms stream. - self._notifier.notify_replication() - # TODO(faster_joins) update room stats and user directory? - # https://github.com/matrix-org/synapse/issues/12814 - # https://github.com/matrix-org/synapse/issues/12815 + await self._notifier.on_un_partial_stated_room( + room_id, new_stream_id + ) return # we raced against more events arriving with partial state. Go round @@ -1809,9 +1949,9 @@ async def _sync_partial_state_room( def _prioritise_destinations_for_partial_state_resync( initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: AbstractSet[str], room_id: str, -) -> Collection[str]: +) -> StrCollection: """Work out the order in which we should ask servers to resync events. If an `initial_destination` is given, it takes top priority. 
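`_prioritise_destinations_for_partial_state_resync` orders the homeservers to ask: the server the room was joined via comes first, then the remaining candidates. A hedged sketch of such an ordering; the alphabetical fallback used here is illustrative, the real function may order the remainder differently:

```python
from typing import Iterable, List, Optional


def prioritise_destinations(
    initial_destination: Optional[str], other_destinations: Iterable[str]
) -> List[str]:
    """Put the server we joined via first, then the remaining candidates."""
    ordered: List[str] = []
    if initial_destination is not None:
        ordered.append(initial_destination)
    # Keep a deterministic order for the rest, skipping the initial
    # destination if it also appears among the candidates.
    for destination in sorted(other_destinations):
        if destination != initial_destination:
            ordered.append(destination)
    return ordered


print(prioritise_destinations("a.example", {"c.example", "b.example", "a.example"}))
# ['a.example', 'b.example', 'c.example']
```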
Otherwise diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 6df000faafed..e037acbca2bf 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -80,6 +80,7 @@ PersistedEventPosition, RoomStreamToken, StateMap, + StrCollection, UserID, get_domain_from_id, ) @@ -615,7 +616,7 @@ async def update_state_for_partial_state_event( @trace async def backfill( - self, dest: str, room_id: str, limit: int, extremities: Collection[str] + self, dest: str, room_id: str, limit: int, extremities: StrCollection ) -> None: """Trigger a backfill request to `dest` for the given `room_id` @@ -1565,7 +1566,7 @@ async def backfill_event_id( @trace @tag_args async def _get_events_and_persist( - self, destination: str, room_id: str, event_ids: Collection[str] + self, destination: str, room_id: str, event_ids: StrCollection ) -> None: """Fetch the given events from a server, and persist them as outliers. @@ -2259,6 +2260,10 @@ async def persist_events_and_notify( event_and_contexts, backfilled=backfilled ) + # After persistence we always need to notify replication there may + # be new data. + self._notifier.notify_replication() + if self._ephemeral_messages_enabled: for event in events: # If there's an expiry timestamp on the event, schedule its expiry. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 9c335e6863f4..191529bd8e27 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -15,7 +15,13 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple, cast -from synapse.api.constants import EduTypes, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + Direction, + EduTypes, + EventTypes, + Membership, +) from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig @@ -57,7 +63,13 @@ def __init__(self, hs: "HomeServer"): self.validator = EventValidator() self.snapshot_cache: ResponseCache[ Tuple[ - str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool + str, + Optional[StreamToken], + Optional[StreamToken], + Direction, + int, + bool, + bool, ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() @@ -239,7 +251,7 @@ async def handle_room(event: RoomsForUser) -> None: tags = tags_by_room.get(event.room_id) if tags: account_data_events.append( - {"type": "m.tag", "content": {"tags": tags}} + {"type": AccountDataTypes.TAG, "content": {"tags": tags}} ) account_data = account_data_by_room.get(event.room_id, {}) @@ -326,7 +338,9 @@ async def room_initial_sync( account_data_events = [] tags = await self.store.get_tags_for_room(user_id, room_id) if tags: - account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) + account_data_events.append( + {"type": AccountDataTypes.TAG, "content": {"tags": tags}} + ) account_data = await self.store.get_account_data_for_room(user_id, room_id) for account_data_type, content in account_data.items(): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index f275c527231e..0b2d23706ae6 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -377,7 +377,7 @@ def maybe_schedule_expiry(self, event: EventBase) -> None: """ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if not isinstance(expiry_ts, int) or event.is_state(): + if type(expiry_ts) is not int or 
event.is_state(): return # _schedule_expiry_for_event won't actually schedule anything if there's already @@ -1540,17 +1540,28 @@ async def cache_joined_hosts_for_events( external federation senders don't have to recalculate it themselves. """ - for event, _ in events_and_context: - if not self._external_cache.is_enabled(): - return + if not self._external_cache.is_enabled(): + return + + # If external cache is enabled we should always have this. + assert self._external_cache_joined_hosts_updates is not None + for event, event_context in events_and_context: # Beeper hack: we don't need joined hosts for bridged events because # we don't federate them. if event.sender.startswith("@_"): return - # If external cache is enabled we should always have this. - assert self._external_cache_joined_hosts_updates is not None + if event_context.partial_state: + # To populate the cache for a partial-state event, we either have to + # block until full state, which the code below does, or change the + # meaning of cache values to be the list of hosts to which we plan to + # send events and calculate that instead. + # + # The federation senders don't use the external cache when sending + # events in partial-state rooms anyway, so let's not bother populating + # the cache. + continue # We actually store two mappings, event ID -> prev state group, # state group -> joined hosts, which is much more space efficient @@ -1948,7 +1959,9 @@ async def persist_and_notify_client_events( if event.type == EventTypes.Message: # We don't want to block sending messages on any presence code. This # matters as sometimes presence code can take a while. - run_in_background(self._bump_active_time, requester.user) + run_as_background_process( + "bump_presence_active_time", self._bump_active_time, requester.user + ) async def _notify() -> None: try: diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 89a947762101..0b5fb889f92e 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -14,14 +14,14 @@ # limitations under the License. 
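The switch above from `isinstance(expiry_ts, int)` to `type(expiry_ts) is not int` tightens validation of the self-destruct timestamp: `bool` is a subclass of `int`, so `isinstance(True, int)` passes, whereas the exact type check rejects it (and any other `int` subclass). A short demonstration:

```python
def accepts_isinstance(value: object) -> bool:
    return isinstance(value, int)


def accepts_exact_type(value: object) -> bool:
    return type(value) is int


for candidate in (1674000000000, True, 1.5, "1674000000000"):
    print(repr(candidate), accepts_isinstance(candidate), accepts_exact_type(candidate))
# True is accepted by isinstance() but rejected by the exact type check.
```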
import logging import time -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set import attr from prometheus_client import Histogram from twisted.python.failure import Failure -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events.utils import SerializeEventConfig @@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.admin._base import assert_user_is_admin from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StreamKeyType +from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType from synapse.types.state import StateFilter from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string @@ -395,7 +395,7 @@ def get_delete_status(self, delete_id: str) -> Optional[DeleteStatus]: """ return self._delete_by_id.get(delete_id) - def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]: + def get_delete_ids_by_room(self, room_id: str) -> Optional[StrCollection]: """Get all active delete ids by room Args: @@ -460,7 +460,7 @@ async def get_messages( if pagin_config.from_token: from_token = pagin_config.from_token - elif pagin_config.direction == "f": + elif pagin_config.direction == Direction.FORWARDS: from_token = ( await self.hs.get_event_sources().get_start_token_for_pagination( room_id @@ -488,7 +488,7 @@ async def get_messages( room_id, requester, allow_departed_users=True ) - if pagin_config.direction == "b": + if pagin_config.direction == Direction.BACKWARDS: # if we're going backwards, we might need to backfill. This # requires that we have a topo token. if room_token.topological: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2af90b25a39c..87af31aa2706 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -64,7 +64,13 @@ from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream from synapse.storage.databases.main import DataStore from synapse.streams import EventSource -from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id +from synapse.types import ( + JsonDict, + StrCollection, + StreamKeyType, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -320,7 +326,7 @@ async def maybe_send_presence_to_interested_destinations( for destination, host_states in hosts_to_states.items(): self._federation.send_presence_to_destinations(host_states, [destination]) - async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None: + async def send_full_presence_to_users(self, user_ids: StrCollection) -> None: """ Adds to the list of users who should receive a full snapshot of presence upon their next sync. Note that this only works for local users. @@ -1601,7 +1607,7 @@ async def get_new_events( # Having a default limit doesn't match the EventSource API, but some # callers do not provide it. It is unused in this class. 
limit: int = 0, - room_ids: Optional[Collection[str]] = None, + room_ids: Optional[StrCollection] = None, is_guest: bool = False, explicit_room_id: Optional[str] = None, include_offline: bool = True, @@ -1688,7 +1694,7 @@ async def get_new_events( # The set of users that we're interested in and that have had a presence update. # We'll actually pull the presence updates for these users at the end. - interested_and_updated_users: Collection[str] + interested_and_updated_users: StrCollection if from_key is not None: # First get all users that have had a presence update @@ -2120,7 +2126,7 @@ def __init__(self, hs: "HomeServer", presence_handler: BasePresenceHandler): # stream_id, destinations, user_ids)`. We don't store the full states # for efficiency, and remote workers will already have the full states # cached. - self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = [] + self._queue: List[Tuple[int, int, StrCollection, Set[str]]] = [] self._next_id = 1 @@ -2142,7 +2148,7 @@ def _clear_queue(self) -> None: self._queue = self._queue[index:] def send_presence_to_destinations( - self, states: Collection[UserPresenceState], destinations: Collection[str] + self, states: Collection[UserPresenceState], destinations: StrCollection ) -> None: """Send the presence states to the given destinations. @@ -2155,6 +2161,11 @@ def send_presence_to_destinations( # This should only be called on a presence writer. assert self._presence_writer + if not states or not destinations: + # Ignore calls which either don't have any new states or don't need + # to be sent anywhere. + return + if self._federation: self._federation.send_presence_to_destinations( states=states, diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 57c7ca2f1bc2..c9be10a574f5 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -316,5 +316,5 @@ async def get_new_events_as( return events, to_key - def get_current_key(self, direction: str = "f") -> int: + def get_current_key(self) -> int: return self.store.get_max_receipt_stream_id() diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index e96f9999a8d6..0fb15391e07e 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ import attr -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -413,7 +413,11 @@ async def _get_threads_for_events( # Attempt to find another event to use as the latest event. potential_events, _ = await self._main_store.get_relations_for_event( - event_id, event, room_id, RelationTypes.THREAD, direction="f" + event_id, + event, + room_id, + RelationTypes.THREAD, + direction=Direction.FORWARDS, ) # Filter out ignored users. 
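Across these handlers the bare "f"/"b" pagination strings are replaced by a `Direction` enum from `synapse.api.constants`. A plausible shape for that enum and how the comparisons read with it; the member names match the usage in this diff, but the exact definition in `synapse.api.constants` may differ slightly:

```python
import enum


class Direction(enum.Enum):
    BACKWARDS = "b"
    FORWARDS = "f"


def start_of_pagination(direction: Direction) -> str:
    if direction == Direction.FORWARDS:
        return "start token"
    return "end token"


# String values round-trip through query parameters unchanged.
assert Direction("f") is Direction.FORWARDS
print(start_of_pagination(Direction.BACKWARDS))
```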
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index fdd59b76147e..7b7f3b088f5d 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -21,16 +21,7 @@ import time from collections import OrderedDict from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Collection, - Dict, - List, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple import attr from prometheus_client import Histogram @@ -38,6 +29,7 @@ import synapse.events.snapshot from synapse.api.constants import ( + Direction, EventContentFields, EventTypes, GuestAccess, @@ -74,6 +66,7 @@ RoomID, RoomStreamToken, StateMap, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -1514,7 +1507,7 @@ async def get_event_for_timestamp( requester: Requester, room_id: str, timestamp: int, - direction: str, + direction: Direction, ) -> Tuple[str, int]: """Find the closest event to the given timestamp in the given direction. If we can't find an event locally or the event we have locally is next to a gap, @@ -1525,7 +1518,7 @@ async def get_event_for_timestamp( room_id: Room to fetch the event from timestamp: The point in time (inclusive) we should navigate from in the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward + direction: indicates whether we should navigate forward or backward from the given timestamp to find the closest event. Returns: @@ -1560,13 +1553,13 @@ async def get_event_for_timestamp( local_event_id, allow_none=False, allow_rejected=False ) - if direction == "f": + if direction == Direction.FORWARDS: # We only need to check for a backward gap if we're looking forwards # to ensure there is nothing in between. is_event_next_to_backward_gap = ( await self.store.is_event_next_to_backward_gap(local_event) ) - elif direction == "b": + elif direction == Direction.BACKWARDS: # We only need to check for a forward gap if we're looking backwards # to ensure there is nothing in between is_event_next_to_forward_gap = ( @@ -1663,7 +1656,7 @@ async def get_new_events( user: UserID, from_key: RoomStreamToken, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[EventBase], RoomStreamToken]: diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index c6b869c6f44e..4472019fbcd2 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -36,7 +36,7 @@ ) from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase -from synapse.types import JsonDict, Requester +from synapse.types import JsonDict, Requester, StrCollection from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -870,7 +870,7 @@ class _RoomQueueEntry: # The room ID of this entry. room_id: str # The server to query if the room is not known locally. - via: Sequence[str] + via: StrCollection # The minimum number of hops necessary to get to this room (compared to the # originally requested room). 
depth: int = 0 diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 40f4635c4e29..9bbf83047dd6 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -14,7 +14,7 @@ import itertools import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple import attr from unpaddedbase64 import decode_base64, encode_base64 @@ -23,7 +23,7 @@ from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase -from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID from synapse.types.state import StateFilter from synapse.visibility import filter_events_for_client @@ -418,7 +418,7 @@ async def _search( async def _search_by_rank( self, user: UserID, - room_ids: Collection[str], + room_ids: StrCollection, search_term: str, keys: Iterable[str], search_filter: Filter, @@ -491,7 +491,7 @@ async def _search_by_rank( async def _search_by_recent( self, user: UserID, - room_ids: Collection[str], + room_ids: StrCollection, search_term: str, keys: Iterable[str], search_filter: Filter, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 44e70fc4b874..4a27c0f05142 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -20,7 +20,6 @@ Any, Awaitable, Callable, - Collection, Dict, Iterable, List, @@ -47,6 +46,7 @@ from synapse.http.site import SynapseRequest from synapse.types import ( JsonDict, + StrCollection, UserID, contains_invalid_mxid_characters, create_requester, @@ -141,7 +141,8 @@ class UserAttributes: confirm_localpart: bool = False display_name: Optional[str] = None picture: Optional[str] = None - emails: Collection[str] = attr.Factory(list) + # mypy thinks these are incompatible for some reason. + emails: StrCollection = attr.Factory(list) # type: ignore[assignment] @attr.s(slots=True, auto_attribs=True) @@ -159,7 +160,7 @@ class UsernameMappingSession: # attributes returned by the ID mapper display_name: Optional[str] - emails: Collection[str] + emails: StrCollection # An optional dictionary of extra attributes to be provided to the client in the # login response. 
@@ -174,7 +175,7 @@ class UsernameMappingSession: # choices made by the user chosen_localpart: Optional[str] = None use_display_name: bool = True - emails_to_use: Collection[str] = () + emails_to_use: StrCollection = () terms_accepted_version: Optional[str] = None diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9cbe9c13ba9d..a8fa06538aed 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -17,7 +17,6 @@ TYPE_CHECKING, AbstractSet, Any, - Collection, Dict, FrozenSet, List, @@ -31,7 +30,12 @@ import attr from prometheus_client import Counter -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + EventContentFields, + EventTypes, + Membership, +) from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -57,6 +61,7 @@ Requester, RoomStreamToken, StateMap, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -285,7 +290,7 @@ def __init__(self, hs: "HomeServer"): expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) - self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync + self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync async def wait_for_sync_for_user( self, @@ -1191,7 +1196,7 @@ async def compute_state_delta( async def _find_missing_partial_state_memberships( self, room_id: str, - members_to_fetch: Collection[str], + members_to_fetch: StrCollection, events_with_membership_auth: Mapping[str, EventBase], found_state_ids: StateMap[str], ) -> StateMap[str]: @@ -1379,7 +1384,10 @@ async def generate_sync_result( membership_change_events = [] if since_token: membership_change_events = await self.store.get_membership_changes_for_user( - user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude + user_id, + since_token.room_key, + now_token.room_key, + self.rooms_to_exclude_globally, ) mem_last_change_by_room_id: Dict[str, EventBase] = {} @@ -1417,12 +1425,49 @@ async def generate_sync_result( if user_id in user_ids_in_room: mutable_joined_room_ids.add(room_id) + # Tweak the set of rooms to return to the client for eager (non-lazy) syncs. + mutable_rooms_to_exclude = set(self.rooms_to_exclude_globally) + if not sync_config.filter_collection.lazy_load_members(): + # Non-lazy syncs should never include partially stated rooms. + # Exclude all partially stated rooms from this sync. + results = await self.store.is_partial_state_room_batched( + mutable_joined_room_ids + ) + mutable_rooms_to_exclude.update( + room_id + for room_id, is_partial_state in results.items() + if is_partial_state + ) + + # Incremental eager syncs should additionally include rooms that + # - we are joined to + # - are full-stated + # - became fully-stated at some point during the sync period + # (These rooms will have been omitted during a previous eager sync.) 
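For eager (non-lazy-loading) syncs, the sync handler now batches a partial-state lookup over the joined rooms and folds any partial-state rooms into the exclusion set, on top of the globally configured exclusions. A condensed sketch of that step; `is_partial_state_room_batched` mirrors the store method used in the diff, while the stub implementation and scaffolding are illustrative:

```python
import asyncio
from typing import Dict, Iterable, Set


async def is_partial_state_room_batched(room_ids: Iterable[str]) -> Dict[str, bool]:
    # Stand-in for the storage call: pretend rooms ending in ":partial" are partial-stated.
    return {room_id: room_id.endswith(":partial") for room_id in room_ids}


async def rooms_to_exclude_for_eager_sync(
    joined_room_ids: Set[str], rooms_to_exclude_globally: Set[str]
) -> Set[str]:
    excluded = set(rooms_to_exclude_globally)
    results = await is_partial_state_room_batched(joined_room_ids)
    excluded.update(room_id for room_id, is_partial in results.items() if is_partial)
    return excluded


print(asyncio.run(rooms_to_exclude_for_eager_sync({"!a:hs", "!b:partial"}, {"!blocked:hs"})))
# e.g. {'!blocked:hs', '!b:partial'}
```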
+ forced_newly_joined_room_ids: Set[str] = set() + if since_token and not sync_config.filter_collection.lazy_load_members(): + un_partial_stated_rooms = ( + await self.store.get_un_partial_stated_rooms_between( + since_token.un_partial_stated_rooms_key, + now_token.un_partial_stated_rooms_key, + mutable_joined_room_ids, + ) + ) + results = await self.store.is_partial_state_room_batched( + un_partial_stated_rooms + ) + forced_newly_joined_room_ids.update( + room_id + for room_id, is_partial_state in results.items() + if not is_partial_state + ) + # Now we have our list of joined room IDs, exclude as configured and freeze joined_room_ids = frozenset( ( room_id for room_id in mutable_joined_room_ids - if room_id not in self.rooms_to_exclude + if room_id not in mutable_rooms_to_exclude ) ) @@ -1452,6 +1497,8 @@ async def generate_sync_result( since_token=since_token, now_token=now_token, joined_room_ids=joined_room_ids, + excluded_room_ids=frozenset(mutable_rooms_to_exclude), + forced_newly_joined_room_ids=frozenset(forced_newly_joined_room_ids), membership_change_events=membership_change_events, ) @@ -1461,39 +1508,67 @@ async def generate_sync_result( sync_result_builder ) - logger.debug("Fetching room data") - - ( - newly_joined_rooms, - newly_joined_or_invited_or_knocked_users, - newly_left_rooms, - newly_left_users, - ) = await self._generate_sync_entry_for_rooms( - sync_result_builder, account_data_by_room + # Presence data is included if the server has it enabled and not filtered out. + include_presence_data = bool( + self.hs_config.server.use_presence + and not sync_config.filter_collection.blocks_all_presence() ) + # Device list updates are sent if a since token is provided. + include_device_list_updates = bool(since_token and since_token.device_list_key) - block_all_presence_data = ( - since_token is None and sync_config.filter_collection.blocks_all_presence() - ) - if self.hs_config.server.use_presence and not block_all_presence_data: - logger.debug("Fetching presence data") - await self._generate_sync_entry_for_presence( - sync_result_builder, + # If we do not care about the rooms or things which depend on the room + # data (namely presence and device list updates), then we can skip + # this process completely. + device_lists = DeviceListUpdates() + if ( + not sync_result_builder.sync_config.filter_collection.blocks_all_rooms() + or include_presence_data + or include_device_list_updates + ): + logger.debug("Fetching room data") + + # Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which + # is used in calculate_user_changes below. + ( newly_joined_rooms, - newly_joined_or_invited_or_knocked_users, + newly_left_rooms, + ) = await self._generate_sync_entry_for_rooms( + sync_result_builder, account_data_by_room ) + # Work out which users have joined or left rooms we're in. We use this + # to build the presence and device_list parts of the sync response in + # `_generate_sync_entry_for_presence` and + # `_generate_sync_entry_for_device_list` respectively. + if include_presence_data or include_device_list_updates: + # This uses the sync_result_builder.joined which is set in + # `_generate_sync_entry_for_rooms`, if that didn't find any joined + # rooms for some reason it is a no-op. 
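Incremental eager syncs additionally surface rooms that stopped being partial-stated during the sync window: the un-partial-stated rooms between the two stream tokens are fetched, filtered down to rooms that are now fully stated, and then treated as newly joined so the client finally receives them. A hedged sketch of that selection; the two store calls mirror the ones above, but the token handling and data shapes are simplified stand-ins:

```python
import asyncio
from typing import Dict, Iterable, Set


async def get_un_partial_stated_rooms_between(
    from_key: int, to_key: int, joined_rooms: Set[str]
) -> Set[str]:
    # Stand-in: joined rooms which were un-partial-stated in (from_key, to_key].
    return {"!finally_full:hs"} & joined_rooms


async def is_partial_state_room_batched(room_ids: Iterable[str]) -> Dict[str, bool]:
    return {room_id: False for room_id in room_ids}  # stand-in: all full-state now


async def forced_newly_joined(
    since_key: int, now_key: int, joined_rooms: Set[str]
) -> Set[str]:
    candidates = await get_un_partial_stated_rooms_between(since_key, now_key, joined_rooms)
    results = await is_partial_state_room_batched(candidates)
    # Only rooms that are *still* fully stated are re-presented as newly joined;
    # a room that became partial again stays excluded.
    return {room_id for room_id, is_partial in results.items() if not is_partial}


print(asyncio.run(forced_newly_joined(10, 20, {"!finally_full:hs", "!other:hs"})))
```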
+ ( + newly_joined_or_invited_or_knocked_users, + newly_left_users, + ) = sync_result_builder.calculate_user_changes() + + if include_presence_data: + logger.debug("Fetching presence data") + await self._generate_sync_entry_for_presence( + sync_result_builder, + newly_joined_rooms, + newly_joined_or_invited_or_knocked_users, + ) + + if include_device_list_updates: + device_lists = await self._generate_sync_entry_for_device_list( + sync_result_builder, + newly_joined_rooms=newly_joined_rooms, + newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users, + newly_left_rooms=newly_left_rooms, + newly_left_users=newly_left_users, + ) + logger.debug("Fetching to-device data") await self._generate_sync_entry_for_to_device(sync_result_builder) - device_lists = await self._generate_sync_entry_for_device_list( - sync_result_builder, - newly_joined_rooms=newly_joined_rooms, - newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users, - newly_left_rooms=newly_left_rooms, - newly_left_users=newly_left_users, - ) - logger.debug("Fetching OTK data") device_id = sync_config.device_id one_time_keys_count: JsonDict = {} @@ -1562,6 +1637,7 @@ async def _generate_sync_entry_for_device_list( user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token + assert since_token is not None # Take a copy since these fields will be mutated later. newly_joined_or_invited_or_knocked_users = set( @@ -1569,92 +1645,85 @@ async def _generate_sync_entry_for_device_list( ) newly_left_users = set(newly_left_users) - if since_token and since_token.device_list_key: - # We want to figure out what user IDs the client should refetch - # device keys for, and which users we aren't going to track changes - # for anymore. - # - # For the first step we check: - # a. if any users we share a room with have updated their devices, - # and - # b. we also check if we've joined any new rooms, or if a user has - # joined a room we're in. - # - # For the second step we just find any users we no longer share a - # room with by looking at all users that have left a room plus users - # that were in a room we've left. - - users_that_have_changed = set() + # We want to figure out what user IDs the client should refetch + # device keys for, and which users we aren't going to track changes + # for anymore. + # + # For the first step we check: + # a. if any users we share a room with have updated their devices, + # and + # b. we also check if we've joined any new rooms, or if a user has + # joined a room we're in. + # + # For the second step we just find any users we no longer share a + # room with by looking at all users that have left a room plus users + # that were in a room we've left. - joined_rooms = sync_result_builder.joined_room_ids + users_that_have_changed = set() - # Step 1a, check for changes in devices of users we share a room - # with - # - # We do this in two different ways depending on what we have cached. - # If we already have a list of all the user that have changed since - # the last sync then it's likely more efficient to compare the rooms - # they're in with the rooms the syncing user is in. - # - # If we don't have that info cached then we get all the users that - # share a room with our user and check if those users have changed. 
- cache_result = self.store.get_cached_device_list_changes( - since_token.device_list_key - ) - if cache_result.hit: - changed_users = cache_result.entities - - result = await self.store.get_rooms_for_users(changed_users) - - for changed_user_id, entries in result.items(): - # Check if the changed user shares any rooms with the user, - # or if the changed user is the syncing user (as we always - # want to include device list updates of their own devices). - if user_id == changed_user_id or any( - rid in joined_rooms for rid in entries - ): - users_that_have_changed.add(changed_user_id) - else: - users_that_have_changed = ( - await self._device_handler.get_device_changes_in_shared_rooms( - user_id, - sync_result_builder.joined_room_ids, - from_token=since_token, - ) - ) + joined_rooms = sync_result_builder.joined_room_ids - # Step 1b, check for newly joined rooms - for room_id in newly_joined_rooms: - joined_users = await self.store.get_users_in_room(room_id) - newly_joined_or_invited_or_knocked_users.update(joined_users) + # Step 1a, check for changes in devices of users we share a room + # with + # + # We do this in two different ways depending on what we have cached. + # If we already have a list of all the user that have changed since + # the last sync then it's likely more efficient to compare the rooms + # they're in with the rooms the syncing user is in. + # + # If we don't have that info cached then we get all the users that + # share a room with our user and check if those users have changed. + cache_result = self.store.get_cached_device_list_changes( + since_token.device_list_key + ) + if cache_result.hit: + changed_users = cache_result.entities - # TODO: Check that these users are actually new, i.e. either they - # weren't in the previous sync *or* they left and rejoined. - users_that_have_changed.update(newly_joined_or_invited_or_knocked_users) + result = await self.store.get_rooms_for_users(changed_users) - user_signatures_changed = ( - await self.store.get_users_whose_signatures_changed( - user_id, since_token.device_list_key + for changed_user_id, entries in result.items(): + # Check if the changed user shares any rooms with the user, + # or if the changed user is the syncing user (as we always + # want to include device list updates of their own devices). + if user_id == changed_user_id or any( + rid in joined_rooms for rid in entries + ): + users_that_have_changed.add(changed_user_id) + else: + users_that_have_changed = ( + await self._device_handler.get_device_changes_in_shared_rooms( + user_id, + sync_result_builder.joined_room_ids, + from_token=since_token, ) ) - users_that_have_changed.update(user_signatures_changed) - # Now find users that we no longer track - for room_id in newly_left_rooms: - left_users = await self.store.get_users_in_room(room_id) - newly_left_users.update(left_users) + # Step 1b, check for newly joined rooms + for room_id in newly_joined_rooms: + joined_users = await self.store.get_users_in_room(room_id) + newly_joined_or_invited_or_knocked_users.update(joined_users) + + # TODO: Check that these users are actually new, i.e. either they + # weren't in the previous sync *or* they left and rejoined. + users_that_have_changed.update(newly_joined_or_invited_or_knocked_users) - # Remove any users that we still share a room with. 
- left_users_rooms = await self.store.get_rooms_for_users(newly_left_users) - for user_id, entries in left_users_rooms.items(): - if any(rid in joined_rooms for rid in entries): - newly_left_users.discard(user_id) + user_signatures_changed = await self.store.get_users_whose_signatures_changed( + user_id, since_token.device_list_key + ) + users_that_have_changed.update(user_signatures_changed) - return DeviceListUpdates( - changed=users_that_have_changed, left=newly_left_users - ) - else: - return DeviceListUpdates() + # Now find users that we no longer track + for room_id in newly_left_rooms: + left_users = await self.store.get_users_in_room(room_id) + newly_left_users.update(left_users) + + # Remove any users that we still share a room with. + left_users_rooms = await self.store.get_rooms_for_users(newly_left_users) + for user_id, entries in left_users_rooms.items(): + if any(rid in joined_rooms for rid in entries): + newly_left_users.discard(user_id) + + return DeviceListUpdates(changed=users_that_have_changed, left=newly_left_users) @trace async def _generate_sync_entry_for_to_device( @@ -1731,6 +1800,7 @@ async def _generate_sync_entry_for_account_data( since_token = sync_result_builder.since_token if since_token and not sync_result_builder.full_state: + # TODO Do not fetch room account data if it will be unused. ( global_account_data, account_data_by_room, @@ -1747,6 +1817,7 @@ async def _generate_sync_entry_for_account_data( sync_config.user ) else: + # TODO Do not fetch room account data if it will be unused. ( global_account_data, account_data_by_room, @@ -1829,7 +1900,7 @@ async def _generate_sync_entry_for_rooms( self, sync_result_builder: "SyncResultBuilder", account_data_by_room: Dict[str, Dict[str, JsonDict]], - ) -> Tuple[AbstractSet[str], AbstractSet[str], AbstractSet[str], AbstractSet[str]]: + ) -> Tuple[AbstractSet[str], AbstractSet[str]]: """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1842,26 +1913,21 @@ async def _generate_sync_entry_for_rooms( account_data_by_room: Dictionary of per room account data Returns: - Returns a 4-tuple describing rooms the user has joined or left, and users who've - joined or left rooms any rooms the user is in. This gets used later in - `_generate_sync_entry_for_device_list`. + Returns a 2-tuple describing rooms the user has joined or left. Its entries are: - newly_joined_rooms - - newly_joined_or_invited_or_knocked_users - newly_left_rooms - - newly_left_users """ since_token = sync_result_builder.since_token + user_id = sync_result_builder.sync_config.user.to_string() # 1. Start by fetching all ephemeral events in rooms we've joined (if required). - user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( - since_token is None - and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() + sync_result_builder.sync_config.filter_collection.blocks_all_rooms() + or sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) - if block_all_room_ephemeral: ephemeral_by_room: Dict[str, List[JsonDict]] = {} else: @@ -1884,19 +1950,21 @@ async def _generate_sync_entry_for_rooms( ) if not tags_by_room: logger.debug("no-oping sync") - return set(), set(), set(), set() + return set(), set() # 3. Work out which rooms need reporting in the sync response. 
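At the end of the device-list computation, any user with whom the syncing user still shares a joined room is removed from the "left" set, so clients only stop tracking a user's devices when no shared room remains. A stand-alone version of that filter using plain data structures instead of the store call:

```python
from typing import Dict, Set


def prune_still_shared(
    newly_left_users: Set[str],
    rooms_by_left_user: Dict[str, Set[str]],
    joined_rooms: Set[str],
) -> Set[str]:
    """Drop users from the 'left' set if we still share a joined room with them."""
    result = set(newly_left_users)
    for user_id, their_rooms in rooms_by_left_user.items():
        if their_rooms & joined_rooms:
            result.discard(user_id)
    return result


print(
    prune_still_shared(
        {"@a:hs", "@b:hs"},
        {"@a:hs": {"!shared:hs"}, "@b:hs": {"!gone:hs"}},
        {"!shared:hs", "!current:hs"},
    )
)
# {'@b:hs'} -- @a:hs still shares !shared:hs with us, so their devices stay tracked.
```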
ignored_users = await self.store.ignored_users(user_id) if since_token: - room_changes = await self._get_rooms_changed( + room_changes = await self._get_room_changes_for_incremental_sync( sync_result_builder, ignored_users ) tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) else: - room_changes = await self._get_all_rooms(sync_result_builder, ignored_users) + room_changes = await self._get_room_changes_for_initial_sync( + sync_result_builder, ignored_users + ) tags_by_room = await self.store.get_tags_for_user(user_id) log_kv({"rooms_changed": len(room_changes.room_entries)}) @@ -1911,6 +1979,7 @@ async def _generate_sync_entry_for_rooms( # joined or archived). async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: logger.debug("Generating room entry for %s", room_entry.room_id) + # Note that this mutates sync_result_builder.{joined,archived}. await self._generate_room_entry( sync_result_builder, room_entry, @@ -1927,20 +1996,7 @@ async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: sync_result_builder.invited.extend(invited) sync_result_builder.knocked.extend(knocked) - # 5. Work out which users have joined or left rooms we're in. We use this - # to build the device_list part of the sync response in - # `_generate_sync_entry_for_device_list`. - ( - newly_joined_or_invited_or_knocked_users, - newly_left_users, - ) = sync_result_builder.calculate_user_changes() - - return ( - set(newly_joined_rooms), - newly_joined_or_invited_or_knocked_users, - set(newly_left_rooms), - newly_left_users, - ) + return set(newly_joined_rooms), set(newly_left_rooms) async def _have_rooms_changed( self, sync_result_builder: "SyncResultBuilder" @@ -1955,7 +2011,7 @@ async def _have_rooms_changed( assert since_token - if membership_change_events: + if membership_change_events or sync_result_builder.forced_newly_joined_room_ids: return True stream_id = since_token.room_key.stream @@ -1964,7 +2020,7 @@ async def _have_rooms_changed( return True return False - async def _get_rooms_changed( + async def _get_room_changes_for_incremental_sync( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str], @@ -2002,7 +2058,9 @@ async def _get_rooms_changed( for event in membership_change_events: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) - newly_joined_rooms: List[str] = [] + newly_joined_rooms: List[str] = list( + sync_result_builder.forced_newly_joined_room_ids + ) newly_left_rooms: List[str] = [] room_entries: List[RoomSyncResultBuilder] = [] invited: List[InvitedSyncResult] = [] @@ -2214,7 +2272,7 @@ async def _get_rooms_changed( newly_left_rooms, ) - async def _get_all_rooms( + async def _get_room_changes_for_initial_sync( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str], @@ -2239,7 +2297,7 @@ async def _get_all_rooms( room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=Membership.LIST, - excluded_rooms=self.rooms_to_exclude, + excluded_rooms=sync_result_builder.excluded_room_ids, ) room_entries = [] @@ -2399,7 +2457,9 @@ async def _generate_room_entry( account_data_events = [] if tags is not None: - account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) + account_data_events.append( + {"type": AccountDataTypes.TAG, "content": {"tags": tags}} + ) for account_data_type, content in account_data.items(): account_data_events.append( @@ -2641,6 +2701,13 @@ class SyncResultBuilder: 
since_token: The token supplied by user, or None. now_token: The token to sync up to. joined_room_ids: List of rooms the user is joined to + excluded_room_ids: Set of room ids we should omit from the /sync response. + forced_newly_joined_room_ids: + Rooms that should be presented in the /sync response as if they were + newly joined during the sync period, even if that's not the case. + (This is useful if the room was previously excluded from a /sync response, + and now the client should be made aware of it.) + Only used by incremental syncs. # The following mirror the fields in a sync response presence @@ -2657,6 +2724,8 @@ class SyncResultBuilder: since_token: Optional[StreamToken] now_token: StreamToken joined_room_ids: FrozenSet[str] + excluded_room_ids: FrozenSet[str] + forced_newly_joined_room_ids: FrozenSet[str] membership_change_events: List[EventBase] presence: List[UserPresenceState] = attr.Factory(list) diff --git a/synapse/http/client.py b/synapse/http/client.py index f1b73664ac6f..49160a8410f3 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -44,6 +44,7 @@ IAddress, IDelayedCall, IHostResolution, + IReactorCore, IReactorPluggableNameResolver, IReactorTime, IResolutionReceiver, @@ -226,7 +227,9 @@ def resolutionComplete() -> None: return recv -@implementer(ISynapseReactor) +# ISynapseReactor implies IReactorCore, but explicitly marking it this as an implementer +# of IReactorCore seems to keep mypy-zope happier. +@implementer(IReactorCore, ISynapseReactor) class BlacklistingReactorWrapper: """ A Reactor wrapper which will prevent DNS resolution to blacklisted IP diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 18899bc6d18d..94ef737b9ee0 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -38,7 +38,6 @@ from synapse.http import redact_uri from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials -from synapse.types import ISynapseReactor logger = logging.getLogger(__name__) @@ -84,7 +83,7 @@ class ProxyAgent(_AgentBase): def __init__( self, reactor: IReactorCore, - proxy_reactor: Optional[ISynapseReactor] = None, + proxy_reactor: Optional[IReactorCore] = None, contextFactory: Optional[IPolicyForHTTPS] = None, connectTimeout: Optional[float] = None, bindAddress: Optional[bytes] = None, diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index dead02cd5c4f..0070bd2940ea 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -13,6 +13,7 @@ # limitations under the License. """ This module contains base REST classes for constructing REST servlets. """ +import enum import logging from http import HTTPStatus from typing import ( @@ -362,6 +363,7 @@ def parse_string( request: Request, name: str, *, + default: Optional[str] = None, required: bool = False, allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", @@ -413,6 +415,74 @@ def parse_string( ) +EnumT = TypeVar("EnumT", bound=enum.Enum) + + +@overload +def parse_enum( + request: Request, + name: str, + E: Type[EnumT], + default: EnumT, +) -> EnumT: + ... + + +@overload +def parse_enum( + request: Request, + name: str, + E: Type[EnumT], + *, + required: Literal[True], +) -> EnumT: + ... + + +def parse_enum( + request: Request, + name: str, + E: Type[EnumT], + default: Optional[EnumT] = None, + required: bool = False, +) -> Optional[EnumT]: + """ + Parse an enum parameter from the request query string. + + Note that the enum *must only have string values*. 
+ + Args: + request: the twisted HTTP request. + name: the name of the query parameter. + E: the enum which represents valid values + default: enum value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + + Returns: + An enum value. + + Raises: + SynapseError if the parameter is absent and required, or if the + parameter is present, must be one of a list of allowed values and + is not one of those allowed values. + """ + # Assert the enum values are strings. + assert all( + isinstance(e.value, str) for e in E + ), "parse_enum only works with string values" + str_value = parse_string( + request, + name, + default=default.value if default is not None else None, + required=required, + allowed_values=[e.value for e in E], + ) + if str_value is None: + return None + return E(str_value) + + def _parse_string_value( value: bytes, allowed_values: Optional[Iterable[str]], diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index a705af83565d..8ef9a0dda8e5 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -322,6 +322,11 @@ class SynapseTags: # The name of the external cache CACHE_NAME = "cache.name" + # Boolean. Present on /v2/send_join requests, omitted from all others. + # True iff partial state was requested and we provided (or intended to provide) + # partial state in the response. + SEND_JOIN_RESPONSE_IS_PARTIAL_STATE = "send_join.partial_state_response" + # Used to tag function arguments # # Tag a named arg. The name of the argument should be appended to this prefix. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6f4a934b0509..d22dd19d388a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1158,7 +1158,7 @@ async def send_local_online_presence_to(self, users: Iterable[str]) -> None: # Send to remote destinations. destination = UserID.from_string(user).domain presence_handler.get_federation_queue().send_presence_to_destinations( - presence_events, destination + presence_events, [destination] ) def looping_background_call( @@ -1585,6 +1585,33 @@ async def create_room( return room_id_and_alias["room_id"], room_id_and_alias.get("room_alias", None) + async def set_displayname( + self, + user_id: UserID, + new_displayname: str, + deactivation: bool = False, + ) -> None: + """Sets a user's display name. + + Added in Synapse v1.76.0. + + Args: + user_id: + The user whose display name is to be changed. + new_displayname: + The new display name to give the user. + deactivation: + Whether this change was made while deactivating the user. 
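`parse_enum` builds on `parse_string`: it constrains the query parameter to the enum's string values and converts the matched string back into an enum member. The simplified model below operates on a plain dict of query arguments rather than a Twisted request so it can run stand-alone; it only models the behaviour of the definition above, the helper itself is hypothetical:

```python
import enum
from typing import Dict, Optional, Type, TypeVar

EnumT = TypeVar("EnumT", bound=enum.Enum)


class Direction(enum.Enum):
    BACKWARDS = "b"
    FORWARDS = "f"


def parse_enum_from_args(
    args: Dict[str, str], name: str, E: Type[EnumT], default: Optional[EnumT] = None
) -> Optional[EnumT]:
    """Simplified model of parse_enum over a plain dict of query arguments."""
    # parse_enum only supports enums whose values are all strings.
    assert all(isinstance(e.value, str) for e in E)
    raw = args.get(name, default.value if default is not None else None)
    if raw is None:
        return None
    if raw not in {e.value for e in E}:
        raise ValueError(f"{name} must be one of {[e.value for e in E]}")
    return E(raw)


print(parse_enum_from_args({"dir": "b"}, "dir", Direction, Direction.FORWARDS))
print(parse_enum_from_args({}, "dir", Direction, Direction.FORWARDS))
# Direction.BACKWARDS, then the FORWARDS default.
```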
+ """ + requester = create_requester(user_id) + await self._hs.get_profile_handler().set_displayname( + target_user=user_id, + requester=requester, + new_displayname=new_displayname, + by_admin=True, + deactivation=deactivation, + ) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/notifier.py b/synapse/notifier.py index 26b97cf766c3..a8832a3f8e80 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -46,6 +46,7 @@ JsonDict, PersistedEventPosition, RoomStreamToken, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -226,8 +227,7 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.pending_new_room_events: List[_PendingRoomEventEntry] = [] - # Called when there are new things to stream over replication - self.replication_callbacks: List[Callable[[], None]] = [] + self._replication_notifier = hs.get_replication_notifier() self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = [] self._federation_client = hs.get_federation_http_client() @@ -279,7 +279,7 @@ def add_replication_callback(self, cb: Callable[[], None]) -> None: it needs to do any asynchronous work, a background thread should be started and wrapped with run_as_background_process. """ - self.replication_callbacks.append(cb) + self._replication_notifier.add_replication_callback(cb) def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None: """Add a callback that will be called when a user joins a room. @@ -315,6 +315,32 @@ async def on_new_room_events( event_entries.append((entry, event.event_id)) await self.notify_new_room_events(event_entries, max_room_stream_token) + async def on_un_partial_stated_room( + self, + room_id: str, + new_token: int, + ) -> None: + """Used by the resync background processes to wake up all listeners + of this room when it is un-partial-stated. + + It will also notify replication listeners of the change in stream. + """ + + # Wake up all related user stream notifiers + user_streams = self.room_to_user_streams.get(room_id, set()) + time_now_ms = self.clock.time_msec() + for user_stream in user_streams: + try: + user_stream.notify( + StreamKeyType.UN_PARTIAL_STATED_ROOMS, new_token, time_now_ms + ) + except Exception: + logger.exception("Failed to notify listener") + + # Poke the replication so that other workers also see the write to + # the un-partial-stated rooms stream. + self.notify_replication() + async def notify_new_room_events( self, event_entries: List[Tuple[_PendingRoomEventEntry, str]], @@ -691,7 +717,7 @@ async def check_for_updates( async def _get_room_ids( self, user: UserID, explicit_room_id: Optional[str] - ) -> Tuple[Collection[str], bool]: + ) -> Tuple[StrCollection, bool]: joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: @@ -741,8 +767,7 @@ def _user_joined_room(self, user_id: str, room_id: str) -> None: def notify_replication(self) -> None: """Notify the any replication listeners that there's a new event""" - for cb in self.replication_callbacks: - cb() + self._replication_notifier.notify_replication() def notify_user_joined_room(self, event_id: str, room_id: str) -> None: for cb in self._new_join_in_room_callbacks: @@ -759,3 +784,26 @@ def notify_remote_server_up(self, server: str) -> None: # Tell the federation client about the fact the server is back up, so # that any in flight requests can be immediately retried. 
self._federation_client.wake_destination(server) + + +@attr.s(auto_attribs=True) +class ReplicationNotifier: + """Tracks callbacks for things that need to know about stream changes. + + This is separate from the notifier to avoid circular dependencies. + """ + + _replication_callbacks: List[Callable[[], None]] = attr.Factory(list) + + def add_replication_callback(self, cb: Callable[[], None]) -> None: + """Add a callback that will be called when some new data is available. + Callback is not given any arguments. It should *not* return a Deferred - if + it needs to do any asynchronous work, a background thread should be started and + wrapped with run_as_background_process. + """ + self._replication_callbacks.append(cb) + + def notify_replication(self) -> None: + """Notify the any replication listeners that there's a new event""" + for cb in self._replication_callbacks: + cb() diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8ba276d61965..92907afbc6b1 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -22,6 +22,7 @@ List, Mapping, Optional, + Set, Tuple, Union, ) @@ -35,7 +36,7 @@ Membership, RelationTypes, ) -from synapse.api.room_versions import PushRuleRoomFlag, RoomVersion +from synapse.api.room_versions import PushRuleRoomFlag from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext @@ -68,6 +69,9 @@ } +SENTINEL = object() + + def _should_count_as_unread( event: EventBase, context: EventContext, @@ -116,6 +120,9 @@ def __init__(self, hs: "HomeServer"): self.should_calculate_push_rules = self.hs.config.push.enable_push self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled + self._intentional_mentions_enabled = ( + self.hs.config.experimental.msc3952_intentional_mentions + ) self.room_push_rule_cache_metrics = register_cache( "cache", @@ -136,15 +143,34 @@ async def _get_rules_for_event( Returns: Mapping of user ID to their push rules. """ - # We get the users who may need to be notified by first fetching the - # local users currently in the room, finding those that have push rules, - # and *then* checking which users are actually allowed to see the event. - # - # The alternative is to first fetch all users that were joined at the - # event, but that requires fetching the full state at the event, which - # may be expensive for large rooms with few local users. + # If this is a membership event, only calculate push rules for the target. + # While it's possible for users to configure push rules to respond to such an + # event, in practise nobody does this. At the cost of violating the spec a + # little, we can skip fetching a huge number of push rules in large rooms. + # This helps make joins and leaves faster. + if event.type == EventTypes.Member: + local_users = [] + # We never notify a user about their own actions. This is enforced in + # `_action_for_event_by_user` in the loop over `rules_by_user`, but we + # do the same check here to avoid unnecessary DB queries. + if event.sender != event.state_key and self.hs.is_mine_id(event.state_key): + # Check the target is in the room, to avoid notifying them of + # e.g. a pre-emptive ban. 
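The listener bookkeeping moves out of `Notifier` into the separate `ReplicationNotifier` above so that components can register for "new replication data" wake-ups without depending on the full notifier, avoiding a circular dependency. A toy end-to-end use of that class, mirroring the definition in the diff; the subscriber function is hypothetical:

```python
from typing import Callable, List

import attr


@attr.s(auto_attribs=True)
class ReplicationNotifier:
    _replication_callbacks: List[Callable[[], None]] = attr.Factory(list)

    def add_replication_callback(self, cb: Callable[[], None]) -> None:
        self._replication_callbacks.append(cb)

    def notify_replication(self) -> None:
        for cb in self._replication_callbacks:
            cb()


wakeups = 0


def on_new_replication_data() -> None:
    # A real subscriber would kick off a background process to poke connections.
    global wakeups
    wakeups += 1


notifier = ReplicationNotifier()
notifier.add_replication_callback(on_new_replication_data)
notifier.notify_replication()
notifier.notify_replication()
print(wakeups)  # 2
```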
+ target_already_in_room = await self.store.check_local_user_in_room( + event.state_key, event.room_id + ) + if target_already_in_room: + local_users = [event.state_key] + else: + # We get the users who may need to be notified by first fetching the + # local users currently in the room, finding those that have push rules, + # and *then* checking which users are actually allowed to see the event. + # + # The alternative is to first fetch all users that were joined at the + # event, but that requires fetching the full state at the event, which + # may be expensive for large rooms with few local users. - local_users = await self.store.get_local_users_in_room(event.room_id) + local_users = await self.store.get_local_users_in_room(event.room_id) # Filter out appservice users. local_users = [ @@ -161,6 +187,9 @@ async def _get_rules_for_event( local_users = list(local_users) local_users.append(invited) + if not local_users: + return {} + rules_by_user = await self.store.bulk_get_push_rules(local_users) logger.debug("Users in room: %s", local_users) @@ -338,14 +367,44 @@ async def _action_for_event_by_user( related_events = await self._related_events(event) # It's possible that old room versions have non-integer power levels (floats or - # strings). Workaround this by explicitly converting to int. + # strings; even the occasional `null`). For old rooms, we interpret these as if + # they were integers. Do this here for the `@room` power level threshold. + # Note that this is done automatically for the sender's power level by + # _get_power_levels_and_sender_level in its call to get_user_power_level + # (even for room V10.) notification_levels = power_levels.get("notifications", {}) if not event.room_version.msc3667_int_only_power_levels: - for user_id, level in notification_levels.items(): - notification_levels[user_id] = int(level) + keys = list(notification_levels.keys()) + for key in keys: + level = notification_levels.get(key, SENTINEL) + if level is not SENTINEL and type(level) is not int: + try: + notification_levels[key] = int(level) + except (TypeError, ValueError): + del notification_levels[key] + + # Pull out any user and room mentions. + mentions = event.content.get(EventContentFields.MSC3952_MENTIONS) + has_mentions = self._intentional_mentions_enabled and isinstance(mentions, dict) + user_mentions: Set[str] = set() + room_mention = False + if has_mentions: + # mypy seems to have lost the type even though it must be a dict here. + assert isinstance(mentions, dict) + # Remove out any non-string items and convert to a set. + user_mentions_raw = mentions.get("user_ids") + if isinstance(user_mentions_raw, list): + user_mentions = set( + filter(lambda item: isinstance(item, str), user_mentions_raw) + ) + # Room mention is only true if the value is exactly true. + room_mention = mentions.get("room") is True evaluator = PushRuleEvaluator( - _flatten_dict(event, room_version=event.room_version), + _flatten_dict(event), + has_mentions, + user_mentions, + room_mention, non_bot_room_member_count, sender_power_level, notification_levels, @@ -440,10 +499,31 @@ async def _action_for_event_by_user( def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], - room_version: Optional[RoomVersion] = None, prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: + """ + Given a JSON dictionary (or event) which might contain sub dictionaries, + flatten it into a single layer dictionary by combining the keys & sub-keys. 
+ + Any (non-dictionary), non-string value is dropped. + + Transforms: + + {"foo": {"bar": "test"}} + + To: + + {"foo.bar": "test"} + + Args: + d: The event or content to continue flattening. + prefix: The key prefix (from outer dictionaries). + result: The result to mutate. + + Returns: + The resulting dictionary. + """ if prefix is None: prefix = [] if result is None: @@ -459,14 +539,13 @@ def _flatten_dict( # `room_version` should only ever be set when looking at the top level of an event if ( - room_version is not None - and PushRuleRoomFlag.EXTENSIBLE_EVENTS in room_version.msc3931_push_features - and isinstance(d, EventBase) + isinstance(d, EventBase) + and PushRuleRoomFlag.EXTENSIBLE_EVENTS in d.room_version.msc3931_push_features ): # Room supports extensible events: replace `content.body` with the plain text # representation from `m.markup`, as per MSC1767. markup = d.get("content").get("m.markup") - if room_version.identifier.startswith("org.matrix.msc1767."): + if d.room_version.identifier.startswith("org.matrix.msc1767."): markup = d.get("content").get("org.matrix.msc1767.markup") if markup is not None and isinstance(markup, list): text = "" diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 3f4d3fc51ae3..c20d9c7e9da7 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -17,7 +17,7 @@ import re import urllib.parse from inspect import signature -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple from prometheus_client import Counter, Gauge @@ -27,6 +27,7 @@ from synapse.api.errors import HttpResponseException, SynapseError from synapse.http import RequestTimedOutError from synapse.http.server import HttpServer +from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.logging import opentracing from synapse.logging.opentracing import trace_with_opname @@ -53,6 +54,9 @@ ) +_STREAM_POSITION_KEY = "_INT_STREAM_POS" + + class ReplicationEndpoint(metaclass=abc.ABCMeta): """Helper base class for defining new replication HTTP endpoints. @@ -94,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): a connection error is received. RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when receiving connection errors, each will backoff exponentially longer. + WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to + catch up before processing the request and/or response. Defaults to + True. """ NAME: str = abc.abstractproperty() # type: ignore @@ -104,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): RETRY_ON_CONNECT_ERROR = True RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1) + WAIT_FOR_STREAMS: ClassVar[bool] = True + def __init__(self, hs: "HomeServer"): if self.CACHE: self.response_cache: ResponseCache[str] = ResponseCache( @@ -126,6 +135,10 @@ def __init__(self, hs: "HomeServer"): if hs.config.worker.worker_replication_secret: self._replication_secret = hs.config.worker.worker_replication_secret + self._streams = hs.get_replication_command_handler().get_streams_to_replicate() + self._replication = hs.get_replication_data_handler() + self._instance_name = hs.get_instance_name() + def _check_auth(self, request: Request) -> None: # Get the authorization header. 
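(Illustrative sketch, not part of the diff.) `WAIT_FOR_STREAMS` and `_STREAM_POSITION_KEY`, introduced above, let replication HTTP calls carry stream positions in both directions: a non-GET request body gains the caller's current positions, and the response carries the positions the handler wrote up to, so each side can wait for the other's streams before continuing. Roughly (worker name and tokens are made up):

    # Shape of the extra field added to a non-GET request body:
    request_body_extra = {
        "_INT_STREAM_POS": {
            "instance_name": "client_reader1",          # illustrative worker
            "streams": {"events": 12345, "account_data": 678},
        },
    }
    # The response uses the same key but a flat {stream_name: token} mapping of
    # the streams this endpoint may have written to, e.g.:
    response_extra = {"_INT_STREAM_POS": {"events": 12350}}
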
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") @@ -160,7 +173,7 @@ async def _serialize_payload(**kwargs) -> JsonDict: @abc.abstractmethod async def _handle_request( - self, request: Request, **kwargs: Any + self, request: Request, content: JsonDict, **kwargs: Any ) -> Tuple[int, JsonDict]: """Handle incoming request. @@ -201,6 +214,10 @@ def make_client(cls, hs: "HomeServer") -> Callable: @trace_with_opname("outgoing_replication_request") async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: + # We have to pull these out here to avoid circular dependencies... + streams = hs.get_replication_command_handler().get_streams_to_replicate() + replication = hs.get_replication_data_handler() + with outgoing_gauge.track_inprogress(): if instance_name == local_instance_name: raise Exception("Trying to send HTTP request to self") @@ -219,6 +236,24 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: data = await cls._serialize_payload(**kwargs) + if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS: + # Include the current stream positions that we write to. We + # don't do this for GETs as they don't have a body, and we + # generally assume that a GET won't rely on data we have + # written. + if _STREAM_POSITION_KEY in data: + raise Exception( + "data to send contains %r key", _STREAM_POSITION_KEY + ) + + data[_STREAM_POSITION_KEY] = { + "streams": { + stream.NAME: stream.current_token(local_instance_name) + for stream in streams + }, + "instance_name": local_instance_name, + } + url_args = [ urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS ] @@ -308,6 +343,17 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: ) from e _outgoing_request_counter.labels(cls.NAME, 200).inc() + + # Wait on any streams that the remote may have written to. + for stream_name, position in result.get( + _STREAM_POSITION_KEY, {} + ).items(): + await replication.wait_for_stream_position( + instance_name=instance_name, + stream_name=stream_name, + position=position, + ) + return result return send_request @@ -353,6 +399,22 @@ async def _check_auth_and_handle( if self._replication_secret: self._check_auth(request) + if self.METHOD == "GET": + # GET APIs always have an empty body. + content = {} + else: + content = parse_json_object_from_request(request) + + # Wait on any streams that the remote may have written to. + for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[ + "streams" + ].items(): + await self._replication.wait_for_stream_position( + instance_name=content[_STREAM_POSITION_KEY]["instance_name"], + stream_name=stream_name, + position=position, + ) + if self.CACHE: txn_id = kwargs.pop("txn_id") @@ -361,13 +423,30 @@ async def _check_auth_and_handle( # correctly yet. In particular, there may be issues to do with logging # context lifetimes. - return await self.response_cache.wrap( - txn_id, self._handle_request, request, **kwargs + code, response = await self.response_cache.wrap( + txn_id, self._handle_request, request, content, **kwargs ) + # Take a copy so we don't mutate things in the cache. + response = dict(response) + else: + # The `@cancellable` decorator may be applied to `_handle_request`. But we + # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`, + # so we have to set up the cancellable flag ourselves. 
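(Illustrative sketch, not part of the diff.) The signature change threaded through the rest of this patch, `_handle_request(self, request, content, **kwargs)`, means subclasses no longer call `parse_json_object_from_request` themselves; the base class parses the body (or supplies `{}` for GETs) and passes it in. A minimal endpoint under the new contract; the endpoint name and handler logic are invented:

    from typing import Tuple

    from twisted.web.server import Request

    from synapse.replication.http._base import ReplicationEndpoint
    from synapse.types import JsonDict

    class ReplicationExampleServlet(ReplicationEndpoint):  # illustrative
        NAME = "example"
        PATH_ARGS = ("user_id",)

        @staticmethod
        async def _serialize_payload(user_id: str, note: str) -> JsonDict:  # type: ignore[override]
            return {"note": note}

        async def _handle_request(  # type: ignore[override]
            self, request: Request, content: JsonDict, user_id: str
        ) -> Tuple[int, JsonDict]:
            # `content` is the already-parsed JSON body; no servlet parsing here.
            return 200, {"echoed": content["note"], "user_id": user_id}
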
+ request.is_render_cancellable = is_function_cancellable( + self._handle_request + ) + + code, response = await self._handle_request(request, content, **kwargs) + + # Return streams we may have written to in the course of processing this + # request. + if _STREAM_POSITION_KEY in response: + raise Exception("data to send contains %r key", _STREAM_POSITION_KEY) - # The `@cancellable` decorator may be applied to `_handle_request`. But we - # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`, - # so we have to set up the cancellable flag ourselves. - request.is_render_cancellable = is_function_cancellable(self._handle_request) + if self.WAIT_FOR_STREAMS: + response[_STREAM_POSITION_KEY] = { + stream.NAME: stream.current_token(self._instance_name) + for stream in self._streams + } - return await self._handle_request(request, **kwargs) + return code, response diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 0edc95977b3a..2374f810c94f 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -18,7 +18,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -61,10 +60,8 @@ async def _serialize_payload( # type: ignore[override] return payload async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, account_data_type: str + self, request: Request, content: JsonDict, user_id: str, account_data_type: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - max_stream_id = await self.handler.add_account_data_for_user( user_id, account_data_type, content["content"] ) @@ -101,7 +98,7 @@ async def _serialize_payload( # type: ignore[override] return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, account_data_type: str + self, request: Request, content: JsonDict, user_id: str, account_data_type: str ) -> Tuple[int, JsonDict]: max_stream_id = await self.handler.remove_account_data_for_user( user_id, account_data_type @@ -143,10 +140,13 @@ async def _serialize_payload( # type: ignore[override] return payload async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, room_id: str, account_data_type: str + self, + request: Request, + content: JsonDict, + user_id: str, + room_id: str, + account_data_type: str, ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - max_stream_id = await self.handler.add_account_data_to_room( user_id, room_id, account_data_type, content["content"] ) @@ -183,7 +183,12 @@ async def _serialize_payload( # type: ignore[override] return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, room_id: str, account_data_type: str + self, + request: Request, + content: JsonDict, + user_id: str, + room_id: str, + account_data_type: str, ) -> Tuple[int, JsonDict]: max_stream_id = await self.handler.remove_account_data_for_room( user_id, room_id, account_data_type @@ -225,10 +230,8 @@ async def _serialize_payload( # type: ignore[override] return payload async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, room_id: str, tag: str + self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str ) -> Tuple[int, 
JsonDict]: - content = parse_json_object_from_request(request) - max_stream_id = await self.handler.add_tag_to_room( user_id, room_id, tag, content["content"] ) @@ -266,7 +269,7 @@ async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, room_id: str, tag: str + self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str ) -> Tuple[int, JsonDict]: max_stream_id = await self.handler.remove_tag_from_room( user_id, diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index ea5c08e6cfdf..ecea6fc915c7 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -18,7 +18,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.logging.opentracing import active_span from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -78,7 +77,7 @@ async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, Optional[JsonDict]]: user_devices = await self.device_list_updater.user_device_resync(user_id) @@ -138,9 +137,8 @@ async def _serialize_payload(user_ids: List[str]) -> JsonDict: # type: ignore[o return {"user_ids": user_ids} async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, content: JsonDict ) -> Tuple[int, Dict[str, Optional[JsonDict]]]: - content = parse_json_object_from_request(request) user_ids: List[str] = content["user_ids"] logger.info("Resync for %r", user_ids) @@ -205,10 +203,8 @@ async def _serialize_payload( # type: ignore[override] } async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, content: JsonDict ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - user_id = content["user_id"] device_id = content["device_id"] keys = content["keys"] diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index d3abafed2871..53ad32703029 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -21,7 +21,6 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict from synapse.util.metrics import Measure @@ -114,10 +113,8 @@ async def _serialize_payload( # type: ignore[override] return payload - async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override] + async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override] with Measure(self.clock, "repl_fed_send_events_parse"): - content = parse_json_object_from_request(request) - room_id = content["room_id"] backfilled = content["backfilled"] @@ -181,13 +178,10 @@ async def _serialize_payload( # type: ignore[override] return {"origin": origin, "content": content} async def _handle_request( # type: ignore[override] - self, request: Request, edu_type: str + 
self, request: Request, content: JsonDict, edu_type: str ) -> Tuple[int, JsonDict]: - with Measure(self.clock, "repl_fed_send_edu_parse"): - content = parse_json_object_from_request(request) - - origin = content["origin"] - edu_content = content["content"] + origin = content["origin"] + edu_content = content["content"] logger.info("Got %r edu from %s", edu_type, origin) @@ -231,13 +225,10 @@ async def _serialize_payload(query_type: str, args: JsonDict) -> JsonDict: # ty return {"args": args} async def _handle_request( # type: ignore[override] - self, request: Request, query_type: str + self, request: Request, content: JsonDict, query_type: str ) -> Tuple[int, JsonDict]: - with Measure(self.clock, "repl_fed_query_parse"): - content = parse_json_object_from_request(request) - - args = content["args"] - args["origin"] = content["origin"] + args = content["args"] + args["origin"] = content["origin"] logger.info("Got %r query from %s", query_type, args["origin"]) @@ -274,7 +265,7 @@ async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override return {} async def _handle_request( # type: ignore[override] - self, request: Request, room_id: str + self, request: Request, content: JsonDict, room_id: str ) -> Tuple[int, JsonDict]: await self.store.clean_room_for_join(room_id) @@ -307,9 +298,8 @@ async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDic return {"room_version": room_version.identifier} async def _handle_request( # type: ignore[override] - self, request: Request, room_id: str + self, request: Request, content: JsonDict, room_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] await self.store.maybe_store_room_on_outlier_membership(room_id, room_version) return 200, {} diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index c68e18da129b..6ad6cb1bfe4e 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -18,7 +18,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -73,10 +72,8 @@ async def _serialize_payload( # type: ignore[override] } async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - device_id = content["device_id"] initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 663bff573848..9fa1060d48f6 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -17,7 +17,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, Requester, UserID @@ -79,10 +78,8 @@ async def _serialize_payload( # type: ignore[override] } async def _handle_request( # type: ignore[override] - self, request: SynapseRequest, room_id: str, user_id: str + self, request: SynapseRequest, content: JsonDict, room_id: str, user_id: str ) -> 
Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - remote_room_hosts = content["remote_room_hosts"] event_content = content["content"] @@ -147,11 +144,10 @@ async def _serialize_payload( # type: ignore[override] async def _handle_request( # type: ignore[override] self, request: SynapseRequest, + content: JsonDict, room_id: str, user_id: str, ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - remote_room_hosts = content["remote_room_hosts"] event_content = content["content"] @@ -217,10 +213,8 @@ async def _serialize_payload( # type: ignore[override] } async def _handle_request( # type: ignore[override] - self, request: SynapseRequest, invite_event_id: str + self, request: SynapseRequest, content: JsonDict, invite_event_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - txn_id = content["txn_id"] event_content = content["content"] @@ -285,10 +279,9 @@ async def _serialize_payload( # type: ignore[override] async def _handle_request( # type: ignore[override] self, request: SynapseRequest, + content: JsonDict, knock_event_id: str, ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - txn_id = content["txn_id"] event_content = content["content"] @@ -347,7 +340,12 @@ async def _serialize_payload( # type: ignore[override] return {} async def _handle_request( # type: ignore[override] - self, request: Request, room_id: str, user_id: str, change: str + self, + request: Request, + content: JsonDict, + room_id: str, + user_id: str, + change: str, ) -> Tuple[int, JsonDict]: logger.info("user membership change: %s in %s", user_id, room_id) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index 4a5b08f56f73..db16aac9c206 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -18,7 +18,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, UserID @@ -56,7 +55,7 @@ async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: await self._presence_handler.bump_presence_active_time( UserID.from_string(user_id) @@ -107,10 +106,8 @@ async def _serialize_payload( # type: ignore[override] } async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - await self._presence_handler.set_state( UserID.from_string(user_id), content["state"], diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py index af5c2f66a735..297e8ad564bd 100644 --- a/synapse/replication/http/push.py +++ b/synapse/replication/http/push.py @@ -18,7 +18,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -61,10 +60,8 @@ async def _serialize_payload(app_id: str, pushkey: str, user_id: str) -> JsonDic return payload async def _handle_request( # type: ignore[override] - 
self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - app_id = content["app_id"] pushkey = content["pushkey"] diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 976c2833603d..265e601b96a9 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -18,7 +18,6 @@ from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -96,10 +95,8 @@ async def _serialize_payload( # type: ignore[override] } async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - await self.registration_handler.check_registration_ratelimit(content["address"]) # Always default admin users to approved (since it means they were created by @@ -150,10 +147,8 @@ async def _serialize_payload( # type: ignore[override] return {"auth_result": auth_result, "access_token": access_token} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - auth_result = content["auth_result"] access_token = content["access_token"] diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index a3b5d7cc1e55..bc8622333b91 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -21,7 +21,6 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, Requester, UserID from synapse.util.metrics import Measure @@ -117,11 +116,9 @@ async def _serialize_payload( # type: ignore[override] return payload async def _handle_request( # type: ignore[override] - self, request: Request, event_id: str + self, request: Request, content: JsonDict, event_id: str ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_send_event_parse"): - content = parse_json_object_from_request(request) - event_dict = content["event"] room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]] internal_metadata = content["internal_metadata"] diff --git a/synapse/replication/http/send_events.py b/synapse/replication/http/send_events.py index 9121906fd16a..7fc13ef402be 100644 --- a/synapse/replication/http/send_events.py +++ b/synapse/replication/http/send_events.py @@ -21,7 +21,6 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, Requester, UserID from synapse.util.metrics import Measure @@ -116,10 +115,9 @@ async def _serialize_payload( # type: ignore[override] return payload async def _handle_request( # type: ignore[override] - self, request: Request + self, request: 
Request, payload: JsonDict ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_send_events_parse"): - payload = parse_json_object_from_request(request) events_and_context = [] events = payload["events"] diff --git a/synapse/replication/http/state.py b/synapse/replication/http/state.py index 838b7584e56f..0c524e7de3fd 100644 --- a/synapse/replication/http/state.py +++ b/synapse/replication/http/state.py @@ -57,7 +57,7 @@ async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override return {} async def _handle_request( # type: ignore[override] - self, request: Request, room_id: str + self, request: Request, content: JsonDict, room_id: str ) -> Tuple[int, JsonDict]: writer_instance = self._events_shard_config.get_instance(room_id) if writer_instance != self._instance_name: diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index c06522536254..3c7b5b18eab8 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -54,6 +54,10 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): PATH_ARGS = ("stream_name",) METHOD = "GET" + # We don't want to wait for replication streams to catch up, as this gets + # called in the process of catching replication streams up. + WAIT_FOR_STREAMS = False + def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -67,7 +71,7 @@ async def _serialize_payload( # type: ignore[override] return {"from_token": from_token, "upto_token": upto_token} async def _handle_request( # type: ignore[override] - self, request: Request, stream_name: str + self, request: Request, content: JsonDict, stream_name: str ) -> Tuple[int, JsonDict]: stream = self.streams.get(stream_name) if stream is None: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index b5e40da5337e..cc0528bd8e5f 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.internet.interfaces import IAddress, IConnector from twisted.internet.protocol import ReconnectingClientFactory @@ -33,7 +34,6 @@ PushersStream, PushRulesStream, ReceiptsStream, - TagAccountDataStream, ToDeviceStream, TypingStream, UnPartialStatedEventStream, @@ -59,7 +59,7 @@ logger = logging.getLogger(__name__) # How long we allow callers to wait for replication updates before timing out. -_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30 +_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 5 class DirectTcpReplicationClientFactory(ReconnectingClientFactory): @@ -133,9 +133,9 @@ def __init__(self, hs: "HomeServer"): if hs.should_send_federation(): self.send_handler = FederationSenderHandler(hs) - # Map from stream to list of deferreds waiting for the stream to + # Map from stream and instance to list of deferreds waiting for the stream to # arrive at a particular position. The lists are sorted by stream position. 
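(Illustrative sketch, not part of the diff.) Keying the waiter lists by `(stream_name, instance_name)` matters once a stream has more than one writer: a position received from one writer must not release a caller waiting on a different writer. Conceptually (values invented):

    from typing import Dict, List, Tuple
    from twisted.internet.defer import Deferred

    d1, d2, d3 = Deferred(), Deferred(), Deferred()
    # (stream name, writer instance) -> list of (position, deferred), sorted by position
    streams_to_waiters: Dict[Tuple[str, str], List[Tuple[int, Deferred]]] = {
        ("events", "persister1"): [(100, d1), (105, d2)],
        ("events", "persister2"): [(100, d3)],
    }
    # RDATA from persister1 up to token 103 fires d1 only; d3 stays pending even
    # though it waits on the same stream name at the same numeric position.
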
- self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {} + self._streams_to_waiters: Dict[Tuple[str, str], List[Tuple[int, Deferred]]] = {} async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -168,7 +168,7 @@ async def on_rdata( self.notifier.on_new_event( StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows] ) - elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME): + elif stream_name in AccountDataStream.NAME: self.notifier.on_new_event( StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows] ) @@ -188,7 +188,7 @@ async def on_rdata( elif stream_name == DeviceListsStream.NAME: all_room_ids: Set[str] = set() for row in rows: - if row.entity.startswith("@"): + if row.entity.startswith("@") and not row.is_signature: room_ids = await self.store.get_rooms_for_user(row.entity) all_room_ids.update(room_ids) self.notifier.on_new_event( @@ -207,6 +207,12 @@ async def on_rdata( # we don't need to optimise this for multiple rows. for row in rows: if row.type != EventsStreamEventRow.TypeId: + # The row's data is an `EventsStreamCurrentStateRow`. + # When we recompute the current state of a room based on forward + # extremities (see `update_current_state`), no new events are + # persisted, so we must poke the replication callbacks ourselves. + # This functionality is used when finishing up a partial state join. + self.notifier.notify_replication() continue assert isinstance(row, EventsStreamRow) assert isinstance(row.data, EventsStreamEventRow) @@ -254,6 +260,7 @@ async def on_rdata( self._state_storage_controller.notify_room_un_partial_stated( row.room_id ) + await self.notifier.on_un_partial_stated_room(row.room_id, token) elif stream_name == UnPartialStatedEventStream.NAME: for row in rows: assert isinstance(row, UnPartialStatedEventStreamRow) @@ -270,7 +277,7 @@ async def on_rdata( # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is # greater than the received row position. - waiting_list = self._streams_to_waiters.get(stream_name, []) + waiting_list = self._streams_to_waiters.get((stream_name, instance_name), []) # Index of first item with a position after the current token, i.e we # have called all deferreds before this index. If not overwritten by @@ -279,14 +286,13 @@ async def on_rdata( # `len(list)` works for both cases. index_of_first_deferred_not_called = len(waiting_list) + # We don't fire the deferreds until after we finish iterating over the + # list, to avoid the list changing when we fire the deferreds. + deferreds_to_callback = [] + for idx, (position, deferred) in enumerate(waiting_list): if position <= token: - try: - with PreserveLoggingContext(): - deferred.callback(None) - except Exception: - # The deferred has been cancelled or timed out. - pass + deferreds_to_callback.append(deferred) else: # The list is sorted by position so we don't need to continue # checking any further entries in the list. @@ -297,6 +303,14 @@ async def on_rdata( # loop. (This maintains the order so no need to resort) waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] + for deferred in deferreds_to_callback: + try: + with PreserveLoggingContext(): + deferred.callback(None) + except Exception: + # The deferred has been cancelled or timed out. 
+ pass + async def on_position( self, stream_name: str, instance_name: str, token: int ) -> None: @@ -315,10 +329,18 @@ def on_remote_server_up(self, server: str) -> None: self.send_handler.wake_destination(server) async def wait_for_stream_position( - self, instance_name: str, stream_name: str, position: int + self, + instance_name: str, + stream_name: str, + position: int, ) -> None: """Wait until this instance has received updates up to and including the given stream position. + + Args: + instance_name + stream_name + position """ if instance_name == self._instance_name: @@ -326,7 +348,7 @@ async def wait_for_stream_position( # anyway in that case we don't need to wait. return - current_position = self._streams[stream_name].current_token(self._instance_name) + current_position = self._streams[stream_name].current_token(instance_name) if position <= current_position: # We're already past the position return @@ -338,17 +360,32 @@ async def wait_for_stream_position( deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor ) - waiting_list = self._streams_to_waiters.setdefault(stream_name, []) + waiting_list = self._streams_to_waiters.setdefault( + (stream_name, instance_name), [] + ) waiting_list.append((position, deferred)) waiting_list.sort(key=lambda t: t[0]) # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): - logger.info("Waiting for repl stream %r to reach %s", stream_name, position) - await make_deferred_yieldable(deferred) logger.info( - "Finished waiting for repl stream %r to reach %s", stream_name, position + "Waiting for repl stream %r to reach %s (%s)", + stream_name, + position, + instance_name, + ) + try: + await make_deferred_yieldable(deferred) + except defer.TimeoutError: + logger.error("Timed out waiting for stream %s", stream_name) + return + + logger.info( + "Finished waiting for repl stream %r to reach %s (%s)", + stream_name, + position, + instance_name, ) def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: @@ -423,7 +460,11 @@ async def process_replication_rows( # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about # changes. - hosts = {row.entity for row in rows if not row.entity.startswith("@")} + hosts = { + row.entity + for row in rows + if not row.entity.startswith("@") and not row.is_signature + } for host in hosts: self.federation_sender.send_device_messages(host, immediate=False) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 3e85c5d03a32..0ad31a05c8ad 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -58,7 +58,6 @@ PresenceStream, ReceiptsStream, Stream, - TagAccountDataStream, ToDeviceStream, TypingStream, ) @@ -145,7 +144,7 @@ def __init__(self, hs: "HomeServer"): continue - if isinstance(stream, (AccountDataStream, TagAccountDataStream)): + if isinstance(stream, AccountDataStream): # Only add AccountDataStream and TagAccountDataStream as a source on the # instance in charge of account_data persistence. 
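(Illustrative sketch, not part of the diff.) Two behavioural points in `wait_for_stream_position` above are easy to miss: the early-return check now consults the writer's current token (`current_token(instance_name)`) rather than this instance's own, and a timeout is now logged and swallowed instead of propagating. The typical call on the requesting worker looks like (values invented):

    # Wait until this worker has replicated the events stream up to the position
    # reported by the writer that handled our replication HTTP request.
    await hs.get_replication_data_handler().wait_for_stream_position(
        instance_name="event_persister1",  # the writer that produced the data
        stream_name="events",
        position=4242,
    )
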
if hs.get_instance_name() in hs.config.worker.writers.account_data: diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 99f09669f00b..9d17eff71451 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -199,33 +199,28 @@ async def _run_notifier_loop(self) -> None: # The token has advanced but there is no data to # send, so we send a `POSITION` to inform other # workers of the updated position. - if stream.NAME == EventsStream.NAME: - # XXX: We only do this for the EventStream as it - # turns out that e.g. account data streams share - # their "current token" with each other, meaning - # that it is *not* safe to send a POSITION. - - # Note: `last_token` may not *actually* be the - # last token we sent out in a RDATA or POSITION. - # This can happen if we sent out an RDATA for - # position X when our current token was say X+1. - # Other workers will see RDATA for X and then a - # POSITION with last token of X+1, which will - # cause them to check if there were any missing - # updates between X and X+1. - logger.info( - "Sending position: %s -> %s", + + # Note: `last_token` may not *actually* be the + # last token we sent out in a RDATA or POSITION. + # This can happen if we sent out an RDATA for + # position X when our current token was say X+1. + # Other workers will see RDATA for X and then a + # POSITION with last token of X+1, which will + # cause them to check if there were any missing + # updates between X and X+1. + logger.info( + "Sending position: %s -> %s", + stream.NAME, + current_token, + ) + self.command_handler.send_command( + PositionCommand( stream.NAME, + self._instance_name, + last_token, current_token, ) - self.command_handler.send_command( - PositionCommand( - stream.NAME, - self._instance_name, - last_token, - current_token, - ) - ) + ) continue # Some streams return multiple rows with the same stream IDs, diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 110f10aab9a5..9c67f661a362 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -35,10 +35,8 @@ PushRulesStream, ReceiptsStream, Stream, - TagAccountDataStream, ToDeviceStream, TypingStream, - UserSignatureStream, ) from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.federation import FederationStream @@ -62,9 +60,7 @@ DeviceListsStream, ToDeviceStream, FederationStream, - TagAccountDataStream, AccountDataStream, - UserSignatureStream, UnPartialStatedRoomStream, UnPartialStatedEventStream, ) @@ -83,9 +79,7 @@ "CachesStream", "DeviceListsStream", "ToDeviceStream", - "TagAccountDataStream", "AccountDataStream", - "UserSignatureStream", "UnPartialStatedRoomStream", "UnPartialStatedEventStream", ] diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index e01155ad597b..a4bdb48c0c99 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -28,8 +28,8 @@ import attr +from synapse.api.constants import AccountDataTypes from synapse.replication.http.streams import ReplicationGetStreamUpdates -from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -463,18 +463,67 @@ class DeviceListsStream(Stream): @attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceListsStreamRow: entity: str + # Indicates that a user has signed their own device with their 
user-signing key + is_signature: bool NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_device_stream_token), - store.get_all_device_list_changes_for_remotes, + current_token_without_instance(self.store.get_device_stream_token), + self._update_function, ) + async def _update_function( + self, + instance_name: str, + from_token: Token, + current_token: Token, + target_row_count: int, + ) -> StreamUpdateResult: + ( + device_updates, + devices_to_token, + devices_limited, + ) = await self.store.get_all_device_list_changes_for_remotes( + instance_name, from_token, current_token, target_row_count + ) + + ( + signatures_updates, + signatures_to_token, + signatures_limited, + ) = await self.store.get_all_user_signature_changes_for_remotes( + instance_name, from_token, current_token, target_row_count + ) + + upper_limit_token = current_token + if devices_limited: + upper_limit_token = min(upper_limit_token, devices_to_token) + if signatures_limited: + upper_limit_token = min(upper_limit_token, signatures_to_token) + + device_updates = [ + (stream_id, (entity, False)) + for stream_id, (entity,) in device_updates + if stream_id <= upper_limit_token + ] + + signatures_updates = [ + (stream_id, (entity, True)) + for stream_id, (entity,) in signatures_updates + if stream_id <= upper_limit_token + ] + + updates = list( + heapq.merge(device_updates, signatures_updates, key=lambda row: row[0]) + ) + + return updates, upper_limit_token, devices_limited or signatures_limited + class ToDeviceStream(Stream): """New to_device messages for a client""" @@ -495,27 +544,6 @@ def __init__(self, hs: "HomeServer"): ) -class TagAccountDataStream(Stream): - """Someone added/removed a tag for a room""" - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class TagAccountDataStreamRow: - user_id: str - room_id: str - data: JsonDict - - NAME = "tag_account_data" - ROW_TYPE = TagAccountDataStreamRow - - def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_max_account_data_stream_id), - store.get_all_updated_tags, - ) - - class AccountDataStream(Stream): """Global or per room account data was changed""" @@ -560,6 +588,19 @@ async def _update_function( to_token = room_results[-1][0] limited = True + tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags( + instance_name, + from_token, + to_token, + limit, + ) + + # again, if the tag results hit the limit, limit the global results to + # the same stream token. + if tags_limited: + to_token = tag_to_token + limited = True + # convert the global results to the right format, and limit them to the to_token # at the same time global_rows = ( @@ -568,11 +609,16 @@ async def _update_function( if stream_id <= to_token ) - # we know that the room_results are already limited to `to_token` so no need - # for a check on `stream_id` here. room_rows = ( (stream_id, (user_id, room_id, account_data_type)) for stream_id, user_id, room_id, account_data_type in room_results + if stream_id <= to_token + ) + + tag_rows = ( + (stream_id, (user_id, room_id, AccountDataTypes.TAG)) + for stream_id, user_id, room_id in tags + if stream_id <= to_token ) # We need to return a sorted list, so merge them together. 
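(Illustrative sketch, not part of the diff.) Both update functions above now stitch several underlying tables into a single replication stream: each sub-query is clamped to a common upper token and the sorted row lists are merged by stream ID. A toy, self-contained version of that merge pattern (data invented):

    import heapq

    device_rows = [(1, ("@a:example.org", False)), (4, ("@b:example.org", False))]
    signature_rows = [(2, ("@a:example.org", True)), (3, ("@c:example.org", True))]

    # Merge two already-sorted lists into one stream ordered by stream ID, as the
    # patch does with heapq.merge(..., key=lambda row: row[0]).
    merged = list(heapq.merge(device_rows, signature_rows, key=lambda row: row[0]))
    assert [stream_id for stream_id, _ in merged] == [1, 2, 3, 4]
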
@@ -582,24 +628,7 @@ async def _update_function( # leading to a comparison between the data tuples. The comparison could # fail due to attempting to compare the `room_id` which results in a # `TypeError` from comparing a `str` vs `None`. - updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0])) - return updates, to_token, limited - - -class UserSignatureStream(Stream): - """A user has signed their own device with their user-signing key""" - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class UserSignatureStreamRow: - user_id: str - - NAME = "user_signature" - ROW_TYPE = UserSignatureStreamRow - - def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_device_stream_token), - store.get_all_user_signature_changes_for_remotes, + updates = list( + heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0]) ) + return updates, to_token, limited diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py index b5a2ae74b685..a8ce5ffd7289 100644 --- a/synapse/replication/tcp/streams/partial_state.py +++ b/synapse/replication/tcp/streams/partial_state.py @@ -16,7 +16,6 @@ import attr from synapse.replication.tcp.streams import Stream -from synapse.replication.tcp.streams._base import current_token_without_instance if TYPE_CHECKING: from synapse.server import HomeServer @@ -42,8 +41,7 @@ def __init__(self, hs: "HomeServer"): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - # TODO(faster_joins, multiple writers): we need to account for instance names - current_token_without_instance(store.get_un_partial_stated_rooms_token), + store.get_un_partial_stated_rooms_token, store.get_un_partial_stated_rooms_from_stream, ) @@ -70,7 +68,6 @@ def __init__(self, hs: "HomeServer"): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - # TODO(faster_joins, multiple writers): we need to account for instance names - current_token_without_instance(store.get_un_partial_stated_events_token), + store.get_un_partial_stated_events_token, store.get_un_partial_stated_events_from_stream, ) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index fb73886df061..79f22a59f195 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -152,7 +152,7 @@ async def on_POST( logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: ts = body["purge_up_to_ts"] - if not isinstance(ts, int): + if type(ts) is not int: raise SynapseError( HTTPStatus.BAD_REQUEST, "purge_up_to_ts must be an int", diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 6d634eef7081..a3beb74e2c3d 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -16,8 +16,9 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import Direction from synapse.api.errors import Codes, NotFoundError, SynapseError -from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.types import JsonDict @@ -60,7 +61,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, 
JsonDict]: start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) - direction = parse_string(request, "dir", default="b") + direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS) user_id = parse_string(request, "user_id") room_id = parse_string(request, "room_id") @@ -78,13 +79,6 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: errcode=Codes.INVALID_PARAM, ) - if direction not in ("f", "b"): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown direction: %s" % (direction,), - errcode=Codes.INVALID_PARAM, - ) - event_reports, total = await self.store.get_event_reports_paginate( start, limit, direction, user_id, room_id ) diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 023ed92144b4..e0ee55bd0eb6 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -15,9 +15,10 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import Direction from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.federation.transport.server import Authenticator -from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.storage.databases.main.transactions import DestinationSortOrder @@ -79,7 +80,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: allowed_values=[dest.value for dest in DestinationSortOrder], ) - direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) destinations, total = await self._store.get_destinations_paginate( start, limit, destination, order_by, direction @@ -192,7 +193,7 @@ async def on_GET( errcode=Codes.INVALID_PARAM, ) - direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) rooms, total = await self._store.get_destination_rooms_paginate( destination, start, limit, direction diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 73470f09ae44..0d072c42a7aa 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -17,9 +17,16 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import Direction from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string +from synapse.http.servlet import ( + RestServlet, + parse_boolean, + parse_enum, + parse_integer, + parse_string, +) from synapse.http.site import SynapseRequest from synapse.rest.admin._base import ( admin_patterns, @@ -389,7 +396,7 @@ async def on_GET( # to newest media is on top for backward compatibility. 
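(Illustrative sketch, not part of the diff.) The repeated `parse_string(request, "dir", ...)` plus hand-rolled "f"/"b" validation is replaced by `parse_enum` with the shared `Direction` constant, which rejects anything that is not a valid enum value. Assuming `Direction` is the usual string-valued enum ("f"/"b"), the pattern is:

    from synapse.api.constants import Direction
    from synapse.http.servlet import parse_enum

    # Returns Direction.FORWARDS or Direction.BACKWARDS; any other ?dir= value is
    # rejected by parse_enum with a 400, so no manual validation is needed.
    direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
    reverse_order = direction == Direction.BACKWARDS
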
if b"order_by" not in request.args and b"dir" not in request.args: order_by = MediaSortOrder.CREATED_TS.value - direction = "b" + direction = Direction.BACKWARDS else: order_by = parse_string( request, @@ -397,8 +404,8 @@ async def on_GET( default=MediaSortOrder.CREATED_TS.value, allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) - direction = parse_string( - request, "dir", default="f", allowed_values=("f", "b") + direction = parse_enum( + request, "dir", Direction, default=Direction.FORWARDS ) media, total = await self.store.get_local_media_by_user_paginate( @@ -447,7 +454,7 @@ async def on_DELETE( # to newest media is on top for backward compatibility. if b"order_by" not in request.args and b"dir" not in request.args: order_by = MediaSortOrder.CREATED_TS.value - direction = "b" + direction = Direction.BACKWARDS else: order_by = parse_string( request, @@ -455,8 +462,8 @@ async def on_DELETE( default=MediaSortOrder.CREATED_TS.value, allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) - direction = parse_string( - request, "dir", default="f", allowed_values=("f", "b") + direction = parse_enum( + request, "dir", Direction, default=Direction.FORWARDS ) media, _ = await self.store.get_local_media_by_user_paginate( diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index af606e925216..95e751288b03 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -143,7 +143,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: else: # Get length of token to generate (default is 16) length = body.get("length", 16) - if not isinstance(length, int): + if type(length) is not int: raise SynapseError( HTTPStatus.BAD_REQUEST, "length must be an integer", @@ -163,8 +163,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: uses_allowed = body.get("uses_allowed", None) if not ( - uses_allowed is None - or (isinstance(uses_allowed, int) and uses_allowed >= 0) + uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0) ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -173,13 +172,13 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: ) expiry_time = body.get("expiry_time", None) - if not isinstance(expiry_time, (int, type(None))): + if type(expiry_time) not in (int, type(None)): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + if type(expiry_time) is int and expiry_time < self.clock.time_msec(): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", @@ -284,7 +283,7 @@ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDi uses_allowed = body["uses_allowed"] if not ( uses_allowed is None - or (isinstance(uses_allowed, int) and uses_allowed >= 0) + or (type(uses_allowed) is int and uses_allowed >= 0) ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -295,13 +294,13 @@ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDi if "expiry_time" in body: expiry_time = body["expiry_time"] - if not isinstance(expiry_time, (int, type(None))): + if type(expiry_time) not in (int, type(None)): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + if 
type(expiry_time) is int and expiry_time < self.clock.time_msec(): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 3b2b12114163..97d2da5dd4d2 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -19,13 +19,14 @@ from prometheus_client import Histogram -from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.constants import Direction, EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.http.servlet import ( ResolveRoomIdMixin, RestServlet, assert_params_in_dict, + parse_enum, parse_integer, parse_json_object_from_request, parse_string, @@ -232,15 +233,8 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: errcode=Codes.INVALID_PARAM, ) - direction = parse_string(request, "dir", default="f") - if direction not in ("f", "b"): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown direction: %s" % (direction,), - errcode=Codes.INVALID_PARAM, - ) - - reverse_order = True if direction == "b" else False + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) + reverse_order = True if direction == Direction.BACKWARDS else False # Return list of rooms according to parameters rooms, total_rooms = await self.store.get_rooms_paginate( @@ -963,7 +957,7 @@ async def on_GET( await assert_user_is_admin(self._auth, requester) timestamp = parse_integer(request, "ts", required=True) - direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) ( event_id, diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 3b142b840206..9c45f4650dc3 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -16,8 +16,9 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import Direction from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.storage.databases.main.stats import UserSortOrder @@ -102,13 +103,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: errcode=Codes.INVALID_PARAM, ) - direction = parse_string(request, "dir", default="f") - if direction not in ("f", "b"): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown direction: %s" % (direction,), - errcode=Codes.INVALID_PARAM, - ) + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) users_media, total = await self.store.get_users_media_usage_paginate( start, limit, from_ts, until_ts, order_by, direction, search_term diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 61eb0cf0358e..5f15b92c60c8 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -18,12 +18,13 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Dict, List, Optional, Tuple -from synapse.api.constants import UserTypes +from synapse.api.constants import Direction, UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError from 
synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_boolean, + parse_enum, parse_integer, parse_json_object_from_request, parse_string, @@ -121,7 +122,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: ), ) - direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) users, total = await self.store.get_users_paginate( start, @@ -975,7 +976,7 @@ async def on_POST( body = parse_json_object_from_request(request, allow_empty_body=True) valid_until_ms = body.get("valid_until_ms") - if valid_until_ms and not isinstance(valid_until_ms, int): + if type(valid_until_ms) not in (int, type(None)): raise SynapseError( HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int" ) @@ -1127,14 +1128,14 @@ async def on_POST( messages_per_second = body.get("messages_per_second", 0) burst_count = body.get("burst_count", 0) - if not isinstance(messages_per_second, int) or messages_per_second < 0: + if type(messages_per_second) is not int or messages_per_second < 0: raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (messages_per_second,), errcode=Codes.INVALID_PARAM, ) - if not isinstance(burst_count, int) or burst_count < 0: + if type(burst_count) is not int or burst_count < 0: raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (burst_count,), diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 8191b4e32c34..ad5c10c99dd8 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, List, Sequence, Tuple, Union +from typing import TYPE_CHECKING, List, Tuple, Union from synapse.api.errors import ( NotFoundError, @@ -169,7 +169,7 @@ async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDic raise UnrecognizedRequestError() -def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec: +def _rule_spec_from_path(path: List[str]) -> RuleSpec: """Turn a sequence of path components into a rule spec Args: diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 9dd59196d9a2..7456d6f50724 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -16,6 +16,7 @@ import re from typing import TYPE_CHECKING, Optional, Tuple +from synapse.api.constants import Direction from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -59,7 +60,7 @@ async def on_GET( requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = await PaginationConfig.from_request( - self._store, request, default_limit=5, default_dir="b" + self._store, request, default_limit=5, default_dir=Direction.BACKWARDS ) # The unstable version of this API returns an extra field for client diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index 6e962a45321e..e2b410cf32a3 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -54,7 +54,7 @@ async def on_POST( "Param 'reason' must be a string", Codes.BAD_JSON, ) - if not isinstance(body.get("score", 0), int): + if type(body.get("score", 0)) is not int: raise SynapseError( HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 4a0d09ce24f5..a63a87c40e71 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -26,7 +26,7 @@ from twisted.web.server import Request from synapse import event_auth -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -44,6 +44,7 @@ RestServlet, assert_params_in_dict, parse_boolean, + parse_enum, parse_integer, parse_json_object_from_request, parse_string, @@ -1297,7 +1298,7 @@ async def on_GET( await self._auth.check_user_in_room_or_world_readable(room_id, requester) timestamp = parse_integer(request, "ts", required=True) - direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) ( event_id, diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 61375651bc15..3f40f1874a8b 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -19,6 +19,7 @@ from typing_extensions import ParamSpec +from twisted.internet.defer import Deferred from twisted.python.failure import Failure from twisted.web.server import Request @@ -90,7 +91,7 @@ def fetch_or_execute( fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], *args: P.args, **kwargs: P.kwargs, - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> "Deferred[Tuple[int, JsonDict]]": """Fetches the response for this transaction, or executes the given function to produce a response for this transaction. 
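The admin and client REST hunks above replace the raw "f"/"b" strings for the `dir` query parameter with the `Direction` enum, parsed via `parse_enum` instead of `parse_string` plus a hand-rolled allowed-values check. Neither the enum definition nor the `parse_enum` signature appears in this diff, so the following is only a sketch of the behaviour those call sites now rely on; the names `parse_enum_sketch` and the assumed "f"/"b" values are inferred from the swapped defaults, not taken from the patch.

from enum import Enum
from typing import Mapping, Type, TypeVar


class Direction(Enum):
    # Assumed shape of synapse.api.constants.Direction; the swap of
    # default="f" for Direction.FORWARDS suggests string values "f" and "b".
    FORWARDS = "f"
    BACKWARDS = "b"


E = TypeVar("E", bound=Enum)


def parse_enum_sketch(
    args: Mapping[str, str], name: str, enum_cls: Type[E], default: E
) -> E:
    # Hypothetical stand-in for synapse.http.servlet.parse_enum: map a query
    # parameter onto an enum member, falling back to the default when absent.
    raw = args.get(name)
    if raw is None:
        return default
    try:
        return enum_cls(raw)
    except ValueError:
        raise ValueError(
            "Query parameter %r must be one of %r"
            % (name, [member.value for member in enum_cls])
        )


# e.g. ?dir=b maps onto Direction.BACKWARDS; a missing ?dir falls back to FORWARDS.
assert parse_enum_sketch({"dir": "b"}, "dir", Direction, Direction.FORWARDS) is Direction.BACKWARDS
assert parse_enum_sketch({}, "dir", Direction, Direction.FORWARDS) is Direction.FORWARDS

One visible effect in the diff is that the per-handler "Unknown direction" SynapseError blocks (rooms.py, statistics.py) disappear, since rejecting an invalid value now happens in the shared parsing helper.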
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index a3738a62507d..7592aa5d4767 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -200,7 +200,7 @@ def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult: calc_description_and_urls(open_graph_response, oembed["html"]) for size in ("width", "height"): val = oembed.get(size) - if val is not None and isinstance(val, int): + if type(val) is int: open_graph_response[f"og:video:{size}"] = val elif oembed_type == "link": diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index a48a4de92ae2..9480cc576332 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -77,7 +77,7 @@ def __init__(self, input_path: str): image_exif = self.image._getexif() # type: ignore if image_exif is not None: image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) - assert isinstance(image_orientation, int) + assert type(image_orientation) is int self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) except Exception as e: # A lot of parsing errors can happen when parsing EXIF diff --git a/synapse/server.py b/synapse/server.py index 981d3c2b43b1..3cfb1b173b50 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -107,7 +107,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager from synapse.module_api import ModuleApi -from synapse.notifier import Notifier +from synapse.notifier import Notifier, ReplicationNotifier from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler @@ -390,6 +390,10 @@ def get_federation_server(self) -> FederationServer: def get_notifier(self) -> Notifier: return Notifier(self) + @cache_in_self + def get_replication_notifier(self) -> ReplicationNotifier: + return ReplicationNotifier() + @cache_in_self def get_auth(self) -> Auth: return Auth(self) @@ -744,7 +748,7 @@ def get_oidc_handler(self) -> "OidcHandler": @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer() + return EventClientSerializer(self.config.experimental.msc3925_inhibit_edit) @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 26d79c6e62f2..52efd4a1719b 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -493,8 +493,6 @@ async def get_current_state_deltas( up to date. """ # FIXME(faster_joins): what do we do here? - # https://github.com/matrix-org/synapse/issues/12814 - # https://github.com/matrix-org/synapse/issues/12815 # https://github.com/matrix-org/synapse/issues/13008 return await self.stores.main.get_partial_current_state_deltas( @@ -571,10 +569,11 @@ async def get_current_hosts_in_room_or_partial_state_approximation( is arbitrary for rooms with partial state. """ # We have to read this list first to mitigate races with un-partial stating. - # This will be empty for rooms with full state. 
hosts_at_join = await self.stores.main.get_partial_state_servers_at_join( room_id ) + if hosts_at_join is None: + hosts_at_join = frozenset() hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 88479a16db0c..e20c5c5302fa 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1819,7 +1819,7 @@ async def simple_select_many_batch( keyvalues: Optional[Dict[str, Any]] = None, desc: str = "simple_select_many_batch", batch_size: int = 100, - ) -> List[Any]: + ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 3b5d166f56ab..b421a57ba839 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -17,6 +17,7 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple, cast +from synapse.api.constants import Direction from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import ( DatabasePool, @@ -169,7 +170,7 @@ async def get_users_paginate( guests: bool = True, deactivated: bool = False, order_by: str = UserSortOrder.NAME.value, - direction: str = "f", + direction: Direction = Direction.FORWARDS, approved: bool = True, appservice: Optional[str] = None, ) -> Tuple[List[JsonDict], int]: @@ -200,7 +201,7 @@ def get_users_paginate_txn( # Set ordering order_by_column = UserSortOrder(order_by).value - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c298af6819a6..939a08aad162 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -28,7 +28,7 @@ ) from synapse.api.constants import AccountDataTypes -from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream +from synapse.replication.tcp.streams import AccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( DatabasePool, @@ -80,6 +80,7 @@ def __init__( self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="account_data", instance_name=self._instance_name, tables=[ @@ -100,6 +101,7 @@ def __init__( # SQLite). self._account_data_id_gen = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "room_account_data", "stream_id", extra_tables=[("room_tags_revisions", "stream_id")], @@ -459,9 +461,7 @@ def process_replication_rows( def process_replication_position( self, stream_name: str, instance_name: str, token: int ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - elif stream_name == AccountDataStream.NAME: + if stream_name == AccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) super().process_replication_position(stream_name, instance_name, token) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index ed757539840b..2cddd3b94529 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -14,7 +14,17 @@ # limitations under the License. 
import logging import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Pattern, + Sequence, + Tuple, + cast, +) from synapse.appservice import ( ApplicationService, @@ -258,7 +268,7 @@ async def set_appservice_state( async def create_appservice_txn( self, service: ApplicationService, - events: List[EventBase], + events: Sequence[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], one_time_keys_count: TransactionOneTimeKeysCount, diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index af6b217d5d03..14d5aa4f941c 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -75,6 +75,7 @@ def __init__( self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, + notifier=hs.get_replication_notifier(), stream_name="caches", instance_name=hs.get_instance_name(), tables=[ diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 89b85ec9caa7..193dba320877 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -91,6 +91,7 @@ def __init__( MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="to_device", instance_name=self._instance_name, tables=[("device_inbox", "instance_name", "stream_id")], @@ -101,7 +102,7 @@ def __init__( else: self._can_write_to_device = True self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_inbox", "stream_id" + db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id" ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 155b1a9bf6d8..6ed699f7824e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,7 +38,7 @@ whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream +from synapse.replication.tcp.streams._base import DeviceListsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -92,12 +92,14 @@ def __init__( # class below that is used on the main process. 
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "device_lists_stream", "stream_id", extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), ("device_lists_changes_in_room", "stream_id"), + ("device_lists_remote_pending", "stream_id"), ], is_writer=hs.config.worker.worker_app is None, ) @@ -163,9 +165,7 @@ def process_replication_rows( ) -> None: if stream_name == DeviceListsStream.NAME: self._invalidate_caches_for_devices(token, rows) - elif stream_name == UserSignatureStream.NAME: - for row in rows: - self._user_signature_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) def process_replication_position( @@ -173,14 +173,17 @@ def process_replication_position( ) -> None: if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(instance_name, token) - elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) def _invalidate_caches_for_devices( self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] ) -> None: for row in rows: + if row.is_signature: + self._user_signature_stream_cache.entity_has_changed(row.entity, token) + continue + # NOTE: here we always tell both stream change caches, either about # the entity or just the known position. # The entities are either user IDs (starting with '@') whose devices diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 4c691642e2b5..c4ac6c33ba54 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1181,7 +1181,10 @@ def __init__( super().__init__(database, db_conn, hs) self._cross_signing_id_gen = StreamIdGenerator( - db_conn, "e2e_cross_signing_keys", "stream_id" + db_conn, + hs.get_replication_notifier(), + "e2e_cross_signing_keys", + "stream_id", ) async def set_e2e_device_keys( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 08f5c9e77a7f..f4ce258ed6bf 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1651,7 +1651,7 @@ def _update_metadata_tables_txn( if self._ephemeral_messages_enabled: # If there's an expiry timestamp on the event, store it. expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if isinstance(expiry_ts, int) and not event.is_state(): + if type(expiry_ts) is int and not event.is_state(): self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) # Insert into the room_memberships table. @@ -2133,10 +2133,10 @@ def _store_retention_policy_for_room_txn( ): if ( "min_lifetime" in event.content - and not isinstance(event.content.get("min_lifetime"), int) + and type(event.content["min_lifetime"]) is not int ) or ( "max_lifetime" in event.content - and not isinstance(event.content.get("max_lifetime"), int) + and type(event.content["max_lifetime"]) is not int ): # Ignore the event if one of the value isn't an integer. 
return diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 9e31798ab198..b9d3c36d602c 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -69,6 +69,8 @@ class _BackgroundUpdates: EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections" + EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index" + @attr.s(slots=True, frozen=True, auto_attribs=True) class _CalculateChainCover: @@ -260,6 +262,16 @@ def __init__( self._background_events_populate_state_key_rejections, ) + # Add an index that would be useful for jumping to date using + # get_event_id_for_timestamp. + self.db_pool.updates.register_background_index_update( + _BackgroundUpdates.EVENTS_JUMP_TO_DATE_INDEX, + index_name="events_jump_to_date_idx", + table="events", + columns=["room_id", "origin_server_ts"], + where_clause="NOT outlier", + ) + async def _background_reindex_fields_sender( self, progress: JsonDict, batch_size: int ) -> int: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 0c5d694f34c7..0c419b4715a3 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -39,7 +39,7 @@ from twisted.internet import defer -from synapse.api.constants import EventTypes +from synapse.api.constants import Direction, EventTypes from synapse.api.errors import NotFoundError, SynapseError from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, @@ -112,6 +112,10 @@ ) +class InvalidEventError(Exception): + """The event retrieved from the database is invalid and cannot be used.""" + + @attr.s(slots=True, auto_attribs=True) class EventCacheEntry: event: EventBase @@ -193,6 +197,7 @@ def __init__( self._stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="events", instance_name=hs.get_instance_name(), tables=[("events", "instance_name", "stream_ordering")], @@ -202,6 +207,7 @@ def __init__( self._backfill_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="backfill", instance_name=hs.get_instance_name(), tables=[("events", "instance_name", "stream_ordering")], @@ -219,12 +225,14 @@ def __init__( # SQLite). self._stream_id_gen = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "events", "stream_ordering", is_writer=hs.get_instance_name() in hs.config.worker.writers.events, ) self._backfill_id_gen = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "events", "stream_ordering", step=-1, @@ -314,6 +322,7 @@ def get_chain_id_txn(txn: Cursor) -> int: self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="un_partial_stated_event_stream", instance_name=hs.get_instance_name(), tables=[ @@ -325,14 +334,18 @@ def get_chain_id_txn(txn: Cursor) -> int: ) else: self._un_partial_stated_events_stream_id_gen = StreamIdGenerator( - db_conn, "un_partial_stated_event_stream", "stream_id" + db_conn, + hs.get_replication_notifier(), + "un_partial_stated_event_stream", + "stream_id", ) - def get_un_partial_stated_events_token(self) -> int: - # TODO(faster_joins, multiple writers): This is inappropriate if there are multiple - # writers because workers that don't write often will hold all - # readers up. 
- return self._un_partial_stated_events_stream_id_gen.get_current_token() + def get_un_partial_stated_events_token(self, instance_name: str) -> int: + return ( + self._un_partial_stated_events_stream_id_gen.get_current_token_for_writer( + instance_name + ) + ) async def get_un_partial_stated_events_from_stream( self, instance_name: str, last_id: int, current_id: int, limit: int @@ -422,6 +435,8 @@ def process_replication_position( self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: self._backfill_id_gen.advance(instance_name, -token) + elif stream_name == UnPartialStatedEventStream.NAME: + self._un_partial_stated_events_stream_id_gen.advance(instance_name, token) super().process_replication_position(stream_name, instance_name, token) async def have_censored_event(self, event_id: str) -> bool: @@ -1318,7 +1333,7 @@ async def _fetch_event_ids_and_get_outstanding_redactions( # invites, so just accept it for all membership events. # if d["type"] != EventTypes.Member: - raise Exception( + raise InvalidEventError( "Room %s for event %s is unknown" % (d["room_id"], event_id) ) @@ -1783,7 +1798,7 @@ def get_ex_outlier_stream_rows_txn( txn: LoggingTransaction, ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: sql = ( - "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + "SELECT out.event_stream_ordering, e.event_id, e.room_id, e.type," " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL," " e.outlier" " FROM events AS e" @@ -1795,10 +1810,10 @@ def get_ex_outlier_stream_rows_txn( " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" + " WHERE ? < out.event_stream_ordering" + " AND out.event_stream_ordering <= ?" " AND out.instance_name = ?" - " ORDER BY event_stream_ordering ASC" + " ORDER BY out.event_stream_ordering ASC" ) txn.execute(sql, (last_id, current_id, instance_name)) @@ -2244,7 +2259,7 @@ def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: ) async def get_event_id_for_timestamp( - self, room_id: str, timestamp: int, direction: str + self, room_id: str, timestamp: int, direction: Direction ) -> Optional[str]: """Find the closest event to the given timestamp in the given direction. @@ -2252,14 +2267,14 @@ async def get_event_id_for_timestamp( room_id: Room to fetch the event from timestamp: The point in time (inclusive) we should navigate from in the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward + direction: indicates whether we should navigate forward or backward from the given timestamp to find the closest event. Returns: The closest event_id otherwise None if we can't find any event in the given direction. """ - if direction == "b": + if direction == Direction.BACKWARDS: # Find closest event *before* a given timestamp. We use descending # (which gives values largest to smallest) because we want the # largest possible timestamp *before* the given timestamp. 
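Several hunks above (registration_tokens.py, users.py, report_event.py, events.py, oembed.py, thumbnailer.py) swap `isinstance(x, int)` checks for exact `type(x) is int` checks. A small, self-contained illustration of the difference, under the assumption that the point is to stop `bool` (a subclass of `int`) from slipping through integer validation; the diff itself only shows the mechanical change:

# bool is a subclass of int, so a JSON true/false passes isinstance(..., int)
# but fails an exact type(...) is int check.
samples = [5, True, 5.0, None]

for value in samples:
    print(f"{value!r}: isinstance={isinstance(value, int)}, exact_type={type(value) is int}")

# 5: isinstance=True, exact_type=True
# True: isinstance=True, exact_type=False   <- passed the old check, fails the new one
# 5.0: isinstance=False, exact_type=False
# None: isinstance=False, exact_type=False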
@@ -2311,9 +2326,6 @@ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: return None - if direction not in ("f", "b"): - raise ValueError("Unknown direction: %s" % (direction,)) - return await self.db_pool.runInteraction( "get_event_id_for_timestamp_txn", get_event_id_for_timestamp_txn, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 9b172a64d86d..b202c5eb87a6 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -26,6 +26,7 @@ cast, ) +from synapse.api.constants import Direction from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -176,7 +177,7 @@ async def get_local_media_by_user_paginate( limit: int, user_id: str, order_by: str = MediaSortOrder.CREATED_TS.value, - direction: str = "f", + direction: Direction = Direction.FORWARDS, ) -> Tuple[List[Dict[str, Any]], int]: """Get a paginated list of metadata for a local piece of media which an user_id has uploaded @@ -199,7 +200,7 @@ def get_local_media_by_user_paginate_txn( # Set ordering order_by_column = MediaSortOrder(order_by).value - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 7b60815043a6..beb210f8eefd 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -77,6 +77,7 @@ def __init__( self._presence_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="presence_stream", instance_name=self._instance_name, tables=[("presence_stream", "instance_name", "stream_id")], @@ -85,7 +86,7 @@ def __init__( ) else: self._presence_id_gen = StreamIdGenerator( - db_conn, "presence_stream", "stream_id" + db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id" ) self.hs = hs diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 478eec8f9e9f..ced9e231f625 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -86,8 +86,11 @@ def _load_rules( filtered_rules = FilteredPushRules( push_rules, enabled_map, - msc3664_enabled=experimental_config.msc3664_enabled, msc1767_enabled=experimental_config.msc1767_enabled, + msc3664_enabled=experimental_config.msc3664_enabled, + msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, + msc3952_intentional_mentions=experimental_config.msc3952_intentional_mentions, + msc3958_suppress_edits_enabled=experimental_config.msc3958_supress_edit_notifs, ) return filtered_rules @@ -117,6 +120,7 @@ def __init__( # class below that is used on the main process. self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "push_rules_stream", "stream_id", is_writer=hs.config.worker.worker_app is None, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 7f24a3b6ec5e..df53e726e62a 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -62,6 +62,7 @@ def __init__( # class below that is used on the main process. 
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")], diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index b5cf2d8c3511..b2048fb167e1 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -75,6 +75,7 @@ def __init__( self._receipts_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="receipts", instance_name=self._instance_name, tables=[("receipts_linearized", "instance_name", "stream_id")], @@ -93,6 +94,7 @@ def __init__( # SQLite). self._receipts_id_gen = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "receipts_linearized", "stream_id", is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, @@ -965,10 +967,14 @@ async def _background_receipts_linearized_unique_index( receipts.""" def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: + if isinstance(self.database_engine, PostgresEngine): + ROW_ID_NAME = "ctid" + else: + ROW_ID_NAME = "rowid" + # Identify any duplicate receipts arising from # https://github.com/matrix-org/synapse/issues/14406. - # We expect the following query to use the per-thread receipt index and take - # less than a minute. + # The following query takes less than a minute on matrix.org. sql = """ SELECT MAX(stream_id), room_id, receipt_type, user_id FROM receipts_linearized @@ -980,19 +986,33 @@ def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: duplicate_keys = cast(List[Tuple[int, str, str, str]], list(txn)) # Then remove duplicate receipts, keeping the one with the highest - # `stream_id`. There should only be a single receipt with any given - # `stream_id`. - for max_stream_id, room_id, receipt_type, user_id in duplicate_keys: - sql = """ + # `stream_id`. Since there might be duplicate rows with the same + # `stream_id`, we delete by the ctid instead. + for stream_id, room_id, receipt_type, user_id in duplicate_keys: + sql = f""" + SELECT {ROW_ID_NAME} + FROM receipts_linearized + WHERE + room_id = ? AND + receipt_type = ? AND + user_id = ? AND + thread_id IS NULL AND + stream_id = ? + LIMIT 1 + """ + txn.execute(sql, (room_id, receipt_type, user_id, stream_id)) + row_id = cast(Tuple[str], txn.fetchone())[0] + + sql = f""" DELETE FROM receipts_linearized WHERE room_id = ? AND receipt_type = ? AND user_id = ? AND thread_id IS NULL AND - stream_id < ? + {ROW_ID_NAME} != ? 
""" - txn.execute(sql, (room_id, receipt_type, user_id, max_stream_id)) + txn.execute(sql, (room_id, receipt_type, user_id, row_id)) await self.db_pool.runInteraction( self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index aea96e9d2478..0018d6f7abcb 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -30,7 +30,7 @@ import attr -from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore @@ -40,9 +40,13 @@ LoggingTransaction, make_in_list_sql_clause, ) -from synapse.storage.databases.main.stream import generate_pagination_where_clause +from synapse.storage.databases.main.stream import ( + generate_next_token, + generate_pagination_bounds, + generate_pagination_where_clause, +) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken +from synapse.types import JsonDict, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -164,7 +168,7 @@ async def get_relations_for_event( relation_type: Optional[str] = None, event_type: Optional[str] = None, limit: int = 5, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: @@ -177,8 +181,8 @@ async def get_relations_for_event( relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. limit: Only fetch the most recent `limit` events. - direction: Whether to fetch the most recent first (`"b"`) or the - oldest first (`"f"`). + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards). from_token: Fetch rows from the given token, or from the start if None. to_token: Fetch rows up to the given token, or up to the end if None. @@ -207,24 +211,23 @@ async def get_relations_for_event( where_clause.append("type = ?") where_args.append(event_type) + order, from_bound, to_bound = generate_pagination_bounds( + direction, + from_token.room_key if from_token else None, + to_token.room_key if to_token else None, + ) + pagination_clause = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.room_key.as_historical_tuple() - if from_token - else None, - to_token=to_token.room_key.as_historical_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) if pagination_clause: where_clause.append(pagination_clause) - if direction == "b": - order = "DESC" - else: - order = "ASC" - sql = """ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering FROM event_relations @@ -266,16 +269,9 @@ def _get_recent_references_for_event_txn( topo_orderings = topo_orderings[:limit] stream_orderings = stream_orderings[:limit] - topo = topo_orderings[-1] - token = stream_orderings[-1] - if direction == "b": - # Tokens are positions between events. - # This token points *after* the last event in the chunk. 
- # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - token -= 1 - next_key = RoomStreamToken(topo, token) + next_key = generate_next_token( + direction, topo_orderings[-1], stream_orderings[-1] + ) if from_token: next_token = from_token.copy_and_replace( @@ -292,6 +288,7 @@ def _get_recent_references_for_event_txn( to_device_key=0, device_list_key=0, groups_key=0, + un_partial_stated_rooms_key=0, ) return events[:limit], next_token diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 78906a5e1d9e..644bbb88788f 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -18,6 +18,7 @@ from enum import Enum from typing import ( TYPE_CHECKING, + AbstractSet, Any, Awaitable, Collection, @@ -25,7 +26,7 @@ List, Mapping, Optional, - Sequence, + Set, Tuple, Union, cast, @@ -34,6 +35,7 @@ import attr from synapse.api.constants import ( + Direction, EventContentFields, EventTypes, JoinRules, @@ -43,6 +45,7 @@ from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase +from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -58,9 +61,9 @@ MultiWriterIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID +from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID from synapse.util import json_encoder -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.stringutils import MXC_REGEX if TYPE_CHECKING: @@ -106,7 +109,7 @@ class RoomSortOrder(Enum): @attr.s(slots=True, frozen=True, auto_attribs=True) class PartialStateResyncInfo: joined_via: Optional[str] - servers_in_room: List[str] = attr.ib(factory=list) + servers_in_room: Set[str] = attr.ib(factory=set) class RoomWorkerStore(CacheInvalidationWorkerStore): @@ -126,6 +129,7 @@ def __init__( self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="un_partial_stated_room_stream", instance_name=self._instance_name, tables=[ @@ -137,9 +141,19 @@ def __init__( ) else: self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator( - db_conn, "un_partial_stated_room_stream", "stream_id" + db_conn, + hs.get_replication_notifier(), + "un_partial_stated_room_stream", + "stream_id", ) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == UnPartialStatedRoomStream.NAME: + self._un_partial_stated_rooms_stream_id_gen.advance(instance_name, token) + return super().process_replication_position(stream_name, instance_name, token) + async def store_room( self, room_id: str, @@ -1179,21 +1193,35 @@ def get_rooms_for_retention_period_in_range_txn( get_rooms_for_retention_period_in_range_txn, ) - @cached(iterable=True) - async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]: - """Gets the list of servers in a partial state room at the time we joined it. 
+ async def get_partial_state_servers_at_join( + self, room_id: str + ) -> Optional[AbstractSet[str]]: + """Gets the set of servers in a partial state room at the time we joined it. Returns: The `servers_in_room` list from the `/send_join` response for partial state rooms. May not be accurate or complete, as it comes from a remote homeserver. - An empty list for full state rooms. + `None` for full state rooms. """ - return await self.db_pool.simple_select_onecol( - "partial_state_rooms_servers", - keyvalues={"room_id": room_id}, - retcol="server_name", - desc="get_partial_state_servers_at_join", + servers_in_room = await self._get_partial_state_servers_at_join(room_id) + + if len(servers_in_room) == 0: + return None + + return servers_in_room + + @cached(iterable=True) + async def _get_partial_state_servers_at_join( + self, room_id: str + ) -> AbstractSet[str]: + return frozenset( + await self.db_pool.simple_select_onecol( + "partial_state_rooms_servers", + keyvalues={"room_id": room_id}, + retcol="server_name", + desc="get_partial_state_servers_at_join", + ) ) async def get_partial_state_room_resync_info( @@ -1238,11 +1266,11 @@ async def get_partial_state_room_resync_info( # partial-joined between the two SELECTs, but this is unlikely to happen # in practice.) continue - entry.servers_in_room.append(server_name) + entry.servers_in_room.add(server_name) return room_servers - @cached() + @cached(max_entries=10000) async def is_partial_state_room(self, room_id: str) -> bool: """Checks if this room has partial state. @@ -1261,6 +1289,27 @@ async def is_partial_state_room(self, room_id: str) -> bool: return entry is not None + @cachedList(cached_method_name="is_partial_state_room", list_name="room_ids") + async def is_partial_state_room_batched( + self, room_ids: StrCollection + ) -> Mapping[str, bool]: + """Checks if the given rooms have partial state. + + Returns true for "partial-state" rooms, which means that the state + at events in the room, and `current_state_events`, may not yet be + complete. + """ + + rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch( + table="partial_state_rooms", + column="room_id", + iterable=room_ids, + retcols=("room_id",), + desc="is_partial_state_room_batched", + ) + partial_state_rooms = {row_dict["room_id"] for row_dict in rows} + return {room_id: room_id in partial_state_rooms for room_id in room_ids} + async def get_join_event_id_and_device_lists_stream_id_for_partial_state( self, room_id: str ) -> Tuple[str, int]: @@ -1277,18 +1326,49 @@ async def get_join_event_id_and_device_lists_stream_id_for_partial_state( ) return result["join_event_id"], result["device_lists_stream_id"] - def get_un_partial_stated_rooms_token(self) -> int: - # TODO(faster_joins, multiple writers): This is inappropriate if there - # are multiple writers because workers that don't write often will - # hold all readers up. - # (See `MultiWriterIdGenerator.get_persisted_upto_position` for an - # explanation.) - return self._un_partial_stated_rooms_stream_id_gen.get_current_token() + def get_un_partial_stated_rooms_token(self, instance_name: str) -> int: + return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer( + instance_name + ) + + async def get_un_partial_stated_rooms_between( + self, last_id: int, current_id: int, room_ids: Collection[str] + ) -> Set[str]: + """Get all rooms that got un partial stated between `last_id` exclusive and + `current_id` inclusive. + + Returns: + The list of room ids. 
+ """ + + if last_id == current_id: + return set() + + def _get_un_partial_stated_rooms_between_txn( + txn: LoggingTransaction, + ) -> Set[str]: + sql = """ + SELECT DISTINCT room_id FROM un_partial_stated_room_stream + WHERE ? < stream_id AND stream_id <= ? AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + txn.execute(sql + clause, [last_id, current_id] + args) + + return {r[0] for r in txn} + + return await self.db_pool.runInteraction( + "get_un_partial_stated_rooms_between", + _get_un_partial_stated_rooms_between_txn, + ) async def get_un_partial_stated_rooms_from_stream( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: - """Get updates for caches replication stream. + """Get updates for un partial stated rooms replication stream. Args: instance_name: The writer we want to fetch updates from. Unused @@ -1876,7 +1956,7 @@ async def upsert_room_on_join( async def store_partial_state_room( self, room_id: str, - servers: Collection[str], + servers: AbstractSet[str], device_lists_stream_id: int, joined_via: str, ) -> None: @@ -1891,11 +1971,13 @@ async def store_partial_state_room( Args: room_id: the ID of the room - servers: other servers known to be in the room + servers: other servers known to be in the room. must include `joined_via`. device_lists_stream_id: the device_lists stream ID at the time when we first joined the room. joined_via: the server name we requested a partial join from. """ + assert joined_via in servers + await self.db_pool.runInteraction( "store_partial_state_room", self._store_partial_state_room_txn, @@ -1909,7 +1991,7 @@ def _store_partial_state_room_txn( self, txn: LoggingTransaction, room_id: str, - servers: Collection[str], + servers: AbstractSet[str], device_lists_stream_id: int, joined_via: str, ) -> None: @@ -1932,7 +2014,7 @@ def _store_partial_state_room_txn( ) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) self._invalidate_cache_and_stream( - txn, self.get_partial_state_servers_at_join, (room_id,) + txn, self._get_partial_state_servers_at_join, (room_id,) ) async def write_partial_state_rooms_join_event_id( @@ -2139,7 +2221,7 @@ async def get_event_reports_paginate( self, start: int, limit: int, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, user_id: Optional[str] = None, room_id: Optional[str] = None, ) -> Tuple[List[Dict[str, Any]], int]: @@ -2148,8 +2230,8 @@ async def get_event_reports_paginate( Args: start: event offset to begin the query from limit: number of rows to retrieve - direction: Whether to fetch the most recent first (`"b"`) or the - oldest first (`"f"`) + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards) user_id: search for user_id. Ignored if user_id is None room_id: search for room_id. Ignored if room_id is None Returns: @@ -2171,7 +2253,7 @@ def _get_event_reports_paginate_txn( filters.append("er.room_id LIKE ?") args.extend(["%" + room_id + "%"]) - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" @@ -2295,16 +2377,16 @@ async def unblock_room(self, room_id: str) -> None: (room_id,), ) - async def clear_partial_state_room(self, room_id: str) -> bool: + async def clear_partial_state_room(self, room_id: str) -> Optional[int]: """Clears the partial state flag for a room. Args: room_id: The room whose partial state flag is to be cleared. 
Returns: - `True` if the partial state flag has been cleared successfully. + The corresponding stream id for the un-partial-stated rooms stream. - `False` if the partial state flag could not be cleared because the room + `None` if the partial state flag could not be cleared because the room still contains events with partial state. """ try: @@ -2315,7 +2397,7 @@ async def clear_partial_state_room(self, room_id: str) -> bool: room_id, un_partial_state_room_stream_id, ) - return True + return un_partial_state_room_stream_id except self.db_pool.engine.module.IntegrityError as e: # Assume that any `IntegrityError`s are due to partial state events. logger.info( @@ -2323,7 +2405,7 @@ async def clear_partial_state_room(self, room_id: str) -> bool: room_id, e, ) - return False + return None def _clear_partial_state_room_txn( self, @@ -2343,7 +2425,7 @@ def _clear_partial_state_room_txn( ) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) self._invalidate_cache_and_stream( - txn, self.get_partial_state_servers_at_join, (room_id,) + txn, self._get_partial_state_servers_at_join, (room_id,) ) DatabasePool.simple_insert_txn( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 9c5adb8c221f..3abc4dcfef1a 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -14,8 +14,10 @@ # limitations under the License. import logging import re +from itertools import chain from typing import ( TYPE_CHECKING, + AbstractSet, Collection, Dict, FrozenSet, @@ -48,7 +50,13 @@ ProfileInfo, RoomsForUser, ) -from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id +from synapse.types import ( + JsonDict, + PersistedEventPosition, + StateMap, + StrCollection, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -408,7 +416,7 @@ async def get_rooms_for_local_user_where_membership_is( self, user_id: str, membership_list: Collection[str], - excluded_rooms: Optional[List[str]] = None, + excluded_rooms: StrCollection = (), ) -> List[RoomsForUser]: """Get all the rooms for this *local* user where the membership for this user matches one in the membership list. @@ -435,10 +443,12 @@ async def get_rooms_for_local_user_where_membership_is( ) # Now we filter out forgotten and excluded rooms - rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id) + rooms_to_exclude = await self.get_forgotten_rooms_for_user(user_id) if excluded_rooms is not None: - rooms_to_exclude.update(set(excluded_rooms)) + # Take a copy to avoid mutating the in-cache set + rooms_to_exclude = set(rooms_to_exclude) + rooms_to_exclude.update(excluded_rooms) return [room for room in rooms if room.room_id not in rooms_to_exclude] @@ -1145,12 +1155,33 @@ async def _get_joined_hosts( else: # The cache doesn't match the state group or prev state group, # so we calculate the result from first principles. + # + # We need to fetch all hosts joined to the room according to `state` by + # inspecting all join memberships in `state`. However, if the `state` is + # relatively recent then many of its events are likely to be held in + # the current state of the room, which is easily available and likely + # cached. + # + # We therefore compute the set of `state` events not in the + # current state and only fetch those. 
+ current_memberships = ( + await self._get_approximate_current_memberships_in_room(room_id) + ) + unknown_state_events = {} + joined_users_in_current_state = [] + + for (type, state_key), event_id in state.items(): + if event_id not in current_memberships: + unknown_state_events[type, state_key] = event_id + elif current_memberships[event_id] == Membership.JOIN: + joined_users_in_current_state.append(state_key) + joined_user_ids = await self.get_joined_user_ids_from_state( - room_id, state + room_id, unknown_state_events ) cache.hosts_to_joined_users = {} - for user_id in joined_user_ids: + for user_id in chain(joined_user_ids, joined_users_in_current_state): host = intern_string(get_domain_from_id(user_id)) cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) @@ -1161,6 +1192,26 @@ async def _get_joined_hosts( return frozenset(cache.hosts_to_joined_users) + async def _get_approximate_current_memberships_in_room( + self, room_id: str + ) -> Mapping[str, Optional[str]]: + """Build a map from event id to membership, for all events in the current state. + + The event ids of non-memberships events (e.g. `m.room.power_levels`) are present + in the result, mapped to values of `None`. + + The result is approximate for partially-joined rooms. It is fully accurate + for fully-joined rooms. + """ + + rows = await self.db_pool.simple_select_list( + "current_state_events", + keyvalues={"room_id": room_id}, + retcols=("event_id", "membership"), + desc="has_completed_background_updates", + ) + return {row["event_id"]: row["membership"] for row in rows} + @cached(max_entries=10000) def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": return _JoinedHostsCache() @@ -1192,7 +1243,7 @@ def f(txn: LoggingTransaction) -> int: return count == 0 @cached() - async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]: + async def get_forgotten_rooms_for_user(self, user_id: str) -> AbstractSet[str]: """Gets all rooms the user has forgotten. 
Args: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index f32cbb2decd8..ba325d390b58 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -95,6 +95,7 @@ def process_replication_rows( for row in rows: assert isinstance(row, UnPartialStatedEventStreamRow) self._get_state_group_for_event.invalidate((row.event_id,)) + self.is_partial_state_event.invalidate((row.event_id,)) super().process_replication_rows(stream_name, instance_name, token, rows) @@ -485,6 +486,7 @@ def _update_state_for_partial_state_event_txn( "rejection_status_changed": rejection_status_changed, }, ) + txn.call_after(self.hs.get_notifier().on_new_replication_data) class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 356d4ca78819..d7b7d0c3c909 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -22,13 +22,14 @@ from twisted.internet.defer import DeferredLock -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership from synapse.api.errors import StoreError from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.events_worker import InvalidEventError from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -554,7 +555,17 @@ def _fetch_current_state_stats( "get_initial_state_for_room", _fetch_current_state_stats ) - state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined] + try: + state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined] + except InvalidEventError as e: + # If an exception occurs fetching events then the room is broken; + # skip processing it to avoid being stuck on a room.
+ logger.warning( + "Failed to fetch events for room %s, skipping stats calculation: %r.", + room_id, + e, + ) + return room_state: Dict[str, Union[None, bool, str]] = { "join_rules": None, @@ -652,7 +663,7 @@ async def get_users_media_usage_paginate( from_ts: Optional[int] = None, until_ts: Optional[int] = None, order_by: Optional[str] = UserSortOrder.USER_ID.value, - direction: Optional[str] = "f", + direction: Direction = Direction.FORWARDS, search_term: Optional[str] = None, ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users and their uploaded local media @@ -703,7 +714,7 @@ def get_users_media_usage_paginate_txn( 500, "Incorrect value for order_by provided: %s" % order_by ) - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 386e02552e19..5be3e192ddad 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -55,6 +55,7 @@ from twisted.internet import defer +from synapse.api.constants import Direction from synapse.api.filtering import Filter from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -67,7 +68,7 @@ make_in_list_sql_clause, ) from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import PersistedEventPosition, RoomStreamToken from synapse.util.caches.descriptors import cached @@ -86,7 +87,6 @@ _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" - # Used as return values for pagination APIs @attr.s(slots=True, frozen=True, auto_attribs=True) class _EventDictReturn: @@ -104,7 +104,7 @@ class _EventsAround: def generate_pagination_where_clause( - direction: str, + direction: Direction, column_names: Tuple[str, str], from_token: Optional[Tuple[Optional[int], int]], to_token: Optional[Tuple[Optional[int], int]], @@ -130,27 +130,26 @@ def generate_pagination_where_clause( token, but include those that match the to token. Args: - direction: Whether we're paginating backwards("b") or forwards ("f"). + direction: Whether we're paginating backwards or forwards. column_names: The column names to bound. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. from_token: The start point for the pagination. This is an exclusive - minimum bound if direction is "f", and an inclusive maximum bound if - direction is "b". + minimum bound if direction is forwards, and an inclusive maximum bound if + direction is backwards. to_token: The endpoint point for the pagination. This is an inclusive - maximum bound if direction is "f", and an exclusive minimum bound if - direction is "b". + maximum bound if direction is forwards, and an exclusive minimum bound if + direction is backwards. 
engine: The database engine to generate the clauses for Returns: The sql expression """ - assert direction in ("b", "f") where_clause = [] if from_token: where_clause.append( _make_generic_sql_bound( - bound=">=" if direction == "b" else "<", + bound=">=" if direction == Direction.BACKWARDS else "<", column_names=column_names, values=from_token, engine=engine, @@ -160,7 +159,7 @@ def generate_pagination_where_clause( if to_token: where_clause.append( _make_generic_sql_bound( - bound="<" if direction == "b" else ">=", + bound="<" if direction == Direction.BACKWARDS else ">=", column_names=column_names, values=to_token, engine=engine, @@ -170,6 +169,104 @@ def generate_pagination_where_clause( return " AND ".join(where_clause) +def generate_pagination_bounds( + direction: Direction, + from_token: Optional[RoomStreamToken], + to_token: Optional[RoomStreamToken], +) -> Tuple[ + str, Optional[Tuple[Optional[int], int]], Optional[Tuple[Optional[int], int]] +]: + """ + Generate a start and end point for this page of events. + + Args: + direction: Whether pagination is going forwards or backwards. + from_token: The token to start pagination at, or None to start at the first value. + to_token: The token to end pagination at, or None to not limit the end point. + + Returns: + A three tuple of: + + ASC or DESC for sorting of the query. + + The starting position as a tuple of ints representing + (topological position, stream position) or None if no from_token was + provided. The topological position may be None for live tokens. + + The end position in the same format as the starting position, or None + if no to_token was provided. + """ + + # Tokens really represent positions between elements, but we use + # the convention of pointing to the event before the gap. Hence + # we have a bit of asymmetry when it comes to equalities. + if direction == Direction.BACKWARDS: + order = "DESC" + else: + order = "ASC" + + # The bounds for the stream tokens are complicated by the fact + # that we need to handle the instance_map part of the tokens. We do this + # by fetching all events between the min stream token and the maximum + # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and + # then filtering the results. + from_bound: Optional[Tuple[Optional[int], int]] = None + if from_token: + if from_token.topological is not None: + from_bound = from_token.as_historical_tuple() + elif direction == Direction.BACKWARDS: + from_bound = ( + None, + from_token.get_max_stream_pos(), + ) + else: + from_bound = ( + None, + from_token.stream, + ) + + to_bound: Optional[Tuple[Optional[int], int]] = None + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_historical_tuple() + elif direction == Direction.BACKWARDS: + to_bound = ( + None, + to_token.stream, + ) + else: + to_bound = ( + None, + to_token.get_max_stream_pos(), + ) + + return order, from_bound, to_bound + + +def generate_next_token( + direction: Direction, last_topo_ordering: int, last_stream_ordering: int +) -> RoomStreamToken: + """ + Generate the next room stream token based on the currently returned data. + + Args: + direction: Whether pagination is going forwards or backwards. + last_topo_ordering: The last topological ordering being returned. + last_stream_ordering: The last stream ordering being returned. + + Returns: + A new RoomStreamToken to return to the client. + """ + if direction == Direction.BACKWARDS: + # Tokens are positions between events. + # This token points *after* the last event in the chunk. 
+ # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + last_stream_ordering -= 1 + return RoomStreamToken(last_topo_ordering, last_stream_ordering) + + def _make_generic_sql_bound( bound: str, column_names: Tuple[str, str], @@ -944,12 +1041,40 @@ async def get_current_topological_token(self, room_id: str, stream_key: int) -> room_id stream_key """ - sql = ( - "SELECT coalesce(MIN(topological_ordering), 0) FROM events" - " WHERE room_id = ? AND stream_ordering >= ?" - ) + if isinstance(self.database_engine, PostgresEngine): + min_function = "LEAST" + elif isinstance(self.database_engine, Sqlite3Engine): + min_function = "MIN" + else: + raise RuntimeError(f"Unknown database engine {self.database_engine}") + + # This query used to be + # SELECT COALESCE(MIN(topological_ordering), 0) FROM events + # WHERE room_id = ? and events.stream_ordering >= {stream_key} + # which returns 0 if the stream_key is newer than any event in + # the room. That's not wrong, but it seems to interact oddly with backfill, + # requiring a second call to /messages to actually backfill from a remote + # homeserver. + # + # Instead, rollback the stream ordering to that after the most recent event in + # this room. + sql = f""" + WITH fallback(max_stream_ordering) AS ( + SELECT MAX(stream_ordering) + FROM events + WHERE room_id = ? + ) + SELECT COALESCE(MIN(topological_ordering), 0) FROM events + WHERE + room_id = ? + AND events.stream_ordering >= {min_function}( + ?, + (SELECT max_stream_ordering FROM fallback) + ) + """ + row = await self.db_pool.execute( - "get_current_topological_token", None, sql, room_id, stream_key + "get_current_topological_token", None, sql, room_id, room_id, stream_key ) return row[0][0] if row else 0 @@ -1075,7 +1200,7 @@ def _get_events_around_txn( txn, room_id, before_token, - direction="b", + direction=Direction.BACKWARDS, limit=before_limit, event_filter=event_filter, ) @@ -1085,7 +1210,7 @@ def _get_events_around_txn( txn, room_id, after_token, - direction="f", + direction=Direction.FORWARDS, limit=after_limit, event_filter=event_filter, ) @@ -1248,7 +1373,7 @@ def _paginate_room_events_txn( room_id: str, from_token: RoomStreamToken, to_token: Optional[RoomStreamToken] = None, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: @@ -1259,8 +1384,8 @@ def _paginate_room_events_txn( room_id from_token: The token used to stream from to_token: A token which if given limits the results to only those before - direction: Either 'b' or 'f' to indicate whether we are paginating - forwards or backwards from `from_key`. + direction: Indicates whether we are paginating forwards or backwards + from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. @@ -1272,47 +1397,11 @@ def _paginate_room_events_txn( `to_token`), or `limit` is zero. """ - # Tokens really represent positions between elements, but we use - # the convention of pointing to the event before the gap. Hence - # we have a bit of asymmetry when it comes to equalities. args = [False, room_id] - if direction == "b": - order = "DESC" - else: - order = "ASC" - # The bounds for the stream tokens are complicated by the fact - # that we need to handle the instance_map part of the tokens. 
We do this - # by fetching all events between the min stream token and the maximum - # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and - # then filtering the results. - if from_token.topological is not None: - from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple() - elif direction == "b": - from_bound = ( - None, - from_token.get_max_stream_pos(), - ) - else: - from_bound = ( - None, - from_token.stream, - ) - - to_bound: Optional[Tuple[Optional[int], int]] = None - if to_token: - if to_token.topological is not None: - to_bound = to_token.as_historical_tuple() - elif direction == "b": - to_bound = ( - None, - to_token.stream, - ) - else: - to_bound = ( - None, - to_token.get_max_stream_pos(), - ) + order, from_bound, to_bound = generate_pagination_bounds( + direction, from_token, to_token + ) bounds = generate_pagination_where_clause( direction=direction, @@ -1399,8 +1488,12 @@ def _paginate_room_events_txn( _EventDictReturn(event_id, topological_ordering, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( - lower_token=to_token if direction == "b" else from_token, - upper_token=from_token if direction == "b" else to_token, + lower_token=to_token + if direction == Direction.BACKWARDS + else from_token, + upper_token=from_token + if direction == Direction.BACKWARDS + else to_token, instance_name=instance_name, topological_ordering=topological_ordering, stream_ordering=stream_ordering, @@ -1408,16 +1501,10 @@ def _paginate_room_events_txn( ][:limit] if rows: - topo = rows[-1].topological_ordering - token = rows[-1].stream_ordering - if direction == "b": - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - token -= 1 - next_token = RoomStreamToken(topo, token) + assert rows[-1].topological_ordering is not None + next_token = generate_next_token( + direction, rows[-1].topological_ordering, rows[-1].stream_ordering + ) else: # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token @@ -1430,7 +1517,7 @@ async def paginate_room_events( room_id: str, from_key: RoomStreamToken, to_key: Optional[RoomStreamToken] = None, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[EventBase], RoomStreamToken]: @@ -1440,8 +1527,8 @@ async def paginate_room_events( room_id from_key: The token used to stream from to_key: A token which if given limits the results to only those before - direction: Either 'b' or 'f' to indicate whether we are paginating - forwards or backwards from `from_key`. + direction: Indicates whether we are paginating forwards or backwards + from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. 
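A minimal sketch of the two Direction-dependent rules that the new stream.py helpers above encode, for orientation only: it models tokens as plain (topological, stream) integer pairs, omits RoomStreamToken and the instance_map handling entirely, and only echoes (rather than reproduces) generate_pagination_bounds and generate_next_token.

# Illustrative sketch, not the patched Synapse code. Rule 1: the SQL sort order
# follows the pagination direction. Rule 2: tokens are positions *between*
# events, so when paginating backwards the returned token's stream part is
# stepped back by one so it points before the last event in the chunk.
from enum import Enum
from typing import Tuple


class Direction(Enum):
    BACKWARDS = "b"
    FORWARDS = "f"


def pagination_order(direction: Direction) -> str:
    """SQL sort order for a page of events."""
    return "DESC" if direction == Direction.BACKWARDS else "ASC"


def next_token(
    direction: Direction, last_topological: int, last_stream: int
) -> Tuple[int, int]:
    """Token to hand back to the client for the next page."""
    if direction == Direction.BACKWARDS:
        # Point at the gap just before the last returned event, so the next
        # backwards page does not return it again.
        last_stream -= 1
    return (last_topological, last_stream)


# A backwards page sorted DESC that ends at event (topological=5, stream=42)
# yields the token (5, 41); a forwards page keeps the stream part as-is.
assert pagination_order(Direction.BACKWARDS) == "DESC"
assert next_token(Direction.BACKWARDS, 5, 42) == (5, 41)
assert next_token(Direction.FORWARDS, 5, 42) == (5, 42)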
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index e23c927e02f2..d5500cdd470c 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -17,7 +17,8 @@ import logging from typing import Any, Dict, Iterable, List, Tuple, cast -from synapse.replication.tcp.streams import TagAccountDataStream +from synapse.api.constants import AccountDataTypes +from synapse.replication.tcp.streams import AccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore @@ -54,7 +55,7 @@ async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict] async def get_all_updated_tags( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]: + ) -> Tuple[List[Tuple[int, str, str]], int, bool]: """Get updates for tags replication stream. Args: @@ -73,7 +74,7 @@ async def get_all_updated_tags( The token returned can be used in a subsequent call to this function to get further updatees. - The updates are a list of 2-tuples of stream ID and the row data + The updates are a list of tuples of stream ID, user ID and room ID """ if last_id == current_id: @@ -96,38 +97,13 @@ def get_all_updated_tags_txn( "get_all_updated_tags", get_all_updated_tags_txn ) - def get_tag_content( - txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]] - ) -> List[Tuple[int, Tuple[str, str, str]]]: - sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" - results = [] - for stream_id, user_id, room_id in tag_ids: - txn.execute(sql, (user_id, room_id)) - tags = [] - for tag, content in txn: - tags.append(json_encoder.encode(tag) + ":" + content) - tag_json = "{" + ",".join(tags) + "}" - results.append((stream_id, (user_id, room_id, tag_json))) - - return results - - batch_size = 50 - results = [] - for i in range(0, len(tag_ids), batch_size): - tags = await self.db_pool.runInteraction( - "get_all_updated_tag_content", - get_tag_content, - tag_ids[i : i + batch_size], - ) - results.extend(tags) - limited = False upto_token = current_id - if len(results) >= limit: - upto_token = results[-1][0] + if len(tag_ids) >= limit: + upto_token = tag_ids[-1][0] limited = True - return results, upto_token, limited + return tag_ids, upto_token, limited async def get_updated_tags( self, user_id: str, stream_id: int @@ -299,20 +275,16 @@ def process_replication_rows( token: int, rows: Iterable[Any], ) -> None: - if stream_name == TagAccountDataStream.NAME: + if stream_name == AccountDataStream.NAME: for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) + if row.data_type == AccountDataTypes.TAG: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed( + row.user_id, token + ) super().process_replication_rows(stream_name, instance_name, token, rows) - def process_replication_position( - self, stream_name: str, instance_name: str, token: int - ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - super().process_replication_position(stream_name, instance_name, token) - class TagsStore(TagsWorkerStore): pass diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 
f8c6877ee847..6b33d809b6cf 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -19,6 +19,7 @@ import attr from canonicaljson import encode_canonical_json +from synapse.api.constants import Direction from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -496,7 +497,7 @@ async def get_destinations_paginate( limit: int, destination: Optional[str] = None, order_by: str = DestinationSortOrder.DESTINATION.value, - direction: str = "f", + direction: Direction = Direction.FORWARDS, ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of destinations. This will return a json list of destinations and the @@ -518,7 +519,7 @@ def get_destinations_paginate_txn( ) -> Tuple[List[JsonDict], int]: order_by_column = DestinationSortOrder(order_by).value - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" @@ -550,7 +551,11 @@ def get_destinations_paginate_txn( ) async def get_destination_rooms_paginate( - self, destination: str, start: int, limit: int, direction: str = "f" + self, + destination: str, + start: int, + limit: int, + direction: Direction = Direction.FORWARDS, ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of destination's rooms. This will return a json list of rooms and the @@ -569,7 +574,7 @@ def get_destination_rooms_paginate_txn( txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 70e594a68f09..0363cdc038f0 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -132,6 +132,10 @@ def executescript(cursor: CursorType, script: str) -> None: """Execute a chunk of SQL containing multiple semicolon-delimited statements. This is not provided by DBAPI2, and so needs engine-specific support. + + Any ongoing transaction is committed before executing the script in its own + transaction. The script transaction is left open and it is the responsibility of + the caller to commit it. """ ... diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index f9f562ea4544..b350f57ccb4a 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -220,5 +220,9 @@ def executescript(cursor: psycopg2.extensions.cursor, script: str) -> None: """Execute a chunk of SQL containing multiple semicolon-delimited statements. Psycopg2 seems happy to do this in DBAPI2's `execute()` function. + + For consistency with SQLite, any ongoing transaction is committed before + executing the script in its own transaction. The script transaction is + left open and it is the responsibility of the caller to commit it. """ - cursor.execute(script) + cursor.execute(f"COMMIT; BEGIN TRANSACTION; {script}") diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 14260442b6e4..28751e89a5a5 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -135,13 +135,16 @@ def executescript(cursor: sqlite3.Cursor, script: str) -> None: > than one statement with it, it will raise a Warning. Use executescript() if > you want to execute multiple SQL statements with one call. 
- Though the docs for `executescript` warn: + The script is prefixed with a `BEGIN TRANSACTION`, since the docs for + `executescript` warn: > If there is a pending transaction, an implicit COMMIT statement is executed > first. No other implicit transaction control is performed; any transaction > control must be added to sql_script. """ - cursor.executescript(script) + # The implementation of `executescript` can be found at + # https://github.com/python/cpython/blob/3.11/Modules/_sqlite/cursor.c#L1035. + cursor.executescript(f"BEGIN TRANSACTION; {script}") # Following functions taken from: https://github.com/coleifer/peewee diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 19dbf2da7fee..d3103a6c7a05 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 73 # remember to update the list below when updating +SCHEMA_VERSION = 74 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -78,7 +78,7 @@ - Unused column application_services_state.last_txn is dropped - Cache invalidation stream id sequence now begins at 2 to match code expectation. -Changes in SCHEMA_VERSION = 73; +Changes in SCHEMA_VERSION = 73: - thread_id column is added to event_push_actions, event_push_actions_staging event_push_summary, receipts_linearized, and receipts_graph. - Add table `event_failed_pull_attempts` to keep track when we fail to pull @@ -86,6 +86,11 @@ - Add indexes to various tables (`event_failed_pull_attempts`, `insertion_events`, `batch_events`) to make it easy to delete all associated rows when purging a room. - `inserted_ts` column is added to `event_push_actions_staging` table. + +Changes in SCHEMA_VERSION = 74: + - A query on `event_stream_ordering` column has now been disambiguated (i.e. the + codebase can handle the `current_state_events`, `local_current_memberships` and + `room_memberships` tables having an `event_stream_ordering` column). """ diff --git a/synapse/storage/schema/main/delta/73/24_events_jump_to_date_index.sql b/synapse/storage/schema/main/delta/73/24_events_jump_to_date_index.sql new file mode 100644 index 000000000000..67059909a1d2 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/24_events_jump_to_date_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7324, 'events_jump_to_date_index', '{}'); diff --git a/synapse/storage/schema/main/delta/73/25drop_presence.sql b/synapse/storage/schema/main/delta/73/25drop_presence.sql new file mode 100644 index 000000000000..9f6ffa20b613 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/25drop_presence.sql @@ -0,0 +1,17 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this table is unused +DROP TABLE presence; diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 0d7108f01b41..9adff3f4f523 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -20,6 +20,7 @@ from contextlib import contextmanager from types import TracebackType from typing import ( + TYPE_CHECKING, AsyncContextManager, ContextManager, Dict, @@ -49,6 +50,9 @@ from synapse.storage.types import Cursor from synapse.storage.util.sequence import PostgresSequenceGenerator +if TYPE_CHECKING: + from synapse.notifier import ReplicationNotifier + logger = logging.getLogger(__name__) @@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator): def __init__( self, db_conn: LoggingDatabaseConnection, + notifier: "ReplicationNotifier", table: str, column: str, extra_tables: Iterable[Tuple[str, str]] = (), @@ -205,6 +210,8 @@ def __init__( # The key and values are the same, but we never look at the values. self._unfinished_ids: OrderedDict[int, int] = OrderedDict() + self._notifier = notifier + def advance(self, instance_name: str, new_id: int) -> None: # Advance should never be called on a writer instance, only over replication if self._is_writer: @@ -227,6 +234,8 @@ def manager() -> Generator[int, None, None]: with self._lock: self._unfinished_ids.pop(next_id) + self._notifier.notify_replication() + return _AsyncCtxManagerWrapper(manager()) def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: @@ -250,6 +259,8 @@ def manager() -> Generator[Sequence[int], None, None]: for next_id in next_ids: self._unfinished_ids.pop(next_id) + self._notifier.notify_replication() + return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: @@ -296,6 +307,7 @@ def __init__( self, db_conn: LoggingDatabaseConnection, db: DatabasePool, + notifier: "ReplicationNotifier", stream_name: str, instance_name: str, tables: List[Tuple[str, str, str]], @@ -304,6 +316,7 @@ def __init__( positive: bool = True, ) -> None: self._db = db + self._notifier = notifier self._stream_name = stream_name self._instance_name = instance_name self._positive = positive @@ -378,6 +391,12 @@ def __init__( self._current_positions.values(), default=1 ) + if not writers: + # If there have been no explicit writers given then any instance can + # write to the stream. In which case, let's pre-seed our own + # position with the current minimum. 
+ self._current_positions[self._instance_name] = self._persisted_upto_position + def _load_current_ids( self, db_conn: LoggingDatabaseConnection, @@ -529,7 +548,9 @@ def get_next(self) -> AsyncContextManager[int]: # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids, # controls the return type. If `None` or omitted, the context manager yields # a single integer stream_id; otherwise it yields a list of stream_ids. - return cast(AsyncContextManager[int], _MultiWriterCtxManager(self)) + return cast( + AsyncContextManager[int], _MultiWriterCtxManager(self, self._notifier) + ) def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]: # If we have a list of instances that are allowed to write to this @@ -538,7 +559,10 @@ def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]: raise Exception("Tried to allocate stream ID on non-writer") # Cast safety: see get_next. - return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n)) + return cast( + AsyncContextManager[List[int]], + _MultiWriterCtxManager(self, self._notifier, n), + ) def get_next_txn(self, txn: LoggingTransaction) -> int: """ @@ -557,6 +581,7 @@ def get_next_txn(self, txn: LoggingTransaction) -> int: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) + txn.call_after(self._notifier.notify_replication) # Update the `stream_positions` table with newly updated stream # ID (unless self._writers is not set in which case we don't @@ -695,24 +720,22 @@ def _add_persisted_position(self, new_id: int) -> None: heapq.heappush(self._known_persisted_positions, new_id) - # If we're a writer and we don't have any active writes we update our - # current position to the latest position seen. This allows the instance - # to report a recent position when asked, rather than a potentially old - # one (if this instance hasn't written anything for a while). - our_current_position = self._current_positions.get(self._instance_name) - if ( - our_current_position - and not self._unfinished_ids - and not self._in_flight_fetches - ): - self._current_positions[self._instance_name] = max( - our_current_position, new_id - ) - # We move the current min position up if the minimum current positions # of all instances is higher (since by definition all positions less # that that have been persisted). 
- min_curr = min(self._current_positions.values(), default=0) + our_current_position = self._current_positions.get(self._instance_name, 0) + min_curr = min( + ( + token + for name, token in self._current_positions.items() + if name != self._instance_name + ), + default=our_current_position, + ) + + if our_current_position and (self._unfinished_ids or self._in_flight_fetches): + min_curr = min(min_curr, our_current_position) + self._persisted_upto_position = max(min_curr, self._persisted_upto_position) # We now iterate through the seen positions, discarding those that are @@ -783,6 +806,7 @@ class _MultiWriterCtxManager: """Async context manager returned by MultiWriterIdGenerator""" id_gen: MultiWriterIdGenerator + notifier: "ReplicationNotifier" multiple_ids: Optional[int] = None stream_ids: List[int] = attr.Factory(list) @@ -810,6 +834,8 @@ async def __aexit__( for i in self.stream_ids: self.id_gen._mark_id_as_finished(i) + self.notifier.notify_replication() + if exc_type is not None: return False diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index 2dcd43d0a2f5..c6c8a0315c9b 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Generic, List, Optional, Tuple, TypeVar +from typing import Generic, List, Optional, Tuple, TypeVar -from synapse.types import UserID +from synapse.types import StrCollection, UserID # The key, this is either a stream token or int. K = TypeVar("K") @@ -28,7 +28,7 @@ async def get_new_events( user: UserID, from_key: K, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[R], K]: diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 6df2de919cda..a04428041015 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -16,8 +16,9 @@ import attr +from synapse.api.constants import Direction from synapse.api.errors import SynapseError -from synapse.http.servlet import parse_integer, parse_string +from synapse.http.servlet import parse_enum, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.storage.databases.main import DataStore from synapse.types import StreamToken @@ -34,7 +35,7 @@ class PaginationConfig: from_token: Optional[StreamToken] to_token: Optional[StreamToken] - direction: str + direction: Direction limit: int @classmethod @@ -43,11 +44,9 @@ async def from_request( store: "DataStore", request: SynapseRequest, default_limit: int, - default_dir: str = "f", + default_dir: Direction = Direction.FORWARDS, ) -> "PaginationConfig": - direction = parse_string( - request, "dir", default=default_dir, allowed_values=["f", "b"] - ) + direction = parse_enum(request, "dir", Direction, default=default_dir) from_tok_str = parse_string(request, "from") to_tok_str = parse_string(request, "to") diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 619eb7f601de..d7084d2358cc 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -53,11 +53,15 @@ def __init__(self, hs: "HomeServer"): *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) ) self.store = hs.get_datastores().main + self._instance_name = hs.get_instance_name() def get_current_token(self) -> StreamToken: push_rules_key = self.store.get_max_push_rules_stream_id() to_device_key = 
self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() + un_partial_stated_rooms_key = self.store.get_un_partial_stated_rooms_token( + self._instance_name + ) token = StreamToken( room_key=self.sources.room.get_current_key(), @@ -70,6 +74,7 @@ def get_current_token(self) -> StreamToken: device_list_key=device_list_key, # Groups key is unused. groups_key=0, + un_partial_stated_rooms_key=un_partial_stated_rooms_key, ) return token @@ -107,5 +112,6 @@ async def get_current_token_for_pagination(self, room_id: str) -> StreamToken: to_device_key=0, device_list_key=0, groups_key=0, + un_partial_stated_rooms_key=0, ) return token diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 0c725eb9677d..f82d1cfc298b 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -17,6 +17,7 @@ import string from typing import ( TYPE_CHECKING, + AbstractSet, Any, ClassVar, Dict, @@ -79,7 +80,7 @@ # Collection[str] that does not include str itself; str being a Sequence[str] # is very misleading and results in bugs. -StrCollection = Union[Tuple[str, ...], List[str], Set[str]] +StrCollection = Union[Tuple[str, ...], List[str], AbstractSet[str]] # Note that this seems to require inheriting *directly* from Interface in order @@ -604,6 +605,12 @@ async def to_string(self, store: "DataStore") -> str: elif self.instance_map: entries = [] for name, pos in self.instance_map.items(): + if pos <= self.stream: + # Ignore instances who are below the minimum stream position + # (we might know they've advanced without seeing a recent + # write from them). + continue + instance_id = await store.get_id_for_instance(name) entries.append(f"{instance_id}.{pos}") @@ -627,6 +634,7 @@ class StreamKeyType: PUSH_RULES: Final = "push_rules_key" TO_DEVICE: Final = "to_device_key" DEVICE_LIST: Final = "device_list_key" + UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -634,7 +642,7 @@ class StreamToken: """A collection of keys joined together by underscores in the following order and which represent the position in their respective streams. - ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1` + ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379` 1. `room_key`: `s2633508` which is a `RoomStreamToken` - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` - See the docstring for `RoomStreamToken` for more details. @@ -646,12 +654,13 @@ class StreamToken: 7. `to_device_key`: `274711` 8. `device_list_key`: `265584` 9. `groups_key`: `1` (note that this key is now unused) + 10. `un_partial_stated_rooms_key`: `379` You can see how many of these keys correspond to the various fields in a "/sync" response: ```json { - "next_batch": "s12_4_0_1_1_1_1_4_1", + "next_batch": "s12_4_0_1_1_1_1_4_1_1", "presence": { "events": [] }, @@ -663,7 +672,7 @@ class StreamToken: "!QrZlfIDQLNLdZHqTnt:hs1": { "timeline": { "events": [], - "prev_batch": "s10_4_0_1_1_1_1_4_1", + "prev_batch": "s10_4_0_1_1_1_1_4_1_1", "limited": false }, "state": { @@ -699,6 +708,7 @@ class StreamToken: device_list_key: int # Note that the groups key is no longer used and may have bogus values. groups_key: int + un_partial_stated_rooms_key: int _SEPARATOR = "_" START: ClassVar["StreamToken"] @@ -737,6 +747,7 @@ async def to_string(self, store: "DataStore") -> str: # serialized so that there will not be confusion in the future # if additional tokens are added. 
str(self.groups_key), + str(self.un_partial_stated_rooms_key), ] ) @@ -769,7 +780,7 @@ def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": return attr.evolve(self, **{key: new_value}) -StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) +StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0) @attr.s(slots=True, frozen=True, auto_attribs=True) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 2aceb1a47fb6..f262bf95a0fd 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -364,12 +364,22 @@ def on_both(r: object) -> object: def _on_exit(self, request_id: object) -> None: logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id)) - self.current_processing.discard(request_id) - try: - # start processing the next item on the queue. - _, deferred = self.ready_request_queue.popitem(last=False) - with PreserveLoggingContext(): - deferred.callback(None) - except KeyError: - pass + # When requests complete synchronously, we will recursively start the next + # request in the queue. To avoid stack exhaustion, we defer starting the next + # request until the next reactor tick. + + def start_next_request() -> None: + # We only remove the completed request from the list when we're about to + # start the next one, otherwise we can allow extra requests through. + self.current_processing.discard(request_id) + try: + # start processing the next item on the queue. + _, deferred = self.ready_request_queue.popitem(last=False) + + with PreserveLoggingContext(): + deferred.callback(None) + except KeyError: + pass + + self.clock.call_later(0.0, start_next_request) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index e0f363555b3c..6e36e73f0d75 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -31,7 +31,7 @@ from synapse.appservice import ApplicationService from synapse.server import HomeServer from synapse.storage.databases.main.registration import TokenLookupResult -from synapse.types import Requester +from synapse.types import Requester, UserID from synapse.util import Clock from tests import unittest @@ -41,10 +41,12 @@ class AuthTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = Mock() - hs.datastores.main = self.store + # type-ignore: datastores is None until hs.setup() is called---but it'll + # have been called by the HomeserverTestCase machinery. 
+ hs.datastores.main = self.store # type: ignore[union-attr] hs.get_auth_handler().store = self.store self.auth = Auth(hs) @@ -61,7 +63,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): self.store.insert_client_ip = simple_async_mock(None) self.store.is_support_user = simple_async_mock(False) - def test_get_user_by_req_user_valid_token(self): + def test_get_user_by_req_user_valid_token(self) -> None: user_info = TokenLookupResult( user_id=self.test_user, token_id=5, device_id="device" ) @@ -74,7 +76,7 @@ def test_get_user_by_req_user_valid_token(self): requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEqual(requester.user.to_string(), self.test_user) - def test_get_user_by_req_user_bad_token(self): + def test_get_user_by_req_user_bad_token(self) -> None: self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) @@ -86,7 +88,7 @@ def test_get_user_by_req_user_bad_token(self): self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") - def test_get_user_by_req_user_missing_token(self): + def test_get_user_by_req_user_missing_token(self) -> None: user_info = TokenLookupResult(user_id=self.test_user, token_id=5) self.store.get_user_by_access_token = simple_async_mock(user_info) @@ -98,7 +100,7 @@ def test_get_user_by_req_user_missing_token(self): self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") - def test_get_user_by_req_appservice_valid_token(self): + def test_get_user_by_req_appservice_valid_token(self) -> None: app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) @@ -112,7 +114,7 @@ def test_get_user_by_req_appservice_valid_token(self): requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEqual(requester.user.to_string(), self.test_user) - def test_get_user_by_req_appservice_valid_token_good_ip(self): + def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None: from netaddr import IPSet app_service = Mock( @@ -131,7 +133,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self): requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEqual(requester.user.to_string(), self.test_user) - def test_get_user_by_req_appservice_valid_token_bad_ip(self): + def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None: from netaddr import IPSet app_service = Mock( @@ -153,7 +155,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self): self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") - def test_get_user_by_req_appservice_bad_token(self): + def test_get_user_by_req_appservice_bad_token(self) -> None: self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_user_by_access_token = simple_async_mock(None) @@ -166,7 +168,7 @@ def test_get_user_by_req_appservice_bad_token(self): self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") - def test_get_user_by_req_appservice_missing_token(self): + def test_get_user_by_req_appservice_missing_token(self) -> None: app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = simple_async_mock(None) @@ -179,7 +181,7 @@ def test_get_user_by_req_appservice_missing_token(self): self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") - def test_get_user_by_req_appservice_valid_token_valid_user_id(self): + def 
test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None: masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None @@ -200,7 +202,7 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self): requester.user.to_string(), masquerading_user_id.decode("utf8") ) - def test_get_user_by_req_appservice_valid_token_bad_user_id(self): + def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None: masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None @@ -217,7 +219,7 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self): self.get_failure(self.auth.get_user_by_req(request), AuthError) @override_config({"experimental_features": {"msc3202_device_masquerading": True}}) - def test_get_user_by_req_appservice_valid_token_valid_device_id(self): + def test_get_user_by_req_appservice_valid_token_valid_device_id(self) -> None: """ Tests that when an application service passes the device_id URL parameter with the ID of a valid device for the user in question, @@ -249,7 +251,7 @@ def test_get_user_by_req_appservice_valid_token_valid_device_id(self): self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8")) @override_config({"experimental_features": {"msc3202_device_masquerading": True}}) - def test_get_user_by_req_appservice_valid_token_invalid_device_id(self): + def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None: """ Tests that when an application service passes the device_id URL parameter with an ID that is not a valid device ID for the user in question, @@ -279,7 +281,7 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self): self.assertEqual(failure.value.code, 400) self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) - def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self): + def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None: self.store.get_user_by_access_token = simple_async_mock( TokenLookupResult( user_id="@baldrick:matrix.org", @@ -298,7 +300,7 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self): self.get_success(self.auth.get_user_by_req(request)) self.store.insert_client_ip.assert_called_once() - def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self): + def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: self.auth._track_puppeted_user_ips = True self.store.get_user_by_access_token = simple_async_mock( TokenLookupResult( @@ -318,7 +320,7 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self): self.get_success(self.auth.get_user_by_req(request)) self.assertEqual(self.store.insert_client_ip.call_count, 2) - def test_get_user_from_macaroon(self): + def test_get_user_from_macaroon(self) -> None: self.store.get_user_by_access_token = simple_async_mock(None) user_id = "@baldrick:matrix.org" @@ -336,7 +338,7 @@ def test_get_user_from_macaroon(self): self.auth.get_user_by_access_token(serialized), InvalidClientTokenError ) - def test_get_guest_user_from_macaroon(self): + def test_get_guest_user_from_macaroon(self) -> None: self.store.get_user_by_id = simple_async_mock({"is_guest": True}) self.store.get_user_by_access_token = simple_async_mock(None) @@ -357,7 +359,7 @@ def test_get_guest_user_from_macaroon(self): self.assertTrue(user_info.is_guest) 
self.store.get_user_by_id.assert_called_with(user_id) - def test_blocking_mau(self): + def test_blocking_mau(self) -> None: self.auth_blocking._limit_usage_by_mau = False self.auth_blocking._max_mau_value = 50 lots_of_users = 100 @@ -381,7 +383,7 @@ def test_blocking_mau(self): self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) self.get_success(self.auth_blocking.check_auth_blocking()) - def test_blocking_mau__depending_on_user_type(self): + def test_blocking_mau__depending_on_user_type(self) -> None: self.auth_blocking._max_mau_value = 50 self.auth_blocking._limit_usage_by_mau = True @@ -400,7 +402,9 @@ def test_blocking_mau__depending_on_user_type(self): # Real users not allowed self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) - def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self): + def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips( + self, + ) -> None: self.auth_blocking._max_mau_value = 50 self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._track_appservice_user_ips = False @@ -418,7 +422,7 @@ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self): sender="@appservice:sender", ) requester = Requester( - user="@appservice:server", + user=UserID.from_string("@appservice:server"), access_token_id=None, device_id="FOOBAR", is_guest=False, @@ -428,7 +432,9 @@ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self): ) self.get_success(self.auth_blocking.check_auth_blocking(requester=requester)) - def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self): + def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips( + self, + ) -> None: self.auth_blocking._max_mau_value = 50 self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._track_appservice_user_ips = True @@ -446,7 +452,7 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self): sender="@appservice:sender", ) requester = Requester( - user="@appservice:server", + user=UserID.from_string("@appservice:server"), access_token_id=None, device_id="FOOBAR", is_guest=False, @@ -459,7 +465,7 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self): ResourceLimitError, ) - def test_reserved_threepid(self): + def test_reserved_threepid(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._max_mau_value = 1 self.store.get_monthly_active_count = simple_async_mock(2) @@ -476,7 +482,7 @@ def test_reserved_threepid(self): self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid)) - def test_hs_disabled(self): + def test_hs_disabled(self) -> None: self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure( @@ -486,7 +492,7 @@ def test_hs_disabled(self): self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.code, 403) - def test_hs_disabled_no_server_notices_user(self): + def test_hs_disabled_no_server_notices_user(self) -> None: """Check that 'hs_disabled_message' works correctly when there is no server_notices user. 
""" @@ -503,7 +509,7 @@ def test_hs_disabled_no_server_notices_user(self): self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.code, 403) - def test_server_notices_mxid_special_cased(self): + def test_server_notices_mxid_special_cased(self) -> None: self.auth_blocking._hs_disabled = True user = "@user:server" self.auth_blocking._server_notices_mxid = user diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index d5524d296e38..0f45615160ed 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -14,40 +14,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import List from unittest.mock import patch import jsonschema from frozendict import frozendict +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.filtering import Filter -from synapse.events import make_event_from_dict +from synapse.api.presence import UserPresenceState +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest +from tests.events.test_utils import MockEvent user_localpart = "test_user" -def MockEvent(**kwargs): - if "event_id" not in kwargs: - kwargs["event_id"] = "fake_event_id" - if "type" not in kwargs: - kwargs["type"] = "fake_type" - if "content" not in kwargs: - kwargs["content"] = {} - return make_event_from_dict(kwargs) - - class FilteringTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.filtering = hs.get_filtering() self.datastore = hs.get_datastores().main - def test_errors_on_invalid_filters(self): + def test_errors_on_invalid_filters(self) -> None: # See USER_FILTER_SCHEMA for the filter schema. - invalid_filters = [ + invalid_filters: List[JsonDict] = [ # `account_data` must be a dictionary {"account_data": "Hello World"}, # `event_fields` entries must not contain backslashes @@ -63,10 +59,10 @@ def test_errors_on_invalid_filters(self): with self.assertRaises(SynapseError): self.filtering.check_valid_filter(filter) - def test_ignores_unknown_filter_fields(self): + def test_ignores_unknown_filter_fields(self) -> None: # For forward compatibility, we must ignore unknown filter fields. # See USER_FILTER_SCHEMA for the filter schema. - filters = [ + filters: List[JsonDict] = [ {"org.matrix.msc9999.future_option": True}, {"presence": {"org.matrix.msc9999.future_option": True}}, {"room": {"org.matrix.msc9999.future_option": True}}, @@ -76,8 +72,8 @@ def test_ignores_unknown_filter_fields(self): self.filtering.check_valid_filter(filter) # Must not raise. 
- def test_valid_filters(self): - valid_filters = [ + def test_valid_filters(self) -> None: + valid_filters: List[JsonDict] = [ { "room": { "timeline": {"limit": 20}, @@ -132,22 +128,22 @@ def test_valid_filters(self): except jsonschema.ValidationError as e: self.fail(e) - def test_limits_are_applied(self): + def test_limits_are_applied(self) -> None: # TODO pass - def test_definition_types_works_with_literals(self): + def test_definition_types_works_with_literals(self) -> None: definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_types_works_with_wildcards(self): + def test_definition_types_works_with_wildcards(self) -> None: definition = {"types": ["m.*", "org.matrix.foo.bar"]} event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_types_works_with_unknowns(self): + def test_definition_types_works_with_unknowns(self) -> None: definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} event = MockEvent( sender="@foo:bar", @@ -156,24 +152,24 @@ def test_definition_types_works_with_unknowns(self): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_types_works_with_literals(self): + def test_definition_not_types_works_with_literals(self) -> None: definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]} event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_types_works_with_wildcards(self): + def test_definition_not_types_works_with_wildcards(self) -> None: definition = {"not_types": ["m.room.message", "org.matrix.*"]} event = MockEvent( sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar" ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_types_works_with_unknowns(self): + def test_definition_not_types_works_with_unknowns(self) -> None: definition = {"not_types": ["m.*", "org.*"]} event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar") self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_not_types_takes_priority_over_types(self): + def test_definition_not_types_takes_priority_over_types(self) -> None: definition = { "not_types": ["m.*", "org.*"], "types": ["m.room.message", "m.room.topic"], @@ -181,35 +177,35 @@ def test_definition_not_types_takes_priority_over_types(self): event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_senders_works_with_literals(self): + def test_definition_senders_works_with_literals(self) -> None: definition = {"senders": ["@flibble:wibble"]} event = MockEvent( sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" ) self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_senders_works_with_unknowns(self): + def test_definition_senders_works_with_unknowns(self) -> None: definition = {"senders": ["@flibble:wibble"]} event = MockEvent( sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_senders_works_with_literals(self): + def 
test_definition_not_senders_works_with_literals(self) -> None: definition = {"not_senders": ["@flibble:wibble"]} event = MockEvent( sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_senders_works_with_unknowns(self): + def test_definition_not_senders_works_with_unknowns(self) -> None: definition = {"not_senders": ["@flibble:wibble"]} event = MockEvent( sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" ) self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_not_senders_takes_priority_over_senders(self): + def test_definition_not_senders_takes_priority_over_senders(self) -> None: definition = { "not_senders": ["@misspiggy:muppets"], "senders": ["@kermit:muppets", "@misspiggy:muppets"], @@ -219,14 +215,14 @@ def test_definition_not_senders_takes_priority_over_senders(self): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_rooms_works_with_literals(self): + def test_definition_rooms_works_with_literals(self) -> None: definition = {"rooms": ["!secretbase:unknown"]} event = MockEvent( sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown" ) self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_rooms_works_with_unknowns(self): + def test_definition_rooms_works_with_unknowns(self) -> None: definition = {"rooms": ["!secretbase:unknown"]} event = MockEvent( sender="@foo:bar", @@ -235,7 +231,7 @@ def test_definition_rooms_works_with_unknowns(self): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_rooms_works_with_literals(self): + def test_definition_not_rooms_works_with_literals(self) -> None: definition = {"not_rooms": ["!anothersecretbase:unknown"]} event = MockEvent( sender="@foo:bar", @@ -244,7 +240,7 @@ def test_definition_not_rooms_works_with_literals(self): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_rooms_works_with_unknowns(self): + def test_definition_not_rooms_works_with_unknowns(self) -> None: definition = {"not_rooms": ["!secretbase:unknown"]} event = MockEvent( sender="@foo:bar", @@ -253,7 +249,7 @@ def test_definition_not_rooms_works_with_unknowns(self): ) self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_not_rooms_takes_priority_over_rooms(self): + def test_definition_not_rooms_takes_priority_over_rooms(self) -> None: definition = { "not_rooms": ["!secretbase:unknown"], "rooms": ["!secretbase:unknown"], @@ -263,7 +259,7 @@ def test_definition_not_rooms_takes_priority_over_rooms(self): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_combined_event(self): + def test_definition_combined_event(self) -> None: definition = { "not_senders": ["@misspiggy:muppets"], "senders": ["@kermit:muppets"], @@ -279,7 +275,7 @@ def test_definition_combined_event(self): ) self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_combined_event_bad_sender(self): + def test_definition_combined_event_bad_sender(self) -> None: definition = { "not_senders": ["@misspiggy:muppets"], "senders": ["@kermit:muppets"], @@ -295,7 +291,7 @@ def test_definition_combined_event_bad_sender(self): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_combined_event_bad_room(self): + def test_definition_combined_event_bad_room(self) -> None: definition = { "not_senders": 
["@misspiggy:muppets"], "senders": ["@kermit:muppets"], @@ -311,7 +307,7 @@ def test_definition_combined_event_bad_room(self): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_combined_event_bad_type(self): + def test_definition_combined_event_bad_type(self) -> None: definition = { "not_senders": ["@misspiggy:muppets"], "senders": ["@kermit:muppets"], @@ -327,7 +323,7 @@ def test_definition_combined_event_bad_type(self): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_filter_labels(self): + def test_filter_labels(self) -> None: definition = {"org.matrix.labels": ["#fun"]} event = MockEvent( sender="@foo:bar", @@ -356,7 +352,7 @@ def test_filter_labels(self): ) self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_filter_not_labels(self): + def test_filter_not_labels(self) -> None: definition = {"org.matrix.not_labels": ["#fun"]} event = MockEvent( sender="@foo:bar", @@ -377,7 +373,7 @@ def test_filter_not_labels(self): self.assertTrue(Filter(self.hs, definition)._check(event)) @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) - def test_filter_rel_type(self): + def test_filter_rel_type(self) -> None: definition = {"org.matrix.msc3874.rel_types": ["m.thread"]} event = MockEvent( sender="@foo:bar", @@ -407,7 +403,7 @@ def test_filter_rel_type(self): self.assertTrue(Filter(self.hs, definition)._check(event)) @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) - def test_filter_not_rel_type(self): + def test_filter_not_rel_type(self) -> None: definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]} event = MockEvent( sender="@foo:bar", @@ -436,15 +432,25 @@ def test_filter_not_rel_type(self): self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_filter_presence_match(self): - user_filter_json = {"presence": {"types": ["m.*"]}} + def test_filter_presence_match(self) -> None: + """Check that filter_presence return events which matches the filter.""" + user_filter_json = {"presence": {"senders": ["@foo:bar"]}} filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json ) ) - event = MockEvent(sender="@foo:bar", type="m.profile") - events = [event] + presence_states = [ + UserPresenceState( + user_id="@foo:bar", + state="unavailable", + last_active_ts=0, + last_federation_update_ts=0, + last_user_sync_ts=0, + status_msg=None, + currently_active=False, + ), + ] user_filter = self.get_success( self.filtering.get_user_filter( @@ -452,23 +458,29 @@ def test_filter_presence_match(self): ) ) - results = self.get_success(user_filter.filter_presence(events=events)) - self.assertEqual(events, results) + results = self.get_success(user_filter.filter_presence(presence_states)) + self.assertEqual(presence_states, results) - def test_filter_presence_no_match(self): - user_filter_json = {"presence": {"types": ["m.*"]}} + def test_filter_presence_no_match(self) -> None: + """Check that filter_presence does not return events rejected by the filter.""" + user_filter_json = {"presence": {"not_senders": ["@foo:bar"]}} filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart + "2", user_filter=user_filter_json ) ) - event = MockEvent( - event_id="$asdasd:localhost", - sender="@foo:bar", - type="custom.avatar.3d.crazy", - ) - events = [event] + presence_states = [ + UserPresenceState( + user_id="@foo:bar", + state="unavailable", + last_active_ts=0, + 
last_federation_update_ts=0, + last_user_sync_ts=0, + status_msg=None, + currently_active=False, + ), + ] user_filter = self.get_success( self.filtering.get_user_filter( @@ -476,10 +488,10 @@ def test_filter_presence_no_match(self): ) ) - results = self.get_success(user_filter.filter_presence(events=events)) + results = self.get_success(user_filter.filter_presence(presence_states)) self.assertEqual([], results) - def test_filter_room_state_match(self): + def test_filter_room_state_match(self) -> None: user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = self.get_success( self.datastore.add_user_filter( @@ -498,7 +510,7 @@ def test_filter_room_state_match(self): results = self.get_success(user_filter.filter_room_state(events=events)) self.assertEqual(events, results) - def test_filter_room_state_no_match(self): + def test_filter_room_state_no_match(self) -> None: user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = self.get_success( self.datastore.add_user_filter( @@ -519,7 +531,7 @@ def test_filter_room_state_no_match(self): results = self.get_success(user_filter.filter_room_state(events)) self.assertEqual([], results) - def test_filter_rooms(self): + def test_filter_rooms(self) -> None: definition = { "rooms": ["!allowed:example.com", "!excluded:example.com"], "not_rooms": ["!excluded:example.com"], @@ -535,7 +547,7 @@ def test_filter_rooms(self): self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) - def test_filter_relations(self): + def test_filter_relations(self) -> None: events = [ # An event without a relation. MockEvent( @@ -551,9 +563,8 @@ def test_filter_relations(self): type="org.matrix.custom.event", room_id="!foo:bar", ), - # Non-EventBase objects get passed through. - {}, ] + jsondicts: List[JsonDict] = [{}] # For the following tests we patch the datastore method (intead of injecting # events). This is a bit cheeky, but tests the logic of _check_event_relations. @@ -561,7 +572,7 @@ def test_filter_relations(self): # Filter for a particular sender. definition = {"related_by_senders": ["@foo:bar"]} - async def events_have_relations(*args, **kwargs): + async def events_have_relations(*args: object, **kwargs: object) -> List[str]: return ["$with_relation"] with patch.object( @@ -572,9 +583,17 @@ async def events_have_relations(*args, **kwargs): Filter(self.hs, definition)._check_event_relations(events) ) ) + # Non-EventBase objects get passed through. 
+ filtered_jsondicts = list( + self.get_success( + Filter(self.hs, definition)._check_event_relations(jsondicts) + ) + ) + self.assertEqual(filtered_events, events[1:]) + self.assertEqual(filtered_jsondicts, [{}]) - def test_add_filter(self): + def test_add_filter(self) -> None: user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = self.get_success( @@ -595,7 +614,7 @@ def test_add_filter(self): ), ) - def test_get_filter(self): + def test_get_filter(self) -> None: user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = self.get_success( diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index c86f783c5bd4..fa6c1c02ce95 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -6,9 +6,12 @@ class TestRatelimiter(unittest.HomeserverTestCase): - def test_allowed_via_can_do_action(self): + def test_allowed_via_can_do_action(self) -> None: limiter = Ratelimiter( - store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, + clock=self.clock, + rate_hz=0.1, + burst_count=1, ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=0) @@ -28,9 +31,9 @@ def test_allowed_via_can_do_action(self): self.assertTrue(allowed) self.assertEqual(20.0, time_allowed) - def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): + def test_allowed_appservice_ratelimited_via_can_requester_do_action(self) -> None: appservice = ApplicationService( - None, + token="fake_token", id="foo", rate_limited=True, sender="@as:example.com", @@ -38,7 +41,10 @@ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter( - store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, + clock=self.clock, + rate_hz=0.1, + burst_count=1, ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -58,9 +64,9 @@ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): self.assertTrue(allowed) self.assertEqual(20.0, time_allowed) - def test_allowed_appservice_via_can_requester_do_action(self): + def test_allowed_appservice_via_can_requester_do_action(self) -> None: appservice = ApplicationService( - None, + token="fake_token", id="foo", rate_limited=False, sender="@as:example.com", @@ -68,7 +74,10 @@ def test_allowed_appservice_via_can_requester_do_action(self): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter( - store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, + clock=self.clock, + rate_hz=0.1, + burst_count=1, ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -88,9 +97,12 @@ def test_allowed_appservice_via_can_requester_do_action(self): self.assertTrue(allowed) self.assertEqual(-1, time_allowed) - def test_allowed_via_ratelimit(self): + def test_allowed_via_ratelimit(self) -> None: limiter = Ratelimiter( - store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, + clock=self.clock, + rate_hz=0.1, + burst_count=1, ) # Shouldn't raise @@ -108,13 +120,16 @@ def test_allowed_via_ratelimit(self): limiter.ratelimit(None, key="test_id", _time_now_s=10) ) - def 
test_allowed_via_can_do_action_and_overriding_parameters(self): + def test_allowed_via_can_do_action_and_overriding_parameters(self) -> None: """Test that we can override options of can_do_action that would otherwise fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter( - store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, + clock=self.clock, + rate_hz=0.1, + burst_count=1, ) # First attempt should be allowed @@ -154,13 +169,16 @@ def test_allowed_via_can_do_action_and_overriding_parameters(self): self.assertTrue(allowed) self.assertEqual(1.0, time_allowed) - def test_allowed_via_ratelimit_and_overriding_parameters(self): + def test_allowed_via_ratelimit_and_overriding_parameters(self) -> None: """Test that we can override options of the ratelimit method that would otherwise fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter( - store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, + clock=self.clock, + rate_hz=0.1, + burst_count=1, ) # First attempt should be allowed @@ -186,9 +204,12 @@ def test_allowed_via_ratelimit_and_overriding_parameters(self): limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10) ) - def test_pruning(self): + def test_pruning(self) -> None: limiter = Ratelimiter( - store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, + clock=self.clock, + rate_hz=0.1, + burst_count=1, ) self.get_success_or_raise( limiter.can_do_action(None, key="test_id_1", _time_now_s=0) @@ -202,7 +223,7 @@ def test_pruning(self): self.assertNotIn("test_id_1", limiter.actions) - def test_db_user_override(self): + def test_db_user_override(self) -> None: """Test that users that have ratelimiting disabled in the DB aren't ratelimited. """ @@ -223,15 +244,18 @@ def test_db_user_override(self): ) ) - limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1) # Shouldn't raise for _ in range(20): self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0)) - def test_multiple_actions(self): + def test_multiple_actions(self) -> None: limiter = Ratelimiter( - store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + store=self.hs.get_datastores().main, + clock=self.clock, + rate_hz=0.1, + burst_count=3, ) # Test that 4 actions aren't allowed with a maximum burst of 3. allowed, time_allowed = self.get_success_or_raise( @@ -295,7 +319,10 @@ def test_rate_limit_burst_only_given_once(self) -> None: extra tokens by timing requests. """ limiter = Ratelimiter( - store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + store=self.hs.get_datastores().main, + clock=self.clock, + rate_hz=0.1, + burst_count=3, ) def consume_at(time: float) -> bool: @@ -317,7 +344,10 @@ def consume_at(time: float) -> bool: def test_record_action_which_doesnt_fill_bucket(self) -> None: limiter = Ratelimiter( - store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + store=self.hs.get_datastores().main, + clock=self.clock, + rate_hz=0.1, + burst_count=3, ) # Observe two actions, leaving room in the bucket for one more. 
@@ -337,7 +367,10 @@ def test_record_action_which_doesnt_fill_bucket(self) -> None: def test_record_action_which_fills_bucket(self) -> None: limiter = Ratelimiter( - store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + store=self.hs.get_datastores().main, + clock=self.clock, + rate_hz=0.1, + burst_count=3, ) # Observe three actions, filling up the bucket. @@ -363,7 +396,10 @@ def test_record_action_which_fills_bucket(self) -> None: def test_record_action_which_overfills_bucket(self) -> None: limiter = Ratelimiter( - store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + store=self.hs.get_datastores().main, + clock=self.clock, + rate_hz=0.1, + burst_count=3, ) # Observe four actions, exceeding the bucket. diff --git a/tests/app/test_homeserver_start.py b/tests/app/test_homeserver_start.py index cbcada04517e..788c93553715 100644 --- a/tests/app/test_homeserver_start.py +++ b/tests/app/test_homeserver_start.py @@ -19,7 +19,7 @@ class HomeserverAppStartTestCase(ConfigFileTestCase): - def test_wrong_start_caught(self): + def test_wrong_start_caught(self) -> None: # Generate a config with a worker_app self.generate_config() # Add a blank line as otherwise the next addition ends up on a line with a comment diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 8d03da7f96a6..5d89ba94ad8f 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -11,26 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List from unittest.mock import Mock, patch from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + from synapse.app.generic_worker import GenericWorkerServer from synapse.app.homeserver import SynapseHomeServer from synapse.config.server import parse_listener_def +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.server import make_request from tests.unittest import HomeserverTestCase class FederationReaderOpenIDListenerTests(HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver( federation_http_client=None, homeserver_to_use=GenericWorkerServer ) return hs - def default_config(self): + def default_config(self) -> JsonDict: conf = super().default_config() # we're using FederationReaderServer, which uses a SlavedStore, so we # have to tell the FederationHandler not to try to access stuff that is only @@ -47,7 +53,7 @@ def default_config(self): (["openid"], "auth_fail"), ] ) - def test_openid_listener(self, names, expectation): + def test_openid_listener(self, names: List[str], expectation: str) -> None: """ Test different openid listener configurations. 
@@ -81,7 +87,7 @@ def test_openid_listener(self, names, expectation): @patch("synapse.app.homeserver.KeyResource", new=Mock()) class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver( federation_http_client=None, homeserver_to_use=SynapseHomeServer ) @@ -95,7 +101,7 @@ def make_homeserver(self, reactor, clock): (["openid"], "auth_fail"), ] ) - def test_openid_listener(self, names, expectation): + def test_openid_listener(self, names: List[str], expectation: str) -> None: """ Test different openid listener configurations. diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index df731eb599cc..a860eedbcfef 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -1,8 +1,11 @@ import synapse from synapse.app.phone_stats_home import start_phone_stats_home from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest +from tests.server import ThreadedMemoryReactorClock from tests.unittest import HomeserverTestCase FIVE_MINUTES_IN_SECONDS = 300 @@ -19,7 +22,7 @@ class PhoneHomeTestCase(HomeserverTestCase): # Override the retention time for the user_ips table because otherwise it # gets pruned too aggressively for our R30 test. @unittest.override_config({"user_ips_max_age": "365d"}) - def test_r30_minimum_usage(self): + def test_r30_minimum_usage(self) -> None: """ Tests the minimum amount of interaction necessary for the R30 metric to consider a user 'retained'. @@ -68,7 +71,7 @@ def test_r30_minimum_usage(self): r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) - def test_r30_minimum_usage_using_default_config(self): + def test_r30_minimum_usage_using_default_config(self) -> None: """ Tests the minimum amount of interaction necessary for the R30 metric to consider a user 'retained'. @@ -122,7 +125,7 @@ def test_r30_minimum_usage_using_default_config(self): r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) - def test_r30_user_must_be_retained_for_at_least_a_month(self): + def test_r30_user_must_be_retained_for_at_least_a_month(self) -> None: """ Tests that a newly-registered user must be retained for a whole month before appearing in the R30 statistic, even if they post every day @@ -164,12 +167,14 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): login.register_servlets, ] - def _advance_to(self, desired_time_secs: float): + def _advance_to(self, desired_time_secs: float) -> None: now = self.hs.get_clock().time() assert now < desired_time_secs self.reactor.advance(desired_time_secs - now) - def make_homeserver(self, reactor, clock): + def make_homeserver( + self, reactor: ThreadedMemoryReactorClock, clock: Clock + ) -> HomeServer: hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock) # We don't want our tests to actually report statistics, so check @@ -181,7 +186,7 @@ def make_homeserver(self, reactor, clock): start_phone_stats_home(hs) return hs - def test_r30v2_minimum_usage(self): + def test_r30v2_minimum_usage(self) -> None: """ Tests the minimum amount of interaction necessary for the R30v2 metric to consider a user 'retained'. 
@@ -250,7 +255,7 @@ def test_r30v2_minimum_usage(self): r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} ) - def test_r30v2_user_must_be_retained_for_at_least_a_month(self): + def test_r30v2_user_must_be_retained_for_at_least_a_month(self) -> None: """ Tests that a newly-registered user must be retained for a whole month before appearing in the R30v2 statistic, even if they post every day @@ -316,7 +321,7 @@ def test_r30v2_user_must_be_retained_for_at_least_a_month(self): r30_results, {"all": 1, "android": 1, "electron": 0, "ios": 0, "web": 0} ) - def test_r30v2_returning_dormant_users_not_counted(self): + def test_r30v2_returning_dormant_users_not_counted(self) -> None: """ Tests that dormant users (users inactive for a long time) do not contribute to R30v2 when they return for just a single day. diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 89ee79396fb5..9d183b733ec2 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -29,7 +29,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.api = hs.get_application_service_api() self.service = ApplicationService( id="unique_identifier", @@ -39,7 +39,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): hs_token=TOKEN, ) - def test_query_3pe_authenticates_token(self): + def test_query_3pe_authenticates_token(self) -> None: """ Tests that 3pe queries to the appservice are authenticated with the appservice's token. diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index d4dccfc2f070..dee976356faa 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import re +from typing import Generator from unittest.mock import Mock from twisted.internet import defer @@ -27,7 +28,7 @@ def _regex(regex: str, exclusive: bool = True) -> Namespace: class ApplicationServiceTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.service = ApplicationService( id="unique_identifier", sender="@as:test", @@ -46,7 +47,9 @@ def setUp(self): self.store.get_local_users_in_room = simple_async_mock([]) @defer.inlineCallbacks - def test_regex_user_id_prefix_match(self): + def test_regex_user_id_prefix_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.assertTrue( @@ -60,7 +63,9 @@ def test_regex_user_id_prefix_match(self): ) @defer.inlineCallbacks - def test_regex_user_id_prefix_no_match(self): + def test_regex_user_id_prefix_no_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.assertFalse( @@ -74,7 +79,9 @@ def test_regex_user_id_prefix_no_match(self): ) @defer.inlineCallbacks - def test_regex_room_member_is_checked(self): + def test_regex_room_member_is_checked( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.event.type = "m.room.member" @@ -90,7 +97,9 @@ def test_regex_room_member_is_checked(self): ) @defer.inlineCallbacks - def test_regex_room_id_match(self): + def test_regex_room_id_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!some_prefix.*some_suffix:matrix.org") ) @@ -106,7 +115,9 @@ def test_regex_room_id_match(self): ) @defer.inlineCallbacks - def test_regex_room_id_no_match(self): + def test_regex_room_id_no_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!some_prefix.*some_suffix:matrix.org") ) @@ -122,7 +133,9 @@ def test_regex_room_id_no_match(self): ) @defer.inlineCallbacks - def test_regex_alias_match(self): + def test_regex_alias_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) @@ -140,44 +153,46 @@ def test_regex_alias_match(self): ) ) - def test_non_exclusive_alias(self): + def test_non_exclusive_alias(self) -> None: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org", exclusive=False) ) self.assertFalse(self.service.is_exclusive_alias("#irc_foobar:matrix.org")) - def test_non_exclusive_room(self): + def test_non_exclusive_room(self) -> None: self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!irc_.*:matrix.org", exclusive=False) ) self.assertFalse(self.service.is_exclusive_room("!irc_foobar:matrix.org")) - def test_non_exclusive_user(self): + def test_non_exclusive_user(self) -> None: self.service.namespaces[ApplicationService.NS_USERS].append( _regex("@irc_.*:matrix.org", exclusive=False) ) self.assertFalse(self.service.is_exclusive_user("@irc_foobar:matrix.org")) - def test_exclusive_alias(self): + def test_exclusive_alias(self) -> None: self.service.namespaces[ApplicationService.NS_ALIASES].append( 
_regex("#irc_.*:matrix.org", exclusive=True) ) self.assertTrue(self.service.is_exclusive_alias("#irc_foobar:matrix.org")) - def test_exclusive_user(self): + def test_exclusive_user(self) -> None: self.service.namespaces[ApplicationService.NS_USERS].append( _regex("@irc_.*:matrix.org", exclusive=True) ) self.assertTrue(self.service.is_exclusive_user("@irc_foobar:matrix.org")) - def test_exclusive_room(self): + def test_exclusive_room(self) -> None: self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!irc_.*:matrix.org", exclusive=True) ) self.assertTrue(self.service.is_exclusive_room("!irc_foobar:matrix.org")) @defer.inlineCallbacks - def test_regex_alias_no_match(self): + def test_regex_alias_no_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) @@ -196,7 +211,9 @@ def test_regex_alias_no_match(self): ) @defer.inlineCallbacks - def test_regex_multiple_matches(self): + def test_regex_multiple_matches( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) @@ -215,7 +232,9 @@ def test_regex_multiple_matches(self): ) @defer.inlineCallbacks - def test_interested_in_self(self): + def test_interested_in_self( + self, + ) -> Generator["defer.Deferred[object]", object, None]: # make sure invites get through self.service.sender = "@appservice:name" self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) @@ -233,7 +252,9 @@ def test_interested_in_self(self): ) @defer.inlineCallbacks - def test_member_list_match(self): + def test_member_list_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. self.store.get_local_users_in_room = simple_async_mock( diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 0a1ae83a2bbf..febcc1499d06 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -11,20 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast from unittest.mock import Mock +from typing_extensions import TypeAlias + from twisted.internet import defer -from synapse.appservice import ApplicationServiceState +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + TransactionOneTimeKeysCount, + TransactionUnusedFallbackKeys, +) from synapse.appservice.scheduler import ( ApplicationServiceScheduler, _Recoverer, _TransactionController, ) +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable from synapse.server import HomeServer -from synapse.types import DeviceListUpdates +from synapse.types import DeviceListUpdates, JsonDict from synapse.util import Clock from tests import unittest @@ -37,18 +45,18 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.clock = MockClock() self.store = Mock() self.as_api = Mock() self.recoverer = Mock() self.recoverer_fn = Mock(return_value=self.recoverer) self.txnctrl = _TransactionController( - clock=self.clock, store=self.store, as_api=self.as_api + clock=cast(Clock, self.clock), store=self.store, as_api=self.as_api ) self.txnctrl.RECOVERER_CLASS = self.recoverer_fn - def test_single_service_up_txn_sent(self): + def test_single_service_up_txn_sent(self) -> None: # Test: The AS is up and the txn is successfully sent. service = Mock() events = [Mock(), Mock()] @@ -76,7 +84,7 @@ def test_single_service_up_txn_sent(self): self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed - def test_single_service_down(self): + def test_single_service_down(self) -> None: # Test: The AS is down so it shouldn't push; Recoverers will do it. # It should still make a transaction though. service = Mock() @@ -103,7 +111,7 @@ def test_single_service_down(self): self.assertEqual(0, txn.send.call_count) # txn not sent though self.assertEqual(0, txn.complete.call_count) # or completed - def test_single_service_up_txn_not_sent(self): + def test_single_service_up_txn_not_sent(self) -> None: # Test: The AS is up and the txn is not sent. A Recoverer is made and # started. 
service = Mock() @@ -139,26 +147,28 @@ def test_single_service_up_txn_not_sent(self): class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.clock = MockClock() self.as_api = Mock() self.store = Mock() self.service = Mock() self.callback = simple_async_mock() self.recoverer = _Recoverer( - clock=self.clock, + clock=cast(Clock, self.clock), as_api=self.as_api, store=self.store, service=self.service, callback=self.callback, ) - def test_recover_single_txn(self): + def test_recover_single_txn(self) -> None: txn = Mock() # return one txn to send, then no more old txns txns = [txn, None] - def take_txn(*args, **kwargs): + def take_txn( + *args: object, **kwargs: object + ) -> "defer.Deferred[Optional[Mock]]": return defer.succeed(txns.pop(0)) self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) @@ -177,12 +187,14 @@ def take_txn(*args, **kwargs): self.callback.assert_called_once_with(self.recoverer) self.assertEqual(self.recoverer.service, self.service) - def test_recover_retry_txn(self): + def test_recover_retry_txn(self) -> None: txn = Mock() txns = [txn, None] pop_txn = False - def take_txn(*args, **kwargs): + def take_txn( + *args: object, **kwargs: object + ) -> "defer.Deferred[Optional[Mock]]": if pop_txn: return defer.succeed(txns.pop(0)) else: @@ -214,8 +226,24 @@ def take_txn(*args, **kwargs): self.callback.assert_called_once_with(self.recoverer) +# Corresponds to synapse.appservice.scheduler._TransactionController.send +TxnCtrlArgs: TypeAlias = """ +defer.Deferred[ + Tuple[ + ApplicationService, + Sequence[EventBase], + Optional[List[JsonDict]], + Optional[List[JsonDict]], + Optional[TransactionOneTimeKeysCount], + Optional[TransactionUnusedFallbackKeys], + Optional[DeviceListUpdates], + ] +] +""" + + class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer): + def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None: self.scheduler = ApplicationServiceScheduler(hs) self.txn_ctrl = Mock() self.txn_ctrl.send = simple_async_mock() @@ -224,7 +252,7 @@ def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer): self.scheduler.txn_ctrl = self.txn_ctrl self.scheduler.queuer.txn_ctrl = self.txn_ctrl - def test_send_single_event_no_queue(self): + def test_send_single_event_no_queue(self) -> None: # Expect the event to be sent immediately. service = Mock(id=4) event = Mock() @@ -233,8 +261,8 @@ def test_send_single_event_no_queue(self): service, [event], [], [], None, None, DeviceListUpdates() ) - def test_send_single_event_with_queue(self): - d = defer.Deferred() + def test_send_single_event_with_queue(self) -> None: + d: TxnCtrlArgs = defer.Deferred() self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) service = Mock(id=4) event = Mock(event_id="first") @@ -257,22 +285,22 @@ def test_send_single_event_with_queue(self): ) self.assertEqual(2, self.txn_ctrl.send.call_count) - def test_multiple_service_queues(self): + def test_multiple_service_queues(self) -> None: # Tests that each service has its own queue, and that they don't block # on each other. 
srv1 = Mock(id=4) - srv_1_defer = defer.Deferred() + srv_1_defer: "defer.Deferred[EventBase]" = defer.Deferred() srv_1_event = Mock(event_id="srv1a") srv_1_event2 = Mock(event_id="srv1b") srv2 = Mock(id=6) - srv_2_defer = defer.Deferred() + srv_2_defer: "defer.Deferred[EventBase]" = defer.Deferred() srv_2_event = Mock(event_id="srv2a") srv_2_event2 = Mock(event_id="srv2b") send_return_list = [srv_1_defer, srv_2_defer] - def do_send(*args, **kwargs): + def do_send(*args: object, **kwargs: object) -> "defer.Deferred[EventBase]": return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) @@ -297,12 +325,12 @@ def do_send(*args, **kwargs): ) self.assertEqual(3, self.txn_ctrl.send.call_count) - def test_send_large_txns(self): - srv_1_defer = defer.Deferred() - srv_2_defer = defer.Deferred() + def test_send_large_txns(self) -> None: + srv_1_defer: "defer.Deferred[EventBase]" = defer.Deferred() + srv_2_defer: "defer.Deferred[EventBase]" = defer.Deferred() send_return_list = [srv_1_defer, srv_2_defer] - def do_send(*args, **kwargs): + def do_send(*args: object, **kwargs: object) -> "defer.Deferred[EventBase]": return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) @@ -328,7 +356,7 @@ def do_send(*args, **kwargs): ) self.assertEqual(3, self.txn_ctrl.send.call_count) - def test_send_single_ephemeral_no_queue(self): + def test_send_single_ephemeral_no_queue(self) -> None: # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event")] @@ -337,7 +365,7 @@ def test_send_single_ephemeral_no_queue(self): service, [], event_list, [], None, None, DeviceListUpdates() ) - def test_send_multiple_ephemeral_no_queue(self): + def test_send_multiple_ephemeral_no_queue(self) -> None: # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] @@ -346,8 +374,8 @@ def test_send_multiple_ephemeral_no_queue(self): service, [], event_list, [], None, None, DeviceListUpdates() ) - def test_send_single_ephemeral_with_queue(self): - d = defer.Deferred() + def test_send_single_ephemeral_with_queue(self) -> None: + d: TxnCtrlArgs = defer.Deferred() self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) service = Mock(id=4) event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")] @@ -377,8 +405,8 @@ def test_send_single_ephemeral_with_queue(self): ) self.assertEqual(2, self.txn_ctrl.send.call_count) - def test_send_large_txns_ephemeral(self): - d = defer.Deferred() + def test_send_large_txns_ephemeral(self) -> None: + d: TxnCtrlArgs = defer.Deferred() self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) # Expect the event to be sent immediately. 
service = Mock(id=4, name="service") diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index b703e4472e9e..a9893def7422 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -16,6 +16,8 @@ import attr +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EduTypes from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router from synapse.federation.units import Transaction @@ -23,11 +25,13 @@ from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client import login, presence, room +from synapse.server import HomeServer from synapse.types import JsonDict, StreamToken, create_requester +from synapse.util import Clock from tests.handlers.test_sync import generate_sync_config from tests.test_utils import simple_async_mock -from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config +from tests.unittest import FederatingHomeserverTestCase, override_config @attr.s @@ -49,9 +53,7 @@ async def get_users_for_states( } return users_to_state - async def get_interested_users( - self, user_id: str - ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + async def get_interested_users(self, user_id: str) -> Union[Set[str], str]: if user_id in self._config.users_who_should_receive_all_presence: return PresenceRouter.ALL_USERS @@ -71,9 +73,14 @@ def parse_config(config_dict: dict) -> PresenceRouterTestConfig: # Initialise a typed config object config = PresenceRouterTestConfig() - config.users_who_should_receive_all_presence = config_dict.get( + users_who_should_receive_all_presence = config_dict.get( "users_who_should_receive_all_presence" ) + assert isinstance(users_who_should_receive_all_presence, list) + + config.users_who_should_receive_all_presence = ( + users_who_should_receive_all_presence + ) return config @@ -96,9 +103,7 @@ async def get_users_for_states( } return users_to_state - async def get_interested_users( - self, user_id: str - ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + async def get_interested_users(self, user_id: str) -> Union[Set[str], str]: if user_id in self._config.users_who_should_receive_all_presence: return PresenceRouter.ALL_USERS @@ -118,9 +123,14 @@ def parse_config(config_dict: dict) -> PresenceRouterTestConfig: # Initialise a typed config object config = PresenceRouterTestConfig() - config.users_who_should_receive_all_presence = config_dict.get( + users_who_should_receive_all_presence = config_dict.get( "users_who_should_receive_all_presence" ) + assert isinstance(users_who_should_receive_all_presence, list) + + config.users_who_should_receive_all_presence = ( + users_who_should_receive_all_presence + ) return config @@ -140,7 +150,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): presence.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. 
fed_transport_client = Mock(spec=["send_transaction"]) fed_transport_client.send_transaction = simple_async_mock({}) @@ -153,7 +163,9 @@ def make_homeserver(self, reactor, clock): return hs - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.sync_handler = self.hs.get_sync_handler() self.module_api = homeserver.get_module_api() @@ -176,7 +188,7 @@ def default_config(self) -> JsonDict: }, } ) - def test_receiving_all_presence_legacy(self): + def test_receiving_all_presence_legacy(self) -> None: self.receiving_all_presence_test_body() @override_config( @@ -193,10 +205,10 @@ def test_receiving_all_presence_legacy(self): ], } ) - def test_receiving_all_presence(self): + def test_receiving_all_presence(self) -> None: self.receiving_all_presence_test_body() - def receiving_all_presence_test_body(self): + def receiving_all_presence_test_body(self) -> None: """Test that a user that does not share a room with another other can receive presence for them, due to presence routing. """ @@ -302,7 +314,7 @@ def receiving_all_presence_test_body(self): }, } ) - def test_send_local_online_presence_to_with_module_legacy(self): + def test_send_local_online_presence_to_with_module_legacy(self) -> None: self.send_local_online_presence_to_with_module_test_body() @override_config( @@ -321,10 +333,10 @@ def test_send_local_online_presence_to_with_module_legacy(self): ], } ) - def test_send_local_online_presence_to_with_module(self): + def test_send_local_online_presence_to_with_module(self) -> None: self.send_local_online_presence_to_with_module_test_body() - def send_local_online_presence_to_with_module_test_body(self): + def send_local_online_presence_to_with_module_test_body(self) -> None: """Tests that send_local_presence_to_users sends local online presence to a set of specified local and remote users, with a custom PresenceRouter module enabled. 
""" @@ -447,18 +459,18 @@ def send_local_online_presence_to_with_module_test_body(self): continue # EDUs can contain multiple presence updates - for presence_update in edu["content"]["push"]: + for presence_edu in edu["content"]["push"]: # Check for presence updates that contain the user IDs we're after - found_users.add(presence_update["user_id"]) + found_users.add(presence_edu["user_id"]) # Ensure that no offline states are being sent out - self.assertNotEqual(presence_update["presence"], "offline") + self.assertNotEqual(presence_edu["presence"], "offline") self.assertEqual(found_users, expected_users) def send_presence_update( - testcase: TestCase, + testcase: FederatingHomeserverTestCase, user_id: str, access_token: str, presence_state: str, @@ -479,7 +491,7 @@ def send_presence_update( def sync_presence( - testcase: TestCase, + testcase: FederatingHomeserverTestCase, user_id: str, since_token: Optional[StreamToken] = None, ) -> Tuple[List[UserPresenceState], StreamToken]: @@ -500,7 +512,7 @@ def sync_presence( requester = create_requester(user_id) sync_config = generate_sync_config(requester.user.to_string()) sync_result = testcase.get_success( - testcase.sync_handler.wait_for_sync_for_user( + testcase.hs.get_sync_handler().wait_for_sync_for_user( requester, sync_config, since_token ) ) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 8ddce83b830d..6687c28e8fea 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -12,9 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.test_utils.event_injection import create_event @@ -27,7 +32,7 @@ class TestEventContext(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() @@ -35,7 +40,7 @@ def prepare(self, reactor, clock, hs): self.user_tok = self.login("u1", "pass") self.room_id = self.helper.create_room_as(tok=self.user_tok) - def test_serialize_deserialize_msg(self): + def test_serialize_deserialize_msg(self) -> None: """Test that an EventContext for a message event is the same after serialize/deserialize. """ @@ -51,7 +56,7 @@ def test_serialize_deserialize_msg(self): self._check_serialize_deserialize(event, context) - def test_serialize_deserialize_state_no_prev(self): + def test_serialize_deserialize_state_no_prev(self) -> None: """Test that an EventContext for a state event (with not previous entry) is the same after serialize/deserialize. """ @@ -67,7 +72,7 @@ def test_serialize_deserialize_state_no_prev(self): self._check_serialize_deserialize(event, context) - def test_serialize_deserialize_state_prev(self): + def test_serialize_deserialize_state_prev(self) -> None: """Test that an EventContext for a state event (which replaces a previous entry) is the same after serialize/deserialize. 
""" @@ -84,7 +89,9 @@ def test_serialize_deserialize_state_prev(self): self._check_serialize_deserialize(event, context) - def _check_serialize_deserialize(self, event, context): + def _check_serialize_deserialize( + self, event: EventBase, context: EventContext + ) -> None: serialized = self.get_success(context.serialize(event, self.store)) d_context = EventContext.deserialize(self._storage_controllers, serialized) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index a79256846fb3..4174a237ec84 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -13,25 +13,30 @@ # limitations under the License. import unittest as stdlib_unittest +from typing import Any, List, Mapping, Optional from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import ( + PowerLevelsContent, SerializeEventConfig, copy_and_fixup_power_levels_contents, maybe_upsert_event_field, prune_event, serialize_event, ) +from synapse.types import JsonDict from synapse.util.frozenutils import freeze -def MockEvent(**kwargs): +def MockEvent(**kwargs: Any) -> EventBase: if "event_id" not in kwargs: kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" + if "content" not in kwargs: + kwargs["content"] = {} return make_event_from_dict(kwargs) @@ -60,7 +65,7 @@ def test_update_not_okay_leaves_original_value(self) -> None: class PruneEventTestCase(stdlib_unittest.TestCase): - def run_test(self, evdict, matchdict, **kwargs): + def run_test(self, evdict: JsonDict, matchdict: JsonDict, **kwargs: Any) -> None: """ Asserts that a new event constructed with `evdict` will look like `matchdict` when it is redacted. 
@@ -74,7 +79,7 @@ def run_test(self, evdict, matchdict, **kwargs): prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict ) - def test_minimal(self): + def test_minimal(self) -> None: self.run_test( {"type": "A", "event_id": "$test:domain"}, { @@ -86,7 +91,7 @@ def test_minimal(self): }, ) - def test_basic_keys(self): + def test_basic_keys(self) -> None: """Ensure that the keys that should be untouched are kept.""" # Note that some of the values below don't really make sense, but the # pruning of events doesn't worry about the values of any fields (with @@ -138,7 +143,7 @@ def test_basic_keys(self): room_version=RoomVersions.MSC2176, ) - def test_unsigned(self): + def test_unsigned(self) -> None: """Ensure that unsigned properties get stripped (except age_ts and replaces_state).""" self.run_test( { @@ -159,7 +164,7 @@ def test_unsigned(self): }, ) - def test_content(self): + def test_content(self) -> None: """The content dictionary should be stripped in most cases.""" self.run_test( {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}}, @@ -194,7 +199,7 @@ def test_content(self): }, ) - def test_create(self): + def test_create(self) -> None: """Create events are partially redacted until MSC2176.""" self.run_test( { @@ -223,7 +228,7 @@ def test_create(self): room_version=RoomVersions.MSC2176, ) - def test_power_levels(self): + def test_power_levels(self) -> None: """Power level events keep a variety of content keys.""" self.run_test( { @@ -273,7 +278,7 @@ def test_power_levels(self): room_version=RoomVersions.MSC2176, ) - def test_alias_event(self): + def test_alias_event(self) -> None: """Alias events have special behavior up through room version 6.""" self.run_test( { @@ -302,7 +307,7 @@ def test_alias_event(self): room_version=RoomVersions.V6, ) - def test_redacts(self): + def test_redacts(self) -> None: """Redaction events have no special behaviour until MSC2174/MSC2176.""" self.run_test( @@ -328,7 +333,7 @@ def test_redacts(self): room_version=RoomVersions.MSC2176, ) - def test_join_rules(self): + def test_join_rules(self) -> None: """Join rules events have changed behavior starting with MSC3083.""" self.run_test( { @@ -371,7 +376,7 @@ def test_join_rules(self): room_version=RoomVersions.V8, ) - def test_member(self): + def test_member(self) -> None: """Member events have changed behavior starting with MSC3375.""" self.run_test( { @@ -417,12 +422,12 @@ def test_member(self): class SerializeEventTestCase(stdlib_unittest.TestCase): - def serialize(self, ev, fields): + def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict: return serialize_event( ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) ) - def test_event_fields_works_with_keys(self): + def test_event_fields_works_with_keys(self) -> None: self.assertEqual( self.serialize( MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"] @@ -430,7 +435,7 @@ def test_event_fields_works_with_keys(self): {"room_id": "!foo:bar"}, ) - def test_event_fields_works_with_nested_keys(self): + def test_event_fields_works_with_nested_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -443,7 +448,7 @@ def test_event_fields_works_with_nested_keys(self): {"content": {"body": "A message"}}, ) - def test_event_fields_works_with_dot_keys(self): + def test_event_fields_works_with_dot_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -456,7 +461,7 @@ def test_event_fields_works_with_dot_keys(self): {"content": {"key.with.dots": 
{}}}, ) - def test_event_fields_works_with_nested_dot_keys(self): + def test_event_fields_works_with_nested_dot_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -472,7 +477,7 @@ def test_event_fields_works_with_nested_dot_keys(self): {"content": {"nested.dot.key": {"leaf.key": 42}}}, ) - def test_event_fields_nops_with_unknown_keys(self): + def test_event_fields_nops_with_unknown_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -485,7 +490,7 @@ def test_event_fields_nops_with_unknown_keys(self): {"content": {"foo": "bar"}}, ) - def test_event_fields_nops_with_non_dict_keys(self): + def test_event_fields_nops_with_non_dict_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -498,7 +503,7 @@ def test_event_fields_nops_with_non_dict_keys(self): {}, ) - def test_event_fields_nops_with_array_keys(self): + def test_event_fields_nops_with_array_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -511,7 +516,7 @@ def test_event_fields_nops_with_array_keys(self): {}, ) - def test_event_fields_all_fields_if_empty(self): + def test_event_fields_all_fields_if_empty(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -531,16 +536,16 @@ def test_event_fields_all_fields_if_empty(self): }, ) - def test_event_fields_fail_if_fields_not_str(self): + def test_event_fields_fail_if_fields_not_str(self) -> None: with self.assertRaises(TypeError): self.serialize( - MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] + MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] # type: ignore[list-item] ) class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): def setUp(self) -> None: - self.test_content = { + self.test_content: PowerLevelsContent = { "ban": 50, "events": {"m.room.name": 100, "m.room.power_levels": 100}, "events_default": 0, @@ -553,10 +558,11 @@ def setUp(self) -> None: "users_default": 0, } - def _test(self, input): + def _test(self, input: PowerLevelsContent) -> None: a = copy_and_fixup_power_levels_contents(input) self.assertEqual(a["ban"], 50) + assert isinstance(a["events"], Mapping) self.assertEqual(a["events"]["m.room.name"], 100) # make sure that changing the copy changes the copy and not the orig @@ -564,18 +570,19 @@ def _test(self, input): a["events"]["m.room.power_levels"] = 20 self.assertEqual(input["ban"], 50) + assert isinstance(input["events"], Mapping) self.assertEqual(input["events"]["m.room.power_levels"], 100) - def test_unfrozen(self): + def test_unfrozen(self) -> None: self._test(self.test_content) - def test_frozen(self): + def test_frozen(self) -> None: input = freeze(self.test_content) self._test(input) - def test_stringy_integers(self): + def test_stringy_integers(self) -> None: """String representations of decimal integers are converted to integers.""" - input = { + input: PowerLevelsContent = { "a": "100", "b": { "foo": 99, @@ -603,9 +610,9 @@ def test_strings_that_dont_represent_decimal_integers(self) -> None: def test_invalid_types_raise_type_error(self) -> None: with self.assertRaises(TypeError): - copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]}) # type: ignore[arg-type] - copy_and_fixup_power_levels_contents({"a": None}) # type: ignore[arg-type] + copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]}) # type: ignore[dict-item] + copy_and_fixup_power_levels_contents({"a": None}) # type: ignore[dict-item] def test_invalid_nesting_raises_type_error(self) -> None: with self.assertRaises(TypeError): - 
copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) + copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) # type: ignore[dict-item] diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 9f1115dd23b8..d667dd27bf40 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -17,7 +17,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.rest import admin from synapse.rest.client import login, room -from synapse.types import UserID +from synapse.types import JsonDict, UserID from tests import unittest from tests.test_utils import make_awaitable @@ -31,12 +31,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): login.register_servlets, ] - def default_config(self): + def default_config(self) -> JsonDict: config = super().default_config() config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} return config - def test_complexity_simple(self): + def test_complexity_simple(self) -> None: u1 = self.register_user("u1", "pass") u1_token = self.login("u1", "pass") @@ -66,7 +66,7 @@ def test_complexity_simple(self): complexity = channel.json_body["v1"] self.assertEqual(complexity, 1.23) - def test_join_too_large(self): + def test_join_too_large(self) -> None: u1 = self.register_user("u1", "pass") @@ -95,7 +95,7 @@ def test_join_too_large(self): self.assertEqual(f.value.code, 400, f.value) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - def test_join_too_large_admin(self): + def test_join_too_large_admin(self) -> None: # Check whether an admin can join if option "admins_can_join" is undefined, # this option defaults to false, so the join should fail. @@ -126,7 +126,7 @@ def test_join_too_large_admin(self): self.assertEqual(f.value.code, 400, f.value) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - def test_join_too_large_once_joined(self): + def test_join_too_large_once_joined(self) -> None: u1 = self.register_user("u1", "pass") u1_token = self.login("u1", "pass") @@ -180,7 +180,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): login.register_servlets, ] - def default_config(self): + def default_config(self) -> JsonDict: config = super().default_config() config["limit_remote_rooms"] = { "enabled": True, @@ -189,7 +189,7 @@ def default_config(self): } return config - def test_join_too_large_no_admin(self): + def test_join_too_large_no_admin(self) -> None: # A user which is not an admin should not be able to join a remote room # which is too complex. @@ -220,7 +220,7 @@ def test_join_too_large_no_admin(self): self.assertEqual(f.value.code, 400, f.value) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - def test_join_too_large_admin(self): + def test_join_too_large_admin(self) -> None: # An admin should be able to join rooms where a complexity check fails. 
u1 = self.register_user("u1", "pass", admin=True) diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index b8fee7289838..a986b15f0a87 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -1,13 +1,17 @@ -from typing import List, Tuple +from typing import Callable, List, Optional, Tuple from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.federation.sender import PerDestinationQueue, TransactionManager -from synapse.federation.units import Edu +from synapse.federation.units import Edu, Transaction from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.types import JsonDict +from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination from tests.test_utils import event_injection, make_awaitable @@ -28,23 +32,25 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver( federation_transport_client=Mock(spec=["send_transaction"]), ) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # stub out get_current_hosts_in_room - state_handler = hs.get_state_handler() + state_storage_controller = hs.get_storage_controllers().state # This mock is crucial for destination_rooms to be populated. - state_handler.get_current_hosts_in_room = Mock( - return_value=make_awaitable(["test", "host2"]) + # TODO: this seems to no longer be the case---tests pass with this mock + # commented out. + state_storage_controller.get_current_hosts_in_room = Mock( # type: ignore[assignment] + return_value=make_awaitable({"test", "host2"}) ) # whenever send_transaction is called, record the pdu data - self.pdus = [] - self.failed_pdus = [] + self.pdus: List[JsonDict] = [] + self.failed_pdus: List[JsonDict] = [] self.is_online = True self.hs.get_federation_transport_client().send_transaction.side_effect = ( self.record_transaction @@ -55,8 +61,13 @@ def default_config(self) -> JsonDict: config["federation_sender_instances"] = None return config - async def record_transaction(self, txn, json_cb): - if self.is_online: + async def record_transaction( + self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] + ) -> JsonDict: + if json_cb is None: + # The tests seem to expect that this method raises in this situation. + raise Exception("Blank json_cb") + elif self.is_online: data = json_cb() self.pdus.extend(data["pdus"]) return {} @@ -92,7 +103,7 @@ def get_destination_room(self, room: str, destination: str = "host2") -> dict: )[0] return {"event_id": event_id, "stream_ordering": stream_ordering} - def test_catch_up_destination_rooms_tracking(self): + def test_catch_up_destination_rooms_tracking(self) -> None: """ Tests that we populate the `destination_rooms` table as needed. 
""" @@ -117,7 +128,7 @@ def test_catch_up_destination_rooms_tracking(self): self.assertEqual(row_2["event_id"], event_id_2) self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1) - def test_catch_up_last_successful_stream_ordering_tracking(self): + def test_catch_up_last_successful_stream_ordering_tracking(self) -> None: """ Tests that we populate the `destination_rooms` table as needed. """ @@ -174,7 +185,7 @@ def test_catch_up_last_successful_stream_ordering_tracking(self): "Send succeeded but not marked as last_successful_stream_ordering", ) - def test_catch_up_from_blank_state(self): + def test_catch_up_from_blank_state(self) -> None: """ Runs an overall test of federation catch-up from scratch. Further tests will focus on more narrow aspects and edge-cases, but I @@ -261,16 +272,15 @@ async def fake_send( destination_tm: str, pending_pdus: List[EventBase], _pending_edus: List[Edu], - ) -> bool: + ) -> None: assert destination == destination_tm results_list.extend(pending_pdus) - return True # success! - transaction_manager.send_new_transaction = fake_send + transaction_manager.send_new_transaction = fake_send # type: ignore[assignment] return per_dest_queue, results_list - def test_catch_up_loop(self): + def test_catch_up_loop(self) -> None: """ Tests the behaviour of _catch_up_transmission_loop. """ @@ -334,7 +344,7 @@ def test_catch_up_loop(self): event_5.internal_metadata.stream_ordering, ) - def test_catch_up_on_synapse_startup(self): + def test_catch_up_on_synapse_startup(self) -> None: """ Tests the behaviour of get_catch_up_outstanding_destinations and _wake_destinations_needing_catchup. @@ -412,7 +422,7 @@ def test_catch_up_on_synapse_startup(self): # patch wake_destination to just count the destinations instead woken = [] - def wake_destination_track(destination): + def wake_destination_track(destination: str) -> None: woken.append(destination) self.hs.get_federation_sender().wake_destination = wake_destination_track @@ -432,7 +442,7 @@ def wake_destination_track(destination): # - all destinations are woken exactly once; they appear once in woken. self.assertCountEqual(woken, server_names[:-1]) - def test_not_latest_event(self): + def test_not_latest_event(self) -> None: """Test that we send the latest event in the room even if its not ours.""" per_dest_queue, sent_pdus = self.make_fake_destination_queue() diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index e67f4058260f..86e1236501fd 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -36,7 +36,9 @@ class FederationClientTest(FederatingHomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: super().prepare(reactor, clock, homeserver) # mock out the Agent used by the federation client, which is easier than @@ -51,7 +53,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): self.creator = f"@creator:{self.OTHER_SERVER_NAME}" self.test_room_id = "!room_id" - def test_get_room_state(self): + def test_get_room_state(self) -> None: # mock up some events to use in the response. # In real life, these would have things in `prev_events` and `auth_events`, but that's # a bit annoying to mock up, and the code under test doesn't care, so we don't bother. 
@@ -140,7 +142,7 @@ def test_get_room_state(self):
             ["m.room.create", "m.room.member", "m.room.power_levels"],
         )
-    def test_get_pdu_returns_nothing_when_event_does_not_exist(self):
+    def test_get_pdu_returns_nothing_when_event_does_not_exist(self) -> None:
         """No event should be returned when the event does not exist"""
         pulled_pdu_info = self.get_success(
             self.hs.get_federation_client().get_pdu(
@@ -151,11 +153,11 @@ def test_get_pdu_returns_nothing_when_event_does_not_exist(self):
         )
         self.assertEqual(pulled_pdu_info, None)
-    def test_get_pdu(self):
+    def test_get_pdu(self) -> None:
         """Test to make sure an event is returned by `get_pdu()`"""
         self._get_pdu_once()
-    def test_get_pdu_event_from_cache_is_pristine(self):
+    def test_get_pdu_event_from_cache_is_pristine(self) -> None:
         """Test that modifications made to events returned by `get_pdu()` do not
         propagate back to the internal cache (events returned should be a copy).
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 8692d8190f66..ddeffe1ad53b 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -11,18 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Callable, FrozenSet, List, Optional, Set
 from unittest.mock import Mock
 from signedjson import key, sign
 from signedjson.types import BaseKey, SigningKey
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
+from synapse.federation.units import Transaction
 from synapse.rest import admin
 from synapse.rest.client import login
+from synapse.server import HomeServer
 from synapse.types import JsonDict, ReadReceipt
+from synapse.util import Clock
 from tests.test_utils import make_awaitable
 from tests.unittest import HomeserverTestCase
@@ -36,16 +40,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
     re-enabled for the main process.
     """
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver(
             federation_transport_client=Mock(spec=["send_transaction"]),
         )
-        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
+        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
             return_value=make_awaitable({"test", "host2"})
         )
-        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (
+        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (  # type: ignore[assignment]
             hs.get_storage_controllers().state.get_current_hosts_in_room
         )
@@ -56,7 +60,7 @@ def default_config(self) -> JsonDict:
         config["federation_sender_instances"] = None
         return config
-    def test_send_receipts(self):
+    def test_send_receipts(self) -> None:
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
@@ -98,7 +102,7 @@ def test_send_receipts(self):
             ],
         )
-    def test_send_receipts_thread(self):
+    def test_send_receipts_thread(self) -> None:
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
@@ -174,7 +178,7 @@ def test_send_receipts_thread(self):
             ],
         )
-    def test_send_receipts_with_backoff(self):
+    def test_send_receipts_with_backoff(self) -> None:
         """Send two receipts in quick succession; the second should be flushed, but
         only after 20ms"""
         mock_send_transaction = (
@@ -272,51 +276,55 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         login.register_servlets,
     ]
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         return self.setup_test_homeserver(
             federation_transport_client=Mock(
                 spec=["send_transaction", "query_user_devices"]
             ),
         )
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         c = super().default_config()
         # Enable federation sending on the main process.
         c["federation_sender_instances"] = None
         return c
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         test_room_id = "!room:host1"
         # stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the
         # server thinks the user shares a room with `@user2:host2`
-        def get_rooms_for_user(user_id):
-            return defer.succeed({test_room_id})
+        def get_rooms_for_user(user_id: str) -> "defer.Deferred[FrozenSet[str]]":
+            return defer.succeed(frozenset({test_room_id}))
-        hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
+        hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user  # type: ignore[assignment]
-        async def get_current_hosts_in_room(room_id):
+        async def get_current_hosts_in_room(room_id: str) -> Set[str]:
             if room_id == test_room_id:
-                return ["host2"]
-
-            # TODO: We should fail the test when we encounter an unexpected room ID.
-            # We can't just use `self.fail(...)` here because the app code is greedy
-            # with `Exception` and will catch it before the test can see it.
+                return {"host2"}
+            else:
+                # TODO: We should fail the test when we encounter an unexpected room ID.
+                # We can't just use `self.fail(...)` here because the app code is greedy
+                # with `Exception` and will catch it before the test can see it.
+ return set() - hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room + hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment] # whenever send_transaction is called, record the edu data - self.edus = [] + self.edus: List[JsonDict] = [] self.hs.get_federation_transport_client().send_transaction.side_effect = ( self.record_transaction ) - def record_transaction(self, txn, json_cb): + def record_transaction( + self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None + ) -> "defer.Deferred[JsonDict]": + assert json_cb is not None data = json_cb() self.edus.extend(data["edus"]) return defer.succeed({}) - def test_send_device_updates(self): + def test_send_device_updates(self) -> None: """Basic case: each device update should result in an EDU""" # create a device u1 = self.register_user("user", "pass") @@ -340,7 +348,7 @@ def test_send_device_updates(self): self.assertEqual(len(self.edus), 1) self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) - def test_dont_send_device_updates_for_remote_users(self): + def test_dont_send_device_updates_for_remote_users(self) -> None: """Check that we don't send device updates for remote users""" # Send the server a device list EDU for the other user, this will cause @@ -379,7 +387,7 @@ def test_dont_send_device_updates_for_remote_users(self): ) self.assertIn("D1", devices) - def test_upload_signatures(self): + def test_upload_signatures(self) -> None: """Uploading signatures on some devices should produce updates for that user""" e2e_handler = self.hs.get_e2e_keys_handler() @@ -391,7 +399,7 @@ def test_upload_signatures(self): # expect two edus self.assertEqual(len(self.edus), 2) - stream_id = None + stream_id: Optional[int] = None stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) @@ -473,13 +481,13 @@ def test_upload_signatures(self): self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] if stream_id is not None: - self.assertEqual(c["prev_id"], [stream_id]) + self.assertEqual(c["prev_id"], [stream_id]) # type: ignore[unreachable] self.assertGreaterEqual(c["stream_id"], stream_id) stream_id = c["stream_id"] devices = {edu["content"]["device_id"] for edu in self.edus} self.assertEqual({"D1", "D2"}, devices) - def test_delete_devices(self): + def test_delete_devices(self) -> None: """If devices are deleted, that should result in EDUs too""" # create devices @@ -521,7 +529,7 @@ def test_delete_devices(self): devices = {edu["content"]["device_id"] for edu in self.edus} self.assertEqual({"D1", "D2", "D3"}, devices) - def test_unreachable_server(self): + def test_unreachable_server(self) -> None: """If the destination server is unreachable, all the updates should get sent on recovery """ @@ -555,7 +563,7 @@ def test_unreachable_server(self): # for each device, there should be a single update self.assertEqual(len(self.edus), 3) - stream_id = None + stream_id: Optional[int] = None for edu in self.edus: self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] @@ -566,7 +574,7 @@ def test_unreachable_server(self): devices = {edu["content"]["device_id"] for edu in self.edus} self.assertEqual({"D1", "D2", "D3"}, devices) - def test_prune_outbound_device_pokes1(self): + def test_prune_outbound_device_pokes1(self) -> None: """If a destination is unreachable, and the updates are pruned, we should get a 
single update. @@ -615,7 +623,7 @@ def test_prune_outbound_device_pokes1(self): # synapse uses an empty prev_id list to indicate "needs a full resync". self.assertEqual(c["prev_id"], []) - def test_prune_outbound_device_pokes2(self): + def test_prune_outbound_device_pokes2(self) -> None: """If a destination is unreachable, and the updates are pruned, we should get a single update. @@ -741,7 +749,7 @@ def encode_pubkey(sk: SigningKey) -> str: return key.encode_verify_key_base64(key.get_verify_key(sk)) -def build_device_dict(user_id: str, device_id: str, sk: SigningKey): +def build_device_dict(user_id: str, device_id: str, sk: SigningKey) -> JsonDict: """Build a dict representing the given device""" return { "user_id": user_id, diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 177e5b5afce3..bba6469b559b 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -21,7 +21,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin from synapse.rest.client import login, room @@ -42,7 +42,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): ] @parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)]) - def test_bad_request(self, query_content): + def test_bad_request(self, query_content: bytes) -> None: """ Querying with bad data returns a reasonable error code. """ @@ -64,7 +64,7 @@ def test_bad_request(self, query_content): class ServerACLsTestCase(unittest.TestCase): - def test_blacklisted_server(self): + def test_blacklisted_server(self) -> None: e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]}) logging.info("ACL event: %s", e.content) @@ -74,7 +74,7 @@ def test_blacklisted_server(self): self.assertTrue(server_matches_acl_event("evil.com.au", e)) self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e)) - def test_block_ip_literals(self): + def test_block_ip_literals(self) -> None: e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]}) logging.info("ACL event: %s", e.content) @@ -83,7 +83,7 @@ def test_block_ip_literals(self): self.assertFalse(server_matches_acl_event("[1:2::]", e)) self.assertTrue(server_matches_acl_event("1:2:3:4", e)) - def test_wildcard_matching(self): + def test_wildcard_matching(self) -> None: e = _create_acl_event({"allow": ["good*.com"]}) self.assertTrue( server_matches_acl_event("good.com", e), @@ -110,7 +110,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): login.register_servlets, ] - def test_needs_to_be_in_room(self): + def test_needs_to_be_in_room(self) -> None: """/v1/state/ requires the server to be in the room""" u1 = self.register_user("u1", "pass") u1_token = self.login("u1", "pass") @@ -131,7 +131,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) self._storage_controllers = hs.get_storage_controllers() @@ -157,7 +157,7 @@ def _make_join(self, user_id: str) -> JsonDict: self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) return channel.json_body - 
def test_send_join(self): + def test_send_join(self) -> None: """happy-path test of send_join""" joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME join_result = self._make_join(joining_user) @@ -211,9 +211,8 @@ def test_send_join(self): ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") - @override_config({"experimental_features": {"msc3706_enabled": True}}) def test_send_join_partial_state(self) -> None: - """When MSC3706 support is enabled, /send_join should return partial state""" + """/send_join should return partial state, if requested""" joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME join_result = self._make_join(joining_user) @@ -224,7 +223,7 @@ def test_send_join_partial_state(self) -> None: ) channel = self.make_signed_federation_request( "PUT", - f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", + f"/_matrix/federation/v2/send_join/{self._room_id}/x?omit_members=true", content=join_event_dict, ) self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) @@ -325,7 +324,7 @@ def test_send_join_contributes_to_room_join_rate_limit_and_is_limited(self) -> N # is probably sufficient to reassure that the bucket is updated. -def _create_acl_event(content): +def _create_acl_event(content: JsonDict) -> EventBase: return make_event_from_dict( { "room_id": "!a:b", diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py index e88e5d8bb355..55655de8628e 100644 --- a/tests/federation/transport/server/test__base.py +++ b/tests/federation/transport/server/test__base.py @@ -15,6 +15,8 @@ from http import HTTPStatus from typing import Dict, List, Tuple +from twisted.web.resource import Resource + from synapse.api.errors import Codes from synapse.federation.transport.server import BaseFederationServlet from synapse.federation.transport.server._base import Authenticator, _parse_auth_header @@ -62,7 +64,7 @@ class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCa path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}" - def create_test_resource(self): + def create_test_resource(self) -> Resource: """Overrides `HomeserverTestCase.create_test_resource`.""" resource = JsonResource(self.hs) diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index b84c74fc0eca..3d61b1e8a9ca 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -13,12 +13,14 @@ # limitations under the License. 
import json +from typing import List, Optional from unittest.mock import Mock import ijson.common from synapse.api.room_versions import RoomVersions from synapse.federation.transport.client import SendJoinParser +from synapse.types import JsonDict from synapse.util import ExceptionBundle from tests.unittest import TestCase @@ -66,38 +68,73 @@ def test_two_writes(self) -> None: self.assertEqual(len(parsed_response.state), 1, parsed_response) self.assertEqual(parsed_response.event_dict, {}, parsed_response) self.assertIsNone(parsed_response.event, parsed_response) - self.assertFalse(parsed_response.partial_state, parsed_response) + self.assertFalse(parsed_response.members_omitted, parsed_response) self.assertEqual(parsed_response.servers_in_room, None, parsed_response) def test_partial_state(self) -> None: - """Check that the partial_state flag is correctly parsed""" - parser = SendJoinParser(RoomVersions.V1, False) - response = { - "org.matrix.msc3706.partial_state": True, - } + """Check that the members_omitted flag is correctly parsed""" - serialised_response = json.dumps(response).encode() + def parse(response: JsonDict) -> bool: + parser = SendJoinParser(RoomVersions.V1, False) + serialised_response = json.dumps(response).encode() - # Send data to the parser - parser.write(serialised_response) + # Send data to the parser + parser.write(serialised_response) - # Retrieve and check the parsed SendJoinResponse - parsed_response = parser.finish() - self.assertTrue(parsed_response.partial_state) + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + return parsed_response.members_omitted - def test_servers_in_room(self) -> None: - """Check that the servers_in_room field is correctly parsed""" - parser = SendJoinParser(RoomVersions.V1, False) - response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]} + self.assertTrue(parse({"members_omitted": True})) + self.assertTrue(parse({"org.matrix.msc3706.partial_state": True})) - serialised_response = json.dumps(response).encode() + self.assertFalse(parse({"members_omitted": False})) + self.assertFalse(parse({"org.matrix.msc3706.partial_state": False})) - # Send data to the parser - parser.write(serialised_response) + # If there's a conflict, the stable field wins. 
+ self.assertTrue( + parse({"members_omitted": True, "org.matrix.msc3706.partial_state": False}) + ) + self.assertFalse( + parse({"members_omitted": False, "org.matrix.msc3706.partial_state": True}) + ) - # Retrieve and check the parsed SendJoinResponse - parsed_response = parser.finish() - self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"]) + def test_servers_in_room(self) -> None: + """Check that the servers_in_room field is correctly parsed""" + + def parse(response: JsonDict) -> Optional[List[str]]: + parser = SendJoinParser(RoomVersions.V1, False) + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response) + + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + return parsed_response.servers_in_room + + self.assertEqual( + parse({"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}), + ["hs1", "hs2"], + ) + self.assertEqual(parse({"servers_in_room": ["example.com"]}), ["example.com"]) + + # If both are provided, the stable identifier should win + self.assertEqual( + parse( + { + "org.matrix.msc3706.servers_in_room": ["old"], + "servers_in_room": ["new"], + } + ), + ["new"], + ) + + # And lastly, we should be able to tell if neither field was present. + self.assertEqual( + parse({}), + None, + ) def test_errors_closing_coroutines(self) -> None: """Check we close all coroutines, even if closing the first raises an Exception. diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index d21c11b716cd..70209ab09011 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -12,21 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict -from typing import Dict, List +from typing import Any, Dict, List, Optional + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, JoinRules, Membership -from synapse.api.room_versions import RoomVersions -from synapse.events import builder +from synapse.api.room_versions import RoomVersion, RoomVersions +from synapse.events import EventBase, builder +from synapse.events.snapshot import EventContext from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import RoomAlias +from synapse.util import Clock from tests.test_utils import event_injection -from tests.unittest import FederatingHomeserverTestCase, TestCase +from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase -class KnockingStrippedStateEventHelperMixin(TestCase): +class KnockingStrippedStateEventHelperMixin(HomeserverTestCase): def send_example_state_events_to_room( self, hs: "HomeServer", @@ -49,7 +53,7 @@ def send_example_state_events_to_room( # To set a canonical alias, we'll need to point an alias at the room first. 
canonical_alias = "#fancy_alias:test" self.get_success( - self.store.create_room_alias_association( + self.hs.get_datastores().main.create_room_alias_association( RoomAlias.from_string(canonical_alias), room_id, ["test"] ) ) @@ -197,7 +201,9 @@ class FederationKnockingTestCase( login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.store = homeserver.get_datastores().main # We're not going to be properly signing events as our remote homeserver is fake, @@ -205,23 +211,29 @@ def prepare(self, reactor, clock, homeserver): # Note that these checks are not relevant to this test case. # Have this homeserver auto-approve all event signature checking. - async def approve_all_signature_checking(_, pdu): + async def approve_all_signature_checking( + room_version: RoomVersion, + pdu: EventBase, + record_failure_callback: Any = None, + ) -> EventBase: return pdu - homeserver.get_federation_server()._check_sigs_and_hash = ( + homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[assignment] approve_all_signature_checking ) # Have this homeserver skip event auth checks. This is necessary due to # event auth checks ensuring that events were signed by the sender's homeserver. - async def _check_event_auth(origin, event, context, *args, **kwargs): - return context + async def _check_event_auth( + origin: Optional[str], event: EventBase, context: EventContext + ) -> None: + pass - homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth + homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] return super().prepare(reactor, clock, homeserver) - def test_room_state_returned_when_knocking(self): + def test_room_state_returned_when_knocking(self) -> None: """ Tests that specific, stripped state events from a room are returned after a remote homeserver successfully knocks on a local room. diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index cfd550a04bcc..c4231f4aa972 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -20,7 +20,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): @override_config({"allow_public_rooms_over_federation": False}) - def test_blocked_public_room_list_over_federation(self): + def test_blocked_public_room_list_over_federation(self) -> None: """Test that unauthenticated requests to the public rooms directory 403 when allow_public_rooms_over_federation is False. """ @@ -31,7 +31,7 @@ def test_blocked_public_room_list_over_federation(self): self.assertEqual(403, channel.code) @override_config({"allow_public_rooms_over_federation": True}) - def test_open_public_room_list_over_federation(self): + def test_open_public_room_list_over_federation(self) -> None: """Test that unauthenticated requests to the public rooms directory 200 when allow_public_rooms_over_federation is True. """ @@ -42,7 +42,7 @@ def test_open_public_room_list_over_federation(self): self.assertEqual(200, channel.code) @DEBUG - def test_edu_debugging_doesnt_explode(self): + def test_edu_debugging_doesnt_explode(self) -> None: """Sanity check incoming federation succeeds with `synapse.debug_8631` enabled. Remove this when we strip out issue_8631_logger. 
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index c1579dac610f..6f300b8e1119 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -38,6 +38,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_handler = hs.get_admin_handler() + self._store = hs.get_datastores().main self.user1 = self.register_user("user1", "password") self.token1 = self.login("user1", "password") @@ -236,3 +237,62 @@ def test_knock(self) -> None: self.assertEqual(args[0], room_id) self.assertEqual(args[1].content["membership"], "knock") self.assertTrue(args[2]) # Assert there is at least one bit of state + + def test_profile(self) -> None: + """Tests that user profile get exported.""" + writer = Mock() + + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) + + writer.write_events.assert_not_called() + writer.write_profile.assert_called_once() + + # check only a few values, not all available + args = writer.write_profile.call_args[0] + self.assertEqual(args[0]["name"], self.user2) + self.assertIn("displayname", args[0]) + self.assertIn("avatar_url", args[0]) + self.assertIn("threepids", args[0]) + self.assertIn("external_ids", args[0]) + self.assertIn("creation_ts", args[0]) + + def test_devices(self) -> None: + """Tests that user devices get exported.""" + writer = Mock() + + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) + + writer.write_events.assert_not_called() + writer.write_devices.assert_called_once() + + args = writer.write_devices.call_args[0] + self.assertEqual(len(args[0]), 1) + self.assertEqual(args[0][0]["user_id"], self.user2) + self.assertIn("device_id", args[0][0]) + self.assertIsNone(args[0][0]["display_name"]) + self.assertIsNone(args[0][0]["last_seen_user_agent"]) + self.assertIsNone(args[0][0]["last_seen_ts"]) + self.assertIsNone(args[0][0]["last_seen_ip"]) + + def test_connections(self) -> None: + """Tests that user sessions / connections get exported.""" + # Insert a user IP + self.get_success( + self._store.insert_client_ip( + self.user2, "access_token", "ip", "user_agent", "MY_DEVICE" + ) + ) + + writer = Mock() + + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) + + writer.write_events.assert_not_called() + writer.write_connections.assert_called_once() + + args = writer.write_connections.call_args[0] + self.assertEqual(len(args[0]), 1) + self.assertEqual(args[0][0]["ip"], "ip") + self.assertEqual(args[0][0]["user_agent"], "user_agent") + self.assertGreater(args[0][0]["last_seen"], 0) + self.assertNotIn("access_token", args[0][0]) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index cedbb9fafcfa..57675fa407e4 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import cast +from typing import Collection, Optional, cast from unittest import TestCase from unittest.mock import Mock, patch +from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes @@ -655,7 +656,7 @@ def test_failed_partial_join_is_clean(self) -> None: EVENT_INVITATION_MEMBERSHIP, ], partial_state=True, - servers_in_room=["example.com"], + servers_in_room={"example.com"}, ) ) ) @@ -679,3 +680,112 @@ def test_failed_partial_join_is_clean(self) -> None: f"Stale partial-stated room flag left over for {room_id} after a" f" failed do_invite_join!", ) + + def test_duplicate_partial_state_room_syncs(self) -> None: + """ + Tests that concurrent partial state syncs are not started for the same room. + """ + is_partial_state = True + end_sync: "Deferred[None]" = Deferred() + + async def is_partial_state_room(room_id: str) -> bool: + return is_partial_state + + async def sync_partial_state_room( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, + ) -> None: + nonlocal end_sync + try: + await end_sync + finally: + end_sync = Deferred() + + mock_is_partial_state_room = Mock(side_effect=is_partial_state_room) + mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room) + + fed_handler = self.hs.get_federation_handler() + store = self.hs.get_datastores().main + + with patch.object( + fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room + ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): + # Start the partial state sync. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Try to start another partial state sync. + # Nothing should happen. + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # End the partial state sync + is_partial_state = False + end_sync.callback(None) + + # The partial state sync should not be restarted. + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # The next attempt to start the partial state sync should work. + is_partial_state = True + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + def test_partial_state_room_sync_restart(self) -> None: + """ + Tests that partial state syncs are restarted when a second partial state sync + was deduplicated and the first partial state sync fails. + """ + is_partial_state = True + end_sync: "Deferred[None]" = Deferred() + + async def is_partial_state_room(room_id: str) -> bool: + return is_partial_state + + async def sync_partial_state_room( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, + ) -> None: + nonlocal end_sync + try: + await end_sync + finally: + end_sync = Deferred() + + mock_is_partial_state_room = Mock(side_effect=is_partial_state_room) + mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room) + + fed_handler = self.hs.get_federation_handler() + store = self.hs.get_datastores().main + + with patch.object( + fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room + ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): + # Start the partial state sync. 
+ fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Fail the partial state sync. + # The partial state sync should not be restarted. + end_sync.errback(Exception("Failed to request /state_ids")) + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Start the partial state sync again. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + # Deduplicate another partial state sync. + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + # Fail the partial state sync. + # It should restart with the latest parameters. + end_sync.errback(Exception("Failed to request /state_ids")) + self.assertEqual(mock_sync_partial_state_room.call_count, 3) + mock_sync_partial_state_room.assert_called_with( + initial_destination="hs3", + other_destinations=["hs2"], + room_id="room_id", + ) diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index 6bbfd5dc843f..6a38893b688a 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -171,7 +171,7 @@ def test_remote_joins_contribute_to_rate_limit(self) -> None: state=[create_event], auth_chain=[create_event], partial_state=False, - servers_in_room=[], + servers_in_room=frozenset(), ) ) ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index efbb5a8dbbaf..1fe9563c98f3 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -14,21 +14,22 @@ import json -from typing import Dict +from typing import Dict, List, Set from unittest.mock import ANY, Mock, call -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource from synapse.api.constants import EduTypes from synapse.api.errors import AuthError from synapse.federation.transport.server import TransportLayerServer +from synapse.handlers.typing import TypingWriterHandler from synapse.server import HomeServer from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util import Clock from tests import unittest +from tests.server import ThreadedMemoryReactorClock from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -62,7 +63,11 @@ def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes: class TypingNotificationsTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + def make_homeserver( + self, + reactor: ThreadedMemoryReactorClock, + clock: Clock, + ) -> HomeServer: # we mock out the keyring so as to skip the authentication check on the # federation API call. 
mock_keyring = Mock(spec=["verify_json_for_server"]) @@ -75,8 +80,9 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # the tests assume that we are starting at unix time 1000 reactor.pump((1000,)) + self.mock_hs_notifier = Mock() hs = self.setup_test_homeserver( - notifier=Mock(), + notifier=self.mock_hs_notifier, federation_http_client=mock_federation_client, keyring=mock_keyring, replication_streams={}, @@ -90,32 +96,38 @@ def create_resource_dict(self) -> Dict[str, Resource]: return d def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - mock_notifier = hs.get_notifier() - self.on_new_event = mock_notifier.on_new_event + self.on_new_event = self.mock_hs_notifier.on_new_event - self.handler = hs.get_typing_handler() + # hs.get_typing_handler will return a TypingWriterHandler when calling it + # from the main process, and a FollowerTypingHandler on workers. + # We rely on methods only available on the former, so assert we have the + # correct type here. We have to assign self.handler after the assert, + # otherwise mypy will treat it as a FollowerTypingHandler + handler = hs.get_typing_handler() + assert isinstance(handler, TypingWriterHandler) + self.handler = handler self.event_source = hs.get_event_sources().sources.typing self.datastore = hs.get_datastores().main + self.datastore.get_destination_retry_timings = Mock( return_value=make_awaitable(None) ) - self.datastore.get_device_updates_by_remote = Mock( + self.datastore.get_device_updates_by_remote = Mock( # type: ignore[assignment] return_value=make_awaitable((0, [])) ) - self.datastore.get_destination_last_successful_stream_ordering = Mock( + self.datastore.get_destination_last_successful_stream_ordering = Mock( # type: ignore[assignment] return_value=make_awaitable(None) ) - def get_received_txn_response(*args): - return defer.succeed(None) - - self.datastore.get_received_txn_response = get_received_txn_response + self.datastore.get_received_txn_response = Mock( # type: ignore[assignment] + return_value=make_awaitable(None) + ) - self.room_members = [] + self.room_members: List[UserID] = [] async def check_user_in_room(room_id: str, requester: Requester) -> None: if requester.user.to_string() not in [ @@ -124,47 +136,54 @@ async def check_user_in_room(room_id: str, requester: Requester) -> None: raise AuthError(401, "User is not in the room") return None - hs.get_auth().check_user_in_room = check_user_in_room + hs.get_auth().check_user_in_room = Mock( # type: ignore[assignment] + side_effect=check_user_in_room + ) async def check_host_in_room(room_id: str, server_name: str) -> bool: return room_id == ROOM_ID - hs.get_event_auth_handler().is_host_in_room = check_host_in_room + hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[assignment] + side_effect=check_host_in_room + ) - async def get_current_hosts_in_room(room_id: str): + async def get_current_hosts_in_room(room_id: str) -> Set[str]: return {member.domain for member in self.room_members} - hs.get_storage_controllers().state.get_current_hosts_in_room = ( - get_current_hosts_in_room + hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] + side_effect=get_current_hosts_in_room ) - hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( - get_current_hosts_in_room + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[assignment] + 
side_effect=get_current_hosts_in_room ) - async def get_users_in_room(room_id: str): + async def get_users_in_room(room_id: str) -> Set[str]: return {str(u) for u in self.room_members} - self.datastore.get_users_in_room = get_users_in_room + self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room) - self.datastore.get_user_directory_stream_pos = Mock( + self.datastore.get_user_directory_stream_pos = Mock( # type: ignore[assignment] side_effect=( - # we deliberately return a non-None stream pos to avoid doing an initial_spam + # we deliberately return a non-None stream pos to avoid + # doing an initial_sync lambda: make_awaitable(1) ) ) - self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment] - self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = ( - lambda *args, **kargs: make_awaitable(([], 0)) + self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment] + side_effect=lambda: 0 + ) + self.datastore.get_new_device_msgs_for_remote = Mock( # type: ignore[assignment] + side_effect=lambda *args, **kargs: make_awaitable(([], 0)) ) - self.datastore.delete_device_msgs_for_remote = ( - lambda *args, **kargs: make_awaitable(None) + self.datastore.delete_device_msgs_for_remote = Mock( # type: ignore[assignment] + side_effect=lambda *args, **kargs: make_awaitable(None) ) - self.datastore.set_received_txn_response = ( - lambda *args, **kwargs: make_awaitable(None) + self.datastore.set_received_txn_response = Mock( # type: ignore[assignment] + side_effect=lambda *args, **kwargs: make_awaitable(None) ) def test_started_typing_local(self) -> None: @@ -186,7 +205,7 @@ def test_started_typing_local(self) -> None: self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( - user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False + user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False ) ) self.assertEqual( @@ -257,7 +276,7 @@ def test_started_typing_remote_recv(self) -> None: self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( - user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False + user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False ) ) self.assertEqual( @@ -298,7 +317,7 @@ def test_started_typing_remote_recv_not_in_room(self) -> None: self.event_source.get_new_events( user=U_APPLE, from_key=0, - limit=None, + limit=0, room_ids=[OTHER_ROOM_ID], is_guest=False, ) @@ -351,7 +370,7 @@ def test_stopped_typing(self) -> None: self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( - user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False + user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False ) ) self.assertEqual( @@ -387,7 +406,7 @@ def test_typing_timeout(self) -> None: self.event_source.get_new_events( user=U_APPLE, from_key=0, - limit=None, + limit=0, room_ids=[ROOM_ID], is_guest=False, ) @@ -412,7 +431,7 @@ def test_typing_timeout(self) -> None: self.event_source.get_new_events( user=U_APPLE, from_key=1, - limit=None, + limit=0, room_ids=[ROOM_ID], is_guest=False, ) @@ -447,7 +466,7 @@ def test_typing_timeout(self) -> None: self.event_source.get_new_events( user=U_APPLE, from_key=0, - limit=None, + limit=0, 
room_ids=[ROOM_ID], is_guest=False, ) diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 093537adef52..528cdee34b8b 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -19,13 +19,15 @@ from OpenSSL import SSL from OpenSSL.SSL import Connection +from twisted.internet.address import IPv4Address from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.ssl import Certificate, trustRootFromCertificates +from twisted.protocols.tls import TLSMemoryBIOProtocol from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401 from twisted.web.iweb import IPolicyForHTTPS # noqa: F401 -def get_test_https_policy(): +def get_test_https_policy() -> BrowserLikePolicyForHTTPS: """Get a test IPolicyForHTTPS which trusts the test CA cert Returns: @@ -39,7 +41,7 @@ def get_test_https_policy(): return BrowserLikePolicyForHTTPS(trustRoot=trust_root) -def get_test_ca_cert_file(): +def get_test_ca_cert_file() -> str: """Get the path to the test CA cert The keypair is generated with: @@ -51,7 +53,7 @@ def get_test_ca_cert_file(): return os.path.join(os.path.dirname(__file__), "ca.crt") -def get_test_key_file(): +def get_test_key_file() -> str: """get the path to the test key The key file is made with: @@ -137,15 +139,20 @@ class TestServerTLSConnectionFactory: """An SSL connection creator which returns connections which present a certificate signed by our test CA.""" - def __init__(self, sanlist): + def __init__(self, sanlist: List[bytes]): """ Args: - sanlist: list[bytes]: a list of subjectAltName values for the cert + sanlist: a list of subjectAltName values for the cert """ self._cert_file = create_test_cert_file(sanlist) - def serverConnectionForTLS(self, tlsProtocol): + def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connection: ctx = SSL.Context(SSL.SSLv23_METHOD) ctx.use_certificate_file(self._cert_file) ctx.use_privatekey_file(get_test_key_file()) return Connection(ctx, None) + + +# A dummy address, useful for tests that use FakeTransport and don't care about where +# packets are going to/coming from. 
+dummy_address = IPv4Address("TCP", "127.0.0.1", 80) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 992d8f94fd70..acfdcd3bca25 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -14,7 +14,7 @@ import base64 import logging import os -from typing import Iterable, Optional +from typing import Any, Awaitable, Callable, Generator, List, Optional, cast from unittest.mock import Mock, patch import treq @@ -24,14 +24,19 @@ from twisted.internet import defer from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions -from twisted.internet.interfaces import IProtocolFactory +from twisted.internet.defer import Deferred +from twisted.internet.endpoints import _WrappingProtocol +from twisted.internet.interfaces import ( + IOpenSSLClientConnectionCreator, + IProtocolFactory, +) from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web._newclient import ResponseNeverReceived from twisted.web.client import Agent from twisted.web.http import HTTPChannel, Request from twisted.web.http_headers import Headers -from twisted.web.iweb import IPolicyForHTTPS +from twisted.web.iweb import IPolicyForHTTPS, IResponse from synapse.config.homeserver import HomeServerConfig from synapse.crypto.context_factory import FederationPolicyForHTTPS @@ -42,11 +47,21 @@ WellKnownResolver, _cache_period_from_headers, ) -from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + LoggingContextOrSentinel, + current_context, +) +from synapse.types import ISynapseReactor from synapse.util.caches.ttlcache import TTLCache from tests import unittest -from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file +from tests.http import ( + TestServerTLSConnectionFactory, + dummy_address, + get_test_ca_cert_file, +) from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.utils import default_config @@ -54,15 +69,17 @@ # Once Async Mocks or lambdas are supported this can go away. 
-def generate_resolve_service(result): - async def resolve_service(_): +def generate_resolve_service( + result: List[Server], +) -> Callable[[Any], Awaitable[List[Server]]]: + async def resolve_service(_: Any) -> List[Server]: return result return resolve_service class MatrixFederationAgentTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() self.mock_resolver = Mock() @@ -75,8 +92,12 @@ def setUp(self): self.tls_factory = FederationPolicyForHTTPS(config) - self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) - self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) + self.well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache( + "test_cache", timer=self.reactor.seconds + ) + self.had_well_known_cache: TTLCache[bytes, bool] = TTLCache( + "test_cache", timer=self.reactor.seconds + ) self.well_known_resolver = WellKnownResolver( self.reactor, Agent(self.reactor, contextFactory=self.tls_factory), @@ -89,8 +110,8 @@ def _make_connection( self, client_factory: IProtocolFactory, ssl: bool = True, - expected_sni: bytes = None, - tls_sanlist: Optional[Iterable[bytes]] = None, + expected_sni: Optional[bytes] = None, + tls_sanlist: Optional[List[bytes]] = None, ) -> HTTPChannel: """Builds a test server, and completes the outgoing client connection Args: @@ -116,8 +137,8 @@ def _make_connection( if ssl: server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) - server_protocol = server_factory.buildProtocol(None) - + server_protocol = server_factory.buildProtocol(dummy_address) + assert server_protocol is not None # now, tell the client protocol factory to build the client protocol (it will be a # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an # HTTP11ClientProtocol) and wire the output of said protocol up to the server via @@ -125,7 +146,8 @@ def _make_connection( # # Normally this would be done by the TCP socket code in Twisted, but we are # stubbing that out here. - client_protocol = client_factory.buildProtocol(None) + client_protocol = client_factory.buildProtocol(dummy_address) + assert isinstance(client_protocol, _WrappingProtocol) client_protocol.makeConnection( FakeTransport(server_protocol, self.reactor, client_protocol) ) @@ -136,6 +158,7 @@ def _make_connection( ) if ssl: + assert isinstance(server_protocol, TLSMemoryBIOProtocol) # fish the test server back out of the server-side TLS protocol. 
http_protocol = server_protocol.wrappedProtocol # grab a hold of the TLS connection, in case it gets torn down @@ -144,6 +167,7 @@ def _make_connection( http_protocol = server_protocol tls_connection = None + assert isinstance(http_protocol, HTTPChannel) # give the reactor a pump to get the TLS juices flowing (if needed) self.reactor.advance(0) @@ -159,12 +183,14 @@ def _make_connection( return http_protocol @defer.inlineCallbacks - def _make_get_request(self, uri: bytes): + def _make_get_request( + self, uri: bytes + ) -> Generator["Deferred[object]", object, IResponse]: """ Sends a simple GET request via the agent, and checks its logcontext management """ with LoggingContext("one") as context: - fetch_d = self.agent.request(b"GET", uri) + fetch_d: Deferred[IResponse] = self.agent.request(b"GET", uri) # Nothing happened yet self.assertNoResult(fetch_d) @@ -172,8 +198,9 @@ def _make_get_request(self, uri: bytes): # should have reset logcontext to the sentinel _check_logcontext(SENTINEL_CONTEXT) + fetch_res: IResponse try: - fetch_res = yield fetch_d + fetch_res = yield fetch_d # type: ignore[misc, assignment] return fetch_res except Exception as e: logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e) @@ -216,7 +243,7 @@ def _send_well_known_response( request: Request, content: bytes, headers: Optional[dict] = None, - ): + ) -> None: """Check that an incoming request looks like a valid .well-known request, and send back the response. """ @@ -237,16 +264,16 @@ def _make_agent(self) -> MatrixFederationAgent: because it is created too early during setUp """ return MatrixFederationAgent( - reactor=self.reactor, + reactor=cast(ISynapseReactor, self.reactor), tls_client_options_factory=self.tls_factory, - user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. + user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided. 
ip_whitelist=IPSet(), ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, ) - def test_get(self): + def test_get(self) -> None: """happy-path test of a GET request with an explicit port""" self._do_get() @@ -254,11 +281,11 @@ def test_get(self): os.environ, {"https_proxy": "proxy.com", "no_proxy": "testserv"}, ) - def test_get_bypass_proxy(self): + def test_get_bypass_proxy(self) -> None: """test of a GET request with an explicit port and bypass proxy""" self._do_get() - def _do_get(self): + def _do_get(self) -> None: """test of a GET request with an explicit port""" self.agent = self._make_agent() @@ -318,7 +345,7 @@ def _do_get(self): @patch.dict( os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"} ) - def test_get_via_http_proxy(self): + def test_get_via_http_proxy(self) -> None: """test for federation request through a http proxy""" self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None) @@ -326,7 +353,7 @@ def test_get_via_http_proxy(self): os.environ, {"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"}, ) - def test_get_via_http_proxy_with_auth(self): + def test_get_via_http_proxy_with_auth(self) -> None: """test for federation request through a http proxy with authentication""" self._do_get_via_proxy( expect_proxy_ssl=False, expected_auth_credentials=b"user:pass" @@ -335,7 +362,7 @@ def test_get_via_http_proxy_with_auth(self): @patch.dict( os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} ) - def test_get_via_https_proxy(self): + def test_get_via_https_proxy(self) -> None: """test for federation request through a https proxy""" self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None) @@ -343,7 +370,7 @@ def test_get_via_https_proxy(self): os.environ, {"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"}, ) - def test_get_via_https_proxy_with_auth(self): + def test_get_via_https_proxy_with_auth(self) -> None: """test for federation request through a https proxy with authentication""" self._do_get_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=b"user:pass" @@ -353,7 +380,7 @@ def _do_get_via_proxy( self, expect_proxy_ssl: bool = False, expected_auth_credentials: Optional[bytes] = None, - ): + ) -> None: """Send a https federation request via an agent and check that it is correctly received at the proxy and client. The proxy can use either http or https. Args: @@ -418,10 +445,12 @@ def _do_get_via_proxy( # now we make another test server to act as the upstream HTTP server. server_ssl_protocol = _wrap_server_factory_for_tls( _get_test_protocol_factory() - ).buildProtocol(None) + ).buildProtocol(dummy_address) + assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) # Tell the HTTP server to send outgoing traffic back via the proxy's transport. proxy_server_transport = proxy_server.transport + assert proxy_server_transport is not None server_ssl_protocol.makeConnection(proxy_server_transport) # ... 
and replace the protocol on the proxy's transport with the @@ -451,6 +480,7 @@ def _do_get_via_proxy( # now there should be a pending request http_server = server_ssl_protocol.wrappedProtocol + assert isinstance(http_server, HTTPChannel) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] @@ -491,7 +521,7 @@ def _do_get_via_proxy( json = self.successResultOf(treq.json_content(response)) self.assertEqual(json, {"a": 1}) - def test_get_ip_address(self): + def test_get_ip_address(self) -> None: """ Test the behaviour when the server name contains an explicit IP (with no port) """ @@ -526,7 +556,7 @@ def test_get_ip_address(self): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_ipv6_address(self): + def test_get_ipv6_address(self) -> None: """ Test the behaviour when the server name contains an explicit IPv6 address (with no port) @@ -562,7 +592,7 @@ def test_get_ipv6_address(self): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_ipv6_address_with_port(self): + def test_get_ipv6_address_with_port(self) -> None: """ Test the behaviour when the server name contains an explicit IPv6 address (with explicit port) @@ -598,7 +628,7 @@ def test_get_ipv6_address_with_port(self): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_hostname_bad_cert(self): + def test_get_hostname_bad_cert(self) -> None: """ Test the behaviour when the certificate on the server doesn't match the hostname """ @@ -651,7 +681,7 @@ def test_get_hostname_bad_cert(self): failure_reason = e.value.reasons[0] self.assertIsInstance(failure_reason.value, VerificationError) - def test_get_ip_address_bad_cert(self): + def test_get_ip_address_bad_cert(self) -> None: """ Test the behaviour when the server name contains an explicit IP, but the server cert doesn't cover it @@ -684,7 +714,7 @@ def test_get_ip_address_bad_cert(self): failure_reason = e.value.reasons[0] self.assertIsInstance(failure_reason.value, VerificationError) - def test_get_no_srv_no_well_known(self): + def test_get_no_srv_no_well_known(self) -> None: """ Test the behaviour when the server name has no port, no SRV, and no well-known """ @@ -740,7 +770,7 @@ def test_get_no_srv_no_well_known(self): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_well_known(self): + def test_get_well_known(self) -> None: """Test the behaviour when the .well-known delegates elsewhere""" self.agent = self._make_agent() @@ -802,7 +832,7 @@ def test_get_well_known(self): self.well_known_cache.expire() self.assertNotIn(b"testserv", self.well_known_cache) - def test_get_well_known_redirect(self): + def test_get_well_known_redirect(self) -> None: """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ @@ -892,7 +922,7 @@ def test_get_well_known_redirect(self): self.well_known_cache.expire() self.assertNotIn(b"testserv", self.well_known_cache) - def test_get_invalid_well_known(self): + def test_get_invalid_well_known(self) -> None: """ Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ @@ -945,7 +975,7 @@ def test_get_invalid_well_known(self): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_well_known_unsigned_cert(self): + def test_get_well_known_unsigned_cert(self) -> None: """Test the behaviour when the .well-known server presents a cert not signed by a CA """ @@ -969,7 +999,7 @@ def test_get_well_known_unsigned_cert(self): ip_blacklist=IPSet(), 
_srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( - self.reactor, + cast(ISynapseReactor, self.reactor), Agent(self.reactor, contextFactory=tls_factory), b"test-agent", well_known_cache=self.well_known_cache, @@ -999,7 +1029,7 @@ def test_get_well_known_unsigned_cert(self): b"_matrix._tcp.testserv" ) - def test_get_hostname_srv(self): + def test_get_hostname_srv(self) -> None: """ Test the behaviour when there is a single SRV record """ @@ -1041,7 +1071,7 @@ def test_get_hostname_srv(self): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_well_known_srv(self): + def test_get_well_known_srv(self) -> None: """Test the behaviour when the .well-known redirects to a place where there is a SRV. """ @@ -1101,7 +1131,7 @@ def test_get_well_known_srv(self): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_idna_servername(self): + def test_idna_servername(self) -> None: """test the behaviour when the server name has idna chars in""" self.agent = self._make_agent() @@ -1163,7 +1193,7 @@ def test_idna_servername(self): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_idna_srv_target(self): + def test_idna_srv_target(self) -> None: """test the behaviour when the target of a SRV record has idna chars""" self.agent = self._make_agent() @@ -1206,7 +1236,7 @@ def test_idna_srv_target(self): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_well_known_cache(self): + def test_well_known_cache(self) -> None: self.reactor.lookups["testserv"] = "1.2.3.4" fetch_d = defer.ensureDeferred( @@ -1262,7 +1292,7 @@ def test_well_known_cache(self): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"other-server") - def test_well_known_cache_with_temp_failure(self): + def test_well_known_cache_with_temp_failure(self) -> None: """Test that we refetch well-known before the cache expires, and that it ignores transient errors. 
""" @@ -1341,7 +1371,7 @@ def test_well_known_cache_with_temp_failure(self): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, None) - def test_well_known_too_large(self): + def test_well_known_too_large(self) -> None: """A well-known query that returns a result which is too large should be rejected.""" self.reactor.lookups["testserv"] = "1.2.3.4" @@ -1367,7 +1397,7 @@ def test_well_known_too_large(self): r = self.successResultOf(fetch_d) self.assertIsNone(r.delegated_server) - def test_srv_fallbacks(self): + def test_srv_fallbacks(self) -> None: """Test that other SRV results are tried if the first one fails.""" self.agent = self._make_agent() @@ -1427,7 +1457,7 @@ def test_srv_fallbacks(self): class TestCachePeriodFromHeaders(unittest.TestCase): - def test_cache_control(self): + def test_cache_control(self) -> None: # uppercase self.assertEqual( _cache_period_from_headers( @@ -1464,7 +1494,7 @@ def test_cache_control(self): 0, ) - def test_expires(self): + def test_expires(self) -> None: self.assertEqual( _cache_period_from_headers( Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}), @@ -1491,14 +1521,14 @@ def test_expires(self): self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0) -def _check_logcontext(context): +def _check_logcontext(context: LoggingContextOrSentinel) -> None: current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) def _wrap_server_factory_for_tls( - factory: IProtocolFactory, sanlist: Iterable[bytes] = None + factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None ) -> IProtocolFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory The resultant factory will create a TLS server which presents a certificate @@ -1537,7 +1567,7 @@ def _get_test_protocol_factory() -> IProtocolFactory: return server_factory -def _log_request(request: str): +def _log_request(request: str) -> None: """Implements Factory.log, which is expected by Request.finish""" logger.info(f"Completed request {request}") @@ -1547,6 +1577,8 @@ class TrustingTLSPolicyForHTTPS: """An IPolicyForHTTPS which checks that the certificate belongs to the right server, but doesn't check the certificate chain.""" - def creatorForNetloc(self, hostname, port): + def creatorForNetloc( + self, hostname: bytes, port: int + ) -> IOpenSSLClientConnectionCreator: certificateOptions = OpenSSLCertificateOptions() return ClientTLSOptions(hostname, certificateOptions.getContext()) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 77ce8432ac53..7748f56ee6e2 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +from typing import Dict, Generator, List, Tuple, cast from unittest.mock import Mock from twisted.internet import defer @@ -20,7 +20,7 @@ from twisted.internet.error import ConnectError from twisted.names import dns, error -from synapse.http.federation.srv_resolver import SrvResolver +from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.logging.context import LoggingContext, current_context from tests import unittest @@ -28,7 +28,7 @@ class SrvResolverTestCase(unittest.TestCase): - def test_resolve(self): + def test_resolve(self) -> None: dns_client_mock = Mock() service_name = b"test_service.example.com" @@ -38,18 +38,19 @@ def test_resolve(self): type=dns.SRV, payload=dns.Record_SRV(target=host_name) ) - result_deferred = Deferred() + result_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred() dns_client_mock.lookupService.return_value = result_deferred - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) @defer.inlineCallbacks - def do_lookup(): + def do_lookup() -> Generator["Deferred[object]", object, List[Server]]: with LoggingContext("one") as ctx: resolve_d = resolver.resolve_service(service_name) - result = yield defer.ensureDeferred(resolve_d) + result: List[Server] + result = yield defer.ensureDeferred(resolve_d) # type: ignore[assignment] # should have restored our context self.assertIs(current_context(), ctx) @@ -70,7 +71,9 @@ def do_lookup(): self.assertEqual(servers[0].host, host_name) @defer.inlineCallbacks - def test_from_cache_expired_and_dns_fail(self): + def test_from_cache_expired_and_dns_fail( + self, + ) -> Generator["Deferred[object]", object, None]: dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) @@ -81,10 +84,13 @@ def test_from_cache_expired_and_dns_fail(self): entry.priority = 0 entry.weight = 0 - cache = {service_name: [entry]} + cache = {service_name: [cast(Server, entry)]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) + servers: List[Server] + servers = yield defer.ensureDeferred( + resolver.resolve_service(service_name) + ) # type: ignore[assignment] dns_client_mock.lookupService.assert_called_once_with(service_name) @@ -92,7 +98,7 @@ def test_from_cache_expired_and_dns_fail(self): self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks - def test_from_cache(self): + def test_from_cache(self) -> Generator["Deferred[object]", object, None]: clock = MockClock() dns_client_mock = Mock(spec_set=["lookupService"]) @@ -105,12 +111,15 @@ def test_from_cache(self): entry.priority = 0 entry.weight = 0 - cache = {service_name: [entry]} + cache = {service_name: [cast(Server, entry)]} resolver = SrvResolver( dns_client=dns_client_mock, cache=cache, get_time=clock.time ) - servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) + servers: List[Server] + servers = yield defer.ensureDeferred( + resolver.resolve_service(service_name) + ) # type: ignore[assignment] self.assertFalse(dns_client_mock.lookupService.called) @@ -118,45 +127,48 @@ def test_from_cache(self): self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks - def test_empty_cache(self): + def test_empty_cache(self) -> Generator["Deferred[object]", object, None]: dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) service_name = 
b"test_service.example.com" - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) with self.assertRaises(error.DNSServerError): yield defer.ensureDeferred(resolver.resolve_service(service_name)) @defer.inlineCallbacks - def test_name_error(self): + def test_name_error(self) -> Generator["Deferred[object]", object, None]: dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) service_name = b"test_service.example.com" - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) + servers: List[Server] + servers = yield defer.ensureDeferred( + resolver.resolve_service(service_name) + ) # type: ignore[assignment] self.assertEqual(len(servers), 0) self.assertEqual(len(cache), 0) - def test_disabled_service(self): + def test_disabled_service(self) -> None: """ test the behaviour when there is a single record which is ".". """ service_name = b"test_service.example.com" - lookup_deferred = Deferred() + lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred() dns_client_mock = Mock() dns_client_mock.lookupService.return_value = lookup_deferred - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) # Old versions of Twisted don't have an ensureDeferred in failureResultOf. @@ -173,16 +185,16 @@ def test_disabled_service(self): self.failureResultOf(resolve_d, ConnectError) - def test_non_srv_answer(self): + def test_non_srv_answer(self) -> None: """ test the behaviour when the dns server gives us a spurious non-SRV response """ service_name = b"test_service.example.com" - lookup_deferred = Deferred() + lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred() dns_client_mock = Mock() dns_client_mock.lookupService.return_value = lookup_deferred - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) # Old versions of Twisted don't have an ensureDeferred in successResultOf. diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 5071f835745e..36472e57a839 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -556,6 +556,6 @@ def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str: return method_name -def _hash_stack(stack: List[inspect.FrameInfo]): +def _hash_stack(stack: List[inspect.FrameInfo]) -> Tuple[str, ...]: """Turns a stack into a hashable value that can be put into a set.""" return tuple(_format_stack_frame(frame) for frame in stack) diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py index 391196425c38..ec6aacf235ae 100644 --- a/tests/http/test_additional_resource.py +++ b/tests/http/test_additional_resource.py @@ -11,28 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any +from twisted.web.server import Request from synapse.http.additional_resource import AdditionalResource from synapse.http.server import respond_with_json +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from tests.server import FakeSite, make_request from tests.unittest import HomeserverTestCase class _AsyncTestCustomEndpoint: - def __init__(self, config, module_api): + def __init__(self, config: JsonDict, module_api: Any) -> None: pass - async def handle_request(self, request): + async def handle_request(self, request: Request) -> None: + assert isinstance(request, SynapseRequest) respond_with_json(request, 200, {"some_key": "some_value_async"}) class _SyncTestCustomEndpoint: - def __init__(self, config, module_api): + def __init__(self, config: JsonDict, module_api: Any) -> None: pass - async def handle_request(self, request): + async def handle_request(self, request: Request) -> None: + assert isinstance(request, SynapseRequest) respond_with_json(request, 200, {"some_key": "some_value_sync"}) @@ -41,7 +47,7 @@ class AdditionalResourceTests(HomeserverTestCase): and async handlers. """ - def test_async(self): + def test_async(self) -> None: handler = _AsyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) @@ -52,7 +58,7 @@ def test_async(self): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) - def test_sync(self): + def test_sync(self) -> None: handler = _SyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) diff --git a/tests/http/test_client.py b/tests/http/test_client.py index 7e2f2a01cc07..9cfe1ad0de0a 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -13,10 +13,12 @@ # limitations under the License. from io import BytesIO +from typing import Tuple, Union from unittest.mock import Mock from netaddr import IPSet +from twisted.internet.defer import Deferred from twisted.internet.error import DNSLookupError from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol @@ -28,6 +30,7 @@ BlacklistingAgentWrapper, BlacklistingReactorWrapper, BodyExceededMaxSize, + _DiscardBodyWithMaxSizeProtocol, read_body_with_max_size, ) @@ -36,7 +39,9 @@ class ReadBodyWithMaxSizeTests(TestCase): - def _build_response(self, length=UNKNOWN_LENGTH): + def _build_response( + self, length: Union[int, str] = UNKNOWN_LENGTH + ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]: """Start reading the body, returns the response, result and proto""" response = Mock(length=length) result = BytesIO() @@ -48,23 +53,27 @@ def _build_response(self, length=UNKNOWN_LENGTH): return result, deferred, protocol - def _assert_error(self, deferred, protocol): + def _assert_error( + self, deferred: "Deferred[int]", protocol: _DiscardBodyWithMaxSizeProtocol + ) -> None: """Ensure that the expected error is received.""" - self.assertIsInstance(deferred.result, Failure) + assert isinstance(deferred.result, Failure) self.assertIsInstance(deferred.result.value, BodyExceededMaxSize) - protocol.transport.abortConnection.assert_called_once() + assert protocol.transport is not None + # type-ignore: presumably abortConnection has been replaced with a Mock. 
+ protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined] - def _cleanup_error(self, deferred): + def _cleanup_error(self, deferred: "Deferred[int]") -> None: """Ensure that the error in the Deferred is handled gracefully.""" called = [False] - def errback(f): + def errback(f: Failure) -> None: called[0] = True deferred.addErrback(errback) self.assertTrue(called[0]) - def test_no_error(self): + def test_no_error(self) -> None: """A response that is NOT too large.""" result, deferred, protocol = self._build_response() @@ -76,7 +85,7 @@ def test_no_error(self): self.assertEqual(result.getvalue(), b"12345") self.assertEqual(deferred.result, 5) - def test_too_large(self): + def test_too_large(self) -> None: """A response which is too large raises an exception.""" result, deferred, protocol = self._build_response() @@ -87,7 +96,7 @@ def test_too_large(self): self._assert_error(deferred, protocol) self._cleanup_error(deferred) - def test_multiple_packets(self): + def test_multiple_packets(self) -> None: """Data should be accumulated through mutliple packets.""" result, deferred, protocol = self._build_response() @@ -100,7 +109,7 @@ def test_multiple_packets(self): self.assertEqual(result.getvalue(), b"1234") self.assertEqual(deferred.result, 4) - def test_additional_data(self): + def test_additional_data(self) -> None: """A connection can receive data after being closed.""" result, deferred, protocol = self._build_response() @@ -115,7 +124,7 @@ def test_additional_data(self): self._assert_error(deferred, protocol) self._cleanup_error(deferred) - def test_content_length(self): + def test_content_length(self) -> None: """The body shouldn't be read (at all) if the Content-Length header is too large.""" result, deferred, protocol = self._build_response(length=10) @@ -132,7 +141,7 @@ def test_content_length(self): class BlacklistingAgentTest(TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor, self.clock = get_clock() self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4" @@ -151,7 +160,7 @@ def setUp(self): self.ip_whitelist = IPSet([self.allowed_ip.decode()]) self.ip_blacklist = IPSet(["5.0.0.0/8"]) - def test_reactor(self): + def test_reactor(self) -> None: """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs.""" agent = Agent( BlacklistingReactorWrapper( @@ -197,7 +206,7 @@ def test_reactor(self): response = self.successResultOf(d) self.assertEqual(response.code, 200) - def test_agent(self): + def test_agent(self) -> None: """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs.""" agent = BlacklistingAgentWrapper( Agent(self.reactor), diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index a801f002a004..8c18e5688177 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -17,7 +17,7 @@ class ServerNameTestCase(unittest.TestCase): - def test_parse_server_name(self): + def test_parse_server_name(self) -> None: test_data = { "localhost": ("localhost", None), "my-example.com:1234": ("my-example.com", 1234), @@ -32,7 +32,7 @@ def test_parse_server_name(self): for i, o in test_data.items(): self.assertEqual(parse_server_name(i), o) - def test_validate_bad_server_names(self): + def test_validate_bad_server_names(self) -> None: test_data = [ "", # empty "localhost:http", # non-numeric port diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py index be9eaf34e8e7..fdd22a8e9437 
100644 --- a/tests/http/test_matrixfederationclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -11,16 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Generator from unittest.mock import Mock from netaddr import IPSet from parameterized import parameterized from twisted.internet import defer -from twisted.internet.defer import TimeoutError +from twisted.internet.defer import Deferred, TimeoutError from twisted.internet.error import ConnectingCancelledError, DNSLookupError -from twisted.test.proto_helpers import StringTransport +from twisted.test.proto_helpers import MemoryReactor, StringTransport from twisted.web.client import ResponseNeverReceived from twisted.web.http import HTTPChannel @@ -30,34 +30,43 @@ MatrixFederationHttpClient, MatrixFederationRequest, ) -from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + LoggingContextOrSentinel, + current_context, +) +from synapse.server import HomeServer +from synapse.util import Clock from tests.server import FakeTransport from tests.unittest import HomeserverTestCase -def check_logcontext(context): +def check_logcontext(context: LoggingContextOrSentinel) -> None: current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) class FederationClientTests(HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver(reactor=reactor, clock=clock) return hs - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.cl = MatrixFederationHttpClient(self.hs, None) self.reactor.lookups["testserv"] = "1.2.3.4" - def test_client_get(self): + def test_client_get(self) -> None: """ happy-path test of a GET request """ @defer.inlineCallbacks - def do_request(): + def do_request() -> Generator["Deferred[object]", object, object]: with LoggingContext("one") as context: fetch_d = defer.ensureDeferred( self.cl.get_json("testserv:8008", "foo/bar") @@ -119,7 +128,7 @@ def do_request(): # check the response is as expected self.assertEqual(res, {"a": 1}) - def test_dns_error(self): + def test_dns_error(self) -> None: """ If the DNS lookup returns an error, it will bubble up. """ @@ -132,7 +141,7 @@ def test_dns_error(self): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance(f.value.inner_exception, DNSLookupError) - def test_client_connection_refused(self): + def test_client_connection_refused(self) -> None: d = defer.ensureDeferred( self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) ) @@ -156,7 +165,7 @@ def test_client_connection_refused(self): self.assertIsInstance(f.value, RequestSendFailed) self.assertIs(f.value.inner_exception, e) - def test_client_never_connect(self): + def test_client_never_connect(self) -> None: """ If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. 
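The typing pattern applied to the `@defer.inlineCallbacks` tests in these hunks generalises: the decorated generator is annotated as yielding Deferreds and returning the test's final value, and because the value sent back in by each `yield` is only known to be `object`, the result is re-declared with its real type and the assignment error is ignored. A minimal sketch of that pattern, using a hypothetical `fetch_length` helper rather than anything from Synapse:

    from typing import Generator

    from twisted.internet import defer
    from twisted.internet.defer import Deferred, succeed


    @defer.inlineCallbacks
    def fetch_length(data: str) -> Generator["Deferred[object]", object, int]:
        # Hypothetical helper, for illustration only.  The generator yields
        # Deferreds and ultimately returns an int; the value sent back in by
        # the inlineCallbacks trampoline is typed as `object`, hence the
        # explicit re-declaration and the targeted ignore.
        text: str
        text = yield succeed(data)  # type: ignore[assignment]
        return len(text)

Under the reactor, `fetch_length("abc")` is a `Deferred` that fires with `3`, while mypy checks the declared return type via the `Generator` annotation.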
@@ -188,7 +197,7 @@ def test_client_never_connect(self): f.value.inner_exception, (ConnectingCancelledError, TimeoutError) ) - def test_client_connect_no_response(self): + def test_client_connect_no_response(self) -> None: """ If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. @@ -222,7 +231,7 @@ def test_client_connect_no_response(self): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived) - def test_client_ip_range_blacklist(self): + def test_client_ip_range_blacklist(self) -> None: """Ensure that Synapse does not try to connect to blacklisted IPs""" # Set up the ip_range blacklist @@ -292,7 +301,7 @@ def test_client_ip_range_blacklist(self): f = self.failureResultOf(d, RequestSendFailed) self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError) - def test_client_gets_headers(self): + def test_client_gets_headers(self) -> None: """ Once the client gets the headers, _request returns successfully. """ @@ -319,7 +328,7 @@ def test_client_gets_headers(self): self.assertEqual(r.code, 200) @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"]) - def test_timeout_reading_body(self, method_name: str): + def test_timeout_reading_body(self, method_name: str) -> None: """ If the HTTP request is connected, but gets no response before being timed out, it'll give a RequestSendFailed with can_retry. @@ -351,7 +360,7 @@ def test_timeout_reading_body(self, method_name: str): self.assertTrue(f.value.can_retry) self.assertIsInstance(f.value.inner_exception, defer.TimeoutError) - def test_client_requires_trailing_slashes(self): + def test_client_requires_trailing_slashes(self) -> None: """ If a connection is made to a client but the client rejects it due to requiring a trailing slash. We need to retry the request with a @@ -405,7 +414,7 @@ def test_client_requires_trailing_slashes(self): r = self.successResultOf(d) self.assertEqual(r, {}) - def test_client_does_not_retry_on_400_plus(self): + def test_client_does_not_retry_on_400_plus(self) -> None: """ Another test for trailing slashes but now test that we don't retry on trailing slashes on a non-400/M_UNRECOGNIZED response. @@ -450,7 +459,7 @@ def test_client_does_not_retry_on_400_plus(self): # We should get a 404 failure response self.failureResultOf(d) - def test_client_sends_body(self): + def test_client_sends_body(self) -> None: defer.ensureDeferred( self.cl.post_json( "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"} @@ -474,7 +483,7 @@ def test_client_sends_body(self): content = request.content.read() self.assertEqual(content, b'{"a":"b"}') - def test_closes_connection(self): + def test_closes_connection(self) -> None: """Check that the client closes unused HTTP connections""" d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) @@ -514,7 +523,7 @@ def test_closes_connection(self): self.assertTrue(conn.disconnecting) @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)]) - def test_json_error(self, return_value): + def test_json_error(self, return_value: bytes) -> None: """ Test what happens if invalid JSON is returned from the remote endpoint. """ @@ -560,7 +569,7 @@ def test_json_error(self, return_value): f = self.failureResultOf(test_d) self.assertIsInstance(f.value, RequestSendFailed) - def test_too_big(self): + def test_too_big(self) -> None: """ Test what happens if a huge response is returned from the remote endpoint. 
""" diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index 2db77c6a7345..a817940730fd 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -14,7 +14,7 @@ import base64 import logging import os -from typing import Iterable, Optional +from typing import List, Optional from unittest.mock import patch import treq @@ -22,7 +22,11 @@ from parameterized import parameterized from twisted.internet import interfaces # noqa: F401 -from twisted.internet.endpoints import HostnameEndpoint, _WrapperEndpoint +from twisted.internet.endpoints import ( + HostnameEndpoint, + _WrapperEndpoint, + _WrappingProtocol, +) from twisted.internet.interfaces import IProtocol, IProtocolFactory from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol @@ -32,7 +36,11 @@ from synapse.http.connectproxyclient import ProxyCredentials from synapse.http.proxyagent import ProxyAgent, parse_proxy -from tests.http import TestServerTLSConnectionFactory, get_test_https_policy +from tests.http import ( + TestServerTLSConnectionFactory, + dummy_address, + get_test_https_policy, +) from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.unittest import TestCase @@ -183,7 +191,7 @@ def test_parse_proxy( expected_hostname: bytes, expected_port: int, expected_credentials: Optional[bytes], - ): + ) -> None: """ Tests that a given proxy URL will be broken into the components. Args: @@ -209,7 +217,7 @@ def test_parse_proxy( class MatrixFederationAgentTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() def _make_connection( @@ -218,7 +226,7 @@ def _make_connection( server_factory: IProtocolFactory, ssl: bool = False, expected_sni: Optional[bytes] = None, - tls_sanlist: Optional[Iterable[bytes]] = None, + tls_sanlist: Optional[List[bytes]] = None, ) -> IProtocol: """Builds a test server, and completes the outgoing client connection @@ -244,7 +252,8 @@ def _make_connection( if ssl: server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) - server_protocol = server_factory.buildProtocol(None) + server_protocol = server_factory.buildProtocol(dummy_address) + assert server_protocol is not None # now, tell the client protocol factory to build the client protocol, # and wire the output of said protocol up to the server via @@ -252,7 +261,8 @@ def _make_connection( # # Normally this would be done by the TCP socket code in Twisted, but we are # stubbing that out here. - client_protocol = client_factory.buildProtocol(None) + client_protocol = client_factory.buildProtocol(dummy_address) + assert client_protocol is not None client_protocol.makeConnection( FakeTransport(server_protocol, self.reactor, client_protocol) ) @@ -263,6 +273,7 @@ def _make_connection( ) if ssl: + assert isinstance(server_protocol, TLSMemoryBIOProtocol) http_protocol = server_protocol.wrappedProtocol tls_connection = server_protocol._tlsConnection else: @@ -288,7 +299,7 @@ def _test_request_direct_connection( scheme: bytes, hostname: bytes, path: bytes, - ): + ) -> None: """Runs a test case for a direct connection not going through a proxy. 
Args: @@ -319,6 +330,7 @@ def _test_request_direct_connection( ssl=is_https, expected_sni=hostname if is_https else None, ) + assert isinstance(http_server, HTTPChannel) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -339,34 +351,34 @@ def _test_request_direct_connection( body = self.successResultOf(treq.content(resp)) self.assertEqual(body, b"result") - def test_http_request(self): + def test_http_request(self) -> None: agent = ProxyAgent(self.reactor) self._test_request_direct_connection(agent, b"http", b"test.com", b"") - def test_https_request(self): + def test_https_request(self) -> None: agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy()) self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") - def test_http_request_use_proxy_empty_environment(self): + def test_http_request_use_proxy_empty_environment(self) -> None: agent = ProxyAgent(self.reactor, use_proxy=True) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"}) - def test_http_request_via_uppercase_no_proxy(self): + def test_http_request_via_uppercase_no_proxy(self) -> None: agent = ProxyAgent(self.reactor, use_proxy=True) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict( os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"} ) - def test_http_request_via_no_proxy(self): + def test_http_request_via_no_proxy(self) -> None: agent = ProxyAgent(self.reactor, use_proxy=True) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict( os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"} ) - def test_https_request_via_no_proxy(self): + def test_https_request_via_no_proxy(self) -> None: agent = ProxyAgent( self.reactor, contextFactory=get_test_https_policy(), @@ -375,12 +387,12 @@ def test_https_request_via_no_proxy(self): self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"}) - def test_http_request_via_no_proxy_star(self): + def test_http_request_via_no_proxy_star(self) -> None: agent = ProxyAgent(self.reactor, use_proxy=True) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"}) - def test_https_request_via_no_proxy_star(self): + def test_https_request_via_no_proxy_star(self) -> None: agent = ProxyAgent( self.reactor, contextFactory=get_test_https_policy(), @@ -389,7 +401,7 @@ def test_https_request_via_no_proxy_star(self): self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"}) - def test_http_request_via_proxy(self): + def test_http_request_via_proxy(self) -> None: """ Tests that requests can be made through a proxy. """ @@ -401,7 +413,7 @@ def test_http_request_via_proxy(self): os.environ, {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"}, ) - def test_http_request_via_proxy_with_auth(self): + def test_http_request_via_proxy_with_auth(self) -> None: """ Tests that authenticated requests can be made through a proxy. 
""" @@ -412,7 +424,7 @@ def test_http_request_via_proxy_with_auth(self): @patch.dict( os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"} ) - def test_http_request_via_https_proxy(self): + def test_http_request_via_https_proxy(self) -> None: self._do_http_request_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=None ) @@ -424,13 +436,13 @@ def test_http_request_via_https_proxy(self): "no_proxy": "unused.com", }, ) - def test_http_request_via_https_proxy_with_auth(self): + def test_http_request_via_https_proxy_with_auth(self) -> None: self._do_http_request_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" ) @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) - def test_https_request_via_proxy(self): + def test_https_request_via_proxy(self) -> None: """Tests that TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( expect_proxy_ssl=False, expected_auth_credentials=None @@ -440,7 +452,7 @@ def test_https_request_via_proxy(self): os.environ, {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, ) - def test_https_request_via_proxy_with_auth(self): + def test_https_request_via_proxy_with_auth(self) -> None: """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" @@ -449,7 +461,7 @@ def test_https_request_via_proxy_with_auth(self): @patch.dict( os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} ) - def test_https_request_via_https_proxy(self): + def test_https_request_via_https_proxy(self) -> None: """Tests that TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=None @@ -459,7 +471,7 @@ def test_https_request_via_https_proxy(self): os.environ, {"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, ) - def test_https_request_via_https_proxy_with_auth(self): + def test_https_request_via_https_proxy_with_auth(self) -> None: """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" @@ -469,7 +481,7 @@ def _do_http_request_via_proxy( self, expect_proxy_ssl: bool = False, expected_auth_credentials: Optional[bytes] = None, - ): + ) -> None: """Send a http request via an agent and check that it is correctly received at the proxy. The proxy can use either http or https. Args: @@ -501,6 +513,7 @@ def _do_http_request_via_proxy( tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, expected_sni=b"proxy.com" if expect_proxy_ssl else None, ) + assert isinstance(http_server, HTTPChannel) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -542,7 +555,7 @@ def _do_https_request_via_proxy( self, expect_proxy_ssl: bool = False, expected_auth_credentials: Optional[bytes] = None, - ): + ) -> None: """Send a https request via an agent and check that it is correctly received at the proxy and client. The proxy can use either http or https. Args: @@ -606,10 +619,12 @@ def _do_https_request_via_proxy( # now we make another test server to act as the upstream HTTP server. 
server_ssl_protocol = _wrap_server_factory_for_tls( _get_test_protocol_factory() - ).buildProtocol(None) + ).buildProtocol(dummy_address) + assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) # Tell the HTTP server to send outgoing traffic back via the proxy's transport. proxy_server_transport = proxy_server.transport + assert proxy_server_transport is not None server_ssl_protocol.makeConnection(proxy_server_transport) # ... and replace the protocol on the proxy's transport with the @@ -644,6 +659,7 @@ def _do_https_request_via_proxy( # now there should be a pending request http_server = server_ssl_protocol.wrappedProtocol + assert isinstance(http_server, HTTPChannel) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] @@ -667,7 +683,7 @@ def _do_https_request_via_proxy( self.assertEqual(body, b"result") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) - def test_http_request_via_proxy_with_blacklist(self): + def test_http_request_via_proxy_with_blacklist(self) -> None: # The blacklist includes the configured proxy IP. agent = ProxyAgent( BlacklistingReactorWrapper( @@ -691,6 +707,7 @@ def test_http_request_via_proxy_with_blacklist(self): http_server = self._make_connection( client_factory, _get_test_protocol_factory() ) + assert isinstance(http_server, HTTPChannel) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -712,7 +729,7 @@ def test_http_request_via_proxy_with_blacklist(self): self.assertEqual(body, b"result") @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"}) - def test_https_request_via_uppercase_proxy_with_blacklist(self): + def test_https_request_via_uppercase_proxy_with_blacklist(self) -> None: # The blacklist includes the configured proxy IP. agent = ProxyAgent( BlacklistingReactorWrapper( @@ -737,11 +754,15 @@ def test_https_request_via_uppercase_proxy_with_blacklist(self): proxy_server = self._make_connection( client_factory, _get_test_protocol_factory() ) + assert isinstance(proxy_server, HTTPChannel) # fish the transports back out so that we can do the old switcheroo s2c_transport = proxy_server.transport + assert isinstance(s2c_transport, FakeTransport) client_protocol = s2c_transport.other + assert isinstance(client_protocol, _WrappingProtocol) c2s_transport = client_protocol.transport + assert isinstance(c2s_transport, FakeTransport) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -762,8 +783,10 @@ def test_https_request_via_uppercase_proxy_with_blacklist(self): # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) - ssl_protocol = ssl_factory.buildProtocol(None) + ssl_protocol = ssl_factory.buildProtocol(dummy_address) + assert isinstance(ssl_protocol, TLSMemoryBIOProtocol) http_server = ssl_protocol.wrappedProtocol + assert isinstance(http_server, HTTPChannel) ssl_protocol.makeConnection( FakeTransport(client_protocol, self.reactor, ssl_protocol) @@ -797,28 +820,28 @@ def test_https_request_via_uppercase_proxy_with_blacklist(self): self.assertEqual(body, b"result") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) - def test_proxy_with_no_scheme(self): + def test_proxy_with_no_scheme(self) -> None: http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) + assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) 
self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) @patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"}) - def test_proxy_with_unsupported_scheme(self): + def test_proxy_with_unsupported_scheme(self) -> None: with self.assertRaises(ValueError): ProxyAgent(self.reactor, use_proxy=True) @patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"}) - def test_proxy_with_http_scheme(self): + def test_proxy_with_http_scheme(self) -> None: http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) + assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) @patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"}) - def test_proxy_with_https_scheme(self): + def test_proxy_with_https_scheme(self) -> None: https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - self.assertIsInstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint) + assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint) self.assertEqual( https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com" ) @@ -828,7 +851,7 @@ def test_proxy_with_https_scheme(self): def _wrap_server_factory_for_tls( - factory: IProtocolFactory, sanlist: Iterable[bytes] = None + factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None ) -> IProtocolFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory @@ -865,6 +888,6 @@ def _get_test_protocol_factory() -> IProtocolFactory: return server_factory -def _log_request(request: str): +def _log_request(request: str) -> None: """Implements Factory.log, which is expected by Request.finish""" logger.info(f"Completed request {request}") diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py index 46166292fed9..c8d215b6dcfe 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py @@ -14,7 +14,7 @@ import json from http import HTTPStatus from io import BytesIO -from typing import Tuple +from typing import Tuple, Union from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError @@ -33,7 +33,7 @@ from tests.http.server._base import test_disconnect -def make_request(content): +def make_request(content: Union[bytes, JsonDict]) -> Mock: """Make an object that acts enough like a request.""" request = Mock(spec=["method", "uri", "content"]) @@ -47,7 +47,7 @@ def make_request(content): class TestServletUtils(unittest.TestCase): - def test_parse_json_value(self): + def test_parse_json_value(self) -> None: """Basic tests for parse_json_value_from_request.""" # Test round-tripping. obj = {"foo": 1} @@ -78,7 +78,7 @@ def test_parse_json_value(self): with self.assertRaises(SynapseError): parse_json_value_from_request(make_request(b'{"foo": Infinity}')) - def test_parse_json_object(self): + def test_parse_json_object(self) -> None: """Basic tests for parse_json_object_from_request.""" # Test empty. 
result = parse_json_object_from_request( diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py index c85a3665c127..010601da4b93 100644 --- a/tests/http/test_simple_client.py +++ b/tests/http/test_simple_client.py @@ -17,22 +17,24 @@ from twisted.internet import defer from twisted.internet.error import DNSLookupError +from twisted.test.proto_helpers import MemoryReactor from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase class SimpleHttpClientTests(HomeserverTestCase): - def prepare(self, reactor, clock, hs: "HomeServer"): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: "HomeServer") -> None: # Add a DNS entry for a test server self.reactor.lookups["testserv"] = "1.2.3.4" self.cl = hs.get_simple_http_client() - def test_dns_error(self): + def test_dns_error(self) -> None: """ If the DNS lookup returns an error, it will bubble up. """ @@ -42,7 +44,7 @@ def test_dns_error(self): f = self.failureResultOf(d) self.assertIsInstance(f.value, DNSLookupError) - def test_client_connection_refused(self): + def test_client_connection_refused(self) -> None: d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar")) self.pump() @@ -63,7 +65,7 @@ def test_client_connection_refused(self): self.assertIs(f.value, e) - def test_client_never_connect(self): + def test_client_never_connect(self) -> None: """ If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. @@ -90,7 +92,7 @@ def test_client_never_connect(self): self.assertIsInstance(f.value, RequestTimedOutError) - def test_client_connect_no_response(self): + def test_client_connect_no_response(self) -> None: """ If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. @@ -121,7 +123,7 @@ def test_client_connect_no_response(self): self.assertIsInstance(f.value, RequestTimedOutError) - def test_client_ip_range_blacklist(self): + def test_client_ip_range_blacklist(self) -> None: """Ensure that Synapse does not try to connect to blacklisted IPs""" # Add some DNS entries we'll blacklist diff --git a/tests/http/test_site.py b/tests/http/test_site.py index b2dbf76d33b1..9a78fede9294 100644 --- a/tests/http/test_site.py +++ b/tests/http/test_site.py @@ -13,18 +13,20 @@ # limitations under the License. from twisted.internet.address import IPv6Address -from twisted.test.proto_helpers import StringTransport +from twisted.test.proto_helpers import MemoryReactor, StringTransport from synapse.app.homeserver import SynapseHomeServer +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase class SynapseRequestTestCase(HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer) - def test_large_request(self): + def test_large_request(self) -> None: """overlarge HTTP requests should be rejected""" self.hs.start_listening() diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py index 1acf5666a856..1c5de95a809f 100644 --- a/tests/logging/__init__.py +++ b/tests/logging/__init__.py @@ -13,9 +13,11 @@ # limitations under the License. 
import logging +from tests.unittest import TestCase -class LoggerCleanupMixin: - def get_logger(self, handler): + +class LoggerCleanupMixin(TestCase): + def get_logger(self, handler: logging.Handler) -> logging.Logger: """ Attach a handler to a logger and add clean-ups to remove revert this. """ diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py index 0917e478a5b0..e28ba84cc2b7 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py @@ -153,7 +153,7 @@ def test_overlapping_spans(self) -> None: scopes = [] - async def task(i: int): + async def task(i: int) -> None: scope = start_active_span( f"task{i}", tracer=self._tracer, @@ -165,7 +165,7 @@ async def task(i: int): self.assertEqual(self._tracer.active_span, scope.span) scope.close() - async def root(): + async def root() -> None: with start_active_span("root span", tracer=self._tracer) as root_scope: self.assertEqual(self._tracer.active_span, root_scope.span) scopes.append(root_scope) diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py index b0d046fe0079..c08954d887cb 100644 --- a/tests/logging/test_remote_handler.py +++ b/tests/logging/test_remote_handler.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.test.proto_helpers import AccumulatingProtocol +from typing import Tuple + +from twisted.internet.protocol import Protocol +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from synapse.logging import RemoteHandler @@ -20,7 +23,9 @@ from tests.unittest import TestCase -def connect_logging_client(reactor, client_id): +def connect_logging_client( + reactor: MemoryReactorClock, client_id: int +) -> Tuple[Protocol, AccumulatingProtocol]: # This is essentially tests.server.connect_client, but disabling autoflush on # the client transport. This is necessary to avoid an infinite loop due to # sending of data via the logging transport causing additional logs to be @@ -35,10 +40,10 @@ def connect_logging_client(reactor, client_id): class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor, _ = get_clock() - def test_log_output(self): + def test_log_output(self) -> None: """ The remote handler delivers logs over TCP. """ @@ -51,6 +56,7 @@ def test_log_output(self): client, server = connect_logging_client(self.reactor, 0) # Trigger data being sent + assert isinstance(client.transport, FakeTransport) client.transport.flush() # One log message, with a single trailing newline @@ -61,7 +67,7 @@ def test_log_output(self): # Ensure the data passed through properly. self.assertEqual(logs[0], "Hello there, wally!") - def test_log_backpressure_debug(self): + def test_log_backpressure_debug(self) -> None: """ When backpressure is hit, DEBUG logs will be shed. """ @@ -83,6 +89,7 @@ def test_log_backpressure_debug(self): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) + assert isinstance(client.transport, FakeTransport) client.transport.flush() # Only the 7 infos made it through, the debugs were elided @@ -90,7 +97,7 @@ def test_log_backpressure_debug(self): self.assertEqual(len(logs), 7) self.assertNotIn(b"debug", server.data) - def test_log_backpressure_info(self): + def test_log_backpressure_info(self) -> None: """ When backpressure is hit, DEBUG and INFO logs will be shed. 
""" @@ -116,6 +123,7 @@ def test_log_backpressure_info(self): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) + assert isinstance(client.transport, FakeTransport) client.transport.flush() # The 10 warnings made it through, the debugs and infos were elided @@ -124,7 +132,7 @@ def test_log_backpressure_info(self): self.assertNotIn(b"debug", server.data) self.assertNotIn(b"info", server.data) - def test_log_backpressure_cut_middle(self): + def test_log_backpressure_cut_middle(self) -> None: """ When backpressure is hit, and no more DEBUG and INFOs cannot be culled, it will cut the middle messages out. @@ -140,6 +148,7 @@ def test_log_backpressure_cut_middle(self): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) + assert isinstance(client.transport, FakeTransport) client.transport.flush() # The first five and last five warnings made it through, the debugs and @@ -151,7 +160,7 @@ def test_log_backpressure_cut_middle(self): logs, ) - def test_cancel_connection(self): + def test_cancel_connection(self) -> None: """ Gracefully handle the connection being cancelled. """ diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index ac1aacf36871..1dddcd209e88 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -14,24 +14,28 @@ import json import logging from io import BytesIO, StringIO +from typing import cast from unittest.mock import Mock, patch +from twisted.web.http import HTTPChannel from twisted.web.server import Request from synapse.http.site import SynapseRequest from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter from synapse.logging.context import LoggingContext, LoggingContextFilter +from synapse.types import JsonDict from tests.logging import LoggerCleanupMixin -from tests.server import FakeChannel +from tests.server import FakeChannel, get_clock from tests.unittest import TestCase class TerseJsonTestCase(LoggerCleanupMixin, TestCase): - def setUp(self): + def setUp(self) -> None: self.output = StringIO() + self.reactor, _ = get_clock() - def get_log_line(self): + def get_log_line(self) -> JsonDict: # One log message, with a single trailing newline. data = self.output.getvalue() logs = data.splitlines() @@ -39,7 +43,7 @@ def get_log_line(self): self.assertEqual(data.count("\n"), 1) return json.loads(logs[0]) - def test_terse_json_output(self): + def test_terse_json_output(self) -> None: """ The Terse JSON formatter converts log messages to JSON. """ @@ -61,7 +65,7 @@ def test_terse_json_output(self): self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") - def test_extra_data(self): + def test_extra_data(self) -> None: """ Additional information can be included in the structured logging. """ @@ -93,7 +97,7 @@ def test_extra_data(self): self.assertEqual(log["int"], 3) self.assertIs(log["bool"], True) - def test_json_output(self): + def test_json_output(self) -> None: """ The Terse JSON formatter converts log messages to JSON. """ @@ -114,7 +118,7 @@ def test_json_output(self): self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") - def test_with_context(self): + def test_with_context(self) -> None: """ The logging context should be added to the JSON response. 
""" @@ -139,7 +143,7 @@ def test_with_context(self): self.assertEqual(log["log"], "Hello there, wally!") self.assertEqual(log["request"], "name") - def test_with_request_context(self): + def test_with_request_context(self) -> None: """ Information from the logging context request should be added to the JSON response. """ @@ -154,11 +158,13 @@ def test_with_request_context(self): site.server_version_string = "Server v1" site.reactor = Mock() site.experimental_cors_msc3886 = False - request = SynapseRequest(FakeChannel(site, None), site) + request = SynapseRequest( + cast(HTTPChannel, FakeChannel(site, self.reactor)), site + ) # Call requestReceived to finish instantiating the object. request.content = BytesIO() - # Partially skip some of the internal processing of SynapseRequest. - request._started_processing = Mock() + # Partially skip some internal processing of SynapseRequest. + request._started_processing = Mock() # type: ignore[assignment] request.request_metrics = Mock(spec=["name"]) with patch.object(Request, "render"): request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1") @@ -201,7 +207,7 @@ def test_with_request_context(self): self.assertEqual(log["protocol"], "1.1") self.assertEqual(log["user_agent"], "") - def test_with_exception(self): + def test_with_exception(self) -> None: """ The logging exception type & value should be added to the JSON response. """ diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index b0f3f4374da8..8f88c0117d78 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -110,6 +110,24 @@ def test_can_set_admin(self): self.assertEqual(found_user.user_id.to_string(), user_id) self.assertIdentical(found_user.is_admin, True) + def test_can_set_displayname(self): + localpart = "alice_wants_a_new_displayname" + user_id = self.register_user( + localpart, "1234", displayname="Alice", admin=False + ) + found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + + self.get_success( + self.module_api.set_displayname( + found_userinfo.user_id, "Bob", deactivation=False + ) + ) + found_profile = self.get_success( + self.module_api.get_profile_for_user(localpart) + ) + + self.assertEqual(found_profile.display_name, "Bob") + def test_get_userinfo_by_id(self): user_id = self.register_user("alice", "1234") found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) @@ -386,6 +404,9 @@ def test_send_local_online_presence_to_federation(self): self.module_api.send_local_online_presence_to([remote_user_id]) ) + # We don't always send out federation immediately, so we advance the clock. + self.reactor.advance(1000) + # Check that a presence update was sent as part of a federation transaction found_update = False calls = ( diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 1cd453248ec2..7567756135b7 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -1,10 +1,32 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional from unittest.mock import patch +from parameterized import parameterized + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventContentFields, RelationTypes from synapse.api.room_versions import RoomVersions from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.rest import admin from synapse.rest.client import login, register, room -from synapse.types import create_requester +from synapse.server import HomeServer +from synapse.types import JsonDict, create_requester +from synapse.util import Clock from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -19,54 +41,109 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): register.register_servlets, ] - def test_action_for_event_by_user_handles_noninteger_power_levels(self) -> None: - """We should convert floats and strings to integers before passing to Rust. + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Create a new user and room. + self.alice = self.register_user("alice", "pass") + self.token = self.login(self.alice, "pass") + self.requester = create_requester(self.alice) + + self.room_id = self.helper.create_room_as( + # This is deliberately set to V9, because we want to test the logic which + # handles stringy power levels. Stringy power levels were outlawed in V10. + self.alice, + room_version=RoomVersions.V9.identifier, + tok=self.token, + ) + + self.event_creation_handler = self.hs.get_event_creation_handler() + + @parameterized.expand( + [ + # The historically-permitted bad values. Alice's notification should be + # allowed if this threshold is at or below her power level (60) + ("100", False), + ("0", True), + (12.34, True), + (60.0, True), + (67.89, False), + # Values that int(...) would not successfully cast should be ignored. + # The room notification level should then default to 50, per the spec, so + # Alice's notification is allowed. + (None, True), + # We haven't seen `"room": []` or `"room": {}` in the wild (yet), but + # let's check them for paranoia's sake. + ([], True), + ({}, True), + ] + ) + def test_action_for_event_by_user_handles_noninteger_room_power_levels( + self, bad_room_level: object, should_permit: bool + ) -> None: + """We should convert strings in `room` to integers before passing to Rust. + + Test this as follows: + - Create a room as Alice and invite two other users Bob and Charlie. + - Set PLs so that Alice has PL 60 and `notifications.room` is set to a bad value. + - Have Alice create a message notifying @room. + - Evaluate notification actions for that message. This should not raise. + - Look in the DB to see if that message triggered a highlight for Bob. + + The test is parameterised with two arguments: + - the bad power level value for "room", before JSON serisalistion + - whether Bob should expect the message to be highlighted Reproduces #14060. A lack of validation: the gift that keeps on giving. """ - # Create a new user and room. - alice = self.register_user("alice", "pass") - token = self.login(alice, "pass") + # Join another user to the room, so that there is someone to see Alice's + # @room notification. 
+ bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + self.helper.join(self.room_id, bob, tok=bob_token) - room_id = self.helper.create_room_as( - alice, room_version=RoomVersions.V9.identifier, tok=token - ) - - # Alter the power levels in that room to include stringy and floaty levels. - # We need to suppress the validation logic or else it will reject these dodgy - # values. (Presumably this validation was not always present.) - event_creation_handler = self.hs.get_event_creation_handler() - requester = create_requester(alice) + # Alter the power levels in that room to include the bad @room notification + # level. We need to suppress + # + # - canonicaljson validation, because canonicaljson forbids floats; + # - the event jsonschema validation, because it will forbid bad values; and + # - the auth rules checks, because they stop us from creating power levels + # with `"room": null`. (We want to test this case, because we have seen it + # in the wild.) + # + # We have seen stringy and null values for "room" in the wild, so presumably + # some of this validation was missing in the past. with patch("synapse.events.validator.validate_canonicaljson"), patch( "synapse.events.validator.jsonschema.validate" - ): - self.helper.send_state( - room_id, + ), patch("synapse.handlers.event_auth.check_state_dependent_auth_rules"): + pl_event_id = self.helper.send_state( + self.room_id, "m.room.power_levels", { - "users": {alice: "100"}, # stringy - "notifications": {"room": 100.0}, # float + "users": {self.alice: 60}, + "notifications": {"room": bad_room_level}, }, - token, + self.token, state_key="", - ) + )["event_id"] # Create a new message event, and try to evaluate it under the dodgy # power level event. event, context = self.get_success( - event_creation_handler.create_event( - requester, + self.event_creation_handler.create_event( + self.requester, { "type": "m.room.message", - "room_id": room_id, + "room_id": self.room_id, "content": { "msgtype": "m.text", - "body": "helo", + "body": "helo @room", }, - "sender": alice, + "sender": self.alice, }, + prev_event_ids=[pl_event_id], ) ) @@ -74,42 +151,262 @@ def test_action_for_event_by_user_handles_noninteger_power_levels(self) -> None: # should not raise self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) + # Did Bob see Alice's @room notification? + highlighted_actions = self.get_success( + self.hs.get_datastores().main.db_pool.simple_select_list( + table="event_push_actions_staging", + keyvalues={ + "event_id": event.event_id, + "user_id": bob, + "highlight": 1, + }, + retcols=("*",), + desc="get_event_push_actions_staging", + ) + ) + self.assertEqual(len(highlighted_actions), int(should_permit)) + @override_config({"push": {"enabled": False}}) def test_action_for_event_by_user_disabled_by_config(self) -> None: """Ensure that push rules are not calculated when disabled in the config""" - # Create a new user and room. - alice = self.register_user("alice", "pass") - token = self.login(alice, "pass") - - room_id = self.helper.create_room_as( - alice, room_version=RoomVersions.V9.identifier, tok=token - ) - - # Alter the power levels in that room to include stringy and floaty levels. - # We need to suppress the validation logic or else it will reject these dodgy - # values. (Presumably this validation was not always present.) 
- event_creation_handler = self.hs.get_event_creation_handler() - requester = create_requester(alice) - # Create a new message event, and try to evaluate it under the dodgy - # power level event. + # Create a new message event which should cause a notification. event, context = self.get_success( - event_creation_handler.create_event( - requester, + self.event_creation_handler.create_event( + self.requester, { "type": "m.room.message", - "room_id": room_id, + "room_id": self.room_id, "content": { "msgtype": "m.text", "body": "helo", }, - "sender": alice, + "sender": self.alice, }, ) ) bulk_evaluator = BulkPushRuleEvaluator(self.hs) + # Mock the method which calculates push rules -- we do this instead of + # e.g. checking the results in the database because we want to ensure + # that code isn't even running. bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment] - # should not raise + + # Ensure no actions are generated! self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) bulk_evaluator._action_for_event_by_user.assert_not_called() + + def _create_and_process( + self, bulk_evaluator: BulkPushRuleEvaluator, content: Optional[JsonDict] = None + ) -> bool: + """Returns true iff the `mentions` trigger an event push action.""" + # Create a new message event which should cause a notification. + event, context = self.get_success( + self.event_creation_handler.create_event( + self.requester, + { + "type": "test", + "room_id": self.room_id, + "content": content or {}, + "sender": f"@bob:{self.hs.hostname}", + }, + ) + ) + + # Execute the push rule machinery. + self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) + + # If any actions are generated for this event, return true. + result = self.get_success( + self.hs.get_datastores().main.db_pool.simple_select_list( + table="event_push_actions_staging", + keyvalues={"event_id": event.event_id}, + retcols=("*",), + desc="get_event_push_actions_staging", + ) + ) + return len(result) > 0 + + @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + def test_user_mentions(self) -> None: + """Test the behavior of an event which includes invalid user mentions.""" + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + + # Not including the mentions field should not notify. + self.assertFalse(self._create_and_process(bulk_evaluator)) + # An empty mentions field should not notify. + self.assertFalse( + self._create_and_process( + bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {}} + ) + ) + + # Non-dict mentions should be ignored. + mentions: Any + for mentions in (None, True, False, 1, "foo", []): + self.assertFalse( + self._create_and_process( + bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: mentions} + ) + ) + + # A non-list should be ignored. + for mentions in (None, True, False, 1, "foo", {}): + self.assertFalse( + self._create_and_process( + bulk_evaluator, + {EventContentFields.MSC3952_MENTIONS: {"user_ids": mentions}}, + ) + ) + + # The Matrix ID appearing anywhere in the list should notify. + self.assertTrue( + self._create_and_process( + bulk_evaluator, + {EventContentFields.MSC3952_MENTIONS: {"user_ids": [self.alice]}}, + ) + ) + self.assertTrue( + self._create_and_process( + bulk_evaluator, + { + EventContentFields.MSC3952_MENTIONS: { + "user_ids": ["@another:test", self.alice] + } + }, + ) + ) + + # Duplicate user IDs should notify. 
+ self.assertTrue( + self._create_and_process( + bulk_evaluator, + { + EventContentFields.MSC3952_MENTIONS: { + "user_ids": [self.alice, self.alice] + } + }, + ) + ) + + # Invalid entries in the list are ignored. + self.assertFalse( + self._create_and_process( + bulk_evaluator, + { + EventContentFields.MSC3952_MENTIONS: { + "user_ids": [None, True, False, {}, []] + } + }, + ) + ) + self.assertTrue( + self._create_and_process( + bulk_evaluator, + { + EventContentFields.MSC3952_MENTIONS: { + "user_ids": [None, True, False, {}, [], self.alice] + } + }, + ) + ) + + # The legacy push rule should not mention if the mentions field exists. + self.assertFalse( + self._create_and_process( + bulk_evaluator, + { + "body": self.alice, + "msgtype": "m.text", + EventContentFields.MSC3952_MENTIONS: {}, + }, + ) + ) + + @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + def test_room_mentions(self) -> None: + """Test the behavior of an event which includes invalid room mentions.""" + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + + # Room mentions from those without power should not notify. + self.assertFalse( + self._create_and_process( + bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {"room": True}} + ) + ) + + # Room mentions from those with power should notify. + self.helper.send_state( + self.room_id, + "m.room.power_levels", + {"notifications": {"room": 0}}, + self.token, + state_key="", + ) + self.assertTrue( + self._create_and_process( + bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {"room": True}} + ) + ) + + # Invalid data should not notify. + mentions: Any + for mentions in (None, False, 1, "foo", [], {}): + self.assertFalse( + self._create_and_process( + bulk_evaluator, + {EventContentFields.MSC3952_MENTIONS: {"room": mentions}}, + ) + ) + + # The legacy push rule should not mention if the mentions field exists. + self.assertFalse( + self._create_and_process( + bulk_evaluator, + { + "body": "@room", + "msgtype": "m.text", + EventContentFields.MSC3952_MENTIONS: {}, + }, + ) + ) + + @override_config({"experimental_features": {"msc3958_supress_edit_notifs": True}}) + def test_suppress_edits(self) -> None: + """Under the default push rules, event edits should not generate notifications.""" + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + + # Create & persist an event to use as the parent of the relation. + event, context = self.get_success( + self.event_creation_handler.create_event( + self.requester, + { + "type": "m.room.message", + "room_id": self.room_id, + "content": { + "msgtype": "m.text", + "body": "helo", + }, + "sender": self.alice, + }, + ) + ) + self.get_success( + self.event_creation_handler.handle_new_client_event( + self.requester, events_and_context=[(event, context)] + ) + ) + + # Room mentions from those without power should not notify. + self.assertFalse( + self._create_and_process( + bulk_evaluator, + { + "body": self.alice, + "m.relates_to": { + "rel_type": RelationTypes.REPLACE, + "event_id": event.event_id, + }, + }, + ) + ) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 57b2f0536e4a..ab8bb417e759 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -13,25 +13,28 @@ # limitations under the License. 
import email.message import os -from typing import Dict, List, Sequence, Tuple +from typing import Any, Dict, List, Sequence, Tuple import attr import pkg_resources from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes, SynapseError from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase -@attr.s +@attr.s(auto_attribs=True) class _User: "Helper wrapper for user ID and access token" - id = attr.ib() - token = attr.ib() + id: str + token: str class EmailPusherTests(HomeserverTestCase): @@ -41,10 +44,9 @@ class EmailPusherTests(HomeserverTestCase): room.register_servlets, login.register_servlets, ] - user_id = True hijack_auth = False - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["email"] = { @@ -72,17 +74,17 @@ def make_homeserver(self, reactor, clock): # List[Tuple[Deferred, args, kwargs]] self.email_attempts: List[Tuple[Deferred, Sequence, Dict]] = [] - def sendmail(*args, **kwargs): + def sendmail(*args: Any, **kwargs: Any) -> Deferred: # This mocks out synapse.reactor.send_email._sendmail. - d = Deferred() + d: Deferred = Deferred() self.email_attempts.append((d, args, kwargs)) return d - hs.get_send_email_handler()._sendmail = sendmail + hs.get_send_email_handler()._sendmail = sendmail # type: ignore[assignment] return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Register the user who gets notified self.user_id = self.register_user("user", "pass") self.access_token = self.login("user", "pass") @@ -129,7 +131,7 @@ def prepare(self, reactor, clock, hs): self.auth_handler = hs.get_auth_handler() self.store = hs.get_datastores().main - def test_need_validated_email(self): + def test_need_validated_email(self) -> None: """Test that we can only add an email pusher if the user has validated their email. """ @@ -151,7 +153,7 @@ def test_need_validated_email(self): self.assertEqual(400, cm.exception.code) self.assertEqual(Codes.THREEPID_NOT_FOUND, cm.exception.errcode) - def test_simple_sends_email(self): + def test_simple_sends_email(self) -> None: # Create a simple room with two users room = self.helper.create_room_as(self.user_id, tok=self.access_token) self.helper.invite( @@ -171,7 +173,7 @@ def test_simple_sends_email(self): self._check_for_mail() - def test_invite_sends_email(self): + def test_invite_sends_email(self) -> None: # Create a room and invite the user to it room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token) self.helper.invite( @@ -184,7 +186,7 @@ def test_invite_sends_email(self): # We should get emailed about the invite self._check_for_mail() - def test_invite_to_empty_room_sends_email(self): + def test_invite_to_empty_room_sends_email(self) -> None: # Create a room and invite the user to it room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token) self.helper.invite( @@ -200,7 +202,7 @@ def test_invite_to_empty_room_sends_email(self): # We should get emailed about the invite self._check_for_mail() - def test_multiple_members_email(self): + def test_multiple_members_email(self) -> None: # We want to test multiple notifications, so we pause processing of push # while we send messages. 
self.pusher._pause_processing() @@ -227,7 +229,7 @@ def test_multiple_members_email(self): # We should get emailed about those messages self._check_for_mail() - def test_multiple_rooms(self): + def test_multiple_rooms(self) -> None: # We want to test multiple notifications from multiple rooms, so we pause # processing of push while we send messages. self.pusher._pause_processing() @@ -257,7 +259,7 @@ def test_multiple_rooms(self): # We should get emailed about those messages self._check_for_mail() - def test_room_notifications_include_avatar(self): + def test_room_notifications_include_avatar(self) -> None: # Create a room and set its avatar. room = self.helper.create_room_as(self.user_id, tok=self.access_token) self.helper.send_state( @@ -290,7 +292,7 @@ def test_room_notifications_include_avatar(self): ) self.assertIn("_matrix/media/v1/thumbnail/DUMMY_MEDIA_ID", html) - def test_empty_room(self): + def test_empty_room(self) -> None: """All users leaving a room shouldn't cause the pusher to break.""" # Create a simple room with two users room = self.helper.create_room_as(self.user_id, tok=self.access_token) @@ -309,7 +311,7 @@ def test_empty_room(self): # We should get emailed about that message self._check_for_mail() - def test_empty_room_multiple_messages(self): + def test_empty_room_multiple_messages(self) -> None: """All users leaving a room shouldn't cause the pusher to break.""" # Create a simple room with two users room = self.helper.create_room_as(self.user_id, tok=self.access_token) @@ -329,7 +331,7 @@ def test_empty_room_multiple_messages(self): # We should get emailed about that message self._check_for_mail() - def test_encrypted_message(self): + def test_encrypted_message(self) -> None: room = self.helper.create_room_as(self.user_id, tok=self.access_token) self.helper.invite( room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id @@ -342,7 +344,7 @@ def test_encrypted_message(self): # We should get emailed about that message self._check_for_mail() - def test_no_email_sent_after_removed(self): + def test_no_email_sent_after_removed(self) -> None: # Create a simple room with two users room = self.helper.create_room_as(self.user_id, tok=self.access_token) self.helper.invite( @@ -379,7 +381,7 @@ def test_no_email_sent_after_removed(self): pushers = list(pushers) self.assertEqual(len(pushers), 0) - def test_remove_unlinked_pushers_background_job(self): + def test_remove_unlinked_pushers_background_job(self) -> None: """Checks that all existing pushers associated with unlinked email addresses are removed upon running the remove_deleted_email_pushers background update. """ diff --git a/tests/push/test_http.py b/tests/push/test_http.py index d89c3ac313c6..1e4f97406a10 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -46,7 +46,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: m = Mock() - def post_json_get_json(url, body): + def post_json_get_json(url: str, body: JsonDict) -> Deferred: d: Deferred = Deferred() self.push_attempts.append((d, url, body)) return make_deferred_yieldable(d) diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py index aff563919d43..d37f8ce26236 100644 --- a/tests/push/test_presentable_names.py +++ b/tests/push/test_presentable_names.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, cast from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions -from synapse.events import FrozenEvent +from synapse.events import EventBase, FrozenEvent from synapse.push.presentable_names import calculate_room_name from synapse.types import StateKey, StateMap @@ -51,13 +51,15 @@ def __init__(self, events: Iterable[Tuple[StateKey, dict]]): ) async def get_event( - self, event_id: StateKey, allow_none: bool = False + self, event_id: str, allow_none: bool = False ) -> Optional[FrozenEvent]: assert allow_none, "Mock not configured for allow_none = False" - return self._events.get(event_id) + # Decode the state key from the event ID. + state_key = cast(Tuple[str, str], tuple(event_id.split("|", 1))) + return self._events.get(state_key) - async def get_events(self, event_ids: Iterable[StateKey]): + async def get_events(self, event_ids: Iterable[StateKey]) -> StateMap[EventBase]: # This is cheating since it just returns all events. return self._events @@ -68,17 +70,17 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase): def _calculate_room_name( self, - events: StateMap[dict], + events: Iterable[Tuple[Tuple[str, str], dict]], user_id: str = "", fallback_to_members: bool = True, fallback_to_single_member: bool = True, - ): - # This isn't 100% accurate, but works with MockDataStore. - room_state_ids = {k[0]: k[0] for k in events} + ) -> Optional[str]: + # Encode the state key into the event ID. + room_state_ids = {k[0]: "|".join(k[0]) for k in events} return self.get_success( calculate_room_name( - MockDataStore(events), + MockDataStore(events), # type: ignore[arg-type] room_state_ids, user_id or self.USER_ID, fallback_to_members, @@ -86,9 +88,9 @@ def _calculate_room_name( ) ) - def test_name(self): + def test_name(self) -> None: """A room name event should be used.""" - events = [ + events: List[Tuple[Tuple[str, str], dict]] = [ ((EventTypes.Name, ""), {"name": "test-name"}), ] self.assertEqual("test-name", self._calculate_room_name(events)) @@ -100,9 +102,9 @@ def test_name(self): events = [((EventTypes.Name, ""), {"name": 1})] self.assertEqual(1, self._calculate_room_name(events)) - def test_canonical_alias(self): + def test_canonical_alias(self) -> None: """An canonical alias should be used.""" - events = [ + events: List[Tuple[Tuple[str, str], dict]] = [ ((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}), ] self.assertEqual("#test-name:test", self._calculate_room_name(events)) @@ -114,9 +116,9 @@ def test_canonical_alias(self): events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})] self.assertEqual("Empty Room", self._calculate_room_name(events)) - def test_invite(self): + def test_invite(self) -> None: """An invite has special behaviour.""" - events = [ + events: List[Tuple[Tuple[str, str], dict]] = [ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}), ((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}), ] @@ -140,9 +142,9 @@ def test_invite(self): ] self.assertEqual("Room Invite", self._calculate_room_name(events)) - def test_no_members(self): + def test_no_members(self) -> None: """Behaviour of an empty room.""" - events = [] + events: List[Tuple[Tuple[str, str], dict]] = [] self.assertEqual("Empty Room", self._calculate_room_name(events)) # Note that events with invalid (or missing) membership are ignored. 
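Aside: the MockDataStore changes above round-trip a state key through a synthetic event ID, packing it with "|".join(...) when building room_state_ids and splitting it back apart in get_event. A minimal, self-contained sketch of that round trip (the helper names are illustrative, not part of the test module, and it assumes event types never contain "|"):

from typing import Tuple

def encode_state_key(state_key: Tuple[str, str]) -> str:
    # Pack ("m.room.member", "@user:test") into one synthetic event ID, the same
    # shape as the "|".join(k[0]) used to build room_state_ids above.
    return "|".join(state_key)

def decode_state_key(event_id: str) -> Tuple[str, str]:
    # Reverse the packing; splitting on the first "|" keeps any "|" inside the
    # state key itself intact, mirroring event_id.split("|", 1) in get_event.
    event_type, state_key = event_id.split("|", 1)
    return (event_type, state_key)

assert decode_state_key(encode_state_key(("m.room.member", "@user:test"))) == (
    "m.room.member",
    "@user:test",
)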
@@ -152,7 +154,7 @@ def test_no_members(self): ] self.assertEqual("Empty Room", self._calculate_room_name(events)) - def test_no_other_members(self): + def test_no_other_members(self) -> None: """Behaviour of a room with no other members in it.""" events = [ ( @@ -185,7 +187,7 @@ def test_no_other_members(self): self._calculate_room_name(events, user_id=self.OTHER_USER_ID), ) - def test_one_other_member(self): + def test_one_other_member(self) -> None: """Behaviour of a room with a single other member.""" events = [ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}), @@ -209,7 +211,7 @@ def test_one_other_member(self): ] self.assertEqual("@user:test", self._calculate_room_name(events)) - def test_other_members(self): + def test_other_members(self) -> None: """Behaviour of a room with multiple other members.""" # Two other members. events = [ diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index e48d2f8d1910..297430d9e582 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Union +from typing import Any, Dict, List, Optional, Set, Union, cast import frozendict @@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.room_versions import RoomVersions from synapse.appservice import ApplicationService -from synapse.events import FrozenEvent +from synapse.events import FrozenEvent, make_event_from_dict from synapse.push.bulk_push_rule_evaluator import _flatten_dict from synapse.push.httppusher import tweaks_for_actions from synapse.rest import admin @@ -30,16 +30,107 @@ from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator, PushRules -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, JsonMapping, UserID from synapse.util import Clock from tests import unittest from tests.test_utils.event_injection import create_event, inject_member_event +class FlattenDictTestCase(unittest.TestCase): + def test_simple(self) -> None: + """Test a dictionary that isn't modified.""" + input = {"foo": "abc"} + self.assertEqual(input, _flatten_dict(input)) + + def test_nested(self) -> None: + """Nested dictionaries become dotted paths.""" + input = {"foo": {"bar": "abc"}} + self.assertEqual({"foo.bar": "abc"}, _flatten_dict(input)) + + def test_non_string(self) -> None: + """Non-string items are dropped.""" + input: Dict[str, Any] = { + "woo": "woo", + "foo": True, + "bar": 1, + "baz": None, + "fuzz": [], + "boo": {}, + } + self.assertEqual({"woo": "woo"}, _flatten_dict(input)) + + def test_event(self) -> None: + """Events can also be flattened.""" + event = make_event_from_dict( + { + "room_id": "!test:test", + "type": "m.room.message", + "sender": "@alice:test", + "content": { + "msgtype": "m.text", + "body": "Hello world!", + "format": "org.matrix.custom.html", + "formatted_body": "
<h1>Hello world!</h1>
", + }, + }, + room_version=RoomVersions.V8, + ) + expected = { + "content.msgtype": "m.text", + "content.body": "hello world!", + "content.format": "org.matrix.custom.html", + "content.formatted_body": "
<h1>hello world!</h1>
", + "room_id": "!test:test", + "sender": "@alice:test", + "type": "m.room.message", + } + self.assertEqual(expected, _flatten_dict(event)) + + def test_extensible_events(self) -> None: + """Extensible events has compatibility behaviour.""" + event_dict = { + "room_id": "!test:test", + "type": "m.room.message", + "sender": "@alice:test", + "content": { + "org.matrix.msc1767.markup": [ + {"mimetype": "text/plain", "body": "Hello world!"}, + {"mimetype": "text/html", "body": "
<h1>Hello world!</h1>
"}, + ] + }, + } + + # For a current room version, there's no special behavior. + event = make_event_from_dict(event_dict, room_version=RoomVersions.V8) + expected = { + "room_id": "!test:test", + "sender": "@alice:test", + "type": "m.room.message", + } + self.assertEqual(expected, _flatten_dict(event)) + + # For a room version with extensible events, they parse out the text/plain + # to a content.body property. + event = make_event_from_dict(event_dict, room_version=RoomVersions.MSC1767v10) + expected = { + "content.body": "hello world!", + "room_id": "!test:test", + "sender": "@alice:test", + "type": "m.room.message", + } + self.assertEqual(expected, _flatten_dict(event)) + + class PushRuleEvaluatorTestCase(unittest.TestCase): def _get_evaluator( - self, content: JsonDict, related_events=None + self, + content: JsonMapping, + *, + has_mentions: bool = False, + user_mentions: Optional[Set[str]] = None, + room_mention: bool = False, + related_events: Optional[JsonDict] = None, ) -> PushRuleEvaluator: event = FrozenEvent( { @@ -57,22 +148,23 @@ def _get_evaluator( power_levels: Dict[str, Union[int, Dict[str, int]]] = {} return PushRuleEvaluator( _flatten_dict(event), + has_mentions, + user_mentions or set(), + room_mention, room_member_count, sender_power_level, - power_levels.get("notifications", {}), + cast(Dict[str, int], power_levels.get("notifications", {})), {} if related_events is None else related_events, - True, - event.room_version.msc3931_push_features, - True, + related_event_match_enabled=True, + room_version_feature_flags=event.room_version.msc3931_push_features, + msc3931_enabled=True, ) def test_display_name(self) -> None: """Check for a matching display name in the body of the event.""" evaluator = self._get_evaluator({"body": "foo bar baz"}) - condition = { - "kind": "contains_display_name", - } + condition = {"kind": "contains_display_name"} # Blank names are skipped. self.assertFalse(evaluator.matches(condition, "@user:test", "")) @@ -92,8 +184,55 @@ def test_display_name(self) -> None: # A display name with spaces should work fine. self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) + def test_user_mentions(self) -> None: + """Check for user mentions.""" + condition = {"kind": "org.matrix.msc3952.is_user_mention"} + + # No mentions shouldn't match. + evaluator = self._get_evaluator({}, has_mentions=True) + self.assertFalse(evaluator.matches(condition, "@user:test", None)) + + # An empty set shouldn't match + evaluator = self._get_evaluator({}, has_mentions=True, user_mentions=set()) + self.assertFalse(evaluator.matches(condition, "@user:test", None)) + + # The Matrix ID appearing anywhere in the mentions list should match + evaluator = self._get_evaluator( + {}, has_mentions=True, user_mentions={"@user:test"} + ) + self.assertTrue(evaluator.matches(condition, "@user:test", None)) + + evaluator = self._get_evaluator( + {}, has_mentions=True, user_mentions={"@another:test", "@user:test"} + ) + self.assertTrue(evaluator.matches(condition, "@user:test", None)) + + # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions + # since the BulkPushRuleEvaluator is what handles data sanitisation. + + def test_room_mentions(self) -> None: + """Check for room mentions.""" + condition = {"kind": "org.matrix.msc3952.is_room_mention"} + + # No room mention shouldn't match. 
+ evaluator = self._get_evaluator({}, has_mentions=True) + self.assertFalse(evaluator.matches(condition, None, None)) + + # Room mention should match. + evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True) + self.assertTrue(evaluator.matches(condition, None, None)) + + # A room mention and user mention is valid. + evaluator = self._get_evaluator( + {}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True + ) + self.assertTrue(evaluator.matches(condition, None, None)) + + # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions + # since the BulkPushRuleEvaluator is what handles data sanitisation. + def _assert_matches( - self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None + self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None ) -> None: evaluator = self._get_evaluator(content) self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg) @@ -287,7 +426,7 @@ def test_tweaks_for_actions(self) -> None: This tests the behaviour of tweaks_for_actions. """ - actions = [ + actions: List[Union[Dict[str, str], str]] = [ {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}, "notify", @@ -298,7 +437,7 @@ def test_tweaks_for_actions(self) -> None: {"sound": "default", "highlight": True}, ) - def test_related_event_match(self): + def test_related_event_match(self) -> None: evaluator = self._get_evaluator( { "m.relates_to": { @@ -310,7 +449,7 @@ def test_related_event_match(self): }, } }, - { + related_events={ "m.in_reply_to": { "event_id": "$parent_event_id", "type": "m.room.message", @@ -397,7 +536,7 @@ def test_related_event_match(self): ) ) - def test_related_event_match_with_fallback(self): + def test_related_event_match_with_fallback(self) -> None: evaluator = self._get_evaluator( { "m.relates_to": { @@ -410,7 +549,7 @@ def test_related_event_match_with_fallback(self): }, } }, - { + related_events={ "m.in_reply_to": { "event_id": "$parent_event_id", "type": "m.room.message", @@ -469,7 +608,7 @@ def test_related_event_match_with_fallback(self): ) ) - def test_related_event_match_no_related_event(self): + def test_related_event_match_no_related_event(self) -> None: evaluator = self._get_evaluator( {"msgtype": "m.text", "body": "Message without related event"} ) @@ -518,7 +657,9 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: # Define an application service so that we can register appservice users self._service_token = "some_token" self._service = ApplicationService( diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 6a7174b333c9..46a8e2013e41 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -16,7 +16,9 @@ from typing import Any, Dict, List, Optional, Set, Tuple from twisted.internet.address import IPv4Address -from twisted.internet.protocol import Protocol +from twisted.internet.protocol import Protocol, connectionDone +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource from synapse.app.generic_worker import GenericWorkerServer @@ -30,6 +32,7 @@ ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.server import HomeServer +from 
synapse.util import Clock from tests import unittest from tests.server import FakeTransport @@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): if not hiredis: skip = "Requires hiredis" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) self.streamer = hs.get_replication_streamer() @@ -92,8 +95,8 @@ def prepare(self, reactor, clock, hs): repl_handler, ) - self._client_transport = None - self._server_transport = None + self._client_transport: Optional[FakeTransport] = None + self._server_transport: Optional[FakeTransport] = None def create_resource_dict(self) -> Dict[str, Resource]: d = super().create_resource_dict() @@ -107,10 +110,10 @@ def _get_worker_hs_config(self) -> dict: config["worker_replication_http_port"] = "8765" return config - def _build_replication_data_handler(self): + def _build_replication_data_handler(self) -> "TestReplicationDataHandler": return TestReplicationDataHandler(self.worker_hs) - def reconnect(self): + def reconnect(self) -> None: if self._client_transport: self.client.close() @@ -123,7 +126,7 @@ def reconnect(self): self._server_transport = FakeTransport(self.client, self.reactor) self.server.makeConnection(self._server_transport) - def disconnect(self): + def disconnect(self) -> None: if self._client_transport: self._client_transport = None self.client.close() @@ -132,7 +135,7 @@ def disconnect(self): self._server_transport = None self.server.close() - def replicate(self): + def replicate(self) -> None: """Tell the master side of replication that something has happened, and then wait for the replication to occur. """ @@ -168,7 +171,7 @@ def handle_http_replication_attempt(self) -> SynapseRequest: requests: List[SynapseRequest] = [] real_request_factory = channel.requestFactory - def request_factory(*args, **kwargs): + def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest: request = real_request_factory(*args, **kwargs) requests.append(request) return request @@ -202,7 +205,7 @@ def request_factory(*args, **kwargs): def assert_request_is_get_repl_stream_updates( self, request: SynapseRequest, stream_name: str - ): + ) -> None: """Asserts that the given request is a HTTP replication request for fetching updates for given stream. """ @@ -244,7 +247,7 @@ def default_config(self) -> Dict[str, Any]: base["redis"] = {"enabled": True} return base - def setUp(self): + def setUp(self) -> None: super().setUp() # build a replication server @@ -287,7 +290,7 @@ def setUp(self): lambda: self._handle_http_replication_attempt(self.hs, 8765), ) - def create_test_resource(self): + def create_test_resource(self) -> ReplicationRestResource: """Overrides `HomeserverTestCase.create_test_resource`.""" # We override this so that it automatically registers all the HTTP # replication servlets, without having to explicitly do that in all @@ -301,7 +304,7 @@ def create_test_resource(self): return resource def make_worker_hs( - self, worker_app: str, extra_config: Optional[dict] = None, **kwargs + self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any ) -> HomeServer: """Make a new worker HS instance, correctly connecting replcation stream to the master HS. 
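The handle_http_replication_attempt helper above captures outgoing replication requests by swapping channel.requestFactory for a wrapper that records everything the real factory builds. A self-contained sketch of that wrap-and-record pattern (the names and the placeholder path are illustrative, not Synapse APIs):

from typing import Any, Callable, List

def recording_factory(
    real_factory: Callable[..., Any], sink: List[Any]
) -> Callable[..., Any]:
    # Wrap a factory so every object it builds is also appended to `sink` for
    # later assertions, as request_factory does with SynapseRequest objects.
    def factory(*args: Any, **kwargs: Any) -> Any:
        obj = real_factory(*args, **kwargs)
        sink.append(obj)
        return obj

    return factory

built: List[str] = []
make_request = recording_factory(lambda path: "GET " + path, built)
make_request("/some/replication/path")
assert built == ["GET /some/replication/path"]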
@@ -385,14 +388,14 @@ def _get_worker_hs_config(self) -> dict: config["worker_replication_http_port"] = "8765" return config - def replicate(self): + def replicate(self) -> None: """Tell the master side of replication that something has happened, and then wait for the replication to occur. """ self.streamer.on_notifier_poke() self.pump() - def _handle_http_replication_attempt(self, hs, repl_port): + def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None: """Handles a connection attempt to the given HS replication HTTP listener on the given port. """ @@ -429,7 +432,7 @@ def _handle_http_replication_attempt(self, hs, repl_port): # inside `connecTCP` before the connection has been passed back to the # code that requested the TCP connection. - def connect_any_redis_attempts(self): + def connect_any_redis_attempts(self) -> None: """If redis is enabled we need to deal with workers connecting to a redis server. We don't want to use a real Redis server so we use a fake one. @@ -440,8 +443,11 @@ def connect_any_redis_attempts(self): self.assertEqual(host, "localhost") self.assertEqual(port, 6379) - client_protocol = client_factory.buildProtocol(None) - server_protocol = self._redis_server.buildProtocol(None) + client_address = IPv4Address("TCP", "127.0.0.1", 6379) + client_protocol = client_factory.buildProtocol(client_address) + + server_address = IPv4Address("TCP", host, port) + server_protocol = self._redis_server.buildProtocol(server_address) client_to_server_transport = FakeTransport( server_protocol, self.reactor, client_protocol @@ -463,7 +469,9 @@ def __init__(self, hs: HomeServer): # list of received (stream_name, token, row) tuples self.received_rdata_rows: List[Tuple[str, int, Any]] = [] - async def on_rdata(self, stream_name, instance_name, token, rows): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ) -> None: await super().on_rdata(stream_name, instance_name, token, rows) for r in rows: self.received_rdata_rows.append((stream_name, token, r)) @@ -472,28 +480,30 @@ async def on_rdata(self, stream_name, instance_name, token, rows): class FakeRedisPubSubServer: """A fake Redis server for pub/sub.""" - def __init__(self): + def __init__(self) -> None: self._subscribers_by_channel: Dict[ bytes, Set["FakeRedisPubSubProtocol"] ] = defaultdict(set) - def add_subscriber(self, conn, channel: bytes): + def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None: """A connection has called SUBSCRIBE""" self._subscribers_by_channel[channel].add(conn) - def remove_subscriber(self, conn): + def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None: """A connection has lost connection""" for subscribers in self._subscribers_by_channel.values(): subscribers.discard(conn) - def publish(self, conn, channel: bytes, msg) -> int: + def publish( + self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object + ) -> int: """A connection want to publish a message to subscribers.""" for sub in self._subscribers_by_channel[channel]: sub.send(["message", channel, msg]) return len(self._subscribers_by_channel) - def buildProtocol(self, addr): + def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol": return FakeRedisPubSubProtocol(self) @@ -506,7 +516,7 @@ def __init__(self, server: FakeRedisPubSubServer): self._server = server self._reader = hiredis.Reader() - def dataReceived(self, data): + def dataReceived(self, data: bytes) -> None: self._reader.feed(data) # We might get multiple 
messages in one packet. @@ -523,7 +533,7 @@ def dataReceived(self, data): self.handle_command(msg[0], *msg[1:]) - def handle_command(self, command, *args): + def handle_command(self, command: bytes, *args: bytes) -> None: """Received a Redis command from the client.""" # We currently only support pub/sub. @@ -548,9 +558,9 @@ def handle_command(self, command, *args): self.send("PONG") else: - raise Exception(f"Unknown command: {command}") + raise Exception(f"Unknown command: {command!r}") - def send(self, msg): + def send(self, msg: object) -> None: """Send a message back to the client.""" assert self.transport is not None @@ -559,7 +569,7 @@ def send(self, msg): self.transport.write(raw) self.transport.flush() - def encode(self, obj): + def encode(self, obj: object) -> str: """Encode an object to its Redis format. Supports: strings/bytes, integers and list/tuples. @@ -581,5 +591,5 @@ def encode(self, obj): raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) - def connectionLost(self, reason): + def connectionLost(self, reason: Failure = connectionDone) -> None: self._server.remove_subscriber(self) diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py index 936ab4504a79..9be11ab80217 100644 --- a/tests/replication/http/test__base.py +++ b/tests/replication/http/test__base.py @@ -44,7 +44,7 @@ async def _serialize_payload() -> JsonDict: @cancellable async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, content: JsonDict ) -> Tuple[int, JsonDict]: await self.clock.sleep(1.0) return HTTPStatus.OK, {"result": True} @@ -54,6 +54,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint): NAME = "uncancellable_sleep" PATH_ARGS = () CACHE = False + WAIT_FOR_STREAMS = False def __init__(self, hs: HomeServer): super().__init__(hs) @@ -64,7 +65,7 @@ async def _serialize_payload() -> JsonDict: return {} async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, content: JsonDict ) -> Tuple[int, JsonDict]: await self.clock.sleep(1.0) return HTTPStatus.OK, {"result": True} @@ -73,7 +74,7 @@ async def _handle_request( # type: ignore[override] class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase): """Tests for `ReplicationEndpoint` cancellation.""" - def create_test_resource(self): + def create_test_resource(self) -> JsonResource: """Overrides `HomeserverTestCase.create_test_resource`.""" resource = JsonResource(self.hs) @@ -85,7 +86,7 @@ def create_test_resource(self): def test_cancellable_disconnect(self) -> None: """Test that handlers with the `@cancellable` flag can be cancelled.""" path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/" - channel = self.make_request("POST", path, await_result=False) + channel = self.make_request("POST", path, await_result=False, content={}) test_disconnect( self.reactor, channel, @@ -96,7 +97,7 @@ def test_cancellable_disconnect(self) -> None: def test_uncancellable_disconnect(self) -> None: """Test that handlers without the `@cancellable` flag cannot be cancelled.""" path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/" - channel = self.make_request("POST", path, await_result=False) + channel = self.make_request("POST", path, await_result=False, content={}) test_disconnect( self.reactor, channel, diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index c5705256e6fa..4c9b494344ae 100644 --- 
a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -13,35 +13,42 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Iterable, Optional from unittest.mock import Mock -from tests.replication._base import BaseStreamTestCase +from twisted.test.proto_helpers import MemoryReactor +from synapse.server import HomeServer +from synapse.util import Clock -class BaseSlavedStoreTestCase(BaseStreamTestCase): - def make_homeserver(self, reactor, clock): +from tests.replication._base import BaseStreamTestCase - hs = self.setup_test_homeserver(federation_client=Mock()) - return hs +class BaseSlavedStoreTestCase(BaseStreamTestCase): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + return self.setup_test_homeserver(federation_client=Mock()) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) self.reconnect() self.master_store = hs.get_datastores().main self.slaved_store = self.worker_hs.get_datastores().main - self._storage_controllers = hs.get_storage_controllers() + persistence = hs.get_storage_controllers().persistence + assert persistence is not None + self.persistance = persistence - def replicate(self): + def replicate(self) -> None: """Tell the master side of replication that something has happened, and then wait for the replication to occur. """ self.streamer.on_notifier_poke() self.pump(0.1) - def check(self, method, args, expected_result=None): + def check( + self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None + ) -> None: master_result = self.get_success(getattr(self.master_store, method)(*args)) slaved_result = self.get_success(getattr(self.slaved_store, method)(*args)) if expected_result is not None: diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index becbc88947a9..f4df28b193a8 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -12,15 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import Iterable, Optional +from typing import Any, Callable, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import ReceiptTypes from synapse.api.room_versions import RoomVersions -from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict +from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.handlers.room import RoomEventSource +from synapse.server import HomeServer from synapse.storage.databases.main.event_push_actions import ( NotifCounts, RoomNotifCounts, @@ -28,6 +32,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.types import PersistedEventPosition +from synapse.util import Clock from tests.server import FakeTransport @@ -41,19 +46,19 @@ logger = logging.getLogger(__name__) -def dict_equals(self, other): +def dict_equals(self: EventBase, other: EventBase) -> bool: me = encode_canonical_json(self.get_pdu_json()) them = encode_canonical_json(other.get_pdu_json()) return me == them -def patch__eq__(cls): +def patch__eq__(cls: object) -> Callable[[], None]: eq = getattr(cls, "__eq__", None) - cls.__eq__ = dict_equals + cls.__eq__ = dict_equals # type: ignore[assignment] - def unpatch(): + def unpatch() -> None: if eq is not None: - cls.__eq__ = eq + cls.__eq__ = eq # type: ignore[assignment] return unpatch @@ -62,14 +67,14 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): STORE_TYPE = EventsWorkerStore - def setUp(self): + def setUp(self) -> None: # Patch up the equality operator for events so that we can check # whether lists of events match using assertEqual - self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] - return super().setUp() + self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)] + super().setUp() - def prepare(self, *args, **kwargs): - super().prepare(*args, **kwargs) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) self.get_success( self.master_store.store_room( @@ -80,10 +85,10 @@ def prepare(self, *args, **kwargs): ) ) - def tearDown(self): + def tearDown(self) -> None: [unpatch() for unpatch in self.unpatches] - def test_get_latest_event_ids_in_room(self): + def test_get_latest_event_ids_in_room(self) -> None: create = self.persist(type="m.room.create", key="", creator=USER_ID) self.replicate() self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) @@ -97,7 +102,7 @@ def test_get_latest_event_ids_in_room(self): self.replicate() self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) - def test_redactions(self): + def test_redactions(self) -> None: self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.member", key=USER_ID, membership="join") @@ -117,7 +122,7 @@ def test_redactions(self): ) self.check("get_event", [msg.event_id], redacted) - def test_backfilled_redactions(self): + def test_backfilled_redactions(self) -> None: self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.member", key=USER_ID, membership="join") @@ -139,7 +144,7 @@ def test_backfilled_redactions(self): ) self.check("get_event", 
[msg.event_id], redacted) - def test_invites(self): + def test_invites(self) -> None: self.persist(type="m.room.create", key="", creator=USER_ID) self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") @@ -163,7 +168,7 @@ def test_invites(self): ) @parameterized.expand([(True,), (False,)]) - def test_push_actions_for_user(self, send_receipt: bool): + def test_push_actions_for_user(self, send_receipt: bool) -> None: self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.member", key=USER_ID, membership="join") self.persist( @@ -219,7 +224,7 @@ def test_push_actions_for_user(self, send_receipt: bool): ), ) - def test_get_rooms_for_user_with_stream_ordering(self): + def test_get_rooms_for_user_with_stream_ordering(self) -> None: """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated by rows in the events stream """ @@ -243,7 +248,9 @@ def test_get_rooms_for_user_with_stream_ordering(self): {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, ) - def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): + def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist( + self, + ) -> None: """Check that current_state invalidation happens correctly with multiple events in the persistence batch. @@ -283,11 +290,7 @@ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) msg, msgctx = self.build_event() - self.get_success( - self._storage_controllers.persistence.persist_events( - [(j2, j2ctx), (msg, msgctx)] - ) - ) + self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)])) self.replicate() assert j2.internal_metadata.stream_ordering is not None @@ -339,7 +342,7 @@ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): event_id = 0 - def persist(self, backfill=False, **kwargs) -> FrozenEvent: + def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase: """ Returns: The event that was persisted. 
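The patch__eq__/unpatch helpers in this file temporarily replace __eq__ on the event classes with a canonical-JSON comparison and hand back a callback that tearDown uses to restore the original behaviour. A generic sketch of that patch-and-restore shape (a simplified variant for illustration, not the module's exact helper):

from typing import Any, Callable

def patch_eq(cls: type, new_eq: Callable[[Any, Any], bool]) -> Callable[[], None]:
    # Install a custom __eq__ and return a callback that undoes the patch.
    original = cls.__dict__.get("__eq__")
    cls.__eq__ = new_eq  # type: ignore[assignment]

    def unpatch() -> None:
        if original is not None:
            cls.__eq__ = original  # type: ignore[assignment]
        else:
            del cls.__eq__

    return unpatch

class Point:
    def __init__(self, x: int) -> None:
        self.x = x

undo = patch_eq(Point, lambda a, b: isinstance(b, Point) and a.x == b.x)
assert Point(1) == Point(1)   # structural equality while patched
undo()
assert Point(1) != Point(1)   # back to the default identity-based equality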
@@ -348,32 +351,28 @@ def persist(self, backfill=False, **kwargs) -> FrozenEvent: if backfill: self.get_success( - self._storage_controllers.persistence.persist_events( - [(event, context)], backfilled=True - ) + self.persistance.persist_events([(event, context)], backfilled=True) ) else: - self.get_success( - self._storage_controllers.persistence.persist_event(event, context) - ) + self.get_success(self.persistance.persist_event(event, context)) return event def build_event( self, - sender=USER_ID, - room_id=ROOM_ID, - type="m.room.message", - key=None, + sender: str = USER_ID, + room_id: str = ROOM_ID, + type: str = "m.room.message", + key: Optional[str] = None, internal: Optional[dict] = None, - depth=None, - prev_events: Optional[list] = None, - auth_events: Optional[list] = None, - prev_state: Optional[list] = None, - redacts=None, + depth: Optional[int] = None, + prev_events: Optional[List[Tuple[str, dict]]] = None, + auth_events: Optional[List[str]] = None, + prev_state: Optional[List[str]] = None, + redacts: Optional[str] = None, push_actions: Iterable = frozenset(), - **content, - ): + **content: object, + ) -> Tuple[EventBase, EventContext]: prev_events = prev_events or [] auth_events = auth_events or [] prev_state = prev_state or [] diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py index 50fbff5f324f..01df1be0473c 100644 --- a/tests/replication/tcp/streams/test_account_data.py +++ b/tests/replication/tcp/streams/test_account_data.py @@ -21,7 +21,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase): - def test_update_function_room_account_data_limit(self): + def test_update_function_room_account_data_limit(self) -> None: """Test replication with many room account data updates""" store = self.hs.get_datastores().main @@ -67,7 +67,7 @@ def test_update_function_room_account_data_limit(self): self.assertEqual([], received_rows) - def test_update_function_global_account_data_limit(self): + def test_update_function_global_account_data_limit(self) -> None: """Test replication with many global account data updates""" store = self.hs.get_datastores().main diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 641a94133b1d..043dbe76afb0 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional +from typing import Any, List, Optional + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase @@ -25,6 +27,8 @@ ) from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests.replication._base import BaseStreamTestCase from tests.test_utils.event_injection import inject_event, inject_member_event @@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) self.user_id = self.register_user("u1", "pass") self.user_tok = self.login("u1", "pass") @@ -47,7 +51,7 @@ def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(tok=self.user_tok) self.test_handler.received_rdata_rows.clear() - def test_update_function_event_row_limit(self): + def test_update_function_event_row_limit(self) -> None: """Test replication with many non-state events Checks that all events are correctly replicated when there are lots of @@ -102,7 +106,7 @@ def test_update_function_event_row_limit(self): self.assertEqual([], received_rows) - def test_update_function_huge_state_change(self): + def test_update_function_huge_state_change(self) -> None: """Test replication with many state events Ensures that all events are correctly replicated when there are lots of @@ -256,7 +260,7 @@ def test_update_function_huge_state_change(self): # "None" indicates the state has been deleted self.assertIsNone(sr.event_id) - def test_update_function_state_row_limit(self): + def test_update_function_state_row_limit(self) -> None: """Test replication with many state events over several stream ids.""" # we want to generate lots of state changes, but for this test, we want to @@ -376,7 +380,7 @@ def test_update_function_state_row_limit(self): self.assertEqual([], received_rows) - def test_backwards_stream_id(self): + def test_backwards_stream_id(self) -> None: """ Test that RDATA that comes after the current position should be discarded. """ @@ -437,7 +441,7 @@ def test_backwards_stream_id(self): event_count = 0 def _inject_test_event( - self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs + self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any ) -> EventBase: if sender is None: sender = self.user_id diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py index bcb82c9c8051..cdbdfaf057dc 100644 --- a/tests/replication/tcp/streams/test_federation.py +++ b/tests/replication/tcp/streams/test_federation.py @@ -26,7 +26,7 @@ def _get_worker_hs_config(self) -> dict: config["federation_sender_instances"] = ["federation_sender1"] return config - def test_catchup(self): + def test_catchup(self) -> None: """Basic test of catchup on reconnect Makes sure that updates sent while we are offline are received later. 
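Many of these hunks only add "-> None" and parameter annotations to test methods. Under mypy's default settings the body of an unannotated function is skipped entirely, so the signature annotation is what actually brings the test body under the type checker. A tiny illustration (assuming default mypy settings, i.e. check_untyped_defs left off):

def untyped_test():
    x: int = "oops"  # mypy skips this body entirely: the function is unannotated

def typed_test() -> None:
    x: int = "oops"  # mypy reports an incompatible assignment here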
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py index 2c10eab4dbcd..38b5020ce0e1 100644 --- a/tests/replication/tcp/streams/test_partial_state.py +++ b/tests/replication/tcp/streams/test_partial_state.py @@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase): hijack_auth = True user_id = "@bob:test" - def setUp(self): + def setUp(self) -> None: super().setUp() self.store = self.hs.get_datastores().main diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index 9a229dd23f4a..68de5d1cc280 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -27,10 +27,11 @@ class TypingStreamTestCase(BaseStreamTestCase): - def _build_replication_data_handler(self): - return Mock(wraps=super()._build_replication_data_handler()) + def _build_replication_data_handler(self) -> Mock: + self.mock_handler = Mock(wraps=super()._build_replication_data_handler()) + return self.mock_handler - def test_typing(self): + def test_typing(self) -> None: typing = self.hs.get_typing_handler() self.reconnect() @@ -43,8 +44,8 @@ def test_typing(self): request = self.handle_http_replication_attempt() self.assert_request_is_get_repl_stream_updates(request, "typing") - self.test_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.mock_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row: TypingStream.TypingStreamRow = rdata_rows[0] @@ -54,11 +55,11 @@ def test_typing(self): # Now let's disconnect and insert some data. self.disconnect() - self.test_handler.on_rdata.reset_mock() + self.mock_handler.on_rdata.reset_mock() typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False) - self.test_handler.on_rdata.assert_not_called() + self.mock_handler.on_rdata.assert_not_called() self.reconnect() self.pump(0.1) @@ -71,15 +72,15 @@ def test_typing(self): assert request.args is not None self.assertEqual(int(request.args[b"from_token"][0]), token) - self.test_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.mock_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] self.assertEqual(ROOM_ID, row.room_id) self.assertEqual([], row.user_ids) - def test_reset(self): + def test_reset(self) -> None: """ Test what happens when a typing stream resets. @@ -98,8 +99,8 @@ def test_reset(self): request = self.handle_http_replication_attempt() self.assert_request_is_get_repl_stream_updates(request, "typing") - self.test_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.mock_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row: TypingStream.TypingStreamRow = rdata_rows[0] @@ -134,15 +135,15 @@ def test_reset(self): self.assert_request_is_get_repl_stream_updates(request, "typing") # Reset the test code. 
- self.test_handler.on_rdata.reset_mock() - self.test_handler.on_rdata.assert_not_called() + self.mock_handler.on_rdata.reset_mock() + self.mock_handler.on_rdata.assert_not_called() # Push additional data. typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False) self.reactor.advance(0) - self.test_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.mock_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py index cca7ebb7195d..5d6b72b16d84 100644 --- a/tests/replication/tcp/test_commands.py +++ b/tests/replication/tcp/test_commands.py @@ -21,12 +21,12 @@ class ParseCommandTestCase(TestCase): - def test_parse_one_word_command(self): + def test_parse_one_word_command(self) -> None: line = "REPLICATE" cmd = parse_command_from_line(line) self.assertIsInstance(cmd, ReplicateCommand) - def test_parse_rdata(self): + def test_parse_rdata(self) -> None: line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' cmd = parse_command_from_line(line) assert isinstance(cmd, RdataCommand) @@ -34,7 +34,7 @@ def test_parse_rdata(self): self.assertEqual(cmd.instance_name, "master") self.assertEqual(cmd.token, 6287863) - def test_parse_rdata_batch(self): + def test_parse_rdata_batch(self) -> None: line = 'RDATA presence master batch ["@foo:example.com", "online"]' cmd = parse_command_from_line(line) assert isinstance(cmd, RdataCommand) diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py index 1e299d2d67ea..6e4055cc2102 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + +from synapse.replication.tcp.commands import PositionCommand + from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -71,3 +75,68 @@ def test_non_background_worker_not_subscribed_to_user_ip(self) -> None: self.assertEqual( len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1 ) + + def test_wait_for_stream_position(self) -> None: + """Check that wait for stream position correctly waits for an update from the + correct instance. + """ + store = self.hs.get_datastores().main + cmd_handler = self.hs.get_replication_command_handler() + data_handler = self.hs.get_replication_data_handler() + + worker1 = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "worker_name": "worker1", + "run_background_tasks_on": "worker1", + "redis": {"enabled": True}, + }, + ) + + cache_id_gen = worker1.get_datastores().main._cache_id_gen + assert cache_id_gen is not None + + self.replicate() + + # First, make sure the master knows that `worker1` exists. + initial_token = cache_id_gen.get_current_token() + cmd_handler.send_command( + PositionCommand("caches", "worker1", initial_token, initial_token) + ) + self.replicate() + + # Next send out a normal RDATA, and check that waiting for that stream + # ID returns immediately. 
+ ctx = cache_id_gen.get_next() + next_token = self.get_success(ctx.__aenter__()) + self.get_success(ctx.__aexit__(None, None, None)) + + self.get_success( + data_handler.wait_for_stream_position("worker1", "caches", next_token) + ) + + # `wait_for_stream_position` should only return once master receives a + # notification that `next_token` has persisted. + ctx_worker1 = cache_id_gen.get_next() + next_token = self.get_success(ctx_worker1.__aenter__()) + + d = defer.ensureDeferred( + data_handler.wait_for_stream_position("worker1", "caches", next_token) + ) + self.assertFalse(d.called) + + # ... updating the cache ID gen on the master still shouldn't cause the + # deferred to wake up. + ctx = store._cache_id_gen.get_next() + self.get_success(ctx.__aenter__()) + self.get_success(ctx.__aexit__(None, None, None)) + + d = defer.ensureDeferred( + data_handler.wait_for_stream_position("worker1", "caches", next_token) + ) + self.assertFalse(d.called) + + # ... but worker1 finishing (and so sending an update) should. + self.get_success(ctx_worker1.__aexit__(None, None, None)) + + self.assertTrue(d.called) diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py index 545f11acd1bf..b75fc05fd55b 100644 --- a/tests/replication/tcp/test_remote_server_up.py +++ b/tests/replication/tcp/test_remote_server_up.py @@ -16,15 +16,17 @@ from twisted.internet.address import IPv4Address from twisted.internet.interfaces import IProtocol -from twisted.test.proto_helpers import StringTransport +from twisted.test.proto_helpers import MemoryReactor, StringTransport from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase class RemoteServerUpTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.factory = ReplicationStreamProtocolFactory(hs) def _make_client(self) -> Tuple[IProtocol, StringTransport]: @@ -40,7 +42,7 @@ def _make_client(self) -> Tuple[IProtocol, StringTransport]: return proto, transport - def test_relay(self): + def test_relay(self) -> None: """Test that Synapse will relay REMOTE_SERVER_UP commands to all other connections, but not the one that sent it. """ diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index 5d7a89e0c702..98602371e467 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -13,7 +13,11 @@ # limitations under the License. import logging +from twisted.test.proto_helpers import MemoryReactor + from synapse.rest.client import register +from synapse.server import HomeServer +from synapse.util import Clock from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import FakeChannel, make_request @@ -27,7 +31,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): servlets = [register.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # This isn't a real configuration option but is used to provide the main # homeserver and worker homeserver different options. 
@@ -77,7 +81,7 @@ def _test_register(self) -> FakeChannel: {"auth": {"session": session, "type": "m.login.dummy"}}, ) - def test_no_auth(self): + def test_no_auth(self) -> None: """With no authentication the request should finish.""" channel = self._test_register() self.assertEqual(channel.code, 200) @@ -86,7 +90,7 @@ def test_no_auth(self): self.assertEqual(channel.json_body["user_id"], "@user:test") @override_config({"main_replication_secret": "my-secret"}) - def test_missing_auth(self): + def test_missing_auth(self) -> None: """If the main process expects a secret that is not provided, an error results.""" channel = self._test_register() self.assertEqual(channel.code, 500) @@ -97,13 +101,13 @@ def test_missing_auth(self): "worker_replication_secret": "wrong-secret", } ) - def test_unauthorized(self): + def test_unauthorized(self) -> None: """If the main process receives the wrong secret, an error results.""" channel = self._test_register() self.assertEqual(channel.code, 500) @override_config({"worker_replication_secret": "my-secret"}) - def test_authorized(self): + def test_authorized(self) -> None: """The request should finish when the worker provides the authentication header.""" channel = self._test_register() self.assertEqual(channel.code, 200) diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index eb5b376534b9..eca503376110 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -33,7 +33,7 @@ def _get_worker_hs_config(self) -> dict: config["worker_replication_http_port"] = "8765" return config - def test_register_single_worker(self): + def test_register_single_worker(self) -> None: """Test that registration works when using a single generic worker.""" worker_hs = self.make_worker_hs("synapse.app.generic_worker") site = self._hs_to_site[worker_hs] @@ -63,7 +63,7 @@ def test_register_single_worker(self): # We're given a registered user. 
self.assertEqual(channel_2.json_body["user_id"], "@user:test") - def test_register_multi_worker(self): + def test_register_multi_worker(self) -> None: """Test that registration works when using multiple generic workers.""" worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker") worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker") diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 63b1dd40b5b9..12668b34c5aa 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -14,10 +14,14 @@ from unittest import mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.app.generic_worker import GenericWorkerServer from synapse.replication.tcp.commands import FederationAckCommand from synapse.replication.tcp.protocol import IReplicationConnection from synapse.replication.tcp.streams.federation import FederationStream +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -30,12 +34,10 @@ def default_config(self) -> dict: config["federation_sender_instances"] = ["federation_sender1"] return config - def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer) - - return hs + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + return self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer) - def test_federation_ack_sent(self): + def test_federation_ack_sent(self) -> None: """A FEDERATION_ACK should be sent back after each RDATA federation This test checks that the federation sender is correctly sending back diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index c28073b8f797..89380e25b59a 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -40,7 +40,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): room.register_servlets, ] - def test_send_event_single_sender(self): + def test_send_event_single_sender(self) -> None: """Test that using a single federation sender worker correctly sends a new event. """ @@ -71,7 +71,7 @@ def test_send_event_single_sender(self): self.assertEqual(mock_client.put_json.call_args[0][0], "other_server") self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus")) - def test_send_event_sharded(self): + def test_send_event_sharded(self) -> None: """Test that using two federation sender workers correctly sends new events. """ @@ -138,7 +138,7 @@ def test_send_event_sharded(self): self.assertTrue(sent_on_1) self.assertTrue(sent_on_2) - def test_send_typing_sharded(self): + def test_send_typing_sharded(self) -> None: """Test that using two federation sender workers correctly sends new typing EDUs. 
""" @@ -215,7 +215,9 @@ def test_send_typing_sharded(self): self.assertTrue(sent_on_1) self.assertTrue(sent_on_2) - def create_room_with_remote_server(self, user, token, remote_server="other_server"): + def create_room_with_remote_server( + self, user: str, token: str, remote_server: str = "other_server" + ) -> str: room = self.helper.create_room_as(user, tok=token) store = self.hs.get_datastores().main federation = self.hs.get_federation_event_handler() diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py index b93cae67d3c4..9c4fbda71b87 100644 --- a/tests/replication/test_module_cache_invalidation.py +++ b/tests/replication/test_module_cache_invalidation.py @@ -39,7 +39,7 @@ class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase): synapse.rest.admin.register_servlets, ] - def test_module_cache_full_invalidation(self): + def test_module_cache_full_invalidation(self) -> None: main_cache = TestCache() self.hs.get_module_api().register_cached_function(main_cache.cached_function) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 96cdf2c45b16..1527b4a82d84 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -18,12 +18,14 @@ from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol +from twisted.test.proto_helpers import MemoryReactor from twisted.web.http import HTTPChannel from twisted.web.server import Request from synapse.rest import admin from synapse.rest.client import login from synapse.server import HomeServer +from synapse.util import Clock from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -43,13 +45,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "pass") self.access_token = self.login("user", "pass") self.reactor.lookups["example.com"] = "1.2.3.4" - def default_config(self): + def default_config(self) -> dict: conf = super().default_config() conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] return conf @@ -122,7 +124,7 @@ def _get_media_req( return channel, request - def test_basic(self): + def test_basic(self) -> None: """Test basic fetching of remote media from a single worker.""" hs1 = self.make_worker_hs("synapse.app.generic_worker") @@ -138,7 +140,7 @@ def test_basic(self): self.assertEqual(channel.code, 200) self.assertEqual(channel.result["body"], b"Hello!") - def test_download_simple_file_race(self): + def test_download_simple_file_race(self) -> None: """Test that fetching remote media from two different processes at the same time works. """ @@ -177,7 +179,7 @@ def test_download_simple_file_race(self): # We expect only one new file to have been persisted. self.assertEqual(start_count + 1, self._count_remote_media()) - def test_download_image_race(self): + def test_download_image_race(self) -> None: """Test that fetching remote *images* from two different processes at the same time works. 
@@ -229,7 +231,7 @@ def _count_remote_thumbnails(self) -> int: return sum(len(files) for _, _, files in os.walk(path)) -def get_connection_factory(): +def get_connection_factory() -> TestServerTLSConnectionFactory: # this needs to happen once, but not until we are ready to run the first test global test_server_connection_factory if test_server_connection_factory is None: @@ -263,6 +265,6 @@ def _build_test_server( return server_tls_factory.buildProtocol(None) -def _log_request(request): +def _log_request(request: Request) -> None: """Implements Factory.log, which is expected by Request.finish""" logger.info("Completed request %s", request) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index ca18ad655394..9345cfbeb266 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -15,9 +15,12 @@ from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -33,12 +36,12 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Register a user who sends a message that we'll get notified about self.other_user_id = self.register_user("otheruser", "pass") self.other_access_token = self.login("otheruser", "pass") - def _create_pusher_and_send_msg(self, localpart): + def _create_pusher_and_send_msg(self, localpart: str) -> str: # Create a user that will get push notifications user_id = self.register_user(localpart, "pass") access_token = self.login(localpart, "pass") @@ -79,7 +82,7 @@ def _create_pusher_and_send_msg(self, localpart): return event_id - def test_send_push_single_worker(self): + def test_send_push_single_worker(self) -> None: """Test that registration works when using a pusher worker.""" http_client_mock = Mock(spec_set=["post_json_get_json"]) http_client_mock.post_json_get_json.side_effect = ( @@ -109,7 +112,7 @@ def test_send_push_single_worker(self): ], ) - def test_send_push_multiple_workers(self): + def test_send_push_multiple_workers(self) -> None: """Test that registration works when using sharded pusher workers.""" http_client_mock1 = Mock(spec_set=["post_json_get_json"]) http_client_mock1.post_json_get_json.side_effect = ( diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 541d3902860c..7f9cc67e735a 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -14,9 +14,13 @@ import logging from unittest.mock import patch +from twisted.test.proto_helpers import MemoryReactor + from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.util import Clock from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request @@ -34,7 +38,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): sync.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Register a user 
who sends a message that we'll get notified about self.other_user_id = self.register_user("otheruser", "pass") self.other_access_token = self.login("otheruser", "pass") @@ -42,7 +46,7 @@ def prepare(self, reactor, clock, hs): self.room_creator = self.hs.get_room_creation_handler() self.store = hs.get_datastores().main - def default_config(self): + def default_config(self) -> dict: conf = super().default_config() conf["stream_writers"] = {"events": ["worker1", "worker2"]} conf["instance_map"] = { @@ -51,7 +55,7 @@ def default_config(self): } return conf - def _create_room(self, room_id: str, user_id: str, tok: str): + def _create_room(self, room_id: str, user_id: str, tok: str) -> None: """Create a room with given room_id""" # We control the room ID generation by patching out the @@ -62,7 +66,7 @@ def _create_room(self, room_id: str, user_id: str, tok: str): mock.side_effect = lambda: room_id self.helper.create_room_as(user_id, tok=tok) - def test_basic(self): + def test_basic(self) -> None: """Simple test to ensure that multiple rooms can be created and joined, and that different rooms get handled by different instances. """ @@ -112,7 +116,7 @@ def test_basic(self): self.assertTrue(persisted_on_1) self.assertTrue(persisted_on_2) - def test_vector_clock_token(self): + def test_vector_clock_token(self) -> None: """Tests that using a stream token with a vector clock component works correctly with basic /sync and /messages usage. """ diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 8a4e5c3f777b..233eba351690 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -280,7 +280,10 @@ def test_invalid_search_order(self) -> None: self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - self.assertEqual("Unknown direction: bar", channel.json_body["error"]) + self.assertEqual( + "Query parameter 'dir' must be one of ['b', 'f']", + channel.json_body["error"], + ) def test_limit_is_negative(self) -> None: """ diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index e0f5d54abab0..453a6e979c02 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1831,7 +1831,7 @@ def test_timestamp_to_event(self) -> None: def test_topo_token_is_accepted(self) -> None: """Test Topo Token is accepted.""" - token = "t1-0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), @@ -1845,7 +1845,7 @@ def test_topo_token_is_accepted(self) -> None: def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: """Test that stream token is accepted for forward pagination.""" - token = "s0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index b86f341ff5b4..c8a6911d5ee5 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -30,6 +30,7 @@ from tests.server import FakeChannel from tests.test_utils import make_awaitable from tests.test_utils.event_injection import inject_event +from tests.unittest import override_config class BaseRelationsTestCase(unittest.HomeserverTestCase): @@ -355,30 +356,67 @@ def test_ignore_invalid_room(self) -> None: self.assertEqual(200, channel.code, 
channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def _assert_edit_bundle( + self, event_json: JsonDict, edit_event_id: str, edit_event_content: JsonDict + ) -> None: + """ + Assert that the given event has a correctly-serialised edit event in its + bundled aggregations + + Args: + event_json: the serialised event to be checked + edit_event_id: the ID of the edit event that we expect to be bundled + edit_event_content: the content of that event, excluding the 'm.relates_to` + property + """ + relations_dict = event_json["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in [ + "event_id", + "sender", + "origin_server_ts", + "content", + "type", + "unsigned", + ]: + self.assertIn(key, m_replace_dict) + + expected_edit_content = { + "m.relates_to": { + "event_id": event_json["event_id"], + "rel_type": "m.replace", + } + } + expected_edit_content.update(edit_event_content) + + self.assert_dict( + { + "event_id": edit_event_id, + "sender": self.user_id, + "content": expected_edit_content, + "type": "m.room.message", + }, + m_replace_dict, + ) + def test_edit(self) -> None: """Test that a simple edit works.""" new_body = {"msgtype": "m.text", "body": "I've been edited!"} + edit_event_content = { + "msgtype": "m.text", + "body": "foo", + "m.new_content": new_body, + } channel = self._send_relation( RelationTypes.REPLACE, "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + content=edit_event_content, ) edit_event_id = channel.json_body["event_id"] - def assert_bundle(event_json: JsonDict) -> None: - """Assert the expected values of the bundled aggregations.""" - relations_dict = event_json["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - # /event should return the *original* event channel = self.make_request( "GET", @@ -389,7 +427,7 @@ def assert_bundle(event_json: JsonDict) -> None: self.assertEqual( channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"} ) - assert_bundle(channel.json_body) + self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content) # Request the room messages. channel = self.make_request( @@ -398,7 +436,11 @@ def assert_bundle(event_json: JsonDict) -> None: access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) + self._assert_edit_bundle( + self._find_event_in_chunk(channel.json_body["chunk"]), + edit_event_id, + edit_event_content, + ) # Request the room context. # /context should return the edited event. 
@@ -408,7 +450,9 @@ def assert_bundle(event_json: JsonDict) -> None: access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]) + self._assert_edit_bundle( + channel.json_body["event"], edit_event_id, edit_event_content + ) self.assertEqual(channel.json_body["event"]["content"], new_body) # Request sync, but limit the timeline so it becomes limited (and includes @@ -420,7 +464,11 @@ def assert_bundle(event_json: JsonDict) -> None: self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) - assert_bundle(self._find_event_in_chunk(room_timeline["events"])) + self._assert_edit_bundle( + self._find_event_in_chunk(room_timeline["events"]), + edit_event_id, + edit_event_content, + ) # Request search. channel = self.make_request( @@ -437,7 +485,45 @@ def assert_bundle(event_json: JsonDict) -> None: "results" ] ] - assert_bundle(self._find_event_in_chunk(chunk)) + self._assert_edit_bundle( + self._find_event_in_chunk(chunk), + edit_event_id, + edit_event_content, + ) + + @override_config({"experimental_features": {"msc3925_inhibit_edit": True}}) + def test_edit_inhibit_replace(self) -> None: + """ + If msc3925_inhibit_edit is enabled, then the original event should not be + replaced. + """ + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + edit_event_content = { + "msgtype": "m.text", + "body": "foo", + "m.new_content": new_body, + } + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content=edit_event_content, + ) + edit_event_id = channel.json_body["event_id"] + + # /context should return the *original* event. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( + channel.json_body["event"]["content"], {"body": "Hi!", "msgtype": "m.text"} + ) + self._assert_edit_bundle( + channel.json_body["event"], edit_event_id, edit_event_content + ) def test_multi_edit(self) -> None: """Test that multiple edits, including attempts by people who @@ -455,10 +541,15 @@ def test_multi_edit(self) -> None: ) new_body = {"msgtype": "m.text", "body": "I've been edited!"} + edit_event_content = { + "msgtype": "m.text", + "body": "foo", + "m.new_content": new_body, + } channel = self._send_relation( RelationTypes.REPLACE, "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + content=edit_event_content, ) edit_event_id = channel.json_body["event_id"] @@ -480,16 +571,8 @@ def test_multi_edit(self) -> None: self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["event"]["content"], new_body) - - relations_dict = channel.json_body["event"]["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + self._assert_edit_bundle( + channel.json_body["event"], edit_event_id, edit_event_content ) def test_edit_reply(self) -> None: @@ -502,11 +585,15 @@ def test_edit_reply(self) -> None: ) reply = channel.json_body["event_id"] - new_body = {"msgtype": "m.text", "body": "I've been edited!"} + edit_event_content = { 
+ "msgtype": "m.text", + "body": "foo", + "m.new_content": {"msgtype": "m.text", "body": "I've been edited!"}, + } channel = self._send_relation( RelationTypes.REPLACE, "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + content=edit_event_content, parent_id=reply, ) edit_event_id = channel.json_body["event_id"] @@ -549,28 +636,22 @@ def test_edit_reply(self) -> None: # We expect that the edit relation appears in the unsigned relations # section. - relations_dict = result_event_dict["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict, desc) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict, desc) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + self._assert_edit_bundle( + result_event_dict, edit_event_id, edit_event_content ) def test_edit_edit(self) -> None: """Test that an edit cannot be edited.""" new_body = {"msgtype": "m.text", "body": "Initial edit"} + edit_event_content = { + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": new_body, + } channel = self._send_relation( RelationTypes.REPLACE, "m.room.message", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": new_body, - }, + content=edit_event_content, ) edit_event_id = channel.json_body["event_id"] @@ -599,8 +680,7 @@ def test_edit_edit(self) -> None: ) # The relations information should not include the edit to the edit. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) + self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content) # /context should return the event updated for the *first* edit # (The edit to the edit should be ignored.) @@ -611,13 +691,8 @@ def test_edit_edit(self) -> None: ) self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["event"]["content"], new_body) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + self._assert_edit_bundle( + channel.json_body["event"], edit_event_id, edit_event_content ) # Directly requesting the edit should not have the edit to the edit applied. 
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index b4daace55617..9222cab19801 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -1987,7 +1987,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) def test_topo_token_is_accepted(self) -> None: - token = "t1-0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) @@ -1998,7 +1998,7 @@ def test_topo_token_is_accepted(self) -> None: self.assertTrue("end" in channel.json_body) def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: - token = "s0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) @@ -2728,7 +2728,7 @@ def test_messages_filter_labels(self) -> None: """Test that we can filter by a label on a /messages request.""" self._send_labelled_messages_in_room() - token = "s0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" @@ -2745,7 +2745,7 @@ def test_messages_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /messages request.""" self._send_labelled_messages_in_room() - token = "s0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" @@ -2768,7 +2768,7 @@ def test_messages_filter_labels_not_labels(self) -> None: """ self._send_labelled_messages_in_room() - token = "s0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 49d93849e630..2cb624ce7b57 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -294,9 +294,7 @@ def test_sync_backwards_typing(self) -> None: self.make_request("GET", sync_url % (access_token, next_batch)) -class SyncKnockTestCase( - unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin -): +class SyncKnockTestCase(KnockingStrippedStateEventHelperMixin): servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -1269,7 +1267,9 @@ def prepare( # We need to manually append the room ID, because we can't know the ID before # creating the room, and we can't set the config after starting the homeserver. - self.hs.get_sync_handler().rooms_to_exclude.append(self.excluded_room_id) + self.hs.get_sync_handler().rooms_to_exclude_globally.append( + self.excluded_room_id + ) def test_join_leave(self) -> None: """Tests that rooms are correctly excluded from the 'join' and 'leave' sections of diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 21a1ca2a6885..3086e1b5650b 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -13,18 +13,22 @@ # limitations under the License. 
from http import HTTPStatus +from typing import Any, Generator, Tuple, cast from unittest.mock import Mock, call -from twisted.internet import defer, reactor +from twisted.internet import defer, reactor as _reactor from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache +from synapse.types import ISynapseReactor, JsonDict from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable from tests.utils import MockClock +reactor = cast(ISynapseReactor, _reactor) + class HttpTransactionCacheTestCase(unittest.TestCase): def setUp(self) -> None: @@ -34,11 +38,13 @@ def setUp(self) -> None: self.hs.get_auth = Mock() self.cache = HttpTransactionCache(self.hs) - self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!") + self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"}) self.mock_key = "foo" @defer.inlineCallbacks - def test_executes_given_function(self): + def test_executes_given_function( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg" @@ -47,7 +53,9 @@ def test_executes_given_function(self): self.assertEqual(res, self.mock_http_response) @defer.inlineCallbacks - def test_deduplicates_based_on_key(self): + def test_deduplicates_based_on_key( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) for i in range(3): # invoke multiple times res = yield self.cache.fetch_or_execute( @@ -58,18 +66,20 @@ def test_deduplicates_based_on_key(self): cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0) @defer.inlineCallbacks - def test_logcontexts_with_async_result(self): + def test_logcontexts_with_async_result( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: @defer.inlineCallbacks - def cb(): + def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]: yield Clock(reactor).sleep(0) - return "yay" + return 1, {} @defer.inlineCallbacks - def test(): + def test() -> Generator["defer.Deferred[Any]", object, None]: with LoggingContext("c") as c1: res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertIs(current_context(), c1) - self.assertEqual(res, "yay") + self.assertEqual(res, (1, {})) # run the test twice in parallel d = defer.gatherResults([test(), test()]) @@ -78,13 +88,15 @@ def test(): self.assertIs(current_context(), SENTINEL_CONTEXT) @defer.inlineCallbacks - def test_does_not_cache_exceptions(self): + def test_does_not_cache_exceptions( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: """Checks that, if the callback throws an exception, it is called again for the next request. """ called = [False] - def cb(): + def cb() -> "defer.Deferred[Tuple[int, JsonDict]]": if called[0]: # return a valid result the second time return defer.succeed(self.mock_http_response) @@ -104,13 +116,15 @@ def cb(): self.assertIs(current_context(), test_context) @defer.inlineCallbacks - def test_does_not_cache_failures(self): + def test_does_not_cache_failures( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: """Checks that, if the callback returns a failure, it is called again for the next request. 
""" called = [False] - def cb(): + def cb() -> "defer.Deferred[Tuple[int, JsonDict]]": if called[0]: # return a valid result the second time return defer.succeed(self.mock_http_response) @@ -130,7 +144,7 @@ def cb(): self.assertIs(current_context(), test_context) @defer.inlineCallbacks - def test_cleans_up(self): + def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") # should NOT have cleaned up yet diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index d48f33d1f978..04c352b66666 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -200,9 +200,15 @@ def test_power_levels_tombstone(self) -> None: def test_stringy_power_levels(self) -> None: """The room upgrade converts stringy power levels to proper integers.""" + # Create a room on room version < 10. + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_token, room_version="9" + ) + self.helper.join(room_id, self.other, tok=self.other_token) + # Retrieve the room's current power levels. power_levels = self.helper.get_state( - self.room_id, + room_id, "m.room.power_levels", tok=self.creator_token, ) @@ -218,14 +224,14 @@ def test_stringy_power_levels(self) -> None: # conscience, we ought to ensure it's upgrading from a sufficiently old # version of room. self.helper.send_state( - self.room_id, + room_id, "m.room.power_levels", body=power_levels, tok=self.creator_token, ) # Upgrade the room. Check the homeserver reports success. - channel = self._upgrade_room() + channel = self._upgrade_room(room_id=room_id) self.assertEqual(200, channel.code, channel.result) # Extract the new room ID. diff --git a/tests/server.py b/tests/server.py index b1730fcc8dd5..237bcad8ba6f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -70,7 +70,7 @@ from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine -from synapse.types import JsonDict +from synapse.types import ISynapseReactor, JsonDict from synapse.util import Clock from tests.utils import ( @@ -401,7 +401,9 @@ def make_request( return channel -@implementer(IReactorPluggableNameResolver) +# ISynapseReactor implies IReactorPluggableNameResolver, but explicitly +# marking this as an implementer of the latter seems to keep mypy-zope happier. +@implementer(IReactorPluggableNameResolver, ISynapseReactor) class ThreadedMemoryReactorClock(MemoryReactorClock): """ A MemoryReactorClock that supports callFromThread. 
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 71bdfe68cdbb..be1be8d85939 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -70,7 +70,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=make_awaitable(1000) ) - self._rlsn._server_notices_manager.send_notice = Mock( + self._rlsn._server_notices_manager.send_notice = Mock( # type: ignore[assignment] return_value=make_awaitable(Mock()) ) self._send_notice = self._rlsn._server_notices_manager.send_notice @@ -83,8 +83,8 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock( return_value=make_awaitable("!something:localhost") ) - self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) - self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) + self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] + self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment] @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self): @@ -363,9 +363,10 @@ def _trigger_notice_and_join(self) -> Tuple[str, str, str]: tok: The access token of the user that joined the room. room_id: The ID of the room that's been joined. """ - user_id = None - tok = None - invites = [] + # We need at least one user to process + self.assertGreater(self.hs.config.server.max_mau_value, 0) + + invites = {} # Register as many users as the MAU limit allows. for i in range(self.hs.config.server.max_mau_value): diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py index 68026e283046..ac77aec003b1 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py @@ -168,7 +168,9 @@ def test_background_receipts_linearized_unique_index(self) -> None: {"stream_id": 6, "event_id": "$some_event"}, ], (self.other_room_id, "m.read", self.user_id): [ - {"stream_id": 7, "event_id": "$some_event"} + # It is possible for stream IDs to be duplicated. + {"stream_id": 7, "event_id": "$some_event"}, + {"stream_id": 7, "event_id": "$some_event"}, ], }, expected_unique_receipts={ diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 7d961fac649c..3108ca344423 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -40,9 +40,23 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.token = self.login("foo", "pass") def _generate_room(self) -> str: - room_id = self.helper.create_room_as(self.user_id, tok=self.token) + """Create a room and return the room ID.""" + return self.helper.create_room_as(self.user_id, tok=self.token) - return room_id + def run_background_updates(self, update_name: str) -> None: + """Insert and run the background update.""" + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": update_name, "progress_json": "{}"}, + ) + ) + + # ... 
and tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + self.wait_for_background_updates() def test_background_populate_rooms_creator_column(self) -> None: """Test that the background update to populate the rooms creator column @@ -71,22 +85,7 @@ def test_background_populate_rooms_creator_column(self) -> None: ) self.assertEqual(room_creator_before, None) - # Insert and run the background update. - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN, - "progress_json": "{}", - }, - ) - ) - - # ... and tell the DataStore that it hasn't finished all updates yet - self.store.db_pool.updates._all_done = False - - # Now let's actually drive the updates to completion - self.wait_for_background_updates() + self.run_background_updates(_BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN) # Make sure the background update filled in the room creator room_creator_after = self.get_success( @@ -137,22 +136,7 @@ def test_background_add_room_type_column(self) -> None: ) ) - # Insert and run the background update - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN, - "progress_json": "{}", - }, - ) - ) - - # ... and tell the DataStore that it hasn't finished all updates yet - self.store.db_pool.updates._all_done = False - - # Now let's actually drive the updates to completion - self.wait_for_background_updates() + self.run_background_updates(_BackgroundUpdates.ADD_ROOM_TYPE_COLUMN) # Make sure the background update filled in the room type room_type_after = self.get_success( @@ -164,3 +148,39 @@ def test_background_add_room_type_column(self) -> None: ) ) self.assertEqual(room_type_after, RoomTypes.SPACE) + + def test_populate_stats_broken_rooms(self) -> None: + """Ensure that re-populating room stats skips broken rooms.""" + + # Create a good room. + good_room_id = self._generate_room() + + # Create a room and then break it by having no room version. + room_id = self._generate_room() + self.get_success( + self.store.db_pool.simple_update( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"room_version": None}, + desc="test", + ) + ) + + # Nuke any current stats in the database. + self.get_success( + self.store.db_pool.simple_delete( + table="room_stats_state", keyvalues={"1": 1}, desc="test" + ) + ) + + self.run_background_updates("populate_stats_process_rooms") + + # Only the good room appears in the stats tables. 
+ results = self.get_success( + self.store.db_pool.simple_select_onecol( + table="room_stats_state", + keyvalues={}, + retcol="room_id", + ) + ) + self.assertEqual(results, [good_room_id]) diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 543cce6b3e12..8cd7c89ca2f8 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -22,6 +22,7 @@ from synapse.server import HomeServer from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_tuple_comparison_clause, ) @@ -37,6 +38,101 @@ def test_native_tuple_comparison(self) -> None: self.assertEqual(args, [1, 2]) +class ExecuteScriptTestCase(unittest.HomeserverTestCase): + """Tests for `BaseDatabaseEngine.executescript` implementations.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + self.get_success( + self.db_pool.runInteraction( + "create", + lambda txn: txn.execute("CREATE TABLE foo (name TEXT PRIMARY KEY)"), + ) + ) + + def test_transaction(self) -> None: + """Test that all statements are run in a single transaction.""" + + def run(conn: LoggingDatabaseConnection) -> None: + cur = conn.cursor(txn_name="test_transaction") + self.db_pool.engine.executescript( + cur, + ";".join( + [ + "INSERT INTO foo (name) VALUES ('transaction test')", + # This next statement will fail. When `executescript` is not + # transactional, the previous row will be observed later. + "INSERT INTO foo (name) VALUES ('transaction test')", + ] + ), + ) + + self.get_failure( + self.db_pool.runWithConnection(run), + self.db_pool.engine.module.IntegrityError, + ) + + self.assertIsNone( + self.get_success( + self.db_pool.simple_select_one_onecol( + "foo", + keyvalues={"name": "transaction test"}, + retcol="name", + allow_none=True, + ) + ), + "executescript is not running statements inside a transaction", + ) + + def test_commit(self) -> None: + """Test that the script transaction remains open and can be committed.""" + + def run(conn: LoggingDatabaseConnection) -> None: + cur = conn.cursor(txn_name="test_commit") + self.db_pool.engine.executescript( + cur, "INSERT INTO foo (name) VALUES ('commit test')" + ) + cur.execute("COMMIT") + + self.get_success(self.db_pool.runWithConnection(run)) + + self.assertIsNotNone( + self.get_success( + self.db_pool.simple_select_one_onecol( + "foo", + keyvalues={"name": "commit test"}, + retcol="name", + allow_none=True, + ) + ), + ) + + def test_rollback(self) -> None: + """Test that the script transaction remains open and can be rolled back.""" + + def run(conn: LoggingDatabaseConnection) -> None: + cur = conn.cursor(txn_name="test_rollback") + self.db_pool.engine.executescript( + cur, "INSERT INTO foo (name) VALUES ('rollback test')" + ) + cur.execute("ROLLBACK") + + self.get_success(self.db_pool.runWithConnection(run)) + + self.assertIsNone( + self.get_success( + self.db_pool.simple_select_one_onecol( + "foo", + keyvalues={"name": "rollback test"}, + retcol="name", + allow_none=True, + ) + ), + "executescript is not leaving the script transaction open", + ) + + class CallbacksTestCase(unittest.HomeserverTestCase): """Tests for transaction callbacks.""" diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 5fa8bd2d98ce..76c06a9d1e72 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -154,7 +154,7 @@ def 
test_count_aggregation(self) -> None: # Create a user to receive notifications and send receipts. user_id, token, _, other_token, room_id = self._create_users_and_room() - last_event_id: str + last_event_id = "" def _assert_counts(notif_count: int, highlight_count: int) -> None: counts = self.get_success( @@ -289,7 +289,7 @@ def test_count_aggregation_threads(self) -> None: user_id, token, _, other_token, room_id = self._create_users_and_room() thread_id: str - last_event_id: str + last_event_id = "" def _assert_counts( notif_count: int, @@ -471,7 +471,7 @@ def test_count_aggregation_mixed(self) -> None: user_id, token, _, other_token, room_id = self._create_users_and_room() thread_id: str - last_event_id: str + last_event_id = "" def _assert_counts( notif_count: int, diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index d6a2b8d2743e..9174fb096470 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -52,6 +52,7 @@ def _create_id_generator(self) -> StreamIdGenerator: def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: return StreamIdGenerator( db_conn=conn, + notifier=self.hs.get_replication_notifier(), table="foobar", column="stream_id", ) @@ -196,6 +197,7 @@ def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, + notifier=self.hs.get_replication_notifier(), stream_name="test_stream", instance_name=instance_name, tables=[("foobar", "instance_name", "stream_id")], @@ -349,8 +351,8 @@ def test_multi_instance(self) -> None: # The first ID gen will notice that it can advance its token to 7 as it # has no in progress writes... - self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7}) - self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) # ... but the second ID gen doesn't know that. 
@@ -366,8 +368,9 @@ async def _get_next_async() -> None: self.assertEqual(stream_id, 8) self.assertEqual( - first_id_gen.get_positions(), {"first": 7, "second": 7} + first_id_gen.get_positions(), {"first": 3, "second": 7} ) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) self.get_success(_get_next_async()) @@ -473,7 +476,7 @@ def test_get_persisted_upto_position_get_next(self) -> None: id_gen = self._create_id_generator("first", writers=["first", "second"]) - self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5}) + self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) self.assertEqual(id_gen.get_persisted_upto_position(), 5) @@ -629,6 +632,7 @@ def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, + notifier=self.hs.get_replication_notifier(), stream_name="test_stream", instance_name=instance_name, tables=[("foobar", "instance_name", "stream_id")], @@ -720,7 +724,7 @@ async def _get_next_async2() -> None: self.get_success(_get_next_async2()) - self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2}) + self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) @@ -765,6 +769,7 @@ def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, + notifier=self.hs.get_replication_notifier(), stream_name="test_stream", instance_name=instance_name, tables=[ @@ -816,15 +821,12 @@ def test_load_existing_stream(self) -> None: first_id_gen = self._create_id_generator("first", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"]) - # The first ID gen will notice that it can advance its token to 7 as it - # has no in progress writes... - self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6}) - self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6) self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) - # ... but the second ID gen doesn't know that. 
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) - self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) + self.assertEqual(second_id_gen.get_persisted_upto_position(), 7) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 3ba896ecf384..f1ca523d23f9 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -28,6 +28,7 @@ from synapse.storage.roommember import ProfileInfo from synapse.util import Clock +from tests.server import ThreadedMemoryReactorClock from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config @@ -138,7 +139,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): register.register_servlets, ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + def make_homeserver( + self, reactor: ThreadedMemoryReactorClock, clock: Clock + ) -> HomeServer: self.appservice = ApplicationService( token="i_am_an_app_service", id="1234", diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index f4d9fba0a14c..0a7937f1cc72 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -13,7 +13,7 @@ # limitations under the License. import unittest -from typing import Collection, Dict, Iterable, List, Optional +from typing import Any, Collection, Dict, Iterable, List, Optional from parameterized import parameterized @@ -728,6 +728,36 @@ def test_room_v10_rejects_string_power_levels(self) -> None: pl_event.room_version, pl_event2, {("fake_type", "fake_key"): pl_event} ) + def test_room_v10_rejects_other_non_integer_power_levels(self) -> None: + """We should reject PLs that are non-integer, non-string JSON values. + + test_room_v10_rejects_string_power_levels above handles the string case. + """ + + def create_event(pl_event_content: Dict[str, Any]) -> EventBase: + return make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10), + "type": "m.room.power_levels", + "sender": "@test:test.com", + "state_key": "", + "content": pl_event_content, + "signatures": {"test.com": {"ed25519:0": "some9signature"}}, + }, + room_version=RoomVersions.V10, + ) + + contents: Iterable[Dict[str, Any]] = [ + {"notifications": {"room": None}}, + {"users": {"@alice:wonderland": []}}, + {"users_default": {}}, + ] + for content in contents: + event = create_event(content) + with self.assertRaises(SynapseError): + event_auth._check_power_levels(event.room_version, event, {}) + # helpers for making events TEST_DOMAIN = "example.com" diff --git a/tests/test_state.py b/tests/test_state.py index 504530b49a8e..b20a26e1ffe4 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -11,7 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Collection, Dict, List, Optional, cast +from typing import ( + Any, + Collection, + Dict, + Generator, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + cast, +) from unittest.mock import Mock from twisted.internet import defer @@ -19,9 +31,11 @@ from synapse.api.auth import Auth from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry +from synapse.types import MutableStateMap, StateMap +from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.macaroons import MacaroonGenerator @@ -33,14 +47,14 @@ def create_event( - name=None, - type=None, - state_key=None, - depth=2, - event_id=None, - prev_events: Optional[List[str]] = None, - **kwargs, -): + name: Optional[str] = None, + type: Optional[str] = None, + state_key: Optional[str] = None, + depth: int = 2, + event_id: Optional[str] = None, + prev_events: Optional[List[Tuple[str, dict]]] = None, + **kwargs: Any, +) -> EventBase: global _next_event_id if not event_id: @@ -67,21 +81,21 @@ def create_event( d.update(kwargs) - event = make_event_from_dict(d) - - return event + return make_event_from_dict(d) class _DummyStore: - def __init__(self): - self._event_to_state_group = {} - self._group_to_state = {} + def __init__(self) -> None: + self._event_to_state_group: Dict[str, int] = {} + self._group_to_state: Dict[int, MutableStateMap[str]] = {} - self._event_id_to_event = {} + self._event_id_to_event: Dict[str, EventBase] = {} self._next_group = 1 - async def get_state_groups_ids(self, room_id, event_ids): + async def get_state_groups_ids( + self, room_id: str, event_ids: Collection[str] + ) -> Dict[int, MutableStateMap[str]]: groups = {} for event_id in event_ids: group = self._event_to_state_group.get(event_id) @@ -90,16 +104,25 @@ async def get_state_groups_ids(self, room_id, event_ids): return groups - async def get_state_ids_for_group(self, state_group, state_filter=None): + async def get_state_ids_for_group( + self, state_group: int, state_filter: Optional[StateFilter] = None + ) -> MutableStateMap[str]: return self._group_to_state[state_group] async def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[StateMap[str]], + current_state_ids: Optional[StateMap[str]], + ) -> int: state_group = self._next_group self._next_group += 1 if current_state_ids is None: + assert prev_group is not None + assert delta_ids is not None current_state_ids = dict(self._group_to_state[prev_group]) current_state_ids.update(delta_ids) @@ -107,7 +130,9 @@ async def store_state_group( return state_group - async def get_events(self, event_ids, **kwargs): + async def get_events( + self, event_ids: Collection[str], **kwargs: Any + ) -> Dict[str, EventBase]: return { e_id: self._event_id_to_event[e_id] for e_id in event_ids @@ -119,31 +144,36 @@ async def get_partial_state_events( ) -> Dict[str, bool]: return {e: False for e in event_ids} - async def get_state_group_delta(self, name): + async def get_state_group_delta( + self, name: str + ) -> Tuple[Optional[int], Optional[StateMap[str]]]: return None, None - def register_events(self, events): + def register_events(self, 
events: Iterable[EventBase]) -> None: for e in events: self._event_id_to_event[e.event_id] = e - def register_event_context(self, event, context): + def register_event_context(self, event: EventBase, context: EventContext) -> None: + assert context.state_group is not None self._event_to_state_group[event.event_id] = context.state_group - def register_event_id_state_group(self, event_id, state_group): + def register_event_id_state_group(self, event_id: str, state_group: int) -> None: self._event_to_state_group[event_id] = state_group - async def get_room_version_id(self, room_id): + async def get_room_version_id(self, room_id: str) -> str: return RoomVersions.V1.identifier async def get_state_group_for_events( - self, event_ids, await_full_state: bool = True - ): + self, event_ids: Collection[str], await_full_state: bool = True + ) -> Dict[str, int]: res = {} for event in event_ids: res[event] = self._event_to_state_group[event] return res - async def get_state_for_groups(self, groups): + async def get_state_for_groups( + self, groups: Collection[int] + ) -> Dict[int, MutableStateMap[str]]: res = {} for group in groups: state = self._group_to_state[group] @@ -152,21 +182,21 @@ async def get_state_for_groups(self, groups): class DictObj(dict): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(kwargs) self.__dict__ = self class Graph: - def __init__(self, nodes, edges): - events = {} - clobbered = set(events.keys()) + def __init__(self, nodes: Dict[str, DictObj], edges: Dict[str, List[str]]): + events: Dict[str, EventBase] = {} + clobbered: Set[str] = set() for event_id, fields in nodes.items(): refs = edges.get(event_id) if refs: clobbered.difference_update(refs) - prev_events = [(r, {}) for r in refs] + prev_events: List[Tuple[str, dict]] = [(r, {}) for r in refs] else: prev_events = [] @@ -177,15 +207,12 @@ def __init__(self, nodes, edges): self._leaves = clobbered self._events = sorted(events.values(), key=lambda e: e.depth) - def walk(self): + def walk(self) -> Iterator[EventBase]: return iter(self._events) - def get_leaves(self): - return (self._events[i] for i in self._leaves) - class StateTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.dummy_store = _DummyStore() storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store) hs = Mock( @@ -220,7 +247,7 @@ def setUp(self): self.event_id = 0 @defer.inlineCallbacks - def test_branch_no_conflict(self): + def test_branch_no_conflict(self) -> Generator[defer.Deferred, Any, None]: graph = Graph( nodes={ "START": DictObj( @@ -248,6 +275,7 @@ def test_branch_no_conflict(self): ctx_c = context_store["C"] ctx_d = context_store["D"] + prev_state_ids: StateMap[str] prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertEqual(2, len(prev_state_ids)) @@ -255,7 +283,9 @@ def test_branch_no_conflict(self): self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) @defer.inlineCallbacks - def test_branch_basic_conflict(self): + def test_branch_basic_conflict( + self, + ) -> Generator["defer.Deferred[object]", Any, None]: graph = Graph( nodes={ "START": DictObj( @@ -280,7 +310,7 @@ def test_branch_basic_conflict(self): self.dummy_store.register_events(graph.walk()) - context_store = {} + context_store: Dict[str, EventContext] = {} for event in graph.walk(): context = yield defer.ensureDeferred( @@ -294,6 +324,7 @@ def test_branch_basic_conflict(self): ctx_c = context_store["C"] ctx_d = context_store["D"] + prev_state_ids: 
StateMap[str] prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values())) @@ -301,7 +332,9 @@ def test_branch_basic_conflict(self): self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) @defer.inlineCallbacks - def test_branch_have_banned_conflict(self): + def test_branch_have_banned_conflict( + self, + ) -> Generator["defer.Deferred[object]", Any, None]: graph = Graph( nodes={ "START": DictObj( @@ -338,7 +371,7 @@ def test_branch_have_banned_conflict(self): self.dummy_store.register_events(graph.walk()) - context_store = {} + context_store: Dict[str, EventContext] = {} for event in graph.walk(): context = yield defer.ensureDeferred( @@ -353,13 +386,16 @@ def test_branch_have_banned_conflict(self): ctx_c = context_store["C"] ctx_e = context_store["E"] + prev_state_ids: StateMap[str] prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids()) self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event) self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group) @defer.inlineCallbacks - def test_branch_have_perms_conflict(self): + def test_branch_have_perms_conflict( + self, + ) -> Generator["defer.Deferred[object]", Any, None]: userid1 = "@user_id:example.com" userid2 = "@user_id2:example.com" @@ -413,7 +449,7 @@ def test_branch_have_perms_conflict(self): self.dummy_store.register_events(graph.walk()) - context_store = {} + context_store: Dict[str, EventContext] = {} for event in graph.walk(): context = yield defer.ensureDeferred( @@ -428,14 +464,17 @@ def test_branch_have_perms_conflict(self): ctx_b = context_store["B"] ctx_d = context_store["D"] + prev_state_ids: StateMap[str] prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())) self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event) self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) - def _add_depths(self, nodes, edges): - def _get_depth(ev): + def _add_depths( + self, nodes: Dict[str, DictObj], edges: Dict[str, List[str]] + ) -> None: + def _get_depth(ev: str) -> int: node = nodes[ev] if "depth" not in node: prevs = edges[ev] @@ -447,7 +486,9 @@ def _get_depth(ev): _get_depth(n) @defer.inlineCallbacks - def test_annotate_with_old_message(self): + def test_annotate_with_old_message( + self, + ) -> Generator["defer.Deferred[object]", Any, None]: event = create_event(type="test_message", name="event") old_state = [ @@ -456,6 +497,7 @@ def test_annotate_with_old_message(self): create_event(type="test2", state_key=""), ] + context: EventContext context = yield defer.ensureDeferred( self.state.compute_event_context( event, @@ -466,9 +508,11 @@ def test_annotate_with_old_message(self): ) ) + prev_state_ids: StateMap[str] prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) + current_state_ids: StateMap[str] current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state), current_state_ids.values() @@ -478,7 +522,9 @@ def test_annotate_with_old_message(self): self.assertEqual(context.state_group_before_event, context.state_group) @defer.inlineCallbacks - def test_annotate_with_old_state(self): + def test_annotate_with_old_state( + self, + ) -> 
Generator["defer.Deferred[object]", Any, None]: event = create_event(type="state", state_key="", name="event") old_state = [ @@ -487,6 +533,7 @@ def test_annotate_with_old_state(self): create_event(type="test2", state_key=""), ] + context: EventContext context = yield defer.ensureDeferred( self.state.compute_event_context( event, @@ -497,9 +544,11 @@ def test_annotate_with_old_state(self): ) ) + prev_state_ids: StateMap[str] prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) + current_state_ids: StateMap[str] current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state + [event]), current_state_ids.values() @@ -511,7 +560,9 @@ def test_annotate_with_old_state(self): self.assertEqual({("state", ""): event.event_id}, context.delta_ids) @defer.inlineCallbacks - def test_trivial_annotate_message(self): + def test_trivial_annotate_message( + self, + ) -> Generator["defer.Deferred[object]", Any, None]: prev_event_id = "prev_event_id" event = create_event( type="test_message", name="event2", prev_events=[(prev_event_id, {})] @@ -534,8 +585,10 @@ def test_trivial_annotate_message(self): ) self.dummy_store.register_event_id_state_group(prev_event_id, group_name) + context: EventContext context = yield defer.ensureDeferred(self.state.compute_event_context(event)) + current_state_ids: StateMap[str] current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual( @@ -545,7 +598,9 @@ def test_trivial_annotate_message(self): self.assertEqual(group_name, context.state_group) @defer.inlineCallbacks - def test_trivial_annotate_state(self): + def test_trivial_annotate_state( + self, + ) -> Generator["defer.Deferred[object]", Any, None]: prev_event_id = "prev_event_id" event = create_event( type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})] @@ -568,8 +623,10 @@ def test_trivial_annotate_state(self): ) self.dummy_store.register_event_id_state_group(prev_event_id, group_name) + context: EventContext context = yield defer.ensureDeferred(self.state.compute_event_context(event)) + prev_state_ids: StateMap[str] prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values())) @@ -577,7 +634,9 @@ def test_trivial_annotate_state(self): self.assertIsNotNone(context.state_group) @defer.inlineCallbacks - def test_resolve_message_conflict(self): + def test_resolve_message_conflict( + self, + ) -> Generator["defer.Deferred[Any]", Any, None]: prev_event_id1 = "event_id1" prev_event_id2 = "event_id2" event = create_event( @@ -605,10 +664,12 @@ def test_resolve_message_conflict(self): self.dummy_store.register_events(old_state_1) self.dummy_store.register_events(old_state_2) + context: EventContext context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) + current_state_ids: StateMap[str] current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -616,7 +677,9 @@ def test_resolve_message_conflict(self): self.assertIsNotNone(context.state_group) @defer.inlineCallbacks - def test_resolve_state_conflict(self): + def test_resolve_state_conflict( + self, + ) -> Generator["defer.Deferred[Any]", Any, None]: prev_event_id1 = "event_id1" prev_event_id2 = "event_id2" event = create_event( @@ -645,12 +708,14 
@@ def test_resolve_state_conflict(self): store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.dummy_store.get_events = store.get_events + self.dummy_store.get_events = store.get_events # type: ignore[assignment] + context: EventContext context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) + current_state_ids: StateMap[str] current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -658,7 +723,9 @@ def test_resolve_state_conflict(self): self.assertIsNotNone(context.state_group) @defer.inlineCallbacks - def test_standard_depth_conflict(self): + def test_standard_depth_conflict( + self, + ) -> Generator["defer.Deferred[Any]", Any, None]: prev_event_id1 = "event_id1" prev_event_id2 = "event_id2" event = create_event( @@ -700,12 +767,14 @@ def test_standard_depth_conflict(self): store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.dummy_store.get_events = store.get_events + self.dummy_store.get_events = store.get_events # type: ignore[assignment] + context: EventContext context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) + current_state_ids: StateMap[str] current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -740,8 +809,14 @@ def test_standard_depth_conflict(self): @defer.inlineCallbacks def _get_context( - self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 - ): + self, + event: EventBase, + prev_event_id_1: str, + old_state_1: Collection[EventBase], + prev_event_id_2: str, + old_state_2: Collection[EventBase], + ) -> Generator["defer.Deferred[object]", Any, EventContext]: + sg1: int sg1 = yield defer.ensureDeferred( self.dummy_store.store_state_group( prev_event_id_1, @@ -753,6 +828,7 @@ def _get_context( ) self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1) + sg2: int sg2 = yield defer.ensureDeferred( self.dummy_store.store_state_group( prev_event_id_2, @@ -767,7 +843,7 @@ def _get_context( result = yield defer.ensureDeferred(self.state.compute_event_context(event)) return result - def test_make_state_cache_entry(self): + def test_make_state_cache_entry(self) -> None: "Test that calculating a prev_group and delta is correct" new_state = { diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index abd7459a8cb7..52424aa08713 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -14,9 +14,12 @@ from unittest.mock import Mock -from twisted.test.proto_helpers import MemoryReactorClock +from twisted.internet.interfaces import IReactorTime +from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock from synapse.rest.client.register import register_servlets +from synapse.server import HomeServer +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -25,7 +28,7 @@ class TermsTestCase(unittest.HomeserverTestCase): servlets = [register_servlets] - def default_config(self): + def default_config(self) -> JsonDict: config = super().default_config() config.update( { @@ -40,17 +43,21 @@ def default_config(self): ) return config - def prepare(self, reactor, clock, hs): - self.clock = MemoryReactorClock() + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # type-ignore: mypy-zope 
doesn't seem to recognise that MemoryReactorClock + # implements IReactorTime, via inheritance from twisted.internet.testing.Clock + self.clock: IReactorTime = MemoryReactorClock() # type: ignore[assignment] self.hs_clock = Clock(self.clock) self.url = "/_matrix/client/r0/register" self.registration_handler = Mock() self.auth_handler = Mock() self.device_handler = Mock() - def test_ui_auth(self): + def test_ui_auth(self) -> None: # Do a UI auth request - request_data = {"username": "kermit", "password": "monkey"} + request_data: JsonDict = {"username": "kermit", "password": "monkey"} channel = self.make_request(b"POST", self.url, request_data) self.assertEqual(channel.code, 401, channel.result) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index 9228454c9e78..304c7b98c5c9 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -17,6 +17,7 @@ import twisted.logger from synapse.logging.context import LoggingContextFilter +from synapse.synapse_rust import reset_logging_config class ToTwistedHandler(logging.Handler): @@ -52,3 +53,5 @@ def setup_logging(): log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR") root_logger.setLevel(log_level) + + reset_logging_config() diff --git a/tests/unittest.py b/tests/unittest.py index a120c2976ccd..fa92dd94eb6a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -75,6 +75,7 @@ from tests.server import ( CustomHeaderType, FakeChannel, + ThreadedMemoryReactorClock, get_clock, make_request, setup_test_homeserver, @@ -360,7 +361,7 @@ def wait_for_background_updates(self) -> None: store.db_pool.updates.do_next_background_update(False), by=0.1 ) - def make_homeserver(self, reactor: MemoryReactor, clock: Clock): + def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock): """ Make and return a homeserver. diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index 5b327b390ef2..fe4961dcf3ab 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Optional +from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.config.homeserver import HomeServerConfig @@ -57,6 +58,7 @@ def test_concurrent_limit(self) -> None: # ... until we complete an earlier request cm2.__exit__(None, None, None) + reactor.advance(0.0) self.successResultOf(d3) def test_sleep_limit(self) -> None: @@ -81,6 +83,43 @@ def test_sleep_limit(self) -> None: sleep_time = _await_resolution(reactor, d3) self.assertAlmostEqual(sleep_time, 500, places=3) + def test_lots_of_queued_things(self) -> None: + """Tests lots of synchronous things queued up behind a slow thing. + + The stack should *not* explode when the slow thing completes. + """ + reactor, clock = get_clock() + rc_config = build_rc_config( + { + "rc_federation": { + "sleep_limit": 1000000000, # never sleep + "reject_limit": 1000000000, # never reject requests + "concurrent": 1, + } + } + ) + ratelimiter = FederationRateLimiter(clock, rc_config) + + with ratelimiter.ratelimit("testhost") as d: + # shouldn't block + self.successResultOf(d) + + async def task() -> None: + with ratelimiter.ratelimit("testhost") as d: + await d + + for _ in range(1, 100): + defer.ensureDeferred(task()) + + last_task = defer.ensureDeferred(task()) + + # Upon exiting the context manager, all the synchronous things will resume. + # If a stack overflow occurs, the final task will not complete. 
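Both ratelimiter tests above have to pump the fake reactor with reactor.advance(0.0) before asserting on their Deferreds. A minimal sketch of why, using only stock Twisted test helpers (the variable names below are illustrative):

# Sketch only: work scheduled with callLater(0, ...) on a MemoryReactorClock does
# not run until the clock is advanced, even by zero seconds.
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactorClock

reactor = MemoryReactorClock()
d: "Deferred[str]" = Deferred()

# Simulate a queued request being woken on the next reactor tick.
reactor.callLater(0, d.callback, "resumed")

assert not d.called   # only scheduled so far, nothing has run
reactor.advance(0.0)  # run everything that is due "now"
assert d.called       # the queued request has been resumed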
+ + # Wait for all the things to complete. + reactor.advance(0.0) + self.successResultOf(last_task) + def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float: """advance the clock until the deferred completes.
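As a rough sketch of what a helper like this typically does (a hypothetical stand-alone version; the 0.1-second step and names are illustrative, not necessarily what the real helper uses):

# Hypothetical, simplified "advance the fake clock until resolved" helper.
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactorClock


def await_resolution(reactor: MemoryReactorClock, d: "Deferred[object]") -> float:
    start = reactor.seconds()
    while not d.called:
        # Keep moving the fake clock forward so pending callLater calls fire.
        reactor.advance(0.1)
    return reactor.seconds() - start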