diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d8130079..836e0136 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -@snowflakedb/snowpark-python-api +* @snowflakedb/snowcli diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index 35c306fa..5e9823f2 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -24,164 +24,282 @@ jobs: name: Check linting runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + with: + persist-credentials: false - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.7' - - name: Display Python version - run: python -c "import sys; import os; print(\"\n\".join(os.environ[\"PATH\"].split(os.pathsep))); print(sys.version); print(sys.executable);" - - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel - - name: Install tox - run: python -m pip install tox + uses: actions/setup-python@v5 + with: + python-version: '3.8' + - name: Upgrade and install tools + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default - name: Set PY - run: echo "PY=$(python -VV | sha256sum | cut -d' ' -f1)" >> $GITHUB_ENV - - uses: actions/cache@v1 + run: echo "PY=$(hatch run gh-cache-sum)" >> $GITHUB_ENV + - uses: actions/cache@v4 with: path: ~/.cache/pre-commit key: pre-commit|${{ env.PY }}|${{ hashFiles('.pre-commit-config.yaml') }} - - name: Run fix_lint - run: python -m tox -e fix_lint + - name: Run lint checks + run: hatch run check + + build-install: + name: Test package build and installation + runs-on: ubuntu-latest + needs: lint + strategy: + fail-fast: true + matrix: + hatch-env: [default, sa14] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.8' + - name: Upgrade and install tools + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + - name: Build package + run: | + python -m hatch -e ${{ matrix.hatch-env }} build --clean + - name: Install and check import + run: | + python -m uv pip install dist/snowflake_sqlalchemy-*.whl + python -c "import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)" + + test-dialect: + name: Test dialect ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: [ lint, build-install ] + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-13, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip and prepare environment + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run test for AWS + run: hatch run test-dialect-aws + if: matrix.cloud-provider == 'aws' + - name: Run tests + run: hatch run test-dialect + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-${{
matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml + + test-dialect-compatibility: + name: Test dialect compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: [ lint, build-install ] + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-13, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip and install hatch + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run tests + run: hatch run test-dialect-compatibility + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml - test: - name: Test ${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} - needs: lint - runs-on: ${{ matrix.os.image_name }} - strategy: - fail-fast: false - matrix: - os: - - image_name: ubuntu-latest - download_name: manylinux_x86_64 - - image_name: macos-latest - download_name: macosx_x86_64 - - image_name: windows-2019 - download_name: win_amd64 - python-version: ["3.7"] - cloud-provider: [aws, azure, gcp] - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - name: Setup parameters file - shell: bash - env: - PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} - run: | - gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ - .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py - - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel - - name: Install tox - run: python -m pip install tox - - name: List installed packages - run: python -m pip freeze - - name: Run tests - run: python -m tox -e "py${PYTHON_VERSION/\./}" --skip-missing-interpreters false - env: - PYTHON_VERSION: ${{ matrix.python-version }} - PYTEST_ADDOPTS: -vvv --color=yes --tb=short - TOX_PARALLEL_NO_SPINNER: 1 - - name: Combine coverages - run: python -m tox -e coverage --skip-missing-interpreters false - shell: bash - - uses: actions/upload-artifact@v2 - with: - name: coverage_${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} - path: | - .tox/.coverage - .tox/coverage.xml + test-dialect-v14: + name: Test dialect v14 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: [ lint, build-install ] + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-13, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: 
actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Upgrade pip and install hatch + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default + - name: Run test for AWS + run: hatch run sa14:test-dialect-aws + if: matrix.cloud-provider == 'aws' + - name: Run tests + run: hatch run sa14:test-dialect + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-v14-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml - test_connector_regression: - name: Connector Regression Test ${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} - needs: lint - runs-on: ${{ matrix.os.image_name }} - strategy: - fail-fast: false - matrix: - os: - - image_name: ubuntu-latest - download_name: manylinux_x86_64 - python-version: ["3.7"] - cloud-provider: [aws] - steps: - - uses: actions/checkout@v2 - with: - submodules: true - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - name: Setup parameters file - shell: bash - env: - PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} - run: | - gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ - .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/connector_regression/test/parameters.py - - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel - - name: Install tox - run: python -m pip install tox - - name: List installed packages - run: python -m pip freeze - - name: Run tests - run: python -m tox -e connector_regression --skip-missing-interpreters false - env: - PYTEST_ADDOPTS: -vvv --color=yes --tb=short - TOX_PARALLEL_NO_SPINNER: 1 + test-dialect-compatibility-v14: + name: Test dialect v14 compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: lint + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-13, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip and install hatch + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run tests + run: hatch run sa14:test-dialect-compatibility + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-v14-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml combine-coverage: - if: ${{ success() || failure() }} name: Combine coverage - needs: [test, 
test_connector_regression] + if: ${{ success() || failure() }} + needs: [test-dialect, test-dialect-compatibility, test-dialect-v14, test-dialect-compatibility-v14] runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/download-artifact@v2 - with: - path: artifacts - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.7' - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - name: Upgrade setuptools and pip - run: python -m pip install -U setuptools pip wheel - - name: Install tox - run: python -m pip install tox - - name: Collect all coverages to one dir + uses: actions/setup-python@v5 + with: + python-version: "3.8" + - name: Prepare environment + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + hatch env create default + - uses: actions/checkout@v4 + with: + persist-credentials: false + - uses: actions/download-artifact@v4 + with: + path: artifacts/ + - name: Combine coverage files run: | - python -c ' - from pathlib import Path - import shutil - src_dir = Path("artifacts") - dst_dir = Path(".") / ".tox" - dst_dir.mkdir() - for src_file in src_dir.glob("*/.coverage"): - dst_file = dst_dir / f".coverage.{src_file.parent.name[9:]}" - print(f"{src_file} copy to {dst_file}") - shutil.copy(str(src_file), str(dst_file))' - - name: Combine coverages - run: python -m tox -e coverage - - name: Publish html coverage - uses: actions/upload-artifact@v2 - with: - name: overall_cov_html - path: .tox/htmlcov - - name: Publish xml coverage - uses: actions/upload-artifact@v2 - with: - name: overall_cov_xml - path: .tox/coverage.xml - - uses: codecov/codecov-action@v1 - with: - file: .tox/coverage.xml + hatch run coverage combine -a artifacts/coverage.xml_*/coverage.xml + hatch run coverage report -m + - name: Store coverage reports + uses: actions/upload-artifact@v4 + with: + name: coverage.xml + path: coverage.xml + - name: Upload to codecov + uses: codecov/codecov-action@v4 + with: + file: coverage.xml + env_vars: OS,PYTHON + fail_ci_if_error: false + flags: unittests + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + url: https://snowflake.codecov.io/ diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml new file mode 100644 index 00000000..252405fd --- /dev/null +++ b/.github/workflows/changelog.yml @@ -0,0 +1,21 @@ +name: Changelog Check + +on: + pull_request: + types: [opened, synchronize, labeled, unlabeled] + branches: + - main + +jobs: + check_change_log: + runs-on: ubuntu-latest + if: ${{!contains(github.event.pull_request.labels.*.name, 'NO-CHANGELOG-UPDATES')}} + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + persist-credentials: false + fetch-depth: 0 + + - name: Ensure DESCRIPTION.md is updated + run: git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -wq "DESCRIPTION.md" diff --git a/.github/workflows/cla_bot.yml b/.github/workflows/cla_bot.yml index 5574667a..2c87fc92 100644 --- a/.github/workflows/cla_bot.yml +++ b/.github/workflows/cla_bot.yml @@ -8,6 +8,11 @@ on: jobs: CLAssistant: runs-on: ubuntu-latest + permissions: + actions: write + contents: write + pull-requests: write + statuses: write steps: - name: "CLA Assistant" if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' @@ -17,7 +22,7 @@ jobs: PERSONAL_ACCESS_TOKEN : ${{ secrets.CLA_BOT_TOKEN
}} with: path-to-signatures: 'signatures/version1.json' - path-to-document: 'https://github.com/Snowflake-Labs/CLA/blob/main/README.md' + path-to-document: 'https://github.com/snowflakedb/CLA/blob/main/README.md' branch: 'main' allowlist: 'dependabot[bot],github-actions, sfc-gh-snyk-sca-sa' remote-organization-name: 'snowflakedb' diff --git a/.github/workflows/create_req_files.yml b/.github/workflows/create_req_files.yml index 20d5cede..2cb7a371 100644 --- a/.github/workflows/create_req_files.yml +++ b/.github/workflows/create_req_files.yml @@ -9,20 +9,22 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + with: + persist-credentials: false - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Display Python version run: python -c "import sys; print(sys.version)" - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel + run: python -m pip install -U setuptools pip wheel uv - name: Install Snowflake SQLAlchemy shell: bash - run: python -m pip install . + run: python -m uv pip install . - name: Generate reqs file name shell: bash run: echo "requirements_file=temp_requirement/requirements_$(python -c 'from sys import version_info;print(str(version_info.major)+str(version_info.minor))').reqs" >> $GITHUB_ENV @@ -32,12 +34,12 @@ jobs: mkdir temp_requirement echo "# Generated on: $(python --version)" >${{ env.requirements_file }} python -m pip freeze | grep -v snowflake-sqlalchemy 1>>${{ env.requirements_file }} 2>/dev/null - echo "snowflake-sqlalchemy==$(python -m pip show snowflake-sqlalchemy | grep ^Version | cut -d' ' -f2-)" >>${{ env.requirements_file }} + echo "snowflake-sqlalchemy==$(python -m uv pip show snowflake-sqlalchemy | grep ^Version | cut -d' ' -f2-)" >>${{ env.requirements_file }} id: create-reqs-file - name: Show created req file shell: bash run: cat ${{ env.requirements_file }} - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 with: path: temp_requirement @@ -46,11 +48,12 @@ jobs: name: Commit and push files runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: + persist-credentials: false token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets - name: Download requirement files - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v4 with: name: artifact path: tested_requirements diff --git a/.github/workflows/jira_close.yml b/.github/workflows/jira_close.yml index dfcb8bc7..7862f483 100644 --- a/.github/workflows/jira_close.yml +++ b/.github/workflows/jira_close.yml @@ -9,14 +9,15 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: + persist-credentials: false repository: snowflakedb/gh-actions ref: jira_v1 token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets path: . 
- name: Jira login - uses: atlassian/gajira-login@master + uses: atlassian/gajira-login@v3 env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} diff --git a/.github/workflows/jira_comment.yml b/.github/workflows/jira_comment.yml index 954929fa..8533c14c 100644 --- a/.github/workflows/jira_comment.yml +++ b/.github/workflows/jira_comment.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Jira login - uses: atlassian/gajira-login@master + uses: atlassian/gajira-login@v3 env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} @@ -22,7 +22,7 @@ jobs: jira=$(echo -n $TITLE | awk '{print $1}' | sed -e 's/://') echo ::set-output name=jira::$jira - name: Comment on issue - uses: atlassian/gajira-comment@master + uses: atlassian/gajira-comment@v3 if: startsWith(steps.extract.outputs.jira, 'SNOW-') with: issue: "${{ steps.extract.outputs.jira }}" diff --git a/.github/workflows/jira_issue.yml b/.github/workflows/jira_issue.yml index 3683bbba..85c774ca 100644 --- a/.github/workflows/jira_issue.yml +++ b/.github/workflows/jira_issue.yml @@ -9,18 +9,21 @@ on: jobs: create-issue: runs-on: ubuntu-latest + permissions: + issues: write if: ((github.event_name == 'issue_comment' && github.event.comment.body == 'recreate jira' && github.event.comment.user.login == 'sfc-gh-mkeller') || (github.event_name == 'issues' && github.event.pull_request.user.login != 'whitesource-for-github-com[bot]')) steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: + persist-credentials: false repository: snowflakedb/gh-actions ref: jira_v1 token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets path: . - name: Login - uses: atlassian/gajira-login@v2.0.0 + uses: atlassian/gajira-login@v3 env: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} @@ -28,7 +31,7 @@ jobs: - name: Create JIRA Ticket id: create - uses: atlassian/gajira-create@v2.0.1 + uses: atlassian/gajira-create@v3 with: project: SNOW issuetype: Bug diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index dd1e1ba6..52f43106 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -13,7 +13,8 @@ on: types: [published] permissions: - contents: read + contents: write + id-token: write jobs: deploy: @@ -21,17 +22,59 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + with: + persist-credentials: false - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: '3.x' - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install build + python -m pip install -U uv + python -m uv pip install -U hatch - name: Build package - run: python -m build + run: python -m hatch build --clean + - name: List artifacts + run: ls ./dist + - name: Install sigstore + run: python -m pip install sigstore + - name: Signing + run: | + for dist in dist/*; do + dist_base="$(basename "${dist}")" + echo "dist: ${dist}" + echo "dist_base: ${dist_base}" + python -m \ + sigstore sign "${dist}" \ + --output-signature "${dist_base}.sig" \ + --output-certificate "${dist_base}.crt" \ + --bundle "${dist_base}.sigstore" + + # Verify using `.sig` `.crt` pair; + python -m \ + sigstore verify identity "${dist}" \ + --signature "${dist_base}.sig" \ + --cert "${dist_base}.crt" \ + --cert-oidc-issuer 
https://token.actions.githubusercontent.com \ + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF} + + # Verify using `.sigstore` bundle; + python -m \ + sigstore verify identity "${dist}" \ + --bundle "${dist_base}.sigstore" \ + --cert-oidc-issuer https://token.actions.githubusercontent.com \ + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF} + done + - name: List artifacts after sign + run: ls ./dist + - name: Copy files to release + run: | + gh release upload ${{ github.event.release.tag_name }} *.sigstore + gh release upload ${{ github.event.release.tag_name }} *.sig + gh release upload ${{ github.event.release.tag_name }} *.crt + env: + GITHUB_TOKEN: ${{ github.TOKEN }} - name: Publish package uses: pypa/gh-action-pypi-publish@release/v1 with: diff --git a/.github/workflows/snyk-issue.yml b/.github/workflows/snyk-issue.yml index 7098b01e..94dfeb53 100644 --- a/.github/workflows/snyk-issue.yml +++ b/.github/workflows/snyk-issue.yml @@ -4,6 +4,11 @@ on: schedule: - cron: '* */12 * * *' +permissions: + contents: read + issues: write + pull-requests: write + concurrency: snyk-issue jobs: @@ -11,13 +16,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Action - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: + persist-credentials: false repository: snowflakedb/whitesource-actions token: ${{ secrets.whitesource_action_token }} path: whitesource-actions - name: Set Env - run: echo "repo=$(basename $github_repository)" >> $github_env + run: echo "repo=${{ github.event.repository.name }}" >> $GITHUB_ENV - name: Jira Creation uses: ./whitesource-actions/snyk-issue with: diff --git a/.github/workflows/snyk-pr.yml b/.github/workflows/snyk-pr.yml index 51e531f4..cc5e8644 100644 --- a/.github/workflows/snyk-pr.yml +++ b/.github/workflows/snyk-pr.yml @@ -3,20 +3,28 @@ on: pull_request: branches: - main + +permissions: + contents: read + issues: write + pull-requests: write + jobs: snyk: runs-on: ubuntu-latest if: ${{ github.event.pull_request.user.login == 'sfc-gh-snyk-sca-sa' }} steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: + persist-credentials: false ref: ${{ github.event.pull_request.head.ref }} fetch-depth: 0 - name: Checkout Action - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: + persist-credentials: false repository: snowflakedb/whitesource-actions token: ${{ secrets.whitesource_action_token }} path: whitesource-actions diff --git a/.github/workflows/stale_issue_bot.yml b/.github/workflows/stale_issue_bot.yml new file mode 100644 index 00000000..4ee56ff8 --- /dev/null +++ b/.github/workflows/stale_issue_bot.yml @@ -0,0 +1,22 @@ +name: Close Stale Issues +on: + workflow_dispatch: + inputs: + staleDays: + required: true + + +jobs: + stale: + runs-on: ubuntu-latest + steps: + - uses: actions/stale@v9 + with: + close-issue-message: 'To clean up and re-prioritize bugs and feature requests we are closing all issues older than 6 months as of Apr 1, 2023. If there are any issues or feature requests that you would like us to address, please re-create them. 
For urgent issues, opening a support case with this link [Snowflake Community](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge) is the fastest way to get a response' + days-before-issue-stale: ${{ inputs.staleDays }} + days-before-pr-stale: -1 + # Stale issues are closed immediately + days-before-issue-close: 0 + days-before-pr-close: -1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitmodules b/.gitmodules index 6f8702e6..e69de29b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "tests/connector_regression"] - path = tests/connector_regression - url = git@github.com:snowflakedb/snowflake-connector-python diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 83a5c0bd..b7370b74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,31 +1,32 @@ exclude: '^(.*egg.info.*|.*/parameters.py).*$' repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.5.0 hooks: - id: trailing-whitespace + exclude: '\.ambr$' - id: end-of-file-fixer - id: check-yaml exclude: .github/repo_meta.yaml - id: debug-statements - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.13.2 hooks: - id: isort - repo: https://github.com/asottile/pyupgrade - rev: v2.37.3 + rev: v3.15.1 hooks: - id: pyupgrade args: [--py37-plus] - repo: https://github.com/psf/black - rev: 22.6.0 + rev: 24.2.0 hooks: - id: black args: - --safe language_version: python3 - repo: https://github.com/Lucas-C/pre-commit-hooks.git - rev: v1.3.0 + rev: v1.5.5 hooks: - id: insert-license name: insert-py-license @@ -39,8 +40,15 @@ repos: - --license-filepath - license_header.txt - repo: https://github.com/pycqa/flake8 - rev: 5.0.4 + rev: 7.0.0 hooks: - id: flake8 additional_dependencies: - flake8-bugbear +- repo: local + hooks: + - id: requirements-update + name: "Update dependencies from pyproject.toml to snyk/requirements.txt" + language: system + entry: python snyk/update_requirements.py + files: ^pyproject.toml$ diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 3ca790c1..909d52cf 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,60 @@ Source code is also available at: # Release Notes +- (Unreleased) + + - Add support for dynamic tables and required options + - Add support for hybrid tables + - Fixed SAWarning when registering functions with existing name in default namespace + - Update options to be defined as keyword arguments instead of positional arguments. + - Add support for refresh_mode option in DynamicTable + +- v1.6.1(July 9, 2024) + + - Update internal project workflow with pypi publishing + +- v1.6.0(July 8, 2024) + + - Support for installing with SQLAlchemy 2.0.x + - Use `hatch` & `uv` for managing project virtual environments + +- v1.5.4 + + - Add ability to set ORDER / NOORDER sequence on columns with IDENTITY + +- v1.5.3(April 16, 2024) + + - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 + +- v1.5.2(April 11, 2024) + + - Bump min SQLAlchemy to 1.4.19 for outer lateral join + - Add support for sequence ordering in tests + +- v1.5.1(November 03, 2023) + + - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check . + - Fixed credentials with `externalbrowser` authentication not caching due to incorrect parsing of boolean query parameters. + - This fixes other boolean parameter passing to driver as well. + +- v1.5.0(Aug 23, 2023) + + - Added an option to the create stage command for creating temporary stages.
+ - Added support for geometry type. + - Fixed a compatibility issue of regex expression with SQLAlchemy 1.4.49. + +- v1.4.7(Mar 22, 2023) + + - Re-applied the application name of driver connection `SnowflakeConnection` to `SnowflakeSQLAlchemy`. + - `SnowflakeDialect.get_columns` now throws a `NoSuchTableError` exception when the specified table doesn't exist, instead of the more vague `KeyError`. + - Fixed a bug where the dialect could not be created with an empty host name. + - Fixed a bug where `sqlalchemy.func.now` was not rendered correctly. + +- v1.4.6(Feb 8, 2023) + + - Bumped snowflake-connector-python dependency to newest version which supports Python 3.11. + - Reverted the change of application name introduced in v1.4.5 until support gets added. + - v1.4.5(Dec 7, 2022) - Updated the application name of driver connection `SnowflakeConnection` to `SnowflakeSQLAlchemy`. diff --git a/LICENSE.txt b/LICENSE.txt index 960fe780..f024f6ca 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -187,7 +187,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright (c) 2012-2022 Snowflake Computing, Inc. + Copyright (c) 2012-2023 Snowflake Computing, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index 0c75854e..c428353f 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ containing special characters need to be URL encoded to be parsed correctly. These characters could lead to authentication failure. The encoding for the password can be generated using `urllib.parse`: + ```python import urllib.parse urllib.parse.quote("kx@% jj5/g") @@ -111,6 +112,7 @@ urllib.parse.quote("kx@% jj5/g") To create an engine with the proper encodings, either manually construct the url string by formatting or take advantage of the `snowflake.sqlalchemy.URL` helper method: + ```python import urllib.parse from snowflake.sqlalchemy import URL @@ -191,14 +193,23 @@ engine = create_engine(...) engine.execute() engine.dispose() -# Do this. +# Better. engine = create_engine(...)
connection = engine.connect() try: - connection.execute() + connection.execute(text()) finally: connection.close() engine.dispose() + +# Best +try: + with engine.connect() as connection: + connection.execute(text()) + # or + connection.exec_driver_sql() +finally: + engine.dispose() ``` ### Auto-increment Behavior @@ -242,14 +253,14 @@ engine = create_engine(URL( specific_date = np.datetime64('2016-03-04T12:03:05.123456789Z') -connection = engine.connect() -connection.execute( - "CREATE OR REPLACE TABLE ts_tbl(c1 TIMESTAMP_NTZ)") -connection.execute( - "INSERT INTO ts_tbl(c1) values(%s)", (specific_date,) -) -df = pd.read_sql_query("SELECT * FROM ts_tbl", engine) -assert df.c1.values[0] == specific_date +with engine.connect() as connection: + connection.exec_driver_sql( + "CREATE OR REPLACE TABLE ts_tbl(c1 TIMESTAMP_NTZ)") + connection.exec_driver_sql( + "INSERT INTO ts_tbl(c1) values(%s)", (specific_date,) + ) + df = pd.read_sql_query("SELECT * FROM ts_tbl", connection) + assert df.c1.values[0] == specific_date ``` The following `NumPy` data types are supported: diff --git a/ci/build.sh b/ci/build.sh index 4229506d..b63c8e01 100755 --- a/ci/build.sh +++ b/ci/build.sh @@ -3,7 +3,7 @@ # Build snowflake-sqlalchemy set -o pipefail -PYTHON="python3.7" +PYTHON="python3.8" THIS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SQLALCHEMY_DIR="$(dirname "${THIS_DIR}")" DIST_DIR="${SQLALCHEMY_DIR}/dist" @@ -11,14 +11,16 @@ DIST_DIR="${SQLALCHEMY_DIR}/dist" cd "$SQLALCHEMY_DIR" # Clean up previously built DIST_DIR if [ -d "${DIST_DIR}" ]; then - echo "[WARN] ${DIST_DIR} already existing, deleting it..." - rm -rf "${DIST_DIR}" + echo "[WARN] ${DIST_DIR} already existing, deleting it..." + rm -rf "${DIST_DIR}" fi # Constants and setup +export PATH=$PATH:$HOME/.local/bin echo "[Info] Building snowflake-sqlalchemy with $PYTHON" # Clean up possible build artifacts rm -rf build generated_version.py -${PYTHON} -m pip install --upgrade pip setuptools wheel build -${PYTHON} -m build --outdir ${DIST_DIR} . +export UV_NO_CACHE=true +${PYTHON} -m pip install uv hatch +${PYTHON} -m hatch build diff --git a/ci/docker/sqlalchemy_build/Dockerfile b/ci/docker/sqlalchemy_build/Dockerfile index b31c4f50..745f4963 100644 --- a/ci/docker/sqlalchemy_build/Dockerfile +++ b/ci/docker/sqlalchemy_build/Dockerfile @@ -14,6 +14,6 @@ RUN chmod +x /usr/local/bin/entrypoint.sh WORKDIR /home/user RUN chmod 777 /home/user -ENV PATH="${PATH}:/opt/python/cp37-cp37m/bin:/opt/python/cp38-cp38/bin:/opt/python/cp39-cp39/bin:/opt/python/cp310-cp310/bin" +ENV PATH="${PATH}:/opt/python/cp37-cp37m/bin:/opt/python/cp38-cp38/bin:/opt/python/cp39-cp39/bin:/opt/python/cp310-cp310/bin:/opt/python/cp311-cp311/bin" ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] diff --git a/ci/test_linux.sh b/ci/test_linux.sh index 026e5c83..f5afc4fb 100755 --- a/ci/test_linux.sh +++ b/ci/test_linux.sh @@ -2,25 +2,24 @@ # # Test Snowflake SQLAlchemy in Linux # NOTES: -# - Versions to be tested should be passed in as the first argument, e.g: "3.7 3.8". If omitted 3.7-3.10 will be assumed. +# - Versions to be tested should be passed in as the first argument, e.g: "3.8 3.9". If omitted 3.8-3.11 will be assumed.
# - This script assumes that ../dist/repaired_wheels has the wheel(s) built for all versions to be tested # - This is the script that test_docker.sh runs inside of the docker container -PYTHON_VERSIONS="${1:-3.7 3.8 3.9 3.10}" -THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -SQLALCHEMY_DIR="$( dirname "${THIS_DIR}")" +PYTHON_VERSIONS="${1:-3.8 3.9 3.10 3.11}" +THIS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SQLALCHEMY_DIR="$(dirname "${THIS_DIR}")" # Install one copy of tox -python3 -m pip install -U tox-external-wheels -python3 -m pip install --pre tox +python3 -m pip install -U tox # Run tests cd $SQLALCHEMY_DIR for PYTHON_VERSION in ${PYTHON_VERSIONS}; do - echo "[Info] Testing with ${PYTHON_VERSION}" - SHORT_VERSION=$(python3 -c "print('${PYTHON_VERSION}'.replace('.', ''))") - SQLALCHEMY_WHL=$(ls $SQLALCHEMY_DIR/dist/snowflake_sqlalchemy-*-py2.py3-none-any.whl | sort -r | head -n 1) - TEST_ENVLIST=fix_lint,py${SHORT_VERSION}-ci,py${SHORT_VERSION}-coverage - echo "[Info] Running tox for ${TEST_ENVLIST}" - python3 -m tox -e ${TEST_ENVLIST} --external_wheels ${SQLALCHEMY_WHL} + echo "[Info] Testing with ${PYTHON_VERSION}" + SHORT_VERSION=$(python3 -c "print('${PYTHON_VERSION}'.replace('.', ''))") + SQLALCHEMY_WHL=$(ls $SQLALCHEMY_DIR/dist/snowflake_sqlalchemy-*-py3-none-any.whl | sort -r | head -n 1) + TEST_ENVLIST=fix_lint,py${SHORT_VERSION}-ci,py${SHORT_VERSION}-coverage + echo "[Info] Running tox for ${TEST_ENVLIST}" + python3 -m tox -e ${TEST_ENVLIST} --installpkg ${SQLALCHEMY_WHL} done diff --git a/license_header.txt b/license_header.txt index 7359dcd8..eea2d7b3 100644 --- a/license_header.txt +++ b/license_header.txt @@ -1,2 +1,2 @@ -Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
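For reference, the connection-handling pattern recommended in the README change above looks like this end to end. This is a minimal sketch rather than code from this change: the account, user, password, database, and query are placeholder values, and it assumes SQLAlchemy 1.4+, where raw SQL strings must be wrapped in `text()`.

```python
from sqlalchemy import create_engine, text

from snowflake.sqlalchemy import URL

# URL() composes a connection string with the proper encodings, per the
# README section on special characters in passwords.
engine = create_engine(
    URL(
        account="myaccount",    # placeholder
        user="myuser",          # placeholder
        password="kx@% jj5/g",  # placeholder
        database="testdb",
        schema="public",
    )
)

try:
    # The context manager returns the connection to the pool on exit,
    # even if the statement raises.
    with engine.connect() as connection:
        result = connection.execute(text("SELECT CURRENT_VERSION()"))
        print(result.scalar())
finally:
    engine.dispose()
```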
diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..268d1ecd --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,133 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "snowflake-sqlalchemy-decube" +dynamic = ["version"] +description = "Snowflake SQLAlchemy Dialect" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.8" +authors = [ + { name = "Snowflake Inc.", email = "triage-snowpark-python-api-dl@snowflake.com" }, +] +keywords = ["Snowflake", "analytics", "cloud", "database", "db", "warehouse"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Environment :: Other Environment", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Information Technology", + "Intended Audience :: System Administrators", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Scientific/Engineering :: Information Analysis", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Application Frameworks", + "Topic :: Software Development :: Libraries :: Python Modules", +] +dependencies = ["SQLAlchemy>=1.4.19", "snowflake-connector-python<4.0.0"] + +[tool.hatch.version] +path = "src/snowflake/sqlalchemy/version.py" + +[project.optional-dependencies] +development = [ + "pre-commit", + "pytest", + "pytest-cov", + "pytest-timeout", + "pytest-rerunfailures", + "pytz", + "numpy", + "mock", + "syrupy==4.6.1", +] +pandas = ["snowflake-connector-python[pandas]"] + +[project.entry-points."sqlalchemy.dialects"] +snowflake = "snowflake.sqlalchemy:dialect" + +[project.urls] +Changelog = "https://github.com/snowflakedb/snowflake-sqlalchemy/blob/main/DESCRIPTION.md" +Documentation = "https://docs.snowflake.com/en/user-guide/sqlalchemy.html" +Homepage = "https://www.snowflake.com/" +Issues = "https://github.com/snowflakedb/snowflake-sqlalchemy/issues" +Source = "https://github.com/snowflakedb/snowflake-sqlalchemy" + +[tool.hatch.build.targets.sdist] +exclude = ["/.github"] + +[tool.hatch.build.targets.wheel] +packages = ["src/snowflake"] + +[tool.hatch.envs.default] +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.1.0"] +features = ["development", "pandas"] +python = "3.8" +installer = "uv" + +[tool.hatch.envs.sa14] +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.0.0"] +features = ["development", "pandas"] +python = "3.8" + +[tool.hatch.envs.default.env-vars] +COVERAGE_FILE = "coverage.xml" +SQLALCHEMY_WARN_20 = "1" + +[tool.hatch.envs.default.scripts] +check = "pre-commit run --all-files" +test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" +test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" +test-dialect-aws = "pytest -m \"aws\" -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" +gh-cache-sum =
"python -VV | sha256sum | cut -d' ' -f1" +check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'" + +[[tool.hatch.envs.release.matrix]] +python = ["3.8", "3.9", "3.10", "3.11", "3.12"] +features = ["development", "pandas"] + +[tool.hatch.envs.release.scripts] +test-dialect = "pytest -ra -vvv --tb=short --ignore=tests/sqlalchemy_test_suite tests/" +test-compatibility = "pytest -ra -vvv --tb=short tests/sqlalchemy_test_suite tests/" + +[tool.ruff] +line-length = 88 + +[tool.black] +line-length = 88 + +[tool.pytest.ini_options] +addopts = "-m 'not feature_max_lob_size and not aws'" +markers = [ + # Optional dependency groups markers + "lambda: AWS lambda tests", + "pandas: tests for pandas integration", + "sso: tests for sso optional dependency integration", + # Cloud provider markers + "aws: tests for Amazon Cloud storage", + "azure: tests for Azure Cloud storage", + "gcp: tests for Google Cloud storage", + # Test type markers + "integ: integration tests", + "unit: unit tests", + "skipolddriver: skip for old driver tests", + # Other markers + "timeout: tests that need a timeout time", + "internal: tests that could but should only run on our internal CI", + "external: tests that could but should only run on our external CI", + "feature_max_lob_size: tests that could but should only run on our external CI", +] diff --git a/setup.cfg b/setup.cfg index 4f249666..7924cc57 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,80 +1,3 @@ -[bdist_wheel] -universal = 1 - -[metadata] -name = snowflake-sqlalchemy-decube -version = 1.4.8 -description = Snowflake SQLAlchemy Dialect -long_description = file: DESCRIPTION.md -long_description_content_type = text/markdown -url = https://www.snowflake.com/ -author = Snowflake, Inc -author_email = triage-snowpark-python-api-dl@snowflake.com -license = Apache-2.0 -license_files = LICENSE.txt -classifiers = - Development Status :: 5 - Production/Stable - Environment :: Console - Environment :: Other Environment - Intended Audience :: Developers - Intended Audience :: Education - Intended Audience :: Information Technology - Intended Audience :: System Administrators - License :: OSI Approved :: Apache Software License - Operating System :: OS Independent - Programming Language :: Python :: 3 - Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.7 - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Programming Language :: SQL - Topic :: Database - Topic :: Scientific/Engineering :: Information Analysis - Topic :: Software Development - Topic :: Software Development :: Libraries - Topic :: Software Development :: Libraries :: Application Frameworks - Topic :: Software Development :: Libraries :: Python Modules -keywords = Snowflake db database cloud analytics warehouse -project_urls = - Documentation=https://docs.snowflake.com/en/user-guide/sqlalchemy.html - Source=https://github.com/snowflakedb/snowflake-sqlalchemy - Issues=https://github.com/snowflakedb/snowflake-sqlalchemy/issues - Changelog=https://github.com/snowflakedb/snowflake-sqlalchemy/blob/main/DESCRIPTION.md - -[options] -python_requires = >=3.7 -packages = find_namespace: -install_requires = - importlib-metadata;python_version<"3.8" - sqlalchemy<2.0.0,>=1.4.0 -; Keep in sync with extras dependency - snowflake-connector-python<3.0.0 -include_package_data = True -package_dir = - =src -zip_safe = False - -[options.packages.find] -where = src -include = snowflake.* - 
-[options.entry_points] -sqlalchemy.dialects = - snowflake=snowflake.sqlalchemy:dialect - -[options.extras_require] -development = - pytest - pytest-cov - pytest-rerunfailures - pytest-timeout - mock - pytz - numpy -pandas = - snowflake-connector-python[pandas]<3.0.0 - [sqla_testing] requirement_cls=snowflake.sqlalchemy.requirements:Requirements profile_file=tests/profiles.txt diff --git a/setup.py b/setup.py deleted file mode 100644 index 66a8ea8e..00000000 --- a/setup.py +++ /dev/null @@ -1,7 +0,0 @@ -# -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. -# - -from setuptools import setup - -setup() diff --git a/snyk/requirements.txt b/snyk/requirements.txt new file mode 100644 index 00000000..0166d751 --- /dev/null +++ b/snyk/requirements.txt @@ -0,0 +1,2 @@ +SQLAlchemy>=1.4.19 +snowflake-connector-python<4.0.0 diff --git a/snyk/requiremtnts.txt b/snyk/requiremtnts.txt new file mode 100644 index 00000000..a92c527e --- /dev/null +++ b/snyk/requiremtnts.txt @@ -0,0 +1,2 @@ +snowflake-connector-python<4.0.0 +SQLAlchemy>=1.4.19,<2.1.0 diff --git a/snyk/update_requirements.py b/snyk/update_requirements.py new file mode 100644 index 00000000..e0771fbd --- /dev/null +++ b/snyk/update_requirements.py @@ -0,0 +1,17 @@ +from pathlib import Path + +import tomlkit + + +def sync(): + pyproject = tomlkit.loads(Path("pyproject.toml").read_text()) + snyk_requirements = Path("snyk/requirements.txt") + dependencies = pyproject.get("project", {}).get("dependencies", []) + + with snyk_requirements.open("w") as fh: + fh.write("\n".join(dependencies)) + fh.write("\n") + + +if __name__ == "__main__": + sync() diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 2b2f11b8..9de05123 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # import sys @@ -49,6 +49,7 @@ DOUBLE, FIXED, GEOGRAPHY, + GEOMETRY, NUMBER, OBJECT, STRING, @@ -60,6 +61,17 @@ VARBINARY, VARIANT, ) +from .sql.custom_schema import DynamicTable, HybridTable +from .sql.custom_schema.options import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + SnowflakeKeyword, + TableOptionKey, + TargetLagOption, + TimeUnit, +) from .util import _url as URL base.dialect = dialect = snowdialect.dialect @@ -67,6 +79,7 @@ __version__ = importlib_metadata.version("snowflake-sqlalchemy-decube") __all__ = ( + # Custom Types "BIGINT", "BINARY", "BOOLEAN", @@ -90,6 +103,7 @@ "DOUBLE", "FIXED", "GEOGRAPHY", + "GEOMETRY", "OBJECT", "NUMBER", "STRING", @@ -100,6 +114,7 @@ "TINYINT", "VARBINARY", "VARIANT", + # Custom Commands "MergeInto", "CSVFormatter", "JSONFormatter", @@ -111,4 +126,17 @@ "ExternalStage", "CreateStage", "CreateFileFormat", + # Custom Tables + "HybridTable", + "DynamicTable", + # Custom Table Options + "AsQueryOption", + "TargetLagOption", + "LiteralOption", + "IdentifierOption", + "KeywordOption", + # Enums + "TimeUnit", + "TableOptionKey", + "SnowflakeKeyword", ) diff --git a/src/snowflake/sqlalchemy/_constants.py b/src/snowflake/sqlalchemy/_constants.py index 2ad1ca94..839745ee 100644 --- a/src/snowflake/sqlalchemy/_constants.py +++ b/src/snowflake/sqlalchemy/_constants.py @@ -1,7 +1,7 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
# -import pkg_resources +from .version import VERSION # parameters needed for usage tracking PARAM_APPLICATION = "application" @@ -9,6 +9,5 @@ PARAM_INTERNAL_APPLICATION_VERSION = "internal_application_version" APPLICATION_NAME = "SnowflakeSQLAlchemy" -SNOWFLAKE_SQLALCHEMY_VERSION = pkg_resources.get_distribution( - "snowflake-sqlalchemy-decube" -).version +SNOWFLAKE_SQLALCHEMY_VERSION = VERSION +DIALECT_NAME = "snowflake" diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 4456bbea..f46578ae 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -1,20 +1,47 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # +import itertools import operator import re +from typing import List +from sqlalchemy import exc as sa_exc +from sqlalchemy import inspect, sql from sqlalchemy import util as sa_util from sqlalchemy.engine import default +from sqlalchemy.orm import context +from sqlalchemy.orm.context import _MapperEntity from sqlalchemy.schema import Sequence, Table -from sqlalchemy.sql import compiler, expression +from sqlalchemy.sql import compiler, expression, functions +from sqlalchemy.sql.base import CompileState from sqlalchemy.sql.elements import quoted_name +from sqlalchemy.sql.selectable import Lateral, SelectState from sqlalchemy.util import text_type from sqlalchemy.util.compat import string_types -from .custom_commands import AWSBucket, AzureContainer, ExternalStage -from .util import _set_connection_interpolate_empty_sequences +from snowflake.sqlalchemy._constants import DIALECT_NAME +from snowflake.sqlalchemy.compat import IS_VERSION_20, args_reducer, string_types +from snowflake.sqlalchemy.custom_commands import ( + AWSBucket, + AzureContainer, + ExternalStage, +) + +from .exc import ( + CustomOptionsAreOnlySupportedOnSnowflakeTables, + UnexpectedOptionTypeError, +) +from .functions import flatten +from .sql.custom_schema.custom_table_base import CustomTableBase +from .sql.custom_schema.options.table_option import TableOption +from .util import ( + _find_left_clause_to_join_from, + _set_connection_interpolate_empty_sequences, + _Snowflake_ORMJoin, + _Snowflake_Selectable_Join, +) RESERVED_WORDS = frozenset( [ @@ -89,6 +116,327 @@ ) +""" +Overwrite methods to handle Snowflake BCR change: +https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 +- _join_determine_implicit_left_side +- _join_left_to_right +""" + + +# handle Snowflake BCR bcr-1057 +@CompileState.plugin_for("default", "select") +class SnowflakeSelectState(SelectState): + def _setup_joins(self, args, raw_columns): + for right, onclause, left, flags in args: + isouter = flags["isouter"] + full = flags["full"] + + if left is None: + ( + left, + replace_from_obj_index, + ) = self._join_determine_implicit_left_side( + raw_columns, left, right, onclause + ) + else: + (replace_from_obj_index) = self._join_place_explicit_left_side(left) + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + ( + _Snowflake_Selectable_Join( # handle Snowflake BCR bcr-1057 + left_clause, + right, + onclause, + isouter=isouter, + full=full, + ), + ) + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + self.from_clauses = self.from_clauses + ( + # handle Snowflake BCR 
bcr-1057 + _Snowflake_Selectable_Join( + left, right, onclause, isouter=isouter, full=full + ), + ) + + @sa_util.preload_module("sqlalchemy.sql.util") + def _join_determine_implicit_left_side(self, raw_columns, left, right, onclause): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + replace_from_obj_index = None + + from_clauses = self.from_clauses + + if from_clauses: + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from(from_clauses, right, onclause) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = from_clauses[replace_from_obj_index] + else: + potential = {} + statement = self.statement + + for from_clause in itertools.chain( + itertools.chain.from_iterable( + [element._from_objects for element in raw_columns] + ), + itertools.chain.from_iterable( + [element._from_objects for element in statement._where_criteria] + ), + ): + potential[from_clause] = () + + all_clauses = list(potential.keys()) + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from(all_clauses, right, onclause) + + if len(indexes) == 1: + left = all_clauses[indexes[0]] + + if len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already to " + "help resolve the ambiguity." + ) + elif not indexes: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already to " + "help resolve the ambiguity." % (right,) + ) + return left, replace_from_obj_index + + +# handle Snowflake BCR bcr-1057 +@sql.base.CompileState.plugin_for("orm", "select") +class SnowflakeORMSelectCompileState(context.ORMSelectCompileState): + def _join_determine_implicit_left_side( + self, entities_collection, left, right, onclause + ): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + # when we are here, it means join() was called without an ORM- + # specific way of telling us what the "left" side is, e.g.: + # + # join(RightEntity) + # + # or + # + # join(RightEntity, RightEntity.foo == LeftEntity.bar) + # + + r_info = inspect(right) + + replace_from_obj_index = use_entity_index = None + + if self.from_clauses: + # we have a list of FROMs already. So by definition this + # join has to connect to one of those FROMs. + + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from( + self.from_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = self.from_clauses[replace_from_obj_index] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. 
" + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." % (right,) + ) + + elif entities_collection: + # we have no explicit FROMs, so the implicit left has to + # come from our list of entities. + + potential = {} + for entity_index, ent in enumerate(entities_collection): + entity = ent.entity_zero_or_selectable + if entity is None: + continue + ent_info = inspect(entity) + if ent_info is r_info: # left and right are the same, skip + continue + + # by using a dictionary with the selectables as keys this + # de-duplicates those selectables as occurs when the query is + # against a series of columns from the same selectable + if isinstance(ent, context._MapperEntity): + potential[ent.selectable] = (entity_index, entity) + else: + potential[ent_info.selectable] = (None, entity) + + all_clauses = list(potential.keys()) + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from( + all_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + use_entity_index, left = potential[all_clauses[indexes[0]]] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." % (right,) + ) + else: + raise sa_exc.InvalidRequestError( + "No entities to join from; please use " + "select_from() to establish the left " + "entity/selectable of this join" + ) + + return left, replace_from_obj_index, use_entity_index + + @args_reducer(positions_to_drop=(6, 7)) + def _join_left_to_right( + self, entities_collection, left, right, onclause, prop, outerjoin, full + ): + """given raw "left", "right", "onclause" parameters consumed from + a particular key within _join(), add a real ORMJoin object to + our _from_obj list (or augment an existing one) + + """ + + if left is None: + # left not given (e.g. no relationship object/name specified) + # figure out the best "left" side based on our existing froms / + # entities + assert prop is None + ( + left, + replace_from_obj_index, + use_entity_index, + ) = self._join_determine_implicit_left_side( + entities_collection, left, right, onclause + ) + else: + # left is given via a relationship/name, or as explicit left side. + # Determine where in our + # "froms" list it should be spliced/appended as well as what + # existing entity it corresponds to. + ( + replace_from_obj_index, + use_entity_index, + ) = self._join_place_explicit_left_side(entities_collection, left) + + if left is right: + raise sa_exc.InvalidRequestError( + "Can't construct a join from %s to %s, they " + "are the same entity" % (left, right) + ) + + # the right side as given often needs to be adapted. additionally + # a lot of things can be wrong with it. 
handle all that and + # get back the new effective "right" side + + if IS_VERSION_20: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop + ) + else: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop, False, False + ) + + if not r_info.is_selectable: + extra_criteria = self._get_extra_criteria(r_info) + else: + extra_criteria = () + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + [ + _Snowflake_ORMJoin( # handle Snowflake BCR bcr-1057 + left_clause, + right, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, + ) + ] + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + # add a new element to the self._from_obj list + if use_entity_index is not None: + # make use of _MapperEntity selectable, which is usually + # entity_zero.selectable, but if with_polymorphic() were used + # might be distinct + assert isinstance(entities_collection[use_entity_index], _MapperEntity) + left_clause = entities_collection[use_entity_index].selectable + else: + left_clause = left + + self.from_clauses = self.from_clauses + [ + _Snowflake_ORMJoin( # handle Snowflake BCR bcr-1057 + left_clause, + r_info, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, + ) + ] + + class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = {x.lower() for x in RESERVED_WORDS} @@ -165,6 +513,9 @@ class SnowflakeCompiler(compiler.SQLCompiler): def visit_sequence(self, sequence, **kw): return self.dialect.identifier_preparer.format_sequence(sequence) + ".nextval" + def visit_now_func(self, now, **kw): + return "CURRENT_TIMESTAMP" + def visit_merge_into(self, merge_into, **kw): clauses = " ".join( clause._compiler_dispatch(self, **kw) for clause in merge_into.clauses @@ -251,9 +602,11 @@ def visit_copy_into(self, copy_into, **kw): [ "{} = {}".format( n, - v._compiler_dispatch(self, **kw) - if getattr(v, "compiler_dispatch", False) - else str(v), + ( + v._compiler_dispatch(self, **kw) + if getattr(v, "compiler_dispatch", False) + else str(v) + ), ) for n, v in options_list ] @@ -276,20 +629,24 @@ def visit_copy_formatter(self, formatter, **kw): return f"FILE_FORMAT=(format_name = {formatter.options['format_name']})" return "FILE_FORMAT=(TYPE={}{})".format( formatter.file_format, - " " - + " ".join( - [ - "{}={}".format( - name, - value._compiler_dispatch(self, **kw) - if hasattr(value, "_compiler_dispatch") - else formatter.value_repr(name, value), - ) - for name, value in options_list - ] - ) - if formatter.options - else "", + ( + " " + + " ".join( + [ + "{}={}".format( + name, + ( + value._compiler_dispatch(self, **kw) + if hasattr(value, "_compiler_dispatch") + else formatter.value_repr(name, value) + ), + ) + for name, value in options_list + ] + ) + if formatter.options + else "" + ), ) def visit_aws_bucket(self, aws_bucket, **kw): @@ -384,7 +741,14 @@ def visit_regexp_match_op_binary(self, binary, operator, **kw): def visit_regexp_replace_op_binary(self, binary, operator, **kw): string, pattern, flags = self._get_regexp_args(binary, kw) - replacement = self.process(binary.modifiers["replacement"], **kw) + try: + replacement = self.process(binary.modifiers["replacement"], **kw) + except KeyError: + # in sqlalchemy 1.4.49, the internal structure of the 
expression is changed + # that binary.modifiers doesn't have "replacement": + # https://docs.sqlalchemy.org/en/20/changelog/changelog_14.html#change-1.4.49 + return f"REGEXP_REPLACE({string}, {pattern}{'' if flags is None else f', {flags}'})" + if flags is None: return f"REGEXP_REPLACE({string}, {pattern}, {replacement})" else: @@ -393,6 +757,45 @@ def visit_regexp_replace_op_binary(self, binary, operator, **kw): def visit_not_regexp_match_op_binary(self, binary, operator, **kw): return f"NOT {self.visit_regexp_match_op_binary(binary, operator, **kw)}" + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.update( + itertools.product(join.left._from_objects, join.right._from_objects) + ) + + if join.full: + join_type = " FULL OUTER JOIN " + elif join.isouter: + join_type = " LEFT OUTER JOIN " + else: + join_type = " JOIN " + + join_statement = ( + join.left._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + + join_type + + join.right._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + ) + + if join.onclause is None and isinstance(join.right, Lateral): + # in snowflake, onclause is not accepted for lateral due to BCR change: + # https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 + # sqlalchemy only allows join with on condition. + # to adapt to snowflake syntax change, + # we make the change such that when oncaluse is None and the right part is + # Lateral, we do not append the on condition + return join_statement + + return ( + join_statement + + " ON " + # TODO: likely need asfrom=True here? + + join.onclause._compiler_dispatch(self, from_linter=from_linter, **kwargs) + ) + def render_literal_value(self, value, type_): # escape backslash return super().render_literal_value(value, type_).replace("\\", "\\\\") @@ -500,7 +903,7 @@ def get_column_specification(self, column, **kwargs): return " ".join(colspec) - def post_create_table(self, table): + def handle_cluster_by(self, table): """ Handles snowflake-specific ``CREATE TABLE ... CLUSTER BY`` syntax. @@ -530,7 +933,7 @@ def post_create_table(self, table): """ text = "" - info = table.dialect_options["snowflake"] + info = table.dialect_options[DIALECT_NAME] cluster = info.get("clusterby") if cluster: text += " CLUSTER BY ({})".format( @@ -538,15 +941,39 @@ def post_create_table(self, table): ) return text + def post_create_table(self, table): + text = self.handle_cluster_by(table) + options = [] + invalid_options: List[str] = [] + + for key, option in table.dialect_options[DIALECT_NAME].items(): + if isinstance(option, TableOption): + options.append(option) + elif key not in ["clusterby", "*"]: + invalid_options.append(key) + + if len(invalid_options) > 0: + raise UnexpectedOptionTypeError(sorted(invalid_options)) + + if isinstance(table, CustomTableBase): + options.sort(key=lambda x: (x.priority.value, x.option_name), reverse=True) + for option in options: + text += "\t" + option.render_option(self) + elif len(options) > 0: + raise CustomOptionsAreOnlySupportedOnSnowflakeTables() + + return text + def visit_create_stage(self, create_stage, **kw): """ This visitor will create the SQL representation for a CREATE STAGE command. 
""" - return "CREATE {}STAGE {}{} URL={}".format( - "OR REPLACE " if create_stage.replace_if_exists else "", + return "CREATE {or_replace}{temporary}STAGE {}{} URL={}".format( create_stage.stage.namespace, create_stage.stage.name, repr(create_stage.container), + or_replace="OR REPLACE " if create_stage.replace_if_exists else "", + temporary="TEMPORARY " if create_stage.temporary else "", ) def visit_create_file_format(self, file_format, **kw): @@ -585,13 +1012,38 @@ def visit_drop_column_comment(self, drop, **kw): ) def visit_identity_column(self, identity, **kw): - text = " IDENTITY" + text = "IDENTITY" if identity.start is not None or identity.increment is not None: start = 1 if identity.start is None else identity.start increment = 1 if identity.increment is None else identity.increment text += f"({start},{increment})" + if identity.order is not None: + order = "ORDER" if identity.order else "NOORDER" + text += f" {order}" return text + def get_identity_options(self, identity_options): + text = [] + if identity_options.increment is not None: + text.append("INCREMENT BY %d" % identity_options.increment) + if identity_options.start is not None: + text.append("START WITH %d" % identity_options.start) + if identity_options.minvalue is not None: + text.append("MINVALUE %d" % identity_options.minvalue) + if identity_options.maxvalue is not None: + text.append("MAXVALUE %d" % identity_options.maxvalue) + if identity_options.nominvalue is not None: + text.append("NO MINVALUE") + if identity_options.nomaxvalue is not None: + text.append("NO MAXVALUE") + if identity_options.cache is not None: + text.append("CACHE %d" % identity_options.cache) + if identity_options.cycle is not None: + text.append("CYCLE" if identity_options.cycle else "NO CYCLE") + if identity_options.order is not None: + text.append("ORDER" if identity_options.order else "NOORDER") + return " ".join(text) + class SnowflakeTypeCompiler(compiler.GenericTypeCompiler): def visit_BYTEINT(self, type_, **kw): @@ -654,5 +1106,10 @@ def visit_TIMESTAMP(self, type_, **kw): def visit_GEOGRAPHY(self, type_, **kw): return "GEOGRAPHY" + def visit_GEOMETRY(self, type_, **kw): + return "GEOMETRY" + construct_arguments = [(Table, {"clusterby": None})] + +functions.register_function("flatten", flatten, "snowflake") diff --git a/src/snowflake/sqlalchemy/compat.py b/src/snowflake/sqlalchemy/compat.py new file mode 100644 index 00000000..9e97e574 --- /dev/null +++ b/src/snowflake/sqlalchemy/compat.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +from __future__ import annotations + +import functools +from typing import Callable + +from sqlalchemy import __version__ as SA_VERSION +from sqlalchemy import util + +string_types = (str,) +returns_unicode = util.symbol("RETURNS_UNICODE") + +IS_VERSION_20 = tuple(int(v) for v in SA_VERSION.split(".")) >= (2, 0, 0) + + +def args_reducer(positions_to_drop: tuple): + """Removes args at positions provided in tuple positions_to_drop. + + For example tuple (3, 5) will remove items at third and fifth position. + Keep in mind that on class methods first postion is cls or self. 
+ """ + + def fn_wrapper(fn: Callable): + @functools.wraps(fn) + def wrapper(*args): + reduced_args = args + if not IS_VERSION_20: + reduced_args = tuple( + arg for idx, arg in enumerate(args) if idx not in positions_to_drop + ) + fn(*reduced_args) + + return wrapper + + return fn_wrapper diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index 51ea09f2..15585bd5 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # from collections.abc import Sequence @@ -10,7 +10,8 @@ from sqlalchemy.sql.dml import UpdateBase from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.roles import FromClauseRole -from sqlalchemy.util.compat import string_types + +from .compat import string_types NoneType = type(None) @@ -259,7 +260,8 @@ def field_delimiter(self, deli_type): def file_extension(self, ext): """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is - responsible for specifying a valid file extension that can be read by the desired software or service.""" + responsible for specifying a valid file extension that can be read by the desired software or service. + """ if not isinstance(ext, (NoneType, string_types)): raise TypeError("File extension should be a string") self.options["FILE_EXTENSION"] = ext @@ -386,7 +388,8 @@ def compression(self, comp_type): def file_extension(self, ext): """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is - responsible for specifying a valid file extension that can be read by the desired software or service.""" + responsible for specifying a valid file extension that can be read by the desired software or service. + """ if not isinstance(ext, (NoneType, string_types)): raise TypeError("File extension should be a string") self.options["FILE_EXTENSION"] = ext @@ -482,9 +485,10 @@ class CreateStage(DDLElement): __visit_name__ = "create_stage" - def __init__(self, container, stage, replace_if_exists=False): + def __init__(self, container, stage, replace_if_exists=False, *, temporary=False): super().__init__() self.container = container + self.temporary = temporary self.stage = stage self.replace_if_exists = replace_if_exists diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py index bf3cc97b..802d1ce1 100644 --- a/src/snowflake/sqlalchemy/custom_types.py +++ b/src/snowflake/sqlalchemy/custom_types.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # import sqlalchemy.types as sqltypes @@ -61,6 +61,10 @@ class GEOGRAPHY(SnowflakeType): __visit_name__ = "GEOGRAPHY" +class GEOMETRY(SnowflakeType): + __visit_name__ = "GEOMETRY" + + class _CUSTOM_Date(SnowflakeType, sqltypes.Date): def literal_processor(self, dialect): def process(value): diff --git a/src/snowflake/sqlalchemy/exc.py b/src/snowflake/sqlalchemy/exc.py new file mode 100644 index 00000000..898de279 --- /dev/null +++ b/src/snowflake/sqlalchemy/exc.py @@ -0,0 +1,74 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
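A minimal sketch of how the args_reducer decorator defined in compat.py above behaves; the handle function and its argument values are hypothetical:

    from snowflake.sqlalchemy.compat import IS_VERSION_20, args_reducer

    # The decorated function keeps the SQLAlchemy 2.0 signature; under 1.4 the
    # wrapper drops the legacy arguments at the given positions before
    # delegating. Note the wrapper does not propagate fn's return value.
    @args_reducer(positions_to_drop=(2, 3))
    def handle(a, b, c, d):
        print(a, b, c, d)

    if IS_VERSION_20:
        handle(1, 2, 3, 4)  # passed through unchanged
    else:
        handle(1, 2, "legacy", "legacy", 3, 4)  # items at indexes 2 and 3 are dropped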
+
+from typing import List
+
+from sqlalchemy.exc import ArgumentError
+
+
+class NoPrimaryKeyError(ArgumentError):
+    def __init__(self, target: str):
+        super().__init__(f"Table {target} requires a primary key.")
+
+
+class UnsupportedPrimaryKeysAndForeignKeysError(ArgumentError):
+    def __init__(self, target: str):
+        super().__init__(f"Primary key and foreign keys are not supported in {target}.")
+
+
+class RequiredParametersNotProvidedError(ArgumentError):
+    def __init__(self, target: str, parameters: List[str]):
+        super().__init__(
+            f"{target} requires the following parameters: {', '.join(parameters)}."
+        )
+
+
+class UnexpectedTableOptionKeyError(ArgumentError):
+    def __init__(self, expected: str, actual: str):
+        super().__init__(f"Expected table option {expected} but got {actual}.")
+
+
+class OptionKeyNotProvidedError(ArgumentError):
+    def __init__(self, target: str):
+        super().__init__(
+            f"Expected option key in {target} option but got NoneType instead."
+        )
+
+
+class UnexpectedOptionParameterTypeError(ArgumentError):
+    def __init__(self, parameter_name: str, target: str, types: List[str]):
+        super().__init__(
+            f"Parameter {parameter_name} of {target} must be one"
+            f" of the following types: {', '.join(types)}."
+        )
+
+
+class CustomOptionsAreOnlySupportedOnSnowflakeTables(ArgumentError):
+    def __init__(self):
+        super().__init__(
+            "Identifier, Literal, TargetLag and other custom options are only supported on Snowflake tables."
+        )
+
+
+class UnexpectedOptionTypeError(ArgumentError):
+    def __init__(self, options: List[str]):
+        super().__init__(
+            f"The following options are either unsupported or should be defined using a Snowflake table: {', '.join(options)}."
+        )
+
+
+class InvalidTableParameterTypeError(ArgumentError):
+    def __init__(self, name: str, input_type: str, expected_types: List[str]):
+        expected_types_str = "', '".join(expected_types)
+        super().__init__(
+            f"Invalid parameter type '{input_type}' provided for '{name}'. "
+            f"Expected one of the following types: '{expected_types_str}'.\n"
+        )
+
+
+class MultipleErrors(ArgumentError):
+    def __init__(self, errors):
+        self.errors = errors
+
+    def __str__(self):
+        return "".join(str(e) for e in self.errors)
diff --git a/src/snowflake/sqlalchemy/functions.py b/src/snowflake/sqlalchemy/functions.py
new file mode 100644
index 00000000..c08aa734
--- /dev/null
+++ b/src/snowflake/sqlalchemy/functions.py
@@ -0,0 +1,16 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+
+import warnings
+
+from sqlalchemy.sql import functions as sqlfunc
+
+FLATTEN_WARNING = "For backward compatibility params are not rendered."
+
+
+class flatten(sqlfunc.GenericFunction):
+    name = "flatten"
+
+    def __init__(self, *args, **kwargs):
+        warnings.warn(FLATTEN_WARNING, DeprecationWarning, stacklevel=2)
+        super().__init__(*args, **kwargs)
diff --git a/src/snowflake/sqlalchemy/provision.py b/src/snowflake/sqlalchemy/provision.py
index 6092e563..2c8368fd 100644
--- a/src/snowflake/sqlalchemy/provision.py
+++ b/src/snowflake/sqlalchemy/provision.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
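An illustrative sketch (assuming the flatten registration at the end of base.py above has run) showing that instantiating the generic function now emits the deprecation warning:

    import warnings

    import snowflake.sqlalchemy  # noqa: F401 -- importing the dialect defines/registers flatten
    from sqlalchemy import func

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        expr = func.flatten()  # GenericFunction subclasses resolve by name
        assert any(issubclass(w.category, DeprecationWarning) for w in caught)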
# from sqlalchemy.testing.provision import set_default_schema_on_connection diff --git a/src/snowflake/sqlalchemy/requirements.py b/src/snowflake/sqlalchemy/requirements.py index f9b57ea6..f2844804 100644 --- a/src/snowflake/sqlalchemy/requirements.py +++ b/src/snowflake/sqlalchemy/requirements.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # from sqlalchemy.testing import exclusions @@ -289,9 +289,25 @@ def datetime_implicit_bound(self): # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. return exclusions.closed() + @property + def date_implicit_bound(self): + # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding + # parameters in string forms of timestamp values. + return exclusions.closed() + + @property + def time_implicit_bound(self): + # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding + # parameters in string forms of timestamp values. + return exclusions.closed() + @property def timestamp_microseconds_implicit_bound(self): # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding # parameters in string forms of timestamp values. # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. return exclusions.closed() + + @property + def array_type(self): + return exclusions.closed() diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 5003f286..f2fb9b18 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -1,21 +1,21 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
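For context, a hedged sketch of the URL handling that the snowdialect.py changes below introduce: query parameters arrive as strings and are cast to the types declared in the connector's DEFAULT_CONFIGURATION (the account and credentials here are placeholders):

    from sqlalchemy import create_engine

    engine = create_engine(
        "snowflake://user:password@my_account/testdb/public"
        "?warehouse=testwh&validate_default_parameters=True&login_timeout=30"
    )
    # "True" is parsed via parse_url_boolean and "30" via parse_url_integer
    # before the options are handed to snowflake.connector.connect().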
# import operator from collections import defaultdict from functools import reduce +from typing import Any from urllib.parse import unquote_plus import sqlalchemy.types as sqltypes from sqlalchemy import event as sa_vnt from sqlalchemy import exc as sa_exc from sqlalchemy import util as sa_util -from sqlalchemy.engine import default, reflection +from sqlalchemy.engine import URL, default, reflection from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.sql.sqltypes import String from sqlalchemy.types import ( BIGINT, BINARY, @@ -38,8 +38,11 @@ ) from snowflake.connector import errors as sf_errors +from snowflake.connector.connection import DEFAULT_CONFIGURATION from snowflake.connector.constants import UTF8 +from snowflake.sqlalchemy.compat import returns_unicode +from ._constants import DIALECT_NAME from .base import ( SnowflakeCompiler, SnowflakeDDLCompiler, @@ -51,6 +54,7 @@ _CUSTOM_DECIMAL, ARRAY, GEOGRAPHY, + GEOMETRY, OBJECT, TIMESTAMP_LTZ, TIMESTAMP_NTZ, @@ -61,6 +65,12 @@ _CUSTOM_Float, _CUSTOM_Time, ) +from .sql.custom_schema.custom_table_prefix import CustomTablePrefix +from .util import ( + _update_connection_application_name, + parse_url_boolean, + parse_url_integer, +) colspecs = { Date: _CUSTOM_Date, @@ -104,11 +114,14 @@ "OBJECT": OBJECT, "ARRAY": ARRAY, "GEOGRAPHY": GEOGRAPHY, + "GEOMETRY": GEOMETRY, } +_ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True + class SnowflakeDialect(default.DefaultDialect): - name = "snowflake" + name = DIALECT_NAME driver = "snowflake" max_identifier_length = 255 cte_follows_insert = True @@ -128,7 +141,7 @@ class SnowflakeDialect(default.DefaultDialect): # unicode strings supports_unicode_statements = True supports_unicode_binds = True - returns_unicode_strings = String.RETURNS_UNICODE + returns_unicode_strings = returns_unicode description_encoding = None # No lastrowid support. See SNOW-11155 @@ -189,11 +202,35 @@ class SnowflakeDialect(default.DefaultDialect): @classmethod def dbapi(cls): + return cls.import_dbapi() + + @classmethod + def import_dbapi(cls): from snowflake import connector return connector - def create_connect_args(self, url): + @staticmethod + def parse_query_param_type(name: str, value: Any) -> Any: + """Cast param value if possible to type defined in connector-python.""" + if not (maybe_type_configuration := DEFAULT_CONFIGURATION.get(name)): + return value + + _, expected_type = maybe_type_configuration + if not isinstance(expected_type, tuple): + expected_type = (expected_type,) + + if isinstance(value, expected_type): + return value + + elif bool in expected_type: + return parse_url_boolean(value) + elif int in expected_type: + return parse_url_integer(value) + else: + return value + + def create_connect_args(self, url: URL): opts = url.translate_connect_args(username="user") if "database" in opts: name_spaces = [unquote_plus(e) for e in opts["database"].split("/")] @@ -206,7 +243,11 @@ def create_connect_args(self, url): raise sa_exc.ArgumentError( f"Invalid name space is specified: {opts['database']}" ) - if ".snowflakecomputing.com" not in opts["host"] and not opts.get("port"): + if ( + "host" in opts + and ".snowflakecomputing.com" not in opts["host"] + and not opts.get("port") + ): opts["account"] = opts["host"] if "." 
in opts["account"]: # remove region subdomain @@ -216,26 +257,34 @@ def create_connect_args(self, url): opts["host"] = opts["host"] + ".snowflakecomputing.com" opts["port"] = "443" opts["autocommit"] = False # autocommit is disabled by default - opts.update(url.query) + + query = dict(**url.query) # make mutable + cache_column_metadata = query.pop("cache_column_metadata", None) self._cache_column_metadata = ( - opts.get("cache_column_metadata", "false").lower() == "true" + parse_url_boolean(cache_column_metadata) if cache_column_metadata else False ) + + # URL sets the query parameter values as strings, we need to cast to expected types when necessary + for name, value in query.items(): + opts[name] = self.parse_query_param_type(name, value) + return ([], opts) - def has_table(self, connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): """ Checks if the table exists """ return self._has_object(connection, "TABLE", table_name, schema) - def has_sequence(self, connection, sequence_name, schema=None): + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): """ Checks if the sequence exists """ return self._has_object(connection, "SEQUENCE", sequence_name, schema) def _has_object(self, connection, object_type, object_name, schema=None): - full_name = self._denormalize_quote_join(schema, object_name) try: results = connection.execute( @@ -304,14 +353,6 @@ def _map_name_to_idx(result): name_to_idx[col[0]] = idx return name_to_idx - @reflection.cache - def get_indexes(self, connection, table_name, schema=None, **kw): - """ - Gets all indexes - """ - # no index is supported by Snowflake - return [] - @reflection.cache def get_check_constraints(self, connection, table_name, schema, **kw): # check constraints are not supported by Snowflake @@ -555,11 +596,13 @@ def _get_schema_columns(self, connection, schema, **kw): "autoincrement": is_identity == "YES", "comment": comment, "primary_key": ( - column_name - in schema_primary_keys[table_name]["constrained_columns"] - ) - if current_table_pks - else False, + ( + column_name + in schema_primary_keys[table_name]["constrained_columns"] + ) + if current_table_pks + else False + ), } ) if is_identity == "YES": @@ -648,13 +691,19 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw): "autoincrement": is_identity == "YES", "comment": comment if comment != "" else None, "primary_key": ( - column_name - in schema_primary_keys[table_name]["constrained_columns"] - ) - if current_table_pks - else False, + ( + column_name + in schema_primary_keys[table_name]["constrained_columns"] + ) + if current_table_pks + else False + ), } ) + + # If we didn't find any columns for the table, the table doesn't exist. 
+        if len(ans) == 0:
+            raise sa_exc.NoSuchTableError()
         return ans
 
     def get_columns(self, connection, table_name, schema=None, **kw):
@@ -669,7 +718,10 @@ def get_columns(self, connection, table_name, schema=None, **kw):
         if schema_columns is None:
             # Too many results, fall back to only query about single table
             return self._get_table_columns(connection, table_name, schema, **kw)
-        return schema_columns[self.normalize_name(table_name)]
+        normalized_table_name = self.normalize_name(table_name)
+        if normalized_table_name not in schema_columns:
+            raise sa_exc.NoSuchTableError()
+        return schema_columns[normalized_table_name]
 
     @reflection.cache
     def get_table_names(self, connection, schema=None, **kw):
@@ -829,16 +881,159 @@ def get_table_comment(self, connection, table_name, schema=None, **kw):
             result = self._get_view_comment(connection, table_name, schema)
 
         return {
-            "text": result._mapping["comment"]
-            if result and result._mapping["comment"]
-            else None
+            "text": (
+                result._mapping["comment"]
+                if result and result._mapping["comment"]
+                else None
+            )
         }
 
+    def get_multi_indexes(
+        self,
+        connection,
+        *,
+        schema,
+        filter_names,
+        **kw,
+    ):
+        """
+        Gets the indexes definition
+        """
+
+        table_prefixes = self.get_multi_prefixes(
+            connection, schema, filter_prefix=CustomTablePrefix.HYBRID.name
+        )
+        if len(table_prefixes) == 0:
+            return []
+        schema = schema or self.default_schema_name
+        if not schema:
+            result = connection.execute(
+                text("SHOW /* sqlalchemy:get_multi_indexes */ INDEXES")
+            )
+        else:
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}"
+                )
+            )
+
+        n2i = self.__class__._map_name_to_idx(result)
+        indexes = {}
+
+        for row in result.cursor.fetchall():
+            table = self.normalize_name(str(row[n2i["table"]]))
+            if (
+                row[n2i["name"]] == f'SYS_INDEX_{row[n2i["table"]]}_PRIMARY'
+                or table not in filter_names
+                or (schema, table) not in table_prefixes
+                or (
+                    (schema, table) in table_prefixes
+                    and CustomTablePrefix.HYBRID.name
+                    not in table_prefixes[(schema, table)]
+                )
+            ):
+                continue
+            index = {
+                "name": row[n2i["name"]],
+                "unique": row[n2i["is_unique"]] == "Y",
+                "column_names": row[n2i["columns"]],
+                "include_columns": row[n2i["included_columns"]],
+                "dialect_options": {},
+            }
+            if (schema, table) in indexes:
+                # list.append mutates in place and returns None, so don't
+                # assign its result back to the dictionary entry
+                indexes[(schema, table)].append(index)
+            else:
+                indexes[(schema, table)] = [index]
+
+        return list(indexes.items())
+
+    def _value_or_default(self, data, table, schema):
+        table = self.normalize_name(str(table))
+        dic_data = dict(data)
+        if (schema, table) in dic_data:
+            return dic_data[(schema, table)]
+        else:
+            return []
+
+    def get_prefixes_from_data(self, n2i, row, **kw):
+        prefixes_found = []
+        for valid_prefix in CustomTablePrefix:
+            key = f"is_{valid_prefix.name.lower()}"
+            if key in n2i and row[n2i[key]] == "Y":
+                prefixes_found.append(valid_prefix.name)
+        return prefixes_found
+
+    @reflection.cache
+    def get_multi_prefixes(
+        self, connection, schema, table_name=None, filter_prefix=None, **kw
+    ):
+        """
+        Gets all table prefixes
+        """
+        schema = schema or self.default_schema_name
+        filter = f"LIKE '{table_name}'" if table_name else ""
+        if schema:
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES IN SCHEMA {schema}"
+                )
+            )
+        else:
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES LIKE '{table_name}'"
+                )
+            )
+
+        n2i = self.__class__._map_name_to_idx(result)
+        tables_prefixes = {}
+        for row in result.cursor.fetchall():
+            table = self.normalize_name(str(row[n2i["name"]]))
+            table_prefixes = self.get_prefixes_from_data(n2i, row)
+            if filter_prefix and filter_prefix not in table_prefixes:
+                continue
+            if (schema, table) in tables_prefixes:
+                tables_prefixes[(schema, table)].append(table_prefixes)
+            else:
+                tables_prefixes[(schema, table)] = table_prefixes
+
+        return tables_prefixes
+
+    @reflection.cache
+    def get_indexes(self, connection, tablename, schema, **kw):
+        """
+        Gets the indexes definition
+        """
+        table_name = self.normalize_name(str(tablename))
+        data = self.get_multi_indexes(
+            connection=connection, schema=schema, filter_names=[table_name], **kw
+        )
+
+        return self._value_or_default(data, table_name, schema)
+
+    def connect(self, *cargs, **cparams):
+        return (
+            super().connect(
+                *cargs,
+                **(
+                    _update_connection_application_name(**cparams)
+                    if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME
+                    else cparams
+                ),
+            )
+            if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME
+            else super().connect(*cargs, **cparams)
+        )
+
 
 @sa_vnt.listens_for(Table, "before_create")
 def check_table(table, connection, _ddl_runner, **kw):
+    from .sql.custom_schema.hybrid_table import HybridTable
+
+    if HybridTable.is_equal_type(table):  # noqa
+        return True
     if isinstance(_ddl_runner.dialect, SnowflakeDialect) and table.indexes:
-        raise NotImplementedError("Snowflake does not support indexes")
+        raise NotImplementedError("Only Snowflake Hybrid Tables support indexes")
 
 
 dialect = SnowflakeDialect
diff --git a/src/snowflake/sqlalchemy/sql/__init__.py b/src/snowflake/sqlalchemy/sql/__init__.py
new file mode 100644
index 00000000..ef416f64
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/__init__.py
@@ -0,0 +1,3 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py
new file mode 100644
index 00000000..66b9270f
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py
@@ -0,0 +1,7 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from .dynamic_table import DynamicTable
+from .hybrid_table import HybridTable
+
+__all__ = ["DynamicTable", "HybridTable"]
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py
new file mode 100644
index 00000000..671c6957
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py
@@ -0,0 +1,111 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
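A sketch of the behavior change in check_table above (table and column names are illustrative): indexes are now rejected only for non-hybrid tables:

    from sqlalchemy import Column, Index, Integer, MetaData, String, Table
    from snowflake.sqlalchemy import HybridTable

    metadata = MetaData()

    HybridTable(
        "orders",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("customer", String),
        Index("ix_orders_customer", "customer"),  # accepted: hybrid table
    )

    plain = Table("plain_orders", metadata, Column("id", Integer, primary_key=True))
    Index("ix_plain_orders_id", plain.c.id)
    # metadata.create_all(engine) would raise NotImplementedError for
    # "plain_orders" ("Only Snowflake Hybrid Tables support indexes"),
    # while "orders" is created together with its index.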
+# +import typing +from typing import Any, List + +from sqlalchemy.sql.schema import MetaData, SchemaItem, Table + +from ..._constants import DIALECT_NAME +from ...compat import IS_VERSION_20 +from ...custom_commands import NoneType +from ...exc import ( + MultipleErrors, + NoPrimaryKeyError, + RequiredParametersNotProvidedError, + UnsupportedPrimaryKeysAndForeignKeysError, +) +from .custom_table_prefix import CustomTablePrefix +from .options.invalid_table_option import InvalidTableOption +from .options.table_option import TableOption, TableOptionKey + + +class CustomTableBase(Table): + __table_prefixes__: typing.List[CustomTablePrefix] = [] + _support_primary_and_foreign_keys: bool = True + _enforce_primary_keys: bool = False + _required_parameters: List[TableOptionKey] = [] + + @property + def table_prefixes(self) -> typing.List[str]: + return [prefix.name for prefix in self.__table_prefixes__] + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if len(self.__table_prefixes__) > 0: + prefixes = kw.get("prefixes", []) + self.table_prefixes + kw.update(prefixes=prefixes) + + if not IS_VERSION_20 and hasattr(super(), "_init"): + kw.pop("_no_init", True) + super()._init(name, metadata, *args, **kw) + else: + super().__init__(name, metadata, *args, **kw) + + if not kw.get("autoload_with", False): + self._validate_table() + + def _validate_table(self): + exceptions: List[Exception] = [] + + for _, option in self.dialect_options[DIALECT_NAME].items(): + if isinstance(option, InvalidTableOption): + exceptions.append(option.exception) + + if isinstance(self.key, NoneType) and self._enforce_primary_keys: + exceptions.append(NoPrimaryKeyError(self.__class__.__name__)) + missing_parameters: List[str] = [] + + for required_parameter in self._required_parameters: + if isinstance(self._get_dialect_option(required_parameter), NoneType): + missing_parameters.append(required_parameter.name.lower()) + if missing_parameters: + exceptions.append( + RequiredParametersNotProvidedError( + self.__class__.__name__, missing_parameters + ) + ) + + if not self._support_primary_and_foreign_keys and ( + self.primary_key or self.foreign_keys + ): + exceptions.append( + UnsupportedPrimaryKeysAndForeignKeysError(self.__class__.__name__) + ) + + if len(exceptions) > 1: + exceptions.sort(key=lambda e: str(e)) + raise MultipleErrors(exceptions) + elif len(exceptions) == 1: + raise exceptions[0] + + def _get_dialect_option( + self, option_name: TableOptionKey + ) -> typing.Optional[TableOption]: + if option_name.value in self.dialect_options[DIALECT_NAME]: + return self.dialect_options[DIALECT_NAME][option_name.value] + return None + + def _as_dialect_options( + self, table_options: List[TableOption] + ) -> typing.Dict[str, TableOption]: + result = {} + for table_option in table_options: + if isinstance(table_option, TableOption) and isinstance( + table_option.option_name, str + ): + result[DIALECT_NAME + "_" + table_option.option_name] = table_option + return result + + @classmethod + def is_equal_type(cls, table: Table) -> bool: + for prefix in cls.__table_prefixes__: + if prefix.name not in table._prefixes: + return False + + return True diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py new file mode 100644 index 00000000..de7835d1 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2012-2023 
Snowflake Computing Inc. All rights reserved.
+
+from enum import Enum
+
+
+class CustomTablePrefix(Enum):
+    DEFAULT = 0
+    EXTERNAL = 1
+    EVENT = 2
+    HYBRID = 3
+    ICEBERG = 4
+    DYNAMIC = 5
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
new file mode 100644
index 00000000..6db4312d
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
@@ -0,0 +1,117 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+import typing
+from typing import Any, Union
+
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from .custom_table_prefix import CustomTablePrefix
+from .options import (
+    IdentifierOption,
+    IdentifierOptionType,
+    KeywordOptionType,
+    LiteralOption,
+    TableOptionKey,
+    TargetLagOption,
+    TargetLagOptionType,
+)
+from .options.keyword_option import KeywordOption
+from .table_from_query import TableFromQueryBase
+
+
+class DynamicTable(TableFromQueryBase):
+    """
+    A class representing a dynamic table with configurable options and settings.
+
+    The `DynamicTable` class allows for the creation and querying of tables with
+    specific options, such as `Warehouse` and `TargetLag`.
+
+    While it does not support reflection at this time, it provides a flexible
+    interface for creating and managing dynamic tables.
+
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
+
+    Example using option values:
+        DynamicTable(
+            "dynamic_test_table_1",
+            metadata,
+            Column("id", Integer),
+            Column("name", String),
+            target_lag=(1, TimeUnit.HOURS),
+            warehouse='warehouse_name',
+            refresh_mode=SnowflakeKeyword.AUTO,
+            as_query="SELECT id, name from test_table_1;"
+        )
+
+    Example using full options:
+        DynamicTable(
+            "dynamic_test_table_1",
+            metadata,
+            Column("id", Integer),
+            Column("name", String),
+            target_lag=TargetLagOption(1, TimeUnit.HOURS),
+            warehouse=IdentifierOption('warehouse_name'),
+            refresh_mode=KeywordOption(SnowflakeKeyword.AUTO),
+            as_query=AsQueryOption("SELECT id, name from test_table_1;")
+        )
+    """
+
+    __table_prefixes__ = [CustomTablePrefix.DYNAMIC]
+    _support_primary_and_foreign_keys = False
+    _required_parameters = [
+        TableOptionKey.WAREHOUSE,
+        TableOptionKey.AS_QUERY,
+        TableOptionKey.TARGET_LAG,
+    ]
+
+    @property
+    def warehouse(self) -> typing.Optional[LiteralOption]:
+        return self._get_dialect_option(TableOptionKey.WAREHOUSE)
+
+    @property
+    def target_lag(self) -> typing.Optional[TargetLagOption]:
+        return self._get_dialect_option(TableOptionKey.TARGET_LAG)
+
+    def __init__(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        warehouse: IdentifierOptionType = None,
+        target_lag: Union[TargetLagOptionType, KeywordOptionType] = None,
+        refresh_mode: KeywordOptionType = None,
+        **kw: Any,
+    ) -> None:
+        if kw.get("_no_init", True):
+            return
+
+        options = [
+            IdentifierOption.create(TableOptionKey.WAREHOUSE, warehouse),
+            TargetLagOption.create(target_lag),
+            KeywordOption.create(TableOptionKey.REFRESH_MODE, refresh_mode),
+        ]
+
+        kw.update(self._as_dialect_options(options))
+        super().__init__(name, metadata, *args, **kw)
+
+    def _init(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        self.__init__(name, metadata, *args, _no_init=False, **kw)
+
+    def __repr__(self) -> str:
+        return "DynamicTable(%s)" % ", ".join(
+            [repr(self.name)]
+            + [repr(self.metadata)]
+            + [repr(x) for x in self.columns]
+            + [repr(self.target_lag)]
+            + [repr(self.warehouse)]
+            + [repr(self.as_query)]
+            + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]]
+        )
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
new file mode 100644
index 00000000..b7c29e78
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
@@ -0,0 +1,62 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from typing import Any
+
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from .custom_table_base import CustomTableBase
+from .custom_table_prefix import CustomTablePrefix
+
+
+class HybridTable(CustomTableBase):
+    """
+    A class representing a hybrid table with configurable options and settings.
+
+    The `HybridTable` class allows for the creation and querying of OLTP Snowflake tables.
+
+    While it does not support reflection at this time, it provides a flexible
+    interface for creating and managing hybrid tables.
+
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-hybrid-table
+
+    Example usage:
+        HybridTable(
+            table_name,
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String)
+        )
+    """
+
+    __table_prefixes__ = [CustomTablePrefix.HYBRID]
+    _enforce_primary_keys: bool = True
+
+    def __init__(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        if kw.get("_no_init", True):
+            return
+        super().__init__(name, metadata, *args, **kw)
+
+    def _init(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        self.__init__(name, metadata, *args, _no_init=False, **kw)
+
+    def __repr__(self) -> str:
+        return "HybridTable(%s)" % ", ".join(
+            [repr(self.name)]
+            + [repr(self.metadata)]
+            + [repr(x) for x in self.columns]
+            + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]]
+        )
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py
new file mode 100644
index 00000000..11b54c1a
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py
@@ -0,0 +1,30 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from .as_query_option import AsQueryOption, AsQueryOptionType
+from .identifier_option import IdentifierOption, IdentifierOptionType
+from .keyword_option import KeywordOption, KeywordOptionType
+from .keywords import SnowflakeKeyword
+from .literal_option import LiteralOption, LiteralOptionType
+from .table_option import TableOptionKey
+from .target_lag_option import TargetLagOption, TargetLagOptionType, TimeUnit
+
+__all__ = [
+    # Options
+    "IdentifierOption",
+    "LiteralOption",
+    "KeywordOption",
+    "AsQueryOption",
+    "TargetLagOption",
+    # Enums
+    "TimeUnit",
+    "SnowflakeKeyword",
+    "TableOptionKey",
+    # Types
+    "IdentifierOptionType",
+    "LiteralOptionType",
+    "AsQueryOptionType",
+    "TargetLagOptionType",
+    "KeywordOptionType",
+]
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py
new file mode 100644
index 00000000..93994abc
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py
@@ -0,0 +1,63 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
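A usage sketch for the DynamicTable defined above, assuming the package re-exports it at the top level; the shorthand values are coerced by the option factories defined later in this patch:

    from sqlalchemy import Column, Integer, MetaData, String
    from snowflake.sqlalchemy import DynamicTable
    from snowflake.sqlalchemy.sql.custom_schema.options import (
        SnowflakeKeyword,
        TimeUnit,
    )

    metadata = MetaData()
    DynamicTable(
        "dynamic_orders",
        metadata,
        Column("id", Integer),
        Column("name", String),
        target_lag=(10, TimeUnit.MINUTES),  # coerced to TargetLagOption
        warehouse="transform_wh",  # coerced to IdentifierOption
        refresh_mode=SnowflakeKeyword.AUTO,  # coerced to KeywordOption
        as_query="SELECT id, name FROM orders",  # coerced to AsQueryOption
    )
    # Rendered roughly as:
    #   CREATE DYNAMIC TABLE dynamic_orders (id INTEGER, name VARCHAR)
    #   WAREHOUSE = transform_wh TARGET_LAG = '10 minutes' REFRESH_MODE = AUTO
    #   AS SELECT id, name FROM orders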
+# +from typing import Optional, Union + +from sqlalchemy.sql import Selectable + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class AsQueryOption(TableOption): + """Class to represent an AS clause in tables. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas + + Example: + as_query=AsQueryOption('select name, address from existing_table where name = "test"') + + is equivalent to: + + as select name, address from existing_table where name = "test" + """ + + def __init__(self, query: Union[str, Selectable]) -> None: + super().__init__() + self._name: TableOptionKey = TableOptionKey.AS_QUERY + self.query = query + + @staticmethod + def create( + value: Optional[Union["AsQueryOption", str, Selectable]] + ) -> "TableOption": + if isinstance(value, (NoneType, AsQueryOption)): + return value + if isinstance(value, (str, Selectable)): + return AsQueryOption(value) + return TableOption._get_invalid_table_option( + TableOptionKey.AS_QUERY, + str(type(value).__name__), + [AsQueryOption.__name__, str.__name__, Selectable.__name__], + ) + + def template(self) -> str: + return "AS %s" + + @property + def priority(self) -> Priority: + return Priority.LOWEST + + def __get_expression(self): + if isinstance(self.query, Selectable): + return self.query.compile(compile_kwargs={"literal_binds": True}) + return self.query + + def _render(self, compiler) -> str: + return self.template() % (self.__get_expression()) + + def __repr__(self) -> str: + return "AsQueryOption(%s)" % self.__get_expression() + + +AsQueryOptionType = Union[AsQueryOption, str, Selectable] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py new file mode 100644 index 00000000..b296898b --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class IdentifierOption(TableOption): + """Class to represent an identifier option in Snowflake Tables. 
+ + Example: + warehouse = IdentifierOption('my_warehouse') + + is equivalent to: + + WAREHOUSE = my_warehouse + """ + + def __init__(self, value: Union[str]) -> None: + super().__init__() + self.value: str = value + + @property + def priority(self): + return Priority.HIGH + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[str, "IdentifierOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return None + + if isinstance(value, str): + value = IdentifierOption(value) + + if isinstance(value, IdentifierOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, str(type(value).__name__), [IdentifierOption.__name__, str.__name__] + ) + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def _render(self, compiler) -> str: + return self.template() % self.value + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if not isinstance(self.option_name, NoneType) + else "" + ) + return f"IdentifierOption(value='{self.value}'{option_name})" + + +IdentifierOptionType = Union[IdentifierOption, str, int] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py new file mode 100644 index 00000000..2bdc9dd3 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional + +from .table_option import TableOption, TableOptionKey + + +class InvalidTableOption(TableOption): + """Class to store errors and raise them after table initialization in order to avoid recursion error.""" + + def __init__(self, name: TableOptionKey, value: Exception) -> None: + super().__init__() + self.exception: Exception = value + self._name = name + + @staticmethod + def create(name: TableOptionKey, value: Exception) -> Optional[TableOption]: + return InvalidTableOption(name, value) + + def _render(self, compiler) -> str: + raise self.exception + + def __repr__(self) -> str: + return f"ErrorOption(value='{self.exception}')" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py new file mode 100644 index 00000000..ff6b444d --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .keywords import SnowflakeKeyword +from .table_option import Priority, TableOption, TableOptionKey + + +class KeywordOption(TableOption): + """Class to represent a keyword option in Snowflake Tables. 
+
+    Example:
+        target_lag = KeywordOption(SnowflakeKeyword.DOWNSTREAM)
+
+    is equivalent to:
+
+        TARGET_LAG = DOWNSTREAM
+    """
+
+    def __init__(self, value: SnowflakeKeyword) -> None:
+        super().__init__()
+        self.value: str = value.value
+
+    @property
+    def priority(self):
+        return Priority.HIGH
+
+    def template(self) -> str:
+        return f"{self.option_name.upper()} = %s"
+
+    def _render(self, compiler) -> str:
+        return self.template() % self.value.upper()
+
+    @staticmethod
+    def create(
+        name: TableOptionKey, value: Optional[Union[SnowflakeKeyword, "KeywordOption"]]
+    ) -> Optional[TableOption]:
+        if isinstance(value, NoneType):
+            return value
+        if isinstance(value, SnowflakeKeyword):
+            value = KeywordOption(value)
+
+        if isinstance(value, KeywordOption):
+            value._set_option_name(name)
+            return value
+
+        return TableOption._get_invalid_table_option(
+            name,
+            str(type(value).__name__),
+            [KeywordOption.__name__, SnowflakeKeyword.__name__],
+        )
+
+    def __repr__(self) -> str:
+        option_name = (
+            f", table_option_key={self.option_name}"
+            if not isinstance(self.option_name, NoneType)
+            else ""
+        )
+        return f"KeywordOption(value='{self.value}'{option_name})"
+
+
+KeywordOptionType = Union[KeywordOption, SnowflakeKeyword]
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py
new file mode 100644
index 00000000..f4ba87ea
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py
@@ -0,0 +1,14 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+
+from enum import Enum
+
+
+class SnowflakeKeyword(Enum):
+    # TARGET_LAG
+    DOWNSTREAM = "DOWNSTREAM"
+
+    # REFRESH_MODE
+    AUTO = "AUTO"
+    FULL = "FULL"
+    INCREMENTAL = "INCREMENTAL"
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py
new file mode 100644
index 00000000..55dd7675
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py
@@ -0,0 +1,67 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from typing import Any, Optional, Union
+
+from snowflake.sqlalchemy.custom_commands import NoneType
+
+from .table_option import Priority, TableOption, TableOptionKey
+
+
+class LiteralOption(TableOption):
+    """Class to represent a literal option in Snowflake Tables.
+ + Example: + warehouse = LiteralOption('my_warehouse') + + is equivalent to: + + WAREHOUSE = 'my_warehouse' + """ + + def __init__(self, value: Union[int, str]) -> None: + super().__init__() + self.value: Any = value + + @property + def priority(self): + return Priority.HIGH + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[str, int, "LiteralOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return None + if isinstance(value, (str, int)): + value = LiteralOption(value) + + if isinstance(value, LiteralOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, + str(type(value).__name__), + [LiteralOption.__name__, str.__name__, int.__name__], + ) + + def template(self) -> str: + if isinstance(self.value, int): + return f"{self.option_name.upper()} = %d" + else: + return f"{self.option_name.upper()} = '%s'" + + def _render(self, compiler) -> str: + return self.template() % self.value + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if not isinstance(self.option_name, NoneType) + else "" + ) + return f"LiteralOption(value='{self.value}'{option_name})" + + +LiteralOptionType = Union[LiteralOption, str, int] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py new file mode 100644 index 00000000..14b91f2e --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py @@ -0,0 +1,83 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from enum import Enum +from typing import List, Optional + +from snowflake.sqlalchemy import exc +from snowflake.sqlalchemy.custom_commands import NoneType + + +class Priority(Enum): + LOWEST = 0 + VERY_LOW = 1 + LOW = 2 + MEDIUM = 4 + HIGH = 6 + VERY_HIGH = 7 + HIGHEST = 8 + + +class TableOption: + + def __init__(self) -> None: + self._name: Optional[TableOptionKey] = None + + @property + def option_name(self) -> str: + if isinstance(self._name, NoneType): + return None + return str(self._name.value) + + def _set_option_name(self, name: Optional["TableOptionKey"]): + self._name = name + + @property + def priority(self) -> Priority: + return Priority.MEDIUM + + @staticmethod + def create(**kwargs) -> "TableOption": + raise NotImplementedError + + @staticmethod + def _get_invalid_table_option( + parameter_name: "TableOptionKey", input_type: str, expected_types: List[str] + ) -> "TableOption": + from .invalid_table_option import InvalidTableOption + + return InvalidTableOption( + parameter_name, + exc.InvalidTableParameterTypeError( + parameter_name.value, input_type, expected_types + ), + ) + + def _validate_option(self): + if isinstance(self.option_name, NoneType): + raise exc.OptionKeyNotProvidedError(self.__class__.__name__) + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def render_option(self, compiler) -> str: + self._validate_option() + return self._render(compiler) + + def _render(self, compiler) -> str: + raise NotImplementedError + + +class TableOptionKey(Enum): + AS_QUERY = "as_query" + BASE_LOCATION = "base_location" + CATALOG = "catalog" + CATALOG_SYNC = "catalog_sync" + DATA_RETENTION_TIME_IN_DAYS = "data_retention_time_in_days" + DEFAULT_DDL_COLLATION = "default_ddl_collation" + EXTERNAL_VOLUME = "external_volume" + MAX_DATA_EXTENSION_TIME_IN_DAYS = "max_data_extension_time_in_days" + REFRESH_MODE = "refresh_mode" + 
STORAGE_SERIALIZATION_POLICY = "storage_serialization_policy" + TARGET_LAG = "target_lag" + WAREHOUSE = "warehouse" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py new file mode 100644 index 00000000..7c1c0825 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py @@ -0,0 +1,94 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# from enum import Enum +from enum import Enum +from typing import Optional, Tuple, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .keyword_option import KeywordOption, KeywordOptionType +from .keywords import SnowflakeKeyword +from .table_option import Priority, TableOption, TableOptionKey + + +class TimeUnit(Enum): + SECONDS = "seconds" + MINUTES = "minutes" + HOURS = "hours" + DAYS = "days" + + +class TargetLagOption(TableOption): + """Class to represent the target lag clause in Dynamic Tables. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table + + Example using the time and unit parameters: + + target_lag = TargetLagOption(10, TimeUnit.SECONDS) + + is equivalent to: + + TARGET_LAG = '10 SECONDS' + + Example using keyword parameter: + + target_lag = KeywordOption(SnowflakeKeyword.DOWNSTREAM) + + is equivalent to: + + TARGET_LAG = DOWNSTREAM + + """ + + def __init__( + self, + time: Optional[int] = 0, + unit: Optional[TimeUnit] = TimeUnit.MINUTES, + ) -> None: + super().__init__() + self.time = time + self.unit = unit + self._name: TableOptionKey = TableOptionKey.TARGET_LAG + + @staticmethod + def create( + value: Union["TargetLagOption", Tuple[int, TimeUnit], KeywordOptionType] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return value + + if isinstance(value, Tuple): + time, unit = value + value = TargetLagOption(time, unit) + + if isinstance(value, TargetLagOption): + return value + + if isinstance(value, (KeywordOption, SnowflakeKeyword)): + return KeywordOption.create(TableOptionKey.TARGET_LAG, value) + + return TableOption._get_invalid_table_option( + TableOptionKey.TARGET_LAG, + str(type(value).__name__), + [ + TargetLagOption.__name__, + f"Tuple[int, {TimeUnit.__name__}])", + SnowflakeKeyword.__name__, + ], + ) + + def __get_expression(self): + return f"'{str(self.time)} {str(self.unit.value)}'" + + @property + def priority(self) -> Priority: + return Priority.HIGH + + def _render(self, compiler) -> str: + return self.template() % (self.__get_expression()) + + def __repr__(self) -> str: + return "TargetLagOption(%s)" % self.__get_expression() + + +TargetLagOptionType = Union[TargetLagOption, Tuple[int, TimeUnit]] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py new file mode 100644 index 00000000..fccc7a0b --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py @@ -0,0 +1,54 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
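A standalone sketch of how the option classes above render; render_option only requires the option key to be set, and these implementations ignore the compiler argument:

    from snowflake.sqlalchemy.sql.custom_schema.options import (
        IdentifierOption,
        KeywordOption,
        LiteralOption,
        SnowflakeKeyword,
        TableOptionKey,
        TargetLagOption,
        TimeUnit,
    )

    wh = IdentifierOption.create(TableOptionKey.WAREHOUSE, "my_wh")
    print(wh.render_option(None))  # WAREHOUSE = my_wh

    loc = LiteralOption.create(TableOptionKey.BASE_LOCATION, "my_loc")
    print(loc.render_option(None))  # BASE_LOCATION = 'my_loc'

    lag = TargetLagOption.create((1, TimeUnit.HOURS))
    print(lag.render_option(None))  # TARGET_LAG = '1 hours'

    mode = KeywordOption.create(TableOptionKey.REFRESH_MODE, SnowflakeKeyword.FULL)
    print(mode.render_option(None))  # REFRESH_MODE = FULL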
+# +import typing +from typing import Any, Optional + +from sqlalchemy.sql import Selectable +from sqlalchemy.sql.schema import Column, MetaData, SchemaItem + +from .custom_table_base import CustomTableBase +from .options.as_query_option import AsQueryOption, AsQueryOptionType +from .options.table_option import TableOptionKey + + +class TableFromQueryBase(CustomTableBase): + + @property + def as_query(self) -> Optional[AsQueryOption]: + return self._get_dialect_option(TableOptionKey.AS_QUERY) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + as_query: AsQueryOptionType = None, + **kw: Any, + ) -> None: + items = [item for item in args] + as_query = AsQueryOption.create(as_query) # noqa + kw.update(self._as_dialect_options([as_query])) + if ( + isinstance(as_query, AsQueryOption) + and isinstance(as_query.query, Selectable) + and not self.__has_defined_columns(items) + ): + columns = self.__create_columns_from_selectable(as_query.query) + args = items + columns + super().__init__(name, metadata, *args, **kw) + + def __has_defined_columns(self, items: typing.List[SchemaItem]) -> bool: + for item in items: + if isinstance(item, Column): + return True + + def __create_columns_from_selectable( + self, selectable: Selectable + ) -> Optional[typing.List[Column]]: + if not isinstance(selectable, Selectable): + return + columns: typing.List[Column] = [] + for _, c in selectable.exported_columns.items(): + columns += [Column(c.name, c.type)] + return columns diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 54a20086..a1aefff9 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -1,15 +1,25 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
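Context for the BCR-1057 handling implemented in util.py below: Snowflake no longer accepts an ON clause on a LATERAL join, so the patched join machinery allows omitting it when the right side is a Lateral. A conceptual sketch (names are illustrative, and the snowflake dialect must be in use):

    from sqlalchemy import Column, Integer, MetaData, Table, func
    from sqlalchemy.orm import Session

    metadata = MetaData()
    events = Table(
        "events",
        metadata,
        Column("id", Integer),
        Column("payload", Integer),
    )

    flattened = func.flatten(events.c.payload).lateral()
    # With the patched ORM join, no ON clause is required for the Lateral;
    # visit_join then renders "... JOIN LATERAL flatten(events.payload)".
    stmt = Session().query(events.c.id).join(flattened).statement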
 #
 import re
+from itertools import chain
 from typing import Any
 from urllib.parse import quote_plus
 
-from sqlalchemy import exc
+from sqlalchemy import exc, inspection, sql
+from sqlalchemy.exc import NoForeignKeysError
+from sqlalchemy.orm.interfaces import MapperProperty
+from sqlalchemy.orm.util import _ORMJoin as sa_orm_util_ORMJoin
+from sqlalchemy.orm.util import attributes
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql.base import _expand_cloned, _from_objects
+from sqlalchemy.sql.elements import _find_columns
+from sqlalchemy.sql.selectable import Join, Lateral, coercions, operators, roles
 
 from snowflake.connector.compat import IS_STR
 from snowflake.connector.connection import SnowflakeConnection
+from snowflake.sqlalchemy import compat
 
 from ._constants import (
     APPLICATION_NAME,
@@ -104,3 +114,231 @@ def _update_connection_application_name(**conn_kwargs: Any) -> Any:
     if PARAM_INTERNAL_APPLICATION_VERSION not in conn_kwargs:
         conn_kwargs[PARAM_INTERNAL_APPLICATION_VERSION] = SNOWFLAKE_SQLALCHEMY_VERSION
     return conn_kwargs
+
+
+def parse_url_boolean(value: str) -> bool:
+    if value == "True":
+        return True
+    elif value == "False":
+        return False
+    else:
+        raise ValueError(f"Invalid boolean value detected: '{value}'")
+
+
+def parse_url_integer(value: str) -> int:
+    try:
+        return int(value)
+    except ValueError as e:
+        raise ValueError(f"Invalid int value detected: '{value}'") from e
+
+
+# handle Snowflake BCR bcr-1057
+# the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.sql.selectable.SelectState
+# which used the 'sqlalchemy.util.preloaded.sql_util.find_left_clause_to_join_from' method that
+# cannot handle the BCR change; we implement it so that a lateral join does not need an onclause
+def _find_left_clause_to_join_from(clauses, join_to, onclause):
+    """Given a list of FROM clauses, a selectable,
+    and optional ON clause, return a list of integer indexes from the
+    clauses list indicating the clauses that can be joined from.
+
+    The presence of an "onclause" indicates that at least one clause can
+    definitely be joined from; if the list of clauses is of length one
+    and the onclause is given, returns that index. If the list of clauses
+    is more than length one, and the onclause is given, attempts to locate
+    which clauses contain the same columns.
+
+    """
+    idx = []
+    selectables = set(_from_objects(join_to))
+
+    # if we are given more than one target clause to join
+    # from, use the onclause to provide a more specific answer.
+    # otherwise, don't try to limit, after all, "ON TRUE" is a valid
+    # on clause
+    if len(clauses) > 1 and onclause is not None:
+        resolve_ambiguity = True
+        cols_in_onclause = _find_columns(onclause)
+    else:
+        resolve_ambiguity = False
+        cols_in_onclause = None
+
+    for i, f in enumerate(clauses):
+        for s in selectables.difference([f]):
+            if resolve_ambiguity:
+                if set(f.c).union(s.c).issuperset(cols_in_onclause):
+                    idx.append(i)
+                    break
+            elif onclause is not None or Join._can_join(f, s):
+                idx.append(i)
+                break
+            elif onclause is None and isinstance(s, Lateral):
+                # in snowflake, onclause is not accepted for lateral due to BCR change:
+                # https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057
+                # sqlalchemy only allows join with on condition.
+                # to adapt to snowflake syntax change,
+                # we make the change such that when onclause is None and the right part is
+                # Lateral, we append the index indicating the Lateral clause can be joined from without an onclause.
+ idx.append(i) + break + + if len(idx) > 1: + # this is the same "hide froms" logic from + # Selectable._get_display_froms + toremove = set(chain(*[_expand_cloned(f._hide_froms) for f in clauses])) + idx = [i for i in idx if clauses[i] not in toremove] + + # onclause was given and none of them resolved, so assume + # all indexes can match + if not idx and onclause is not None: + return range(len(clauses)) + else: + return idx + + +class _Snowflake_ORMJoin(sa_orm_util_ORMJoin): + def __init__( + self, + left, + right, + onclause=None, + isouter=False, + full=False, + _left_memo=None, + _right_memo=None, + _extra_criteria=(), + ): + left_info = inspection.inspect(left) + + right_info = inspection.inspect(right) + adapt_to = right_info.selectable + + # used by joined eager loader + self._left_memo = _left_memo + self._right_memo = _right_memo + + # legacy, for string attr name ON clause. if that's removed + # then the "_joined_from_info" concept can go + left_orm_info = getattr(left, "_joined_from_info", left_info) + self._joined_from_info = right_info + if isinstance(onclause, compat.string_types): + onclause = getattr(left_orm_info.entity, onclause) + # #### + + if isinstance(onclause, attributes.QueryableAttribute): + on_selectable = onclause.comparator._source_selectable() + prop = onclause.property + _extra_criteria += onclause._extra_criteria + elif isinstance(onclause, MapperProperty): + # used internally by joined eager loader...possibly not ideal + prop = onclause + on_selectable = prop.parent.selectable + else: + prop = None + + if prop: + left_selectable = left_info.selectable + + if sql_util.clause_is_present(on_selectable, left_selectable): + adapt_from = on_selectable + else: + adapt_from = left_selectable + + ( + pj, + sj, + source, + dest, + secondary, + target_adapter, + ) = prop._create_joins( + source_selectable=adapt_from, + dest_selectable=adapt_to, + source_polymorphic=True, + of_type_entity=right_info, + alias_secondary=True, + extra_criteria=_extra_criteria, + ) + + if sj is not None: + if isouter: + # note this is an inner join from secondary->right + right = sql.join(secondary, right, sj) + onclause = pj + else: + left = sql.join(left, secondary, pj, isouter) + onclause = sj + else: + onclause = pj + + self._target_adapter = target_adapter + + # we don't use the normal coercions logic for _ORMJoin + # (probably should), so do some gymnastics to get the entity. + # logic here is for #8721, which was a major bug in 1.4 + # for almost two years, not reported/fixed until 1.4.43 (!) + if left_info.is_selectable: + parententity = left_selectable._annotations.get("parententity", None) + elif left_info.is_mapper or left_info.is_aliased_class: + parententity = left_info + else: + parententity = None + + if parententity is not None: + self._annotations = self._annotations.union( + {"parententity": parententity} + ) + + augment_onclause = onclause is None and _extra_criteria + # handle Snowflake BCR bcr-1057 + _Snowflake_Selectable_Join.__init__(self, left, right, onclause, isouter, full) + + if augment_onclause: + self.onclause &= sql.and_(*_extra_criteria) + + if ( + not prop + and getattr(right_info, "mapper", None) + and right_info.mapper.single + ): + # if single inheritance target and we are using a manual + # or implicit ON clause, augment it the same way we'd augment the + # WHERE. 
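+            # Illustrative (hypothetical) mapping for the branch below: with a
+            # single-table-inheritance subclass such as
+            #
+            #     class Manager(Employee):
+            #         __mapper_args__ = {"polymorphic_identity": "manager"}
+            #
+            # joining to Manager must also filter the shared table by its
+            # discriminator, so the criterion is AND-ed into the ON clause.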
+            single_crit = right_info.mapper._single_table_criterion
+            if single_crit is not None:
+                if right_info.is_aliased_class:
+                    single_crit = right_info._adapter.traverse(single_crit)
+                self.onclause = self.onclause & single_crit
+
+
+class _Snowflake_Selectable_Join(Join):
+    def __init__(self, left, right, onclause=None, isouter=False, full=False):
+        """Construct a new :class:`_expression.Join`.
+
+        The usual entrypoint here is the :func:`_expression.join`
+        function or the :meth:`_expression.FromClause.join` method of any
+        :class:`_expression.FromClause` object.
+
+        """
+        self.left = coercions.expect(roles.FromClauseRole, left, deannotate=True)
+        self.right = coercions.expect(
+            roles.FromClauseRole, right, deannotate=True
+        ).self_group()
+
+        if onclause is None:
+            try:
+                self.onclause = self._match_primaries(self.left, self.right)
+            except NoForeignKeysError:
+                # handle Snowflake BCR bcr-1057
+                if isinstance(self.right, Lateral):
+                    self.onclause = None
+                else:
+                    raise
+        else:
+            # note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba
+            # not merged yet
+            self.onclause = coercions.expect(roles.OnClauseRole, onclause).self_group(
+                against=operators._asbool
+            )
+
+        self.isouter = isouter
+        self.full = full
diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py
new file mode 100644
index 00000000..d90f706b
--- /dev/null
+++ b/src/snowflake/sqlalchemy/version.py
@@ -0,0 +1,6 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+# Update this for each release.
+# Don't change the fourth version number from None.
+VERSION = "1.6.1"
diff --git a/tests/__init__.py b/tests/__init__.py
index 0fc6abfa..ef416f64 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,3 +1,3 @@
 #
-# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
# diff --git a/tests/__snapshots__/test_compile_dynamic_table.ambr b/tests/__snapshots__/test_compile_dynamic_table.ambr new file mode 100644 index 00000000..81c7f90f --- /dev/null +++ b/tests/__snapshots__/test_compile_dynamic_table.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm + "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm_with_str_keys + "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_selectable + "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" +# --- diff --git a/tests/__snapshots__/test_orm.ambr b/tests/__snapshots__/test_orm.ambr new file mode 100644 index 00000000..2116e9e9 --- /dev/null +++ b/tests/__snapshots__/test_orm.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_orm_one_to_many_relationship_with_hybrid_table + ProgrammingError('(snowflake.connector.errors.ProgrammingError) 200009 (22000): Foreign key constraint "SYS_INDEX_HB_TBL_ADDRESS_FOREIGN_KEY_USER_ID_HB_TBL_USER_ID" was violated.') +# --- diff --git a/tests/__snapshots__/test_reflect_dynamic_table.ambr b/tests/__snapshots__/test_reflect_dynamic_table.ambr new file mode 100644 index 00000000..d4cc22b5 --- /dev/null +++ b/tests/__snapshots__/test_reflect_dynamic_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- diff --git a/tests/conftest.py b/tests/conftest.py index 84e29efd..d4dab3d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,20 +1,29 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# +from __future__ import annotations import os import sys import time import uuid -from functools import partial from logging import getLogger import pytest from sqlalchemy import create_engine +from sqlalchemy.pool import NullPool import snowflake.connector +import snowflake.connector.connection from snowflake.connector.compat import IS_WINDOWS from snowflake.sqlalchemy import URL, dialect +from snowflake.sqlalchemy._constants import ( + APPLICATION_NAME, + PARAM_APPLICATION, + PARAM_INTERNAL_APPLICATION_NAME, + PARAM_INTERNAL_APPLICATION_VERSION, + SNOWFLAKE_SQLALCHEMY_VERSION, +) from .parameters import CONNECTION_PARAMETERS @@ -23,19 +32,18 @@ INTERNAL_SKIP_TAGS = {"external"} RUNNING_ON_GH = os.getenv("GITHUB_ACTIONS") == "true" -TEST_SCHEMA = f"sqlalchemy_tests_{str(uuid.uuid4()).replace('-', '_')}" - -create_engine_with_future_flag = create_engine +snowflake.connector.connection.DEFAULT_CONFIGURATION[PARAM_APPLICATION] = ( + APPLICATION_NAME, + (type(None), str), +) +snowflake.connector.connection.DEFAULT_CONFIGURATION[ + PARAM_INTERNAL_APPLICATION_NAME +] = (APPLICATION_NAME, (type(None), str)) +snowflake.connector.connection.DEFAULT_CONFIGURATION[ + PARAM_INTERNAL_APPLICATION_VERSION +] = (SNOWFLAKE_SQLALCHEMY_VERSION, (type(None), str)) - -def pytest_addoption(parser): - parser.addoption( - "--run_v20_sqlalchemy", - help="Use only 2.0 SQLAlchemy APIs, any legacy features (< 2.0) will not be supported." - "Turning on this option will set future flag to True on Engine and Session objects according to" - "the migration guide: https://docs.sqlalchemy.org/en/14/changelog/migration_20.html", - action="store_true", - ) +TEST_SCHEMA = f"sqlalchemy_tests_{str(uuid.uuid4()).replace('-', '_')}" @pytest.fixture(scope="session") @@ -83,10 +91,10 @@ def help(): @pytest.fixture(scope="session") def db_parameters(): - return get_db_parameters() + yield get_db_parameters() -def get_db_parameters(): +def get_db_parameters() -> dict: """ Sets the db connection parameters """ @@ -94,12 +102,9 @@ def get_db_parameters(): os.environ["TZ"] = "UTC" if not IS_WINDOWS: time.tzset() - for k, v in CONNECTION_PARAMETERS.items(): - ret[k] = v - for k, v in DEFAULT_PARAMETERS.items(): - if k not in ret: - ret[k] = v + ret.update(DEFAULT_PARAMETERS) + ret.update(CONNECTION_PARAMETERS) if "account" in ret and ret["account"] == DEFAULT_PARAMETERS["account"]: help() @@ -134,43 +139,50 @@ def get_db_parameters(): return ret -def get_engine(user=None, password=None, account=None, schema=None): - """ - Creates a connection using the parameters defined in JDBC connect string - """ - ret = get_db_parameters() - - if user is not None: - ret["user"] = user - if password is not None: - ret["password"] = password - if account is not None: - ret["account"] = account - - from sqlalchemy.pool import NullPool - - engine = create_engine_with_future_flag( - URL( - user=ret["user"], - password=ret["password"], - host=ret["host"], - port=ret["port"], - database=ret["database"], - schema=TEST_SCHEMA if not schema else schema, - account=ret["account"], - protocol=ret["protocol"], - ), - poolclass=NullPool, - ) +def url_factory(**kwargs) -> URL: + url_params = get_db_parameters() + url_params.update(kwargs) + return URL(**url_params) + - return engine, ret +def get_engine(url: URL, **engine_kwargs): + engine_params = { + "poolclass": NullPool, + "future": True, + "echo": True, + } + engine_params.update(engine_kwargs) + engine = create_engine(url, **engine_params) + return engine @pytest.fixture() def engine_testaccount(request): - 
engine, _ = get_engine() + url = url_factory() + engine = get_engine(url) request.addfinalizer(engine.dispose) - return engine + yield engine + + +@pytest.fixture() +def engine_testaccount_with_numpy(request): + url = url_factory(numpy=True) + engine = get_engine(url) + request.addfinalizer(engine.dispose) + yield engine + + +@pytest.fixture() +def engine_testaccount_with_qmark(request): + snowflake.connector.paramstyle = "qmark" + + url = url_factory() + engine = get_engine(url) + request.addfinalizer(engine.dispose) + + yield engine + + snowflake.connector.paramstyle = "pyformat" @pytest.fixture(scope="session", autouse=True) @@ -213,19 +225,6 @@ def sql_compiler(): ).replace("\n", "") -@pytest.fixture(scope="session") -def run_v20_sqlalchemy(pytestconfig): - return pytestconfig.option.run_v20_sqlalchemy - - -def pytest_sessionstart(session): - # patch the create_engine with future flag - global create_engine_with_future_flag - create_engine_with_future_flag = partial( - create_engine, future=session.config.option.run_v20_sqlalchemy - ) - - def running_on_public_ci() -> bool: """Whether or not tests are currently running on one of our public CIs.""" return os.getenv("GITHUB_ACTIONS") == "true" diff --git a/tests/connector_regression b/tests/connector_regression deleted file mode 160000 index a3132409..00000000 --- a/tests/connector_regression +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a313240982f8182ddfa11d031b335ad7d3db3c6f diff --git a/tests/custom_tables/__init__.py b/tests/custom_tables/__init__.py new file mode 100644 index 00000000..d43f066c --- /dev/null +++ b/tests/custom_tables/__init__.py @@ -0,0 +1,2 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. diff --git a/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr new file mode 100644 index 00000000..66c8f98e --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr @@ -0,0 +1,40 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm + "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm_with_str_keys + 'CREATE DYNAMIC TABLE "SCHEMA_DB".test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = \'10 seconds\'\tAS SELECT * FROM table' +# --- +# name: test_compile_dynamic_table_with_multiple_wrong_option_types + ''' + Invalid parameter type 'IdentifierOption' provided for 'refresh_mode'. Expected one of the following types: 'KeywordOption', 'SnowflakeKeyword'. + Invalid parameter type 'IdentifierOption' provided for 'target_lag'. Expected one of the following types: 'TargetLagOption', 'Tuple[int, TimeUnit])', 'SnowflakeKeyword'. + Invalid parameter type 'KeywordOption' provided for 'as_query'. Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'. + Invalid parameter type 'KeywordOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. + + ''' +# --- +# name: test_compile_dynamic_table_with_one_wrong_option_types + ''' + Invalid parameter type 'LiteralOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. 
+ + ''' +# --- +# name: test_compile_dynamic_table_with_options_objects + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = AUTO\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.AUTO] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = AUTO\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.FULL] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = FULL\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.INCREMENTAL] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = INCREMENTAL\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_selectable + "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" +# --- diff --git a/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr new file mode 100644 index 00000000..9412fb45 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_compile_hybrid_table + 'CREATE HYBRID TABLE test_hybrid_table (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tgeom GEOMETRY, \tPRIMARY KEY (id))' +# --- +# name: test_compile_hybrid_table_orm + 'CREATE HYBRID TABLE test_hybrid_table_orm (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr new file mode 100644 index 00000000..80201495 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_create_dynamic_table_without_dynamictable_and_defined_options + CustomOptionsAreOnlySupportedOnSnowflakeTables('Identifier, Literal, TargetLag and other custom options are only supported on Snowflake tables.') +# --- +# name: test_create_dynamic_table_without_dynamictable_class + UnexpectedOptionTypeError('The following options are either unsupported or should be defined using a Snowflake table: as_query, warehouse.') +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr new file mode 100644 index 00000000..696ff9c8 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_create_hybrid_table + "[(1, 'test')]" +# --- +# name: test_create_hybrid_table_with_multiple_index + ProgrammingError("(snowflake.connector.errors.ProgrammingError) 391480 (0A000): Another index is being built on table 'TEST_HYBRID_TABLE_WITH_MULTIPLE_INDEX'. Only one index can be built at a time. 
Either cancel the other index creation or wait until it is complete.") +# --- diff --git a/tests/custom_tables/__snapshots__/test_generic_options.ambr b/tests/custom_tables/__snapshots__/test_generic_options.ambr new file mode 100644 index 00000000..eef5e6fd --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_generic_options.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_identifier_option_with_wrong_type + InvalidTableParameterTypeError("Invalid parameter type 'int' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'.\n") +# --- +# name: test_identifier_option_without_name + OptionKeyNotProvidedError('Expected option key in IdentifierOption option but got NoneType instead.') +# --- +# name: test_invalid_as_query_option + InvalidTableParameterTypeError("Invalid parameter type 'int' provided for 'as_query'. Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'.\n") +# --- +# name: test_literal_option_with_wrong_type + InvalidTableParameterTypeError("Invalid parameter type 'SnowflakeKeyword' provided for 'warehouse'. Expected one of the following types: 'LiteralOption', 'str', 'int'.\n") +# --- diff --git a/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr new file mode 100644 index 00000000..6f6cd395 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_simple_reflection_hybrid_table_as_table + 'CREATE TABLE test_hybrid_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- diff --git a/tests/custom_tables/test_compile_dynamic_table.py b/tests/custom_tables/test_compile_dynamic_table.py new file mode 100644 index 00000000..935c61cd --- /dev/null +++ b/tests/custom_tables/test_compile_dynamic_table.py @@ -0,0 +1,271 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + Integer, + MetaData, + String, + Table, + exc, + select, +) +from sqlalchemy.exc import ArgumentError +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import GEOMETRY, DynamicTable +from snowflake.sqlalchemy.exc import MultipleErrors +from snowflake.sqlalchemy.sql.custom_schema.options import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + TargetLagOption, + TimeUnit, +) +from snowflake.sqlalchemy.sql.custom_schema.options.keywords import SnowflakeKeyword + + +def test_compile_dynamic_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +@pytest.mark.parametrize( + "refresh_mode_keyword", + [ + SnowflakeKeyword.AUTO, + SnowflakeKeyword.FULL, + SnowflakeKeyword.INCREMENTAL, + ], +) +def test_compile_dynamic_table_with_refresh_mode( + sql_compiler, snapshot, refresh_mode_keyword +): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + refresh_mode=refresh_mode_keyword, + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_options_objects(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=TargetLagOption(10, TimeUnit.SECONDS), + warehouse=IdentifierOption("warehouse"), + as_query=AsQueryOption("SELECT * FROM table"), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_one_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + with pytest.raises(ArgumentError) as argument_error: + DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=TargetLagOption(10, TimeUnit.SECONDS), + warehouse=LiteralOption("warehouse"), + as_query=AsQueryOption("SELECT * FROM table"), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_dynamic_table_with_multiple_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + with pytest.raises(MultipleErrors) as argument_error: + DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=IdentifierOption(SnowflakeKeyword.AUTO), + warehouse=KeywordOption(SnowflakeKeyword.AUTO), + as_query=KeywordOption(SnowflakeKeyword.AUTO), + refresh_mode=IdentifierOption(SnowflakeKeyword.AUTO), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_dynamic_table_without_required_args(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="DynamicTable requires the following parameters: 
warehouse, "
+        "as_query, target_lag.",
+    ):
+        DynamicTable(
+            "test_dynamic_table",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("geom", GEOMETRY),
+        )
+
+
+def test_compile_dynamic_table_with_primary_key(sql_compiler):
+    with pytest.raises(
+        exc.ArgumentError,
+        match="Primary key and foreign keys are not supported in DynamicTable.",
+    ):
+        DynamicTable(
+            "test_dynamic_table",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("geom", GEOMETRY),
+            target_lag=(10, TimeUnit.SECONDS),
+            warehouse="warehouse",
+            as_query="SELECT * FROM table",
+        )
+
+
+def test_compile_dynamic_table_with_foreign_key(sql_compiler):
+    with pytest.raises(
+        exc.ArgumentError,
+        match="Primary key and foreign keys are not supported in DynamicTable.",
+    ):
+        DynamicTable(
+            "test_dynamic_table",
+            MetaData(),
+            Column("id", Integer),
+            Column("geom", GEOMETRY),
+            ForeignKeyConstraint(["id"], ["table.id"]),
+            target_lag=(10, TimeUnit.SECONDS),
+            warehouse="warehouse",
+            as_query="SELECT * FROM table",
+        )
+
+
+def test_compile_dynamic_table_orm(sql_compiler, snapshot):
+    Base = declarative_base()
+    metadata = MetaData()
+    table_name = "test_dynamic_table_orm"
+    test_dynamic_table_orm = DynamicTable(
+        table_name,
+        metadata,
+        Column("id", Integer),
+        Column("name", String),
+        target_lag=(10, TimeUnit.SECONDS),
+        warehouse="warehouse",
+        as_query="SELECT * FROM table",
+    )
+
+    class TestDynamicTableOrm(Base):
+        __table__ = test_dynamic_table_orm
+        __mapper_args__ = {
+            "primary_key": [test_dynamic_table_orm.c.id, test_dynamic_table_orm.c.name]
+        }
+
+        def __repr__(self):
+            return f"<TestDynamicTableOrm({self.name})>"
+
+    value = CreateTable(TestDynamicTableOrm.__table__)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
+
+
+def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, snapshot):
+    Base = declarative_base()
+
+    class TestDynamicTableOrm(Base):
+        __tablename__ = "test_dynamic_table_orm_2"
+
+        @classmethod
+        def __table_cls__(cls, name, metadata, *arg, **kw):
+            return DynamicTable(name, metadata, *arg, **kw)
+
+        __table_args__ = {
+            "schema": "SCHEMA_DB",
+            "target_lag": (10, TimeUnit.SECONDS),
+            "warehouse": "warehouse",
+            "as_query": "SELECT * FROM table",
+        }
+
+        id = Column(Integer)
+        name = Column(String)
+
+        __mapper_args__ = {"primary_key": [id, name]}
+
+        def __repr__(self):
+            return f"<TestDynamicTableOrm({self.name})>"
+
+    value = CreateTable(TestDynamicTableOrm.__table__)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
+
+
+def test_compile_dynamic_table_with_selectable(sql_compiler, snapshot):
+    Base = declarative_base()
+
+    test_table_1 = Table(
+        "test_table_1",
+        Base.metadata,
+        Column("id", Integer, primary_key=True),
+        Column("name", String),
+    )
+
+    dynamic_test_table = DynamicTable(
+        "dynamic_test_table_1",
+        Base.metadata,
+        target_lag=(10, TimeUnit.SECONDS),
+        warehouse="warehouse",
+        as_query=select(test_table_1).where(test_table_1.c.id == 23),
+    )
+
+    value = CreateTable(dynamic_test_table)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
diff --git a/tests/custom_tables/test_compile_hybrid_table.py b/tests/custom_tables/test_compile_hybrid_table.py
new file mode 100644
index 00000000..f1af6dc2
--- /dev/null
+++ b/tests/custom_tables/test_compile_hybrid_table.py
@@ -0,0 +1,52 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+import pytest
+from sqlalchemy import Column, Integer, MetaData, String
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.sql.ddl import CreateTable
+
+from snowflake.sqlalchemy import GEOMETRY, HybridTable
+
+
+@pytest.mark.aws
+def test_compile_hybrid_table(sql_compiler, snapshot):
+    metadata = MetaData()
+    table_name = "test_hybrid_table"
+    test_geometry = HybridTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("name", String),
+        Column("geom", GEOMETRY),
+    )
+
+    value = CreateTable(test_geometry)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
+
+
+@pytest.mark.aws
+def test_compile_hybrid_table_orm(sql_compiler, snapshot):
+    Base = declarative_base()
+
+    class TestHybridTableOrm(Base):
+        __tablename__ = "test_hybrid_table_orm"
+
+        @classmethod
+        def __table_cls__(cls, name, metadata, *arg, **kw):
+            return HybridTable(name, metadata, *arg, **kw)
+
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+
+        def __repr__(self):
+            return f"<TestHybridTableOrm({self.name})>"
+
+    value = CreateTable(TestHybridTableOrm.__table__)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
diff --git a/tests/custom_tables/test_create_dynamic_table.py b/tests/custom_tables/test_create_dynamic_table.py
new file mode 100644
index 00000000..b583faad
--- /dev/null
+++ b/tests/custom_tables/test_create_dynamic_table.py
@@ -0,0 +1,124 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+import pytest
+from sqlalchemy import Column, Integer, MetaData, String, Table, select
+
+from snowflake.sqlalchemy import DynamicTable, exc
+from snowflake.sqlalchemy.sql.custom_schema.options.as_query_option import AsQueryOption
+from snowflake.sqlalchemy.sql.custom_schema.options.identifier_option import (
+    IdentifierOption,
+)
+from snowflake.sqlalchemy.sql.custom_schema.options.keywords import SnowflakeKeyword
+from snowflake.sqlalchemy.sql.custom_schema.options.table_option import TableOptionKey
+from snowflake.sqlalchemy.sql.custom_schema.options.target_lag_option import (
+    TargetLagOption,
+    TimeUnit,
+)
+
+
+def test_create_dynamic_table(engine_testaccount, db_parameters):
+    warehouse = db_parameters.get("warehouse", "default")
+    metadata = MetaData()
+    test_table_1 = Table(
+        "test_table_1", metadata, Column("id", Integer), Column("name", String)
+    )
+
+    metadata.create_all(engine_testaccount)
+
+    with engine_testaccount.connect() as conn:
+        ins = test_table_1.insert().values(id=1, name="test")
+
+        conn.execute(ins)
+        conn.commit()
+
+    dynamic_test_table_1 = DynamicTable(
+        "dynamic_test_table_1",
+        metadata,
+        Column("id", Integer),
+        Column("name", String),
+        target_lag=(1, TimeUnit.HOURS),
+        warehouse=warehouse,
+        as_query="SELECT id, name from test_table_1;",
+        refresh_mode=SnowflakeKeyword.FULL,
+    )
+
+    metadata.create_all(engine_testaccount)
+
+    try:
+        with engine_testaccount.connect() as conn:
+            s = select(dynamic_test_table_1)
+            results_dynamic_table = conn.execute(s).fetchall()
+            s = select(test_table_1)
+            results_table = conn.execute(s).fetchall()
+            assert results_dynamic_table == results_table
+
+    finally:
+        metadata.drop_all(engine_testaccount)
+
+
+def test_create_dynamic_table_without_dynamictable_class(
+    engine_testaccount, db_parameters, snapshot
+):
+    warehouse = db_parameters.get("warehouse", "default")
+    metadata = MetaData()
+    test_table_1 = Table(
+        "test_table_1", metadata, Column("id", Integer), Column("name", String)
+    )
+
+    metadata.create_all(engine_testaccount)
+
+    with engine_testaccount.connect() as conn:
+        ins = test_table_1.insert().values(id=1, name="test")
+
+        conn.execute(ins)
+        conn.commit()
+
+    Table(
+        "dynamic_test_table_1",
+        metadata,
+        Column("id", Integer),
+        Column("name", String),
+        snowflake_warehouse=warehouse,
+        snowflake_as_query="SELECT id, name from test_table_1;",
+        prefixes=["DYNAMIC"],
+    )
+
+    with pytest.raises(exc.UnexpectedOptionTypeError) as exc_info:
+        metadata.create_all(engine_testaccount)
+    assert exc_info.value == snapshot
+
+
+def test_create_dynamic_table_without_dynamictable_and_defined_options(
+    engine_testaccount, db_parameters, snapshot
+):
+    warehouse = db_parameters.get("warehouse", "default")
+    metadata = MetaData()
+    test_table_1 = Table(
+        "test_table_1", metadata, Column("id", Integer), Column("name", String)
+    )
+
+    metadata.create_all(engine_testaccount)
+
+    with engine_testaccount.connect() as conn:
+        ins = test_table_1.insert().values(id=1, name="test")
+
+        conn.execute(ins)
+        conn.commit()
+
+    Table(
+        "dynamic_test_table_1",
+        metadata,
+        Column("id", Integer),
+        Column("name", String),
+        snowflake_target_lag=TargetLagOption.create((1, TimeUnit.HOURS)),
+        snowflake_warehouse=IdentifierOption.create(
+            TableOptionKey.WAREHOUSE, warehouse
+        ),
+        snowflake_as_query=AsQueryOption.create("SELECT id, name from test_table_1;"),
+        prefixes=["DYNAMIC"],
+    )
+
+    with pytest.raises(exc.CustomOptionsAreOnlySupportedOnSnowflakeTables) as exc_info:
+        metadata.create_all(engine_testaccount)
+    assert exc_info.value == snapshot
diff --git a/tests/custom_tables/test_create_hybrid_table.py b/tests/custom_tables/test_create_hybrid_table.py
new file mode 100644
index 00000000..43ae3ab6
--- /dev/null
+++ b/tests/custom_tables/test_create_hybrid_table.py
@@ -0,0 +1,95 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+# +import pytest +import sqlalchemy.exc +from sqlalchemy import Column, Index, Integer, MetaData, String, select +from sqlalchemy.orm import Session, declarative_base + +from snowflake.sqlalchemy import HybridTable + + +@pytest.mark.aws +def test_create_hybrid_table(engine_testaccount, db_parameters, snapshot): + metadata = MetaData() + table_name = "test_create_hybrid_table" + + dynamic_test_table_1 = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = dynamic_test_table_1.insert().values(id=1, name="test") + conn.execute(ins) + conn.commit() + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_hybrid_table = conn.execute(s).fetchall() + assert str(results_hybrid_table) == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_create_hybrid_table_with_multiple_index( + engine_testaccount, db_parameters, snapshot, sql_compiler +): + metadata = MetaData() + table_name = "test_hybrid_table_with_multiple_index" + + hybrid_test_table_1 = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String, index=True), + Column("name2", String), + Column("name3", String), + ) + + metadata.create_all(engine_testaccount) + + index = Index("idx_col34", hybrid_test_table_1.c.name2, hybrid_test_table_1.c.name3) + + with pytest.raises(sqlalchemy.exc.ProgrammingError) as exc_info: + index.create(engine_testaccount) + try: + assert exc_info.value == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_create_hybrid_table_with_orm(sql_compiler, engine_testaccount): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestHybridTableOrm(Base): + __tablename__ = "test_hybrid_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + instance = TestHybridTableOrm(id=0, name="name_example") + session.add(instance) + session.commit() + data = session.query(TestHybridTableOrm).all() + assert str(data) == "[(0, 'name_example')]" + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_generic_options.py b/tests/custom_tables/test_generic_options.py new file mode 100644 index 00000000..916b94c6 --- /dev/null +++ b/tests/custom_tables/test_generic_options.py @@ -0,0 +1,83 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+ +import pytest + +from snowflake.sqlalchemy import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + SnowflakeKeyword, + TableOptionKey, + TargetLagOption, + exc, +) +from snowflake.sqlalchemy.sql.custom_schema.options.invalid_table_option import ( + InvalidTableOption, +) + + +def test_identifier_option(): + identifier = IdentifierOption.create(TableOptionKey.WAREHOUSE, "xsmall") + assert identifier.render_option(None) == "WAREHOUSE = xsmall" + + +def test_literal_option(): + literal = LiteralOption.create(TableOptionKey.WAREHOUSE, "xsmall") + assert literal.render_option(None) == "WAREHOUSE = 'xsmall'" + + +def test_identifier_option_without_name(snapshot): + identifier = IdentifierOption("xsmall") + with pytest.raises(exc.OptionKeyNotProvidedError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +def test_identifier_option_with_wrong_type(snapshot): + identifier = IdentifierOption.create(TableOptionKey.WAREHOUSE, 23) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +def test_literal_option_with_wrong_type(snapshot): + literal = LiteralOption.create( + TableOptionKey.WAREHOUSE, SnowflakeKeyword.DOWNSTREAM + ) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + literal.render_option(None) + assert exc_info.value == snapshot + + +def test_invalid_as_query_option(snapshot): + as_query = AsQueryOption.create(23) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + as_query.render_option(None) + assert exc_info.value == snapshot + + +@pytest.mark.parametrize( + "table_option", + [ + IdentifierOption, + LiteralOption, + KeywordOption, + ], +) +def test_generic_option_with_wrong_type(table_option): + literal = table_option.create(TableOptionKey.WAREHOUSE, 0.32) + assert isinstance(literal, InvalidTableOption), "Expected InvalidTableOption" + + +@pytest.mark.parametrize( + "table_option", + [ + TargetLagOption, + AsQueryOption, + ], +) +def test_non_generic_option_with_wrong_type(table_option): + literal = table_option.create(0.32) + assert isinstance(literal, InvalidTableOption), "Expected InvalidTableOption" diff --git a/tests/custom_tables/test_reflect_dynamic_table.py b/tests/custom_tables/test_reflect_dynamic_table.py new file mode 100644 index 00000000..52eb4457 --- /dev/null +++ b/tests/custom_tables/test_reflect_dynamic_table.py @@ -0,0 +1,88 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable +from snowflake.sqlalchemy.custom_commands import NoneType + + +def test_simple_reflection_dynamic_table_as_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = Table( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_simple_reflection_without_options_loading(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = DynamicTable( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + # TODO: Add support for loading options when table is reflected + assert isinstance(dynamic_test_table.warehouse, NoneType) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_reflect_hybrid_table.py b/tests/custom_tables/test_reflect_hybrid_table.py new file mode 100644 index 00000000..4a777bf0 --- /dev/null +++ b/tests/custom_tables/test_reflect_hybrid_table.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+#
+import pytest
+from sqlalchemy import MetaData, Table
+from sqlalchemy.sql.ddl import CreateTable
+
+
+@pytest.mark.aws
+def test_simple_reflection_hybrid_table_as_table(
+    engine_testaccount, db_parameters, sql_compiler, snapshot
+):
+    metadata = MetaData()
+    table_name = "test_hybrid_table_reflection"
+
+    create_table_sql = f"""
+    CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX index_name (name));
+    """
+
+    with engine_testaccount.connect() as connection:
+        connection.exec_driver_sql(create_table_sql)
+
+    hybrid_test_table = Table(table_name, metadata, autoload_with=engine_testaccount)
+
+    constraint = hybrid_test_table.constraints.pop()
+    constraint.name = "demo_name"
+    hybrid_test_table.constraints.add(constraint)
+
+    try:
+        with engine_testaccount.connect():
+            value = CreateTable(hybrid_test_table)
+
+            actual = sql_compiler(value)
+
+            # Reflection of table prefixes (e.g. "HYBRID", "DYNAMIC") is not
+            # supported, so the snapshot is a plain CREATE TABLE statement.
+            assert actual == snapshot
+
+    finally:
+        metadata.drop_all(engine_testaccount)
+
+
+@pytest.mark.aws
+def test_reflect_hybrid_table_with_index(
+    engine_testaccount, db_parameters, sql_compiler
+):
+    metadata = MetaData()
+    schema = db_parameters["schema"]
+
+    table_name = "test_hybrid_table_2"
+    index_name = "INDEX_NAME_2"
+
+    create_table_sql = f"""
+    CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX {index_name} (name));
+    """
+
+    with engine_testaccount.connect() as connection:
+        connection.exec_driver_sql(create_table_sql)
+
+    table = Table(table_name, metadata, schema=schema, autoload_with=engine_testaccount)
+
+    try:
+        assert len(table.indexes) == 1 and table.indexes.pop().name == index_name
+
+    finally:
+        metadata.drop_all(engine_testaccount)
diff --git a/tests/sqlalchemy_test_suite/__init__.py b/tests/sqlalchemy_test_suite/__init__.py
index 0fc6abfa..ef416f64 100644
--- a/tests/sqlalchemy_test_suite/__init__.py
+++ b/tests/sqlalchemy_test_suite/__init__.py
@@ -1,3 +1,3 @@
 #
-# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #
diff --git a/tests/sqlalchemy_test_suite/conftest.py b/tests/sqlalchemy_test_suite/conftest.py
index 50318d2a..f0464c7d 100644
--- a/tests/sqlalchemy_test_suite/conftest.py
+++ b/tests/sqlalchemy_test_suite/conftest.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #
 
 import sqlalchemy.testing.config
@@ -15,6 +15,7 @@
 import snowflake.connector
 
 from snowflake.sqlalchemy import URL
+from snowflake.sqlalchemy.compat import IS_VERSION_20
 
 from ..conftest import get_db_parameters
 from ..util import random_string
@@ -25,6 +26,12 @@
 TEST_SCHEMA_2 = f"{TEST_SCHEMA}_2"
 
 
+if IS_VERSION_20:
+    collect_ignore_glob = ["test_suite.py"]
+else:
+    collect_ignore_glob = ["test_suite_20.py"]
+
+
 # patch sqlalchemy.testing.config.Config.__init__ for schema name randomization;
 # reusing the same schema name would cause conflicts because we run tests in parallel in the CI
 def config_patched__init__(self, db, db_opts, options, file_config):
diff --git a/tests/sqlalchemy_test_suite/test_suite.py b/tests/sqlalchemy_test_suite/test_suite.py
index fea83492..643d1559 100644
--- a/tests/sqlalchemy_test_suite/test_suite.py
+++ b/tests/sqlalchemy_test_suite/test_suite.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #
 import pytest
 from sqlalchemy import Integer, testing
@@ -69,6 +69,10 @@ def test_empty_insert(self, connection):
         pass
 
     def test_empty_insert_multiple(self, connection):
         pass
 
+    @pytest.mark.skip("Snowflake does not support returning in insert.")
+    def test_no_results_for_non_returning_insert(self, connection, style, executemany):
+        pass
+
 
 # 2. Patched Tests
diff --git a/tests/sqlalchemy_test_suite/test_suite_20.py b/tests/sqlalchemy_test_suite/test_suite_20.py
new file mode 100644
index 00000000..1f79c4e9
--- /dev/null
+++ b/tests/sqlalchemy_test_suite/test_suite_20.py
@@ -0,0 +1,205 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+import pytest
+from sqlalchemy import Integer, testing
+from sqlalchemy.schema import Column, Sequence, Table
+from sqlalchemy.testing import config
+from sqlalchemy.testing.assertions import eq_
+from sqlalchemy.testing.suite import (
+    BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest,
+)
+from sqlalchemy.testing.suite import (
+    CompositeKeyReflectionTest as _CompositeKeyReflectionTest,
+)
+from sqlalchemy.testing.suite import DateTimeHistoricTest as _DateTimeHistoricTest
+from sqlalchemy.testing.suite import FetchLimitOffsetTest as _FetchLimitOffsetTest
+from sqlalchemy.testing.suite import HasSequenceTest as _HasSequenceTest
+from sqlalchemy.testing.suite import InsertBehaviorTest as _InsertBehaviorTest
+from sqlalchemy.testing.suite import LikeFunctionsTest as _LikeFunctionsTest
+from sqlalchemy.testing.suite import LongNameBlowoutTest as _LongNameBlowoutTest
+from sqlalchemy.testing.suite import SimpleUpdateDeleteTest as _SimpleUpdateDeleteTest
+from sqlalchemy.testing.suite import TimeMicrosecondsTest as _TimeMicrosecondsTest
+from sqlalchemy.testing.suite import TrueDivTest as _TrueDivTest
+from sqlalchemy.testing.suite import *  # noqa
+
+# 1. Unsupported by Snowflake
+
+del ComponentReflectionTest  # requires indexes, which Snowflake does not support
+del HasIndexTest  # requires indexes, which Snowflake does not support
+del QuotedNameArgumentTest  # requires indexes, which Snowflake does not support
+
+
+class LongNameBlowoutTest(_LongNameBlowoutTest):
+    # The ("ix",) combination is removed because Snowflake does not support indexes
+    def ix(self, metadata, connection):
+        pytest.skip("ix requires the index feature, which Snowflake does not support")
+
+
+class FetchLimitOffsetTest(_FetchLimitOffsetTest):
+    @pytest.mark.skip(
+        "Snowflake only takes non-negative integer constants for offset/limit"
+    )
+    def test_bound_offset(self, connection):
+        pass
+
+    @pytest.mark.skip(
+        "Snowflake only takes non-negative integer constants for offset/limit"
+    )
+    def test_simple_limit_expr_offset(self, connection):
+        pass
+
+    @pytest.mark.skip(
+        "Snowflake only takes non-negative integer constants for offset/limit"
+    )
+    def test_simple_offset(self, connection):
+        pass
+
+    @pytest.mark.skip(
+        "Snowflake only takes non-negative integer constants for offset/limit"
+    )
+    def test_simple_offset_zero(self, connection):
+        pass
+
+
+class InsertBehaviorTest(_InsertBehaviorTest):
+    @pytest.mark.skip(
+        "Snowflake does not support inserting empty values; the value may be a literal or an expression."
+    )
+    def test_empty_insert(self, connection):
+        pass
+
+    @pytest.mark.skip(
+        "Snowflake does not support inserting empty values; the value may be a literal or an expression."
+ ) + def test_empty_insert_multiple(self, connection): + pass + + @pytest.mark.skip("Snowflake does not support returning in insert.") + def test_no_results_for_non_returning_insert(self, connection, style, executemany): + pass + + +# road to 2.0 +class TrueDivTest(_TrueDivTest): + @pytest.mark.skip("`//` not supported") + def test_floordiv_integer_bound(self, connection): + """Snowflake does not provide `//` arithmetic operator. + + https://docs.snowflake.com/en/sql-reference/operators-arithmetic. + """ + pass + + @pytest.mark.skip("`//` not supported") + def test_floordiv_integer(self, connection, left, right, expected): + """Snowflake does not provide `//` arithmetic operator. + + https://docs.snowflake.com/en/sql-reference/operators-arithmetic. + """ + pass + + +class TimeMicrosecondsTest(_TimeMicrosecondsTest): + def __init__(self): + super().__init__() + + +class DateTimeHistoricTest(_DateTimeHistoricTest): + def __init__(self): + super().__init__() + + +# 2. Patched Tests + + +class HasSequenceTest(_HasSequenceTest): + # Override the define_tables method as snowflake does not support 'nomaxvalue'/'nominvalue' + @classmethod + def define_tables(cls, metadata): + Sequence("user_id_seq", metadata=metadata) + # Replace Sequence("other_seq") creation as in the original test suite, + # the Sequence created with 'nomaxvalue' and 'nominvalue' + # which snowflake does not support: + # Sequence("other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True) + Sequence("other_seq", metadata=metadata) + if testing.requires.schemas.enabled: + Sequence("user_id_seq", schema=config.test_schema, metadata=metadata) + Sequence("schema_seq", schema=config.test_schema, metadata=metadata) + Table( + "user_id_table", + metadata, + Column("id", Integer, primary_key=True), + ) + + +class LikeFunctionsTest(_LikeFunctionsTest): + @testing.requires.regexp_match + @testing.combinations( + ("a.cde.*", {1, 5, 6, 9}), + ("abc.*", {1, 5, 6, 9, 10}), + ("^abc.*", {1, 5, 6, 9, 10}), + (".*9cde.*", {8}), + ("^a.*", set(range(1, 11))), + (".*(b|c).*", set(range(1, 11))), + ("^(b|c).*", set()), + ) + def test_regexp_match(self, text, expected): + super().test_regexp_match(text, expected) + + def test_not_regexp_match(self): + col = self.tables.some_table.c.data + self._test(~col.regexp_match("a.cde.*"), {2, 3, 4, 7, 8, 10}) + + +class SimpleUpdateDeleteTest(_SimpleUpdateDeleteTest): + def test_update(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.update().where(t.c.id == 2), dict(data="d2_new")) + assert not r.is_insert + # snowflake returns a row with numbers of rows updated and number of multi-joined rows updated + assert r.returns_rows + assert r.rowcount == 1 + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (2, "d2_new"), (3, "d3")], + ) + + def test_delete(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.delete().where(t.c.id == 2)) + assert not r.is_insert + # snowflake returns a row with number of rows deleted + assert r.returns_rows + assert r.rowcount == 1 + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (3, "d3")], + ) + + +class CompositeKeyReflectionTest(_CompositeKeyReflectionTest): + @pytest.mark.xfail(reason="Fixing this would require behavior breaking change.") + def test_fk_column_order(self): + # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. 
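+        # (Presumably the reflected key columns come back in a different order
+        # than the one in which they were defined; correcting that would break
+        # consumers relying on the current order, hence the xfail above.)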
+ super().test_fk_column_order() + + @pytest.mark.xfail(reason="Fixing this would require behavior breaking change.") + def test_pk_column_order(self): + # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. + super().test_pk_column_order() + + +class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest): + @testing.combinations( + ("id",), ("(3)",), ("col%p",), ("[brack]",), argnames="columnname" + ) + @testing.variation("use_composite", [True, False]) + @testing.combinations( + ("plain",), + ("(2)",), + ("[brackets]",), + argnames="tablename", + ) + def test_fk_ref(self, connection, metadata, use_composite, tablename, columnname): + super().test_fk_ref(connection, metadata, use_composite, tablename, columnname) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 43820623..40207b41 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -1,11 +1,13 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from sqlalchemy import Integer, String, and_, select +from sqlalchemy import Integer, String, and_, func, select from sqlalchemy.schema import DropColumnComment, DropTableComment from sqlalchemy.sql import column, quoted_name, table -from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing.assertions import AssertsCompiledSQL + +from snowflake.sqlalchemy import snowdialect table1 = table( "table1", column("id", Integer), column("name", String), column("value", Integer) @@ -23,6 +25,14 @@ class TestSnowflakeCompiler(AssertsCompiledSQL): __dialect__ = "snowflake" + def test_now_func(self): + statement = select(func.now()) + self.assert_compile( + statement, + "SELECT CURRENT_TIMESTAMP AS now_1", + dialect="snowflake", + ) + def test_multi_table_delete(self): statement = table1.delete().where(table1.c.id == table2.c.id) self.assert_compile( @@ -99,3 +109,14 @@ def test_quoted_name_label(engine_testaccount): sel_from_tbl = select(col).group_by(col).select_from(table("abc")) compiled_result = sel_from_tbl.compile() assert str(compiled_result) == t["output"] + + +def test_outer_lateral_join(): + col = column("colname").label("label") + col2 = column("colname2").label("label2") + lateral_table = func.flatten(func.PARSE_JSON(col2), outer=True).lateral() + stmt = select(col).select_from(table("abc")).join(lateral_table).group_by(col) + assert ( + str(stmt.compile(dialect=snowdialect.dialect())) + == "SELECT colname AS label \nFROM abc JOIN LATERAL flatten(PARSE_JSON(colname2)) AS anon_1 GROUP BY colname" + ) diff --git a/tests/test_copy.py b/tests/test_copy.py index 0a7ab718..e0752d4f 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # import pytest diff --git a/tests/test_core.py b/tests/test_core.py index 68a2a12e..980db1d2 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# import decimal import json @@ -28,14 +28,17 @@ String, Table, UniqueConstraint, + create_engine, dialects, + insert, inspect, text, ) -from sqlalchemy.exc import DBAPIError -from sqlalchemy.pool import NullPool +from sqlalchemy.exc import DBAPIError, NoSuchTableError, OperationalError from sqlalchemy.sql import and_, not_, or_, select +import snowflake.connector.errors +import snowflake.sqlalchemy.snowdialect from snowflake.connector import Error, ProgrammingError, connect from snowflake.sqlalchemy import URL, MergeInto, dialect from snowflake.sqlalchemy._constants import ( @@ -44,8 +47,7 @@ ) from snowflake.sqlalchemy.snowdialect import SnowflakeDialect -from .conftest import create_engine_with_future_flag as create_engine -from .conftest import get_engine +from .conftest import get_engine, url_factory from .parameters import CONNECTION_PARAMETERS from .util import ischema_names_baseline, random_string @@ -99,35 +101,22 @@ def _create_users_addresses_tables_without_sequence(engine_testaccount, metadata return users, addresses -def verify_engine_connection(engine, verify_app_name): +def verify_engine_connection(engine): with engine.connect() as conn: results = conn.execute(text("select current_version()")).fetchone() - if verify_app_name: - assert conn.connection.driver_connection.application == APPLICATION_NAME - assert ( - conn.connection.driver_connection._internal_application_name - == APPLICATION_NAME - ) - assert ( - conn.connection.driver_connection._internal_application_version - == SNOWFLAKE_SQLALCHEMY_VERSION - ) + assert conn.connection.driver_connection.application == APPLICATION_NAME + assert ( + conn.connection.driver_connection._internal_application_name + == APPLICATION_NAME + ) + assert ( + conn.connection.driver_connection._internal_application_version + == SNOWFLAKE_SQLALCHEMY_VERSION + ) assert results is not None -@pytest.mark.parametrize( - "verify_app_name", - [ - False, - pytest.param( - True, - marks=pytest.mark.xfail( - reason="Pending backend service to recognize SnowflakeSQLAlchemy as a valid client app id" - ), - ), - ], -) -def test_connect_args(verify_app_name): +def test_connect_args(): """ Tests connect string @@ -148,7 +137,7 @@ def test_connect_args(verify_app_name): ) ) try: - verify_engine_connection(engine, verify_app_name) + verify_engine_connection(engine) finally: engine.dispose() @@ -163,7 +152,7 @@ def test_connect_args(verify_app_name): ) ) try: - verify_engine_connection(engine, verify_app_name) + verify_engine_connection(engine) finally: engine.dispose() @@ -179,7 +168,39 @@ def test_connect_args(verify_app_name): ) ) try: - verify_engine_connection(engine, verify_app_name) + verify_engine_connection(engine) + finally: + engine.dispose() + + +def test_boolean_query_argument_parsing(): + engine = create_engine( + URL( + user=CONNECTION_PARAMETERS["user"], + password=CONNECTION_PARAMETERS["password"], + account=CONNECTION_PARAMETERS["account"], + host=CONNECTION_PARAMETERS["host"], + port=CONNECTION_PARAMETERS["port"], + protocol=CONNECTION_PARAMETERS["protocol"], + validate_default_parameters=True, + ) + ) + try: + verify_engine_connection(engine) + connection = engine.raw_connection() + assert connection.validate_default_parameters is True + finally: + connection.close() + engine.dispose() + + +def test_create_dialect(): + """ + Tests getting only dialect object through create_engine + """ + engine = create_engine("snowflake://") + try: + assert engine.dialect finally: engine.dispose() @@ -385,16 +406,6 @@ def 
test_insert_tables(engine_testaccount): str(users.join(addresses)) == "users JOIN addresses ON " "users.id = addresses.user_id" ) - assert ( - str( - users.join( - addresses, - addresses.c.email_address.like(users.c.name + "%"), - ) - ) - == "users JOIN addresses " - "ON addresses.email_address LIKE users.name || :name_1" - ) s = select(users.c.fullname).select_from( users.join( @@ -417,6 +428,15 @@ def test_insert_tables(engine_testaccount): users.drop(engine_testaccount) +def test_table_does_not_exist(engine_testaccount): + """ + Tests Correct Exception Thrown When Table Does Not Exist + """ + meta = MetaData() + with pytest.raises(NoSuchTableError): + Table("does_not_exist", meta, autoload_with=engine_testaccount) + + @pytest.mark.skip( """ Reflection is not implemented yet. @@ -440,9 +460,7 @@ def test_reflextion(engine_testaccount): ) try: meta = MetaData() - user_reflected = Table( - "user", meta, autoload=True, autoload_with=engine_testaccount - ) + user_reflected = Table("user", meta, autoload_with=engine_testaccount) assert user_reflected.c == ["user.id", "user.name", "user.fullname"] finally: conn.execute("DROP TABLE IF EXISTS user") @@ -484,19 +502,20 @@ def test_inspect_column(engine_testaccount): users.drop(engine_testaccount) -def test_get_indexes(engine_testaccount): +def test_get_indexes(engine_testaccount, db_parameters): """ Tests get indexes - NOTE: Snowflake doesn't support indexes + NOTE: Only Snowflake Hybrid Tables support indexes """ + schema = db_parameters["schema"] metadata = MetaData() users, addresses = _create_users_addresses_tables_without_sequence( engine_testaccount, metadata ) try: inspector = inspect(engine_testaccount) - assert inspector.get_indexes("users") == [] + assert inspector.get_indexes("users", schema) == [] finally: addresses.drop(engine_testaccount) @@ -906,37 +925,6 @@ class Appointment(Base): assert str(t.columns["real_data"].type) == "FLOAT" -def _get_engine_with_columm_metadata_cache( - db_parameters, user=None, password=None, account=None -): - """ - Creates a connection with column metadata cache - """ - if user is not None: - db_parameters["user"] = user - if password is not None: - db_parameters["password"] = password - if account is not None: - db_parameters["account"] = account - - engine = create_engine( - URL( - user=db_parameters["user"], - password=db_parameters["password"], - host=db_parameters["host"], - port=db_parameters["port"], - database=db_parameters["database"], - schema=db_parameters["schema"], - account=db_parameters["account"], - protocol=db_parameters["protocol"], - cache_column_metadata=True, - ), - poolclass=NullPool, - ) - - return engine - - def test_many_table_column_metadta(db_parameters): """ Get dozens of table metadata with column metadata cache. @@ -944,7 +932,9 @@ def test_many_table_column_metadta(db_parameters): cache_column_metadata=True will cache all column metadata for all tables in the schema. 
""" - engine = _get_engine_with_columm_metadata_cache(db_parameters) + url = url_factory(cache_column_metadata=True) + engine = get_engine(url) + RE_SUFFIX_NUM = re.compile(r".*(\d+)$") metadata = MetaData() total_objects = 10 @@ -1070,28 +1060,16 @@ def harass_inspector(): assert outcome -@pytest.mark.timeout(15) -def test_region(): - engine = create_engine( - URL( - user="testuser", - password="testpassword", - account="testaccount", - region="eu-central-1", - login_timeout=5, - ) - ) - try: - engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert "Failed to connect to DB" in ex.orig.msg - assert "testaccount.eu-central-1.snowflakecomputing.com" in ex.orig.msg - - -@pytest.mark.timeout(15) -def test_azure(): +@pytest.mark.skip(reason="Testaccount is not available, it returns 404 error.") +@pytest.mark.timeout(10) +@pytest.mark.parametrize( + "region", + ( + pytest.param("eu-central-1", id="region"), + pytest.param("east-us-2.azure", id="azure"), + ), +) +def test_connection_timeout_error(region): engine = create_engine( URL( user="testuser", @@ -1101,13 +1079,13 @@ def test_azure(): login_timeout=5, ) ) - try: + + with pytest.raises(OperationalError) as excinfo: engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert "Failed to connect to DB" in ex.orig.msg - assert "testaccount.east-us-2.azure.snowflakecomputing.com" in ex.orig.msg + + assert excinfo.value.orig.errno == 250001 + assert "Could not connect to Snowflake backend" in excinfo.value.orig.msg + assert region not in excinfo.value.orig.msg def test_load_dialect(): @@ -1313,7 +1291,8 @@ def test_comment_sqlalchemy(db_parameters, engine_testaccount, on_public_ci): column_comment1 = random_string(10, choices=string.ascii_uppercase) table_comment2 = random_string(10, choices=string.ascii_uppercase) column_comment2 = random_string(10, choices=string.ascii_uppercase) - engine2, _ = get_engine(schema=new_schema) + + engine2 = get_engine(url_factory(schema=new_schema)) con2 = None if not on_public_ci: con2 = engine2.connect() @@ -1394,47 +1373,51 @@ def test_special_schema_character(db_parameters, on_public_ci): def test_autoincrement(engine_testaccount): + """Snowflake does not guarantee generating sequence numbers without gaps. + + The generated numbers are not necessarily contiguous. 
+ https://docs.snowflake.com/en/user-guide/querying-sequences + """ metadata = MetaData() users = Table( "users", metadata, - Column("uid", Integer, Sequence("id_seq"), primary_key=True), + Column("uid", Integer, Sequence("id_seq", order=True), primary_key=True), Column("name", String(39)), ) try: - users.create(engine_testaccount) - - with engine_testaccount.connect() as connection: - with connection.begin(): - connection.execute(users.insert(), [{"name": "sf1"}]) - assert connection.execute(select(users)).fetchall() == [(1, "sf1")] - connection.execute(users.insert(), [{"name": "sf2"}, {"name": "sf3"}]) - assert connection.execute(select(users)).fetchall() == [ - (1, "sf1"), - (2, "sf2"), - (3, "sf3"), - ] - connection.execute(users.insert(), {"name": "sf4"}) - assert connection.execute(select(users)).fetchall() == [ - (1, "sf1"), - (2, "sf2"), - (3, "sf3"), - (4, "sf4"), - ] - - seq = Sequence("id_seq") - nextid = connection.execute(seq) - connection.execute(users.insert(), [{"uid": nextid, "name": "sf5"}]) - assert connection.execute(select(users)).fetchall() == [ - (1, "sf1"), - (2, "sf2"), - (3, "sf3"), - (4, "sf4"), - (5, "sf5"), - ] + metadata.create_all(engine_testaccount) + + with engine_testaccount.begin() as connection: + connection.execute(insert(users), [{"name": "sf1"}]) + assert connection.execute(select(users)).fetchall() == [(1, "sf1")] + connection.execute(insert(users), [{"name": "sf2"}, {"name": "sf3"}]) + assert connection.execute(select(users)).fetchall() == [ + (1, "sf1"), + (2, "sf2"), + (3, "sf3"), + ] + connection.execute(insert(users), {"name": "sf4"}) + assert connection.execute(select(users)).fetchall() == [ + (1, "sf1"), + (2, "sf2"), + (3, "sf3"), + (4, "sf4"), + ] + + seq = Sequence("id_seq") + nextid = connection.execute(seq) + connection.execute(insert(users), [{"uid": nextid, "name": "sf5"}]) + assert connection.execute(select(users)).fetchall() == [ + (1, "sf1"), + (2, "sf2"), + (3, "sf3"), + (4, "sf4"), + (5, "sf5"), + ] finally: - users.drop(engine_testaccount) + metadata.drop_all(engine_testaccount) @pytest.mark.skip( @@ -1529,11 +1512,16 @@ def test_too_many_columns_detection(engine_testaccount, db_parameters): metadata.create_all(engine_testaccount) inspector = inspect(engine_testaccount) # Do test - original_execute = inspector.bind.execute + connection = inspector.bind.connect() + original_execute = connection.execute + + too_many_columns_was_raised = False def mock_helper(command, *args, **kwargs): - if "_get_schema_columns" in command: + if "_get_schema_columns" in command.text: # Creating exception exactly how SQLAlchemy does + nonlocal too_many_columns_was_raised + too_many_columns_was_raised = True raise DBAPIError.instance( """ SELECT /* sqlalchemy:_get_schema_columns */ @@ -1565,9 +1553,12 @@ def mock_helper(command, *args, **kwargs): else: return original_execute(command, *args, **kwargs) - with patch.object(inspector.bind, "execute", side_effect=mock_helper): - column_metadata = inspector.get_columns("users", db_parameters["schema"]) + with patch.object(engine_testaccount, "connect") as conn: + conn.return_value = connection + with patch.object(connection, "execute", side_effect=mock_helper): + column_metadata = inspector.get_columns("users", db_parameters["schema"]) assert len(column_metadata) == 4 + assert too_many_columns_was_raised # Clean up metadata.drop_all(engine_testaccount) @@ -1603,14 +1594,13 @@ def test_column_type_schema(engine_testaccount): C1 BIGINT, C2 BINARY, C3 BOOLEAN, C4 CHAR, C5 CHARACTER, C6 DATE, C7 DATETIME, C8 
DEC, C9 DECIMAL, C10 DOUBLE, C11 FLOAT, C12 INT, C13 INTEGER, C14 NUMBER, C15 REAL, C16 BYTEINT, C17 SMALLINT, C18 STRING, C19 TEXT, C20 TIME, C21 TIMESTAMP, C22 TIMESTAMP_TZ, C23 TIMESTAMP_LTZ, - C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR, C28 VARIANT, C29 OBJECT, C30 ARRAY, C31 GEOGRAPHY + C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR, C28 VARIANT, C29 OBJECT, C30 ARRAY, C31 GEOGRAPHY, + C32 GEOMETRY ) """ ) - table_reflected = Table( - table_name, MetaData(), autoload=True, autoload_with=conn - ) + table_reflected = Table(table_name, MetaData(), autoload_with=conn) columns = table_reflected.columns assert ( len(columns) == len(ischema_names_baseline) - 1 @@ -1626,13 +1616,12 @@ def test_result_type_and_value(engine_testaccount): C1 BIGINT, C2 BINARY, C3 BOOLEAN, C4 CHAR, C5 CHARACTER, C6 DATE, C7 DATETIME, C8 DEC(12,3), C9 DECIMAL(12,3), C10 DOUBLE, C11 FLOAT, C12 INT, C13 INTEGER, C14 NUMBER, C15 REAL, C16 BYTEINT, C17 SMALLINT, C18 STRING, C19 TEXT, C20 TIME, C21 TIMESTAMP, C22 TIMESTAMP_TZ, C23 TIMESTAMP_LTZ, - C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR, C28 VARIANT, C29 OBJECT, C30 ARRAY, C31 GEOGRAPHY + C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR, C28 VARIANT, C29 OBJECT, C30 ARRAY, C31 GEOGRAPHY, + C32 GEOMETRY ) """ ) - table_reflected = Table( - table_name, MetaData(), autoload=True, autoload_with=conn - ) + table_reflected = Table(table_name, MetaData(), autoload_with=conn) current_date = date.today() current_utctime = datetime.utcnow() current_localtime = pytz.utc.localize(current_utctime, is_dst=False).astimezone( @@ -1652,6 +1641,8 @@ def test_result_type_and_value(engine_testaccount): CHAR_VALUE = "A" GEOGRAPHY_VALUE = "POINT(-122.35 37.55)" GEOGRAPHY_RESULT_VALUE = '{"coordinates": [-122.35,37.55],"type": "Point"}' + GEOMETRY_VALUE = "POINT(-94.58473 39.08985)" + GEOMETRY_RESULT_VALUE = '{"coordinates": [-94.58473,39.08985],"type": "Point"}' ins = table_reflected.insert().values( c1=MAX_INT_VALUE, # BIGINT @@ -1685,6 +1676,7 @@ def test_result_type_and_value(engine_testaccount): c29=None, # OBJECT, currently snowflake-sqlalchemy/connector does not support binding variant c30=None, # ARRAY, currently snowflake-sqlalchemy/connector does not support binding variant c31=GEOGRAPHY_VALUE, # GEOGRAPHY + c32=GEOMETRY_VALUE, # GEOMETRY ) conn.execute(ins) @@ -1723,6 +1715,7 @@ def test_result_type_and_value(engine_testaccount): and result[28] is None and result[29] is None and json.loads(result[30]) == json.loads(GEOGRAPHY_RESULT_VALUE) + and json.loads(result[31]) == json.loads(GEOMETRY_RESULT_VALUE) ) sql = f""" @@ -1785,3 +1778,90 @@ def test_normalize_and_denormalize_empty_string_column_name(engine_testaccount): and columns[0]["name"] == "col" and columns[1]["name"] == "" ) + + +def test_snowflake_sqlalchemy_as_valid_client_type(): + engine = create_engine( + URL( + user=CONNECTION_PARAMETERS["user"], + password=CONNECTION_PARAMETERS["password"], + account=CONNECTION_PARAMETERS["account"], + host=CONNECTION_PARAMETERS["host"], + port=CONNECTION_PARAMETERS["port"], + protocol=CONNECTION_PARAMETERS["protocol"], + ), + connect_args={"internal_application_name": "UnknownClient"}, + ) + with engine.connect() as conn: + with pytest.raises(snowflake.connector.errors.NotSupportedError): + conn.exec_driver_sql("select 1").cursor.fetch_pandas_all() + + engine = create_engine( + URL( + user=CONNECTION_PARAMETERS["user"], + password=CONNECTION_PARAMETERS["password"], + account=CONNECTION_PARAMETERS["account"], + 
host=CONNECTION_PARAMETERS["host"],
+            port=CONNECTION_PARAMETERS["port"],
+            protocol=CONNECTION_PARAMETERS["protocol"],
+        )
+    )
+    with engine.connect() as conn:
+        conn.exec_driver_sql("select 1").cursor.fetch_pandas_all()
+
+    try:
+        snowflake.sqlalchemy.snowdialect._ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = False
+        origin_app = snowflake.connector.connection.DEFAULT_CONFIGURATION["application"]
+        origin_internal_app_name = snowflake.connector.connection.DEFAULT_CONFIGURATION[
+            "internal_application_name"
+        ]
+        origin_internal_app_version = (
+            snowflake.connector.connection.DEFAULT_CONFIGURATION[
+                "internal_application_version"
+            ]
+        )
+        snowflake.connector.connection.DEFAULT_CONFIGURATION["application"] = (
+            None,
+            (type(None), str),
+        )
+        snowflake.connector.connection.DEFAULT_CONFIGURATION[
+            "internal_application_name"
+        ] = (
+            "PythonConnector",
+            (type(None), str),
+        )
+        snowflake.connector.connection.DEFAULT_CONFIGURATION[
+            "internal_application_version"
+        ] = (
+            "3.0.0",
+            (type(None), str),
+        )
+        engine = create_engine(
+            URL(
+                user=CONNECTION_PARAMETERS["user"],
+                password=CONNECTION_PARAMETERS["password"],
+                account=CONNECTION_PARAMETERS["account"],
+                host=CONNECTION_PARAMETERS["host"],
+                port=CONNECTION_PARAMETERS["port"],
+                protocol=CONNECTION_PARAMETERS["protocol"],
+            )
+        )
+        with engine.connect() as conn:
+            conn.exec_driver_sql("select 1").cursor.fetch_pandas_all()
+            assert (
+                conn.connection.driver_connection._internal_application_name
+                == "PythonConnector"
+            )
+            assert (
+                conn.connection.driver_connection._internal_application_version
+                == "3.0.0"
+            )
+    finally:
+        snowflake.sqlalchemy.snowdialect._ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True
+        snowflake.connector.connection.DEFAULT_CONFIGURATION["application"] = origin_app
+        snowflake.connector.connection.DEFAULT_CONFIGURATION[
+            "internal_application_name"
+        ] = origin_internal_app_name
+        snowflake.connector.connection.DEFAULT_CONFIGURATION[
+            "internal_application_version"
+        ] = origin_internal_app_version
diff --git a/tests/test_create.py b/tests/test_create.py
index 2438b077..0b8b48fa 100644
--- a/tests/test_create.py
+++ b/tests/test_create.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #

 from snowflake.sqlalchemy import (
@@ -54,6 +54,16 @@ def test_create_stage(sql_compiler):
     )
     assert actual == expected

+    create_stage = CreateStage(stage=stage, container=container, temporary=True)
+    # validate that the resulting SQL is as expected
+    actual = sql_compiler(create_stage)
+    expected = (
+        "CREATE TEMPORARY STAGE MY_DB.MY_SCHEMA.AZURE_STAGE "
+        "URL='azure://myaccount.blob.core.windows.net/my-container' "
+        "CREDENTIALS=(AZURE_SAS_TOKEN='saas_token')"
+    )
+    assert actual == expected
+

 def test_create_csv_format(sql_compiler):
     """
diff --git a/tests/test_custom_functions.py b/tests/test_custom_functions.py
new file mode 100644
index 00000000..2a1e1cb5
--- /dev/null
+++ b/tests/test_custom_functions.py
@@ -0,0 +1,25 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+
+import pytest
+from sqlalchemy import func
+
+from snowflake.sqlalchemy import snowdialect
+
+
+def test_flatten_does_not_render_params():
+    """This behavior is for backward compatibility.
+
+    In previous versions params were not rendered.
+    In the future this behavior will change.
+ """ + flat = func.flatten("[1, 2]", outer=True) + res = flat.compile(dialect=snowdialect.dialect()) + + assert str(res) == "flatten(%(flatten_1)s)" + + +def test_flatten_emits_warning(): + expected_warning = "For backward compatibility params are not rendered." + with pytest.warns(DeprecationWarning, match=expected_warning): + func.flatten().compile(dialect=snowdialect.dialect()) diff --git a/tests/test_custom_types.py b/tests/test_custom_types.py index b92b5e46..3961a5d3 100644 --- a/tests/test_custom_types.py +++ b/tests/test_custom_types.py @@ -1,8 +1,11 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from snowflake.sqlalchemy import custom_types +import pytest +from sqlalchemy import Column, Integer, MetaData, Table, text + +from snowflake.sqlalchemy import TEXT, custom_types def test_string_conversions(): @@ -15,6 +18,7 @@ def test_string_conversions(): "TIMESTAMP_LTZ", "TIMESTAMP_NTZ", "GEOGRAPHY", + "GEOMETRY", ] sf_types = [ "TEXT", @@ -33,3 +37,31 @@ def test_string_conversions(): sample = getattr(custom_types, type_)() if type_ in sf_custom_types: assert type_ == str(sample) + + +@pytest.mark.feature_max_lob_size +def test_create_table_with_text_type(engine_testaccount): + metadata = MetaData() + table_name = "test_max_lob_size_0" + test_max_lob_size = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("full_name", TEXT(), server_default=text("id::varchar")), + ) + + metadata.create_all(engine_testaccount) + try: + assert test_max_lob_size is not None + + with engine_testaccount.connect() as conn: + with conn.begin(): + query = text(f"SELECT GET_DDL('TABLE', '{table_name}')") + result = conn.execute(query) + row = str(result.mappings().fetchone()) + assert ( + "VARCHAR(134217728)" in row + ), f"Expected VARCHAR(134217728) in {row}" + + finally: + test_max_lob_size.drop(engine_testaccount) diff --git a/tests/test_geography.py b/tests/test_geography.py index fd738cf0..7168d2a3 100644 --- a/tests/test_geography.py +++ b/tests/test_geography.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # from json import loads diff --git a/tests/test_geometry.py b/tests/test_geometry.py new file mode 100644 index 00000000..742b518e --- /dev/null +++ b/tests/test_geometry.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
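
The new GEOMETRY type added in this change mirrors the existing GEOGRAPHY support: values are bound as WKT on insert and come back as GeoJSON strings on select, as the tests in this new file exercise. A minimal usage sketch (it assumes an `engine` already connected to a Snowflake test account; the table name geometry_demo is illustrative):

from sqlalchemy import Column, Integer, MetaData, Table, select

from snowflake.sqlalchemy import GEOMETRY

metadata = MetaData()
demo = Table(
    "geometry_demo",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("geom", GEOMETRY),
)
metadata.create_all(engine)
try:
    with engine.begin() as conn:
        # Bound as WKT; Snowflake returns the value as a GeoJSON string.
        conn.execute(demo.insert().values(id=1, geom="POINT(-94.58473 39.08985)"))
        print(conn.execute(select(demo)).fetchone())
finally:
    metadata.drop_all(engine)
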
+# +from json import loads + +from sqlalchemy import Column, Integer, MetaData, Table +from sqlalchemy.sql import select + +from snowflake.sqlalchemy import GEOMETRY + + +def test_create_table_geometry_datatypes(engine_testaccount): + """ + Create table including geometry data types + """ + metadata = MetaData() + table_name = "test_geometry0" + test_geometry = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + ) + metadata.create_all(engine_testaccount) + try: + assert test_geometry is not None + finally: + test_geometry.drop(engine_testaccount) + + +def test_inspect_geometry_datatypes(engine_testaccount): + """ + Create table including geometry data types + """ + metadata = MetaData() + table_name = "test_geometry0" + test_geometry = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom1", GEOMETRY), + Column("geom2", GEOMETRY), + ) + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + test_point = "POINT(-94.58473 39.08985)" + test_point1 = '{"coordinates": [-94.58473, 39.08985],"type": "Point"}' + + ins = test_geometry.insert().values( + id=1, geom1=test_point, geom2=test_point1 + ) + + with conn.begin(): + results = conn.execute(ins) + results.close() + + s = select(test_geometry) + results = conn.execute(s) + rows = results.fetchone() + results.close() + assert rows[0] == 1 + assert rows[1] == rows[2] + assert loads(rows[2]) == loads(test_point1) + finally: + test_geometry.drop(engine_testaccount) diff --git a/tests/test_index_reflection.py b/tests/test_index_reflection.py new file mode 100644 index 00000000..09f5cfe7 --- /dev/null +++ b/tests/test_index_reflection.py @@ -0,0 +1,34 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import MetaData +from sqlalchemy.engine import reflection + + +@pytest.mark.aws +def test_indexes_reflection(engine_testaccount, db_parameters, sql_compiler): + metadata = MetaData() + + table_name = "test_hybrid_table_2" + index_name = "INDEX_NAME_2" + schema = db_parameters["schema"] + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX {index_name} (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + insp = reflection.Inspector.from_engine(engine_testaccount) + + try: + with engine_testaccount.connect(): + # Prefixes reflection not supported, example: "HYBRID, DYNAMIC" + indexes = insp.get_indexes(table_name, schema) + assert len(indexes) == 1 + assert indexes[0].get("name") == index_name + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_multivalues_insert.py b/tests/test_multivalues_insert.py index 571b0844..5a91d3da 100644 --- a/tests/test_multivalues_insert.py +++ b/tests/test_multivalues_insert.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # from sqlalchemy import Integer, Sequence, String diff --git a/tests/test_orm.py b/tests/test_orm.py index 7bf697b7..cb3a7768 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -1,15 +1,30 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
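
The new tests/test_index_reflection.py above covers the one case where get_indexes() now returns rows: Snowflake Hybrid Tables, which are the only table kind that supports secondary indexes. A sketch of the reflection call it exercises (it assumes an `engine` and a hybrid table like the one created in that test; the schema name is illustrative, and prefix reflection such as HYBRID or DYNAMIC is still unsupported):

from sqlalchemy import inspect

insp = inspect(engine)
# Returns [] for regular tables; hybrid tables report their secondary indexes.
for index in insp.get_indexes("test_hybrid_table_2", schema="my_schema"):
    print(index.get("name"))
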
# import enum +import logging import pytest -from sqlalchemy import Column, Enum, ForeignKey, Integer, Sequence, String, text +from sqlalchemy import ( + TEXT, + Column, + Enum, + ForeignKey, + Integer, + Sequence, + String, + exc, + func, + select, + text, +) from sqlalchemy.orm import Session, declarative_base, relationship +from snowflake.sqlalchemy import HybridTable + -def test_basic_orm(engine_testaccount, run_v20_sqlalchemy): +def test_basic_orm(engine_testaccount): """ Tests declarative """ @@ -35,7 +50,6 @@ def __repr__(self): ed_user = User(name="ed", fullname="Edward Jones") session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(ed_user) our_user = session.query(User).filter_by(name="ed").first() @@ -45,14 +59,15 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_orm_one_to_many_relationship(engine_testaccount, run_v20_sqlalchemy): +def test_orm_one_to_many_relationship(engine_testaccount, db_parameters): """ Tests One to Many relationship """ Base = declarative_base() + prefix = "tbl_" class User(Base): - __tablename__ = "user" + __tablename__ = prefix + "user" id = Column(Integer, Sequence("user_id_seq"), primary_key=True) name = Column(String) @@ -62,13 +77,13 @@ def __repr__(self): return f"" class Address(Base): - __tablename__ = "address" + __tablename__ = prefix + "address" id = Column(Integer, Sequence("address_id_seq"), primary_key=True) email_address = Column(String, nullable=False) - user_id = Column(Integer, ForeignKey("user.id")) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) - user = relationship("User", backref="addresses") + user = relationship(User, backref="addresses") def __repr__(self): return f"" @@ -86,7 +101,6 @@ def __repr__(self): ] session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(jack) # cascade each Address into the Session as well session.commit() @@ -113,14 +127,143 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_delete_cascade(engine_testaccount, run_v20_sqlalchemy): +@pytest.mark.aws +def test_orm_one_to_many_relationship_with_hybrid_table(engine_testaccount, snapshot): + """ + Tests One to Many relationship + """ + Base = declarative_base() + + class User(Base): + __tablename__ = "hb_tbl_user" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + name = Column(String) + fullname = Column(String) + + def __repr__(self): + return f"" + + class Address(Base): + __tablename__ = "hb_tbl_address" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("address_id_seq"), primary_key=True) + email_address = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) + + user = relationship(User, backref="addresses") + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + jack = User(name="jack", fullname="Jack Bean") + assert jack.addresses == [], "one to many record is empty list" + + jack.addresses = [ + Address(email_address="jack@gmail.com"), + Address(email_address="j25@yahoo.com"), + Address(email_address="jack@hotmail.com"), + ] + + session = Session(bind=engine_testaccount) + session.add(jack) # cascade each Address into the Session as well + session.commit() + + session.delete(jack) + 
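+        # Hybrid tables enforce the FK: with addresses still referencing jack,
+        # the flushed DELETE fails and surfaces as ProgrammingError on the next query.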
+ with pytest.raises(exc.ProgrammingError) as exc_info: + session.query(Address).all() + + assert exc_info.value == snapshot, "Iceberg Table enforce FK constraint" + + finally: + Base.metadata.drop_all(engine_testaccount) + + +def test_delete_cascade(engine_testaccount): """ Test delete cascade """ Base = declarative_base() + prefix = "tbl_" class User(Base): - __tablename__ = "user" + __tablename__ = prefix + "user" + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + name = Column(String) + fullname = Column(String) + + addresses = relationship( + "Address", back_populates="user", cascade="all, delete, delete-orphan" + ) + + def __repr__(self): + return f"" + + class Address(Base): + __tablename__ = prefix + "address" + + id = Column(Integer, Sequence("address_id_seq"), primary_key=True) + email_address = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) + + user = relationship(User, back_populates="addresses") + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + jack = User(name="jack", fullname="Jack Bean") + assert jack.addresses == [], "one to many record is empty list" + + jack.addresses = [ + Address(email_address="jack@gmail.com"), + Address(email_address="j25@yahoo.com"), + Address(email_address="jack@hotmail.com"), + ] + + session = Session(bind=engine_testaccount) + session.add(jack) # cascade each Address into the Session as well + session.commit() + + got_jack = session.query(User).first() + assert got_jack == jack + + session.delete(jack) + got_addresses = session.query(Address).all() + assert len(got_addresses) == 0, "no address record" + finally: + Base.metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_delete_cascade_hybrid_table(engine_testaccount): + """ + Test delete cascade + """ + Base = declarative_base() + prefix = "hb_tbl_" + + class User(Base): + __tablename__ = prefix + "user" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) id = Column(Integer, Sequence("user_id_seq"), primary_key=True) name = Column(String) @@ -134,13 +277,17 @@ def __repr__(self): return f"" class Address(Base): - __tablename__ = "address" + __tablename__ = prefix + "address" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) id = Column(Integer, Sequence("address_id_seq"), primary_key=True) email_address = Column(String, nullable=False) - user_id = Column(Integer, ForeignKey("user.id")) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) - user = relationship("User", back_populates="addresses") + user = relationship(User, back_populates="addresses") def __repr__(self): return f"" @@ -158,7 +305,6 @@ def __repr__(self): ] session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(jack) # cascade each Address into the Session as well session.commit() @@ -178,7 +324,7 @@ def __repr__(self): WIP """, ) -def test_orm_query(engine_testaccount, run_v20_sqlalchemy): +def test_orm_query(engine_testaccount): """ Tests ORM query """ @@ -199,7 +345,6 @@ def __repr__(self): # TODO: insert rows session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy # TODO: query.all() for name, fullname in session.query(User.name, User.fullname): @@ -209,7 +354,7 @@ def __repr__(self): # MultipleResultsFound if not one result -def test_schema_including_db(engine_testaccount, 
db_parameters, run_v20_sqlalchemy): +def test_schema_including_db(engine_testaccount, db_parameters): """ Test schema parameter including database separated by a dot. """ @@ -232,7 +377,6 @@ class User(Base): ed_user = User(name="ed", fullname="Edward Jones") session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(ed_user) ret_user = session.query(User.id, User.name).first() @@ -244,7 +388,7 @@ class User(Base): Base.metadata.drop_all(engine_testaccount) -def test_schema_including_dot(engine_testaccount, db_parameters, run_v20_sqlalchemy): +def test_schema_including_dot(engine_testaccount, db_parameters): """ Tests pseudo schema name including dot. """ @@ -265,7 +409,6 @@ class User(Base): fullname = Column(String) session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy query = session.query(User.id) assert str(query).startswith( 'SELECT {db}."{schema}.{schema}".{db}.users.id'.format( @@ -274,9 +417,7 @@ class User(Base): ) -def test_schema_translate_map( - engine_testaccount, db_parameters, sql_compiler, run_v20_sqlalchemy -): +def test_schema_translate_map(engine_testaccount, db_parameters): """ Test schema translate map execution option works replaces schema correctly """ @@ -299,7 +440,6 @@ class User(Base): schema_translate_map={schema_map: db_parameters["schema"]} ) as con: session = Session(bind=con) - session.future = run_v20_sqlalchemy with con.begin(): Base.metadata.create_all(con) try: @@ -326,3 +466,119 @@ class User(Base): assert user.fullname == "test_user" finally: Base.metadata.drop_all(con) + + +def test_outer_lateral_join(engine_testaccount, caplog): + Base = declarative_base() + + class Employee(Base): + __tablename__ = "employees" + + employee_id = Column(Integer, primary_key=True) + last_name = Column(String) + + class Department(Base): + __tablename__ = "departments" + + department_id = Column(Integer, primary_key=True) + name = Column(String) + + Base.metadata.create_all(engine_testaccount) + session = Session(bind=engine_testaccount) + e1 = Employee(employee_id=101, last_name="Richards") + d1 = Department(department_id=1, name="Engineering") + session.add_all([e1, d1]) + session.commit() + + sub = select(Department).lateral() + query = ( + select(Employee.employee_id, Department.department_id) + .select_from(Employee) + .outerjoin(sub) + ) + compiled_stmts = ( + # v1.x + "SELECT employees.employee_id, departments.department_id " + "FROM departments, employees LEFT OUTER JOIN LATERAL " + "(SELECT departments.department_id AS department_id, departments.name AS name " + "FROM departments) AS anon_1", + # v2.x + "SELECT employees.employee_id, departments.department_id " + "FROM employees LEFT OUTER JOIN LATERAL " + "(SELECT departments.department_id AS department_id, departments.name AS name " + "FROM departments) AS anon_1, departments", + ) + compiled_stmt = str(query.compile(engine_testaccount)).replace("\n", "") + assert compiled_stmt in compiled_stmts + + with caplog.at_level(logging.DEBUG): + assert [res for res in session.execute(query)] + assert ( + "SELECT employees.employee_id, departments.department_id FROM departments" + in caplog.text + ) or ( + "SELECT employees.employee_id, departments.department_id FROM employees" + in caplog.text + ) + + +def test_lateral_join_without_condition(engine_testaccount, caplog): + Base = declarative_base() + + class Employee(Base): + __tablename__ = "Employee" + + pkey = Column(String, primary_key=True) + uid = Column(Integer) + content = Column(String) + + 
Base.metadata.create_all(engine_testaccount) + lateral_table = func.flatten( + func.PARSE_JSON(Employee.content), outer=False + ).lateral() + query = ( + select( + Employee.uid, + ) + .select_from(Employee) + .join(lateral_table) + .where(Employee.uid == "123") + ) + session = Session(bind=engine_testaccount) + with caplog.at_level(logging.DEBUG): + session.execute(query) + assert ( + '[SELECT "Employee".uid FROM "Employee" JOIN LATERAL flatten(PARSE_JSON("Employee"' + in caplog.text + ) + + +@pytest.mark.feature_max_lob_size +def test_basic_table_with_large_lob_size_in_memory(engine_testaccount, sql_compiler): + Base = declarative_base() + + class User(Base): + __tablename__ = "user" + + id = Column(Integer, primary_key=True) + full_name = Column(TEXT(), server_default=text("id::varchar")) + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + assert User.__table__ is not None + + with engine_testaccount.connect() as conn: + with conn.begin(): + query = text(f"SELECT GET_DDL('TABLE', '{User.__tablename__}')") + result = conn.execute(query) + row = str(result.mappings().fetchone()) + assert ( + "VARCHAR(134217728)" in row + ), f"Expected VARCHAR(134217728) in {row}" + + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/test_pandas.py b/tests/test_pandas.py index f2b3f777..2a6b9f1b 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # import operator @@ -24,13 +24,10 @@ select, text, ) -from sqlalchemy.pool import NullPool from snowflake.connector import ProgrammingError from snowflake.connector.pandas_tools import make_pd_writer, pd_writer -from snowflake.sqlalchemy import URL - -from .conftest import create_engine_with_future_flag as create_engine +from snowflake.sqlalchemy.compat import IS_VERSION_20 def _create_users_addresses_tables(engine_testaccount, metadata): @@ -113,40 +110,8 @@ def test_a_simple_read_sql(engine_testaccount): users.drop(engine_testaccount) -def get_engine_with_numpy(db_parameters, user=None, password=None, account=None): - """ - Creates a connection using the parameters defined in JDBC connect string - """ - from snowflake.sqlalchemy import URL - - if user is not None: - db_parameters["user"] = user - if password is not None: - db_parameters["password"] = password - if account is not None: - db_parameters["account"] = account - - engine = create_engine( - URL( - user=db_parameters["user"], - password=db_parameters["password"], - host=db_parameters["host"], - port=db_parameters["port"], - database=db_parameters["database"], - schema=db_parameters["schema"], - account=db_parameters["account"], - protocol=db_parameters["protocol"], - numpy=True, - ), - poolclass=NullPool, - ) - - return engine - - -def test_numpy_datatypes(db_parameters): - engine = get_engine_with_numpy(db_parameters) - with engine.connect() as conn: +def test_numpy_datatypes(engine_testaccount_with_numpy, db_parameters): + with engine_testaccount_with_numpy.connect() as conn: try: specific_date = np.datetime64("2016-03-04T12:03:05.123456789") with conn.begin(): @@ -163,12 +128,11 @@ def test_numpy_datatypes(db_parameters): assert df.c1.values[0] == specific_date finally: conn.exec_driver_sql(f"DROP TABLE IF EXISTS {db_parameters['name']}") - engine.dispose() + engine_testaccount_with_numpy.dispose() -def test_to_sql(db_parameters): - engine = 
get_engine_with_numpy(db_parameters)
-    with engine.connect() as conn:
+def test_to_sql(engine_testaccount_with_numpy, db_parameters):
+    with engine_testaccount_with_numpy.connect() as conn:
         total_rows = 10000
         conn.exec_driver_sql(
             textwrap.dedent(
@@ -182,7 +146,13 @@ def test_to_sql(db_parameters):
         conn.exec_driver_sql("create or replace table dst(c1 float)")

         tbl = pd.read_sql_query(text("select * from src"), conn)

-        tbl.to_sql("dst", engine, if_exists="append", chunksize=1000, index=False)
+        tbl.to_sql(
+            "dst",
+            engine_testaccount_with_numpy,
+            if_exists="append",
+            chunksize=1000,
+            index=False,
+        )

         df = pd.read_sql_query(text("select count(*) as cnt from dst"), conn)
         assert df.cnt.values[0] == total_rows
@@ -199,41 +169,15 @@ def test_no_indexes(engine_testaccount, db_parameters):
                 con=conn,
                 if_exists="replace",
             )
-        assert str(exc.value) == "Snowflake does not support indexes"
+        assert str(exc.value) == "Only Snowflake Hybrid Tables supports indexes"


-def test_timezone(db_parameters):
+def test_timezone(db_parameters, engine_testaccount, engine_testaccount_with_numpy):
     test_table_name = "".join([random.choice(string.ascii_letters) for _ in range(5)])
-    sa_engine = create_engine(
-        URL(
-            account=db_parameters["account"],
-            password=db_parameters["password"],
-            database=db_parameters["database"],
-            port=db_parameters["port"],
-            user=db_parameters["user"],
-            host=db_parameters["host"],
-            protocol=db_parameters["protocol"],
-            schema=db_parameters["schema"],
-            numpy=True,
-        )
-    )
-
-    sa_engine2_raw_conn = create_engine(
-        URL(
-            account=db_parameters["account"],
-            password=db_parameters["password"],
-            database=db_parameters["database"],
-            port=db_parameters["port"],
-            user=db_parameters["user"],
-            host=db_parameters["host"],
-            protocol=db_parameters["protocol"],
-            schema=db_parameters["schema"],
-            timezone="America/Los_Angeles",
-            numpy="",
-        )
-    ).raw_connection()
+    sa_engine = engine_testaccount_with_numpy
+    sa_engine2_raw_conn = engine_testaccount.raw_connection()

     with sa_engine.connect() as conn:
@@ -297,8 +241,8 @@ def test_timezone(db_parameters):
         conn.exec_driver_sql(f"DROP TABLE {test_table_name};")


-def test_pandas_writeback(engine_testaccount, run_v20_sqlalchemy):
-    if run_v20_sqlalchemy and sys.version_info < (3, 8):
+def test_pandas_writeback(engine_testaccount):
+    if IS_VERSION_20 and sys.version_info < (3, 8):
         pytest.skip(
             "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlalchemy 2.0, and pandas does not support Python 3.7 anymore."
         )

@@ -316,18 +260,13 @@ def test_pandas_writeback(engine_testaccount, run_v20_sqlalchemy):
     sf_connector_version_df = pd.DataFrame(
         sf_connector_version_data, columns=["NAME", "NEWEST_VERSION"]
     )
-    sf_connector_version_df.to_sql(table_name, conn, index=False, method=pd_writer)
-
-    assert (
-        (
-            pd.read_sql_table(table_name, conn).rename(
-                columns={"newest_version": "NEWEST_VERSION", "name": "NAME"}
-            )
-            == sf_connector_version_df
-        )
-        .all()
-        .all()
+    sf_connector_version_df.to_sql(
+        table_name, conn, index=False, method=pd_writer, if_exists="replace"
+    )
+    results = pd.read_sql_table(table_name, conn).rename(
+        columns={"newest_version": "NEWEST_VERSION", "name": "NAME"}
     )
+    assert results.equals(sf_connector_version_df)


 @pytest.mark.parametrize("chunk_size", [5, 1])
@@ -414,8 +353,8 @@ def test_pandas_invalid_make_pd_writer(engine_testaccount):
     )


-def test_percent_signs(engine_testaccount, run_v20_sqlalchemy):
-    if run_v20_sqlalchemy and sys.version_info < (3, 8):
+def test_percent_signs(engine_testaccount):
+    if IS_VERSION_20 and sys.version_info < (3, 8):
         pytest.skip(
             "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlalchemy 2.0, and pandas does not support Python 3.7 anymore."
         )
@@ -438,7 +377,7 @@ def test_percent_signs(engine_testaccount, run_v20_sqlalchemy):
     not_like_sql = f"select * from {table_name} where c2 not like '%b%'"
     like_sql = f"select * from {table_name} where c2 like '%b%'"
     calculate_sql = "SELECT 1600 % 400 AS a, 1599 % 400 as b"
-    if run_v20_sqlalchemy:
+    if IS_VERSION_20:
         not_like_sql = sqlalchemy.text(not_like_sql)
         like_sql = sqlalchemy.text(like_sql)
         calculate_sql = sqlalchemy.text(calculate_sql)
diff --git a/tests/test_qmark.py b/tests/test_qmark.py
index 68200034..3761181a 100644
--- a/tests/test_qmark.py
+++ b/tests/test_qmark.py
@@ -1,80 +1,39 @@
 #
-# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #

 import os
 import sys

+import pandas as pd
 import pytest
 from sqlalchemy import text

-from snowflake.sqlalchemy import URL
-
-from .conftest import create_engine_with_future_flag as create_engine
-
 THIS_DIR = os.path.dirname(os.path.realpath(__file__))


-def _get_engine_with_qmark(db_parameters, user=None, password=None, account=None):
-    """
-    Creates a connection with column metadata cache
-    """
-    if user is not None:
-        db_parameters["user"] = user
-    if password is not None:
-        db_parameters["password"] = password
-    if account is not None:
-        db_parameters["account"] = account
-
-    engine = create_engine(
-        URL(
-            user=db_parameters["user"],
-            password=db_parameters["password"],
-            host=db_parameters["host"],
-            port=db_parameters["port"],
-            database=db_parameters["database"],
-            schema=db_parameters["schema"],
-            account=db_parameters["account"],
-            protocol=db_parameters["protocol"],
-        )
-    )
-    return engine
-
-
-def test_qmark_bulk_insert(db_parameters, run_v20_sqlalchemy):
+def test_qmark_bulk_insert(engine_testaccount_with_qmark):
     """
     Bulk insert using qmark paramstyle
     """
-    if run_v20_sqlalchemy and sys.version_info < (3, 8):
+    if sys.version_info < (3, 8):
         pytest.skip(
             "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlalchemy 2.0, and pandas does not support Python 3.7 anymore."
) - import snowflake.connector - - snowflake.connector.paramstyle = "qmark" - - engine = _get_engine_with_qmark(db_parameters) - import pandas as pd - - with engine.connect() as con: - try: - with con.begin(): - con.exec_driver_sql( - """ - create or replace table src(c1 int, c2 string) as select seq8(), - randstr(100, random()) from table(generator(rowcount=>100000)) - """ + with engine_testaccount_with_qmark.connect() as con: + with con.begin(): + con.exec_driver_sql( + """ + create or replace table src(c1 int, c2 string) as select seq8(), + randstr(100, random()) from table(generator(rowcount=>100000)) + """ + ) + con.exec_driver_sql("create or replace table dst like src") + + for data in pd.read_sql_query( + text("select * from src"), con, chunksize=16000 + ): + data.to_sql( + "dst", con, if_exists="append", index=False, index_label=None ) - con.exec_driver_sql("create or replace table dst like src") - - for data in pd.read_sql_query( - text("select * from src"), con, chunksize=16000 - ): - data.to_sql( - "dst", con, if_exists="append", index=False, index_label=None - ) - - finally: - engine.dispose() - snowflake.connector.paramstyle = "pyformat" diff --git a/tests/test_quote.py b/tests/test_quote.py index 9e7b19fc..ca6f36dd 100644 --- a/tests/test_quote.py +++ b/tests/test_quote.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table, inspect diff --git a/tests/test_semi_structured_datatypes.py b/tests/test_semi_structured_datatypes.py index 6de7e11e..c6c15bd4 100644 --- a/tests/test_semi_structured_datatypes.py +++ b/tests/test_semi_structured_datatypes.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # import json diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 8c8a2c58..32fc390e 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,90 +1,163 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table, select +from sqlalchemy import ( + Column, + Identity, + Integer, + MetaData, + Sequence, + String, + Table, + insert, + select, +) +from sqlalchemy.sql import text +from sqlalchemy.sql.ddl import CreateTable def test_table_with_sequence(engine_testaccount, db_parameters): + """Snowflake does not guarantee generating sequence numbers without gaps. + + The generated numbers are not necessarily contiguous. 
+ https://docs.snowflake.com/en/user-guide/querying-sequences + """ # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/124 test_table_name = "sequence" test_sequence_name = f"{test_table_name}_id_seq" + metadata = MetaData() + sequence_table = Table( test_table_name, - MetaData(), - Column("id", Integer, Sequence(test_sequence_name), primary_key=True), + metadata, + Column( + "id", Integer, Sequence(test_sequence_name, order=True), primary_key=True + ), Column("data", String(39)), ) - sequence_table.create(engine_testaccount) - seq = Sequence(test_sequence_name) + + autoload_metadata = MetaData() + try: - with engine_testaccount.connect() as conn: - with conn.begin(): - conn.execute(sequence_table.insert(), [{"data": "test_insert_1"}]) - select_stmt = select(sequence_table).order_by("id") - result = conn.execute(select_stmt).fetchall() - assert result == [(1, "test_insert_1")] - autoload_sequence_table = Table( - test_table_name, MetaData(), autoload_with=engine_testaccount - ) - conn.execute( - autoload_sequence_table.insert(), - [{"data": "multi_insert_1"}, {"data": "multi_insert_2"}], - ) - conn.execute( - autoload_sequence_table.insert(), [{"data": "test_insert_2"}] - ) - nextid = conn.execute(seq) - conn.execute( - autoload_sequence_table.insert(), - [{"id": nextid, "data": "test_insert_seq"}], - ) - result = conn.execute(select_stmt).fetchall() - assert result == [ - (1, "test_insert_1"), - (2, "multi_insert_1"), - (3, "multi_insert_2"), - (4, "test_insert_2"), - (5, "test_insert_seq"), - ] + metadata.create_all(engine_testaccount) + + with engine_testaccount.begin() as conn: + conn.execute(insert(sequence_table), ({"data": "test_insert_1"})) + result = conn.execute(select(sequence_table)).fetchall() + assert result == [(1, "test_insert_1")], result + + autoload_sequence_table = Table( + test_table_name, + autoload_metadata, + autoload_with=engine_testaccount, + ) + seq = Sequence(test_sequence_name, order=True) + + conn.execute( + insert(autoload_sequence_table), + ( + {"data": "multi_insert_1"}, + {"data": "multi_insert_2"}, + ), + ) + conn.execute( + insert(autoload_sequence_table), + ({"data": "test_insert_2"},), + ) + + nextid = conn.execute(seq) + conn.execute( + insert(autoload_sequence_table), + ({"id": nextid, "data": "test_insert_seq"}), + ) + + result = conn.execute(select(sequence_table)).fetchall() + + assert result == [ + (1, "test_insert_1"), + (2, "multi_insert_1"), + (3, "multi_insert_2"), + (4, "test_insert_2"), + (5, "test_insert_seq"), + ], result + finally: - sequence_table.drop(engine_testaccount) - seq.drop(engine_testaccount) + metadata.drop_all(engine_testaccount) -def test_table_with_autoincrement(engine_testaccount, db_parameters): +def test_table_with_autoincrement(engine_testaccount): + """Snowflake does not guarantee generating sequence numbers without gaps. + + The generated numbers are not necessarily contiguous. 
+ https://docs.snowflake.com/en/user-guide/querying-sequences + """ # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/124 test_table_name = "sequence" + metadata = MetaData() autoincrement_table = Table( test_table_name, - MetaData(), + metadata, Column("id", Integer, autoincrement=True, primary_key=True), Column("data", String(39)), ) - autoincrement_table.create(engine_testaccount) + + select_stmt = select(autoincrement_table).order_by("id") + try: - with engine_testaccount.connect() as conn: - with conn.begin(): - conn.execute(autoincrement_table.insert(), [{"data": "test_insert_1"}]) - select_stmt = select(autoincrement_table).order_by("id") - result = conn.execute(select_stmt).fetchall() - assert result == [(1, "test_insert_1")] - autoload_sequence_table = Table( - test_table_name, MetaData(), autoload_with=engine_testaccount - ) - conn.execute( - autoload_sequence_table.insert(), - [{"data": "multi_insert_1"}, {"data": "multi_insert_2"}], - ) - conn.execute( - autoload_sequence_table.insert(), [{"data": "test_insert_2"}] - ) - result = conn.execute(select_stmt).fetchall() - assert result == [ - (1, "test_insert_1"), - (2, "multi_insert_1"), - (3, "multi_insert_2"), - (4, "test_insert_2"), - ] + with engine_testaccount.begin() as conn: + conn.execute(text("ALTER SESSION SET NOORDER_SEQUENCE_AS_DEFAULT = FALSE")) + metadata.create_all(conn) + + conn.execute(insert(autoincrement_table), ({"data": "test_insert_1"})) + result = conn.execute(select_stmt).fetchall() + assert result == [(1, "test_insert_1")] + + autoload_sequence_table = Table( + test_table_name, MetaData(), autoload_with=engine_testaccount + ) + conn.execute( + insert(autoload_sequence_table), + [ + {"data": "multi_insert_1"}, + {"data": "multi_insert_2"}, + ], + ) + conn.execute( + insert(autoload_sequence_table), + [{"data": "test_insert_2"}], + ) + result = conn.execute(select_stmt).fetchall() + assert result == [ + (1, "test_insert_1"), + (2, "multi_insert_1"), + (3, "multi_insert_2"), + (4, "test_insert_2"), + ], result + finally: - autoincrement_table.drop(engine_testaccount) + metadata.drop_all(engine_testaccount) + + +def test_table_with_identity(sql_compiler): + test_table_name = "identity" + metadata = MetaData() + identity_autoincrement_table = Table( + test_table_name, + metadata, + Column( + "id", Integer, Identity(start=1, increment=1, order=True), primary_key=True + ), + Column("identity_col_unordered", Integer, Identity(order=False)), + Column("identity_col", Integer, Identity()), + ) + create_table = CreateTable(identity_autoincrement_table) + actual = sql_compiler(create_table) + expected = ( + "CREATE TABLE identity (" + "\tid INTEGER NOT NULL IDENTITY(1,1) ORDER, " + "\tidentity_col_unordered INTEGER NOT NULL IDENTITY NOORDER, " + "\tidentity_col INTEGER NOT NULL IDENTITY, " + "\tPRIMARY KEY (id))" + ) + assert actual == expected diff --git a/tests/test_timestamp.py b/tests/test_timestamp.py index 27379e1b..72bc9173 100644 --- a/tests/test_timestamp.py +++ b/tests/test_timestamp.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # from datetime import datetime diff --git a/tests/test_unit_core.py b/tests/test_unit_core.py index b778941c..27c7cf4a 100644 --- a/tests/test_unit_core.py +++ b/tests/test_unit_core.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
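
The new test_table_with_identity case above pins the DDL rendered for Snowflake's ordered and unordered IDENTITY columns. The same output can be reproduced with a plain compile, no connection required (a sketch; the table and column names are illustrative):

from sqlalchemy import Column, Identity, Integer, MetaData, Table
from sqlalchemy.sql.ddl import CreateTable

from snowflake.sqlalchemy import snowdialect

t = Table(
    "identity_demo",
    MetaData(),
    Column(
        "id", Integer, Identity(start=1, increment=1, order=True), primary_key=True
    ),
    Column("unordered_id", Integer, Identity(order=False)),
)
# Renders IDENTITY(1,1) ORDER for the first column and IDENTITY NOORDER for the second.
print(CreateTable(t).compile(dialect=snowdialect.dialect()))
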
#
 from sqlalchemy.engine.url import URL
diff --git a/tests/test_unit_cte.py b/tests/test_unit_cte.py
index 87e733f6..d3421653 100644
--- a/tests/test_unit_cte.py
+++ b/tests/test_unit_cte.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #

diff --git a/tests/test_unit_types.py b/tests/test_unit_types.py
index 5ac77b87..fc6d3c23 100644
--- a/tests/test_unit_types.py
+++ b/tests/test_unit_types.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #

 import snowflake.sqlalchemy
diff --git a/tests/test_unit_url.py b/tests/test_unit_url.py
index 3e2db8fc..d2ca4e47 100644
--- a/tests/test_unit_url.py
+++ b/tests/test_unit_url.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #

 import urllib.parse
diff --git a/tests/util.py b/tests/util.py
index 423b595f..db0b0c9c 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #

 from __future__ import annotations
@@ -28,6 +28,7 @@
 from snowflake.sqlalchemy.custom_types import (
     ARRAY,
     GEOGRAPHY,
+    GEOMETRY,
     OBJECT,
     TIMESTAMP_LTZ,
     TIMESTAMP_NTZ,
@@ -70,6 +71,7 @@
     "OBJECT": OBJECT,
     "ARRAY": ARRAY,
     "GEOGRAPHY": GEOGRAPHY,
+    "GEOMETRY": GEOMETRY,
 }
diff --git a/tox.ini b/tox.ini
index 69f440b3..102e2273 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,9 +1,8 @@
 [tox]
 min_version = 4.0.0
 envlist = fix_lint,
-          py{37,38,39,310}{,-pandas},
+          py{37,38,39,310,311}{,-pandas},
           coverage,
-          connector_regression
 skip_missing_interpreters = true

 [testenv]
@@ -17,6 +16,7 @@ external_wheels =
     py38-ci: dist/*.whl
     py39-ci: dist/*.whl
     py310-ci: dist/*.whl
+    py311-ci: dist/*.whl
 deps = pip
 passenv =
     AWS_ACCESS_KEY_ID
@@ -34,29 +34,17 @@
 setenv =
     COVERAGE_FILE = {env:COVERAGE_FILE:{toxworkdir}/.coverage.{envname}}
     SQLALCHEMY_WARN_20 = 1
-    ci: SNOWFLAKE_PYTEST_OPTS = -vvv
+    ci: SNOWFLAKE_PYTEST_OPTS = -vvv --tb=long
 commands = pytest \
            {env:SNOWFLAKE_PYTEST_OPTS:} \
            --cov "snowflake.sqlalchemy" \
            --junitxml {toxworkdir}/junit_{envname}.xml \
+           --ignore=tests/sqlalchemy_test_suite \
            {posargs:tests}
           pytest {env:SNOWFLAKE_PYTEST_OPTS:} \
           --cov "snowflake.sqlalchemy" --cov-append \
           --junitxml {toxworkdir}/junit_{envname}.xml \
           {posargs:tests/sqlalchemy_test_suite}
-          pytest \
-          {env:SNOWFLAKE_PYTEST_OPTS:} \
-          --cov "snowflake.sqlalchemy" --cov-append \
-          --junitxml {toxworkdir}/junit_{envname}.xml \
-          --run_v20_sqlalchemy \
-          {posargs:tests}
-
-[testenv:connector_regression]
-deps = pendulum
-commands = pytest \
-           {env:SNOWFLAKE_PYTEST_OPTS:} \
-           -m "not gcp and not azure" \
-           {posargs:tests/connector_regression/test}

 [testenv:.pkg_external]
 deps = build
@@ -78,22 +66,23 @@
 commands =
     coverage combine
     coverage xml -o {toxworkdir}/coverage.xml
    coverage html -d {toxworkdir}/htmlcov
;diff-cover --compare-branch {env:DIFF_AGAINST:origin/main} {toxworkdir}/coverage.xml
-depends = py37, py38, py39, py310
+depends = py37, py38, py39, py310, py311

 [testenv:fix_lint]
 description = format the code base to adhere to our styles, and complain about what we cannot do automatically
-basepython = python3.7
+basepython = python3.8
 passenv = PROGRAMDATA
 deps =
     {[testenv]deps}
+    tomlkit
     pre-commit >= 2.9.0
skip_install = True commands = pre-commit run --all-files python -c 'import pathlib; print("hint: run \{\} install to add checks as pre-commit hook".format(pathlib.Path(r"{envdir}") / "bin" / "pre-commit"))' [pytest] -addopts = -ra --strict-markers --ignore=tests/sqlalchemy_test_suite --ignore=tests/connector_regression +addopts = -ra --ignore=tests/sqlalchemy_test_suite junit_family = legacy log_level = info markers =