diff --git a/.asf.yaml b/.asf.yaml index 848391d29d32..defb99e8c37f 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -46,6 +46,66 @@ github: merge: true rebase: false + protected_branches: + master: {} + release-2.49.0: {} + release-2.48.0: {} + release-2.47.0: {} + release-2.46.0: {} + release-2.45.0: {} + release-2.44.0: {} + release-2.43.0: {} + release-2.42.0: {} + release-2.41.0: {} + release-2.40.0: {} + release-2.39.0: {} + release-2.38.0: {} + release-2.37.0: {} + release-2.36.0: {} + release-2.35.0: {} + release-2.34.0: {} + release-2.33.0: {} + release-2.32.0: {} + release-2.31.0: {} + release-2.30.0: {} + release-2.29.0: {} + release-2.28.0: {} + release-2.27.0: {} + release-2.26.0: {} + release-2.25.0: {} + release-2.24.0: {} + release-2.23.0: {} + release-2.22.0: {} + release-2.21.0: {} + release-2.20.0: {} + release-2.19.0: {} + release-2.18.0: {} + release-2.17.0: {} + release-2.16.0: {} + release-2.15.0: {} + release-2.14.0: {} + release-2.13.0: {} + release-2.12.0: {} + release-2.11.0: {} + release-2.10.0: {} + release-2.8.0: {} + release-2.8.0: {} + release-2.7.0: {} + release-2.6.0: {} + release-2.5.0: {} + release-2.4.0: {} + release-2.3.0: {} + release-2.2.0: {} + release-2.1.1: {} + release-2.1.0: {} + release-0.6.0: {} + release-0.5.0: {} + release-0.4.0: {} + release-0.4.0-incubating: {} + release-0.3.0-incubating: {} + release-0.2.0-incubating: {} + release-0.1.0-incubating: {} + notifications: commits: commits@beam.apache.org issues: github@beam.apache.org diff --git a/.github/workflows/README.md b/.github/workflows/README.md index ffbb49c63670..083a1326ae76 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -20,6 +20,7 @@ Please note that jobs with matrix need to have matrix element in the comment. Ex |:-------------:|:------:|:--------------:|:-----------:| | [ Go PreCommit ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Go.yml) | N/A |`Run Go PreCommit`| [![Go PreCommit](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Go.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Go.yml) | | [ Python PreCommit Docker ](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python_DockerBuild.yml) | ['3.8','3.9','3.10','3.11'] | `Run PythonDocker PreCommit (matrix_element)`| [![.github/workflows/job_PreCommit_Python_DockerBuild.yml](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python_DockerBuild.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python_DockerBuild.yml) | +| [ Python PreCommit Docs ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_PythonDocs.yml) | N/A | `Run PythonDocs PreCommit`| [![.github/workflows/beam_PreCommit_PythonDocs.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_PythonDocs.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_PythonDocs.yml) | | [ Python PreCommit Formatter ](https://github.com/apache/beam/actions/workflows/job_PreCommit_PythonAutoformatter.yml) | N/A | `Run PythonFormatter PreCommit`| [![.github/workflows/job_PreCommit_PythonAutoformatter.yml](https://github.com/apache/beam/actions/workflows/job_PreCommit_PythonAutoformatter.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/job_PreCommit_PythonAutoformatter.yml) | | [ Python PreCommit Coverage ](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python_Coverage.yml) | N/A | `Run Python_Coverage PreCommit`| [![.github/workflows/job_PreCommit_Python_Coverage.yml](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python_Coverage.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python_Coverage.yml) | | [ Python PreCommit Dataframes ](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python_Dataframes.yml) | ['3.8','3.9','3.10','3.11'] | `Run Python_Dataframes PreCommit (matrix_element)`| [![.github/workflows/job_PreCommit_Python_Dataframes.yml](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python_Dataframes.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python_Dataframes.yml) | @@ -28,4 +29,5 @@ Please note that jobs with matrix need to have matrix element in the comment. Ex | [ Python PreCommit Lint ](https://github.com/apache/beam/actions/workflows/job_PreCommit_PythonLint.yml) | N/A | `Run PythonLint PreCommit` | [![.github/workflows/job_PreCommit_PythonLint.yml](https://github.com/apache/beam/actions/workflows/job_PreCommit_PythonLint.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/job_PreCommit_PythonLint.yml) | | [ Python PreCommit Transforms ](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python_Transforms.yml) | ['3.8','3.9','3.10','3.11'] | `Run Python_Transforms PreCommit (matrix_element)`| [![.github/workflows/job_PreCommit_Python_Transforms.yml](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python_Transforms.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python_Transforms.yml) | | [ Python PreCommit ](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python.yml) | ['3.8','3.9','3.10','3.11'] | `Run Python PreCommit (matrix_element)` | [![.github/workflows/job_PreCommit_Python.yml](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/job_PreCommit_Python.yml) | +| [ PreCommit Python Integration](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Python_Integration.yml) | ['3.8','3.11'] | `Run Python_Integration PreCommit (matrix_element)` | [![.github/workflows/beam_PreCommit_Python_Integration.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Python_Integration.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Python_Integration.yml) | | [ RAT PreCommit ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_RAT.yml) | N/A | `Run RAT PreCommit` | [![.github/workflows/beam_PreCommit_RAT.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_RAT.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_RAT.yml) | diff --git a/.github/workflows/beam_PreCommit_CommunityMetrics.yml b/.github/workflows/beam_PreCommit_CommunityMetrics.yml new file mode 100644 index 000000000000..31d145a09163 --- /dev/null +++ b/.github/workflows/beam_PreCommit_CommunityMetrics.yml @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: PreCommit Community Metrics + +on: + push: + tags: ['v*'] + branches: ['master', 'release-*'] + paths: ['.test-infra/metrics/**', '.github/workflows/beam_PreCommit_CommunityMetrics.yml'] + pull_request_target: + branches: ['master', 'release-*'] + paths: ['.test-infra/metrics/**'] + issue_comment: + types: [created] + schedule: + - cron: '* */6 * * *' + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +jobs: + beam_PreCommit_CommunityMetrics: + name: beam_PreCommit_CommunityMetrics + runs-on: [self-hosted, ubuntu-20.04, main] + if: | + github.event_name == 'push' || + github.event_name == 'pull_request_target' || + github.event_name == 'schedule' || + github.event.comment.body == 'Run CommunityMetrics PreCommit' + steps: + - uses: actions/checkout@v3 + - name: Install Java + uses: actions/setup-java@v3.8.0 + with: + distribution: 'zulu' + java-version: '8' + - name: Setup Gradle + uses: gradle/gradle-build-action@v2 + with: + cache-read-only: false + - name: Rerun on comment + if: github.event.comment.body == 'Run CommunityMetrics PreCommit' + uses: ./.github/actions/rerun-job-action + with: + pull_request_url: ${{ github.event.issue.pull_request.url }} + github_repository: ${{ github.repository }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ github.job }} + github_current_run_id: ${{ github.run_id }} + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + - name: Remove default github maven configuration + run: rm ~/.m2/settings.xml + - name: Install docker compose + run: | + sudo curl -L https://github.com/docker/compose/releases/download/1.22.0/docker-compose-$(uname -s)-$(uname -m) -o /usr/local/bin/docker-compose + sudo chmod +x /usr/local/bin/docker-compose + - name: Authenticate on GCP + uses: google-github-actions/setup-gcloud@v0 + with: + service_account_email: ${{ secrets.GCP_SA_EMAIL }} + service_account_key: ${{ secrets.GCP_SA_KEY }} + project_id: ${{ secrets.GCP_PROJECT_ID }} + - name: Install gcloud Kubectl + run: gcloud components install kubectl + - name: run Community Metrics PreCommit script + run: ./gradlew :communityMetricsPreCommit -PKUBE_CONFIG_PATH='$HOME/.kube/config' \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Go.yml b/.github/workflows/beam_PreCommit_Go.yml index 3e32873d6d82..211910c47f49 100644 --- a/.github/workflows/beam_PreCommit_Go.yml +++ b/.github/workflows/beam_PreCommit_Go.yml @@ -28,6 +28,10 @@ permissions: security-events: read statuses: read +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true jobs: beam_PreCommit_Go: diff --git a/.github/workflows/beam_PreCommit_Java_Examples_Dataflow.yml b/.github/workflows/beam_PreCommit_Java_Examples_Dataflow.yml index 04461af058f9..af3b3106cc15 100644 --- a/.github/workflows/beam_PreCommit_Java_Examples_Dataflow.yml +++ b/.github/workflows/beam_PreCommit_Java_Examples_Dataflow.yml @@ -43,6 +43,11 @@ on: schedule: - cron: '* */6 * * *' +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + #Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event permissions: actions: write diff --git a/.github/workflows/beam_PreCommit_PythonDocs.yml b/.github/workflows/beam_PreCommit_PythonDocs.yml new file mode 100644 index 000000000000..2d7036fd5a24 --- /dev/null +++ b/.github/workflows/beam_PreCommit_PythonDocs.yml @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Python PreCommit Docs + +on: + pull_request_target: + branches: [ "master", "release-*" ] + paths: ["sdks/python/**"] + issue_comment: + types: [created] + push: + tags: ['v*'] + branches: ['master', 'release-*'] + paths: ["sdks/python/**",".github/workflows/beam_PreCommit_PythonDocs.yml"] + schedule: + - cron: '* */6 * * *' +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +jobs: + beam_PreCommit_PythonDocs: + name: beam_PreCommit_PythonDocs + if: | + github.event_name == 'push' || + github.event_name == 'pull_request_target' || + github.event.comment.body == 'Run PythonDocs PreCommit' || + github.event_name == 'schedule' + runs-on: [self-hosted, ubuntu-20.04, main] + steps: + - name: Check out repository code + uses: actions/checkout@v3 + with: + ref: ${{ github.event.pull_request.head.sha }} + - name: Rerun on comment + if: github.event.comment.body == 'Run PythonDocs PreCommit' + uses: ./.github/actions/rerun-job-action + with: + pull_request_url: ${{ github.event.issue.pull_request.url }} + github_repository: ${{ github.repository }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ github.job }} + github_current_run_id: ${{ github.run_id }} + - name: Install Java + uses: actions/setup-java@v3.8.0 + with: + distribution: 'zulu' + java-version: '8' + - name: Install Python + uses: actions/setup-python@v4 + with: + python-version: '3.8' + - name: Setup Gradle + uses: gradle/gradle-build-action@v2 + with: + cache-read-only: false + - name: run PythonDocsPrecommit script + run: ./gradlew :pythonDocsPreCommit diff --git a/.github/workflows/beam_PreCommit_Python_Integration.yml b/.github/workflows/beam_PreCommit_Python_Integration.yml new file mode 100644 index 000000000000..4316192349b4 --- /dev/null +++ b/.github/workflows/beam_PreCommit_Python_Integration.yml @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: PreCommit Python Integration + +on: + pull_request_target: + branches: [ "master", "release-*" ] + paths: ["model/**", "sdks/python/**", "release/**"] + issue_comment: + types: [created] + push: + tags: ['v*'] + branches: ['master', 'release-*'] + paths: ["model/**", "sdks/python/**", "release/**", ".github/workflows/beam_PreCommit_Python_Integration.yml"] + schedule: + - cron: '* */6 * * *' + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + +jobs: + beam_PreCommit_Python_Integration: + name: beam_PreCommit_Python_Integration + strategy: + fail-fast: false + matrix: + python_version: ['3.8', '3.11'] + if: | + github.event_name == 'push' || + github.event_name == 'pull_request_target' || + startsWith(github.event.comment.body, 'Run Python_Integration PreCommit') || + github.event_name == 'schedule' + runs-on: [self-hosted, ubuntu-20.04, main] + steps: + - name: Check out repository code + uses: actions/checkout@v3 + with: + ref: ${{ github.event.pull_request.head.sha }} + - name: Set comment body with matrix + id: set_comment_body + run: | + echo "comment_body=Run Python_Integration PreCommit (${{ matrix.python_version }})" >> $GITHUB_OUTPUT + - name: Rerun on comment + if: github.event.comment.body == steps.set_comment_body.outputs.comment_body + uses: ./.github/actions/rerun-job-action + with: + pull_request_url: ${{ github.event.issue.pull_request.url }} + github_repository: ${{ github.repository }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: "${{ github.job }} (${{ matrix.python_version }})" + github_current_run_id: ${{ github.run_id }} + - name: Install Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version }} + - name: Install Java + uses: actions/setup-java@v3.8.0 + with: + distribution: 'zulu' + java-version: '8' + - name: Setup Gradle + uses: gradle/gradle-build-action@v2 + with: + cache-read-only: false + - name: run Python Integration PreCommit batch script + run: | + PY_VER=${{ matrix.python_version }} + ./gradlew :sdks:python:test-suites:dataflow:py${PY_VER//.}:preCommitIT_batch -PuseWheelDistribution -PpythonVersion=${PY_VER} + - name: run Python Integration PreCommit streaming script + run: | + PY_VER=${{ matrix.python_version }} + ./gradlew :sdks:python:test-suites:dataflow:py${PY_VER//.}:preCommitIT_streaming -PuseWheelDistribution -PpythonVersion=${PY_VER} + - name: Archive code coverage results + uses: actions/upload-artifact@v3 + with: + name: python-code-coverage-report + path: "**/pytest*.xml" diff --git a/.github/workflows/beam_PreCommit_Typescript.yml b/.github/workflows/beam_PreCommit_Typescript.yml index 6d118525d1cb..c415da70a0d0 100644 --- a/.github/workflows/beam_PreCommit_Typescript.yml +++ b/.github/workflows/beam_PreCommit_Typescript.yml @@ -29,6 +29,11 @@ on: types: [created] schedule: - cron: '* */6 * * *' + + # This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true #Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event permissions: diff --git a/.github/workflows/beam_PreCommit_Website.yml b/.github/workflows/beam_PreCommit_Website.yml index a2a59c294319..88441fcb1b4d 100644 --- a/.github/workflows/beam_PreCommit_Website.yml +++ b/.github/workflows/beam_PreCommit_Website.yml @@ -43,6 +43,11 @@ permissions: security-events: read statuses: read +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + jobs: beam_PreCommit_Website: if: | diff --git a/.github/workflows/beam_PreCommit_Whitespace.yml b/.github/workflows/beam_PreCommit_Whitespace.yml index edb548fd6bc1..1d85fd64cf3b 100644 --- a/.github/workflows/beam_PreCommit_Whitespace.yml +++ b/.github/workflows/beam_PreCommit_Whitespace.yml @@ -42,6 +42,11 @@ permissions: security-events: read statuses: read +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + jobs: beam_PreCommit_Whitespace: if: | diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index 3befac33f223..3df0281059f7 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -269,7 +269,7 @@ jobs: # TODO: https://github.com/apache/beam/issues/23048 CIBW_SKIP: "*-musllinux_*" CIBW_ENVIRONMENT: "SETUPTOOLS_USE_DISTUTILS=stdlib" - CIBW_BEFORE_BUILD: pip install cython numpy && pip install --upgrade setuptools + CIBW_BEFORE_BUILD: pip install cython==0.29.36 numpy && pip install --upgrade setuptools run: cibuildwheel --print-build-identifiers && cibuildwheel --output-dir wheelhouse shell: bash - name: install sha512sum on MacOS @@ -295,7 +295,7 @@ jobs: # TODO: https://github.com/apache/beam/issues/23048 CIBW_SKIP: "*-musllinux_*" CIBW_ENVIRONMENT: "SETUPTOOLS_USE_DISTUTILS=stdlib" - CIBW_BEFORE_BUILD: pip install cython numpy && pip install --upgrade setuptools + CIBW_BEFORE_BUILD: pip install cython==0.29.36 numpy && pip install --upgrade setuptools run: cibuildwheel --print-build-identifiers && cibuildwheel --output-dir wheelhouse shell: bash - name: Add RC checksums diff --git a/.github/workflows/cut_release_branch.yml b/.github/workflows/cut_release_branch.yml index f026c41ca9c3..4e104d78a445 100644 --- a/.github/workflows/cut_release_branch.yml +++ b/.github/workflows/cut_release_branch.yml @@ -100,6 +100,7 @@ jobs: MASTER_BRANCH: master NEXT_RELEASE: ${{ github.event.inputs.NEXT_VERSION }} SCRIPT_DIR: ./release/src/main/scripts + RELEASE: ${{ github.event.inputs.RELEASE_VERSION }} steps: - name: Mask Jenkins token run: | @@ -146,10 +147,13 @@ jobs: fi done - cat /tmp/result | sort | uniq | grep -i -E 'precommit|postcommit|validates|vr|example|test|gradle build' | grep -v -i -E 'load|perf|website' >> release/src/main/scripts/jenkins_jobs.txt + cat /tmp/result | sort | uniq | grep -i -E 'precommit|postcommit|validates|vr|example|test' | grep -v -i -E 'load|perf|website' >> release/src/main/scripts/jenkins_jobs.txt env: JENKINS_USERNAME: ${{ github.event.inputs.JENKINS_USERNAME }} JENKINS_TOKEN: ${{ github.event.inputs.JENKINS_TOKEN }} + - name: Update .asf.yaml to protect new release branch from force push + run: | + sed -i -e "s/master: {}/master: {}\n release-${RELEASE}: {}/g" .asf.yaml - name: Update master branch run: | bash "${SCRIPT_DIR}/set_version.sh" "${NEXT_VERSION_IN_BASE_BRANCH}" @@ -159,6 +163,7 @@ jobs: - name: Commit and Push to master branch files with Next Version run: | git add * + git add .asf.yaml git commit -m "Moving to ${NEXT_VERSION_IN_BASE_BRANCH}-SNAPSHOT on master branch." git push origin ${MASTER_BRANCH} @@ -170,6 +175,7 @@ jobs: REMOTE_NAME: remote_repo REMOTE_URL: ${{ github.server_url }}/${{ github.repository }} BRANCH_NAME: snapshot_build-${{ github.event.inputs.RELEASE_VERSION }} + RELEASE_BRANCH: release-${{ github.event.inputs.RELEASE_VERSION }} steps: - name: Install Hub run: | @@ -187,7 +193,7 @@ jobs: run: | git remote add ${REMOTE_NAME} ${REMOTE_URL} git checkout -b ${BRANCH_NAME} - touch empty_file.txt + touch empty_file.json git add -A git commit -m "Add empty file in order to create PR" git push -f ${REMOTE_NAME} @@ -195,7 +201,7 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | - hub pull-request -F- <<<"[DO NOT MERGE]Start snapshot build for release process + hub pull-request -b apache:${RELEASE_BRANCH} -F- <<<"[DO NOT MERGE]Start snapshot build for release process Run Gradle Publish" diff --git a/.github/workflows/job_PreCommit_Python.yml b/.github/workflows/job_PreCommit_Python.yml index 49d095961b2a..7e2c4a914ac8 100644 --- a/.github/workflows/job_PreCommit_Python.yml +++ b/.github/workflows/job_PreCommit_Python.yml @@ -41,6 +41,12 @@ permissions: repository-projects: read security-events: read statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + jobs: beam_PreCommit_Python: strategy: diff --git a/.github/workflows/job_PreCommit_PythonAutoformatter.yml b/.github/workflows/job_PreCommit_PythonAutoformatter.yml index f41e5be05d6e..73c5a2202b21 100644 --- a/.github/workflows/job_PreCommit_PythonAutoformatter.yml +++ b/.github/workflows/job_PreCommit_PythonAutoformatter.yml @@ -41,6 +41,11 @@ permissions: security-events: read statuses: read +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + jobs: beam_PreCommit_PythonFormatter: if: | diff --git a/.github/workflows/job_Precommit_PythonLint.yml b/.github/workflows/job_PreCommit_PythonLint.yml similarity index 92% rename from .github/workflows/job_Precommit_PythonLint.yml rename to .github/workflows/job_PreCommit_PythonLint.yml index f75c860af190..2b3586bd37ba 100644 --- a/.github/workflows/job_Precommit_PythonLint.yml +++ b/.github/workflows/job_PreCommit_PythonLint.yml @@ -41,6 +41,11 @@ permissions: security-events: read statuses: read +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + jobs: beam_PreCommit_PythonLint: if: | diff --git a/.github/workflows/job_Precommit_Python_Coverage.yml b/.github/workflows/job_PreCommit_Python_Coverage.yml similarity index 93% rename from .github/workflows/job_Precommit_Python_Coverage.yml rename to .github/workflows/job_PreCommit_Python_Coverage.yml index c4821622322b..8f4448cb79fc 100644 --- a/.github/workflows/job_Precommit_Python_Coverage.yml +++ b/.github/workflows/job_PreCommit_Python_Coverage.yml @@ -42,6 +42,11 @@ permissions: security-events: read statuses: read +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + jobs: beam_PreCommit_Python_Coverage: if: | diff --git a/.github/workflows/job_PreCommit_Python_Dataframes.yml b/.github/workflows/job_PreCommit_Python_Dataframes.yml index 967bb65749ff..d33294148446 100644 --- a/.github/workflows/job_PreCommit_Python_Dataframes.yml +++ b/.github/workflows/job_PreCommit_Python_Dataframes.yml @@ -40,6 +40,12 @@ permissions: repository-projects: read security-events: read statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + jobs: beam_PreCommit_Python_Dataframes: strategy: diff --git a/.github/workflows/job_PreCommit_Python_DockerBuild.yml b/.github/workflows/job_PreCommit_Python_DockerBuild.yml index 2b0fe8a4100e..beaf9a8dfd98 100644 --- a/.github/workflows/job_PreCommit_Python_DockerBuild.yml +++ b/.github/workflows/job_PreCommit_Python_DockerBuild.yml @@ -41,6 +41,11 @@ permissions: security-events: read statuses: read +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + jobs: beam_PreCommit_PythonDocker: strategy: diff --git a/.github/workflows/job_PreCommit_Python_Examples.yml b/.github/workflows/job_PreCommit_Python_Examples.yml index 9e26188c4dbd..8cce583cfecb 100644 --- a/.github/workflows/job_PreCommit_Python_Examples.yml +++ b/.github/workflows/job_PreCommit_Python_Examples.yml @@ -41,6 +41,11 @@ permissions: security-events: read statuses: read +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + jobs: beam_PreCommit_Python_Examples: strategy: diff --git a/.github/workflows/job_PreCommit_Python_Runners.yml b/.github/workflows/job_PreCommit_Python_Runners.yml index 5325f32f3c1c..394b9bf092c1 100644 --- a/.github/workflows/job_PreCommit_Python_Runners.yml +++ b/.github/workflows/job_PreCommit_Python_Runners.yml @@ -42,6 +42,11 @@ permissions: security-events: read statuses: read +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + jobs: beam_PreCommit_Python_Runners: strategy: diff --git a/.github/workflows/job_PreCommit_Python_Transforms.yml b/.github/workflows/job_PreCommit_Python_Transforms.yml index 3266ca6f3cc2..893d8a9a730b 100644 --- a/.github/workflows/job_PreCommit_Python_Transforms.yml +++ b/.github/workflows/job_PreCommit_Python_Transforms.yml @@ -42,6 +42,11 @@ permissions: security-events: read statuses: read +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + jobs: beam_PreCommit_Python_Transforms: strategy: diff --git a/.github/workflows/run_rc_validation.yml b/.github/workflows/run_rc_validation.yml index 7a98942eed1f..720150b57450 100644 --- a/.github/workflows/run_rc_validation.yml +++ b/.github/workflows/run_rc_validation.yml @@ -92,8 +92,8 @@ jobs: - name: Create Pull Request run: | git checkout -b ${{env.WORKING_BRANCH}} ${{ env.RC_TAG }} --quiet - touch empty_file.txt - git add empty_file.txt + touch empty_file.json + git add empty_file.json git commit -m "Add empty file in order to create PR" --quiet git push origin ${{env.WORKING_BRANCH}} --quiet GITHUB_PR_URL=$(gh pr create -B ${{env.RELEASE_BRANCH}} -H ${{env.WORKING_BRANCH}} -t"[DO NOT MERGE] Run Python RC Validation Tests" -b "PR to run Python ReleaseCandidate Jenkins Job.") diff --git a/.test-infra/jenkins/README.md b/.test-infra/jenkins/README.md index 2c04fec63f52..02cddfdc65c7 100644 --- a/.test-infra/jenkins/README.md +++ b/.test-infra/jenkins/README.md @@ -295,7 +295,6 @@ Beam Jenkins overview page: [link](https://ci-beam.apache.org/) | beam_Publish_Beam_SDK_Snapshots | [cron](https://ci-beam.apache.org/job/beam_Publish_Beam_SDK_Snapshots/)| N/A | [![Build Status](https://ci-beam.apache.org/job/beam_Publish_Beam_SDK_Snapshots/badge/icon)](https://ci-beam.apache.org/job/beam_Publish_Beam_SDK_Snapshots/) | | beam_Publish_Docker_Snapshots | [cron](https://ci-beam.apache.org/job/beam_Publish_Docker_Snapshots/)| N/A | [![Build Status](https://ci-beam.apache.org/job/beam_Publish_Docker_Snapshots/badge/icon)](https://ci-beam.apache.org/job/beam_Publish_Docker_Snapshots/) | | beam_PostRelease_Python_Candidate | [cron](https://ci-beam.apache.org/job/beam_PostRelease_Python_Candidate/)| `Run Python ReleaseCandidate` | [![Build Status](https://ci-beam.apache.org/job/beam_PostRelease_Python_Candidate/badge/icon)](https://ci-beam.apache.org/job/beam_PostRelease_Python_Candidate/) | -| beam_Release_Gradle_Build | [cron](https://ci-beam.apache.org/job/beam_Release_Gradle_Build/) | `Run Release Gradle Build` | [![Build Status](https://ci-beam.apache.org/job/beam_Release_Gradle_Build/badge/icon)](https://ci-beam.apache.org/job/beam_Release_Gradle_Build/) ### Notes: diff --git a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy index 13e78617d2c8..db052a0046ce 100644 --- a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy +++ b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy @@ -36,8 +36,8 @@ PostcommitJobBuilder.postCommitJob('beam_PostCommit_Py_VR_Dataflow', 'Run Python steps { gradle { rootBuildScriptDir(commonJobProperties.checkoutDir) - tasks(':sdks:python:test-suites:dataflow:validatesRunnerBatchTestsV2') - tasks(':sdks:python:test-suites:dataflow:validatesRunnerStreamingTestsV2') + tasks(':sdks:python:test-suites:dataflow:validatesRunnerBatchTests') + tasks(':sdks:python:test-suites:dataflow:validatesRunnerStreamingTests') switches('-PuseWheelDistribution') commonJobProperties.setGradleSwitches(delegate) } diff --git a/.test-infra/jenkins/job_Release_Gradle_Build.groovy b/.test-infra/jenkins/job_Release_Gradle_Build.groovy deleted file mode 100644 index ba3efeca85fe..000000000000 --- a/.test-infra/jenkins/job_Release_Gradle_Build.groovy +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import CommonJobProperties as commonJobProperties - -// This job runs complete cycle of Gradle build against the official release -// version. Release manager should use this job to verify a release branch -// after cut. -job('beam_Release_Gradle_Build') { - description('Verify Gradle build against the official release version.') - - // Set common parameters. - commonJobProperties - .setTopLevelMainJobProperties( - delegate, - defaultBranch='master', - defaultTimeout=300) - - // Allows triggering this build against pull requests. - commonJobProperties.enablePhraseTriggeringFromPullRequest( - delegate, - 'Release_Build ("Run Release Gradle Build")', - 'Run Release Gradle Build') - - steps { - gradle { - rootBuildScriptDir(commonJobProperties.checkoutDir) - tasks('build') - commonJobProperties.setGradleSwitches(delegate) - switches('-PisRelease') - switches('--stacktrace') - } - } -} diff --git a/.test-infra/metrics/build.gradle b/.test-infra/metrics/build.gradle index 679ecd35735d..ad7dbdb5fd96 100644 --- a/.test-infra/metrics/build.gradle +++ b/.test-infra/metrics/build.gradle @@ -27,7 +27,7 @@ ext { cluster = 'metrics' zone='us-central1-a' projectName='apache-beam-testing' - kubeConfigPath='/home/jenkins/.kube/config' + kubeConfigPath = project.properties['KUBE_CONFIG_PATH'] ?: '/home/jenkins/.kube/config' } applyGroovyNature() diff --git a/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/README.md b/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/README.md new file mode 100644 index 000000000000..19c1a5fe7db6 --- /dev/null +++ b/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/README.md @@ -0,0 +1,46 @@ + + +# Overview + +This directory sets up the Kubernetes environment for subsequent modules. + +# Usage + +Follow terraform workflow convention to apply this module. +The following assumes the working directory is at +[.test-infra/pipelines/infrastructure/03.io/api-overuse-study](..). + +## Terraform Init + +Initialize the terraform workspace. + +``` +DIR=01.setup +terraform -chdir=$DIR init +``` + +## Terraform Apply + +Apply the terraform module. + +``` +DIR=01.setup +terraform -chdir=$DIR apply -var-file=common.tfvars +``` \ No newline at end of file diff --git a/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/common.tfvars b/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/common.tfvars new file mode 100644 index 000000000000..f71b496b5a20 --- /dev/null +++ b/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/common.tfvars @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace = "api-overuse-study" \ No newline at end of file diff --git a/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/namespace.tf b/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/namespace.tf new file mode 100644 index 000000000000..f72e08dc08db --- /dev/null +++ b/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/namespace.tf @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Provisions the namespace shared by all resources. +resource "kubernetes_namespace" "default" { + metadata { + name = var.namespace + } +} \ No newline at end of file diff --git a/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/provider.tf b/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/provider.tf new file mode 100644 index 000000000000..1846a8717469 --- /dev/null +++ b/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/provider.tf @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +provider "kubernetes" { + config_path = "~/.kube/config" +} \ No newline at end of file diff --git a/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/variables.tf b/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/variables.tf new file mode 100644 index 000000000000..2ae6cc65410d --- /dev/null +++ b/.test-infra/pipelines/infrastructure/03.io/api-overuse-study/variables.tf @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +variable "namespace" { + type = string + description = "The Kubernetes namespace to provision resources" +} \ No newline at end of file diff --git a/CHANGES.md b/CHANGES.md index 2e1db58cdb9a..ec1c112ff4df 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -56,6 +56,7 @@ * New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)). * New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)). +* Spark 3.2.2 is used as default version for Spark runner ([#23804](https://github.com/apache/beam/issues/23804)). ## I/Os @@ -70,7 +71,7 @@ ## Breaking Changes -* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)). +* Legacy runner support removed from Dataflow, all pipelines must use runner v2. ## Deprecations @@ -80,6 +81,7 @@ * Fixed DirectRunner bug in Python SDK where GroupByKey gets empty PCollection and fails when pipeline option `direct_num_workers!=1`. ([#27373](https://github.com/apache/beam/pull/27373)) * Fixed BigQuery I/O bug when estimating size on queries that utilize row-level security ([#27474](https://github.com/apache/beam/pull/27474)) +* Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## Known Issues diff --git a/buildSrc/build.gradle.kts b/buildSrc/build.gradle.kts index 8b2956f98c95..50a146b1f5cd 100644 --- a/buildSrc/build.gradle.kts +++ b/buildSrc/build.gradle.kts @@ -58,7 +58,7 @@ dependencies { runtimeOnly("com.avast.gradle:gradle-docker-compose-plugin:0.16.12") // Enable docker compose tasks runtimeOnly("ca.cutterslade.gradle:gradle-dependency-analyze:1.8.3") // Enable dep analysis runtimeOnly("gradle.plugin.net.ossindex:ossindex-gradle-plugin:0.4.11") // Enable dep vulnerability analysis - runtimeOnly("org.checkerframework:checkerframework-gradle-plugin:0.6.19") // Enable enhanced static checking plugin + runtimeOnly("org.checkerframework:checkerframework-gradle-plugin:0.6.29") // Enable enhanced static checking plugin } // Because buildSrc is built and tested automatically _before_ gradle diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 9637f17dbf8b..852edfe62e41 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -559,7 +559,7 @@ class BeamModulePlugin implements Plugin { def singlestore_jdbc_version = "1.1.4" def slf4j_version = "1.7.30" def spark2_version = "2.4.8" - def spark3_version = "3.1.2" + def spark3_version = "3.2.2" def spotbugs_version = "4.0.6" def testcontainers_version = "1.17.3" def arrow_version = "5.0.0" diff --git a/examples/java/src/main/java/org/apache/beam/examples/complete/game/injector/InjectorUtils.java b/examples/java/src/main/java/org/apache/beam/examples/complete/game/injector/InjectorUtils.java index dbefed2b7cc3..4713d15a05c2 100644 --- a/examples/java/src/main/java/org/apache/beam/examples/complete/game/injector/InjectorUtils.java +++ b/examples/java/src/main/java/org/apache/beam/examples/complete/game/injector/InjectorUtils.java @@ -47,7 +47,7 @@ public static Pubsub getClient(final HttpTransport httpTransport, final JsonFact } if (credential.getClientAuthentication() != null) { System.out.println( - "\n***Warning! You are not using service account credentials to " + "\n***Error! You are not using service account credentials to " + "authenticate.\nYou need to use service account credentials for this example," + "\nsince user-level credentials do not have enough pubsub quota,\nand so you will run " + "out of PubSub quota very quickly.\nSee " diff --git a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto index efa37ea996e3..66d144ab2310 100644 --- a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto +++ b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto @@ -46,6 +46,52 @@ import "google/protobuf/struct.proto"; import "google/protobuf/timestamp.proto"; import "google/protobuf/duration.proto"; + +// Describes transforms necessary to execute Beam over the FnAPI but are +// implementation details rather than part of the core model. +message FnApiTransforms { + enum Runner { + // DataSource is a Root Transform, and a source of data for downstream + // transforms in the same ProcessBundleDescriptor. + // It represents a stream of values coming in from an external source/over + // a data channel, typically from the runner. It's not the PCollection itself + // but a description of how to get the portion of the PCollection for a given + // bundle. + // + // The DataSource transform is implemented in each SDK and not explicitly + // provided during pipeline construction. A runner inserts the transform + // in ProcessBundleDescriptors to indicate where the bundle + // can retrieve data for an associated ProcessBundleRequest. + // Data for the same request will be retrieved with the matching instruction ID, + // and transform ID determined by the runner. + // + // The DataSource transform will take a stream of bytes from the remote + // source for the matching instruction ID and decode them as windowed + // values using the provided coder ID, which must be a windowed value coder. + // + // Payload: RemoteGrpcPort + DATA_SOURCE = 0 [(org.apache.beam.model.pipeline.v1.beam_urn) = "beam:runner:source:v1"]; + + // DataSink is a transform that sends PCollection elements to a remote + // port using the Data API. + // + // The DataSink transform is implemented in each SDK and not explicitly + // provided during pipeline construction. A runner inserts the transform in + // ProcessBundleDescriptors to indicate where the bundle can send + // data for each associated ProcessBundleRequest. Data for the same + // request will be sent with the matching instruction ID and transform ID. + // Each PCollection that exits the ProcessBundleDescriptor subgraph will have + // it's own DataSink, keyed by a transform ID determined by the runner. + // + // The DataSink will take in a stream of elements for a given instruction ID + // and encode them for transmission to the remote sink. The coder ID must be + // for a windowed value coder. + // + // Payload: RemoteGrpcPort + DATA_SINK = 1 [(org.apache.beam.model.pipeline.v1.beam_urn) = "beam:runner:sink:v1"]; + } +} + // A descriptor for connecting to a remote port using the Beam Fn Data API. // Allows for communication between two environments (for example between the // runner and the SDK). diff --git a/playground/README.md b/playground/README.md index 6f69a59d0551..ec324a1bb7d2 100644 --- a/playground/README.md +++ b/playground/README.md @@ -45,7 +45,7 @@ build, test, and deploy the frontend and backend services. **Ubuntu 22.04 and newer:** ```shell - sudo apt install golang` + sudo apt install golang ``` **Other Linux variants:** Follow manual at https://go.dev/doc/install @@ -224,4 +224,4 @@ Several directories in this repository are used for the Beam Playground project. # Contribution guide - Backend: see [backend/README.md](/playground/backend/README.md) and [backend/CONTRIBUTE.md](/playground/backend/CONTRIBUTE.md) -- Frontend: see [frontend/README.md](/playground/frontend/README.md) and [frontend/CONTRIBUTE.md](/playground/frontend/CONTRIBUTE.md) \ No newline at end of file +- Frontend: see [frontend/README.md](/playground/frontend/README.md) and [frontend/CONTRIBUTE.md](/playground/frontend/CONTRIBUTE.md) diff --git a/release/src/main/scripts/jenkins_jobs.txt b/release/src/main/scripts/jenkins_jobs.txt index 00db48bee486..c32b47f1dbf6 100644 --- a/release/src/main/scripts/jenkins_jobs.txt +++ b/release/src/main/scripts/jenkins_jobs.txt @@ -139,7 +139,6 @@ Run Python_Transforms PreCommit,beam_PreCommit_Python_Transforms_Phrase Run Python_Xlang_Gcp_Dataflow PostCommit,beam_PostCommit_Python_Xlang_Gcp_Dataflow_PR Run Python_Xlang_Gcp_Direct PostCommit,beam_PostCommit_Python_Xlang_Gcp_Direct_PR Run RAT PreCommit,beam_PreCommit_RAT_Phrase -Run Release Gradle Build,beam_Release_Gradle_Build Run SQL PostCommit,beam_PostCommit_SQL_PR Run SQL PreCommit,beam_PreCommit_SQL_Phrase Run SQL_Java11 PreCommit,beam_PreCommit_SQL_Java11_Phrase diff --git a/release/src/main/scripts/publish_docker_images.sh b/release/src/main/scripts/publish_docker_images.sh index b189c9c1bfdc..d7c1d1b9599e 100755 --- a/release/src/main/scripts/publish_docker_images.sh +++ b/release/src/main/scripts/publish_docker_images.sh @@ -37,18 +37,27 @@ echo "Which release candidate will be the source of final docker images? (ex: 1) read RC_NUM RC_VERSION="rc${RC_NUM}" -echo "================Confirming Release and RC version===========" +echo "================Pull RC Containers from DockerHub===========" +IMAGES=$(docker search ${DOCKER_IMAGE_DEFAULT_REPO_ROOT}/${DOCKER_IMAGE_DEFAULT_REPO_PREFIX} --format "{{.Name}}" --limit 100) +KNOWN_IMAGES=() echo "We are using ${RC_VERSION} to push docker images for ${RELEASE}." +while read IMAGE; do + # Try pull verified RC from dockerhub. + if docker pull "${IMAGE}:${RELEASE}${RC_VERSION}" 2>/dev/null ; then + KNOWN_IMAGES+=( $IMAGE ) + fi +done < <(echo "${IMAGES}") + +echo "================Confirming Release and RC version===========" echo "Publishing the following images:" -IMAGES=$(docker images --filter "reference=apache/beam_*:${RELEASE}${RC_VERSION}" --format "{{.Repository}}") -echo "${IMAGES}" +# Sort by name for easy examination +IFS=$'\n' KNOWN_IMAGES=($(sort <<<"${KNOWN_IMAGES[*]}")) +unset IFS +printf "%s\n" ${KNOWN_IMAGES[@]} echo "Do you want to proceed? [y|N]" read confirmation if [[ $confirmation = "y" ]]; then - echo "${IMAGES}" | while read IMAGE; do - # Pull verified RC from dockerhub. - docker pull "${IMAGE}:${RELEASE}${RC_VERSION}" - + for IMAGE in "${KNOWN_IMAGES[@]}"; do # Tag with ${RELEASE} and push to dockerhub. docker tag "${IMAGE}:${RELEASE}${RC_VERSION}" "${IMAGE}:${RELEASE}" docker push "${IMAGE}:${RELEASE}" @@ -58,4 +67,4 @@ if [[ $confirmation = "y" ]]; then docker push "${IMAGE}:latest" done -fi \ No newline at end of file +fi diff --git a/release/src/main/scripts/run_rc_validation.sh b/release/src/main/scripts/run_rc_validation.sh index 4ad8af16723b..7f32c2979660 100755 --- a/release/src/main/scripts/run_rc_validation.sh +++ b/release/src/main/scripts/run_rc_validation.sh @@ -271,8 +271,8 @@ echo "This task will create a PR against apache/beam, trigger a jenkins job to r echo "1. Python quickstart validations(batch & streaming)" echo "2. Python MobileGame validations(UserScore, HourlyTeamScore)" if [[ "$python_quickstart_mobile_game" = true && ! -z `which hub` ]]; then - touch empty_file.txt - git add empty_file.txt + touch empty_file.json + git add empty_file.json git commit -m "Add empty file in order to create PR" --quiet git push -f ${GITHUB_USERNAME} --quiet # Create a test PR diff --git a/release/src/main/scripts/verify_release_build.sh b/release/src/main/scripts/verify_release_build.sh index 0f9921c2bf86..214c65cc9ef6 100755 --- a/release/src/main/scripts/verify_release_build.sh +++ b/release/src/main/scripts/verify_release_build.sh @@ -125,8 +125,8 @@ hub version echo "" -echo "==================== 3 Run Gradle Release Build & PostCommit Tests on Jenkins ===================" -echo "[Current Task] Run Gradle release build and all PostCommit Tests against Release Branch on Jenkins." +echo "==================== 3 Run PostCommit Tests on Jenkins ===================" +echo "[Current Task] Run all PostCommit Tests against Release Branch on Jenkins." echo "This task will create a PR against apache/beam." echo "After PR created, you need to comment phrases listed in description in the created PR:" diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/renderer/PipelineDotRendererTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/renderer/PipelineDotRendererTest.java index eca57b9e8b19..8d7e02d70c8d 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/renderer/PipelineDotRendererTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/renderer/PipelineDotRendererTest.java @@ -52,7 +52,10 @@ public void testEmptyPipeline() { @Test public void testCompositePipeline() { - p.apply(Create.timestamped(TimestampedValue.of(KV.of(1, 1), new Instant(1)))) + p.apply( + Create.timestamped( + TimestampedValue.of(KV.of(1, 1), new Instant(1)), + TimestampedValue.of(KV.of(2, 2), new Instant(2)))) .apply(Window.into(FixedWindows.of(Duration.millis(10)))) .apply(Sum.integersPerKey()); assertEquals( diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index 74ab21507045..2611a5845135 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -129,6 +129,7 @@ import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.Combine.GroupedValues; +import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupIntoBatches; import org.apache.beam.sdk.transforms.Impulse; @@ -495,6 +496,23 @@ protected DataflowRunner(DataflowPipelineOptions options) { this.ptransformViewsWithNonDeterministicKeyCoders = new HashSet<>(); } + private static class AlwaysCreateViaRead + implements PTransformOverrideFactory, Create.Values> { + @Override + public PTransformOverrideFactory.PTransformReplacement> + getReplacementTransform( + AppliedPTransform, Create.Values> appliedTransform) { + return PTransformOverrideFactory.PTransformReplacement.of( + appliedTransform.getPipeline().begin(), appliedTransform.getTransform().alwaysUseRead()); + } + + @Override + public final Map, ReplacementOutput> mapOutputs( + Map, PCollection> outputs, PCollection newOutput) { + return ReplacementOutputs.singleton(outputs, newOutput); + } + } + private List getOverrides(boolean streaming) { ImmutableList.Builder overridesBuilder = ImmutableList.builder(); @@ -509,6 +527,13 @@ private List getOverrides(boolean streaming) { PTransformOverride.of( PTransformMatchers.emptyFlatten(), EmptyFlattenAsCreateFactory.instance())); + if (streaming) { + // For update compatibility, always use a Read for Create in streaming mode. + overridesBuilder.add( + PTransformOverride.of( + PTransformMatchers.classEqualTo(Create.Values.class), new AlwaysCreateViaRead())); + } + // By default Dataflow runner replaces single-output ParDo with a ParDoSingle override. // However, we want a different expansion for single-output splittable ParDo. overridesBuilder @@ -661,6 +686,29 @@ private List getOverrides(boolean streaming) { PTransformOverride.of( PTransformMatchers.classEqualTo(ParDo.SingleOutput.class), new PrimitiveParDoSingleFactory())); + + if (streaming) { + // For update compatibility, always use a Read for Create in streaming mode. + overridesBuilder + .add( + PTransformOverride.of( + PTransformMatchers.classEqualTo(Create.Values.class), new AlwaysCreateViaRead())) + // Create is implemented in terms of BoundedRead. + .add( + PTransformOverride.of( + PTransformMatchers.classEqualTo(Read.Bounded.class), + new StreamingBoundedReadOverrideFactory())) + // Streaming Bounded Read is implemented in terms of Streaming Unbounded Read. + .add( + PTransformOverride.of( + PTransformMatchers.classEqualTo(Read.Unbounded.class), + new StreamingUnboundedReadOverrideFactory())) + .add( + PTransformOverride.of( + PTransformMatchers.classEqualTo(ParDo.SingleOutput.class), + new PrimitiveParDoSingleFactory())); + } + return overridesBuilder.build(); } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java index 27b98a9c7591..77c0a28981df 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java @@ -723,7 +723,7 @@ public void testTaggedNamesOverridden() throws Exception { PCollectionTuple outputs = pipeline - .apply(Create.of(3)) + .apply(Create.of(3, 4)) .apply( ParDo.of( new DoFn() { @@ -775,7 +775,7 @@ public void testBatchStatefulParDoTranslation() throws Exception { TupleTag mainOutputTag = new TupleTag() {}; pipeline - .apply(Create.of(KV.of(1, 1))) + .apply(Create.of(KV.of(1, 1), KV.of(2, 3))) .apply( ParDo.of( new DoFn, Integer>() { @@ -917,7 +917,7 @@ public void processElement(ProcessContext c) { // No need to actually check the pipeline as the ValidatesRunner tests // ensure translation is correct. This is just a quick check to see that translation // does not crash. - assertEquals(24, steps.size()); + assertEquals(25, steps.size()); } /** Smoke test to fail fast if translation of a splittable ParDo in streaming breaks. */ @@ -1057,7 +1057,7 @@ public void testToSingletonTranslationWithIsmSideInput() throws Exception { assertAllStepOutputsHaveUniqueIds(job); List steps = job.getSteps(); - assertEquals(9, steps.size()); + assertEquals(10, steps.size()); @SuppressWarnings("unchecked") List> toIsmRecordOutputs = diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowSystemMetrics.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowSystemMetrics.java index 0dae356b754b..087dd624ae78 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowSystemMetrics.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowSystemMetrics.java @@ -39,6 +39,7 @@ public enum StreamingSystemCounterNames { JAVA_HARNESS_USED_MEMORY("dataflow_java_harness_used_memory"), JAVA_HARNESS_MAX_MEMORY("dataflow_java_harness_max_memory"), JAVA_HARNESS_RESTARTS("dataflow_java_harness_restarts"), + TIME_AT_MAX_ACTIVE_THREADS("dataflow_time_at_max_active_threads"), WINDMILL_QUOTA_THROTTLING("dataflow_streaming_engine_throttled_msecs"), MEMORY_THRASHING("dataflow_streaming_engine_user_worker_thrashing"); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index ea5065260a97..04e83d8fee69 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -433,6 +433,7 @@ public int getSize() { // Built-in cumulative counters. private final Counter javaHarnessUsedMemory; private final Counter javaHarnessMaxMemory; + private final Counter timeAtMaxActiveThreads; private final Counter windmillMaxObservedWorkItemCommitBytes; private final Counter memoryThrashing; private ScheduledExecutorService refreshWorkTimer; @@ -611,6 +612,9 @@ public static StreamingDataflowWorker fromDataflowWorkerHarnessOptions( this.javaHarnessMaxMemory = pendingCumulativeCounters.longSum( StreamingSystemCounterNames.JAVA_HARNESS_MAX_MEMORY.counterName()); + this.timeAtMaxActiveThreads = + pendingCumulativeCounters.longSum( + StreamingSystemCounterNames.TIME_AT_MAX_ACTIVE_THREADS.counterName()); this.windmillMaxObservedWorkItemCommitBytes = pendingCumulativeCounters.intMax( StreamingSystemCounterNames.WINDMILL_MAX_WORK_ITEM_COMMIT_BYTES.counterName()); @@ -2020,9 +2024,15 @@ private void updateVMMetrics() { javaHarnessMaxMemory.addValue(maxMemory); } + private void updateThreadMetrics() { + timeAtMaxActiveThreads.getAndReset(); + timeAtMaxActiveThreads.addValue(workUnitExecutor.allThreadsActiveTime()); + } + @VisibleForTesting public void reportPeriodicWorkerUpdates() { updateVMMetrics(); + updateThreadMetrics(); try { sendWorkerUpdatesToDataflowService(pendingDeltaCounters, pendingCumulativeCounters); } catch (IOException e) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java index 908221973fae..446934b69589 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java @@ -19,9 +19,9 @@ import java.io.IOException; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; -import org.apache.beam.runners.dataflow.worker.windmill.GrpcWindmillServer; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServer; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; +import org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcWindmillServer; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.DefaultValueFactory; import org.apache.beam.sdk.options.Description; @@ -207,7 +207,7 @@ public WindmillServerStub create(PipelineOptions options) { || streamingOptions.isEnableStreamingEngine() || streamingOptions.getLocalWindmillHostport().startsWith("grpc:")) { try { - return new GrpcWindmillServer(streamingOptions); + return GrpcWindmillServer.create(streamingOptions); } catch (IOException e) { throw new RuntimeException("Failed to create GrpcWindmillServer: ", e); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java index 1784bbf8e3ba..79a95399d813 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java @@ -21,6 +21,7 @@ import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Monitor; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Monitor.Guard; @@ -36,6 +37,9 @@ public class BoundedQueueExecutor { private final Monitor monitor = new Monitor(); private int elementsOutstanding = 0; private long bytesOutstanding = 0; + private final AtomicInteger activeCount = new AtomicInteger(); + private long startTimeMaxActiveThreadsUsed; + private long totalTimeMaxActiveThreadsUsed; public BoundedQueueExecutor( int maximumPoolSize, @@ -51,7 +55,29 @@ public BoundedQueueExecutor( keepAliveTime, unit, new LinkedBlockingQueue<>(), - threadFactory); + threadFactory) { + @Override + protected void beforeExecute(Thread t, Runnable r) { + super.beforeExecute(t, r); + synchronized (this) { + if (activeCount.getAndIncrement() >= maximumPoolSize - 1) { + startTimeMaxActiveThreadsUsed = System.currentTimeMillis(); + } + } + } + + @Override + protected void afterExecute(Runnable r, Throwable t) { + super.afterExecute(r, t); + synchronized (this) { + if (activeCount.getAndDecrement() == maximumPoolSize) { + totalTimeMaxActiveThreadsUsed += + (System.currentTimeMillis() - startTimeMaxActiveThreadsUsed); + startTimeMaxActiveThreadsUsed = 0; + } + } + } + }; executor.allowCoreThreadTimeOut(true); this.maximumElementsOutstanding = maximumElementsOutstanding; this.maximumBytesOutstanding = maximumBytesOutstanding; @@ -89,6 +115,10 @@ public boolean executorQueueIsEmpty() { return executor.getQueue().isEmpty(); } + public long allThreadsActiveTime() { + return totalTimeMaxActiveThreadsUsed; + } + public String summaryHtml() { monitor.enter(); try { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java new file mode 100644 index 000000000000..1f01a8cc09d3 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill; + +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.function.Supplier; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.WindmillStream; +import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Status; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.StatusRuntimeException; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base class for persistent streams connecting to Windmill. + * + *

This class handles the underlying gRPC StreamObservers, and automatically reconnects the + * stream if it is broken. Subclasses are responsible for retrying requests that have been lost on a + * broken stream. + * + *

Subclasses should override onResponse to handle responses from the server, and onNewStream to + * perform any work that must be done when a new stream is created, such as sending headers or + * retrying requests. + * + *

send and startStream should not be called from onResponse; use executor() instead. + * + *

Synchronization on this is used to synchronize the gRpc stream state and internal data + * structures. Since grpc channel operations may block, synchronization on this stream may also + * block. This is generally not a problem since streams are used in a single-threaded manner. + * However, some accessors used for status page and other debugging need to take care not to require + * synchronizing on this. + */ +public abstract class AbstractWindmillStream implements WindmillStream { + protected static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; + // Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce + // per-chunk overhead, and small enough that we can still perform granular flow-control. + protected static final int RPC_STREAM_CHUNK_SIZE = 2 << 20; + + private static final Logger LOG = LoggerFactory.getLogger(AbstractWindmillStream.class); + + protected final AtomicBoolean clientClosed; + + private final Executor executor; + private final BackOff backoff; + // Indicates if the current stream in requestObserver is closed by calling close() method + private final AtomicBoolean streamClosed; + private final AtomicLong startTimeMs; + private final AtomicLong lastSendTimeMs; + private final AtomicLong lastResponseTimeMs; + private final AtomicInteger errorCount; + private final AtomicReference lastError; + private final AtomicLong sleepUntil; + private final CountDownLatch finishLatch; + private final Set> streamRegistry; + private final int logEveryNStreamFailures; + private final Supplier> requestObserverSupplier; + private @Nullable StreamObserver requestObserver; + + protected AbstractWindmillStream( + Function, StreamObserver> clientFactory, + BackOff backoff, + StreamObserverFactory streamObserverFactory, + Set> streamRegistry, + int logEveryNStreamFailures) { + this.executor = + Executors.newSingleThreadExecutor( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("WindmillStream-thread") + .build()); + this.backoff = backoff; + this.streamRegistry = streamRegistry; + this.logEveryNStreamFailures = logEveryNStreamFailures; + this.clientClosed = new AtomicBoolean(); + this.streamClosed = new AtomicBoolean(); + this.startTimeMs = new AtomicLong(); + this.lastSendTimeMs = new AtomicLong(); + this.lastResponseTimeMs = new AtomicLong(); + this.errorCount = new AtomicInteger(); + this.lastError = new AtomicReference<>(); + this.sleepUntil = new AtomicLong(); + this.finishLatch = new CountDownLatch(1); + this.requestObserverSupplier = + () -> + streamObserverFactory.from( + clientFactory, new AbstractWindmillStream.ResponseObserver()); + } + + private static long debugDuration(long nowMs, long startMs) { + if (startMs <= 0) { + return -1; + } + return Math.max(0, nowMs - startMs); + } + + /** Called on each response from the server. */ + protected abstract void onResponse(ResponseT response); + + /** Called when a new underlying stream to the server has been opened. */ + protected abstract void onNewStream(); + + /** Returns whether there are any pending requests that should be retried on a stream break. */ + protected abstract boolean hasPendingRequests(); + + /** + * Called when the stream is throttled due to resource exhausted errors. Will be called for each + * resource exhausted error not just the first. onResponse() must stop throttling on receipt of + * the first good message. + */ + protected abstract void startThrottleTimer(); + + private StreamObserver requestObserver() { + if (requestObserver == null) { + throw new NullPointerException( + "requestObserver cannot be null. Missing a call to startStream() to initialize."); + } + + return requestObserver; + } + + /** Send a request to the server. */ + protected final void send(RequestT request) { + lastSendTimeMs.set(Instant.now().getMillis()); + synchronized (this) { + if (streamClosed.get()) { + throw new IllegalStateException("Send called on a client closed stream."); + } + + requestObserver().onNext(request); + } + } + + /** Starts the underlying stream. */ + protected final void startStream() { + // Add the stream to the registry after it has been fully constructed. + streamRegistry.add(this); + while (true) { + try { + synchronized (this) { + startTimeMs.set(Instant.now().getMillis()); + lastResponseTimeMs.set(0); + streamClosed.set(false); + // lazily initialize the requestObserver. Gets reset whenever the stream is reopened. + requestObserver = requestObserverSupplier.get(); + onNewStream(); + if (clientClosed.get()) { + close(); + } + return; + } + } catch (Exception e) { + LOG.error("Failed to create new stream, retrying: ", e); + try { + long sleep = backoff.nextBackOffMillis(); + sleepUntil.set(Instant.now().getMillis() + sleep); + Thread.sleep(sleep); + } catch (InterruptedException | IOException i) { + // Keep trying to create the stream. + } + } + } + } + + protected final Executor executor() { + return executor; + } + + public final synchronized void maybeSendHealthCheck(Instant lastSendThreshold) { + if (lastSendTimeMs.get() < lastSendThreshold.getMillis() && !clientClosed.get()) { + try { + sendHealthCheck(); + } catch (RuntimeException e) { + LOG.debug("Received exception sending health check.", e); + } + } + } + + protected abstract void sendHealthCheck(); + + // Care is taken that synchronization on this is unnecessary for all status page information. + // Blocking sends are made beneath this stream object's lock which could block status page + // rendering. + public final void appendSummaryHtml(PrintWriter writer) { + appendSpecificHtml(writer); + if (errorCount.get() > 0) { + writer.format(", %d errors, last error [ %s ]", errorCount.get(), lastError.get()); + } + if (clientClosed.get()) { + writer.write(", client closed"); + } + long nowMs = Instant.now().getMillis(); + long sleepLeft = sleepUntil.get() - nowMs; + if (sleepLeft > 0) { + writer.format(", %dms backoff remaining", sleepLeft); + } + writer.format( + ", current stream is %dms old, last send %dms, last response %dms, closed: %s", + debugDuration(nowMs, startTimeMs.get()), + debugDuration(nowMs, lastSendTimeMs.get()), + debugDuration(nowMs, lastResponseTimeMs.get()), + streamClosed.get()); + } + + // Don't require synchronization on stream, see the appendSummaryHtml comment. + protected abstract void appendSpecificHtml(PrintWriter writer); + + @Override + public final synchronized void close() { + // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. + clientClosed.set(true); + requestObserver().onCompleted(); + streamClosed.set(true); + } + + @Override + public final boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException { + return finishLatch.await(time, unit); + } + + @Override + public final Instant startTime() { + return new Instant(startTimeMs.get()); + } + + private class ResponseObserver implements StreamObserver { + @Override + public void onNext(ResponseT response) { + try { + backoff.reset(); + } catch (IOException e) { + // Ignore. + } + lastResponseTimeMs.set(Instant.now().getMillis()); + onResponse(response); + } + + @Override + public void onError(Throwable t) { + onStreamFinished(t); + } + + @Override + public void onCompleted() { + onStreamFinished(null); + } + + private void onStreamFinished(@Nullable Throwable t) { + synchronized (this) { + if (clientClosed.get() && !hasPendingRequests()) { + streamRegistry.remove(AbstractWindmillStream.this); + finishLatch.countDown(); + return; + } + } + if (t != null) { + Status status = null; + if (t instanceof StatusRuntimeException) { + status = ((StatusRuntimeException) t).getStatus(); + } + String statusError = status == null ? "" : status.toString(); + lastError.set(statusError); + if (errorCount.getAndIncrement() % logEveryNStreamFailures == 0) { + long nowMillis = Instant.now().getMillis(); + String responseDebug; + if (lastResponseTimeMs.get() == 0) { + responseDebug = "never received response"; + } else { + responseDebug = + "received response " + (nowMillis - lastResponseTimeMs.get()) + "ms ago"; + } + LOG.debug( + "{} streaming Windmill RPC errors for {}, last was: {} with status {}." + + " created {}ms ago, {}. This is normal with autoscaling.", + AbstractWindmillStream.this.getClass(), + errorCount.get(), + t, + statusError, + nowMillis - startTimeMs.get(), + responseDebug); + } + // If the stream was stopped due to a resource exhausted error then we are throttled. + if (status != null && status.getCode() == Status.Code.RESOURCE_EXHAUSTED) { + startThrottleTimer(); + } + + try { + long sleep = backoff.nextBackOffMillis(); + sleepUntil.set(Instant.now().getMillis() + sleep); + Thread.sleep(sleep); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (IOException e) { + // Ignore. + } + } else { + errorCount.incrementAndGet(); + String error = + "Stream completed successfully but did not complete requested operations, " + + "recreating"; + LOG.warn(error); + lastError.set(error); + } + executor.execute(AbstractWindmillStream.this::startStream); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java deleted file mode 100644 index 9dcae93c8d1f..000000000000 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java +++ /dev/null @@ -1,1852 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.dataflow.worker.windmill; - -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -import java.io.IOException; -import java.io.InputStream; -import java.io.PrintWriter; -import java.io.SequenceInputStream; -import java.net.URI; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.Deque; -import java.util.EnumMap; -import java.util.Enumeration; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.NoSuchElementException; -import java.util.Random; -import java.util.Set; -import java.util.Timer; -import java.util.TimerTask; -import java.util.concurrent.BlockingDeque; -import java.util.concurrent.CancellationException; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedDeque; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; -import java.util.concurrent.LinkedBlockingDeque; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; -import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; -import org.apache.beam.runners.dataflow.worker.options.StreamingDataflowWorkerOptions; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkRequest; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigRequest; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigResponse; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataResponse; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkResponse; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkStreamTimingInfo; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkStreamTimingInfo.Event; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution.State; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsRequest; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsResponse; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitResponse; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitWorkRequest; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataRequest; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequest; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequestExtension; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; -import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.util.BackOff; -import org.apache.beam.sdk.util.BackOffUtils; -import org.apache.beam.sdk.util.FluentBackoff; -import org.apache.beam.sdk.util.Sleeper; -import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; -import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.CallCredentials; -import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Channel; -import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Status; -import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.StatusRuntimeException; -import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.auth.MoreCallCredentials; -import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessChannelBuilder; -import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.netty.GrpcSslContexts; -import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.netty.NegotiationType; -import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.netty.NettyChannelBuilder; -import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Verify; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.net.HostAndPort; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.DateTimeUtils.MillisProvider; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** gRPC client for communicating with Windmill Service. */ -// Very likely real potential for bugs - https://github.com/apache/beam/issues/19273 -// Very likely real potential for bugs - https://github.com/apache/beam/issues/19271 -@SuppressFBWarnings({"JLM_JSR166_UTILCONCURRENT_MONITORENTER", "IS2_INCONSISTENT_SYNC"}) -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public class GrpcWindmillServer extends WindmillServerStub { - private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServer.class); - - // If a connection cannot be established, gRPC will fail fast so this deadline can be relatively - // high. - private static final long DEFAULT_UNARY_RPC_DEADLINE_SECONDS = 300; - private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; - - private static final Duration MIN_BACKOFF = Duration.millis(1); - private static final Duration MAX_BACKOFF = Duration.standardSeconds(30); - // Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce - // per-chunk overhead, and small enough that we can still granularly flow-control. - private static final int COMMIT_STREAM_CHUNK_SIZE = 2 << 20; - private static final int GET_DATA_STREAM_CHUNK_SIZE = 2 << 20; - - private static final long HEARTBEAT_REQUEST_ID = Long.MAX_VALUE; - - private static final AtomicLong nextId = new AtomicLong(0); - - private final StreamingDataflowWorkerOptions options; - private final int streamingRpcBatchLimit; - private final List stubList = - new ArrayList<>(); - private final List - syncStubList = new ArrayList<>(); - private WindmillApplianceGrpc.WindmillApplianceBlockingStub syncApplianceStub = null; - private long unaryDeadlineSeconds = DEFAULT_UNARY_RPC_DEADLINE_SECONDS; - private long streamDeadlineSeconds = DEFAULT_STREAM_RPC_DEADLINE_SECONDS; - private ImmutableSet endpoints; - private int logEveryNStreamFailures = 20; - private Duration maxBackoff = MAX_BACKOFF; - private final ThrottleTimer getWorkThrottleTimer = new ThrottleTimer(); - private final ThrottleTimer getDataThrottleTimer = new ThrottleTimer(); - private final ThrottleTimer commitWorkThrottleTimer = new ThrottleTimer(); - private final Random rand = new Random(); - - private final Set> streamRegistry = - Collections.newSetFromMap(new ConcurrentHashMap, Boolean>()); - - private final Timer healthCheckTimer; - - public GrpcWindmillServer(StreamingDataflowWorkerOptions options) throws IOException { - this.options = options; - this.streamingRpcBatchLimit = options.getWindmillServiceStreamingRpcBatchLimit(); - this.logEveryNStreamFailures = options.getWindmillServiceStreamingLogEveryNStreamFailures(); - this.endpoints = ImmutableSet.of(); - if (options.getWindmillServiceEndpoint() != null) { - Set endpoints = new HashSet<>(); - for (String endpoint : Splitter.on(',').split(options.getWindmillServiceEndpoint())) { - endpoints.add( - HostAndPort.fromString(endpoint).withDefaultPort(options.getWindmillServicePort())); - } - initializeWindmillService(endpoints); - } else if (!streamingEngineEnabled() && options.getLocalWindmillHostport() != null) { - int portStart = options.getLocalWindmillHostport().lastIndexOf(':'); - String endpoint = options.getLocalWindmillHostport().substring(0, portStart); - assert ("grpc:localhost".equals(endpoint)); - int port = Integer.parseInt(options.getLocalWindmillHostport().substring(portStart + 1)); - this.endpoints = ImmutableSet.of(HostAndPort.fromParts("localhost", port)); - initializeLocalHost(port); - } - if (options.getWindmillServiceStreamingRpcHealthCheckPeriodMs() > 0) { - this.healthCheckTimer = new Timer("WindmillHealthCheckTimer"); - this.healthCheckTimer.schedule( - new TimerTask() { - @Override - public void run() { - Instant reportThreshold = - Instant.now() - .minus( - Duration.millis( - options.getWindmillServiceStreamingRpcHealthCheckPeriodMs())); - for (AbstractWindmillStream stream : streamRegistry) { - stream.maybeSendHealthCheck(reportThreshold); - } - } - }, - 0, - options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()); - } else { - this.healthCheckTimer = null; - } - } - - private GrpcWindmillServer(String name, boolean enableStreamingEngine) { - this.options = PipelineOptionsFactory.create().as(StreamingDataflowWorkerOptions.class); - this.streamingRpcBatchLimit = Integer.MAX_VALUE; - options.setProject("project"); - options.setJobId("job"); - options.setWorkerId("worker"); - if (enableStreamingEngine) { - List experiments = this.options.getExperiments(); - if (experiments == null) { - experiments = new ArrayList<>(); - } - experiments.add(GcpOptions.STREAMING_ENGINE_EXPERIMENT); - options.setExperiments(experiments); - } - this.stubList.add(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel(name))); - this.healthCheckTimer = null; - } - - private boolean streamingEngineEnabled() { - return options.isEnableStreamingEngine(); - } - - @Override - public synchronized void setWindmillServiceEndpoints(Set endpoints) - throws IOException { - Preconditions.checkNotNull(endpoints); - if (endpoints.equals(this.endpoints)) { - // The endpoints are equal don't recreate the stubs. - return; - } - LOG.info("Creating a new windmill stub, endpoints: {}", endpoints); - if (this.endpoints != null) { - LOG.info("Previous windmill stub endpoints: {}", this.endpoints); - } - initializeWindmillService(endpoints); - } - - @Override - public synchronized boolean isReady() { - return !stubList.isEmpty(); - } - - private synchronized void initializeLocalHost(int port) throws IOException { - this.logEveryNStreamFailures = 1; - this.maxBackoff = Duration.millis(500); - this.unaryDeadlineSeconds = 10; // For local testing use short deadlines. - Channel channel = localhostChannel(port); - if (streamingEngineEnabled()) { - this.stubList.add(CloudWindmillServiceV1Alpha1Grpc.newStub(channel)); - this.syncStubList.add(CloudWindmillServiceV1Alpha1Grpc.newBlockingStub(channel)); - } else { - this.syncApplianceStub = WindmillApplianceGrpc.newBlockingStub(channel); - } - } - - /** - * Create a wrapper around credentials callback that delegates to the underlying vendored {@link - * com.google.auth.RequestMetadataCallback}. Note that this class should override every method - * that is not final and not static and call the delegate directly. - * - *

TODO: Replace this with an auto generated proxy which calls the underlying implementation - * delegate to reduce maintenance burden. - */ - private static class VendoredRequestMetadataCallbackAdapter - implements com.google.auth.RequestMetadataCallback { - private final org.apache.beam.vendor.grpc.v1p54p0.com.google.auth.RequestMetadataCallback - callback; - - private VendoredRequestMetadataCallbackAdapter( - org.apache.beam.vendor.grpc.v1p54p0.com.google.auth.RequestMetadataCallback callback) { - this.callback = callback; - } - - @Override - public void onSuccess(Map> metadata) { - callback.onSuccess(metadata); - } - - @Override - public void onFailure(Throwable exception) { - callback.onFailure(exception); - } - } - - /** - * Create a wrapper around credentials that delegates to the underlying {@link - * com.google.auth.Credentials}. Note that this class should override every method that is not - * final and not static and call the delegate directly. - * - *

TODO: Replace this with an auto generated proxy which calls the underlying implementation - * delegate to reduce maintenance burden. - */ - private static class VendoredCredentialsAdapter - extends org.apache.beam.vendor.grpc.v1p54p0.com.google.auth.Credentials { - private final com.google.auth.Credentials credentials; - - private VendoredCredentialsAdapter(com.google.auth.Credentials credentials) { - this.credentials = credentials; - } - - @Override - public String getAuthenticationType() { - return credentials.getAuthenticationType(); - } - - @Override - public Map> getRequestMetadata() throws IOException { - return credentials.getRequestMetadata(); - } - - @Override - public void getRequestMetadata( - final URI uri, - Executor executor, - final org.apache.beam.vendor.grpc.v1p54p0.com.google.auth.RequestMetadataCallback - callback) { - credentials.getRequestMetadata( - uri, executor, new VendoredRequestMetadataCallbackAdapter(callback)); - } - - @Override - public Map> getRequestMetadata(URI uri) throws IOException { - return credentials.getRequestMetadata(uri); - } - - @Override - public boolean hasRequestMetadata() { - return credentials.hasRequestMetadata(); - } - - @Override - public boolean hasRequestMetadataOnly() { - return credentials.hasRequestMetadataOnly(); - } - - @Override - public void refresh() throws IOException { - credentials.refresh(); - } - } - - private synchronized void initializeWindmillService(Set endpoints) - throws IOException { - LOG.info("Initializing Streaming Engine GRPC client for endpoints: {}", endpoints); - this.stubList.clear(); - this.syncStubList.clear(); - this.endpoints = ImmutableSet.copyOf(endpoints); - for (HostAndPort endpoint : this.endpoints) { - if ("localhost".equals(endpoint.getHost())) { - initializeLocalHost(endpoint.getPort()); - } else { - CallCredentials creds = - MoreCallCredentials.from(new VendoredCredentialsAdapter(options.getGcpCredential())); - this.stubList.add( - CloudWindmillServiceV1Alpha1Grpc.newStub(remoteChannel(endpoint)) - .withCallCredentials(creds)); - this.syncStubList.add( - CloudWindmillServiceV1Alpha1Grpc.newBlockingStub(remoteChannel(endpoint)) - .withCallCredentials(creds)); - } - } - } - - @VisibleForTesting - static GrpcWindmillServer newTestInstance(String name, boolean enableStreamingEngine) { - return new GrpcWindmillServer(name, enableStreamingEngine); - } - - private Channel inProcessChannel(String name) { - return InProcessChannelBuilder.forName(name).directExecutor().build(); - } - - private Channel localhostChannel(int port) { - return NettyChannelBuilder.forAddress("localhost", port) - .maxInboundMessageSize(java.lang.Integer.MAX_VALUE) - .negotiationType(NegotiationType.PLAINTEXT) - .build(); - } - - private Channel remoteChannel(HostAndPort endpoint) throws IOException { - NettyChannelBuilder builder = - NettyChannelBuilder.forAddress(endpoint.getHost(), endpoint.getPort()); - int timeoutSec = options.getWindmillServiceRpcChannelAliveTimeoutSec(); - if (timeoutSec > 0) { - builder = - builder - .keepAliveTime(timeoutSec, TimeUnit.SECONDS) - .keepAliveTimeout(timeoutSec, TimeUnit.SECONDS) - .keepAliveWithoutCalls(true); - } - return builder - .flowControlWindow(10 * 1024 * 1024) - .maxInboundMessageSize(java.lang.Integer.MAX_VALUE) - .maxInboundMetadataSize(1024 * 1024) - .negotiationType(NegotiationType.TLS) - // Set ciphers(null) to not use GCM, which is disabled for Dataflow - // due to it being horribly slow. - .sslContext(GrpcSslContexts.forClient().ciphers(null).build()) - .build(); - } - - private synchronized CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub() { - if (stubList.isEmpty()) { - throw new RuntimeException("windmillServiceEndpoint has not been set"); - } - if (stubList.size() == 1) { - return stubList.get(0); - } - return stubList.get(rand.nextInt(stubList.size())); - } - - private synchronized CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub - syncStub() { - if (syncStubList.isEmpty()) { - throw new RuntimeException("windmillServiceEndpoint has not been set"); - } - if (syncStubList.size() == 1) { - return syncStubList.get(0); - } - return syncStubList.get(rand.nextInt(syncStubList.size())); - } - - @Override - public void appendSummaryHtml(PrintWriter writer) { - writer.write("Active Streams:
"); - for (AbstractWindmillStream stream : streamRegistry) { - stream.appendSummaryHtml(writer); - writer.write("
"); - } - } - - // Configure backoff to retry calls forever, with a maximum sane retry interval. - private BackOff grpcBackoff() { - return FluentBackoff.DEFAULT - .withInitialBackoff(MIN_BACKOFF) - .withMaxBackoff(maxBackoff) - .backoff(); - } - - private ResponseT callWithBackoff(Supplier function) { - BackOff backoff = grpcBackoff(); - int rpcErrors = 0; - while (true) { - try { - return function.get(); - } catch (StatusRuntimeException e) { - try { - if (++rpcErrors % 20 == 0) { - LOG.warn( - "Many exceptions calling gRPC. Last exception: {} with status {}", - e, - e.getStatus()); - } - if (!BackOffUtils.next(Sleeper.DEFAULT, backoff)) { - throw new WindmillServerStub.RpcException(e); - } - } catch (IOException | InterruptedException i) { - if (i instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - WindmillServerStub.RpcException rpcException = new WindmillServerStub.RpcException(e); - rpcException.addSuppressed(i); - throw rpcException; - } - } - } - } - - @Override - public GetWorkResponse getWork(GetWorkRequest request) { - if (syncApplianceStub == null) { - return callWithBackoff( - () -> - syncStub() - .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) - .getWork( - request - .toBuilder() - .setJobId(options.getJobId()) - .setProjectId(options.getProject()) - .setWorkerId(options.getWorkerId()) - .build())); - } else { - return callWithBackoff( - () -> - syncApplianceStub - .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) - .getWork(request)); - } - } - - @Override - public GetDataResponse getData(GetDataRequest request) { - if (syncApplianceStub == null) { - return callWithBackoff( - () -> - syncStub() - .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) - .getData( - request - .toBuilder() - .setJobId(options.getJobId()) - .setProjectId(options.getProject()) - .build())); - } else { - return callWithBackoff( - () -> - syncApplianceStub - .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) - .getData(request)); - } - } - - @Override - public CommitWorkResponse commitWork(CommitWorkRequest request) { - if (syncApplianceStub == null) { - return callWithBackoff( - () -> - syncStub() - .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) - .commitWork( - request - .toBuilder() - .setJobId(options.getJobId()) - .setProjectId(options.getProject()) - .build())); - } else { - return callWithBackoff( - () -> - syncApplianceStub - .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) - .commitWork(request)); - } - } - - @Override - public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver receiver) { - return new GrpcGetWorkStream( - GetWorkRequest.newBuilder(request) - .setJobId(options.getJobId()) - .setProjectId(options.getProject()) - .setWorkerId(options.getWorkerId()) - .build(), - receiver); - } - - @Override - public GetDataStream getDataStream() { - return new GrpcGetDataStream(); - } - - @Override - public CommitWorkStream commitWorkStream() { - return new GrpcCommitWorkStream(); - } - - @Override - public GetConfigResponse getConfig(GetConfigRequest request) { - if (syncApplianceStub == null) { - throw new RpcException( - new UnsupportedOperationException("GetConfig not supported with windmill service.")); - } else { - return callWithBackoff( - () -> - syncApplianceStub - .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) - .getConfig(request)); - } - } - - @Override - public ReportStatsResponse reportStats(ReportStatsRequest request) { - if (syncApplianceStub == null) { - throw new RpcException( - new UnsupportedOperationException("ReportStats not supported with windmill service.")); - } else { - return callWithBackoff( - () -> - syncApplianceStub - .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) - .reportStats(request)); - } - } - - @Override - public long getAndResetThrottleTime() { - return getWorkThrottleTimer.getAndResetThrottleTime() - + getDataThrottleTimer.getAndResetThrottleTime() - + commitWorkThrottleTimer.getAndResetThrottleTime(); - } - - private JobHeader makeHeader() { - return JobHeader.newBuilder() - .setJobId(options.getJobId()) - .setProjectId(options.getProject()) - .setWorkerId(options.getWorkerId()) - .build(); - } - - /** Returns a long that is unique to this process. */ - private static long uniqueId() { - return nextId.incrementAndGet(); - } - - /** - * Base class for persistent streams connecting to Windmill. - * - *

This class handles the underlying gRPC StreamObservers, and automatically reconnects the - * stream if it is broken. Subclasses are responsible for retrying requests that have been lost on - * a broken stream. - * - *

Subclasses should override onResponse to handle responses from the server, and onNewStream - * to perform any work that must be done when a new stream is created, such as sending headers or - * retrying requests. - * - *

send and startStream should not be called from onResponse; use executor() instead. - * - *

Synchronization on this is used to synchronize the gRpc stream state and internal data - * structures. Since grpc channel operations may block, synchronization on this stream may also - * block. This is generally not a problem since streams are used in a single-threaded manner. - * However some accessors used for status page and other debugging need to take care not to - * require synchronizing on this. - */ - private abstract class AbstractWindmillStream implements WindmillStream { - private final StreamObserverFactory streamObserverFactory = - StreamObserverFactory.direct( - streamDeadlineSeconds * 2, options.getWindmillMessagesBetweenIsReadyChecks()); - private final Function, StreamObserver> clientFactory; - private final Executor executor = - Executors.newSingleThreadExecutor( - new ThreadFactoryBuilder() - .setDaemon(true) - .setNameFormat("WindmillStream-thread") - .build()); - - // The following should be protected by synchronizing on this, except for - // the atomics which may be read atomically for status pages. - private StreamObserver requestObserver; - // Indicates if the current stream in requestObserver is closed by calling close() method - private final AtomicBoolean streamClosed = new AtomicBoolean(); - private final AtomicLong startTimeMs = new AtomicLong(); - private final AtomicLong lastSendTimeMs = new AtomicLong(); - private final AtomicLong lastResponseTimeMs = new AtomicLong(); - private final AtomicInteger errorCount = new AtomicInteger(); - private final AtomicReference lastError = new AtomicReference<>(); - private final BackOff backoff = grpcBackoff(); - private final AtomicLong sleepUntil = new AtomicLong(); - protected final AtomicBoolean clientClosed = new AtomicBoolean(); - private final CountDownLatch finishLatch = new CountDownLatch(1); - - protected AbstractWindmillStream( - Function, StreamObserver> clientFactory) { - this.clientFactory = clientFactory; - } - - /** Called on each response from the server. */ - protected abstract void onResponse(ResponseT response); - /** Called when a new underlying stream to the server has been opened. */ - protected abstract void onNewStream(); - /** Returns whether there are any pending requests that should be retried on a stream break. */ - protected abstract boolean hasPendingRequests(); - /** - * Called when the stream is throttled due to resource exhausted errors. Will be called for each - * resource exhausted error not just the first. onResponse() must stop throttling on receipt of - * the first good message. - */ - protected abstract void startThrottleTimer(); - /** Send a request to the server. */ - protected final void send(RequestT request) { - lastSendTimeMs.set(Instant.now().getMillis()); - synchronized (this) { - if (streamClosed.get()) { - throw new IllegalStateException("Send called on a client closed stream."); - } - requestObserver.onNext(request); - } - } - - /** Starts the underlying stream. */ - protected final void startStream() { - // Add the stream to the registry after it has been fully constructed. - streamRegistry.add(this); - BackOff backoff = grpcBackoff(); - while (true) { - try { - synchronized (this) { - startTimeMs.set(Instant.now().getMillis()); - lastResponseTimeMs.set(0); - requestObserver = streamObserverFactory.from(clientFactory, new ResponseObserver()); - streamClosed.set(false); - onNewStream(); - if (clientClosed.get()) { - close(); - } - return; - } - } catch (Exception e) { - LOG.error("Failed to create new stream, retrying: ", e); - try { - long sleep = backoff.nextBackOffMillis(); - sleepUntil.set(Instant.now().getMillis() + sleep); - Thread.sleep(sleep); - } catch (InterruptedException i) { - // Keep trying to create the stream. - } catch (IOException i) { - // Ignore. - } - } - } - } - - protected final Executor executor() { - return executor; - } - - public final synchronized void maybeSendHealthCheck(Instant lastSendThreshold) { - if (lastSendTimeMs.get() < lastSendThreshold.getMillis() && !clientClosed.get()) { - try { - sendHealthCheck(); - } catch (RuntimeException e) { - LOG.debug("Received exception sending health check.", e); - } - } - } - - protected abstract void sendHealthCheck(); - - protected final long debugDuration(long nowMs, long startMs) { - if (startMs <= 0) { - return -1; - } - return Math.max(0, nowMs - startMs); - } - - // Care is taken that synchronization on this is unnecessary for all status page information. - // Blocking sends are made beneath this stream object's lock which could block status page - // rendering. - public final void appendSummaryHtml(PrintWriter writer) { - appendSpecificHtml(writer); - if (errorCount.get() > 0) { - writer.format(", %d errors, last error [ %s ]", errorCount.get(), lastError.get()); - } - if (clientClosed.get()) { - writer.write(", client closed"); - } - long nowMs = Instant.now().getMillis(); - long sleepLeft = sleepUntil.get() - nowMs; - if (sleepLeft > 0) { - writer.format(", %dms backoff remaining", sleepLeft); - } - writer.format( - ", current stream is %dms old, last send %dms, last response %dms, closed: %s", - debugDuration(nowMs, startTimeMs.get()), - debugDuration(nowMs, lastSendTimeMs.get()), - debugDuration(nowMs, lastResponseTimeMs.get()), - streamClosed.get()); - } - - // Don't require synchronization on stream, see the appendSummaryHtml comment. - protected abstract void appendSpecificHtml(PrintWriter writer); - - private class ResponseObserver implements StreamObserver { - @Override - public void onNext(ResponseT response) { - try { - backoff.reset(); - } catch (IOException e) { - // Ignore. - } - lastResponseTimeMs.set(Instant.now().getMillis()); - onResponse(response); - } - - @Override - public void onError(Throwable t) { - onStreamFinished(t); - } - - @Override - public void onCompleted() { - onStreamFinished(null); - } - - private void onStreamFinished(@Nullable Throwable t) { - synchronized (this) { - if (clientClosed.get() && !hasPendingRequests()) { - streamRegistry.remove(AbstractWindmillStream.this); - finishLatch.countDown(); - return; - } - } - if (t != null) { - Status status = null; - if (t instanceof StatusRuntimeException) { - status = ((StatusRuntimeException) t).getStatus(); - } - String statusError = status.toString(); - lastError.set(statusError); - if (errorCount.getAndIncrement() % logEveryNStreamFailures == 0) { - long nowMillis = Instant.now().getMillis(); - String responseDebug; - if (lastResponseTimeMs.get() == 0) { - responseDebug = "never received response"; - } else { - responseDebug = - "received response " + (nowMillis - lastResponseTimeMs.get()) + "ms ago"; - } - LOG.debug( - "{} streaming Windmill RPC errors for {}, last was: {} with status {}." - + " created {}ms ago, {}. This is normal with autoscaling.", - AbstractWindmillStream.this.getClass(), - errorCount.get(), - t.toString(), - statusError, - nowMillis - startTimeMs.get(), - responseDebug); - } - // If the stream was stopped due to a resource exhausted error then we are throttled. - if (status != null && status.getCode() == Status.Code.RESOURCE_EXHAUSTED) { - startThrottleTimer(); - } - - try { - long sleep = backoff.nextBackOffMillis(); - sleepUntil.set(Instant.now().getMillis() + sleep); - Thread.sleep(sleep); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } catch (IOException e) { - // Ignore. - } - } else { - errorCount.incrementAndGet(); - String error = - "Stream completed successfully but did not complete requested operations, " - + "recreating"; - LOG.warn(error); - lastError.set(error); - } - executor.execute(AbstractWindmillStream.this::startStream); - } - } - - @Override - public final synchronized void close() { - // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. - clientClosed.set(true); - requestObserver.onCompleted(); - streamClosed.set(true); - } - - @Override - public final boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException { - return finishLatch.await(time, unit); - } - - @Override - public final Instant startTime() { - return new Instant(startTimeMs.get()); - } - } - - static class GetWorkTimingInfosTracker { - private static class SumAndMaxDurations { - private Duration sum; - private Duration max; - - public SumAndMaxDurations(Duration sum, Duration max) { - this.sum = sum; - this.max = max; - } - } - - private Instant workItemCreationEndTime = Instant.EPOCH; - private Instant workItemLastChunkReceivedByWorkerTime = Instant.EPOCH; - - private LatencyAttribution workItemCreationLatency = null; - private final Map aggregatedGetWorkStreamLatencies; - - private final MillisProvider clock; - - public GetWorkTimingInfosTracker(MillisProvider clock) { - this.aggregatedGetWorkStreamLatencies = new EnumMap<>(State.class); - this.clock = clock; - } - - public void addTimingInfo(Collection infos) { - // We want to record duration for each stage and also be reflective on total work item - // processing time. It can be tricky because timings of different - // StreamingGetWorkResponseChunks can be interleaved. Current strategy is to record the - // sum duration in each transmission stage across different chunks, then divide the total - // duration (start from the chunk creation end in the windmill worker to the end of last chunk - // reception by the user worker) proportionally according the sum duration values across the - // many stages, the final latency is also capped by the corresponding stage maximum latency - // seen across multiple chunks. This should allow us to identify the slow stage meanwhile - // avoid confusions for comparing the stage duration to the total processing elapsed wall - // time. - Map getWorkStreamTimings = new HashMap<>(); - for (GetWorkStreamTimingInfo info : infos) { - getWorkStreamTimings.putIfAbsent( - info.getEvent(), Instant.ofEpochMilli(info.getTimestampUsec() / 1000)); - } - - // Record the difference between starting to get work and the first chunk being sent as the - // work creation time. - Instant workItemCreationStart = getWorkStreamTimings.get(Event.GET_WORK_CREATION_START); - Instant workItemCreationEnd = getWorkStreamTimings.get(Event.GET_WORK_CREATION_END); - if (workItemCreationStart != null - && workItemCreationEnd != null - && workItemCreationLatency == null) { - workItemCreationLatency = - LatencyAttribution.newBuilder() - .setState(State.GET_WORK_IN_WINDMILL_WORKER) - .setTotalDurationMillis( - new Duration(workItemCreationStart, workItemCreationEnd).getMillis()) - .build(); - } - // Record the work item creation end time as the start of transmission stages. - if (workItemCreationEnd != null && workItemCreationEnd.isAfter(workItemCreationEndTime)) { - workItemCreationEndTime = workItemCreationEnd; - } - - // Record the latency of each chunk between send on worker and arrival on dispatcher. - Instant receivedByDispatcherTiming = - getWorkStreamTimings.get(Event.GET_WORK_RECEIVED_BY_DISPATCHER); - if (workItemCreationEnd != null && receivedByDispatcherTiming != null) { - Duration newDuration = new Duration(workItemCreationEnd, receivedByDispatcherTiming); - aggregatedGetWorkStreamLatencies.compute( - State.GET_WORK_IN_TRANSIT_TO_DISPATCHER, - (stateKey, duration) -> { - if (duration == null) { - return new SumAndMaxDurations(newDuration, newDuration); - } - duration.max = newDuration.isLongerThan(duration.max) ? newDuration : duration.max; - duration.sum = duration.sum.plus(newDuration); - return duration; - }); - } - - // Record the latency of each chunk between send on dispatcher and arrival on worker. - Instant forwardedByDispatcherTiming = - getWorkStreamTimings.get(Event.GET_WORK_FORWARDED_BY_DISPATCHER); - Instant now = Instant.ofEpochMilli(clock.getMillis()); - if (forwardedByDispatcherTiming != null) { - Duration newDuration = new Duration(forwardedByDispatcherTiming, now); - aggregatedGetWorkStreamLatencies.compute( - State.GET_WORK_IN_TRANSIT_TO_USER_WORKER, - (stateKey, duration) -> { - if (duration == null) { - return new SumAndMaxDurations(newDuration, newDuration); - } - duration.max = newDuration.isLongerThan(duration.max) ? newDuration : duration.max; - duration.sum = duration.sum.plus(newDuration); - return duration; - }); - } - workItemLastChunkReceivedByWorkerTime = now; - } - - List getLatencyAttributions() { - if (workItemCreationLatency == null && aggregatedGetWorkStreamLatencies.isEmpty()) { - return Collections.emptyList(); - } - List latencyAttributions = - new ArrayList<>(aggregatedGetWorkStreamLatencies.size() + 1); - if (workItemCreationLatency != null) { - latencyAttributions.add(workItemCreationLatency); - } - if (workItemCreationEndTime.isAfter(workItemLastChunkReceivedByWorkerTime)) { - LOG.warn( - "Work item creation time {} is after the work received time {}, " - + "one or more GetWorkStream timing infos are missing.", - workItemCreationEndTime, - workItemLastChunkReceivedByWorkerTime); - return latencyAttributions; - } - long totalTransmissionDurationElapsedTime = - new Duration(workItemCreationEndTime, workItemLastChunkReceivedByWorkerTime).getMillis(); - long totalSumDurationTimeMills = 0; - for (SumAndMaxDurations duration : aggregatedGetWorkStreamLatencies.values()) { - totalSumDurationTimeMills += duration.sum.getMillis(); - } - final long finalTotalSumDurationTimeMills = totalSumDurationTimeMills; - - aggregatedGetWorkStreamLatencies.forEach( - (state, duration) -> { - long scaledDuration = - (long) - (((double) duration.sum.getMillis() / finalTotalSumDurationTimeMills) - * totalTransmissionDurationElapsedTime); - // Cap final duration by the max state duration across different chunks. This ensures - // the sum of final durations does not exceed the total elapsed time and the duration - // for each stage does not exceed the stage maximum. - long durationMills = Math.min(duration.max.getMillis(), scaledDuration); - latencyAttributions.add( - LatencyAttribution.newBuilder() - .setState(state) - .setTotalDurationMillis(durationMills) - .build()); - }); - return latencyAttributions; - } - - public void reset() { - this.aggregatedGetWorkStreamLatencies.clear(); - this.workItemCreationEndTime = Instant.EPOCH; - this.workItemLastChunkReceivedByWorkerTime = Instant.EPOCH; - this.workItemCreationLatency = null; - } - } - - private class GrpcGetWorkStream - extends AbstractWindmillStream - implements GetWorkStream { - private final GetWorkRequest request; - private final WorkItemReceiver receiver; - private final Map buffers = new ConcurrentHashMap<>(); - private final AtomicLong inflightMessages = new AtomicLong(); - private final AtomicLong inflightBytes = new AtomicLong(); - - private GrpcGetWorkStream(GetWorkRequest request, WorkItemReceiver receiver) { - super( - responseObserver -> - stub() - .withDeadlineAfter(streamDeadlineSeconds, TimeUnit.SECONDS) - .getWorkStream(responseObserver)); - this.request = request; - this.receiver = receiver; - startStream(); - } - - @Override - protected synchronized void onNewStream() { - buffers.clear(); - inflightMessages.set(request.getMaxItems()); - inflightBytes.set(request.getMaxBytes()); - send(StreamingGetWorkRequest.newBuilder().setRequest(request).build()); - } - - @Override - protected boolean hasPendingRequests() { - return false; - } - - @Override - public void appendSpecificHtml(PrintWriter writer) { - // Number of buffers is same as distict workers that sent work on this stream. - writer.format( - "GetWorkStream: %d buffers, %d inflight messages allowed, %d inflight bytes allowed", - buffers.size(), inflightMessages.intValue(), inflightBytes.intValue()); - } - - @Override - public void sendHealthCheck() { - send( - StreamingGetWorkRequest.newBuilder() - .setRequestExtension( - StreamingGetWorkRequestExtension.newBuilder() - .setMaxItems(0) - .setMaxBytes(0) - .build()) - .build()); - } - - @Override - protected void onResponse(StreamingGetWorkResponseChunk chunk) { - getWorkThrottleTimer.stop(); - long id = chunk.getStreamId(); - - WorkItemBuffer buffer = buffers.computeIfAbsent(id, (Long l) -> new WorkItemBuffer()); - buffer.append(chunk); - - if (chunk.getRemainingBytesForWorkItem() == 0) { - long size = buffer.bufferedSize(); - buffer.runAndReset(); - - // Record the fact that there are now fewer outstanding messages and bytes on the stream. - long numInflight = inflightMessages.decrementAndGet(); - long bytesInflight = inflightBytes.addAndGet(-size); - - // If the outstanding items or bytes limit has gotten too low, top both off with a - // GetWorkExtension. The goal is to keep the limits relatively close to their maximum - // values without sending too many extension requests. - if (numInflight < request.getMaxItems() / 2 || bytesInflight < request.getMaxBytes() / 2) { - long moreItems = request.getMaxItems() - numInflight; - long moreBytes = request.getMaxBytes() - bytesInflight; - inflightMessages.getAndAdd(moreItems); - inflightBytes.getAndAdd(moreBytes); - final StreamingGetWorkRequest extension = - StreamingGetWorkRequest.newBuilder() - .setRequestExtension( - StreamingGetWorkRequestExtension.newBuilder() - .setMaxItems(moreItems) - .setMaxBytes(moreBytes)) - .build(); - executor() - .execute( - () -> { - try { - send(extension); - } catch (IllegalStateException e) { - // Stream was closed. - } - }); - } - } - } - - @Override - protected void startThrottleTimer() { - getWorkThrottleTimer.start(); - } - - private class WorkItemBuffer { - - private String computation; - private Instant inputDataWatermark; - private Instant synchronizedProcessingTime; - private ByteString data = ByteString.EMPTY; - private long bufferedSize = 0; - - private GetWorkTimingInfosTracker workTimingInfosTracker = - new GetWorkTimingInfosTracker(System::currentTimeMillis); - - private void setMetadata(Windmill.ComputationWorkItemMetadata metadata) { - this.computation = metadata.getComputationId(); - this.inputDataWatermark = - WindmillTimeUtils.windmillToHarnessWatermark(metadata.getInputDataWatermark()); - this.synchronizedProcessingTime = - WindmillTimeUtils.windmillToHarnessWatermark( - metadata.getDependentRealtimeInputWatermark()); - } - - public void append(StreamingGetWorkResponseChunk chunk) { - if (chunk.hasComputationMetadata()) { - setMetadata(chunk.getComputationMetadata()); - } - - this.data = data.concat(chunk.getSerializedWorkItem()); - this.bufferedSize += chunk.getSerializedWorkItem().size(); - workTimingInfosTracker.addTimingInfo(chunk.getPerWorkItemTimingInfosList()); - } - - public long bufferedSize() { - return bufferedSize; - } - - public void runAndReset() { - try { - Windmill.WorkItem workItem = Windmill.WorkItem.parseFrom(data.newInput()); - List getWorkStreamLatencies = - workTimingInfosTracker.getLatencyAttributions(); - receiver.receiveWork( - computation, - inputDataWatermark, - synchronizedProcessingTime, - workItem, - getWorkStreamLatencies); - } catch (IOException e) { - LOG.error("Failed to parse work item from stream: ", e); - } - workTimingInfosTracker.reset(); - data = ByteString.EMPTY; - bufferedSize = 0; - } - } - } - - private class GrpcGetDataStream - extends AbstractWindmillStream - implements GetDataStream { - private class QueuedRequest { - public QueuedRequest(String computation, KeyedGetDataRequest request) { - this.id = uniqueId(); - this.globalDataRequest = null; - this.dataRequest = - ComputationGetDataRequest.newBuilder() - .setComputationId(computation) - .addRequests(request) - .build(); - this.byteSize = this.dataRequest.getSerializedSize(); - } - - public QueuedRequest(GlobalDataRequest request) { - this.id = uniqueId(); - this.globalDataRequest = request; - this.dataRequest = null; - this.byteSize = this.globalDataRequest.getSerializedSize(); - } - - final long id; - final long byteSize; - final GlobalDataRequest globalDataRequest; - final ComputationGetDataRequest dataRequest; - AppendableInputStream responseStream = null; - } - - private class QueuedBatch { - public QueuedBatch() {} - - final List requests = new ArrayList<>(); - long byteSize = 0; - boolean finalized = false; - final CountDownLatch sent = new CountDownLatch(1); - }; - - private final Deque batches = new ConcurrentLinkedDeque<>(); - private final Map pending = new ConcurrentHashMap<>(); - - @Override - public void appendSpecificHtml(PrintWriter writer) { - writer.format( - "GetDataStream: %d queued batches, %d pending requests [", - batches.size(), pending.size()); - for (Map.Entry entry : pending.entrySet()) { - writer.format("Stream %d ", entry.getKey()); - if (entry.getValue().cancelled.get()) { - writer.append("cancelled "); - } - if (entry.getValue().complete.get()) { - writer.append("complete "); - } - int queueSize = entry.getValue().queue.size(); - if (queueSize > 0) { - writer.format("%d queued responses ", queueSize); - } - long blockedMs = entry.getValue().blockedStartMs.get(); - if (blockedMs > 0) { - writer.format("blocked for %dms", Instant.now().getMillis() - blockedMs); - } - } - writer.append("]"); - } - - GrpcGetDataStream() { - super( - responseObserver -> - stub() - .withDeadlineAfter(streamDeadlineSeconds, TimeUnit.SECONDS) - .getDataStream(responseObserver)); - startStream(); - } - - @Override - protected synchronized void onNewStream() { - send(StreamingGetDataRequest.newBuilder().setHeader(makeHeader()).build()); - - if (clientClosed.get()) { - // We rely on close only occurring after all methods on the stream have returned. - // Since the requestKeyedData and requestGlobalData methods are blocking this - // means there should be no pending requests. - Verify.verify(!hasPendingRequests()); - } else { - for (AppendableInputStream responseStream : pending.values()) { - responseStream.cancel(); - } - } - } - - @Override - protected boolean hasPendingRequests() { - return !pending.isEmpty() || !batches.isEmpty(); - } - - @Override - protected void onResponse(StreamingGetDataResponse chunk) { - Preconditions.checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount()); - Preconditions.checkArgument( - chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1); - getDataThrottleTimer.stop(); - - for (int i = 0; i < chunk.getRequestIdCount(); ++i) { - AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); - Verify.verify(responseStream != null, "No pending response stream"); - responseStream.append(chunk.getSerializedResponse(i).newInput()); - if (chunk.getRemainingBytesForResponse() == 0) { - responseStream.complete(); - } - } - } - - @Override - protected void startThrottleTimer() { - getDataThrottleTimer.start(); - } - - @Override - public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) { - return issueRequest(new QueuedRequest(computation, request), KeyedGetDataResponse::parseFrom); - } - - @Override - public GlobalData requestGlobalData(GlobalDataRequest request) { - return issueRequest(new QueuedRequest(request), GlobalData::parseFrom); - } - - @Override - public void refreshActiveWork(Map> active) { - long builderBytes = 0; - StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); - for (Map.Entry> entry : active.entrySet()) { - for (KeyedGetDataRequest request : entry.getValue()) { - // Calculate the bytes with some overhead for proto encoding. - long bytes = (long) entry.getKey().length() + request.getSerializedSize() + 10; - if (builderBytes > 0 - && (builderBytes + bytes > GET_DATA_STREAM_CHUNK_SIZE - || builder.getRequestIdCount() >= streamingRpcBatchLimit)) { - send(builder.build()); - builderBytes = 0; - builder.clear(); - } - builderBytes += bytes; - builder.addStateRequest( - ComputationGetDataRequest.newBuilder() - .setComputationId(entry.getKey()) - .addRequests(request)); - } - } - if (builderBytes > 0) { - send(builder.build()); - } - } - - @Override - public void sendHealthCheck() { - if (hasPendingRequests()) { - send(StreamingGetDataRequest.newBuilder().build()); - } - } - - private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) { - while (true) { - request.responseStream = new AppendableInputStream(); - try { - queueRequestAndWait(request); - return parseFn.parse(request.responseStream); - } catch (CancellationException e) { - // Retry issuing the request since the response stream was cancelled. - continue; - } catch (IOException e) { - LOG.error("Parsing GetData response failed: ", e); - continue; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } finally { - pending.remove(request.id); - } - } - } - - private void queueRequestAndWait(QueuedRequest request) throws InterruptedException { - QueuedBatch batch; - boolean responsibleForSend = false; - CountDownLatch waitForSendLatch = null; - synchronized (batches) { - batch = batches.isEmpty() ? null : batches.getLast(); - if (batch == null - || batch.finalized - || batch.requests.size() >= streamingRpcBatchLimit - || batch.byteSize + request.byteSize > GET_DATA_STREAM_CHUNK_SIZE) { - if (batch != null) { - waitForSendLatch = batch.sent; - } - batch = new QueuedBatch(); - batches.addLast(batch); - responsibleForSend = true; - } - batch.requests.add(request); - batch.byteSize += request.byteSize; - } - if (responsibleForSend) { - if (waitForSendLatch == null) { - // If there was not a previous batch wait a little while to improve - // batching. - Thread.sleep(1); - } else { - waitForSendLatch.await(); - } - // Finalize the batch so that no additional requests will be added. Leave the batch in the - // queue so that a subsequent batch will wait for it's completion. - synchronized (batches) { - Verify.verify(batch == batches.peekFirst()); - batch.finalized = true; - } - sendBatch(batch.requests); - synchronized (batches) { - Verify.verify(batch == batches.pollFirst()); - } - // Notify all waiters with requests in this batch as well as the sender - // of the next batch (if one exists). - batch.sent.countDown(); - } else { - // Wait for this batch to be sent before parsing the response. - batch.sent.await(); - } - } - - private void sendBatch(List requests) { - StreamingGetDataRequest batchedRequest = flushToBatch(requests); - synchronized (this) { - // Synchronization of pending inserts is necessary with send to ensure duplicates are not - // sent on stream reconnect. - for (QueuedRequest request : requests) { - Verify.verify(pending.put(request.id, request.responseStream) == null); - } - try { - send(batchedRequest); - } catch (IllegalStateException e) { - // The stream broke before this call went through; onNewStream will retry the fetch. - LOG.warn("GetData stream broke before call started.", e); - } - } - } - - private StreamingGetDataRequest flushToBatch(List requests) { - // Put all global data requests first because there is only a single repeated field for - // request ids and the initial ids correspond to global data requests if they are present. - requests.sort( - (QueuedRequest r1, QueuedRequest r2) -> { - boolean r1gd = r1.globalDataRequest != null; - boolean r2gd = r2.globalDataRequest != null; - return r1gd == r2gd ? 0 : (r1gd ? -1 : 1); - }); - StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); - for (QueuedRequest request : requests) { - builder.addRequestId(request.id); - if (request.globalDataRequest == null) { - builder.addStateRequest(request.dataRequest); - } else { - builder.addGlobalDataRequest(request.globalDataRequest); - } - } - return builder.build(); - } - } - - private class GrpcCommitWorkStream - extends AbstractWindmillStream - implements CommitWorkStream { - private class PendingRequest { - private final String computation; - private final WorkItemCommitRequest request; - private final Consumer onDone; - - PendingRequest( - String computation, WorkItemCommitRequest request, Consumer onDone) { - this.computation = computation; - this.request = request; - this.onDone = onDone; - } - - long getBytes() { - return (long) request.getSerializedSize() + computation.length(); - } - } - - private final Map pending = new ConcurrentHashMap<>(); - - private class Batcher { - long queuedBytes = 0; - final Map queue = new HashMap<>(); - - boolean canAccept(PendingRequest request) { - return queue.isEmpty() - || (queue.size() < streamingRpcBatchLimit - && (request.getBytes() + queuedBytes) < COMMIT_STREAM_CHUNK_SIZE); - } - - void add(long id, PendingRequest request) { - assert (canAccept(request)); - queuedBytes += request.getBytes(); - queue.put(id, request); - } - - void flush() { - flushInternal(queue); - queuedBytes = 0; - queue.clear(); - } - } - - private final Batcher batcher = new Batcher(); - - GrpcCommitWorkStream() { - super( - responseObserver -> - stub() - .withDeadlineAfter(streamDeadlineSeconds, TimeUnit.SECONDS) - .commitWorkStream(responseObserver)); - startStream(); - } - - @Override - public void appendSpecificHtml(PrintWriter writer) { - writer.format("CommitWorkStream: %d pending", pending.size()); - } - - @Override - protected synchronized void onNewStream() { - send(StreamingCommitWorkRequest.newBuilder().setHeader(makeHeader()).build()); - Batcher resendBatcher = new Batcher(); - for (Map.Entry entry : pending.entrySet()) { - if (!resendBatcher.canAccept(entry.getValue())) { - resendBatcher.flush(); - } - resendBatcher.add(entry.getKey(), entry.getValue()); - } - resendBatcher.flush(); - } - - @Override - protected boolean hasPendingRequests() { - return !pending.isEmpty(); - } - - @Override - public void sendHealthCheck() { - if (hasPendingRequests()) { - StreamingCommitWorkRequest.Builder builder = StreamingCommitWorkRequest.newBuilder(); - builder.addCommitChunkBuilder().setRequestId(HEARTBEAT_REQUEST_ID); - send(builder.build()); - } - } - - @Override - protected void onResponse(StreamingCommitResponse response) { - commitWorkThrottleTimer.stop(); - - RuntimeException finalException = null; - for (int i = 0; i < response.getRequestIdCount(); ++i) { - long requestId = response.getRequestId(i); - if (requestId == HEARTBEAT_REQUEST_ID) { - continue; - } - PendingRequest done = pending.remove(requestId); - if (done == null) { - LOG.error("Got unknown commit request ID: {}", requestId); - } else { - try { - done.onDone.accept( - (i < response.getStatusCount()) ? response.getStatus(i) : CommitStatus.OK); - } catch (RuntimeException e) { - // Catch possible exceptions to ensure that an exception for one commit does not prevent - // other commits from being processed. - LOG.warn("Exception while processing commit response.", e); - finalException = e; - } - } - } - if (finalException != null) { - throw finalException; - } - } - - @Override - protected void startThrottleTimer() { - commitWorkThrottleTimer.start(); - } - - @Override - public boolean commitWorkItem( - String computation, WorkItemCommitRequest commitRequest, Consumer onDone) { - PendingRequest request = new PendingRequest(computation, commitRequest, onDone); - if (!batcher.canAccept(request)) { - return false; - } - batcher.add(uniqueId(), request); - return true; - } - - @Override - public void flush() { - batcher.flush(); - } - - private void flushInternal(Map requests) { - if (requests.isEmpty()) { - return; - } - if (requests.size() == 1) { - Map.Entry elem = requests.entrySet().iterator().next(); - if (elem.getValue().request.getSerializedSize() > COMMIT_STREAM_CHUNK_SIZE) { - issueMultiChunkRequest(elem.getKey(), elem.getValue()); - } else { - issueSingleRequest(elem.getKey(), elem.getValue()); - } - } else { - issueBatchedRequest(requests); - } - } - - private void issueSingleRequest(final long id, PendingRequest pendingRequest) { - StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); - requestBuilder - .addCommitChunkBuilder() - .setComputationId(pendingRequest.computation) - .setRequestId(id) - .setShardingKey(pendingRequest.request.getShardingKey()) - .setSerializedWorkItemCommit(pendingRequest.request.toByteString()); - StreamingCommitWorkRequest chunk = requestBuilder.build(); - synchronized (this) { - pending.put(id, pendingRequest); - try { - send(chunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. - } - } - } - - private void issueBatchedRequest(Map requests) { - StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); - String lastComputation = null; - for (Map.Entry entry : requests.entrySet()) { - PendingRequest request = entry.getValue(); - StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder(); - if (lastComputation == null || !lastComputation.equals(request.computation)) { - chunkBuilder.setComputationId(request.computation); - lastComputation = request.computation; - } - chunkBuilder.setRequestId(entry.getKey()); - chunkBuilder.setShardingKey(request.request.getShardingKey()); - chunkBuilder.setSerializedWorkItemCommit(request.request.toByteString()); - } - StreamingCommitWorkRequest request = requestBuilder.build(); - synchronized (this) { - pending.putAll(requests); - try { - send(request); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. - } - } - } - - private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) { - Preconditions.checkNotNull(pendingRequest.computation); - final ByteString serializedCommit = pendingRequest.request.toByteString(); - - synchronized (this) { - pending.put(id, pendingRequest); - for (int i = 0; i < serializedCommit.size(); i += COMMIT_STREAM_CHUNK_SIZE) { - int end = i + COMMIT_STREAM_CHUNK_SIZE; - ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size())); - - StreamingCommitRequestChunk.Builder chunkBuilder = - StreamingCommitRequestChunk.newBuilder() - .setRequestId(id) - .setSerializedWorkItemCommit(chunk) - .setComputationId(pendingRequest.computation) - .setShardingKey(pendingRequest.request.getShardingKey()); - int remaining = serializedCommit.size() - end; - if (remaining > 0) { - chunkBuilder.setRemainingBytesForWorkItem(remaining); - } - - StreamingCommitWorkRequest requestChunk = - StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); - try { - send(requestChunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. - break; - } - } - } - } - } - - @FunctionalInterface - private interface ParseFn { - ResponseT parse(InputStream input) throws IOException; - } - - /** An InputStream that can be dynamically extended with additional InputStreams. */ - @SuppressWarnings("JdkObsolete") - private static class AppendableInputStream extends InputStream { - private static final InputStream POISON_PILL = ByteString.EMPTY.newInput(); - private final AtomicBoolean cancelled = new AtomicBoolean(false); - private final AtomicBoolean complete = new AtomicBoolean(false); - private final AtomicLong blockedStartMs = new AtomicLong(); - private final BlockingDeque queue = new LinkedBlockingDeque<>(10); - private final InputStream stream = - new SequenceInputStream( - new Enumeration() { - // The first stream is eagerly read on SequenceInputStream creation. For this reason - // we use an empty element as the first input to avoid blocking from the queue when - // creating the AppendableInputStream. - InputStream current = ByteString.EMPTY.newInput(); - - @Override - public boolean hasMoreElements() { - if (current != null) { - return true; - } - - try { - blockedStartMs.set(Instant.now().getMillis()); - current = queue.poll(180, TimeUnit.SECONDS); - if (current != null && current != POISON_PILL) { - return true; - } - if (cancelled.get()) { - throw new CancellationException(); - } - if (complete.get()) { - return false; - } - throw new IllegalStateException( - "Got poison pill or timeout but stream is not done."); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new CancellationException(); - } - } - - @Override - public InputStream nextElement() { - if (!hasMoreElements()) { - throw new NoSuchElementException(); - } - blockedStartMs.set(0); - InputStream next = current; - current = null; - return next; - } - }); - - /** Appends a new InputStream to the tail of this stream. */ - public synchronized void append(InputStream chunk) { - try { - queue.put(chunk); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - LOG.debug("interrupted append"); - } - } - - /** Cancels the stream. Future calls to InputStream methods will throw CancellationException. */ - public synchronized void cancel() { - cancelled.set(true); - try { - // Put the poison pill at the head of the queue to cancel as quickly as possible. - queue.clear(); - queue.putFirst(POISON_PILL); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - LOG.debug("interrupted cancel"); - } - } - - /** Signals that no new InputStreams will be added to this stream. */ - public synchronized void complete() { - complete.set(true); - try { - queue.put(POISON_PILL); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - LOG.debug("interrupted complete"); - } - } - - @Override - public int read() throws IOException { - if (cancelled.get()) { - throw new CancellationException(); - } - return stream.read(); - } - - @Override - public int read(byte[] b, int off, int len) throws IOException { - if (cancelled.get()) { - throw new CancellationException(); - } - return stream.read(b, off, len); - } - - @Override - public int available() throws IOException { - if (cancelled.get()) { - throw new CancellationException(); - } - return stream.available(); - } - - @Override - public void close() throws IOException { - stream.close(); - } - } - - /** - * A stopwatch used to track the amount of time spent throttled due to Resource Exhausted errors. - * Throttle time is cumulative for all three rpcs types but not for all streams. So if GetWork and - * CommitWork are both blocked for x, totalTime will be 2x. However, if 2 GetWork streams are both - * blocked for x totalTime will be x. All methods are thread safe. - */ - private static class ThrottleTimer { - - // This is -1 if not currently being throttled or the time in - // milliseconds when throttling for this type started. - private long startTime = -1; - // This is the collected total throttle times since the last poll. Throttle times are - // reported as a delta so this is cleared whenever it gets reported. - private long totalTime = 0; - - /** - * Starts the timer if it has not been started and does nothing if it has already been started. - */ - public synchronized void start() { - if (!throttled()) { // This timer is not started yet so start it now. - startTime = Instant.now().getMillis(); - } - } - - /** Stops the timer if it has been started and does nothing if it has not been started. */ - public synchronized void stop() { - if (throttled()) { // This timer has been started already so stop it now. - totalTime += Instant.now().getMillis() - startTime; - startTime = -1; - } - } - - /** Returns if the specified type is currently being throttled. */ - public synchronized boolean throttled() { - return startTime != -1; - } - - /** Returns the combined total of all throttle times and resets those times to 0. */ - public synchronized long getAndResetThrottleTime() { - if (throttled()) { - stop(); - start(); - } - long toReturn = totalTime; - totalTime = 0; - return toReturn; - } - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java index 8907d86d92cd..9c791d349414 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java @@ -31,7 +31,7 @@ public class WindmillServerBase extends WindmillServerStub { /** Pointer to the underlying native windmill client object. */ - private long nativePointer; + private final long nativePointer; WindmillServerBase(String host) { this.nativePointer = create(host); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java index be6c365ee33e..cd5a7953ceba 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java @@ -71,18 +71,6 @@ public abstract class WindmillServerStub implements StatusDataProvider { /** Report execution information to the server. */ public abstract Windmill.ReportStatsResponse reportStats(Windmill.ReportStatsRequest request); - /** Functional interface for receiving WorkItems. */ - @FunctionalInterface - public interface WorkItemReceiver { - - void receiveWork( - String computation, - @Nullable Instant inputDataWatermark, - Instant synchronizedProcessingTime, - Windmill.WorkItem workItem, - Collection getWorkStreamLatencies); - } - /** * Gets work to process, returned as a stream. * @@ -104,6 +92,18 @@ public abstract GetWorkStream getWorkStream( @Override public void appendSummaryHtml(PrintWriter writer) {} + /** Functional interface for receiving WorkItems. */ + @FunctionalInterface + public interface WorkItemReceiver { + + void receiveWork( + String computation, + @Nullable Instant inputDataWatermark, + @Nullable Instant synchronizedProcessingTime, + Windmill.WorkItem workItem, + Collection getWorkStreamLatencies); + } + /** Superclass for streams returned by streaming Windmill methods. */ @ThreadSafe public interface WindmillStream { @@ -162,13 +162,8 @@ boolean commitWorkItem( public static class StreamPool { private final Duration streamTimeout; - - private final class StreamData { - final S stream = supplier.get(); - int holds = 1; - }; - private final List streams; + private final Supplier supplier; private final HashMap holds; @@ -222,6 +217,11 @@ public void releaseStream(S stream) { stream.close(); } } + + private final class StreamData { + final S stream = supplier.get(); + int holds = 1; + } } /** Generic Exception type for implementors to use to represent errors while making RPCs. */ diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/AppendableInputStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/AppendableInputStream.java new file mode 100644 index 000000000000..dbd3613ee4c2 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/AppendableInputStream.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; + +import java.io.IOException; +import java.io.InputStream; +import java.io.SequenceInputStream; +import java.util.Enumeration; +import java.util.NoSuchElementException; +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.CancellationException; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import javax.annotation.Nullable; +import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** An InputStream that can be dynamically extended with additional InputStreams. */ +@SuppressWarnings("JdkObsolete") +final class AppendableInputStream extends InputStream { + private static final Logger LOG = LoggerFactory.getLogger(AppendableInputStream.class); + private static final int QUEUE_MAX_CAPACITY = 10; + private static final InputStream POISON_PILL = ByteString.EMPTY.newInput(); + + private final AtomicBoolean cancelled; + private final AtomicBoolean complete; + private final AtomicLong blockedStartMs; + private final BlockingDeque queue; + private final InputStream stream; + + AppendableInputStream() { + this.cancelled = new AtomicBoolean(false); + this.complete = new AtomicBoolean(false); + this.blockedStartMs = new AtomicLong(); + this.queue = new LinkedBlockingDeque<>(QUEUE_MAX_CAPACITY); + this.stream = new SequenceInputStream(new InputStreamEnumeration()); + } + + long getBlockedStartMs() { + return blockedStartMs.get(); + } + + boolean isComplete() { + return complete.get(); + } + + boolean isCancelled() { + return cancelled.get(); + } + + int size() { + return queue.size(); + } + + /** Appends a new InputStream to the tail of this stream. */ + synchronized void append(InputStream chunk) { + try { + queue.put(chunk); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.debug("interrupted append"); + } + } + + /** Cancels the stream. Future calls to InputStream methods will throw CancellationException. */ + synchronized void cancel() { + cancelled.set(true); + try { + // Put the poison pill at the head of the queue to cancel as quickly as possible. + queue.clear(); + queue.putFirst(POISON_PILL); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.debug("interrupted cancel"); + } + } + + /** Signals that no new InputStreams will be added to this stream. */ + synchronized void complete() { + complete.set(true); + try { + queue.put(POISON_PILL); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.debug("interrupted complete"); + } + } + + @Override + public int read() throws IOException { + if (cancelled.get()) { + throw new CancellationException(); + } + return stream.read(); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (cancelled.get()) { + throw new CancellationException(); + } + return stream.read(b, off, len); + } + + @Override + public int available() throws IOException { + if (cancelled.get()) { + throw new CancellationException(); + } + return stream.available(); + } + + @Override + public void close() throws IOException { + stream.close(); + } + + @SuppressWarnings("NullableProblems") + private class InputStreamEnumeration implements Enumeration { + // The first stream is eagerly read on SequenceInputStream creation. For this reason + // we use an empty element as the first input to avoid blocking from the queue when + // creating the AppendableInputStream. + private @Nullable InputStream current = POISON_PILL; + + @Override + public boolean hasMoreElements() { + if (current != null) { + return true; + } + + try { + blockedStartMs.set(Instant.now().getMillis()); + current = queue.poll(180, TimeUnit.SECONDS); + if (current != null && current != POISON_PILL) { + return true; + } + if (cancelled.get()) { + throw new CancellationException(); + } + if (complete.get()) { + return false; + } + throw new IllegalStateException("Got poison pill or timeout but stream is not done."); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new CancellationException(); + } + } + + @SuppressWarnings("return") + @Override + public InputStream nextElement() { + if (!hasMoreElements()) { + throw new NoSuchElementException(); + } + blockedStartMs.set(0); + InputStream next = current; + current = null; + return next; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GetWorkTimingInfosTracker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GetWorkTimingInfosTracker.java new file mode 100644 index 000000000000..e6710993af9b --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GetWorkTimingInfosTracker.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkStreamTimingInfo; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkStreamTimingInfo.Event; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution.State; +import org.joda.time.DateTimeUtils.MillisProvider; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class GetWorkTimingInfosTracker { + + private static final Logger LOG = LoggerFactory.getLogger(GetWorkTimingInfosTracker.class); + + private final Map aggregatedGetWorkStreamLatencies; + private final MillisProvider clock; + private Instant workItemCreationEndTime; + private Instant workItemLastChunkReceivedByWorkerTime; + private @Nullable LatencyAttribution workItemCreationLatency; + + GetWorkTimingInfosTracker(MillisProvider clock) { + this.aggregatedGetWorkStreamLatencies = new EnumMap<>(State.class); + this.clock = clock; + this.workItemCreationEndTime = Instant.EPOCH; + workItemLastChunkReceivedByWorkerTime = Instant.EPOCH; + workItemCreationLatency = null; + } + + public void addTimingInfo(Collection infos) { + // We want to record duration for each stage and also be reflective on total work item + // processing time. It can be tricky because timings of different + // StreamingGetWorkResponseChunks can be interleaved. Current strategy is to record the + // sum duration in each transmission stage across different chunks, then divide the total + // duration (start from the chunk creation end in the windmill worker to the end of last chunk + // reception by the user worker) proportionally according the sum duration values across the + // many stages, the final latency is also capped by the corresponding stage maximum latency + // seen across multiple chunks. This should allow us to identify the slow stage meanwhile + // avoid confusions for comparing the stage duration to the total processing elapsed wall + // time. + Map getWorkStreamTimings = new HashMap<>(); + for (GetWorkStreamTimingInfo info : infos) { + getWorkStreamTimings.putIfAbsent( + info.getEvent(), Instant.ofEpochMilli(info.getTimestampUsec() / 1000)); + } + + // Record the difference between starting to get work and the first chunk being sent as the + // work creation time. + Instant workItemCreationStart = getWorkStreamTimings.get(Event.GET_WORK_CREATION_START); + Instant workItemCreationEnd = getWorkStreamTimings.get(Event.GET_WORK_CREATION_END); + if (workItemCreationStart != null + && workItemCreationEnd != null + && workItemCreationLatency == null) { + workItemCreationLatency = + LatencyAttribution.newBuilder() + .setState(State.GET_WORK_IN_WINDMILL_WORKER) + .setTotalDurationMillis( + new Duration(workItemCreationStart, workItemCreationEnd).getMillis()) + .build(); + } + // Record the work item creation end time as the start of transmission stages. + if (workItemCreationEnd != null && workItemCreationEnd.isAfter(workItemCreationEndTime)) { + workItemCreationEndTime = workItemCreationEnd; + } + + // Record the latency of each chunk between send on worker and arrival on dispatcher. + Instant receivedByDispatcherTiming = + getWorkStreamTimings.get(Event.GET_WORK_RECEIVED_BY_DISPATCHER); + if (workItemCreationEnd != null && receivedByDispatcherTiming != null) { + Duration newDuration = new Duration(workItemCreationEnd, receivedByDispatcherTiming); + aggregatedGetWorkStreamLatencies.compute( + State.GET_WORK_IN_TRANSIT_TO_DISPATCHER, + (stateKey, duration) -> { + if (duration == null) { + return new SumAndMaxDurations(newDuration, newDuration); + } + duration.max = newDuration.isLongerThan(duration.max) ? newDuration : duration.max; + duration.sum = duration.sum.plus(newDuration); + return duration; + }); + } + + // Record the latency of each chunk between send on dispatcher and arrival on worker. + Instant forwardedByDispatcherTiming = + getWorkStreamTimings.get(Event.GET_WORK_FORWARDED_BY_DISPATCHER); + Instant now = Instant.ofEpochMilli(clock.getMillis()); + if (forwardedByDispatcherTiming != null) { + Duration newDuration = new Duration(forwardedByDispatcherTiming, now); + aggregatedGetWorkStreamLatencies.compute( + State.GET_WORK_IN_TRANSIT_TO_USER_WORKER, + (stateKey, duration) -> { + if (duration == null) { + return new SumAndMaxDurations(newDuration, newDuration); + } + duration.max = newDuration.isLongerThan(duration.max) ? newDuration : duration.max; + duration.sum = duration.sum.plus(newDuration); + return duration; + }); + } + workItemLastChunkReceivedByWorkerTime = now; + } + + List getLatencyAttributions() { + if (workItemCreationLatency == null && aggregatedGetWorkStreamLatencies.isEmpty()) { + return Collections.emptyList(); + } + List latencyAttributions = + new ArrayList<>(aggregatedGetWorkStreamLatencies.size() + 1); + if (workItemCreationLatency != null) { + latencyAttributions.add(workItemCreationLatency); + } + if (workItemCreationEndTime.isAfter(workItemLastChunkReceivedByWorkerTime)) { + LOG.warn( + "Work item creation time {} is after the work received time {}, " + + "one or more GetWorkStream timing infos are missing.", + workItemCreationEndTime, + workItemLastChunkReceivedByWorkerTime); + return latencyAttributions; + } + long totalTransmissionDurationElapsedTime = + new Duration(workItemCreationEndTime, workItemLastChunkReceivedByWorkerTime).getMillis(); + long totalSumDurationTimeMills = 0; + for (SumAndMaxDurations duration : aggregatedGetWorkStreamLatencies.values()) { + totalSumDurationTimeMills += duration.sum.getMillis(); + } + final long finalTotalSumDurationTimeMills = totalSumDurationTimeMills; + + aggregatedGetWorkStreamLatencies.forEach( + (state, duration) -> { + long scaledDuration = + (long) + (((double) duration.sum.getMillis() / finalTotalSumDurationTimeMills) + * totalTransmissionDurationElapsedTime); + // Cap final duration by the max state duration across different chunks. This ensures + // the sum of final durations does not exceed the total elapsed time and the duration + // for each stage does not exceed the stage maximum. + long durationMills = Math.min(duration.max.getMillis(), scaledDuration); + latencyAttributions.add( + LatencyAttribution.newBuilder() + .setState(state) + .setTotalDurationMillis(durationMills) + .build()); + }); + return latencyAttributions; + } + + public void reset() { + this.aggregatedGetWorkStreamLatencies.clear(); + this.workItemCreationEndTime = Instant.EPOCH; + this.workItemLastChunkReceivedByWorkerTime = Instant.EPOCH; + this.workItemCreationLatency = null; + } + + private static class SumAndMaxDurations { + + private Duration sum; + private Duration max; + + public SumAndMaxDurations(Duration sum, Duration max) { + this.sum = sum; + this.max = max; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java new file mode 100644 index 000000000000..30f00a32ff1a --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; + +import java.io.PrintWriter; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.StreamObserverFactory; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitResponse; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitWorkRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.CommitWorkStream; +import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +final class GrpcCommitWorkStream + extends AbstractWindmillStream + implements CommitWorkStream { + private static final Logger LOG = LoggerFactory.getLogger(GrpcCommitWorkStream.class); + + private static final long HEARTBEAT_REQUEST_ID = Long.MAX_VALUE; + + private final Map pending; + private final Batcher batcher; + private final AtomicLong idGenerator; + private final JobHeader jobHeader; + private final ThrottleTimer commitWorkThrottleTimer; + private final int streamingRpcBatchLimit; + + private GrpcCommitWorkStream( + CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub, + BackOff backoff, + StreamObserverFactory streamObserverFactory, + Set> streamRegistry, + int logEveryNStreamFailures, + ThrottleTimer commitWorkThrottleTimer, + JobHeader jobHeader, + AtomicLong idGenerator, + int streamingRpcBatchLimit) { + super( + responseObserver -> + stub.withDeadlineAfter( + AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS) + .commitWorkStream(responseObserver), + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures); + pending = new ConcurrentHashMap<>(); + batcher = new Batcher(); + this.idGenerator = idGenerator; + this.jobHeader = jobHeader; + this.commitWorkThrottleTimer = commitWorkThrottleTimer; + this.streamingRpcBatchLimit = streamingRpcBatchLimit; + } + + static GrpcCommitWorkStream create( + CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub, + BackOff backoff, + StreamObserverFactory streamObserverFactory, + Set> streamRegistry, + int logEveryNStreamFailures, + ThrottleTimer commitWorkThrottleTimer, + JobHeader jobHeader, + AtomicLong idGenerator, + int streamingRpcBatchLimit) { + GrpcCommitWorkStream commitWorkStream = + new GrpcCommitWorkStream( + stub, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + commitWorkThrottleTimer, + jobHeader, + idGenerator, + streamingRpcBatchLimit); + commitWorkStream.startStream(); + return commitWorkStream; + } + + @Override + public void appendSpecificHtml(PrintWriter writer) { + writer.format("CommitWorkStream: %d pending", pending.size()); + } + + @Override + protected synchronized void onNewStream() { + send(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build()); + Batcher resendBatcher = new Batcher(); + for (Map.Entry entry : pending.entrySet()) { + if (!resendBatcher.canAccept(entry.getValue())) { + resendBatcher.flush(); + } + resendBatcher.add(entry.getKey(), entry.getValue()); + } + resendBatcher.flush(); + } + + @Override + protected boolean hasPendingRequests() { + return !pending.isEmpty(); + } + + @Override + public void sendHealthCheck() { + if (hasPendingRequests()) { + StreamingCommitWorkRequest.Builder builder = StreamingCommitWorkRequest.newBuilder(); + builder.addCommitChunkBuilder().setRequestId(HEARTBEAT_REQUEST_ID); + send(builder.build()); + } + } + + @Override + protected void onResponse(StreamingCommitResponse response) { + commitWorkThrottleTimer.stop(); + + RuntimeException finalException = null; + for (int i = 0; i < response.getRequestIdCount(); ++i) { + long requestId = response.getRequestId(i); + if (requestId == HEARTBEAT_REQUEST_ID) { + continue; + } + PendingRequest done = pending.remove(requestId); + if (done == null) { + LOG.error("Got unknown commit request ID: {}", requestId); + } else { + try { + done.onDone.accept( + (i < response.getStatusCount()) ? response.getStatus(i) : CommitStatus.OK); + } catch (RuntimeException e) { + // Catch possible exceptions to ensure that an exception for one commit does not prevent + // other commits from being processed. + LOG.warn("Exception while processing commit response.", e); + finalException = e; + } + } + } + if (finalException != null) { + throw finalException; + } + } + + @Override + protected void startThrottleTimer() { + commitWorkThrottleTimer.start(); + } + + @Override + public boolean commitWorkItem( + String computation, WorkItemCommitRequest commitRequest, Consumer onDone) { + PendingRequest request = new PendingRequest(computation, commitRequest, onDone); + if (!batcher.canAccept(request)) { + return false; + } + batcher.add(idGenerator.incrementAndGet(), request); + return true; + } + + @Override + public void flush() { + batcher.flush(); + } + + private void flushInternal(Map requests) { + if (requests.isEmpty()) { + return; + } + if (requests.size() == 1) { + Map.Entry elem = requests.entrySet().iterator().next(); + if (elem.getValue().request.getSerializedSize() + > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { + issueMultiChunkRequest(elem.getKey(), elem.getValue()); + } else { + issueSingleRequest(elem.getKey(), elem.getValue()); + } + } else { + issueBatchedRequest(requests); + } + } + + private void issueSingleRequest(final long id, PendingRequest pendingRequest) { + StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); + requestBuilder + .addCommitChunkBuilder() + .setComputationId(pendingRequest.computation) + .setRequestId(id) + .setShardingKey(pendingRequest.request.getShardingKey()) + .setSerializedWorkItemCommit(pendingRequest.request.toByteString()); + StreamingCommitWorkRequest chunk = requestBuilder.build(); + synchronized (this) { + pending.put(id, pendingRequest); + try { + send(chunk); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. + } + } + } + + private void issueBatchedRequest(Map requests) { + StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); + String lastComputation = null; + for (Map.Entry entry : requests.entrySet()) { + PendingRequest request = entry.getValue(); + StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder(); + if (lastComputation == null || !lastComputation.equals(request.computation)) { + chunkBuilder.setComputationId(request.computation); + lastComputation = request.computation; + } + chunkBuilder.setRequestId(entry.getKey()); + chunkBuilder.setShardingKey(request.request.getShardingKey()); + chunkBuilder.setSerializedWorkItemCommit(request.request.toByteString()); + } + StreamingCommitWorkRequest request = requestBuilder.build(); + synchronized (this) { + pending.putAll(requests); + try { + send(request); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. + } + } + } + + private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) { + Preconditions.checkNotNull(pendingRequest.computation); + final ByteString serializedCommit = pendingRequest.request.toByteString(); + + synchronized (this) { + pending.put(id, pendingRequest); + for (int i = 0; + i < serializedCommit.size(); + i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { + int end = i + AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE; + ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size())); + + StreamingCommitRequestChunk.Builder chunkBuilder = + StreamingCommitRequestChunk.newBuilder() + .setRequestId(id) + .setSerializedWorkItemCommit(chunk) + .setComputationId(pendingRequest.computation) + .setShardingKey(pendingRequest.request.getShardingKey()); + int remaining = serializedCommit.size() - end; + if (remaining > 0) { + chunkBuilder.setRemainingBytesForWorkItem(remaining); + } + + StreamingCommitWorkRequest requestChunk = + StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); + try { + send(requestChunk); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. + break; + } + } + } + } + + private static class PendingRequest { + + private final String computation; + private final WorkItemCommitRequest request; + private final Consumer onDone; + + PendingRequest( + String computation, WorkItemCommitRequest request, Consumer onDone) { + this.computation = computation; + this.request = request; + this.onDone = onDone; + } + + long getBytes() { + return (long) request.getSerializedSize() + computation.length(); + } + } + + private class Batcher { + + final Map queue = new HashMap<>(); + long queuedBytes = 0; + + boolean canAccept(PendingRequest request) { + return queue.isEmpty() + || (queue.size() < streamingRpcBatchLimit + && (request.getBytes() + queuedBytes) < AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE); + } + + void add(long id, PendingRequest request) { + assert (canAccept(request)); + queuedBytes += request.getBytes(); + queue.put(id, request); + } + + void flush() { + flushInternal(queue); + queuedBytes = 0; + queue.clear(); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java new file mode 100644 index 000000000000..f8893f3bfcba --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java @@ -0,0 +1,338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; + +import java.io.IOException; +import java.io.InputStream; +import java.io.PrintWriter; +import java.util.Deque; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.StreamObserverFactory; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcGetDataStreamRequests.QueuedBatch; +import org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcGetDataStreamRequests.QueuedRequest; +import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Verify; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +final class GrpcGetDataStream + extends AbstractWindmillStream + implements GetDataStream { + private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStream.class); + + private final Deque batches; + private final Map pending; + private final AtomicLong idGenerator; + private final ThrottleTimer getDataThrottleTimer; + private final JobHeader jobHeader; + private final int streamingRpcBatchLimit; + + private GrpcGetDataStream( + CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub, + BackOff backoff, + StreamObserverFactory streamObserverFactory, + Set> streamRegistry, + int logEveryNStreamFailures, + ThrottleTimer getDataThrottleTimer, + JobHeader jobHeader, + AtomicLong idGenerator, + int streamingRpcBatchLimit) { + super( + responseObserver -> + stub.withDeadlineAfter( + AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS) + .getDataStream(responseObserver), + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures); + this.idGenerator = idGenerator; + this.getDataThrottleTimer = getDataThrottleTimer; + this.jobHeader = jobHeader; + this.streamingRpcBatchLimit = streamingRpcBatchLimit; + this.batches = new ConcurrentLinkedDeque<>(); + this.pending = new ConcurrentHashMap<>(); + } + + static GrpcGetDataStream create( + CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub, + BackOff backoff, + StreamObserverFactory streamObserverFactory, + Set> streamRegistry, + int logEveryNStreamFailures, + ThrottleTimer getDataThrottleTimer, + JobHeader jobHeader, + AtomicLong idGenerator, + int streamingRpcBatchLimit) { + GrpcGetDataStream getDataStream = + new GrpcGetDataStream( + stub, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getDataThrottleTimer, + jobHeader, + idGenerator, + streamingRpcBatchLimit); + getDataStream.startStream(); + return getDataStream; + } + + @Override + protected synchronized void onNewStream() { + send(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); + if (clientClosed.get()) { + // We rely on close only occurring after all methods on the stream have returned. + // Since the requestKeyedData and requestGlobalData methods are blocking this + // means there should be no pending requests. + Verify.verify(!hasPendingRequests()); + } else { + for (AppendableInputStream responseStream : pending.values()) { + responseStream.cancel(); + } + } + } + + @Override + protected boolean hasPendingRequests() { + return !pending.isEmpty() || !batches.isEmpty(); + } + + @Override + @SuppressWarnings("dereference.of.nullable") + protected void onResponse(StreamingGetDataResponse chunk) { + Preconditions.checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount()); + Preconditions.checkArgument( + chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1); + getDataThrottleTimer.stop(); + + for (int i = 0; i < chunk.getRequestIdCount(); ++i) { + AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); + Verify.verify(responseStream != null, "No pending response stream"); + responseStream.append(chunk.getSerializedResponse(i).newInput()); + if (chunk.getRemainingBytesForResponse() == 0) { + responseStream.complete(); + } + } + } + + @Override + protected void startThrottleTimer() { + getDataThrottleTimer.start(); + } + + private long uniqueId() { + return idGenerator.incrementAndGet(); + } + + @Override + public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) { + return issueRequest( + QueuedRequest.forComputation(uniqueId(), computation, request), + KeyedGetDataResponse::parseFrom); + } + + @Override + public GlobalData requestGlobalData(GlobalDataRequest request) { + return issueRequest(QueuedRequest.global(uniqueId(), request), GlobalData::parseFrom); + } + + @Override + public void refreshActiveWork(Map> active) { + long builderBytes = 0; + StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); + for (Map.Entry> entry : active.entrySet()) { + for (KeyedGetDataRequest request : entry.getValue()) { + // Calculate the bytes with some overhead for proto encoding. + long bytes = (long) entry.getKey().length() + request.getSerializedSize() + 10; + if (builderBytes > 0 + && (builderBytes + bytes > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE + || builder.getRequestIdCount() >= streamingRpcBatchLimit)) { + send(builder.build()); + builderBytes = 0; + builder.clear(); + } + builderBytes += bytes; + builder.addStateRequest( + ComputationGetDataRequest.newBuilder() + .setComputationId(entry.getKey()) + .addRequests(request)); + } + } + if (builderBytes > 0) { + send(builder.build()); + } + } + + @Override + public void sendHealthCheck() { + if (hasPendingRequests()) { + send(StreamingGetDataRequest.newBuilder().build()); + } + } + + @Override + public void appendSpecificHtml(PrintWriter writer) { + writer.format( + "GetDataStream: %d queued batches, %d pending requests [", batches.size(), pending.size()); + for (Map.Entry entry : pending.entrySet()) { + writer.format("Stream %d ", entry.getKey()); + if (entry.getValue().isCancelled()) { + writer.append("cancelled "); + } + if (entry.getValue().isComplete()) { + writer.append("complete "); + } + int queueSize = entry.getValue().size(); + if (queueSize > 0) { + writer.format("%d queued responses ", queueSize); + } + long blockedMs = entry.getValue().getBlockedStartMs(); + if (blockedMs > 0) { + writer.format("blocked for %dms", Instant.now().getMillis() - blockedMs); + } + } + writer.append("]"); + } + + private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) { + while (true) { + request.resetResponseStream(); + try { + queueRequestAndWait(request); + return parseFn.parse(request.getResponseStream()); + } catch (CancellationException e) { + // Retry issuing the request since the response stream was cancelled. + continue; + } catch (IOException e) { + LOG.error("Parsing GetData response failed: ", e); + continue; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } finally { + pending.remove(request.id()); + } + } + } + + private void queueRequestAndWait(QueuedRequest request) throws InterruptedException { + QueuedBatch batch; + boolean responsibleForSend = false; + CountDownLatch waitForSendLatch = null; + synchronized (batches) { + batch = batches.isEmpty() ? null : batches.getLast(); + if (batch == null + || batch.isFinalized() + || batch.requests().size() >= streamingRpcBatchLimit + || batch.byteSize() + request.byteSize() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { + if (batch != null) { + waitForSendLatch = batch.getLatch(); + } + batch = new QueuedBatch(); + batches.addLast(batch); + responsibleForSend = true; + } + batch.addRequest(request); + } + if (responsibleForSend) { + if (waitForSendLatch == null) { + // If there was not a previous batch wait a little while to improve + // batching. + Thread.sleep(1); + } else { + waitForSendLatch.await(); + } + // Finalize the batch so that no additional requests will be added. Leave the batch in the + // queue so that a subsequent batch will wait for it's completion. + synchronized (batches) { + Verify.verify(batch == batches.peekFirst()); + batch.markFinalized(); + } + sendBatch(batch.requests()); + synchronized (batches) { + Verify.verify(batch == batches.pollFirst()); + } + // Notify all waiters with requests in this batch as well as the sender + // of the next batch (if one exists). + batch.countDown(); + } else { + // Wait for this batch to be sent before parsing the response. + batch.await(); + } + } + + @SuppressWarnings("NullableProblems") + private void sendBatch(List requests) { + StreamingGetDataRequest batchedRequest = flushToBatch(requests); + synchronized (this) { + // Synchronization of pending inserts is necessary with send to ensure duplicates are not + // sent on stream reconnect. + for (QueuedRequest request : requests) { + // Map#put returns null if there was no previous mapping for the key, meaning we have not + // seen it before. + Verify.verify(pending.put(request.id(), request.getResponseStream()) == null); + } + try { + send(batchedRequest); + } catch (IllegalStateException e) { + // The stream broke before this call went through; onNewStream will retry the fetch. + LOG.warn("GetData stream broke before call started.", e); + } + } + } + + @SuppressWarnings("argument") + private StreamingGetDataRequest flushToBatch(List requests) { + // Put all global data requests first because there is only a single repeated field for + // request ids and the initial ids correspond to global data requests if they are present. + requests.sort(QueuedRequest.globalRequestsFirst()); + StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); + for (QueuedRequest request : requests) { + request.addToStreamingGetDataRequest(builder); + } + return builder.build(); + } + + @FunctionalInterface + private interface ParseFn { + ResponseT parse(InputStream input) throws IOException; + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStreamRequests.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStreamRequests.java new file mode 100644 index 000000000000..7da7b13958b9 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStreamRequests.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; + +import com.google.auto.value.AutoOneOf; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; + +/** Utility data classes for {@link GrpcGetDataStream}. */ +final class GrpcGetDataStreamRequests { + private GrpcGetDataStreamRequests() {} + + static class QueuedRequest { + private final long id; + private final ComputationOrGlobalDataRequest dataRequest; + private AppendableInputStream responseStream; + + private QueuedRequest(long id, ComputationOrGlobalDataRequest dataRequest) { + this.id = id; + this.dataRequest = dataRequest; + responseStream = new AppendableInputStream(); + } + + static QueuedRequest forComputation( + long id, String computation, KeyedGetDataRequest keyedGetDataRequest) { + ComputationGetDataRequest computationGetDataRequest = + ComputationGetDataRequest.newBuilder() + .setComputationId(computation) + .addRequests(keyedGetDataRequest) + .build(); + return new QueuedRequest( + id, ComputationOrGlobalDataRequest.computation(computationGetDataRequest)); + } + + static QueuedRequest global(long id, GlobalDataRequest globalDataRequest) { + return new QueuedRequest(id, ComputationOrGlobalDataRequest.global(globalDataRequest)); + } + + static Comparator globalRequestsFirst() { + return (QueuedRequest r1, QueuedRequest r2) -> { + boolean r1gd = r1.dataRequest.isGlobal(); + boolean r2gd = r2.dataRequest.isGlobal(); + return r1gd == r2gd ? 0 : (r1gd ? -1 : 1); + }; + } + + long id() { + return id; + } + + long byteSize() { + return dataRequest.serializedSize(); + } + + AppendableInputStream getResponseStream() { + return responseStream; + } + + void resetResponseStream() { + this.responseStream = new AppendableInputStream(); + } + + void addToStreamingGetDataRequest(Windmill.StreamingGetDataRequest.Builder builder) { + builder.addRequestId(id); + if (dataRequest.isForComputation()) { + builder.addStateRequest(dataRequest.computation()); + } else { + builder.addGlobalDataRequest(dataRequest.global()); + } + } + } + + static class QueuedBatch { + private final List requests = new ArrayList<>(); + private final CountDownLatch sent = new CountDownLatch(1); + private long byteSize = 0; + private boolean finalized = false; + + CountDownLatch getLatch() { + return sent; + } + + List requests() { + return requests; + } + + long byteSize() { + return byteSize; + } + + boolean isFinalized() { + return finalized; + } + + void markFinalized() { + finalized = true; + } + + void addRequest(QueuedRequest request) { + requests.add(request); + byteSize += request.byteSize(); + } + + void countDown() { + sent.countDown(); + } + + void await() throws InterruptedException { + sent.await(); + } + } + + @AutoOneOf(ComputationOrGlobalDataRequest.Kind.class) + abstract static class ComputationOrGlobalDataRequest { + static ComputationOrGlobalDataRequest computation( + ComputationGetDataRequest computationGetDataRequest) { + return AutoOneOf_GrpcGetDataStreamRequests_ComputationOrGlobalDataRequest.computation( + computationGetDataRequest); + } + + static ComputationOrGlobalDataRequest global(GlobalDataRequest globalDataRequest) { + return AutoOneOf_GrpcGetDataStreamRequests_ComputationOrGlobalDataRequest.global( + globalDataRequest); + } + + abstract Kind getKind(); + + abstract ComputationGetDataRequest computation(); + + abstract GlobalDataRequest global(); + + boolean isGlobal() { + return getKind() == Kind.GLOBAL; + } + + boolean isForComputation() { + return getKind() == Kind.COMPUTATION; + } + + long serializedSize() { + switch (getKind()) { + case GLOBAL: + return global().getSerializedSize(); + case COMPUTATION: + return computation().getSerializedSize(); + // this will never happen since the switch is exhaustive. + default: + throw new UnsupportedOperationException("unknown dataRequest type."); + } + } + + enum Kind { + COMPUTATION, + GLOBAL + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java new file mode 100644 index 000000000000..d0edcc458289 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; + +import java.io.IOException; +import java.io.PrintWriter; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; +import org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.StreamObserverFactory; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequestExtension; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.WorkItemReceiver; +import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +final class GrpcGetWorkStream + extends AbstractWindmillStream + implements GetWorkStream { + + private static final Logger LOG = LoggerFactory.getLogger(GrpcGetWorkStream.class); + + private final GetWorkRequest request; + private final WorkItemReceiver receiver; + private final ThrottleTimer getWorkThrottleTimer; + private final Map buffers; + private final AtomicLong inflightMessages; + private final AtomicLong inflightBytes; + + private GrpcGetWorkStream( + CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub, + GetWorkRequest request, + BackOff backoff, + StreamObserverFactory streamObserverFactory, + Set> streamRegistry, + int logEveryNStreamFailures, + ThrottleTimer getWorkThrottleTimer, + WorkItemReceiver receiver) { + super( + responseObserver -> + stub.withDeadlineAfter( + AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS) + .getWorkStream(responseObserver), + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures); + this.request = request; + this.getWorkThrottleTimer = getWorkThrottleTimer; + this.receiver = receiver; + this.buffers = new ConcurrentHashMap<>(); + this.inflightMessages = new AtomicLong(); + this.inflightBytes = new AtomicLong(); + } + + static GrpcGetWorkStream create( + CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub, + GetWorkRequest request, + BackOff backoff, + StreamObserverFactory streamObserverFactory, + Set> streamRegistry, + int logEveryNStreamFailures, + ThrottleTimer getWorkThrottleTimer, + WorkItemReceiver receiver) { + GrpcGetWorkStream getWorkStream = + new GrpcGetWorkStream( + stub, + request, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getWorkThrottleTimer, + receiver); + getWorkStream.startStream(); + return getWorkStream; + } + + private void sendRequestExtension(long moreItems, long moreBytes) { + final StreamingGetWorkRequest extension = + StreamingGetWorkRequest.newBuilder() + .setRequestExtension( + StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(moreItems) + .setMaxBytes(moreBytes)) + .build(); + + executor() + .execute( + () -> { + try { + send(extension); + } catch (IllegalStateException e) { + // Stream was closed. + } + }); + } + + @Override + protected synchronized void onNewStream() { + buffers.clear(); + inflightMessages.set(request.getMaxItems()); + inflightBytes.set(request.getMaxBytes()); + send(StreamingGetWorkRequest.newBuilder().setRequest(request).build()); + } + + @Override + protected boolean hasPendingRequests() { + return false; + } + + @Override + public void appendSpecificHtml(PrintWriter writer) { + // Number of buffers is same as distinct workers that sent work on this stream. + writer.format( + "GetWorkStream: %d buffers, %d inflight messages allowed, %d inflight bytes allowed", + buffers.size(), inflightMessages.intValue(), inflightBytes.intValue()); + } + + @Override + public void sendHealthCheck() { + send( + StreamingGetWorkRequest.newBuilder() + .setRequestExtension( + StreamingGetWorkRequestExtension.newBuilder().setMaxItems(0).setMaxBytes(0).build()) + .build()); + } + + @Override + protected void onResponse(StreamingGetWorkResponseChunk chunk) { + getWorkThrottleTimer.stop(); + + GrpcGetWorkStream.WorkItemBuffer buffer = + buffers.computeIfAbsent( + chunk.getStreamId(), unused -> new GrpcGetWorkStream.WorkItemBuffer()); + buffer.append(chunk); + + if (chunk.getRemainingBytesForWorkItem() == 0) { + long size = buffer.bufferedSize(); + buffer.runAndReset(); + + // Record the fact that there are now fewer outstanding messages and bytes on the stream. + long numInflight = inflightMessages.decrementAndGet(); + long bytesInflight = inflightBytes.addAndGet(-size); + + // If the outstanding items or bytes limit has gotten too low, top both off with a + // GetWorkExtension. The goal is to keep the limits relatively close to their maximum + // values without sending too many extension requests. + if (numInflight < request.getMaxItems() / 2 || bytesInflight < request.getMaxBytes() / 2) { + long moreItems = request.getMaxItems() - numInflight; + long moreBytes = request.getMaxBytes() - bytesInflight; + inflightMessages.getAndAdd(moreItems); + inflightBytes.getAndAdd(moreBytes); + sendRequestExtension(moreItems, moreBytes); + } + } + } + + @Override + protected void startThrottleTimer() { + getWorkThrottleTimer.start(); + } + + private class WorkItemBuffer { + private final GetWorkTimingInfosTracker workTimingInfosTracker; + private String computation; + @Nullable private Instant inputDataWatermark; + @Nullable private Instant synchronizedProcessingTime; + private ByteString data; + private long bufferedSize; + + @SuppressWarnings("initialization.fields.uninitialized") + WorkItemBuffer() { + workTimingInfosTracker = new GetWorkTimingInfosTracker(System::currentTimeMillis); + data = ByteString.EMPTY; + bufferedSize = 0; + } + + @SuppressWarnings("NullableProblems") + private void setMetadata(Windmill.ComputationWorkItemMetadata metadata) { + this.computation = metadata.getComputationId(); + this.inputDataWatermark = + WindmillTimeUtils.windmillToHarnessWatermark(metadata.getInputDataWatermark()); + this.synchronizedProcessingTime = + WindmillTimeUtils.windmillToHarnessWatermark( + metadata.getDependentRealtimeInputWatermark()); + } + + private void append(StreamingGetWorkResponseChunk chunk) { + if (chunk.hasComputationMetadata()) { + setMetadata(chunk.getComputationMetadata()); + } + + this.data = data.concat(chunk.getSerializedWorkItem()); + this.bufferedSize += chunk.getSerializedWorkItem().size(); + workTimingInfosTracker.addTimingInfo(chunk.getPerWorkItemTimingInfosList()); + } + + private long bufferedSize() { + return bufferedSize; + } + + private void runAndReset() { + try { + Windmill.WorkItem workItem = Windmill.WorkItem.parseFrom(data.newInput()); + List getWorkStreamLatencies = + workTimingInfosTracker.getLatencyAttributions(); + receiver.receiveWork( + computation, + inputDataWatermark, + synchronizedProcessingTime, + workItem, + getWorkStreamLatencies); + } catch (IOException e) { + LOG.error("Failed to parse work item from stream: ", e); + } + workTimingInfosTracker.reset(); + data = ByteString.EMPTY; + bufferedSize = 0; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java new file mode 100644 index 000000000000..0aa4c53a10c6 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java @@ -0,0 +1,629 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.io.IOException; +import java.io.PrintWriter; +import java.net.URI; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.Timer; +import java.util.TimerTask; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import org.apache.beam.runners.dataflow.worker.options.StreamingDataflowWorkerOptions; +import org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.StreamObserverFactory; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigResponse; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataResponse; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkResponse; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsResponse; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillApplianceGrpc; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.sdk.util.BackOffUtils; +import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.sdk.util.Sleeper; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.CallCredentials; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Channel; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.StatusRuntimeException; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.auth.MoreCallCredentials; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.netty.GrpcSslContexts; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.netty.NegotiationType; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.netty.NettyChannelBuilder; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.net.HostAndPort; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** gRPC client for communicating with Windmill Service. */ +// Very likely real potential for bugs - https://github.com/apache/beam/issues/19273 +// Very likely real potential for bugs - https://github.com/apache/beam/issues/19271 +@SuppressFBWarnings({"JLM_JSR166_UTILCONCURRENT_MONITORENTER", "IS2_INCONSISTENT_SYNC"}) +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public final class GrpcWindmillServer extends WindmillServerStub { + private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServer.class); + + // If a connection cannot be established, gRPC will fail fast so this deadline can be relatively + // high. + private static final long DEFAULT_UNARY_RPC_DEADLINE_SECONDS = 300; + private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; + private static final int DEFAULT_LOG_EVERY_N_FAILURES = 20; + private static final String LOCALHOST = "localhost"; + private static final Duration MIN_BACKOFF = Duration.millis(1); + private static final Duration MAX_BACKOFF = Duration.standardSeconds(30); + private static final AtomicLong nextId = new AtomicLong(0); + private static final int NO_HEALTH_CHECK = -1; + + private final StreamingDataflowWorkerOptions options; + private final int streamingRpcBatchLimit; + private final List stubList; + private final List + syncStubList; + private final ThrottleTimer getWorkThrottleTimer; + private final ThrottleTimer getDataThrottleTimer; + private final ThrottleTimer commitWorkThrottleTimer; + private final Random rand; + private final Set> streamRegistry; + + private long unaryDeadlineSeconds; + private ImmutableSet endpoints; + private int logEveryNStreamFailures; + private Duration maxBackoff = MAX_BACKOFF; + private WindmillApplianceGrpc.WindmillApplianceBlockingStub syncApplianceStub = null; + + private GrpcWindmillServer(StreamingDataflowWorkerOptions options) { + this.options = options; + this.streamingRpcBatchLimit = options.getWindmillServiceStreamingRpcBatchLimit(); + this.stubList = new ArrayList<>(); + this.syncStubList = new ArrayList<>(); + this.logEveryNStreamFailures = options.getWindmillServiceStreamingLogEveryNStreamFailures(); + this.endpoints = ImmutableSet.of(); + this.getWorkThrottleTimer = new ThrottleTimer(); + this.getDataThrottleTimer = new ThrottleTimer(); + this.commitWorkThrottleTimer = new ThrottleTimer(); + this.rand = new Random(); + this.streamRegistry = Collections.newSetFromMap(new ConcurrentHashMap<>()); + this.unaryDeadlineSeconds = DEFAULT_UNARY_RPC_DEADLINE_SECONDS; + } + + private static StreamingDataflowWorkerOptions testOptions() { + StreamingDataflowWorkerOptions options = + PipelineOptionsFactory.create().as(StreamingDataflowWorkerOptions.class); + options.setProject("project"); + options.setJobId("job"); + options.setWorkerId("worker"); + List experiments = + options.getExperiments() == null ? new ArrayList<>() : options.getExperiments(); + experiments.add(GcpOptions.STREAMING_ENGINE_EXPERIMENT); + options.setExperiments(experiments); + + options.setWindmillServiceStreamingRpcBatchLimit(Integer.MAX_VALUE); + options.setWindmillServiceStreamingRpcHealthCheckPeriodMs(NO_HEALTH_CHECK); + options.setWindmillServiceStreamingLogEveryNStreamFailures(DEFAULT_LOG_EVERY_N_FAILURES); + + return options; + } + + /** Create new instance of {@link GrpcWindmillServer}. */ + public static GrpcWindmillServer create(StreamingDataflowWorkerOptions workerOptions) + throws IOException { + GrpcWindmillServer grpcWindmillServer = new GrpcWindmillServer(workerOptions); + if (workerOptions.getWindmillServiceEndpoint() != null) { + grpcWindmillServer.configureWindmillServiceEndpoints(); + } else if (!workerOptions.isEnableStreamingEngine() + && workerOptions.getLocalWindmillHostport() != null) { + grpcWindmillServer.configureLocalHost(); + } + + if (workerOptions.getWindmillServiceStreamingRpcHealthCheckPeriodMs() > 0) { + grpcWindmillServer.scheduleHealthCheckTimer( + workerOptions, () -> grpcWindmillServer.streamRegistry); + } + + return grpcWindmillServer; + } + + @VisibleForTesting + static GrpcWindmillServer newTestInstance(String name) { + GrpcWindmillServer testServer = new GrpcWindmillServer(testOptions()); + testServer.stubList.add(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel(name))); + return testServer; + } + + private static Channel inProcessChannel(String name) { + return InProcessChannelBuilder.forName(name).directExecutor().build(); + } + + private static Channel localhostChannel(int port) { + return NettyChannelBuilder.forAddress(LOCALHOST, port) + .maxInboundMessageSize(Integer.MAX_VALUE) + .negotiationType(NegotiationType.PLAINTEXT) + .build(); + } + + private void scheduleHealthCheckTimer( + StreamingDataflowWorkerOptions options, Supplier>> streams) { + new Timer("WindmillHealthCheckTimer") + .schedule( + new HealthCheckTimerTask(options, streams), + 0, + options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()); + } + + private void configureWindmillServiceEndpoints() throws IOException { + Set endpoints = new HashSet<>(); + for (String endpoint : Splitter.on(',').split(options.getWindmillServiceEndpoint())) { + endpoints.add( + HostAndPort.fromString(endpoint).withDefaultPort(options.getWindmillServicePort())); + } + initializeWindmillService(endpoints); + } + + private void configureLocalHost() { + int portStart = options.getLocalWindmillHostport().lastIndexOf(':'); + String endpoint = options.getLocalWindmillHostport().substring(0, portStart); + assert ("grpc:localhost".equals(endpoint)); + int port = Integer.parseInt(options.getLocalWindmillHostport().substring(portStart + 1)); + this.endpoints = ImmutableSet.of(HostAndPort.fromParts(LOCALHOST, port)); + initializeLocalHost(port); + } + + @Override + public synchronized void setWindmillServiceEndpoints(Set endpoints) + throws IOException { + Preconditions.checkNotNull(endpoints); + if (endpoints.equals(this.endpoints)) { + // The endpoints are equal don't recreate the stubs. + return; + } + LOG.info("Creating a new windmill stub, endpoints: {}", endpoints); + if (this.endpoints != null) { + LOG.info("Previous windmill stub endpoints: {}", this.endpoints); + } + initializeWindmillService(endpoints); + } + + @Override + public synchronized boolean isReady() { + return !stubList.isEmpty(); + } + + private synchronized void initializeLocalHost(int port) { + this.logEveryNStreamFailures = 1; + this.maxBackoff = Duration.millis(500); + this.unaryDeadlineSeconds = 10; // For local testing use short deadlines. + Channel channel = localhostChannel(port); + if (options.isEnableStreamingEngine()) { + this.stubList.add(CloudWindmillServiceV1Alpha1Grpc.newStub(channel)); + this.syncStubList.add(CloudWindmillServiceV1Alpha1Grpc.newBlockingStub(channel)); + } else { + this.syncApplianceStub = WindmillApplianceGrpc.newBlockingStub(channel); + } + } + + private synchronized void initializeWindmillService(Set endpoints) + throws IOException { + LOG.info("Initializing Streaming Engine GRPC client for endpoints: {}", endpoints); + this.stubList.clear(); + this.syncStubList.clear(); + this.endpoints = ImmutableSet.copyOf(endpoints); + for (HostAndPort endpoint : this.endpoints) { + if (LOCALHOST.equals(endpoint.getHost())) { + initializeLocalHost(endpoint.getPort()); + } else { + CallCredentials creds = + MoreCallCredentials.from(new VendoredCredentialsAdapter(options.getGcpCredential())); + this.stubList.add( + CloudWindmillServiceV1Alpha1Grpc.newStub(remoteChannel(endpoint)) + .withCallCredentials(creds)); + this.syncStubList.add( + CloudWindmillServiceV1Alpha1Grpc.newBlockingStub(remoteChannel(endpoint)) + .withCallCredentials(creds)); + } + } + } + + private Channel remoteChannel(HostAndPort endpoint) throws IOException { + NettyChannelBuilder builder = + NettyChannelBuilder.forAddress(endpoint.getHost(), endpoint.getPort()); + int timeoutSec = options.getWindmillServiceRpcChannelAliveTimeoutSec(); + if (timeoutSec > 0) { + builder + .keepAliveTime(timeoutSec, TimeUnit.SECONDS) + .keepAliveTimeout(timeoutSec, TimeUnit.SECONDS) + .keepAliveWithoutCalls(true); + } + return builder + .flowControlWindow(10 * 1024 * 1024) + .maxInboundMessageSize(Integer.MAX_VALUE) + .maxInboundMetadataSize(1024 * 1024) + .negotiationType(NegotiationType.TLS) + // Set ciphers(null) to not use GCM, which is disabled for Dataflow + // due to it being horribly slow. + .sslContext(GrpcSslContexts.forClient().ciphers(null).build()) + .build(); + } + + private synchronized CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub() { + if (stubList.isEmpty()) { + throw new RuntimeException("windmillServiceEndpoint has not been set"); + } + if (stubList.size() == 1) { + return stubList.get(0); + } + return stubList.get(rand.nextInt(stubList.size())); + } + + private synchronized CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub + syncStub() { + if (syncStubList.isEmpty()) { + throw new RuntimeException("windmillServiceEndpoint has not been set"); + } + if (syncStubList.size() == 1) { + return syncStubList.get(0); + } + return syncStubList.get(rand.nextInt(syncStubList.size())); + } + + @Override + public void appendSummaryHtml(PrintWriter writer) { + writer.write("Active Streams:
"); + for (AbstractWindmillStream stream : streamRegistry) { + stream.appendSummaryHtml(writer); + writer.write("
"); + } + } + + // Configure backoff to retry calls forever, with a maximum sane retry interval. + private BackOff grpcBackoff() { + return FluentBackoff.DEFAULT + .withInitialBackoff(MIN_BACKOFF) + .withMaxBackoff(maxBackoff) + .backoff(); + } + + private ResponseT callWithBackoff(Supplier function) { + BackOff backoff = grpcBackoff(); + int rpcErrors = 0; + while (true) { + try { + return function.get(); + } catch (StatusRuntimeException e) { + try { + if (++rpcErrors % 20 == 0) { + LOG.warn( + "Many exceptions calling gRPC. Last exception: {} with status {}", + e, + e.getStatus()); + } + if (!BackOffUtils.next(Sleeper.DEFAULT, backoff)) { + throw new RpcException(e); + } + } catch (IOException | InterruptedException i) { + if (i instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + RpcException rpcException = new RpcException(e); + rpcException.addSuppressed(i); + throw rpcException; + } + } + } + } + + @Override + public GetWorkResponse getWork(GetWorkRequest request) { + if (syncApplianceStub == null) { + return callWithBackoff( + () -> + syncStub() + .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) + .getWork( + request + .toBuilder() + .setJobId(options.getJobId()) + .setProjectId(options.getProject()) + .setWorkerId(options.getWorkerId()) + .build())); + } else { + return callWithBackoff( + () -> + syncApplianceStub + .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) + .getWork(request)); + } + } + + @Override + public GetDataResponse getData(GetDataRequest request) { + if (syncApplianceStub == null) { + return callWithBackoff( + () -> + syncStub() + .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) + .getData( + request + .toBuilder() + .setJobId(options.getJobId()) + .setProjectId(options.getProject()) + .build())); + } else { + return callWithBackoff( + () -> + syncApplianceStub + .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) + .getData(request)); + } + } + + @Override + public CommitWorkResponse commitWork(CommitWorkRequest request) { + if (syncApplianceStub == null) { + return callWithBackoff( + () -> + syncStub() + .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) + .commitWork( + request + .toBuilder() + .setJobId(options.getJobId()) + .setProjectId(options.getProject()) + .build())); + } else { + return callWithBackoff( + () -> + syncApplianceStub + .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) + .commitWork(request)); + } + } + + private StreamObserverFactory newStreamObserverFactory() { + return StreamObserverFactory.direct( + DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, options.getWindmillMessagesBetweenIsReadyChecks()); + } + + @Override + public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver receiver) { + GetWorkRequest getWorkRequest = + GetWorkRequest.newBuilder(request) + .setJobId(options.getJobId()) + .setProjectId(options.getProject()) + .setWorkerId(options.getWorkerId()) + .build(); + + return GrpcGetWorkStream.create( + stub(), + getWorkRequest, + grpcBackoff(), + newStreamObserverFactory(), + streamRegistry, + logEveryNStreamFailures, + getWorkThrottleTimer, + receiver); + } + + @Override + public GetDataStream getDataStream() { + return GrpcGetDataStream.create( + stub(), + grpcBackoff(), + newStreamObserverFactory(), + streamRegistry, + logEveryNStreamFailures, + getDataThrottleTimer, + makeHeader(), + nextId, + streamingRpcBatchLimit); + } + + @Override + public CommitWorkStream commitWorkStream() { + return GrpcCommitWorkStream.create( + stub(), + grpcBackoff(), + newStreamObserverFactory(), + streamRegistry, + logEveryNStreamFailures, + commitWorkThrottleTimer, + makeHeader(), + nextId, + streamingRpcBatchLimit); + } + + @Override + public GetConfigResponse getConfig(GetConfigRequest request) { + if (syncApplianceStub == null) { + throw new RpcException( + new UnsupportedOperationException("GetConfig not supported with windmill service.")); + } else { + return callWithBackoff( + () -> + syncApplianceStub + .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) + .getConfig(request)); + } + } + + @Override + public ReportStatsResponse reportStats(ReportStatsRequest request) { + if (syncApplianceStub == null) { + throw new RpcException( + new UnsupportedOperationException("ReportStats not supported with windmill service.")); + } else { + return callWithBackoff( + () -> + syncApplianceStub + .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) + .reportStats(request)); + } + } + + @Override + public long getAndResetThrottleTime() { + return getWorkThrottleTimer.getAndResetThrottleTime() + + getDataThrottleTimer.getAndResetThrottleTime() + + commitWorkThrottleTimer.getAndResetThrottleTime(); + } + + private JobHeader makeHeader() { + return JobHeader.newBuilder() + .setJobId(options.getJobId()) + .setProjectId(options.getProject()) + .setWorkerId(options.getWorkerId()) + .build(); + } + + /** + * Create a wrapper around credentials callback that delegates to the underlying vendored {@link + * com.google.auth.RequestMetadataCallback}. Note that this class should override every method + * that is not final and not static and call the delegate directly. + * + *

TODO: Replace this with an auto generated proxy which calls the underlying implementation + * delegate to reduce maintenance burden. + */ + private static class VendoredRequestMetadataCallbackAdapter + implements com.google.auth.RequestMetadataCallback { + + private final org.apache.beam.vendor.grpc.v1p54p0.com.google.auth.RequestMetadataCallback + callback; + + private VendoredRequestMetadataCallbackAdapter( + org.apache.beam.vendor.grpc.v1p54p0.com.google.auth.RequestMetadataCallback callback) { + this.callback = callback; + } + + @Override + public void onSuccess(Map> metadata) { + callback.onSuccess(metadata); + } + + @Override + public void onFailure(Throwable exception) { + callback.onFailure(exception); + } + } + + /** + * Create a wrapper around credentials that delegates to the underlying {@link + * com.google.auth.Credentials}. Note that this class should override every method that is not + * final and not static and call the delegate directly. + * + *

TODO: Replace this with an auto generated proxy which calls the underlying implementation + * delegate to reduce maintenance burden. + */ + private static class VendoredCredentialsAdapter + extends org.apache.beam.vendor.grpc.v1p54p0.com.google.auth.Credentials { + + private final com.google.auth.Credentials credentials; + + private VendoredCredentialsAdapter(com.google.auth.Credentials credentials) { + this.credentials = credentials; + } + + @Override + public String getAuthenticationType() { + return credentials.getAuthenticationType(); + } + + @Override + public Map> getRequestMetadata() throws IOException { + return credentials.getRequestMetadata(); + } + + @Override + public void getRequestMetadata( + final URI uri, + Executor executor, + final org.apache.beam.vendor.grpc.v1p54p0.com.google.auth.RequestMetadataCallback + callback) { + credentials.getRequestMetadata( + uri, executor, new VendoredRequestMetadataCallbackAdapter(callback)); + } + + @Override + public Map> getRequestMetadata(URI uri) throws IOException { + return credentials.getRequestMetadata(uri); + } + + @Override + public boolean hasRequestMetadata() { + return credentials.hasRequestMetadata(); + } + + @Override + public boolean hasRequestMetadataOnly() { + return credentials.hasRequestMetadataOnly(); + } + + @Override + public void refresh() throws IOException { + credentials.refresh(); + } + } + + private static class HealthCheckTimerTask extends TimerTask { + private final StreamingDataflowWorkerOptions options; + private final Supplier>> streams; + + public HealthCheckTimerTask( + StreamingDataflowWorkerOptions options, + Supplier>> streams) { + this.options = options; + this.streams = streams; + } + + @Override + public void run() { + Instant reportThreshold = + Instant.now() + .minus(Duration.millis(options.getWindmillServiceStreamingRpcHealthCheckPeriodMs())); + for (AbstractWindmillStream stream : streams.get()) { + stream.maybeSendHealthCheck(reportThreshold); + } + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/ThrottleTimer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/ThrottleTimer.java new file mode 100644 index 000000000000..237339aff399 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/ThrottleTimer.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; + +import org.joda.time.Instant; + +/** + * A stopwatch used to track the amount of time spent throttled due to Resource Exhausted errors. + * Throttle time is cumulative for all three rpcs types but not for all streams. So if GetWork and + * CommitWork are both blocked for x, totalTime will be 2x. However, if 2 GetWork streams are both + * blocked for x totalTime will be x. All methods are thread safe. + */ +class ThrottleTimer { + // This is -1 if not currently being throttled or the time in + // milliseconds when throttling for this type started. + private long startTime = -1; + // This is the collected total throttle times since the last poll. Throttle times are + // reported as a delta so this is cleared whenever it gets reported. + private long totalTime = 0; + + /** + * Starts the timer if it has not been started and does nothing if it has already been started. + */ + synchronized void start() { + if (!throttled()) { // This timer is not started yet so start it now. + startTime = Instant.now().getMillis(); + } + } + + /** Stops the timer if it has been started and does nothing if it has not been started. */ + public synchronized void stop() { + if (throttled()) { // This timer has been started already so stop it now. + totalTime += Instant.now().getMillis() - startTime; + startTime = -1; + } + } + + /** Returns if the specified type is currently being throttled. */ + public synchronized boolean throttled() { + return startTime != -1; + } + + /** Returns the combined total of all throttle times and resets those times to 0. */ + public synchronized long getAndResetThrottleTime() { + if (throttled()) { + stop(); + start(); + } + long toReturn = totalTime; + totalTime = 0; + return toReturn; + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 0e53210c0189..a1b57f21e193 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -164,11 +164,13 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.UnsignedLong; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles; import org.hamcrest.Matcher; import org.hamcrest.Matchers; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ErrorCollector; @@ -2852,6 +2854,75 @@ public void testActiveWorkForShardedKeys() throws Exception { Mockito.verifyNoMoreInteractions(mockExecutor); } + @Test + @Ignore // Test is flaky on Jenkins (#27555) + public void testMaxThreadMetric() throws Exception { + int maxThreads = 2; + int threadExpiration = 60; + // setting up actual implementation of executor instead of mocking to keep track of + // active thread count. + BoundedQueueExecutor executor = + new BoundedQueueExecutor( + maxThreads, + threadExpiration, + TimeUnit.SECONDS, + maxThreads, + 10000000, + new ThreadFactoryBuilder() + .setNameFormat("DataflowWorkUnits-%d") + .setDaemon(true) + .build()); + + StreamingDataflowWorker.ComputationState computationState = + new StreamingDataflowWorker.ComputationState( + "computation", + defaultMapTask(Arrays.asList(makeSourceInstruction(StringUtf8Coder.of()))), + executor, + ImmutableMap.of(), + null); + + ShardedKey key1Shard1 = ShardedKey.create(ByteString.copyFromUtf8("key1"), 1); + + // overriding definition of MockWork to add sleep, which will help us keep track of how + // long each work item takes to process and therefore let us manipulate how long the time + // at which we're at max threads is. + MockWork m2 = + new MockWork(2) { + @Override + public void run() { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }; + + MockWork m3 = + new MockWork(3) { + @Override + public void run() { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }; + + assertTrue(computationState.activateWork(key1Shard1, m2)); + assertTrue(computationState.activateWork(key1Shard1, m3)); + executor.execute(m2, m2.getWorkItem().getSerializedSize()); + + executor.execute(m3, m3.getWorkItem().getSerializedSize()); + + // Will get close to 1000ms that both work items are processing (sleeping, really) + // give or take a few ms. + long i = 990L; + assertTrue(executor.allThreadsActiveTime() >= i); + executor.shutdown(); + } + static class TestExceptionInvalidatesCacheFn extends DoFn>, String> { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServerTest.java similarity index 97% rename from runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java rename to runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServerTest.java index c9459b7d71af..77f046e7384a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServerTest.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.runners.dataflow.worker.windmill; +package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -39,7 +39,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase; -import org.apache.beam.runners.dataflow.worker.windmill.GrpcWindmillServer.GetWorkTimingInfosTracker; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationWorkItemMetadata; @@ -88,19 +88,22 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** Unit tests for {@link GrpcWindmillServer}. */ +/** + * Unit tests for {@link + * org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcWindmillServer}. + */ @RunWith(JUnit4.class) @SuppressWarnings({ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) }) public class GrpcWindmillServerTest { - private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServerTest.class); + private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServerTest.class); + private static final int STREAM_CHUNK_SIZE = 2 << 20; private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); @Rule public ErrorCollector errorCollector = new ErrorCollector(); private Server server; - private GrpcWindmillServer client; - private static final int STREAM_CHUNK_SIZE = 2 << 20; + private org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcWindmillServer client; private int remainingErrors = 20; @Before @@ -114,7 +117,7 @@ public void setUp() throws Exception { .build() .start(); - this.client = GrpcWindmillServer.newTestInstance(name, true); + this.client = GrpcWindmillServer.newTestInstance(name); } @After @@ -133,50 +136,6 @@ private void maybeInjectError(Stream stream) { } } - class ResponseErrorInjector { - private Stream stream; - private Thread errorThread; - private boolean cancelled = false; - - public ResponseErrorInjector(Stream stream) { - this.stream = stream; - errorThread = new Thread(this::errorThreadBody); - errorThread.start(); - } - - private void errorThreadBody() { - int i = 0; - while (true) { - try { - Thread.sleep(ThreadLocalRandom.current().nextInt(++i * 10)); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - break; - } - synchronized (this) { - if (cancelled) { - break; - } - } - maybeInjectError(stream); - } - } - - public void cancel() { - LOG.info("Starting cancel of error injector."); - synchronized (this) { - cancelled = true; - } - errorThread.interrupt(); - try { - errorThread.join(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - LOG.info("Done cancelling."); - } - } - @Test public void testStreamingGetWork() throws Exception { // This fake server returns an infinite stream of identical WorkItems, obeying the request size @@ -187,8 +146,8 @@ public void testStreamingGetWork() throws Exception { public StreamObserver getWorkStream( StreamObserver responseObserver) { return new StreamObserver() { + final ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver); boolean sawHeader = false; - ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver); @Override public void onNext(StreamingGetWorkRequest request) { @@ -274,7 +233,7 @@ public void onCompleted() { (String computation, @Nullable Instant inputDataWatermark, Instant synchronizedProcessingTime, - Windmill.WorkItem workItem, + WorkItem workItem, Collection getWorkStreamLatencies) -> { latch.countDown(); assertEquals(inputDataWatermark, new Instant(18)); @@ -297,11 +256,11 @@ public void testStreamingGetData() throws Exception { public StreamObserver getDataStream( StreamObserver responseObserver) { return new StreamObserver() { - boolean sawHeader = false; - HashSet seenIds = new HashSet<>(); - ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver); - StreamingGetDataResponse.Builder responseBuilder = + final HashSet seenIds = new HashSet<>(); + final ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver); + final StreamingGetDataResponse.Builder responseBuilder = StreamingGetDataResponse.newBuilder(); + boolean sawHeader = false; @Override public void onNext(StreamingGetDataRequest chunk) { @@ -508,10 +467,10 @@ private StreamObserver getTestCommitStreamObserver( StreamObserver responseObserver, Map commitRequests) { return new StreamObserver() { + final ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver); boolean sawHeader = false; InputStream buffer = null; long remainingBytes = 0; - ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver); @Override public void onNext(StreamingCommitWorkRequest request) { @@ -1016,4 +975,49 @@ public void testGetWorkTimingInfosTracker() throws Exception { Math.min(34, (long) (elapsedTime * (130.0 / sumDurations))), latencies.get(State.GET_WORK_IN_TRANSIT_TO_USER_WORKER).getTotalDurationMillis()); } + + class ResponseErrorInjector { + + private final Stream stream; + private final Thread errorThread; + private boolean cancelled = false; + + public ResponseErrorInjector(Stream stream) { + this.stream = stream; + errorThread = new Thread(this::errorThreadBody); + errorThread.start(); + } + + private void errorThreadBody() { + int i = 0; + while (true) { + try { + Thread.sleep(ThreadLocalRandom.current().nextInt(++i * 10)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } + synchronized (this) { + if (cancelled) { + break; + } + } + maybeInjectError(stream); + } + } + + public void cancel() { + LOG.info("Starting cancel of error injector."); + synchronized (this) { + cancelled = true; + } + errorThread.interrupt(); + try { + errorThread.join(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + LOG.info("Done cancelling."); + } + } } diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index 86a551f8b647..f66b2bed48c6 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -27,14 +27,14 @@ option java_outer_classname = "Windmill"; // API Data types message Message { - required int64 timestamp = 1 [default=-0x8000000000000000]; + required int64 timestamp = 1 [default = -0x8000000000000000]; required bytes data = 2; optional bytes metadata = 3; } message Timer { required bytes tag = 1; - optional int64 timestamp = 2 [default=-0x8000000000000000]; + optional int64 timestamp = 2 [default = -0x8000000000000000]; enum Type { WATERMARK = 0; REALTIME = 1; @@ -124,7 +124,7 @@ message TimerBundle { } message Value { - required int64 timestamp = 1 [default=-0x8000000000000000]; + required int64 timestamp = 1 [default = -0x8000000000000000]; required bytes data = 2; } @@ -293,7 +293,7 @@ message TagSortedListFetchResponse { repeated SortedListRange fetch_ranges = 5; // Request position copied from request. optional bytes request_position = 6; - } +} message TagSortedListUpdateRequest { optional bytes tag = 1; @@ -332,7 +332,7 @@ message SourceState { message WatermarkHold { required bytes tag = 1; - repeated int64 timestamps = 2 [packed=true]; + repeated int64 timestamps = 2 [packed = true]; optional bool reset = 3; optional string state_family = 4; } @@ -354,7 +354,7 @@ message WorkItem { optional TimerBundle timers = 4; repeated GlobalDataId global_data_id_notifications = 5; optional SourceState source_state = 6; - optional int64 output_data_watermark = 8 [default=-0x8000000000000000]; + optional int64 output_data_watermark = 8 [default = -0x8000000000000000]; // Indicates that this is a new key with no data associated. This allows // the harness to optimize data fetching. optional bool is_new_key = 10; @@ -368,9 +368,9 @@ message WorkItem { message ComputationWorkItems { required string computation_id = 1; repeated WorkItem work = 2; - optional int64 input_data_watermark = 3 [default=-0x8000000000000000]; + optional int64 input_data_watermark = 3 [default = -0x8000000000000000]; optional int64 dependent_realtime_input_watermark = 4 - [default = -0x8000000000000000]; + [default = -0x8000000000000000]; } //////////////////////////////////////////////////////////////////////////////// @@ -478,12 +478,12 @@ message Counter { // value accumulated since the worker started working on this WorkItem. // By default this is false, indicating that this metric is reported // as a delta that is not associated with any WorkItem. - optional bool cumulative = 7; + optional bool cumulative = 7; } message GlobalDataRequest { required GlobalDataId data_id = 1; - optional int64 existence_watermark_deadline = 2 [default=0x7FFFFFFFFFFFFFFF]; + optional int64 existence_watermark_deadline = 2 [default = 0x7FFFFFFFFFFFFFFF]; optional string state_family = 3; } @@ -509,8 +509,8 @@ message WorkItemCommitRequest { repeated GlobalDataRequest global_data_requests = 11; repeated GlobalData global_data_updates = 10; optional SourceState source_state_updates = 12; - optional int64 source_watermark = 13 [default=-0x8000000000000000]; - optional int64 source_backlog_bytes = 17 [default=-1]; + optional int64 source_watermark = 13 [default = -0x8000000000000000]; + optional int64 source_backlog_bytes = 17 [default = -1]; optional int64 source_bytes_processed = 22; repeated WatermarkHold watermark_holds = 14; @@ -592,7 +592,7 @@ message GetConfigResponse { optional string computation_id = 2; } repeated SystemNameToComputationIdMapEntry - system_name_to_computation_id_map = 3; + system_name_to_computation_id_map = 3; // Map of computation id to ComputationConfig. message ComputationConfigMapEntry { @@ -627,10 +627,10 @@ message ReportStatsResponse { // Streaming API message StreamingGetWorkRequest { - oneof chunk_type { + oneof chunk_type { GetWorkRequest request = 1; StreamingGetWorkRequestExtension request_extension = 2; - } + } } message StreamingGetWorkRequestExtension { @@ -661,7 +661,7 @@ message StreamingGetWorkResponseChunk { message ComputationWorkItemMetadata { optional string computation_id = 1; - optional int64 input_data_watermark = 2 [default=-0x8000000000000000]; + optional int64 input_data_watermark = 2 [default = -0x8000000000000000]; optional int64 dependent_realtime_input_watermark = 3 [default = -0x8000000000000000]; } @@ -742,6 +742,36 @@ message StreamingCommitResponse { repeated CommitStatus status = 2; } +message WorkerMetadataRequest { + optional JobHeader header = 1; +} + +message WorkerMetadataResponse { + // The metadata version increases with every modification. Within a single + // stream it will always be increasing. The version may be used across streams + // to ensure that the view of the metadata does not move backwards. + optional int64 metadata_version = 1; + + // Endpoints that should be used for requesting work with GetWorkStream. + // Additional data for returned work should be fetched from the endpoint with + // GetDataStream. The work should be committed to the endpoint with + // CommitWorkStream. Each response on this stream replaces the previous, and + // connections to endpoints that are no longer present should be closed. + message Endpoint { + optional string endpoint = 1; + } + repeated Endpoint work_endpoints = 2; + + // Maps from GlobalData tag to the endpoint that should be used for GetData + // calls to retrieve that global data. + map global_data_endpoints = 3; + + // DirectPath endpoints to be used by user workers for streaming engine jobs. + // DirectPath endpoints here are virtual IPv6 addresses of the windmill + // workers. + repeated Endpoint direct_path_endpoints = 4; +} + service WindmillAppliance { // Gets streaming Dataflow work. rpc GetWork(.windmill.GetWorkRequest) returns (.windmill.GetWorkResponse); diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto index bef819d69901..803766d1a464 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto @@ -33,6 +33,10 @@ service CloudWindmillServiceV1Alpha1 { rpc GetWorkStream(stream .windmill.StreamingGetWorkRequest) returns (stream .windmill.StreamingGetWorkResponseChunk); + // Gets worker metadata. Response is a stream. + rpc GetWorkerMetadataStream(.windmill.WorkerMetadataRequest) + returns (stream .windmill.WorkerMetadataResponse); + // Gets data from Windmill. rpc GetData(.windmill.GetDataRequest) returns(.windmill.GetDataResponse); diff --git a/runners/samza/src/test/java/org/apache/beam/runners/samza/util/PipelineJsonRendererTest.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/util/PipelineJsonRendererTest.java index 3860fb229632..0a4f532808b1 100644 --- a/runners/samza/src/test/java/org/apache/beam/runners/samza/util/PipelineJsonRendererTest.java +++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/util/PipelineJsonRendererTest.java @@ -85,7 +85,10 @@ public void testCompositePipeline() throws IOException { Pipeline p = Pipeline.create(options); - p.apply(Create.timestamped(TimestampedValue.of(KV.of(1, 1), new Instant(1)))) + p.apply( + Create.timestamped( + TimestampedValue.of(KV.of(1, 1), new Instant(1)), + TimestampedValue.of(KV.of(2, 2), new Instant(2)))) .apply(Window.into(FixedWindows.of(Duration.millis(10)))) .apply(Sum.integersPerKey()); diff --git a/runners/spark/3/build.gradle b/runners/spark/3/build.gradle index 494d367131b4..cb34a1fd972b 100644 --- a/runners/spark/3/build.gradle +++ b/runners/spark/3/build.gradle @@ -35,7 +35,7 @@ createJavaExamplesArchetypeValidationTask(type: 'Quickstart', runner: 'Spark') // Additional supported Spark versions (used in compatibility tests) def sparkVersions = [ "330": "3.3.0", - "321": "3.2.1" + "312": "3.1.2" ] sparkVersions.each { kv -> diff --git a/runners/spark/3/job-server/build.gradle b/runners/spark/3/job-server/build.gradle index d11a1a8edb13..68bb8d9a10e1 100644 --- a/runners/spark/3/job-server/build.gradle +++ b/runners/spark/3/job-server/build.gradle @@ -28,13 +28,4 @@ project.ext { } // Load the main build script which contains all build logic. -apply from: "$basePath/spark_job_server.gradle" - - -configurations.runtimeClasspath { - resolutionStrategy { - // Downgrade the Scala version of the job-server to match the Scala version of a Spark 3.1.2 cluster to prevent - // a Scala bug (InvalidClassException when deserializing WrappedArray), see https://github.com/apache/beam/issues/21092 - force "org.scala-lang:scala-library:2.12.10" - } -} \ No newline at end of file +apply from: "$basePath/spark_job_server.gradle" \ No newline at end of file diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java index 5770d93094c9..480025b6f174 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java @@ -167,7 +167,7 @@ && eligibleForGroupByWindow(windowing, false) result = input .select(explode(col("windows")).as("window"), col("value"), col("timestamp")) - .groupBy(col("value.key"), col("window")) + .groupBy(col("value.key").as("key"), col("window")) .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner)) .select( inSingleWindow( diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java index e70cc7253f8d..ceafc1642baa 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java @@ -69,6 +69,10 @@ private static Expression invoke( // Spark 3.1.x return STATIC_INVOKE_CONSTRUCTOR.newInstance( cls, type, fun, seqOf(args), propagateNull, true); + case 7: + // Spark 3.2.0 + return STATIC_INVOKE_CONSTRUCTOR.newInstance( + cls, type, fun, seqOf(args), emptyList(), propagateNull, true); case 8: // Spark 3.2.x, 3.3.x return STATIC_INVOKE_CONSTRUCTOR.newInstance( @@ -89,8 +93,14 @@ static Expression invoke( // created reflectively. This is fine as it's just needed once to create the query plan. switch (STATIC_INVOKE_CONSTRUCTOR.getParameterCount()) { case 6: + // Spark 3.1.x return INVOKE_CONSTRUCTOR.newInstance(obj, fun, type, seqOf(args), false, nullable); + case 7: + // Spark 3.2.0 + return INVOKE_CONSTRUCTOR.newInstance( + obj, fun, type, seqOf(args), emptyList(), false, nullable); case 8: + // Spark 3.2.x, 3.3.x return INVOKE_CONSTRUCTOR.newInstance( obj, fun, type, seqOf(args), emptyList(), false, nullable, true); default: diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle index 0e3028d3843c..9ee9b48585c5 100644 --- a/runners/spark/spark_runner.gradle +++ b/runners/spark/spark_runner.gradle @@ -181,7 +181,7 @@ dependencies { runtimeOnly library.java.jackson_module_scala_2_12 // Force paranamer 2.8 to avoid issues when using Scala 2.12 runtimeOnly "com.thoughtworks.paranamer:paranamer:2.8" - provided library.java.hadoop_common + provided "org.apache.hadoop:hadoop-client-api:3.3.1" provided library.java.commons_io provided library.java.hamcrest provided "com.esotericsoftware:kryo-shaded:4.0.2" diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java index 157b3cb946e9..ef3e9589ba2e 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java @@ -83,7 +83,8 @@ public void debugBatchPipeline() { .apply(TextIO.write().to("!!PLACEHOLDER-OUTPUT-DIR!!").withNumShards(3).withSuffix(".txt")); final String expectedPipeline = - "sparkContext.()\n" + "sparkContext.()\n" + + "_.mapPartitions(new org.apache.beam.sdk.transforms.FlatMapElements$3())\n" + "_.mapPartitions(" + "new org.apache.beam.runners.spark.examples.WordCount$ExtractWordsFn())\n" + "_.mapPartitions(new org.apache.beam.sdk.transforms.Count$PerElement$1())\n" diff --git a/sdks/go.mod b/sdks/go.mod index 05fcf8489459..fb4e0a74d6af 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -30,11 +30,11 @@ require ( cloud.google.com/go/pubsub v1.32.0 cloud.google.com/go/spanner v1.47.0 cloud.google.com/go/storage v1.31.0 - github.com/aws/aws-sdk-go-v2 v1.18.1 - github.com/aws/aws-sdk-go-v2/config v1.18.27 - github.com/aws/aws-sdk-go-v2/credentials v1.13.26 - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.71 - github.com/aws/aws-sdk-go-v2/service/s3 v1.36.0 + github.com/aws/aws-sdk-go-v2 v1.19.0 + github.com/aws/aws-sdk-go-v2/config v1.18.28 + github.com/aws/aws-sdk-go-v2/credentials v1.13.27 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.72 + github.com/aws/aws-sdk-go-v2/service/s3 v1.37.0 github.com/aws/smithy-go v1.13.5 github.com/docker/go-connections v0.4.0 github.com/dustin/go-humanize v1.0.1 @@ -57,8 +57,8 @@ require ( golang.org/x/sync v0.3.0 golang.org/x/sys v0.10.0 golang.org/x/text v0.11.0 - google.golang.org/api v0.130.0 - google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc + google.golang.org/api v0.131.0 + google.golang.org/genproto v0.0.0-20230629202037-9506855d4529 google.golang.org/grpc v1.56.2 google.golang.org/protobuf v1.31.0 gopkg.in/retry.v1 v1.0.3 @@ -85,18 +85,18 @@ require ( github.com/apache/thrift v0.16.0 // indirect github.com/aws/aws-sdk-go v1.34.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.4 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.34 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.28 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.3.35 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.26 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.35 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.29 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.3.36 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.27 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.29 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.28 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.3 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.12.12 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.12 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.19.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.30 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.29 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.12.13 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.13 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.19.3 // indirect github.com/cenkalti/backoff/v4 v4.2.0 // indirect github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect @@ -119,7 +119,7 @@ require ( github.com/google/renameio/v2 v2.0.0 // indirect github.com/google/s2a-go v0.1.4 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.5 // indirect - github.com/googleapis/gax-go/v2 v2.11.0 // indirect + github.com/googleapis/gax-go/v2 v2.12.0 // indirect github.com/gorilla/handlers v1.5.1 // indirect github.com/gorilla/mux v1.8.0 // indirect github.com/imdario/mergo v0.3.15 // indirect @@ -158,6 +158,6 @@ require ( golang.org/x/tools v0.9.1 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230629202037-9506855d4529 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20230629202037-9506855d4529 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20230706204954-ccb25ca9f130 // indirect ) diff --git a/sdks/go.sum b/sdks/go.sum index 0c8704d9cb37..ed376e34a3aa 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -78,53 +78,53 @@ github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZve github.com/aws/aws-sdk-go v1.34.0 h1:brux2dRrlwCF5JhTL7MUT3WUwo9zfDHZZp3+g3Mvlmo= github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go-v2 v1.7.1/go.mod h1:L5LuPC1ZgDr2xQS7AmIec/Jlc7O/Y1u2KxJyNVab250= -github.com/aws/aws-sdk-go-v2 v1.18.1 h1:+tefE750oAb7ZQGzla6bLkOwfcQCEtC5y2RqoqCeqKo= -github.com/aws/aws-sdk-go-v2 v1.18.1/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= +github.com/aws/aws-sdk-go-v2 v1.19.0 h1:klAT+y3pGFBU/qVf1uzwttpBbiuozJYWzNLHioyDJ+k= +github.com/aws/aws-sdk-go-v2 v1.19.0/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 h1:dK82zF6kkPeCo8J1e+tGx4JdvDIQzj7ygIoLg8WMuGs= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10/go.mod h1:VeTZetY5KRJLuD/7fkQXMU6Mw7H5m/KP2J5Iy9osMno= github.com/aws/aws-sdk-go-v2/config v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA= -github.com/aws/aws-sdk-go-v2/config v1.18.27 h1:Az9uLwmssTE6OGTpsFqOnaGpLnKDqNYOJzWuC6UAYzA= -github.com/aws/aws-sdk-go-v2/config v1.18.27/go.mod h1:0My+YgmkGxeqjXZb5BYme5pc4drjTnM+x1GJ3zv42Nw= +github.com/aws/aws-sdk-go-v2/config v1.18.28 h1:TINEaKyh1Td64tqFvn09iYpKiWjmHYrG1fa91q2gnqw= +github.com/aws/aws-sdk-go-v2/config v1.18.28/go.mod h1:nIL+4/8JdAuNHEjn/gPEXqtnS02Q3NXB/9Z7o5xE4+A= github.com/aws/aws-sdk-go-v2/credentials v1.3.1/go.mod h1:r0n73xwsIVagq8RsxmZbGSRQFj9As3je72C2WzUIToc= -github.com/aws/aws-sdk-go-v2/credentials v1.13.26 h1:qmU+yhKmOCyujmuPY7tf5MxR/RKyZrOPO3V4DobiTUk= -github.com/aws/aws-sdk-go-v2/credentials v1.13.26/go.mod h1:GoXt2YC8jHUBbA4jr+W3JiemnIbkXOfxSXcisUsZ3os= +github.com/aws/aws-sdk-go-v2/credentials v1.13.27 h1:dz0yr/yR1jweAnsCx+BmjerUILVPQ6FS5AwF/OyG1kA= +github.com/aws/aws-sdk-go-v2/credentials v1.13.27/go.mod h1:syOqAek45ZXZp29HlnRS/BNgMIW6uiRmeuQsz4Qh2UE= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.3.0/go.mod h1:2LAuqPx1I6jNfaGDucWfA2zqQCYCOMCDHiCOciALyNw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.4 h1:LxK/bitrAr4lnh9LnIS6i7zWbCOdMsfzKFBI6LUCS0I= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.4/go.mod h1:E1hLXN/BL2e6YizK1zFlYd8vsfi2GTjbjBazinMmeaM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.5 h1:kP3Me6Fy3vdi+9uHd7YLr6ewPxRL+PU6y15urfTaamU= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.5/go.mod h1:Gj7tm95r+QsDoN2Fhuz/3npQvcZbkEf5mL70n3Xfluc= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.3.2/go.mod h1:qaqQiHSrOUVOfKe6fhgQ6UzhxjwqVW8aHNegd6Ws4w4= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.71 h1:SAB1UAVaf6nGCu3zyIrV+VWsendXrms1GqtW4zBotKA= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.71/go.mod h1:ZNo5H4PR3/fwsXYqb+Ld5YAfvHcYCbltaTTtSay4l2o= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.34 h1:A5UqQEmPaCFpedKouS4v+dHCTUo2sKqhoKO9U5kxyWo= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.34/go.mod h1:wZpTEecJe0Btj3IYnDx/VlUzor9wm3fJHyvLpQF0VwY= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.28 h1:srIVS45eQuewqz6fKKu6ZGXaq6FuFg5NzgQBAM6g8Y4= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.28/go.mod h1:7VRpKQQedkfIEXb4k52I7swUnZP0wohVajJMRn3vsUw= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.72 h1:m0MmP89v1B0t3b8W8rtATU76KNsodak69QtiokHyEvo= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.72/go.mod h1:ylOTxIuoTL+XjH46Omv2iPjHdeGUk3SQ4hxYho4EHMA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.35 h1:hMUCiE3Zi5AHrRNGf5j985u0WyqI6r2NULhUfo0N/No= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.35/go.mod h1:ipR5PvpSPqIqL5Mi82BxLnfMkHVbmco8kUwO2xrCi0M= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.29 h1:yOpYx+FTBdpk/g+sBU6Cb1H0U/TLEcYYp66mYqsPpcc= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.29/go.mod h1:M/eUABlDbw2uVrdAn+UsI6M727qp2fxkp8K0ejcBDUY= github.com/aws/aws-sdk-go-v2/internal/ini v1.1.1/go.mod h1:Zy8smImhTdOETZqfyn01iNOe0CNggVbPjCajyaz6Gvg= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.35 h1:LWA+3kDM8ly001vJ1X1waCuLJdtTl48gwkPKWy9sosI= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.35/go.mod h1:0Eg1YjxE0Bhn56lx+SHJwCzhW+2JGtizsrx+lCqrfm0= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.26 h1:wscW+pnn3J1OYnanMnza5ZVYXLX4cKk5rAvUAl4Qu+c= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.26/go.mod h1:MtYiox5gvyB+OyP0Mr0Sm/yzbEAIPL9eijj/ouHAPw0= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.36 h1:8r5m1BoAWkn0TDC34lUculryf7nUF25EgIMdjvGCkgo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.36/go.mod h1:Rmw2M1hMVTwiUhjwMoIBFWFJMhvJbct06sSidxInkhY= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.27 h1:cZG7psLfqpkB6H+fIrgUDWmlzM474St1LP0jcz272yI= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.27/go.mod h1:ZdjYvJpDlefgh8/hWelJhqgqJeodxu4SmbVsSdBlL7E= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.1/go.mod h1:v33JQ57i2nekYTA70Mb+O18KeH4KqhdqxTJZNK1zdRE= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 h1:y2+VQzC6Zh2ojtV2LoC0MNwHWc6qXv/j2vrQtlftkdA= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11/go.mod h1:iV4q2hsqtNECrfmlXyord9u4zyuFEJX9eLgLpSPzWA8= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.29 h1:zZSLP3v3riMOP14H7b4XP0uyfREDQOYv2cqIrvTXDNQ= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.29/go.mod h1:z7EjRjVwZ6pWcWdI2H64dKttvzaP99jRIj5hphW0M5U= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.30 h1:Bje8Xkh2OWpjBdNfXLrnn8eZg569dUQmhgtydxAYyP0= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.30/go.mod h1:qQtIBl5OVMfmeQkz8HaVyh5DzFmmFXyvK27UgIgOr4c= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.1/go.mod h1:zceowr5Z1Nh2WVP8bf/3ikB41IZW59E4yIYbg+pC6mw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.28 h1:bkRyG4a929RCnpVSTvLM2j/T4ls015ZhhYApbmYs15s= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.28/go.mod h1:jj7znCIg05jXlaGBlFMGP8+7UN3VtCkRBG2spnmRQkU= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.29 h1:IiDolu/eLmuB18DRZibj77n1hHQT7z12jnGO7Ze3pLc= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.29/go.mod h1:fDbkK4o7fpPXWn8YAPmTieAMuB9mk/VgvW64uaUqxd4= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.1/go.mod h1:6EQZIwNNvHpq/2/QSJnp4+ECvqIy55w95Ofs0ze+nGQ= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.3 h1:dBL3StFxHtpBzJJ/mNEsjXVgfO+7jR0dAIEwLqMapEA= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.3/go.mod h1:f1QyiAsvIv4B49DmCqrhlXqyaR+0IxMmyX+1P+AnzOM= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.4 h1:hx4WksB0NRQ9utR+2c3gEGzl6uKj3eM6PMQ6tN3lgXs= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.4/go.mod h1:JniVpqvw90sVjNqanGLufrVapWySL28fhBlYgl96Q/w= github.com/aws/aws-sdk-go-v2/service/s3 v1.11.1/go.mod h1:XLAGFrEjbvMCLvAtWLLP32yTv8GpBquCApZEycDLunI= -github.com/aws/aws-sdk-go-v2/service/s3 v1.36.0 h1:lEmQ1XSD9qLk+NZXbgvLJI/IiTz7OIR2TYUTFH25EI4= -github.com/aws/aws-sdk-go-v2/service/s3 v1.36.0/go.mod h1:aVbf0sko/TsLWHx30c/uVu7c62+0EAJ3vbxaJga0xCw= +github.com/aws/aws-sdk-go-v2/service/s3 v1.37.0 h1:PalLOEGZ/4XfQxpGZFTLaoJSmPoybnqJYotaIZEf/Rg= +github.com/aws/aws-sdk-go-v2/service/s3 v1.37.0/go.mod h1:PwyKKVL0cNkC37QwLcrhyeCrAk+5bY8O2ou7USyAS2A= github.com/aws/aws-sdk-go-v2/service/sso v1.3.1/go.mod h1:J3A3RGUvuCZjvSuZEcOpHDnzZP/sKbhDWV2T1EOzFIM= -github.com/aws/aws-sdk-go-v2/service/sso v1.12.12 h1:nneMBM2p79PGWBQovYO/6Xnc2ryRMw3InnDJq1FHkSY= -github.com/aws/aws-sdk-go-v2/service/sso v1.12.12/go.mod h1:HuCOxYsF21eKrerARYO6HapNeh9GBNq7fius2AcwodY= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.12 h1:2qTR7IFk7/0IN/adSFhYu9Xthr0zVFTgBrmPldILn80= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.12/go.mod h1:E4VrHCPzmVB/KFXtqBGKb3c8zpbNBgKe3fisDNLAW5w= +github.com/aws/aws-sdk-go-v2/service/sso v1.12.13 h1:sWDv7cMITPcZ21QdreULwxOOAmE05JjEsT6fCDtDA9k= +github.com/aws/aws-sdk-go-v2/service/sso v1.12.13/go.mod h1:DfX0sWuT46KpcqbMhJ9QWtxAIP1VozkDWf8VAkByjYY= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.13 h1:BFubHS/xN5bjl818QaroN6mQdjneYQ+AOx44KNXlyH4= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.13/go.mod h1:BzqsVVFduubEmzrVtUFQQIQdFqvUItF8XUq2EnS8Wog= github.com/aws/aws-sdk-go-v2/service/sts v1.6.0/go.mod h1:q7o0j7d7HrJk/vr9uUt3BVRASvcU7gYZB9PUgPiByXg= -github.com/aws/aws-sdk-go-v2/service/sts v1.19.2 h1:XFJ2Z6sNUUcAz9poj+245DMkrHE4h2j5I9/xD50RHfE= -github.com/aws/aws-sdk-go-v2/service/sts v1.19.2/go.mod h1:dp0yLPsLBOi++WTxzCjA/oZqi6NPIhoR+uF7GeMU9eg= +github.com/aws/aws-sdk-go-v2/service/sts v1.19.3 h1:e5mnydVdCVWxP+5rPAGi2PYxC7u2OZgH1ypC114H04U= +github.com/aws/aws-sdk-go-v2/service/sts v1.19.3/go.mod h1:yVGZA1CPkmUhBdA039jXNJJG7/6t+G+EBWmFq23xqnY= github.com/aws/smithy-go v1.6.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/aws/smithy-go v1.13.5 h1:hgz0X/DX0dGqTYpGALqXJoRKRj5oQ7150i5FdTePzO8= github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= @@ -290,8 +290,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.2.5 h1:UR4rDjcgpgEnqpIEvki github.com/googleapis/enterprise-certificate-proxy v0.2.5/go.mod h1:RxW0N9901Cko1VOCW3SXCpWP+mlIEkk2tP7jnHy9a3w= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= -github.com/googleapis/gax-go/v2 v2.11.0 h1:9V9PWXEsWnPpQhu/PeQIkS4eGzMlTLGgt80cUUI8Ki4= -github.com/googleapis/gax-go/v2 v2.11.0/go.mod h1:DxmR61SGKkGLa2xigwuZIQpkCI2S5iydzRfb3peWZJI= +github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas= +github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU= github.com/gorilla/handlers v1.5.1 h1:9lRY6j8DEeeBT10CvO9hGW0gmky0BprnvDI5vfhUHH4= github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= @@ -661,8 +661,8 @@ google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsb google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.130.0 h1:A50ujooa1h9iizvfzA4rrJr2B7uRmWexwbekQ2+5FPQ= -google.golang.org/api v0.130.0/go.mod h1:J/LCJMYSDFvAVREGCbrESb53n4++NMBDetSHGL5I5RY= +google.golang.org/api v0.131.0 h1:AcgWS2edQ4chVEt/SxgDKubVu/9/idCJy00tBGuGB4M= +google.golang.org/api v0.131.0/go.mod h1:7vtkbKv2REjJbxmHSkBTBQ5LUGvPdAqjjvt84XAfhpA= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -689,12 +689,12 @@ google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfG google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc h1:8DyZCyvI8mE1IdLy/60bS+52xfymkE72wv1asokgtao= -google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:xZnkP7mREFX5MORlOPEzLMr+90PPZQ2QWzrVTWfAq64= -google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc h1:kVKPf/IiYSBWEWtkIn6wZXwWGCnLKcC8oWfZvXjsGnM= -google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230629202037-9506855d4529 h1:DEH99RbiLZhMxrpEJCZ0A+wdTe0EOgou/poSLx9vWf4= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230629202037-9506855d4529/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= +google.golang.org/genproto v0.0.0-20230629202037-9506855d4529 h1:9JucMWR7sPvCxUFd6UsOUNmA5kCcWOfORaT3tpAsKQs= +google.golang.org/genproto v0.0.0-20230629202037-9506855d4529/go.mod h1:xZnkP7mREFX5MORlOPEzLMr+90PPZQ2QWzrVTWfAq64= +google.golang.org/genproto/googleapis/api v0.0.0-20230629202037-9506855d4529 h1:s5YSX+ZH5b5vS9rnpGymvIyMpLRJizowqDlOuyjXnTk= +google.golang.org/genproto/googleapis/api v0.0.0-20230629202037-9506855d4529/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230706204954-ccb25ca9f130 h1:2FZP5XuJY9zQyGM5N0rtovnoXjiMUEIUMvw0m9wlpLc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230706204954-ccb25ca9f130/go.mod h1:8mL13HKkDa+IuJ8yruA3ci0q+0vsUz4m//+ottjwS5o= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= diff --git a/sdks/go/examples/snippets/04transforms.go b/sdks/go/examples/snippets/04transforms.go index bb21e9e317e4..e0ff23351135 100644 --- a/sdks/go/examples/snippets/04transforms.go +++ b/sdks/go/examples/snippets/04transforms.go @@ -27,6 +27,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/state" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/timers" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" "github.com/apache/beam/sdks/v2/go/pkg/beam/io/rtrackers/offsetrange" "github.com/apache/beam/sdks/v2/go/pkg/beam/register" @@ -548,58 +549,114 @@ func contains(s []string, e string) bool { return false } -// TODO(https://github.com/apache/beam/issues/22737): Update state_and_timers to a good example to demonstrate both state and timers. -// Rename this to bag_state and update the bag state example in the programming guide at that point. // [START state_and_timers] +// stateAndTimersFn is an example stateful DoFn with state and a timer. +type stateAndTimersFn struct { + Buffer1 state.Bag[string] + Buffer2 state.Bag[int64] + Watermark timers.EventTime +} + +func (s *stateAndTimersFn) ProcessElement(sp state.Provider, tp timers.Provider, w beam.Window, key string, value int64, emit func(string, int64)) error { + // ... handle processing elements here, set a callback timer... + + // Read all the data from Buffer1 in this window. + vals, ok, err := s.Buffer1.Read(sp) + if err != nil { + return err + } + if ok && s.shouldClearBuffer(vals) { + // clear the buffer data if required conditions are met. + s.Buffer1.Clear(sp) + } + + // Add the value to Buffer2. + s.Buffer2.Add(sp, value) + + if s.allConditionsMet() { + // Clear the timer if certain condition met and you don't want to trigger + // the callback method. + s.Watermark.Clear(tp) + } + + emit(key, value) + + return nil +} + +func (s *stateAndTimersFn) OnTimer(sp state.Provider, tp timers.Provider, w beam.Window, key string, timer timers.Context, emit func(string, int64)) error { + // Window and key parameters are really useful especially for debugging issues. + switch timer.Family { + case s.Watermark.Family: + // timer expired, emit a different signal + emit(key, -1) + } + return nil +} + +func (s *stateAndTimersFn) shouldClearBuffer([]string) bool { + // some business logic + return false +} + +func (s *stateAndTimersFn) allConditionsMet() bool { + // other business logic + return true +} + +// [END state_and_timers] + +// [START bag_state] + // bagStateFn only emits words that haven't been seen type bagStateFn struct { - bag state.Bag[string] + Bag state.Bag[string] } -func (s *bagStateFn) ProcessElement(p state.Provider, book string, word string, emitWords func(string)) error { +func (s *bagStateFn) ProcessElement(p state.Provider, book, word string, emitWords func(string)) error { // Get all values we've written to this bag state in this window. - vals, ok, err := s.bag.Read(p) + vals, ok, err := s.Bag.Read(p) if err != nil { return err } if !ok || !contains(vals, word) { emitWords(word) - s.bag.Add(p, word) + s.Bag.Add(p, word) } if len(vals) > 10000 { // Example of clearing and starting again with an empty bag - s.bag.Clear(p) + s.Bag.Clear(p) } return nil } -// [END state_and_timers] +// [END bag_state] // [START value_state] // valueStateFn keeps track of the number of elements seen. type valueStateFn struct { - val state.Value[int] + Val state.Value[int] } func (s *valueStateFn) ProcessElement(p state.Provider, book string, word string, emitWords func(string)) error { // Get the value stored in our state - val, ok, err := s.val.Read(p) + val, ok, err := s.Val.Read(p) if err != nil { return err } if !ok { - s.val.Write(p, 1) + s.Val.Write(p, 1) } else { - s.val.Write(p, val+1) + s.Val.Write(p, val+1) } if val > 10000 { // Example of clearing and starting again with an empty bag - s.val.Clear(p) + s.Val.Clear(p) } return nil @@ -620,7 +677,7 @@ func (m MyCustomType) FromBytes(_ []byte) MyCustomType { // [START value_state_coder] type valueStateDoFn struct { - val state.Value[MyCustomType] + Val state.Value[MyCustomType] } func encode(m MyCustomType) []byte { @@ -644,40 +701,422 @@ type combineFn struct{} // combiningStateFn keeps track of the number of elements seen. type combiningStateFn struct { // types are the types of the accumulator, input, and output respectively - val state.Combining[int, int, int] + Val state.Combining[int, int, int] } func (s *combiningStateFn) ProcessElement(p state.Provider, book string, word string, emitWords func(string)) error { // Get the value stored in our state - val, _, err := s.val.Read(p) + val, _, err := s.Val.Read(p) if err != nil { return err } - s.val.Add(p, 1) + s.Val.Add(p, 1) if val > 10000 { // Example of clearing and starting again with an empty bag - s.val.Clear(p) + s.Val.Clear(p) } return nil } -func main() { +func combineState(s beam.Scope, input beam.PCollection) beam.PCollection { // ... // CombineFn param can be a simple fn like this or a structural CombineFn cFn := state.MakeCombiningState[int, int, int]("stateKey", func(a, b int) int { return a + b }) + combined := beam.ParDo(s, combiningStateFn{Val: cFn}, input) + // ... // [END combining_state] - fmt.Print(cFn) + return combined +} + +// [START event_time_timer] + +type eventTimerDoFn struct { + State state.Value[int64] + Timer timers.EventTime +} + +func (fn *eventTimerDoFn) ProcessElement(ts beam.EventTime, sp state.Provider, tp timers.Provider, book, word string, emitWords func(string)) { + // ... + + // Set an event-time timer to the element timestamp. + fn.Timer.Set(tp, ts.ToTime()) + + // ... +} + +func (fn *eventTimerDoFn) OnTimer(sp state.Provider, tp timers.Provider, w beam.Window, key string, timer timers.Context, emitWords func(string)) { + switch timer.Family { + case fn.Timer.Family: + // process callback for this timer + } +} + +func AddEventTimeDoFn(s beam.Scope, in beam.PCollection) beam.PCollection { + return beam.ParDo(s, &eventTimerDoFn{ + // Timers are given family names so their callbacks can be handled independantly. + Timer: timers.InEventTime("processWatermark"), + State: state.MakeValueState[int64]("latest"), + }, in) +} + +// [END event_time_timer] + +// [START processing_time_timer] + +type processingTimerDoFn struct { + Timer timers.ProcessingTime +} + +func (fn *processingTimerDoFn) ProcessElement(sp state.Provider, tp timers.Provider, book, word string, emitWords func(string)) { + // ... + + // Set a timer to go off 30 seconds in the future. + fn.Timer.Set(tp, time.Now().Add(30*time.Second)) + + // ... +} + +func (fn *processingTimerDoFn) OnTimer(sp state.Provider, tp timers.Provider, w beam.Window, key string, timer timers.Context, emitWords func(string)) { + switch timer.Family { + case fn.Timer.Family: + // process callback for this timer + } +} + +func AddProcessingTimeDoFn(s beam.Scope, in beam.PCollection) beam.PCollection { + return beam.ParDo(s, &processingTimerDoFn{ + // Timers are given family names so their callbacks can be handled independantly. + Timer: timers.InProcessingTime("timer"), + }, in) +} + +// [END processing_time_timer] + +// [START dynamic_timer_tags] + +type hasAction interface { + Action() string +} + +type dynamicTagsDoFn[V hasAction] struct { + Timer timers.EventTime +} + +func (fn *dynamicTagsDoFn[V]) ProcessElement(ts beam.EventTime, tp timers.Provider, key string, value V, emitWords func(string)) { + // ... + + // Set a timer to go off 30 seconds in the future. + fn.Timer.Set(tp, ts.ToTime(), timers.WithTag(value.Action())) + + // ... +} + +func (fn *dynamicTagsDoFn[V]) OnTimer(tp timers.Provider, w beam.Window, key string, timer timers.Context, emitWords func(string)) { + switch timer.Family { + case fn.Timer.Family: + tag := timer.Tag // Do something with fired tag + _ = tag + } +} + +func AddDynamicTimerTagsDoFn[V hasAction](s beam.Scope, in beam.PCollection) beam.PCollection { + return beam.ParDo(s, &dynamicTagsDoFn[V]{ + Timer: timers.InEventTime("actionTimers"), + }, in) +} + +// [END dynamic_timer_tags] + +// [START timer_output_timestamps_bad] + +type badTimerOutputTimestampsFn[V any] struct { + ElementBag state.Bag[V] + TimerSet state.Value[bool] + OutputState timers.ProcessingTime +} + +func (fn *badTimerOutputTimestampsFn[V]) ProcessElement(sp state.Provider, tp timers.Provider, key string, value V, emit func(string)) error { + // Add the current element to the bag for this key. + if err := fn.ElementBag.Add(sp, value); err != nil { + return err + } + set, _, err := fn.TimerSet.Read(sp) + if err != nil { + return err + } + if !set { + fn.OutputState.Set(tp, time.Now().Add(1*time.Minute)) + fn.TimerSet.Write(sp, true) + } + return nil +} + +func (fn *badTimerOutputTimestampsFn[V]) OnTimer(sp state.Provider, tp timers.Provider, w beam.Window, key string, timer timers.Context, emit func(string)) error { + switch timer.Family { + case fn.OutputState.Family: + vs, _, err := fn.ElementBag.Read(sp) + if err != nil { + return err + } + for _, v := range vs { + // Output each element + emit(fmt.Sprintf("%v", v)) + } + + fn.ElementBag.Clear(sp) + // Note that the timer has now fired. + fn.TimerSet.Clear(sp) + } + return nil +} + +// [END timer_output_timestamps_bad] + +// [START timer_output_timestamps_good] + +type element[V any] struct { + Timestamp int64 + Value V +} + +type goodTimerOutputTimestampsFn[V any] struct { + ElementBag state.Bag[element[V]] // The bag of elements accumulated. + TimerTimerstamp state.Value[int64] // The timestamp of the timer set. + MinTimestampInBag state.Combining[int64, int64, int64] // The minimum timestamp stored in the bag. + OutputState timers.ProcessingTime // The timestamp of the timer. +} + +func (fn *goodTimerOutputTimestampsFn[V]) ProcessElement(et beam.EventTime, sp state.Provider, tp timers.Provider, key string, value V, emit func(beam.EventTime, string)) error { + // ... + // Add the current element to the bag for this key, and preserve the event time. + if err := fn.ElementBag.Add(sp, element[V]{Timestamp: et.Milliseconds(), Value: value}); err != nil { + return err + } + + // Keep track of the minimum element timestamp currently stored in the bag. + fn.MinTimestampInBag.Add(sp, et.Milliseconds()) + + // If the timer is already set, then reset it at the same time but with an updated output timestamp (otherwise + // we would keep resetting the timer to the future). If there is no timer set, then set one to expire in a minute. + ts, ok, _ := fn.TimerTimerstamp.Read(sp) + var tsToSet time.Time + if ok { + tsToSet = time.UnixMilli(ts) + } else { + tsToSet = time.Now().Add(1 * time.Minute) + } + + minTs, _, _ := fn.MinTimestampInBag.Read(sp) + outputTs := time.UnixMilli(minTs) + + // Setting the outputTimestamp to the minimum timestamp in the bag holds the watermark to that timestamp until the + // timer fires. This allows outputting all the elements with their timestamp. + fn.OutputState.Set(tp, tsToSet, timers.WithOutputTimestamp(outputTs)) + fn.TimerTimerstamp.Write(sp, tsToSet.UnixMilli()) + + return nil +} + +func (fn *goodTimerOutputTimestampsFn[V]) OnTimer(sp state.Provider, tp timers.Provider, w beam.Window, key string, timer timers.Context, emit func(beam.EventTime, string)) error { + switch timer.Family { + case fn.OutputState.Family: + vs, _, err := fn.ElementBag.Read(sp) + if err != nil { + return err + } + for _, v := range vs { + // Output each element with their timestamp + emit(beam.EventTime(v.Timestamp), fmt.Sprintf("%v", v.Value)) + } + + fn.ElementBag.Clear(sp) + // Note that the timer has now fired. + fn.TimerTimerstamp.Clear(sp) + } + return nil +} + +func AddTimedOutputBatching[V any](s beam.Scope, in beam.PCollection) beam.PCollection { + return beam.ParDo(s, &goodTimerOutputTimestampsFn[V]{ + ElementBag: state.MakeBagState[element[V]]("elementBag"), + TimerTimerstamp: state.MakeValueState[int64]("timerTimestamp"), + MinTimestampInBag: state.MakeCombiningState[int64, int64, int64]("minTimestampInBag", func(a, b int64) int64 { + if a < b { + return a + } + return b + }), + OutputState: timers.InProcessingTime("outputState"), + }, in) +} + +// [END timer_output_timestamps_good] + +// updateState exists for example purposes only +func updateState(sp, state, k, v any) {} + +// [START timer_garbage_collection] + +type timerGarbageCollectionFn[V any] struct { + State state.Value[V] // The state for the key. + MaxTimestampInBag state.Combining[int64, int64, int64] // The maximum element timestamp seen so far. + GcTimer timers.EventTime // The timestamp of the timer. +} + +func (fn *timerGarbageCollectionFn[V]) ProcessElement(et beam.EventTime, sp state.Provider, tp timers.Provider, key string, value V, emit func(beam.EventTime, string)) { + updateState(sp, fn.State, key, value) + fn.MaxTimestampInBag.Add(sp, et.Milliseconds()) + + // Set the timer to be one hour after the maximum timestamp seen. This will keep overwriting the same timer, so + // as long as there is activity on this key the state will stay active. Once the key goes inactive for one hour's + // worth of event time (as measured by the watermark), then the gc timer will fire. + maxTs, _, _ := fn.MaxTimestampInBag.Read(sp) + expirationTime := time.UnixMilli(maxTs).Add(1 * time.Hour) + fn.GcTimer.Set(tp, expirationTime) } +func (fn *timerGarbageCollectionFn[V]) OnTimer(sp state.Provider, tp timers.Provider, w beam.Window, key string, timer timers.Context, emit func(beam.EventTime, string)) { + switch timer.Family { + case fn.GcTimer.Family: + // Clear all the state for the key + fn.State.Clear(sp) + fn.MaxTimestampInBag.Clear(sp) + } +} + +func AddTimerGarbageCollection[V any](s beam.Scope, in beam.PCollection) beam.PCollection { + return beam.ParDo(s, &timerGarbageCollectionFn[V]{ + State: state.MakeValueState[V]("timerTimestamp"), + MaxTimestampInBag: state.MakeCombiningState[int64, int64, int64]("maxTimestampInBag", func(a, b int64) int64 { + if a > b { + return a + } + return b + }), + GcTimer: timers.InEventTime("gcTimer"), + }, in) +} + +// [END timer_garbage_collection] + +type Event struct{} + +func (*Event) isClick() bool { return false } + +// [START join_dofn_example] + +type JoinedEvent struct { + View, Click *Event +} + +type joinDoFn struct { + View state.Value[*Event] // Store the view event. + Click state.Value[*Event] // Store the click event. + + MaxTimestampSeen state.Combining[int64, int64, int64] // The maximum element timestamp seen so far. + GcTimer timers.EventTime // The timestamp of the timer. +} + +func (fn *joinDoFn) ProcessElement(et beam.EventTime, sp state.Provider, tp timers.Provider, key string, event *Event, emit func(JoinedEvent)) { + valueState := fn.View + if event.isClick() { + valueState = fn.Click + } + valueState.Write(sp, event) + + view, _, _ := fn.View.Read(sp) + click, _, _ := fn.Click.Read(sp) + if view != nil && click != nil { + emit(JoinedEvent{View: view, Click: click}) + fn.clearState(sp) + return + } + + fn.MaxTimestampSeen.Add(sp, et.Milliseconds()) + expTs, _, _ := fn.MaxTimestampSeen.Read(sp) + fn.GcTimer.Set(tp, time.UnixMilli(expTs).Add(1*time.Hour)) +} + +func (fn *joinDoFn) OnTimer(sp state.Provider, tp timers.Provider, w beam.Window, key string, timer timers.Context, emit func(beam.EventTime, string)) { + switch timer.Family { + case fn.GcTimer.Family: + fn.clearState(sp) + } +} + +func (fn *joinDoFn) clearState(sp state.Provider) { + fn.View.Clear(sp) + fn.Click.Clear(sp) + fn.MaxTimestampSeen.Clear(sp) +} + +func AddJoinDoFn(s beam.Scope, in beam.PCollection) beam.PCollection { + return beam.ParDo(s, &joinDoFn{ + View: state.MakeValueState[*Event]("view"), + Click: state.MakeValueState[*Event]("click"), + MaxTimestampSeen: state.MakeCombiningState[int64, int64, int64]("maxTimestampSeen", func(a, b int64) int64 { + if a > b { + return a + } + return b + }), + GcTimer: timers.InEventTime("gcTimer"), + }, in) +} + +// [END join_dofn_example] + +func sendRpc(...any) {} + +// [START batching_dofn_example] + +type bufferDoFn[V any] struct { + Elements state.Bag[V] // Store the elements buffered so far. + IsTimerSet state.Value[bool] // Keep track of whether a timer is currently set or not. + + OutputElements timers.ProcessingTime // The processing-time timer user to publish the RPC. +} + +func (fn *bufferDoFn[V]) ProcessElement(et beam.EventTime, sp state.Provider, tp timers.Provider, key string, value V) { + fn.Elements.Add(sp, value) + + isSet, _, _ := fn.IsTimerSet.Read(sp) + if !isSet { + fn.OutputElements.Set(tp, time.Now().Add(10*time.Second)) + fn.IsTimerSet.Write(sp, true) + } +} + +func (fn *bufferDoFn[V]) OnTimer(sp state.Provider, tp timers.Provider, w beam.Window, key string, timer timers.Context) { + switch timer.Family { + case fn.OutputElements.Family: + elements, _, _ := fn.Elements.Read(sp) + sendRpc(elements) + fn.Elements.Clear(sp) + fn.IsTimerSet.Clear(sp) + } +} + +func AddBufferDoFn[V any](s beam.Scope, in beam.PCollection) beam.PCollection { + return beam.ParDo(s, &bufferDoFn[V]{ + Elements: state.MakeBagState[V]("elements"), + IsTimerSet: state.MakeValueState[bool]("isTimerSet"), + + OutputElements: timers.InProcessingTime("outputElements"), + }, in) +} + +// [END batching_dofn_example] + type statefulDoFn struct { - s state.Value[int] + S state.Value[int] } func statefulPipeline() beam.PCollection { @@ -686,7 +1125,9 @@ func statefulPipeline() beam.PCollection { // [START windowed_state] - items := beam.ParDo(s, statefulDoFn{}, elements) + items := beam.ParDo(s, statefulDoFn{ + S: state.MakeValueState[int]("S"), + }, elements) out := beam.WindowInto(s, window.NewFixedWindows(24*time.Hour), items) // [END windowed_state] diff --git a/sdks/go/pkg/beam/core/runtime/exec/pardo.go b/sdks/go/pkg/beam/core/runtime/exec/pardo.go index 212ff53b6dd8..b93835264507 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/pardo.go +++ b/sdks/go/pkg/beam/core/runtime/exec/pardo.go @@ -552,5 +552,5 @@ func (n *ParDo) fail(err error) error { } func (n *ParDo) String() string { - return fmt.Sprintf("ParDo[%v] Out:%v Sig: %v", path.Base(n.Fn.Name()), IDs(n.Out...), n.Fn.ProcessElementFn().Fn.Type()) + return fmt.Sprintf("ParDo[%v] Out:%v Sig: %v, SideInputs: %v", path.Base(n.Fn.Name()), IDs(n.Out...), n.Fn.ProcessElementFn().Fn.Type(), n.Side) } diff --git a/sdks/go/pkg/beam/core/runtime/exec/sideinput.go b/sdks/go/pkg/beam/core/runtime/exec/sideinput.go index 1af4e71689b1..c3ceeee5d8b8 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/sideinput.go +++ b/sdks/go/pkg/beam/core/runtime/exec/sideinput.go @@ -140,7 +140,7 @@ func (s *sideInputAdapter) NewKeyedIterable(ctx context.Context, reader StateRea } func (s *sideInputAdapter) String() string { - return fmt.Sprintf("SideInputAdapter[%v, %v]", s.sid, s.sideInputID) + return fmt.Sprintf("SideInputAdapter[%v, %v] - Coder %v", s.sid, s.sideInputID, s.c) } // proxyReStream is a simple wrapper of an open function. diff --git a/sdks/go/pkg/beam/core/runtime/harness/harness.go b/sdks/go/pkg/beam/core/runtime/harness/harness.go index 3f0e82c8265f..d97b6b7db079 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/harness.go +++ b/sdks/go/pkg/beam/core/runtime/harness/harness.go @@ -393,7 +393,8 @@ func (c *control) handleInstruction(ctx context.Context, req *fnpb.InstructionRe c.mu.Unlock() if err != nil { - return fail(ctx, instID, "Failed: %v", err) + c.failed[instID] = err + return fail(ctx, instID, "ProcessBundle failed: %v", err) } tokens := msg.GetCacheTokens() @@ -427,6 +428,7 @@ func (c *control) handleInstruction(ctx context.Context, req *fnpb.InstructionRe // If there was an error on the data channel reads, fail this bundle // since we may have had a short read. c.failed[instID] = dataError + err = dataError } else { // Non failure plans should either be moved to the finalized state // or to plans so they can be re-used. diff --git a/sdks/go/pkg/beam/io/avroio/avroio.go b/sdks/go/pkg/beam/io/avroio/avroio.go index b282c4aa3047..b00c6d2eea00 100644 --- a/sdks/go/pkg/beam/io/avroio/avroio.go +++ b/sdks/go/pkg/beam/io/avroio/avroio.go @@ -25,13 +25,16 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam" "github.com/apache/beam/sdks/v2/go/pkg/beam/io/filesystem" "github.com/apache/beam/sdks/v2/go/pkg/beam/log" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/linkedin/goavro/v2" ) func init() { - beam.RegisterFunction(expandFn) - beam.RegisterType(reflect.TypeOf((*avroReadFn)(nil)).Elem()) - beam.RegisterType(reflect.TypeOf((*writeAvroFn)(nil)).Elem()) + register.Function3x1(expandFn) + register.DoFn3x1[context.Context, string, func(beam.X), error]((*avroReadFn)(nil)) + register.DoFn3x1[context.Context, int, func(*string) bool, error]((*writeAvroFn)(nil)) + register.Emitter1[beam.X]() + register.Iter1[string]() } // Read reads a set of files and returns lines as a PCollection diff --git a/sdks/go/pkg/beam/io/avroio/avroio_test.go b/sdks/go/pkg/beam/io/avroio/avroio_test.go index 8e2894133bfe..403a81875557 100644 --- a/sdks/go/pkg/beam/io/avroio/avroio_test.go +++ b/sdks/go/pkg/beam/io/avroio/avroio_test.go @@ -25,12 +25,30 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam" _ "github.com/apache/beam/sdks/v2/go/pkg/beam/io/filesystem/local" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" "github.com/linkedin/goavro/v2" ) +func TestMain(m *testing.M) { + ptest.Main(m) +} + +func init() { + beam.RegisterType(reflect.TypeOf((*Tweet)(nil)).Elem()) + beam.RegisterType(reflect.TypeOf((*NullableFloat64)(nil)).Elem()) + beam.RegisterType(reflect.TypeOf((*NullableString)(nil)).Elem()) + beam.RegisterType(reflect.TypeOf((*NullableTweet)(nil)).Elem()) + register.Function2x0(toJSONString) +} + +func toJSONString(user TwitterUser, emit func(string)) { + b, _ := json.Marshal(user) + emit(string(b)) +} + type Tweet struct { Stamp int64 `json:"timestamp"` Tweet string `json:"tweet"` @@ -126,16 +144,11 @@ func TestWrite(t *testing.T) { avroFile := "./user.avro" testUsername := "user1" testInfo := "userInfo" - p, s, sequence := ptest.CreateList([]string{testUsername}) - format := beam.ParDo(s, func(username string, emit func(string)) { - newUser := TwitterUser{ - User: username, - Info: testInfo, - } - - b, _ := json.Marshal(newUser) - emit(string(b)) - }, sequence) + p, s, sequence := ptest.CreateList([]TwitterUser{{ + User: testUsername, + Info: testInfo, + }}) + format := beam.ParDo(s, toJSONString, sequence) Write(s, avroFile, userSchema, format) t.Cleanup(func() { os.Remove(avroFile) diff --git a/sdks/go/pkg/beam/io/bigqueryio/bigquery.go b/sdks/go/pkg/beam/io/bigqueryio/bigquery.go index 4ca64be87800..12beacd4a016 100644 --- a/sdks/go/pkg/beam/io/bigqueryio/bigquery.go +++ b/sdks/go/pkg/beam/io/bigqueryio/bigquery.go @@ -337,10 +337,9 @@ func (f *writeFn) ProcessElement(ctx context.Context, _ int, iter func(*beam.X) } data = nil size = writeOverheadBytes - } else { - data = append(data, reflect.ValueOf(val.(any))) - size += current } + data = append(data, reflect.ValueOf(val.(any))) + size += current } if len(data) == 0 { return nil diff --git a/sdks/go/pkg/beam/io/datastoreio/datastore_test.go b/sdks/go/pkg/beam/io/datastoreio/datastore_test.go index a18891bfd14d..345eaa2a59ef 100644 --- a/sdks/go/pkg/beam/io/datastoreio/datastore_test.go +++ b/sdks/go/pkg/beam/io/datastoreio/datastore_test.go @@ -29,6 +29,17 @@ import ( "google.golang.org/api/option" ) +func TestMain(m *testing.M) { + // TODO(https://github.com/apache/beam/issues/27549): Make tests compatible with portable runners. + // To work on this change, replace call with `ptest.Main(m)` + ptest.MainWithDefault(m, "direct") +} + +func init() { + beam.RegisterType(reflect.TypeOf((*Foo)(nil)).Elem()) + beam.RegisterType(reflect.TypeOf((*Bar)(nil)).Elem()) +} + // fake client type implements datastoreio.clientType type fakeClient struct { runCounter int @@ -75,7 +86,7 @@ func Test_query(t *testing.T) { } itemType := reflect.TypeOf(tc.v) - itemKey := runtime.RegisterType(itemType) + itemKey, _ := runtime.TypeKey(itemType) p, s := beam.NewPipelineWithRoot() query(s, "project", "Item", tc.shard, itemType, itemKey, newClient) @@ -93,7 +104,12 @@ func Test_query(t *testing.T) { } } +// Baz is intentionally unregistered. +type Baz struct { +} + func Test_query_Bad(t *testing.T) { + fooKey, _ := runtime.TypeKey(reflect.TypeOf(Foo{})) testCases := []struct { v any itemType reflect.Type @@ -103,8 +119,8 @@ func Test_query_Bad(t *testing.T) { }{ // mismatch typeKey parameter { - Foo{}, - reflect.TypeOf(Foo{}), + Baz{}, + reflect.TypeOf(Baz{}), "MismatchType", "No type registered MismatchType", nil, @@ -113,7 +129,7 @@ func Test_query_Bad(t *testing.T) { { Foo{}, reflect.TypeOf(Foo{}), - runtime.RegisterType(reflect.TypeOf(Foo{})), + fooKey, "fake client error", errors.New("fake client error"), }, diff --git a/sdks/go/pkg/beam/io/spannerio/common.go b/sdks/go/pkg/beam/io/spannerio/common.go index 04cc2154a604..743a70d2fcff 100644 --- a/sdks/go/pkg/beam/io/spannerio/common.go +++ b/sdks/go/pkg/beam/io/spannerio/common.go @@ -18,9 +18,10 @@ package spannerio import ( - "cloud.google.com/go/spanner" "context" "fmt" + + "cloud.google.com/go/spanner" "google.golang.org/api/option" "google.golang.org/api/option/internaloption" "google.golang.org/grpc" @@ -28,9 +29,9 @@ import ( ) type spannerFn struct { - Database string `json:"database"` // Database is the spanner connection string - endpoint string // Override spanner endpoint in tests - client *spanner.Client // Spanner Client + Database string `json:"database"` // Database is the spanner connection string + TestEndpoint string // Optional endpoint override for local testing. Not required for production pipelines. + client *spanner.Client // Spanner Client } func newSpannerFn(db string) spannerFn { @@ -48,9 +49,9 @@ func (f *spannerFn) Setup(ctx context.Context) error { var opts []option.ClientOption // Append emulator options assuming endpoint is local (for testing). - if f.endpoint != "" { + if f.TestEndpoint != "" { opts = []option.ClientOption{ - option.WithEndpoint(f.endpoint), + option.WithEndpoint(f.TestEndpoint), option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())), option.WithoutAuthentication(), internaloption.SkipDialSettingsValidation(), diff --git a/sdks/go/pkg/beam/io/spannerio/read_test.go b/sdks/go/pkg/beam/io/spannerio/read_test.go index 1a7705b1aca2..7e1a65d0fe8a 100644 --- a/sdks/go/pkg/beam/io/spannerio/read_test.go +++ b/sdks/go/pkg/beam/io/spannerio/read_test.go @@ -27,6 +27,10 @@ import ( spannertest "github.com/apache/beam/sdks/v2/go/test/integration/io/spannerio" ) +func TestMain(m *testing.M) { + ptest.Main(m) +} + func TestRead(t *testing.T) { ctx := context.Background() @@ -102,7 +106,7 @@ func TestRead(t *testing.T) { p, s := beam.NewPipelineWithRoot() fn := newQueryFn(testCase.database, "SELECT * from "+testCase.table, reflect.TypeOf(TestDto{}), queryOptions{}) - fn.endpoint = srv.Addr + fn.TestEndpoint = srv.Addr imp := beam.Impulse(s) rows := beam.ParDo(s, fn, imp, beam.TypeDefinition{Var: beam.XType, T: reflect.TypeOf(TestDto{})}) diff --git a/sdks/go/pkg/beam/io/spannerio/write_test.go b/sdks/go/pkg/beam/io/spannerio/write_test.go index f273315ba119..3c2c1f591519 100644 --- a/sdks/go/pkg/beam/io/spannerio/write_test.go +++ b/sdks/go/pkg/beam/io/spannerio/write_test.go @@ -17,12 +17,12 @@ package spannerio import ( "context" - spannertest "github.com/apache/beam/sdks/v2/go/test/integration/io/spannerio" "testing" "cloud.google.com/go/spanner" "github.com/apache/beam/sdks/v2/go/pkg/beam" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" + spannertest "github.com/apache/beam/sdks/v2/go/test/integration/io/spannerio" "google.golang.org/api/iterator" ) @@ -77,7 +77,7 @@ func TestWrite(t *testing.T) { p, s, col := ptest.CreateList(testCase.rows) fn := newWriteFn(testCase.database, testCase.table, col.Type().Type()) - fn.endpoint = srv.Addr + fn.TestEndpoint = srv.Addr beam.ParDo0(s, fn, col) diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go index 89fececea108..95ad2e562d4c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go @@ -706,8 +706,15 @@ func (ss *stageState) bundleReady(em *ElementManager) (mtime.Time, bool) { } ready := true for _, side := range ss.sides { - pID := em.pcolParents[side] - parent := em.stages[pID] + pID, ok := em.pcolParents[side] + // These panics indicate pre-process/stage construction problems. + if !ok { + panic(fmt.Sprintf("stage[%v] no parent ID for side input %v", ss.ID, side)) + } + parent, ok := em.stages[pID] + if !ok { + panic(fmt.Sprintf("stage[%v] no parent for side input %v, with parent ID %v", ss.ID, side, pID)) + } ow := parent.OutputWatermark() if upstreamW > ow { ready = false diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go index ea7b09c84413..cd8ab7943ce5 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go @@ -160,6 +160,6 @@ func (j *Job) Done() { // Failed indicates that the job completed unsuccessfully. func (j *Job) Failed(err error) { - j.sendState(jobpb.JobState_FAILED) j.failureErr = err + j.sendState(jobpb.JobState_FAILED) } diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go index f65d2eb070f7..0c16b5eb34f4 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -186,9 +186,20 @@ func (s *Server) GetMessageStream(req *jobpb.JobMessagesRequest, stream jobpb.Jo for { for (curMsg >= job.maxMsg || len(job.msgs) == 0) && curState > job.stateIdx { switch state { - case jobpb.JobState_CANCELLED, jobpb.JobState_DONE, jobpb.JobState_DRAINED, jobpb.JobState_FAILED, jobpb.JobState_UPDATED: + case jobpb.JobState_CANCELLED, jobpb.JobState_DONE, jobpb.JobState_DRAINED, jobpb.JobState_UPDATED: // Reached terminal state. return nil + case jobpb.JobState_FAILED: + // Ensure we send an error message with the cause of the job failure. + stream.Send(&jobpb.JobMessagesResponse{ + Response: &jobpb.JobMessagesResponse_MessageResponse{ + MessageResponse: &jobpb.JobMessage{ + MessageText: job.failureErr.Error(), + Importance: jobpb.JobMessage_JOB_MESSAGE_ERROR, + }, + }, + }) + return nil } job.streamCond.Wait() select { // Quit out if the external connection is done. diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index 7dbf8cf87e77..44f9c1e9d281 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -116,7 +116,11 @@ progress: progTick.Stop() break progress // exit progress loop on close. case <-progTick.C: - resp := b.Progress(wk) + resp, err := b.Progress(wk) + if err != nil { + slog.Debug("SDK Error from progress, aborting progress", "bundle", rb, "error", err.Error()) + break progress + } index, unknownIDs := j.ContributeTentativeMetrics(resp) if len(unknownIDs) > 0 { md := wk.MonitoringMetadata(unknownIDs) @@ -125,7 +129,11 @@ progress: slog.Debug("progress report", "bundle", rb, "index", index) // Progress for the bundle hasn't advanced. Try splitting. if previousIndex == index && !splitsDone { - sr := b.Split(wk, 0.5 /* fraction of remainder */, nil /* allowed splits */) + sr, err := b.Split(wk, 0.5 /* fraction of remainder */, nil /* allowed splits */) + if err != nil { + slog.Debug("SDK Error from split, aborting splits", "bundle", rb, "error", err.Error()) + break progress + } if sr.GetChannelSplits() == nil { slog.Warn("split failed", "bundle", rb) splitsDone = true @@ -164,8 +172,8 @@ progress: // Bundle has failed, fail the job. // TODO add retries & clean up this logic. Channels are closed by the "runner" transforms. if !ok && b.Error != "" { - slog.Error("job failed", "error", b.Error, "bundle", rb, "job", j) - j.Failed(fmt.Errorf("bundle failed: %v", b.Error)) + slog.Error("job failed", "bundle", rb, "job", j) + j.Failed(fmt.Errorf("%v", b.Error)) return } @@ -245,31 +253,54 @@ func buildStage(s *stage, tid string, t *pipepb.PTransform, comps *pipepb.Compon } var inputInfo engine.PColInfo var sides []string + localIdReplacements := map[string]string{} + globalIDReplacements := map[string]string{} for local, global := range t.GetInputs() { + if _, ok := sis[local]; ok { + col := comps.GetPcollections()[global] + oCID := col.GetCoderId() + nCID := lpUnknownCoders(oCID, coders, comps.GetCoders()) + + sides = append(sides, global) + if oCID != nCID { + // Add a synthetic PCollection set with the new coder. + newGlobal := global + "_prismside" + comps.GetPcollections()[newGlobal] = &pipepb.PCollection{ + DisplayData: col.GetDisplayData(), + UniqueName: col.GetUniqueName(), + CoderId: nCID, + IsBounded: col.GetIsBounded(), + WindowingStrategyId: col.WindowingStrategyId, + } + localIdReplacements[local] = newGlobal + globalIDReplacements[newGlobal] = global + } + continue + } // This id is directly used for the source, but this also copies // coders used by side inputs to the coders map for the bundle, so // needs to be run for every ID. wInCid := makeWindowedValueCoder(global, comps, coders) - _, ok := sis[local] - if ok { - sides = append(sides, global) - } else { - // this is the main input - transforms[s.inputTransformID] = sourceTransform(s.inputTransformID, portFor(wInCid, wk), global) - col := comps.GetPcollections()[global] - ed := collectionPullDecoder(col.GetCoderId(), coders, comps) - wDec, wEnc := getWindowValueCoders(comps, col, coders) - inputInfo = engine.PColInfo{ - GlobalID: global, - WDec: wDec, - WEnc: wEnc, - EDec: ed, - } + + // this is the main input + transforms[s.inputTransformID] = sourceTransform(s.inputTransformID, portFor(wInCid, wk), global) + col := comps.GetPcollections()[global] + ed := collectionPullDecoder(col.GetCoderId(), coders, comps) + wDec, wEnc := getWindowValueCoders(comps, col, coders) + inputInfo = engine.PColInfo{ + GlobalID: global, + WDec: wDec, + WEnc: wEnc, + EDec: ed, } // We need to process all inputs to ensure we have all input coders, so we must continue. } + // Update side inputs to point to new PCollection with any replaced coders. + for l, g := range localIdReplacements { + t.GetInputs()[l] = g + } - prepareSides, err := handleSideInputs(t, comps, coders, wk) + prepareSides, err := handleSideInputs(t, comps, coders, wk, globalIDReplacements) if err != nil { slog.Error("buildStage: handleSideInputs", err, slog.String("transformID", tid)) panic(err) @@ -322,7 +353,7 @@ func buildStage(s *stage, tid string, t *pipepb.PTransform, comps *pipepb.Compon } // handleSideInputs ensures appropriate coders are available to the bundle, and prepares a function to stage the data. -func handleSideInputs(t *pipepb.PTransform, comps *pipepb.Components, coders map[string]*pipepb.Coder, wk *worker.W) (func(b *worker.B, tid string, watermark mtime.Time), error) { +func handleSideInputs(t *pipepb.PTransform, comps *pipepb.Components, coders map[string]*pipepb.Coder, wk *worker.W, replacements map[string]string) (func(b *worker.B, tid string, watermark mtime.Time), error) { sis, err := getSideInputs(t) if err != nil { return nil, err @@ -335,6 +366,11 @@ func handleSideInputs(t *pipepb.PTransform, comps *pipepb.Components, coders map if !ok { continue // This is the main input. } + // Use the old global ID as the identifier for the data storage + // This matches what we do in the rest of the stage layer. + if oldGlobal, ok := replacements[global]; ok { + global = oldGlobal + } // this is a side input switch si.GetAccessPattern().GetUrn() { diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go index 58cc813d7108..30515fa6f6e8 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go @@ -16,6 +16,7 @@ package worker import ( + "fmt" "sync/atomic" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" @@ -144,18 +145,22 @@ func (b *B) Cleanup(wk *W) { wk.mu.Unlock() } -func (b *B) Progress(wk *W) *fnpb.ProcessBundleProgressResponse { - return wk.sendInstruction(&fnpb.InstructionRequest{ +func (b *B) Progress(wk *W) (*fnpb.ProcessBundleProgressResponse, error) { + resp := wk.sendInstruction(&fnpb.InstructionRequest{ Request: &fnpb.InstructionRequest_ProcessBundleProgress{ ProcessBundleProgress: &fnpb.ProcessBundleProgressRequest{ InstructionId: b.InstID, }, }, - }).GetProcessBundleProgress() + }) + if resp.GetError() != "" { + return nil, fmt.Errorf("progress[%v] error from SDK: %v", b.InstID, resp.GetError()) + } + return resp.GetProcessBundleProgress(), nil } -func (b *B) Split(wk *W, fraction float64, allowedSplits []int64) *fnpb.ProcessBundleSplitResponse { - return wk.sendInstruction(&fnpb.InstructionRequest{ +func (b *B) Split(wk *W, fraction float64, allowedSplits []int64) (*fnpb.ProcessBundleSplitResponse, error) { + resp := wk.sendInstruction(&fnpb.InstructionRequest{ Request: &fnpb.InstructionRequest_ProcessBundleSplit{ ProcessBundleSplit: &fnpb.ProcessBundleSplitRequest{ InstructionId: b.InstID, @@ -168,5 +173,9 @@ func (b *B) Split(wk *W, fraction float64, allowedSplits []int64) *fnpb.ProcessB }, }, }, - }).GetProcessBundleSplit() + }) + if resp.GetError() != "" { + return nil, fmt.Errorf("split[%v] error from SDK: %v", b.InstID, resp.GetError()) + } + return resp.GetProcessBundleSplit(), nil } diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index 9767dec068fe..80bdadc51626 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -256,11 +256,7 @@ func (wk *W) Control(ctrl fnpb.BeamFnControl_ControlServer) error { // TODO: Do more than assume these are ProcessBundleResponses. wk.mu.Lock() if b, ok := wk.activeInstructions[resp.GetInstructionId()]; ok { - // TODO. Better pipeline error handling. - if resp.Error != "" { - slog.LogAttrs(ctrl.Context(), slog.LevelError, "ctrl.Recv pipeline error", - slog.String("error", resp.GetError())) - } + // Error is handled in the resonse handler. b.Respond(resp) } else { slog.Debug("ctrl.Recv: %v", resp) @@ -327,12 +323,20 @@ func (wk *W) Data(data fnpb.BeamFnData_DataServer) error { } }() - for req := range wk.DataReqs { - if err := data.Send(req); err != nil { - slog.LogAttrs(context.TODO(), slog.LevelDebug, "data.Send error", slog.Any("error", err)) + for { + select { + case req, ok := <-wk.DataReqs: + if !ok { + return nil + } + if err := data.Send(req); err != nil { + slog.LogAttrs(context.TODO(), slog.LevelDebug, "data.Send error", slog.Any("error", err)) + } + case <-data.Context().Done(): + slog.Debug("Data context canceled") + return data.Context().Err() } } - return nil } // State relays elements and timer bytes to SDKs and back again, coordinated via diff --git a/sdks/go/pkg/beam/runners/universal/runnerlib/execute.go b/sdks/go/pkg/beam/runners/universal/runnerlib/execute.go index 68db9b0ee76a..5b49d2f94739 100644 --- a/sdks/go/pkg/beam/runners/universal/runnerlib/execute.go +++ b/sdks/go/pkg/beam/runners/universal/runnerlib/execute.go @@ -41,7 +41,7 @@ func Execute(ctx context.Context, p *pipepb.Pipeline, endpoint string, opt *JobO presult := &universalPipelineResult{} bin := opt.Worker - if bin == "" { + if bin == "" && !opt.Loopback { if self, ok := IsWorkerCompatibleBinary(); ok { bin = self log.Infof(ctx, "Using running binary as worker binary: '%v'", bin) @@ -56,6 +56,11 @@ func Execute(ctx context.Context, p *pipepb.Pipeline, endpoint string, opt *JobO bin = worker } + } else if opt.Loopback { + // TODO(https://github.com/apache/beam/issues/27569: determine the canonical location for Beam temp files. + // In loopback mode, the binary is unused, so we can avoid an unnecessary compile step. + f, _ := os.CreateTemp(os.TempDir(), "beamloopbackworker-*") + bin = f.Name() } else { log.Infof(ctx, "Using specified worker binary: '%v'", bin) } diff --git a/sdks/go/pkg/beam/runners/universal/runnerlib/job.go b/sdks/go/pkg/beam/runners/universal/runnerlib/job.go index daa6896da406..5752b33892bb 100644 --- a/sdks/go/pkg/beam/runners/universal/runnerlib/job.go +++ b/sdks/go/pkg/beam/runners/universal/runnerlib/job.go @@ -39,15 +39,17 @@ type JobOptions struct { // Experiments are additional experiments. Experiments []string - // TODO(herohde) 3/17/2018: add further parametrization as needed - // Worker is the worker binary override. Worker string - // RetainDocker is an option to pass to the runner. + // RetainDocker is an option to pass to the runner indicating the docker containers should be cached. RetainDocker bool + // Indicates a limit on parallelism the runner should impose. Parallelism int + + // Loopback indicates this job is running in loopback mode and will reconnect to the local process. + Loopback bool } // Prepare prepares a job to the given job service. It returns the preparation id @@ -101,10 +103,17 @@ func WaitForCompletion(ctx context.Context, client jobpb.JobServiceClient, jobID return errors.Wrap(err, "failed to get job stream") } + mostRecentError := errors.New("") + var errReceived, jobFailed bool + for { msg, err := stream.Recv() if err != nil { if err == io.EOF { + if jobFailed { + // Connection finished with a failed status, so produce what we have. + return errors.Errorf("job %v failed:\n%w", jobID, mostRecentError) + } return nil } return err @@ -120,7 +129,11 @@ func WaitForCompletion(ctx context.Context, client jobpb.JobServiceClient, jobID case jobpb.JobState_DONE, jobpb.JobState_CANCELLED: return nil case jobpb.JobState_FAILED: - return errors.Errorf("job %v failed", jobID) + jobFailed = true + if errReceived { + return errors.Errorf("job %v failed:\n%w", jobID, mostRecentError) + } + // Otherwise, wait for at least one error log from the runner, or the connection to close. } case msg.GetMessageResponse() != nil: @@ -129,6 +142,15 @@ func WaitForCompletion(ctx context.Context, client jobpb.JobServiceClient, jobID text := fmt.Sprintf("%v (%v): %v", resp.GetTime(), resp.GetMessageId(), resp.GetMessageText()) log.Output(ctx, messageSeverity(resp.GetImportance()), 1, text) + if resp.GetImportance() >= jobpb.JobMessage_JOB_MESSAGE_ERROR { + errReceived = true + mostRecentError = errors.New(resp.GetMessageText()) + + if jobFailed { + return errors.Errorf("job %v failed:\n%w", jobID, mostRecentError) + } + } + default: return errors.Errorf("unexpected job update: %v", proto.MarshalTextString(msg)) } diff --git a/sdks/go/pkg/beam/runners/universal/universal.go b/sdks/go/pkg/beam/runners/universal/universal.go index 299a64acdd69..8af9e91e1e15 100644 --- a/sdks/go/pkg/beam/runners/universal/universal.go +++ b/sdks/go/pkg/beam/runners/universal/universal.go @@ -101,6 +101,7 @@ func Execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error) Worker: *jobopts.WorkerBinary, RetainDocker: *jobopts.RetainDockerContainers, Parallelism: *jobopts.Parallelism, + Loopback: jobopts.IsLoopback(), } return runnerlib.Execute(ctx, pipeline, endpoint, opt, *jobopts.Async) } diff --git a/sdks/go/pkg/beam/testing/passert/equals_test.go b/sdks/go/pkg/beam/testing/passert/equals_test.go index b0ddeae8d6f7..a8a5c835f8ff 100644 --- a/sdks/go/pkg/beam/testing/passert/equals_test.go +++ b/sdks/go/pkg/beam/testing/passert/equals_test.go @@ -182,10 +182,12 @@ func ExampleEqualsList_mismatch() { EqualsList(s, col, list) err := ptest.Run(p) err = unwrapError(err) - fmt.Println(err) + + // Process error for cleaner example output, demonstrating the diff. + processedErr := strings.SplitAfter(err.Error(), "/passert.failIfBadEntries] failed:") + fmt.Println(processedErr[1]) // Output: - // DoFn[UID:1, PID:passert.failIfBadEntries, Name: github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert.failIfBadEntries] failed: // actual PCollection does not match expected values // ========= // 2 correct entries (present in both) diff --git a/sdks/go/pkg/beam/testing/passert/floats.go b/sdks/go/pkg/beam/testing/passert/floats.go index 727c313820b7..f71e55090838 100644 --- a/sdks/go/pkg/beam/testing/passert/floats.go +++ b/sdks/go/pkg/beam/testing/passert/floats.go @@ -24,8 +24,16 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" ) +func init() { + register.DoFn2x1[[]byte, func(*beam.T) bool, error]((*boundsFn)(nil)) + register.DoFn3x1[[]byte, func(*beam.T) bool, func(*beam.T) bool, error]((*thresholdFn)(nil)) + register.Emitter1[beam.T]() + register.Iter1[beam.T]() +} + // EqualsFloat calls into TryEqualsFloat, checkong that two PCollections of non-complex // numeric types are equal, with each element being within a provided threshold of an // expected value. Panics if TryEqualsFloat returns an error. @@ -110,11 +118,11 @@ func AllWithinBounds(s beam.Scope, col beam.PCollection, lo, hi float64) { lo, hi = hi, lo } s = s.Scope(fmt.Sprintf("passert.AllWithinBounds([%v, %v])", lo, hi)) - beam.ParDo0(s, &boundsFn{lo: lo, hi: hi}, beam.Impulse(s), beam.SideInput{Input: col}) + beam.ParDo0(s, &boundsFn{Lo: lo, Hi: hi}, beam.Impulse(s), beam.SideInput{Input: col}) } type boundsFn struct { - lo, hi float64 + Lo, Hi float64 } func (f *boundsFn) ProcessElement(_ []byte, col func(*beam.T) bool) error { @@ -122,9 +130,9 @@ func (f *boundsFn) ProcessElement(_ []byte, col func(*beam.T) bool) error { var input beam.T for col(&input) { val := toFloat(input) - if val < f.lo { + if val < f.Lo { tooLow = append(tooLow, val) - } else if val > f.hi { + } else if val > f.Hi { tooHigh = append(tooHigh, val) } } @@ -134,11 +142,11 @@ func (f *boundsFn) ProcessElement(_ []byte, col func(*beam.T) bool) error { errorStrings := []string{} if len(tooLow) != 0 { sort.Float64s(tooLow) - errorStrings = append(errorStrings, fmt.Sprintf("values below minimum value %v: %v", f.lo, tooLow)) + errorStrings = append(errorStrings, fmt.Sprintf("values below minimum value %v: %v", f.Lo, tooLow)) } if len(tooHigh) != 0 { sort.Float64s(tooHigh) - errorStrings = append(errorStrings, fmt.Sprintf("values above maximum value %v: %v", f.hi, tooHigh)) + errorStrings = append(errorStrings, fmt.Sprintf("values above maximum value %v: %v", f.Hi, tooHigh)) } return errors.New(strings.Join(errorStrings, "\n")) } diff --git a/sdks/go/pkg/beam/testing/passert/passert.go b/sdks/go/pkg/beam/testing/passert/passert.go index 990d3c8c4d47..c4b0f490dafd 100644 --- a/sdks/go/pkg/beam/testing/passert/passert.go +++ b/sdks/go/pkg/beam/testing/passert/passert.go @@ -39,9 +39,13 @@ import ( func Diff(s beam.Scope, a, b beam.PCollection) (left, both, right beam.PCollection) { imp := beam.Impulse(s) - t := beam.ValidateNonCompositeType(a) - beam.ValidateNonCompositeType(b) - return beam.ParDo3(s, &diffFn{Type: beam.EncodedType{T: t.Type()}}, imp, beam.SideInput{Input: a}, beam.SideInput{Input: b}) + ta := beam.ValidateNonCompositeType(a) + tb := beam.ValidateNonCompositeType(b) + + if !typex.IsEqual(ta, tb) { + panic(fmt.Sprintf("passert.Diff input PColections don't have matching types: %v != %v", ta, tb)) + } + return beam.ParDo3(s, &diffFn{Type: beam.EncodedType{T: ta.Type()}}, imp, beam.SideInput{Input: a}, beam.SideInput{Input: b}) } // diffFn computes the symmetrical multi-set difference of 2 collections, under diff --git a/sdks/go/pkg/beam/testing/passert/passert_test.go b/sdks/go/pkg/beam/testing/passert/passert_test.go index 9524bc868ebb..d472f6883939 100644 --- a/sdks/go/pkg/beam/testing/passert/passert_test.go +++ b/sdks/go/pkg/beam/testing/passert/passert_test.go @@ -20,15 +20,30 @@ import ( "testing" "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" ) +func TestMain(m *testing.M) { + ptest.Main(m) +} + +func isA(input string) bool { return input == "a" } +func isB(input string) bool { return input == "b" } +func lessThan13(input int) bool { return input < 13 } +func greaterThan13(input int) bool { return input > 13 } + +func init() { + register.Function1x1(isA) + register.Function1x1(isB) + register.Function1x1(lessThan13) + register.Function1x1(greaterThan13) +} + func TestTrue_string(t *testing.T) { p, s := beam.NewPipelineWithRoot() col := beam.Create(s, "a", "a", "a") - True(s, col, func(input string) bool { - return input == "a" - }) + True(s, col, isA) if err := ptest.Run(p); err != nil { t.Errorf("Pipeline failed: %v", err) } @@ -37,9 +52,7 @@ func TestTrue_string(t *testing.T) { func TestTrue_numeric(t *testing.T) { p, s := beam.NewPipelineWithRoot() col := beam.Create(s, 3, 3, 6) - True(s, col, func(input int) bool { - return input < 13 - }) + True(s, col, lessThan13) if err := ptest.Run(p); err != nil { t.Errorf("Pipeline failed: %v", err) } @@ -48,9 +61,7 @@ func TestTrue_numeric(t *testing.T) { func TestTrue_bad(t *testing.T) { p, s := beam.NewPipelineWithRoot() col := beam.Create(s, "a", "a", "b") - True(s, col, func(input string) bool { - return input == "a" - }) + True(s, col, isA) err := ptest.Run(p) if err == nil { t.Fatalf("Pipeline succeeded when it should haved failed, got %v", err) @@ -63,9 +74,7 @@ func TestTrue_bad(t *testing.T) { func TestFalse_string(t *testing.T) { p, s := beam.NewPipelineWithRoot() col := beam.Create(s, "a", "a", "a") - False(s, col, func(input string) bool { - return input == "b" - }) + False(s, col, isB) if err := ptest.Run(p); err != nil { t.Errorf("Pipeline failed: %v", err) } @@ -74,9 +83,7 @@ func TestFalse_string(t *testing.T) { func TestFalse_numeric(t *testing.T) { p, s := beam.NewPipelineWithRoot() col := beam.Create(s, 3, 3, 6) - False(s, col, func(input int) bool { - return input > 13 - }) + False(s, col, greaterThan13) if err := ptest.Run(p); err != nil { t.Errorf("Pipeline failed: %v", err) } @@ -85,9 +92,7 @@ func TestFalse_numeric(t *testing.T) { func TestFalse_bad(t *testing.T) { p, s := beam.NewPipelineWithRoot() col := beam.Create(s, "a", "a", "b") - False(s, col, func(input string) bool { - return input == "b" - }) + False(s, col, isB) err := ptest.Run(p) if err == nil { t.Fatalf("Pipeline succeeded when it should haved failed, got %v", err) diff --git a/sdks/go/pkg/beam/transforms/filter/filter.go b/sdks/go/pkg/beam/transforms/filter/filter.go index 913e7355c30d..997eec5eb4ef 100644 --- a/sdks/go/pkg/beam/transforms/filter/filter.go +++ b/sdks/go/pkg/beam/transforms/filter/filter.go @@ -21,11 +21,15 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/funcx" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" ) -//go:generate go install github.com/apache/beam/sdks/v2/go/cmd/starcgen -//go:generate starcgen --package=filter --identifiers=filterFn,mapFn,mergeFn -//go:generate go fmt +func init() { + register.DoFn2x0[beam.T, func(beam.T)]((*filterFn)(nil)) + register.Function1x2(mapFn) + register.Function2x1(mergeFn) + register.Emitter1[beam.T]() +} var ( sig = funcx.MakePredicate(beam.TType) // T -> bool diff --git a/sdks/go/pkg/beam/transforms/filter/filter.shims.go b/sdks/go/pkg/beam/transforms/filter/filter.shims.go deleted file mode 100644 index b0d18233ab18..000000000000 --- a/sdks/go/pkg/beam/transforms/filter/filter.shims.go +++ /dev/null @@ -1,201 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by starcgen. DO NOT EDIT. -// File: filter.shims.go - -package filter - -import ( - "context" - "reflect" - - // Library imports - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx/schema" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" -) - -func init() { - runtime.RegisterFunction(mapFn) - runtime.RegisterFunction(mergeFn) - runtime.RegisterType(reflect.TypeOf((*filterFn)(nil)).Elem()) - schema.RegisterType(reflect.TypeOf((*filterFn)(nil)).Elem()) - reflectx.RegisterStructWrapper(reflect.TypeOf((*filterFn)(nil)).Elem(), wrapMakerFilterFn) - reflectx.RegisterFunc(reflect.TypeOf((*func(int, int) int)(nil)).Elem(), funcMakerIntIntГInt) - reflectx.RegisterFunc(reflect.TypeOf((*func(typex.T, func(typex.T)))(nil)).Elem(), funcMakerTypex۰TEmitTypex۰TГ) - reflectx.RegisterFunc(reflect.TypeOf((*func(typex.T) (typex.T, int))(nil)).Elem(), funcMakerTypex۰TГTypex۰TInt) - reflectx.RegisterFunc(reflect.TypeOf((*func())(nil)).Elem(), funcMakerГ) - exec.RegisterEmitter(reflect.TypeOf((*func(typex.T))(nil)).Elem(), emitMakerTypex۰T) -} - -func wrapMakerFilterFn(fn any) map[string]reflectx.Func { - dfn := fn.(*filterFn) - return map[string]reflectx.Func{ - "ProcessElement": reflectx.MakeFunc(func(a0 typex.T, a1 func(typex.T)) { dfn.ProcessElement(a0, a1) }), - "Setup": reflectx.MakeFunc(func() { dfn.Setup() }), - } -} - -type callerIntIntГInt struct { - fn func(int, int) int -} - -func funcMakerIntIntГInt(fn any) reflectx.Func { - f := fn.(func(int, int) int) - return &callerIntIntГInt{fn: f} -} - -func (c *callerIntIntГInt) Name() string { - return reflectx.FunctionName(c.fn) -} - -func (c *callerIntIntГInt) Type() reflect.Type { - return reflect.TypeOf(c.fn) -} - -func (c *callerIntIntГInt) Call(args []any) []any { - out0 := c.fn(args[0].(int), args[1].(int)) - return []any{out0} -} - -func (c *callerIntIntГInt) Call2x1(arg0, arg1 any) any { - return c.fn(arg0.(int), arg1.(int)) -} - -type callerTypex۰TEmitTypex۰TГ struct { - fn func(typex.T, func(typex.T)) -} - -func funcMakerTypex۰TEmitTypex۰TГ(fn any) reflectx.Func { - f := fn.(func(typex.T, func(typex.T))) - return &callerTypex۰TEmitTypex۰TГ{fn: f} -} - -func (c *callerTypex۰TEmitTypex۰TГ) Name() string { - return reflectx.FunctionName(c.fn) -} - -func (c *callerTypex۰TEmitTypex۰TГ) Type() reflect.Type { - return reflect.TypeOf(c.fn) -} - -func (c *callerTypex۰TEmitTypex۰TГ) Call(args []any) []any { - c.fn(args[0].(typex.T), args[1].(func(typex.T))) - return []any{} -} - -func (c *callerTypex۰TEmitTypex۰TГ) Call2x0(arg0, arg1 any) { - c.fn(arg0.(typex.T), arg1.(func(typex.T))) -} - -type callerTypex۰TГTypex۰TInt struct { - fn func(typex.T) (typex.T, int) -} - -func funcMakerTypex۰TГTypex۰TInt(fn any) reflectx.Func { - f := fn.(func(typex.T) (typex.T, int)) - return &callerTypex۰TГTypex۰TInt{fn: f} -} - -func (c *callerTypex۰TГTypex۰TInt) Name() string { - return reflectx.FunctionName(c.fn) -} - -func (c *callerTypex۰TГTypex۰TInt) Type() reflect.Type { - return reflect.TypeOf(c.fn) -} - -func (c *callerTypex۰TГTypex۰TInt) Call(args []any) []any { - out0, out1 := c.fn(args[0].(typex.T)) - return []any{out0, out1} -} - -func (c *callerTypex۰TГTypex۰TInt) Call1x2(arg0 any) (any, any) { - return c.fn(arg0.(typex.T)) -} - -type callerГ struct { - fn func() -} - -func funcMakerГ(fn any) reflectx.Func { - f := fn.(func()) - return &callerГ{fn: f} -} - -func (c *callerГ) Name() string { - return reflectx.FunctionName(c.fn) -} - -func (c *callerГ) Type() reflect.Type { - return reflect.TypeOf(c.fn) -} - -func (c *callerГ) Call(args []any) []any { - c.fn() - return []any{} -} - -func (c *callerГ) Call0x0() { - c.fn() -} - -type emitNative struct { - n exec.ElementProcessor - fn any - est *sdf.WatermarkEstimator - - ctx context.Context - ws []typex.Window - et typex.EventTime - value exec.FullValue -} - -func (e *emitNative) Init(ctx context.Context, ws []typex.Window, et typex.EventTime) error { - e.ctx = ctx - e.ws = ws - e.et = et - return nil -} - -func (e *emitNative) Value() any { - return e.fn -} - -func (e *emitNative) AttachEstimator(est *sdf.WatermarkEstimator) { - e.est = est -} - -func emitMakerTypex۰T(n exec.ElementProcessor) exec.ReusableEmitter { - ret := &emitNative{n: n} - ret.fn = ret.invokeTypex۰T - return ret -} - -func (e *emitNative) invokeTypex۰T(val typex.T) { - e.value = exec.FullValue{Windows: e.ws, Timestamp: e.et, Elm: val} - if e.est != nil { - (*e.est).(sdf.TimestampObservingEstimator).ObserveTimestamp(e.et.ToTime()) - } - if err := e.n.ProcessElement(e.ctx, &e.value); err != nil { - panic(err) - } -} - -// DO NOT MODIFY: GENERATED CODE diff --git a/sdks/go/pkg/beam/transforms/filter/filter_test.go b/sdks/go/pkg/beam/transforms/filter/filter_test.go index 9cc5a526af9c..96b4cbe12d79 100644 --- a/sdks/go/pkg/beam/transforms/filter/filter_test.go +++ b/sdks/go/pkg/beam/transforms/filter/filter_test.go @@ -18,11 +18,28 @@ package filter_test import ( "testing" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/filter" ) +func TestMain(m *testing.M) { + ptest.Main(m) +} + +func init() { + register.Function1x1(alwaysTrue) + register.Function1x1(alwaysFalse) + register.Function1x1(isOne) + register.Function1x1(greaterThanOne) +} + +func alwaysTrue(a int) bool { return true } +func alwaysFalse(a int) bool { return false } +func isOne(a int) bool { return a == 1 } +func greaterThanOne(a int) bool { return a > 1 } + func TestInclude(t *testing.T) { tests := []struct { in []int diff --git a/sdks/go/pkg/beam/transforms/stats/count_test.go b/sdks/go/pkg/beam/transforms/stats/count_test.go index 23627a92f799..be6ce950e20a 100644 --- a/sdks/go/pkg/beam/transforms/stats/count_test.go +++ b/sdks/go/pkg/beam/transforms/stats/count_test.go @@ -20,10 +20,19 @@ import ( "testing" "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" ) +func TestMain(m *testing.M) { + ptest.Main(m) +} + +func init() { + register.Function2x1(kvToCount) +} + type count struct { Elm int Count int diff --git a/sdks/go/pkg/beam/transforms/stats/max_test.go b/sdks/go/pkg/beam/transforms/stats/max_test.go index af817527dc91..531792e70f58 100644 --- a/sdks/go/pkg/beam/transforms/stats/max_test.go +++ b/sdks/go/pkg/beam/transforms/stats/max_test.go @@ -19,10 +19,16 @@ import ( "testing" "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" ) +func init() { + register.Function2x1(kvToStudent) + register.Function1x2(studentToKV) +} + type student struct { Name string Grade float64 diff --git a/sdks/go/pkg/beam/transforms/stats/quantiles.go b/sdks/go/pkg/beam/transforms/stats/quantiles.go index 79a66b58e1f0..6d2baa8b5e99 100644 --- a/sdks/go/pkg/beam/transforms/stats/quantiles.go +++ b/sdks/go/pkg/beam/transforms/stats/quantiles.go @@ -31,6 +31,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" ) func init() { @@ -44,6 +45,9 @@ func init() { beam.RegisterType(reflect.TypeOf((*shardElementsFn)(nil)).Elem()) beam.RegisterCoder(compactorsType, encodeCompactors, decodeCompactors) beam.RegisterCoder(weightedElementType, encodeWeightedElement, decodeWeightedElement) + + register.Function1x2(fixedKey) + register.Function2x1(makeWeightedElement) } // Opts contains settings used to configure how approximate quantiles are computed. @@ -663,12 +667,14 @@ func makeWeightedElement(weight int, element beam.T) weightedElement { return weightedElement{weight: weight, element: element} } +func fixedKey(e beam.T) (int, beam.T) { return 1, e } + // ApproximateQuantiles computes approximate quantiles for the input PCollection. // // The output PCollection contains a single element: a list of numQuantiles - 1 elements approximately splitting up the input collection into numQuantiles separate quantiles. // For example, if numQuantiles = 2, the returned list would contain a single element such that approximately half of the input would be less than that element and half would be greater. func ApproximateQuantiles(s beam.Scope, pc beam.PCollection, less any, opts Opts) beam.PCollection { - return ApproximateWeightedQuantiles(s, beam.ParDo(s, func(e beam.T) (int, beam.T) { return 1, e }, pc), less, opts) + return ApproximateWeightedQuantiles(s, beam.ParDo(s, fixedKey, pc), less, opts) } // reduce takes a PCollection and returns a PCollection<*compactors>. The output PCollection may have at most shardSizes[len(shardSizes) - 1] compactors. diff --git a/sdks/go/pkg/beam/transforms/stats/quantiles_test.go b/sdks/go/pkg/beam/transforms/stats/quantiles_test.go index c03620d0b9b7..1e389eed128b 100644 --- a/sdks/go/pkg/beam/transforms/stats/quantiles_test.go +++ b/sdks/go/pkg/beam/transforms/stats/quantiles_test.go @@ -16,46 +16,19 @@ package stats import ( - "reflect" "testing" "github.com/apache/beam/sdks/v2/go/pkg/beam" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" "github.com/google/go-cmp/cmp" ) func init() { - beam.RegisterFunction(weightedElementToKv) - - // In practice, this runs faster than plain reflection. - // TODO(https://github.com/apache/beam/issues/20271): Remove once collisions don't occur for starcgen over test code and an equivalent is generated for us. - reflectx.RegisterFunc(reflect.ValueOf(less).Type(), func(_ any) reflectx.Func { - return newIntLess() - }) -} - -type intLess struct { - name string - t reflect.Type -} - -func newIntLess() *intLess { - return &intLess{ - name: reflectx.FunctionName(reflect.ValueOf(less).Interface()), - t: reflect.ValueOf(less).Type(), - } -} - -func (i *intLess) Name() string { - return i.name -} -func (i *intLess) Type() reflect.Type { - return i.t -} -func (i *intLess) Call(args []any) []any { - return []any{args[0].(int) < args[1].(int)} + register.Function1x2(weightedElementToKv) + register.Function2x1(less) } func less(a, b int) bool { @@ -68,7 +41,7 @@ func TestLargeQuantiles(t *testing.T) { for i := 0; i < numElements; i++ { inputSlice = append(inputSlice, i) } - p, s, input, expected := ptest.CreateList2(inputSlice, [][]int{[]int{10006, 19973}}) + p, s, input, expected := ptest.CreateList2(inputSlice, [][]int{{10006, 19973}}) quantiles := ApproximateQuantiles(s, input, less, Opts{ K: 200, NumQuantiles: 3, @@ -85,7 +58,7 @@ func TestLargeQuantilesReversed(t *testing.T) { for i := numElements - 1; i >= 0; i-- { inputSlice = append(inputSlice, i) } - p, s, input, expected := ptest.CreateList2(inputSlice, [][]int{[]int{9985, 19959}}) + p, s, input, expected := ptest.CreateList2(inputSlice, [][]int{{9985, 19959}}) quantiles := ApproximateQuantiles(s, input, less, Opts{ K: 200, NumQuantiles: 3, @@ -103,8 +76,8 @@ func TestBasicQuantiles(t *testing.T) { Expected [][]int }{ {[]int{}, [][]int{}}, - {[]int{1}, [][]int{[]int{1}}}, - {[]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, [][]int{[]int{6, 13}}}, + {[]int{1}, [][]int{{1}}}, + {[]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, [][]int{{6, 13}}}, } for _, test := range tests { @@ -180,7 +153,7 @@ func TestMerging(t *testing.T) { K: 3, NumberOfCompactions: 1, Compactors: []compactor{{ - sorted: [][]beam.T{[]beam.T{1}, []beam.T{2}, []beam.T{3}}, + sorted: [][]beam.T{{1}, {2}, {3}}, unsorted: []beam.T{6, 5, 4}, capacity: 4, }}, @@ -191,7 +164,7 @@ func TestMerging(t *testing.T) { NumberOfCompactions: 1, Compactors: []compactor{ { - sorted: [][]beam.T{[]beam.T{7}, []beam.T{8}, []beam.T{9}}, + sorted: [][]beam.T{{7}, {8}, {9}}, unsorted: []beam.T{12, 11, 10}, capacity: 4}, }, @@ -205,7 +178,7 @@ func TestMerging(t *testing.T) { Compactors: []compactor{ {capacity: 4}, { - sorted: [][]beam.T{[]beam.T{1, 3, 5, 7, 9, 11}}, + sorted: [][]beam.T{{1, 3, 5, 7, 9, 11}}, capacity: 4, }, }, @@ -222,12 +195,12 @@ func TestCompactorsEncoding(t *testing.T) { Compactors: []compactor{ { capacity: 4, - sorted: [][]beam.T{[]beam.T{1, 2}}, + sorted: [][]beam.T{{1, 2}}, unsorted: []beam.T{3, 4}, }, { capacity: 4, - sorted: [][]beam.T{[]beam.T{5, 6}}, + sorted: [][]beam.T{{5, 6}}, unsorted: []beam.T{7, 8}, }, }, diff --git a/sdks/go/test/regression/lperror.go b/sdks/go/test/regression/lperror.go index 088f81d7a7cb..db327e588a58 100644 --- a/sdks/go/test/regression/lperror.go +++ b/sdks/go/test/regression/lperror.go @@ -22,8 +22,15 @@ import ( "sort" "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" ) +func init() { + register.Function2x2(toFoo) + register.Iter1[*fruit]() + register.Function3x1(toID) +} + // REPRO found by https://github.com/zelliott type fruit struct { diff --git a/sdks/go/test/regression/pardo.go b/sdks/go/test/regression/pardo.go index 4b8fba7f9dd6..7dc28bff2db0 100644 --- a/sdks/go/test/regression/pardo.go +++ b/sdks/go/test/regression/pardo.go @@ -18,10 +18,22 @@ package regression import ( "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" ) +func init() { + register.Function1x1(directFn) + register.Function2x0(emitFn) + register.Function3x0(emit2Fn) + register.Function2x1(mixedFn) + register.Function2x2(directCountFn) + register.Function3x1(emitCountFn) + register.Emitter1[int]() + register.Iter1[int]() +} + func directFn(elm int) int { return elm + 1 } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java index a112c8030d4d..e4efe9f30459 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.Deque; import java.util.Iterator; import java.util.List; @@ -122,7 +123,7 @@ public class Create { * Otherwise, use {@link Create.Values#withCoder} to set the coder explicitly. */ public static Values of(Iterable elems) { - return new Values<>(elems, Optional.absent(), Optional.absent()); + return new Values<>(elems, Optional.absent(), Optional.absent(), false); } /** @@ -154,7 +155,7 @@ public static Values of(@Nullable T elem, @Nullable T... elems) { */ public static Values empty(Schema schema) { return new Values( - new ArrayList<>(), Optional.of(SchemaCoder.of(schema)), Optional.absent()); + new ArrayList<>(), Optional.of(SchemaCoder.of(schema)), Optional.absent(), false); } /** @@ -167,7 +168,7 @@ public static Values empty(Schema schema) { * the {@code Coder} is provided via the {@code coder} argument. */ public static Values empty(Coder coder) { - return new Values<>(new ArrayList<>(), Optional.of(coder), Optional.absent()); + return new Values<>(new ArrayList<>(), Optional.of(coder), Optional.absent(), false); } /** @@ -181,7 +182,7 @@ public static Values empty(Coder coder) { * must be registered for the class described in the {@code TypeDescriptor}. */ public static Values empty(TypeDescriptor type) { - return new Values<>(new ArrayList<>(), Optional.absent(), Optional.of(type)); + return new Values<>(new ArrayList<>(), Optional.absent(), Optional.of(type), false); } /** @@ -284,7 +285,7 @@ public static class Values extends PTransform> { *

Note that for {@link Create.Values} with no elements, the {@link VoidCoder} is used. */ public Values withCoder(Coder coder) { - return new Values<>(elems, Optional.of(coder), typeDescriptor); + return new Values<>(elems, Optional.of(coder), typeDescriptor, alwaysUseRead); } /** @@ -321,7 +322,11 @@ public Values withRowSchema(Schema schema) { *

Note that for {@link Create.Values} with no elements, the {@link VoidCoder} is used. */ public Values withType(TypeDescriptor type) { - return new Values<>(elems, coder, Optional.of(type)); + return new Values<>(elems, coder, Optional.of(type), alwaysUseRead); + } + + public Values alwaysUseRead() { + return new AlwaysUseRead<>(elems, coder, typeDescriptor); } public Iterable getElements() { @@ -362,6 +367,41 @@ public PCollection expand(PBegin input) { e); } try { + if (!alwaysUseRead) { + int numElements = Iterables.size(elems); + if (numElements == 0) { + return input + .apply(Impulse.create()) + .apply( + FlatMapElements.via( + new SimpleFunction>() { + @Override + public Iterable apply(byte[] input) { + return Collections.emptyList(); + } + })) + .setCoder(coder); + } else if (numElements == 1) { + final byte[] encodedElement = + CoderUtils.encodeToByteArray(coder, Iterables.getOnlyElement(elems)); + final Coder capturedCoder = coder; + return input + .apply(Impulse.create()) + .apply( + MapElements.via( + new SimpleFunction() { + @Override + public T apply(byte[] input) { + try { + return CoderUtils.decodeFromByteArray(capturedCoder, encodedElement); + } catch (CoderException exn) { + throw new RuntimeException(exn); + } + } + })) + .setCoder(coder); + } + } CreateSource source = CreateSource.fromIterable(elems, coder); return input.getPipeline().apply(Read.from(source)); } catch (IOException e) { @@ -381,6 +421,9 @@ public PCollection expand(PBegin input) { /** The value type. */ private final transient Optional> typeDescriptor; + /** Whether to unconditionally implement this via reading a CreateSource. */ + private final transient boolean alwaysUseRead; + /** * Constructs a {@code Create.Values} transform that produces a {@link PCollection} containing * the specified elements. @@ -388,10 +431,14 @@ public PCollection expand(PBegin input) { *

The arguments should not be modified after this is called. */ private Values( - Iterable elems, Optional> coder, Optional> typeDescriptor) { + Iterable elems, + Optional> coder, + Optional> typeDescriptor, + boolean alwaysUseRead) { this.elems = elems; this.coder = coder; this.typeDescriptor = typeDescriptor; + this.alwaysUseRead = alwaysUseRead; } @VisibleForTesting @@ -514,6 +561,14 @@ protected boolean advanceImpl() throws IOException { return true; } } + + /** A subclass to avoid getting re-matched. */ + private static class AlwaysUseRead extends Values { + private AlwaysUseRead( + Iterable elems, Optional> coder, Optional> typeDescriptor) { + super(elems, coder, typeDescriptor, true); + } + } } ///////////////////////////////////////////////////////////////////////////// diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/NoopLock.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/NoopLock.java index 0fc822987a6e..36454a125d67 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/NoopLock.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/NoopLock.java @@ -21,21 +21,19 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; -import javax.annotation.Nonnull; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A lock which can always be acquired. It should not be used when a proper lock is required, but it * is useful as a performance optimization when locking is not necessary but the code paths have to * be shared between the locking and the non-locking variant. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) public class NoopLock implements Lock, Serializable { - private static NoopLock instance; + private static @MonotonicNonNull NoopLock instance; - public static NoopLock get() { + public static @NonNull NoopLock get() { if (instance == null) { instance = new NoopLock(); } @@ -56,14 +54,13 @@ public boolean tryLock() { } @Override - public boolean tryLock(long time, @Nonnull TimeUnit unit) { + public boolean tryLock(long time, TimeUnit unit) { return true; } @Override public void unlock() {} - @Nonnull @Override public Condition newCondition() { throw new UnsupportedOperationException("Not implemented"); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java index aae85133cb9e..a71bd9168780 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java @@ -397,7 +397,7 @@ public void testPAssertEqualsSingletonFalseDefaultReasonString() throws Exceptio String message = thrown.getMessage(); - assertThat(message, containsString("Create.Values/Read(CreateSource)")); + assertThat(message, containsString("Create.Values/")); assertThat(message, containsString("Expected: <44>")); assertThat(message, containsString("but: was <42>")); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinationsHelpers.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinationsHelpers.java index 130a81f7aa28..42eb4d756100 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinationsHelpers.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinationsHelpers.java @@ -421,12 +421,17 @@ public TableDestination getTable(DestinationT destination) { } } - /** Returns the table schema for the destination. */ + /** + * Returns the table schema for the destination. If possible, will return the existing table + * schema. + */ @Override public @Nullable TableSchema getSchema(DestinationT destination) { TableDestination wrappedDestination = super.getTable(destination); @Nullable Table existingTable = getBigQueryTable(wrappedDestination.getTableReference()); - if (existingTable == null || existingTable.getSchema() == null) { + if (existingTable == null + || existingTable.getSchema() == null + || existingTable.getSchema().isEmpty()) { return super.getSchema(destination); } else { return existingTable.getSchema(); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java index b8dbb9703c2e..d231d84aea28 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java @@ -48,7 +48,6 @@ import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; import java.time.format.DateTimeParseException; -import java.time.temporal.ChronoUnit; import java.util.AbstractMap; import java.util.Collections; import java.util.List; @@ -825,23 +824,23 @@ private static void fieldDescriptorFromTableField( try { // '2011-12-03T10:15:30Z', '2011-12-03 10:15:30+05:00' // '2011-12-03 10:15:30 UTC', '2011-12-03T10:15:30 America/New_York' - return ChronoUnit.MICROS.between( - Instant.EPOCH, Instant.from(TIMESTAMP_FORMATTER.parse((String) value))); + Instant timestamp = Instant.from(TIMESTAMP_FORMATTER.parse((String) value)); + return toEpochMicros(timestamp); } catch (DateTimeException e) { try { // for backwards compatibility, default time zone is UTC for values with no time-zone // '2011-12-03T10:15:30' - return ChronoUnit.MICROS.between( - Instant.EPOCH, - Instant.from(TIMESTAMP_FORMATTER.withZone(ZoneOffset.UTC).parse((String) value))); + Instant timestamp = + Instant.from(TIMESTAMP_FORMATTER.withZone(ZoneOffset.UTC).parse((String) value)); + return toEpochMicros(timestamp); } catch (DateTimeParseException err) { // "12345667" - return ChronoUnit.MICROS.between( - Instant.EPOCH, Instant.ofEpochMilli(Long.parseLong((String) value))); + Instant timestamp = Instant.ofEpochMilli(Long.parseLong((String) value)); + return toEpochMicros(timestamp); } } } else if (value instanceof Instant) { - return ChronoUnit.MICROS.between(Instant.EPOCH, (Instant) value); + return toEpochMicros((Instant) value); } else if (value instanceof org.joda.time.Instant) { // joda instant precision is millisecond return ((org.joda.time.Instant) value).getMillis() * 1000L; @@ -972,6 +971,11 @@ private static void fieldDescriptorFromTableField( + schemaInformation.getType()); } + private static long toEpochMicros(Instant timestamp) { + // i.e 1970-01-01T00:01:01.000040Z: 61 * 1000_000L + 40000/1000 = 61000040 + return timestamp.getEpochSecond() * 1000_000L + timestamp.getNano() / 1000; + } + @VisibleForTesting public static TableRow tableRowFromMessage(Message message, boolean includeCdcColumns) { // TODO: Would be more correct to generate TableRows using setF. diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/UpdateSchemaDestination.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/UpdateSchemaDestination.java index 4d5717388313..d85268030aa9 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/UpdateSchemaDestination.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/UpdateSchemaDestination.java @@ -271,7 +271,7 @@ private BigQueryHelpers.PendingJob startZeroLoadJob( try { destinationTable = datasetService.getTable(tableReference); if (destinationTable == null) { - return null; // no need to update schema ahead if table does not exists + return null; // no need to update schema ahead if table does not exist } } catch (IOException | InterruptedException e) { LOG.warn("Failed to get table {} with {}", tableReference, e.toString()); @@ -281,6 +281,7 @@ private BigQueryHelpers.PendingJob startZeroLoadJob( // or when destination schema is null (the write will set the schema) // or when provided schema is null (e.g. when using CREATE_NEVER disposition) if (destinationTable.getSchema() == null + || destinationTable.getSchema().isEmpty() || destinationTable.getSchema().equals(schema) || schema == null) { return null; @@ -322,7 +323,7 @@ private BigQueryHelpers.PendingJob startZeroLoadJob( jobService.startLoadJob( jobRef, loadConfig, new ByteArrayContent("text/plain", new byte[0])); } catch (IOException | InterruptedException e) { - LOG.warn("Load job {} failed with {}", jobRef, e.toString()); + LOG.warn("Schema update load job {} failed with {}", jobRef, e.toString()); throw new RuntimeException(e); } return null; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformConfiguration.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformConfiguration.java index f663f60f09bb..befb22ca6dc2 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformConfiguration.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformConfiguration.java @@ -17,11 +17,13 @@ */ package org.apache.beam.sdk.io.gcp.pubsub; +import com.google.api.client.util.Clock; import com.google.auto.value.AutoValue; import javax.annotation.Nullable; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubTestClient.PubsubTestClientFactory; import org.apache.beam.sdk.schemas.AutoValueSchema; -import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; /** * Configuration for reading from Pub/Sub. @@ -33,137 +35,57 @@ @DefaultSchema(AutoValueSchema.class) @AutoValue public abstract class PubsubReadSchemaTransformConfiguration { + @SchemaFieldDescription( + "The name of the topic to consume data from. If a topic is specified, " + + " will create a new subscription for that topic and start consuming from that point. " + + "Either a topic or a subscription must be provided. " + + "Format: projects/${PROJECT}/topics/${TOPIC}") + public abstract @Nullable String getTopic(); + + @SchemaFieldDescription( + "The name of the subscription to consume data. " + + "Either a topic or subscription must be provided. " + + "Format: projects/${PROJECT}/subscriptions/${SUBSCRIPTION}") + public abstract @Nullable String getSubscription(); + + @SchemaFieldDescription( + "The encoding format for the data stored in Pubsub. Valid options are: " + + PubsubReadSchemaTransformProvider.VALID_FORMATS_STR) + public abstract String getFormat(); // AVRO, JSON + + @SchemaFieldDescription( + "The schema in which the data is encoded in the Pubsub topic. " + + "For AVRO data, this is a schema defined with AVRO schema syntax " + + "(https://avro.apache.org/docs/1.10.2/spec.html#schemas). " + + "For JSON data, this is a schema defined with JSON-schema syntax (https://json-schema.org/).") + public abstract String getSchema(); + + // Used for testing only. + public abstract @Nullable PubsubTestClientFactory getClientFactory(); + + // Used for testing only. + public abstract @Nullable Clock getClock(); - /** Instantiates a {@link PubsubReadSchemaTransformConfiguration.Builder}. */ public static Builder builder() { return new AutoValue_PubsubReadSchemaTransformConfiguration.Builder(); } - /** The expected schema of the Pub/Sub message. */ - public abstract Schema getDataSchema(); - - /** - * The Pub/Sub topic path to write failures. - * - *

See {@link PubsubIO.PubsubTopic#fromPath(String)} for more details on the format of the dead - * letter queue topic string. - */ - @Nullable - public abstract String getDeadLetterQueue(); - - /** - * The expected format of the Pub/Sub message. - * - *

Used to retrieve the {@link org.apache.beam.sdk.schemas.io.payloads.PayloadSerializer} from - * {@link org.apache.beam.sdk.schemas.io.payloads.PayloadSerializers}. - */ - @Nullable - public abstract String getFormat(); - - /** Used by the ProtoPayloadSerializerProvider when serializing from a Pub/Sub message. */ - @Nullable - public abstract String getProtoClass(); - - /** - * The subscription from which to read Pub/Sub messages. - * - *

See {@link PubsubIO.PubsubSubscription#fromPath(String)} for more details on the format of - * the subscription string. - */ - @Nullable - public abstract String getSubscription(); - - /** Used by the ThriftPayloadSerializerProvider when serializing from a Pub/Sub message. */ - @Nullable - public abstract String getThriftClass(); - - /** Used by the ThriftPayloadSerializerProvider when serializing from a Pub/Sub message. */ - @Nullable - public abstract String getThriftProtocolFactoryClass(); - - /** - * When reading from Cloud Pub/Sub where record timestamps are provided as Pub/Sub message - * attributes, specifies the name of the attribute that contains the timestamp. - */ - @Nullable - public abstract String getTimestampAttribute(); - - /** - * When reading from Cloud Pub/Sub where unique record identifiers are provided as Pub/Sub message - * attributes, specifies the name of the attribute containing the unique identifier. - */ - @Nullable - public abstract String getIdAttribute(); - - /** - * The topic from which to read Pub/Sub messages. - * - *

See {@link PubsubIO.PubsubTopic#fromPath(String)} for more details on the format of the - * topic string. - */ - @Nullable - public abstract String getTopic(); - @AutoValue.Builder public abstract static class Builder { + public abstract Builder setTopic(@Nullable String topic); - /** The expected schema of the Pub/Sub message. */ - public abstract Builder setDataSchema(Schema value); - - /** - * The Pub/Sub topic path to write failures. - * - *

See {@link PubsubIO.PubsubTopic#fromPath(String)} for more details on the format of the - * dead letter queue topic string. - */ - public abstract Builder setDeadLetterQueue(String value); - - /** - * The expected format of the Pub/Sub message. - * - *

Used to retrieve the {@link org.apache.beam.sdk.schemas.io.payloads.PayloadSerializer} - * from {@link org.apache.beam.sdk.schemas.io.payloads.PayloadSerializers}. - */ - public abstract Builder setFormat(String value); - - /** Used by the ProtoPayloadSerializerProvider when serializing from a Pub/Sub message. */ - public abstract Builder setProtoClass(String value); - - /** - * The subscription from which to read Pub/Sub messages. - * - *

See {@link PubsubIO.PubsubSubscription#fromPath(String)} for more details on the format of - * the subscription string. - */ - public abstract Builder setSubscription(String value); - - /** Used by the ThriftPayloadSerializerProvider when serializing from a Pub/Sub message. */ - public abstract Builder setThriftClass(String value); + public abstract Builder setSubscription(@Nullable String subscription); - /** Used by the ThriftPayloadSerializerProvider when serializing from a Pub/Sub message. */ - public abstract Builder setThriftProtocolFactoryClass(String value); + public abstract Builder setFormat(String format); - /** - * When reading from Cloud Pub/Sub where record timestamps are provided as Pub/Sub message - * attributes, specifies the name of the attribute that contains the timestamp. - */ - public abstract Builder setTimestampAttribute(String value); + public abstract Builder setSchema(String schema); - /** - * When reading from Cloud Pub/Sub where unique record identifiers are provided as Pub/Sub - * message attributes, specifies the name of the attribute containing the unique identifier. - */ - public abstract Builder setIdAttribute(String value); + // Used for testing only. + public abstract Builder setClientFactory(@Nullable PubsubTestClientFactory clientFactory); - /** - * The topic from which to read Pub/Sub messages. - * - *

See {@link PubsubIO.PubsubTopic#fromPath(String)} for more details on the format of the - * topic string. - */ - public abstract Builder setTopic(String value); + // Used for testing only. + public abstract Builder setClock(@Nullable Clock clock); - /** Builds a {@link PubsubReadSchemaTransformConfiguration} instance. */ public abstract PubsubReadSchemaTransformConfiguration build(); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java index cec07dafef4f..c0e8880d0287 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java @@ -17,23 +17,39 @@ */ package org.apache.beam.sdk.io.gcp.pubsub; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubMessageToRow.DLQ_TAG; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubMessageToRow.MAIN_TAG; - import com.google.api.client.util.Clock; import com.google.auto.service.AutoService; +import java.io.Serializable; +import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializers; +import java.util.Objects; +import java.util.Set; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubTestClient.PubsubTestClientFactory; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; +import org.apache.beam.sdk.schemas.utils.AvroUtils; +import org.apache.beam.sdk.schemas.utils.JsonUtils; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.FinishBundle; +import org.apache.beam.sdk.transforms.DoFn.ProcessElement; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; /** * An implementation of {@link TypedSchemaTransformProvider} for Pub/Sub reads configured using @@ -43,196 +59,191 @@ * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam * repository. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -@Internal @AutoService(SchemaTransformProvider.class) public class PubsubReadSchemaTransformProvider extends TypedSchemaTransformProvider { - static final String OUTPUT_TAG = "OUTPUT"; - /** Returns the expected class of the configuration. */ - @Override - protected Class configurationClass() { - return PubsubReadSchemaTransformConfiguration.class; - } + public static final String VALID_FORMATS_STR = "AVRO,JSON"; + public static final Set VALID_DATA_FORMATS = + Sets.newHashSet(VALID_FORMATS_STR.split(",")); - /** Returns the expected {@link SchemaTransform} of the configuration. */ - @Override - protected SchemaTransform from(PubsubReadSchemaTransformConfiguration configuration) { - PubsubMessageToRow toRowTransform = - PubsubSchemaTransformMessageToRowFactory.from(configuration).buildMessageToRow(); - return new PubsubReadSchemaTransform(configuration, toRowTransform); - } + public static final TupleTag OUTPUT_TAG = new TupleTag() {}; + public static final TupleTag ERROR_TAG = new TupleTag() {}; + public static final Schema ERROR_SCHEMA = + Schema.builder().addStringField("error").addNullableByteArrayField("row").build(); - /** Implementation of the {@link TypedSchemaTransformProvider} identifier method. */ @Override - public String identifier() { - return "beam:schematransform:org.apache.beam:pubsub_read:v1"; - } - - /** - * Implementation of the {@link TypedSchemaTransformProvider} inputCollectionNames method. Since - * no input is expected, this returns an empty list. - */ - @Override - public List inputCollectionNames() { - return Collections.emptyList(); + public Class configurationClass() { + return PubsubReadSchemaTransformConfiguration.class; } - /** - * Implementation of the {@link TypedSchemaTransformProvider} outputCollectionNames method. Since - * a single output is expected, this returns a list with a single name. - */ @Override - public List outputCollectionNames() { - return Collections.singletonList(OUTPUT_TAG); - } - - /** - * An implementation of {@link SchemaTransform} for Pub/Sub reads configured using {@link - * PubsubReadSchemaTransformConfiguration}. - */ - static class PubsubReadSchemaTransform extends SchemaTransform { + public SchemaTransform from(PubsubReadSchemaTransformConfiguration configuration) { + if (configuration.getSubscription() == null && configuration.getTopic() == null) { + throw new IllegalArgumentException( + "To read from Pubsub, a subscription name or a topic name must be provided"); + } - private final PubsubReadSchemaTransformConfiguration configuration; - private final PubsubMessageToRow pubsubMessageToRow; + if (configuration.getSubscription() != null && configuration.getTopic() != null) { + throw new IllegalArgumentException( + "To read from Pubsub, a subscription name or a topic name must be provided. Not both."); + } - private PubsubClient.PubsubClientFactory clientFactory; + if ((Strings.isNullOrEmpty(configuration.getSchema()) + && !Strings.isNullOrEmpty(configuration.getFormat())) + || (!Strings.isNullOrEmpty(configuration.getSchema()) + && Strings.isNullOrEmpty(configuration.getFormat()))) { + throw new IllegalArgumentException( + "A schema was provided without a data format (or viceversa). Please provide " + + "both of these parameters to read from Pubsub, or if you would like to use the Pubsub schema service," + + " please leave both of these blank."); + } - private Clock clock; + Schema beamSchema; + SerializableFunction valueMapper; - private PubsubReadSchemaTransform( - PubsubReadSchemaTransformConfiguration configuration, - PubsubMessageToRow pubsubMessageToRow) { - this.configuration = configuration; - this.pubsubMessageToRow = pubsubMessageToRow; + if (!VALID_DATA_FORMATS.contains(configuration.getFormat())) { + throw new IllegalArgumentException( + String.format( + "Format %s not supported. Only supported formats are %s", + configuration.getFormat(), VALID_FORMATS_STR)); } - - /** - * Sets the {@link PubsubClient.PubsubClientFactory}. - * - *

Used for testing. - */ - void setClientFactory(PubsubClient.PubsubClientFactory value) { - this.clientFactory = value; + beamSchema = + Objects.equals(configuration.getFormat(), "JSON") + ? JsonUtils.beamSchemaFromJsonSchema(configuration.getSchema()) + : AvroUtils.toBeamSchema( + new org.apache.avro.Schema.Parser().parse(configuration.getSchema())); + valueMapper = + Objects.equals(configuration.getFormat(), "JSON") + ? JsonUtils.getJsonBytesToRowFunction(beamSchema) + : AvroUtils.getAvroBytesToRowFunction(beamSchema); + + PubsubReadSchemaTransform transform = + new PubsubReadSchemaTransform( + configuration.getTopic(), configuration.getSubscription(), beamSchema, valueMapper); + + if (configuration.getClientFactory() != null) { + transform.setClientFactory(configuration.getClientFactory()); + } + if (configuration.getClock() != null) { + transform.setClock(configuration.getClock()); } - /** - * Sets the {@link Clock}. - * - *

Used for testing. - */ - void setClock(Clock clock) { - this.clock = clock; + return transform; + } + + private static class PubsubReadSchemaTransform extends SchemaTransform implements Serializable { + final Schema beamSchema; + final SerializableFunction valueMapper; + final @Nullable String topic; + final @Nullable String subscription; + @Nullable PubsubTestClientFactory clientFactory; + @Nullable Clock clock; + + PubsubReadSchemaTransform( + @Nullable String topic, + @Nullable String subscription, + Schema beamSchema, + SerializableFunction valueMapper) { + this.topic = topic; + this.subscription = subscription; + this.beamSchema = beamSchema; + this.valueMapper = valueMapper; } - /** Validates the {@link PubsubReadSchemaTransformConfiguration}. */ - @Override - public void validate(@Nullable PipelineOptions options) { - if (configuration.getSubscription() == null && configuration.getTopic() == null) { - throw new IllegalArgumentException( - String.format( - "%s needs to set either the topic or the subscription", - PubsubReadSchemaTransformConfiguration.class)); - } + private static class ErrorCounterFn extends DoFn { + private Counter pubsubErrorCounter; + private Long errorsInBundle = 0L; + private SerializableFunction valueMapper; - if (configuration.getSubscription() != null && configuration.getTopic() != null) { - throw new IllegalArgumentException( - String.format( - "%s should not set both the topic or the subscription", - PubsubReadSchemaTransformConfiguration.class)); + ErrorCounterFn(String name, SerializableFunction valueMapper) { + this.pubsubErrorCounter = Metrics.counter(PubsubReadSchemaTransformProvider.class, name); + this.valueMapper = valueMapper; } - try { - PayloadSerializers.getSerializer( - configuration.getFormat(), configuration.getDataSchema(), new HashMap<>()); - } catch (IllegalArgumentException e) { - throw new IllegalArgumentException( - String.format( - "Invalid %s, no serializer provider exists for format `%s`", - PubsubReadSchemaTransformConfiguration.class, configuration.getFormat())); + @ProcessElement + public void process(@DoFn.Element PubsubMessage message, MultiOutputReceiver receiver) { + + try { + receiver.get(OUTPUT_TAG).output(valueMapper.apply(message.getPayload())); + } catch (Exception e) { + errorsInBundle += 1; + receiver + .get(ERROR_TAG) + .output( + Row.withSchema(ERROR_SCHEMA) + .addValues(e.toString(), message.getPayload()) + .build()); + } } - } - /** Reads from Pub/Sub according to {@link PubsubReadSchemaTransformConfiguration}. */ - @Override - public PCollectionRowTuple expand(PCollectionRowTuple input) { - if (!input.getAll().isEmpty()) { - throw new IllegalArgumentException( - String.format( - "%s %s input is expected to be empty", - input.getClass().getSimpleName(), getClass().getSimpleName())); + @FinishBundle + public void finish(FinishBundleContext c) { + pubsubErrorCounter.inc(errorsInBundle); + errorsInBundle = 0L; } - - PCollectionTuple rowsWithDlq = - input - .getPipeline() - .apply("ReadFromPubsub", buildPubsubRead()) - .apply("PubsubMessageToRow", pubsubMessageToRow); - - writeToDeadLetterQueue(rowsWithDlq); - - return PCollectionRowTuple.of(OUTPUT_TAG, rowsWithDlq.get(MAIN_TAG)); } - private void writeToDeadLetterQueue(PCollectionTuple rowsWithDlq) { - PubsubIO.Write deadLetterQueue = buildDeadLetterQueueWrite(); - if (deadLetterQueue == null) { - return; - } - rowsWithDlq.get(DLQ_TAG).apply("WriteToDeadLetterQueue", deadLetterQueue); + void setClientFactory(@Nullable PubsubTestClientFactory factory) { + this.clientFactory = factory; } - /** - * Builds {@link PubsubIO.Write} dead letter queue from {@link - * PubsubReadSchemaTransformConfiguration}. - */ - PubsubIO.Write buildDeadLetterQueueWrite() { - if (configuration.getDeadLetterQueue() == null) { - return null; - } - - PubsubIO.Write writeDlq = - PubsubIO.writeMessages().to(configuration.getDeadLetterQueue()); - - if (configuration.getTimestampAttribute() != null) { - writeDlq = writeDlq.withTimestampAttribute(configuration.getTimestampAttribute()); - } - - return writeDlq; + void setClock(@Nullable Clock clock) { + this.clock = clock; } - /** Builds {@link PubsubIO.Read} from a {@link PubsubReadSchemaTransformConfiguration}. */ + @SuppressWarnings("nullness") PubsubIO.Read buildPubsubRead() { - PubsubIO.Read read = PubsubIO.readMessagesWithAttributes(); - - if (configuration.getSubscription() != null) { - read = read.fromSubscription(configuration.getSubscription()); + PubsubIO.Read pubsubRead = PubsubIO.readMessages(); + if (!Strings.isNullOrEmpty(topic)) { + pubsubRead = pubsubRead.fromTopic(topic); + } else { + pubsubRead = pubsubRead.fromSubscription(subscription); } - - if (configuration.getTopic() != null) { - read = read.fromTopic(configuration.getTopic()); + if (clientFactory != null && clock != null) { + pubsubRead = pubsubRead.withClientFactory(clientFactory); + pubsubRead = clientFactory.setClock(pubsubRead, clock); + } else if (clientFactory != null || clock != null) { + throw new IllegalArgumentException( + "Both PubsubTestClientFactory and Clock need to be specified for testing, but only one is provided"); } + return pubsubRead; + } - if (configuration.getTimestampAttribute() != null) { - read = read.withTimestampAttribute(configuration.getTimestampAttribute()); - } + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + PubsubIO.Read pubsubRead = buildPubsubRead(); - if (configuration.getIdAttribute() != null) { - read = read.withIdAttribute(configuration.getIdAttribute()); - } + PCollectionTuple outputTuple = + input + .getPipeline() + .apply(pubsubRead) + .apply( + ParDo.of(new ErrorCounterFn("PubSub-read-error-counter", valueMapper)) + .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); + + return PCollectionRowTuple.of( + "output", + outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema), + "errors", + outputTuple.get(ERROR_TAG).setRowSchema(ERROR_SCHEMA)); + } + } - if (clientFactory != null) { - read = read.withClientFactory(clientFactory); - } + @Override + public @UnknownKeyFor @NonNull @Initialized String identifier() { + return "beam:schematransform:org.apache.beam:pubsub_read:v1"; + } - if (clock != null) { - read = read.withClock(clock); - } + @Override + public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> + inputCollectionNames() { + return Collections.emptyList(); + } - return read; - } + @Override + public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> + outputCollectionNames() { + return Arrays.asList("output", "errors"); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubSchemaTransformMessageToRowFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubSchemaTransformMessageToRowFactory.java deleted file mode 100644 index 988c593e32fa..000000000000 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubSchemaTransformMessageToRowFactory.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.gcp.pubsub; - -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubMessageToRow.ATTRIBUTES_FIELD; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubMessageToRow.PAYLOAD_FIELD; -import static org.apache.beam.sdk.schemas.Schema.TypeName.ROW; - -import java.util.HashMap; -import java.util.Map; -import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializer; -import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializers; - -/** - * Builds a {@link PubsubMessageToRow} from a {@link PubsubReadSchemaTransformConfiguration}. - * - *

Internal only: This class is actively being worked on, and it will likely change. We - * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam - * repository. - */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -@Internal -class PubsubSchemaTransformMessageToRowFactory { - private static final String DEFAULT_FORMAT = "json"; - - private static final Schema.FieldType ATTRIBUTE_MAP_FIELD_TYPE = - Schema.FieldType.map(Schema.FieldType.STRING.withNullable(false), Schema.FieldType.STRING); - private static final Schema ATTRIBUTE_ARRAY_ENTRY_SCHEMA = - Schema.builder().addStringField("key").addStringField("value").build(); - private static final Schema.FieldType ATTRIBUTE_ARRAY_FIELD_TYPE = - Schema.FieldType.array(Schema.FieldType.row(ATTRIBUTE_ARRAY_ENTRY_SCHEMA)); - - private static final String THRIFT_CLASS_KEY = "thriftClass"; - private static final String THRIFT_PROTOCOL_FACTORY_CLASS_KEY = "thriftProtocolFactoryClass"; - private static final String PROTO_CLASS_KEY = "protoClass"; - - /** - * Instantiate a {@link PubsubSchemaTransformMessageToRowFactory} from a {@link - * PubsubReadSchemaTransformConfiguration}. - */ - static PubsubSchemaTransformMessageToRowFactory from( - PubsubReadSchemaTransformConfiguration configuration) { - return new PubsubSchemaTransformMessageToRowFactory(configuration); - } - - /** Build the {@link PubsubMessageToRow}. */ - PubsubMessageToRow buildMessageToRow() { - PubsubMessageToRow.Builder builder = - PubsubMessageToRow.builder() - .messageSchema(configuration.getDataSchema()) - .useDlq( - configuration.getDeadLetterQueue() != null - && !configuration.getDeadLetterQueue().isEmpty()) - .useFlatSchema(!shouldUseNestedSchema()); - - if (needsSerializer()) { - builder = builder.serializerProvider(serializer()); - } - - return builder.build(); - } - - private final PubsubReadSchemaTransformConfiguration configuration; - - private PubsubSchemaTransformMessageToRowFactory( - PubsubReadSchemaTransformConfiguration configuration) { - this.configuration = configuration; - } - - private PayloadSerializer payloadSerializer() { - Schema schema = configuration.getDataSchema(); - String format = DEFAULT_FORMAT; - - if (configuration.getFormat() != null && !configuration.getFormat().isEmpty()) { - format = configuration.getFormat(); - } - - Map params = new HashMap<>(); - - if (configuration.getThriftClass() != null && !configuration.getThriftClass().isEmpty()) { - params.put(THRIFT_CLASS_KEY, configuration.getThriftClass()); - } - - if (configuration.getThriftProtocolFactoryClass() != null - && !configuration.getThriftProtocolFactoryClass().isEmpty()) { - params.put(THRIFT_PROTOCOL_FACTORY_CLASS_KEY, configuration.getThriftProtocolFactoryClass()); - } - - if (configuration.getProtoClass() != null && !configuration.getProtoClass().isEmpty()) { - params.put(PROTO_CLASS_KEY, configuration.getProtoClass()); - } - - return PayloadSerializers.getSerializer(format, schema, params); - } - - PubsubMessageToRow.SerializerProvider serializer() { - return input -> payloadSerializer(); - } - - /** - * Determines whether the {@link PubsubMessageToRow} needs a {@link - * PubsubMessageToRow.SerializerProvider}. - * - *

The determination is based on {@link #shouldUseNestedSchema()} is false or if the {@link - * PubsubMessageToRow#PAYLOAD_FIELD} is not present. - */ - boolean needsSerializer() { - return !shouldUseNestedSchema() || !fieldPresent(PAYLOAD_FIELD, Schema.FieldType.BYTES); - } - - /** - * Determines whether a nested schema should be used for {@link - * PubsubReadSchemaTransformConfiguration#getDataSchema()}. - * - *

The determination is based on {@link #schemaHasValidPayloadField()} and {@link - * #schemaHasValidAttributesField()}} - */ - boolean shouldUseNestedSchema() { - return schemaHasValidPayloadField() && schemaHasValidAttributesField(); - } - - /** - * Determines whether {@link PubsubReadSchemaTransformConfiguration#getDataSchema()} has a valid - * {@link PubsubMessageToRow#PAYLOAD_FIELD}. - */ - boolean schemaHasValidPayloadField() { - Schema schema = configuration.getDataSchema(); - if (!schema.hasField(PAYLOAD_FIELD)) { - return false; - } - if (fieldPresent(PAYLOAD_FIELD, Schema.FieldType.BYTES)) { - return true; - } - return schema.getField(PAYLOAD_FIELD).getType().getTypeName().equals(ROW); - } - - /** - * Determines whether {@link PubsubReadSchemaTransformConfiguration#getDataSchema()} has a valid - * {@link PubsubMessageToRow#ATTRIBUTES_FIELD} field. - * - *

The determination is based on whether {@link #fieldPresent(String, Schema.FieldType)} for - * {@link PubsubMessageToRow#ATTRIBUTES_FIELD} is true for either {@link - * #ATTRIBUTE_MAP_FIELD_TYPE} or {@link #ATTRIBUTE_ARRAY_FIELD_TYPE} {@link Schema.FieldType}s. - */ - boolean schemaHasValidAttributesField() { - return fieldPresent(ATTRIBUTES_FIELD, ATTRIBUTE_MAP_FIELD_TYPE) - || fieldPresent(ATTRIBUTES_FIELD, ATTRIBUTE_ARRAY_FIELD_TYPE); - } - - /** - * Determines whether {@link PubsubReadSchemaTransformConfiguration#getDataSchema()} contains the - * field and whether that field is an expectedType {@link Schema.FieldType}. - */ - boolean fieldPresent(String field, Schema.FieldType expectedType) { - Schema schema = configuration.getDataSchema(); - return schema.hasField(field) - && expectedType.equivalent( - schema.getField(field).getType(), Schema.EquivalenceNullablePolicy.IGNORE); - } -} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformConfiguration.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformConfiguration.java index acaf04cdfc69..57620c968c5f 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformConfiguration.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformConfiguration.java @@ -18,9 +18,9 @@ package org.apache.beam.sdk.io.gcp.pubsub; import com.google.auto.value.AutoValue; -import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; /** * Configuration for writing to Pub/Sub. @@ -32,179 +32,25 @@ @DefaultSchema(AutoValueSchema.class) @AutoValue public abstract class PubsubWriteSchemaTransformConfiguration { + @SchemaFieldDescription( + "The encoding format for the data stored in Pubsub. Valid options are: " + + PubsubWriteSchemaTransformProvider.VALID_FORMATS_STR) + public abstract String getFormat(); - public static final String DEFAULT_TIMESTAMP_ATTRIBUTE = "event_timestamp"; + @SchemaFieldDescription( + "The name of the topic to write data to. " + "Format: projects/${PROJECT}/topics/${TOPIC}") + public abstract String getTopic(); public static Builder builder() { return new AutoValue_PubsubWriteSchemaTransformConfiguration.Builder(); } - public static TargetConfiguration.Builder targetConfigurationBuilder() { - return new AutoValue_PubsubWriteSchemaTransformConfiguration_TargetConfiguration.Builder() - .setTimestampAttributeKey(DEFAULT_TIMESTAMP_ATTRIBUTE); - } - - public static SourceConfiguration.Builder sourceConfigurationBuilder() { - return new AutoValue_PubsubWriteSchemaTransformConfiguration_SourceConfiguration.Builder(); - } - - /** - * Configuration details of the source {@link org.apache.beam.sdk.values.Row} {@link - * org.apache.beam.sdk.schemas.Schema}. - */ - @Nullable - public abstract SourceConfiguration getSource(); - - /** Configuration details of the target {@link PubsubMessage}. */ - public abstract TargetConfiguration getTarget(); - - /** - * The topic to which to write Pub/Sub messages. - * - *

See {@link PubsubIO.PubsubTopic#fromPath(String)} for more details on the format of the - * topic string. - */ - public abstract String getTopic(); - - /** - * The expected format of the Pub/Sub message. - * - *

Used to retrieve the {@link org.apache.beam.sdk.schemas.io.payloads.PayloadSerializer} from - * {@link org.apache.beam.sdk.schemas.io.payloads.PayloadSerializers}. See list of supported - * values by invoking {@link org.apache.beam.sdk.schemas.io.Providers#loadProviders(Class)}. - * - *

{@code Providers.loadProviders(PayloadSerializer.class).keySet()}
- */ - @Nullable - public abstract String getFormat(); - - /** - * When writing to Cloud Pub/Sub where unique record identifiers are provided as Pub/Sub message - * attributes, specifies the name of the attribute containing the unique identifier. - */ - @Nullable - public abstract String getIdAttribute(); - - /** Builder for {@link PubsubWriteSchemaTransformConfiguration}. */ @AutoValue.Builder public abstract static class Builder { + public abstract Builder setFormat(String format); - /** - * Configuration details of the source {@link org.apache.beam.sdk.values.Row} {@link - * org.apache.beam.sdk.schemas.Schema}. - */ - public abstract Builder setSource(SourceConfiguration value); - - /** Configuration details of the target {@link PubsubMessage}. */ - public abstract Builder setTarget(TargetConfiguration value); - - /** - * The topic to which to write Pub/Sub messages. - * - *

See {@link PubsubIO.PubsubTopic#fromPath(String)} for more details on the format of the - * topic string. - */ - public abstract Builder setTopic(String value); - - /** - * The expected format of the Pub/Sub message. - * - *

Used to retrieve the {@link org.apache.beam.sdk.schemas.io.payloads.PayloadSerializer} - * from {@link org.apache.beam.sdk.schemas.io.payloads.PayloadSerializers}. See list of - * supported values by invoking {@link - * org.apache.beam.sdk.schemas.io.Providers#loadProviders(Class)}. - * - *

{@code Providers.loadProviders(PayloadSerializer.class).keySet()}
- */ - public abstract Builder setFormat(String value); - - /** - * When reading from Cloud Pub/Sub where unique record identifiers are provided as Pub/Sub - * message attributes, specifies the name of the attribute containing the unique identifier. - */ - public abstract Builder setIdAttribute(String value); + public abstract Builder setTopic(String topic); public abstract PubsubWriteSchemaTransformConfiguration build(); } - - @DefaultSchema(AutoValueSchema.class) - @AutoValue - public abstract static class SourceConfiguration { - /** - * The attributes field name of the source {@link org.apache.beam.sdk.values.Row}. {@link - * org.apache.beam.sdk.schemas.Schema.FieldType} must be a Map<String, String> - * - */ - @Nullable - public abstract String getAttributesFieldName(); - - /** - * The timestamp field name of the source {@link org.apache.beam.sdk.values.Row}. {@link - * org.apache.beam.sdk.schemas.Schema.FieldType} must be a {@link - * org.apache.beam.sdk.schemas.Schema.FieldType#DATETIME}. - */ - @Nullable - public abstract String getTimestampFieldName(); - - /** - * The payload field name of the source {@link org.apache.beam.sdk.values.Row}. {@link - * org.apache.beam.sdk.schemas.Schema.FieldType} must be either {@link - * org.apache.beam.sdk.schemas.Schema.FieldType#BYTES} or a {@link - * org.apache.beam.sdk.values.Row}. If null, payload serialized from user fields other than - * attributes. Not compatible with other payload intended fields. - */ - @Nullable - public abstract String getPayloadFieldName(); - - @AutoValue.Builder - public abstract static class Builder { - /** - * The attributes field name of the source {@link org.apache.beam.sdk.values.Row}. {@link - * org.apache.beam.sdk.schemas.Schema.FieldType} must be a Map<String, String> - * - */ - public abstract Builder setAttributesFieldName(String value); - - /** - * The timestamp field name of the source {@link org.apache.beam.sdk.values.Row}. {@link - * org.apache.beam.sdk.schemas.Schema.FieldType} must be a {@link - * org.apache.beam.sdk.schemas.Schema.FieldType#DATETIME}. - */ - public abstract Builder setTimestampFieldName(String value); - - /** - * The payload field name of the source {@link org.apache.beam.sdk.values.Row}. {@link - * org.apache.beam.sdk.schemas.Schema.FieldType} must be either {@link - * org.apache.beam.sdk.schemas.Schema.FieldType#BYTES} or a {@link - * org.apache.beam.sdk.values.Row}. If null, payload serialized from user fields other than - * attributes. Not compatible with other payload intended fields. - */ - public abstract Builder setPayloadFieldName(String value); - - public abstract SourceConfiguration build(); - } - } - - @DefaultSchema(AutoValueSchema.class) - @AutoValue - public abstract static class TargetConfiguration { - - /** - * The attribute key to assign the {@link PubsubMessage} stringified timestamp value. {@link - * #builder()} method defaults value to {@link #DEFAULT_TIMESTAMP_ATTRIBUTE}. - */ - public abstract String getTimestampAttributeKey(); - - @AutoValue.Builder - public abstract static class Builder { - - /** - * The attribute key to assign the {@link PubsubMessage} stringified timestamp value. Defaults - * to {@link #DEFAULT_TIMESTAMP_ATTRIBUTE}. - */ - public abstract Builder setTimestampAttributeKey(String value); - - public abstract TargetConfiguration build(); - } - } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java index 7f3f6f2c7020..8e8b804801b3 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java @@ -17,56 +17,29 @@ */ package org.apache.beam.sdk.io.gcp.pubsub; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.ATTRIBUTES_FIELD_TYPE; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.DEFAULT_ATTRIBUTES_KEY_NAME; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.DEFAULT_EVENT_TIMESTAMP_KEY_NAME; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.DEFAULT_PAYLOAD_KEY_NAME; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.ERROR; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.EVENT_TIMESTAMP_FIELD_TYPE; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.OUTPUT; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.PAYLOAD_BYTES_TYPE_NAME; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.PAYLOAD_ROW_TYPE_NAME; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.removeFields; -import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; - -import com.google.api.client.util.Clock; import com.google.auto.service.AutoService; -import java.io.IOException; -import java.util.ArrayList; +import java.io.Serializable; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Objects; import java.util.Set; -import java.util.stream.Stream; -import javax.annotation.Nullable; -import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.SchemaPath; -import org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.FieldMatcher; -import org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.SchemaReflection; -import org.apache.beam.sdk.io.gcp.pubsub.PubsubWriteSchemaTransformConfiguration.SourceConfiguration; -import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.Schema.Field; -import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.Schema.TypeName; -import org.apache.beam.sdk.schemas.io.Providers; -import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializer; -import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializerProvider; -import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializers; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; +import org.apache.beam.sdk.schemas.utils.AvroUtils; +import org.apache.beam.sdk.schemas.utils.JsonUtils; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.Row; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; -import org.joda.time.Instant; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; /** * An implementation of {@link TypedSchemaTransformProvider} for Pub/Sub reads configured using @@ -76,360 +49,102 @@ * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam * repository. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -@Internal @AutoService(SchemaTransformProvider.class) public class PubsubWriteSchemaTransformProvider extends TypedSchemaTransformProvider { - private static final String IDENTIFIER = "beam:schematransform:org.apache.beam:pubsub_write:v1"; - static final String INPUT_TAG = "input"; - static final String ERROR_TAG = "error"; - /** Returns the expected class of the configuration. */ - @Override - protected Class configurationClass() { - return PubsubWriteSchemaTransformConfiguration.class; - } + public static final TupleTag OUTPUT_TAG = new TupleTag() {}; + public static final TupleTag ERROR_TAG = new TupleTag() {}; - /** Returns the expected {@link SchemaTransform} of the configuration. */ - @Override - public SchemaTransform from(PubsubWriteSchemaTransformConfiguration configuration) { - return new PubsubWriteSchemaTransform(configuration); - } - - /** Implementation of the {@link SchemaTransformProvider} identifier method. */ - @Override - public String identifier() { - return IDENTIFIER; - } - - /** - * Implementation of the {@link TypedSchemaTransformProvider} inputCollectionNames method. Since a - * single input is expected, this returns a list with a single name. - */ - @Override - public List inputCollectionNames() { - return Collections.singletonList(INPUT_TAG); - } + public static final String VALID_FORMATS_STR = "AVRO,JSON"; + public static final Set VALID_DATA_FORMATS = + Sets.newHashSet(VALID_FORMATS_STR.split(",")); - /** - * Implementation of the {@link TypedSchemaTransformProvider} outputCollectionNames method. The - * only expected output is the {@link #ERROR_TAG}. - */ @Override - public List outputCollectionNames() { - return Collections.singletonList(ERROR_TAG); + public Class configurationClass() { + return PubsubWriteSchemaTransformConfiguration.class; } - /** - * An implementation of {@link SchemaTransform} for Pub/Sub writes configured using {@link - * PubsubWriteSchemaTransformConfiguration}. - */ - static class PubsubWriteSchemaTransform extends SchemaTransform { - - private final PubsubWriteSchemaTransformConfiguration configuration; + public static class ErrorFn extends DoFn { + private SerializableFunction valueMapper; + private Schema errorSchema; - private PubsubClient.PubsubClientFactory pubsubClientFactory; - - PubsubWriteSchemaTransform(PubsubWriteSchemaTransformConfiguration configuration) { - this.configuration = configuration; + ErrorFn(SerializableFunction valueMapper, Schema errorSchema) { + this.valueMapper = valueMapper; + this.errorSchema = errorSchema; } - PubsubWriteSchemaTransform withPubsubClientFactory(PubsubClient.PubsubClientFactory factory) { - this.pubsubClientFactory = factory; - return this; - } - - @Override - public PCollectionRowTuple expand(PCollectionRowTuple input) { - if (input.getAll().size() != 1 || !input.has(INPUT_TAG)) { - throw new IllegalArgumentException( - String.format( - "%s %s input is expected to contain a single %s tagged PCollection", - input.getClass().getSimpleName(), getClass().getSimpleName(), INPUT_TAG)); - } - - PCollection rows = input.get(INPUT_TAG); - if (rows.getSchema().getFieldCount() == 0) { - throw new IllegalArgumentException(String.format("empty Schema for %s", INPUT_TAG)); - } - - Schema targetSchema = buildTargetSchema(rows.getSchema()); - - rows = - rows.apply( - ConvertForRowToMessage.class.getSimpleName(), - convertForRowToMessage(targetSchema)) - .setRowSchema(targetSchema); - - Schema schema = rows.getSchema(); - - Schema serializableSchema = - removeFields(schema, DEFAULT_ATTRIBUTES_KEY_NAME, DEFAULT_EVENT_TIMESTAMP_KEY_NAME); - FieldMatcher payloadRowMatcher = FieldMatcher.of(DEFAULT_PAYLOAD_KEY_NAME, TypeName.ROW); - if (payloadRowMatcher.match(serializableSchema)) { - serializableSchema = - serializableSchema.getField(DEFAULT_PAYLOAD_KEY_NAME).getType().getRowSchema(); - } - - validateTargetSchemaAgainstPubsubSchema(serializableSchema, input.getPipeline().getOptions()); - - PCollectionTuple pct = - rows.apply( - PubsubRowToMessage.class.getSimpleName(), - buildPubsubRowToMessage(serializableSchema)); - - PCollection messages = pct.get(OUTPUT); - messages.apply(PubsubIO.Write.class.getSimpleName(), buildPubsubWrite()); - return PCollectionRowTuple.of(ERROR_TAG, pct.get(ERROR)); - } - - PayloadSerializer getPayloadSerializer(Schema schema) { - if (configuration.getFormat() == null) { - return null; - } - String format = configuration.getFormat(); - Set availableFormats = - Providers.loadProviders(PayloadSerializerProvider.class).keySet(); - if (!availableFormats.contains(format)) { - String availableFormatsString = String.join(",", availableFormats); - throw new IllegalArgumentException( - String.format( - "%s is not among the valid formats: [%s]", format, availableFormatsString)); - } - return PayloadSerializers.getSerializer(configuration.getFormat(), schema, ImmutableMap.of()); - } - - PubsubRowToMessage buildPubsubRowToMessage(Schema schema) { - PubsubRowToMessage.Builder builder = - PubsubRowToMessage.builder().setPayloadSerializer(getPayloadSerializer(schema)); - - if (configuration.getTarget() != null) { - builder = - builder.setTargetTimestampAttributeName( - configuration.getTarget().getTimestampAttributeKey()); - } - - return builder.build(); - } - - PubsubIO.Write buildPubsubWrite() { - PubsubIO.Write write = PubsubIO.writeMessages().to(configuration.getTopic()); - - if (configuration.getIdAttribute() != null) { - write = write.withIdAttribute(configuration.getIdAttribute()); - } - - if (pubsubClientFactory != null) { - write = write.withClientFactory(pubsubClientFactory); + @ProcessElement + public void processElement(@Element Row row, MultiOutputReceiver receiver) { + try { + receiver.get(OUTPUT_TAG).output(new PubsubMessage(valueMapper.apply(row), null)); + } catch (Exception e) { + receiver + .get(ERROR_TAG) + .output(Row.withSchema(errorSchema).addValues(e.toString(), row).build()); } - - return write; } + } - void validateSourceSchemaAgainstConfiguration(Schema sourceSchema) { - if (sourceSchema.getFieldCount() == 0) { - throw new IllegalArgumentException(String.format("empty Schema for %s", INPUT_TAG)); - } - - if (configuration.getSource() == null) { - return; - } - - SourceConfiguration source = configuration.getSource(); - - if (source.getAttributesFieldName() != null) { - String fieldName = source.getAttributesFieldName(); - FieldType fieldType = ATTRIBUTES_FIELD_TYPE; - FieldMatcher fieldMatcher = FieldMatcher.of(fieldName, fieldType); - checkArgument( - fieldMatcher.match(sourceSchema), - String.format("schema missing field: %s for type %s: ", fieldName, fieldType)); - } - - if (source.getTimestampFieldName() != null) { - String fieldName = source.getTimestampFieldName(); - FieldType fieldType = EVENT_TIMESTAMP_FIELD_TYPE; - FieldMatcher fieldMatcher = FieldMatcher.of(fieldName, fieldType); - checkArgument( - fieldMatcher.match(sourceSchema), - String.format("schema missing field: %s for type: %s", fieldName, fieldType)); - } - - if (source.getPayloadFieldName() == null) { - return; - } - - String fieldName = source.getPayloadFieldName(); - FieldMatcher bytesFieldMatcher = FieldMatcher.of(fieldName, PAYLOAD_BYTES_TYPE_NAME); - FieldMatcher rowFieldMatcher = FieldMatcher.of(fieldName, PAYLOAD_ROW_TYPE_NAME); - SchemaReflection schemaReflection = SchemaReflection.of(sourceSchema); - checkArgument( - schemaReflection.matchesAny(bytesFieldMatcher, rowFieldMatcher), + @Override + public SchemaTransform from(PubsubWriteSchemaTransformConfiguration configuration) { + if (!VALID_DATA_FORMATS.contains(configuration.getFormat())) { + throw new IllegalArgumentException( String.format( - "schema missing field: %s for types %s or %s", - fieldName, PAYLOAD_BYTES_TYPE_NAME, PAYLOAD_ROW_TYPE_NAME)); - - String[] fieldsToExclude = - Stream.of( - source.getAttributesFieldName(), - source.getTimestampFieldName(), - source.getPayloadFieldName()) - .filter(Objects::nonNull) - .toArray(String[]::new); - - Schema userFieldsSchema = removeFields(sourceSchema, fieldsToExclude); - - if (userFieldsSchema.getFieldCount() > 0) { - throw new IllegalArgumentException( - String.format("user fields incompatible with %s field", source.getPayloadFieldName())); - } - } - - void validateTargetSchemaAgainstPubsubSchema(Schema targetSchema, PipelineOptions options) { - checkArgument(options != null); - - try (PubsubClient pubsubClient = getPubsubClient(options.as(PubsubOptions.class))) { - PubsubClient.TopicPath topicPath = PubsubClient.topicPathFromPath(configuration.getTopic()); - PubsubClient.SchemaPath schemaPath = pubsubClient.getSchemaPath(topicPath); - if (schemaPath == null || schemaPath.equals(SchemaPath.DELETED_SCHEMA)) { - return; - } - Schema expectedSchema = pubsubClient.getSchema(schemaPath); - checkState( - targetSchema.equals(expectedSchema), - String.format( - "input schema mismatch with expected schema at path: %s\ninput schema: %s\nPub/Sub schema: %s", - schemaPath, targetSchema, expectedSchema)); - } catch (IOException e) { - throw new IllegalStateException(e.getMessage()); - } - } - - Schema buildTargetSchema(Schema sourceSchema) { - validateSourceSchemaAgainstConfiguration(sourceSchema); - FieldType payloadFieldType = null; - - List fieldsToRemove = new ArrayList<>(); - - if (configuration.getSource() != null) { - SourceConfiguration source = configuration.getSource(); - - if (source.getAttributesFieldName() != null) { - fieldsToRemove.add(source.getAttributesFieldName()); - } - - if (source.getTimestampFieldName() != null) { - fieldsToRemove.add(source.getTimestampFieldName()); - } - - if (source.getPayloadFieldName() != null) { - String fieldName = source.getPayloadFieldName(); - Field field = sourceSchema.getField(fieldName); - payloadFieldType = field.getType(); - fieldsToRemove.add(fieldName); - } - } - - Schema targetSchema = - PubsubRowToMessage.builder() - .build() - .inputSchemaFactory(payloadFieldType) - .buildSchema(sourceSchema.getFields().toArray(new Field[0])); - - return removeFields(targetSchema, fieldsToRemove.toArray(new String[0])); + "Format %s not supported. Only supported formats are %s", + configuration.getFormat(), VALID_FORMATS_STR)); } + return new PubsubWriteSchemaTransform(configuration.getTopic(), configuration.getFormat()); + } - private PubsubClient.PubsubClientFactory getPubsubClientFactory() { - if (pubsubClientFactory != null) { - return pubsubClientFactory; - } - return PubsubGrpcClient.FACTORY; - } + private static class PubsubWriteSchemaTransform extends SchemaTransform implements Serializable { + final String topic; + final String format; - private PubsubClient getPubsubClient(PubsubOptions options) throws IOException { - return getPubsubClientFactory() - .newClient( - configuration.getTarget().getTimestampAttributeKey(), - configuration.getIdAttribute(), - options); + PubsubWriteSchemaTransform(String topic, String format) { + this.topic = topic; + this.format = format; } - ParDo.SingleOutput convertForRowToMessage(Schema targetSchema) { - return convertForRowToMessage(targetSchema, null); - } - - ParDo.SingleOutput convertForRowToMessage( - Schema targetSchema, @Nullable Clock clock) { - String attributesName = null; - String timestampName = null; - String payloadName = null; - SourceConfiguration source = configuration.getSource(); - if (source != null) { - attributesName = source.getAttributesFieldName(); - timestampName = source.getTimestampFieldName(); - payloadName = source.getPayloadFieldName(); - } - return ParDo.of( - new ConvertForRowToMessage( - targetSchema, clock, attributesName, timestampName, payloadName)); + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + final Schema errorSchema = + Schema.builder() + .addStringField("error") + .addNullableRowField("row", input.get("input").getSchema()) + .build(); + SerializableFunction fn = + format.equals("AVRO") + ? AvroUtils.getRowToAvroBytesFunction(input.get("input").getSchema()) + : JsonUtils.getRowToJsonBytesFunction(input.get("input").getSchema()); + + PCollectionTuple outputTuple = + input + .get("input") + .apply( + ParDo.of(new ErrorFn(fn, errorSchema)) + .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); + + outputTuple.get(OUTPUT_TAG).apply(PubsubIO.writeMessages().to(topic)); + + return PCollectionRowTuple.of("errors", outputTuple.get(ERROR_TAG).setRowSchema(errorSchema)); } } - private static class ConvertForRowToMessage extends DoFn { - private final Schema targetSchema; - @Nullable private final Clock clock; - @Nullable private final String attributesFieldName; - @Nullable private final String timestampFieldName; - @Nullable private final String payloadFieldName; - - ConvertForRowToMessage( - Schema targetSchema, - @Nullable Clock clock, - @Nullable String attributesFieldName, - @Nullable String timestampFieldName, - @Nullable String payloadFieldName) { - this.targetSchema = targetSchema; - this.clock = clock; - this.attributesFieldName = attributesFieldName; - this.timestampFieldName = timestampFieldName; - this.payloadFieldName = payloadFieldName; - } - - @ProcessElement - public void process(@Element Row row, OutputReceiver receiver) { - Instant now = Instant.now(); - if (clock != null) { - now = Instant.ofEpochMilli(clock.currentTimeMillis()); - } - Map values = new HashMap<>(); - - // Default attributes value - checkState(targetSchema.hasField(DEFAULT_ATTRIBUTES_KEY_NAME)); - values.put(DEFAULT_ATTRIBUTES_KEY_NAME, ImmutableMap.of()); - - // Default timestamp value - checkState(targetSchema.hasField(DEFAULT_EVENT_TIMESTAMP_KEY_NAME)); - values.put(DEFAULT_EVENT_TIMESTAMP_KEY_NAME, now); + @Override + public @UnknownKeyFor @NonNull @Initialized String identifier() { + return "beam:schematransform:org.apache.beam:pubsub_write:v1"; + } - for (String fieldName : row.getSchema().getFieldNames()) { - if (targetSchema.hasField(fieldName)) { - values.put(fieldName, row.getValue(fieldName)); - } + @Override + public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> + inputCollectionNames() { + return Collections.singletonList("input"); + } - if (attributesFieldName != null) { - values.put(DEFAULT_ATTRIBUTES_KEY_NAME, row.getValue(attributesFieldName)); - } - if (timestampFieldName != null) { - values.put(DEFAULT_EVENT_TIMESTAMP_KEY_NAME, row.getValue(timestampFieldName)); - } - if (payloadFieldName != null) { - values.put(DEFAULT_PAYLOAD_KEY_NAME, row.getValue(payloadFieldName)); - } - } - receiver.output(Row.withSchema(targetSchema).withFieldValues(values).build()); - } + @Override + public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> + outputCollectionNames() { + return Collections.singletonList("errors"); } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageWriteIT.java index db4a1008ecdd..27dec31c8c43 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageWriteIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageWriteIT.java @@ -128,15 +128,15 @@ private void runBigQueryIOStorageWritePipeline( } @Test - public void testBigQueryStorageWrite30MProto() { + public void testBigQueryStorageWrite3MProto() { setUpTestEnvironment(WriteMode.EXACT_ONCE); - runBigQueryIOStorageWritePipeline(3000000, WriteMode.EXACT_ONCE, false); + runBigQueryIOStorageWritePipeline(3_000_000, WriteMode.EXACT_ONCE, false); } @Test - public void testBigQueryStorageWrite30MProtoALO() { + public void testBigQueryStorageWrite3MProtoALO() { setUpTestEnvironment(WriteMode.AT_LEAST_ONCE); - runBigQueryIOStorageWritePipeline(3000000, WriteMode.AT_LEAST_ONCE, false); + runBigQueryIOStorageWritePipeline(3_000_000, WriteMode.AT_LEAST_ONCE, false); } @Test diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java index 970580f3ef67..534c1b0c360b 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java @@ -118,6 +118,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnTester; import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.SerializableFunctions; @@ -133,6 +134,7 @@ import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.transforms.windowing.WindowMappingFn; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.Row; @@ -302,6 +304,7 @@ public void testWriteDynamicDestinationsBatchWithSchemas() throws Exception { @Test public void testWriteDynamicDestinationsStreamingWithAutoSharding() throws Exception { assumeTrue(useStreaming); + assumeTrue(!useStorageApiApproximate); // STORAGE_API_AT_LEAST_ONCE ignores auto-sharding writeDynamicDestinations(true, true); } @@ -599,8 +602,8 @@ public void testClusteringTableFunction() throws Exception { assertEquals(clustering, table.getClustering()); } - @Test - public void testTriggeredFileLoads() throws Exception { + public void runStreamingFileLoads(String tableRef, boolean useTempTables, boolean useTempDataset) + throws Exception { assumeTrue(!useStorageApi); assumeTrue(useStreaming); List elements = Lists.newArrayList(); @@ -620,65 +623,29 @@ public void testTriggeredFileLoads() throws Exception { elements.get(20), Iterables.toArray(elements.subList(21, 30), TableRow.class)) .advanceWatermarkToInfinity(); - BigQueryIO.Write.Method method = Method.FILE_LOADS; - p.apply(testStream) - .apply( - BigQueryIO.writeTableRows() - .to("project-id:dataset-id.table-id") - .withSchema( - new TableSchema() - .setFields( - ImmutableList.of( - new TableFieldSchema().setName("number").setType("INTEGER")))) - .withTestServices(fakeBqServices) - .withTriggeringFrequency(Duration.standardSeconds(30)) - .withNumFileShards(2) - .withMethod(method) - .withoutValidation()); - p.run(); - - assertThat( - fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id"), - containsInAnyOrder(Iterables.toArray(elements, TableRow.class))); - } + BigQueryIO.Write writeTransform = + BigQueryIO.writeTableRows() + .to(tableRef) + .withSchema( + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("number").setType("INTEGER")))) + .withTestServices(fakeBqServices) + .withWriteDisposition(Write.WriteDisposition.WRITE_APPEND) + .withTriggeringFrequency(Duration.standardSeconds(30)) + .withNumFileShards(2) + .withMethod(Method.FILE_LOADS) + .withoutValidation(); - @Test - public void testTriggeredFileLoadsWithTempTablesAndDataset() throws Exception { - String tableRef = "bigquery-project-id:dataset-id.table-id"; - List elements = Lists.newArrayList(); - for (int i = 0; i < 30; ++i) { - elements.add(new TableRow().set("number", i)); + if (useTempTables) { + writeTransform = writeTransform.withMaxBytesPerPartition(1).withMaxFilesPerPartition(1); + } + if (useTempDataset) { + writeTransform = writeTransform.withWriteTempDataset("temp-dataset-id"); } - TestStream testStream = - TestStream.create(TableRowJsonCoder.of()) - .addElements( - elements.get(0), Iterables.toArray(elements.subList(1, 10), TableRow.class)) - .advanceProcessingTime(Duration.standardMinutes(1)) - .addElements( - elements.get(10), Iterables.toArray(elements.subList(11, 20), TableRow.class)) - .advanceProcessingTime(Duration.standardMinutes(1)) - .addElements( - elements.get(20), Iterables.toArray(elements.subList(21, 30), TableRow.class)) - .advanceWatermarkToInfinity(); - BigQueryIO.Write.Method method = Method.FILE_LOADS; - p.apply(testStream) - .apply( - BigQueryIO.writeTableRows() - .to(tableRef) - .withSchema( - new TableSchema() - .setFields( - ImmutableList.of( - new TableFieldSchema().setName("number").setType("INTEGER")))) - .withTestServices(fakeBqServices) - .withTriggeringFrequency(Duration.standardSeconds(30)) - .withNumFileShards(2) - .withMaxBytesPerPartition(1) - .withMaxFilesPerPartition(1) - .withMethod(method) - .withoutValidation() - .withWriteTempDataset("temp-dataset-id")); + p.apply(testStream).apply(writeTransform); p.run(); final int projectIdSplitter = tableRef.indexOf(':'); @@ -690,7 +657,49 @@ public void testTriggeredFileLoadsWithTempTablesAndDataset() throws Exception { containsInAnyOrder(Iterables.toArray(elements, TableRow.class))); } - public void testTriggeredFileLoadsWithTempTables(String tableRef) throws Exception { + public void runStreamingFileLoads(String tableRef) throws Exception { + runStreamingFileLoads(tableRef, true, false); + } + + @Test + public void testStreamingFileLoads() throws Exception { + runStreamingFileLoads("project-id:dataset-id.table-id", false, false); + } + + @Test + public void testStreamingFileLoadsWithTempTables() throws Exception { + runStreamingFileLoads("project-id:dataset-id.table-id"); + } + + @Test + public void testStreamingFileLoadsWithTempTablesDefaultProject() throws Exception { + runStreamingFileLoads("dataset-id.table-id"); + } + + @Test + @ProjectOverride + public void testStreamingFileLoadsWithTempTablesBigQueryProject() throws Exception { + runStreamingFileLoads("bigquery-project-id:dataset-id.table-id"); + } + + @Test + public void testStreamingFileLoadsWithTempTablesAndDataset() throws Exception { + runStreamingFileLoads("bigquery-project-id:dataset-id.table-id", true, true); + } + + @Test + public void testStreamingFileLoadsWithTempTablesToExistingNullSchemaTable() throws Exception { + TableReference ref = + new TableReference() + .setProjectId("project-id") + .setDatasetId("dataset-id") + .setTableId("table-id"); + fakeDatasetService.createTable(new Table().setTableReference(ref).setSchema(null)); + runStreamingFileLoads("project-id:dataset-id.table-id"); + } + + @Test + public void testStreamingFileLoadsWithAutoSharding() throws Exception { assumeTrue(!useStorageApi); assumeTrue(useStreaming); List elements = Lists.newArrayList(); @@ -698,72 +707,102 @@ public void testTriggeredFileLoadsWithTempTables(String tableRef) throws Excepti elements.add(new TableRow().set("number", i)); } + Instant startInstant = new Instant(0L); TestStream testStream = TestStream.create(TableRowJsonCoder.of()) + // Initialize watermark for timer to be triggered correctly. + .advanceWatermarkTo(startInstant) .addElements( elements.get(0), Iterables.toArray(elements.subList(1, 10), TableRow.class)) .advanceProcessingTime(Duration.standardMinutes(1)) + .advanceWatermarkTo(startInstant.plus(Duration.standardSeconds(10))) .addElements( elements.get(10), Iterables.toArray(elements.subList(11, 20), TableRow.class)) .advanceProcessingTime(Duration.standardMinutes(1)) + .advanceWatermarkTo(startInstant.plus(Duration.standardSeconds(30))) .addElements( elements.get(20), Iterables.toArray(elements.subList(21, 30), TableRow.class)) + .advanceProcessingTime(Duration.standardMinutes(2)) .advanceWatermarkToInfinity(); - BigQueryIO.Write.Method method = Method.FILE_LOADS; + int numTables = 3; p.apply(testStream) .apply( BigQueryIO.writeTableRows() - .to(tableRef) + .to( + (ValueInSingleWindow vsw) -> { + String tableSpec = + "project-id:dataset-id.table-" + + ((int) vsw.getValue().get("number") % numTables); + return new TableDestination(tableSpec, null); + }) .withSchema( new TableSchema() .setFields( ImmutableList.of( new TableFieldSchema().setName("number").setType("INTEGER")))) .withTestServices(fakeBqServices) - .withTriggeringFrequency(Duration.standardSeconds(30)) - .withNumFileShards(2) - .withMaxBytesPerPartition(1) - .withMaxFilesPerPartition(1) - .withMethod(method) + // Set a triggering frequency without needing to also specify numFileShards when + // using autoSharding. + .withTriggeringFrequency(Duration.standardSeconds(100)) + .withAutoSharding() + .withMaxBytesPerPartition(1000) + .withMaxFilesPerPartition(10) + .withMethod(BigQueryIO.Write.Method.FILE_LOADS) .withoutValidation()); p.run(); - final int projectIdSplitter = tableRef.indexOf(':'); - final String projectId = - projectIdSplitter == -1 ? "project-id" : tableRef.substring(0, projectIdSplitter); - - assertThat( - fakeDatasetService.getAllRows(projectId, "dataset-id", "table-id"), - containsInAnyOrder(Iterables.toArray(elements, TableRow.class))); + Map> elementsByTableIdx = new HashMap<>(); + for (int i = 0; i < elements.size(); i++) { + elementsByTableIdx + .computeIfAbsent(i % numTables, k -> new ArrayList<>()) + .add(elements.get(i)); + } + for (Map.Entry> entry : elementsByTableIdx.entrySet()) { + assertThat( + fakeDatasetService.getAllRows("project-id", "dataset-id", "table-" + entry.getKey()), + containsInAnyOrder(Iterables.toArray(entry.getValue(), TableRow.class))); + } + // For each table destination, it's expected to create two load jobs based on the triggering + // frequency and processing time intervals. + assertEquals(2 * numTables, fakeDatasetService.getInsertCount()); } @Test - @ProjectOverride - public void testTriggeredFileLoadsWithTempTablesBigQueryProject() throws Exception { - testTriggeredFileLoadsWithTempTables("bigquery-project-id:dataset-id.table-id"); - } + public void testBatchFileLoads() throws Exception { + assumeTrue(!useStreaming); + assumeTrue(!useStorageApi); + List elements = Lists.newArrayList(); + for (int i = 0; i < 30; ++i) { + elements.add(new TableRow().set("number", i)); + } - @Test - public void testTriggeredFileLoadsWithTempTables() throws Exception { - testTriggeredFileLoadsWithTempTables("project-id:dataset-id.table-id"); - } + WriteResult result = + p.apply(Create.of(elements).withCoder(TableRowJsonCoder.of())) + .apply( + BigQueryIO.writeTableRows() + .to("dataset-id.table-id") + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withSchema( + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("name").setType("STRING"), + new TableFieldSchema().setName("number").setType("INTEGER")))) + .withTestServices(fakeBqServices) + .withoutValidation()); - @Test - public void testTriggeredFileLoadsWithTempTablesToExistingNullSchemaTable() throws Exception { - Table fakeTable = new Table(); - TableReference ref = - new TableReference() - .setProjectId("project-id") - .setDatasetId("dataset-id") - .setTableId("table-id"); - fakeTable.setTableReference(ref); - fakeDatasetService.createTable(fakeTable); - testTriggeredFileLoadsWithTempTables("project-id:dataset-id.table-id"); + PAssert.that(result.getSuccessfulTableLoads()) + .containsInAnyOrder(new TableDestination("project-id:dataset-id.table-id", null)); + p.run(); + + assertThat( + fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id"), + containsInAnyOrder(Iterables.toArray(elements, TableRow.class))); } @Test - public void testUntriggeredFileLoadsWithTempTables() throws Exception { + public void testBatchFileLoadsWithTempTables() throws Exception { // Test only non-streaming inserts. assumeTrue(!useStorageApi); assumeTrue(!useStreaming); @@ -771,19 +810,23 @@ public void testUntriggeredFileLoadsWithTempTables() throws Exception { for (int i = 0; i < 30; ++i) { elements.add(new TableRow().set("number", i)); } - p.apply(Create.of(elements)) - .apply( - BigQueryIO.writeTableRows() - .to("project-id:dataset-id.table-id") - .withSchema( - new TableSchema() - .setFields( - ImmutableList.of( - new TableFieldSchema().setName("number").setType("INTEGER")))) - .withTestServices(fakeBqServices) - .withMaxBytesPerPartition(1) - .withMaxFilesPerPartition(1) - .withoutValidation()); + WriteResult result = + p.apply(Create.of(elements)) + .apply( + BigQueryIO.writeTableRows() + .to("project-id:dataset-id.table-id") + .withSchema( + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("number").setType("INTEGER")))) + .withTestServices(fakeBqServices) + .withMaxBytesPerPartition(1) + .withMaxFilesPerPartition(1) + .withoutValidation()); + + PAssert.that(result.getSuccessfulTableLoads()) + .containsInAnyOrder(new TableDestination("project-id:dataset-id.table-id", null)); p.run(); assertThat( @@ -792,12 +835,7 @@ public void testUntriggeredFileLoadsWithTempTables() throws Exception { } @Test - public void testTriggeredFileLoadsWithTempTablesDefaultProject() throws Exception { - testTriggeredFileLoadsWithTempTables("dataset-id.table-id"); - } - - @Test - public void testTriggeredFileLoadsWithTempTablesCreateNever() throws Exception { + public void testBatchFileLoadsWithTempTablesCreateNever() throws Exception { assumeTrue(!useStorageApi); assumeTrue(!useStreaming); @@ -841,77 +879,7 @@ public void testTriggeredFileLoadsWithTempTablesCreateNever() throws Exception { } @Test - public void testTriggeredFileLoadsWithAutoSharding() throws Exception { - assumeTrue(!useStorageApi); - assumeTrue(useStreaming); - List elements = Lists.newArrayList(); - for (int i = 0; i < 30; ++i) { - elements.add(new TableRow().set("number", i)); - } - - Instant startInstant = new Instant(0L); - TestStream testStream = - TestStream.create(TableRowJsonCoder.of()) - // Initialize watermark for timer to be triggered correctly. - .advanceWatermarkTo(startInstant) - .addElements( - elements.get(0), Iterables.toArray(elements.subList(1, 10), TableRow.class)) - .advanceProcessingTime(Duration.standardMinutes(1)) - .advanceWatermarkTo(startInstant.plus(Duration.standardSeconds(10))) - .addElements( - elements.get(10), Iterables.toArray(elements.subList(11, 20), TableRow.class)) - .advanceProcessingTime(Duration.standardMinutes(1)) - .advanceWatermarkTo(startInstant.plus(Duration.standardSeconds(30))) - .addElements( - elements.get(20), Iterables.toArray(elements.subList(21, 30), TableRow.class)) - .advanceProcessingTime(Duration.standardMinutes(2)) - .advanceWatermarkToInfinity(); - - int numTables = 3; - p.apply(testStream) - .apply( - BigQueryIO.writeTableRows() - .to( - (ValueInSingleWindow vsw) -> { - String tableSpec = - "project-id:dataset-id.table-" - + ((int) vsw.getValue().get("number") % numTables); - return new TableDestination(tableSpec, null); - }) - .withSchema( - new TableSchema() - .setFields( - ImmutableList.of( - new TableFieldSchema().setName("number").setType("INTEGER")))) - .withTestServices(fakeBqServices) - // Set a triggering frequency without needing to also specify numFileShards when - // using autoSharding. - .withTriggeringFrequency(Duration.standardSeconds(100)) - .withAutoSharding() - .withMaxBytesPerPartition(1000) - .withMaxFilesPerPartition(10) - .withMethod(BigQueryIO.Write.Method.FILE_LOADS) - .withoutValidation()); - p.run(); - - Map> elementsByTableIdx = new HashMap<>(); - for (int i = 0; i < elements.size(); i++) { - elementsByTableIdx - .computeIfAbsent(i % numTables, k -> new ArrayList<>()) - .add(elements.get(i)); - } - for (Map.Entry> entry : elementsByTableIdx.entrySet()) { - assertThat( - fakeDatasetService.getAllRows("project-id", "dataset-id", "table-" + entry.getKey()), - containsInAnyOrder(Iterables.toArray(entry.getValue(), TableRow.class))); - } - // For each table destination, it's expected to create two load jobs based on the triggering - // frequency and processing time intervals. - assertEquals(2 * numTables, fakeDatasetService.getInsertCount()); - } - - @Test - public void testFailuresNoRetryPolicy() throws Exception { + public void testStreamingInsertsFailuresNoRetryPolicy() throws Exception { assumeTrue(!useStorageApi); assumeTrue(useStreaming); TableRow row1 = new TableRow().set("name", "a").set("number", "1"); @@ -949,7 +917,7 @@ public void testFailuresNoRetryPolicy() throws Exception { } @Test - public void testRetryPolicy() throws Exception { + public void testStreamingInsertsRetryPolicy() throws Exception { assumeTrue(!useStorageApi); assumeTrue(useStreaming); TableRow row1 = new TableRow().set("name", "a").set("number", "1"); @@ -1023,70 +991,6 @@ public void testWrite() throws Exception { p.run(); } - @Test - public void testWriteWithSuccessfulBatchInserts() throws Exception { - assumeTrue(!useStreaming); - assumeTrue(!useStorageApi); - - WriteResult result = - p.apply( - Create.of( - new TableRow().set("name", "a").set("number", 1), - new TableRow().set("name", "b").set("number", 2), - new TableRow().set("name", "c").set("number", 3)) - .withCoder(TableRowJsonCoder.of())) - .apply( - BigQueryIO.writeTableRows() - .to("dataset-id.table-id") - .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) - .withSchema( - new TableSchema() - .setFields( - ImmutableList.of( - new TableFieldSchema().setName("name").setType("STRING"), - new TableFieldSchema().setName("number").setType("INTEGER")))) - .withTestServices(fakeBqServices) - .withoutValidation()); - - PAssert.that(result.getSuccessfulTableLoads()) - .containsInAnyOrder(new TableDestination("project-id:dataset-id.table-id", null)); - - p.run(); - } - - @Test - public void testWriteWithSuccessfulBatchInsertsAndWriteRename() throws Exception { - assumeTrue(!useStreaming); - assumeTrue(!useStorageApi); - - WriteResult result = - p.apply( - Create.of( - new TableRow().set("name", "a").set("number", 1), - new TableRow().set("name", "b").set("number", 2), - new TableRow().set("name", "c").set("number", 3)) - .withCoder(TableRowJsonCoder.of())) - .apply( - BigQueryIO.writeTableRows() - .to("dataset-id.table-id") - .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) - .withSchema( - new TableSchema() - .setFields( - ImmutableList.of( - new TableFieldSchema().setName("name").setType("STRING"), - new TableFieldSchema().setName("number").setType("INTEGER")))) - .withMaxFileSize(1) - .withMaxFilesPerPartition(1) - .withTestServices(fakeBqServices) - .withoutValidation()); - - PAssert.that(result.getSuccessfulTableLoads()) - .containsInAnyOrder(new TableDestination("project-id:dataset-id.table-id", null)); - - p.run(); - } - @Test public void testWriteWithoutInsertId() throws Exception { assumeTrue(!useStorageApi); @@ -1298,9 +1202,13 @@ public void testStreamingWriteWithAutoSharding() throws Exception { } private void streamingWrite(boolean autoSharding) throws Exception { - if (!useStreaming) { - return; - } + assumeTrue(useStreaming); + List elements = + ImmutableList.of( + new TableRow().set("name", "a").set("number", "1"), + new TableRow().set("name", "b").set("number", "2"), + new TableRow().set("name", "c").set("number", "3"), + new TableRow().set("name", "d").set("number", "4")); BigQueryIO.Write write = BigQueryIO.writeTableRows() .to("project-id:dataset-id.table-id") @@ -1316,33 +1224,39 @@ private void streamingWrite(boolean autoSharding) throws Exception { if (autoSharding) { write = write.withAutoSharding(); } - p.apply( - Create.of( - new TableRow().set("name", "a").set("number", "1"), - new TableRow().set("name", "b").set("number", "2"), - new TableRow().set("name", "c").set("number", "3"), - new TableRow().set("name", "d").set("number", "4")) - .withCoder(TableRowJsonCoder.of())) + p.apply(Create.of(elements).withCoder(TableRowJsonCoder.of())) .setIsBoundedInternal(PCollection.IsBounded.UNBOUNDED) .apply("WriteToBQ", write); p.run(); assertThat( fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id"), - containsInAnyOrder( - new TableRow().set("name", "a").set("number", "1"), - new TableRow().set("name", "b").set("number", "2"), - new TableRow().set("name", "c").set("number", "3"), - new TableRow().set("name", "d").set("number", "4"))); - } - - @Test - public void testStorageApiWriteWithAutoSharding() throws Exception { - storageWrite(true); + containsInAnyOrder(Iterables.toArray(elements, TableRow.class))); } private void storageWrite(boolean autoSharding) throws Exception { assumeTrue(useStorageApi); + if (autoSharding) { + assumeTrue(!useStorageApiApproximate); + assumeTrue(useStreaming); + } + List elements = Lists.newArrayList(); + for (int i = 0; i < 30; ++i) { + elements.add(new TableRow().set("number", String.valueOf(i))); + } + + TestStream testStream = + TestStream.create(TableRowJsonCoder.of()) + .addElements( + elements.get(0), Iterables.toArray(elements.subList(1, 10), TableRow.class)) + .advanceProcessingTime(Duration.standardMinutes(1)) + .addElements( + elements.get(10), Iterables.toArray(elements.subList(11, 20), TableRow.class)) + .advanceProcessingTime(Duration.standardMinutes(1)) + .addElements( + elements.get(20), Iterables.toArray(elements.subList(21, 30), TableRow.class)) + .advanceWatermarkToInfinity(); + BigQueryIO.Write write = BigQueryIO.writeTableRows() .to("project-id:dataset-id.table-id") @@ -1351,35 +1265,50 @@ private void storageWrite(boolean autoSharding) throws Exception { new TableSchema() .setFields( ImmutableList.of( - new TableFieldSchema().setName("name").setType("STRING"), new TableFieldSchema().setName("number").setType("INTEGER")))) .withTestServices(fakeBqServices) .withoutValidation(); - if (autoSharding) { - write = - write - .withAutoSharding() - .withTriggeringFrequency(Duration.standardSeconds(5)) - .withMethod(Method.STORAGE_WRITE_API); + + if (useStreaming) { + if (!useStorageApiApproximate) { + write = + write + .withTriggeringFrequency(Duration.standardSeconds(30)) + .withNumStorageWriteApiStreams(2); + } + if (autoSharding) { + write = write.withAutoSharding(); + } } - p.apply( - Create.of( - new TableRow().set("name", "a").set("number", "1"), - new TableRow().set("name", "b").set("number", "2"), - new TableRow().set("name", "c").set("number", "3"), - new TableRow().set("name", "d").set("number", "4")) - .withCoder(TableRowJsonCoder.of())) - .setIsBoundedInternal(PCollection.IsBounded.UNBOUNDED) - .apply("WriteToBQ", write); - p.run(); + + PTransform> source = + useStreaming ? testStream : Create.of(elements).withCoder(TableRowJsonCoder.of()); + + p.apply(source).apply("WriteToBQ", write); + p.run().waitUntilFinish(); assertThat( fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id"), - containsInAnyOrder( - new TableRow().set("name", "a").set("number", "1"), - new TableRow().set("name", "b").set("number", "2"), - new TableRow().set("name", "c").set("number", "3"), - new TableRow().set("name", "d").set("number", "4"))); + containsInAnyOrder(Iterables.toArray(elements, TableRow.class))); + } + + @Test + public void testBatchStorageApiWrite() throws Exception { + assumeTrue(!useStreaming); + storageWrite(false); + } + + @Test + public void testStreamingStorageApiWrite() throws Exception { + assumeTrue(useStreaming); + storageWrite(false); + } + + @Test + public void testStreamingStorageApiWriteWithAutoSharding() throws Exception { + assumeTrue(useStreaming); + assumeTrue(!useStorageApiApproximate); + storageWrite(true); } @DefaultSchema(JavaFieldSchema.class) @@ -1398,7 +1327,7 @@ static class SchemaPojo { public void testSchemaWriteLoads() throws Exception { assumeTrue(!useStreaming); // withMethod overrides the pipeline option, so we need to explicitly request - // STORAGE_API_WRITES. + // STORAGE_WRITE_API. BigQueryIO.Write.Method method = useStorageApi ? (useStorageApiApproximate @@ -2228,6 +2157,7 @@ public void testWriteValidateFailsWithBatchAutoSharding() { @Test public void testMaxRetryJobs() { + assumeTrue(!useStorageApi); BigQueryIO.Write write = BigQueryIO.writeTableRows() .to("dataset.table") @@ -2692,7 +2622,7 @@ public void testWriteToTableDecorator() throws Exception { } @Test - public void testExtendedErrorRetrieval() throws Exception { + public void testStreamingInsertsExtendedErrorRetrieval() throws Exception { assumeTrue(!useStorageApi); TableRow row1 = new TableRow().set("name", "a").set("number", "1"); TableRow row2 = new TableRow().set("name", "b").set("number", "2"); @@ -2850,7 +2780,7 @@ public void testStorageApiErrors() throws Exception { } @Test - public void testWrongErrorConfigs() { + public void testStreamingInsertsWrongErrorConfigs() { assumeTrue(!useStorageApi); p.enableAutoRunIfMissing(true); TableRow row1 = new TableRow().set("name", "a").set("number", "1"); @@ -2914,6 +2844,7 @@ public void testWrongErrorConfigs() { void schemaUpdateOptionsTest( BigQueryIO.Write.Method insertMethod, Set schemaUpdateOptions) throws Exception { + assumeTrue(!useStorageApi); TableRow row = new TableRow().set("date", "2019-01-01").set("number", "1"); TableSchema schema = @@ -3003,32 +2934,6 @@ public void testWriteWithStorageApiWithDefaultProject() throws Exception { containsInAnyOrder(new TableRow().set("name", "a"), new TableRow().set("name", "b"))); } - @Test - public void testWriteWithStorageApiWithoutSettingShardsEnableAutoSharding() throws Exception { - assumeTrue(useStorageApi); - assumeTrue(p.getOptions().as(BigQueryOptions.class).getNumStorageWriteApiStreams() == 0); - BigQueryIO.Write write = - BigQueryIO.writeTableRows() - .to("dataset-id.table-id") - .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) - .withSchema( - new TableSchema() - .setFields( - ImmutableList.of(new TableFieldSchema().setName("name").setType("STRING")))) - .withMethod(Method.STORAGE_WRITE_API) - .withoutValidation() - .withTestServices(fakeBqServices); - - p.apply( - Create.of(new TableRow().set("name", "a"), new TableRow().set("name", "b")) - .withCoder(TableRowJsonCoder.of())) - .apply("WriteToBQ", write); - p.run(); - assertThat( - fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id"), - containsInAnyOrder(new TableRow().set("name", "a"), new TableRow().set("name", "b"))); - } - @Test public void testBatchStorageWriteWithMultipleAppendsPerStream() throws Exception { assumeTrue(useStorageApi); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoTest.java index 58f181700d7e..90be99fce84c 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoTest.java @@ -114,6 +114,7 @@ public class TableRowToStorageApiProtoTest { .setType("TIMESTAMP") .setName("timestampValueSpaceTrailingZero")) .add(new TableFieldSchema().setType("DATETIME").setName("datetimeValueSpace")) + .add(new TableFieldSchema().setType("TIMESTAMP").setName("timestampValueMaximum")) .build()); private static final TableSchema BASE_TABLE_SCHEMA_NO_F = @@ -163,6 +164,7 @@ public class TableRowToStorageApiProtoTest { .setType("TIMESTAMP") .setName("timestampValueSpaceTrailingZero")) .add(new TableFieldSchema().setType("DATETIME").setName("datetimeValueSpace")) + .add(new TableFieldSchema().setType("TIMESTAMP").setName("timestampValueMaximum")) .build()); private static final DescriptorProto BASE_TABLE_SCHEMA_PROTO = @@ -356,6 +358,13 @@ public class TableRowToStorageApiProtoTest { .setType(Type.TYPE_INT64) .setLabel(Label.LABEL_OPTIONAL) .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("timestampvaluemaximum") + .setNumber(28) + .setType(Type.TYPE_INT64) + .setLabel(Label.LABEL_OPTIONAL) + .build()) .build(); private static final DescriptorProto BASE_TABLE_SCHEMA_NO_F_PROTO = @@ -542,6 +551,13 @@ public class TableRowToStorageApiProtoTest { .setType(Type.TYPE_INT64) .setLabel(Label.LABEL_OPTIONAL) .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("timestampvaluemaximum") + .setNumber(27) + .setType(Type.TYPE_INT64) + .setLabel(Label.LABEL_OPTIONAL) + .build()) .build(); private static final TableSchema NESTED_TABLE_SCHEMA = new TableSchema() @@ -689,7 +705,8 @@ public void testNestedFromTableSchema() { new TableCell().setV("1970-01-01 00:00:00.123456 America/New_York"), new TableCell().setV("1970-01-01 00:00:00.123"), new TableCell().setV("1970-01-01 00:00:00.1230"), - new TableCell().setV("2019-08-16 00:52:07.123456"))); + new TableCell().setV("2019-08-16 00:52:07.123456"), + new TableCell().setV("9999-12-31 23:59:59.999999Z"))); private static final TableRow BASE_TABLE_ROW_NO_F = new TableRow() @@ -721,7 +738,8 @@ public void testNestedFromTableSchema() { .set("timestampValueZoneRegion", "1970-01-01 00:00:00.123456 America/New_York") .set("timestampValueSpaceMilli", "1970-01-01 00:00:00.123") .set("timestampValueSpaceTrailingZero", "1970-01-01 00:00:00.1230") - .set("datetimeValueSpace", "2019-08-16 00:52:07.123456"); + .set("datetimeValueSpace", "2019-08-16 00:52:07.123456") + .set("timestampValueMaximum", "9999-12-31 23:59:59.999999Z"); private static final Map BASE_ROW_EXPECTED_PROTO_VALUES = ImmutableMap.builder() @@ -761,6 +779,7 @@ public void testNestedFromTableSchema() { .put("timestampvaluespacemilli", 123000L) .put("timestampvaluespacetrailingzero", 123000L) .put("datetimevaluespace", 142111881387172416L) + .put("timestampvaluemaximum", 253402300799999999L) .build(); private static final Map BASE_ROW_NO_F_EXPECTED_PROTO_VALUES = @@ -800,6 +819,7 @@ public void testNestedFromTableSchema() { .put("timestampvaluespacemilli", 123000L) .put("timestampvaluespacetrailingzero", 123000L) .put("datetimevaluespace", 142111881387172416L) + .put("timestampvaluemaximum", 253402300799999999L) .build(); private void assertBaseRecord(DynamicMessage msg, boolean withF) { diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/healthcare/FhirIOReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/healthcare/FhirIOReadIT.java index 142b26dd2cdf..a9db2fc240eb 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/healthcare/FhirIOReadIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/healthcare/FhirIOReadIT.java @@ -133,7 +133,7 @@ public void testFhirIORead() throws Exception { "waitForAnyMessage", signal.signalSuccessWhen(resources.getCoder(), anyResources -> true)); // wait for any resource - Supplier start = signal.waitForStart(Duration.standardMinutes(5)); + Supplier start = signal.waitForStart(Duration.standardMinutes(8)); pipeline.apply(signal.signalStart()); PipelineResult job = pipeline.run(); start.get(); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadIT.java index 7b370ebf7e2c..193ba5e19c40 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadIT.java @@ -51,7 +51,7 @@ public void testReadPublicData() throws Exception { messages.apply( "waitForAnyMessage", signal.signalSuccessWhen(messages.getCoder(), anyMessages -> true)); - Supplier start = signal.waitForStart(Duration.standardMinutes(5)); + Supplier start = signal.waitForStart(Duration.standardMinutes(8)); pipeline.apply(signal.signalStart()); PipelineResult job = pipeline.run(); start.get(); @@ -79,7 +79,7 @@ public void testReadPubsubMessageId() throws Exception { "isMessageIdNonNull", signal.signalSuccessWhen(messages.getCoder(), new NonEmptyMessageIdCheck())); - Supplier start = signal.waitForStart(Duration.standardMinutes(5)); + Supplier start = signal.waitForStart(Duration.standardMinutes(8)); pipeline.apply(signal.signalStart()); PipelineResult job = pipeline.run(); start.get(); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProviderTest.java index 848549f19298..0de998f11127 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProviderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProviderTest.java @@ -18,300 +18,237 @@ package org.apache.beam.sdk.io.gcp.pubsub; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import com.google.api.client.util.Clock; -import com.google.gson.Gson; -import com.google.gson.JsonObject; -import com.google.gson.JsonPrimitive; import com.google.protobuf.ByteString; import com.google.protobuf.Timestamp; import java.io.IOException; import java.io.Serializable; -import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Objects; import java.util.UUID; import java.util.stream.Collectors; +import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.extensions.avro.schemas.io.payloads.AvroPayloadSerializerProvider; -import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubTestClient.PubsubTestClientFactory; +import org.apache.beam.sdk.metrics.MetricNameFilter; +import org.apache.beam.sdk.metrics.MetricQueryResults; +import org.apache.beam.sdk.metrics.MetricResult; +import org.apache.beam.sdk.metrics.MetricResults; +import org.apache.beam.sdk.metrics.MetricsFilter; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializer; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Tests for {@link PubsubReadSchemaTransformProvider}. */ +/** Tests for {@link org.apache.beam.sdk.io.gcp.pubsub.PubsubReadSchemaTransformProvider}. */ @RunWith(JUnit4.class) public class PubsubReadSchemaTransformProviderTest { - private static final Schema SCHEMA = + private static final Schema BEAM_SCHEMA = Schema.of( Schema.Field.of("name", Schema.FieldType.STRING), Schema.Field.of("number", Schema.FieldType.INT64)); - + private static final Schema BEAM_SCHEMA_WITH_ERROR = + Schema.of(Schema.Field.of("error", Schema.FieldType.STRING)); + private static final String SCHEMA = AvroUtils.toAvroSchema(BEAM_SCHEMA).toString(); private static final String SUBSCRIPTION = "projects/project/subscriptions/subscription"; private static final String TOPIC = "projects/project/topics/topic"; - private static final List cases = - Arrays.asList( - testCase( - "no configured topic or subscription", - PubsubReadSchemaTransformConfiguration.builder().setDataSchema(SCHEMA).build()) - .expectInvalidConfiguration(), - testCase( - "both topic and subscription configured", - PubsubReadSchemaTransformConfiguration.builder() - .setSubscription(SUBSCRIPTION) - .setSubscription(TOPIC) - .setDataSchema(SCHEMA) - .build()) - .expectInvalidConfiguration(), - testCase( - "invalid format configured", - PubsubReadSchemaTransformConfiguration.builder() - .setSubscription(SUBSCRIPTION) - .setDataSchema(SCHEMA) - .setFormat("invalidformat") - .build()) - .expectInvalidConfiguration(), - testCase( - "configuration with subscription", - PubsubReadSchemaTransformConfiguration.builder() - .setSubscription(SUBSCRIPTION) - .setDataSchema(SCHEMA) - .build()) - .withExpectedPubsubRead(PubsubIO.readMessages().fromSubscription(SUBSCRIPTION)), - testCase( - "configuration with topic", - PubsubReadSchemaTransformConfiguration.builder() - .setTopic(TOPIC) - .setDataSchema(SCHEMA) - .build()) - .withExpectedPubsubRead(PubsubIO.readMessages().fromTopic(TOPIC)), - testCase( - "configuration with subscription, timestamp and id attributes", - PubsubReadSchemaTransformConfiguration.builder() - .setSubscription(SUBSCRIPTION) - .setTimestampAttribute("timestampAttribute") - .setIdAttribute("idAttribute") - .setDataSchema(SCHEMA) - .build()) - .withExpectedPubsubRead( - PubsubIO.readMessages() - .fromSubscription(SUBSCRIPTION) - .withTimestampAttribute("timestampAttribute") - .withIdAttribute("idAttribute")), - testCase( - "configuration with subscription and dead letter queue", - PubsubReadSchemaTransformConfiguration.builder() - .setSubscription(SUBSCRIPTION) - .setDataSchema(SCHEMA) - .setDeadLetterQueue(TOPIC) - .build()) - .withExpectedPubsubRead(PubsubIO.readMessages().fromSubscription(SUBSCRIPTION)) - .withExpectedDeadLetterQueue(PubsubIO.writeMessages().to(TOPIC)), - testCase( - "configuration with subscription, timestamp attribute, and dead letter queue", - PubsubReadSchemaTransformConfiguration.builder() - .setSubscription(SUBSCRIPTION) - .setTimestampAttribute("timestampAttribute") - .setDataSchema(SCHEMA) - .setDeadLetterQueue(TOPIC) - .build()) - .withExpectedPubsubRead( - PubsubIO.readMessages() - .fromSubscription(SUBSCRIPTION) - .withTimestampAttribute("timestampAttribute")) - .withExpectedDeadLetterQueue( - PubsubIO.writeMessages().to(TOPIC).withTimestampAttribute("timestampAttribute"))); - - private static final AutoValueSchema AUTO_VALUE_SCHEMA = new AutoValueSchema(); - private static final TypeDescriptor TYPE_DESCRIPTOR = - TypeDescriptor.of(PubsubReadSchemaTransformConfiguration.class); - private static final SerializableFunction - ROW_SERIALIZABLE_FUNCTION = AUTO_VALUE_SCHEMA.toRowFunction(TYPE_DESCRIPTOR); - private static final List ROWS = Arrays.asList( - Row.withSchema(SCHEMA).withFieldValue("name", "a").withFieldValue("number", 100L).build(), - Row.withSchema(SCHEMA).withFieldValue("name", "b").withFieldValue("number", 200L).build(), - Row.withSchema(SCHEMA) + Row.withSchema(BEAM_SCHEMA) + .withFieldValue("name", "a") + .withFieldValue("number", 100L) + .build(), + Row.withSchema(BEAM_SCHEMA) + .withFieldValue("name", "b") + .withFieldValue("number", 200L) + .build(), + Row.withSchema(BEAM_SCHEMA) .withFieldValue("name", "c") .withFieldValue("number", 300L) .build()); - private static final Clock CLOCK = (Clock & Serializable) () -> 1656788475425L; + private static final List ROWSWITHERROR = + Arrays.asList( + Row.withSchema(BEAM_SCHEMA_WITH_ERROR).withFieldValue("error", "a").build(), + Row.withSchema(BEAM_SCHEMA_WITH_ERROR).withFieldValue("error", "b").build(), + Row.withSchema(BEAM_SCHEMA_WITH_ERROR).withFieldValue("error", "c").build()); + + private static final Clock CLOCK = (Clock & Serializable) () -> 1678988970000L; private static final AvroPayloadSerializerProvider AVRO_PAYLOAD_SERIALIZER_PROVIDER = new AvroPayloadSerializerProvider(); private static final PayloadSerializer AVRO_PAYLOAD_SERIALIZER = - AVRO_PAYLOAD_SERIALIZER_PROVIDER.getSerializer(SCHEMA, new HashMap<>()); + AVRO_PAYLOAD_SERIALIZER_PROVIDER.getSerializer(BEAM_SCHEMA, new HashMap<>()); + private static final PayloadSerializer AVRO_PAYLOAD_SERIALIZER_WITH_ERROR = + AVRO_PAYLOAD_SERIALIZER_PROVIDER.getSerializer(BEAM_SCHEMA_WITH_ERROR, new HashMap<>()); - @Rule public TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false); + @Rule public transient TestPipeline p = TestPipeline.create(); @Test - public void testBuildDeadLetterQueueWrite() { - for (TestCase testCase : cases) { - PubsubIO.Write dlq = - testCase.pubsubReadSchemaTransform().buildDeadLetterQueueWrite(); - - if (testCase.expectedDeadLetterQueue == null) { - assertNull(testCase.name, dlq); - return; - } - - Map actual = DisplayData.from(dlq).asMap(); - Map expected = testCase.expectedDeadLetterQueue; - - assertEquals(testCase.name, expected, actual); - } + public void testInvalidConfigNoTopicOrSubscription() { + assertThrows( + IllegalArgumentException.class, + () -> + new PubsubReadSchemaTransformProvider() + .from( + PubsubReadSchemaTransformConfiguration.builder() + .setSchema(SCHEMA) + .setFormat("AVRO") + .build())); } @Test - public void testReadAvro() throws IOException { + public void testInvalidConfigBothTopicAndSubscription() { PCollectionRowTuple begin = PCollectionRowTuple.empty(p); - PubsubReadSchemaTransformProvider.PubsubReadSchemaTransform transform = - schemaTransformWithClock("avro"); - PubsubTestClient.PubsubTestClientFactory clientFactory = - clientFactory(incomingAvroMessagesOf(CLOCK.currentTimeMillis())); - transform.setClientFactory(clientFactory); - PCollectionRowTuple reads = begin.apply(transform); - - PAssert.that(reads.get(PubsubReadSchemaTransformProvider.OUTPUT_TAG)).containsInAnyOrder(ROWS); - + assertThrows( + IllegalArgumentException.class, + () -> + begin.apply( + new PubsubReadSchemaTransformProvider() + .from( + PubsubReadSchemaTransformConfiguration.builder() + .setSchema(SCHEMA) + .setFormat("AVRO") + .setTopic(TOPIC) + .setSubscription(SUBSCRIPTION) + .build()))); p.run().waitUntilFinish(); - clientFactory.close(); } @Test - public void testReadJson() throws IOException { + public void testInvalidConfigInvalidFormat() { PCollectionRowTuple begin = PCollectionRowTuple.empty(p); - PubsubReadSchemaTransformProvider.PubsubReadSchemaTransform transform = - schemaTransformWithClock("json"); - PubsubTestClient.PubsubTestClientFactory clientFactory = - clientFactory(incomingJsonMessagesOf(CLOCK.currentTimeMillis())); - transform.setClientFactory(clientFactory); - PCollectionRowTuple reads = begin.apply(transform); - - PAssert.that(reads.get(PubsubReadSchemaTransformProvider.OUTPUT_TAG)).containsInAnyOrder(ROWS); - + assertThrows( + IllegalArgumentException.class, + () -> + begin.apply( + new PubsubReadSchemaTransformProvider() + .from( + PubsubReadSchemaTransformConfiguration.builder() + .setSchema(SCHEMA) + .setFormat("BadFormat") + .setSubscription(SUBSCRIPTION) + .build()))); p.run().waitUntilFinish(); - - clientFactory.close(); - } - - @Test - public void testBuildPubSubRead() { - for (TestCase testCase : cases) { - if (testCase.invalidConfigurationExpected) { - continue; - } - Map actual = - DisplayData.from(testCase.pubsubReadSchemaTransform().buildPubsubRead()).asMap(); - - Map expected = testCase.expectedPubsubRead; - - assertEquals(testCase.name, expected, actual); - } - } - - @Test - public void testInvalidConfiguration() { - for (TestCase testCase : cases) { - PCollectionRowTuple begin = PCollectionRowTuple.empty(p); - if (testCase.invalidConfigurationExpected) { - assertThrows( - testCase.name, - RuntimeException.class, - () -> begin.apply(testCase.pubsubReadSchemaTransform())); - } - } } @Test - public void testInvalidInput() { - PCollectionRowTuple begin = PCollectionRowTuple.of("BadInput", p.apply(Create.of(ROWS))); + public void testNoSchema() { + PCollectionRowTuple begin = PCollectionRowTuple.empty(p); assertThrows( - IllegalArgumentException.class, + IllegalStateException.class, () -> begin.apply( new PubsubReadSchemaTransformProvider() .from( PubsubReadSchemaTransformConfiguration.builder() - .setDataSchema(SCHEMA) + .setSubscription(SUBSCRIPTION) + .setFormat("AVRO") .build()))); + p.run().waitUntilFinish(); } - private PubsubReadSchemaTransformProvider.PubsubReadSchemaTransform schemaTransformWithClock( - String format) { - PubsubReadSchemaTransformProvider.PubsubReadSchemaTransform transform = - (PubsubReadSchemaTransformProvider.PubsubReadSchemaTransform) - new PubsubReadSchemaTransformProvider() - .from( - PubsubReadSchemaTransformConfiguration.builder() - .setDataSchema(SCHEMA) - .setSubscription(SUBSCRIPTION) - .setFormat(format) - .build()); - - transform.setClock(CLOCK); + @Test + public void testReadAvro() throws IOException { + PCollectionRowTuple begin = PCollectionRowTuple.empty(p); - return transform; + try (PubsubTestClientFactory clientFactory = clientFactory(beamRowToMessage())) { + PubsubReadSchemaTransformConfiguration config = + PubsubReadSchemaTransformConfiguration.builder() + .setFormat("AVRO") + .setSchema(SCHEMA) + .setSubscription(SUBSCRIPTION) + .setClientFactory(clientFactory) + .setClock(CLOCK) + .build(); + SchemaTransform transform = new PubsubReadSchemaTransformProvider().from(config); + PCollectionRowTuple reads = begin.apply(transform); + + PAssert.that(reads.get("output")).containsInAnyOrder(ROWS); + + p.run().waitUntilFinish(); + } catch (Exception e) { + throw e; + } } - private static PubsubTestClient.PubsubTestClientFactory clientFactory( - List messages) { - return PubsubTestClient.createFactoryForPull( - CLOCK, PubsubClient.subscriptionPathFromPath(SUBSCRIPTION), 60, messages); - } + @Test + public void testReadAvroWithError() throws IOException { + PCollectionRowTuple begin = PCollectionRowTuple.empty(p); - private static List incomingAvroMessagesOf(long millisSinceEpoch) { - return ROWS.stream() - .map(row -> incomingAvroMessageOf(row, millisSinceEpoch)) - .collect(Collectors.toList()); - } + try (PubsubTestClientFactory clientFactory = clientFactory(beamRowToMessageWithError())) { + PubsubReadSchemaTransformConfiguration config = + PubsubReadSchemaTransformConfiguration.builder() + .setFormat("AVRO") + .setSchema(SCHEMA) + .setSubscription(SUBSCRIPTION) + .setClientFactory(clientFactory) + .setClock(CLOCK) + .build(); + SchemaTransform transform = new PubsubReadSchemaTransformProvider().from(config); + PCollectionRowTuple reads = begin.apply(transform); + + PAssert.that(reads.get("output")).empty(); + + PipelineResult result = p.run(); + result.waitUntilFinish(); + + MetricResults metrics = result.metrics(); + MetricQueryResults metricResults = + metrics.queryMetrics( + MetricsFilter.builder() + .addNameFilter( + MetricNameFilter.named( + PubsubReadSchemaTransformProvider.class, "PubSub-read-error-counter")) + .build()); + + Iterable> counters = metricResults.getCounters(); + if (!counters.iterator().hasNext()) { + throw new RuntimeException("no counters available "); + } - private static PubsubClient.IncomingMessage incomingAvroMessageOf( - Row row, long millisSinceEpoch) { - byte[] bytes = AVRO_PAYLOAD_SERIALIZER.serialize(row); - return incomingMessageOf(bytes, millisSinceEpoch); + Long expectedCount = 3L; + for (MetricResult count : counters) { + assertEquals(expectedCount, count.getAttempted()); + } + } catch (Exception e) { + throw e; + } } - private static List incomingJsonMessagesOf(long millisSinceEpoch) { - return PubsubReadSchemaTransformProviderTest.ROWS.stream() - .map(row -> incomingJsonMessageOf(row, millisSinceEpoch)) + private static List beamRowToMessage() { + long timestamp = CLOCK.currentTimeMillis(); + return ROWS.stream() + .map( + row -> { + byte[] bytes = AVRO_PAYLOAD_SERIALIZER.serialize(row); + return incomingMessageOf(bytes, timestamp); + }) .collect(Collectors.toList()); } - private static PubsubClient.IncomingMessage incomingJsonMessageOf( - Row row, long millisSinceEpoch) { - String name = Objects.requireNonNull(row.getString("name")); - long number = Objects.requireNonNull(row.getInt64("number")); - return incomingJsonMessageOf(name, number, millisSinceEpoch); - } - - private static PubsubClient.IncomingMessage incomingJsonMessageOf( - String name, long number, long millisSinceEpoch) { - Gson gson = new Gson(); - JsonObject obj = new JsonObject(); - obj.add("name", new JsonPrimitive(name)); - obj.add("number", new JsonPrimitive(number)); - byte[] bytes = gson.toJson(obj).getBytes(StandardCharsets.UTF_8); - return incomingMessageOf(bytes, millisSinceEpoch); + private static List beamRowToMessageWithError() { + long timestamp = CLOCK.currentTimeMillis(); + return ROWSWITHERROR.stream() + .map( + row -> { + byte[] bytes = AVRO_PAYLOAD_SERIALIZER_WITH_ERROR.serialize(row); + return incomingMessageOf(bytes, timestamp); + }) + .collect(Collectors.toList()); } private static PubsubClient.IncomingMessage incomingMessageOf( @@ -329,51 +266,9 @@ private static PubsubClient.IncomingMessage incomingMessageOf( UUID.randomUUID().toString()); } - static TestCase testCase(String name, PubsubReadSchemaTransformConfiguration configuration) { - return new TestCase(name, configuration); - } - - private static class TestCase { - - private final String name; - private final PubsubReadSchemaTransformConfiguration configuration; - - private Map expectedDeadLetterQueue; - - private Map expectedPubsubRead = - DisplayData.from(PubsubIO.readMessages()).asMap(); - - private boolean invalidConfigurationExpected = false; - - TestCase(String name, PubsubReadSchemaTransformConfiguration configuration) { - this.name = name; - this.configuration = configuration; - } - - PubsubReadSchemaTransformProvider.PubsubReadSchemaTransform pubsubReadSchemaTransform() { - PubsubReadSchemaTransformProvider provider = new PubsubReadSchemaTransformProvider(); - Row configurationRow = toBeamRow(); - return (PubsubReadSchemaTransformProvider.PubsubReadSchemaTransform) - provider.from(configurationRow); - } - - private Row toBeamRow() { - return ROW_SERIALIZABLE_FUNCTION.apply(configuration); - } - - TestCase withExpectedDeadLetterQueue(PubsubIO.Write value) { - this.expectedDeadLetterQueue = DisplayData.from(value).asMap(); - return this; - } - - TestCase withExpectedPubsubRead(PubsubIO.Read value) { - this.expectedPubsubRead = DisplayData.from(value).asMap(); - return this; - } - - TestCase expectInvalidConfiguration() { - this.invalidConfigurationExpected = true; - return this; - } + private static PubsubTestClient.PubsubTestClientFactory clientFactory( + List messages) { + return PubsubTestClient.createFactoryForPull( + CLOCK, PubsubClient.subscriptionPathFromPath(SUBSCRIPTION), 60, messages); } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubSchemaTransformMessageToRowFactoryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubSchemaTransformMessageToRowFactoryTest.java deleted file mode 100644 index 709fc35e02ae..000000000000 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubSchemaTransformMessageToRowFactoryTest.java +++ /dev/null @@ -1,337 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.gcp.pubsub; - -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubMessageToRow.ATTRIBUTES_FIELD; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubMessageToRow.PAYLOAD_FIELD; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import org.apache.beam.sdk.extensions.avro.schemas.io.payloads.AvroPayloadSerializerProvider; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.io.payloads.JsonPayloadSerializerProvider; -import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializer; -import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializerProvider; -import org.apache.beam.sdk.values.Row; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Test for {@link PubsubSchemaTransformMessageToRowFactory}. */ -@RunWith(JUnit4.class) -public class PubsubSchemaTransformMessageToRowFactoryTest { - - List cases = - Arrays.asList( - testCase(PubsubReadSchemaTransformConfiguration.builder().setDataSchema(SCHEMA)) - .expectPayloadSerializerProvider(JSON_PAYLOAD_SERIALIZER_PROVIDER) - .withSerializerInput(), - testCase(PubsubReadSchemaTransformConfiguration.builder().setDataSchema(SCHEMA)) - .expectPubsubToRow( - PubsubMessageToRow.builder() - .messageSchema(SCHEMA) - .useFlatSchema(true) - .useDlq(false)), - testCase( - PubsubReadSchemaTransformConfiguration.builder() - .setDataSchema(SCHEMA) - .setDeadLetterQueue("projects/project/topics/topic")) - .expectPubsubToRow( - PubsubMessageToRow.builder() - .messageSchema(SCHEMA) - .useFlatSchema(true) - .useDlq(true)), - testCase( - PubsubReadSchemaTransformConfiguration.builder() - .setDataSchema(SCHEMA) - .setFormat("avro")) - .expectPayloadSerializerProvider(AVRO_PAYLOAD_SERIALIZER_PROVIDER) - .withSerializerInput(), - testCase( - PubsubReadSchemaTransformConfiguration.builder() - .setDataSchema(Schema.of(ATTRIBUTES_FIELD_ARRAY))) - .schemaShouldHaveValidAttributesField() - .fieldShouldBePresent( - ATTRIBUTES_FIELD_ARRAY.getName(), ATTRIBUTES_FIELD_ARRAY.getType()), - testCase( - PubsubReadSchemaTransformConfiguration.builder() - .setDataSchema(Schema.of(ATTRIBUTES_FIELD_MAP))) - .schemaShouldHaveValidAttributesField() - .fieldShouldBePresent(ATTRIBUTES_FIELD_MAP.getName(), ATTRIBUTES_FIELD_MAP.getType()), - testCase( - PubsubReadSchemaTransformConfiguration.builder() - .setDataSchema(Schema.of(ATTRIBUTES_FIELD_SHOULD_NOT_MATCH))), - testCase( - PubsubReadSchemaTransformConfiguration.builder() - .setDataSchema(Schema.of(PAYLOAD_FIELD_SHOULD_NOT_MATCH))), - testCase( - PubsubReadSchemaTransformConfiguration.builder() - .setDataSchema(Schema.of(PAYLOAD_FIELD_BYTES))) - .schemaShouldHaveValidPayloadField() - .fieldShouldBePresent(PAYLOAD_FIELD_BYTES.getName(), PAYLOAD_FIELD_BYTES.getType()), - testCase( - PubsubReadSchemaTransformConfiguration.builder() - .setDataSchema(Schema.of(PAYLOAD_FIELD_ROW))) - .schemaShouldHaveValidPayloadField() - .fieldShouldBePresent(PAYLOAD_FIELD_ROW.getName(), PAYLOAD_FIELD_ROW.getType()), - testCase( - PubsubReadSchemaTransformConfiguration.builder() - .setDataSchema(Schema.of(ATTRIBUTES_FIELD_ARRAY, PAYLOAD_FIELD_BYTES))) - .schemaShouldHaveValidAttributesField() - .schemaShouldHaveValidPayloadField() - .shouldUseNestedSchema() - .shouldNotNeedSerializer() - .expectPubsubToRow( - PubsubMessageToRow.builder() - .messageSchema(Schema.of(ATTRIBUTES_FIELD_ARRAY, PAYLOAD_FIELD_BYTES)) - .useFlatSchema(false) - .useDlq(false))); - - static final Schema.FieldType ATTRIBUTE_MAP_FIELD_TYPE = - Schema.FieldType.map(Schema.FieldType.STRING.withNullable(false), Schema.FieldType.STRING); - static final Schema ATTRIBUTE_ARRAY_ENTRY_SCHEMA = - Schema.builder().addStringField("key").addStringField("value").build(); - - static final Schema.FieldType ATTRIBUTE_ARRAY_FIELD_TYPE = - Schema.FieldType.array(Schema.FieldType.row(ATTRIBUTE_ARRAY_ENTRY_SCHEMA)); - - private static final Schema.Field ATTRIBUTES_FIELD_SHOULD_NOT_MATCH = - Schema.Field.of(ATTRIBUTES_FIELD, Schema.FieldType.STRING); - - private static final Schema.Field ATTRIBUTES_FIELD_MAP = - Schema.Field.of(ATTRIBUTES_FIELD, ATTRIBUTE_MAP_FIELD_TYPE); - - private static final Schema.Field ATTRIBUTES_FIELD_ARRAY = - Schema.Field.of(ATTRIBUTES_FIELD, ATTRIBUTE_ARRAY_FIELD_TYPE); - - private static final Schema.Field PAYLOAD_FIELD_SHOULD_NOT_MATCH = - Schema.Field.of(PAYLOAD_FIELD, Schema.FieldType.STRING); - - private static final Schema.Field PAYLOAD_FIELD_BYTES = - Schema.Field.of(PAYLOAD_FIELD, Schema.FieldType.BYTES); - - private static final Schema.Field PAYLOAD_FIELD_ROW = - Schema.Field.of(PAYLOAD_FIELD, Schema.FieldType.row(Schema.of())); - - private static final PayloadSerializerProvider JSON_PAYLOAD_SERIALIZER_PROVIDER = - new JsonPayloadSerializerProvider(); - - private static final AvroPayloadSerializerProvider AVRO_PAYLOAD_SERIALIZER_PROVIDER = - new AvroPayloadSerializerProvider(); - - private static final Schema SCHEMA = - Schema.of( - Schema.Field.of("name", Schema.FieldType.STRING), - Schema.Field.of("number", Schema.FieldType.INT64)); - - private static final Row ROW = - Row.withSchema(SCHEMA).withFieldValue("name", "a").withFieldValue("number", 1L).build(); - - @Test - public void testBuildMessageToRow() { - for (TestCase testCase : cases) { - if (testCase.expectPubsubToRow == null) { - continue; - } - - PubsubSchemaTransformMessageToRowFactory factory = testCase.factory(); - - PubsubMessageToRow expected = testCase.expectPubsubToRow; - PubsubMessageToRow actual = factory.buildMessageToRow(); - - assertEquals("messageSchema", expected.messageSchema(), actual.messageSchema()); - assertEquals("useFlatSchema", expected.useFlatSchema(), actual.useFlatSchema()); - assertEquals("useDlq", expected.useDlq(), actual.useDlq()); - } - } - - @Test - public void serializer() { - for (TestCase testCase : cases) { - PubsubSchemaTransformMessageToRowFactory factory = testCase.factory(); - - if (testCase.expectPayloadSerializerProvider == null) { - continue; - } - - Row serializerInput = testCase.serializerInput; - - byte[] expectedBytes = - testCase - .expectSerializerProvider() - .apply(testCase.dataSchema()) - .serialize(serializerInput); - - byte[] actualBytes = - factory.serializer().apply(testCase.dataSchema()).serialize(serializerInput); - - String expected = new String(expectedBytes, StandardCharsets.UTF_8); - String actual = new String(actualBytes, StandardCharsets.UTF_8); - - assertEquals(expected, actual); - } - } - - @Test - public void needsSerializer() { - for (TestCase testCase : cases) { - PubsubSchemaTransformMessageToRowFactory factory = testCase.factory(); - - boolean expected = testCase.shouldNeedSerializer; - boolean actual = factory.needsSerializer(); - - assertEquals(expected, actual); - } - } - - @Test - public void shouldUseNestedSchema() { - for (TestCase testCase : cases) { - PubsubSchemaTransformMessageToRowFactory factory = testCase.factory(); - - boolean expected = testCase.shouldUseNestedSchema; - boolean actual = factory.shouldUseNestedSchema(); - - assertEquals(expected, actual); - } - } - - @Test - public void schemaHasValidPayloadField() { - for (TestCase testCase : cases) { - PubsubSchemaTransformMessageToRowFactory factory = testCase.factory(); - - boolean expected = testCase.shouldSchemaHaveValidPayloadField; - boolean actual = factory.schemaHasValidPayloadField(); - - assertEquals(expected, actual); - } - } - - @Test - public void schemaHasValidAttributesField() { - for (TestCase testCase : cases) { - PubsubSchemaTransformMessageToRowFactory factory = testCase.factory(); - - boolean expected = testCase.shouldSchemaHaveValidAttributesField; - boolean actual = factory.schemaHasValidAttributesField(); - - assertEquals(expected, actual); - } - } - - @Test - public void fieldPresent() { - for (TestCase testCase : cases) { - PubsubSchemaTransformMessageToRowFactory factory = testCase.factory(); - for (Entry entry : testCase.shouldFieldPresent.entrySet()) { - - boolean actual = factory.fieldPresent(entry.getKey(), entry.getValue()); - - assertTrue(actual); - } - } - } - - static TestCase testCase(PubsubReadSchemaTransformConfiguration.Builder configurationBuilder) { - return new TestCase(configurationBuilder); - } - - private static class TestCase { - private final PubsubReadSchemaTransformConfiguration configuration; - - private PubsubMessageToRow expectPubsubToRow; - - private PayloadSerializerProvider expectPayloadSerializerProvider; - - private boolean shouldUseNestedSchema = false; - private boolean shouldNeedSerializer = true; - private boolean shouldSchemaHaveValidPayloadField = false; - private boolean shouldSchemaHaveValidAttributesField = false; - private final Map shouldFieldPresent = new HashMap<>(); - - private Row serializerInput; - - TestCase(PubsubReadSchemaTransformConfiguration.Builder configurationBuilder) { - this.configuration = configurationBuilder.build(); - } - - PubsubSchemaTransformMessageToRowFactory factory() { - return PubsubSchemaTransformMessageToRowFactory.from(configuration); - } - - Schema dataSchema() { - return configuration.getDataSchema(); - } - - TestCase expectPubsubToRow(PubsubMessageToRow.Builder pubsubMessageToRowBuilder) { - this.expectPubsubToRow = pubsubMessageToRowBuilder.build(); - return this; - } - - TestCase withSerializerInput() { - this.serializerInput = PubsubSchemaTransformMessageToRowFactoryTest.ROW; - return this; - } - - TestCase expectPayloadSerializerProvider(PayloadSerializerProvider value) { - this.expectPayloadSerializerProvider = value; - return this; - } - - PubsubMessageToRow.SerializerProvider expectSerializerProvider() { - Map params = new HashMap<>(); - PayloadSerializer payloadSerializer = - expectPayloadSerializerProvider.getSerializer(configuration.getDataSchema(), params); - - return (input -> payloadSerializer); - } - - TestCase shouldUseNestedSchema() { - this.shouldUseNestedSchema = true; - return this; - } - - TestCase shouldNotNeedSerializer() { - this.shouldNeedSerializer = false; - return this; - } - - TestCase schemaShouldHaveValidPayloadField() { - this.shouldSchemaHaveValidPayloadField = true; - return this; - } - - TestCase schemaShouldHaveValidAttributesField() { - this.shouldSchemaHaveValidAttributesField = true; - return this; - } - - TestCase fieldShouldBePresent(String name, Schema.FieldType expectedType) { - this.shouldFieldPresent.put(name, expectedType); - return this; - } - } -} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProviderIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProviderIT.java deleted file mode 100644 index cb0e6ec03ccd..000000000000 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProviderIT.java +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.gcp.pubsub; - -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubWriteSchemaTransformConfiguration.DEFAULT_TIMESTAMP_ATTRIBUTE; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubWriteSchemaTransformProvider.INPUT_TAG; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import org.apache.beam.sdk.PipelineResult; -import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.IncomingMessage; -import org.apache.beam.sdk.schemas.AutoValueSchema; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.Schema.Field; -import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.values.PCollectionRowTuple; -import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; -import org.apache.commons.lang3.tuple.Pair; -import org.joda.time.Instant; -import org.joda.time.format.ISODateTimeFormat; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Rule; -import org.junit.Test; - -/** Integration tests for {@link PubsubWriteSchemaTransformProvider}. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public class PubsubWriteSchemaTransformProviderIT { - - @Rule public transient TestPipeline pipeline = TestPipeline.create(); - - private static final TestPubsubOptions TEST_PUBSUB_OPTIONS = - TestPipeline.testingPipelineOptions().as(TestPubsubOptions.class); - - static { - TEST_PUBSUB_OPTIONS.setBlockOnRun(false); - } - - private static final String HAS_NO_SCHEMA = "has-no-schema"; - - private static PubsubClient pubsubClient; - - private static PubsubClient.TopicPath hasNoSchemaTopic; - - private static PubsubClient.SubscriptionPath hasNoSchemaSubscription; - - private static final Instant TIMESTAMP = Instant.now(); - - private static final String RESOURCE_NAME_POSTFIX = "-" + TIMESTAMP.getMillis(); - - private static final int ACK_DEADLINE_SECONDS = 60; - - private static final int AWAIT_TERMINATED_SECONDS = 30; - - private static final AutoValueSchema AUTO_VALUE_SCHEMA = new AutoValueSchema(); - - private static final TypeDescriptor - CONFIGURATION_TYPE_DESCRIPTOR = - TypeDescriptor.of(PubsubWriteSchemaTransformConfiguration.class); - - private static final SerializableFunction - TO_ROW_FN = AUTO_VALUE_SCHEMA.toRowFunction(CONFIGURATION_TYPE_DESCRIPTOR); - - private final Field timestampField = Field.of("timestamp", FieldType.DATETIME); - - private final Field payloadBytesField = Field.of("payload", FieldType.BYTES); - - @BeforeClass - public static void setUp() throws IOException { - String project = TEST_PUBSUB_OPTIONS.as(PubsubOptions.class).getProject(); - pubsubClient = PubsubGrpcClient.FACTORY.newClient(null, null, TEST_PUBSUB_OPTIONS); - hasNoSchemaTopic = - PubsubClient.topicPathFromName(project, HAS_NO_SCHEMA + RESOURCE_NAME_POSTFIX); - hasNoSchemaSubscription = - PubsubClient.subscriptionPathFromName(project, HAS_NO_SCHEMA + RESOURCE_NAME_POSTFIX); - - pubsubClient.createTopic(hasNoSchemaTopic); - pubsubClient.createSubscription( - hasNoSchemaTopic, hasNoSchemaSubscription, ACK_DEADLINE_SECONDS); - } - - @AfterClass - public static void tearDown() throws IOException { - pubsubClient.deleteSubscription(hasNoSchemaSubscription); - pubsubClient.deleteTopic(hasNoSchemaTopic); - - pubsubClient.close(); - } - - @Test - public void testWritePayloadBytes() throws IOException { - Instant timestamp = Instant.ofEpochMilli(100000L); - Schema schema = Schema.of(payloadBytesField, timestampField); - List input = - Collections.singletonList( - Row.withSchema(schema).attachValues("aaa".getBytes(StandardCharsets.UTF_8), timestamp)); - Row configuration = - TO_ROW_FN.apply( - PubsubWriteSchemaTransformConfiguration.builder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setPayloadFieldName(payloadBytesField.getName()) - .setTimestampFieldName(timestampField.getName()) - .build()) - .setTopic(hasNoSchemaTopic.getPath()) - .setTarget( - PubsubWriteSchemaTransformConfiguration.targetConfigurationBuilder().build()) - .build()); - - PCollectionRowTuple.of(INPUT_TAG, pipeline.apply(Create.of(input).withRowSchema(schema))) - .apply(new PubsubWriteSchemaTransformProvider().from(configuration)); - - PipelineResult job = pipeline.run(TEST_PUBSUB_OPTIONS); - Instant now = Instant.now(); - Instant stop = Instant.ofEpochMilli(now.getMillis() + AWAIT_TERMINATED_SECONDS * 1000); - List>> actualList = new ArrayList<>(); - while (now.isBefore(stop)) { - List received = pubsubClient.pull(0, hasNoSchemaSubscription, 1, true); - for (IncomingMessage incoming : received) { - actualList.add( - Pair.of( - incoming.message().getData().toStringUtf8(), - ImmutableMap.of( - DEFAULT_TIMESTAMP_ATTRIBUTE, - incoming - .message() - .getAttributesMap() - .getOrDefault(DEFAULT_TIMESTAMP_ATTRIBUTE, "")))); - } - if (actualList.size() == input.size()) { - break; - } - now = Instant.now(); - } - job.cancel(); - assertFalse( - String.format( - "messages pulled from %s should not be empty", hasNoSchemaSubscription.getPath()), - actualList.isEmpty()); - Pair> actual = actualList.get(0); - Row expected = input.get(0); - String payload = - new String( - Objects.requireNonNull(expected.getBytes(payloadBytesField.getName())), - StandardCharsets.UTF_8); - assertEquals(payload, actual.getLeft()); - assertEquals( - ISODateTimeFormat.dateTime().print(timestamp), - actual.getRight().get(DEFAULT_TIMESTAMP_ATTRIBUTE)); - } -} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProviderTest.java deleted file mode 100644 index 98939f7ddc68..000000000000 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProviderTest.java +++ /dev/null @@ -1,786 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.gcp.pubsub; - -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.ATTRIBUTES_FIELD_TYPE; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.DEFAULT_ATTRIBUTES_KEY_NAME; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.DEFAULT_EVENT_TIMESTAMP_KEY_NAME; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.DEFAULT_PAYLOAD_KEY_NAME; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessage.EVENT_TIMESTAMP_FIELD_TYPE; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessageTest.ALL_DATA_TYPES_SCHEMA; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessageTest.NON_USER_WITH_BYTES_PAYLOAD; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubRowToMessageTest.rowWithAllDataTypes; -import static org.apache.beam.sdk.io.gcp.pubsub.PubsubWriteSchemaTransformProvider.INPUT_TAG; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertThrows; - -import com.google.api.client.util.Clock; -import java.io.IOException; -import java.io.Serializable; -import java.math.BigDecimal; -import java.nio.charset.StandardCharsets; -import java.util.Map; -import org.apache.avro.SchemaParseException; -import org.apache.beam.sdk.extensions.avro.schemas.io.payloads.AvroPayloadSerializerProvider; -import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.SchemaPath; -import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.TopicPath; -import org.apache.beam.sdk.io.gcp.pubsub.PubsubTestClient.PubsubTestClientFactory; -import org.apache.beam.sdk.io.gcp.pubsub.PubsubWriteSchemaTransformProvider.PubsubWriteSchemaTransform; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.schemas.AutoValueSchema; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.Schema.Field; -import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.io.payloads.JsonPayloadSerializerProvider; -import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializer; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.util.RowJson.UnsupportedRowJsonException; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionRowTuple; -import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; -import org.joda.time.Instant; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for {@link PubsubWriteSchemaTransformProvider}. */ -@RunWith(JUnit4.class) -public class PubsubWriteSchemaTransformProviderTest { - - private static final String ID_ATTRIBUTE = "id_attribute"; - private static final String TOPIC = "projects/project/topics/topic"; - private static final MockClock CLOCK = new MockClock(Instant.now()); - private static final AutoValueSchema AUTO_VALUE_SCHEMA = new AutoValueSchema(); - private static final TypeDescriptor TYPE_DESCRIPTOR = - TypeDescriptor.of(PubsubWriteSchemaTransformConfiguration.class); - private static final SerializableFunction TO_ROW = - AUTO_VALUE_SCHEMA.toRowFunction(TYPE_DESCRIPTOR); - - private static final PipelineOptions OPTIONS = PipelineOptionsFactory.create(); - - static { - OPTIONS.setStableUniqueNames(PipelineOptions.CheckEnabled.OFF); - } - - @Rule public transient TestPipeline pipeline = TestPipeline.create(); - - @Test - public void testBuildPubsubWrite() { - assertEquals( - "default configuration should yield a topic Pub/Sub write", - pubsubWrite(), - transform(configurationBuilder()).buildPubsubWrite()); - - assertEquals( - "idAttribute in configuration should yield a idAttribute set Pub/Sub write", - pubsubWrite().withIdAttribute(ID_ATTRIBUTE), - transform(configurationBuilder().setIdAttribute(ID_ATTRIBUTE)).buildPubsubWrite()); - } - - @Test - public void testBuildPubsubRowToMessage() { - assertEquals( - "override timestamp attribute on configuration should yield a PubsubRowToMessage with target timestamp", - rowToMessageBuilder().setTargetTimestampAttributeName("custom_timestamp_attribute").build(), - transform( - configurationBuilder() - .setTarget( - PubsubWriteSchemaTransformConfiguration.targetConfigurationBuilder() - .setTimestampAttributeKey("custom_timestamp_attribute") - .build())) - .buildPubsubRowToMessage(NON_USER_WITH_BYTES_PAYLOAD)); - - assertNull( - "failing to set format should yield a null payload serializer", - transform(configurationBuilder()) - .buildPubsubRowToMessage(ALL_DATA_TYPES_SCHEMA) - .getPayloadSerializer()); - - assertThrows( - "setting 'json' format for a unsupported field containing Schema should throw an Exception", - UnsupportedRowJsonException.class, - () -> - transform(configurationBuilder().setFormat("json")) - .buildPubsubRowToMessage( - Schema.of(Field.of(DEFAULT_ATTRIBUTES_KEY_NAME, ATTRIBUTES_FIELD_TYPE)))); - - assertThrows( - "setting 'avro' format for a unsupported field containing Schema should throw an Exception", - SchemaParseException.class, - () -> - transform(configurationBuilder().setFormat("avro")) - .buildPubsubRowToMessage( - Schema.of(Field.of(DEFAULT_ATTRIBUTES_KEY_NAME, ATTRIBUTES_FIELD_TYPE)))); - - assertNotNull( - "setting 'json' format for valid schema should yield PayloadSerializer", - transform(configurationBuilder().setFormat("json")) - .buildPubsubRowToMessage(ALL_DATA_TYPES_SCHEMA) - .getPayloadSerializer()); - - assertNotNull( - "setting 'avro' format for valid schema should yield PayloadSerializer", - transform(configurationBuilder().setFormat("avro")) - .buildPubsubRowToMessage(ALL_DATA_TYPES_SCHEMA) - .getPayloadSerializer()); - } - - @Test - public void testInvalidTaggedInput() { - Row withAllDataTypes = - rowWithAllDataTypes( - true, - (byte) 0, - Instant.now().toDateTime(), - BigDecimal.valueOf(1L), - 3.12345, - 4.1f, - (short) 5, - 2, - 7L, - "asdfjkl;"); - - PCollection rows = - pipeline.apply(Create.of(withAllDataTypes)).setRowSchema(ALL_DATA_TYPES_SCHEMA); - - assertThrows( - "empty input should not be allowed", - IllegalArgumentException.class, - () -> transform(configurationBuilder()).expand(PCollectionRowTuple.empty(pipeline))); - - assertThrows( - "input with >1 tagged rows should not be allowed", - IllegalArgumentException.class, - () -> - transform(configurationBuilder()) - .expand(PCollectionRowTuple.of(INPUT_TAG, rows).and("somethingelse", rows))); - - assertThrows( - "input missing INPUT tag should not be allowed", - IllegalArgumentException.class, - () -> - transform(configurationBuilder()) - .expand(PCollectionRowTuple.of("somethingelse", rows))); - - pipeline.run(OPTIONS); - } - - @Test - public void testValidateSourceSchemaAgainstConfiguration() { - // Only containing user fields and no configuration details should be valid - transform(configurationBuilder()) - .validateSourceSchemaAgainstConfiguration(ALL_DATA_TYPES_SCHEMA); - - // Matching attributes, timestamp, and payload (bytes) fields configured with expected types - // should be valid - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setAttributesFieldName("attributes") - .setTimestampFieldName("timestamp") - .setPayloadFieldName("payload") - .build())) - .validateSourceSchemaAgainstConfiguration( - Schema.of( - Field.of("attributes", ATTRIBUTES_FIELD_TYPE), - Field.of("timestamp", EVENT_TIMESTAMP_FIELD_TYPE), - Field.of("payload", Schema.FieldType.BYTES))); - - // Matching attributes, timestamp, and payload (ROW) fields configured with expected types - // should be valid - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setAttributesFieldName("attributes") - .setTimestampFieldName("timestamp") - .setPayloadFieldName("payload") - .build())) - .validateSourceSchemaAgainstConfiguration( - Schema.of( - Field.of("attributes", ATTRIBUTES_FIELD_TYPE), - Field.of("timestamp", EVENT_TIMESTAMP_FIELD_TYPE), - Field.of("payload", Schema.FieldType.row(ALL_DATA_TYPES_SCHEMA)))); - - assertThrows( - "empty Schema should be invalid", - IllegalArgumentException.class, - () -> - transform(configurationBuilder()) - .validateSourceSchemaAgainstConfiguration(Schema.of())); - - assertThrows( - "attributes field in configuration but not in schema should be invalid", - IllegalArgumentException.class, - () -> - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setAttributesFieldName("attributes") - .build())) - .validateSourceSchemaAgainstConfiguration(ALL_DATA_TYPES_SCHEMA)); - - assertThrows( - "timestamp field in configuration but not in schema should be invalid", - IllegalArgumentException.class, - () -> - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setTimestampFieldName("timestamp") - .build())) - .validateSourceSchemaAgainstConfiguration(ALL_DATA_TYPES_SCHEMA)); - - assertThrows( - "payload field in configuration but not in schema should be invalid", - IllegalArgumentException.class, - () -> - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setPayloadFieldName("payload") - .build())) - .validateSourceSchemaAgainstConfiguration(ALL_DATA_TYPES_SCHEMA)); - - assertThrows( - "attributes field in configuration but mismatching attributes type should be invalid", - IllegalArgumentException.class, - () -> - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setAttributesFieldName("attributes") - .build())) - .validateSourceSchemaAgainstConfiguration( - // should be FieldType.map(FieldType.STRING, FieldType.STRING) - Schema.of( - Field.of("attributes", FieldType.map(FieldType.BYTES, FieldType.STRING))))); - - assertThrows( - "timestamp field in configuration but mismatching timestamp type should be invalid", - IllegalArgumentException.class, - () -> - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setAttributesFieldName("timestamp") - .build())) - .validateSourceSchemaAgainstConfiguration( - // should be FieldType.DATETIME - Schema.of(Field.of("timestamp", FieldType.STRING)))); - - assertThrows( - "payload field in configuration but mismatching payload type should be invalid", - IllegalArgumentException.class, - () -> - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setAttributesFieldName("payload") - .build())) - .validateSourceSchemaAgainstConfiguration( - // should be FieldType.BYTES or FieldType.row(...) - Schema.of(Field.of("payload", FieldType.STRING)))); - } - - @Test - public void testValidateTargetSchemaAgainstPubsubSchema() throws IOException { - TopicPath topicPath = PubsubClient.topicPathFromPath(TOPIC); - PubsubTestClientFactory noSchemaFactory = - PubsubTestClient.createFactoryForGetSchema(topicPath, null, null); - - PubsubTestClientFactory schemaDeletedFactory = - PubsubTestClient.createFactoryForGetSchema(topicPath, SchemaPath.DELETED_SCHEMA, null); - - PubsubTestClientFactory mismatchingSchemaFactory = - PubsubTestClient.createFactoryForGetSchema( - topicPath, - PubsubClient.schemaPathFromId("testProject", "misMatch"), - Schema.of(Field.of("StringField", FieldType.STRING))); - - PubsubTestClientFactory matchingSchemaFactory = - PubsubTestClient.createFactoryForGetSchema( - topicPath, - PubsubClient.schemaPathFromId("testProject", "match"), - ALL_DATA_TYPES_SCHEMA); - - // Should pass validation exceptions if Pub/Sub topic lacks schema - transform(configurationBuilder()) - .withPubsubClientFactory(noSchemaFactory) - .validateTargetSchemaAgainstPubsubSchema(ALL_DATA_TYPES_SCHEMA, OPTIONS); - noSchemaFactory.close(); - - // Should pass validation if Pub/Sub topic schema deleted - transform(configurationBuilder()) - .withPubsubClientFactory(schemaDeletedFactory) - .validateTargetSchemaAgainstPubsubSchema(ALL_DATA_TYPES_SCHEMA, OPTIONS); - schemaDeletedFactory.close(); - - assertThrows( - "mismatched schema should be detected from Pub/Sub topic", - IllegalStateException.class, - () -> - transform(configurationBuilder()) - .withPubsubClientFactory(mismatchingSchemaFactory) - .validateTargetSchemaAgainstPubsubSchema(ALL_DATA_TYPES_SCHEMA, OPTIONS)); - mismatchingSchemaFactory.close(); - - // Should pass validation if Pub/Sub topic schema matches - transform(configurationBuilder()) - .withPubsubClientFactory(matchingSchemaFactory) - .validateTargetSchemaAgainstPubsubSchema(ALL_DATA_TYPES_SCHEMA, OPTIONS); - matchingSchemaFactory.close(); - } - - @Test - public void testBuildTargetSchema() { - - Field sourceAttributesField = Field.of("attributes", ATTRIBUTES_FIELD_TYPE); - Field sourceTimestampField = Field.of("timestamp", EVENT_TIMESTAMP_FIELD_TYPE); - Field sourcePayloadBytesField = Field.of("payload", FieldType.BYTES); - Field sourcePayloadRowField = Field.of("payload", FieldType.row(ALL_DATA_TYPES_SCHEMA)); - - Field targetAttributesField = Field.of(DEFAULT_ATTRIBUTES_KEY_NAME, ATTRIBUTES_FIELD_TYPE); - Field targetTimestampField = - Field.of(DEFAULT_EVENT_TIMESTAMP_KEY_NAME, EVENT_TIMESTAMP_FIELD_TYPE); - Field targetPayloadBytesField = Field.of(DEFAULT_PAYLOAD_KEY_NAME, FieldType.BYTES); - Field targetPayloadRowField = - Field.of(DEFAULT_PAYLOAD_KEY_NAME, FieldType.row(ALL_DATA_TYPES_SCHEMA)); - - assertEquals( - "attributes and timestamp field should append to user fields", - Schema.builder() - .addField(targetAttributesField) - .addField(targetTimestampField) - .addFields(ALL_DATA_TYPES_SCHEMA.getFields()) - .build(), - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .build())) - .buildTargetSchema(ALL_DATA_TYPES_SCHEMA)); - - assertEquals( - "timestamp field should append to user fields; attributes field name changed", - Schema.builder() - .addField(targetAttributesField) - .addField(targetTimestampField) - .addFields(ALL_DATA_TYPES_SCHEMA.getFields()) - .build(), - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setAttributesFieldName("attributes") - .build())) - .buildTargetSchema( - Schema.builder() - .addField(sourceAttributesField) - .addFields(ALL_DATA_TYPES_SCHEMA.getFields()) - .build())); - - assertEquals( - "attributes field should append to user fields; timestamp field name changed", - Schema.builder() - .addField(targetAttributesField) - .addField(targetTimestampField) - .addFields(ALL_DATA_TYPES_SCHEMA.getFields()) - .build(), - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setTimestampFieldName("timestamp") - .build())) - .buildTargetSchema( - Schema.builder() - .addField(sourceTimestampField) - .addFields(ALL_DATA_TYPES_SCHEMA.getFields()) - .build())); - - assertEquals( - "attributes and timestamp field appended to user payload bytes field; payload field name changed", - Schema.builder() - .addField(targetAttributesField) - .addField(targetTimestampField) - .addField(targetPayloadBytesField) - .build(), - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setPayloadFieldName("payload") - .build())) - .buildTargetSchema(Schema.builder().addField(sourcePayloadBytesField).build())); - - assertEquals( - "attributes and timestamp field appended to user payload row field; payload field name changed", - Schema.builder() - .addField(targetAttributesField) - .addField(targetTimestampField) - .addField(targetPayloadRowField) - .build(), - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setPayloadFieldName("payload") - .build())) - .buildTargetSchema(Schema.builder().addField(sourcePayloadRowField).build())); - - assertEquals( - "attributes and timestamp fields name changed", - Schema.builder() - .addField(targetAttributesField) - .addField(targetTimestampField) - .addFields(ALL_DATA_TYPES_SCHEMA.getFields()) - .build(), - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setAttributesFieldName("attributes") - .setTimestampFieldName("timestamp") - .build())) - .buildTargetSchema( - Schema.builder() - .addField(sourceAttributesField) - .addField(sourceTimestampField) - .addFields(ALL_DATA_TYPES_SCHEMA.getFields()) - .build())); - - assertEquals( - "attributes, timestamp, payload bytes fields name changed", - Schema.builder() - .addField(targetAttributesField) - .addField(targetTimestampField) - .addFields(targetPayloadBytesField) - .build(), - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setAttributesFieldName("attributes") - .setTimestampFieldName("timestamp") - .setPayloadFieldName("payload") - .build())) - .buildTargetSchema( - Schema.builder() - .addField(sourceAttributesField) - .addField(sourceTimestampField) - .addField(sourcePayloadBytesField) - .build())); - - assertEquals( - "attributes, timestamp, payload row fields name changed", - Schema.builder() - .addField(targetAttributesField) - .addField(targetTimestampField) - .addFields(targetPayloadRowField) - .build(), - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration.sourceConfigurationBuilder() - .setAttributesFieldName("attributes") - .setTimestampFieldName("timestamp") - .setPayloadFieldName("payload") - .build())) - .buildTargetSchema( - Schema.builder() - .addField(sourceAttributesField) - .addField(sourceTimestampField) - .addField(sourcePayloadRowField) - .build())); - } - - @Test - public void testConvertForRowToMessageTransform() { - Row userRow = - rowWithAllDataTypes( - false, - (byte) 0, - Instant.ofEpochMilli(CLOCK.currentTimeMillis()).toDateTime(), - BigDecimal.valueOf(1L), - 1.12345, - 1.1f, - (short) 1, - 1, - 1L, - "吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮"); - - Field sourceAttributes = Field.of("attributes", ATTRIBUTES_FIELD_TYPE); - Field targetAttributes = Field.of(DEFAULT_ATTRIBUTES_KEY_NAME, ATTRIBUTES_FIELD_TYPE); - - Field sourceTimestamp = Field.of("timestamp", EVENT_TIMESTAMP_FIELD_TYPE); - Field targetTimestamp = Field.of(DEFAULT_EVENT_TIMESTAMP_KEY_NAME, EVENT_TIMESTAMP_FIELD_TYPE); - - Field sourcePayloadBytes = Field.of("payload", FieldType.BYTES); - Field targetPayloadBytes = Field.of(DEFAULT_PAYLOAD_KEY_NAME, FieldType.BYTES); - - Field sourcePayloadRow = Field.of("payload", FieldType.row(ALL_DATA_TYPES_SCHEMA)); - Field targetPayloadRow = - Field.of(DEFAULT_PAYLOAD_KEY_NAME, FieldType.row(ALL_DATA_TYPES_SCHEMA)); - - Map attributes = ImmutableMap.of("a", "1"); - Instant generatedTimestamp = Instant.ofEpochMilli(CLOCK.currentTimeMillis()); - Instant timestampFromSource = Instant.ofEpochMilli(CLOCK.currentTimeMillis() + 10000L); - byte[] payloadBytes = "吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮".getBytes(StandardCharsets.UTF_8); - - PAssert.that( - "attributes only source yields attributes + timestamp target", - pipeline - .apply( - Create.of(Row.withSchema(Schema.of(sourceAttributes)).attachValues(attributes))) - .setRowSchema(Schema.of(sourceAttributes)) - .apply( - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration - .sourceConfigurationBuilder() - .setAttributesFieldName(sourceAttributes.getName()) - .build())) - .convertForRowToMessage( - Schema.of(targetAttributes, targetTimestamp), CLOCK)) - .setRowSchema(Schema.of(targetAttributes, targetTimestamp))) - .containsInAnyOrder( - Row.withSchema(Schema.of(targetAttributes, targetTimestamp)) - .attachValues(attributes, generatedTimestamp)); - - PAssert.that( - "timestamp only source yields attributes + timestamp target", - pipeline - .apply( - Create.of( - Row.withSchema(Schema.of(sourceTimestamp)) - .attachValues(timestampFromSource))) - .setRowSchema(Schema.of(sourceTimestamp)) - .apply( - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration - .sourceConfigurationBuilder() - .setTimestampFieldName(sourceTimestamp.getName()) - .build())) - .convertForRowToMessage( - Schema.of(targetAttributes, targetTimestamp), CLOCK)) - .setRowSchema(Schema.of(targetAttributes, targetTimestamp))) - .containsInAnyOrder( - Row.withSchema(Schema.of(targetAttributes, targetTimestamp)) - .attachValues(ImmutableMap.of(), timestampFromSource)); - - PAssert.that( - "timestamp and attributes source yields renamed fields in target", - pipeline - .apply( - Create.of( - Row.withSchema(Schema.of(sourceAttributes, sourceTimestamp)) - .attachValues(attributes, timestampFromSource))) - .setRowSchema(Schema.of(sourceAttributes, sourceTimestamp)) - .apply( - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration - .sourceConfigurationBuilder() - .setAttributesFieldName(sourceAttributes.getName()) - .setTimestampFieldName(sourceTimestamp.getName()) - .build())) - .convertForRowToMessage( - Schema.of(targetAttributes, targetTimestamp), CLOCK)) - .setRowSchema(Schema.of(targetAttributes, targetTimestamp))) - .containsInAnyOrder( - Row.withSchema(Schema.of(targetAttributes, targetTimestamp)) - .attachValues(attributes, timestampFromSource)); - - PAssert.that( - "bytes payload only source yields attributes + timestamp + renamed bytes payload target", - pipeline - .apply( - Create.of( - Row.withSchema(Schema.of(sourcePayloadBytes)) - .withFieldValue(sourcePayloadBytes.getName(), payloadBytes) - .build())) - .setRowSchema(Schema.of(sourcePayloadBytes)) - .apply( - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration - .sourceConfigurationBuilder() - .setPayloadFieldName(sourcePayloadBytes.getName()) - .build())) - .convertForRowToMessage( - Schema.of(targetAttributes, targetTimestamp, targetPayloadBytes), - CLOCK)) - .setRowSchema(Schema.of(targetAttributes, targetTimestamp, targetPayloadBytes))) - .containsInAnyOrder( - Row.withSchema(Schema.of(targetAttributes, targetTimestamp, targetPayloadBytes)) - .attachValues(ImmutableMap.of(), generatedTimestamp, payloadBytes)); - - PAssert.that( - "row payload only source yields attributes + timestamp + renamed row payload target", - pipeline - .apply(Create.of(Row.withSchema(Schema.of(sourcePayloadRow)).attachValues(userRow))) - .setRowSchema(Schema.of(sourcePayloadRow)) - .apply( - transform( - configurationBuilder() - .setSource( - PubsubWriteSchemaTransformConfiguration - .sourceConfigurationBuilder() - .setPayloadFieldName(sourcePayloadRow.getName()) - .build())) - .convertForRowToMessage( - Schema.of(targetAttributes, targetTimestamp, targetPayloadRow), CLOCK)) - .setRowSchema(Schema.of(targetAttributes, targetTimestamp, targetPayloadRow))) - .containsInAnyOrder( - Row.withSchema(Schema.of(targetAttributes, targetTimestamp, targetPayloadRow)) - .attachValues(ImmutableMap.of(), generatedTimestamp, userRow)); - - PAssert.that( - "user only fields source yields attributes + timestamp + user fields target", - pipeline - .apply(Create.of(userRow)) - .setRowSchema(ALL_DATA_TYPES_SCHEMA) - .apply( - transform(configurationBuilder()) - .convertForRowToMessage( - Schema.builder() - .addField(targetAttributes) - .addField(targetTimestamp) - .addFields(ALL_DATA_TYPES_SCHEMA.getFields()) - .build(), - CLOCK)) - .setRowSchema( - Schema.builder() - .addField(targetAttributes) - .addField(targetTimestamp) - .addFields(ALL_DATA_TYPES_SCHEMA.getFields()) - .build())) - .containsInAnyOrder( - Row.withSchema( - Schema.builder() - .addField(targetAttributes) - .addField(targetTimestamp) - .addFields(ALL_DATA_TYPES_SCHEMA.getFields()) - .build()) - .addValue(ImmutableMap.of()) - .addValue(generatedTimestamp) - .addValues(userRow.getValues()) - .build()); - - pipeline.run(OPTIONS); - } - - @Test - public void testGetPayloadSerializer() { - Row withAllDataTypes = - rowWithAllDataTypes( - false, - (byte) 0, - Instant.now().toDateTime(), - BigDecimal.valueOf(-1L), - -3.12345, - -4.1f, - (short) -5, - -2, - -7L, - "吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮"); - - PayloadSerializer jsonPayloadSerializer = - new JsonPayloadSerializerProvider().getSerializer(ALL_DATA_TYPES_SCHEMA, ImmutableMap.of()); - byte[] expectedJson = jsonPayloadSerializer.serialize(withAllDataTypes); - byte[] actualJson = - transform(configurationBuilder().setFormat("json")) - .getPayloadSerializer(ALL_DATA_TYPES_SCHEMA) - .serialize(withAllDataTypes); - - PayloadSerializer avroPayloadSerializer = - new AvroPayloadSerializerProvider().getSerializer(ALL_DATA_TYPES_SCHEMA, ImmutableMap.of()); - byte[] expectedAvro = avroPayloadSerializer.serialize(withAllDataTypes); - byte[] actualAvro = - transform(configurationBuilder().setFormat("avro")) - .getPayloadSerializer(ALL_DATA_TYPES_SCHEMA) - .serialize(withAllDataTypes); - - assertArrayEquals( - "configuration with json format should yield JSON PayloadSerializer", - expectedJson, - actualJson); - - assertArrayEquals( - "configuration with avro format should yield Avro PayloadSerializer", - expectedAvro, - actualAvro); - } - - private static PubsubWriteSchemaTransformConfiguration.Builder configurationBuilder() { - return PubsubWriteSchemaTransformConfiguration.builder() - .setTopic(TOPIC) - .setTarget(PubsubWriteSchemaTransformConfiguration.targetConfigurationBuilder().build()); - } - - private static PubsubRowToMessage.Builder rowToMessageBuilder() { - return PubsubRowToMessage.builder(); - } - - private static PubsubIO.Write pubsubWrite() { - return PubsubIO.writeMessages().to(TOPIC); - } - - private static PubsubWriteSchemaTransformProvider.PubsubWriteSchemaTransform transform( - PubsubWriteSchemaTransformConfiguration.Builder configurationBuilder) { - Row configurationRow = TO_ROW.apply(configurationBuilder.build()); - PubsubWriteSchemaTransformProvider provider = new PubsubWriteSchemaTransformProvider(); - return (PubsubWriteSchemaTransform) provider.from(configurationRow); - } - - private static class MockClock implements Clock, Serializable { - private final Long millis; - - private MockClock(Instant timestamp) { - this.millis = timestamp.getMillis(); - } - - @Override - public long currentTimeMillis() { - return millis; - } - } -} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/ReadWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/ReadWriteIT.java index 89a70a642f50..fd01cf9d0068 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/ReadWriteIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/ReadWriteIT.java @@ -318,7 +318,7 @@ public void testPubsubLiteWriteReadWithSchemaTransform() throws Exception { return Objects.requireNonNull(row.getInt64("numberInInt")).intValue(); })); ids.apply("PubsubSignalTest", signal.signalSuccessWhen(BigEndianIntegerCoder.of(), testIds())); - Supplier start = signal.waitForStart(Duration.standardMinutes(5)); + Supplier start = signal.waitForStart(Duration.standardMinutes(8)); pipeline.apply("start signal", signal.signalStart()); PipelineResult job = pipeline.run(); start.get(); @@ -363,7 +363,7 @@ public void testReadWrite() throws Exception { PCollection messages = readMessages(subscription, pipeline); PCollection ids = messages.apply(MapElements.via(extractIds())); ids.apply("PubsubSignalTest", signal.signalSuccessWhen(BigEndianIntegerCoder.of(), testIds())); - Supplier start = signal.waitForStart(Duration.standardMinutes(5)); + Supplier start = signal.waitForStart(Duration.standardMinutes(8)); pipeline.apply(signal.signalStart()); PipelineResult job = pipeline.run(); start.get(); diff --git a/sdks/java/io/kafka/kafka-integration-test.gradle b/sdks/java/io/kafka/kafka-integration-test.gradle index bfb8c7f5fd02..57a04eabfd1b 100644 --- a/sdks/java/io/kafka/kafka-integration-test.gradle +++ b/sdks/java/io/kafka/kafka-integration-test.gradle @@ -19,6 +19,7 @@ import org.apache.beam.gradle.kafka.KafkaTestUtilities apply plugin: 'org.apache.beam.module' applyJavaNature( + publish: false, automaticModuleName: 'org.apache.beam.sdk.io.kafka', mavenRepositories: [ [id: 'io.confluent', url: 'https://packages.confluent.io/maven/'] diff --git a/sdks/python/apache_beam/examples/inference/runinference_metrics/pipeline/options.py b/sdks/python/apache_beam/examples/inference/runinference_metrics/pipeline/options.py index b32200ed7331..1966100430e9 100644 --- a/sdks/python/apache_beam/examples/inference/runinference_metrics/pipeline/options.py +++ b/sdks/python/apache_beam/examples/inference/runinference_metrics/pipeline/options.py @@ -61,7 +61,6 @@ def get_pipeline_options( flags = [ "--experiment=worker_accelerator=type:nvidia-tesla-p4;count:1;"\ "install-nvidia-driver", - "--experiment=use_runner_v2", ] dataflow_options.update({ "sdk_container_image": cfg.DOCKER_IMG, diff --git a/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py b/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py index 65e943f5c697..2166d0db366e 100644 --- a/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py +++ b/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py @@ -61,49 +61,111 @@ def parse_args(): return parser.parse_known_args() -def run(args): - data = [ - dict(x=["Let's", "go", "to", "the", "park"]), - dict(x=["I", "enjoy", "going", "to", "the", "park"]), - dict(x=["I", "enjoy", "reading", "books"]), - dict(x=["Beam", "can", "be", "fun"]), - dict(x=["The", "weather", "is", "really", "nice", "today"]), - dict(x=["I", "love", "to", "go", "to", "the", "park"]), - dict(x=["I", "love", "to", "read", "books"]), - dict(x=["I", "love", "to", "program"]), - ] +def preprocess_data_for_ml_training(train_data, artifact_mode, args): + """ + Preprocess the data for ML training. This method runs a pipeline to + preprocess the data needed for ML training. It produces artifacts that + can be used for ML inference later. + """ with beam.Pipeline() as p: - input_data = p | beam.Create(data) - - # arfifacts produce mode. - input_data |= ( - 'MLTransform' >> MLTransform( + train_data_pcoll = (p | "CreateData" >> beam.Create(train_data)) + + # When 'artifact_mode' is set to 'produce', the ComputeAndApplyVocabulary + # function generates a vocabulary file. This file, stored in + # 'artifact_location', contains the vocabulary of the entire dataset. + # This is considered as an artifact of ComputeAndApplyVocabulary transform. + # The indices of the vocabulary in this file are returned as + # the output of MLTransform. + transformed_data_pcoll = ( + train_data_pcoll + | 'MLTransform' >> MLTransform( artifact_location=args.artifact_location, - artifact_mode=ArtifactMode.PRODUCE, + artifact_mode=artifact_mode, ).with_transform(ComputeAndApplyVocabulary( columns=['x'])).with_transform(TFIDF(columns=['x']))) - # _ = input_data | beam.Map(logging.info) + _ = transformed_data_pcoll | beam.Map(logging.info) + # output for the element dict(x=["Let's", "go", "to", "the", "park"]) + # will be: + # Row(x=array([21, 5, 0, 2, 3]), + # x_tfidf_weight=array([0.28109303, 0.36218604, 0.36218604, 0.41972247, + # 0.5008155 ], dtype=float32), x_vocab_index=array([ 0, 2, 3, 5, 21])) + +def preprocess_data_for_ml_inference(test_data, artifact_mode, args): + """ + Preprocess the data for ML inference. This method runs a pipeline to + preprocess the data needed for ML inference. It consumes the artifacts + produced during the preprocessing stage for ML training. + """ with beam.Pipeline() as p: - input_data = [ - dict(x=['I', 'love', 'books']), dict(x=['I', 'love', 'Apache', 'Beam']) - ] - input_data = p | beam.Create(input_data) - - # artifacts consume mode. - input_data |= ( - MLTransform( + + test_data_pcoll = (p | beam.Create(test_data)) + # Here, the previously saved vocabulary from an MLTransform run is used by + # ComputeAndApplyVocabulary to access and apply the stored artifacts to the + # test data. + transformed_data_pcoll = ( + test_data_pcoll + | "MLTransformOnTestData" >> MLTransform( artifact_location=args.artifact_location, - artifact_mode=ArtifactMode.CONSUME, - # you don't need to specify transforms as they are already saved in + artifact_mode=artifact_mode, + # ww don't need to specify transforms as they are already saved in # in the artifacts. )) + _ = transformed_data_pcoll | beam.Map(logging.info) + # output for dict(x=['I', 'love', 'books']) will be: + # Row(x=array([1, 4, 7]), + # x_tfidf_weight=array([0.4684884 , 0.6036434 , 0.69953746], dtype=float32) + # , x_vocab_index=array([1, 4, 7])) - _ = input_data | beam.Map(logging.info) - # To fetch the artifacts after the pipeline is run +def run(args): + """ + This example demonstrates how to use MLTransform in ML workflow. + 1. Preprocess the data for ML training. + 2. Do some ML model training. + 3. Preprocess the data for ML inference. + + training and inference on ML modes are not shown in this example. + This example only shows how to use MLTransform for preparing data for ML + training and inference. + """ + + train_data = [ + dict(x=["Let's", "go", "to", "the", "park"]), + dict(x=["I", "enjoy", "going", "to", "the", "park"]), + dict(x=["I", "enjoy", "reading", "books"]), + dict(x=["Beam", "can", "be", "fun"]), + dict(x=["The", "weather", "is", "really", "nice", "today"]), + dict(x=["I", "love", "to", "go", "to", "the", "park"]), + dict(x=["I", "love", "to", "read", "books"]), + dict(x=["I", "love", "to", "program"]), + ] + + test_data = [ + dict(x=['I', 'love', 'books']), dict(x=['I', 'love', 'Apache', 'Beam']) + ] + + # Preprocess the data for ML training. + # For the data going into the ML model training, we want to produce the + # artifacts. So, we set artifact_mode to ArtifactMode.PRODUCE. + preprocess_data_for_ml_training( + train_data, artifact_mode=ArtifactMode.PRODUCE, args=args) + + # Do some ML model training here. + + # Preprocess the data for ML inference. + # For the data going into the ML model inference, we want to consume the + # artifacts produced during the stage where we preprocessed the data for ML + # training. So, we set artifact_mode to ArtifactMode.CONSUME. + preprocess_data_for_ml_inference( + test_data, artifact_mode=ArtifactMode.CONSUME, args=args) + + # To fetch the artifacts produced in MLTransform, you can use + # ArtifactsFetcher for fetching vocab related artifacts. For + # others such as TFIDF weight, they can be accessed directly + # from the output of MLTransform. artifacts_fetcher = ArtifactsFetcher(artifact_location=args.artifact_location) vocab_list = artifacts_fetcher.get_vocab_list() assert vocab_list[22] == 'Beam' diff --git a/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py index 248e0849cdd0..0672cd5ea168 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py @@ -792,10 +792,7 @@ def create_bq_schema(cls, with_extra=False): @skip(['PortableRunner', 'FlinkRunner']) @pytest.mark.it_postcommit def test_read_queries(self): - # TODO(https://github.com/apache/beam/issues/20610): Remove experiment when - # tests run on r_v2. - args = self.args + ["--experiments=use_runner_v2"] - with beam.Pipeline(argv=args) as p: + with beam.Pipeline(argv=self.args) as p: result = ( p | beam.Create([ diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py index 541bbfd94113..2fdbce73170a 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio.py +++ b/sdks/python/apache_beam/io/gcp/gcsio.py @@ -824,7 +824,6 @@ def finish(self): self._upload_thread.join() # Check for exception since the last put() call. if self._upload_thread.last_error is not None: + e = self._upload_thread.last_error raise type(self._upload_thread.last_error)( - "Error while uploading file %s: %s", - self._path, - self._upload_thread.last_error.message) # pylint: disable=raising-bad-type + "Error while uploading file %s" % self._path) from e # pylint: disable=raising-bad-type diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index 8cee8acfebb0..af58006d6e76 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -39,9 +39,9 @@ from typing import Tuple from apache_beam import coders +from apache_beam.io import iobase from apache_beam.io.iobase import Read from apache_beam.io.iobase import Write -from apache_beam.runners.dataflow.native_io import iobase as dataflow_io from apache_beam.transforms import Flatten from apache_beam.transforms import Map from apache_beam.transforms import PTransform @@ -261,6 +261,7 @@ def __init__( timestamp_attribute=timestamp_attribute) def expand(self, pvalue): + # TODO(BEAM-27443): Apply a proper transform rather than Read. pcoll = pvalue.pipeline | Read(self._source) pcoll.element_type = bytes if self.with_attributes: @@ -423,7 +424,8 @@ def parse_subscription(full_subscription): return project, subscription_name -class _PubSubSource(dataflow_io.NativeSource): +# TODO(BEAM-27443): Remove (or repurpose as a proper PTransform). +class _PubSubSource(iobase.SourceBase): """Source for a Cloud Pub/Sub topic or subscription. This ``NativeSource`` is overridden by a native Pubsub implementation. @@ -460,11 +462,6 @@ def __init__( if subscription: self.project, self.subscription_name = parse_subscription(subscription) - @property - def format(self): - """Source format name required for remote execution.""" - return 'pubsub' - def display_data(self): return { 'id_label': DisplayDataItem(self.id_label, @@ -480,14 +477,15 @@ def display_data(self): label='Timestamp Attribute').drop_if_none(), } - def reader(self): - raise NotImplementedError + def default_output_coder(self): + return self.coder def is_bounded(self): return False -class _PubSubSink(dataflow_io.NativeSink): +# TODO(BEAM-27443): Remove in favor of a proper WriteToPubSub transform. +class _PubSubSink(object): """Sink for a Cloud Pub/Sub topic. This ``NativeSource`` is overridden by a native Pubsub implementation. @@ -505,14 +503,6 @@ def __init__( self.project, self.topic_name = parse_topic(topic) - @property - def format(self): - """Sink format name required for remote execution.""" - return 'pubsub' - - def writer(self): - raise NotImplementedError - class PubSubSourceDescriptor(NamedTuple): """A PubSub source descriptor for ``MultipleReadFromPubSub``` diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index 6d75d520af55..e15205ead4d4 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -101,6 +101,9 @@ class SourceBase(HasDisplayData, urns.RunnerApiFn): """ urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_SOURCE) + def default_output_coder(self): + raise NotImplementedError + def is_bounded(self): # type: () -> bool raise NotImplementedError @@ -923,11 +926,8 @@ def get_windowing(self, unused_inputs): def _infer_output_coder(self, input_type=None, input_coder=None): # type: (...) -> Optional[coders.Coder] - from apache_beam.runners.dataflow.native_io import iobase as dataflow_io - if isinstance(self.source, BoundedSource): + if isinstance(self.source, SourceBase): return self.source.default_output_coder() - elif isinstance(self.source, dataflow_io.NativeSource): - return self.source.coder else: return None @@ -941,18 +941,17 @@ def to_runner_api_parameter( self, context: PipelineContext, ) -> Tuple[str, Any]: - from apache_beam.runners.dataflow.native_io import iobase as dataflow_io - if isinstance(self.source, (BoundedSource, dataflow_io.NativeSource)): - from apache_beam.io.gcp.pubsub import _PubSubSource - if isinstance(self.source, _PubSubSource): - return ( - common_urns.composites.PUBSUB_READ.urn, - beam_runner_api_pb2.PubSubReadPayload( - topic=self.source.full_topic, - subscription=self.source.full_subscription, - timestamp_attribute=self.source.timestamp_attribute, - with_attributes=self.source.with_attributes, - id_attribute=self.source.id_label)) + from apache_beam.io.gcp.pubsub import _PubSubSource + if isinstance(self.source, _PubSubSource): + return ( + common_urns.composites.PUBSUB_READ.urn, + beam_runner_api_pb2.PubSubReadPayload( + topic=self.source.full_topic, + subscription=self.source.full_subscription, + timestamp_attribute=self.source.timestamp_attribute, + with_attributes=self.source.with_attributes, + id_attribute=self.source.id_label)) + if isinstance(self.source, BoundedSource): return ( common_urns.deprecated_primitives.READ.urn, beam_runner_api_pb2.ReadPayload( @@ -976,6 +975,7 @@ def from_runner_api_parameter( if transform.spec.urn == common_urns.composites.PUBSUB_READ.urn: assert isinstance(payload, beam_runner_api_pb2.PubSubReadPayload) # Importing locally to prevent circular dependencies. + # TODO(BEAM-27443): Remove the need for this. from apache_beam.io.gcp.pubsub import _PubSubSource source = _PubSubSource( topic=payload.topic or None, @@ -1015,6 +1015,7 @@ def _from_runner_api_parameter_pubsub_read( Read._from_runner_api_parameter_read, ) +# TODO(BEAM-27443): Remove. ptransform.PTransform.register_urn( common_urns.composites.PUBSUB_READ.urn, beam_runner_api_pb2.PubSubReadPayload, @@ -1065,10 +1066,11 @@ def display_data(self): return {'sink': self.sink.__class__, 'sink_dd': self.sink} def expand(self, pcoll): - from apache_beam.runners.dataflow.native_io import iobase as dataflow_io - if isinstance(self.sink, dataflow_io.NativeSink): - # A native sink - return pcoll | 'NativeWrite' >> dataflow_io._NativeWrite(self.sink) + # Importing locally to prevent circular dependencies. + from apache_beam.io.gcp.pubsub import _PubSubSink + if isinstance(self.sink, _PubSubSink): + # TODO(BEAM-27443): Remove the need for special casing here. + return pvalue.PDone(pcoll.pipeline) elif isinstance(self.sink, Sink): # A custom sink return pcoll | WriteImpl(self.sink) @@ -1084,6 +1086,7 @@ def to_runner_api_parameter( self, context: PipelineContext, ) -> Tuple[str, Any]: + # TODO(BEAM-27443): Remove the need for special casing here. # Importing locally to prevent circular dependencies. from apache_beam.io.gcp.pubsub import _PubSubSink if isinstance(self.sink, _PubSubSink): diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py index 34fd6d27448d..a0e0d9d3f8f0 100644 --- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py +++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py @@ -66,6 +66,7 @@ def __init__( project: str, location: str, experiment: Optional[str] = None, + network: Optional[str] = None, **kwargs): """Implementation of the ModelHandler interface for Vertex AI. **NOTE:** This API and its implementation are under development and @@ -73,19 +74,30 @@ def __init__( Unlike other ModelHandler implementations, this does not load the model being used onto the worker and instead makes remote queries to a Vertex AI endpoint. In that way it functions more like a mid-pipeline - IO. At present this implementation only supports public endpoints with - a maximum request size of 1.5 MB. + IO. Public Vertex AI endpoints have a maximum request size of 1.5 MB. + If you wish to make larger requests and use a private endpoint, provide + the Compute Engine network you wish to use. + Args: endpoint_id: the numerical ID of the Vertex AI endpoint to query project: the GCP project name where the endpoint is deployed location: the GCP location where the endpoint is deployed - experiment (Optional): experiment label to apply to the queries + experiment: optional. experiment label to apply to the + queries + network: optional. the full name of the Compute Engine + network the endpoint is deployed on; used for private + endpoints only. + Ex: "projects/12345/global/networks/myVPC" """ self._env_vars = kwargs.get('env_vars', {}) # TODO: support the full list of options for aiplatform.init() # See https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform#google_cloud_aiplatform_init - aiplatform.init(project=project, location=location, experiment=experiment) + aiplatform.init( + project=project, + location=location, + experiment=experiment, + network=network) # Check for liveness here but don't try to actually store the endpoint # in the class yet diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index 3ac59ff98a7c..09f4ddfa53f3 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -72,13 +72,13 @@ def test_ml_transform_appends_transforms_to_process_handler_correctly(self): self.assertEqual( ml_transform._process_handler.transforms[1].name, 'fake_fn_2') - def test_ml_transform_on_unbatched_dict(self): + def test_ml_transform_on_dict(self): transforms = [tft.ScaleTo01(columns=['x'])] - unbatched_data = [{'x': 1}, {'x': 2}] + data = [{'x': 1}, {'x': 2}] with beam.Pipeline() as p: result = ( p - | beam.Create(unbatched_data) + | beam.Create(data) | base.MLTransform( artifact_location=self.artifact_location, transforms=transforms)) expected_output = [ @@ -89,20 +89,20 @@ def test_ml_transform_on_unbatched_dict(self): assert_that( actual_output, equal_to(expected_output, equals_fn=np.array_equal)) - def test_ml_transform_on_batched_dict(self): + def test_ml_transform_on_list_dict(self): transforms = [tft.ScaleTo01(columns=['x'])] - batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] with beam.Pipeline() as p: - batched_result = ( + result = ( p - | beam.Create(batched_data) + | beam.Create(data) | base.MLTransform( transforms=transforms, artifact_location=self.artifact_location)) expected_output = [ np.array([0, 0.2, 0.4], dtype=np.float32), np.array([0.6, 0.8, 1], dtype=np.float32), ] - actual_output = batched_result | beam.Map(lambda x: x.x) + actual_output = result | beam.Map(lambda x: x.x) assert_that( actual_output, equal_to(expected_output, equals_fn=np.array_equal)) @@ -193,19 +193,19 @@ def test_ml_transform_fail_for_non_global_windows_in_produce_mode(self): def test_ml_transform_on_multiple_columns_single_transform(self): transforms = [tft.ScaleTo01(columns=['x', 'y'])] - batched_data = [{'x': [1, 2, 3], 'y': [1.0, 10.0, 20.0]}] + data = [{'x': [1, 2, 3], 'y': [1.0, 10.0, 20.0]}] with beam.Pipeline() as p: - batched_result = ( + result = ( p - | beam.Create(batched_data) + | beam.Create(data) | base.MLTransform( transforms=transforms, artifact_location=self.artifact_location)) expected_output_x = [ np.array([0, 0.5, 1], dtype=np.float32), ] expected_output_y = [np.array([0, 0.47368422, 1], dtype=np.float32)] - actual_output_x = batched_result | beam.Map(lambda x: x.x) - actual_output_y = batched_result | beam.Map(lambda x: x.y) + actual_output_x = result | beam.Map(lambda x: x.x) + actual_output_y = result | beam.Map(lambda x: x.y) assert_that( actual_output_x, equal_to(expected_output_x, equals_fn=np.array_equal)) @@ -219,19 +219,19 @@ def test_ml_transforms_on_multiple_columns_multiple_transforms(self): tft.ScaleTo01(columns=['x']), tft.ComputeAndApplyVocabulary(columns=['y']) ] - batched_data = [{'x': [1, 2, 3], 'y': ['a', 'b', 'c']}] + data = [{'x': [1, 2, 3], 'y': ['a', 'b', 'c']}] with beam.Pipeline() as p: - batched_result = ( + result = ( p - | beam.Create(batched_data) + | beam.Create(data) | base.MLTransform( transforms=transforms, artifact_location=self.artifact_location)) expected_output_x = [ np.array([0, 0.5, 1], dtype=np.float32), ] expected_output_y = [np.array([2, 1, 0])] - actual_output_x = batched_result | beam.Map(lambda x: x.x) - actual_output_y = batched_result | beam.Map(lambda x: x.y) + actual_output_x = result | beam.Map(lambda x: x.x) + actual_output_y = result | beam.Map(lambda x: x.y) assert_that( actual_output_x, diff --git a/sdks/python/apache_beam/ml/transforms/handlers_test.py b/sdks/python/apache_beam/ml/transforms/handlers_test.py index 4abcfee0a6e9..3889efc77b7c 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers_test.py +++ b/sdks/python/apache_beam/ml/transforms/handlers_test.py @@ -66,15 +66,15 @@ def get_artifacts(self, data, col_name): return {'artifact': tf.convert_to_tensor([1])} -class UnBatchedIntType(NamedTuple): +class IntType(NamedTuple): x: int -class BatchedIntType(NamedTuple): +class ListIntType(NamedTuple): x: List[int] -class BatchedNumpyType(NamedTuple): +class NumpyType(NamedTuple): x: np.int64 @@ -116,13 +116,12 @@ def test_preprocessing_fn_with_artifacts(self): expected_result = {'x': [1, 2, 3], 'artifact': tf.convert_to_tensor([1])} self.assertDictEqual(actual_result, expected_result) - def test_input_type_from_schema_named_tuple_pcoll_unbatched(self): - non_batched_data = [{'x': 1}] + def test_input_type_from_schema_named_tuple_pcoll(self): + data = [{'x': 1}] with beam.Pipeline() as p: data = ( - p | beam.Create(non_batched_data) - | beam.Map(lambda x: UnBatchedIntType(**x)).with_output_types( - UnBatchedIntType)) + p | beam.Create(data) + | beam.Map(lambda x: IntType(**x)).with_output_types(IntType)) element_type = data.element_type process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location) @@ -132,13 +131,12 @@ def test_input_type_from_schema_named_tuple_pcoll_unbatched(self): self.assertEqual(inferred_input_type, expected_input_type) - def test_input_type_from_schema_named_tuple_pcoll_batched(self): - batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + def test_input_type_from_schema_named_tuple_pcoll_list(self): + data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] with beam.Pipeline() as p: data = ( - p | beam.Create(batched_data) - | beam.Map(lambda x: BatchedIntType(**x)).with_output_types( - BatchedIntType)) + p | beam.Create(data) + | beam.Map(lambda x: ListIntType(**x)).with_output_types(ListIntType)) element_type = data.element_type process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location) @@ -147,11 +145,11 @@ def test_input_type_from_schema_named_tuple_pcoll_batched(self): expected_input_type = dict(x=List[int]) self.assertEqual(inferred_input_type, expected_input_type) - def test_input_type_from_row_type_pcoll_unbatched(self): - non_batched_data = [{'x': 1}] + def test_input_type_from_row_type_pcoll(self): + data = [{'x': 1}] with beam.Pipeline() as p: data = ( - p | beam.Create(non_batched_data) + p | beam.Create(data) | beam.Map(lambda ele: beam.Row(x=int(ele['x'])))) element_type = data.element_type process_handler = handlers.TFTProcessHandler( @@ -161,11 +159,11 @@ def test_input_type_from_row_type_pcoll_unbatched(self): expected_input_type = dict(x=List[int]) self.assertEqual(inferred_input_type, expected_input_type) - def test_input_type_from_row_type_pcoll_batched(self): - batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + def test_input_type_from_row_type_pcoll_list(self): + data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] with beam.Pipeline() as p: data = ( - p | beam.Create(batched_data) + p | beam.Create(data) | beam.Map(lambda ele: beam.Row(x=list(ele['x']))).with_output_types( beam.row_type.RowTypeConstraint.from_fields([('x', List[int])]))) @@ -177,17 +175,16 @@ def test_input_type_from_row_type_pcoll_batched(self): expected_input_type = dict(x=List[int]) self.assertEqual(inferred_input_type, expected_input_type) - def test_input_type_from_named_tuple_pcoll_batched_numpy(self): - batched = [{ + def test_input_type_from_named_tuple_pcoll_numpy(self): + np_data = [{ 'x': np.array([1, 2, 3], dtype=np.int64) }, { 'x': np.array([4, 5, 6], dtype=np.int64) }] with beam.Pipeline() as p: data = ( - p | beam.Create(batched) - | beam.Map(lambda x: BatchedNumpyType(**x)).with_output_types( - BatchedNumpyType)) + p | beam.Create(np_data) + | beam.Map(lambda x: NumpyType(**x)).with_output_types(NumpyType)) element_type = data.element_type process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location) diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py index c96290d0440a..b24c3cff216a 100644 --- a/sdks/python/apache_beam/ml/transforms/tft.py +++ b/sdks/python/apache_beam/ml/transforms/tft.py @@ -39,6 +39,7 @@ from typing import Iterable from typing import List from typing import Optional +from typing import Tuple from typing import Union import tensorflow as tf @@ -56,6 +57,8 @@ 'Bucketize', 'TFIDF', 'TFTOperation', + 'ScaleByMinMax', + 'NGrams', ] # Register the expected input types for each operation @@ -98,12 +101,39 @@ def get_artifacts(self, data: common_types.TensorType, """ return {} + def _split_string_with_delimiter(self, data, delimiter): + """ + only applicable to string columns. + """ + data = tf.sparse.to_dense(data) + # this method acts differently compared to tf.strings.split + # this will split the string based on multiple delimiters while + # the latter will split the string based on a single delimiter. + fn = lambda data: tf.compat.v1.string_split( + data, delimiter, result_type='RaggedTensor') + # tf.compat.v1.string_split works on a single string. Use tf.map_fn + # to apply the function on each element of the input data. + data = tf.map_fn( + fn, + data, + fn_output_signature=tf.RaggedTensorSpec( + tf.TensorShape([None, None]), tf.string)) + data = data.values.to_sparse() + # the columns of the sparse tensor are suffixed with $indices, $values + # related to sparse tensor. Create a new sparse tensor by extracting + # the indices, values and dense_shape from the original sparse tensor + # to preserve the original column name. + data = tf.sparse.SparseTensor( + indices=data.indices, values=data.values, dense_shape=data.dense_shape) + return data + @register_input_dtype(str) class ComputeAndApplyVocabulary(TFTOperation): def __init__( self, columns: List[str], + split_string_by_delimiter: Optional[str] = None, *, default_value: Any = -1, top_k: Optional[int] = None, @@ -118,6 +148,8 @@ def __init__( Args: columns: List of column names to apply the transformation. + split_string_by_delimiter: (Optional) A string that specifies the + delimiter to split strings. default_value: (Optional) The value to use for out-of-vocabulary values. top_k: (Optional) The number of most frequent tokens to keep. frequency_threshold: (Optional) Limit the generated vocabulary only to @@ -140,10 +172,14 @@ def __init__( self._vocab_filename = vocab_filename if vocab_filename else ( 'compute_and_apply_vocab') self._name = name + self.split_string_by_delimiter = split_string_by_delimiter def apply_transform( self, data: common_types.TensorType, output_column_name: str) -> Dict[str, common_types.TensorType]: + if self.split_string_by_delimiter: + data = self._split_string_with_delimiter( + data, self.split_string_by_delimiter) return { output_column_name: tft.compute_and_apply_vocabulary( x=data, @@ -434,3 +470,78 @@ def apply_transform( output_column_name + '_tfidf_weight': tfidf_weight } return output + + +@register_input_dtype(float) +class ScaleByMinMax(TFTOperation): + def __init__( + self, + columns: List[str], + min_value: float = 0.0, + max_value: float = 1.0, + name: Optional[str] = None): + """ + This function applies a scaling transformation on the given columns + of incoming data. The transformation scales the input values to the + range [min_value, max_value]. + + Args: + columns: A list of column names to apply the transformation on. + min_value: The minimum value of the output range. + max_value: The maximum value of the output range. + name: A name for the operation (optional). + """ + super().__init__(columns) + self.min_value = min_value + self.max_value = max_value + self.name = name + + if self.max_value <= self.min_value: + raise ValueError('max_value must be greater than min_value') + + def apply_transform( + self, data: tf.Tensor, output_column_name: str) -> tf.Tensor: + + output = tft.scale_by_min_max( + x=data, output_min=self.min_value, output_max=self.max_value) + return {output_column_name: output} + + +@register_input_dtype(str) +class NGrams(TFTOperation): + def __init__( + self, + columns: List[str], + split_string_by_delimiter: Optional[str] = None, + *, + ngram_range: Tuple[int, int], + ngrams_separator: str, + name: Optional[str] = None): + """ + An n-gram is a contiguous sequence of n items from a given sample of text + or speech. This operation applies an n-gram transformation to + specified columns of incoming data, splitting the input data into a + set of consecutive n-grams. + + Args: + columns: A list of column names to apply the transformation on. + split_string_by_delimiter: (Optional) A string that specifies the + delimiter to split the input strings before computing ngrams. + ngram_range: A tuple of integers(inclusive) specifying the range of + n-gram sizes. + ngrams_separator: A string that will be inserted between each ngram. + name: A name for the operation (optional). + """ + super().__init__(columns) + self.ngram_range = ngram_range + self.ngrams_separator = ngrams_separator + self.name = name + self.split_string_by_delimiter = split_string_by_delimiter + + def apply_transform(self, data: tf.SparseTensor, + output_column_name: str) -> Dict[str, tf.SparseTensor]: + if self.split_string_by_delimiter: + data = self._split_string_with_delimiter( + data, self.split_string_by_delimiter) + output = tft.ngrams(data, self.ngram_range, self.ngrams_separator) + return {output_column_name: output} diff --git a/sdks/python/apache_beam/ml/transforms/tft_test.py b/sdks/python/apache_beam/ml/transforms/tft_test.py index 66578c7366dc..2cb24defa59c 100644 --- a/sdks/python/apache_beam/ml/transforms/tft_test.py +++ b/sdks/python/apache_beam/ml/transforms/tft_test.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # + # pytype: skip-file import shutil @@ -70,40 +71,28 @@ def setUp(self) -> None: def tearDown(self): shutil.rmtree(self.artifact_location) - def test_z_score_unbatched(self): - unbatched_data = [{ - 'x': 1 - }, { - 'x': 2 - }, { - 'x': 3 - }, { - 'x': 4 - }, { - 'x': 5 - }, { - 'x': 6 - }] + def test_z_score(self): + data = [{'x': 1}, {'x': 2}, {'x': 3}, {'x': 4}, {'x': 5}, {'x': 6}] with beam.Pipeline() as p: - unbatched_result = ( + result = ( p - | "unbatchedCreate" >> beam.Create(unbatched_data) - | "unbatchedMLTransform" >> base.MLTransform( + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( artifact_location=self.artifact_location).with_transform( tft.ScaleToZScore(columns=['x']))) - _ = (unbatched_result | beam.Map(assert_z_score_artifacts)) + _ = (result | beam.Map(assert_z_score_artifacts)) - def test_z_score_batched(self): - batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + def test_z_score_list_data(self): + list_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] with beam.Pipeline() as p: - batched_result = ( + list_result = ( p - | "batchedCreate" >> beam.Create(batched_data) - | "batchedMLTransform" >> base.MLTransform( + | "listCreate" >> beam.Create(list_data) + | "listMLTransform" >> base.MLTransform( artifact_location=self.artifact_location).with_transform( tft.ScaleToZScore(columns=['x']))) - _ = (batched_result | beam.Map(assert_z_score_artifacts)) + _ = (list_result | beam.Map(assert_z_score_artifacts)) class ScaleTo01Test(unittest.TestCase): @@ -113,48 +102,36 @@ def setUp(self) -> None: def tearDown(self): shutil.rmtree(self.artifact_location) - def test_ScaleTo01_batched(self): - batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + def test_ScaleTo01_list(self): + list_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] with beam.Pipeline() as p: - batched_result = ( + list_result = ( p - | "batchedCreate" >> beam.Create(batched_data) - | "batchedMLTransform" >> base.MLTransform( + | "listCreate" >> beam.Create(list_data) + | "MLTransform" >> base.MLTransform( artifact_location=self.artifact_location).with_transform( tft.ScaleTo01(columns=['x']))) - _ = (batched_result | beam.Map(assert_ScaleTo01_artifacts)) + _ = (list_result | beam.Map(assert_ScaleTo01_artifacts)) expected_output = [ np.array([0, 0.2, 0.4], dtype=np.float32), np.array([0.6, 0.8, 1], dtype=np.float32) ] - actual_output = (batched_result | beam.Map(lambda x: x.x)) + actual_output = (list_result | beam.Map(lambda x: x.x)) assert_that( actual_output, equal_to(expected_output, equals_fn=np.array_equal)) - def test_ScaleTo01_unbatched(self): - unbatched_data = [{ - 'x': 1 - }, { - 'x': 2 - }, { - 'x': 3 - }, { - 'x': 4 - }, { - 'x': 5 - }, { - 'x': 6 - }] + def test_ScaleTo01(self): + data = [{'x': 1}, {'x': 2}, {'x': 3}, {'x': 4}, {'x': 5}, {'x': 6}] with beam.Pipeline() as p: - unbatched_result = ( + result = ( p - | "unbatchedCreate" >> beam.Create(unbatched_data) - | "unbatchedMLTransform" >> base.MLTransform( + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( artifact_location=self.artifact_location).with_transform( tft.ScaleTo01(columns=['x']))) - _ = (unbatched_result | beam.Map(assert_ScaleTo01_artifacts)) + _ = (result | beam.Map(assert_ScaleTo01_artifacts)) expected_output = ( np.array([0], dtype=np.float32), np.array([0.2], dtype=np.float32), @@ -162,7 +139,7 @@ def test_ScaleTo01_unbatched(self): np.array([0.6], dtype=np.float32), np.array([0.8], dtype=np.float32), np.array([1], dtype=np.float32)) - actual_output = (unbatched_result | beam.Map(lambda x: x.x)) + actual_output = (result | beam.Map(lambda x: x.x)) assert_that( actual_output, equal_to(expected_output, equals_fn=np.array_equal)) @@ -174,18 +151,18 @@ def setUp(self) -> None: def tearDown(self): shutil.rmtree(self.artifact_location) - def test_bucketize_unbatched(self): - unbatched = [{'x': 1}, {'x': 2}, {'x': 3}, {'x': 4}, {'x': 5}, {'x': 6}] + def test_bucketize(self): + data = [{'x': 1}, {'x': 2}, {'x': 3}, {'x': 4}, {'x': 5}, {'x': 6}] with beam.Pipeline() as p: - unbatched_result = ( + result = ( p - | "unbatchedCreate" >> beam.Create(unbatched) - | "unbatchedMLTransform" >> base.MLTransform( + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( artifact_location=self.artifact_location).with_transform( tft.Bucketize(columns=['x'], num_buckets=3))) - _ = (unbatched_result | beam.Map(assert_bucketize_artifacts)) + _ = (result | beam.Map(assert_bucketize_artifacts)) - transformed_data = (unbatched_result | beam.Map(lambda x: x.x)) + transformed_data = (result | beam.Map(lambda x: x.x)) expected_data = [ np.array([0]), np.array([0]), @@ -197,19 +174,19 @@ def test_bucketize_unbatched(self): assert_that( transformed_data, equal_to(expected_data, equals_fn=np.array_equal)) - def test_bucketize_batched(self): - batched = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + def test_bucketize_list(self): + list_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] with beam.Pipeline() as p: - batched_result = ( + list_result = ( p - | "batchedCreate" >> beam.Create(batched) - | "batchedMLTransform" >> base.MLTransform( + | "Create" >> beam.Create(list_data) + | "MLTransform" >> base.MLTransform( artifact_location=self.artifact_location).with_transform( tft.Bucketize(columns=['x'], num_buckets=3))) - _ = (batched_result | beam.Map(assert_bucketize_artifacts)) + _ = (list_result | beam.Map(assert_bucketize_artifacts)) transformed_data = ( - batched_result + list_result | "TransformedColumnX" >> beam.Map(lambda ele: ele.x)) expected_data = [ np.array([0, 0, 1], dtype=np.int64), @@ -290,9 +267,9 @@ def setUp(self) -> None: def tearDown(self): shutil.rmtree(self.artifact_location) - def test_compute_and_apply_vocabulary_unbatched_inputs(self): - batch_size = 100 - num_instances = batch_size + 1 + def test_compute_and_apply_vocabulary_inputs(self): + num_elements = 100 + num_instances = num_elements + 1 input_data = [{ 'x': '%.10i' % i, # Front-padded to facilitate lexicographic sorting. } for i in range(num_instances)] @@ -312,9 +289,9 @@ def test_compute_and_apply_vocabulary_unbatched_inputs(self): assert_that(actual_data, equal_to(expected_data)) - def test_compute_and_apply_vocabulary_batched(self): - batch_size = 100 - num_instances = batch_size + 1 + def test_compute_and_apply_vocabulary(self): + num_elements = 100 + num_instances = num_elements + 1 input_data = [ { 'x': ['%.10i' % i, '%.10i' % (i + 1), '%.10i' % (i + 2)], @@ -322,7 +299,7 @@ def test_compute_and_apply_vocabulary_batched(self): } for i in range(0, num_instances, 3) ] - # since we have 3 elements in a single batch, multiply with 3 for + # since we have 3 elements in a single list, multiply with 3 for # each iteration i on the expected output. excepted_data = [ np.array([(len(input_data) * 3 - 1) - i, @@ -344,6 +321,48 @@ def test_compute_and_apply_vocabulary_batched(self): assert_that( actual_output, equal_to(excepted_data, equals_fn=np.array_equal)) + def test_string_split_with_single_delimiter(self): + data = [{ + 'x': ['I like pie', 'yum yum pie'], + }, { + 'x': 'yum yum pie' + }] + + with beam.Pipeline() as p: + result = ( + p + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( + artifact_location=self.artifact_location).with_transform( + tft.ComputeAndApplyVocabulary( + columns=['x'], split_string_by_delimiter=' '))) + result = result | beam.Map(lambda x: x.x) + expected_result = [ + np.array([3, 2, 1]), np.array([0, 0, 1]), np.array([0, 0, 1]) + ] + assert_that(result, equal_to(expected_result, equals_fn=np.array_equal)) + + def test_string_split_with_multiple_delimiters(self): + data = [{ + 'x': ['I like pie', 'yum;yum;pie'], + }, { + 'x': 'yum yum pie' + }] + + with beam.Pipeline() as p: + result = ( + p + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( + artifact_location=self.artifact_location).with_transform( + tft.ComputeAndApplyVocabulary( + columns=['x'], split_string_by_delimiter=' ;'))) + result = result | beam.Map(lambda x: x.x) + expected_result = [ + np.array([3, 2, 1]), np.array([0, 0, 1]), np.array([0, 0, 1]) + ] + assert_that(result, equal_to(expected_result, equals_fn=np.array_equal)) + class TFIDIFTest(unittest.TestCase): def setUp(self) -> None: @@ -352,7 +371,7 @@ def setUp(self) -> None: def tearDown(self): shutil.rmtree(self.artifact_location) - def test_tfidf_batched_compute_vocab_size_during_runtime(self): + def test_tfidf_compute_vocab_size_during_runtime(self): raw_data = [ dict(x=["I", "like", "pie", "pie", "pie"]), dict(x=["yum", "yum", "pie"]) @@ -391,5 +410,148 @@ def equals_fn(a, b): assert_that(actual_output, equal_to(expected_output, equals_fn=equals_fn)) +class ScaleToMinMaxTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.artifact_location) + + def test_scale_to_min_max(self): + data = [{ + 'x': 4, + }, { + 'x': 1, + }, { + 'x': 5, + }, { + 'x': 2, + }] + with beam.Pipeline() as p: + result = ( + p + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( + artifact_location=self.artifact_location).with_transform( + tft.ScaleByMinMax( + columns=['x'], + min_value=-1, + max_value=1, + ), + )) + result = result | beam.Map(lambda x: x.as_dict()) + expected_data = [{ + 'x': np.array([0.5], dtype=np.float32) + }, { + 'x': np.array([-1.0], dtype=np.float32) + }, { + 'x': np.array([1.0], dtype=np.float32) + }, { + 'x': np.array([-0.5], dtype=np.float32) + }] + assert_that(result, equal_to(expected_data)) + + def test_fail_max_value_less_than_min(self): + with self.assertRaises(ValueError): + tft.ScaleByMinMax(columns=['x'], min_value=10, max_value=0) + + +class NGramsTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.artifact_location) + + def test_ngrams_on_list_separated_words(self): + data = [{ + 'x': ['I', 'like', 'pie'], + }, { + 'x': ['yum', 'yum', 'pie'] + }] + with beam.Pipeline() as p: + result = ( + p + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( + artifact_location=self.artifact_location, + transforms=[ + tft.NGrams( + columns=['x'], ngram_range=(1, 3), ngrams_separator=' ') + ])) + result = result | beam.Map(lambda x: x.x) + expected_data = [ + np.array( + [b'I', b'I like', b'I like pie', b'like', b'like pie', b'pie'], + dtype=object), + np.array( + [b'yum', b'yum yum', b'yum yum pie', b'yum', b'yum pie', b'pie'], + dtype=object) + ] + assert_that(result, equal_to(expected_data, equals_fn=np.array_equal)) + + def test_with_string_split_delimiter(self): + data = [{ + 'x': 'I like pie', + }, { + 'x': 'yum yum pie' + }] + with beam.Pipeline() as p: + result = ( + p + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( + artifact_location=self.artifact_location, + transforms=[ + tft.NGrams( + columns=['x'], + split_string_by_delimiter=' ', + ngram_range=(1, 3), + ngrams_separator=' ') + ])) + result = result | beam.Map(lambda x: x.x) + + expected_data = [ + np.array( + [b'I', b'I like', b'I like pie', b'like', b'like pie', b'pie'], + dtype=object), + np.array( + [b'yum', b'yum yum', b'yum yum pie', b'yum', b'yum pie', b'pie'], + dtype=object) + ] + assert_that(result, equal_to(expected_data, equals_fn=np.array_equal)) + + def test_with_multiple_string_delimiters(self): + data = [{ + 'x': 'I?like?pie', + }, { + 'x': 'yum yum pie' + }] + with beam.Pipeline() as p: + result = ( + p + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( + artifact_location=self.artifact_location, + transforms=[ + tft.NGrams( + columns=['x'], + split_string_by_delimiter=' ?', + ngram_range=(1, 3), + ngrams_separator=' ') + ])) + result = result | beam.Map(lambda x: x.x) + + expected_data = [ + np.array( + [b'I', b'I like', b'I like pie', b'like', b'like pie', b'pie'], + dtype=object), + np.array( + [b'yum', b'yum yum', b'yum yum pie', b'yum', b'yum pie', b'pie'], + dtype=object) + ] + assert_that(result, equal_to(expected_data, equals_fn=np.array_equal)) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index d56b464e71c6..d36b335ae294 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -1151,8 +1151,7 @@ def _add_argparse_args(cls, parser): help=( 'Number of threads per worker to use on the runner. If left ' 'unspecified, the runner will compute an appropriate number of ' - 'threads to use. Currently only enabled for DataflowRunner when ' - 'experiment \'use_runner_v2\' is enabled.')) + 'threads to use.')) def add_experiment(self, experiment): # pylint: disable=access-member-before-definition diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index 18ab0f091aaf..c9ac4ce4c13d 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -30,7 +30,7 @@ from apache_beam import typehints from apache_beam.coders import BytesCoder from apache_beam.io import Read -from apache_beam.metrics import Metrics +from apache_beam.io.iobase import SourceBase from apache_beam.options.pipeline_options import PortableOptions from apache_beam.pipeline import Pipeline from apache_beam.pipeline import PipelineOptions @@ -40,7 +40,6 @@ from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.pvalue import AsSingleton from apache_beam.pvalue import TaggedOutput -from apache_beam.runners.dataflow.native_io.iobase import NativeSource from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to @@ -61,39 +60,9 @@ from apache_beam.utils import windowed_value from apache_beam.utils.timestamp import MIN_TIMESTAMP -# TODO(BEAM-1555): Test is failing on the service, with FakeSource. - -class FakeSource(NativeSource): - """Fake source returning a fixed list of values.""" - class _Reader(object): - def __init__(self, vals): - self._vals = vals - self._output_counter = Metrics.counter('main', 'outputs') - - def __enter__(self): - return self - - def __exit__(self, exception_type, exception_value, traceback): - pass - - def __iter__(self): - for v in self._vals: - self._output_counter.inc() - yield v - - def __init__(self, vals): - self._vals = vals - - def reader(self): - return FakeSource._Reader(self._vals) - - -class FakeUnboundedSource(NativeSource): +class FakeUnboundedSource(SourceBase): """Fake unbounded source. Does not work at runtime""" - def reader(self): - return None - def is_bounded(self): return False @@ -259,24 +228,6 @@ def test_create_singleton_pcollection(self): pcoll = pipeline | 'label' >> Create([[1, 2, 3]]) assert_that(pcoll, equal_to([[1, 2, 3]])) - # TODO(BEAM-1555): Test is failing on the service, with FakeSource. - # @pytest.mark.it_validatesrunner - def test_metrics_in_fake_source(self): - pipeline = TestPipeline() - pcoll = pipeline | Read(FakeSource([1, 2, 3, 4, 5, 6])) - assert_that(pcoll, equal_to([1, 2, 3, 4, 5, 6])) - res = pipeline.run() - metric_results = res.metrics().query() - outputs_counter = metric_results['counters'][0] - self.assertEqual(outputs_counter.key.step, 'Read') - self.assertEqual(outputs_counter.key.metric.name, 'outputs') - self.assertEqual(outputs_counter.committed, 6) - - def test_fake_read(self): - with TestPipeline() as pipeline: - pcoll = pipeline | 'read' >> Read(FakeSource([1, 2, 3])) - assert_that(pcoll, equal_to([1, 2, 3])) - def test_visit_entire_graph(self): pipeline = Pipeline() pcoll1 = pipeline | 'pcoll' >> beam.Impulse() diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_metrics_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_metrics_test.py index 06e1585ffc42..86e71f9c1ed2 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_metrics_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_metrics_test.py @@ -503,7 +503,6 @@ def test_translate_portable_job_step_name(self): self.ONLY_COUNTERS_LIST) pipeline_options = PipelineOptions([ - '--experiments=use_runner_v2', '--experiments=use_portable_job_submission', '--temp_location=gs://any-location/temp', '--project=dummy_project', diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index ff1beeab510d..5f1d3c0c329a 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -1160,7 +1160,7 @@ def _get_required_container_version(): current version of the SDK. """ if 'dev' in beam_version.__version__: - return names.BEAM_FNAPI_CONTAINER_VERSION + return names.BEAM_DEV_SDK_CONTAINER_TAG else: return _get_container_image_tag() diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py index 22e779a8c274..d639ad21c31c 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py @@ -635,7 +635,7 @@ def test_pinned_worker_harness_image_tag_used_in_dev_sdk(self): '/beam_python%d.%d_sdk:%s' % ( sys.version_info[0], sys.version_info[1], - names.BEAM_FNAPI_CONTAINER_VERSION))) + names.BEAM_DEV_SDK_CONTAINER_TAG))) pipeline_options = PipelineOptions( ['--temp_location', 'gs://any-location/temp']) @@ -651,7 +651,7 @@ def test_pinned_worker_harness_image_tag_used_in_dev_sdk(self): '/beam_python%d.%d_sdk:%s' % ( sys.version_info[0], sys.version_info[1], - names.BEAM_FNAPI_CONTAINER_VERSION))) + names.BEAM_DEV_SDK_CONTAINER_TAG))) @mock.patch( 'apache_beam.runners.dataflow.internal.apiclient.' diff --git a/sdks/python/apache_beam/runners/dataflow/internal/names.py b/sdks/python/apache_beam/runners/dataflow/internal/names.py index f86306eb276e..2075c8eee3f1 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/names.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/names.py @@ -30,15 +30,10 @@ SOURCE_TYPE = 'CustomSourcesType' SERIALIZED_SOURCE_KEY = 'serialized_source' -# In a released SDK, container tags are selected based on the SDK version. -# Unreleased versions use container versions based on values of -# BEAM_CONTAINER_VERSION and BEAM_FNAPI_CONTAINER_VERSION (see below). - -# Update this version to the next version whenever there is a change that will -# require changes to legacy Dataflow worker execution environment. -BEAM_CONTAINER_VERSION = 'beam-master-20230629' -# Update this version to the next version whenever there is a change that +# In a released SDK, Python sdk container image is tagged with the SDK version. +# Unreleased sdks use container image tag specified below. +# Update this tag whenever there is a change that # requires changes to SDK harness container or SDK harness launcher. -BEAM_FNAPI_CONTAINER_VERSION = 'beam-master-20230705' +BEAM_DEV_SDK_CONTAINER_TAG = 'beam-master-20230717' DATAFLOW_CONTAINER_IMAGE_REPOSITORY = 'gcr.io/cloud-dataflow/v1beta3' diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/__init__.py b/sdks/python/apache_beam/runners/dataflow/native_io/__init__.py deleted file mode 100644 index cce3acad34a4..000000000000 --- a/sdks/python/apache_beam/runners/dataflow/native_io/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py b/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py deleted file mode 100644 index 3d1afe546901..000000000000 --- a/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py +++ /dev/null @@ -1,342 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Dataflow native sources and sinks. - -For internal use only; no backwards-compatibility guarantees. -""" - -# pytype: skip-file - -import logging -from typing import TYPE_CHECKING -from typing import Optional - -from apache_beam import pvalue -from apache_beam.io import iobase -from apache_beam.transforms import ptransform -from apache_beam.transforms.display import HasDisplayData - -if TYPE_CHECKING: - from apache_beam import coders - -_LOGGER = logging.getLogger(__name__) - - -def _dict_printable_fields(dict_object, skip_fields): - """Returns a list of strings for the interesting fields of a dict.""" - return [ - '%s=%r' % (name, value) for name, - value in dict_object.items() - # want to output value 0 but not None nor [] - if (value or value == 0) and name not in skip_fields - ] - - -_minor_fields = [ - 'coder', - 'key_coder', - 'value_coder', - 'config_bytes', - 'elements', - 'append_trailing_newlines', - 'strip_trailing_newlines', - 'compression_type' -] - - -class NativeSource(iobase.SourceBase): - """A source implemented by Dataflow service. - - This class is to be only inherited by sources natively implemented by Cloud - Dataflow service, hence should not be sub-classed by users. - - This class is deprecated and should not be used to define new sources. - """ - coder = None # type: Optional[coders.Coder] - - def reader(self): - """Returns a NativeSourceReader instance associated with this source.""" - raise NotImplementedError - - def is_bounded(self): - return True - - def __repr__(self): - return '<{name} {vals}>'.format( - name=self.__class__.__name__, - vals=', '.join(_dict_printable_fields(self.__dict__, _minor_fields))) - - -class NativeSourceReader(object): - """A reader for a source implemented by Dataflow service.""" - def __enter__(self): - """Opens everything necessary for a reader to function properly.""" - raise NotImplementedError - - def __exit__(self, exception_type, exception_value, traceback): - """Cleans up after a reader executed.""" - raise NotImplementedError - - def __iter__(self): - """Returns an iterator over all the records of the source.""" - raise NotImplementedError - - @property - def returns_windowed_values(self): - """Returns whether this reader returns windowed values.""" - return False - - def get_progress(self): - """Returns a representation of how far the reader has read. - - Returns: - A SourceReaderProgress object that gives the current progress of the - reader. - """ - - def request_dynamic_split(self, dynamic_split_request): - """Attempts to split the input in two parts. - - The two parts are named the "primary" part and the "residual" part. The - current 'NativeSourceReader' keeps processing the primary part, while the - residual part will be processed elsewhere (e.g. perhaps on a different - worker). - - The primary and residual parts, if concatenated, must represent the - same input as the current input of this 'NativeSourceReader' before this - call. - - The boundary between the primary part and the residual part is - specified in a framework-specific way using 'DynamicSplitRequest' e.g., - if the framework supports the notion of positions, it might be a - position at which the input is asked to split itself (which is not - necessarily the same position at which it *will* split itself); it - might be an approximate fraction of input, or something else. - - This function returns a 'DynamicSplitResult', which encodes, in a - framework-specific way, the information sufficient to construct a - description of the resulting primary and residual inputs. For example, it - might, again, be a position demarcating these parts, or it might be a pair - of fully-specified input descriptions, or something else. - - After a successful call to 'request_dynamic_split()', subsequent calls - should be interpreted relative to the new primary. - - Args: - dynamic_split_request: A 'DynamicSplitRequest' describing the split - request. - - Returns: - 'None' if the 'DynamicSplitRequest' cannot be honored (in that - case the input represented by this 'NativeSourceReader' stays the same), - or a 'DynamicSplitResult' describing how the input was split into a - primary and residual part. - """ - _LOGGER.debug( - 'SourceReader %r does not support dynamic splitting. Ignoring dynamic ' - 'split request: %r', - self, - dynamic_split_request) - - -class ReaderProgress(object): - """A representation of how far a NativeSourceReader has read.""" - def __init__( - self, - position=None, - percent_complete=None, - remaining_time=None, - consumed_split_points=None, - remaining_split_points=None): - - self._position = position - - if percent_complete is not None: - percent_complete = float(percent_complete) - if percent_complete < 0 or percent_complete > 1: - raise ValueError( - 'The percent_complete argument was %f. Must be in range [0, 1].' % - percent_complete) - self._percent_complete = percent_complete - - self._remaining_time = remaining_time - self._consumed_split_points = consumed_split_points - self._remaining_split_points = remaining_split_points - - @property - def position(self): - """Returns progress, represented as a ReaderPosition object.""" - return self._position - - @property - def percent_complete(self): - """Returns progress, represented as a percentage of total work. - - Progress range from 0.0 (beginning, nothing complete) to 1.0 (end of the - work range, entire WorkItem complete). - - Returns: - Progress represented as a percentage of total work. - """ - return self._percent_complete - - @property - def remaining_time(self): - """Returns progress, represented as an estimated time remaining.""" - return self._remaining_time - - @property - def consumed_split_points(self): - return self._consumed_split_points - - @property - def remaining_split_points(self): - return self._remaining_split_points - - -class ReaderPosition(object): - """A representation of position in an iteration of a 'NativeSourceReader'.""" - def __init__( - self, - end=None, - key=None, - byte_offset=None, - record_index=None, - shuffle_position=None, - concat_position=None): - """Initializes ReaderPosition. - - A ReaderPosition may get instantiated for one of these position types. Only - one of these should be specified. - - Args: - end: position is past all other positions. For example, this may be used - to represent the end position of an unbounded range. - key: position is a string key. - byte_offset: position is a byte offset. - record_index: position is a record index - shuffle_position: position is a base64 encoded shuffle position. - concat_position: position is a 'ConcatPosition'. - """ - - self.end = end - self.key = key - self.byte_offset = byte_offset - self.record_index = record_index - self.shuffle_position = shuffle_position - - if concat_position is not None: - assert isinstance(concat_position, ConcatPosition) - self.concat_position = concat_position - - -class ConcatPosition(object): - """A position that encapsulate an inner position and an index. - - This is used to represent the position of a source that encapsulate several - other sources. - """ - def __init__(self, index, position): - """Initializes ConcatPosition. - - Args: - index: index of the source currently being read. - position: inner position within the source currently being read. - """ - - if position is not None: - assert isinstance(position, ReaderPosition) - self.index = index - self.position = position - - -class DynamicSplitRequest(object): - """Specifies how 'NativeSourceReader.request_dynamic_split' should split. - """ - def __init__(self, progress): - assert isinstance(progress, ReaderProgress) - self.progress = progress - - -class DynamicSplitResult(object): - pass - - -class DynamicSplitResultWithPosition(DynamicSplitResult): - def __init__(self, stop_position): - assert isinstance(stop_position, ReaderPosition) - self.stop_position = stop_position - - -class NativeSink(HasDisplayData): - """A sink implemented by Dataflow service. - - This class is to be only inherited by sinks natively implemented by Cloud - Dataflow service, hence should not be sub-classed by users. - """ - def writer(self): - """Returns a SinkWriter for this source.""" - raise NotImplementedError - - def __repr__(self): - return '<{name} {vals}>'.format( - name=self.__class__.__name__, - vals=_dict_printable_fields(self.__dict__, _minor_fields)) - - -class NativeSinkWriter(object): - """A writer for a sink implemented by Dataflow service.""" - def __enter__(self): - """Opens everything necessary for a writer to function properly.""" - raise NotImplementedError - - def __exit__(self, exception_type, exception_value, traceback): - """Cleans up after a writer executed.""" - raise NotImplementedError - - @property - def takes_windowed_values(self): - """Returns whether this writer takes windowed values.""" - return False - - def Write(self, o): # pylint: disable=invalid-name - """Writes a record to the sink associated with this writer.""" - raise NotImplementedError - - -class _NativeWrite(ptransform.PTransform): - """A PTransform for writing to a Dataflow native sink. - - These are sinks that are implemented natively by the Dataflow service - and hence should not be updated by users. These sinks are processed - using a Dataflow native write transform. - - Applying this transform results in a ``pvalue.PDone``. - """ - def __init__(self, sink): - """Initializes a Write transform. - - Args: - sink: Sink to use for the write - """ - super().__init__() - self.sink = sink - - def expand(self, pcoll): - self._check_pcollection(pcoll) - return pvalue.PDone(pcoll.pipeline) diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py deleted file mode 100644 index 5e72ca555b69..000000000000 --- a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py +++ /dev/null @@ -1,203 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests corresponding to Dataflow's iobase module.""" - -# pytype: skip-file - -import unittest - -from apache_beam import Create -from apache_beam import error -from apache_beam import pvalue -from apache_beam.runners.dataflow.native_io.iobase import ConcatPosition -from apache_beam.runners.dataflow.native_io.iobase import DynamicSplitRequest -from apache_beam.runners.dataflow.native_io.iobase import DynamicSplitResultWithPosition -from apache_beam.runners.dataflow.native_io.iobase import NativeSink -from apache_beam.runners.dataflow.native_io.iobase import NativeSinkWriter -from apache_beam.runners.dataflow.native_io.iobase import NativeSource -from apache_beam.runners.dataflow.native_io.iobase import ReaderPosition -from apache_beam.runners.dataflow.native_io.iobase import ReaderProgress -from apache_beam.runners.dataflow.native_io.iobase import _dict_printable_fields -from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite -from apache_beam.testing.test_pipeline import TestPipeline - - -class TestHelperFunctions(unittest.TestCase): - def test_dict_printable_fields(self): - dict_object = { - 'key_alpha': '1', - 'key_beta': None, - 'key_charlie': [], - 'key_delta': 2.0, - 'key_echo': 'skip_me', - 'key_fox': 0 - } - skip_fields = [ - 'key_echo', - ] - self.assertEqual( - sorted(_dict_printable_fields(dict_object, skip_fields)), - ["key_alpha='1'", 'key_delta=2.0', 'key_fox=0']) - - -class TestNativeSource(unittest.TestCase): - def test_reader_method(self): - native_source = NativeSource() - self.assertRaises(NotImplementedError, native_source.reader) - - def test_repr_method(self): - class FakeSource(NativeSource): - """A fake source modeled after BigQuerySource, which inherits from - NativeSource.""" - def __init__( - self, - table=None, - dataset=None, - project=None, - query=None, - validate=False, - coder=None, - use_std_sql=False, - flatten_results=True): - self.validate = validate - - fake_source = FakeSource() - self.assertEqual(fake_source.__repr__(), '') - - -class TestReaderProgress(unittest.TestCase): - def test_out_of_bounds_percent_complete(self): - with self.assertRaises(ValueError): - ReaderProgress(percent_complete=-0.1) - with self.assertRaises(ValueError): - ReaderProgress(percent_complete=1.1) - - def test_position_property(self): - reader_progress = ReaderProgress(position=ReaderPosition()) - self.assertEqual(type(reader_progress.position), ReaderPosition) - - def test_percent_complete_property(self): - reader_progress = ReaderProgress(percent_complete=0.5) - self.assertEqual(reader_progress.percent_complete, 0.5) - - -class TestReaderPosition(unittest.TestCase): - def test_invalid_concat_position_type(self): - with self.assertRaises(AssertionError): - ReaderPosition(concat_position=1) - - def test_valid_concat_position_type(self): - ReaderPosition(concat_position=ConcatPosition(None, None)) - - -class TestConcatPosition(unittest.TestCase): - def test_invalid_position_type(self): - with self.assertRaises(AssertionError): - ConcatPosition(None, position=1) - - def test_valid_position_type(self): - ConcatPosition(None, position=ReaderPosition()) - - -class TestDynamicSplitRequest(unittest.TestCase): - def test_invalid_progress_type(self): - with self.assertRaises(AssertionError): - DynamicSplitRequest(progress=1) - - def test_valid_progress_type(self): - DynamicSplitRequest(progress=ReaderProgress()) - - -class TestDynamicSplitResultWithPosition(unittest.TestCase): - def test_invalid_stop_position_type(self): - with self.assertRaises(AssertionError): - DynamicSplitResultWithPosition(stop_position=1) - - def test_valid_stop_position_type(self): - DynamicSplitResultWithPosition(stop_position=ReaderPosition()) - - -class TestNativeSink(unittest.TestCase): - def test_writer_method(self): - native_sink = NativeSink() - self.assertRaises(NotImplementedError, native_sink.writer) - - def test_repr_method(self): - class FakeSink(NativeSink): - """A fake sink modeled after BigQuerySink, which inherits from - NativeSink.""" - def __init__( - self, - validate=False, - dataset=None, - project=None, - schema=None, - create_disposition='create', - write_disposition=None, - coder=None): - self.validate = validate - - fake_sink = FakeSink() - self.assertEqual(fake_sink.__repr__(), "") - - def test_on_direct_runner(self): - class FakeSink(NativeSink): - """A fake sink outputing a number of elements.""" - def __init__(self): - self.written_values = [] - self.writer_instance = FakeSinkWriter(self.written_values) - - def writer(self): - return self.writer_instance - - class FakeSinkWriter(NativeSinkWriter): - """A fake sink writer for testing.""" - def __init__(self, written_values): - self.written_values = written_values - - def __enter__(self): - return self - - def __exit__(self, *unused_args): - pass - - def Write(self, value): - self.written_values.append(value) - - with TestPipeline() as p: - sink = FakeSink() - p | Create(['a', 'b', 'c']) | _NativeWrite(sink) # pylint: disable=expression-not-assigned - - self.assertEqual(['a', 'b', 'c'], sorted(sink.written_values)) - - -class Test_NativeWrite(unittest.TestCase): - def setUp(self): - self.native_sink = NativeSink() - self.native_write = _NativeWrite(self.native_sink) - - def test_expand_method_pcollection_errors(self): - with self.assertRaises(error.TransformError): - self.native_write.expand(None) - with self.assertRaises(error.TransformError): - pcoll = pvalue.PCollection(pipeline=None) - self.native_write.expand(pcoll) - - -if __name__ == '__main__': - unittest.main() diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 6466bc7752b5..db53e4122bbc 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -72,9 +72,9 @@ def is_fnapi_compatible(self): def run_pipeline(self, pipeline, options): from apache_beam.pipeline import PipelineVisitor - from apache_beam.runners.dataflow.native_io.iobase import NativeSource - from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite from apache_beam.testing.test_stream import TestStream + from apache_beam.io.gcp.pubsub import ReadFromPubSub + from apache_beam.io.gcp.pubsub import WriteToPubSub class _FnApiRunnerSupportVisitor(PipelineVisitor): """Visitor determining if a Pipeline can be run on the FnApiRunner.""" @@ -83,18 +83,17 @@ def accept(self, pipeline): pipeline.visit(self) return self.supported_by_fnapi_runner + def enter_composite_transform(self, applied_ptransform): + # The FnApiRunner does not support streaming execution. + if isinstance(applied_ptransform.transform, + (ReadFromPubSub, WriteToPubSub)): + self.supported_by_fnapi_runner = False + def visit_transform(self, applied_ptransform): transform = applied_ptransform.transform # The FnApiRunner does not support streaming execution. if isinstance(transform, TestStream): self.supported_by_fnapi_runner = False - # The FnApiRunner does not support reads from NativeSources. - if (isinstance(transform, beam.io.Read) and - isinstance(transform.source, NativeSource)): - self.supported_by_fnapi_runner = False - # The FnApiRunner does not support the use of _NativeWrites. - if isinstance(transform, _NativeWrite): - self.supported_by_fnapi_runner = False if isinstance(transform, beam.ParDo): dofn = transform.dofn # The FnApiRunner does not support execution of CombineFns with diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index bfb27c4adc00..37004c7258a7 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -39,7 +39,6 @@ from apache_beam.runners import common from apache_beam.runners.common import DoFnRunner from apache_beam.runners.common import DoFnState -from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite # pylint: disable=protected-access from apache_beam.runners.direct.direct_runner import _DirectReadFromPubSub from apache_beam.runners.direct.direct_runner import _GroupByKeyOnly from apache_beam.runners.direct.direct_runner import _StreamingGroupAlsoByWindow @@ -106,7 +105,6 @@ def __init__(self, evaluation_context): _GroupByKeyOnly: _GroupByKeyOnlyEvaluator, _StreamingGroupByKeyOnly: _StreamingGroupByKeyOnlyEvaluator, _StreamingGroupAlsoByWindow: _StreamingGroupAlsoByWindowEvaluator, - _NativeWrite: _NativeWriteEvaluator, _TestStream: _TestStreamEvaluator, ProcessElements: _ProcessElementsEvaluator, _WatermarkController: _WatermarkControllerEvaluator, @@ -172,11 +170,10 @@ def should_execute_serially(self, applied_ptransform): Returns: True if executor should execute applied_ptransform serially. """ - if isinstance(applied_ptransform.transform, - (_GroupByKeyOnly, - _StreamingGroupByKeyOnly, - _StreamingGroupAlsoByWindow, - _NativeWrite)): + if isinstance( + applied_ptransform.transform, + (_GroupByKeyOnly, _StreamingGroupByKeyOnly, + _StreamingGroupAlsoByWindow)): return True elif (isinstance(applied_ptransform.transform, core.ParDo) and is_stateful_dofn(applied_ptransform.transform.dofn)): @@ -1125,77 +1122,6 @@ def finish_bundle(self): return TransformResult(self, bundles, [], None, self.keyed_holds) -class _NativeWriteEvaluator(_TransformEvaluator): - """TransformEvaluator for _NativeWrite transform.""" - - ELEMENTS_TAG = _ListStateTag('elements') - - def __init__( - self, - evaluation_context, - applied_ptransform, - input_committed_bundle, - side_inputs): - assert not side_inputs - super().__init__( - evaluation_context, - applied_ptransform, - input_committed_bundle, - side_inputs) - - assert applied_ptransform.transform.sink - self._sink = applied_ptransform.transform.sink - - @property - def _is_final_bundle(self): - return ( - self._execution_context.watermarks.input_watermark == - WatermarkManager.WATERMARK_POS_INF) - - @property - def _has_already_produced_output(self): - return ( - self._execution_context.watermarks.output_watermark == - WatermarkManager.WATERMARK_POS_INF) - - def start_bundle(self): - self.global_state = self._step_context.get_keyed_state(None) - - def process_timer(self, timer_firing): - # We do not need to emit a KeyedWorkItem to process_element(). - pass - - def process_element(self, element): - self.global_state.add_state( - None, _NativeWriteEvaluator.ELEMENTS_TAG, element) - - def finish_bundle(self): - # finish_bundle will append incoming bundles in memory until all the bundles - # carrying data is processed. This is done to produce only a single output - # shard (some tests depends on this behavior). It is possible to have - # incoming empty bundles after the output is produced, these bundles will be - # ignored and would not generate additional output files. - # TODO(altay): Do not wait until the last bundle to write in a single shard. - if self._is_final_bundle: - elements = self.global_state.get_state( - None, _NativeWriteEvaluator.ELEMENTS_TAG) - if self._has_already_produced_output: - # Ignore empty bundles that arrive after the output is produced. - assert elements == [] - else: - self._sink.pipeline_options = self._evaluation_context.pipeline_options - with self._sink.writer() as writer: - for v in elements: - writer.Write(v.value) - hold = WatermarkManager.WATERMARK_POS_INF - else: - hold = WatermarkManager.WATERMARK_NEG_INF - self.global_state.set_timer( - None, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_POS_INF) - - return TransformResult(self, [], [], None, {None: hold}) - - class _ProcessElementsEvaluator(_TransformEvaluator): """An evaluator for sdf_direct_runner.ProcessElements transform.""" diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py index be7f99dc61f4..8d957068d08b 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py @@ -825,11 +825,10 @@ def _execute_bundle(self, buffers_to_clean = set() known_consumers = set() - for transform_id, buffer_id in ( - bundle_context_manager.stage_data_outputs.items()): - for (consuming_stage_name, consuming_transform - ) in runner_execution_context.buffer_id_to_consumer_pairs.get( - buffer_id, []): + for _, buffer_id in bundle_context_manager.stage_data_outputs.items(): + for (consuming_stage_name, consuming_transform) in \ + runner_execution_context.buffer_id_to_consumer_pairs.get(buffer_id, + []): buffer = runner_execution_context.pcoll_buffers.get(buffer_id, None) if (buffer_id in runner_execution_context.pcoll_buffers and @@ -841,11 +840,6 @@ def _execute_bundle(self, # so we create a copy of the buffer for every new stage. runner_execution_context.pcoll_buffers[buffer_id] = buffer.copy() buffer = runner_execution_context.pcoll_buffers[buffer_id] - # When the buffer is not in the pcoll_buffers, it means that the - # it could be an empty PCollection. In this case, get the buffer using - # the buffer id and transform id - if buffer is None: - buffer = bundle_context_manager.get_buffer(buffer_id, transform_id) # If the buffer has already been added to be consumed by # (stage, transform), then we don't need to add it again. This case @@ -860,7 +854,7 @@ def _execute_bundle(self, # MAX_TIMESTAMP for the downstream stage. runner_execution_context.queues.watermark_pending_inputs.enque( ((consuming_stage_name, timestamp.MAX_TIMESTAMP), - DataInput({consuming_transform: buffer}, {}))) + DataInput({consuming_transform: buffer}, {}))) # type: ignore for bid in buffers_to_clean: if bid in runner_execution_context.pcoll_buffers: diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py index b55c7162aea7..ed09bb8f2236 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py @@ -1831,15 +1831,6 @@ def create_pipeline(self, is_drain=False): p._options.view_as(DebugOptions).experiments.remove('beam_fn_api') return p - def test_group_by_key_with_empty_pcoll_elements(self): - with self.create_pipeline() as p: - res = ( - p - | beam.Create([('test_key', 'test_value')]) - | beam.Filter(lambda x: False) - | beam.GroupByKey()) - assert_that(res, equal_to([])) - def test_metrics(self): raise unittest.SkipTest("This test is for a single worker only.") diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py index 8b81c9f17ac6..673c1cea111a 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py @@ -19,8 +19,6 @@ # pytype: skip-file import unittest -from typing import Dict -from typing import List import apache_beam as beam from apache_beam.coders.coders import FastPrimitivesCoder diff --git a/sdks/python/apache_beam/runners/worker/data_sampler_test.py b/sdks/python/apache_beam/runners/worker/data_sampler_test.py index b67936121830..8d063fdb49d6 100644 --- a/sdks/python/apache_beam/runners/worker/data_sampler_test.py +++ b/sdks/python/apache_beam/runners/worker/data_sampler_test.py @@ -22,7 +22,6 @@ import traceback import unittest from typing import Any -from typing import Dict from typing import List from typing import Optional diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py index abbcfb72382b..87cf06e862ab 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py @@ -52,7 +52,7 @@ def _import_beam_plugins(plugins): for plugin in plugins: try: importlib.import_module(plugin) - _LOGGER.info('Imported beam-plugin %s', plugin) + _LOGGER.debug('Imported beam-plugin %s', plugin) except ImportError: try: _LOGGER.debug(( @@ -61,7 +61,7 @@ def _import_beam_plugins(plugins): plugin) module, _ = plugin.rsplit('.', 1) importlib.import_module(module) - _LOGGER.info('Imported %s for beam-plugin %s', module, plugin) + _LOGGER.debug('Imported %s for beam-plugin %s', module, plugin) except ImportError as exc: _LOGGER.warning('Failed to import beam-plugin %s', plugin, exc_info=exc) diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py index ccfc27d0101c..8570c5a7722c 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py @@ -23,7 +23,6 @@ import logging import unittest from collections import namedtuple -from typing import Any import grpc import hamcrest as hc diff --git a/sdks/python/apache_beam/testing/analyzers/constants.py b/sdks/python/apache_beam/testing/analyzers/constants.py index 2a52fab563db..8f8bdf13300c 100644 --- a/sdks/python/apache_beam/testing/analyzers/constants.py +++ b/sdks/python/apache_beam/testing/analyzers/constants.py @@ -70,3 +70,5 @@ }, { 'name': _ISSUE_URL, 'field_type': 'STRING', 'mode': 'REQUIRED' }] + +_ANOMALY_MARKER = ' <---- Anomaly' diff --git a/sdks/python/apache_beam/testing/analyzers/github_issues_utils.py b/sdks/python/apache_beam/testing/analyzers/github_issues_utils.py index f6e2939161ef..e1f20baa50a6 100644 --- a/sdks/python/apache_beam/testing/analyzers/github_issues_utils.py +++ b/sdks/python/apache_beam/testing/analyzers/github_issues_utils.py @@ -24,6 +24,8 @@ import pandas as pd import requests +from apache_beam.testing.analyzers import constants + try: _GITHUB_TOKEN: Optional[str] = os.environ['GITHUB_TOKEN'] except KeyError as e: @@ -162,28 +164,32 @@ def get_issue_description( """ # TODO: Add mean and median before and after the changepoint index. - max_timestamp_index = min( - change_point_index + max_results_to_display, len(metric_values) - 1) - min_timestamp_index = max(0, change_point_index - max_results_to_display) - description = _ISSUE_DESCRIPTION_TEMPLATE.format( - test_name, metric_name) + 2 * '\n' + description = [] + + description.append(_ISSUE_DESCRIPTION_TEMPLATE.format(test_name, metric_name)) - description += ( - "`Test description:` " + f'{test_description}' + - 2 * '\n') if test_description else '' + description.append(("`Test description:` " + + f'{test_description}') if test_description else '') - description += '```' + '\n' - runs_to_display = [ - _METRIC_INFO_TEMPLATE.format( - timestamps[i].ctime(), format(metric_values[i], '.2f')) - for i in reversed(range(min_timestamp_index, max_timestamp_index + 1)) - ] + description.append('```') + + runs_to_display = [] + max_timestamp_index = min( + change_point_index + max_results_to_display, len(metric_values) - 1) + min_timestamp_index = max(0, change_point_index - max_results_to_display) - runs_to_display[change_point_index - min_timestamp_index] += " <---- Anomaly" - description += '\n'.join(runs_to_display) + '\n' - description += '```' + '\n' - return description + # run in reverse to display the most recent runs first. + for i in reversed(range(min_timestamp_index, max_timestamp_index + 1)): + row_template = _METRIC_INFO_TEMPLATE.format( + timestamps[i].ctime(), format(metric_values[i], '.2f')) + if i == change_point_index: + row_template += constants._ANOMALY_MARKER + runs_to_display.append(row_template) + + description.append(os.linesep.join(runs_to_display)) + description.append('```') + return (2 * os.linesep).join(description) def report_change_point_on_issues( diff --git a/sdks/python/apache_beam/testing/analyzers/perf_analysis_test.py b/sdks/python/apache_beam/testing/analyzers/perf_analysis_test.py index fabf185a41d8..c18b1bb9506b 100644 --- a/sdks/python/apache_beam/testing/analyzers/perf_analysis_test.py +++ b/sdks/python/apache_beam/testing/analyzers/perf_analysis_test.py @@ -16,8 +16,10 @@ # # pytype: skip-file +import datetime import logging import os +import re import unittest import mock @@ -28,10 +30,13 @@ try: import apache_beam.testing.analyzers.perf_analysis as analysis from apache_beam.testing.analyzers import constants + from apache_beam.testing.analyzers import github_issues_utils from apache_beam.testing.analyzers.perf_analysis_utils import is_change_point_in_valid_window from apache_beam.testing.analyzers.perf_analysis_utils import is_perf_alert from apache_beam.testing.analyzers.perf_analysis_utils import e_divisive + from apache_beam.testing.analyzers.perf_analysis_utils import find_latest_change_point_index from apache_beam.testing.analyzers.perf_analysis_utils import validate_config + except ImportError as e: analysis = None # type: ignore @@ -45,17 +50,18 @@ def get_fake_data_with_no_change_point(**kwargs): def get_fake_data_with_change_point(**kwargs): + # change point will be at index 13. num_samples = 20 - metric_values = [0] * (num_samples // 2) + [1] * (num_samples // 2) + metric_values = [0] * 12 + [3] + [4] * 7 timestamps = [i for i in range(num_samples)] return metric_values, timestamps def get_existing_issue_data(**kwargs): - # change point found at index 10. So passing 10 in the + # change point found at index 13. So passing 13 in the # existing issue data in mock method. return pd.DataFrame([{ - constants._CHANGE_POINT_TIMESTAMP_LABEL: 10, + constants._CHANGE_POINT_TIMESTAMP_LABEL: 13, constants._ISSUE_NUMBER: np.array([0]) }]) @@ -193,6 +199,29 @@ def test_alert_on_data_with_reported_change_point(self, *args): big_query_metrics_fetcher=None) self.assertFalse(is_alert) + def test_change_point_has_anomaly_marker_in_gh_description(self): + metric_values, timestamps = get_fake_data_with_change_point() + timestamps = [datetime.datetime.fromtimestamp(ts) for ts in timestamps] + change_point_index = find_latest_change_point_index(metric_values) + + description = github_issues_utils.get_issue_description( + test_name=self.test_id, + test_description=self.params['test_description'], + metric_name=self.params['metric_name'], + metric_values=metric_values, + timestamps=timestamps, + change_point_index=change_point_index, + max_results_to_display=( + constants._NUM_RESULTS_TO_DISPLAY_ON_ISSUE_DESCRIPTION)) + + runs_info = next(( + line for line in description.split(2 * os.linesep) + if re.match(r'timestamp: .*, metric_value: .*', line.strip())), + '') + pattern = (r'timestamp: .+ (\d{4}), metric_value: (\d+.\d+) <---- Anomaly') + match = re.search(pattern, runs_info) + self.assertTrue(match) + if __name__ == '__main__': logging.getLogger().setLevel(logging.DEBUG) diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py index 6834c1d6174b..62dbbc5fb77c 100644 --- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py +++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py @@ -20,12 +20,10 @@ # pytype: skip-file import unittest -from functools import wraps import pytest from parameterized import parameterized_class -from apache_beam.options.pipeline_options import DebugOptions from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import StandardOptions from apache_beam.runners.direct import direct_runner @@ -36,39 +34,17 @@ from apache_beam.transforms.combinefn_lifecycle_pipeline import run_pardo -def skip_unless_v2(fn): - @wraps(fn) - def wrapped(*args, **kwargs): - self = args[0] - options = self.pipeline.get_pipeline_options() - standard_options = options.view_as(StandardOptions) - experiments = options.view_as(DebugOptions).experiments or [] - - if 'DataflowRunner' in standard_options.runner and \ - 'use_runner_v2' not in experiments: - self.skipTest( - 'CombineFn.setup and CombineFn.teardown are not supported. ' - 'Please use Dataflow Runner V2.') - else: - return fn(*args, **kwargs) - - return wrapped - - @pytest.mark.it_validatesrunner class CombineFnLifecycleTest(unittest.TestCase): def setUp(self): self.pipeline = TestPipeline(is_integration_test=True) - @skip_unless_v2 def test_combine(self): run_combine(self.pipeline) - @skip_unless_v2 def test_non_liftable_combine(self): run_combine(self.pipeline, lift_combiners=False) - @skip_unless_v2 def test_combining_value_state(self): if ('DataflowRunner' in self.pipeline.get_pipeline_options().view_as( StandardOptions).runner): diff --git a/sdks/python/build-requirements.txt b/sdks/python/build-requirements.txt index e8152fbc3ba1..32ccb05dce08 100644 --- a/sdks/python/build-requirements.txt +++ b/sdks/python/build-requirements.txt @@ -22,7 +22,7 @@ wheel>=0.36.0 grpcio-tools==1.53.0 mypy-protobuf==3.4.0 # Avoid https://github.com/pypa/virtualenv/issues/2006 -distlib==0.3.6 +distlib==0.3.7 # Numpy headers numpy>=1.14.3,<1.26 diff --git a/sdks/python/build.gradle b/sdks/python/build.gradle index 0c2ed72dfba1..88d0abae83bd 100644 --- a/sdks/python/build.gradle +++ b/sdks/python/build.gradle @@ -96,7 +96,7 @@ platform_identifiers_map.each { platform, idsuffix -> exec { environment CIBW_BUILD: "cp${pyversion}-${idsuffix}" environment CIBW_ENVIRONMENT: "SETUPTOOLS_USE_DISTUTILS=stdlib" - environment CIBW_BEFORE_BUILD: "pip install cython numpy && pip install --upgrade setuptools" + environment CIBW_BEFORE_BUILD: "pip install cython==0.29.36 numpy && pip install --upgrade setuptools" // note: sync cibuildwheel version with GitHub Action // .github/workflow/build_wheel.yml:build_wheels "Install cibuildwheel" step executable 'sh' diff --git a/sdks/python/container/boot.go b/sdks/python/container/boot.go index 1e70e0db1513..da2f3cc28f54 100644 --- a/sdks/python/container/boot.go +++ b/sdks/python/container/boot.go @@ -164,11 +164,11 @@ func launchSDKProcess() error { if err != nil { return errors.New( "failed to create a virtual environment. If running on Ubuntu systems, " + - "you might need to install `python3-venv` package. " + - "To run the SDK process in default environment instead, " + - "set the environment variable `RUN_PYTHON_SDK_IN_DEFAULT_ENVIRONMENT=1`. " + - "In custom Docker images, you can do that with an `ENV` statement. " + - fmt.Sprintf("Encountered error: %v", err)) + "you might need to install `python3-venv` package. " + + "To run the SDK process in default environment instead, " + + "set the environment variable `RUN_PYTHON_SDK_IN_DEFAULT_ENVIRONMENT=1`. " + + "In custom Docker images, you can do that with an `ENV` statement. " + + fmt.Sprintf("Encountered error: %v", err)) } cleanupFunc := func() { os.RemoveAll(venvDir) diff --git a/sdks/python/container/piputil.go b/sdks/python/container/piputil.go index a00e017445e3..c9e396b0e6a4 100644 --- a/sdks/python/container/piputil.go +++ b/sdks/python/container/piputil.go @@ -37,14 +37,15 @@ func pipInstallRequirements(files []string, dir, name string) error { // as possible PyPI downloads. In the first round the --find-links // option will make sure that only things staged in the worker will be // used without following their dependencies. - args := []string{"-m", "pip", "install", "-r", filepath.Join(dir, name), "--no-cache-dir", "--disable-pip-version-check", "--no-index", "--no-deps", "--find-links", dir} + args := []string{"-m", "pip", "install", "-r", filepath.Join(dir, name), "--no-cache-dir", "--disable-pip-version-check", "--no-index", "--no-deps", "--find-links", + "-q", dir} if err := execx.Execute("python", args...); err != nil { fmt.Println("Some packages could not be installed solely from the requirements cache. Installing packages from PyPI.") } // The second install round opens up the search for packages on PyPI and // also installs dependencies. The key is that if all the packages have // been installed in the first round then this command will be a no-op. - args = []string{"-m", "pip", "install", "-r", filepath.Join(dir, name), "--no-cache-dir", "--disable-pip-version-check", "--find-links", dir} + args = []string{"-m", "pip", "install", "-r", filepath.Join(dir, name), "--no-cache-dir", "--disable-pip-version-check", "--find-links", "-q", dir} return execx.Execute("python", args...) } } @@ -76,18 +77,18 @@ func pipInstallPackage(files []string, dir, name string, force, optional bool, e // installed version will match the package specified, the package itself // will not be reinstalled, but its dependencies will now be resolved and // installed if necessary. This achieves our goal outlined above. - args := []string{"-m", "pip", "install", "--no-cache-dir", "--disable-pip-version-check", "--upgrade", "--force-reinstall", "--no-deps", + args := []string{"-m", "pip", "install", "--no-cache-dir", "--disable-pip-version-check", "--upgrade", "--force-reinstall", "--no-deps", "-q", filepath.Join(dir, packageSpec)} err := execx.Execute("python", args...) if err != nil { return err } - args = []string{"-m", "pip", "install", "--no-cache-dir", "--disable-pip-version-check", filepath.Join(dir, packageSpec)} + args = []string{"-m", "pip", "install", "--no-cache-dir", "--disable-pip-version-check", "-q", filepath.Join(dir, packageSpec)} return execx.Execute("python", args...) } // Case when we do not perform a forced reinstall. - args := []string{"-m", "pip", "install", "--no-cache-dir", "--disable-pip-version-check", filepath.Join(dir, packageSpec)} + args := []string{"-m", "pip", "install", "--no-cache-dir", "--disable-pip-version-check", "-q", filepath.Join(dir, packageSpec)} return execx.Execute("python", args...) } } diff --git a/sdks/python/container/py310/base_image_requirements.txt b/sdks/python/container/py310/base_image_requirements.txt index e50cd2841156..5396fc615977 100644 --- a/sdks/python/container/py310/base_image_requirements.txt +++ b/sdks/python/container/py310/base_image_requirements.txt @@ -29,48 +29,48 @@ bs4==0.0.1 cachetools==5.3.1 certifi==2023.5.7 cffi==1.15.1 -charset-normalizer==3.1.0 -click==8.1.3 +charset-normalizer==3.2.0 +click==8.1.5 cloudpickle==2.2.1 crcmod==1.7 -cryptography==41.0.1 -Cython==0.29.35 +cryptography==41.0.2 +Cython==0.29.36 deprecation==2.1.0 dill==0.3.1.1 dnspython==2.3.0 docker==6.1.3 docopt==0.6.2 -exceptiongroup==1.1.1 -execnet==1.9.0 -fastavro==1.7.4 +exceptiongroup==1.1.2 +execnet==2.0.2 +fastavro==1.8.0 fasteners==0.18 flatbuffers==23.5.26 freezegun==1.2.2 future==0.18.3 gast==0.4.0 google-api-core==2.11.1 -google-api-python-client==2.90.0 +google-api-python-client==2.93.0 google-apitools==0.5.31 -google-auth==2.21.0 +google-auth==2.22.0 google-auth-httplib2==0.1.0 google-auth-oauthlib==1.0.0 -google-cloud-aiplatform==1.26.1 -google-cloud-bigquery==3.11.2 -google-cloud-bigquery-storage==2.20.0 +google-cloud-aiplatform==1.28.0 +google-cloud-bigquery==3.11.3 +google-cloud-bigquery-storage==2.22.0 google-cloud-bigtable==2.19.0 -google-cloud-core==2.3.2 -google-cloud-datastore==2.16.0 -google-cloud-dlp==3.12.1 -google-cloud-language==2.10.0 +google-cloud-core==2.3.3 +google-cloud-datastore==2.16.1 +google-cloud-dlp==3.12.2 +google-cloud-language==2.10.1 google-cloud-profiler==4.0.0 -google-cloud-pubsub==2.17.1 -google-cloud-pubsublite==1.8.2 -google-cloud-recommendations-ai==0.10.3 -google-cloud-resource-manager==1.10.1 +google-cloud-pubsub==2.18.0 +google-cloud-pubsublite==1.8.3 +google-cloud-recommendations-ai==0.10.4 +google-cloud-resource-manager==1.10.2 google-cloud-spanner==3.36.0 google-cloud-storage==2.10.0 -google-cloud-videointelligence==2.11.2 -google-cloud-vision==3.4.3 +google-cloud-videointelligence==2.11.3 +google-cloud-vision==3.4.4 google-crc32c==1.5.0 google-pasta==0.2.0 google-resumable-media==2.5.0 @@ -83,33 +83,31 @@ guppy3==3.1.3 h5py==3.9.0 hdfs==2.7.0 httplib2==0.22.0 -hypothesis==6.79.3 +hypothesis==6.81.1 idna==3.4 iniconfig==2.0.0 -jax==0.4.13 -joblib==1.2.0 -keras==2.12.0 +joblib==1.3.1 +keras==2.13.1 libclang==16.0.0 Markdown==3.4.3 MarkupSafe==2.1.3 -ml-dtypes==0.2.0 mmh3==4.0.0 -mock==5.0.2 +mock==5.1.0 nltk==3.8.1 nose==1.3.7 -numpy==1.23.5 +numpy==1.24.3 oauth2client==4.1.3 oauthlib==3.2.2 objsize==0.6.1 opt-einsum==3.3.0 -orjson==3.9.1 +orjson==3.9.2 overrides==6.5.0 packaging==23.1 pandas==1.5.3 parameterized==0.9.0 pluggy==1.2.0 proto-plus==1.22.3 -protobuf==4.23.3 +protobuf==4.23.4 psycopg2-binary==2.9.6 pyarrow==11.0.0 pyasn1==0.5.0 @@ -117,7 +115,7 @@ pyasn1-modules==0.3.0 pycparser==2.21 pydot==1.4.2 PyHamcrest==2.0.4 -pymongo==4.4.0 +pymongo==4.4.1 PyMySQL==1.1.0 pyparsing==3.1.0 pytest==7.4.0 @@ -132,29 +130,29 @@ requests==2.31.0 requests-mock==1.11.0 requests-oauthlib==1.3.1 rsa==4.9 -scikit-learn==1.2.2 -scipy==1.11.0 +scikit-learn==1.3.0 +scipy==1.11.1 Shapely==1.8.5.post1 six==1.16.0 sortedcontainers==2.4.0 soupsieve==2.4.1 -SQLAlchemy==1.4.48 +SQLAlchemy==1.4.49 sqlparse==0.4.4 tenacity==8.2.2 -tensorboard==2.12.3 +tensorboard==2.13.0 tensorboard-data-server==0.7.1 -tensorflow==2.12.0 -tensorflow-estimator==2.12.0 +tensorflow==2.13.0 +tensorflow-estimator==2.13.0 tensorflow-io-gcs-filesystem==0.32.0 termcolor==2.3.0 testcontainers==3.7.1 -threadpoolctl==3.1.0 +threadpoolctl==3.2.0 tomli==2.0.1 tqdm==4.65.0 -typing_extensions==4.6.3 +typing_extensions==4.5.0 uritemplate==4.1.1 urllib3==1.26.16 websocket-client==1.6.1 Werkzeug==2.3.6 -wrapt==1.14.1 +wrapt==1.15.0 zstandard==0.21.0 diff --git a/sdks/python/container/py311/base_image_requirements.txt b/sdks/python/container/py311/base_image_requirements.txt index 8afa7e2a3456..20fcba6e2117 100644 --- a/sdks/python/container/py311/base_image_requirements.txt +++ b/sdks/python/container/py311/base_image_requirements.txt @@ -29,19 +29,19 @@ bs4==0.0.1 cachetools==5.3.1 certifi==2023.5.7 cffi==1.15.1 -charset-normalizer==3.1.0 -click==8.1.3 +charset-normalizer==3.2.0 +click==8.1.5 cloudpickle==2.2.1 crcmod==1.7 -cryptography==41.0.1 -Cython==0.29.35 +cryptography==41.0.2 +Cython==0.29.36 deprecation==2.1.0 dill==0.3.1.1 dnspython==2.3.0 docker==6.1.3 docopt==0.6.2 -execnet==1.9.0 -fastavro==1.7.4 +execnet==2.0.2 +fastavro==1.8.0 fasteners==0.18 flatbuffers==23.5.26 freezegun==1.2.2 @@ -49,25 +49,25 @@ future==0.18.3 gast==0.4.0 google-api-core==2.11.1 google-apitools==0.5.31 -google-auth==2.21.0 +google-auth==2.22.0 google-auth-httplib2==0.1.0 google-auth-oauthlib==1.0.0 -google-cloud-aiplatform==1.26.1 -google-cloud-bigquery==3.11.2 -google-cloud-bigquery-storage==2.20.0 +google-cloud-aiplatform==1.28.0 +google-cloud-bigquery==3.11.3 +google-cloud-bigquery-storage==2.22.0 google-cloud-bigtable==2.19.0 -google-cloud-core==2.3.2 -google-cloud-datastore==2.16.0 -google-cloud-dlp==3.12.1 -google-cloud-language==2.10.0 -google-cloud-pubsub==2.17.1 -google-cloud-pubsublite==1.8.2 -google-cloud-recommendations-ai==0.10.3 -google-cloud-resource-manager==1.10.1 +google-cloud-core==2.3.3 +google-cloud-datastore==2.16.1 +google-cloud-dlp==3.12.2 +google-cloud-language==2.10.1 +google-cloud-pubsub==2.18.0 +google-cloud-pubsublite==1.8.3 +google-cloud-recommendations-ai==0.10.4 +google-cloud-resource-manager==1.10.2 google-cloud-spanner==3.36.0 google-cloud-storage==2.10.0 -google-cloud-videointelligence==2.11.2 -google-cloud-vision==3.4.3 +google-cloud-videointelligence==2.11.3 +google-cloud-vision==3.4.4 google-crc32c==1.5.0 google-pasta==0.2.0 google-resumable-media==2.5.0 @@ -80,33 +80,31 @@ guppy3==3.1.3 h5py==3.9.0 hdfs==2.7.0 httplib2==0.22.0 -hypothesis==6.79.3 +hypothesis==6.81.1 idna==3.4 iniconfig==2.0.0 -jax==0.4.13 -joblib==1.2.0 -keras==2.12.0 +joblib==1.3.1 +keras==2.13.1 libclang==16.0.0 Markdown==3.4.3 MarkupSafe==2.1.3 -ml-dtypes==0.2.0 mmh3==4.0.0 -mock==5.0.2 +mock==5.1.0 nltk==3.8.1 nose==1.3.7 -numpy==1.23.5 +numpy==1.24.3 oauth2client==4.1.3 oauthlib==3.2.2 objsize==0.6.1 opt-einsum==3.3.0 -orjson==3.9.1 +orjson==3.9.2 overrides==6.5.0 packaging==23.1 pandas==1.5.3 parameterized==0.9.0 pluggy==1.2.0 proto-plus==1.22.3 -protobuf==4.23.3 +protobuf==4.23.4 psycopg2-binary==2.9.6 pyarrow==11.0.0 pyasn1==0.5.0 @@ -114,7 +112,7 @@ pyasn1-modules==0.3.0 pycparser==2.21 pydot==1.4.2 PyHamcrest==2.0.4 -pymongo==4.4.0 +pymongo==4.4.1 PyMySQL==1.1.0 pyparsing==3.1.0 pytest==7.4.0 @@ -128,27 +126,27 @@ requests==2.31.0 requests-mock==1.11.0 requests-oauthlib==1.3.1 rsa==4.9 -scikit-learn==1.2.2 -scipy==1.11.0 +scikit-learn==1.3.0 +scipy==1.11.1 Shapely==1.8.5.post1 six==1.16.0 sortedcontainers==2.4.0 soupsieve==2.4.1 -SQLAlchemy==1.4.48 +SQLAlchemy==1.4.49 sqlparse==0.4.4 tenacity==8.2.2 -tensorboard==2.12.3 +tensorboard==2.13.0 tensorboard-data-server==0.7.1 -tensorflow==2.12.0 -tensorflow-estimator==2.12.0 +tensorflow==2.13.0 +tensorflow-estimator==2.13.0 tensorflow-io-gcs-filesystem==0.32.0 termcolor==2.3.0 testcontainers==3.7.1 -threadpoolctl==3.1.0 +threadpoolctl==3.2.0 tqdm==4.65.0 -typing_extensions==4.6.3 +typing_extensions==4.5.0 urllib3==1.26.16 websocket-client==1.6.1 Werkzeug==2.3.6 -wrapt==1.14.1 +wrapt==1.15.0 zstandard==0.21.0 diff --git a/sdks/python/container/py38/base_image_requirements.txt b/sdks/python/container/py38/base_image_requirements.txt index e03782dd7a74..127bc835c928 100644 --- a/sdks/python/container/py38/base_image_requirements.txt +++ b/sdks/python/container/py38/base_image_requirements.txt @@ -29,48 +29,48 @@ bs4==0.0.1 cachetools==5.3.1 certifi==2023.5.7 cffi==1.15.1 -charset-normalizer==3.1.0 -click==8.1.3 +charset-normalizer==3.2.0 +click==8.1.5 cloudpickle==2.2.1 crcmod==1.7 -cryptography==41.0.1 -Cython==0.29.35 +cryptography==41.0.2 +Cython==0.29.36 deprecation==2.1.0 dill==0.3.1.1 dnspython==2.3.0 docker==6.1.3 docopt==0.6.2 -exceptiongroup==1.1.1 -execnet==1.9.0 -fastavro==1.7.4 +exceptiongroup==1.1.2 +execnet==2.0.2 +fastavro==1.8.0 fasteners==0.18 flatbuffers==23.5.26 freezegun==1.2.2 future==0.18.3 gast==0.4.0 google-api-core==2.11.1 -google-api-python-client==2.90.0 +google-api-python-client==2.93.0 google-apitools==0.5.31 -google-auth==2.21.0 +google-auth==2.22.0 google-auth-httplib2==0.1.0 google-auth-oauthlib==1.0.0 -google-cloud-aiplatform==1.26.1 -google-cloud-bigquery==3.11.2 -google-cloud-bigquery-storage==2.20.0 +google-cloud-aiplatform==1.28.0 +google-cloud-bigquery==3.11.3 +google-cloud-bigquery-storage==2.22.0 google-cloud-bigtable==2.19.0 -google-cloud-core==2.3.2 -google-cloud-datastore==2.16.0 -google-cloud-dlp==3.12.1 -google-cloud-language==2.10.0 +google-cloud-core==2.3.3 +google-cloud-datastore==2.16.1 +google-cloud-dlp==3.12.2 +google-cloud-language==2.10.1 google-cloud-profiler==4.0.0 -google-cloud-pubsub==2.17.1 -google-cloud-pubsublite==1.8.2 -google-cloud-recommendations-ai==0.10.3 -google-cloud-resource-manager==1.10.1 +google-cloud-pubsub==2.18.0 +google-cloud-pubsublite==1.8.3 +google-cloud-recommendations-ai==0.10.4 +google-cloud-resource-manager==1.10.2 google-cloud-spanner==3.36.0 google-cloud-storage==2.10.0 -google-cloud-videointelligence==2.11.2 -google-cloud-vision==3.4.3 +google-cloud-videointelligence==2.11.3 +google-cloud-vision==3.4.4 google-crc32c==1.5.0 google-pasta==0.2.0 google-resumable-media==2.5.0 @@ -83,34 +83,32 @@ guppy3==3.1.3 h5py==3.9.0 hdfs==2.7.0 httplib2==0.22.0 -hypothesis==6.79.3 +hypothesis==6.81.1 idna==3.4 -importlib-metadata==6.7.0 +importlib-metadata==6.8.0 iniconfig==2.0.0 -jax==0.4.13 -joblib==1.2.0 -keras==2.12.0 +joblib==1.3.1 +keras==2.13.1 libclang==16.0.0 Markdown==3.4.3 MarkupSafe==2.1.3 -ml-dtypes==0.2.0 mmh3==4.0.0 -mock==5.0.2 +mock==5.1.0 nltk==3.8.1 nose==1.3.7 -numpy==1.23.5 +numpy==1.24.3 oauth2client==4.1.3 oauthlib==3.2.2 objsize==0.6.1 opt-einsum==3.3.0 -orjson==3.9.1 +orjson==3.9.2 overrides==6.5.0 packaging==23.1 pandas==1.5.3 parameterized==0.9.0 pluggy==1.2.0 proto-plus==1.22.3 -protobuf==4.23.3 +protobuf==4.23.4 psycopg2-binary==2.9.6 pyarrow==11.0.0 pyasn1==0.5.0 @@ -118,7 +116,7 @@ pyasn1-modules==0.3.0 pycparser==2.21 pydot==1.4.2 PyHamcrest==2.0.4 -pymongo==4.4.0 +pymongo==4.4.1 PyMySQL==1.1.0 pyparsing==3.1.0 pytest==7.4.0 @@ -133,30 +131,30 @@ requests==2.31.0 requests-mock==1.11.0 requests-oauthlib==1.3.1 rsa==4.9 -scikit-learn==1.2.2 +scikit-learn==1.3.0 scipy==1.10.1 Shapely==1.8.5.post1 six==1.16.0 sortedcontainers==2.4.0 soupsieve==2.4.1 -SQLAlchemy==1.4.48 +SQLAlchemy==1.4.49 sqlparse==0.4.4 tenacity==8.2.2 -tensorboard==2.12.3 +tensorboard==2.13.0 tensorboard-data-server==0.7.1 -tensorflow==2.12.0 -tensorflow-estimator==2.12.0 +tensorflow==2.13.0 +tensorflow-estimator==2.13.0 tensorflow-io-gcs-filesystem==0.32.0 termcolor==2.3.0 testcontainers==3.7.1 -threadpoolctl==3.1.0 +threadpoolctl==3.2.0 tomli==2.0.1 tqdm==4.65.0 -typing_extensions==4.6.3 +typing_extensions==4.5.0 uritemplate==4.1.1 urllib3==1.26.16 websocket-client==1.6.1 Werkzeug==2.3.6 -wrapt==1.14.1 -zipp==3.15.0 +wrapt==1.15.0 +zipp==3.16.1 zstandard==0.21.0 diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt index 1a89004a32f3..680303078a2b 100644 --- a/sdks/python/container/py39/base_image_requirements.txt +++ b/sdks/python/container/py39/base_image_requirements.txt @@ -29,48 +29,48 @@ bs4==0.0.1 cachetools==5.3.1 certifi==2023.5.7 cffi==1.15.1 -charset-normalizer==3.1.0 -click==8.1.3 +charset-normalizer==3.2.0 +click==8.1.5 cloudpickle==2.2.1 crcmod==1.7 -cryptography==41.0.1 -Cython==0.29.35 +cryptography==41.0.2 +Cython==0.29.36 deprecation==2.1.0 dill==0.3.1.1 dnspython==2.3.0 docker==6.1.3 docopt==0.6.2 -exceptiongroup==1.1.1 -execnet==1.9.0 -fastavro==1.7.4 +exceptiongroup==1.1.2 +execnet==2.0.2 +fastavro==1.8.0 fasteners==0.18 flatbuffers==23.5.26 freezegun==1.2.2 future==0.18.3 gast==0.4.0 google-api-core==2.11.1 -google-api-python-client==2.90.0 +google-api-python-client==2.93.0 google-apitools==0.5.31 -google-auth==2.21.0 +google-auth==2.22.0 google-auth-httplib2==0.1.0 google-auth-oauthlib==1.0.0 -google-cloud-aiplatform==1.26.1 -google-cloud-bigquery==3.11.2 -google-cloud-bigquery-storage==2.20.0 +google-cloud-aiplatform==1.28.0 +google-cloud-bigquery==3.11.3 +google-cloud-bigquery-storage==2.22.0 google-cloud-bigtable==2.19.0 -google-cloud-core==2.3.2 -google-cloud-datastore==2.16.0 -google-cloud-dlp==3.12.1 -google-cloud-language==2.10.0 +google-cloud-core==2.3.3 +google-cloud-datastore==2.16.1 +google-cloud-dlp==3.12.2 +google-cloud-language==2.10.1 google-cloud-profiler==4.0.0 -google-cloud-pubsub==2.17.1 -google-cloud-pubsublite==1.8.2 -google-cloud-recommendations-ai==0.10.3 -google-cloud-resource-manager==1.10.1 +google-cloud-pubsub==2.18.0 +google-cloud-pubsublite==1.8.3 +google-cloud-recommendations-ai==0.10.4 +google-cloud-resource-manager==1.10.2 google-cloud-spanner==3.36.0 google-cloud-storage==2.10.0 -google-cloud-videointelligence==2.11.2 -google-cloud-vision==3.4.3 +google-cloud-videointelligence==2.11.3 +google-cloud-vision==3.4.4 google-crc32c==1.5.0 google-pasta==0.2.0 google-resumable-media==2.5.0 @@ -83,34 +83,32 @@ guppy3==3.1.3 h5py==3.9.0 hdfs==2.7.0 httplib2==0.22.0 -hypothesis==6.79.3 +hypothesis==6.81.1 idna==3.4 -importlib-metadata==6.7.0 +importlib-metadata==6.8.0 iniconfig==2.0.0 -jax==0.4.13 -joblib==1.2.0 -keras==2.12.0 +joblib==1.3.1 +keras==2.13.1 libclang==16.0.0 Markdown==3.4.3 MarkupSafe==2.1.3 -ml-dtypes==0.2.0 mmh3==4.0.0 -mock==5.0.2 +mock==5.1.0 nltk==3.8.1 nose==1.3.7 -numpy==1.23.5 +numpy==1.24.3 oauth2client==4.1.3 oauthlib==3.2.2 objsize==0.6.1 opt-einsum==3.3.0 -orjson==3.9.1 +orjson==3.9.2 overrides==6.5.0 packaging==23.1 pandas==1.5.3 parameterized==0.9.0 pluggy==1.2.0 proto-plus==1.22.3 -protobuf==4.23.3 +protobuf==4.23.4 psycopg2-binary==2.9.6 pyarrow==11.0.0 pyasn1==0.5.0 @@ -118,7 +116,7 @@ pyasn1-modules==0.3.0 pycparser==2.21 pydot==1.4.2 PyHamcrest==2.0.4 -pymongo==4.4.0 +pymongo==4.4.1 PyMySQL==1.1.0 pyparsing==3.1.0 pytest==7.4.0 @@ -133,30 +131,30 @@ requests==2.31.0 requests-mock==1.11.0 requests-oauthlib==1.3.1 rsa==4.9 -scikit-learn==1.2.2 -scipy==1.11.0 +scikit-learn==1.3.0 +scipy==1.11.1 Shapely==1.8.5.post1 six==1.16.0 sortedcontainers==2.4.0 soupsieve==2.4.1 -SQLAlchemy==1.4.48 +SQLAlchemy==1.4.49 sqlparse==0.4.4 tenacity==8.2.2 -tensorboard==2.12.3 +tensorboard==2.13.0 tensorboard-data-server==0.7.1 -tensorflow==2.12.0 -tensorflow-estimator==2.12.0 +tensorflow==2.13.0 +tensorflow-estimator==2.13.0 tensorflow-io-gcs-filesystem==0.32.0 termcolor==2.3.0 testcontainers==3.7.1 -threadpoolctl==3.1.0 +threadpoolctl==3.2.0 tomli==2.0.1 tqdm==4.65.0 -typing_extensions==4.6.3 +typing_extensions==4.5.0 uritemplate==4.1.1 urllib3==1.26.16 websocket-client==1.6.1 Werkzeug==2.3.6 -wrapt==1.14.1 -zipp==3.15.0 +wrapt==1.15.0 +zipp==3.16.1 zstandard==0.21.0 diff --git a/sdks/python/scripts/run_integration_test.sh b/sdks/python/scripts/run_integration_test.sh index 508d9f50421e..4f29ed5a4ad9 100755 --- a/sdks/python/scripts/run_integration_test.sh +++ b/sdks/python/scripts/run_integration_test.sh @@ -133,16 +133,6 @@ case $key in shift # past argument shift # past value ;; - --runner_v2) - RUNNER_V2="$2" - shift # past argument - shift # past value - ;; - --disable_runner_v2) - DISABLE_RUNNER_V2="$2" - shift # past argument - shift # past value - ;; --kms_key_name) KMS_KEY_NAME="$2" shift # past argument @@ -244,23 +234,6 @@ if [[ -z $PIPELINE_OPTS ]]; then opts+=("--streaming") fi - # Add --runner_v2 if provided - if [[ "$RUNNER_V2" = true ]]; then - opts+=("--experiments=use_runner_v2") - if [[ "$STREAMING" = true ]]; then - # Dataflow Runner V2 only supports streaming engine. - opts+=("--enable_streaming_engine") - else - opts+=("--experiments=beam_fn_api") - fi - - fi - - # Add --disable_runner_v2 if provided - if [[ "$DISABLE_RUNNER_V2" = true ]]; then - opts+=("--experiments=disable_runner_v2") - fi - if [[ ! -z "$KMS_KEY_NAME" ]]; then opts+=( "--kms_key_name=$KMS_KEY_NAME" diff --git a/sdks/python/test-suites/dataflow/build.gradle b/sdks/python/test-suites/dataflow/build.gradle index 50d35774ffc3..08f03e207f32 100644 --- a/sdks/python/test-suites/dataflow/build.gradle +++ b/sdks/python/test-suites/dataflow/build.gradle @@ -42,14 +42,14 @@ task chicagoTaxiExample { } } -task validatesRunnerBatchTestsV2 { - getVersionsAsList('dataflow_validates_runner_batch_tests_V2').each { +task validatesRunnerBatchTests { + getVersionsAsList('dataflow_validates_runner_batch_tests').each { dependsOn.add(":sdks:python:test-suites:dataflow:py${getVersionSuffix(it)}:validatesRunnerBatchTests") } } -task validatesRunnerStreamingTestsV2 { - getVersionsAsList('dataflow_validates_runner_streaming_tests_V2').each { +task validatesRunnerStreamingTests { + getVersionsAsList('dataflow_validates_runner_streaming_tests').each { dependsOn.add(":sdks:python:test-suites:dataflow:py${getVersionSuffix(it)}:validatesRunnerStreamingTests") } } diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 44257b09c01a..12c440131855 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -181,7 +181,6 @@ task examples { def argMap = [ "test_opts": testOpts + ["--numprocesses=8", "--dist=loadfile"], "sdk_location": project.ext.sdkLocation, - "runner_v2": "true", "suite": "postCommitIT-df${pythonVersionSuffix}-xdist", "collect": "examples_postcommit and not no_xdist and not sickbay_dataflow" ] @@ -197,7 +196,6 @@ task examples { def argMap = [ "test_opts": testOpts, "sdk_location": project.ext.sdkLocation, - "runner_v2": "true", "suite": "postCommitIT-df${pythonVersionSuffix}-no-xdist", "collect": "examples_postcommit and no_xdist and not sickbay_dataflow" ] @@ -220,13 +218,6 @@ task validatesRunnerBatchTests { "collect": "it_validatesrunner and not no_sickbay_batch" ] - if (project.hasProperty('useRunnerV2')) { - argMap.put("runner_v2", "true") - } - - if (project.hasProperty('disableRunnerV2')) { - argMap.put("disable_runner_v2", "true") - } def cmdArgs = mapToArgString(argMap) exec { executable 'sh' @@ -247,7 +238,6 @@ task validatesRunnerStreamingTests { "sdk_location": project.ext.sdkLocation, "suite": "validatesRunnerStreamingTests-df${pythonVersionSuffix}-xdist", "collect": "it_validatesrunner and not no_sickbay_streaming and not no_xdist", - "runner_v2": "true", ] def cmdArgs = mapToArgString(argMap) @@ -265,7 +255,6 @@ task validatesRunnerStreamingTests { "sdk_location": project.ext.sdkLocation, "suite": "validatesRunnerStreamingTests-df${pythonVersionSuffix}-noxdist", "collect": "it_validatesrunner and not no_sickbay_streaming and no_xdist", - "runner_v2": "true", ] def cmdArgs = mapToArgString(argMap) diff --git a/sdks/python/test-suites/gradle.properties b/sdks/python/test-suites/gradle.properties index af3b16d2e30e..72fc651733db 100644 --- a/sdks/python/test-suites/gradle.properties +++ b/sdks/python/test-suites/gradle.properties @@ -28,8 +28,8 @@ dataflow_mongodbio_it_task_py_versions=3.8 dataflow_chicago_taxi_example_task_py_versions=3.8 # TODO: Enable following tests after making sure we have enough capacity. -dataflow_validates_runner_batch_tests_V2=3.8,3.11 -dataflow_validates_runner_streaming_tests_V2=3.8,3.11 +dataflow_validates_runner_batch_tests=3.8,3.11 +dataflow_validates_runner_streaming_tests=3.8,3.11 dataflow_examples_postcommit_py_versions=3.11 # TFX_BSL is not yet supported on Python 3.10. dataflow_cloudml_benchmark_tests_py_versions=3.9 diff --git a/sdks/typescript/package-lock.json b/sdks/typescript/package-lock.json index 12c01ae96a4c..22cb6c1c5b15 100644 --- a/sdks/typescript/package-lock.json +++ b/sdks/typescript/package-lock.json @@ -4063,9 +4063,9 @@ } }, "node_modules/word-wrap": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.3.tgz", - "integrity": "sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==", + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.4.tgz", + "integrity": "sha512-2V81OA4ugVo5pRo46hAoD2ivUJx8jXmWXfUkY4KFNw0hEptvN0QfH3K4nHiwzGeKl5rFKedV48QVoqYavy4YpA==", "dev": true, "engines": { "node": ">=0.10.0" @@ -7137,9 +7137,9 @@ } }, "word-wrap": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.3.tgz", - "integrity": "sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==", + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.4.tgz", + "integrity": "sha512-2V81OA4ugVo5pRo46hAoD2ivUJx8jXmWXfUkY4KFNw0hEptvN0QfH3K4nHiwzGeKl5rFKedV48QVoqYavy4YpA==", "dev": true }, "wordwrap": { diff --git a/sdks/typescript/package.json b/sdks/typescript/package.json index 5bd7bfbcbbe7..18be237f816f 100644 --- a/sdks/typescript/package.json +++ b/sdks/typescript/package.json @@ -56,8 +56,9 @@ "main": "./dist/src/apache_beam/index.js", "exports": { ".": "./dist/src/apache_beam/index.js", - "./transforms": "./dist/src/apache_beam/transforms/index.js", + "./io": "./dist/src/apache_beam/io/index.js", "./runners": "./dist/src/apache_beam/index.js", + "./transforms": "./dist/src/apache_beam/transforms/index.js", "./*": "./dist/src/apache_beam/*.js" } } diff --git a/website/www/site/config.toml b/website/www/site/config.toml index 406011422bd2..e57bdbd280e8 100644 --- a/website/www/site/config.toml +++ b/website/www/site/config.toml @@ -104,7 +104,7 @@ github_project_repo = "https://github.com/apache/beam" [params] description = "Apache Beam is an open source, unified model and set of language-specific SDKs for defining and executing data processing workflows, and also data ingestion and integration flows, supporting Enterprise Integration Patterns (EIPs) and Domain Specific Languages (DSLs). Dataflow pipelines simplify the mechanics of large-scale batch and streaming data processing and can run on a number of runtimes like Apache Flink, Apache Spark, and Google Cloud Dataflow (a cloud service). Beam also brings DSL in different languages, allowing users to easily implement their data integration processes." -release_latest = "2.48.0" +release_latest = "2.49.0" # The repository and branch where the files live in Github or Colab. This is used # to serve and stage from your local branch, but publish to the master branch. # e.g. https://github.com/{{< param branch_repo >}}/path/to/notebook.ipynb diff --git a/website/www/site/content/en/blog/beam-2.33.0.md b/website/www/site/content/en/blog/beam-2.33.0.md index 8e7237591d1b..d2afd7a79bcd 100644 --- a/website/www/site/content/en/blog/beam-2.33.0.md +++ b/website/www/site/content/en/blog/beam-2.33.0.md @@ -79,6 +79,9 @@ notes](https://issues.apache.org/jira/secure/ReleaseNote.jspa?projectId=12319527 * Spark 2.x users will need to update Spark's Jackson runtime dependencies (`spark.jackson.version`) to at least version 2.9.2, due to Beam updating its dependencies. * See a full list of open [issues that affect](https://issues.apache.org/jira/issues/?jql=project%20%3D%20BEAM%20AND%20affectedVersion%20%3D%202.33.0%20ORDER%20BY%20priority%20DESC%2C%20updated%20DESC) this version. * Go SDK jobs may produce "Failed to deduce Step from MonitoringInfo" messages following successful job execution. The messages are benign and don't indicate job failure. These are due to not yet handling PCollection metrics. +* Large Java BigQueryIO writes with the FILE_LOADS method will fail in batch mode (specifically, when copy jobs are used). + This results in the error message: `IllegalArgumentException: Attempting to access unknown side input`. + Please upgrade to a newer version (> 2.34.0) or use another write method (e.g. `STORAGE_WRITE_API`). ## List of Contributors diff --git a/website/www/site/content/en/blog/beam-2.34.0.md b/website/www/site/content/en/blog/beam-2.34.0.md index 2f335a86fd83..288496aaa809 100644 --- a/website/www/site/content/en/blog/beam-2.34.0.md +++ b/website/www/site/content/en/blog/beam-2.34.0.md @@ -61,6 +61,12 @@ notes](https://issues.apache.org/jira/secure/ReleaseNote.jspa?projectId=12319527 * Fixed error when importing the DataFrame API with pandas 1.0.x installed ([BEAM-12945](https://issues.apache.org/jira/browse/BEAM-12945)). * Fixed top.SmallestPerKey implementation in the Go SDK ([BEAM-12946](https://issues.apache.org/jira/browse/BEAM-12946)). +### Known Issues + +* Large Java BigQueryIO writes with the FILE_LOADS method will fail in batch mode (specifically, when copy jobs are used). + This results in the error message: `IllegalArgumentException: Attempting to access unknown side input`. + Please upgrade to a newer version (> 2.34.0) or use another write method (e.g. `STORAGE_WRITE_API`). + ## List of Contributors According to git shortlog, the following people contributed to the 2.34.0 release. Thank you to all contributors! diff --git a/website/www/site/content/en/blog/beam-2.49.0.md b/website/www/site/content/en/blog/beam-2.49.0.md new file mode 100644 index 000000000000..621637d655f6 --- /dev/null +++ b/website/www/site/content/en/blog/beam-2.49.0.md @@ -0,0 +1,221 @@ +--- +title: "Apache Beam 2.49.0" +date: 2023-07-17 09:00:00 -0400 +categories: + - blog + - release +authors: + - yhu +--- + + +We are happy to present the new 2.49.0 release of Beam. +This release includes both improvements and new functionality. +See the [download page](/get-started/downloads/#2490-2023-07-17) for this release. + + + +For more information on changes in 2.49.0, check out the [detailed release notes](https://github.com/apache/beam/milestone/13). + +## I/Os + +* Support for Bigtable Change Streams added in Java `BigtableIO.ReadChangeStream` ([#27183](https://github.com/apache/beam/issues/27183)). +* Added Bigtable Read and Write cross-language transforms to Python SDK (([#26593](https://github.com/apache/beam/issues/26593)), ([#27146](https://github.com/apache/beam/issues/27146))). + +## New Features / Improvements + +* Allow prebuilding large images when using `--prebuild_sdk_container_engine=cloud_build`, like images depending on `tensorflow` or `torch` ([#27023](https://github.com/apache/beam/pull/27023)). +* Disabled `pip` cache when installing packages on the workers. This reduces the size of prebuilt Python container images ([#27035](https://github.com/apache/beam/pull/27035)). +* Select dedicated avro datum reader and writer (Java) ([#18874](https://github.com/apache/beam/issues/18874)). +* Timer API for the Go SDK (Go) ([#22737](https://github.com/apache/beam/issues/22737)). + + +## Deprecations + +* Remove Python 3.7 support. ([#26447](https://github.com/apache/beam/issues/26447)) + +## Bugfixes + +* Fixed KinesisIO `NullPointerException` when a progress check is made before the reader is started (IO) ([#23868](https://github.com/apache/beam/issues/23868)) + +### Known Issues + + +## List of Contributors + +According to git shortlog, the following people contributed to the 2.49.0 release. Thank you to all contributors! + +Abzal Tuganbay + +AdalbertMemSQL + +Ahmed Abualsaud + +Ahmet Altay + +Alan Zhang + +Alexey Romanenko + +Anand Inguva + +Andrei Gurau + +Arwin Tio + +Bartosz Zablocki + +Bruno Volpato + +Burke Davison + +Byron Ellis + +Chamikara Jayalath + +Charles Rothrock + +Chris Gavin + +Claire McGinty + +Clay Johnson + +Damon + +Daniel Dopierała + +Danny McCormick + +Darkhan Nausharipov + +David Cavazos + +Dip Patel + +Dmitry Repin + +Gavin McDonald + +Jack Dingilian + +Jack McCluskey + +James Fricker + +Jan Lukavský + +Jasper Van den Bossche + +John Casey + +John Gill + +Joseph Crowley + +Kanishk Karanawat + +Katie Liu + +Kenneth Knowles + +Kyle Galloway + +Liam Miller-Cushon + +MakarkinSAkvelon + +Masato Nakamura + +Mattie Fu + +Michel Davit + +Naireen Hussain + +Nathaniel Young + +Nelson Osacky + +Nick Li + +Oleh Borysevych + +Pablo Estrada + +Reeba Qureshi + +Reuven Lax + +Ritesh Ghorse + +Robert Bradshaw + +Robert Burke + +Rouslan + +Saadat Su + +Sam Rohde + +Sam Whittle + +Sanil Jain + +Shunping Huang + +Smeet nagda + +Svetak Sundhar + +Timur Sultanov + +Udi Meiri + +Valentyn Tymofieiev + +Vlado Djerek + +WuA + +XQ Hu + +Xianhua Liu + +Xinyu Liu + +Yi Hu + +Zachary Houfek + +alexeyinkin + +bigduu + +bullet03 + +bzablocki + +jonathan-lemos + +jubebo + +magicgoody + +ruslan-ikhsan + +sultanalieva-s + +vitaly.terentyev + diff --git a/website/www/site/content/en/contribute/release-guide.md b/website/www/site/content/en/contribute/release-guide.md index 7bca90daefd3..3b0a0cb2c9a4 100644 --- a/website/www/site/content/en/contribute/release-guide.md +++ b/website/www/site/content/en/contribute/release-guide.md @@ -318,7 +318,7 @@ There are 2 ways to perform this verification, either running automation script( ``` cd beam/release/src/main/scripts && ./verify_release_build.sh ``` - 1. Trigger `beam_Release_Gradle_Build` and all Jenkins PostCommit jobs from the PR created by the previous step. + 1. Trigger all Jenkins PostCommit jobs from the PR created by the previous step. You can run [mass_comment.py](https://github.com/apache/beam/blob/master/release/src/main/scripts/mass_comment.py) to do that. Or manually add one trigger phrase per PR comment. See [jenkins_jobs.txt](https://github.com/apache/beam/blob/master/release/src/main/scripts/jenkins_jobs.txt) @@ -328,9 +328,6 @@ There are 2 ways to perform this verification, either running automation script( 1. Installs ```hub``` with your agreement and setup local git repo; 1. Create a test PR against release branch; -The [`beam_Release_Gradle_Build`](https://ci-beam.apache.org/job/beam_Release_Gradle_Build/) Jenkins job runs `./gradlew build -PisRelease`. -This only verifies that everything builds with unit tests passing. - #### Verify the build succeeds * Tasks you need to do manually to __verify the build succeed__: diff --git a/website/www/site/content/en/documentation/io/connectors.md b/website/www/site/content/en/documentation/io/connectors.md index 83c53f165b51..4524dbc0c975 100644 --- a/website/www/site/content/en/documentation/io/connectors.md +++ b/website/www/site/content/en/documentation/io/connectors.md @@ -79,7 +79,7 @@ This table provides a consolidated, at-a-glance overview of the available built- ✔ - TextIO + TextIO (metrics) ✔ ✔ @@ -200,7 +200,7 @@ This table provides a consolidated, at-a-glance overview of the available built- ✘ - GcsFileSystem + GcsFileSystem (metrics) ✔ ✔ @@ -522,7 +522,7 @@ This table provides a consolidated, at-a-glance overview of the available built- ✔ - BigQueryIO (guide) + BigQueryIO (guide) (metrics) ✔ ✔ @@ -545,7 +545,7 @@ This table provides a consolidated, at-a-glance overview of the available built- ✔ - BigTableIO + BigTableIO (metrics) ✔ ✔ @@ -554,7 +554,10 @@ This table provides a consolidated, at-a-glance overview of the available built- ✔ - native + native (sink) +
+ ✔ + via X-language Not available Not available diff --git a/website/www/site/content/en/documentation/programming-guide.md b/website/www/site/content/en/documentation/programming-guide.md index 30bdb7247d49..0427e50e0b19 100644 --- a/website/www/site/content/en/documentation/programming-guide.md +++ b/website/www/site/content/en/documentation/programming-guide.md @@ -2568,9 +2568,7 @@ Timers and States are explained in more detail in the {{< paragraph class="language-go">}} **Timer and State:** -User defined State parameters can be used in a stateful DoFn. Timers aren't implemented in the Go SDK yet; -see more at [Issue 22737](https://github.com/apache/beam/issues/22737). Once implemented, user defined Timer -parameters can be used in a stateful DoFn. +User defined State and Timer parameters can be used in a stateful DoFn. Timers and States are explained in more detail in the [Timely (and Stateful) Processing with Apache Beam](/blog/2017/08/28/timely-processing.html) blog post. {{< /paragraph >}} @@ -6147,7 +6145,7 @@ _ = (p | 'Read per user' >> ReadPerUser() {{< /highlight >}} {{< highlight go >}} -{{< code_sample "sdks/go/examples/snippets/04transforms.go" state_and_timers >}} +{{< code_sample "sdks/go/examples/snippets/04transforms.go" bag_state >}} {{< /highlight >}} ### 11.2. Deferred state reads {#deferred-state-reads} @@ -6270,7 +6268,7 @@ _ = (p | 'Read per user' >> ReadPerUser() {{< /highlight >}} {{< highlight go >}} -This is not supported yet, see https://github.com/apache/beam/issues/22737. +{{< code_sample "sdks/go/examples/snippets/04transforms.go" event_time_timer >}} {{< /highlight >}} #### 11.3.2. Processing-time timers {#processing-time-timers} @@ -6322,7 +6320,7 @@ _ = (p | 'Read per user' >> ReadPerUser() {{< /highlight >}} {{< highlight go >}} -This is not supported yet, see https://github.com/apache/beam/issues/22737. +{{< code_sample "sdks/go/examples/snippets/04transforms.go" processing_time_timer >}} {{< /highlight >}} #### 11.3.3. Dynamic timer tags {#dynamic-timer-tags} @@ -6383,7 +6381,7 @@ _ = (p | 'Read per user' >> ReadPerUser() {{< /highlight >}} {{< highlight go >}} -This is not supported yet, see https://github.com/apache/beam/issues/22737. +{{< code_sample "sdks/go/examples/snippets/04transforms.go" dynamic_timer_tags >}} {{< /highlight >}} #### 11.3.4. Timer output timestamps {#timer-output-timestamps} @@ -6435,6 +6433,10 @@ perUser.apply(ParDo.of(new DoFn, OutputT>() { })); {{< /highlight >}} +{{< highlight go >}} +{{< code_sample "sdks/go/examples/snippets/04transforms.go" timer_output_timestamps_bad >}} +{{< /highlight >}} + The problem with this code is that the ParDo is buffering elements, however nothing is preventing the watermark from advancing past the timestamp of those elements, so all those elements might be dropped as late data. In order to prevent this from happening, an output timestamp needs to be set on the timer to prevent the watermark from advancing @@ -6471,7 +6473,7 @@ perUser.apply(ParDo.of(new DoFn, OutputT>() { ? Instant.now().plus(Duration.standardMinutes(1)) : new Instant(timerTimestampMs); // Setting the outputTimestamp to the minimum timestamp in the bag holds the watermark to that timestamp until the // timer fires. This allows outputting all the elements with their timestamp. - timer.withOutputTimestamp(minTimestamp.read()).set(timerToSet). + timer.withOutputTimestamp(minTimestamp.read()).s et(timerToSet). timerTimestamp.write(timerToSet.getMillis()); } @@ -6494,7 +6496,7 @@ Timer output timestamps is not yet supported in Python SDK. See https://github.c {{< /highlight >}} {{< highlight go >}} -This is not supported yet, see https://github.com/apache/beam/issues/22737. +{{< code_sample "sdks/go/examples/snippets/04transforms.go" timer_output_timestamps_good >}} {{< /highlight >}} ### 11.4. Garbage collecting state {#garbage-collecting-state} @@ -6624,7 +6626,7 @@ _ = (p | 'Read per user' >> ReadPerUser() {{< /highlight >}} {{< highlight go >}} -This is not supported yet, see https://github.com/apache/beam/issues/22737. +{{< code_sample "sdks/go/examples/snippets/04transforms.go" timer_garbage_collection >}} {{< /highlight >}} ### 11.5. State and timers examples {#state-timers-examples} @@ -6764,6 +6766,11 @@ _ = (p | 'EventsPerLinkId' >> ReadPerLinkEvents() | 'Join DoFn' >> beam.ParDo(JoinDoFn())) {{< /highlight >}} + +{{< highlight go >}} +{{< code_sample "sdks/go/examples/snippets/04transforms.go" join_dofn_example >}} +{{< /highlight >}} + #### 11.5.2. Batching RPCs {#batching-rpcs} In this example, input elements are being forwarded to an external RPC service. The RPC accepts batch requests - @@ -6830,6 +6837,12 @@ class BufferDoFn(DoFn): {{< /highlight >}} + +{{< highlight go >}} +{{< code_sample "sdks/go/examples/snippets/04transforms.go" batching_dofn_example >}} +{{< /highlight >}} + + ## 12. Splittable `DoFns` {#splittable-dofns} A Splittable `DoFn` (SDF) enables users to create modular components containing I/Os (and some advanced diff --git a/website/www/site/content/en/documentation/runners/flink.md b/website/www/site/content/en/documentation/runners/flink.md index 1549857068fa..64a6e9bade9b 100644 --- a/website/www/site/content/en/documentation/runners/flink.md +++ b/website/www/site/content/en/documentation/runners/flink.md @@ -325,235 +325,84 @@ To find out which version of Flink is compatible with Beam please see the table - + - + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + - - - - - - - - - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + - +
Beam Version Flink Version Artifact IdSupported Beam Versions
≥ 2.47.0 1.16.x beam-runners-flink-1.16≥ 2.47.0
1.15.x beam-runners-flink-1.15≥ 2.40.0
1.14.x beam-runners-flink-1.14≥ 2.38.0
1.13.x beam-runners-flink-1.13≥ 2.31.0
1.12.x beam-runners-flink-1.12
2.40.0 - 2.46.01.15.xbeam-runners-flink-1.15
1.14.xbeam-runners-flink-1.14
1.13.xbeam-runners-flink-1.13
1.12.xbeam-runners-flink-1.12
2.39.01.14.xbeam-runners-flink-1.14
1.13.xbeam-runners-flink-1.13
1.12.xbeam-runners-flink-1.12
2.38.01.14.xbeam-runners-flink-1.14
1.13.xbeam-runners-flink-1.13
1.12.xbeam-runners-flink-1.12
1.11.xbeam-runners-flink-1.11
2.31.0 - 2.37.01.13.xbeam-runners-flink-1.13
1.12.xbeam-runners-flink-1.12
1.11.xbeam-runners-flink-1.11
2.30.01.12.xbeam-runners-flink-1.12
1.11.xbeam-runners-flink-1.11
1.10.xbeam-runners-flink-1.10
2.27.0 - 2.29.01.12.xbeam-runners-flink-1.12≥ 2.27.0
1.11.x beam-runners-flink-1.112.25.0 - 2.38.0
1.10.x beam-runners-flink-1.102.21.0 - 2.30.0
1.9.x beam-runners-flink-1.92.17.0 - 2.29.0
1.8.x beam-runners-flink-1.8
2.25.0 - 2.26.01.11.xbeam-runners-flink-1.11
1.10.xbeam-runners-flink-1.10
1.9.xbeam-runners-flink-1.9
1.8.xbeam-runners-flink-1.8
2.21.0 - 2.24.01.10.xbeam-runners-flink-1.10
1.9.xbeam-runners-flink-1.9
1.8.xbeam-runners-flink-1.8
2.17.0 - 2.20.01.9.xbeam-runners-flink-1.9
1.8.xbeam-runners-flink-1.8
1.7.xbeam-runners-flink-1.7
2.13.0 - 2.16.01.8.xbeam-runners-flink-1.82.13.0 - 2.29.0
1.7.x beam-runners-flink-1.72.10.0 - 2.20.0
1.6.x beam-runners-flink-1.62.10.0 - 2.16.0
1.5.x beam-runners-flink_2.112.6.0 - 2.16.0
2.10.0 - 2.16.01.7.xbeam-runners-flink-1.7
1.6.xbeam-runners-flink-1.6
1.5.x1.4.x with Scala 2.11 beam-runners-flink_2.112.3.0 - 2.5.0
2.9.01.5.xbeam-runners-flink_2.11
2.8.0
2.7.0
2.6.0
2.5.01.4.x with Scala 2.11beam-runners-flink_2.11
2.4.0
2.3.0
2.2.01.3.x with Scala 2.10beam-runners-flink_2.10
2.1.x1.3.x with Scala 2.10beam-runners-flink_2.102.1.x - 2.2.0
2.0.0 1.2.x with Scala 2.10 beam-runners-flink_2.102.0.0
diff --git a/website/www/site/content/en/documentation/runners/spark.md b/website/www/site/content/en/documentation/runners/spark.md index 0b3075061e44..dcc166873dc2 100644 --- a/website/www/site/content/en/documentation/runners/spark.md +++ b/website/www/site/content/en/documentation/runners/spark.md @@ -67,8 +67,8 @@ the portable Runner. For more information on portability, please visit the ## Spark Runner prerequisites and setup -The Spark runner currently supports Spark's 3.1.x branch. -> **Note:** Support for Spark 2.4.x was deprecated as of Beam 2.41.0 and finally dropped with the release of Beam 2.46.0. +The Spark runner currently supports Spark's 3.2.x branch. +> **Note:** Support for Spark 2.4.x was dropped with Beam 2.46.0. {{< paragraph class="language-java" >}} You can add a dependency on the latest version of the Spark runner by adding to your pom.xml the following: @@ -215,7 +215,7 @@ options = PipelineOptions([ "--job_endpoint=localhost:8099", "--environment_type=LOOPBACK" ]) -with beam.Pipeline(options=options) as p: +with beam.Pipeline(options) as p: ... {{< /highlight >}} @@ -243,7 +243,7 @@ See [here](/roadmap/portability/#sdk-harness-config) for details.) ### Running on Dataproc cluster (YARN backed) -To run Beam jobs written in Python, Go, and other supported languages, you can use the `SparkRunner` and `PortableRunner` as described on the Beam's [Spark Runner](/documentation/runners/spark/) page (also see [Portability Framework Roadmap](/roadmap/portability/)). +To run Beam jobs written in Python, Go, and other supported languages, you can use the `SparkRunner` and `PortableRunner` as described on the Beam's [Spark Runner](https://beam.apache.org/documentation/runners/spark/) page (also see [Portability Framework Roadmap](https://beam.apache.org/roadmap/portability/)). The following example runs a portable Beam job in Python from the Dataproc cluster's master node with Yarn backed. diff --git a/website/www/site/content/en/get-started/downloads.md b/website/www/site/content/en/get-started/downloads.md index 35acb147f110..a0762fd2ea93 100644 --- a/website/www/site/content/en/get-started/downloads.md +++ b/website/www/site/content/en/get-started/downloads.md @@ -96,10 +96,18 @@ versions denoted `0.x.y`. ## Releases +### 2.49.0 (2023-07-17) +Official [source code download](https://downloads.apache.org/beam/2.49.0/apache-beam-2.49.0-source-release.zip). +[SHA-512](https://downloads.apache.org/beam/2.49.0/apache-beam-2.49.0-source-release.zip.sha512). +[signature](https://downloads.apache.org/beam/2.49.0/apache-beam-2.49.0-source-release.zip.asc). + +[Release notes](https://github.com/apache/beam/releases/tag/v2.49.0) +[Blog post](/blog/beam-2.49.0). + ### 2.48.0 (2023-05-31) -Official [source code download](https://downloads.apache.org/beam/2.48.0/apache-beam-2.48.0-source-release.zip). -[SHA-512](https://downloads.apache.org/beam/2.48.0/apache-beam-2.48.0-source-release.zip.sha512). -[signature](https://downloads.apache.org/beam/2.48.0/apache-beam-2.48.0-source-release.zip.asc). +Official [source code download](https://archive.apache.org/dist/beam/2.48.0/apache-beam-2.48.0-source-release.zip). +[SHA-512](https://archive.apache.org/dist/beam/2.48.0/apache-beam-2.48.0-source-release.zip.sha512). +[signature](https://archive.apache.org/dist/beam/2.48.0/apache-beam-2.48.0-source-release.zip.asc). [Release notes](https://github.com/apache/beam/releases/tag/v2.48.0) [Blog post](/blog/beam-2.48.0). diff --git a/website/www/site/content/en/performance/_index.md b/website/www/site/content/en/performance/_index.md new file mode 100644 index 000000000000..f821b0f25084 --- /dev/null +++ b/website/www/site/content/en/performance/_index.md @@ -0,0 +1,40 @@ +--- +title: "Beam IO Performance" +--- + + + +# Beam IO Performance + +Various Beam pipelines measure characteristics of reading from and writing to +various IOs. + +# Available Metrics + +Various metrics were gathered using the Beam SDK +[Metrics API](/documentation/programming-guide/#metrics) +from a pipeline Job running on [Dataflow](/documentation/runners/dataflow/). + +See the [glossary](/performance/glossary) for a list of the metrics and their +definition. + +# Measured Beam IOs + +See the following pages for performance measures recorded when reading from and +writing to various Beam IOs. + +- [BigQuery](/performance/bigquery) +- [BigTable](/performance/bigtable) +- [TextIO](/performance/textio) \ No newline at end of file diff --git a/website/www/site/content/en/performance/bigquery/_index.md b/website/www/site/content/en/performance/bigquery/_index.md new file mode 100644 index 000000000000..5d46e1180970 --- /dev/null +++ b/website/www/site/content/en/performance/bigquery/_index.md @@ -0,0 +1,50 @@ +--- +title: "BigQuery Performance" +--- + + + +# BigQuery Performance + +The following graphs show various metrics when reading from and writing to +BigQuery. See the [glossary](/performance/glossary) for definitions. + +## Read + +### What is the estimated cost to read from BigQuery? + +{{< performance_looks io="bigquery" read_or_write="read" section="test_name" >}} + +### How has various metrics changed when reading from BigQuery for different Beam SDK versions? + +{{< performance_looks io="bigquery" read_or_write="read" section="version" >}} + +### How has various metrics changed over time when reading from BigQuery? + +{{< performance_looks io="bigquery" read_or_write="read" section="date" >}} + +## Write + +### What is the estimated cost to write to BigQuery? + +{{< performance_looks io="bigquery" read_or_write="write" section="test_name" >}} + +### How has various metrics changed when writing to BigQuery for different Beam SDK versions? + +{{< performance_looks io="bigquery" read_or_write="write" section="version" >}} + +### How has various metrics changed over time when writing to BigQuery? + +{{< performance_looks io="bigquery" read_or_write="write" section="date" >}} diff --git a/website/www/site/content/en/performance/bigtable/_index.md b/website/www/site/content/en/performance/bigtable/_index.md new file mode 100644 index 000000000000..d394528a05fa --- /dev/null +++ b/website/www/site/content/en/performance/bigtable/_index.md @@ -0,0 +1,50 @@ +--- +title: "BigTable Performance" +--- + + + +# BigTable Performance + +The following graphs show various metrics when reading from and writing to +BigTable. See the [glossary](/performance/glossary) for definitions. + +## Read + +### What is the estimated cost to read from BigTable? + +{{< performance_looks io="bigquery" read_or_write="read" section="test_name" >}} + +### How has various metrics changed when reading from BigTable for different Beam SDK versions? + +{{< performance_looks io="bigquery" read_or_write="read" section="version" >}} + +### How has various metrics changed over time when reading from BigTable? + +{{< performance_looks io="bigquery" read_or_write="read" section="date" >}} + +## Write + +### What is the estimated cost to write to BigTable? + +{{< performance_looks io="bigquery" read_or_write="write" section="test_name" >}} + +### How has various metrics changed when writing to BigTable for different Beam SDK versions? + +{{< performance_looks io="bigquery" read_or_write="write" section="version" >}} + +### How has various metrics changed over time when writing to BigTable? + +{{< performance_looks io="bigquery" read_or_write="write" section="date" >}} diff --git a/website/www/site/content/en/performance/glossary/_index.md b/website/www/site/content/en/performance/glossary/_index.md new file mode 100644 index 000000000000..6741d107ad33 --- /dev/null +++ b/website/www/site/content/en/performance/glossary/_index.md @@ -0,0 +1,47 @@ +--- +title: "Performance Glossary" +--- + + + +# Metrics Glossary + +The metrics glossary defines various metrics presented in the visualizations, +measured using the [Beam Metrics API](/documentation/programming-guide/#metrics) +from a pipeline Job running on [Dataflow](/documentation/runners/dataflow/). + +## AvgInputThroughputBytesPerSec + +The mean input throughput of the pipeline Job measured in bytes per second. + +## AvgInputThroughputElementsPerSec + +The mean elements input throughput per second of the pipeline Job. + +## AvgOutputThroughputBytesPerSec + +The mean output throughput of the pipeline Job measured in bytes per second. + +## AvgOutputThroughputElementsPerSec + +The mean elements output throughput per second of the pipeline Job. + +## EstimatedCost + +The estimated cost of the pipeline Job. + +## RunTime + +The time it took for the pipeline Job to run measured in seconds. diff --git a/website/www/site/content/en/performance/textio/_index.md b/website/www/site/content/en/performance/textio/_index.md new file mode 100644 index 000000000000..97c740ad057e --- /dev/null +++ b/website/www/site/content/en/performance/textio/_index.md @@ -0,0 +1,50 @@ +--- +title: "TextIO Performance" +--- + + + +# TextIO Performance + +The following graphs show various metrics when reading from or writing to Google Cloud Storage using +TextIO. See the [glossary](/performance/glossary) for definitions. + +## Read + +### What is the estimated cost of reading from Google Cloud Storage using TextIO? + +{{< performance_looks io="textio" read_or_write="read" section="test_name" >}} + +### How has various metrics changed when reading from Google Cloud Storage using TextIO for different Beam SDK versions? + +{{< performance_looks io="textio" read_or_write="read" section="version" >}} + +### How has various metrics changed over time when reading from Google Cloud Storage using TextIO? + +{{< performance_looks io="textio" read_or_write="read" section="date" >}} + +## Write + +### What is the estimated cost of writing to Google Cloud Storage using TextIO? + +{{< performance_looks io="textio" read_or_write="write" section="test_name" >}} + +### How has various metrics changed when writing to Google Cloud Storage using TextIO for different Beam SDK versions? + +{{< performance_looks io="textio" read_or_write="write" section="version" >}} + +### How has various metrics changed over time when writing to Google Cloud Storage using TextIO? + +{{< performance_looks io="textio" read_or_write="write" section="date" >}} diff --git a/website/www/site/data/authors.yml b/website/www/site/data/authors.yml index f9a2996ebe3f..0458bda2c963 100644 --- a/website/www/site/data/authors.yml +++ b/website/www/site/data/authors.yml @@ -265,4 +265,7 @@ sysede: linkedin: desyse riteshghorse: name: Ritesh Ghorse - email: riteshghorse@apache.org \ No newline at end of file + email: riteshghorse@apache.org +yhu: + name: Yi Hu + email: yhu@apache.org diff --git a/website/www/site/data/performance.yaml b/website/www/site/data/performance.yaml new file mode 100644 index 000000000000..dc375811c833 --- /dev/null +++ b/website/www/site/data/performance.yaml @@ -0,0 +1,108 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +host: https://storage.googleapis.com +path: public_looker_explores_us_a3853f40 +looks: + bigquery: + read: + folder: 30 + test_name: + - id: nwQxvsnQFdBPTk27pZYxjcGNm2rRfNJk + title: Runtime and Estimated Cost by Test Name + date: + - id: 7QKbMmgT5NgPH6RsDfVQmKMdjJysS37x + title: AvgOutputThroughputBytesPerSec by Date + - id: hSqCyCkHkh4whNZSpdVFwB8ksPWrbScb + title: AvgOutputThroughputElementsPerSec by Date + version: + - id: SghzxqjMb5QQwgZn7yHw4MnpGXBgXQft + title: AvgOutputThroughputBytesPerSec by Version + - id: QT6X4GTCnxxykSZ7tpXyBfBFCKMgdzs7 + title: AvgOutputThroughputElementsPerSec by Version + write: + folder: 31 + test_name: + - id: sHyRfwCfGCPqSwsTWXSsxfbHKH5hXjzc + title: Write BigQuery RunTime and EstimatedCost + date: + - id: p2qnxS58WdWdZTzJhXfSbkhxGyDXMYfY + title: AvgInputThroughputBytesPerSec by Date + - id: Bmg6sv8XPtcn3cGtgPtKP26Zj4cyd2XK + title: AvgInputThroughputElementsPerSec by Date + version: + - id: Q6QvktyfSFrSPq4TSDgQQyV9jttjwQJp + title: AvgInputThroughputBytesPerSec by Version + - id: ktfyZ7Th8yRFFwWgb4WSxngBpmSfz8xh + title: AvgInputThroughputElementsPerSec by Version + bigtable: + read: + folder: 32 + test_name: + - id: YQ2W3wdNnBXMgDgpzCRbmQWMHjyPvZny + title: Read BigTable RunTime and EstimatedCost + date: + - id: szvwNfPrwTtmRmMHWv3QFh6wTKxm26TF + title: AvgOutputThroughputBytesPerSec by Date + - id: ZRQw7np2mj35kQHcvgtNDwVGgP855QNF + title: AvgOutputThroughputElementsPerSec by Date + version: + - id: X7y2hQQkJnYY8ctnr2y7NnNx7xrMkjt2 + title: AvgOutputThroughputBytesPerSec by Version + - id: B3HwmTtn9nBfzwRC7s3WfQSjf7ckrfXx + title: AvgOutputThroughputElementsPerSec by Version + write: + folder: 33 + test_name: + - id: 2sC27RQwWy2MP9DXVHjbvTYSNFYpFvxj + title: Write BigTable RunTime and EstimatedCost + date: + - id: X22sDqD8krBQQ4mRXTFMmpKFvgkZ529g + title: AvgInputThroughputBytesPerSec by Date + - id: WqHVhrjQVrrjpczqgJB9s9k8fvt3Vc2d + title: AvgInputThroughputElementsPerSec by Date + version: + - id: j72shCBz6rJhQP8JwYB3JcVrhv9BWGpD + title: AvgInputThroughputBytesPerSec by Version + - id: 6DyhkfvcNfkJg53hwFnRzxqbJDdGb5t7 + title: AvgInputThroughputElementsPerSec by Version + textio: + read: + folder: 34 + test_name: + - id: SsfFgvwyMthnHjdRwBjWGNbvfNgym4wb + title: Read TextIO RunTime and EstimatedCost + date: + - id: Pr8MBG66JVgBmdQbDr6HXtRprjr3ybj6 + title: AvgOutputThroughputBytesPerSec by Date + - id: CTZYK4KVYM65Zjn6jWgDYYNBWckDYKRQ + title: AvgOutputThroughputElementsPerSec by Date + version: + - id: Dcvfh3XFZySrsmPY4Rm8NYyMg5QQRBF6 + title: AvgOutputThroughputBytesPerSec by Version + - id: dN8mTZsVZc7vGDYKJCT8S67BCXzVJT4s + title: AvgOutputThroughputElementsPerSec by Version + write: + folder: 35 + test_name: + - id: Mm2j2QPc2x4hZqYNSpZC4bDpH2fgvqvp + title: Write TextIO RunTime and EstimatedCost + date: + - id: J9VXhp3ry5zbPFFGsYNfDRypGNVNMbPV + title: AvgInputThroughputBytesPerSec by Date + - id: KwSCnyz75wMpXhGpZQ8FRp6cwt8pjhfD + title: AvgInputThroughputElementsPerSec by Date + version: + - id: VFXbPV9JGJxmNYnGsypQzH97RPDFjpPN + title: AvgInputThroughputBytesPerSec by Version + - id: fVVHhXCrHNgBG52TJsTjR8VbmWCCQnVN + title: AvgInputThroughputElementsPerSec by Version diff --git a/website/www/site/i18n/navbar/en.yaml b/website/www/site/i18n/navbar/en.yaml index d356d282671e..aeb1988183b0 100644 --- a/website/www/site/i18n/navbar/en.yaml +++ b/website/www/site/i18n/navbar/en.yaml @@ -48,3 +48,13 @@ translation: "About" - id: nav-connectors translation: "I/O Connectors" +- id: nav-performance-general + translation: "General" +- id: nav-performance-bigquery + translation: "BigQuery" +- id: nav-performance-bigtable + translation: "BigTable" +- id: nav-performance-textio + translation: "TextIO" +- id: nav-performance-glossary + translation: "Glossary" diff --git a/website/www/site/layouts/performance/baseof.html b/website/www/site/layouts/performance/baseof.html new file mode 100644 index 000000000000..06aa0031077f --- /dev/null +++ b/website/www/site/layouts/performance/baseof.html @@ -0,0 +1,40 @@ +{{/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. See accompanying LICENSE file. +*/}} + + + + + {{ partial "head.html" . }} + + +{{ partial "header.html" . }} +
+
+ + +
+ + + +
+ {{ .Content }} + {{ partial "feedback.html" . }} +
+
+{{ partial "footer.html" . }} + + \ No newline at end of file diff --git a/website/www/site/layouts/shortcodes/performance_looks.html b/website/www/site/layouts/shortcodes/performance_looks.html new file mode 100644 index 000000000000..1a4f1d0a578f --- /dev/null +++ b/website/www/site/layouts/shortcodes/performance_looks.html @@ -0,0 +1,30 @@ +{{/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. See accompanying LICENSE file. +*/}} + +
+ {{ $host := .Site.Data.performance.host }} + {{ $path := .Site.Data.performance.path }} + {{ $looks := .Site.Data.performance.looks }} + {{ $io := index $looks (.Get "io") }} + {{ $rw := index $io (.Get "read_or_write") }} + {{ $section := index $rw (.Get "section") }} + {{ range $section }} + {{ $folder := index $rw "folder" }} +

{{.title}}

+
+ {{.title}} +
+ {{ end }} +