diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000000000..9dd627b01abed --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,13 @@ +FROM rust:bookworm + +RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ + # Remove imagemagick due to https://security-tracker.debian.org/tracker/CVE-2019-10131 + && apt-get purge -y imagemagick imagemagick-6-common + +# Add protoc +# https://datafusion.apache.org/contributor-guide/getting_started.html#protoc-installation +RUN curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v25.1/protoc-25.1-linux-x86_64.zip \ + && unzip protoc-25.1-linux-x86_64.zip -d $HOME/.local \ + && rm protoc-25.1-linux-x86_64.zip + +ENV PATH="$PATH:$HOME/.local/bin" \ No newline at end of file diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000000000..1af22306ed8c9 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,16 @@ +{ + "build": { + "dockerfile": "./Dockerfile", + "context": "." + }, + "customizations": { + "vscode": { + "extensions": [ + "rust-lang.rust-analyzer" + ] + } + }, + "features": { + "ghcr.io/devcontainers/features/rust:1": "latest" + } +} diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 5578517ec3594..22d2f2187dd07 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -28,16 +28,18 @@ runs: - name: Install Build Dependencies shell: bash run: | - apt-get update - apt-get install -y protobuf-compiler + RETRY=("ci/scripts/retry" timeout 120) + "${RETRY[@]}" apt-get update + "${RETRY[@]}" apt-get install -y protobuf-compiler - name: Setup Rust toolchain shell: bash # rustfmt is needed for the substrait build script run: | + RETRY=("ci/scripts/retry" timeout 120) echo "Installing ${{ inputs.rust-version }}" - rustup toolchain install ${{ inputs.rust-version }} - rustup default ${{ inputs.rust-version }} - rustup component add rustfmt + "${RETRY[@]}" rustup toolchain install ${{ inputs.rust-version }} + "${RETRY[@]}" rustup default ${{ inputs.rust-version }} + "${RETRY[@]}" rustup component add rustfmt - name: Configure rust runtime env uses: ./.github/actions/setup-rust-runtime - name: Fixup git permissions diff --git a/.github/actions/setup-macos-aarch64-builder/action.yaml b/.github/actions/setup-macos-aarch64-builder/action.yaml index c4e14906ed108..288799a284b01 100644 --- a/.github/actions/setup-macos-aarch64-builder/action.yaml +++ b/.github/actions/setup-macos-aarch64-builder/action.yaml @@ -30,8 +30,8 @@ runs: run: | mkdir -p $HOME/d/protoc cd $HOME/d/protoc - export PROTO_ZIP="protoc-21.4-osx-aarch_64.zip" - curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v21.4/$PROTO_ZIP + export PROTO_ZIP="protoc-29.1-osx-aarch_64.zip" + curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v29.1/$PROTO_ZIP unzip $PROTO_ZIP echo "$HOME/d/protoc/bin" >> $GITHUB_PATH export PATH=$PATH:$HOME/d/protoc/bin @@ -43,5 +43,7 @@ runs: rustup toolchain install stable rustup default stable rustup component add rustfmt + - name: Setup rust cache + uses: Swatinem/rust-cache@v2 - name: Configure rust runtime env uses: ./.github/actions/setup-rust-runtime diff --git a/.github/actions/setup-macos-builder/action.yaml b/.github/actions/setup-macos-builder/action.yaml index 02419f6179429..fffdab160b043 100644 --- a/.github/actions/setup-macos-builder/action.yaml +++ 
b/.github/actions/setup-macos-builder/action.yaml @@ -30,8 +30,8 @@ runs: run: | mkdir -p $HOME/d/protoc cd $HOME/d/protoc - export PROTO_ZIP="protoc-21.4-osx-x86_64.zip" - curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v21.4/$PROTO_ZIP + export PROTO_ZIP="protoc-29.1-osx-x86_64.zip" + curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v29.1/$PROTO_ZIP unzip $PROTO_ZIP echo "$HOME/d/protoc/bin" >> $GITHUB_PATH export PATH=$PATH:$HOME/d/protoc/bin diff --git a/.github/actions/setup-windows-builder/action.yaml b/.github/actions/setup-windows-builder/action.yaml index 5e937358c7d74..a0304168c744e 100644 --- a/.github/actions/setup-windows-builder/action.yaml +++ b/.github/actions/setup-windows-builder/action.yaml @@ -30,8 +30,8 @@ runs: run: | mkdir -p $HOME/d/protoc cd $HOME/d/protoc - export PROTO_ZIP="protoc-21.4-win64.zip" - curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v21.4/$PROTO_ZIP + export PROTO_ZIP="protoc-29.1-win64.zip" + curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v29.1/$PROTO_ZIP unzip $PROTO_ZIP export PATH=$PATH:$HOME/d/protoc/bin protoc.exe --version diff --git a/.github/workflows/dependencies.yml b/.github/workflows/dependencies.yml index ebc5bcf91c94b..f87215565bb53 100644 --- a/.github/workflows/dependencies.yml +++ b/.github/workflows/dependencies.yml @@ -42,6 +42,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true + fetch-depth: 1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 19af21ec910be..cf204b2cd6c12 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -23,18 +23,12 @@ concurrency: cancel-in-progress: true jobs: - rat: - name: Release Audit Tool (RAT) + license-header-check: runs-on: ubuntu-latest + name: Check License Header steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: "3.10" - - name: Audit licenses - run: ./dev/release/run-rat.sh . + - uses: actions/checkout@v4 + - uses: korandoru/hawkeye@v5 prettier: name: Use prettier to check formatting of documents diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 44ca5aaf4eda1..0b43339f57a61 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -26,7 +26,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.12" - name: Install dependencies run: | @@ -61,4 +61,4 @@ jobs: git add --all git commit -m 'Publish built docs triggered by ${{ github.sha }}' git push || git push --force - fi \ No newline at end of file + fi diff --git a/.github/workflows/docs_pr.yaml b/.github/workflows/docs_pr.yaml index c2f3dd684a23e..3fad08643aa22 100644 --- a/.github/workflows/docs_pr.yaml +++ b/.github/workflows/docs_pr.yaml @@ -43,6 +43,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true + fetch-depth: 1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml new file mode 100644 index 0000000000000..b98e0a1740cbe --- /dev/null +++ b/.github/workflows/extended.yml @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: Datafusion extended tests + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + +# https://docs.github.com/en/actions/writing-workflows/choosing-when-your-workflow-runs/events-that-trigger-workflows#running-your-pull_request-workflow-when-a-pull-request-merges +# +# this job is intended to only run only on the main branch as it is time consuming +# and should not fail often. However, it is important coverage to ensure correctness +# in the (very rare) event of a hash failure or sqlite query failure. +on: + # Run on all commits to main + push: + branches: + - main + +jobs: + # Check answers are correct when hash values collide + hash-collisions: + name: cargo test hash collisions (amd64) + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 1 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Run tests + run: | + cd datafusion + cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --exclude datafusion-sqllogictest --workspace --lib --tests --features=force_hash_collisions,avro + sqllogictest-sqlite: + name: "Run sqllogictests with the sqlite test suite" + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 1 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Run sqllogictest + run: cargo test --profile release-nonlto --test sqllogictests -- --include-sqlite diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 4527d047e4c07..7ac0dfa78215c 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -39,9 +39,17 @@ on: workflow_dispatch: jobs: - # Check crate compiles + # Check license header + license-header-check: + runs-on: ubuntu-20.04 + name: Check License Header + steps: + - uses: actions/checkout@v4 + - uses: korandoru/hawkeye@v5 + + # Check crate compiles and base cargo check passes linux-build-lib: - name: cargo check + name: linux build test runs-on: ubuntu-latest container: image: amd64/rust @@ -51,89 +59,119 @@ jobs: uses: ./.github/actions/setup-builder with: rust-version: stable + - name: Prepare cargo build + run: cargo check --profile ci --all-targets - - name: Cache Cargo - uses: actions/cache@v4 + # cargo check common, functions and substrait with no default features + linux-cargo-check-no-default-features: + name: cargo check no default features + needs: linux-build-lib + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder with: - path: | - ~/.cargo/bin/ - ~/.cargo/registry/index/ - ~/.cargo/registry/cache/ - ~/.cargo/git/db/ - ./target/ - 
./datafusion-cli/target/ - key: cargo-cache-${{ hashFiles('**/Cargo.toml', '**/Cargo.lock') }} - + rust-version: stable - name: Check datafusion without default features # Some of the test binaries require the parquet feature still #run: cargo check --all-targets --no-default-features -p datafusion - run: cargo check --no-default-features -p datafusion + run: cargo check --profile ci --no-default-features -p datafusion - name: Check datafusion-common without default features - run: cargo check --all-targets --no-default-features -p datafusion-common + run: cargo check --profile ci --all-targets --no-default-features -p datafusion-common - - name: Check datafusion-functions - run: cargo check --all-targets --no-default-features -p datafusion-functions + - name: Check datafusion-functions without default features + run: cargo check --profile ci --all-targets --no-default-features -p datafusion-functions + + - name: Check datafusion-substrait without default features + run: cargo check --profile ci --all-targets --no-default-features -p datafusion-substrait - name: Check workspace in debug mode - run: cargo check --all-targets --workspace + run: cargo check --profile ci --all-targets --workspace - name: Check workspace with avro,json features - run: cargo check --workspace --benches --features avro,json + run: cargo check --profile ci --workspace --benches --features avro,json - name: Check Cargo.lock for datafusion-cli run: | # If this test fails, try running `cargo update` in the `datafusion-cli` directory # and check in the updated Cargo.lock file. - cargo check --manifest-path datafusion-cli/Cargo.toml --locked + cargo check --profile ci --manifest-path datafusion-cli/Cargo.toml --locked - # Ensure that the datafusion crate can be built with only a subset of the function - # packages enabled. + # cargo check datafusion to ensure that the datafusion crate can be built with only a + # subset of the function packages enabled. 
+ linux-cargo-check-datafusion: + name: cargo check datafusion + needs: linux-build-lib + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable - name: Check datafusion (nested_expressions) - run: cargo check --no-default-features --features=nested_expressions -p datafusion + run: cargo check --profile ci --no-default-features --features=nested_expressions -p datafusion - name: Check datafusion (crypto) - run: cargo check --no-default-features --features=crypto_expressions -p datafusion + run: cargo check --profile ci --no-default-features --features=crypto_expressions -p datafusion - name: Check datafusion (datetime_expressions) - run: cargo check --no-default-features --features=datetime_expressions -p datafusion + run: cargo check --profile ci --no-default-features --features=datetime_expressions -p datafusion - name: Check datafusion (encoding_expressions) - run: cargo check --no-default-features --features=encoding_expressions -p datafusion + run: cargo check --profile ci --no-default-features --features=encoding_expressions -p datafusion - name: Check datafusion (math_expressions) - run: cargo check --no-default-features --features=math_expressions -p datafusion + run: cargo check --profile ci --no-default-features --features=math_expressions -p datafusion - name: Check datafusion (regex_expressions) - run: cargo check --no-default-features --features=regex_expressions -p datafusion + run: cargo check --profile ci --no-default-features --features=regex_expressions -p datafusion - name: Check datafusion (string_expressions) - run: cargo check --no-default-features --features=string_expressions -p datafusion + run: cargo check --profile ci --no-default-features --features=string_expressions -p datafusion - # Ensure that the datafusion-functions crate can be built with only a subset of the function - # packages enabled. + # cargo check datafusion-functions to ensure that the datafusion-functions crate can be built with + # only a subset of the function packages enabled. 
+ linux-cargo-check-datafusion-functions: + name: cargo check functions + needs: linux-build-lib + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable - name: Check datafusion-functions (crypto) - run: cargo check --all-targets --no-default-features --features=crypto_expressions -p datafusion-functions + run: cargo check --profile ci --all-targets --no-default-features --features=crypto_expressions -p datafusion-functions - name: Check datafusion-functions (datetime_expressions) - run: cargo check --all-targets --no-default-features --features=datetime_expressions -p datafusion-functions + run: cargo check --profile ci --all-targets --no-default-features --features=datetime_expressions -p datafusion-functions - name: Check datafusion-functions (encoding_expressions) - run: cargo check --all-targets --no-default-features --features=encoding_expressions -p datafusion-functions + run: cargo check --profile ci --all-targets --no-default-features --features=encoding_expressions -p datafusion-functions - name: Check datafusion-functions (math_expressions) - run: cargo check --all-targets --no-default-features --features=math_expressions -p datafusion-functions + run: cargo check --profile ci --all-targets --no-default-features --features=math_expressions -p datafusion-functions - name: Check datafusion-functions (regex_expressions) - run: cargo check --all-targets --no-default-features --features=regex_expressions -p datafusion-functions + run: cargo check --profile ci --all-targets --no-default-features --features=regex_expressions -p datafusion-functions - name: Check datafusion-functions (string_expressions) - run: cargo check --all-targets --no-default-features --features=string_expressions -p datafusion-functions + run: cargo check --profile ci --all-targets --no-default-features --features=string_expressions -p datafusion-functions # Run tests linux-test: name: cargo test (amd64) - needs: [ linux-build-lib ] + needs: linux-build-lib runs-on: ubuntu-latest container: image: amd64/rust @@ -141,18 +179,19 @@ jobs: - uses: actions/checkout@v4 with: submodules: true + fetch-depth: 1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: - rust-version: stable + rust-version: stable - name: Run tests (excluding doctests) - run: cargo test --lib --tests --bins --features avro,json,backtrace + run: cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --workspace --lib --tests --bins --features avro,json,backtrace - name: Verify Working Directory Clean run: git diff --exit-code linux-test-datafusion-cli: name: cargo test datafusion-cli (amd64) - needs: [ linux-build-lib ] + needs: linux-build-lib runs-on: ubuntu-latest container: image: amd64/rust @@ -160,6 +199,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true + fetch-depth: 1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -167,13 +207,13 @@ jobs: - name: Run tests (excluding doctests) run: | cd datafusion-cli - cargo test --lib --tests --bins --all-features + cargo test --profile ci --lib --tests --bins --all-features - name: Verify Working Directory Clean run: git diff --exit-code linux-test-example: name: cargo examples (amd64) - needs: [ linux-build-lib ] + needs: linux-build-lib runs-on: ubuntu-latest container: image: amd64/rust @@ -181,6 +221,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true + 
fetch-depth: 1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -188,18 +229,16 @@ jobs: - name: Run examples run: | # test datafusion-sql examples - cargo run --example sql + cargo run --profile ci --example sql # test datafusion-examples ci/scripts/rust_example.sh - name: Verify Working Directory Clean run: git diff --exit-code - - # Run `cargo test doc` (test documentation examples) linux-test-doc: name: cargo test doc (amd64) - needs: [ linux-build-lib ] + needs: linux-build-lib runs-on: ubuntu-latest container: image: amd64/rust @@ -207,22 +246,23 @@ jobs: - uses: actions/checkout@v4 with: submodules: true + fetch-depth: 1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - name: Run doctests run: | - cargo test --doc --features avro,json + cargo test --profile ci --doc --features avro,json cd datafusion-cli - cargo test --doc --all-features + cargo test --profile ci --doc --all-features - name: Verify Working Directory Clean run: git diff --exit-code # Run `cargo doc` to ensure the rustdoc is clean linux-rustdoc: name: cargo doc - needs: [ linux-build-lib ] + needs: linux-build-lib runs-on: ubuntu-latest container: image: amd64/rust @@ -255,7 +295,7 @@ jobs: # verify that the benchmark queries return the correct results verify-benchmark-results: name: verify benchmark results (amd64) - needs: [ linux-build-lib ] + needs: linux-build-lib runs-on: ubuntu-latest container: image: amd64/rust @@ -263,6 +303,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true + fetch-depth: 1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -277,17 +318,20 @@ jobs: mv *.tbl ../datafusion/sqllogictest/test_files/tpch/data - name: Verify that benchmark queries return expected results run: | + # increase stack size to fix stack overflow + export RUST_MIN_STACK=20971520 export TPCH_DATA=`realpath datafusion/sqllogictest/test_files/tpch/data` - # use release build for plan verificaton because debug build causes stack overflow - cargo test plan_q --package datafusion-benchmarks --profile release-nonlto --features=ci -- --test-threads=1 - INCLUDE_TPCH=true cargo test --test sqllogictests + cargo test plan_q --package datafusion-benchmarks --profile ci --features=ci -- --test-threads=1 + INCLUDE_TPCH=true cargo test --profile ci --package datafusion-sqllogictest --test sqllogictests - name: Verify Working Directory Clean run: git diff --exit-code sqllogictest-postgres: name: "Run sqllogictest with Postgres runner" - needs: [ linux-build-lib ] + needs: linux-build-lib runs-on: ubuntu-latest + container: + image: amd64/rust services: postgres: image: postgres:15 @@ -296,7 +340,7 @@ jobs: POSTGRES_DB: db_test POSTGRES_INITDB_ARGS: --encoding=UTF-8 --lc-collate=C --lc-ctype=C ports: - - 5432/tcp + - 5432:5432 options: >- --health-cmd pg_isready --health-interval 10s @@ -306,47 +350,59 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - name: Setup toolchain - run: | - rustup toolchain install stable - rustup default stable + fetch-depth: 1 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable - name: Run sqllogictest - run: PG_COMPAT=true PG_URI="postgresql://postgres:postgres@localhost:$POSTGRES_PORT/db_test" cargo test --features=postgres --test sqllogictests + run: | + cd datafusion/sqllogictest + PG_COMPAT=true PG_URI="postgresql://postgres:postgres@$POSTGRES_HOST:$POSTGRES_PORT/db_test" cargo test --profile ci --features=postgres --test 
sqllogictests env: + # use postgres for the host here because we have specified a container for the job + POSTGRES_HOST: postgres POSTGRES_PORT: ${{ job.services.postgres.ports[5432] }} - windows: - name: cargo test (win64) - runs-on: windows-latest - steps: - - uses: actions/checkout@v4 - with: - submodules: true - - name: Setup Rust toolchain - uses: ./.github/actions/setup-windows-builder - - name: Run tests (excluding doctests) - shell: bash - run: | - export PATH=$PATH:$HOME/d/protoc/bin - cargo test --lib --tests --bins --features avro,json,backtrace - cd datafusion-cli - cargo test --lib --tests --bins --all-features - - macos: - name: cargo test (macos) - runs-on: macos-latest - steps: - - uses: actions/checkout@v4 - with: - submodules: true - - name: Setup Rust toolchain - uses: ./.github/actions/setup-macos-builder - - name: Run tests (excluding doctests) - shell: bash - run: | - cargo test --lib --tests --bins --features avro,json,backtrace - cd datafusion-cli - cargo test --lib --tests --bins --all-features +# Temporarily commenting out the Windows flow, the reason is enormously slow running build +# Waiting for new Windows 2025 github runner +# Details: https://github.com/apache/datafusion/issues/13726 +# +# windows: +# name: cargo test (win64) +# runs-on: windows-latest +# steps: +# - uses: actions/checkout@v4 +# with: +# submodules: true +# - name: Setup Rust toolchain +# uses: ./.github/actions/setup-windows-builder +# - name: Run tests (excluding doctests) +# shell: bash +# run: | +# export PATH=$PATH:$HOME/d/protoc/bin +# cargo test --lib --tests --bins --features avro,json,backtrace +# cd datafusion-cli +# cargo test --lib --tests --bins --all-features + +# Commenting out intel mac build as so few users would ever use it +# Details: https://github.com/apache/datafusion/issues/13846 +# macos: +# name: cargo test (macos) +# runs-on: macos-latest +# steps: +# - uses: actions/checkout@v4 +# with: +# submodules: true +# fetch-depth: 1 +# - name: Setup Rust toolchain +# uses: ./.github/actions/setup-macos-builder +# - name: Run tests (excluding doctests) +# shell: bash +# run: | +# cargo test run --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --workspace --lib --tests --bins --features avro,json,backtrace +# cd datafusion-cli +# cargo test run --profile ci --lib --tests --bins --all-features macos-aarch64: name: cargo test (macos-aarch64) @@ -355,18 +411,19 @@ jobs: - uses: actions/checkout@v4 with: submodules: true + fetch-depth: 1 - name: Setup Rust toolchain uses: ./.github/actions/setup-macos-aarch64-builder - name: Run tests (excluding doctests) shell: bash run: | - cargo test --lib --tests --bins --features avro,json,backtrace + cargo test --profile ci --lib --tests --bins --features avro,json,backtrace cd datafusion-cli - cargo test --lib --tests --bins --all-features + cargo test --profile ci --lib --tests --bins --all-features test-datafusion-pyarrow: name: cargo test pyarrow (amd64) - needs: [ linux-build-lib ] + needs: linux-build-lib runs-on: ubuntu-20.04 container: image: amd64/rust:bullseye # Workaround https://github.com/actions/setup-python/issues/721 @@ -374,6 +431,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true + fetch-depth: 1 - uses: actions/setup-python@v5 with: python-version: "3.8" @@ -386,7 +444,7 @@ jobs: with: rust-version: stable - name: Run datafusion-common tests - run: cargo test -p datafusion-common --features=pyarrow + run: cargo test --profile ci -p datafusion-common --features=pyarrow vendor: name: 
Verify Vendored Code @@ -397,6 +455,8 @@ jobs: - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder + with: + rust-version: stable - name: Run gen run: ./regen.sh working-directory: ./datafusion/proto @@ -463,7 +523,7 @@ jobs: clippy: name: clippy - needs: [ linux-build-lib ] + needs: linux-build-lib runs-on: ubuntu-latest container: image: amd64/rust @@ -471,6 +531,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true + fetch-depth: 1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -480,29 +541,9 @@ jobs: - name: Run clippy run: ci/scripts/rust_clippy.sh - # Check answers are correct when hash values collide - hash-collisions: - name: cargo test hash collisions (amd64) - needs: [ linux-build-lib ] - runs-on: ubuntu-latest - container: - image: amd64/rust - steps: - - uses: actions/checkout@v4 - with: - submodules: true - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: stable - - name: Run tests - run: | - cd datafusion - cargo test --lib --tests --features=force_hash_collisions,avro - cargo-toml-formatting-checks: name: check Cargo.toml formatting - needs: [ linux-build-lib ] + needs: linux-build-lib runs-on: ubuntu-latest container: image: amd64/rust @@ -510,6 +551,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true + fetch-depth: 1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -522,7 +564,7 @@ jobs: config-docs-check: name: check configs.md and ***_functions.md is up-to-date - needs: [ linux-build-lib ] + needs: linux-build-lib runs-on: ubuntu-latest container: image: amd64/rust @@ -530,6 +572,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true + fetch-depth: 1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -574,9 +617,9 @@ jobs: # # To reproduce: # 1. Install the version of Rust that is failing. Example: - # rustup install 1.78.0 + # rustup install 1.80.1 # 2. Run the command that failed with that version. Example: - # cargo +1.78.0 check -p datafusion + # cargo +1.80.1 check -p datafusion # # To resolve, either: # 1. Change your code to use older Rust features, @@ -595,4 +638,4 @@ jobs: run: cargo msrv --output-format json --log-target stdout verify - name: Check datafusion-cli working-directory: datafusion-cli - run: cargo msrv --output-format json --log-target stdout verify \ No newline at end of file + run: cargo msrv --output-format json --log-target stdout verify diff --git a/.gitignore b/.gitignore index 05570eacf630c..1fa79249ff8e0 100644 --- a/.gitignore +++ b/.gitignore @@ -64,6 +64,12 @@ datafusion/sqllogictests/test_files/tpch/data/* # Scratch temp dir for sqllogictests datafusion/sqllogictest/test_files/scratch* +# temp file for core +datafusion/core/*.parquet + +# Generated core benchmark data +datafusion/core/benches/data/* + # rat filtered_rat.txt rat.txt diff --git a/.gitmodules b/.gitmodules index ec5d6208b8ddb..037accdbe4241 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,7 @@ [submodule "testing"] path = testing url = https://github.com/apache/arrow-testing +[submodule "datafusion-testing"] + path = datafusion-testing + url = https://github.com/apache/datafusion-testing.git + branch = main diff --git a/CHANGELOG.md b/CHANGELOG.md index ea0c339ac4514..c481ce0b96a0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ under the License. 
--> -* [DataFusion CHANGELOG](./datafusion/CHANGELOG.md) +Change logs for each release can be found [here](dev/changelog). + For older versions, see [apache/arrow/CHANGELOG.md](https://github.com/apache/arrow/blob/master/CHANGELOG.md). diff --git a/Cargo.toml b/Cargo.toml index 448607257ca1e..aa412cba51087 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,9 +26,11 @@ members = [ "datafusion/expr", "datafusion/expr-common", "datafusion/execution", + "datafusion/ffi", "datafusion/functions", "datafusion/functions-aggregate", "datafusion/functions-aggregate-common", + "datafusion/functions-table", "datafusion/functions-nested", "datafusion/functions-window", "datafusion/functions-window-common", @@ -46,8 +48,13 @@ members = [ "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", + "datafusion-examples/examples/ffi/ffi_example_table_provider", + "datafusion-examples/examples/ffi/ffi_module_interface", + "datafusion-examples/examples/ffi/ffi_module_loader", "test-utils", "benchmarks", + "datafusion/macros", + "datafusion/doc", ] resolver = "2" @@ -58,75 +65,75 @@ homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" -rust-version = "1.78" -version = "42.0.0" +rust-version = "1.80.1" +version = "44.0.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can # selectively turn them on if needed, since we can override default-features = true (from false) # for the inherited dependency but cannot do the reverse (override from true to false). # -# See for more detaiils: https://github.com/rust-lang/cargo/issues/11329 +# See for more details: https://github.com/rust-lang/cargo/issues/11329 ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } -arrow = { version = "53.1.0", features = [ +arrow = { version = "54.0.0", features = [ "prettyprint", ] } -arrow-array = { version = "53.1.0", default-features = false, features = [ +arrow-array = { version = "54.0.0", default-features = false, features = [ "chrono-tz", ] } -arrow-buffer = { version = "53.1.0", default-features = false } -arrow-flight = { version = "53.1.0", features = [ +arrow-buffer = { version = "54.0.0", default-features = false } +arrow-flight = { version = "54.0.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "53.1.0", default-features = false, features = [ +arrow-ipc = { version = "54.0.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "53.1.0", default-features = false } -arrow-schema = { version = "53.1.0", default-features = false } -arrow-string = { version = "53.1.0", default-features = false } +arrow-ord = { version = "54.0.0", default-features = false } +arrow-schema = { version = "54.0.0", default-features = false } async-trait = "0.1.73" -bigdecimal = "=0.4.1" +bigdecimal = "0.4.7" bytes = "1.4" chrono = { version = "0.4.38", default-features = false } -ctor = "0.2.0" +ctor = "0.2.9" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "42.0.0", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "42.0.0" } -datafusion-common = { path = "datafusion/common", version = "42.0.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "42.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "42.0.0" } -datafusion-expr = { path = "datafusion/expr", version = 
"42.0.0" } -datafusion-expr-common = { path = "datafusion/expr-common", version = "42.0.0" } -datafusion-functions = { path = "datafusion/functions", version = "42.0.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "42.0.0" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "42.0.0" } -datafusion-functions-nested = { path = "datafusion/functions-nested", version = "42.0.0" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "42.0.0" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "42.0.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "42.0.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "42.0.0", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "42.0.0", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "42.0.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "42.0.0" } -datafusion-proto = { path = "datafusion/proto", version = "42.0.0" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "42.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "42.0.0" } -datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "42.0.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "42.0.0" } +datafusion = { path = "datafusion/core", version = "44.0.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "44.0.0" } +datafusion-common = { path = "datafusion/common", version = "44.0.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "44.0.0" } +datafusion-doc = { path = "datafusion/doc", version = "44.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "44.0.0" } +datafusion-expr = { path = "datafusion/expr", version = "44.0.0" } +datafusion-expr-common = { path = "datafusion/expr-common", version = "44.0.0" } +datafusion-ffi = { path = "datafusion/ffi", version = "44.0.0" } +datafusion-functions = { path = "datafusion/functions", version = "44.0.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "44.0.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "44.0.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "44.0.0" } +datafusion-functions-table = { path = "datafusion/functions-table", version = "44.0.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "44.0.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "44.0.0" } +datafusion-macros = { path = "datafusion/macros", version = "44.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "44.0.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "44.0.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "44.0.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "44.0.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "44.0.0" } 
+datafusion-proto = { path = "datafusion/proto", version = "44.0.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "44.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "44.0.0" } doc-comment = "0.3" env_logger = "0.11" futures = "0.3" half = { version = "2.2.1", default-features = false } hashbrown = { version = "0.14.5", features = ["raw"] } indexmap = "2.0.0" -itertools = "0.13" +itertools = "0.14" log = "^0.4" -num_cpus = "1.13.0" object_store = { version = "0.11.0", default-features = false } parking_lot = "0.12" -parquet = { version = "53.1.0", default-features = false, features = [ +parquet = { version = "54.0.0", default-features = false, features = [ "arrow", "async", "object_store", @@ -136,14 +143,14 @@ pbjson = { version = "0.7.0" } prost = "0.13.1" prost-derive = "0.13.1" rand = "0.8" +recursive = "0.1.1" regex = "1.8" -rstest = "0.23.0" +rstest = "0.24.0" serde_json = "1" -sqlparser = { version = "0.51.0", features = ["visitor"] } +sqlparser = { version = "0.53.0", features = ["visitor"] } tempfile = "3" -thiserror = "1.0.44" tokio = { version = "1.36", features = ["macros", "rt", "sync"] } -url = "2.2" +url = "2.5.4" [profile.release] codegen-units = 1 @@ -163,9 +170,21 @@ overflow-checks = false panic = 'unwind' rpath = false +[profile.ci] +inherits = "dev" +incremental = false + +# ci turns off debug info, etc for dependencies to allow for smaller binaries making caching more effective +[profile.ci.package."*"] +debug = false +debug-assertions = false +strip = "debuginfo" +incremental = false + [workspace.lints.clippy] # Detects large stack-allocated futures that may cause stack overflow crashes (see threshold in clippy.toml) large_futures = "warn" [workspace.lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)"] } +unused_qualifications = "deny" diff --git a/README.md b/README.md index 5d0b096c1de11..2a20faa9e2fa5 100644 --- a/README.md +++ b/README.md @@ -38,14 +38,15 @@ [Chat](https://discord.com/channels/885562378132000778/885562378132000781) - logo + logo DataFusion is an extensible query engine written in [Rust] that uses [Apache Arrow] as its in-memory format. -The DataFusion libraries in this repository are used to build data-centric system software. DataFusion also provides the -following subprojects, which are packaged versions of DataFusion intended for end users. +This crate provides libraries and binaries for developers building fast and +feature rich database and analytic systems, customized to particular workloads. +See [use cases] for examples. The following related subprojects target end users: - [DataFusion Python](https://github.com/apache/datafusion-python/) offers a Python interface for SQL and DataFrame queries. @@ -54,13 +55,10 @@ following subprojects, which are packaged versions of DataFusion intended for en - [DataFusion Comet](https://github.com/apache/datafusion-comet/) is an accelerator for Apache Spark based on DataFusion. -The target audience for the DataFusion crates in this repository are -developers building fast and feature rich database and analytic systems, -customized to particular workloads. See [use cases] for examples. - -DataFusion offers [SQL] and [`Dataframe`] APIs, -excellent [performance], built-in support for CSV, Parquet, JSON, and Avro, -extensive customization, and a great community. 
+"Out of the box," +DataFusion offers [SQL] and [`Dataframe`] APIs, excellent [performance], +built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and +a great community. DataFusion features a full query planner, a columnar, streaming, multi-threaded, vectorized execution engine, and partitioned data sources. You can @@ -114,7 +112,8 @@ Default features: - `parquet`: support for reading the [Apache Parquet] format - `regex_expressions`: regular expression functions, such as `regexp_match` - `unicode_expressions`: Include unicode aware functions such as `character_length` -- `unparser` : enables support to reverse LogicalPlans back into SQL +- `unparser`: enables support to reverse LogicalPlans back into SQL +- `recursive_protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection. Optional features: @@ -128,11 +127,46 @@ Optional features: ## Rust Version Compatibility Policy -DataFusion's Minimum Required Stable Rust Version (MSRV) policy is to support stable [4 latest -Rust versions](https://releases.rs) OR the stable minor Rust version as of 4 months, whichever is lower. +The Rust toolchain releases are tracked at [Rust Versions](https://releases.rs) and follow +[semantic versioning](https://semver.org/). A Rust toolchain release can be identified +by a version string like `1.80.0`, or more generally `major.minor.patch`. + +DataFusion's supports the last 4 stable Rust minor versions released and any such versions released within the last 4 months. For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81.0` DataFusion will support 1.78.0, which is 3 minor versions prior to the most minor recent `1.81`. -If a hotfix is released for the minimum supported Rust version (MSRV), the MSRV will be the minor version with all hotfixes, even if it surpasses the four-month window. +Note: If a Rust hotfix is released for the current MSRV, the MSRV will be updated to the specific minor version that includes all applicable hotfixes preceding other policies. + +DataFusion enforces MSRV policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) + +## DataFusion API Evolution and Deprecation Guidelines + +Public methods in Apache DataFusion evolve over time: while we try to maintain a +stable API, we also improve the API over time. As a result, we typically +deprecate methods before removing them, according to the [deprecation guidelines]. + +[deprecation guidelines]: https://datafusion.apache.org/library-user-guide/api-health.html + +## Dependencies and a `Cargo.lock` + +`datafusion` is intended for use as a library and thus purposely does not have a +`Cargo.lock` file checked in. You can read more about the distinction in the +[Cargo book]. + +CI tests always run against the latest compatible versions of all dependencies +(the equivalent of doing `cargo update`), as suggested in the [Cargo CI guide] +and we rely on Dependabot for other upgrades. This strategy has two problems +that occasionally arise: + +1. CI failures when downstream libraries upgrade in some non compatible way +2. Local development builds that fail when DataFusion inadvertently relies on + a feature in a newer version of a dependency than declared in `Cargo.toml` + (e.g. a new method is added to a trait that we use). 
+ +However, we think the current strategy is the best tradeoff between maintenance +overhead and user experience and ensures DataFusion always works with the latest +compatible versions of all dependencies. If you encounter either of these +problems, please open an issue or PR. -We enforce this policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) +[cargo book]: https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +[cargo ci guide]: https://doc.rust-lang.org/cargo/guide/continuous-integration.html#verifying-latest-dependencies diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 7f29f7471b6fc..ad8debaf2fa38 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -42,7 +42,6 @@ env_logger = { workspace = true } futures = { workspace = true } log = { workspace = true } mimalloc = { version = "0.1", optional = true, default-features = false } -num_cpus = { workspace = true } parquet = { workspace = true, default-features = true } serde = { version = "1.0.136", features = ["derive"] } serde_json = { workspace = true } diff --git a/benchmarks/README.md b/benchmarks/README.md index afaf28bb75769..332cac8459d75 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -32,7 +32,7 @@ DataFusion is included in the benchmark setups for several popular benchmarks that compare performance with other engines. For example: * [ClickBench] scripts are in the [ClickBench repo](https://github.com/ClickHouse/ClickBench/tree/main/datafusion) -* [H2o.ai `db-benchmark`] scripts are in [db-benchmark](db-benchmark) directory +* [H2o.ai `db-benchmark`] scripts are in [db-benchmark](https://github.com/apache/datafusion/tree/main/benchmarks/src/h2o.rs) [ClickBench]: https://github.com/ClickHouse/ClickBench/tree/main [H2o.ai `db-benchmark`]: https://github.com/h2oai/db-benchmark @@ -330,6 +330,40 @@ steps. The tests sort the entire dataset using several different sort orders. +## Sort TPCH + +Test performance of end-to-end sort SQL queries. (While the `Sort` benchmark focuses on a single sort executor, this benchmark tests how sorting is executed across multiple CPU cores by benchmarking sorting the whole relational table.) + +Sort integration benchmark runs whole table sort queries on TPCH `lineitem` table, with different characteristics. For example, different number of sort keys, different sort key cardinality, different number of payload columns, etc. + +See [`sort_tpch.rs`](src/sort_tpch.rs) for more details. + +### Sort TPCH Benchmark Example Runs +1. Run all queries with default setting: +```bash + cargo run --release --bin dfbench -- sort-tpch -p '....../datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' +``` + +2. Run a specific query: +```bash + cargo run --release --bin dfbench -- sort-tpch -p '....../datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' --query 2 +``` + +3. Run all queries with `bench.sh` script: +```bash +./bench.sh run sort_tpch +``` + +## IMDB + +Run Join Order Benchmark (JOB) on IMDB dataset. + +The Internet Movie Database (IMDB) dataset contains real-world movie data. Unlike synthetic datasets like TPCH, which assume uniform data distribution and uncorrelated columns, the IMDB dataset includes skewed data and correlated columns (which are common for real dataset), making it more suitable for testing query optimizers, particularly for cardinality estimation. 
+ +This benchmark is derived from [Join Order Benchmark](https://github.com/gregrahn/join-order-benchmark). + +See paper [How Good Are Query Optimizers, Really](http://www.vldb.org/pvldb/vol9/p204-leis.pdf) for more details. + ## TPCH Run the tpch benchmark. @@ -342,32 +376,79 @@ This benchmarks is derived from the [TPC-H][1] version [2]: https://github.com/databricks/tpch-dbgen.git, [2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf +## External Aggregation + +Run the benchmark for aggregations with limited memory. -# Older Benchmarks +When the memory limit is exceeded, the aggregation intermediate results will be spilled to disk, and finally read back with sort-merge. -## h2o benchmarks +External aggregation benchmarks run several aggregation queries with different memory limits, on TPCH `lineitem` table. Queries can be found in [`external_aggr.rs`](src/bin/external_aggr.rs). +This benchmark is inspired by [DuckDB's external aggregation paper](https://hannes.muehleisen.org/publications/icde2024-out-of-core-kuiper-boncz-muehleisen.pdf), specifically Section VI. + +### External Aggregation Example Runs +1. Run all queries with predefined memory limits: ```bash -cargo run --release --bin h2o group-by --query 1 --path /mnt/bigdata/h2oai/N_1e7_K_1e2_single.csv --mem-table --debug +# Under 'benchmarks/' directory +cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' ``` -Example run: +2. Run a query with specific memory limit: +```bash +cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' --query 1 --memory-limit 30M +``` +3. Run all queries with `bench.sh` script: +```bash +./bench.sh data external_aggr +./bench.sh run external_aggr ``` -Running benchmarks with the following options: GroupBy(GroupBy { query: 1, path: "/mnt/bigdata/h2oai/N_1e7_K_1e2_single.csv", debug: false }) -Executing select id1, sum(v1) as v1 from x group by id1 -+-------+--------+ -| id1 | v1 | -+-------+--------+ -| id063 | 199420 | -| id094 | 200127 | -| id044 | 198886 | -... -| id093 | 200132 | -| id003 | 199047 | -+-------+--------+ -h2o groupby query 1 took 1669 ms + +## h2o benchmarks for groupby + +### Generate data for h2o benchmarks +There are three options for generating data for h2o benchmarks: `small`, `medium`, and `big`. The data is generated in the `data` directory. + +1. Generate small data (1e7 rows) +```bash +./bench.sh data h2o_small +``` + + +2. Generate medium data (1e8 rows) +```bash +./bench.sh data h2o_medium +``` + + +3. Generate large data (1e9 rows) +```bash +./bench.sh data h2o_big +``` + +### Run h2o benchmarks +There are three options for running h2o benchmarks: `small`, `medium`, and `big`. +1. Run small data benchmark +```bash +./bench.sh run h2o_small +``` + +2. Run medium data benchmark +```bash +./bench.sh run h2o_medium +``` + +3. Run large data benchmark +```bash +./bench.sh run h2o_big +``` + +4. 
Run a specific query with a specific data path + +For example, to run query 1 with the small data generated above: +```bash +cargo run --release --bin dfbench -- h2o --path ./benchmarks/data/h2o/G1_1e7_1e7_100_0.csv --query 1 ``` [1]: http://www.tpc.org/tpch/ diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 70faa9ef2b737..20cb32722c879 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -75,9 +75,14 @@ tpch10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), tpch_mem10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), query from memory parquet: Benchmark of parquet reader's filtering speed sort: Benchmark of sorting speed +sort_tpch: Benchmark of sorting speed for end-to-end sort queries on TPCH dataset clickbench_1: ClickBench queries against a single parquet file clickbench_partitioned: ClickBench queries against a partitioned (100 files) parquet clickbench_extended: ClickBench \"inspired\" queries against a single parquet (DataFusion specific) +external_aggr: External aggregation benchmark +h2o_small: h2oai benchmark with small dataset (1e7 rows), default file format is csv +h2o_medium: h2oai benchmark with medium dataset (1e8 rows), default file format is csv +h2o_big: h2oai benchmark with large dataset (1e9 rows), default file format is csv ********** * Supported Configuration (Environment Variables) @@ -140,6 +145,9 @@ main() { all) data_tpch "1" data_tpch "10" + data_h2o "SMALL" + data_h2o "MEDIUM" + data_h2o "BIG" data_clickbench_1 data_clickbench_partitioned data_imdb @@ -170,6 +178,23 @@ main() { imdb) data_imdb ;; + h2o_small) + data_h2o "SMALL" "CSV" + ;; + h2o_medium) + data_h2o "MEDIUM" "CSV" + ;; + h2o_big) + data_h2o "BIG" "CSV" + ;; + external_aggr) + # same data as for tpch + data_tpch "1" + ;; + sort_tpch) + # same data as for tpch + data_tpch "1" + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" usage @@ -211,7 +236,11 @@ main() { run_clickbench_1 run_clickbench_partitioned run_clickbench_extended + run_h2o "SMALL" "PARQUET" "groupby" + run_h2o "MEDIUM" "PARQUET" "groupby" + run_h2o "BIG" "PARQUET" "groupby" run_imdb + run_external_aggr ;; tpch) run_tpch "1" @@ -243,6 +272,21 @@ main() { imdb) run_imdb ;; + h2o_small) + run_h2o "SMALL" "CSV" "groupby" + ;; + h2o_medium) + run_h2o "MEDIUM" "CSV" "groupby" + ;; + h2o_big) + run_h2o "BIG" "CSV" "groupby" + ;; + external_aggr) + run_external_aggr + ;; + sort_tpch) + run_sort_tpch + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" usage @@ -357,7 +401,7 @@ run_parquet() { RESULTS_FILE="${RESULTS_DIR}/parquet.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running parquet filter benchmark..." - $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" + $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" } # Runs the sort benchmark @@ -365,7 +409,7 @@ run_sort() { RESULTS_FILE="${RESULTS_DIR}/sort.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running sort benchmark..." 
- $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" + $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" } @@ -439,11 +483,11 @@ run_clickbench_extended() { } # Downloads the csv.gz files IMDB datasets from Peter Boncz's homepage(one of the JOB paper authors) -# http://homepages.cwi.nl/~boncz/job/imdb.tgz +# https://event.cwi.nl/da/job/imdb.tgz data_imdb() { local imdb_dir="${DATA_DIR}/imdb" local imdb_temp_gz="${imdb_dir}/imdb.tgz" - local imdb_url="https://homepages.cwi.nl/~boncz/job/imdb.tgz" + local imdb_url="https://event.cwi.nl/da/job/imdb.tgz" # imdb has 21 files, we just separate them into 3 groups for better readability local first_required_files=( @@ -524,7 +568,150 @@ run_imdb() { $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" } +data_h2o() { + # Default values for size and data format + SIZE=${1:-"SMALL"} + DATA_FORMAT=${2:-"CSV"} + + # Function to compare Python versions + version_ge() { + [ "$(printf '%s\n' "$1" "$2" | sort -V | head -n1)" = "$2" ] + } + + export PYO3_USE_ABI3_FORWARD_COMPATIBILITY=1 + + # Find the highest available Python version (3.10 or higher) + REQUIRED_VERSION="3.10" + PYTHON_CMD=$(command -v python3 || true) + + if [ -n "$PYTHON_CMD" ]; then + PYTHON_VERSION=$($PYTHON_CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") + if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then + echo "Found Python version $PYTHON_VERSION, which is suitable." + else + echo "Python version $PYTHON_VERSION found, but version $REQUIRED_VERSION or higher is required." + PYTHON_CMD="" + fi + fi + + # Search for suitable Python versions if the default is unsuitable + if [ -z "$PYTHON_CMD" ]; then + # Loop through all available Python3 commands on the system + for CMD in $(compgen -c | grep -E '^python3(\.[0-9]+)?$'); do + if command -v "$CMD" &> /dev/null; then + PYTHON_VERSION=$($CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") + if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then + PYTHON_CMD="$CMD" + echo "Found suitable Python version: $PYTHON_VERSION ($CMD)" + break + fi + fi + done + fi + + # If no suitable Python version found, exit with an error + if [ -z "$PYTHON_CMD" ]; then + echo "Python 3.10 or higher is required. Please install it." + return 1 + fi + + echo "Using Python command: $PYTHON_CMD" + + # Install falsa and other dependencies + echo "Installing falsa..." 
+ + # Set virtual environment directory + VIRTUAL_ENV="${PWD}/venv" + + # Create a virtual environment using the detected Python command + $PYTHON_CMD -m venv "$VIRTUAL_ENV" + + # Activate the virtual environment and install dependencies + source "$VIRTUAL_ENV/bin/activate" + + # Ensure 'falsa' is installed (avoid unnecessary reinstall) + pip install --quiet --upgrade falsa + + # Create directory if it doesn't exist + H2O_DIR="${DATA_DIR}/h2o" + mkdir -p "${H2O_DIR}" + + # Generate h2o test data + echo "Generating h2o test data in ${H2O_DIR} with size=${SIZE} and format=${DATA_FORMAT}" + falsa groupby --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}" + + # Deactivate virtual environment after completion + deactivate +} + +## todo now only support groupby, after https://github.com/mrpowers-io/falsa/issues/21 done, we can add support for join +run_h2o() { + # Default values for size and data format + SIZE=${1:-"SMALL"} + DATA_FORMAT=${2:-"CSV"} + DATA_FORMAT=$(echo "$DATA_FORMAT" | tr '[:upper:]' '[:lower:]') + RUN_Type=${3:-"groupby"} + + # Data directory and results file path + H2O_DIR="${DATA_DIR}/h2o" + RESULTS_FILE="${RESULTS_DIR}/h2o.json" + + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running h2o benchmark..." + # Set the file name based on the size + case "$SIZE" in + "SMALL") + FILE_NAME="G1_1e7_1e7_100_0.${DATA_FORMAT}" # For small dataset + ;; + "MEDIUM") + FILE_NAME="G1_1e8_1e8_100_0.${DATA_FORMAT}" # For medium dataset + ;; + "BIG") + FILE_NAME="G1_1e9_1e9_100_0.${DATA_FORMAT}" # For big dataset + ;; + *) + echo "Invalid size. Valid options are SMALL, MEDIUM, or BIG." + return 1 + ;; + esac + + # Set the query file name based on the RUN_Type + QUERY_FILE="${SCRIPT_DIR}/queries/h2o/${RUN_Type}.sql" + + # Run the benchmark using the dynamically constructed file path and query file + $CARGO_COMMAND --bin dfbench -- h2o \ + --iterations 3 \ + --path "${H2O_DIR}/${FILE_NAME}" \ + --queries-path "${QUERY_FILE}" \ + -o "${RESULTS_FILE}" +} + +# Runs the external aggregation benchmark +run_external_aggr() { + # Use TPC-H SF1 dataset + TPCH_DIR="${DATA_DIR}/tpch_sf1" + RESULTS_FILE="${RESULTS_DIR}/external_aggr.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running external aggregation benchmark..." + + # Only parquet is supported. + # Since per-operator memory limit is calculated as (total-memory-limit / + # number-of-partitions), and by default `--partitions` is set to number of + # CPU cores, we set a constant number of partitions to prevent this + # benchmark to fail on some machines. + $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" +} + +# Runs the sort integration benchmark +run_sort_tpch() { + TPCH_DIR="${DATA_DIR}/tpch_sf1" + RESULTS_FILE="${RESULTS_DIR}/sort_tpch.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running sort tpch benchmark..." + + $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" +} compare_benchmarks() { diff --git a/benchmarks/compare.py b/benchmarks/compare.py index 2574c0735ca8d..4b609c744d503 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -1,21 +1,20 @@ #!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. from __future__ import annotations diff --git a/benchmarks/queries/h2o/groupby.sql b/benchmarks/queries/h2o/groupby.sql new file mode 100644 index 0000000000000..c2101ef8ada2d --- /dev/null +++ b/benchmarks/queries/h2o/groupby.sql @@ -0,0 +1,10 @@ +SELECT id1, SUM(v1) AS v1 FROM x GROUP BY id1; +SELECT id1, id2, SUM(v1) AS v1 FROM x GROUP BY id1, id2; +SELECT id3, SUM(v1) AS v1, AVG(v3) AS v3 FROM x GROUP BY id3; +SELECT id4, AVG(v1) AS v1, AVG(v2) AS v2, AVG(v3) AS v3 FROM x GROUP BY id4; +SELECT id6, SUM(v1) AS v1, SUM(v2) AS v2, SUM(v3) AS v3 FROM x GROUP BY id6; +SELECT id4, id5, MEDIAN(v3) AS median_v3, STDDEV(v3) AS sd_v3 FROM x GROUP BY id4, id5; +SELECT id3, MAX(v1) - MIN(v2) AS range_v1_v2 FROM x GROUP BY id3; +SELECT id6, largest2_v3 FROM (SELECT id6, v3 AS largest2_v3, ROW_NUMBER() OVER (PARTITION BY id6 ORDER BY v3 DESC) AS order_v3 FROM x WHERE v3 IS NOT NULL) sub_query WHERE order_v3 <= 2; +SELECT id2, id4, POWER(CORR(v1, v2), 2) AS r2 FROM x GROUP BY id2, id4; +SELECT id1, id2, id3, id4, id5, id6, SUM(v3) AS v3, COUNT(*) AS count FROM x GROUP BY id1, id2, id3, id4, id5, id6; diff --git a/benchmarks/queries/h2o/join.sql b/benchmarks/queries/h2o/join.sql new file mode 100644 index 0000000000000..8546b9292dbb4 --- /dev/null +++ b/benchmarks/queries/h2o/join.sql @@ -0,0 +1,5 @@ +SELECT x.id1, x.id2, x.id3, x.id4 as xid4, small.id4 as smallid4, x.id5, x.id6, x.v1, small.v2 FROM x INNER JOIN small ON x.id1 = small.id1; +SELECT x.id1 as xid1, medium.id1 as mediumid1, x.id2, x.id3, x.id4 as xid4, medium.id4 as mediumid4, x.id5 as xid5, medium.id5 as mediumid5, x.id6, x.v1, medium.v2 FROM x INNER JOIN medium ON x.id2 = medium.id2; +SELECT x.id1 as xid1, medium.id1 as mediumid1, x.id2, x.id3, x.id4 as xid4, medium.id4 as mediumid4, x.id5 as xid5, medium.id5 as mediumid5, x.id6, x.v1, medium.v2 FROM x LEFT JOIN medium ON x.id2 = medium.id2; +SELECT x.id1 as xid1, medium.id1 as mediumid1, x.id2, x.id3, x.id4 as xid4, medium.id4 as mediumid4, x.id5 as xid5, medium.id5 as mediumid5, x.id6, x.v1, medium.v2 FROM x JOIN medium ON x.id5 = medium.id5; +SELECT x.id1 as xid1, large.id1 as largeid1, x.id2 as xid2, large.id2 as largeid2, x.id3, x.id4 as 
xid4, large.id4 as largeid4, x.id5 as xid5, large.id5 as largeid5, x.id6 as xid6, large.id6 as largeid6, x.v1, large.v2 FROM x JOIN large ON x.id3 = large.id3; diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index f7b84116e793a..db6c29f4a46a6 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -33,7 +33,9 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; #[global_allocator] static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; -use datafusion_benchmarks::{clickbench, imdb, parquet_filter, sort, tpch}; +use datafusion_benchmarks::{ + clickbench, h2o, imdb, parquet_filter, sort, sort_tpch, tpch, +}; #[derive(Debug, StructOpt)] #[structopt(about = "benchmark command")] @@ -43,7 +45,9 @@ enum Options { Clickbench(clickbench::RunOpt), ParquetFilter(parquet_filter::RunOpt), Sort(sort::RunOpt), + SortTpch(sort_tpch::RunOpt), Imdb(imdb::RunOpt), + H2o(h2o::RunOpt), } // Main benchmark runner entrypoint @@ -57,6 +61,8 @@ pub async fn main() -> Result<()> { Options::Clickbench(opt) => opt.run().await, Options::ParquetFilter(opt) => opt.run().await, Options::Sort(opt) => opt.run().await, + Options::SortTpch(opt) => opt.run().await, Options::Imdb(opt) => opt.run().await, + Options::H2o(opt) => opt.run().await, } } diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs new file mode 100644 index 0000000000000..a2fb75dd19418 --- /dev/null +++ b/benchmarks/src/bin/external_aggr.rs @@ -0,0 +1,389 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
external_aggr binary entrypoint + +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::LazyLock; +use structopt::StructOpt; + +use arrow::record_batch::RecordBatch; +use arrow::util::pretty; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::datasource::{MemTable, TableProvider}; +use datafusion::error::Result; +use datafusion::execution::memory_pool::FairSpillPool; +use datafusion::execution::memory_pool::{human_readable_size, units}; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::execution::SessionStateBuilder; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::{collect, displayable}; +use datafusion::prelude::*; +use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt}; +use datafusion_common::instant::Instant; +use datafusion_common::utils::get_available_parallelism; +use datafusion_common::{exec_datafusion_err, exec_err, DEFAULT_PARQUET_EXTENSION}; + +#[derive(Debug, StructOpt)] +#[structopt( + name = "datafusion-external-aggregation", + about = "DataFusion external aggregation benchmark" +)] +enum ExternalAggrOpt { + Benchmark(ExternalAggrConfig), +} + +#[derive(Debug, StructOpt)] +struct ExternalAggrConfig { + /// Query number. If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits for given query. + #[structopt(long)] + memory_limit: Option, + + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Path to data files (lineitem). Only parquet format is supported + #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + path: PathBuf, + + /// Load the data into a MemTable before executing the query + #[structopt(short = "m", long = "mem-table")] + mem_table: bool, + + /// Path to JSON benchmark result to be compare using `compare.py` + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, +} + +struct QueryResult { + elapsed: std::time::Duration, + row_count: usize, +} + +/// Query Memory Limits +/// Map query id to predefined memory limits +/// +/// Q1 requires 36MiB for aggregation +/// Memory limits to run: 64MiB, 32MiB, 16MiB +/// Q2 requires 250MiB for aggregation +/// Memory limits to run: 512MiB, 256MiB, 128MiB, 64MiB, 32MiB +static QUERY_MEMORY_LIMITS: LazyLock>> = LazyLock::new(|| { + use units::*; + let mut map = HashMap::new(); + map.insert(1, vec![64 * MB, 32 * MB, 16 * MB]); + map.insert(2, vec![512 * MB, 256 * MB, 128 * MB, 64 * MB, 32 * MB]); + map +}); + +impl ExternalAggrConfig { + const AGGR_TABLES: [&'static str; 1] = ["lineitem"]; + const AGGR_QUERIES: [&'static str; 2] = [ + // Q1: Output size is ~25% of lineitem table + r#" + SELECT count(*) + FROM ( + SELECT DISTINCT l_orderkey + FROM lineitem + ) + "#, + // Q2: Output size is ~99% of lineitem table + r#" + SELECT count(*) + FROM ( + SELECT DISTINCT l_orderkey, l_suppkey + FROM lineitem + ) + "#, + ]; + + /// If `--query` and `--memory-limit` is not speicified, run all queries + /// with pre-configured memory limits + /// If only `--query` is specified, run the query with all memory limits + /// for this query + /// If both `--query` and `--memory-limit` are specified, run the query + /// with the specified memory limit + pub async fn run(&self) -> Result<()> 
{ + let mut benchmark_run = BenchmarkRun::new(); + + let memory_limit = match &self.memory_limit { + Some(limit) => Some(Self::parse_memory_limit(limit)?), + None => None, + }; + + let query_range = match self.query { + Some(query_id) => query_id..=query_id, + None => 1..=Self::AGGR_QUERIES.len(), + }; + + // Each element is (query_id, memory_limit) + // e.g. [(1, 64_000), (1, 32_000)...] means first run Q1 with 64KiB + // memory limit, next run Q1 with 32KiB memory limit, etc. + let mut query_executions = vec![]; + // Setup `query_executions` + for query_id in query_range { + if query_id > Self::AGGR_QUERIES.len() { + return exec_err!( + "Invalid '--query'(query number) {} for external aggregation benchmark.", + query_id + ); + } + + match memory_limit { + Some(limit) => { + query_executions.push((query_id, limit)); + } + None => { + let memory_limits = QUERY_MEMORY_LIMITS.get(&query_id).unwrap(); + for limit in memory_limits { + query_executions.push((query_id, *limit)); + } + } + } + } + + for (query_id, mem_limit) in query_executions { + benchmark_run.start_new_case(&format!( + "{query_id}({})", + human_readable_size(mem_limit as usize) + )); + + let query_results = self.benchmark_query(query_id, mem_limit).await?; + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + + Ok(()) + } + + /// Benchmark query `query_id` in `AGGR_QUERIES` + async fn benchmark_query( + &self, + query_id: usize, + mem_limit: u64, + ) -> Result> { + let query_name = + format!("Q{query_id}({})", human_readable_size(mem_limit as usize)); + let config = self.common.config(); + let runtime_env = RuntimeEnvBuilder::new() + .with_memory_pool(Arc::new(FairSpillPool::new(mem_limit as usize))) + .build_arc()?; + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime_env) + .with_default_features() + .build(); + let ctx = SessionContext::from(state); + + // register tables + self.register_tables(&ctx).await?; + + let mut millis = vec![]; + // run benchmark + let mut query_results = vec![]; + for i in 0..self.iterations() { + let start = Instant::now(); + + let query_idx = query_id - 1; // 1-indexed -> 0-indexed + let sql = Self::AGGR_QUERIES[query_idx]; + + let result = self.execute_query(&ctx, sql).await?; + + let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0; + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + + let row_count = result.iter().map(|b| b.num_rows()).sum(); + println!( + "{query_name} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }); + } + + let avg = millis.iter().sum::() / millis.len() as f64; + println!("{query_name} avg time: {avg:.2} ms"); + + Ok(query_results) + } + + async fn register_tables(&self, ctx: &SessionContext) -> Result<()> { + for table in Self::AGGR_TABLES { + let table_provider = { self.get_table(ctx, table).await? 
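For reference, each entry in `QUERY_MEMORY_LIMITS` above pairs a query with several FairSpillPool sizes, and `benchmark_query` rebuilds the session for every (query, limit) combination. Because the pool budget is shared across partitions, bench.sh pins `--partitions 4` for this benchmark so the per-operator share stays predictable. A minimal sketch of that setup, reusing the builder calls from the code above (the fixed partition count and the 64 MiB limit are example values, not part of the patch):

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::memory_pool::FairSpillPool;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::*;

/// Build a SessionContext whose operators share a fixed memory budget and
/// spill to disk once it is exhausted (sketch; values are examples only).
fn memory_limited_ctx(mem_limit_bytes: usize) -> Result<SessionContext> {
    let runtime_env = RuntimeEnvBuilder::new()
        .with_memory_pool(Arc::new(FairSpillPool::new(mem_limit_bytes)))
        .build_arc()?;
    let state = SessionStateBuilder::new()
        .with_config(SessionConfig::new().with_target_partitions(4))
        .with_runtime_env(runtime_env)
        .with_default_features()
        .build();
    Ok(SessionContext::from(state))
}

fn main() -> Result<()> {
    // 64 MiB pool shared by all operators in this context.
    let _ctx = memory_limited_ctx(64 * 1024 * 1024)?;
    Ok(())
}
```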
}; + + if self.mem_table { + println!("Loading table '{table}' into memory"); + let start = Instant::now(); + let memtable = + MemTable::load(table_provider, Some(self.partitions()), &ctx.state()) + .await?; + println!( + "Loaded table '{}' into memory in {} ms", + table, + start.elapsed().as_millis() + ); + ctx.register_table(table, Arc::new(memtable))?; + } else { + ctx.register_table(table, table_provider)?; + } + } + Ok(()) + } + + async fn execute_query( + &self, + ctx: &SessionContext, + sql: &str, + ) -> Result> { + let debug = self.common.debug; + let plan = ctx.sql(sql).await?; + let (state, plan) = plan.into_parts(); + + if debug { + println!("=== Logical plan ===\n{plan}\n"); + } + + let plan = state.optimize(&plan)?; + if debug { + println!("=== Optimized logical plan ===\n{plan}\n"); + } + let physical_plan = state.create_physical_plan(&plan).await?; + if debug { + println!( + "=== Physical plan ===\n{}\n", + displayable(physical_plan.as_ref()).indent(true) + ); + } + let result = collect(physical_plan.clone(), state.task_ctx()).await?; + if debug { + println!( + "=== Physical plan with metrics ===\n{}\n", + DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()) + .indent(true) + ); + if !result.is_empty() { + // do not call print_batches if there are no batches as the result is confusing + // and makes it look like there is a batch with no columns + pretty::print_batches(&result)?; + } + } + Ok(result) + } + + async fn get_table( + &self, + ctx: &SessionContext, + table: &str, + ) -> Result> { + let path = self.path.to_str().unwrap(); + + // Obtain a snapshot of the SessionState + let state = ctx.state(); + let path = format!("{path}/{table}"); + let format = Arc::new( + ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()), + ); + let extension = DEFAULT_PARQUET_EXTENSION; + + let options = ListingOptions::new(format) + .with_file_extension(extension) + .with_collect_stat(state.config().collect_statistics()); + + let table_path = ListingTableUrl::parse(path)?; + let config = ListingTableConfig::new(table_path).with_listing_options(options); + let config = config.infer_schema(&state).await?; + + Ok(Arc::new(ListingTable::try_new(config)?)) + } + + fn iterations(&self) -> usize { + self.common.iterations + } + + fn partitions(&self) -> usize { + self.common + .partitions + .unwrap_or(get_available_parallelism()) + } + + /// Parse memory limit from string to number of bytes + /// e.g. 
'1.5G', '100M' -> 1572864 + fn parse_memory_limit(limit: &str) -> Result { + let (number, unit) = limit.split_at(limit.len() - 1); + let number: f64 = number.parse().map_err(|_| { + exec_datafusion_err!("Failed to parse number from memory limit '{}'", limit) + })?; + + match unit { + "K" => Ok((number * 1024.0) as u64), + "M" => Ok((number * 1024.0 * 1024.0) as u64), + "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as u64), + _ => exec_err!("Unsupported unit '{}' in memory limit '{}'", unit, limit), + } + } +} + +#[tokio::main] +pub async fn main() -> Result<()> { + env_logger::init(); + + match ExternalAggrOpt::from_args() { + ExternalAggrOpt::Benchmark(opt) => opt.run().await?, + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_memory_limit_all() { + // Test valid inputs + assert_eq!( + ExternalAggrConfig::parse_memory_limit("100K").unwrap(), + 102400 + ); + assert_eq!( + ExternalAggrConfig::parse_memory_limit("1.5M").unwrap(), + 1572864 + ); + assert_eq!( + ExternalAggrConfig::parse_memory_limit("2G").unwrap(), + 2147483648 + ); + + // Test invalid unit + assert!(ExternalAggrConfig::parse_memory_limit("500X").is_err()); + + // Test invalid number + assert!(ExternalAggrConfig::parse_memory_limit("abcM").is_err()); + } +} diff --git a/benchmarks/src/bin/h2o.rs b/benchmarks/src/bin/h2o.rs deleted file mode 100644 index 1bb8cb9d43e4b..0000000000000 --- a/benchmarks/src/bin/h2o.rs +++ /dev/null @@ -1,134 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
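The suffixes accepted here are binary units, so the expectations in `test_parse_memory_limit_all` above follow directly: 100K is 100 * 1024 = 102400, 1.5M is 1.5 * 1024^2 = 1572864, and 2G is 2 * 1024^3 = 2147483648 bytes. A standalone sketch of the same suffix handling (illustrative only; the benchmark uses `ExternalAggrConfig::parse_memory_limit` above):

```rust
/// Convert a human-readable limit such as "100K", "1.5M" or "2G" into bytes.
/// Sketch of the same logic as ExternalAggrConfig::parse_memory_limit.
fn parse_memory_limit(limit: &str) -> Result<u64, String> {
    let (number, unit) = limit.split_at(limit.len() - 1);
    let number: f64 = number
        .parse()
        .map_err(|_| format!("invalid number in memory limit '{limit}'"))?;
    let multiplier = match unit {
        "K" => 1024.0,
        "M" => 1024.0 * 1024.0,
        "G" => 1024.0 * 1024.0 * 1024.0,
        _ => return Err(format!("unsupported unit '{unit}' in '{limit}'")),
    };
    Ok((number * multiplier) as u64)
}

fn main() {
    assert_eq!(parse_memory_limit("100K"), Ok(102400));
    assert_eq!(parse_memory_limit("1.5M"), Ok(1572864));
    assert_eq!(parse_memory_limit("2G"), Ok(2147483648));
}
```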
DataFusion h2o benchmarks - -use datafusion::arrow::datatypes::{DataType, Field, Schema}; -use datafusion::config::ConfigOptions; -use datafusion::datasource::file_format::csv::CsvFormat; -use datafusion::datasource::listing::{ - ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, -}; -use datafusion::datasource::MemTable; -use datafusion::prelude::CsvReadOptions; -use datafusion::{arrow::util::pretty, error::Result, prelude::SessionContext}; -use datafusion_benchmarks::BenchmarkRun; -use std::path::PathBuf; -use std::sync::Arc; -use structopt::StructOpt; -use tokio::time::Instant; - -#[derive(Debug, StructOpt)] -#[structopt(name = "datafusion-h2o", about = "DataFusion h2o benchmarks")] -enum Opt { - GroupBy(GroupBy), //TODO add Join queries -} - -#[derive(Debug, StructOpt)] -struct GroupBy { - /// Query number - #[structopt(short, long)] - query: usize, - /// Path to data file - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] - path: PathBuf, - /// Activate debug mode to see query results - #[structopt(short, long)] - debug: bool, - /// Load the data into a MemTable before executing the query - #[structopt(short = "m", long = "mem-table")] - mem_table: bool, - /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] - output_path: Option, -} - -#[tokio::main] -async fn main() -> Result<()> { - let opt = Opt::from_args(); - println!("Running benchmarks with the following options: {opt:?}"); - match opt { - Opt::GroupBy(config) => group_by(&config).await, - } -} - -async fn group_by(opt: &GroupBy) -> Result<()> { - let mut rundata = BenchmarkRun::new(); - let path = opt.path.to_str().unwrap(); - let mut config = ConfigOptions::from_env()?; - config.execution.batch_size = 65535; - - let ctx = SessionContext::new_with_config(config.into()); - - let schema = Schema::new(vec![ - Field::new("id1", DataType::Utf8, false), - Field::new("id2", DataType::Utf8, false), - Field::new("id3", DataType::Utf8, false), - Field::new("id4", DataType::Int32, false), - Field::new("id5", DataType::Int32, false), - Field::new("id6", DataType::Int32, false), - Field::new("v1", DataType::Int32, false), - Field::new("v2", DataType::Int32, false), - Field::new("v3", DataType::Float64, false), - ]); - - if opt.mem_table { - let listing_config = ListingTableConfig::new(ListingTableUrl::parse(path)?) 
- .with_listing_options(ListingOptions::new(Arc::new(CsvFormat::default()))) - .with_schema(Arc::new(schema)); - let csv = ListingTable::try_new(listing_config)?; - let partition_size = num_cpus::get(); - let memtable = - MemTable::load(Arc::new(csv), Some(partition_size), &ctx.state()).await?; - ctx.register_table("x", Arc::new(memtable))?; - } else { - ctx.register_csv("x", path, CsvReadOptions::default().schema(&schema)) - .await?; - } - rundata.start_new_case(&opt.query.to_string()); - let sql = match opt.query { - 1 => "select id1, sum(v1) as v1 from x group by id1", - 2 => "select id1, id2, sum(v1) as v1 from x group by id1, id2", - 3 => "select id3, sum(v1) as v1, mean(v3) as v3 from x group by id3", - 4 => "select id4, mean(v1) as v1, mean(v2) as v2, mean(v3) as v3 from x group by id4", - 5 => "select id6, sum(v1) as v1, sum(v2) as v2, sum(v3) as v3 from x group by id6", - 6 => "select id4, id5, median(v3) as median_v3, stddev(v3) as sd_v3 from x group by id4, id5", - 7 => "select id3, max(v1)-min(v2) as range_v1_v2 from x group by id3", - 8 => "select id6, largest2_v3 from (select id6, v3 as largest2_v3, row_number() over (partition by id6 order by v3 desc) as order_v3 from x where v3 is not null) sub_query where order_v3 <= 2", - 9 => "select id2, id4, pow(corr(v1, v2), 2) as r2 from x group by id2, id4", - 10 => "select id1, id2, id3, id4, id5, id6, sum(v3) as v3, count(*) as count from x group by id1, id2, id3, id4, id5, id6", - _ => unimplemented!(), - }; - - println!("Executing {sql}"); - let start = Instant::now(); - let df = ctx.sql(sql).await?; - let batches = df.collect().await?; - let elapsed = start.elapsed(); - let numrows = batches.iter().map(|b| b.num_rows()).sum::(); - if opt.debug { - pretty::print_batches(&batches)?; - } - rundata.write_iter(elapsed, numrows); - println!( - "h2o groupby query {} took {} ms", - opt.query, - elapsed.as_secs_f64() * 1000.0 - ); - rundata.maybe_write_json(opt.output_path.as_ref())?; - Ok(()) -} diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index 207da4020b588..6b7c75ed4babc 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -18,6 +18,7 @@ use std::path::Path; use std::path::PathBuf; +use crate::util::{BenchmarkRun, CommonOpt}; use datafusion::{ error::{DataFusionError, Result}, prelude::SessionContext, @@ -26,8 +27,6 @@ use datafusion_common::exec_datafusion_err; use datafusion_common::instant::Instant; use structopt::StructOpt; -use crate::{BenchmarkRun, CommonOpt}; - /// Run the clickbench benchmark /// /// The ClickBench[1] benchmarks are widely cited in the industry and @@ -116,12 +115,14 @@ impl RunOpt { None => queries.min_query_id()..=queries.max_query_id(), }; + // configure parquet options let mut config = self.common.config(); - config - .options_mut() - .execution - .parquet - .schema_force_view_types = self.common.force_view_types; + { + let parquet_options = &mut config.options_mut().execution.parquet; + // The hits_partitioned dataset specifies string columns + // as binary due to how it was written. 
Force it to strings + parquet_options.binary_as_string = true; + } let ctx = SessionContext::new_with_config(config); self.register_hits(&ctx).await?; @@ -144,12 +145,15 @@ impl RunOpt { ); benchmark_run.write_iter(elapsed, row_count); } + if self.common.debug { + ctx.sql(sql).await?.explain(false, false)?.show().await?; + } } benchmark_run.maybe_write_json(self.output_path.as_ref())?; Ok(()) } - /// Registrs the `hits.parquet` as a table named `hits` + /// Registers the `hits.parquet` as a table named `hits` async fn register_hits(&self, ctx: &SessionContext) -> Result<()> { let options = Default::default(); let path = self.path.as_os_str().to_str().unwrap(); diff --git a/benchmarks/src/h2o.rs b/benchmarks/src/h2o.rs new file mode 100644 index 0000000000000..53a516ceb56d4 --- /dev/null +++ b/benchmarks/src/h2o.rs @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::util::{BenchmarkRun, CommonOpt}; +use datafusion::{error::Result, prelude::SessionContext}; +use datafusion_common::{exec_datafusion_err, instant::Instant, DataFusionError}; +use std::path::{Path, PathBuf}; +use structopt::StructOpt; + +/// Run the H2O benchmark +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { + #[structopt(short, long)] + query: Option, + + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Path to queries.sql (single file) + /// default value is the groupby.sql file in the h2o benchmark + #[structopt( + parse(from_os_str), + short = "r", + long = "queries-path", + default_value = "benchmarks/queries/h2o/groupby.sql" + )] + queries_path: PathBuf, + + /// Path to data file (parquet or csv) + /// Default value is the G1_1e7_1e7_100_0.csv file in the h2o benchmark + /// This is the small csv file with 10^7 rows + #[structopt( + parse(from_os_str), + short = "p", + long = "path", + default_value = "benchmarks/data/h2o/G1_1e7_1e7_100_0.csv" + )] + path: PathBuf, + + /// If present, write results json here + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, +} + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running benchmarks with the following options: {self:?}"); + let queries = AllQueries::try_new(&self.queries_path)?; + let query_range = match self.query { + Some(query_id) => query_id..=query_id, + None => queries.min_query_id()..=queries.max_query_id(), + }; + + let config = self.common.config(); + let ctx = SessionContext::new_with_config(config); + + // Register data + self.register_data(&ctx).await?; + + let iterations = self.common.iterations; + let mut benchmark_run = BenchmarkRun::new(); + for query_id in query_range { + benchmark_run.start_new_case(&format!("Query {query_id}")); + let sql = 
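The ClickBench runner above no longer sets the removed `force_view_types` flag; instead it enables the `binary_as_string` Parquet reader option, because the hits_partitioned files declare their string columns as binary. A minimal sketch of enabling the same option on a fresh session (the table name and path are placeholders):

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let mut config = SessionConfig::new();
    // Read Parquet BINARY columns back as UTF-8 strings, as the ClickBench
    // runner does for the hits_partitioned dataset.
    config.options_mut().execution.parquet.binary_as_string = true;

    let ctx = SessionContext::new_with_config(config);
    // Placeholder path; the benchmark registers its own hits.parquet file.
    ctx.register_parquet("hits", "data/hits.parquet", ParquetReadOptions::default())
        .await?;
    ctx.sql("SELECT COUNT(*) FROM hits").await?.show().await?;
    Ok(())
}
```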
queries.get_query(query_id)?; + println!("Q{query_id}: {sql}"); + + for i in 1..=iterations { + let start = Instant::now(); + let results = ctx.sql(sql).await?.collect().await?; + let elapsed = start.elapsed(); + let ms = elapsed.as_secs_f64() * 1000.0; + let row_count: usize = results.iter().map(|b| b.num_rows()).sum(); + println!( + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + benchmark_run.write_iter(elapsed, row_count); + } + if self.common.debug { + ctx.sql(sql).await?.explain(false, false)?.show().await?; + } + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + } + + Ok(()) + } + + async fn register_data(&self, ctx: &SessionContext) -> Result<()> { + let csv_options = Default::default(); + let parquet_options = Default::default(); + let path = self.path.as_os_str().to_str().unwrap(); + + if self.path.extension().map(|s| s == "csv").unwrap_or(false) { + ctx.register_csv("x", path, csv_options) + .await + .map_err(|e| { + DataFusionError::Context( + format!("Registering 'table' as {path}"), + Box::new(e), + ) + }) + .expect("error registering csv"); + } + + if self + .path + .extension() + .map(|s| s == "parquet") + .unwrap_or(false) + { + ctx.register_parquet("x", path, parquet_options) + .await + .map_err(|e| { + DataFusionError::Context( + format!("Registering 'table' as {path}"), + Box::new(e), + ) + }) + .expect("error registering parquet"); + } + Ok(()) + } +} + +struct AllQueries { + queries: Vec, +} + +impl AllQueries { + fn try_new(path: &Path) -> Result { + let all_queries = std::fs::read_to_string(path) + .map_err(|e| exec_datafusion_err!("Could not open {path:?}: {e}"))?; + + Ok(Self { + queries: all_queries.lines().map(|s| s.to_string()).collect(), + }) + } + + /// Returns the text of query `query_id` + fn get_query(&self, query_id: usize) -> Result<&str> { + self.queries + .get(query_id - 1) + .ok_or_else(|| { + let min_id = self.min_query_id(); + let max_id = self.max_query_id(); + exec_datafusion_err!( + "Invalid query id {query_id}. 
Must be between {min_id} and {max_id}" + ) + }) + .map(|s| s.as_str()) + } + + fn min_query_id(&self) -> usize { + 1 + } + + fn max_query_id(&self) -> usize { + self.queries.len() + } +} diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs index 697c79dc894a4..8d2317c62ef11 100644 --- a/benchmarks/src/imdb/run.rs +++ b/benchmarks/src/imdb/run.rs @@ -19,7 +19,7 @@ use std::path::PathBuf; use std::sync::Arc; use super::{get_imdb_table_schema, get_query_sql, IMDB_TABLES}; -use crate::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; @@ -35,6 +35,7 @@ use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::*; use datafusion_common::instant::Instant; +use datafusion_common::utils::get_available_parallelism; use datafusion_common::{DEFAULT_CSV_EXTENSION, DEFAULT_PARQUET_EXTENSION}; use log::info; @@ -305,11 +306,7 @@ impl RunOpt { .config() .with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; - config - .options_mut() - .execution - .parquet - .schema_force_view_types = self.common.force_view_types; + let ctx = SessionContext::new_with_config(config); // register tables @@ -472,7 +469,9 @@ impl RunOpt { } fn partitions(&self) -> usize { - self.common.partitions.unwrap_or(num_cpus::get()) + self.common + .partitions + .unwrap_or(get_available_parallelism()) } } @@ -489,6 +488,7 @@ mod tests { use super::*; + use crate::util::CommonOpt; use datafusion::common::exec_err; use datafusion::error::Result; use datafusion_proto::bytes::{ @@ -516,7 +516,6 @@ mod tests { partitions: Some(2), batch_size: 8192, debug: false, - force_view_types: false, }; let opt = RunOpt { query: Some(query), @@ -550,7 +549,6 @@ mod tests { partitions: Some(2), batch_size: 8192, debug: false, - force_view_types: false, }; let opt = RunOpt { query: Some(query), diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index 52d81ca91816a..858a5b9df7f86 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -17,9 +17,10 @@ //! 
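The h2o runner above reads the queries file with `lines()`, so each statement must sit on a single line (which is why groupby.sql and join.sql are formatted that way), and query ids are 1-based. A standalone sketch of that lookup (hypothetical helper, not part of the patch):

```rust
/// Sketch: load a queries file such as benchmarks/queries/h2o/groupby.sql,
/// where each line is one complete SQL statement, and fetch a query by its
/// 1-based id.
fn load_query(path: &std::path::Path, query_id: usize) -> std::io::Result<String> {
    let text = std::fs::read_to_string(path)?;
    text.lines()
        .nth(query_id - 1) // query ids start at 1
        .map(str::to_string)
        .ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("invalid query id {query_id}"),
            )
        })
}
```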
DataFusion benchmark runner pub mod clickbench; +pub mod h2o; pub mod imdb; pub mod parquet_filter; pub mod sort; +pub mod sort_tpch; pub mod tpch; -mod util; -pub use util::*; +pub mod util; diff --git a/benchmarks/src/parquet_filter.rs b/benchmarks/src/parquet_filter.rs index 5c98a2f8be3de..34103af0ffd21 100644 --- a/benchmarks/src/parquet_filter.rs +++ b/benchmarks/src/parquet_filter.rs @@ -17,7 +17,7 @@ use std::path::PathBuf; -use crate::{AccessLogOpt, BenchmarkRun, CommonOpt}; +use crate::util::{AccessLogOpt, BenchmarkRun, CommonOpt}; use arrow::util::pretty; use datafusion::common::Result; diff --git a/benchmarks/src/sort.rs b/benchmarks/src/sort.rs index 19eec2949ef61..9cf09c57205a7 100644 --- a/benchmarks/src/sort.rs +++ b/benchmarks/src/sort.rs @@ -18,17 +18,17 @@ use std::path::PathBuf; use std::sync::Arc; -use crate::{AccessLogOpt, BenchmarkRun, CommonOpt}; +use crate::util::{AccessLogOpt, BenchmarkRun, CommonOpt}; use arrow::util::pretty; use datafusion::common::Result; -use datafusion::physical_expr::PhysicalSortExpr; +use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; use datafusion::physical_plan::collect; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion::test_util::parquet::TestParquetFile; use datafusion_common::instant::Instant; - +use datafusion_common::utils::get_available_parallelism; use structopt::StructOpt; /// Test performance of sorting large datasets @@ -70,31 +70,28 @@ impl RunOpt { let sort_cases = vec![ ( "sort utf8", - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: col("request_method", &schema)?, options: Default::default(), - }], + }]), ), ( "sort int", - vec![PhysicalSortExpr { - expr: col("request_bytes", &schema)?, + LexOrdering::new(vec![PhysicalSortExpr { + expr: col("response_bytes", &schema)?, options: Default::default(), - }], + }]), ), ( "sort decimal", - vec![ - // sort decimal - PhysicalSortExpr { - expr: col("decimal_price", &schema)?, - options: Default::default(), - }, - ], + LexOrdering::new(vec![PhysicalSortExpr { + expr: col("decimal_price", &schema)?, + options: Default::default(), + }]), ), ( "sort integer tuple", - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: col("request_bytes", &schema)?, options: Default::default(), @@ -103,11 +100,11 @@ impl RunOpt { expr: col("response_bytes", &schema)?, options: Default::default(), }, - ], + ]), ), ( "sort utf8 tuple", - vec![ + LexOrdering::new(vec![ // sort utf8 tuple PhysicalSortExpr { expr: col("service", &schema)?, @@ -125,11 +122,11 @@ impl RunOpt { expr: col("image", &schema)?, options: Default::default(), }, - ], + ]), ), ( "sort mixed tuple", - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: col("service", &schema)?, options: Default::default(), @@ -142,7 +139,7 @@ impl RunOpt { expr: col("decimal_price", &schema)?, options: Default::default(), }, - ], + ]), ), ]; for (title, expr) in sort_cases { @@ -150,7 +147,9 @@ impl RunOpt { rundata.start_new_case(title); for i in 0..self.common.iterations { let config = SessionConfig::new().with_target_partitions( - self.common.partitions.unwrap_or(num_cpus::get()), + self.common + .partitions + .unwrap_or(get_available_parallelism()), ); let ctx = SessionContext::new_with_config(config); let (rows, elapsed) = @@ -170,13 +169,13 @@ impl RunOpt { async fn exec_sort( ctx: &SessionContext, - expr: &[PhysicalSortExpr], + expr: &LexOrdering, test_file: &TestParquetFile, debug: bool, ) -> Result<(usize, 
std::time::Duration)> { let start = Instant::now(); let scan = test_file.create_scan(ctx, None).await?; - let exec = Arc::new(SortExec::new(expr.to_owned(), scan)); + let exec = Arc::new(SortExec::new(expr.clone(), scan)); let task_ctx = ctx.task_ctx(); let result = collect(exec, task_ctx).await?; let elapsed = start.elapsed(); diff --git a/benchmarks/src/sort_tpch.rs b/benchmarks/src/sort_tpch.rs new file mode 100644 index 0000000000000..566a5ea62c2d0 --- /dev/null +++ b/benchmarks/src/sort_tpch.rs @@ -0,0 +1,325 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module provides integration benchmark for sort operation. +//! It will run different sort SQL queries on TPCH `lineitem` parquet dataset. +//! +//! Another `Sort` benchmark focus on single core execution. This benchmark +//! runs end-to-end sort queries and test the performance on multiple CPU cores. + +use futures::StreamExt; +use std::path::PathBuf; +use std::sync::Arc; +use structopt::StructOpt; + +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::datasource::{MemTable, TableProvider}; +use datafusion::error::Result; +use datafusion::execution::SessionStateBuilder; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::{displayable, execute_stream}; +use datafusion::prelude::*; +use datafusion_common::instant::Instant; +use datafusion_common::utils::get_available_parallelism; +use datafusion_common::DEFAULT_PARQUET_EXTENSION; + +use crate::util::{BenchmarkRun, CommonOpt}; + +#[derive(Debug, StructOpt)] +pub struct RunOpt { + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Sort query number. If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Path to data files (lineitem). 
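In the sort benchmark above, the sort keys that were previously passed as a bare `Vec<PhysicalSortExpr>` are now wrapped in `LexOrdering`, which is the type `SortExec::new` accepts in this version. A minimal sketch of building one ordering over an existing input plan (the column name and input are placeholders, and `col` here is the physical-expression column helper):

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::Schema;
use datafusion::error::Result;
use datafusion::physical_expr::expressions::col;
use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr};
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::ExecutionPlan;

/// Sketch: wrap the sort keys in a LexOrdering and build a SortExec over `input`.
fn sort_by_request_method(
    schema: &Schema,
    input: Arc<dyn ExecutionPlan>,
) -> Result<Arc<SortExec>> {
    let ordering = LexOrdering::new(vec![PhysicalSortExpr {
        expr: col("request_method", schema)?,
        options: Default::default(),
    }]);
    Ok(Arc::new(SortExec::new(ordering, input)))
}
```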
Only parquet format is supported + #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + path: PathBuf, + + /// Path to JSON benchmark result to be compare using `compare.py` + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, + + /// Load the data into a MemTable before executing the query + #[structopt(short = "m", long = "mem-table")] + mem_table: bool, +} + +struct QueryResult { + elapsed: std::time::Duration, + row_count: usize, +} + +impl RunOpt { + const SORT_TABLES: [&'static str; 1] = ["lineitem"]; + + /// Sort queries with different characteristics: + /// - Sort key with fixed length or variable length (VARCHAR) + /// - Sort key with different cardinality + /// - Different number of sort keys + /// - Different number of payload columns (thin: 1 additional column other + /// than sort keys; wide: all columns except sort keys) + /// + /// DataSet is `lineitem` table in TPCH dataset (16 columns, 6M rows for + /// scale factor 1.0, cardinality is counted from SF1 dataset) + /// + /// Key Columns: + /// - Column `l_linenumber`, type: `INTEGER`, cardinality: 7 + /// - Column `l_suppkey`, type: `BIGINT`, cardinality: 10k + /// - Column `l_orderkey`, type: `BIGINT`, cardinality: 1.5M + /// - Column `l_comment`, type: `VARCHAR`, cardinality: 4.5M (len is ~26 chars) + /// + /// Payload Columns: + /// - Thin variant: `l_partkey` column with `BIGINT` type (1 column) + /// - Wide variant: all columns except for possible key columns (12 columns) + const SORT_QUERIES: [&'static str; 10] = [ + // Q1: 1 sort key (type: INTEGER, cardinality: 7) + 1 payload column + r#" + SELECT l_linenumber, l_partkey + FROM lineitem + ORDER BY l_linenumber + "#, + // Q2: 1 sort key (type: BIGINT, cardinality: 1.5M) + 1 payload column + r#" + SELECT l_orderkey, l_partkey + FROM lineitem + ORDER BY l_orderkey + "#, + // Q3: 1 sort key (type: VARCHAR, cardinality: 4.5M) + 1 payload column + r#" + SELECT l_comment, l_partkey + FROM lineitem + ORDER BY l_comment + "#, + // Q4: 2 sort keys {(BIGINT, 1.5M), (INTEGER, 7)} + 1 payload column + r#" + SELECT l_orderkey, l_linenumber, l_partkey + FROM lineitem + ORDER BY l_orderkey, l_linenumber + "#, + // Q5: 3 sort keys {(INTEGER, 7), (BIGINT, 10k), (BIGINT, 1.5M)} + no payload column + r#" + SELECT l_linenumber, l_suppkey, l_orderkey + FROM lineitem + ORDER BY l_linenumber, l_suppkey, l_orderkey + "#, + // Q6: 3 sort keys {(INTEGER, 7), (BIGINT, 10k), (BIGINT, 1.5M)} + 1 payload column + r#" + SELECT l_linenumber, l_suppkey, l_orderkey, l_partkey + FROM lineitem + ORDER BY l_linenumber, l_suppkey, l_orderkey + "#, + // Q7: 3 sort keys {(INTEGER, 7), (BIGINT, 10k), (BIGINT, 1.5M)} + 12 all other columns + r#" + SELECT l_linenumber, l_suppkey, l_orderkey, + l_partkey, l_quantity, l_extendedprice, l_discount, l_tax, + l_returnflag, l_linestatus, l_shipdate, l_commitdate, + l_receiptdate, l_shipinstruct, l_shipmode + FROM lineitem + ORDER BY l_linenumber, l_suppkey, l_orderkey + "#, + // Q8: 4 sort keys {(BIGINT, 1.5M), (BIGINT, 10k), (INTEGER, 7), (VARCHAR, 4.5M)} + no payload column + r#" + SELECT l_orderkey, l_suppkey, l_linenumber, l_comment + FROM lineitem + ORDER BY l_orderkey, l_suppkey, l_linenumber, l_comment + "#, + // Q9: 4 sort keys {(BIGINT, 1.5M), (BIGINT, 10k), (INTEGER, 7), (VARCHAR, 4.5M)} + 1 payload column + r#" + SELECT l_orderkey, l_suppkey, l_linenumber, l_comment, l_partkey + FROM lineitem + ORDER BY l_orderkey, l_suppkey, l_linenumber, l_comment + "#, + // Q10: 4 sort keys {(BIGINT, 1.5M), 
(BIGINT, 10k), (INTEGER, 7), (VARCHAR, 4.5M)} + 12 all other columns + r#" + SELECT l_orderkey, l_suppkey, l_linenumber, l_comment, + l_partkey, l_quantity, l_extendedprice, l_discount, l_tax, + l_returnflag, l_linestatus, l_shipdate, l_commitdate, + l_receiptdate, l_shipinstruct, l_shipmode + FROM lineitem + ORDER BY l_orderkey, l_suppkey, l_linenumber, l_comment + "#, + ]; + + /// If query is specified from command line, run only that query. + /// Otherwise, run all queries. + pub async fn run(&self) -> Result<()> { + let mut benchmark_run = BenchmarkRun::new(); + + let query_range = match self.query { + Some(query_id) => query_id..=query_id, + None => 1..=Self::SORT_QUERIES.len(), + }; + + for query_id in query_range { + benchmark_run.start_new_case(&format!("{query_id}")); + + let query_results = self.benchmark_query(query_id).await?; + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + + Ok(()) + } + + /// Benchmark query `query_id` in `SORT_QUERIES` + async fn benchmark_query(&self, query_id: usize) -> Result> { + let config = self.common.config(); + let state = SessionStateBuilder::new() + .with_config(config) + .with_default_features() + .build(); + let ctx = SessionContext::from(state); + + // register tables + self.register_tables(&ctx).await?; + + let mut millis = vec![]; + // run benchmark + let mut query_results = vec![]; + for i in 0..self.iterations() { + let start = Instant::now(); + + let query_idx = query_id - 1; // 1-indexed -> 0-indexed + let sql = Self::SORT_QUERIES[query_idx]; + + let row_count = self.execute_query(&ctx, sql).await?; + + let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0; + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + + println!( + "Q{query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }); + } + + let avg = millis.iter().sum::() / millis.len() as f64; + println!("Q{query_id} avg time: {avg:.2} ms"); + + Ok(query_results) + } + + async fn register_tables(&self, ctx: &SessionContext) -> Result<()> { + for table in Self::SORT_TABLES { + let table_provider = { self.get_table(ctx, table).await? 
}; + + if self.mem_table { + println!("Loading table '{table}' into memory"); + let start = Instant::now(); + let memtable = + MemTable::load(table_provider, Some(self.partitions()), &ctx.state()) + .await?; + println!( + "Loaded table '{}' into memory in {} ms", + table, + start.elapsed().as_millis() + ); + ctx.register_table(table, Arc::new(memtable))?; + } else { + ctx.register_table(table, table_provider)?; + } + } + Ok(()) + } + + async fn execute_query(&self, ctx: &SessionContext, sql: &str) -> Result { + let debug = self.common.debug; + let plan = ctx.sql(sql).await?; + let (state, plan) = plan.into_parts(); + + if debug { + println!("=== Logical plan ===\n{plan}\n"); + } + + let plan = state.optimize(&plan)?; + if debug { + println!("=== Optimized logical plan ===\n{plan}\n"); + } + let physical_plan = state.create_physical_plan(&plan).await?; + if debug { + println!( + "=== Physical plan ===\n{}\n", + displayable(physical_plan.as_ref()).indent(true) + ); + } + + let mut row_count = 0; + + let mut stream = execute_stream(physical_plan.clone(), state.task_ctx())?; + while let Some(batch) = stream.next().await { + row_count += batch.unwrap().num_rows(); + } + + if debug { + println!( + "=== Physical plan with metrics ===\n{}\n", + DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()) + .indent(true) + ); + } + + Ok(row_count) + } + + async fn get_table( + &self, + ctx: &SessionContext, + table: &str, + ) -> Result> { + let path = self.path.to_str().unwrap(); + + // Obtain a snapshot of the SessionState + let state = ctx.state(); + let path = format!("{path}/{table}"); + let format = Arc::new( + ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()), + ); + let extension = DEFAULT_PARQUET_EXTENSION; + + let options = ListingOptions::new(format) + .with_file_extension(extension) + .with_collect_stat(state.config().collect_statistics()); + + let table_path = ListingTableUrl::parse(path)?; + let config = ListingTableConfig::new(table_path).with_listing_options(options); + let config = config.infer_schema(&state).await?; + + Ok(Arc::new(ListingTable::try_new(config)?)) + } + + fn iterations(&self) -> usize { + self.common.iterations + } + + fn partitions(&self) -> usize { + self.common + .partitions + .unwrap_or(get_available_parallelism()) + } +} diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index 1a1f51f700651..de3ee3d67db27 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use super::{ get_query_sql, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_TABLES, }; -use crate::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; @@ -37,6 +37,7 @@ use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::*; use datafusion_common::instant::Instant; +use datafusion_common::utils::get_available_parallelism; use datafusion_common::{DEFAULT_CSV_EXTENSION, DEFAULT_PARQUET_EXTENSION}; use log::info; @@ -120,11 +121,6 @@ impl RunOpt { .config() .with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; - config - .options_mut() - .execution - .parquet - .schema_force_view_types = self.common.force_view_types; let ctx = SessionContext::new_with_config(config); // register tables @@ -301,7 +297,9 @@ impl RunOpt { 
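Where the external aggregation runner `collect`s its result batches, the sort_tpch runner above drains the plan with `execute_stream` and only tallies row counts, so the fully sorted output is never buffered in memory. A condensed sketch of that pattern (assumes an existing SessionContext and physical plan):

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::physical_plan::{execute_stream, ExecutionPlan};
use datafusion::prelude::SessionContext;
use futures::StreamExt;

/// Sketch: stream batches out of a physical plan and count rows without
/// materializing the whole result set.
async fn count_rows(
    ctx: &SessionContext,
    plan: Arc<dyn ExecutionPlan>,
) -> Result<usize> {
    let mut stream = execute_stream(plan, ctx.task_ctx())?;
    let mut row_count = 0;
    while let Some(batch) = stream.next().await {
        row_count += batch?.num_rows();
    }
    Ok(row_count)
}
```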
} fn partitions(&self) -> usize { - self.common.partitions.unwrap_or(num_cpus::get()) + self.common + .partitions + .unwrap_or(get_available_parallelism()) } } @@ -345,7 +343,6 @@ mod tests { partitions: Some(2), batch_size: 8192, debug: false, - force_view_types: false, }; let opt = RunOpt { query: Some(query), @@ -379,7 +376,6 @@ mod tests { partitions: Some(2), batch_size: 8192, debug: false, - force_view_types: false, }; let opt = RunOpt { query: Some(query), diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs index efdb074b2461e..b1570a1d1bc14 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -16,6 +16,7 @@ // under the License. use datafusion::prelude::SessionConfig; +use datafusion_common::utils::get_available_parallelism; use structopt::StructOpt; // Common benchmark options (don't use doc comments otherwise this doc @@ -37,11 +38,6 @@ pub struct CommonOpt { /// Activate debug mode to see more details #[structopt(short, long)] pub debug: bool, - - /// If true, will use StringView/BinaryViewArray instead of String/BinaryArray - /// when reading ParquetFiles - #[structopt(long)] - pub force_view_types: bool, } impl CommonOpt { @@ -53,7 +49,9 @@ impl CommonOpt { /// Modify the existing config appropriately pub fn update_config(&self, config: SessionConfig) -> SessionConfig { config - .with_target_partitions(self.partitions.unwrap_or(num_cpus::get())) + .with_target_partitions( + self.partitions.unwrap_or(get_available_parallelism()), + ) .with_batch_size(self.batch_size) } } diff --git a/benchmarks/src/util/run.rs b/benchmarks/src/util/run.rs index 5ee6691576b44..13969f4d39497 100644 --- a/benchmarks/src/util/run.rs +++ b/benchmarks/src/util/run.rs @@ -16,6 +16,7 @@ // under the License. use datafusion::{error::Result, DATAFUSION_VERSION}; +use datafusion_common::utils::get_available_parallelism; use serde::{Serialize, Serializer}; use serde_json::Value; use std::{ @@ -68,7 +69,7 @@ impl RunContext { Self { benchmark_version: env!("CARGO_PKG_VERSION").to_owned(), datafusion_version: DATAFUSION_VERSION.to_owned(), - num_cpus: num_cpus::get(), + num_cpus: get_available_parallelism(), start_time: SystemTime::now(), arguments: std::env::args().skip(1).collect::>(), } diff --git a/ci/scripts/retry b/ci/scripts/retry new file mode 100755 index 0000000000000..411dc532ca52f --- /dev/null +++ b/ci/scripts/retry @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +set -euo pipefail + +x() { + echo "+ $*" >&2 + "$@" +} + +max_retry_time_seconds=$(( 5 * 60 )) +retry_delay_seconds=10 + +END=$(( $(date +%s) + ${max_retry_time_seconds} )) + +while (( $(date +%s) < $END )); do + x "$@" && exit 0 + sleep "${retry_delay_seconds}" +done + +echo "$0: retrying [$*] timed out" >&2 +exit 1 diff --git a/ci/scripts/rust_example.sh b/ci/scripts/rust_example.sh index 1bb97c88106f2..c3efcf2cf2e92 100755 --- a/ci/scripts/rust_example.sh +++ b/ci/scripts/rust_example.sh @@ -17,9 +17,13 @@ # specific language governing permissions and limitations # under the License. -set -ex +set -e + +export CARGO_PROFILE_CI_OPT_LEVEL="s" +export CARGO_PROFILE_CI_STRIP=true + cd datafusion-examples/examples/ -cargo check --examples +cargo build --profile ci --examples files=$(ls .) for filename in $files @@ -27,7 +31,6 @@ do example_name=`basename $filename ".rs"` # Skip tests that rely on external storage and flight if [ ! 
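Across the benchmarks, the default partition count now comes from `datafusion_common::utils::get_available_parallelism()` rather than the external `num_cpus` crate. A small sketch of the fallback as `CommonOpt` applies it when `--partitions` is omitted (assuming the helper reports the parallelism visible to the process):

```rust
use datafusion_common::utils::get_available_parallelism;

/// Sketch: resolve the target partition count the way CommonOpt does when
/// `--partitions` is not passed on the command line.
fn target_partitions(cli_partitions: Option<usize>) -> usize {
    cli_partitions.unwrap_or_else(get_available_parallelism)
}

fn main() {
    // With no --partitions flag, fall back to the detected parallelism.
    println!("target partitions: {}", target_partitions(None));
    assert_eq!(target_partitions(Some(2)), 2);
}
```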
-d $filename ]; then - cargo run --example $example_name - cargo clean -p datafusion-examples + cargo run --profile ci --example $example_name fi done diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 8a6ccacbb3807..8c7f2113eedb3 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -63,9 +63,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.18" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "android-tzdata" @@ -84,9 +84,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.15" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" dependencies = [ "anstyle", "anstyle-parse", @@ -99,48 +99,49 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.8" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anstyle-parse" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.4" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", - "windows-sys 0.52.0", + "once_cell", + "windows-sys 0.59.0", ] [[package]] name = "apache-avro" -version = "0.16.0" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ceb7c683b2f8f40970b70e39ff8be514c95b96fcb9c4af87e1ed2cb2e10801a0" +checksum = "1aef82843a0ec9f8b19567445ad2421ceeb1d711514384bdd3d49fe37102ee13" dependencies = [ - "bzip2", + "bigdecimal", + "bzip2 0.4.4", "crc32fast", "digest", - "lazy_static", "libflate", "log", "num-bigint", @@ -148,15 +149,16 @@ dependencies = [ "rand", "regex-lite", "serde", + "serde_bytes", "serde_json", "snap", - "strum 0.25.0", - "strum_macros 0.25.3", - "thiserror", + "strum", + "strum_macros", + "thiserror 1.0.69", "typed-builder", "uuid", "xz2", - "zstd 0.12.4", + "zstd", ] [[package]] @@ -173,9 +175,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9ba0d7248932f4e2a12fb37f0a2e3ec82b3bdedbac2a1dce186e036843b8f8c" +checksum = 
"d2ccdcc8fb14508ca20aaec7076032e5c0b0751b906036d4496786e2f227a37a" dependencies = [ "arrow-arith", "arrow-array", @@ -194,24 +196,23 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d60afcdc004841a5c8d8da4f4fa22d64eb19c0c01ef4bcedd77f175a7cf6e38f" +checksum = "a1aad8e27f32e411a0fc0bf5a625a35f0bf9b9f871cf4542abe31f7cef4beea2" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", - "half", "num", ] [[package]] name = "arrow-array" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f16835e8599dbbb1659fd869d865254c4cf32c6c2bb60b6942ac9fc36bfa5da" +checksum = "bd6ed90c28c6f73a706c55799b8cc3a094e89257238e5b1d65ca7c70bd3ae23f" dependencies = [ "ahash", "arrow-buffer", @@ -220,15 +221,15 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "num", ] [[package]] name = "arrow-buffer" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a1f34f0faae77da6b142db61deba2cb6d60167592b178be317b341440acba80" +checksum = "fe4a40bdc1552ea10fbdeae4e5a945d8572c32f66bce457b96c13d9c46b80447" dependencies = [ "bytes", "half", @@ -237,9 +238,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "450e4abb5775bca0740bec0bcf1b1a5ae07eff43bd625661c4436d8e8e4540c4" +checksum = "430c0a21aa7f81bcf0f97c57216d7127795ea755f494d27bae2bd233be43c2cc" dependencies = [ "arrow-array", "arrow-buffer", @@ -258,28 +259,25 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3a4e4d63830a341713e35d9a42452fbc6241d5f42fa5cf6a4681b8ad91370c4" +checksum = "b4444c8f8c57ac00e6a679ede67d1ae8872c170797dc45b46f75702437a77888" dependencies = [ "arrow-array", - "arrow-buffer", "arrow-cast", - "arrow-data", "arrow-schema", "chrono", "csv", "csv-core", "lazy_static", - "lexical-core", "regex", ] [[package]] name = "arrow-data" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b1e618bbf714c7a9e8d97203c806734f012ff71ae3adc8ad1b075689f540634" +checksum = "09af476cfbe9879937e50b1334c73189de6039186e025b1b1ac84b283b87b20e" dependencies = [ "arrow-buffer", "arrow-schema", @@ -289,13 +287,12 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98e983549259a2b97049af7edfb8f28b8911682040e99a94e4ceb1196bd65c2" +checksum = "136296e8824333a8a4c4a6e508e4aa65d5678b801246d0408825ae7b2523c628" dependencies = [ "arrow-array", "arrow-buffer", - "arrow-cast", "arrow-data", "arrow-schema", "flatbuffers", @@ -304,9 +301,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b198b9c6fcf086501730efbbcb483317b39330a116125af7bb06467d04b352a3" +checksum = "e222ad0e419ab8276818c5605a5bb1e35ed86fa8c5e550726433cc63b09c3c78" dependencies = [ "arrow-array", "arrow-buffer", @@ -324,26 +321,23 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "53.1.0" +version = "54.0.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "2427f37b4459a4b9e533045abe87a5183a5e0995a3fc2c2fd45027ae2cc4ef3f" +checksum = "eddf14c5f03b679ec8ceac4dfac43f63cdc4ed54dab3cc120a4ef46af38481eb" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", - "half", - "num", ] [[package]] name = "arrow-row" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15959657d92e2261a7a323517640af87f5afd9fd8a6492e424ebee2203c567f6" +checksum = "e9acdc58da19f383f4ba381fa0e3583534ae2ceb31269aaf4a03f08ff13e8443" dependencies = [ - "ahash", "arrow-array", "arrow-buffer", "arrow-data", @@ -353,15 +347,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf0388a18fd7f7f3fe3de01852d30f54ed5182f9004db700fbe3ba843ed2794" +checksum = "3a1822a1a952955637e85e8f9d6b0e04dd75d65492b87ec548dd593d3a1f772b" [[package]] name = "arrow-select" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b83e5723d307a38bf00ecd2972cd078d1339c7fd3eb044f609958a9a24463f3a" +checksum = "5c4172e9a12dfe15303d3926269f9ead471ea93bdd067d113abc65cb6c48e246" dependencies = [ "ahash", "arrow-array", @@ -373,9 +367,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "53.1.0" +version = "54.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ab3db7c09dd826e74079661d84ed01ed06547cf75d52c2818ef776d0d852305" +checksum = "73683040445f4932342781926189901c9521bb1a787c35dbe628a3ce51372d3c" dependencies = [ "arrow-array", "arrow-buffer", @@ -406,27 +400,26 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.13" +version = "0.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e614738943d3f68c628ae3dbce7c3daffb196665f82f8c8ea6b65de73c79429" +checksum = "df895a515f70646414f4b45c0b79082783b80552b373a68283012928df56f522" dependencies = [ - "bzip2", + "bzip2 0.4.4", "flate2", "futures-core", - "futures-io", "memchr", "pin-project-lite", "tokio", "xz2", - "zstd 0.13.2", - "zstd-safe 7.2.1", + "zstd", + "zstd-safe", ] [[package]] name = "async-trait" -version = "0.1.83" +version = "0.1.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056" dependencies = [ "proc-macro2", "quote", @@ -456,9 +449,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.5.8" +version = "1.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7198e6f03240fdceba36656d8be440297b6b82270325908c7381f37d826a74f6" +checksum = "9b49afaa341e8dd8577e1a2200468f98956d6eda50bcf4a53246cc00174ba924" dependencies = [ "aws-credential-types", "aws-runtime", @@ -467,7 +460,7 @@ dependencies = [ "aws-sdk-sts", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.60.7", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -498,9 +491,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.4.3" +version = "1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a10d5c055aa540164d9561a0e2e74ad30f0dcf7393c3a92f6733ddf9c5762468" +checksum = 
"bee7643696e7fdd74c10f9eb42848a87fe469d35eae9c3323f80aa98f350baac" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -523,15 +516,15 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.45.0" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e33ae899566f3d395cbf42858e433930682cc9c1889fa89318896082fef45efb" +checksum = "05ca43a4ef210894f93096039ef1d6fa4ad3edfabb3be92b80908b9f2e4b4eab" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.61.2", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -545,15 +538,15 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.46.0" +version = "1.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f39c09e199ebd96b9f860b0fce4b6625f211e064ad7c8693b72ecf7ef03881e0" +checksum = "abaf490c2e48eed0bb8e2da2fb08405647bd7f253996e0f93b981958ea0f73b0" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.61.2", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -567,15 +560,15 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.45.0" +version = "1.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d95f93a98130389eb6233b9d615249e543f6c24a68ca1f109af9ca5164a8765" +checksum = "b68fde0d69c8bfdc1060ea7da21df3e39f6014da316783336deff0a9ec28f4bf" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.61.2", "aws-smithy-query", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -590,9 +583,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.4" +version = "1.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc8db6904450bafe7473c6ca9123f88cc11089e41a025408f992db4e22d3be68" +checksum = "690118821e46967b3c4501d67d7d52dd75106a9c54cf36cefa1985cedbe94e05" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -603,7 +596,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.1.0", + "http 1.2.0", "once_cell", "percent-encoding", "sha2", @@ -613,9 +606,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.2.1" +version = "1.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62220bc6e97f946ddd51b5f1361f78996e704677afc518a4ff66b7a72ea1378c" +checksum = "fa59d1327d8b5053c54bf2eaae63bf629ba9e904434d0835a28ed3c0ed0a614e" dependencies = [ "futures-util", "pin-project-lite", @@ -624,9 +617,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.11" +version = "0.60.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c8bc3e8fdc6b8d07d976e301c02fe553f72a39b7a9fea820e023268467d7ab6" +checksum = "7809c27ad8da6a6a68c454e651d4962479e81472aa19ae99e59f9aba1f9713cc" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", @@ -651,6 +644,15 @@ dependencies = [ "aws-smithy-types", ] +[[package]] +name = "aws-smithy-json" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "623a51127f24c30776c8b374295f2df78d92517386f77ba30773f15a30ce1422" +dependencies = [ + "aws-smithy-types", +] + [[package]] name = "aws-smithy-query" version = "0.60.7" @@ -663,9 +665,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.1" 
+version = "1.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ce695746394772e7000b39fe073095db6d45a862d0767dd5ad0ac0d7f8eb87" +checksum = "865f7050bbc7107a6c98a397a9fcd9413690c27fa718446967cf03b2d3ac517e" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -678,7 +680,7 @@ dependencies = [ "http-body 0.4.6", "http-body 1.0.1", "httparse", - "hyper 0.14.30", + "hyper 0.14.32", "hyper-rustls 0.24.2", "once_cell", "pin-project-lite", @@ -690,15 +692,15 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.7.2" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e086682a53d3aa241192aa110fa8dfce98f2f5ac2ead0de84d41582c7e8fdb96" +checksum = "92165296a47a812b267b4f41032ff8069ab7ff783696d217f0994a0d7ab585cd" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.1.0", + "http 1.2.0", "pin-project-lite", "tokio", "tracing", @@ -707,16 +709,16 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.7" +version = "1.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147100a7bea70fa20ef224a6bad700358305f5dc0f84649c53769761395b355b" +checksum = "a28f6feb647fb5e0d5b50f0472c19a7db9462b74e2fec01bb0b44eedcc834e97" dependencies = [ "base64-simd", "bytes", "bytes-utils", "futures-core", "http 0.2.12", - "http 1.1.0", + "http 1.2.0", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -742,9 +744,9 @@ dependencies = [ [[package]] name = "aws-types" -version = "1.3.3" +version = "1.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5221b91b3e441e6675310829fd8984801b772cb1546ef6c0e54dec9f1ac13fef" +checksum = "b0df5a18c4f951c645300d365fec53a61418bcf4650f604f85fe2a665bfaa0c2" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -791,6 +793,20 @@ dependencies = [ "vsimd", ] +[[package]] +name = "bigdecimal" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f31f3af01c5c65a07985c804d3366560e6fa7883d640a122819b14ec327482c" +dependencies = [ + "autocfg", + "libm", + "num-bigint", + "num-integer", + "num-traits", + "serde", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -799,9 +815,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.6.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" [[package]] name = "blake2" @@ -814,9 +830,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.5.4" +version = "1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" +checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e" dependencies = [ "arrayref", "arrayvec", @@ -836,9 +852,9 @@ dependencies = [ [[package]] name = "brotli" -version = "6.0.0" +version = "7.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" +checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -857,9 +873,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.10.0" +version = "1.11.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" dependencies = [ "memchr", "regex-automata", @@ -880,9 +896,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.2" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "bytes-utils" @@ -904,6 +920,16 @@ dependencies = [ "libc", ] +[[package]] +name = "bzip2" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bafdbf26611df8c14810e268ddceda071c297570a5fb360ceddf617fe417ef58" +dependencies = [ + "bzip2-sys", + "libc", +] + [[package]] name = "bzip2-sys" version = "0.1.11+1.0.8" @@ -917,9 +943,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.28" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e80e3b6a3ab07840e1cae9b0666a63970dc28e8ed5ffbcdacbfc760c281bfc1" +checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" dependencies = [ "jobserver", "libc", @@ -938,11 +964,17 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" dependencies = [ "android-tzdata", "iana-time-zone", @@ -974,9 +1006,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.19" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7be5744db7978a28d9df86a214130d106a89ce49644cbc4e3f0c22c3fba30615" +checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" dependencies = [ "clap_builder", "clap_derive", @@ -984,9 +1016,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.19" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5fbc17d3ef8278f55b282b2a2e75ae6f6c7d4bb70ed3d0382375104bfafdb4b" +checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" dependencies = [ "anstream", "anstyle", @@ -996,11 +1028,11 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.18" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn", @@ -1008,9 +1040,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.2" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +checksum = 
"f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] name = "clipboard-win" @@ -1023,19 +1055,19 @@ dependencies = [ [[package]] name = "colorchoice" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "comfy-table" -version = "7.1.1" +version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" +checksum = "24f165e7b643266ea80cb858aed492ad9280e3e05ce24d4a99d7d7b889b6a4d9" dependencies = [ - "strum 0.26.3", - "strum_macros 0.26.4", - "unicode-width", + "strum", + "strum_macros", + "unicode-width 0.2.0", ] [[package]] @@ -1074,6 +1106,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1091,9 +1133,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" dependencies = [ "libc", ] @@ -1109,9 +1151,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" @@ -1131,9 +1173,9 @@ dependencies = [ [[package]] name = "csv" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" dependencies = [ "csv-core", "itoa", @@ -1152,9 +1194,9 @@ dependencies = [ [[package]] name = "ctor" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" +checksum = "32a2785755761f3ddc1492979ce1e48d2c00d09311c39e4466429188f3dd6501" dependencies = [ "quote", "syn", @@ -1162,9 +1204,9 @@ dependencies = [ [[package]] name = "dary_heap" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" +checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" [[package]] name = "dashmap" @@ -1182,9 +1224,8 @@ dependencies = [ [[package]] name = "datafusion" -version = "42.0.0" +version = "44.0.0" dependencies = [ - "ahash", "apache-avro", "arrow", "arrow-array", @@ -1193,7 +1234,7 @@ dependencies = [ "async-compression", "async-trait", "bytes", - "bzip2", + "bzip2 0.5.0", "chrono", "dashmap", "datafusion-catalog", @@ -1204,6 +1245,7 @@ dependencies = [ "datafusion-functions", "datafusion-functions-aggregate", "datafusion-functions-nested", + "datafusion-functions-table", 
"datafusion-functions-window", "datafusion-optimizer", "datafusion-physical-expr", @@ -1214,19 +1256,14 @@ dependencies = [ "flate2", "futures", "glob", - "half", - "hashbrown 0.14.5", - "indexmap", - "itertools", + "itertools 0.14.0", "log", "num-traits", - "num_cpus", "object_store", "parking_lot", "parquet", - "paste", - "pin-project-lite", "rand", + "regex", "sqlparser", "tempfile", "tokio", @@ -1234,12 +1271,12 @@ dependencies = [ "url", "uuid", "xz2", - "zstd 0.13.2", + "zstd", ] [[package]] name = "datafusion-catalog" -version = "42.0.0" +version = "44.0.0" dependencies = [ "arrow-schema", "async-trait", @@ -1252,7 +1289,7 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "42.0.0" +version = "44.0.0" dependencies = [ "arrow", "assert_cmd", @@ -1265,9 +1302,11 @@ dependencies = [ "clap", "ctor", "datafusion", + "datafusion-catalog", "dirs", "env_logger", "futures", + "home", "mimalloc", "object_store", "parking_lot", @@ -1282,46 +1321,51 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "42.0.0" +version = "44.0.0" dependencies = [ "ahash", "apache-avro", "arrow", "arrow-array", "arrow-buffer", + "arrow-ipc", "arrow-schema", - "chrono", + "base64 0.22.1", "half", "hashbrown 0.14.5", - "instant", + "indexmap", "libc", - "num_cpus", + "log", "object_store", "parquet", "paste", + "recursive", "sqlparser", "tokio", + "web-time", ] [[package]] name = "datafusion-common-runtime" -version = "42.0.0" +version = "44.0.0" dependencies = [ "log", "tokio", ] +[[package]] +name = "datafusion-doc" +version = "44.0.0" + [[package]] name = "datafusion-execution" -version = "42.0.0" +version = "44.0.0" dependencies = [ "arrow", - "chrono", "dashmap", "datafusion-common", "datafusion-expr", "futures", - "hashbrown 0.14.5", "log", "object_store", "parking_lot", @@ -1332,38 +1376,35 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "42.0.0" +version = "44.0.0" dependencies = [ - "ahash", "arrow", - "arrow-array", - "arrow-buffer", "chrono", "datafusion-common", + "datafusion-doc", "datafusion-expr-common", "datafusion-functions-aggregate-common", "datafusion-functions-window-common", "datafusion-physical-expr-common", "indexmap", "paste", + "recursive", "serde_json", "sqlparser", - "strum 0.26.3", - "strum_macros 0.26.4", ] [[package]] name = "datafusion-expr-common" -version = "42.0.0" +version = "44.0.0" dependencies = [ "arrow", "datafusion-common", - "paste", + "itertools 0.14.0", ] [[package]] name = "datafusion-functions" -version = "42.0.0" +version = "44.0.0" dependencies = [ "arrow", "arrow-buffer", @@ -1372,11 +1413,14 @@ dependencies = [ "blake3", "chrono", "datafusion-common", + "datafusion-doc", "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", + "datafusion-macros", "hashbrown 0.14.5", "hex", - "itertools", + "itertools 0.14.0", "log", "md-5", "rand", @@ -1388,38 +1432,38 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "42.0.0" +version = "44.0.0" dependencies = [ "ahash", "arrow", "arrow-schema", "datafusion-common", + "datafusion-doc", "datafusion-execution", "datafusion-expr", "datafusion-functions-aggregate-common", + "datafusion-macros", "datafusion-physical-expr", "datafusion-physical-expr-common", "half", - "indexmap", "log", "paste", ] [[package]] name = "datafusion-functions-aggregate-common" -version = "42.0.0" +version = "44.0.0" dependencies = [ "ahash", "arrow", "datafusion-common", "datafusion-expr-common", "datafusion-physical-expr-common", - "rand", ] [[package]] 
name = "datafusion-functions-nested" -version = "42.0.0" +version = "44.0.0" dependencies = [ "arrow", "arrow-array", @@ -1427,24 +1471,42 @@ dependencies = [ "arrow-ord", "arrow-schema", "datafusion-common", + "datafusion-doc", "datafusion-execution", "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", + "datafusion-macros", "datafusion-physical-expr-common", - "itertools", + "itertools 0.14.0", "log", "paste", - "rand", +] + +[[package]] +name = "datafusion-functions-table" +version = "44.0.0" +dependencies = [ + "arrow", + "async-trait", + "datafusion-catalog", + "datafusion-common", + "datafusion-expr", + "datafusion-physical-plan", + "parking_lot", + "paste", ] [[package]] name = "datafusion-functions-window" -version = "42.0.0" +version = "44.0.0" dependencies = [ "datafusion-common", + "datafusion-doc", "datafusion-expr", "datafusion-functions-window-common", + "datafusion-macros", + "datafusion-physical-expr", "datafusion-physical-expr-common", "log", "paste", @@ -1452,86 +1514,95 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "42.0.0" +version = "44.0.0" dependencies = [ "datafusion-common", + "datafusion-physical-expr-common", +] + +[[package]] +name = "datafusion-macros" +version = "44.0.0" +dependencies = [ + "datafusion-expr", + "quote", + "syn", ] [[package]] name = "datafusion-optimizer" -version = "42.0.0" +version = "44.0.0" dependencies = [ "arrow", - "async-trait", "chrono", "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "hashbrown 0.14.5", "indexmap", - "itertools", + "itertools 0.14.0", "log", - "paste", + "recursive", + "regex", "regex-syntax", ] [[package]] name = "datafusion-physical-expr" -version = "42.0.0" +version = "44.0.0" dependencies = [ "ahash", "arrow", "arrow-array", "arrow-buffer", - "arrow-ord", "arrow-schema", - "arrow-string", - "base64 0.22.1", - "chrono", "datafusion-common", - "datafusion-execution", "datafusion-expr", "datafusion-expr-common", "datafusion-functions-aggregate-common", "datafusion-physical-expr-common", "half", "hashbrown 0.14.5", - "hex", "indexmap", - "itertools", + "itertools 0.14.0", "log", "paste", "petgraph", - "regex", ] [[package]] name = "datafusion-physical-expr-common" -version = "42.0.0" +version = "44.0.0" dependencies = [ "ahash", "arrow", "datafusion-common", "datafusion-expr-common", "hashbrown 0.14.5", - "rand", + "itertools 0.14.0", ] [[package]] name = "datafusion-physical-optimizer" -version = "42.0.0" +version = "44.0.0" dependencies = [ + "arrow", "arrow-schema", "datafusion-common", "datafusion-execution", + "datafusion-expr", + "datafusion-expr-common", "datafusion-physical-expr", + "datafusion-physical-expr-common", "datafusion-physical-plan", - "itertools", + "futures", + "itertools 0.14.0", + "log", + "recursive", ] [[package]] name = "datafusion-physical-plan" -version = "42.0.0" +version = "44.0.0" dependencies = [ "ahash", "arrow", @@ -1545,8 +1616,6 @@ dependencies = [ "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate", - "datafusion-functions-aggregate-common", "datafusion-functions-window-common", "datafusion-physical-expr", "datafusion-physical-expr-common", @@ -1554,28 +1623,28 @@ dependencies = [ "half", "hashbrown 0.14.5", "indexmap", - "itertools", + "itertools 0.14.0", "log", - "once_cell", "parking_lot", "pin-project-lite", - "rand", "tokio", ] [[package]] name = "datafusion-sql" -version = "42.0.0" +version = "44.0.0" dependencies = [ "arrow", 
"arrow-array", "arrow-schema", + "bigdecimal", "datafusion-common", "datafusion-expr", + "indexmap", "log", + "recursive", "regex", "sqlparser", - "strum 0.26.3", ] [[package]] @@ -1625,6 +1694,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "doc-comment" version = "0.3.3" @@ -1645,9 +1725,9 @@ checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" [[package]] name = "env_filter" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" dependencies = [ "log", "regex", @@ -1655,9 +1735,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" +checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" dependencies = [ "anstream", "anstyle", @@ -1674,12 +1754,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1690,9 +1770,9 @@ checksum = "a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f" [[package]] name = "fastrand" -version = "2.1.1" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "fd-lock" @@ -1707,15 +1787,15 @@ dependencies = [ [[package]] name = "fixedbitset" -version = "0.4.2" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "flatbuffers" -version = "24.3.25" +version = "24.12.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8add37afff2d4ffa83bc748a70b4b1370984f6980768554182424ef71447c35f" +checksum = "4f1baf0dbf96932ec9a3038d57900329c015b0bfb7b63d904f3bc27e2b02a096" dependencies = [ "bitflags 1.3.2", "rustc_version", @@ -1723,9 +1803,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.34" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" dependencies = [ "crc32fast", "miniz_oxide", @@ -1733,9 +1813,9 @@ dependencies = [ [[package]] name = "float-cmp" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4" +checksum = "b09cf3155332e944990140d967ff5eceb70df778b34f77d8075db46e4704e6d8" dependencies = [ "num-traits", ] @@ -1867,8 +1947,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -1879,9 +1961,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "h2" @@ -1904,16 +1986,16 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e" dependencies = [ "atomic-waker", "bytes", "fnv", "futures-core", "futures-sink", - "http 1.1.0", + "http 1.2.0", "indexmap", "slab", "tokio", @@ -1944,15 +2026,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" - -[[package]] -name = "heck" -version = "0.4.1" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" [[package]] name = "heck" @@ -1960,12 +2036,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" - [[package]] name = "hex" version = "0.4.3" @@ -2003,9 +2073,9 @@ dependencies = [ [[package]] name = "http" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" dependencies = [ "bytes", "fnv", @@ -2030,7 +2100,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.1.0", + "http 1.2.0", ] [[package]] @@ -2041,7 +2111,7 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.2.0", "http-body 1.0.1", "pin-project-lite", ] @@ -2066,9 +2136,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.30" +version = "0.14.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" dependencies = [ "bytes", "futures-channel", @@ -2090,15 
+2160,15 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.1" +version = "1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0" dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.6", - "http 1.1.0", + "h2 0.4.7", + "http 1.2.0", "http-body 1.0.1", "httparse", "itoa", @@ -2116,7 +2186,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper 0.14.30", + "hyper 0.14.32", "log", "rustls 0.21.12", "rustls-native-certs 0.6.3", @@ -2126,34 +2196,34 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.3" +version = "0.27.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" dependencies = [ "futures-util", - "http 1.1.0", - "hyper 1.4.1", + "http 1.2.0", + "hyper 1.5.2", "hyper-util", - "rustls 0.23.14", - "rustls-native-certs 0.8.0", + "rustls 0.23.21", + "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.1", "tower-service", ] [[package]] name = "hyper-util" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.1.0", + "http 1.2.0", "http-body 1.0.1", - "hyper 1.4.1", + "hyper 1.5.2", "pin-project-lite", "socket2", "tokio", @@ -2184,36 +2254,153 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "idna" -version = "0.5.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "idna_adapter", + "smallvec", + "utf8_iter", ] [[package]] -name = "indexmap" -version = "2.6.0" +name = "idna_adapter" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" dependencies = [ - "equivalent", - "hashbrown 0.15.0", + "icu_normalizer", + "icu_properties", ] [[package]] -name = "instant" -version = "0.1.13" +name = "indexmap" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ - "cfg-if", - "js-sys", - "wasm-bindgen", - "web-sys", + "equivalent", + "hashbrown 0.15.2", ] [[package]] @@ -2243,11 +2430,20 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "jobserver" @@ -2260,10 +2456,11 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ + "once_cell", "wasm-bindgen", ] @@ -2275,9 +2472,9 @@ checksum = 
"bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "lexical-core" -version = "1.0.2" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0431c65b318a590c1de6b8fd6e72798c92291d27762d94c9e6c37ed7a73d8458" +checksum = "b765c31809609075565a70b4b71402281283aeda7ecaf4818ac14a7b2ade8958" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -2288,9 +2485,9 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "1.0.2" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb17a4bdb9b418051aa59d41d65b1c9be5affab314a872e5ad7f06231fb3b4e0" +checksum = "de6f9cb01fb0b08060209a057c048fcbab8717b4c1ecd2eac66ebfe39a65b0f2" dependencies = [ "lexical-parse-integer", "lexical-util", @@ -2299,9 +2496,9 @@ dependencies = [ [[package]] name = "lexical-parse-integer" -version = "1.0.2" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5df98f4a4ab53bf8b175b363a34c7af608fe31f93cc1fb1bf07130622ca4ef61" +checksum = "72207aae22fc0a121ba7b6d479e42cbfea549af1479c3f3a4f12c70dd66df12e" dependencies = [ "lexical-util", "static_assertions", @@ -2309,18 +2506,18 @@ dependencies = [ [[package]] name = "lexical-util" -version = "1.0.3" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85314db53332e5c192b6bca611fb10c114a80d1b831ddac0af1e9be1b9232ca0" +checksum = "5a82e24bf537fd24c177ffbbdc6ebcc8d54732c35b50a3f28cc3f4e4c949a0b3" dependencies = [ "static_assertions", ] [[package]] name = "lexical-write-float" -version = "1.0.2" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e7c3ad4e37db81c1cbe7cf34610340adc09c322871972f74877a712abc6c809" +checksum = "c5afc668a27f460fb45a81a757b6bf2f43c2d7e30cb5a2dcd3abf294c78d62bd" dependencies = [ "lexical-util", "lexical-write-integer", @@ -2329,9 +2526,9 @@ dependencies = [ [[package]] name = "lexical-write-integer" -version = "1.0.2" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb89e9f6958b83258afa3deed90b5de9ef68eef090ad5086c791cd2345610162" +checksum = "629ddff1a914a836fb245616a7888b62903aae58fa771e1d83943035efa0f978" dependencies = [ "lexical-util", "static_assertions", @@ -2339,9 +2536,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.159" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libflate" @@ -2369,9 +2566,9 @@ dependencies = [ [[package]] name = "libm" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" [[package]] name = "libmimalloc-sys" @@ -2389,15 +2586,21 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "libc", ] [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" 
+checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + +[[package]] +name = "litemap" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" [[package]] name = "lock_api" @@ -2411,9 +2614,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "lz4_flex" @@ -2468,20 +2671,19 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "miniz_oxide" -version = "0.8.0" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924" dependencies = [ "adler2", ] [[package]] name = "mio" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ - "hermit-abi", "libc", "wasi", "windows-sys 0.52.0", @@ -2502,9 +2704,9 @@ version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "cfg-if", - "cfg_aliases", + "cfg_aliases 0.1.1", "libc", ] @@ -2536,6 +2738,7 @@ checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", + "serde", ] [[package]] @@ -2594,30 +2797,20 @@ dependencies = [ "libm", ] -[[package]] -name = "num_cpus" -version = "1.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" -dependencies = [ - "hermit-abi", - "libc", -] - [[package]] name = "object" -version = "0.36.5" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" dependencies = [ "memchr", ] [[package]] name = "object_store" -version = "0.11.0" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25a0c4b3a0e31f8b66f71ad8064521efa773910196e2cde791436f13409f3b45" +checksum = "3cfccb68961a56facde1163f9319e0d15743352344e7808a11795fb99698dcaf" dependencies = [ "async-trait", "base64 0.22.1", @@ -2625,8 +2818,8 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.4.1", - "itertools", + "hyper 1.5.2", + "itertools 0.13.0", "md-5", "parking_lot", "percent-encoding", @@ -2673,9 +2866,9 @@ dependencies = [ [[package]] name = "outref" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" [[package]] name = "parking_lot" @@ -2702,9 +2895,9 @@ dependencies = [ [[package]] name = "parquet" -version = "53.1.0" +version = "54.0.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "310c46a70a3ba90d98fec39fa2da6d9d731e544191da6fb56c9d199484d0dd3e" +checksum = "3334c50239d9f4951653d84fa6f636da86f53742e5e5849a30fbe852f3ff4383" dependencies = [ "ahash", "arrow-array", @@ -2721,7 +2914,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "lz4_flex", "num", "num-bigint", @@ -2732,7 +2925,7 @@ dependencies = [ "thrift", "tokio", "twox-hash", - "zstd 0.13.2", + "zstd", "zstd-sys", ] @@ -2759,9 +2952,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" -version = "0.6.5" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", "indexmap", @@ -2769,18 +2962,18 @@ dependencies = [ [[package]] name = "phf" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" dependencies = [ "phf_shared", ] [[package]] name = "phf_codegen" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" dependencies = [ "phf_generator", "phf_shared", @@ -2788,9 +2981,9 @@ dependencies = [ [[package]] name = "phf_generator" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", "rand", @@ -2798,18 +2991,18 @@ dependencies = [ [[package]] name = "phf_shared" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" dependencies = [ "siphasher", ] [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" @@ -2840,9 +3033,9 @@ dependencies = [ [[package]] name = "predicates" -version = "3.1.2" +version = "3.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e9086cc7640c29a356d1a29fd134380bee9d8f79a17410aa76e7ad295f42c97" +checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" dependencies = [ "anstyle", "difflib", @@ -2854,15 +3047,15 @@ dependencies = [ [[package]] name = "predicates-core" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae8177bee8e75d6846599c6b9ff679ed51e882816914eec639944d7c9aa11931" +checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" [[package]] name = "predicates-tree" -version = "1.0.11" +version = "1.0.12" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "41b740d195ed3166cd147c8047ec98db0e22ec019eb8eeb76d343b795304fb13" +checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" dependencies = [ "predicates-core", "termtree", @@ -2879,24 +3072,33 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] +[[package]] +name = "psm" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "200b9ff220857e53e184257720a14553b2f4aa02577d2ed9842d45d4b9654810" +dependencies = [ + "cc", +] + [[package]] name = "quad-rand" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b76f1009795ca44bb5aaae8fd3f18953e209259c33d9b059b1f53d58ab7511db" +checksum = "5a651516ddc9168ebd67b24afd085a718be02f8858fe406591b013d101ce2f40" [[package]] name = "quick-xml" -version = "0.36.2" +version = "0.37.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7649a7b4df05aed9ea7ec6f628c67c9953a43869b8bc50929569b2999d443fe" +checksum = "165859e9e55f79d67b96c5d96f4e88b6f2695a1972849c15a6a3f5c59fc2c003" dependencies = [ "memchr", "serde", @@ -2904,45 +3106,49 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" dependencies = [ "bytes", "pin-project-lite", "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.14", + "rustls 0.23.21", "socket2", - "thiserror", + "thiserror 2.0.11", "tokio", "tracing", ] [[package]] name = "quinn-proto" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" +checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", + "getrandom", "rand", "ring", "rustc-hash", - "rustls 0.23.14", + "rustls 0.23.21", + "rustls-pki-types", "slab", - "thiserror", + "thiserror 2.0.11", "tinyvec", "tracing", + "web-time", ] [[package]] name = "quinn-udp" -version = "0.5.5" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b" +checksum = "1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed904" dependencies = [ + "cfg_aliases 0.2.1", "libc", "once_cell", "socket2", @@ -2952,9 +3158,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] @@ -2999,13 +3205,33 @@ dependencies = [ "getrandom", ] +[[package]] +name = "recursive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" +dependencies = [ + "recursive-proc-macro-impl", + 
"stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" +dependencies = [ + "quote", + "syn", +] + [[package]] name = "redox_syscall" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", ] [[package]] @@ -3016,14 +3242,14 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", - "thiserror", + "thiserror 1.0.69", ] [[package]] name = "regex" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -3033,9 +3259,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -3062,20 +3288,20 @@ checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" [[package]] name = "reqwest" -version = "0.12.8" +version = "0.12.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" +checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" dependencies = [ "base64 0.22.1", "bytes", "futures-core", "futures-util", - "h2 0.4.6", - "http 1.1.0", + "h2 0.4.7", + "http 1.2.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", - "hyper-rustls 0.27.3", + "hyper 1.5.2", + "hyper-rustls 0.27.5", "hyper-util", "ipnet", "js-sys", @@ -3085,8 +3311,8 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.14", - "rustls-native-certs 0.8.0", + "rustls 0.23.21", + "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", @@ -3094,8 +3320,9 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.1", "tokio-util", + "tower", "tower-service", "url", "wasm-bindgen", @@ -3164,9 +3391,9 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustc-hash" -version = "2.0.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" +checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" [[package]] name = "rustc_version" @@ -3179,15 +3406,15 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 
0.59.0", ] [[package]] @@ -3204,9 +3431,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.14" +version = "0.23.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "415d9944693cb90382053259f89fbb077ea730ad7273047ec63b19bc9b160ba8" +checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8" dependencies = [ "once_cell", "ring", @@ -3225,20 +3452,19 @@ dependencies = [ "openssl-probe", "rustls-pemfile 1.0.4", "schannel", - "security-framework", + "security-framework 2.11.1", ] [[package]] name = "rustls-native-certs" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" dependencies = [ "openssl-probe", - "rustls-pemfile 2.2.0", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.2.0", ] [[package]] @@ -3261,9 +3487,12 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.9.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55" +checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" +dependencies = [ + "web-time", +] [[package]] name = "rustls-webpki" @@ -3288,9 +3517,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "rustyline" @@ -3298,7 +3527,7 @@ version = "14.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7803e8936da37efd9b6d4478277f4b2b9bb5cdb37a113e8d63222e58da647e63" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "cfg-if", "clipboard-win", "fd-lock", @@ -3309,7 +3538,7 @@ dependencies = [ "nix", "radix_trie", "unicode-segmentation", - "unicode-width", + "unicode-width 0.1.14", "utf8parse", "windows-sys 0.52.0", ] @@ -3331,9 +3560,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ "windows-sys 0.59.0", ] @@ -3360,8 +3589,21 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.6.0", - "core-foundation", + "bitflags 2.8.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +dependencies = [ + "bitflags 2.8.0", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -3369,9 +3611,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.12.0" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" 
+checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" dependencies = [ "core-foundation-sys", "libc", @@ -3379,9 +3621,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.23" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" [[package]] name = "seq-macro" @@ -3391,18 +3633,27 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] +[[package]] +name = "serde_bytes" +version = "0.11.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "387cc504cb06bb40a96c8e04e951fe01854cf6bc921053c954e4a606d9675c6a" +dependencies = [ + "serde", +] + [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", @@ -3411,9 +3662,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.135" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" dependencies = [ "itoa", "memchr", @@ -3461,9 +3712,9 @@ dependencies = [ [[package]] name = "siphasher" -version = "0.3.11" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" [[package]] name = "slab" @@ -3495,7 +3746,7 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn", @@ -3509,9 +3760,9 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" dependencies = [ "libc", "windows-sys 0.52.0", @@ -3525,9 +3776,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.51.0" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fe11944a61da0da3f592e19a45ebe5ab92dc14a779907ff1f08fbb797bfefc7" +checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8" dependencies = [ "log", "sqlparser_derive", @@ -3535,15 +3786,34 @@ dependencies = [ [[package]] name = "sqlparser_derive" -version = "0.2.2" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" +checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" dependencies = [ "proc-macro2", "quote", "syn", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "stacker" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799c883d55abdb5e98af1a7b3f23b9b6de8ecada0ecac058672d7635eb48ca7b" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -3556,33 +3826,11 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" -[[package]] -name = "strum" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" - [[package]] name = "strum" version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" -dependencies = [ - "strum_macros 0.26.4", -] - -[[package]] -name = "strum_macros" -version = "0.25.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "rustversion", - "syn", -] [[package]] name = "strum_macros" @@ -3590,7 +3838,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", @@ -3605,9 +3853,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.79" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -3616,21 +3864,33 @@ dependencies = [ [[package]] name = "sync_wrapper" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" dependencies = [ "futures-core", ] +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tempfile" -version = "3.13.0" +version = "3.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" +checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" dependencies = [ "cfg-if", "fastrand", + "getrandom", "once_cell", "rustix", "windows-sys 0.59.0", @@ -3638,24 +3898,44 @@ dependencies = [ [[package]] name = "termtree" -version = "0.4.1" +version = "0.5.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" + +[[package]] +name = "thiserror" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] [[package]] name = "thiserror" -version = "1.0.64" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.11", ] [[package]] name = "thiserror-impl" -version = "1.0.64" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", @@ -3675,9 +3955,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.36" +version = "0.3.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" dependencies = [ "deranged", "num-conv", @@ -3695,9 +3975,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.18" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" dependencies = [ "num-conv", "time-core", @@ -3712,11 +3992,21 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinyvec" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" dependencies = [ "tinyvec_macros", ] @@ -3729,9 +4019,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.40.0" +version = "1.43.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" +checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" dependencies = [ "backtrace", "bytes", @@ -3747,9 +4037,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +checksum = 
"6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", @@ -3768,20 +4058,19 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.0" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" dependencies = [ - "rustls 0.23.14", - "rustls-pki-types", + "rustls 0.23.21", "tokio", ] [[package]] name = "tokio-util" -version = "0.7.12" +version = "0.7.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" dependencies = [ "bytes", "futures-core", @@ -3807,6 +4096,27 @@ dependencies = [ "winnow", ] +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + [[package]] name = "tower-service" version = "0.3.3" @@ -3815,9 +4125,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -3826,9 +4136,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", @@ -3837,9 +4147,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", ] @@ -3862,18 +4172,18 @@ dependencies = [ [[package]] name = "typed-builder" -version = "0.16.2" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34085c17941e36627a879208083e25d357243812c30e7d7387c3b954f30ade16" +checksum = "a06fbd5b8de54c5f7c91f6fe4cebb949be2125d7758e630bb58b1d831dbce600" dependencies = [ "typed-builder-macro", ] [[package]] name = "typed-builder-macro" -version = "0.16.2" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" +checksum = "f9534daa9fd3ed0bd911d462a37f172228077e7abf18c18a5f67199d959205f8" dependencies = [ "proc-macro2", "quote", @@ -3886,26 +4196,11 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" -[[package]] -name = "unicode-bidi" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" - [[package]] name = "unicode-ident" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" - -[[package]] -name = "unicode-normalization" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" -dependencies = [ - "tinyvec", -] +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" [[package]] name = "unicode-segmentation" @@ -3919,6 +4214,12 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + [[package]] name = "untrusted" version = "0.9.0" @@ -3927,9 +4228,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.2" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" dependencies = [ "form_urlencoded", "idna", @@ -3942,6 +4243,18 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "utf8parse" version = "0.2.2" @@ -3950,9 +4263,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.10.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "744018581f9a3454a9e15beb8a33b017183f1e7c0cd170232a2d1453b23a51c4" dependencies = [ "getrandom", "serde", @@ -4006,24 +4319,24 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", "once_cell", + "rustversion", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = 
"2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" dependencies = [ "bumpalo", "log", - "once_cell", "proc-macro2", "quote", "syn", @@ -4032,21 +4345,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" dependencies = [ "cfg-if", "js-sys", + "once_cell", "wasm-bindgen", "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4054,9 +4368,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", @@ -4067,15 +4381,18 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] [[package]] name = "wasm-streams" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" dependencies = [ "futures-util", "js-sys", @@ -4086,9 +4403,19 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" dependencies = [ "js-sys", "wasm-bindgen", @@ -4292,13 +4619,25 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.20" +version = "0.6.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" dependencies = [ "memchr", ] +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "xmlparser" version = "0.13.6" @@ -4314,6 
+4653,30 @@ dependencies = [ "lzma-sys", ] +[[package]] +name = "yoke" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -4335,6 +4698,27 @@ dependencies = [ "syn", ] +[[package]] +name = "zerofrom" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "zeroize" version = "1.8.1" @@ -4342,31 +4726,34 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" [[package]] -name = "zstd" -version = "0.12.4" +name = "zerovec" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" dependencies = [ - "zstd-safe 6.0.6", + "yoke", + "zerofrom", + "zerovec-derive", ] [[package]] -name = "zstd" -version = "0.13.2" +name = "zerovec-derive" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ - "zstd-safe 7.2.1", + "proc-macro2", + "quote", + "syn", ] [[package]] -name = "zstd-safe" -version = "6.0.6" +name = "zstd" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee98ffd0b48ee95e6c5168188e44a54550b1564d9d530ee21d5f0eaed1069581" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" dependencies = [ - "libc", - "zstd-sys", + "zstd-safe", ] [[package]] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index b86dbd2a38027..b9d190ac07cc5 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -18,51 +18,69 @@ [package] name = "datafusion-cli" description = "Command Line Client for DataFusion query engine." 
-version = "42.0.0" +version = "44.0.0" authors = ["Apache DataFusion "] edition = "2021" keywords = ["arrow", "datafusion", "query", "sql"] license = "Apache-2.0" homepage = "https://datafusion.apache.org" repository = "https://github.com/apache/datafusion" -# Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.78" +rust-version = "1.80.1" readme = "README.md" [dependencies] -arrow = { version = "53.0.0" } +arrow = { version = "54.0.0" } async-trait = "0.1.73" -aws-config = "1.5.5" -aws-sdk-sso = "1.43.0" -aws-sdk-ssooidc = "1.44.0" -aws-sdk-sts = "1.43.0" +## 1.5.13 requires a hiher MSRV 1.81 so lock until DataFusion MSRV catches up +aws-config = "=1.5.10" +## 1.53.0 requires a higher MSRV 1.81 so lock until DataFusion MSRV catches up +aws-sdk-sso = "=1.50.0" +## 1.54.0 requires a higher MSRV 1.81 so lock until DataFusion MSRV catches up +aws-sdk-ssooidc = "=1.51.0" +## 1.54.1 requires a higher MSRV 1.81 so lock until DataFusion MSRV catches up +aws-sdk-sts = "=1.51.0" # end pin aws-sdk crates aws-credential-types = "1.2.0" clap = { version = "4.5.16", features = ["derive", "cargo"] } -datafusion = { path = "../datafusion/core", version = "42.0.0", features = [ +datafusion = { path = "../datafusion/core", version = "44.0.0", features = [ "avro", "crypto_expressions", "datetime_expressions", "encoding_expressions", "parquet", + "recursive_protection", "regex_expressions", "unicode_expressions", "compression", ] } +datafusion-catalog = { path = "../datafusion/catalog", version = "44.0.0" } dirs = "5.0.1" env_logger = "0.11" futures = "0.3" +# pin as home 0.5.11 has MSRV 1.81. Can remove this once we bump MSRV to 1.81 +home = "=0.5.9" mimalloc = { version = "0.1", default-features = false } object_store = { version = "0.11.0", features = ["aws", "gcp", "http"] } parking_lot = { version = "0.12" } -parquet = { version = "53.0.0", default-features = false } +parquet = { version = "54.0.0", default-features = false } regex = "1.8" rustyline = "14.0" tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } -url = "2.2" +url = "2.5.4" [dev-dependencies] assert_cmd = "2.0" -ctor = "0.2.0" +ctor = "0.2.9" predicates = "3.0" rstest = "0.22" + +[profile.ci] +inherits = "dev" +incremental = false + +# ci turns off debug info, etc for dependencies to allow for smaller binaries making caching more effective +[profile.ci.package."*"] +debug = false +debug-assertions = false +strip = "debuginfo" +incremental = false diff --git a/datafusion-cli/Dockerfile b/datafusion-cli/Dockerfile index 7adead64db57c..faf345660dbea 100644 --- a/datafusion-cli/Dockerfile +++ b/datafusion-cli/Dockerfile @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -FROM rust:1.78-bookworm AS builder +FROM rust:1.80-bookworm AS builder COPY . /usr/src/datafusion COPY ./datafusion /usr/src/datafusion/datafusion diff --git a/datafusion-cli/README.md b/datafusion-cli/README.md index 73a2eb01b76ff..ce09c3b345b9b 100644 --- a/datafusion-cli/README.md +++ b/datafusion-cli/README.md @@ -41,6 +41,8 @@ The reason `datafusion-cli` is not part of the main workspace in checked in `Cargo.lock` file to ensure reproducible builds. However, the `datafusion` and sub crates are intended for use as libraries and -thus do not have a `Cargo.lock` file checked in. +thus do not have a `Cargo.lock` file checked in, as described in the [main +README] file. 
[`datafusion cargo.toml`]: https://github.com/apache/datafusion/blob/main/Cargo.toml +[main readme]: ../README.md diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index db4242d971758..a4f154b2de92d 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -33,11 +33,12 @@ use crate::{ }; use datafusion::common::instant::Instant; -use datafusion::common::plan_datafusion_err; +use datafusion::common::{plan_datafusion_err, plan_err}; use datafusion::config::ConfigFileType; use datafusion::datasource::listing::ListingTableUrl; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{DdlStatement, LogicalPlan}; +use datafusion::physical_plan::execution_plan::EmissionType; use datafusion::physical_plan::{collect, execute_stream, ExecutionPlanProperties}; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; @@ -234,10 +235,19 @@ pub(super) async fn exec_and_print( let df = ctx.execute_logical_plan(plan).await?; let physical_plan = df.create_physical_plan().await?; - if physical_plan.execution_mode().is_unbounded() { + if physical_plan.boundedness().is_unbounded() { + if physical_plan.pipeline_behavior() == EmissionType::Final { + return plan_err!( + "The given query can generate a valid result only once \ + the source finishes, but the source is unbounded" + ); + } + // As the input stream comes, we can generate results. + // However, memory safety is not guaranteed. let stream = execute_stream(physical_plan, task_ctx.clone())?; print_options.print_stream(stream, now).await?; } else { + // Bounded stream; collected results are printed after all input consumed. let schema = physical_plan.schema(); let results = collect(physical_plan, task_ctx.clone()).await?; adjusted.into_inner().print_batches(schema, &results, now)?; @@ -383,7 +393,7 @@ pub(crate) async fn register_object_store_and_config_extensions( ctx.register_table_options_extension_from_scheme(scheme); // Clone and modify the default table options based on the provided options - let mut table_options = ctx.session_state().default_table_options().clone(); + let mut table_options = ctx.session_state().default_table_options(); if let Some(format) = format { table_options.set_config_format(format); } diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 3b91abf8f3dcf..25d9b1681e516 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -24,13 +24,13 @@ use async_trait::async_trait; use datafusion::catalog::Session; use datafusion::common::{plan_err, Column}; -use datafusion::datasource::function::TableFunctionImpl; use datafusion::datasource::TableProvider; use datafusion::error::Result; use datafusion::logical_expr::{Expr, Scalar}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::scalar::ScalarValue; +use datafusion_catalog::TableFunctionImpl; use parquet::basic::ConvertedType; use parquet::data_type::{ByteArray, FixedLenByteArray}; use parquet::file::reader::FileReader; @@ -363,7 +363,7 @@ impl TableFunctionImpl for ParquetMetadataFunc { Field::new("total_uncompressed_size", DataType::Int64, true), ])); - // construct recordbatch from metadata + // construct record batch from metadata let mut filename_arr = vec![]; let mut row_group_id_arr = vec![]; let mut row_group_num_rows_arr = vec![]; diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 4c6c352ff3395..52665df3751ea 100644 
--- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -19,12 +19,12 @@ use std::collections::HashMap; use std::env; use std::path::Path; use std::process::ExitCode; -use std::sync::{Arc, OnceLock}; +use std::sync::{Arc, LazyLock}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionConfig; -use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool}; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, MemoryPool}; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicObjectStoreCatalog; use datafusion_cli::functions::ParquetMetadataFunc; @@ -156,27 +156,22 @@ async fn main_inner() -> Result<()> { session_config = session_config.with_batch_size(batch_size); }; - let rt_config = RuntimeConfig::new(); - let rt_config = - // set memory pool size - if let Some(memory_limit) = args.memory_limit { - // set memory pool type - match args.mem_pool_type { - PoolType::Fair => rt_config - .with_memory_pool(Arc::new(FairSpillPool::new(memory_limit))), - PoolType::Greedy => rt_config - .with_memory_pool(Arc::new(GreedyMemoryPool::new(memory_limit))) - } - } else { - rt_config + let mut rt_builder = RuntimeEnvBuilder::new(); + // set memory pool size + if let Some(memory_limit) = args.memory_limit { + // set memory pool type + let pool: Arc = match args.mem_pool_type { + PoolType::Fair => Arc::new(FairSpillPool::new(memory_limit)), + PoolType::Greedy => Arc::new(GreedyMemoryPool::new(memory_limit)), }; + rt_builder = rt_builder.with_memory_pool(pool) + } - let runtime_env = create_runtime_env(rt_config.clone())?; + let runtime_env = rt_builder.build_arc()?; // enable dynamic file query - let ctx = - SessionContext::new_with_config_rt(session_config.clone(), Arc::new(runtime_env)) - .enable_url_table(); + let ctx = SessionContext::new_with_config_rt(session_config, runtime_env) + .enable_url_table(); ctx.refresh_catalogs().await?; // install dynamic catalog provider that can register required object stores ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new( @@ -231,10 +226,6 @@ async fn main_inner() -> Result<()> { Ok(()) } -fn create_runtime_env(rn_config: RuntimeConfig) -> Result { - RuntimeEnv::try_new(rn_config) -} - fn parse_valid_file(dir: &str) -> Result { if Path::new(dir).is_file() { Ok(dir.to_string()) @@ -288,9 +279,8 @@ impl ByteUnit { } fn extract_memory_pool_size(size: &str) -> Result { - fn byte_suffixes() -> &'static HashMap<&'static str, ByteUnit> { - static BYTE_SUFFIXES: OnceLock> = OnceLock::new(); - BYTE_SUFFIXES.get_or_init(|| { + static BYTE_SUFFIXES: LazyLock> = + LazyLock::new(|| { let mut m = HashMap::new(); m.insert("b", ByteUnit::Byte); m.insert("k", ByteUnit::KiB); @@ -302,23 +292,20 @@ fn extract_memory_pool_size(size: &str) -> Result { m.insert("t", ByteUnit::TiB); m.insert("tb", ByteUnit::TiB); m - }) - } + }); - fn suffix_re() -> &'static regex::Regex { - static SUFFIX_REGEX: OnceLock = OnceLock::new(); - SUFFIX_REGEX.get_or_init(|| regex::Regex::new(r"^(-?[0-9]+)([a-z]+)?$").unwrap()) - } + static SUFFIX_REGEX: LazyLock = + LazyLock::new(|| regex::Regex::new(r"^(-?[0-9]+)([a-z]+)?$").unwrap()); let lower = size.to_lowercase(); - if let Some(caps) = suffix_re().captures(&lower) { + if let Some(caps) = SUFFIX_REGEX.captures(&lower) { let num_str = caps.get(1).unwrap().as_str(); let num = num_str.parse::().map_err(|_| { 
format!("Invalid numeric value in memory pool size '{}'", size) })?; let suffix = caps.get(2).map(|m| m.as_str()).unwrap_or("b"); - let unit = byte_suffixes() + let unit = &BYTE_SUFFIXES .get(suffix) .ok_or_else(|| format!("Invalid memory pool size '{}'", size))?; let memory_pool_size = usize::try_from(unit.multiplier()) diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index e8d60e4f0926c..045c924e50370 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -32,7 +32,7 @@ use aws_credential_types::provider::ProvideCredentials; use object_store::aws::{AmazonS3Builder, AwsCredential}; use object_store::gcp::GoogleCloudStorageBuilder; use object_store::http::HttpBuilder; -use object_store::{CredentialProvider, ObjectStore}; +use object_store::{ClientOptions, CredentialProvider, ObjectStore}; use url::Url; pub async fn get_s3_object_store_builder( @@ -437,6 +437,7 @@ pub(crate) async fn get_object_store( } "http" | "https" => Arc::new( HttpBuilder::new() + .with_client_options(ClientOptions::new().with_allow_http(true)) .with_url(url.origin().ascii_serialization()) .build()?, ), @@ -471,12 +472,13 @@ mod tests { #[tokio::test] async fn s3_object_store_builder() -> Result<()> { - let access_key_id = "fake_access_key_id"; - let secret_access_key = "fake_secret_access_key"; + // "fake" is uppercase to ensure the values are not lowercased when parsed + let access_key_id = "FAKE_access_key_id"; + let secret_access_key = "FAKE_secret_access_key"; let region = "fake_us-east-2"; let endpoint = "endpoint33"; - let session_token = "fake_session_token"; - let location = "s3://bucket/path/file.parquet"; + let session_token = "FAKE_session_token"; + let location = "s3://bucket/path/FAKE/file.parquet"; let table_url = ListingTableUrl::parse(location)?; let scheme = table_url.scheme(); @@ -495,7 +497,7 @@ mod tests { if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options().clone(); + let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; let aws_options = table_options.extensions.get::().unwrap(); let builder = @@ -540,7 +542,7 @@ mod tests { if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options().clone(); + let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; let aws_options = table_options.extensions.get::().unwrap(); let err = get_s3_object_store_builder(table_url.as_ref(), aws_options) @@ -566,7 +568,7 @@ mod tests { if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options().clone(); + let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; let aws_options = table_options.extensions.get::().unwrap(); // ensure this isn't an error @@ -594,7 +596,7 @@ mod tests { if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options().clone(); + let mut table_options = 
ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; let aws_options = table_options.extensions.get::().unwrap(); let builder = get_oss_object_store_builder(table_url.as_ref(), aws_options)?; @@ -631,7 +633,7 @@ mod tests { if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options().clone(); + let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; let gcp_options = table_options.extensions.get::().unwrap(); let builder = get_gcs_object_store_builder(table_url.as_ref(), gcp_options)?; diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 92cb106d622bf..1fc949593512b 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -26,7 +26,7 @@ use arrow::datatypes::SchemaRef; use arrow::json::{ArrayWriter, LineDelimitedWriter}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches_with_options; -use datafusion::common::format::DEFAULT_FORMAT_OPTIONS; +use datafusion::common::format::DEFAULT_CLI_FORMAT_OPTIONS; use datafusion::error::Result; /// Allow records to be printed in different formats @@ -133,7 +133,7 @@ fn format_batches_with_maxrows( let formatted = pretty_format_batches_with_options( &filtered_batches, - &DEFAULT_FORMAT_OPTIONS, + &DEFAULT_CLI_FORMAT_OPTIONS, )?; if over_limit { let mut formatted_str = format!("{}", formatted); @@ -145,7 +145,7 @@ fn format_batches_with_maxrows( } MaxRows::Unlimited => { let formatted = - pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?; + pretty_format_batches_with_options(batches, &DEFAULT_CLI_FORMAT_OPTIONS)?; writeln!(writer, "{}", formatted)?; } } @@ -201,7 +201,7 @@ impl PrintFormat { let empty_batch = RecordBatch::new_empty(schema); let formatted = pretty_format_batches_with_options( &[empty_batch], - &DEFAULT_FORMAT_OPTIONS, + &DEFAULT_CLI_FORMAT_OPTIONS, )?; writeln!(writer, "{}", formatted)?; } diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index f430a87e190db..d8aaad801e5c0 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -60,8 +60,10 @@ async-trait = { workspace = true } bytes = { workspace = true } dashmap = { workspace = true } datafusion = { workspace = true, default-features = true, features = ["avro"] } +datafusion-catalog = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions-window-common = { workspace = true } datafusion-optimizer = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true, default-features = true } datafusion-proto = { workspace = true } @@ -70,12 +72,8 @@ env_logger = { workspace = true } futures = { workspace = true } log = { workspace = true } mimalloc = { version = "0.1", default-features = false } -num_cpus = { workspace = true } object_store = { workspace = true, features = ["aws", "http"] } prost = { workspace = true } -prost-derive = { workspace = true } -serde = { version = "1.0.136", features = ["derive"] } -serde_json = { workspace = true } tempfile = { workspace = true } test-utils = { path = "../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } diff --git a/datafusion-examples/README.md 
b/datafusion-examples/README.md index 5f032c3e9cfff..b5f82b4d5140e 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -22,7 +22,7 @@ This crate includes end to end, highly commented examples of how to use various DataFusion APIs to help you get started. -## Prerequisites: +## Prerequisites Run `git submodule update --init` to init test files. @@ -54,22 +54,19 @@ cargo run --example dataframe - [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog - [`composed_extension_codec`](examples/composed_extension_codec.rs): Example of using multiple extension codecs for serialization / deserialization - [`csv_sql_streaming.rs`](examples/csv_sql_streaming.rs): Build and run a streaming query plan from a SQL statement against a local CSV file +- [`csv_json_opener.rs`](examples/csv_json_opener.rs): Use low level `FileOpener` APIs to read CSV/JSON into Arrow `RecordBatch`es - [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) - [`custom_file_format.rs`](examples/custom_file_format.rs): Write data to a custom file format - [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 -- [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file -- [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory -- [`dataframe_output.rs`](examples/dataframe_output.rs): Examples of methods which write data out from a DataFrame -- [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde -- [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify and analyze `Expr`s +- [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame API against parquet files, csv files, and in-memory data, including multiple subqueries. Also demonstrates the various methods to write out a DataFrame to a table, parquet file, csv file, and json file. +- [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results (Arrow ArrayRefs) into Rust structs +- [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify, analyze and coerce `Expr`s - [`file_stream_provider.rs`](examples/file_stream_provider.rs): Run a query on `FileStreamProvider` which implements `StreamProvider` for reading and writing to arbitrary stream sources / sinks. 
- [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients - [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros - [`make_date.rs`](examples/make_date.rs): Examples of using the make_date function -- [`memtable.rs`](examples/memtable.rs): Create an query data in memory using SQL and `RecordBatch`es - [`optimizer_rule.rs`](examples/optimizer_rule.rs): Use a custom OptimizerRule to replace certain predicates - [`parquet_index.rs`](examples/parquet_index.rs): Create an secondary index over several parquet files and use it to speed up queries -- [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files - [`parquet_exec_visitor.rs`](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution - [`parse_sql_expr.rs`](examples/parse_sql_expr.rs): Parse SQL text into DataFusion `Expr`. - [`plan_to_sql.rs`](examples/plan_to_sql.rs): Generate SQL from DataFusion `Expr` and `LogicalPlan` @@ -78,12 +75,14 @@ cargo run --example dataframe - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP - [`regexp.rs`](examples/regexp.rs): Examples of using regular expression functions +- [`remote_catalog.rs`](examples/remote_catalog.rs): Examples of interfacing with a remote catalog (e.g. over a network) - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) - [`sql_analysis.rs`](examples/sql_analysis.rs): Analyse SQL queries with DataFusion structures - [`sql_frontend.rs`](examples/sql_frontend.rs): Create LogicalPlans (only) from sql strings - [`sql_dialect.rs`](examples/sql_dialect.rs): Example of implementing a custom SQL dialect on top of `DFParser` +- [`sql_query.rs`](examples/sql_query.rs): Query data using SQL (in memory `RecordBatch`es, local Parquet files) - [`to_char.rs`](examples/to_char.rs): Examples of using the to_char function - [`to_timestamp.rs`](examples/to_timestamp.rs): Examples of using to_timestamp functions diff --git a/datafusion-examples/examples/advanced_parquet_index.rs b/datafusion-examples/examples/advanced_parquet_index.rs index f6860bb5b87a5..28a3a2f1de09e 100644 --- a/datafusion-examples/examples/advanced_parquet_index.rs +++ b/datafusion-examples/examples/advanced_parquet_index.rs @@ -82,7 +82,7 @@ use url::Url; /// Specifically, this example illustrates how to: /// 1. Use [`ParquetFileReaderFactory`] to avoid re-reading parquet metadata on each query /// 2. Use [`PruningPredicate`] for predicate analysis -/// 3. Pass a row group selection to [`ParuetExec`] +/// 3. Pass a row group selection to [`ParquetExec`] /// 4. Pass a row selection (within a row group) to [`ParquetExec`] /// /// Note this is a *VERY* low level example for people who want to build their @@ -211,7 +211,7 @@ async fn main() -> Result<()> { // // Note: in order to prune pages, the Page Index must be loaded and the // ParquetExec will load it on demand if not present.
To avoid a second IO - // during query, this example loaded the Page Index pre-emptively by setting + // during query, this example loaded the Page Index preemptively by setting // `ArrowReader::with_page_index` in `IndexedFile::try_new` provider.set_use_row_selection(true); println!("** Select data, predicate `id = 950`"); @@ -229,9 +229,9 @@ async fn main() -> Result<()> { /// `file1.parquet` contains values `0..1000` #[derive(Debug)] pub struct IndexTableProvider { - /// Where the file is stored (cleanup on drop) - #[allow(dead_code)] - tmpdir: TempDir, + /// Pointer to temporary file storage. Keeping it in scope to prevent temporary folder + /// to be deleted prematurely + _tmpdir: TempDir, /// The file that is being read. indexed_file: IndexedFile, /// The underlying object store @@ -250,7 +250,7 @@ impl IndexTableProvider { Ok(Self { indexed_file, - tmpdir, + _tmpdir: tmpdir, object_store, use_row_selections: AtomicBool::new(false), }) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 1259f90d64496..a914cea4a928a 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -31,7 +31,9 @@ use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::{cast::as_float64_array, ScalarValue}; use datafusion_expr::{ - function::{AccumulatorArgs, StateFieldsArgs}, + expr::AggregateFunction, + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + simplify::SimplifyInfo, Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, }; @@ -193,44 +195,10 @@ impl Accumulator for GeometricMean { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } -// create local session context with an in-memory table -fn create_context() -> Result { - use datafusion::datasource::MemTable; - // define a schema. - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, false), - Field::new("b", DataType::Float32, false), - ])); - - // define data in two partitions - let batch1 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])), - Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])), - ], - )?; - let batch2 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Float32Array::from(vec![64.0])), - Arc::new(Float32Array::from(vec![2.0])), - ], - )?; - - // declare a new context. In spark API, this corresponds to a new spark SQLsession - let ctx = SessionContext::new(); - - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). - let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - Ok(ctx) -} - // Define a `GroupsAccumulator` for GeometricMean /// which handles accumulator state for multiple groups at once. /// This API is significantly more complicated than `Accumulator`, which manages @@ -394,40 +362,151 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { } fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - + self.prods.capacity() * std::mem::size_of::() + self.counts.capacity() * size_of::() + + self.prods.capacity() * size_of::() + } +} + +/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user +/// defined aggregate function with a different expression which is defined in the `simplify` method. 
+#[derive(Debug, Clone)] +struct SimplifiedGeoMeanUdaf { + signature: Signature, +} + +impl SimplifiedGeoMeanUdaf { + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "simplified_geo_mean" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + unimplemented!("should not be invoked") + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + unimplemented!("should not be invoked") + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + unimplemented!("should not get here"); + } + + /// Optionally replaces a UDAF with another expression during query optimization. + fn simplify(&self) -> Option { + let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { + // Replaces the UDAF with `GeoMeanUdaf` as a placeholder example to demonstrate the `simplify` method. + // In real-world scenarios, you might create UDFs from built-in expressions. + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + Arc::new(AggregateUDF::from(GeoMeanUdaf::new())), + aggregate_function.args, + aggregate_function.distinct, + aggregate_function.filter, + aggregate_function.order_by, + aggregate_function.null_treatment, + ))) + }; + Some(Box::new(simplify)) } } +// create local session context with an in-memory table +fn create_context() -> Result { + use datafusion::datasource::MemTable; + // define a schema. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + ])); + + // define data in two partitions + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])), + Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])), + ], + )?; + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![64.0])), + Arc::new(Float32Array::from(vec![2.0])), + ], + )?; + + // declare a new context. In spark API, this corresponds to a new spark SQLsession + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Arc::new(provider))?; + Ok(ctx) +} + #[tokio::main] async fn main() -> Result<()> { let ctx = create_context()?; - // create the AggregateUDF - let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new()); - ctx.register_udaf(geometric_mean.clone()); + let geo_mean_udf = AggregateUDF::from(GeoMeanUdaf::new()); + let simplified_geo_mean_udf = AggregateUDF::from(SimplifiedGeoMeanUdaf::new()); + + for (udf, udf_name) in [ + (geo_mean_udf, "geo_mean"), + (simplified_geo_mean_udf, "simplified_geo_mean"), + ] { + ctx.register_udaf(udf.clone()); - let sql_df = ctx.sql("SELECT geo_mean(a) FROM t group by b").await?; - sql_df.show().await?; + let sql_df = ctx + .sql(&format!("SELECT {}(a) FROM t GROUP BY b", udf_name)) + .await?; + sql_df.show().await?; - // get a DataFrame from the context - // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0. 
- let df = ctx.table("t").await?; + // get a DataFrame from the context + // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0. + let df = ctx.table("t").await?; - // perform the aggregation - let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?; + // perform the aggregation + let df = df.aggregate(vec![], vec![udf.call(vec![col("a")])])?; - // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature. + // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature. - // execute the query - let results = df.collect().await?; + // execute the query + let results = df.collect().await?; - // downcast the array to the expected type - let result = as_float64_array(results[0].column(0))?; + // downcast the array to the expected type + let result = as_float64_array(results[0].column(0))?; - // verify that the calculation is correct - assert!((result.value(0) - 8.0).abs() < f64::EPSILON); - println!("The geometric mean of [2,4,8,64] is {}", result.value(0)); + // verify that the calculation is correct + assert!((result.value(0) - 8.0).abs() < f64::EPSILON); + println!("The geometric mean of [2,4,8,64] is {}", result.value(0)); + } Ok(()) } diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index 22d37043e4731..0aa2b3f370e97 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -27,9 +27,11 @@ use arrow::record_batch::RecordBatch; use datafusion::error::Result; use datafusion::logical_expr::Volatility; use datafusion::prelude::*; -use datafusion_common::{internal_err, ScalarValue}; +use datafusion_common::{exec_err, internal_err, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, +}; /// This example shows how to use the full ScalarUDFImpl API to implement a user /// defined function. As in the `simple_udf.rs` example, this struct implements @@ -83,21 +85,29 @@ impl ScalarUDFImpl for PowUdf { Ok(DataType::Float64) } - /// This is the function that actually calculates the results. + /// This function actually calculates the results of the scalar function. + /// + /// This is the same way that functions provided with DataFusion are invoked, + /// which permits important special cases: /// - /// This is the same way that functions built into DataFusion are invoked, - /// which permits important special cases when one or both of the arguments - /// are single values (constants). For example `pow(a, 2)` + ///1. When one or both of the arguments are single values (constants). + /// For example `pow(a, 2)` + /// 2. When the input arrays can be reused (avoid allocating a new output array) /// /// However, it also means the implementation is more complex than when /// using `create_udf`. - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // The other fields of the `args` struct are used for more specialized + // uses, and are not needed in this example + let ScalarFunctionArgs { mut args, .. 
} = args; // DataFusion has arranged for the correct inputs to be passed to this // function, but we check again to make sure assert_eq!(args.len(), 2); - let (base, exp) = (&args[0], &args[1]); - assert_eq!(base.data_type(), &DataType::Float64); - assert_eq!(exp.data_type(), &DataType::Float64); + // take ownership of arguments by popping in reverse order + let exp = args.pop().unwrap(); + let base = args.pop().unwrap(); + assert_eq!(*base.data_type(), DataType::Float64); + assert_eq!(*exp.data_type(), DataType::Float64); match (base, exp) { // For demonstration purposes we also implement the scalar / scalar @@ -109,21 +119,17 @@ impl ScalarUDFImpl for PowUdf { // this path once during planning, and simply use the result during // execution. (ColumnarValue::Scalar(base), ColumnarValue::Scalar(exp)) => { - match (base.value(), exp.value()) { - (ScalarValue::Float64(base), ScalarValue::Float64(exp)) => { - // compute the output. Note DataFusion treats `None` as NULL. - let res = match (base, exp) { - (Some(base), Some(exp)) => Some(base.powf(*exp)), - // one or both arguments were NULL - _ => None, - }; - Ok(ColumnarValue::from(ScalarValue::from(res))) - } - _ => { - internal_err!("Invalid argument types to pow function") - } - } + let res = match (base.value(), exp.value()) { + // compute the output. Note DataFusion treats `None` as NULL. + ( + ScalarValue::Float64(Some(base)), + ScalarValue::Float64(Some(exp)), + ) => Some(base.powf(*exp)), + _ => None, // one or both arguments were NULL + }; + Ok(ColumnarValue::from(ScalarValue::from(res))) } + // special case if the exponent is a constant (ColumnarValue::Array(base_array), ColumnarValue::Scalar(exp)) => { let result_array = match exp.value() { @@ -147,24 +153,28 @@ impl ScalarUDFImpl for PowUdf { Ok(ColumnarValue::Array(result_array)) } - // special case if the base is a constant (note this code is quite - // similar to the previous case, so we omit comments) + // special case if the base is a constant. + // + // Note this case is very similar to the previous case, so we could + // use the same pattern. However, for this case we demonstrate an + // even more advanced pattern to potentially avoid allocating a new array (ColumnarValue::Scalar(base), ColumnarValue::Array(exp_array)) => { let res = match base.value() { ScalarValue::Float64(None) => { new_null_array(exp_array.data_type(), exp_array.len()) } ScalarValue::Float64(Some(base)) => { - let exp_array = exp_array.as_primitive::(); - let res: Float64Array = - compute::unary(exp_array, |exp| base.powf(exp)); - Arc::new(res) + maybe_pow_in_place(*base, exp_array)? } - _ => return internal_err!("Invalid argument types to pow function"), + _ => return internal_err!("Invalid scalar argument to pow function"), }; Ok(ColumnarValue::Array(res)) } - // Both arguments are arrays so we have to perform the calculation for every row + // Both arguments are arrays so we have to perform the calculation + // for every row + // + // Note this could also be done in place using `binary_mut` as + // is done in `maybe_pow_in_place` but here we use binary for simplicity (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => { let res: Float64Array = compute::binary( base_array.as_primitive::(), @@ -187,6 +197,52 @@ impl ScalarUDFImpl for PowUdf { } } +/// Evaluate `base ^ exp` *without* allocating a new array, if possible +fn maybe_pow_in_place(base: f64, exp_array: ArrayRef) -> Result { + // Calling `unary` creates a new array for the results. 
Avoiding + // allocations is a common optimization in performance critical code. + // arrow-rs allows this optimization via the `unary_mut` + // and `binary_mut` kernels in certain cases + // + // These kernels can only be used if there are no other references to + // the arrays (exp_array has to be the last remaining reference). + let owned_array = exp_array + // as in the previous example, we first downcast to &Float64Array + .as_primitive::() + // non-obviously, we call clone here to get an owned `Float64Array`. + // Calling clone() is relatively inexpensive as it increments + // some ref counts but doesn't clone the data) + // + // Once we have the owned Float64Array we can drop the original + // exp_array (untyped) reference + .clone(); + + // We *MUST* drop the reference to `exp_array` explicitly so that + // owned_array is the only reference remaining in this function. + // + // Note that depending on the query there may still be other references + // to the underlying buffers, which would prevent reuse. The only way to + // know for sure is the result of `compute::unary_mut` + drop(exp_array); + + // If we have the only reference, compute the result directly into the same + // allocation as was used for the input array + match compute::unary_mut(owned_array, |exp| base.powf(exp)) { + Err(_orig_array) => { + // unary_mut will return the original array if there are other + // references into the underling buffer (and thus reuse is + // impossible) + // + // In a real implementation, this case should fall back to + // calling `unary` and allocate a new array; In this example + // we will return an error for demonstration purposes + exec_err!("Could not reuse array for maybe_pow_in_place") + } + // a result of OK means the operation was run successfully + Ok(res) => Ok(Arc::new(res)), + } +} + /// In this example we register `PowUdf` as a user defined function /// and invoke it via the DataFrame API and SQL #[tokio::main] @@ -211,9 +267,29 @@ async fn main() -> Result<()> { // print the results df.show().await?; - // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL - let sql_df = ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?; - sql_df.show().await?; + // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL + ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t") + .await? + .show() + .await?; + + // You can also invoke pow_in_place by passing a constant base and a + // column `a` as the exponent . If there is only a single + // reference to `a` the code works well + ctx.sql("SELECT pow(2, a) FROM t").await?.show().await?; + + // However, if there are multiple references to `a` in the evaluation + // the array storage can not be reused + let err = ctx + .sql("SELECT pow(2, a), pow(3, a) FROM t") + .await? 
+ .show() + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Execution error: Could not reuse array for maybe_pow_in_place" + ); Ok(()) } diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index fd1b84070cf68..49e890467d21e 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -24,12 +24,16 @@ use arrow::{ }; use arrow_schema::Field; use datafusion::error::Result; +use datafusion::functions_aggregate::average::avg_udaf; use datafusion::prelude::*; use datafusion_common::ScalarValue; -use datafusion_expr::function::WindowUDFFieldArgs; +use datafusion_expr::expr::WindowFunction; +use datafusion_expr::function::{WindowFunctionSimplification, WindowUDFFieldArgs}; +use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ - PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl, + Expr, PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl, }; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// This example shows how to use the full WindowUDFImpl API to implement a user /// defined window function. As in the `simple_udwf.rs` example, this struct implements @@ -74,7 +78,10 @@ impl WindowUDFImpl for SmoothItUdf { /// Create a `PartitionEvaluator` to evaluate this function on a new /// partition. - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { Ok(Box::new(MyPartitionEvaluator::new())) } @@ -138,6 +145,67 @@ impl PartitionEvaluator for MyPartitionEvaluator { } } +/// This UDWF will show how to use the WindowUDFImpl::simplify() API +#[derive(Debug, Clone)] +struct SimplifySmoothItUdf { + signature: Signature, +} + +impl SimplifySmoothItUdf { + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take one arguments of type f64 + vec![DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + } + } +} +impl WindowUDFImpl for SimplifySmoothItUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "simplify_smooth_it" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + todo!() + } + + /// this function will simplify `SimplifySmoothItUdf` to `AggregateUDF` for `Avg` + /// default implementation will not be called (left as `todo!()`) + fn simplify(&self) -> Option { + let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { + Ok(Expr::WindowFunction(WindowFunction { + fun: datafusion_expr::WindowFunctionDefinition::AggregateUDF(avg_udaf()), + args: window_function.args, + partition_by: window_function.partition_by, + order_by: window_function.order_by, + window_frame: window_function.window_frame, + null_treatment: window_function.null_treatment, + })) + }; + + Some(Box::new(simplify)) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, true)) + } +} + // create local execution context with `cars.csv` registered as a table named `cars` async fn create_context() -> Result { // declare a new context. 
In spark API, this corresponds to a new spark SQL session @@ -158,12 +226,15 @@ async fn main() -> Result<()> { let smooth_it = WindowUDF::from(SmoothItUdf::new()); ctx.register_udwf(smooth_it.clone()); - // Use SQL to run the new window function + let simplify_smooth_it = WindowUDF::from(SimplifySmoothItUdf::new()); + ctx.register_udwf(simplify_smooth_it.clone()); + + // Use SQL to retrieve entire table let df = ctx.sql("SELECT * from cars").await?; // print the results df.show().await?; - // Use SQL to run the new window function: + // Use SQL to run smooth_it: // // `PARTITION BY car`:each distinct value of car (red, and green) // should be treated as a separate partition (and will result in @@ -197,7 +268,7 @@ async fn main() -> Result<()> { // print the results df.show().await?; - // this time, call the new widow function with an explicit + // this time, call the function with an explicit // window so evaluate will be invoked with each window. // // `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation @@ -228,5 +299,22 @@ async fn main() -> Result<()> { // print the results df.show().await?; + // Use SQL to run simplify_smooth_it + let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + simplify_smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; + + // print the results + df.show().await?; + Ok(()) } diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/analyzer_rule.rs index bd067be97b8b3..aded64ed4105d 100644 --- a/datafusion-examples/examples/analyzer_rule.rs +++ b/datafusion-examples/examples/analyzer_rule.rs @@ -138,7 +138,7 @@ impl AnalyzerRule for RowLevelAccessControl { fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { // use the TreeNode API to recursively walk the LogicalPlan tree // and all of its children (inputs) - let transfomed_plan = plan.transform(|plan| { + let transformed_plan = plan.transform(|plan| { // This closure is called for each LogicalPlan node // if it is a Scan node, add a filter to remove all managers if is_employee_table_scan(&plan) { @@ -166,7 +166,7 @@ impl AnalyzerRule for RowLevelAccessControl { // // This example does not need the value of either flag, so simply // extract the LogicalPlan "data" - Ok(transfomed_plan.data) + Ok(transformed_plan.data) } fn name(&self) -> &str { diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/catalog.rs index f40f1dfb5a159..655438b78b9fa 100644 --- a/datafusion-examples/examples/catalog.rs +++ b/datafusion-examples/examples/catalog.rs @@ -46,11 +46,11 @@ async fn main() -> Result<()> { let ctx = SessionContext::new(); let state = ctx.state(); - let cataloglist = Arc::new(CustomCatalogProviderList::new()); + let catalog_list = Arc::new(CustomCatalogProviderList::new()); // use our custom catalog list for context. each context has a single catalog list. 
// context will by default have [`MemoryCatalogProviderList`] - ctx.register_catalog_list(cataloglist.clone()); + ctx.register_catalog_list(catalog_list.clone()); // initialize our catalog and schemas let catalog = DirCatalog::new(); @@ -81,7 +81,7 @@ async fn main() -> Result<()> { ctx.register_catalog("dircat", Arc::new(catalog)); { // catalog was passed down into our custom catalog list since we override the ctx's default - let catalogs = cataloglist.catalogs.read().unwrap(); + let catalogs = catalog_list.catalogs.read().unwrap(); assert!(catalogs.contains_key("dircat")); }; @@ -144,8 +144,8 @@ impl DirSchema { async fn create(state: &SessionState, opts: DirSchemaOpts<'_>) -> Result> { let DirSchemaOpts { ext, dir, format } = opts; let mut tables = HashMap::new(); - let direntries = std::fs::read_dir(dir).unwrap(); - for res in direntries { + let dir_entries = std::fs::read_dir(dir).unwrap(); + for res in dir_entries { let entry = res.unwrap(); let filename = entry.file_name().to_str().unwrap().to_string(); if !filename.ends_with(ext) { diff --git a/datafusion-examples/examples/config_extension.rs b/datafusion-examples/examples/config_extension.rs deleted file mode 100644 index b9f83f91ce564..0000000000000 --- a/datafusion-examples/examples/config_extension.rs +++ /dev/null @@ -1,52 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This example demonstrates how to extend the DataFusion configs with custom extensions. - -use datafusion::{ - common::{config::ConfigExtension, extensions_options}, - config::ConfigOptions, -}; - -extensions_options! { - /// My own config options. - pub struct MyConfig { - /// Should "foo" be replaced by "bar"? - pub foo_to_bar: bool, default = true - - /// How many "baz" should be created? - pub baz_count: usize, default = 1337 - } -} - -impl ConfigExtension for MyConfig { - const PREFIX: &'static str = "my_config"; -} - -fn main() { - // set up config struct and register extension - let mut config = ConfigOptions::default(); - config.extensions.insert(MyConfig::default()); - - // overwrite config default - config.set("my_config.baz_count", "42").unwrap(); - - // check config state - let my_config = config.extensions.get::().unwrap(); - assert!(my_config.foo_to_bar,); - assert_eq!(my_config.baz_count, 42,); -} diff --git a/datafusion-examples/examples/csv_opener.rs b/datafusion-examples/examples/csv_json_opener.rs similarity index 50% rename from datafusion-examples/examples/csv_opener.rs rename to datafusion-examples/examples/csv_json_opener.rs index e7b7ead109bc0..334e4c83404ff 100644 --- a/datafusion-examples/examples/csv_opener.rs +++ b/datafusion-examples/examples/csv_json_opener.rs @@ -15,28 +15,36 @@ // specific language governing permissions and limitations // under the License. 
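Aside: the rewritten example below drives the low-level FileStream / CsvOpener machinery directly. For contrast, a minimal sketch of the high-level DataFrame route that the file's new doc comment points readers to; the file name "example.csv" and the function name are only illustrative, not part of the change:

use datafusion::error::Result;
use datafusion::prelude::*;

async fn read_csv_high_level() -> Result<()> {
    let ctx = SessionContext::new();
    // DataFusion infers the schema and builds the scan plan internally,
    // so no FileScanConfig / FileStream wiring is required on this path.
    let df = ctx.read_csv("example.csv", CsvReadOptions::default()).await?;
    df.show().await?;
    Ok(())
}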
-use std::{sync::Arc, vec}; +use std::sync::Arc; +use arrow_schema::{DataType, Field, Schema}; use datafusion::{ assert_batches_eq, datasource::{ file_format::file_compression_type::FileCompressionType, listing::PartitionedFile, object_store::ObjectStoreUrl, - physical_plan::{CsvConfig, CsvOpener, FileScanConfig, FileStream}, + physical_plan::{CsvConfig, CsvOpener, FileScanConfig, FileStream, JsonOpener}, }, error::Result, physical_plan::metrics::ExecutionPlanMetricsSet, test_util::aggr_test_schema, }; - use futures::StreamExt; -use object_store::local::LocalFileSystem; +use object_store::{local::LocalFileSystem, memory::InMemory, ObjectStore}; -/// This example demonstrates a scanning against an Arrow data source (CSV) and -/// fetching results +/// This example demonstrates using the low level [`FileStream`] / [`FileOpener`] APIs to directly +/// read data from (CSV/JSON) into Arrow RecordBatches. +/// +/// If you want to query data in CSV or JSON files, see the [`dataframe.rs`] and [`sql_query.rs`] examples #[tokio::main] async fn main() -> Result<()> { + csv_opener().await?; + json_opener().await?; + Ok(()) +} + +async fn csv_opener() -> Result<()> { let object_store = Arc::new(LocalFileSystem::new()); let schema = aggr_test_schema(); @@ -59,18 +67,17 @@ async fn main() -> Result<()> { let path = std::path::Path::new(&path).canonicalize()?; - let scan_config = - FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema.clone()) - .with_projection(Some(vec![12, 0])) - .with_limit(Some(5)) - .with_file(PartitionedFile::new(path.display().to_string(), 10)); - - let result = - FileStream::new(&scan_config, 0, opener, &ExecutionPlanMetricsSet::new()) - .unwrap() - .map(|b| b.unwrap()) - .collect::>() - .await; + let scan_config = FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema) + .with_projection(Some(vec![12, 0])) + .with_limit(Some(5)) + .with_file(PartitionedFile::new(path.display().to_string(), 10)); + + let mut result = vec![]; + let mut stream = + FileStream::new(&scan_config, 0, opener, &ExecutionPlanMetricsSet::new())?; + while let Some(batch) = stream.next().await.transpose()? { + result.push(batch); + } assert_batches_eq!( &[ "+--------------------------------+----+", @@ -87,3 +94,54 @@ async fn main() -> Result<()> { ); Ok(()) } + +async fn json_opener() -> Result<()> { + let object_store = InMemory::new(); + let path = object_store::path::Path::from("demo.json"); + let data = bytes::Bytes::from( + r#"{"num":5,"str":"test"} + {"num":2,"str":"hello"} + {"num":4,"str":"foo"}"#, + ); + + object_store.put(&path, data.into()).await?; + + let schema = Arc::new(Schema::new(vec![ + Field::new("num", DataType::Int64, false), + Field::new("str", DataType::Utf8, false), + ])); + + let projected = Arc::new(schema.clone().project(&[1, 0])?); + + let opener = JsonOpener::new( + 8192, + projected, + FileCompressionType::UNCOMPRESSED, + Arc::new(object_store), + ); + + let scan_config = FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema) + .with_projection(Some(vec![1, 0])) + .with_limit(Some(5)) + .with_file(PartitionedFile::new(path.to_string(), 10)); + + let mut stream = + FileStream::new(&scan_config, 0, opener, &ExecutionPlanMetricsSet::new())?; + let mut result = vec![]; + while let Some(batch) = stream.next().await.transpose()? 
{ + result.push(batch); + } + assert_batches_eq!( + &[ + "+-------+-----+", + "| str | num |", + "+-------+-----+", + "| test | 5 |", + "| hello | 2 |", + "| foo | 4 |", + "+-------+-----+", + ], + &result + ); + Ok(()) +} diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index 0f7748b133650..bc865fac5a338 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -21,22 +21,23 @@ use std::fmt::{self, Debug, Formatter}; use std::sync::{Arc, Mutex}; use std::time::Duration; +use async_trait::async_trait; use datafusion::arrow::array::{UInt64Builder, UInt8Builder}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::{provider_as_source, TableProvider, TableType}; use datafusion::error::Result; use datafusion::execution::context::TaskContext; +use datafusion::logical_expr::LogicalPlanBuilder; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::memory::MemoryStream; use datafusion::physical_plan::{ - project_schema, DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, - Partitioning, PlanProperties, SendableRecordBatchStream, + project_schema, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + PlanProperties, SendableRecordBatchStream, }; use datafusion::prelude::*; -use datafusion_expr::LogicalPlanBuilder; -use datafusion_physical_expr::EquivalenceProperties; -use async_trait::async_trait; use datafusion::catalog::Session; use tokio::time::timeout; @@ -110,7 +111,7 @@ struct CustomDataSourceInner { } impl Debug for CustomDataSource { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.write_str("custom_db") } } @@ -214,13 +215,14 @@ impl CustomExec { PlanProperties::new( eq_properties, Partitioning::UnknownPartitioning(1), - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ) } } impl DisplayAs for CustomExec { - fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { write!(f, "CustomExec") } } diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs index 1d9b587f15b93..95168597ebaaf 100644 --- a/datafusion-examples/examples/custom_file_format.rs +++ b/datafusion-examples/examples/custom_file_format.rs @@ -74,10 +74,7 @@ impl FileFormat for TSVFileFormat { "tsv".to_string() } - fn get_ext_with_compression( - &self, - c: &FileCompressionType, - ) -> datafusion::error::Result { + fn get_ext_with_compression(&self, c: &FileCompressionType) -> Result { if c == &FileCompressionType::UNCOMPRESSED { Ok("tsv".to_string()) } else { @@ -154,7 +151,7 @@ impl FileFormatFactory for TSVFileFactory { &self, state: &SessionState, format_options: &std::collections::HashMap, - ) -> Result> { + ) -> Result> { let mut new_options = format_options.clone(); new_options.insert("format.delimiter".to_string(), "\t".to_string()); @@ -164,7 +161,7 @@ impl FileFormatFactory for TSVFileFactory { Ok(tsv_file_format) } - fn default(&self) -> std::sync::Arc { + fn default(&self) -> Arc { todo!() } diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index 
d7e0068ef88f4..91d62135b9135 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -15,90 +15,116 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::dataframe::DataFrameWriteOptions; use datafusion::error::Result; +use datafusion::functions_aggregate::average::avg; +use datafusion::functions_aggregate::min_max::max; use datafusion::prelude::*; +use datafusion_common::config::CsvOptions; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::DataFusionError; +use datafusion_common::ScalarValue; use std::fs::File; use std::io::Write; +use std::sync::Arc; use tempfile::tempdir; -/// This example demonstrates executing a simple query against an Arrow data source (Parquet) and -/// fetching results, using the DataFrame trait +/// This example demonstrates using DataFusion's DataFrame API +/// +/// # Reading from different formats +/// +/// * [read_parquet]: execute queries against parquet files +/// * [read_csv]: execute queries against csv files +/// * [read_memory]: execute queries against in-memory arrow data +/// +/// # Writing out to local storage +/// +/// The following examples demonstrate how to write a DataFrame to local +/// storage. See `external_dependency/dataframe-to-s3.rs` for an example writing +/// to a remote object store. +/// +/// * [write_out]: write out a DataFrame to a table, parquet file, csv file, or json file +/// +/// # Executing subqueries +/// +/// * [where_scalar_subquery]: execute a scalar subquery +/// * [where_in_subquery]: execute a subquery with an IN clause +/// * [where_exist_subquery]: execute a subquery with an EXISTS clause +/// +/// # Querying data +/// +/// * [query_to_date]: execute queries against parquet files #[tokio::main] async fn main() -> Result<()> { - // create local execution context + // The SessionContext is the main high level API for interacting with DataFusion let ctx = SessionContext::new(); + read_parquet(&ctx).await?; + read_csv(&ctx).await?; + read_memory(&ctx).await?; + write_out(&ctx).await?; + query_to_date().await?; + register_aggregate_test_data("t1", &ctx).await?; + register_aggregate_test_data("t2", &ctx).await?; + where_scalar_subquery(&ctx).await?; + where_in_subquery(&ctx).await?; + where_exist_subquery(&ctx).await?; + Ok(()) +} +/// Use DataFrame API to +/// 1. Read parquet files, +/// 2. Show the schema +/// 3. Select columns and rows +async fn read_parquet(ctx: &SessionContext) -> Result<()> { + // Find the local path of "alltypes_plain.parquet" let testdata = datafusion::test_util::parquet_test_data(); - let filename = &format!("{testdata}/alltypes_plain.parquet"); - // define the query using the DataFrame trait - let df = ctx - .read_parquet(filename, ParquetReadOptions::default()) - .await? - .select_columns(&["id", "bool_col", "timestamp_col"])? 
- .filter(col("id").gt(lit(1)))?; - - // print the results - df.show().await?; - - // create a csv file waiting to be written - let dir = tempdir()?; - let file_path = dir.path().join("example.csv"); - let file = File::create(&file_path)?; - write_csv_file(file); - - // Reading CSV file with inferred schema example - let csv_df = - example_read_csv_file_with_inferred_schema(file_path.to_str().unwrap()).await; - csv_df.show().await?; - - // Reading CSV file with defined schema - let csv_df = example_read_csv_file_with_schema(file_path.to_str().unwrap()).await; - csv_df.show().await?; - - // Reading PARQUET file and print describe + // Read the parquet files and show its schema using 'describe' let parquet_df = ctx .read_parquet(filename, ParquetReadOptions::default()) .await?; - parquet_df.describe().await.unwrap().show().await?; - let dyn_ctx = ctx.enable_url_table(); - let df = dyn_ctx - .sql(&format!("SELECT * FROM '{}'", file_path.to_str().unwrap())) + // show its schema using 'describe' + parquet_df.clone().describe().await?.show().await?; + + // Select three columns and filter the results + // so that only rows where id > 1 are returned + parquet_df + .select_columns(&["id", "bool_col", "timestamp_col"])? + .filter(col("id").gt(lit(1)))? + .show() .await?; - df.show().await?; Ok(()) } -// Function to create an test CSV file -fn write_csv_file(mut file: File) { - // Create the data to put into the csv file with headers - let content = r#"id,time,vote,unixtime,rating -a1,"10 6, 2013",3,1381017600,5.0 -a2,"08 9, 2013",2,1376006400,4.5"#; - // write the data - file.write_all(content.as_ref()) - .expect("Problem with writing file!"); -} +/// Use the DataFrame API to +/// 1. Read CSV files +/// 2. Optionally specify schema +async fn read_csv(ctx: &SessionContext) -> Result<()> { + // create example.csv file in a temporary directory + let dir = tempdir()?; + let file_path = dir.path().join("example.csv"); + { + let mut file = File::create(&file_path)?; + // write CSV data + file.write_all( + r#"id,time,vote,unixtime,rating + a1,"10 6, 2013",3,1381017600,5.0 + a2,"08 9, 2013",2,1376006400,4.5"# + .as_bytes(), + )?; + } // scope closes the file + let file_path = file_path.to_str().unwrap(); -// Example to read data from a csv file with inferred schema -async fn example_read_csv_file_with_inferred_schema(file_path: &str) -> DataFrame { - // Create a session context - let ctx = SessionContext::new(); - // Register a lazy DataFrame using the context - ctx.read_csv(file_path, CsvReadOptions::default()) - .await - .unwrap() -} + // You can read a CSV file and DataFusion will infer the schema automatically + let csv_df = ctx.read_csv(file_path, CsvReadOptions::default()).await?; + csv_df.show().await?; -// Example to read csv file with a defined schema for the csv file -async fn example_read_csv_file_with_schema(file_path: &str) -> DataFrame { - // Create a session context - let ctx = SessionContext::new(); - // Define the schema + // If you know the types of your data you can specify them explicitly let schema = Schema::new(vec![ Field::new("id", DataType::Utf8, false), Field::new("time", DataType::Utf8, false), @@ -112,6 +138,206 @@ async fn example_read_csv_file_with_schema(file_path: &str) -> DataFrame { schema: Some(&schema), ..Default::default() }; - // Register a lazy DataFrame by using the context and option provider - ctx.read_csv(file_path, csv_read_option).await.unwrap() + let csv_df = ctx.read_csv(file_path, csv_read_option).await?; + csv_df.show().await?; + + // You can also create 
DataFrames from the result of sql queries + // and using the `enable_url_table` refer to local files directly + let dyn_ctx = ctx.clone().enable_url_table(); + let csv_df = dyn_ctx + .sql(&format!("SELECT rating, unixtime FROM '{}'", file_path)) + .await?; + csv_df.show().await?; + + Ok(()) +} + +/// Use the DataFrame API to: +/// 1. Read in-memory data. +async fn read_memory(ctx: &SessionContext) -> Result<()> { + // define data in memory + let a: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d"])); + let b: ArrayRef = Arc::new(Int32Array::from(vec![1, 10, 10, 100])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; + + // declare a table in memory. In Apache Spark API, this corresponds to createDataFrame(...). + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + // construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL + let filter = col("b").eq(lit(10)); + let df = df.select_columns(&["a", "b"])?.filter(filter)?; + + // print the results + df.show().await?; + + Ok(()) +} + +/// Use the DataFrame API to: +/// 1. Write out a DataFrame to a table +/// 2. Write out a DataFrame to a parquet file +/// 3. Write out a DataFrame to a csv file +/// 4. Write out a DataFrame to a json file +async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionError> { + let mut df = ctx.sql("values ('a'), ('b'), ('c')").await.unwrap(); + + // Ensure the column names and types match the target table + df = df.with_column_renamed("column1", "tablecol1").unwrap(); + + ctx.sql( + "create external table + test(tablecol1 varchar) + stored as parquet + location './datafusion-examples/test_table/'", + ) + .await? + .collect() + .await?; + + // This is equivalent to INSERT INTO test VALUES ('a'), ('b'), ('c'). + // The behavior of write_table depends on the TableProvider's implementation + // of the insert_into method. + df.clone() + .write_table("test", DataFrameWriteOptions::new()) + .await?; + + df.clone() + .write_parquet( + "./datafusion-examples/test_parquet/", + DataFrameWriteOptions::new(), + None, + ) + .await?; + + df.clone() + .write_csv( + "./datafusion-examples/test_csv/", + // DataFrameWriteOptions contains options which control how data is written + // such as compression codec + DataFrameWriteOptions::new(), + Some(CsvOptions::default().with_compression(CompressionTypeVariant::GZIP)), + ) + .await?; + + df.clone() + .write_json( + "./datafusion-examples/test_json/", + DataFrameWriteOptions::new(), + None, + ) + .await?; + + Ok(()) +} + +/// This example demonstrates how to use the to_date series +/// of functions in the DataFrame API as well as via sql. +async fn query_to_date() -> Result<()> { + // define a schema. + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)])); + + // define data. + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(StringArray::from(vec![ + "2020-09-08T13:42:29Z", + "2020-09-08T13:42:29.190855-05:00", + "2020-08-09 12:13:29", + "2020-01-02", + ]))], + )?; + + // declare a new context. In spark API, this corresponds to a new spark SQLsession + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
+ ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + // use to_date function to convert col 'a' to timestamp type using the default parsing + let df = df.with_column("a", to_date(vec![col("a")]))?; + + let df = df.select_columns(&["a"])?; + + // print the results + df.show().await?; + + Ok(()) +} + +/// Use the DataFrame API to execute the following subquery: +/// select c1,c2 from t1 where (select avg(t2.c2) from t2 where t1.c1 = t2.c1)>0 limit 3; +async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { + ctx.table("t1") + .await? + .filter( + scalar_subquery(Arc::new( + ctx.table("t2") + .await? + .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? + .aggregate(vec![], vec![avg(col("t2.c2"))])? + .select(vec![avg(col("t2.c2"))])? + .into_unoptimized_plan(), + )) + .gt(lit(0u8)), + )? + .select(vec![col("t1.c1"), col("t1.c2")])? + .limit(0, Some(3))? + .show() + .await?; + Ok(()) +} + +/// Use the DataFrame API to execute the following subquery: +/// select t1.c1, t1.c2 from t1 where t1.c2 in (select max(t2.c2) from t2 where t2.c1 > 0 ) limit 3; +async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { + ctx.table("t1") + .await? + .filter(in_subquery( + col("t1.c2"), + Arc::new( + ctx.table("t2") + .await? + .filter(col("t2.c1").gt(lit(ScalarValue::UInt8(Some(0)))))? + .aggregate(vec![], vec![max(col("t2.c2"))])? + .select(vec![max(col("t2.c2"))])? + .into_unoptimized_plan(), + ), + ))? + .select(vec![col("t1.c1"), col("t1.c2")])? + .limit(0, Some(3))? + .show() + .await?; + Ok(()) +} + +/// Use the DataFrame API to execute the following subquery: +/// select t1.c1, t1.c2 from t1 where exists (select t2.c2 from t2 where t1.c1 = t2.c1) limit 3; +async fn where_exist_subquery(ctx: &SessionContext) -> Result<()> { + ctx.table("t1") + .await? + .filter(exists(Arc::new( + ctx.table("t2") + .await? + .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? + .select(vec![col("t2.c2")])? + .into_unoptimized_plan(), + )))? + .select(vec![col("t1.c1"), col("t1.c2")])? + .limit(0, Some(3))? + .show() + .await?; + Ok(()) +} + +async fn register_aggregate_test_data(name: &str, ctx: &SessionContext) -> Result<()> { + let testdata = datafusion::test_util::arrow_test_data(); + ctx.register_csv( + name, + &format!("{testdata}/csv/aggregate_test_100.csv"), + CsvReadOptions::default(), + ) + .await?; + Ok(()) } diff --git a/datafusion-examples/examples/dataframe_in_memory.rs b/datafusion-examples/examples/dataframe_in_memory.rs deleted file mode 100644 index c57c38870a7e4..0000000000000 --- a/datafusion-examples/examples/dataframe_in_memory.rs +++ /dev/null @@ -1,60 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
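Aside on `query_to_date` above: its doc comment mentions the to_date family "as well as via sql", though the body only shows the DataFrame API call. A minimal sketch of the SQL form, assuming the same in-memory table `t` (one Utf8 column `a`) is registered on `ctx` with the default function registry; the function name here is illustrative:

use datafusion::error::Result;
use datafusion::prelude::*;

async fn query_to_date_sql(ctx: &SessionContext) -> Result<()> {
    // Same conversion as the `to_date(vec![col("a")])` DataFrame call above,
    // expressed through the SQL interface instead.
    ctx.sql("SELECT to_date(a) FROM t").await?.show().await?;
    Ok(())
}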
- -use std::sync::Arc; - -use datafusion::arrow::array::{Int32Array, StringArray}; -use datafusion::arrow::datatypes::{DataType, Field, Schema}; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::error::Result; -use datafusion::prelude::*; - -/// This example demonstrates how to use the DataFrame API against in-memory data. -#[tokio::main] -async fn main() -> Result<()> { - // define a schema. - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Int32, false), - ])); - - // define data. - let batch = RecordBatch::try_new( - schema, - vec![ - Arc::new(StringArray::from(vec!["a", "b", "c", "d"])), - Arc::new(Int32Array::from(vec![1, 10, 10, 100])), - ], - )?; - - // declare a new context. In spark API, this corresponds to a new spark SQLsession - let ctx = SessionContext::new(); - - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). - ctx.register_batch("t", batch)?; - let df = ctx.table("t").await?; - - // construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL - let filter = col("b").eq(lit(10)); - - let df = df.select_columns(&["a", "b"])?.filter(filter)?; - - // print the results - df.show().await?; - - Ok(()) -} diff --git a/datafusion-examples/examples/dataframe_output.rs b/datafusion-examples/examples/dataframe_output.rs deleted file mode 100644 index 60ca090d722d6..0000000000000 --- a/datafusion-examples/examples/dataframe_output.rs +++ /dev/null @@ -1,78 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::{dataframe::DataFrameWriteOptions, prelude::*}; -use datafusion_common::config::CsvOptions; -use datafusion_common::{parsers::CompressionTypeVariant, DataFusionError}; - -/// This example demonstrates the various methods to write out a DataFrame to local storage. -/// See datafusion-examples/examples/external_dependency/dataframe-to-s3.rs for an example -/// using a remote object store. -#[tokio::main] -async fn main() -> Result<(), DataFusionError> { - let ctx = SessionContext::new(); - - let mut df = ctx.sql("values ('a'), ('b'), ('c')").await.unwrap(); - - // Ensure the column names and types match the target table - df = df.with_column_renamed("column1", "tablecol1").unwrap(); - - ctx.sql( - "create external table - test(tablecol1 varchar) - stored as parquet - location './datafusion-examples/test_table/'", - ) - .await? - .collect() - .await?; - - // This is equivalent to INSERT INTO test VALUES ('a'), ('b'), ('c'). - // The behavior of write_table depends on the TableProvider's implementation - // of the insert_into method. 
- df.clone() - .write_table("test", DataFrameWriteOptions::new()) - .await?; - - df.clone() - .write_parquet( - "./datafusion-examples/test_parquet/", - DataFrameWriteOptions::new(), - None, - ) - .await?; - - df.clone() - .write_csv( - "./datafusion-examples/test_csv/", - // DataFrameWriteOptions contains options which control how data is written - // such as compression codec - DataFrameWriteOptions::new(), - Some(CsvOptions::default().with_compression(CompressionTypeVariant::GZIP)), - ) - .await?; - - df.clone() - .write_json( - "./datafusion-examples/test_json/", - DataFrameWriteOptions::new(), - None, - ) - .await?; - - Ok(()) -} diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs deleted file mode 100644 index 3e3d0c1b5a84b..0000000000000 --- a/datafusion-examples/examples/dataframe_subquery.rs +++ /dev/null @@ -1,118 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow_schema::DataType; -use std::sync::Arc; - -use datafusion::error::Result; -use datafusion::functions_aggregate::average::avg; -use datafusion::functions_aggregate::min_max::max; -use datafusion::prelude::*; -use datafusion::test_util::arrow_test_data; -use datafusion_common::ScalarValue; - -/// This example demonstrates how to use the DataFrame API to create a subquery. -#[tokio::main] -async fn main() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_test_data("t1", &ctx).await?; - register_aggregate_test_data("t2", &ctx).await?; - - where_scalar_subquery(&ctx).await?; - - where_in_subquery(&ctx).await?; - - where_exist_subquery(&ctx).await?; - - Ok(()) -} - -//select c1,c2 from t1 where (select avg(t2.c2) from t2 where t1.c1 = t2.c1)>0 limit 3; -async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { - ctx.table("t1") - .await? - .filter( - scalar_subquery(Arc::new( - ctx.table("t2") - .await? - .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? - .aggregate(vec![], vec![avg(col("t2.c2"))])? - .select(vec![avg(col("t2.c2"))])? - .into_unoptimized_plan(), - )) - .gt(lit(0u8)), - )? - .select(vec![col("t1.c1"), col("t1.c2")])? - .limit(0, Some(3))? - .show() - .await?; - Ok(()) -} - -//SELECT t1.c1, t1.c2 FROM t1 WHERE t1.c2 in (select max(t2.c2) from t2 where t2.c1 > 0 ) limit 3; -async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { - ctx.table("t1") - .await? - .filter(in_subquery( - col("t1.c2"), - Arc::new( - ctx.table("t2") - .await? - .filter(col("t2.c1").gt(lit(ScalarValue::UInt8(Some(0)))))? - .aggregate(vec![], vec![max(col("t2.c2"))])? - .select(vec![max(col("t2.c2"))])? - .into_unoptimized_plan(), - ), - ))? - .select(vec![col("t1.c1"), col("t1.c2")])? - .limit(0, Some(3))? 
- .show() - .await?; - Ok(()) -} - -//SELECT t1.c1, t1.c2 FROM t1 WHERE EXISTS (select t2.c2 from t2 where t1.c1 = t2.c1) limit 3; -async fn where_exist_subquery(ctx: &SessionContext) -> Result<()> { - ctx.table("t1") - .await? - .filter(exists(Arc::new( - ctx.table("t2") - .await? - .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? - .select(vec![col("t2.c2")])? - .into_unoptimized_plan(), - )))? - .select(vec![col("t1.c1"), col("t1.c2")])? - .limit(0, Some(3))? - .show() - .await?; - Ok(()) -} - -pub async fn register_aggregate_test_data( - name: &str, - ctx: &SessionContext, -) -> Result<()> { - let testdata = arrow_test_data(); - ctx.register_csv( - name, - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::default(), - ) - .await?; - Ok(()) -} diff --git a/datafusion-examples/examples/deserialize_to_struct.rs b/datafusion-examples/examples/deserialize_to_struct.rs index 985cab703a5cb..5ac3ee6187d11 100644 --- a/datafusion-examples/examples/deserialize_to_struct.rs +++ b/datafusion-examples/examples/deserialize_to_struct.rs @@ -15,62 +15,136 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::AsArray; +use arrow::array::{AsArray, PrimitiveArray}; use arrow::datatypes::{Float64Type, Int32Type}; use datafusion::error::Result; use datafusion::prelude::*; +use datafusion_common::assert_batches_eq; use futures::StreamExt; -/// This example shows that it is possible to convert query results into Rust structs . +/// This example shows how to convert query results into Rust structs by using +/// the Arrow APIs to convert the results into Rust native types. +/// +/// This is a bit tricky initially as the results are returned as columns stored +/// as [ArrayRef] +/// +/// [ArrayRef]: arrow::array::ArrayRef #[tokio::main] async fn main() -> Result<()> { - let data_list = Data::new().await?; - println!("{data_list:#?}"); - Ok(()) -} + // Run a query that returns two columns of data + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + ctx.register_parquet( + "alltypes_plain", + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + let df = ctx + .sql("SELECT int_col, double_col FROM alltypes_plain") + .await?; -#[derive(Debug)] -struct Data { - #[allow(dead_code)] - int_col: i32, - #[allow(dead_code)] - double_col: f64, -} + // print out the results showing we have an int32 and a float64 column + let results = df.clone().collect().await?; + assert_batches_eq!( + [ + "+---------+------------+", + "| int_col | double_col |", + "+---------+------------+", + "| 0 | 0.0 |", + "| 1 | 10.1 |", + "| 0 | 0.0 |", + "| 1 | 10.1 |", + "| 0 | 0.0 |", + "| 1 | 10.1 |", + "| 0 | 0.0 |", + "| 1 | 10.1 |", + "+---------+------------+", + ], + &results + ); -impl Data { - pub async fn new() -> Result> { - // this group is almost the same as the one you find it in parquet_sql.rs - let ctx = SessionContext::new(); + // We will now convert the query results into a Rust struct + let mut stream = df.execute_stream().await?; + let mut list = vec![]; - let testdata = datafusion::test_util::parquet_test_data(); + // DataFusion produces data in chunks called `RecordBatch`es which are + // typically 8000 rows each. This loop processes each `RecordBatch` as it is + // produced by the query plan and adds it to the list + while let Some(b) = stream.next().await.transpose()? { + // Each `RecordBatch` has one or more columns. 
Each column is stored as + // an `ArrayRef`. To interact with data using Rust native types we need to + // convert these `ArrayRef`s into concrete array types using APIs from + // the arrow crate. - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await?; + // In this case, we know that each batch has two columns of the Arrow + // types Int32 and Float64, so first we cast the two columns to the + // appropriate Arrow PrimitiveArray (this is a fast / zero-copy cast).: + let int_col: &PrimitiveArray = b.column(0).as_primitive(); + let float_col: &PrimitiveArray = b.column(1).as_primitive(); - let df = ctx - .sql("SELECT int_col, double_col FROM alltypes_plain") - .await?; + // With PrimitiveArrays, we can access to the values as native Rust + // types i32 and f64, and forming the desired `Data` structs + for (i, f) in int_col.values().iter().zip(float_col.values()) { + list.push(Data { + int_col: *i, + double_col: *f, + }) + } + } - df.clone().show().await?; + // Finally, we have the results in the list of Rust structs + let res = format!("{list:#?}"); + assert_eq!( + res, + r#"[ + Data { + int_col: 0, + double_col: 0.0, + }, + Data { + int_col: 1, + double_col: 10.1, + }, + Data { + int_col: 0, + double_col: 0.0, + }, + Data { + int_col: 1, + double_col: 10.1, + }, + Data { + int_col: 0, + double_col: 0.0, + }, + Data { + int_col: 1, + double_col: 10.1, + }, + Data { + int_col: 0, + double_col: 0.0, + }, + Data { + int_col: 1, + double_col: 10.1, + }, +]"# + ); - let mut stream = df.execute_stream().await?; - let mut list = vec![]; - while let Some(b) = stream.next().await.transpose()? { - let int_col = b.column(0).as_primitive::(); - let float_col = b.column(1).as_primitive::(); + // Use the fields in the struct to avoid clippy complaints + let int_sum = list.iter().fold(0, |acc, x| acc + x.int_col); + let double_sum = list.iter().fold(0.0, |acc, x| acc + x.double_col); + assert_eq!(int_sum, 4); + assert_eq!(double_sum, 40.4); - for (i, f) in int_col.values().iter().zip(float_col.values()) { - list.push(Data { - int_col: *i, - double_col: *f, - }) - } - } + Ok(()) +} - Ok(list) - } +/// This is target struct where we want the query results. 
+#[derive(Debug)] +struct Data { + int_col: i32, + double_col: f64, } diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 85a79a1a56048..2f9b5697c243d 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -18,7 +18,7 @@ use std::collections::HashMap; use std::sync::Arc; -use arrow::array::{BooleanArray, Int32Array}; +use arrow::array::{BooleanArray, Int32Array, Int8Array}; use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; @@ -28,12 +28,14 @@ use datafusion::functions_aggregate::first_last::first_value_udaf; use datafusion::optimizer::simplify_expressions::ExprSimplifier; use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries}; use datafusion::prelude::*; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator}; +use datafusion_optimizer::analyzer::type_coercion::TypeCoercionRewriter; /// This example demonstrates the DataFusion [`Expr`] API. /// @@ -51,6 +53,7 @@ use datafusion_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator}; /// 4. Simplify expressions: [`simplify_demo`] /// 5. Analyze predicates for boundary ranges: [`range_analysis_demo`] /// 6. Get the types of the expressions: [`expression_type_demo`] +/// 7. Apply type coercion to expressions: [`type_coercion_demo`] #[tokio::main] async fn main() -> Result<()> { // The easiest way to do create expressions is to use the @@ -80,6 +83,9 @@ async fn main() -> Result<()> { // See how to determine the data types of expressions expression_type_demo()?; + // See how to type coerce expressions. + type_coercion_demo()?; + Ok(()) } @@ -316,3 +322,103 @@ fn expression_type_demo() -> Result<()> { Ok(()) } + +/// This function demonstrates how to apply type coercion to expressions, such as binary expressions. +/// +/// In most cases, manual type coercion is not required since DataFusion handles it implicitly. +/// However, certain projects may construct `ExecutionPlan`s directly from DataFusion logical expressions, +/// bypassing the construction of DataFusion logical plans. +/// Since constructing `ExecutionPlan`s from logical expressions does not automatically apply type coercion, +/// you may need to handle type coercion manually in these cases. +/// +/// The codes in this function shows various ways to perform type coercion on expressions: +/// 1. Using `SessionContext::create_physical_expr` +/// 2. Using `ExprSimplifier::coerce` +/// 3. Using `TreeNodeRewriter::rewrite` based on `TypeCoercionRewriter` +/// 4. Using `TreeNode::transform` +/// +/// Note, this list may not be complete and there may be other methods to apply type coercion to expressions. +fn type_coercion_demo() -> Result<()> { + // Creates a record batch for demo. + let df_schema = DFSchema::from_unqualified_fields( + vec![Field::new("a", DataType::Int8, false)].into(), + HashMap::new(), + )?; + let i8_array = Int8Array::from_iter_values(vec![0, 1, 2]); + let batch = RecordBatch::try_new( + Arc::new(df_schema.as_arrow().to_owned()), + vec![Arc::new(i8_array) as _], + )?; + + // Constructs a binary expression for demo. 
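Before the manual options that follow: when an expression goes through the normal SQL or DataFrame planning path, the type-coercion analyzer runs automatically, so none of these steps are needed there. A minimal sketch of that implicit path, assuming a table `t` with an Int8 column `a` has been registered (both names are illustrative):

use datafusion::error::Result;
use datafusion::prelude::*;

async fn implicit_coercion(ctx: &SessionContext) -> Result<()> {
    // The analyzer coerces the Int8 column / Int32 literal comparison for us,
    // so no manual rewrite of the expression is required on this path.
    ctx.sql("SELECT a > 1 FROM t").await?.show().await?;
    Ok(())
}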
+ // By default, the literal `1` is translated into the Int32 type and cannot be directly compared with the Int8 type. + let expr = col("a").gt(lit(1)); + + // Evaluation with an expression that has not been type coerced cannot succeed. + let props = ExecutionProps::default(); + let physical_expr = + datafusion_physical_expr::create_physical_expr(&expr, &df_schema, &props)?; + let e = physical_expr.evaluate(&batch).unwrap_err(); + assert!(e + .find_root() + .to_string() + .contains("Invalid comparison operation: Int8 > Int32")); + + // 1. Type coercion with `SessionContext::create_physical_expr` which implicitly applies type coercion before constructing the physical expr. + let physical_expr = + SessionContext::new().create_physical_expr(expr.clone(), &df_schema)?; + assert!(physical_expr.evaluate(&batch).is_ok()); + + // 2. Type coercion with `ExprSimplifier::coerce`. + let context = SimplifyContext::new(&props).with_schema(Arc::new(df_schema.clone())); + let simplifier = ExprSimplifier::new(context); + let coerced_expr = simplifier.coerce(expr.clone(), &df_schema)?; + let physical_expr = datafusion_physical_expr::create_physical_expr( + &coerced_expr, + &df_schema, + &props, + )?; + assert!(physical_expr.evaluate(&batch).is_ok()); + + // 3. Type coercion with `TypeCoercionRewriter`. + let coerced_expr = expr + .clone() + .rewrite(&mut TypeCoercionRewriter::new(&df_schema))? + .data; + let physical_expr = datafusion_physical_expr::create_physical_expr( + &coerced_expr, + &df_schema, + &props, + )?; + assert!(physical_expr.evaluate(&batch).is_ok()); + + // 4. Apply explicit type coercion by manually rewriting the expression + let coerced_expr = expr + .transform(|e| { + // Only type coerces binary expressions. + let Expr::BinaryExpr(e) = e else { + return Ok(Transformed::no(e)); + }; + if let Expr::Column(ref col_expr) = *e.left { + let field = df_schema.field_with_name(None, col_expr.name())?; + let cast_to_type = field.data_type(); + let coerced_right = e.right.cast_to(cast_to_type, &df_schema)?; + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( + e.left, + e.op, + Box::new(coerced_right), + )))) + } else { + Ok(Transformed::no(Expr::BinaryExpr(e))) + } + })? + .data; + let physical_expr = datafusion_physical_expr::create_physical_expr( + &coerced_expr, + &df_schema, + &props, + )?; + assert!(physical_expr.evaluate(&batch).is_ok()); + + Ok(()) +} diff --git a/datafusion-examples/examples/ffi/README.md b/datafusion-examples/examples/ffi/README.md new file mode 100644 index 0000000000000..f29e0012f3180 --- /dev/null +++ b/datafusion-examples/examples/ffi/README.md @@ -0,0 +1,48 @@ + + +# Example FFI Usage + +The purpose of these crates is to provide an example of how one can use the +DataFusion Foreign Function Interface (FFI). See [API Docs] for detailed +usage. + +This example is broken into three crates. + +- `ffi_module_interface` is a common library to be shared by both the module + to be loaded and the program that will load it. It defines how the module + is to be structured. +- `ffi_example_table_provider` creates a library to exposes the module. +- `ffi_module_loader` is an example program that loads the module, gets data + from it, and displays this data to the user. + +## Building and running + +In order for the program to run successfully, the module to be loaded must be +built first. This example expects both the module and the program to be +built using the same build mode (debug or release). 
+ +```shell +cd ffi_example_table_provider +cargo build +cd ../ffi_module_loader +cargo run +``` + +[api docs]: http://docs.rs/datafusion-ffi/latest diff --git a/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml b/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml new file mode 100644 index 0000000000000..52efdb7461abe --- /dev/null +++ b/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "ffi_example_table_provider" +version = "0.1.0" +edition = { workspace = true } +publish = false + +[dependencies] +abi_stable = "0.11.3" +arrow = { workspace = true } +arrow-array = { workspace = true } +arrow-schema = { workspace = true } +datafusion = { workspace = true } +datafusion-ffi = { workspace = true } +ffi_module_interface = { path = "../ffi_module_interface" } + +[lib] +name = "ffi_example_table_provider" +crate-type = ["cdylib", 'rlib'] diff --git a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs new file mode 100644 index 0000000000000..c7eea8a8070b1 --- /dev/null +++ b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use abi_stable::{export_root_module, prefix_type::PrefixTypeTrait}; +use arrow_array::RecordBatch; +use datafusion::{ + arrow::datatypes::{DataType, Field, Schema}, + common::record_batch, + datasource::MemTable, +}; +use datafusion_ffi::table_provider::FFI_TableProvider; +use ffi_module_interface::{TableProviderModule, TableProviderModuleRef}; + +fn create_record_batch(start_value: i32, num_values: usize) -> RecordBatch { + let end_value = start_value + num_values as i32; + let a_vals: Vec = (start_value..end_value).collect(); + let b_vals: Vec = a_vals.iter().map(|v| *v as f64).collect(); + + record_batch!(("a", Int32, a_vals), ("b", Float64, b_vals)).unwrap() +} + +/// Here we only wish to create a simple table provider as an example. +/// We create an in-memory table and convert it to it's FFI counterpart. +extern "C" fn construct_simple_table_provider() -> FFI_TableProvider { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, true), + ])); + + // It is useful to create these as multiple record batches + // so that we can demonstrate the FFI stream. + let batches = vec![ + create_record_batch(1, 5), + create_record_batch(6, 1), + create_record_batch(7, 5), + ]; + + let table_provider = MemTable::try_new(schema, vec![batches]).unwrap(); + + FFI_TableProvider::new(Arc::new(table_provider), true) +} + +#[export_root_module] +/// This defines the entry point for using the module. +pub fn get_simple_memory_table() -> TableProviderModuleRef { + TableProviderModule { + create_table: construct_simple_table_provider, + } + .leak_into_prefix() +} diff --git a/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml new file mode 100644 index 0000000000000..612a219324763 --- /dev/null +++ b/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "ffi_module_interface" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +abi_stable = "0.11.3" +datafusion-ffi = { workspace = true } diff --git a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs new file mode 100644 index 0000000000000..88690e9297135 --- /dev/null +++ b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use abi_stable::{ + declare_root_module_statics, + library::{LibraryError, RootModule}, + package_version_strings, + sabi_types::VersionStrings, + StableAbi, +}; +use datafusion_ffi::table_provider::FFI_TableProvider; + +#[repr(C)] +#[derive(StableAbi)] +#[sabi(kind(Prefix(prefix_ref = TableProviderModuleRef)))] +/// This struct defines the module interfaces. It is to be shared by +/// both the module loading program and library that implements the +/// module. It is possible to move this definition into the loading +/// program and reference it in the modules, but this example shows +/// how a user may wish to separate these concerns. +pub struct TableProviderModule { + /// Constructs the table provider + pub create_table: extern "C" fn() -> FFI_TableProvider, +} + +impl RootModule for TableProviderModuleRef { + declare_root_module_statics! {TableProviderModuleRef} + const BASE_NAME: &'static str = "ffi_example_table_provider"; + const NAME: &'static str = "ffi_example_table_provider"; + const VERSION_STRINGS: VersionStrings = package_version_strings!(); + + fn initialization(self) -> Result { + Ok(self) + } +} diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml new file mode 100644 index 0000000000000..028a366aab1c0 --- /dev/null +++ b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "ffi_module_loader" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +abi_stable = "0.11.3" +datafusion = { workspace = true } +datafusion-ffi = { workspace = true } +ffi_module_interface = { path = "../ffi_module_interface" } +tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs new file mode 100644 index 0000000000000..6e376ca866e8f --- /dev/null +++ b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion::{ + error::{DataFusionError, Result}, + prelude::SessionContext, +}; + +use abi_stable::library::{development_utils::compute_library_path, RootModule}; +use datafusion_ffi::table_provider::ForeignTableProvider; +use ffi_module_interface::TableProviderModuleRef; + +#[tokio::main] +async fn main() -> Result<()> { + // Find the location of the library. This is specific to the build environment, + // so you will need to change the approach here based on your use case. + let target: &std::path::Path = "../../../../target/".as_ref(); + let library_path = compute_library_path::(target) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + // Load the module + let table_provider_module = + TableProviderModuleRef::load_from_directory(&library_path) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + // By calling the code below, the table provided will be created within + // the module's code. + let ffi_table_provider = + table_provider_module + .create_table() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_table".to_string(), + ))?(); + + // In order to access the table provider within this executable, we need to + // turn it into a `ForeignTableProvider`. + let foreign_table_provider: ForeignTableProvider = (&ffi_table_provider).into(); + + let ctx = SessionContext::new(); + + // Display the data to show the full cycle works. 
+ ctx.register_table("external_table", Arc::new(foreign_table_provider))?; + let df = ctx.table("external_table").await?; + df.show().await?; + + Ok(()) +} diff --git a/datafusion-examples/examples/flight/flight_server.rs b/datafusion-examples/examples/flight/flight_server.rs index f9d1b8029f04b..cc5f43746ddfb 100644 --- a/datafusion-examples/examples/flight/flight_server.rs +++ b/datafusion-examples/examples/flight/flight_server.rs @@ -105,7 +105,7 @@ impl FlightService for FlightServiceImpl { } // add an initial FlightData message that sends schema - let options = datafusion::arrow::ipc::writer::IpcWriteOptions::default(); + let options = arrow::ipc::writer::IpcWriteOptions::default(); let schema_flight_data = SchemaAsIpc::new(&schema, &options); let mut flights = vec![FlightData::from(schema_flight_data)]; diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs index f57b3bf604048..58ffa060ebaad 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/function_factory.rs @@ -26,7 +26,9 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{exec_err, internal_err, DataFusionError}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{CreateFunction, Expr, ScalarUDF, ScalarUDFImpl, Signature}; +use datafusion_expr::{ + CreateFunction, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, +}; /// This example shows how to utilize [FunctionFactory] to implement simple /// SQL-macro like functions using a `CREATE FUNCTION` statement. The same @@ -34,7 +36,7 @@ use datafusion_expr::{CreateFunction, Expr, ScalarUDF, ScalarUDFImpl, Signature} /// /// Apart from [FunctionFactory], this example covers /// [ScalarUDFImpl::simplify()] which is often used at the same time, to replace -/// a function call with another expression at rutime. +/// a function call with another expression at runtime. /// /// This example is rather simple and does not cover all cases required for a /// real implementation. @@ -121,7 +123,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { &self.name } - fn signature(&self) -> &datafusion_expr::Signature { + fn signature(&self) -> &Signature { &self.signature } @@ -132,9 +134,9 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { Ok(self.return_type.clone()) } - fn invoke( + fn invoke_with_args( &self, - _args: &[datafusion_expr::ColumnarValue], + _args: ScalarFunctionArgs, ) -> Result { // Since this function is always simplified to another expression, it // should never actually be invoked diff --git a/datafusion-examples/examples/json_opener.rs b/datafusion-examples/examples/json_opener.rs deleted file mode 100644 index 7bc431c5c5eef..0000000000000 --- a/datafusion-examples/examples/json_opener.rs +++ /dev/null @@ -1,88 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{sync::Arc, vec}; - -use arrow_schema::{DataType, Field, Schema}; -use datafusion::{ - assert_batches_eq, - datasource::{ - file_format::file_compression_type::FileCompressionType, - listing::PartitionedFile, - object_store::ObjectStoreUrl, - physical_plan::{FileScanConfig, FileStream, JsonOpener}, - }, - error::Result, - physical_plan::metrics::ExecutionPlanMetricsSet, -}; - -use futures::StreamExt; -use object_store::ObjectStore; - -/// This example demonstrates a scanning against an Arrow data source (JSON) and -/// fetching results -#[tokio::main] -async fn main() -> Result<()> { - let object_store = object_store::memory::InMemory::new(); - let path = object_store::path::Path::from("demo.json"); - let data = bytes::Bytes::from( - r#"{"num":5,"str":"test"} - {"num":2,"str":"hello"} - {"num":4,"str":"foo"}"#, - ); - object_store.put(&path, data.into()).await.unwrap(); - - let schema = Arc::new(Schema::new(vec![ - Field::new("num", DataType::Int64, false), - Field::new("str", DataType::Utf8, false), - ])); - - let projected = Arc::new(schema.clone().project(&[1, 0])?); - - let opener = JsonOpener::new( - 8192, - projected, - FileCompressionType::UNCOMPRESSED, - Arc::new(object_store), - ); - - let scan_config = - FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema.clone()) - .with_projection(Some(vec![1, 0])) - .with_limit(Some(5)) - .with_file(PartitionedFile::new(path.to_string(), 10)); - - let result = - FileStream::new(&scan_config, 0, opener, &ExecutionPlanMetricsSet::new()) - .unwrap() - .map(|b| b.unwrap()) - .collect::>() - .await; - assert_batches_eq!( - &[ - "+-------+-----+", - "| str | num |", - "+-------+-----+", - "| test | 5 |", - "| hello | 2 |", - "| foo | 4 |", - "+-------+-----+", - ], - &result - ); - Ok(()) -} diff --git a/datafusion-examples/examples/memtable.rs b/datafusion-examples/examples/memtable.rs deleted file mode 100644 index 5cce578039e74..0000000000000 --- a/datafusion-examples/examples/memtable.rs +++ /dev/null @@ -1,74 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
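The `function_factory.rs` hunk above moves its UDF from `ScalarUDFImpl::invoke` to `invoke_with_args`, and the `optimizer_rule.rs` hunk below switches to `invoke_batch`. For readers migrating their own UDFs, here is a minimal sketch of the `invoke_with_args` form. It assumes only the `ScalarFunctionArgs`, `Signature`, and `ColumnarValue` APIs already used elsewhere in this diff; the `AlwaysTrue` type is hypothetical and, like the `MyEq` example below, just returns a constant:

```rust
use std::any::Any;

use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};

/// Hypothetical UDF that always returns `true`, written against the newer
/// `invoke_with_args` entry point instead of the deprecated `invoke`.
#[derive(Debug)]
struct AlwaysTrue {
    signature: Signature,
}

impl AlwaysTrue {
    fn new() -> Self {
        Self {
            // accept exactly one argument of any type
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for AlwaysTrue {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "always_true"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    // Arguments now arrive bundled in `ScalarFunctionArgs` rather than being
    // passed as a bare `&[ColumnarValue]` slice.
    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        Ok(ColumnarValue::from(ScalarValue::from(true)))
    }
}
```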
- -use datafusion::arrow::array::{UInt64Array, UInt8Array}; -use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::datasource::MemTable; -use datafusion::error::Result; -use datafusion::prelude::SessionContext; -use std::sync::Arc; -use std::time::Duration; -use tokio::time::timeout; - -/// This example demonstrates executing a simple query against a Memtable -#[tokio::main] -async fn main() -> Result<()> { - let mem_table = create_memtable()?; - - // create local execution context - let ctx = SessionContext::new(); - - // Register the in-memory table containing the data - ctx.register_table("users", Arc::new(mem_table))?; - - let dataframe = ctx.sql("SELECT * FROM users;").await?; - - timeout(Duration::from_secs(10), async move { - let result = dataframe.collect().await.unwrap(); - let record_batch = result.first().unwrap(); - - assert_eq!(1, record_batch.column(0).len()); - dbg!(record_batch.columns()); - }) - .await - .unwrap(); - - Ok(()) -} - -fn create_memtable() -> Result { - MemTable::try_new(get_schema(), vec![vec![create_record_batch()?]]) -} - -fn create_record_batch() -> Result { - let id_array = UInt8Array::from(vec![1]); - let account_array = UInt64Array::from(vec![9000]); - - Ok(RecordBatch::try_new( - get_schema(), - vec![Arc::new(id_array), Arc::new(account_array)], - ) - .unwrap()) -} - -fn get_schema() -> SchemaRef { - SchemaRef::new(Schema::new(vec![ - Field::new("id", DataType::UInt8, false), - Field::new("bank_account", DataType::UInt64, true), - ])) -} diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index 5f18bfe244449..9fd8b0133481a 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -146,7 +146,7 @@ impl MyOptimizerRule { // Closure called for each sub tree match expr { Expr::BinaryExpr(binary_expr) if is_binary_eq(&binary_expr) => { - // destruture the expression + // destructure the expression let BinaryExpr { left, op: _, right } = binary_expr; // rewrite to `my_eq(left, right)` let udf = ScalarUDF::new_from_impl(MyEq::new()); @@ -205,7 +205,11 @@ impl ScalarUDFImpl for MyEq { Ok(DataType::Boolean) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { // this example simply returns "true" which is not what a real // implementation would do. Ok(ColumnarValue::from(ScalarValue::from(true))) diff --git a/datafusion-examples/examples/parquet_sql_multiple_files.rs b/datafusion-examples/examples/parquet_sql_multiple_files.rs deleted file mode 100644 index b0d3922a32789..0000000000000 --- a/datafusion-examples/examples/parquet_sql_multiple_files.rs +++ /dev/null @@ -1,112 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::path::Path; -use std::sync::Arc; - -use datafusion::datasource::file_format::parquet::ParquetFormat; -use datafusion::datasource::listing::ListingOptions; -use datafusion::prelude::*; - -use object_store::local::LocalFileSystem; - -/// This example demonstrates executing a simple query against an Arrow data source (a directory -/// with multiple Parquet files) and fetching results. The query is run twice, once showing -/// how to used `register_listing_table` with an absolute path, and once registering an -/// ObjectStore to use a relative path. -#[tokio::main] -async fn main() -> Result<(), Box> { - // create local execution context - let ctx = SessionContext::new(); - - let test_data = datafusion::test_util::parquet_test_data(); - - // Configure listing options - let file_format = ParquetFormat::default().with_enable_pruning(true); - let listing_options = ListingOptions::new(Arc::new(file_format)) - // This is a workaround for this example since `test_data` contains - // many different parquet different files, - // in practice use FileType::PARQUET.get_ext(). - .with_file_extension("alltypes_plain.parquet"); - - // First example were we use an absolute path, which requires no additional setup. - ctx.register_listing_table( - "my_table", - &format!("file://{test_data}/"), - listing_options.clone(), - None, - None, - ) - .await - .unwrap(); - - // execute the query - let df = ctx - .sql( - "SELECT * \ - FROM my_table \ - LIMIT 1", - ) - .await?; - - // print the results - df.show().await?; - - // Second example were we temporarily move into the test data's parent directory and - // simulate a relative path, this requires registering an ObjectStore. 
- let cur_dir = std::env::current_dir()?; - - let test_data_path = Path::new(&test_data); - let test_data_path_parent = test_data_path - .parent() - .ok_or("test_data path needs a parent")?; - - std::env::set_current_dir(test_data_path_parent)?; - - let local_fs = Arc::new(LocalFileSystem::default()); - - let u = url::Url::parse("file://./")?; - ctx.register_object_store(&u, local_fs); - - // Register a listing table - this will use all files in the directory as data sources - // for the query - ctx.register_listing_table( - "relative_table", - "./data", - listing_options.clone(), - None, - None, - ) - .await?; - - // execute the query - let df = ctx - .sql( - "SELECT * \ - FROM relative_table \ - LIMIT 1", - ) - .await?; - - // print the results - df.show().await?; - - // Reset the current directory - std::env::set_current_dir(cur_dir)?; - - Ok(()) -} diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index e23e5accae397..d8f0778e19e36 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -121,11 +121,11 @@ async fn query_parquet_demo() -> Result<()> { assert_batches_eq!( &[ - "+------------+----------------------+", - "| double_col | sum(?table?.int_col) |", - "+------------+----------------------+", - "| 10.1 | 4 |", - "+------------+----------------------+", + "+------------+-------------+", + "| double_col | sum_int_col |", + "+------------+-------------+", + "| 10.1 | 4 |", + "+------------+-------------+", ], &result ); diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/plan_to_sql.rs index 8ea7c2951223d..cf1202498416a 100644 --- a/datafusion-examples/examples/plan_to_sql.rs +++ b/datafusion-examples/examples/plan_to_sql.rs @@ -16,11 +16,25 @@ // under the License. use datafusion::error::Result; - +use datafusion::logical_expr::sqlparser::ast::Statement; use datafusion::prelude::*; use datafusion::sql::unparser::expr_to_sql; +use datafusion_common::DFSchemaRef; +use datafusion_expr::{ + Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, + UserDefinedLogicalNodeCore, +}; +use datafusion_sql::unparser::ast::{ + DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, +}; use datafusion_sql::unparser::dialect::CustomDialectBuilder; +use datafusion_sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparser; +use datafusion_sql::unparser::extension_unparser::{ + UnparseToStatementResult, UnparseWithinStatementResult, +}; use datafusion_sql::unparser::{plan_to_sql, Unparser}; +use std::fmt; +use std::sync::Arc; /// This example demonstrates the programmatic construction of SQL strings using /// the DataFusion Expr [`Expr`] and LogicalPlan [`LogicalPlan`] API. @@ -44,6 +58,10 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser}; /// /// 5. [`round_trip_plan_to_sql_demo`]: Create a logical plan from a SQL string, modify it using the /// DataFrames API and convert it back to a sql string. +/// +/// 6. [`unparse_my_logical_plan_as_statement`]: Create a custom logical plan and unparse it as a statement. +/// +/// 7. [`unparse_my_logical_plan_as_subquery`]: Create a custom logical plan and unparse it as a subquery. 
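For context on demo 1 in the list above, the round trip from a programmatic `Expr` to SQL text is a single call to `expr_to_sql`, as imported by this example. A minimal sketch assuming those same imports; the exact textual output (parenthesization, identifier quoting) depends on the default `Unparser` dialect:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion::sql::unparser::expr_to_sql;

fn expr_to_sql_sketch() -> Result<()> {
    // build `a < 5 OR a = 8` with the Expr API, then unparse it back to SQL text
    let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8)));
    let sql = expr_to_sql(&expr)?.to_string();
    println!("{sql}"); // e.g. ((a < 5) OR (a = 8))
    Ok(())
}
```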
#[tokio::main] async fn main() -> Result<()> { @@ -53,6 +71,8 @@ async fn main() -> Result<()> { simple_expr_to_sql_demo_escape_mysql_style()?; simple_plan_to_sql_demo().await?; round_trip_plan_to_sql_demo().await?; + unparse_my_logical_plan_as_statement().await?; + unparse_my_logical_plan_as_subquery().await?; Ok(()) } @@ -65,7 +85,7 @@ fn simple_expr_to_sql_demo() -> Result<()> { Ok(()) } -/// DataFusioon can remove parentheses when converting an expression to SQL. +/// DataFusion can remove parentheses when converting an expression to SQL. /// Note that output is intended for humans, not for other SQL engines, /// as difference in precedence rules can cause expressions to be parsed differently. fn simple_expr_to_pretty_sql_demo() -> Result<()> { @@ -152,3 +172,144 @@ async fn round_trip_plan_to_sql_demo() -> Result<()> { Ok(()) } + +#[derive(Debug, PartialEq, Eq, Hash, PartialOrd)] +struct MyLogicalPlan { + input: LogicalPlan, +} + +impl UserDefinedLogicalNodeCore for MyLogicalPlan { + fn name(&self) -> &str { + "MyLogicalPlan" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "MyLogicalPlan") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + Ok(MyLogicalPlan { + input: inputs.into_iter().next().unwrap(), + }) + } +} + +struct PlanToStatement {} +impl UserDefinedLogicalNodeUnparser for PlanToStatement { + fn unparse_to_statement( + &self, + node: &dyn UserDefinedLogicalNode, + unparser: &Unparser, + ) -> Result { + if let Some(plan) = node.as_any().downcast_ref::() { + let input = unparser.plan_to_sql(&plan.input)?; + Ok(UnparseToStatementResult::Modified(input)) + } else { + Ok(UnparseToStatementResult::Unmodified) + } + } +} + +/// This example demonstrates how to unparse a custom logical plan as a statement. +/// The custom logical plan is a simple extension of the logical plan that reads from a parquet file. +/// It can be unparse as a statement that reads from the same parquet file. +async fn unparse_my_logical_plan_as_statement() -> Result<()> { + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + let inner_plan = ctx + .read_parquet( + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await? + .select_columns(&["id", "int_col", "double_col", "date_string_col"])? + .into_unoptimized_plan(); + + let node = Arc::new(MyLogicalPlan { input: inner_plan }); + + let my_plan = LogicalPlan::Extension(Extension { node }); + let unparser = + Unparser::default().with_extension_unparsers(vec![Arc::new(PlanToStatement {})]); + let sql = unparser.plan_to_sql(&my_plan)?.to_string(); + assert_eq!( + sql, + r#"SELECT "?table?".id, "?table?".int_col, "?table?".double_col, "?table?".date_string_col FROM "?table?""# + ); + Ok(()) +} + +struct PlanToSubquery {} +impl UserDefinedLogicalNodeUnparser for PlanToSubquery { + fn unparse( + &self, + node: &dyn UserDefinedLogicalNode, + unparser: &Unparser, + _query: &mut Option<&mut QueryBuilder>, + _select: &mut Option<&mut SelectBuilder>, + relation: &mut Option<&mut RelationBuilder>, + ) -> Result { + if let Some(plan) = node.as_any().downcast_ref::() { + let Statement::Query(input) = unparser.plan_to_sql(&plan.input)? 
else { + return Ok(UnparseWithinStatementResult::Unmodified); + }; + let mut derived_builder = DerivedRelationBuilder::default(); + derived_builder.subquery(input); + derived_builder.lateral(false); + if let Some(rel) = relation { + rel.derived(derived_builder); + } + } + Ok(UnparseWithinStatementResult::Modified) + } +} + +/// This example demonstrates how to unparse a custom logical plan as a subquery. +/// The custom logical plan is a simple extension of the logical plan that reads from a parquet file. +/// It can be unparse as a subquery that reads from the same parquet file, with some columns projected. +async fn unparse_my_logical_plan_as_subquery() -> Result<()> { + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + let inner_plan = ctx + .read_parquet( + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await? + .select_columns(&["id", "int_col", "double_col", "date_string_col"])? + .into_unoptimized_plan(); + + let node = Arc::new(MyLogicalPlan { input: inner_plan }); + + let my_plan = LogicalPlan::Extension(Extension { node }); + let plan = LogicalPlanBuilder::from(my_plan) + .project(vec![ + col("id").alias("my_id"), + col("int_col").alias("my_int"), + ])? + .build()?; + let unparser = + Unparser::default().with_extension_unparsers(vec![Arc::new(PlanToSubquery {})]); + let sql = unparser.plan_to_sql(&plan)?.to_string(); + assert_eq!( + sql, + "SELECT \"?table?\".id AS my_id, \"?table?\".int_col AS my_int FROM \ + (SELECT \"?table?\".id, \"?table?\".int_col, \"?table?\".double_col, \"?table?\".date_string_col FROM \"?table?\")", + ); + Ok(()) +} diff --git a/datafusion-examples/examples/regexp.rs b/datafusion-examples/examples/regexp.rs index 02e74bae22af7..5419efd2faea2 100644 --- a/datafusion-examples/examples/regexp.rs +++ b/datafusion-examples/examples/regexp.rs @@ -148,7 +148,7 @@ async fn main() -> Result<()> { // invalid flags will result in an error let result = ctx - .sql(r"select regexp_like('\b4(?!000)\d\d\d\b', 4010, 'g')") + .sql(r"select regexp_like('\b4(?!000)\d\d\d\b', '4010', 'g')") .await? .collect() .await; diff --git a/datafusion-examples/examples/remote_catalog.rs b/datafusion-examples/examples/remote_catalog.rs new file mode 100644 index 0000000000000..38629328d71c4 --- /dev/null +++ b/datafusion-examples/examples/remote_catalog.rs @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// This example shows how to implement the DataFusion [`CatalogProvider`] API +/// for catalogs that are remote (require network access) and/or offer only +/// asynchronous APIs such as [Polaris], [Unity], and [Hive]. 
+/// +/// Integrating with this catalogs is a bit more complex than with local +/// catalogs because calls like `ctx.sql("SELECT * FROM db.schm.tbl")` may need +/// to perform remote network requests, but many Catalog APIs are synchronous. +/// See the documentation on [`CatalogProvider`] for more details. +/// +/// [`CatalogProvider`]: datafusion_catalog::CatalogProvider +/// +/// [Polaris]: https://github.com/apache/polaris +/// [Unity]: https://github.com/unitycatalog/unitycatalog +/// [Hive]: https://hive.apache.org/ +use arrow::array::record_batch; +use arrow_schema::{Field, Fields, Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion::catalog::TableProvider; +use datafusion::common::Result; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::{DataFrame, SessionContext}; +use datafusion_catalog::{AsyncSchemaProvider, Session}; +use datafusion_common::{assert_batches_eq, internal_datafusion_err, plan_err}; +use datafusion_expr::{Expr, TableType}; +use futures::TryStreamExt; +use std::any::Any; +use std::sync::Arc; + +#[tokio::main] +async fn main() -> Result<()> { + // As always, we create a session context to interact with DataFusion + let ctx = SessionContext::new(); + + // Make a connection to the remote catalog, asynchronously, and configure it + let remote_catalog_interface = Arc::new(RemoteCatalogInterface::connect().await?); + + // Create an adapter to provide the AsyncSchemaProvider interface to DataFusion + // based on our remote catalog interface + let remote_catalog_adapter = RemoteCatalogDatafusionAdapter(remote_catalog_interface); + + // Here is a query that selects data from a table in the remote catalog. + let sql = "SELECT * from remote_schema.remote_table"; + + // The `SessionContext::sql` interface is async, but it does not + // support asynchronous access to catalogs, so we cannot register our schema provider + // directly and the following query fails to find our table. + let results = ctx.sql(sql).await; + assert_eq!( + results.unwrap_err().to_string(), + "Error during planning: table 'datafusion.remote_schema.remote_table' not found" + ); + + // Instead, to use a remote catalog, we must use lower level APIs on + // SessionState (what `SessionContext::sql` does internally). + let state = ctx.state(); + + // First, parse the SQL (but don't plan it / resolve any table references) + let dialect = state.config().options().sql_parser.dialect.as_str(); + let statement = state.sql_to_statement(sql, dialect)?; + + // Find all `TableReferences` in the parsed queries. These correspond to the + // tables referred to by the query (in this case + // `remote_schema.remote_table`) + let references = state.resolve_table_references(&statement)?; + + // Now we can asynchronously resolve the table references to get a cached catalog + // that we can use for our query + let resolved_catalog = remote_catalog_adapter + .resolve(&references, state.config(), "datafusion", "remote_schema") + .await?; + + // This resolved catalog only makes sense for this query and so we create a clone + // of the session context with the resolved catalog + let query_ctx = ctx.clone(); + + query_ctx + .catalog("datafusion") + .ok_or_else(|| internal_datafusion_err!("default catalog was not installed"))? 
+ .register_schema("remote_schema", resolved_catalog)?; + + // We can now continue planning the query with this new query-specific context that + // contains our cached catalog + let query_state = query_ctx.state(); + + let plan = query_state.statement_to_plan(statement).await?; + let results = DataFrame::new(state, plan).collect().await?; + assert_batches_eq!( + [ + "+----+-------+", + "| id | name |", + "+----+-------+", + "| 1 | alpha |", + "| 2 | beta |", + "| 3 | gamma |", + "+----+-------+", + ], + &results + ); + + Ok(()) +} + +/// This is an example of an API that interacts with a remote catalog. +/// +/// Specifically, its APIs are all `async` and thus can not be used by +/// [`SchemaProvider`] or [`TableProvider`] directly. +#[derive(Debug)] +struct RemoteCatalogInterface {} + +impl RemoteCatalogInterface { + /// Establish a connection to the remote catalog + pub async fn connect() -> Result { + // In a real implementation this method might connect to a remote + // catalog, validate credentials, cache basic information, etc + Ok(Self {}) + } + + /// Fetches information for a specific table + pub async fn table_info(&self, name: &str) -> Result> { + if name != "remote_table" { + return Ok(None); + } + + // In this example, we'll model a remote table with columns "id" and + // "name" + // + // A real remote catalog would make a network call to fetch this + // information from a remote source. + let schema = Schema::new(Fields::from(vec![ + Field::new("id", arrow::datatypes::DataType::Int32, false), + Field::new("name", arrow::datatypes::DataType::Utf8, false), + ])); + Ok(Some(Arc::new(schema))) + } + + /// Fetches data for a table from a remote data source + pub async fn read_data(&self, name: &str) -> Result { + if name != "remote_table" { + return plan_err!("Remote table not found: {}", name); + } + + // In a real remote catalog this call would likely perform network IO to + // open and begin reading from a remote datasource, prefetching + // information, etc. + // + // In this example we are just demonstrating how the API works so simply + // return back some static data as a stream. + let batch = record_batch!( + ("id", Int32, [1, 2, 3]), + ("name", Utf8, ["alpha", "beta", "gamma"]) + ) + .unwrap(); + let schema = batch.schema(); + + let stream = futures::stream::iter([Ok(batch)]); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) + } +} + +/// Implements an async version of the DataFusion SchemaProvider API for tables +/// stored in a remote catalog. +struct RemoteCatalogDatafusionAdapter(Arc); + +#[async_trait] +impl AsyncSchemaProvider for RemoteCatalogDatafusionAdapter { + async fn table(&self, name: &str) -> Result>> { + // Fetch information about the table from the remote catalog + // + // Note that a real remote catalog interface could return more + // information, but at the minimum, DataFusion requires the + // table's schema for planing. 
+ Ok(self.0.table_info(name).await?.map(|schema| { + Arc::new(RemoteTable::new(Arc::clone(&self.0), name, schema)) + as Arc + })) + } +} + +/// Represents the information about a table retrieved from the remote catalog +#[derive(Debug)] +struct RemoteTable { + /// connection to the remote catalog + remote_catalog_interface: Arc, + name: String, + schema: SchemaRef, +} + +impl RemoteTable { + pub fn new( + remote_catalog_interface: Arc, + name: impl Into, + schema: SchemaRef, + ) -> Self { + Self { + remote_catalog_interface, + name: name.into(), + schema, + } + } +} + +/// Implement the DataFusion Catalog API for [`RemoteTable`] +#[async_trait] +impl TableProvider for RemoteTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + // Note that `scan` is called once the plan begin execution, and thus is + // async. When interacting with remote data sources, this is the place + // to begin establishing the remote connections and interacting with the + // remote storage system. + // + // As this example is just modeling the catalog API interface, we buffer + // the results locally in memory for simplicity. + let batches = self + .remote_catalog_interface + .read_data(&self.name) + .await? + .try_collect() + .await?; + Ok(Arc::new(MemoryExec::try_new( + &[batches], + self.schema.clone(), + projection.cloned(), + )?)) + } +} diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 140fc0d3572da..ef97bf9763b0f 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -131,7 +131,7 @@ impl Accumulator for GeometricMean { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index 8cc4309d4d31b..5412e5da78a41 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -21,13 +21,13 @@ use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; use datafusion::catalog::Session; -use datafusion::datasource::function::TableFunctionImpl; use datafusion::datasource::TableProvider; use datafusion::error::Result; use datafusion::execution::context::ExecutionProps; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; +use datafusion_catalog::TableFunctionImpl; use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{Expr, Scalar, TableType}; @@ -144,7 +144,7 @@ impl TableFunctionImpl for LocalCsvTableFunc { let limit = exprs .get(1) .map(|expr| { - // try to simpify the expression, so 1+2 becomes 3, for example + // try to simplify the expression, so 1+2 becomes 3, for example let execution_props = ExecutionProps::new(); let info = SimplifyContext::new(&execution_props); let expr = ExprSimplifier::new(info).simplify(expr.clone())?; @@ -181,8 +181,8 @@ fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec Self { - Self { - signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), - } - } -} - -impl AggregateUDFImpl for BetterAvgUdaf { - 
fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "better_avg" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Float64) - } - - fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { - unimplemented!("should not be invoked") - } - - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { - unimplemented!("should not be invoked") - } - - fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { - true - } - - fn create_groups_accumulator( - &self, - _args: AccumulatorArgs, - ) -> Result> { - unimplemented!("should not get here"); - } - - // we override method, to return new expression which would substitute - // user defined function call - fn simplify(&self) -> Option { - // as an example for this functionality we replace UDF function - // with build-in aggregate function to illustrate the use - let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction, - _: &dyn SimplifyInfo| { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( - avg_udaf(), - // yes it is the same Avg, `BetterAvgUdaf` was just a - // marketing pitch :) - aggregate_function.args, - aggregate_function.distinct, - aggregate_function.filter, - aggregate_function.order_by, - aggregate_function.null_treatment, - ))) - }; - - Some(Box::new(simplify)) - } -} - -// create local session context with an in-memory table -fn create_context() -> Result { - use datafusion::datasource::MemTable; - // define a schema. - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, false), - Field::new("b", DataType::Float32, false), - ])); - - // define data in two partitions - let batch1 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])), - Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])), - ], - )?; - let batch2 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Float32Array::from(vec![16.0])), - Arc::new(Float32Array::from(vec![2.0])), - ], - )?; - - let ctx = SessionContext::new(); - - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). - let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - Ok(ctx) -} - -#[tokio::main] -async fn main() -> Result<()> { - let ctx = create_context()?; - - let better_avg = AggregateUDF::from(BetterAvgUdaf::new()); - ctx.register_udaf(better_avg.clone()); - - let result = ctx - .sql("SELECT better_avg(a) FROM t group by b") - .await? - .collect() - .await?; - - let expected = [ - "+-----------------+", - "| better_avg(t.a) |", - "+-----------------+", - "| 7.5 |", - "+-----------------+", - ]; - - assert_batches_eq!(expected, &result); - - let df = ctx.table("t").await?; - let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?; - - let results = df.collect().await?; - let result = as_float64_array(results[0].column(0))?; - - assert!((result.value(0) - 7.5).abs() < f64::EPSILON); - println!("The average of [2,4,8,16] is {}", result.value(0)); - - Ok(()) -} diff --git a/datafusion-examples/examples/simplify_udwf_expression.rs b/datafusion-examples/examples/simplify_udwf_expression.rs deleted file mode 100644 index 1ff629eef1966..0000000000000 --- a/datafusion-examples/examples/simplify_udwf_expression.rs +++ /dev/null @@ -1,130 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::any::Any; - -use arrow_schema::{DataType, Field}; - -use datafusion::execution::context::SessionContext; -use datafusion::functions_aggregate::average::avg_udaf; -use datafusion::{error::Result, execution::options::CsvReadOptions}; -use datafusion_expr::function::{WindowFunctionSimplification, WindowUDFFieldArgs}; -use datafusion_expr::{ - expr::WindowFunction, simplify::SimplifyInfo, Expr, PartitionEvaluator, Signature, - Volatility, WindowUDF, WindowUDFImpl, -}; - -/// This UDWF will show how to use the WindowUDFImpl::simplify() API -#[derive(Debug, Clone)] -struct SimplifySmoothItUdf { - signature: Signature, -} - -impl SimplifySmoothItUdf { - fn new() -> Self { - Self { - signature: Signature::exact( - // this function will always take one arguments of type f64 - vec![DataType::Float64], - // this function is deterministic and will always return the same - // result for the same input - Volatility::Immutable, - ), - } - } -} -impl WindowUDFImpl for SimplifySmoothItUdf { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "simplify_smooth_it" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn partition_evaluator(&self) -> Result> { - todo!() - } - - /// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`. - fn simplify(&self) -> Option { - let simplify = |window_function: datafusion_expr::expr::WindowFunction, - _: &dyn SimplifyInfo| { - Ok(Expr::WindowFunction(WindowFunction { - fun: datafusion_expr::WindowFunctionDefinition::AggregateUDF(avg_udaf()), - args: window_function.args, - partition_by: window_function.partition_by, - order_by: window_function.order_by, - window_frame: window_function.window_frame, - null_treatment: window_function.null_treatment, - })) - }; - - Some(Box::new(simplify)) - } - - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, true)) - } -} - -// create local execution context with `cars.csv` registered as a table named `cars` -async fn create_context() -> Result { - // declare a new context. In spark API, this corresponds to a new spark SQL session - let ctx = SessionContext::new(); - - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
- println!("pwd: {}", std::env::current_dir().unwrap().display()); - let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); - let read_options = CsvReadOptions::default().has_header(true); - - ctx.register_csv("cars", &csv_path, read_options).await?; - Ok(ctx) -} - -#[tokio::main] -async fn main() -> Result<()> { - let ctx = create_context().await?; - let simplify_smooth_it = WindowUDF::from(SimplifySmoothItUdf::new()); - ctx.register_udwf(simplify_smooth_it.clone()); - - // Use SQL to run the new window function - let df = ctx.sql("SELECT * from cars").await?; - // print the results - df.show().await?; - - let df = ctx - .sql( - "SELECT \ - car, \ - speed, \ - simplify_smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ - time \ - from cars \ - ORDER BY \ - car", - ) - .await?; - // print the results - df.show().await?; - - Ok(()) -} diff --git a/datafusion-examples/examples/sql_analysis.rs b/datafusion-examples/examples/sql_analysis.rs index 9a2aabaa79c2e..2158b8e4b016e 100644 --- a/datafusion-examples/examples/sql_analysis.rs +++ b/datafusion-examples/examples/sql_analysis.rs @@ -39,7 +39,7 @@ fn total_join_count(plan: &LogicalPlan) -> usize { // We can use the TreeNode API to walk over a LogicalPlan. plan.apply(|node| { // if we encounter a join we update the running count - if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) { + if matches!(node, LogicalPlan::Join(_)) { total += 1; } Ok(TreeNodeRecursion::Continue) @@ -89,7 +89,7 @@ fn count_trees(plan: &LogicalPlan) -> (usize, Vec) { while let Some(node) = to_visit.pop() { // if we encounter a join, we know were at the root of the tree // count this tree and recurse on it's inputs - if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) { + if matches!(node, LogicalPlan::Join(_)) { let (group_count, inputs) = count_tree(node); total += group_count; groups.push(group_count); @@ -151,7 +151,7 @@ fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) { } // any join we count - if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) { + if matches!(node, LogicalPlan::Join(_)) { total += 1; Ok(TreeNodeRecursion::Continue) } else { diff --git a/datafusion-examples/examples/sql_query.rs b/datafusion-examples/examples/sql_query.rs new file mode 100644 index 0000000000000..a6e7fe91dda52 --- /dev/null +++ b/datafusion-examples/examples/sql_query.rs @@ -0,0 +1,211 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion::arrow::array::{UInt64Array, UInt8Array}; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::ListingOptions; +use datafusion::datasource::MemTable; +use datafusion::error::{DataFusionError, Result}; +use datafusion::prelude::SessionContext; +use datafusion_common::{assert_batches_eq, exec_datafusion_err}; +use object_store::local::LocalFileSystem; +use std::path::Path; +use std::sync::Arc; + +/// Examples of various ways to execute queries using SQL +/// +/// [`query_memtable`]: a simple query against a [`MemTable`] +/// [`query_parquet`]: a simple query against a directory with multiple Parquet files +/// +#[tokio::main] +async fn main() -> Result<()> { + query_memtable().await?; + query_parquet().await?; + Ok(()) +} + +/// Run a simple query against a [`MemTable`] +pub async fn query_memtable() -> Result<()> { + let mem_table = create_memtable()?; + + // create local execution context + let ctx = SessionContext::new(); + + // Register the in-memory table containing the data + ctx.register_table("users", Arc::new(mem_table))?; + + // running a SQL query results in a "DataFrame", which can be used + // to execute the query and collect the results + let dataframe = ctx.sql("SELECT * FROM users;").await?; + + // Calling 'show' on the dataframe will execute the query and + // print the results + dataframe.clone().show().await?; + + // calling 'collect' on the dataframe will execute the query and + // buffer the results into a vector of RecordBatch. There are other + // APIs on DataFrame for incrementally generating results (e.g. streaming) + let result = dataframe.collect().await?; + + // Use the assert_batches_eq macro to compare the results + assert_batches_eq!( + [ + "+----+--------------+", + "| id | bank_account |", + "+----+--------------+", + "| 1 | 9000 |", + "+----+--------------+", + ], + &result + ); + + Ok(()) +} + +fn create_memtable() -> Result { + MemTable::try_new(get_schema(), vec![vec![create_record_batch()?]]) +} + +fn create_record_batch() -> Result { + let id_array = UInt8Array::from(vec![1]); + let account_array = UInt64Array::from(vec![9000]); + + Ok(RecordBatch::try_new( + get_schema(), + vec![Arc::new(id_array), Arc::new(account_array)], + ) + .unwrap()) +} + +fn get_schema() -> SchemaRef { + SchemaRef::new(Schema::new(vec![ + Field::new("id", DataType::UInt8, false), + Field::new("bank_account", DataType::UInt64, true), + ])) +} + +/// The simplest way to query parquet files is to use the +/// [`SessionContext::read_parquet`] API +/// +/// For more control, you can use the lower level [`ListingOptions`] and +/// [`ListingTable`] APIS +/// +/// This example shows how to use relative and absolute paths. +/// +/// [`ListingTable`]: datafusion::datasource::listing::ListingTable +async fn query_parquet() -> Result<()> { + // create local execution context + let ctx = SessionContext::new(); + + let test_data = datafusion::test_util::parquet_test_data(); + + // Configure listing options + let file_format = ParquetFormat::default().with_enable_pruning(true); + let listing_options = ListingOptions::new(Arc::new(file_format)) + // This is a workaround for this example since `test_data` contains + // many different parquet different files, + // in practice use FileType::PARQUET.get_ext(). 
+ .with_file_extension("alltypes_plain.parquet"); + + // First example were we use an absolute path, which requires no additional setup. + ctx.register_listing_table( + "my_table", + &format!("file://{test_data}/"), + listing_options.clone(), + None, + None, + ) + .await + .unwrap(); + + // execute the query + let df = ctx + .sql( + "SELECT * \ + FROM my_table \ + LIMIT 1", + ) + .await?; + + // print the results + let results = df.collect().await?; + assert_batches_eq!( + [ + "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + "| id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col |", + "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + "| 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 |", + "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + ], + &results); + + // Second example were we temporarily move into the test data's parent directory and + // simulate a relative path, this requires registering an ObjectStore. + let cur_dir = std::env::current_dir()?; + + let test_data_path = Path::new(&test_data); + let test_data_path_parent = test_data_path + .parent() + .ok_or(exec_datafusion_err!("test_data path needs a parent"))?; + + std::env::set_current_dir(test_data_path_parent)?; + + let local_fs = Arc::new(LocalFileSystem::default()); + + let u = url::Url::parse("file://./") + .map_err(|e| DataFusionError::External(Box::new(e)))?; + ctx.register_object_store(&u, local_fs); + + // Register a listing table - this will use all files in the directory as data sources + // for the query + ctx.register_listing_table( + "relative_table", + "./data", + listing_options.clone(), + None, + None, + ) + .await?; + + // execute the query + let df = ctx + .sql( + "SELECT * \ + FROM relative_table \ + LIMIT 1", + ) + .await?; + + // print the results + let results = df.collect().await?; + assert_batches_eq!( + [ + "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + "| id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col |", + "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + "| 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 |", + "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + ], + &results); + + // Reset the current directory + std::env::set_current_dir(cur_dir)?; + + Ok(()) +} diff --git a/datafusion-examples/examples/to_date.rs b/datafusion-examples/examples/to_date.rs deleted file mode 100644 index 99ee555ffc17e..0000000000000 --- a/datafusion-examples/examples/to_date.rs +++ /dev/null @@ -1,60 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; - -use datafusion::arrow::array::StringArray; -use datafusion::arrow::datatypes::{DataType, Field, Schema}; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::error::Result; -use datafusion::prelude::*; - -/// This example demonstrates how to use the to_date series -/// of functions in the DataFrame API as well as via sql. -#[tokio::main] -async fn main() -> Result<()> { - // define a schema. - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)])); - - // define data. - let batch = RecordBatch::try_new( - schema, - vec![Arc::new(StringArray::from(vec![ - "2020-09-08T13:42:29Z", - "2020-09-08T13:42:29.190855-05:00", - "2020-08-09 12:13:29", - "2020-01-02", - ]))], - )?; - - // declare a new context. In spark API, this corresponds to a new spark SQLsession - let ctx = SessionContext::new(); - - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). - ctx.register_batch("t", batch)?; - let df = ctx.table("t").await?; - - // use to_date function to convert col 'a' to timestamp type using the default parsing - let df = df.with_column("a", to_date(vec![col("a")]))?; - - let df = df.select_columns(&["a"])?; - - // print the results - df.show().await?; - - Ok(()) -} diff --git a/datafusion-testing b/datafusion-testing new file mode 160000 index 0000000000000..36283d195c728 --- /dev/null +++ b/datafusion-testing @@ -0,0 +1 @@ +Subproject commit 36283d195c728f26b16b517ba999fd62509b6649 diff --git a/datafusion/catalog/Cargo.toml b/datafusion/catalog/Cargo.toml index f9801352087d8..32a87cc7611cf 100644 --- a/datafusion/catalog/Cargo.toml +++ b/datafusion/catalog/Cargo.toml @@ -36,5 +36,8 @@ datafusion-expr = { workspace = true } datafusion-physical-plan = { workspace = true } parking_lot = { workspace = true } +[dev-dependencies] +tokio = { workspace = true } + [lints] workspace = true diff --git a/datafusion/catalog/LICENSE.txt b/datafusion/catalog/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/catalog/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/catalog/NOTICE.txt b/datafusion/catalog/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/catalog/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/catalog/README.md b/datafusion/catalog/README.md new file mode 100644 index 0000000000000..5b201e736fdc4 --- /dev/null +++ b/datafusion/catalog/README.md @@ -0,0 +1,26 @@ + + +# DataFusion Catalog + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion that provides catalog management functionality, including catalogs, schemas, and tables. 
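A short usage sketch of the catalog hierarchy this crate defines (`CatalogProvider` → `SchemaProvider` → `TableProvider`), written against the main `datafusion` crate's default in-memory catalog; the table name `t` and its single column are illustrative, and the default catalog/schema names (`datafusion`/`public`) are assumed:

```rust
use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Register a small in-memory table; it lands in the default
    // "datafusion" catalog under the "public" schema.
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    let _ = ctx.register_table("t", Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))?;

    // Walk the hierarchy: CatalogProvider -> SchemaProvider -> TableProvider
    let catalog = ctx.catalog("datafusion").expect("default catalog");
    let schema_provider = catalog.schema("public").expect("default schema");
    let table = schema_provider.table("t").await?.expect("table was registered");
    println!("resolved table with schema: {:?}", table.schema());
    Ok(())
}
```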
+ +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/catalog/src/async.rs b/datafusion/catalog/src/async.rs new file mode 100644 index 0000000000000..a244261b91e27 --- /dev/null +++ b/datafusion/catalog/src/async.rs @@ -0,0 +1,753 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion_common::{error::Result, not_impl_err, HashMap, TableReference}; +use datafusion_execution::config::SessionConfig; + +use crate::{CatalogProvider, CatalogProviderList, SchemaProvider, TableProvider}; + +/// A schema provider that looks up tables in a cache +/// +/// Instances are created by the [`AsyncSchemaProvider::resolve`] method +#[derive(Debug)] +struct ResolvedSchemaProvider { + owner_name: Option, + cached_tables: HashMap>, +} +#[async_trait] +impl SchemaProvider for ResolvedSchemaProvider { + fn owner_name(&self) -> Option<&str> { + self.owner_name.as_deref() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn table_names(&self) -> Vec { + self.cached_tables.keys().cloned().collect() + } + + async fn table(&self, name: &str) -> Result>> { + Ok(self.cached_tables.get(name).cloned()) + } + + fn register_table( + &self, + name: String, + _table: Arc, + ) -> Result>> { + not_impl_err!( + "Attempt to register table '{name}' with ResolvedSchemaProvider which is not supported" + ) + } + + fn deregister_table(&self, name: &str) -> Result>> { + not_impl_err!("Attempt to deregister table '{name}' with ResolvedSchemaProvider which is not supported") + } + + fn table_exist(&self, name: &str) -> bool { + self.cached_tables.contains_key(name) + } +} + +/// Helper class for building a [`ResolvedSchemaProvider`] +struct ResolvedSchemaProviderBuilder { + owner_name: String, + async_provider: Arc, + cached_tables: HashMap>>, +} +impl ResolvedSchemaProviderBuilder { + fn new(owner_name: String, async_provider: Arc) -> Self { + Self { + owner_name, + async_provider, + cached_tables: HashMap::new(), + } + } + + async fn resolve_table(&mut self, table_name: &str) -> Result<()> { + if !self.cached_tables.contains_key(table_name) { + let resolved_table = self.async_provider.table(table_name).await?; + self.cached_tables + .insert(table_name.to_string(), resolved_table); + } + Ok(()) + } + + fn finish(self) -> Arc { + let cached_tables = self + .cached_tables + .into_iter() + .filter_map(|(key, maybe_value)| maybe_value.map(|value| (key, value))) + .collect(); + Arc::new(ResolvedSchemaProvider { + owner_name: Some(self.owner_name), + cached_tables, + }) + } +} + +/// A catalog provider that looks up schemas in a cache +/// +/// Instances are created by the [`AsyncCatalogProvider::resolve`] method +#[derive(Debug)] +struct ResolvedCatalogProvider { + cached_schemas: HashMap>, +} 
+impl CatalogProvider for ResolvedCatalogProvider { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema_names(&self) -> Vec { + self.cached_schemas.keys().cloned().collect() + } + + fn schema(&self, name: &str) -> Option> { + self.cached_schemas.get(name).cloned() + } +} + +/// Helper class for building a [`ResolvedCatalogProvider`] +struct ResolvedCatalogProviderBuilder { + cached_schemas: HashMap>, + async_provider: Arc, +} +impl ResolvedCatalogProviderBuilder { + fn new(async_provider: Arc) -> Self { + Self { + cached_schemas: HashMap::new(), + async_provider, + } + } + fn finish(self) -> Arc { + let cached_schemas = self + .cached_schemas + .into_iter() + .filter_map(|(key, maybe_value)| { + maybe_value.map(|value| (key, value.finish())) + }) + .collect(); + Arc::new(ResolvedCatalogProvider { cached_schemas }) + } +} + +/// A catalog provider list that looks up catalogs in a cache +/// +/// Instances are created by the [`AsyncCatalogProviderList::resolve`] method +#[derive(Debug)] +struct ResolvedCatalogProviderList { + cached_catalogs: HashMap>, +} +impl CatalogProviderList for ResolvedCatalogProviderList { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn register_catalog( + &self, + _name: String, + _catalog: Arc, + ) -> Option> { + unimplemented!("resolved providers cannot handle registration APIs") + } + + fn catalog_names(&self) -> Vec { + self.cached_catalogs.keys().cloned().collect() + } + + fn catalog(&self, name: &str) -> Option> { + self.cached_catalogs.get(name).cloned() + } +} + +/// A trait for schema providers that must resolve tables asynchronously +/// +/// The [`SchemaProvider::table`] method _is_ asynchronous. However, this is primarily for convenience and +/// it is not a good idea for this method to be slow as this will cause poor planning performance. +/// +/// It is a better idea to resolve the tables once and cache them in memory for the duration of +/// planning. This trait helps implement that pattern. +/// +/// After implementing this trait you can call the [`AsyncSchemaProvider::resolve`] method to get an +/// `Arc` that contains a cached copy of the referenced tables. The `resolve` +/// method can be slow and asynchronous as it is only called once, before planning. +/// +/// See the [remote_catalog.rs] for an end to end example +/// +/// [remote_catalog.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/remote_catalog.rs +#[async_trait] +pub trait AsyncSchemaProvider: Send + Sync { + /// Lookup a table in the schema provider + async fn table(&self, name: &str) -> Result>>; + /// Creates a cached provider that can be used to execute a query containing given references + /// + /// This method will walk through the references and look them up once, creating a cache of table + /// providers. This cache will be returned as a synchronous TableProvider that can be used to plan + /// and execute a query containing the given references. + /// + /// This cache is intended to be short-lived for the execution of a single query. There is no mechanism + /// for refresh or eviction of stale entries. 
+ /// + /// See the [`AsyncSchemaProvider`] documentation for additional details + async fn resolve( + &self, + references: &[TableReference], + config: &SessionConfig, + catalog_name: &str, + schema_name: &str, + ) -> Result> { + let mut cached_tables = HashMap::>>::new(); + + for reference in references { + let ref_catalog_name = reference + .catalog() + .unwrap_or(&config.options().catalog.default_catalog); + + // Maybe this is a reference to some other catalog provided in another way + if ref_catalog_name != catalog_name { + continue; + } + + let ref_schema_name = reference + .schema() + .unwrap_or(&config.options().catalog.default_schema); + + if ref_schema_name != schema_name { + continue; + } + + if !cached_tables.contains_key(reference.table()) { + let resolved_table = self.table(reference.table()).await?; + cached_tables.insert(reference.table().to_string(), resolved_table); + } + } + + let cached_tables = cached_tables + .into_iter() + .filter_map(|(key, maybe_value)| maybe_value.map(|value| (key, value))) + .collect(); + + Ok(Arc::new(ResolvedSchemaProvider { + cached_tables, + owner_name: Some(catalog_name.to_string()), + })) + } +} + +/// A trait for catalog providers that must resolve schemas asynchronously +/// +/// The [`CatalogProvider::schema`] method is synchronous because asynchronous operations should +/// not be used during planning. This trait makes it easy to lookup schema references once and cache +/// them for future planning use. See [`AsyncSchemaProvider`] for more details on motivation. + +#[async_trait] +pub trait AsyncCatalogProvider: Send + Sync { + /// Lookup a schema in the provider + async fn schema(&self, name: &str) -> Result>>; + + /// Creates a cached provider that can be used to execute a query containing given references + /// + /// This method will walk through the references and look them up once, creating a cache of schema + /// providers (each with their own cache of table providers). This cache will be returned as a + /// synchronous CatalogProvider that can be used to plan and execute a query containing the given + /// references. + /// + /// This cache is intended to be short-lived for the execution of a single query. There is no mechanism + /// for refresh or eviction of stale entries. 
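To make the intended flow concrete, here is a minimal, hedged sketch of implementing `AsyncSchemaProvider` and resolving a cached snapshot before planning. The `RemoteSchemaProvider` type, the `remote`/`sales`/`orders` names, the use of `EmptyTable`, and the `datafusion::*` facade import paths are illustrative assumptions (only the trait and the `resolve` signature come from this patch); it also assumes `tokio` and `async-trait` as dependencies.

```rust
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::catalog::{AsyncSchemaProvider, SchemaProvider, TableProvider};
use datafusion::common::TableReference;
use datafusion::datasource::empty::EmptyTable;
use datafusion::error::Result;
use datafusion::prelude::SessionConfig;

/// Pretends to fetch table metadata from a remote catalog service.
#[derive(Debug)]
struct RemoteSchemaProvider;

#[async_trait]
impl AsyncSchemaProvider for RemoteSchemaProvider {
    async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
        // A real implementation would issue a network request here.
        if name == "orders" {
            let schema =
                Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
            Ok(Some(Arc::new(EmptyTable::new(schema))))
        } else {
            Ok(None)
        }
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let provider = RemoteSchemaProvider;

    // The slow lookups happen once, up front, for the references in the query.
    let references = vec![TableReference::full("remote", "sales", "orders")];
    let cached = provider
        .resolve(&references, &SessionConfig::new(), "remote", "sales")
        .await?;

    // The returned SchemaProvider answers from its in-memory cache.
    assert!(cached.table("orders").await?.is_some());
    Ok(())
}
```

The resulting synchronous `SchemaProvider` can then be wired into a session so that planning itself never blocks on network I/O.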
+ async fn resolve( + &self, + references: &[TableReference], + config: &SessionConfig, + catalog_name: &str, + ) -> Result> { + let mut cached_schemas = + HashMap::>::new(); + + for reference in references { + let ref_catalog_name = reference + .catalog() + .unwrap_or(&config.options().catalog.default_catalog); + + // Maybe this is a reference to some other catalog provided in another way + if ref_catalog_name != catalog_name { + continue; + } + + let schema_name = reference + .schema() + .unwrap_or(&config.options().catalog.default_schema); + + let schema = if let Some(schema) = cached_schemas.get_mut(schema_name) { + schema + } else { + let resolved_schema = self.schema(schema_name).await?; + let resolved_schema = resolved_schema.map(|resolved_schema| { + ResolvedSchemaProviderBuilder::new( + catalog_name.to_string(), + resolved_schema, + ) + }); + cached_schemas.insert(schema_name.to_string(), resolved_schema); + cached_schemas.get_mut(schema_name).unwrap() + }; + + // If we can't find the catalog don't bother checking the table + let Some(schema) = schema else { continue }; + + schema.resolve_table(reference.table()).await?; + } + + let cached_schemas = cached_schemas + .into_iter() + .filter_map(|(key, maybe_builder)| { + maybe_builder.map(|schema_builder| (key, schema_builder.finish())) + }) + .collect::>(); + + Ok(Arc::new(ResolvedCatalogProvider { cached_schemas })) + } +} + +/// A trait for catalog provider lists that must resolve catalogs asynchronously +/// +/// The [`CatalogProviderList::catalog`] method is synchronous because asynchronous operations should +/// not be used during planning. This trait makes it easy to lookup catalog references once and cache +/// them for future planning use. See [`AsyncSchemaProvider`] for more details on motivation. +#[async_trait] +pub trait AsyncCatalogProviderList: Send + Sync { + /// Lookup a catalog in the provider + async fn catalog(&self, name: &str) -> Result>>; + + /// Creates a cached provider that can be used to execute a query containing given references + /// + /// This method will walk through the references and look them up once, creating a cache of catalog + /// providers, schema providers, and table providers. This cache will be returned as a + /// synchronous CatalogProvider that can be used to plan and execute a query containing the given + /// references. + /// + /// This cache is intended to be short-lived for the execution of a single query. There is no mechanism + /// for refresh or eviction of stale entries. + async fn resolve( + &self, + references: &[TableReference], + config: &SessionConfig, + ) -> Result> { + let mut cached_catalogs = + HashMap::>::new(); + + for reference in references { + let catalog_name = reference + .catalog() + .unwrap_or(&config.options().catalog.default_catalog); + + // We will do three lookups here, one for the catalog, one for the schema, and one for the table + // We cache the result (both found results and not-found results) to speed up future lookups + // + // Note that a cache-miss is not an error at this point. We allow for the possibility that + // other providers may supply the reference. + // + // If this is the only provider then a not-found error will be raised during planning when it can't + // find the reference in the cache. 
+ + let catalog = if let Some(catalog) = cached_catalogs.get_mut(catalog_name) { + catalog + } else { + let resolved_catalog = self.catalog(catalog_name).await?; + let resolved_catalog = + resolved_catalog.map(ResolvedCatalogProviderBuilder::new); + cached_catalogs.insert(catalog_name.to_string(), resolved_catalog); + cached_catalogs.get_mut(catalog_name).unwrap() + }; + + // If we can't find the catalog don't bother checking the schema / table + let Some(catalog) = catalog else { continue }; + + let schema_name = reference + .schema() + .unwrap_or(&config.options().catalog.default_schema); + + let schema = if let Some(schema) = catalog.cached_schemas.get_mut(schema_name) + { + schema + } else { + let resolved_schema = catalog.async_provider.schema(schema_name).await?; + let resolved_schema = resolved_schema.map(|async_schema| { + ResolvedSchemaProviderBuilder::new( + catalog_name.to_string(), + async_schema, + ) + }); + catalog + .cached_schemas + .insert(schema_name.to_string(), resolved_schema); + catalog.cached_schemas.get_mut(schema_name).unwrap() + }; + + // If we can't find the catalog don't bother checking the table + let Some(schema) = schema else { continue }; + + schema.resolve_table(reference.table()).await?; + } + + // Build the cached catalog provider list + let cached_catalogs = cached_catalogs + .into_iter() + .filter_map(|(key, maybe_builder)| { + maybe_builder.map(|catalog_builder| (key, catalog_builder.finish())) + }) + .collect::>(); + + Ok(Arc::new(ResolvedCatalogProviderList { cached_catalogs })) + } +} + +#[cfg(test)] +mod tests { + use std::{ + any::Any, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + }; + + use arrow_schema::SchemaRef; + use async_trait::async_trait; + use datafusion_common::{error::Result, Statistics, TableReference}; + use datafusion_execution::config::SessionConfig; + use datafusion_expr::{Expr, TableType}; + use datafusion_physical_plan::ExecutionPlan; + + use crate::{Session, TableProvider}; + + use super::{AsyncCatalogProvider, AsyncCatalogProviderList, AsyncSchemaProvider}; + + #[derive(Debug)] + struct MockTableProvider {} + #[async_trait] + impl TableProvider for MockTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + /// Get a reference to the schema for this table + fn schema(&self) -> SchemaRef { + unimplemented!() + } + + fn table_type(&self) -> TableType { + unimplemented!() + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + unimplemented!() + } + + fn statistics(&self) -> Option { + unimplemented!() + } + } + + #[derive(Default)] + struct MockAsyncSchemaProvider { + lookup_count: AtomicU32, + } + + const MOCK_CATALOG: &str = "mock_catalog"; + const MOCK_SCHEMA: &str = "mock_schema"; + const MOCK_TABLE: &str = "mock_table"; + + #[async_trait] + impl AsyncSchemaProvider for MockAsyncSchemaProvider { + async fn table(&self, name: &str) -> Result>> { + self.lookup_count.fetch_add(1, Ordering::Release); + if name == MOCK_TABLE { + Ok(Some(Arc::new(MockTableProvider {}))) + } else { + Ok(None) + } + } + } + + fn test_config() -> SessionConfig { + let mut config = SessionConfig::default(); + config.options_mut().catalog.default_catalog = MOCK_CATALOG.to_string(); + config.options_mut().catalog.default_schema = MOCK_SCHEMA.to_string(); + config + } + + #[tokio::test] + async fn test_async_schema_provider_resolve() { + async fn check( + refs: Vec, + expected_lookup_count: u32, + found_tables: &[&str], + not_found_tables: &[&str], + 
) { + let async_provider = MockAsyncSchemaProvider::default(); + let cached_provider = async_provider + .resolve(&refs, &test_config(), MOCK_CATALOG, MOCK_SCHEMA) + .await + .unwrap(); + + assert_eq!( + async_provider.lookup_count.load(Ordering::Acquire), + expected_lookup_count + ); + + for table_ref in found_tables { + let table = cached_provider.table(table_ref).await.unwrap(); + assert!(table.is_some()); + } + + for table_ref in not_found_tables { + assert!(cached_provider.table(table_ref).await.unwrap().is_none()); + } + } + + // Basic full lookups + check( + vec![ + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE), + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "not_exists"), + ], + 2, + &[MOCK_TABLE], + &["not_exists"], + ) + .await; + + // Catalog / schema mismatch doesn't even search + check( + vec![ + TableReference::full(MOCK_CATALOG, "foo", MOCK_TABLE), + TableReference::full("foo", MOCK_SCHEMA, MOCK_TABLE), + ], + 0, + &[], + &[MOCK_TABLE], + ) + .await; + + // Both hits and misses cached + check( + vec![ + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE), + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE), + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "not_exists"), + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "not_exists"), + ], + 2, + &[MOCK_TABLE], + &["not_exists"], + ) + .await; + } + + #[derive(Default)] + struct MockAsyncCatalogProvider { + lookup_count: AtomicU32, + } + + #[async_trait] + impl AsyncCatalogProvider for MockAsyncCatalogProvider { + async fn schema( + &self, + name: &str, + ) -> Result>> { + self.lookup_count.fetch_add(1, Ordering::Release); + if name == MOCK_SCHEMA { + Ok(Some(Arc::new(MockAsyncSchemaProvider::default()))) + } else { + Ok(None) + } + } + } + + #[tokio::test] + async fn test_async_catalog_provider_resolve() { + async fn check( + refs: Vec, + expected_lookup_count: u32, + found_schemas: &[&str], + not_found_schemas: &[&str], + ) { + let async_provider = MockAsyncCatalogProvider::default(); + let cached_provider = async_provider + .resolve(&refs, &test_config(), MOCK_CATALOG) + .await + .unwrap(); + + assert_eq!( + async_provider.lookup_count.load(Ordering::Acquire), + expected_lookup_count + ); + + for schema_ref in found_schemas { + let schema = cached_provider.schema(schema_ref); + assert!(schema.is_some()); + } + + for schema_ref in not_found_schemas { + assert!(cached_provider.schema(schema_ref).is_none()); + } + } + + // Basic full lookups + check( + vec![ + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "x"), + TableReference::full(MOCK_CATALOG, "not_exists", "x"), + ], + 2, + &[MOCK_SCHEMA], + &["not_exists"], + ) + .await; + + // Catalog mismatch doesn't even search + check( + vec![TableReference::full("foo", MOCK_SCHEMA, "x")], + 0, + &[], + &[MOCK_SCHEMA], + ) + .await; + + // Both hits and misses cached + check( + vec![ + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "x"), + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "x"), + TableReference::full(MOCK_CATALOG, "not_exists", "x"), + TableReference::full(MOCK_CATALOG, "not_exists", "x"), + ], + 2, + &[MOCK_SCHEMA], + &["not_exists"], + ) + .await; + } + + #[derive(Default)] + struct MockAsyncCatalogProviderList { + lookup_count: AtomicU32, + } + + #[async_trait] + impl AsyncCatalogProviderList for MockAsyncCatalogProviderList { + async fn catalog( + &self, + name: &str, + ) -> Result>> { + self.lookup_count.fetch_add(1, Ordering::Release); + if name == MOCK_CATALOG { + Ok(Some(Arc::new(MockAsyncCatalogProvider::default()))) + } 
else { + Ok(None) + } + } + } + + #[tokio::test] + async fn test_async_catalog_provider_list_resolve() { + async fn check( + refs: Vec, + expected_lookup_count: u32, + found_catalogs: &[&str], + not_found_catalogs: &[&str], + ) { + let async_provider = MockAsyncCatalogProviderList::default(); + let cached_provider = + async_provider.resolve(&refs, &test_config()).await.unwrap(); + + assert_eq!( + async_provider.lookup_count.load(Ordering::Acquire), + expected_lookup_count + ); + + for catalog_ref in found_catalogs { + let catalog = cached_provider.catalog(catalog_ref); + assert!(catalog.is_some()); + } + + for catalog_ref in not_found_catalogs { + assert!(cached_provider.catalog(catalog_ref).is_none()); + } + } + + // Basic full lookups + check( + vec![ + TableReference::full(MOCK_CATALOG, "x", "x"), + TableReference::full("not_exists", "x", "x"), + ], + 2, + &[MOCK_CATALOG], + &["not_exists"], + ) + .await; + + // Both hits and misses cached + check( + vec![ + TableReference::full(MOCK_CATALOG, "x", "x"), + TableReference::full(MOCK_CATALOG, "x", "x"), + TableReference::full("not_exists", "x", "x"), + TableReference::full("not_exists", "x", "x"), + ], + 2, + &[MOCK_CATALOG], + &["not_exists"], + ) + .await; + } + + #[tokio::test] + async fn test_defaults() { + for table_ref in &[ + TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE), + TableReference::partial(MOCK_SCHEMA, MOCK_TABLE), + TableReference::bare(MOCK_TABLE), + ] { + let async_provider = MockAsyncCatalogProviderList::default(); + let cached_provider = async_provider + .resolve(&[table_ref.clone()], &test_config()) + .await + .unwrap(); + + let catalog = cached_provider + .catalog(table_ref.catalog().unwrap_or(MOCK_CATALOG)) + .unwrap(); + let schema = catalog + .schema(table_ref.schema().unwrap_or(MOCK_SCHEMA)) + .unwrap(); + assert!(schema.table(table_ref.table()).await.unwrap().is_some()); + } + } +} diff --git a/datafusion/catalog/src/catalog.rs b/datafusion/catalog/src/catalog.rs index 048a7f14ed378..71b9eccf9d657 100644 --- a/datafusion/catalog/src/catalog.rs +++ b/datafusion/catalog/src/catalog.rs @@ -52,12 +52,16 @@ use datafusion_common::Result; /// /// # Implementing "Remote" catalogs /// +/// See [`remote_catalog`] for an end to end example of how to implement a +/// remote catalog. +/// /// Sometimes catalog information is stored remotely and requires a network call /// to retrieve. For example, the [Delta Lake] table format stores table /// metadata in files on S3 that must be first downloaded to discover what /// schemas and tables exist. /// /// [Delta Lake]: https://delta.io/ +/// [`remote_catalog`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/remote_catalog.rs /// /// The [`CatalogProvider`] can support this use case, but it takes some care. /// The planning APIs in DataFusion are not `async` and thus network IO can not @@ -72,15 +76,15 @@ use datafusion_common::Result; /// batch access to the remote catalog to retrieve multiple schemas and tables /// in a single network call. /// -/// Note that [`SchemaProvider::table`] is an `async` function in order to +/// Note that [`SchemaProvider::table`] **is** an `async` function in order to /// simplify implementing simple [`SchemaProvider`]s. For many table formats it /// is easy to list all available tables but there is additional non trivial /// access required to read table details (e.g. statistics). 
/// /// The pattern that DataFusion itself uses to plan SQL queries is to walk over -/// the query to find all table references, -/// performing required remote catalog in parallel, and then plans the query -/// using that snapshot. +/// the query to find all table references, performing required remote catalog +/// lookups in parallel, storing the results in a cached snapshot, and then plans +/// the query using that snapshot. /// /// # Example Catalog Implementations /// @@ -101,7 +105,6 @@ use datafusion_common::Result; /// [`UnityCatalogProvider`]: https://github.com/delta-io/delta-rs/blob/951436ecec476ce65b5ed3b58b50fb0846ca7b91/crates/deltalake-core/src/data_catalog/unity/datafusion.rs#L111-L123 /// /// [`TableProvider`]: crate::TableProvider - pub trait CatalogProvider: Debug + Sync + Send { /// Returns the catalog provider as [`Any`] /// so that it can be downcast to a specific implementation. @@ -151,7 +154,7 @@ pub trait CatalogProvider: Debug + Sync + Send { /// Represent a list of named [`CatalogProvider`]s. /// -/// Please see the documentation on `CatalogProvider` for details of +/// Please see the documentation on [`CatalogProvider`] for details of /// implementing a custom catalog. pub trait CatalogProviderList: Debug + Sync + Send { /// Returns the catalog list as [`Any`] diff --git a/datafusion/catalog/src/lib.rs b/datafusion/catalog/src/lib.rs index 21630f267d2c7..3cf2a3b3cd332 100644 --- a/datafusion/catalog/src/lib.rs +++ b/datafusion/catalog/src/lib.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +mod r#async; mod catalog; mod dynamic_file; mod schema; @@ -23,6 +24,7 @@ mod table; pub use catalog::*; pub use dynamic_file::catalog::*; +pub use r#async::*; pub use schema::*; pub use session::*; pub use table::*; diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index ca3a2bef882e2..3c89604955882 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -25,13 +25,27 @@ use arrow_schema::SchemaRef; use async_trait::async_trait; use datafusion_common::Result; use datafusion_common::{not_impl_err, Constraints, Statistics}; +use datafusion_expr::Expr; + use datafusion_expr::dml::InsertOp; use datafusion_expr::{ - CreateExternalTable, Expr, LogicalPlan, TableProviderFilterPushDown, TableType, + CreateExternalTable, LogicalPlan, TableProviderFilterPushDown, TableType, }; use datafusion_physical_plan::ExecutionPlan; -/// Source table +/// A named table which can be queried. +/// +/// Please see [`CatalogProvider`] for details of implementing a custom catalog. +/// +/// [`TableProvider`] represents a source of data which can provide data as +/// Apache Arrow `RecordBatch`es. Implementations of this trait provide +/// important information for planning such as: +/// +/// 1. [`Self::schema`]: The schema (columns and their types) of the table +/// 2. [`Self::supports_filters_pushdown`]: Should filters be pushed into this scan +/// 2. 
[`Self::scan`]: An [`ExecutionPlan`] that can read data +/// +/// [`CatalogProvider`]: super::CatalogProvider #[async_trait] pub trait TableProvider: Debug + Sync + Send { /// Returns the table provider as [`Any`](std::any::Any) so that it can be @@ -247,6 +261,9 @@ pub trait TableProvider: Debug + Sync + Send { } /// Get statistics for this table, if available + /// Although not presently used in mainline DataFusion, this allows implementation specific + /// behavior for downstream repositories, in conjunction with specialized optimizer rules to + /// perform operations such as re-ordering of joins. fn statistics(&self) -> Option { None } @@ -294,3 +311,40 @@ pub trait TableProviderFactory: Debug + Sync + Send { cmd: &CreateExternalTable, ) -> Result>; } + +/// A trait for table function implementations +pub trait TableFunctionImpl: Debug + Sync + Send { + /// Create a table provider + fn call(&self, args: &[Expr]) -> Result>; +} + +/// A table that uses a function to generate data +#[derive(Debug)] +pub struct TableFunction { + /// Name of the table function + name: String, + /// Function implementation + fun: Arc, +} + +impl TableFunction { + /// Create a new table function + pub fn new(name: String, fun: Arc) -> Self { + Self { name, fun } + } + + /// Get the name of the table function + pub fn name(&self) -> &str { + &self.name + } + + /// Get the implementation of the table function + pub fn function(&self) -> &Arc { + &self.fun + } + + /// Get the function implementation and generate a table + pub fn create_table_provider(&self, args: &[Expr]) -> Result> { + self.fun.call(args) + } +} diff --git a/datafusion/common-runtime/LICENSE.txt b/datafusion/common-runtime/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/common-runtime/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/common-runtime/NOTICE.txt b/datafusion/common-runtime/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/common-runtime/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/common-runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs index 8145bb110464e..51cb988ea06a3 100644 --- a/datafusion/common-runtime/src/lib.rs +++ b/datafusion/common-runtime/src/lib.rs @@ -14,6 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. 
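As a hedged sketch of the `TableFunctionImpl` / `TableFunction` API added above (not part of the patch itself), the example below builds a table function that ignores its arguments and returns an empty single-column table. The `EmptyOrders` name, the use of `EmptyTable`, and the `datafusion::*` facade import paths are assumptions; in practice such a function would typically be registered with a `SessionContext` (for example via its UDTF registration API) so it can be invoked as `SELECT * FROM empty_orders()`.

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::catalog::{TableFunction, TableFunctionImpl, TableProvider};
use datafusion::datasource::empty::EmptyTable;
use datafusion::error::Result;
use datafusion::prelude::Expr;

/// A table function that ignores its arguments and returns an empty,
/// single-column table.
#[derive(Debug)]
struct EmptyOrders;

impl TableFunctionImpl for EmptyOrders {
    fn call(&self, _args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        let schema: SchemaRef = Arc::new(Schema::new(vec![Field::new(
            "order_id",
            DataType::Int64,
            true,
        )]));
        Ok(Arc::new(EmptyTable::new(schema)))
    }
}

fn main() -> Result<()> {
    let func = TableFunction::new("empty_orders".to_string(), Arc::new(EmptyOrders));

    // The planner would call `create_table_provider` with the (constant-folded)
    // arguments from the SQL call site; here we pass no arguments.
    let provider = func.create_table_provider(&[])?;
    assert_eq!(provider.schema().fields().len(), 1);
    Ok(())
}
```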
+ // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 1ac27b40c2194..fe6d652be700a 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -40,10 +40,11 @@ avro = ["apache-avro"] backtrace = [] pyarrow = ["pyo3", "arrow/pyarrow", "parquet"] force_hash_collisions = [] +recursive_protection = ["dep:recursive"] [dependencies] ahash = { workspace = true } -apache-avro = { version = "0.16", default-features = false, features = [ +apache-avro = { version = "0.17", default-features = false, features = [ "bzip", "snappy", "xz", @@ -52,21 +53,25 @@ apache-avro = { version = "0.16", default-features = false, features = [ arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +arrow-ipc = { workspace = true } arrow-schema = { workspace = true } -chrono = { workspace = true } +base64 = "0.22.1" half = { workspace = true } hashbrown = { workspace = true } +indexmap = { workspace = true } libc = "0.2.140" -num_cpus = { workspace = true } +log = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } paste = "1.0.15" -pyo3 = { version = "0.22.0", optional = true } +pyo3 = { version = "0.23.3", optional = true } +recursive = { workspace = true, optional = true } sqlparser = { workspace = true } tokio = { workspace = true } [target.'cfg(target_family = "wasm")'.dependencies] -instant = { version = "0.1", features = ["wasm-bindgen"] } +web-time = "1.1.0" [dev-dependencies] +chrono = { workspace = true } rand = { workspace = true } diff --git a/datafusion/common/LICENSE.txt b/datafusion/common/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/common/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/common/NOTICE.txt b/datafusion/common/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/common/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 0586fcf5e2ae9..bb9d809965710 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -36,33 +36,66 @@ use arrow::{ }, datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; -use arrow_array::{BinaryViewArray, StringViewArray}; +use arrow_array::{ + BinaryViewArray, Float16Array, Int16Array, Int8Array, LargeBinaryArray, + LargeStringArray, StringViewArray, UInt16Array, +}; // Downcast ArrayRef to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> { Ok(downcast_value!(array, Date32Array)) } +// Downcast ArrayRef to Date64Array +pub fn as_date64_array(array: &dyn Array) -> Result<&Date64Array> { + Ok(downcast_value!(array, Date64Array)) +} + // Downcast ArrayRef to StructArray pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray> { Ok(downcast_value!(array, StructArray)) } +// Downcast ArrayRef to Int8Array +pub fn as_int8_array(array: &dyn Array) -> Result<&Int8Array> { + Ok(downcast_value!(array, Int8Array)) +} + // Downcast ArrayRef to UInt8Array pub fn as_uint8_array(array: &dyn Array) -> Result<&UInt8Array> { Ok(downcast_value!(array, UInt8Array)) } +// Downcast ArrayRef to Int16Array +pub fn as_int16_array(array: &dyn Array) -> Result<&Int16Array> { + 
Ok(downcast_value!(array, Int16Array)) +} + +// Downcast ArrayRef to UInt16Array +pub fn as_uint16_array(array: &dyn Array) -> Result<&UInt16Array> { + Ok(downcast_value!(array, UInt16Array)) +} + // Downcast ArrayRef to Int32Array pub fn as_int32_array(array: &dyn Array) -> Result<&Int32Array> { Ok(downcast_value!(array, Int32Array)) } +// Downcast ArrayRef to UInt32Array +pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array> { + Ok(downcast_value!(array, UInt32Array)) +} + // Downcast ArrayRef to Int64Array pub fn as_int64_array(array: &dyn Array) -> Result<&Int64Array> { Ok(downcast_value!(array, Int64Array)) } +// Downcast ArrayRef to UInt64Array +pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array> { + Ok(downcast_value!(array, UInt64Array)) +} + // Downcast ArrayRef to Decimal128Array pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> { Ok(downcast_value!(array, Decimal128Array)) @@ -73,6 +106,11 @@ pub fn as_decimal256_array(array: &dyn Array) -> Result<&Decimal256Array> { Ok(downcast_value!(array, Decimal256Array)) } +// Downcast ArrayRef to Float16Array +pub fn as_float16_array(array: &dyn Array) -> Result<&Float16Array> { + Ok(downcast_value!(array, Float16Array)) +} + // Downcast ArrayRef to Float32Array pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array> { Ok(downcast_value!(array, Float32Array)) @@ -93,14 +131,9 @@ pub fn as_string_view_array(array: &dyn Array) -> Result<&StringViewArray> { Ok(downcast_value!(array, StringViewArray)) } -// Downcast ArrayRef to UInt32Array -pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array> { - Ok(downcast_value!(array, UInt32Array)) -} - -// Downcast ArrayRef to UInt64Array -pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array> { - Ok(downcast_value!(array, UInt64Array)) +// Downcast ArrayRef to LargeStringArray +pub fn as_large_string_array(array: &dyn Array) -> Result<&LargeStringArray> { + Ok(downcast_value!(array, LargeStringArray)) } // Downcast ArrayRef to BooleanArray @@ -232,6 +265,11 @@ pub fn as_binary_view_array(array: &dyn Array) -> Result<&BinaryViewArray> { Ok(downcast_value!(array, BinaryViewArray)) } +// Downcast ArrayRef to LargeBinaryArray +pub fn as_large_binary_array(array: &dyn Array) -> Result<&LargeBinaryArray> { + Ok(downcast_value!(array, LargeBinaryArray)) +} + // Downcast ArrayRef to FixedSizeListArray pub fn as_fixed_size_list_array(array: &dyn Array) -> Result<&FixedSizeListArray> { Ok(downcast_value!(array, FixedSizeListArray)) @@ -242,11 +280,6 @@ pub fn as_fixed_size_binary_array(array: &dyn Array) -> Result<&FixedSizeBinaryA Ok(downcast_value!(array, FixedSizeBinaryArray)) } -// Downcast ArrayRef to Date64Array -pub fn as_date64_array(array: &dyn Array) -> Result<&Date64Array> { - Ok(downcast_value!(array, Date64Array)) -} - // Downcast ArrayRef to GenericBinaryArray pub fn as_generic_string_array( array: &dyn Array, diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index d855198fa7c6b..fdde3d69eddba 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -21,7 +21,7 @@ use arrow_schema::{Field, FieldRef}; use crate::error::_schema_err; use crate::utils::{parse_identifiers_normalized, quote_identifier}; -use crate::{DFSchema, DataFusionError, Result, SchemaError, TableReference}; +use crate::{DFSchema, Result, SchemaError, TableReference}; use std::collections::HashSet; use std::convert::Infallible; use std::fmt; @@ -109,21 +109,23 @@ impl Column { /// where 
`"foo.BAR"` would be parsed to a reference to column named `foo.BAR` pub fn from_qualified_name(flat_name: impl Into) -> Self { let flat_name = flat_name.into(); - Self::from_idents(&mut parse_identifiers_normalized(&flat_name, false)) - .unwrap_or_else(|| Self { + Self::from_idents(&mut parse_identifiers_normalized(&flat_name, false)).unwrap_or( + Self { relation: None, name: flat_name, - }) + }, + ) } /// Deserialize a fully qualified name string into a column preserving column text case pub fn from_qualified_name_ignore_case(flat_name: impl Into) -> Self { let flat_name = flat_name.into(); - Self::from_idents(&mut parse_identifiers_normalized(&flat_name, true)) - .unwrap_or_else(|| Self { + Self::from_idents(&mut parse_identifiers_normalized(&flat_name, true)).unwrap_or( + Self { relation: None, name: flat_name, - }) + }, + ) } /// return the column's name. @@ -228,7 +230,7 @@ impl Column { .collect::>(); for using_col in using_columns { let all_matched = columns.iter().all(|c| using_col.contains(c)); - // All matched fields belong to the same using column set, in orther words + // All matched fields belong to the same using column set, in other words // the same join clause. We simply pick the qualifier from the first match. if all_matched { return Ok(columns[0].clone()); @@ -366,7 +368,8 @@ mod tests { &[], ) .expect_err("should've failed to find field"); - let expected = r#"Schema error: No field named z. Valid fields are t1.a, t1.b, t2.c, t2.d, t3.a, t3.b, t3.c, t3.d, t3.e."#; + let expected = "Schema error: No field named z. \ + Valid fields are t1.a, t1.b, t2.c, t2.d, t3.a, t3.b, t3.c, t3.d, t3.e."; assert_eq!(err.strip_backtrace(), expected); // ambiguous column reference diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 1e1c5d5424b08..33a90017bd7e8 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -19,16 +19,20 @@ use std::any::Any; use std::collections::{BTreeMap, HashMap}; +use std::error::Error; use std::fmt::{self, Display}; use std::str::FromStr; use crate::error::_config_err; use crate::parsers::CompressionTypeVariant; +use crate::utils::get_available_parallelism; use crate::{DataFusionError, Result}; /// A macro that wraps a configuration struct and automatically derives /// [`Default`] and [`ConfigField`] for it, allowing it to be used -/// in the [`ConfigOptions`] configuration tree +/// in the [`ConfigOptions`] configuration tree. +/// +/// `transform` is used to normalize values before parsing. 
/// /// For example, /// @@ -37,7 +41,7 @@ use crate::{DataFusionError, Result}; /// /// Amazing config /// pub struct MyConfig { /// /// Field 1 doc -/// field1: String, default = "".to_string() +/// field1: String, transform = str::to_lowercase, default = "".to_string() /// /// /// Field 2 doc /// field2: usize, default = 232 @@ -66,9 +70,12 @@ use crate::{DataFusionError, Result}; /// fn set(&mut self, key: &str, value: &str) -> Result<()> { /// let (key, rem) = key.split_once('.').unwrap_or((key, "")); /// match key { -/// "field1" => self.field1.set(rem, value), -/// "field2" => self.field2.set(rem, value), -/// "field3" => self.field3.set(rem, value), +/// "field1" => { +/// let value = str::to_lowercase(value); +/// self.field1.set(rem, value.as_ref()) +/// }, +/// "field2" => self.field2.set(rem, value.as_ref()), +/// "field3" => self.field3.set(rem, value.as_ref()), /// _ => _internal_err!( /// "Config value \"{}\" not found on MyConfig", /// key @@ -101,7 +108,6 @@ use crate::{DataFusionError, Result}; /// ``` /// /// NB: Misplaced commas may result in nonsensical errors -/// #[macro_export] macro_rules! config_namespace { ( @@ -109,7 +115,7 @@ macro_rules! config_namespace { $vis:vis struct $struct_name:ident { $( $(#[doc = $d:tt])* - $field_vis:vis $field_name:ident : $field_type:ty, default = $default:expr + $field_vis:vis $field_name:ident : $field_type:ty, $(warn = $warn: expr,)? $(transform = $transform:expr,)? default = $default:expr )*$(,)* } ) => { @@ -126,9 +132,14 @@ macro_rules! config_namespace { impl ConfigField for $struct_name { fn set(&mut self, key: &str, value: &str) -> Result<()> { let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { $( - stringify!($field_name) => self.$field_name.set(rem, value), + stringify!($field_name) => { + $(let value = $transform(value);)? + $(log::warn!($warn);)? + self.$field_name.set(rem, value.as_ref()) + }, )* _ => return _config_err!( "Config value \"{}\" not found on {}", key, stringify!($struct_name) @@ -210,12 +221,15 @@ config_namespace! { /// When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) pub enable_ident_normalization: bool, default = true - /// When set to true, SQL parser will normalize options value (convert value to lowercase) - pub enable_options_value_normalization: bool, default = true + /// When set to true, SQL parser will normalize options value (convert value to lowercase). + /// Note that this option is ignored and will be removed in the future. All case-insensitive values + /// are normalized automatically. + pub enable_options_value_normalization: bool, warn = "`enable_options_value_normalization` is deprecated and ignored", default = false /// Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, /// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. pub dialect: String, default = "generic".to_string() + // no need to lowercase because `sqlparser::dialect_from_str`] is case-insensitive /// If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but /// ignore the length. If false, error if a `VARCHAR` with a length is @@ -250,7 +264,7 @@ config_namespace! { /// concurrency. /// /// Defaults to the number of CPU cores on the system - pub target_partitions: usize, default = num_cpus::get() + pub target_partitions: usize, default = get_available_parallelism() /// The default time zone /// @@ -266,7 +280,18 @@ config_namespace! 
{ /// This is mostly use to plan `UNION` children in parallel. /// /// Defaults to the number of CPU cores on the system - pub planning_concurrency: usize, default = num_cpus::get() + pub planning_concurrency: usize, default = get_available_parallelism() + + /// When set to true, skips verifying that the schema produced by + /// planning the input of `LogicalPlan::Aggregate` exactly matches the + /// schema of the input plan. + /// + /// When set to false, if the schema does not match exactly + /// (including nullability and metadata), a planning error will be raised. + /// + /// This is used to workaround bugs in the planner that are now caught by + /// the new schema verification step. + pub skip_physical_aggregate_schema_check: bool, default = false /// Specifies the reserved memory for each spillable sort operation to /// facilitate an in-memory merge. @@ -338,6 +363,12 @@ config_namespace! { /// if the source of statistics is accurate. /// We plan to make this the default in the future. pub use_row_number_estimates_to_optimize_partitioning: bool, default = false + + /// Should DataFusion enforce batch size in joins or not. By default, + /// DataFusion will not enforce batch size in joins. Enforcing batch size + /// in joins can reduce memory usage when joining large + /// tables with a highly-selective join filter, but is also slightly slower. + pub enforce_batch_size_in_joins: bool, default = false } } @@ -382,7 +413,15 @@ config_namespace! { /// (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, /// and `Binary/BinaryLarge` with `BinaryView`. - pub schema_force_view_types: bool, default = false + pub schema_force_view_types: bool, default = true + + /// (reading) If true, parquet reader will read columns of + /// `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. + /// + /// Parquet files generated by some legacy writers do not correctly set + /// the UTF8 flag for strings, causing string columns to be loaded as + /// BLOB instead. + pub binary_as_string: bool, default = false // The following options affect writing to parquet files // and map to parquet::file::properties::WriterProperties @@ -397,6 +436,12 @@ config_namespace! { /// valid values are "1.0" and "2.0" pub writer_version: String, default = "1.0".to_string() + /// (writing) Skip encoding the embedded arrow metadata in the KV_meta + /// + /// This is analogous to the `ArrowWriterOptions::with_skip_arrow_metadata`. + /// Refer to + pub skip_arrow_metadata: bool, default = false + /// (writing) Sets default parquet compression codec. /// Valid values are: uncompressed, snappy, gzip(level), /// lzo, brotli(level), lz4, zstd(level), and lz4_raw. @@ -405,7 +450,7 @@ config_namespace! { /// /// Note that this default setting is not the same as /// the default parquet writer setting. - pub compression: Option, default = Some("zstd(3)".into()) + pub compression: Option, transform = str::to_lowercase, default = Some("zstd(3)".into()) /// (writing) Sets if dictionary encoding is enabled. If NULL, uses /// default parquet writer setting @@ -418,7 +463,7 @@ config_namespace! { /// Valid values are: "none", "chunk", and "page" /// These values are not case sensitive. If NULL, uses /// default parquet writer setting - pub statistics_enabled: Option, default = Some("page".into()) + pub statistics_enabled: Option, transform = str::to_lowercase, default = Some("page".into()) /// (writing) Sets max statistics size for any column. 
If NULL, uses /// default parquet writer setting @@ -444,7 +489,7 @@ config_namespace! { /// delta_byte_array, rle_dictionary, and byte_stream_split. /// These values are not case sensitive. If NULL, uses /// default parquet writer setting - pub encoding: Option, default = None + pub encoding: Option, transform = str::to_lowercase, default = None /// (writing) Use any available bloom filters when reading parquet files pub bloom_filter_on_read: bool, default = true @@ -850,8 +895,48 @@ impl ConfigOptions { } } -/// [`ConfigExtension`] provides a mechanism to store third-party configuration within DataFusion +/// [`ConfigExtension`] provides a mechanism to store third-party configuration +/// within DataFusion [`ConfigOptions`] +/// +/// This mechanism can be used to pass configuration to user defined functions +/// or optimizer passes +/// +/// # Example +/// ``` +/// use datafusion_common::{ +/// config::ConfigExtension, extensions_options, +/// config::ConfigOptions, +/// }; +/// // Define a new configuration struct using the `extensions_options` macro +/// extensions_options! { +/// /// My own config options. +/// pub struct MyConfig { +/// /// Should "foo" be replaced by "bar"? +/// pub foo_to_bar: bool, default = true +/// +/// /// How many "baz" should be created? +/// pub baz_count: usize, default = 1337 +/// } +/// } +/// +/// impl ConfigExtension for MyConfig { +/// const PREFIX: &'static str = "my_config"; +/// } +/// +/// // set up config struct and register extension +/// let mut config = ConfigOptions::default(); +/// config.extensions.insert(MyConfig::default()); +/// +/// // overwrite config default +/// config.set("my_config.baz_count", "42").unwrap(); /// +/// // check config state +/// let my_config = config.extensions.get::().unwrap(); +/// assert!(my_config.foo_to_bar,); +/// assert_eq!(my_config.baz_count, 42,); +/// ``` +/// +/// # Note: /// Unfortunately associated constants are not currently object-safe, and so this /// extends the object-safe [`ExtensionOptions`] pub trait ConfigExtension: ExtensionOptions { @@ -861,16 +946,18 @@ pub trait ConfigExtension: ExtensionOptions { const PREFIX: &'static str; } -/// An object-safe API for storing arbitrary configuration -pub trait ExtensionOptions: Send + Sync + std::fmt::Debug + 'static { +/// An object-safe API for storing arbitrary configuration. +/// +/// See [`ConfigExtension`] for user defined configuration +pub trait ExtensionOptions: Send + Sync + fmt::Debug + 'static { /// Return `self` as [`Any`] /// - /// This is needed until trait upcasting is stabilised + /// This is needed until trait upcasting is stabilized fn as_any(&self) -> &dyn Any; /// Return `self` as [`Any`] /// - /// This is needed until trait upcasting is stabilised + /// This is needed until trait upcasting is stabilized fn as_any_mut(&mut self) -> &mut dyn Any; /// Return a deep clone of this [`ExtensionOptions`] @@ -945,21 +1032,37 @@ impl ConfigField for Option { } } +fn default_transform(input: &str) -> Result +where + T: FromStr, + ::Err: Sync + Send + Error + 'static, +{ + input.parse().map_err(|e| { + DataFusionError::Context( + format!( + "Error parsing '{}' as {}", + input, + std::any::type_name::() + ), + Box::new(DataFusionError::External(Box::new(e))), + ) + }) +} + #[macro_export] macro_rules! 
config_field { ($t:ty) => { + config_field!($t, value => default_transform(value)?); + }; + + ($t:ty, $arg:ident => $transform:expr) => { impl ConfigField for $t { fn visit(&self, v: &mut V, key: &str, description: &'static str) { v.some(key, self, description) } - fn set(&mut self, _: &str, value: &str) -> Result<()> { - *self = value.parse().map_err(|e| { - DataFusionError::Context( - format!(concat!("Error parsing {} as ", stringify!($t),), value), - Box::new(DataFusionError::External(Box::new(e))), - ) - })?; + fn set(&mut self, _: &str, $arg: &str) -> Result<()> { + *self = $transform; Ok(()) } } @@ -967,7 +1070,7 @@ macro_rules! config_field { } config_field!(String); -config_field!(bool); +config_field!(bool, value => default_transform(value.to_lowercase().as_str())?); config_field!(usize); config_field!(f64); config_field!(u64); @@ -1053,6 +1156,8 @@ pub trait Visit { /// - ``: Default value matching the field type like `42`. /// /// # Example +/// See also a full example on the [`ConfigExtension`] documentation +/// /// ``` /// use datafusion_common::extensions_options; /// @@ -1222,16 +1327,18 @@ impl ConfigField for TableOptions { fn set(&mut self, key: &str, value: &str) -> Result<()> { // Extensions are handled in the public `ConfigOptions::set` let (key, rem) = key.split_once('.').unwrap_or((key, "")); - let Some(format) = &self.current_format else { - return _config_err!("Specify a format for TableOptions"); - }; match key { - "format" => match format { - #[cfg(feature = "parquet")] - ConfigFileType::PARQUET => self.parquet.set(rem, value), - ConfigFileType::CSV => self.csv.set(rem, value), - ConfigFileType::JSON => self.json.set(rem, value), - }, + "format" => { + let Some(format) = &self.current_format else { + return _config_err!("Specify a format for TableOptions"); + }; + match format { + #[cfg(feature = "parquet")] + ConfigFileType::PARQUET => self.parquet.set(rem, value), + ConfigFileType::CSV => self.csv.set(rem, value), + ConfigFileType::JSON => self.json.set(rem, value), + } + } _ => _config_err!("Config value \"{key}\" not found on TableOptions"), } } @@ -1258,8 +1365,7 @@ impl TableOptions { /// A new `TableOptions` instance with settings applied from the session config. pub fn default_from_session_config(config: &ConfigOptions) -> Self { let initial = TableOptions::default(); - initial.combine_with_session_config(config); - initial + initial.combine_with_session_config(config) } /// Updates the current `TableOptions` with settings from a given session config. @@ -1271,6 +1377,7 @@ impl TableOptions { /// # Returns /// /// A new `TableOptions` instance with updated settings from the session config. + #[must_use = "this method returns a new instance"] pub fn combine_with_session_config(&self, config: &ConfigOptions) -> Self { let mut clone = self.clone(); clone.parquet.global = config.execution.parquet.clone(); @@ -1439,6 +1546,20 @@ impl TableParquetOptions { pub fn new() -> Self { Self::default() } + + /// Set whether the encoding of the arrow metadata should occur + /// during the writing of parquet. + /// + /// Default is to encode the arrow schema in the file kv_metadata. + pub fn with_skip_arrow_metadata(self, skip: bool) -> Self { + Self { + global: ParquetOptions { + skip_arrow_metadata: skip, + ..self.global + }, + ..self + } + } } impl ConfigField for TableParquetOptions { @@ -1480,7 +1601,7 @@ macro_rules! 
config_namespace_with_hashmap { $vis:vis struct $struct_name:ident { $( $(#[doc = $d:tt])* - $field_vis:vis $field_name:ident : $field_type:ty, default = $default:expr + $field_vis:vis $field_name:ident : $field_type:ty, $(transform = $transform:expr,)? default = $default:expr )*$(,)* } ) => { @@ -1499,7 +1620,10 @@ macro_rules! config_namespace_with_hashmap { let (key, rem) = key.split_once('.').unwrap_or((key, "")); match key { $( - stringify!($field_name) => self.$field_name.set(rem, value), + stringify!($field_name) => { + $(let value = $transform(value);)? + self.$field_name.set(rem, value.as_ref()) + }, )* _ => _config_err!( "Config value \"{}\" not found on {}", key, stringify!($struct_name) @@ -1578,7 +1702,7 @@ config_namespace_with_hashmap! { /// lzo, brotli(level), lz4, zstd(level), and lz4_raw. /// These values are not case-sensitive. If NULL, uses /// default parquet options - pub compression: Option, default = None + pub compression: Option, transform = str::to_lowercase, default = None /// Sets if statistics are enabled for the column /// Valid values are: "none", "chunk", and "page" @@ -1621,13 +1745,16 @@ config_namespace! { /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. pub newlines_in_values: Option, default = None pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED - pub schema_infer_max_rec: usize, default = 100 + pub schema_infer_max_rec: Option, default = None pub date_format: Option, default = None pub datetime_format: Option, default = None pub timestamp_format: Option, default = None pub timestamp_tz_format: Option, default = None pub time_format: Option, default = None + // The output format for Nulls in the CSV writer. pub null_value: Option, default = None + // The input regex for Nulls when loading CSVs. + pub null_regex: Option, default = None pub comment: Option, default = None } } @@ -1646,7 +1773,7 @@ impl CsvOptions { /// Set a limit in terms of records to scan to infer the schema /// - default to `DEFAULT_SCHEMA_INFER_MAX_RECORD` pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self { - self.schema_infer_max_rec = max_rec; + self.schema_infer_max_rec = Some(max_rec); self } @@ -1746,7 +1873,7 @@ config_namespace! { /// Options controlling JSON format pub struct JsonOptions { pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED - pub schema_infer_max_rec: usize, default = 100 + pub schema_infer_max_rec: Option, default = None } } diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs new file mode 100644 index 0000000000000..674d3386171f8 --- /dev/null +++ b/datafusion/common/src/cse.rs @@ -0,0 +1,912 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Common Subexpression Elimination logic implemented in [`CSE`] can be controlled with +//! a [`CSEController`], that defines how to eliminate common subtrees from a particular +//! [`TreeNode`] tree. + +use crate::hash_utils::combine_hashes; +use crate::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + TreeNodeVisitor, +}; +use crate::Result; +use indexmap::IndexMap; +use std::collections::HashMap; +use std::hash::{BuildHasher, Hash, Hasher, RandomState}; +use std::marker::PhantomData; +use std::sync::Arc; + +/// Hashes the direct content of an [`TreeNode`] without recursing into its children. +/// +/// This method is useful to incrementally compute hashes, such as in [`CSE`] which builds +/// a deep hash of a node and its descendants during the bottom-up phase of the first +/// traversal and so avoid computing the hash of the node and then the hash of its +/// descendants separately. +/// +/// If a node doesn't have any children then the value returned by `hash_node()` is +/// similar to '.hash()`, but not necessarily returns the same value. +pub trait HashNode { + fn hash_node(&self, state: &mut H); +} + +impl HashNode for Arc { + fn hash_node(&self, state: &mut H) { + (**self).hash_node(state); + } +} + +/// The `Normalizeable` trait defines a method to determine whether a node can be normalized. +/// +/// Normalization is the process of converting a node into a canonical form that can be used +/// to compare nodes for equality. This is useful in optimizations like Common Subexpression Elimination (CSE), +/// where semantically equivalent nodes (e.g., `a + b` and `b + a`) should be treated as equal. +pub trait Normalizeable { + fn can_normalize(&self) -> bool; +} + +/// The `NormalizeEq` trait extends `Eq` and `Normalizeable` to provide a method for comparing +/// normalized nodes in optimizations like Common Subexpression Elimination (CSE). +/// +/// The `normalize_eq` method ensures that two nodes that are semantically equivalent (after normalization) +/// are considered equal in CSE optimization, even if their original forms differ. +/// +/// This trait allows for equality comparisons between nodes with equivalent semantics, regardless of their +/// internal representations. +pub trait NormalizeEq: Eq + Normalizeable { + fn normalize_eq(&self, other: &Self) -> bool; +} + +/// Identifier that represents a [`TreeNode`] tree. +/// +/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and +/// "have no collision (as low as possible)" +#[derive(Debug, Eq)] +struct Identifier<'n, N: NormalizeEq> { + // Hash of `node` built up incrementally during the first, visiting traversal. + // Its value is not necessarily equal to default hash of the node. E.g. it is not + // equal to `expr.hash()` if the node is `Expr`. 
+ hash: u64, + node: &'n N, +} + +impl Clone for Identifier<'_, N> { + fn clone(&self) -> Self { + *self + } +} +impl Copy for Identifier<'_, N> {} + +impl Hash for Identifier<'_, N> { + fn hash(&self, state: &mut H) { + state.write_u64(self.hash); + } +} + +impl PartialEq for Identifier<'_, N> { + fn eq(&self, other: &Self) -> bool { + self.hash == other.hash && self.node.normalize_eq(other.node) + } +} + +impl<'n, N> Identifier<'n, N> +where + N: HashNode + NormalizeEq, +{ + fn new(node: &'n N, random_state: &RandomState) -> Self { + let mut hasher = random_state.build_hasher(); + node.hash_node(&mut hasher); + let hash = hasher.finish(); + Self { hash, node } + } + + fn combine(mut self, other: Option) -> Self { + other.map_or(self, |other_id| { + self.hash = combine_hashes(self.hash, other_id.hash); + self + }) + } +} + +/// A cache that contains the postorder index and the identifier of [`TreeNode`]s by the +/// preorder index of the nodes. +/// +/// This cache is filled by [`CSEVisitor`] during the first traversal and is +/// used by [`CSERewriter`] during the second traversal. +/// +/// The purpose of this cache is to quickly find the identifier of a node during the +/// second traversal. +/// +/// Elements in this array are added during `f_down` so the indexes represent the preorder +/// index of nodes and thus element 0 belongs to the root of the tree. +/// +/// The elements of the array are tuples that contain: +/// - Postorder index that belongs to the preorder index. Assigned during `f_up`, start +/// from 0. +/// - The optional [`Identifier`] of the node. If none the node should not be considered +/// for CSE. +/// +/// # Example +/// An expression tree like `(a + b)` would have the following `IdArray`: +/// ```text +/// [ +/// (2, Some(Identifier(hash_of("a + b"), &"a + b"))), +/// (1, Some(Identifier(hash_of("a"), &"a"))), +/// (0, Some(Identifier(hash_of("b"), &"b"))) +/// ] +/// ``` +type IdArray<'n, N> = Vec<(usize, Option>)>; + +#[derive(PartialEq, Eq)] +/// How many times a node is evaluated. A node can be considered common if evaluated +/// surely at least 2 times or surely only once but also conditionally. +enum NodeEvaluation { + SurelyOnce, + ConditionallyAtLeastOnce, + Common, +} + +/// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers. +type NodeStats<'n, N> = HashMap, NodeEvaluation>; + +/// A map that contains the common [`TreeNode`]s and their alias by their identifiers, +/// extracted during the second, rewriting traversal. +type CommonNodes<'n, N> = IndexMap, (N, String)>; + +type ChildrenList = (Vec, Vec); + +/// The [`TreeNode`] specific definition of elimination. +pub trait CSEController { + /// The type of the tree nodes. + type Node; + + /// Splits the children to normal and conditionally evaluated ones or returns `None` + /// if all are always evaluated. + fn conditional_children(node: &Self::Node) -> Option>; + + // Returns true if a node is valid. If a node is invalid then it can't be eliminated. + // Validity is propagated up which means no subtree can be eliminated that contains + // an invalid node. + // (E.g. volatile expressions are not valid and subtrees containing such a node can't + // be extracted.) + fn is_valid(node: &Self::Node) -> bool; + + // Returns true if a node should be ignored during CSE. Contrary to validity of a node, + // it is not propagated up. + fn is_ignored(&self, node: &Self::Node) -> bool; + + // Generates a new name for the extracted subtree. 
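Because `PartialEq` for `Identifier` combines the precomputed hash with `normalize_eq`, nodes that are normalize-equal must also end up with equal hashes; the visitor achieves this by sorting child hashes before folding them into the parent (see `pop_enter_mark` further down). A standalone sketch of that idea, not part of this patch; `combine` here is an invented stand-in for `hash_utils::combine_hashes`.

```rust
// Standalone sketch: child hashes must be combined order-insensitively for
// normalizable nodes, otherwise `a + b` and `b + a` would be normalize-equal
// yet hash differently. `combine` is a simplified stand-in for
// `hash_utils::combine_hashes`.

fn combine(l: u64, r: u64) -> u64 {
    // Deliberately order-dependent mixing, like a typical hash_combine.
    l.wrapping_mul(31).wrapping_add(r)
}

/// Fold child hashes into a parent hash; sort first when the node is
/// normalizable (commutative) so that operand order does not matter.
fn fold_children(parent: u64, mut child_hashes: Vec<u64>, can_normalize: bool) -> u64 {
    if can_normalize {
        child_hashes.sort_unstable();
    }
    child_hashes.into_iter().fold(parent, combine)
}

fn main() {
    let (a, b, plus) = (11u64, 22, 33);
    // Without sorting, swapping the operands changes the parent hash...
    assert_ne!(
        fold_children(plus, vec![a, b], false),
        fold_children(plus, vec![b, a], false)
    );
    // ...with sorting, both orderings hash identically, matching the
    // `PartialEq` + `Hash` contract of `Identifier`.
    assert_eq!(
        fold_children(plus, vec![a, b], true),
        fold_children(plus, vec![b, a], true)
    );
}
```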
+ fn generate_alias(&self) -> String; + + // Replaces a node to the generated alias. + fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node; + + // A helper method called on each node during top-down traversal during the second, + // rewriting traversal of CSE. + fn rewrite_f_down(&mut self, _node: &Self::Node) {} + + // A helper method called on each node during bottom-up traversal during the second, + // rewriting traversal of CSE. + fn rewrite_f_up(&mut self, _node: &Self::Node) {} +} + +/// The result of potentially rewriting a list of [`TreeNode`]s to eliminate common +/// subtrees. +#[derive(Debug)] +pub enum FoundCommonNodes { + /// No common [`TreeNode`]s were found + No { original_nodes_list: Vec> }, + + /// Common [`TreeNode`]s were found + Yes { + /// extracted common [`TreeNode`] + common_nodes: Vec<(N, String)>, + + /// new [`TreeNode`]s with common subtrees replaced + new_nodes_list: Vec>, + + /// original [`TreeNode`]s + original_nodes_list: Vec>, + }, +} + +/// Go through a [`TreeNode`] tree and generate identifiers for each subtrees. +/// +/// An identifier contains information of the [`TreeNode`] itself and its subtrees. +/// This visitor implementation use a stack `visit_stack` to track traversal, which +/// lets us know when a subtree's visiting is finished. When `pre_visit` is called +/// (traversing to a new node), an `EnterMark` and an `NodeItem` will be pushed into stack. +/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `NodeItem` +/// before the first `EnterMark` is considered to be sub-tree of the leaving node. +/// +/// This visitor also records identifier in `id_array`. Makes the following traverse +/// pass can get the identifier of a node without recalculate it. We assign each node +/// in the tree a series number, start from 1, maintained by `series_number`. +/// Series number represents the order we left (`f_up()`) a node. Has the property +/// that child node's series number always smaller than parent's. While `id_array` is +/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to +/// get the index of `id_array` for each node. +/// +/// A [`TreeNode`] without any children (column, literal etc.) will not have identifier +/// because they should not be recognized as common subtree. +struct CSEVisitor<'a, 'n, N, C> +where + N: NormalizeEq, + C: CSEController, +{ + /// statistics of [`TreeNode`]s + node_stats: &'a mut NodeStats<'n, N>, + + /// cache to speed up second traversal + id_array: &'a mut IdArray<'n, N>, + + /// inner states + visit_stack: Vec>, + + /// preorder index, start from 0. + down_index: usize, + + /// postorder index, start from 0. + up_index: usize, + + /// a [`RandomState`] to generate hashes during the first traversal + random_state: &'a RandomState, + + /// a flag to indicate that common [`TreeNode`]s found + found_common: bool, + + /// if we are in a conditional branch. A conditional branch means that the [`TreeNode`] + /// might not be executed depending on the runtime values of other [`TreeNode`]s, and + /// thus can not be extracted as a common [`TreeNode`]. + conditional: bool, + + controller: &'a C, +} + +/// Record item that used when traversing a [`TreeNode`] tree. +enum VisitRecord<'n, N> +where + N: NormalizeEq, +{ + /// Marks the beginning of [`TreeNode`]. It contains: + /// - The post-order index assigned during the first, visiting traversal. + EnterMark(usize), + + /// Marks an accumulated subtree. It contains: + /// - The accumulated identifier of a subtree. 
+ /// - A accumulated boolean flag if the subtree is valid for CSE. + /// The flag is propagated up from children to parent. (E.g. volatile expressions + /// are not valid and can't be extracted, but non-volatile children of volatile + /// expressions can be extracted.) + NodeItem(Identifier<'n, N>, bool), +} + +impl<'n, N, C> CSEVisitor<'_, 'n, N, C> +where + N: TreeNode + HashNode + NormalizeEq, + C: CSEController, +{ + /// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before + /// it. Returns a tuple that contains: + /// - The pre-order index of the [`TreeNode`] we marked. + /// - The accumulated identifier of the children of the marked [`TreeNode`]. + /// - An accumulated boolean flag from the children of the marked [`TreeNode`] if all + /// children are valid for CSE (i.e. it is safe to extract the [`TreeNode`] as a + /// common [`TreeNode`] from its children POV). + /// (E.g. if any of the children of the marked expression is not valid (e.g. is + /// volatile) then the expression is also not valid, so we can propagate this + /// information up from children to parents via `visit_stack` during the first, + /// visiting traversal and no need to test the expression's validity beforehand with + /// an extra traversal). + fn pop_enter_mark( + &mut self, + can_normalize: bool, + ) -> (usize, Option>, bool) { + let mut node_ids: Vec> = vec![]; + let mut is_valid = true; + + while let Some(item) = self.visit_stack.pop() { + match item { + VisitRecord::EnterMark(down_index) => { + if can_normalize { + node_ids.sort_by_key(|i| i.hash); + } + let node_id = node_ids + .into_iter() + .fold(None, |accum, item| Some(item.combine(accum))); + return (down_index, node_id, is_valid); + } + VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => { + node_ids.push(sub_node_id); + is_valid &= sub_node_is_valid; + } + } + } + unreachable!("EnterMark should paired with NodeItem"); + } +} + +impl<'n, N, C> TreeNodeVisitor<'n> for CSEVisitor<'_, 'n, N, C> +where + N: TreeNode + HashNode + NormalizeEq, + C: CSEController, +{ + type Node = N; + + fn f_down(&mut self, node: &'n Self::Node) -> Result { + self.id_array.push((0, None)); + self.visit_stack + .push(VisitRecord::EnterMark(self.down_index)); + self.down_index += 1; + + // If a node can short-circuit then some of its children might not be executed so + // count the occurrence either normal or conditional. + Ok(if self.conditional { + // If we are already in a conditionally evaluated subtree then continue + // traversal. + TreeNodeRecursion::Continue + } else { + // If we are already in a node that can short-circuit then start new + // traversals on its normal conditional children. + match C::conditional_children(node) { + Some((normal, conditional)) => { + normal + .into_iter() + .try_for_each(|n| n.visit(self).map(|_| ()))?; + self.conditional = true; + conditional + .into_iter() + .try_for_each(|n| n.visit(self).map(|_| ()))?; + self.conditional = false; + + TreeNodeRecursion::Jump + } + + // In case of non-short-circuit node continue the traversal. 
+ _ => TreeNodeRecursion::Continue, + } + }) + } + + fn f_up(&mut self, node: &'n Self::Node) -> Result { + let (down_index, sub_node_id, sub_node_is_valid) = + self.pop_enter_mark(node.can_normalize()); + + let node_id = Identifier::new(node, self.random_state).combine(sub_node_id); + let is_valid = C::is_valid(node) && sub_node_is_valid; + + self.id_array[down_index].0 = self.up_index; + if is_valid && !self.controller.is_ignored(node) { + self.id_array[down_index].1 = Some(node_id); + self.node_stats + .entry(node_id) + .and_modify(|evaluation| { + if *evaluation == NodeEvaluation::SurelyOnce + || *evaluation == NodeEvaluation::ConditionallyAtLeastOnce + && !self.conditional + { + *evaluation = NodeEvaluation::Common; + self.found_common = true; + } + }) + .or_insert_with(|| { + if self.conditional { + NodeEvaluation::ConditionallyAtLeastOnce + } else { + NodeEvaluation::SurelyOnce + } + }); + } + self.visit_stack + .push(VisitRecord::NodeItem(node_id, is_valid)); + self.up_index += 1; + + Ok(TreeNodeRecursion::Continue) + } +} + +/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the +/// corresponding temporary [`TreeNode`], that column contains the evaluate result of +/// replaced [`TreeNode`] tree. +struct CSERewriter<'a, 'n, N, C> +where + N: NormalizeEq, + C: CSEController, +{ + /// statistics of [`TreeNode`]s + node_stats: &'a NodeStats<'n, N>, + + /// cache to speed up second traversal + id_array: &'a IdArray<'n, N>, + + /// common [`TreeNode`]s, that are replaced during the second traversal, are collected + /// to this map + common_nodes: &'a mut CommonNodes<'n, N>, + + // preorder index, starts from 0. + down_index: usize, + + controller: &'a mut C, +} + +impl TreeNodeRewriter for CSERewriter<'_, '_, N, C> +where + N: TreeNode + NormalizeEq, + C: CSEController, +{ + type Node = N; + + fn f_down(&mut self, node: Self::Node) -> Result> { + self.controller.rewrite_f_down(&node); + + let (up_index, node_id) = self.id_array[self.down_index]; + self.down_index += 1; + + // Handle nodes with identifiers only + if let Some(node_id) = node_id { + let evaluation = self.node_stats.get(&node_id).unwrap(); + if *evaluation == NodeEvaluation::Common { + // step index to skip all sub-node (which has smaller series number). + while self.down_index < self.id_array.len() + && self.id_array[self.down_index].0 < up_index + { + self.down_index += 1; + } + + // We *must* replace all original nodes with same `node_id`, not just the first + // node which is inserted into the common_nodes. This is because nodes with the same + // `node_id` are semantically equivalent, but not exactly the same. + // + // For example, `a + 1` and `1 + a` are semantically equivalent but not identical. + // In this case, we should replace the common expression `1 + a` with a new variable + // (e.g., `__common_cse_1`). So, `a + 1` and `1 + a` would both be replaced by + // `__common_cse_1`. + // + // The final result would be: + // - `__common_cse_1 as a + 1` + // - `__common_cse_1 as 1 + a` + // + // This way, we can efficiently handle semantically equivalent expressions without + // incorrectly treating them as identical. 
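The `entry().and_modify().or_insert_with()` bookkeeping in `f_up` above is a small state machine over `NodeEvaluation`: a node becomes `Common` when it is surely evaluated at least twice, or surely evaluated once and possibly again conditionally. A standalone sketch of just that transition logic (not part of this patch); `observe` is an invented helper name.

```rust
// Standalone sketch of the NodeEvaluation updates performed in `f_up`.

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum NodeEvaluation {
    SurelyOnce,
    ConditionallyAtLeastOnce,
    Common,
}

/// Update the recorded state when the same identifier is observed again.
/// Returns the new state and whether the node just became common.
fn observe(prev: Option<NodeEvaluation>, conditional: bool) -> (NodeEvaluation, bool) {
    use NodeEvaluation::*;
    match prev {
        // First sighting: remember whether it happened in a conditional branch.
        None if conditional => (ConditionallyAtLeastOnce, false),
        None => (SurelyOnce, false),
        // Already surely evaluated once: any further sighting makes it common.
        Some(SurelyOnce) => (Common, true),
        // Only conditional sightings so far: an unconditional one promotes it...
        Some(ConditionallyAtLeastOnce) if !conditional => (Common, true),
        // ...while further conditional sightings do not.
        Some(ConditionallyAtLeastOnce) => (ConditionallyAtLeastOnce, false),
        Some(Common) => (Common, false),
    }
}

fn main() {
    // Two unconditional sightings -> common.
    let (state, _) = observe(None, false);
    assert_eq!(observe(Some(state), false), (NodeEvaluation::Common, true));

    // Conditional sightings alone never promote the node...
    let (state, _) = observe(None, true);
    assert_eq!(
        observe(Some(state), true),
        (NodeEvaluation::ConditionallyAtLeastOnce, false)
    );
    // ...but a later unconditional sighting does.
    assert_eq!(observe(Some(state), false), (NodeEvaluation::Common, true));
}
```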
+ let rewritten = if let Some((_, alias)) = self.common_nodes.get(&node_id) + { + self.controller.rewrite(&node, alias) + } else { + let node_alias = self.controller.generate_alias(); + let rewritten = self.controller.rewrite(&node, &node_alias); + self.common_nodes.insert(node_id, (node, node_alias)); + rewritten + }; + + return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)); + } + } + + Ok(Transformed::no(node)) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + self.controller.rewrite_f_up(&node); + + Ok(Transformed::no(node)) + } +} + +/// The main entry point of Common Subexpression Elimination. +/// +/// [`CSE`] requires a [`CSEController`], that defines how common subtrees of a particular +/// [`TreeNode`] tree can be eliminated. The elimination process can be started with the +/// [`CSE::extract_common_nodes()`] method. +pub struct CSE> { + random_state: RandomState, + phantom_data: PhantomData, + controller: C, +} + +impl CSE +where + N: TreeNode + HashNode + Clone + NormalizeEq, + C: CSEController, +{ + pub fn new(controller: C) -> Self { + Self { + random_state: RandomState::new(), + phantom_data: PhantomData, + controller, + } + } + + /// Add an identifier to `id_array` for every [`TreeNode`] in this tree. + fn node_to_id_array<'n>( + &self, + node: &'n N, + node_stats: &mut NodeStats<'n, N>, + id_array: &mut IdArray<'n, N>, + ) -> Result { + let mut visitor = CSEVisitor { + node_stats, + id_array, + visit_stack: vec![], + down_index: 0, + up_index: 0, + random_state: &self.random_state, + found_common: false, + conditional: false, + controller: &self.controller, + }; + node.visit(&mut visitor)?; + + Ok(visitor.found_common) + } + + /// Returns the identifier list for each element in `nodes` and a flag to indicate if + /// rewrite phase of CSE make sense. + /// + /// Returns and array with 1 element for each input node in `nodes` + /// + /// Each element is itself the result of [`CSE::node_to_id_array`] for that node + /// (e.g. the identifiers for each node in the tree) + fn to_arrays<'n>( + &self, + nodes: &'n [N], + node_stats: &mut NodeStats<'n, N>, + ) -> Result<(bool, Vec>)> { + let mut found_common = false; + nodes + .iter() + .map(|n| { + let mut id_array = vec![]; + self.node_to_id_array(n, node_stats, &mut id_array) + .map(|fc| { + found_common |= fc; + + id_array + }) + }) + .collect::>>() + .map(|id_arrays| (found_common, id_arrays)) + } + + /// Replace common subtrees in `node` with the corresponding temporary + /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`] + fn replace_common_node<'n>( + &mut self, + node: N, + id_array: &IdArray<'n, N>, + node_stats: &NodeStats<'n, N>, + common_nodes: &mut CommonNodes<'n, N>, + ) -> Result { + if id_array.is_empty() { + Ok(Transformed::no(node)) + } else { + node.rewrite(&mut CSERewriter { + node_stats, + id_array, + common_nodes, + down_index: 0, + controller: &mut self.controller, + }) + } + .data() + } + + /// Replace common subtrees in `nodes_list` with the corresponding temporary + /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`]. 
+ fn rewrite_nodes_list<'n>( + &mut self, + nodes_list: Vec>, + arrays_list: &[Vec>], + node_stats: &NodeStats<'n, N>, + common_nodes: &mut CommonNodes<'n, N>, + ) -> Result>> { + nodes_list + .into_iter() + .zip(arrays_list.iter()) + .map(|(nodes, arrays)| { + nodes + .into_iter() + .zip(arrays.iter()) + .map(|(node, id_array)| { + self.replace_common_node(node, id_array, node_stats, common_nodes) + }) + .collect::>>() + }) + .collect::>>() + } + + /// Extracts common [`TreeNode`]s and rewrites `nodes_list`. + /// + /// Returns [`FoundCommonNodes`] recording the result of the extraction. + pub fn extract_common_nodes( + &mut self, + nodes_list: Vec>, + ) -> Result> { + let mut found_common = false; + let mut node_stats = NodeStats::new(); + + let id_arrays_list = nodes_list + .iter() + .map(|nodes| { + self.to_arrays(nodes, &mut node_stats) + .map(|(fc, id_arrays)| { + found_common |= fc; + + id_arrays + }) + }) + .collect::>>()?; + if found_common { + let mut common_nodes = CommonNodes::new(); + let new_nodes_list = self.rewrite_nodes_list( + // Must clone the list of nodes as Identifiers use references to original + // nodes so we have to keep them intact. + nodes_list.clone(), + &id_arrays_list, + &node_stats, + &mut common_nodes, + )?; + assert!(!common_nodes.is_empty()); + + Ok(FoundCommonNodes::Yes { + common_nodes: common_nodes.into_values().collect(), + new_nodes_list, + original_nodes_list: nodes_list, + }) + } else { + Ok(FoundCommonNodes::No { + original_nodes_list: nodes_list, + }) + } + } +} + +#[cfg(test)] +mod test { + use crate::alias::AliasGenerator; + use crate::cse::{ + CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq, + Normalizeable, CSE, + }; + use crate::tree_node::tests::TestTreeNode; + use crate::Result; + use std::collections::HashSet; + use std::hash::{Hash, Hasher}; + + const CSE_PREFIX: &str = "__common_node"; + + #[derive(Clone, Copy)] + pub enum TestTreeNodeMask { + Normal, + NormalAndAggregates, + } + + pub struct TestTreeNodeCSEController<'a> { + alias_generator: &'a AliasGenerator, + mask: TestTreeNodeMask, + } + + impl<'a> TestTreeNodeCSEController<'a> { + fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> Self { + Self { + alias_generator, + mask, + } + } + } + + impl CSEController for TestTreeNodeCSEController<'_> { + type Node = TestTreeNode; + + fn conditional_children( + _: &Self::Node, + ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> { + None + } + + fn is_valid(_node: &Self::Node) -> bool { + true + } + + fn is_ignored(&self, node: &Self::Node) -> bool { + let is_leaf = node.is_leaf(); + let is_aggr = node.data == "avg" || node.data == "sum"; + + match self.mask { + TestTreeNodeMask::Normal => is_leaf || is_aggr, + TestTreeNodeMask::NormalAndAggregates => is_leaf, + } + } + + fn generate_alias(&self) -> String { + self.alias_generator.next(CSE_PREFIX) + } + + fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias)) + } + } + + impl HashNode for TestTreeNode { + fn hash_node(&self, state: &mut H) { + self.data.hash(state); + } + } + + impl Normalizeable for TestTreeNode { + fn can_normalize(&self) -> bool { + false + } + } + + impl NormalizeEq for TestTreeNode { + fn normalize_eq(&self, other: &Self) -> bool { + self == other + } + } + + #[test] + fn id_array_visitor() -> Result<()> { + let alias_generator = AliasGenerator::new(); + let eliminator = CSE::new(TestTreeNodeCSEController::new( + &alias_generator, + 
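A hedged sketch of how a caller might drive `CSE::extract_common_nodes` and consume `FoundCommonNodes`, assuming the module is exported as `datafusion_common::cse` and that a suitable `CSEController` (for example one like the test controller below) already exists; the bounds mirror the `impl` block above.

```rust
use datafusion_common::cse::{CSEController, FoundCommonNodes, HashNode, NormalizeEq, CSE};
use datafusion_common::tree_node::TreeNode;
use datafusion_common::Result;

/// Run CSE over a list of node lists and return the (possibly rewritten) lists.
fn eliminate<N, C>(mut cse: CSE<N, C>, nodes_list: Vec<Vec<N>>) -> Result<Vec<Vec<N>>>
where
    N: TreeNode + HashNode + Clone + NormalizeEq + std::fmt::Debug,
    C: CSEController<Node = N>,
{
    match cse.extract_common_nodes(nodes_list)? {
        // Nothing worth extracting: hand the original lists back unchanged.
        FoundCommonNodes::No { original_nodes_list } => Ok(original_nodes_list),
        // Common subtrees were found: each one now has an alias, and the
        // rewritten lists refer to the aliases instead of the duplicated trees.
        FoundCommonNodes::Yes {
            common_nodes,
            new_nodes_list,
            original_nodes_list: _,
        } => {
            for (node, alias) in &common_nodes {
                println!("extracted {node:?} as {alias}");
            }
            Ok(new_nodes_list)
        }
    }
}
```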
TestTreeNodeMask::Normal, + )); + + let a_plus_1 = TestTreeNode::new( + vec![ + TestTreeNode::new_leaf("a".to_string()), + TestTreeNode::new_leaf("1".to_string()), + ], + "+".to_string(), + ); + let avg_c = TestTreeNode::new( + vec![TestTreeNode::new_leaf("c".to_string())], + "avg".to_string(), + ); + let sum_a_plus_1 = TestTreeNode::new(vec![a_plus_1], "sum".to_string()); + let sum_a_plus_1_minus_avg_c = + TestTreeNode::new(vec![sum_a_plus_1, avg_c], "-".to_string()); + let root = TestTreeNode::new( + vec![ + sum_a_plus_1_minus_avg_c, + TestTreeNode::new_leaf("2".to_string()), + ], + "*".to_string(), + ); + + let [sum_a_plus_1_minus_avg_c, _] = root.children.as_slice() else { + panic!("Cannot extract subtree references") + }; + let [sum_a_plus_1, avg_c] = sum_a_plus_1_minus_avg_c.children.as_slice() else { + panic!("Cannot extract subtree references") + }; + let [a_plus_1] = sum_a_plus_1.children.as_slice() else { + panic!("Cannot extract subtree references") + }; + + // skip aggregates + let mut id_array = vec![]; + eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?; + + // Collect distinct hashes and set them to 0 in `id_array` + fn collect_hashes( + id_array: &mut IdArray<'_, TestTreeNode>, + ) -> HashSet { + id_array + .iter_mut() + .flat_map(|(_, id_option)| { + id_option.as_mut().map(|node_id| { + let hash = node_id.hash; + node_id.hash = 0; + hash + }) + }) + .collect::>() + } + + let hashes = collect_hashes(&mut id_array); + assert_eq!(hashes.len(), 3); + + let expected = vec![ + ( + 8, + Some(Identifier { + hash: 0, + node: &root, + }), + ), + ( + 6, + Some(Identifier { + hash: 0, + node: sum_a_plus_1_minus_avg_c, + }), + ), + (3, None), + ( + 2, + Some(Identifier { + hash: 0, + node: a_plus_1, + }), + ), + (0, None), + (1, None), + (5, None), + (4, None), + (7, None), + ]; + assert_eq!(expected, id_array); + + // include aggregates + let eliminator = CSE::new(TestTreeNodeCSEController::new( + &alias_generator, + TestTreeNodeMask::NormalAndAggregates, + )); + + let mut id_array = vec![]; + eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?; + + let hashes = collect_hashes(&mut id_array); + assert_eq!(hashes.len(), 5); + + let expected = vec![ + ( + 8, + Some(Identifier { + hash: 0, + node: &root, + }), + ), + ( + 6, + Some(Identifier { + hash: 0, + node: sum_a_plus_1_minus_avg_c, + }), + ), + ( + 3, + Some(Identifier { + hash: 0, + node: sum_a_plus_1, + }), + ), + ( + 2, + Some(Identifier { + hash: 0, + node: a_plus_1, + }), + ), + (0, None), + (1, None), + ( + 5, + Some(Identifier { + hash: 0, + node: avg_c, + }), + ), + (4, None), + (7, None), + ]; + assert_eq!(expected, id_array); + + Ok(()) + } +} diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 69cdf866cf981..302d515e027ee 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -315,7 +315,6 @@ impl DFSchema { None => self_unqualified_names.contains(field.name().as_str()), }; if !duplicated_field { - // self.inner.fields.push(field.clone()); schema_builder.push(Arc::clone(field)); qualifiers.push(qualifier.cloned()); } @@ -406,33 +405,6 @@ impl DFSchema { } } - /// Check whether the column reference is ambiguous - pub fn check_ambiguous_name( - &self, - qualifier: Option<&TableReference>, - name: &str, - ) -> Result<()> { - let count = self - .iter() - .filter(|(field_q, f)| match (field_q, qualifier) { - (Some(q1), Some(q2)) => q1.resolved_eq(q2) && f.name() == name, - (None, None) => f.name() == name, - _ 
=> false, - }) - .take(2) - .count(); - if count > 1 { - _schema_err!(SchemaError::AmbiguousReference { - field: Column { - relation: None, - name: name.to_string(), - }, - }) - } else { - Ok(()) - } - } - /// Find the qualified field with the given name pub fn qualified_field_with_name( &self, @@ -684,9 +656,26 @@ impl DFSchema { (othertype, DataType::Dictionary(_, v1)) => v1.as_ref() == othertype, (DataType::List(f1), DataType::List(f2)) | (DataType::LargeList(f1), DataType::LargeList(f2)) - | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) - | (DataType::Map(f1, _), DataType::Map(f2, _)) => { - Self::field_is_logically_equal(f1, f2) + | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) => { + // Don't compare the names of the technical inner field + // Usually "item" but that's not mandated + Self::datatype_is_logically_equal(f1.data_type(), f2.data_type()) + } + (DataType::Map(f1, _), DataType::Map(f2, _)) => { + // Don't compare the names of the technical inner fields + // Usually "entries", "key", "value" but that's not mandated + match (f1.data_type(), f2.data_type()) { + (DataType::Struct(f1_inner), DataType::Struct(f2_inner)) => { + f1_inner.len() == f2_inner.len() + && f1_inner.iter().zip(f2_inner.iter()).all(|(f1, f2)| { + Self::datatype_is_logically_equal( + f1.data_type(), + f2.data_type(), + ) + }) + } + _ => panic!("Map type should have an inner struct field"), + } } (DataType::Struct(fields1), DataType::Struct(fields2)) => { let iter1 = fields1.iter(); @@ -714,7 +703,7 @@ impl DFSchema { /// name and type), ignoring both metadata and nullability. /// /// request to upstream: - fn datatype_is_semantically_equal(dt1: &DataType, dt2: &DataType) -> bool { + pub fn datatype_is_semantically_equal(dt1: &DataType, dt2: &DataType) -> bool { // check nested fields match (dt1, dt2) { (DataType::Dictionary(k1, v1), DataType::Dictionary(k2, v2)) => { @@ -723,9 +712,26 @@ impl DFSchema { } (DataType::List(f1), DataType::List(f2)) | (DataType::LargeList(f1), DataType::LargeList(f2)) - | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) - | (DataType::Map(f1, _), DataType::Map(f2, _)) => { - Self::field_is_semantically_equal(f1, f2) + | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) => { + // Don't compare the names of the technical inner field + // Usually "item" but that's not mandated + Self::datatype_is_semantically_equal(f1.data_type(), f2.data_type()) + } + (DataType::Map(f1, _), DataType::Map(f2, _)) => { + // Don't compare the names of the technical inner fields + // Usually "entries", "key", "value" but that's not mandated + match (f1.data_type(), f2.data_type()) { + (DataType::Struct(f1_inner), DataType::Struct(f2_inner)) => { + f1_inner.len() == f2_inner.len() + && f1_inner.iter().zip(f2_inner.iter()).all(|(f1, f2)| { + Self::datatype_is_semantically_equal( + f1.data_type(), + f2.data_type(), + ) + }) + } + _ => panic!("Map type should have an inner struct field"), + } } (DataType::Struct(fields1), DataType::Struct(fields2)) => { let iter1 = fields1.iter(); @@ -949,7 +955,7 @@ pub trait ExprSchema: std::fmt::Debug { /// Returns the column's optional metadata. 
fn metadata(&self, col: &Column) -> Result<&HashMap>; - /// Return the coulmn's datatype and nullability + /// Return the column's datatype and nullability fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)>; } @@ -1062,10 +1068,12 @@ mod tests { let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; // lookup with unqualified name "t1.c0" let err = schema.index_of_column(&col).unwrap_err(); - assert_eq!( - err.strip_backtrace(), - "Schema error: No field named \"t1.c0\". Valid fields are t1.c0, t1.c1." - ); + let expected = "Schema error: No field named \"t1.c0\". \ + Column names are case sensitive. \ + You can use double quotes to refer to the \"\"t1.c0\"\" column \ + or set the datafusion.sql_parser.enable_ident_normalization configuration. \ + Valid fields are t1.c0, t1.c1."; + assert_eq!(err.strip_backtrace(), expected); Ok(()) } @@ -1082,10 +1090,9 @@ mod tests { // lookup with unqualified name "t1.c0" let err = schema.index_of_column(&col).unwrap_err(); - assert_eq!( - err.strip_backtrace(), - "Schema error: No field named \"t1.c0\". Valid fields are t1.\"CapitalColumn\", t1.\"field.with.period\"." - ); + let expected = "Schema error: No field named \"t1.c0\". \ + Valid fields are t1.\"CapitalColumn\", t1.\"field.with.period\"."; + assert_eq!(err.strip_backtrace(), expected); Ok(()) } @@ -1255,12 +1262,14 @@ mod tests { let col = Column::from_qualified_name("t1.c0"); let err = schema.index_of_column(&col).unwrap_err(); - assert_eq!(err.strip_backtrace(), "Schema error: No field named t1.c0."); + let expected = "Schema error: No field named t1.c0."; + assert_eq!(err.strip_backtrace(), expected); // the same check without qualifier let col = Column::from_name("c0"); let err = schema.index_of_column(&col).err().unwrap(); - assert_eq!(err.strip_backtrace(), "Schema error: No field named c0."); + let expected = "Schema error: No field named c0."; + assert_eq!(err.strip_backtrace(), expected); } #[test] @@ -1360,6 +1369,286 @@ mod tests { Ok(()) } + #[test] + fn test_datatype_is_logically_equal() { + assert!(DFSchema::datatype_is_logically_equal( + &DataType::Int8, + &DataType::Int8 + )); + + assert!(!DFSchema::datatype_is_logically_equal( + &DataType::Int8, + &DataType::Int16 + )); + + // Test lists + + // Succeeds if both have the same element type, disregards names and nullability + assert!(DFSchema::datatype_is_logically_equal( + &DataType::List(Field::new_list_field(DataType::Int8, true).into()), + &DataType::List(Field::new("element", DataType::Int8, false).into()) + )); + + // Fails if element type is different + assert!(!DFSchema::datatype_is_logically_equal( + &DataType::List(Field::new_list_field(DataType::Int8, true).into()), + &DataType::List(Field::new_list_field(DataType::Int16, true).into()) + )); + + // Test maps + let map_field = DataType::Map( + Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int8, false), + Field::new("value", DataType::Int8, true), + ])), + true, + ) + .into(), + true, + ); + + // Succeeds if both maps have the same key and value types, disregards names and nullability + assert!(DFSchema::datatype_is_logically_equal( + &map_field, + &DataType::Map( + Field::new( + "pairs", + DataType::Struct(Fields::from(vec![ + Field::new("one", DataType::Int8, false), + Field::new("two", DataType::Int8, false) + ])), + true + ) + .into(), + true + ) + )); + // Fails if value type is different + assert!(!DFSchema::datatype_is_logically_equal( + &map_field, + &DataType::Map( + 
Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int8, false), + Field::new("value", DataType::Int16, true) + ])), + true + ) + .into(), + true + ) + )); + + // Fails if key type is different + assert!(!DFSchema::datatype_is_logically_equal( + &map_field, + &DataType::Map( + Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int16, false), + Field::new("value", DataType::Int8, true) + ])), + true + ) + .into(), + true + ) + )); + + // Test structs + + let struct_field = DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Int8, true), + Field::new("b", DataType::Int8, true), + ])); + + // Succeeds if both have same names and datatypes, ignores nullability + assert!(DFSchema::datatype_is_logically_equal( + &struct_field, + &DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Int8, false), + Field::new("b", DataType::Int8, true), + ])) + )); + + // Fails if field names are different + assert!(!DFSchema::datatype_is_logically_equal( + &struct_field, + &DataType::Struct(Fields::from(vec![ + Field::new("x", DataType::Int8, true), + Field::new("y", DataType::Int8, true), + ])) + )); + + // Fails if types are different + assert!(!DFSchema::datatype_is_logically_equal( + &struct_field, + &DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Int16, true), + Field::new("b", DataType::Int8, true), + ])) + )); + + // Fails if more or less fields + assert!(!DFSchema::datatype_is_logically_equal( + &struct_field, + &DataType::Struct(Fields::from(vec![Field::new("a", DataType::Int8, true),])) + )); + } + + #[test] + fn test_datatype_is_logically_equivalent_to_dictionary() { + // Dictionary is logically equal to its value type + assert!(DFSchema::datatype_is_logically_equal( + &DataType::Utf8, + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + )); + } + + #[test] + fn test_datatype_is_semantically_equal() { + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Int8, + &DataType::Int8 + )); + + assert!(!DFSchema::datatype_is_semantically_equal( + &DataType::Int8, + &DataType::Int16 + )); + + // Test lists + + // Succeeds if both have the same element type, disregards names and nullability + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::List(Field::new_list_field(DataType::Int8, true).into()), + &DataType::List(Field::new("element", DataType::Int8, false).into()) + )); + + // Fails if element type is different + assert!(!DFSchema::datatype_is_semantically_equal( + &DataType::List(Field::new_list_field(DataType::Int8, true).into()), + &DataType::List(Field::new_list_field(DataType::Int16, true).into()) + )); + + // Test maps + let map_field = DataType::Map( + Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int8, false), + Field::new("value", DataType::Int8, true), + ])), + true, + ) + .into(), + true, + ); + + // Succeeds if both maps have the same key and value types, disregards names and nullability + assert!(DFSchema::datatype_is_semantically_equal( + &map_field, + &DataType::Map( + Field::new( + "pairs", + DataType::Struct(Fields::from(vec![ + Field::new("one", DataType::Int8, false), + Field::new("two", DataType::Int8, false) + ])), + true + ) + .into(), + true + ) + )); + // Fails if value type is different + assert!(!DFSchema::datatype_is_semantically_equal( + &map_field, + &DataType::Map( + Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", 
DataType::Int8, false), + Field::new("value", DataType::Int16, true) + ])), + true + ) + .into(), + true + ) + )); + + // Fails if key type is different + assert!(!DFSchema::datatype_is_semantically_equal( + &map_field, + &DataType::Map( + Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int16, false), + Field::new("value", DataType::Int8, true) + ])), + true + ) + .into(), + true + ) + )); + + // Test structs + + let struct_field = DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Int8, true), + Field::new("b", DataType::Int8, true), + ])); + + // Succeeds if both have same names and datatypes, ignores nullability + assert!(DFSchema::datatype_is_logically_equal( + &struct_field, + &DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Int8, false), + Field::new("b", DataType::Int8, true), + ])) + )); + + // Fails if field names are different + assert!(!DFSchema::datatype_is_logically_equal( + &struct_field, + &DataType::Struct(Fields::from(vec![ + Field::new("x", DataType::Int8, true), + Field::new("y", DataType::Int8, true), + ])) + )); + + // Fails if types are different + assert!(!DFSchema::datatype_is_logically_equal( + &struct_field, + &DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Int16, true), + Field::new("b", DataType::Int8, true), + ])) + )); + + // Fails if more or less fields + assert!(!DFSchema::datatype_is_logically_equal( + &struct_field, + &DataType::Struct(Fields::from(vec![Field::new("a", DataType::Int8, true),])) + )); + } + + #[test] + fn test_datatype_is_not_semantically_equivalent_to_dictionary() { + // Dictionary is not semantically equal to its value type + assert!(!DFSchema::datatype_is_semantically_equal( + &DataType::Utf8, + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + )); + } + fn test_schema_2() -> Schema { Schema::new(vec![ Field::new("c100", DataType::Boolean, true), diff --git a/datafusion/common/src/display/mod.rs b/datafusion/common/src/display/mod.rs index c12e7419e4b6b..bad51c45f8ee8 100644 --- a/datafusion/common/src/display/mod.rs +++ b/datafusion/common/src/display/mod.rs @@ -62,6 +62,8 @@ pub enum PlanType { FinalPhysicalPlanWithStats, /// The final with schema, fully optimized physical plan which would be executed FinalPhysicalPlanWithSchema, + /// An error creating the physical plan + PhysicalPlanError, } impl Display for PlanType { @@ -91,6 +93,7 @@ impl Display for PlanType { PlanType::FinalPhysicalPlanWithSchema => { write!(f, "physical_plan_with_schema") } + PlanType::PhysicalPlanError => write!(f, "physical_plan_error"), } } } @@ -118,7 +121,9 @@ impl StringifiedPlan { /// `verbose_mode = true` will display all available plans pub fn should_display(&self, verbose_mode: bool) -> bool { match self.plan_type { - PlanType::FinalLogicalPlan | PlanType::FinalPhysicalPlan => true, + PlanType::FinalLogicalPlan + | PlanType::FinalPhysicalPlan + | PlanType::PhysicalPlanError => true, _ => verbose_mode, } } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 05988d6c6da4c..f7c247aaf288c 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -115,7 +115,7 @@ pub enum DataFusionError { Execution(String), /// [`JoinError`] during execution of the query. /// - /// This error can unoccur for unjoined tasks, such as execution shutdown. + /// This error can't occur for unjoined tasks, such as execution shutdown. 
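Relating back to the `display` hunk above: the new `PhysicalPlanError` plan type is shown even without `VERBOSE`, per the `should_display` change. A hedged sketch; `StringifiedPlan::new` and the `datafusion_common::display` path are assumptions, only the variant and `should_display` appear in this diff.

```rust
use datafusion_common::display::{PlanType, StringifiedPlan};

fn demo() {
    // `StringifiedPlan::new` is assumed to keep its usual constructor shape.
    let plan = StringifiedPlan::new(
        PlanType::PhysicalPlanError,
        "Error during planning: ...",
    );
    // Planning failures are now displayed even when EXPLAIN is not VERBOSE,
    // alongside the final logical and physical plans.
    assert!(plan.should_display(false));
}
```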
ExecutionJoin(JoinError), /// Error when resources (such as memory of scratch disk space) are exhausted. /// @@ -167,6 +167,18 @@ impl Display for SchemaError { valid_fields, } => { write!(f, "No field named {}", field.quoted_flat_name())?; + let lower_valid_fields = valid_fields + .iter() + .map(|column| column.flat_name().to_lowercase()) + .collect::>(); + if lower_valid_fields.contains(&field.flat_name().to_lowercase()) { + write!( + f, + ". Column names are case sensitive. You can use double quotes to refer to the \"{}\" column \ + or set the datafusion.sql_parser.enable_ident_normalization configuration", + field.quoted_flat_name() + )?; + } if !valid_fields.is_empty() { write!( f, @@ -598,9 +610,9 @@ macro_rules! arrow_err { #[macro_export] macro_rules! schema_datafusion_err { ($ERR:expr) => { - DataFusionError::SchemaError( + $crate::error::DataFusionError::SchemaError( $ERR, - Box::new(Some(DataFusionError::get_back_trace())), + Box::new(Some($crate::error::DataFusionError::get_back_trace())), ) }; } @@ -609,9 +621,9 @@ macro_rules! schema_datafusion_err { #[macro_export] macro_rules! schema_err { ($ERR:expr) => { - Err(DataFusionError::SchemaError( + Err($crate::error::DataFusionError::SchemaError( $ERR, - Box::new(Some(DataFusionError::get_back_trace())), + Box::new(Some($crate::error::DataFusionError::get_back_trace())), )) }; } diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index 77781457d0d2d..02667e0165717 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -30,7 +30,6 @@ pub mod parquet_writer; mod tests { use std::collections::HashMap; - use super::parquet_writer::ParquetWriterOptions; use crate::{ config::{ConfigFileType, TableOptions}, file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, @@ -40,7 +39,7 @@ mod tests { use parquet::{ basic::{Compression, Encoding, ZstdLevel}, - file::properties::{EnabledStatistics, WriterVersion}, + file::properties::{EnabledStatistics, WriterPropertiesBuilder, WriterVersion}, schema::types::ColumnPath, }; @@ -79,8 +78,10 @@ mod tests { table_config.set_config_format(ConfigFileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; - let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; - let properties = parquet_options.writer_options(); + let properties = WriterPropertiesBuilder::try_from( + &table_config.parquet.with_skip_arrow_metadata(true), + )? + .build(); // Verify the expected options propagated down to parquet crate WriterProperties struct assert_eq!(properties.max_row_group_size(), 123); @@ -184,8 +185,10 @@ mod tests { table_config.set_config_format(ConfigFileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; - let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; - let properties = parquet_options.writer_options(); + let properties = WriterPropertiesBuilder::try_from( + &table_config.parquet.with_skip_arrow_metadata(true), + )? + .build(); let col1 = ColumnPath::from(vec!["col1".to_owned()]); let col2_nested = ColumnPath::from(vec!["col2".to_owned(), "nested".to_owned()]); diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 5d553d59da4ec..3f06e11bb3767 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -17,18 +17,26 @@ //! 
Options related to how parquet files should be written +use base64::Engine; +use std::sync::Arc; + use crate::{ config::{ParquetOptions, TableParquetOptions}, - DataFusionError, Result, + DataFusionError, Result, _internal_datafusion_err, }; +use arrow_schema::Schema; +#[allow(deprecated)] use parquet::{ + arrow::ARROW_SCHEMA_META_KEY, basic::{BrotliLevel, GzipLevel, ZstdLevel}, - file::properties::{ - EnabledStatistics, WriterProperties, WriterPropertiesBuilder, WriterVersion, - DEFAULT_MAX_STATISTICS_SIZE, DEFAULT_STATISTICS_ENABLED, + file::{ + metadata::KeyValue, + properties::{ + EnabledStatistics, WriterProperties, WriterPropertiesBuilder, WriterVersion, + DEFAULT_MAX_STATISTICS_SIZE, DEFAULT_STATISTICS_ENABLED, + }, }, - format::KeyValue, schema::types::ColumnPath, }; @@ -51,6 +59,17 @@ impl ParquetWriterOptions { } } +impl TableParquetOptions { + /// Add the arrow schema to the parquet kv_metadata. + /// If already exists, then overwrites. + pub fn arrow_schema(&mut self, schema: &Arc) { + self.key_value_metadata.insert( + ARROW_SCHEMA_META_KEY.into(), + Some(encode_arrow_schema(schema)), + ); + } +} + impl TryFrom<&TableParquetOptions> for ParquetWriterOptions { type Error = DataFusionError; @@ -79,6 +98,14 @@ impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { let mut builder = global.into_writer_properties_builder()?; + // check that the arrow schema is present in the kv_metadata, if configured to do so + if !global.skip_arrow_metadata + && !key_value_metadata.contains_key(ARROW_SCHEMA_META_KEY) + { + return Err(_internal_datafusion_err!("arrow schema was not added to the kv_metadata, even though it is required by configuration settings")); + } + + // add kv_meta, if any if !key_value_metadata.is_empty() { builder = builder.set_key_value_metadata(Some( key_value_metadata @@ -131,8 +158,10 @@ impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { } if let Some(max_statistics_size) = options.max_statistics_size { - builder = - builder.set_column_max_statistics_size(path, max_statistics_size); + builder = { + #[allow(deprecated)] + builder.set_column_max_statistics_size(path, max_statistics_size) + } } } @@ -140,11 +169,38 @@ impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { } } +/// Encodes the Arrow schema into the IPC format, and base64 encodes it +/// +/// TODO: use extern parquet's private method, once publicly available. +/// Refer to +fn encode_arrow_schema(schema: &Arc) -> String { + let options = arrow_ipc::writer::IpcWriteOptions::default(); + let mut dictionary_tracker = arrow_ipc::writer::DictionaryTracker::new(true); + let data_gen = arrow_ipc::writer::IpcDataGenerator::default(); + let mut serialized_schema = data_gen.schema_to_bytes_with_dictionary_tracker( + schema, + &mut dictionary_tracker, + &options, + ); + + // manually prepending the length to the schema as arrow uses the legacy IPC format + // TODO: change after addressing ARROW-9777 + let schema_len = serialized_schema.ipc_message.len(); + let mut len_prefix_schema = Vec::with_capacity(schema_len + 8); + len_prefix_schema.append(&mut vec![255u8, 255, 255, 255]); + len_prefix_schema.append((schema_len as u32).to_le_bytes().to_vec().as_mut()); + len_prefix_schema.append(&mut serialized_schema.ipc_message); + + base64::prelude::BASE64_STANDARD.encode(&len_prefix_schema) +} + impl ParquetOptions { /// Convert the global session options, [`ParquetOptions`], into a single write action's [`WriterPropertiesBuilder`]. 
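For callers, the practical effect of the new `skip_arrow_metadata` check is that the Arrow schema must be attached before converting the options into writer properties. A hedged sketch of that flow, assuming the crate paths shown in this hunk (`datafusion_common::config::TableParquetOptions`, `parquet::file::properties::WriterPropertiesBuilder`, `arrow_schema::Schema`):

```rust
use std::sync::Arc;

use arrow_schema::Schema;
use datafusion_common::config::TableParquetOptions;
use parquet::file::properties::WriterPropertiesBuilder;

fn build_props() -> datafusion_common::Result<()> {
    let mut opts = TableParquetOptions::default();

    // Embed the Arrow schema in the parquet key/value metadata; without this
    // (or `skip_arrow_metadata = true`) the conversion below now errors.
    opts.arrow_schema(&Arc::new(Schema::empty()));

    let _props = WriterPropertiesBuilder::try_from(&opts)?.build();
    Ok(())
}
```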
/// /// The returned [`WriterPropertiesBuilder`] can then be further modified with additional options /// applied per column; a customization which is not applicable for [`ParquetOptions`]. + /// + /// Note that this method does not include the key_value_metadata from [`TableParquetOptions`]. pub fn into_writer_properties_builder(&self) -> Result { let ParquetOptions { data_pagesize_limit, @@ -176,6 +232,8 @@ impl ParquetOptions { maximum_buffered_record_batches_per_stream: _, bloom_filter_on_read: _, // reads not used for writer props schema_force_view_types: _, + binary_as_string: _, // not used for writer props + skip_arrow_metadata: _, } = self; let mut builder = WriterProperties::builder() @@ -189,15 +247,19 @@ impl ParquetOptions { .and_then(|s| parse_statistics_string(s).ok()) .unwrap_or(DEFAULT_STATISTICS_ENABLED), ) - .set_max_statistics_size( - max_statistics_size.unwrap_or(DEFAULT_MAX_STATISTICS_SIZE), - ) .set_max_row_group_size(*max_row_group_size) .set_created_by(created_by.clone()) .set_column_index_truncate_length(*column_index_truncate_length) .set_data_page_row_count_limit(*data_page_row_count_limit) .set_bloom_filter_enabled(*bloom_filter_on_write); + builder = { + #[allow(deprecated)] + builder.set_max_statistics_size( + max_statistics_size.unwrap_or(DEFAULT_MAX_STATISTICS_SIZE), + ) + }; + if let Some(bloom_filter_fpp) = bloom_filter_fpp { builder = builder.set_bloom_filter_fpp(*bloom_filter_fpp); }; @@ -442,6 +504,8 @@ mod tests { .maximum_buffered_record_batches_per_stream, bloom_filter_on_read: defaults.bloom_filter_on_read, schema_force_view_types: defaults.schema_force_view_types, + binary_as_string: defaults.binary_as_string, + skip_arrow_metadata: defaults.skip_arrow_metadata, } } @@ -471,6 +535,7 @@ mod tests { ), bloom_filter_fpp: bloom_filter_default_props.map(|p| p.fpp), bloom_filter_ndv: bloom_filter_default_props.map(|p| p.ndv), + #[allow(deprecated)] max_statistics_size: Some(props.max_statistics_size(&col)), } } @@ -543,19 +608,56 @@ mod tests { .maximum_buffered_record_batches_per_stream, bloom_filter_on_read: global_options_defaults.bloom_filter_on_read, schema_force_view_types: global_options_defaults.schema_force_view_types, + binary_as_string: global_options_defaults.binary_as_string, + skip_arrow_metadata: global_options_defaults.skip_arrow_metadata, }, column_specific_options, key_value_metadata, } } + #[test] + fn table_parquet_opts_to_writer_props_skip_arrow_metadata() { + // TableParquetOptions, all props set to default + let mut table_parquet_opts = TableParquetOptions::default(); + assert!( + !table_parquet_opts.global.skip_arrow_metadata, + "default false, to not skip the arrow schema requirement" + ); + + // see errors without the schema added, using default settings + let should_error = WriterPropertiesBuilder::try_from(&table_parquet_opts); + assert!( + should_error.is_err(), + "should error without the required arrow schema in kv_metadata", + ); + + // succeeds if we permit skipping the arrow schema + table_parquet_opts = table_parquet_opts.with_skip_arrow_metadata(true); + let should_succeed = WriterPropertiesBuilder::try_from(&table_parquet_opts); + assert!( + should_succeed.is_ok(), + "should work with the arrow schema skipped by config", + ); + + // Set the arrow schema back to required + table_parquet_opts = table_parquet_opts.with_skip_arrow_metadata(false); + // add the arrow schema to the kv_meta + table_parquet_opts.arrow_schema(&Arc::new(Schema::empty())); + let should_succeed = 
WriterPropertiesBuilder::try_from(&table_parquet_opts); + assert!( + should_succeed.is_ok(), + "should work with the arrow schema included in TableParquetOptions", + ); + } + #[test] fn table_parquet_opts_to_writer_props() { // ParquetOptions, all props set to non-default let parquet_options = parquet_options_with_non_defaults(); // TableParquetOptions, using ParquetOptions for global settings - let key = "foo".to_string(); + let key = ARROW_SCHEMA_META_KEY.to_string(); let value = Some("bar".into()); let table_parquet_opts = TableParquetOptions { global: parquet_options.clone(), @@ -582,7 +684,7 @@ mod tests { #[test] fn test_defaults_match() { // ensure the global settings are the same - let default_table_writer_opts = TableParquetOptions::default(); + let mut default_table_writer_opts = TableParquetOptions::default(); let default_parquet_opts = ParquetOptions::default(); assert_eq!( default_table_writer_opts.global, @@ -590,6 +692,10 @@ mod tests { "should have matching defaults for TableParquetOptions.global and ParquetOptions", ); + // selectively skip the arrow_schema metadata, since the WriterProperties default has an empty kv_meta (no arrow schema) + default_table_writer_opts = + default_table_writer_opts.with_skip_arrow_metadata(true); + // WriterProperties::default, a.k.a. using extern parquet's defaults let default_writer_props = WriterProperties::new(); @@ -637,6 +743,7 @@ mod tests { session_config_from_writer_props(&default_writer_props); from_extern_parquet.global.created_by = same_created_by; from_extern_parquet.global.compression = Some("zstd(3)".into()); + from_extern_parquet.global.skip_arrow_metadata = true; assert_eq!( default_table_writer_opts, @@ -650,6 +757,7 @@ mod tests { // the TableParquetOptions::default, with only the bloom filter turned on let mut default_table_writer_opts = TableParquetOptions::default(); default_table_writer_opts.global.bloom_filter_on_write = true; + default_table_writer_opts.arrow_schema(&Arc::new(Schema::empty())); // add the required arrow schema let from_datafusion_defaults = WriterPropertiesBuilder::try_from(&default_table_writer_opts) .unwrap() @@ -678,6 +786,7 @@ mod tests { let mut default_table_writer_opts = TableParquetOptions::default(); default_table_writer_opts.global.bloom_filter_on_write = true; default_table_writer_opts.global.bloom_filter_fpp = Some(0.42); + default_table_writer_opts.arrow_schema(&Arc::new(Schema::empty())); // add the required arrow schema let from_datafusion_defaults = WriterPropertiesBuilder::try_from(&default_table_writer_opts) .unwrap() @@ -710,6 +819,7 @@ mod tests { let mut default_table_writer_opts = TableParquetOptions::default(); default_table_writer_opts.global.bloom_filter_on_write = true; default_table_writer_opts.global.bloom_filter_ndv = Some(42); + default_table_writer_opts.arrow_schema(&Arc::new(Schema::empty())); // add the required arrow schema let from_datafusion_defaults = WriterPropertiesBuilder::try_from(&default_table_writer_opts) .unwrap() diff --git a/datafusion/common/src/format.rs b/datafusion/common/src/format.rs index 484a7f2388f56..23cfb72314a3c 100644 --- a/datafusion/common/src/format.rs +++ b/datafusion/common/src/format.rs @@ -27,3 +27,7 @@ pub const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { safe: false, format_options: DEFAULT_FORMAT_OPTIONS, }; + +pub const DEFAULT_CLI_FORMAT_OPTIONS: FormatOptions<'static> = FormatOptions::new() + .with_duration_format(DurationFormat::Pretty) + .with_null("NULL"); diff --git 
a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 90f4e6e7e3d1e..5f262d634af37 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -18,16 +18,12 @@ //! FunctionalDependencies keeps track of functional dependencies //! inside DFSchema. -use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::ops::Deref; use std::vec::IntoIter; -use crate::error::_plan_err; use crate::utils::{merge_and_order_indices, set_difference}; -use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result}; - -use sqlparser::ast::TableConstraint; +use crate::{DFSchema, HashSet, JoinType}; /// This object defines a constraint on a table. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -60,77 +56,44 @@ impl Constraints { Self { inner: constraints } } - /// Convert each `TableConstraint` to corresponding `Constraint` - pub fn new_from_table_constraints( - constraints: &[TableConstraint], - df_schema: &DFSchemaRef, - ) -> Result { - let constraints = constraints + /// Check whether constraints is empty + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Projects constraints using the given projection indices. + /// Returns None if any of the constraint columns are not included in the projection. + pub fn project(&self, proj_indices: &[usize]) -> Option { + let projected = self + .inner .iter() - .map(|c: &TableConstraint| match c { - TableConstraint::Unique { name, columns, .. } => { - let field_names = df_schema.field_names(); - // Get unique constraint indices in the schema: - let indices = columns - .iter() - .map(|u| { - let idx = field_names - .iter() - .position(|item| *item == u.value) - .ok_or_else(|| { - let name = name - .as_ref() - .map(|name| format!("with name '{name}' ")) - .unwrap_or("".to_string()); - DataFusionError::Execution( - format!("Column for unique constraint {}not found in schema: {}", name,u.value) - ) - })?; - Ok(idx) - }) - .collect::>>()?; - Ok(Constraint::Unique(indices)) - } - TableConstraint::PrimaryKey { columns, .. } => { - let field_names = df_schema.field_names(); - // Get primary key indices in the schema: - let indices = columns - .iter() - .map(|pk| { - let idx = field_names - .iter() - .position(|item| *item == pk.value) - .ok_or_else(|| { - DataFusionError::Execution(format!( - "Column for primary key not found in schema: {}", - pk.value - )) - })?; - Ok(idx) - }) - .collect::>>()?; - Ok(Constraint::PrimaryKey(indices)) - } - TableConstraint::ForeignKey { .. } => { - _plan_err!("Foreign key constraints are not currently supported") - } - TableConstraint::Check { .. } => { - _plan_err!("Check constraints are not currently supported") - } - TableConstraint::Index { .. } => { - _plan_err!("Indexes are not currently supported") - } - TableConstraint::FulltextOrSpatial { .. 
} => { - _plan_err!("Indexes are not currently supported") + .filter_map(|constraint| { + match constraint { + Constraint::PrimaryKey(indices) => { + let new_indices = + update_elements_with_matching_indices(indices, proj_indices); + // Only keep constraint if all columns are preserved + (new_indices.len() == indices.len()) + .then_some(Constraint::PrimaryKey(new_indices)) + } + Constraint::Unique(indices) => { + let new_indices = + update_elements_with_matching_indices(indices, proj_indices); + // Only keep constraint if all columns are preserved + (new_indices.len() == indices.len()) + .then_some(Constraint::Unique(new_indices)) + } } }) - .collect::>>()?; - Ok(Constraints::new_unverified(constraints)) + .collect::>(); + + (!projected.is_empty()).then_some(Constraints::new_unverified(projected)) } +} - /// Check whether constraints is empty - pub fn is_empty(&self) -> bool { - self.inner.is_empty() +impl Default for Constraints { + fn default() -> Self { + Constraints::empty() } } @@ -145,13 +108,13 @@ impl IntoIterator for Constraints { impl Display for Constraints { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let pk: Vec = self.inner.iter().map(|c| format!("{:?}", c)).collect(); + let pk = self + .inner + .iter() + .map(|c| format!("{:?}", c)) + .collect::>(); let pk = pk.join(", "); - if !pk.is_empty() { - write!(f, " constraints=[{pk}]") - } else { - write!(f, "") - } + write!(f, "constraints=[{pk}]") } } @@ -405,7 +368,7 @@ impl FunctionalDependencies { left_func_dependencies.extend(right_func_dependencies); left_func_dependencies } - JoinType::LeftSemi | JoinType::LeftAnti => { + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { // These joins preserve functional dependencies of the left side: left_func_dependencies } @@ -671,6 +634,24 @@ mod tests { assert_eq!(iter.next(), None); } + #[test] + fn test_project_constraints() { + let constraints = Constraints::new_unverified(vec![ + Constraint::PrimaryKey(vec![1, 2]), + Constraint::Unique(vec![0, 3]), + ]); + + // Project keeping columns 1,2,3 + let projected = constraints.project(&[1, 2, 3]).unwrap(); + assert_eq!( + projected, + Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0, 1])]) + ); + + // Project keeping only column 0 - should return None as no constraints are preserved + assert!(constraints.project(&[0]).is_none()); + } + #[test] fn test_get_updated_id_keys() { let fund_dependencies = diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 72cfeafd0bfec..0d1d93acf1fce 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -32,7 +32,7 @@ use arrow_buffer::IntervalMonthDayNano; use crate::cast::{ as_binary_view_array, as_boolean_array, as_fixed_size_list_array, as_generic_binary_array, as_large_list_array, as_list_array, as_map_array, - as_primitive_array, as_string_array, as_string_view_array, as_struct_array, + as_string_array, as_string_view_array, as_struct_array, }; use crate::error::Result; #[cfg(not(feature = "force_hash_collisions"))] @@ -63,7 +63,7 @@ pub trait HashValue { fn hash_one(&self, state: &RandomState) -> u64; } -impl<'a, T: HashValue + ?Sized> HashValue for &'a T { +impl HashValue for &T { fn hash_one(&self, state: &RandomState) -> u64 { T::hash_one(self, state) } @@ -102,8 +102,7 @@ fn hash_array_primitive( hashes_buffer: &mut [u64], rehash: bool, ) where - T: ArrowPrimitiveType, - ::Native: HashValue, + T: ArrowPrimitiveType, { assert_eq!( hashes_buffer.len(), @@ -322,8 +321,7 @@ fn 
hash_fixed_list_array( hashes_buffer: &mut [u64], ) -> Result<()> { let values = Arc::clone(array.values()); - let value_len = array.value_length(); - let offset_size = value_len as usize / array.len(); + let value_length = array.value_length() as usize; let nulls = array.nulls(); let mut values_hashes = vec![0u64; values.len()]; create_hashes(&[values], random_state, &mut values_hashes)?; @@ -331,7 +329,8 @@ fn hash_fixed_list_array( for i in 0..array.len() { if nulls.is_valid(i) { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[i * offset_size..(i + 1) * offset_size] + for values_hash in + &values_hashes[i * value_length..(i + 1) * value_length] { *hash = combine_hashes(*hash, *values_hash); } @@ -340,7 +339,7 @@ fn hash_fixed_list_array( } else { for i in 0..array.len() { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[i * offset_size..(i + 1) * offset_size] { + for values_hash in &values_hashes[i * value_length..(i + 1) * value_length] { *hash = combine_hashes(*hash, *values_hash); } } @@ -393,14 +392,6 @@ pub fn create_hashes<'a>( let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap(); hash_array(array, random_state, hashes_buffer, rehash) } - DataType::Decimal128(_, _) => { - let array = as_primitive_array::(array)?; - hash_array_primitive(array, random_state, hashes_buffer, rehash) - } - DataType::Decimal256(_, _) => { - let array = as_primitive_array::(array)?; - hash_array_primitive(array, random_state, hashes_buffer, rehash) - } DataType::Dictionary(_, _) => downcast_dictionary_array! { array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, _ => unreachable!() @@ -463,6 +454,16 @@ mod tests { Ok(()) } + #[test] + fn create_hashes_for_empty_fixed_size_lit() -> Result<()> { + let empty_array = FixedSizeListBuilder::new(StringBuilder::new(), 1).finish(); + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let hashes_buff = &mut vec![0; 0]; + let hashes = create_hashes(&[Arc::new(empty_array)], &random_state, hashes_buff)?; + assert_eq!(hashes, &Vec::::new()); + Ok(()) + } + #[test] fn create_hashes_for_float_arrays() -> Result<()> { let f32_arr = Arc::new(Float32Array::from(vec![0.12, 0.5, 1f32, 444.7])); diff --git a/datafusion/common/src/instant.rs b/datafusion/common/src/instant.rs index 6401bc29c9426..42f21c061c0c2 100644 --- a/datafusion/common/src/instant.rs +++ b/datafusion/common/src/instant.rs @@ -18,9 +18,9 @@ //! WASM-compatible `Instant` wrapper. #[cfg(target_family = "wasm")] -/// DataFusion wrapper around [`std::time::Instant`]. Uses [`instant::Instant`] +/// DataFusion wrapper around [`std::time::Instant`]. Uses [`web_time::Instant`] /// under `wasm` feature gate. It provides the same API as [`std::time::Instant`]. -pub type Instant = instant::Instant; +pub type Instant = web_time::Instant; #[allow(clippy::disallowed_types)] #[cfg(not(target_family = "wasm"))] diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index fbdae1c50a83e..ac81d977b7296 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -28,28 +28,85 @@ use crate::{DataFusionError, Result}; /// Join type #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] pub enum JoinType { - /// Inner Join + /// Inner Join - Returns only rows where there is a matching value in both tables based on the join condition. + /// For example, if joining table A and B on A.id = B.id, only rows where A.id equals B.id will be included. 
+ /// All columns from both tables are returned for the matching rows. Non-matching rows are excluded entirely. Inner, - /// Left Join + /// Left Join - Returns all rows from the left table and matching rows from the right table. + /// If no match, NULL values are returned for columns from the right table. Left, - /// Right Join + /// Right Join - Returns all rows from the right table and matching rows from the left table. + /// If no match, NULL values are returned for columns from the left table. Right, - /// Full Join + /// Full Join (also called Full Outer Join) - Returns all rows from both tables, matching rows where possible. + /// When a row from either table has no match in the other table, the missing columns are filled with NULL values. + /// For example, if table A has row X with no match in table B, the result will contain row X with NULL values for all of table B's columns. + /// This join type preserves all records from both tables, making it useful when you need to see all data regardless of matches. Full, - /// Left Semi Join + /// Left Semi Join - Returns rows from the left table that have matching rows in the right table. + /// Only columns from the left table are returned. LeftSemi, - /// Right Semi Join + /// Right Semi Join - Returns rows from the right table that have matching rows in the left table. + /// Only columns from the right table are returned. RightSemi, - /// Left Anti Join + /// Left Anti Join - Returns rows from the left table that do not have a matching row in the right table. LeftAnti, - /// Right Anti Join + /// Right Anti Join - Returns rows from the right table that do not have a matching row in the left table. RightAnti, + /// Left Mark Join + /// + /// Returns one record for each record from the left input. The output contains an additional + /// column "mark" which is true if there is at least one match in the right input where the + /// join condition evaluates to true. Otherwise, the mark column is false. For more details see + /// [1]. This join type is used to decorrelate EXISTS subqueries used inside disjunctive + /// predicates. + /// + /// Note: we do not currently implement the full null semantics for the mark join described + /// in [1], which will be needed if we add support for ANY subqueries. In our version the mark column will + /// only be true when there was a match and false when no match was found, never null. + /// + /// [1]: http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf + LeftMark, } impl JoinType { pub fn is_outer(self) -> bool { self == JoinType::Left || self == JoinType::Right || self == JoinType::Full } + + /// Returns the `JoinType` that results if the two inputs are swapped + /// + /// Panics if [`Self::supports_swap`] returns false + pub fn swap(&self) -> JoinType { + match self { + JoinType::Inner => JoinType::Inner, + JoinType::Full => JoinType::Full, + JoinType::Left => JoinType::Right, + JoinType::Right => JoinType::Left, + JoinType::LeftSemi => JoinType::RightSemi, + JoinType::RightSemi => JoinType::LeftSemi, + JoinType::LeftAnti => JoinType::RightAnti, + JoinType::RightAnti => JoinType::LeftAnti, + JoinType::LeftMark => { + unreachable!("LeftMark join type does not support swapping") + } + } + } + + /// Does the join type support swapping inputs?
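The new `swap`/`supports_swap` pair is meant to be used together, since calling `swap` on `LeftMark` panics. A minimal usage sketch of the intended call pattern (the helper name `swapped_if_supported` is illustrative, not part of this patch):

    use datafusion_common::JoinType;

    /// Returns the swapped join type when swapping is supported, `None` otherwise.
    /// `LeftMark` is the one variant that reports `supports_swap() == false`,
    /// so calling `swap()` on it directly would panic.
    fn swapped_if_supported(join_type: JoinType) -> Option<JoinType> {
        join_type.supports_swap().then(|| join_type.swap())
    }

    fn main() {
        assert_eq!(swapped_if_supported(JoinType::Left), Some(JoinType::Right));
        assert_eq!(swapped_if_supported(JoinType::LeftMark), None);
    }
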
+ pub fn supports_swap(&self) -> bool { + matches!( + self, + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + ) + } } impl Display for JoinType { @@ -63,6 +120,7 @@ impl Display for JoinType { JoinType::RightSemi => "RightSemi", JoinType::LeftAnti => "LeftAnti", JoinType::RightAnti => "RightAnti", + JoinType::LeftMark => "LeftMark", }; write!(f, "{join_type}") } @@ -82,6 +140,7 @@ impl FromStr for JoinType { "RIGHTSEMI" => Ok(JoinType::RightSemi), "LEFTANTI" => Ok(JoinType::LeftAnti), "RIGHTANTI" => Ok(JoinType::RightAnti), + "LEFTMARK" => Ok(JoinType::LeftMark), _ => _not_impl_err!("The join type {s} does not exist or is not implemented"), } } @@ -97,10 +156,11 @@ pub enum JoinConstraint { } impl Display for JoinSide { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { JoinSide::Left => write!(f, "left"), JoinSide::Right => write!(f, "right"), + JoinSide::None => write!(f, "none"), } } } @@ -113,6 +173,9 @@ pub enum JoinSide { Left, /// Right side of the join Right, + /// Neither side of the join, used for Mark joins where the mark column does not belong to + /// either side of the join + None, } impl JoinSide { @@ -121,6 +184,7 @@ impl JoinSide { match self { JoinSide::Left => JoinSide::Right, JoinSide::Right => JoinSide::Left, + JoinSide::None => JoinSide::None, } } } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index b8ba1ed4e8cb7..77e8cd60ede24 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -14,6 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -31,6 +32,7 @@ mod unnest; pub mod alias; pub mod cast; pub mod config; +pub mod cse; pub mod display; pub mod error; pub mod file_options; @@ -65,13 +67,14 @@ pub use functional_dependencies::{ get_target_functional_dependencies, Constraint, Constraints, Dependency, FunctionalDependence, FunctionalDependencies, }; +use hashbrown::hash_map::DefaultHashBuilder; pub use join_type::{JoinConstraint, JoinSide, JoinType}; pub use param_value::ParamValues; pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::SchemaReference; pub use stats::{ColumnStatistics, Statistics}; pub use table_reference::{ResolvedTableReference, TableReference}; -pub use unnest::UnnestOptions; +pub use unnest::{RecursionUnnestOption, UnnestOptions}; pub use utils::project_schema; // These are hidden from docs purely to avoid polluting the public view of what this crate exports. @@ -86,6 +89,10 @@ pub use error::{ _substrait_datafusion_err, }; +// The HashMap and HashSet implementations that should be used as the uniform defaults +pub type HashMap = hashbrown::HashMap; +pub type HashSet = hashbrown::HashSet; + /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is /// not possible. In normal usage of DataFusion the downcast should always succeed. /// diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index e23edb4e2adb7..c73c8a55f18c5 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -18,7 +18,6 @@ //! 
Interval parsing logic use std::fmt::Display; -use std::result; use std::str::FromStr; use sqlparser::parser::ParserError; @@ -41,7 +40,7 @@ pub enum CompressionTypeVariant { impl FromStr for CompressionTypeVariant { type Err = ParserError; - fn from_str(s: &str) -> result::Result { + fn from_str(s: &str) -> Result { let s = s.to_uppercase(); match s.as_str() { "GZIP" | "GZ" => Ok(Self::GZIP), diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index 87254a499fb11..60dde7861104a 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -23,7 +23,7 @@ use arrow_array::Array; use pyo3::exceptions::PyException; use pyo3::prelude::PyErr; use pyo3::types::{PyAnyMethods, PyList}; -use pyo3::{Bound, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python}; +use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyObject, PyResult, Python}; use crate::{DataFusionError, ScalarValue}; @@ -34,14 +34,14 @@ impl From for PyErr { } impl FromPyArrow for ScalarValue { - fn from_pyarrow_bound(value: &pyo3::Bound<'_, pyo3::PyAny>) -> PyResult { + fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult { let py = value.py(); let typ = value.getattr("type")?; let val = value.call_method0("as_py")?; // construct pyarrow array from the python value and pyarrow type - let factory = py.import_bound("pyarrow")?.getattr("array")?; - let args = PyList::new_bound(py, [val]); + let factory = py.import("pyarrow")?.getattr("array")?; + let args = PyList::new(py, [val])?; let array = factory.call1((args, typ))?; // convert the pyarrow array to rust array using C data interface @@ -69,14 +69,25 @@ impl<'source> FromPyObject<'source> for ScalarValue { } } -impl IntoPy for ScalarValue { - fn into_py(self, py: Python) -> PyObject { - self.to_pyarrow(py).unwrap() +impl<'source> IntoPyObject<'source> for ScalarValue { + type Target = PyAny; + + type Output = Bound<'source, Self::Target>; + + type Error = PyErr; + + fn into_pyobject(self, py: Python<'source>) -> Result { + let array = self.to_array()?; + // convert to pyarrow array using C data interface + let pyarray = array.to_data().to_pyarrow(py)?; + let pyarray_bound = pyarray.bind(py); + pyarray_bound.call_method1("__getitem__", (0,)) } } #[cfg(test)] mod tests { + use pyo3::ffi::c_str; use pyo3::prepare_freethreaded_python; use pyo3::py_run; use pyo3::types::PyDict; @@ -86,10 +97,12 @@ mod tests { fn init_python() { prepare_freethreaded_python(); Python::with_gil(|py| { - if py.run_bound("import pyarrow", None, None).is_err() { - let locals = PyDict::new_bound(py); - py.run_bound( - "import sys; executable = sys.executable; python_path = sys.path", + if py.run(c_str!("import pyarrow"), None, None).is_err() { + let locals = PyDict::new(py); + py.run( + c_str!( + "import sys; executable = sys.executable; python_path = sys.path" + ), None, Some(&locals), ) @@ -135,17 +148,25 @@ mod tests { } #[test] - fn test_py_scalar() { + fn test_py_scalar() -> PyResult<()> { init_python(); - Python::with_gil(|py| { + Python::with_gil(|py| -> PyResult<()> { let scalar_float = ScalarValue::Float64(Some(12.34)); - let py_float = scalar_float.into_py(py).call_method0(py, "as_py").unwrap(); + let py_float = scalar_float + .into_pyobject(py)? + .call_method0("as_py") + .unwrap(); py_run!(py, py_float, "assert py_float == 12.34"); let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string())); - let py_string = scalar_string.into_py(py).call_method0(py, "as_py").unwrap(); + let py_string = scalar_string + .into_pyobject(py)? 
+ .call_method0("as_py") + .unwrap(); py_run!(py, py_string, "assert py_string == 'Hello!'"); - }); + + Ok(()) + }) } } diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 530b2dd9bbf5d..87b52cc757278 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -28,6 +28,7 @@ use std::fmt; use std::hash::Hash; use std::hash::Hasher; use std::iter::repeat; +use std::mem::{size_of, size_of_val}; use std::str::FromStr; use std::sync::Arc; @@ -38,9 +39,7 @@ use crate::cast::{ }; use crate::error::{DataFusionError, Result, _exec_err, _internal_err, _not_impl_err}; use crate::hash_utils::create_hashes; -use crate::utils::{ - array_into_fixed_size_list_array, array_into_large_list_array, array_into_list_array, -}; +use crate::utils::SingleRowListArrayBuilder; use arrow::compute::kernels::{self, numeric::*}; use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions}; use arrow::{ @@ -58,6 +57,7 @@ use arrow::{ use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, ScalarBuffer}; use arrow_schema::{UnionFields, UnionMode}; +use crate::format::DEFAULT_CAST_OPTIONS; use half::f16; pub use struct_builder::ScalarStructBuilder; @@ -596,8 +596,8 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { let arr1 = first_array_for_list(arr1); let arr2 = first_array_for_list(arr2); - let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + let lt_res = kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = kernels::cmp::eq(&arr1, &arr2).ok()?; for j in 0..lt_res.len() { if lt_res.is_valid(j) && lt_res.value(j) { @@ -624,8 +624,8 @@ fn partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option, m2: &Arc) -> Option { let arr1 = m1.entries().column(col_index); let arr2 = m2.entries().column(col_index); - let lt_res = arrow::compute::kernels::cmp::lt(arr1, arr2).ok()?; - let eq_res = arrow::compute::kernels::cmp::eq(arr1, arr2).ok()?; + let lt_res = kernels::cmp::lt(arr1, arr2).ok()?; + let eq_res = kernels::cmp::eq(arr1, arr2).ok()?; for j in 0..lt_res.len() { if lt_res.is_valid(j) && lt_res.value(j) { @@ -690,8 +690,8 @@ hash_float_value!((f64, u64), (f32, u32)); // # Panics // // Panics if there is an error when creating hash values for rows -impl std::hash::Hash for ScalarValue { - fn hash(&self, state: &mut H) { +impl Hash for ScalarValue { + fn hash(&self, state: &mut H) { use ScalarValue::*; match self { Decimal128(v, p, s) => { @@ -767,7 +767,7 @@ impl std::hash::Hash for ScalarValue { } } -fn hash_nested_array(arr: ArrayRef, state: &mut H) { +fn hash_nested_array(arr: ArrayRef, state: &mut H) { let arrays = vec![arr.to_owned()]; let hashes_buffer = &mut vec![0; arr.len()]; let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); @@ -801,7 +801,7 @@ fn dict_from_scalar( let values_array = value.to_array_of_size(1)?; // Create a key array with `size` elements, each of 0 - let key_array: PrimitiveArray = std::iter::repeat(if value.is_null() { + let key_array: PrimitiveArray = repeat(if value.is_null() { None } else { Some(K::default_value()) @@ -978,6 +978,11 @@ impl ScalarValue { ScalarValue::from(val.into()) } + /// Returns a [`ScalarValue::Utf8View`] representing `val` + pub fn new_utf8view(val: impl Into) -> Self { + ScalarValue::Utf8View(Some(val.into())) + } + /// Returns a [`ScalarValue::IntervalYearMonth`] representing /// `years` years and `months` months pub fn new_interval_ym(years: i32, months: i32) -> Self { 
@@ -1145,6 +1150,12 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))), DataType::Float32 => ScalarValue::Float32(Some(0.0)), DataType::Float64 => ScalarValue::Float64(Some(0.0)), + DataType::Decimal128(precision, scale) => { + ScalarValue::Decimal128(Some(0), *precision, *scale) + } + DataType::Decimal256(precision, scale) => { + ScalarValue::Decimal256(Some(i256::ZERO), *precision, *scale) + } DataType::Timestamp(TimeUnit::Second, tz) => { ScalarValue::TimestampSecond(Some(0), tz.clone()) } @@ -1157,6 +1168,16 @@ impl ScalarValue { DataType::Timestamp(TimeUnit::Nanosecond, tz) => { ScalarValue::TimestampNanosecond(Some(0), tz.clone()) } + DataType::Time32(TimeUnit::Second) => ScalarValue::Time32Second(Some(0)), + DataType::Time32(TimeUnit::Millisecond) => { + ScalarValue::Time32Millisecond(Some(0)) + } + DataType::Time64(TimeUnit::Microsecond) => { + ScalarValue::Time64Microsecond(Some(0)) + } + DataType::Time64(TimeUnit::Nanosecond) => { + ScalarValue::Time64Nanosecond(Some(0)) + } DataType::Interval(IntervalUnit::YearMonth) => { ScalarValue::IntervalYearMonth(Some(0)) } @@ -1176,6 +1197,8 @@ impl ScalarValue { DataType::Duration(TimeUnit::Nanosecond) => { ScalarValue::DurationNanosecond(Some(0)) } + DataType::Date32 => ScalarValue::Date32(Some(0)), + DataType::Date64 => ScalarValue::Date64(Some(0)), _ => { return _not_impl_err!( "Can't create a zero scalar from data_type \"{datatype:?}\"" @@ -1632,11 +1655,12 @@ impl ScalarValue { /// ``` /// use datafusion_common::ScalarValue; /// use arrow::array::{BooleanArray, Int32Array}; + /// use arrow::compute::kernels; /// /// let arr = Int32Array::from(vec![Some(1), None, Some(10)]); /// let five = ScalarValue::Int32(Some(5)); /// - /// let result = arrow::compute::kernels::cmp::lt( + /// let result = kernels::cmp::lt( /// &arr, /// &five.to_scalar().unwrap(), /// ).unwrap(); @@ -2054,7 +2078,7 @@ impl ScalarValue { scale: i8, size: usize, ) -> Result { - Ok(std::iter::repeat(value) + Ok(repeat(value) .take(size) .collect::() .with_precision_and_scale(precision, scale)?) @@ -2095,7 +2119,11 @@ impl ScalarValue { } else { Self::iter_to_array(values.iter().cloned()).unwrap() }; - Arc::new(array_into_list_array(values, nullable)) + Arc::new( + SingleRowListArrayBuilder::new(values) + .with_nullable(nullable) + .build_list_array(), + ) } /// Same as [`ScalarValue::new_list`] but with nullable set to true. 
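The zero-scalar hunk above adds decimal, time, and date coverage; a small sketch of the resulting behavior, assuming (as the hunk suggests) that it extends `ScalarValue::new_zero`, which returns a `Result`:

    use arrow::datatypes::DataType;
    use datafusion_common::{Result, ScalarValue};

    fn main() -> Result<()> {
        // Previously these types fell through to the "not implemented" error
        assert_eq!(
            ScalarValue::new_zero(&DataType::Date32)?,
            ScalarValue::Date32(Some(0))
        );
        assert_eq!(
            ScalarValue::new_zero(&DataType::Decimal128(10, 2))?,
            ScalarValue::Decimal128(Some(0), 10, 2)
        );
        Ok(())
    }
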
@@ -2151,7 +2179,11 @@ impl ScalarValue { } else { Self::iter_to_array(values).unwrap() }; - Arc::new(array_into_list_array(values, nullable)) + Arc::new( + SingleRowListArrayBuilder::new(values) + .with_nullable(nullable) + .build_list_array(), + ) } /// Converts `Vec` where each element has type corresponding to @@ -2188,7 +2220,7 @@ impl ScalarValue { } else { Self::iter_to_array(values.iter().cloned()).unwrap() }; - Arc::new(array_into_large_list_array(values)) + Arc::new(SingleRowListArrayBuilder::new(values).build_large_list_array()) } pub fn to_array_of_size_and_type( @@ -2207,7 +2239,7 @@ impl ScalarValue { /// /// Errors if `self` is /// - a decimal that fails be converted to a decimal array of size - /// - a `Fixedsizelist` that fails to be concatenated into an array of size + /// - a `FixedsizeList` that fails to be concatenated into an array of size /// - a `List` that fails to be concatenated into an array of size /// - a `Dictionary` that fails be converted to a dictionary array of size pub fn to_array_of_size(&self, size: usize) -> Result { @@ -2453,7 +2485,7 @@ impl ScalarValue { e, size ), - ScalarValue::Union(value, fields, _mode) => match value { + ScalarValue::Union(value, fields, mode) => match value { Some((v_id, value)) => { let mut new_fields = Vec::with_capacity(fields.len()); let mut child_arrays = Vec::::with_capacity(fields.len()); @@ -2462,7 +2494,12 @@ impl ScalarValue { value.to_array_of_size(size)? } else { let dt = field.data_type(); - new_null_array(dt, size) + match mode { + UnionMode::Sparse => new_null_array(dt, size), + // In a dense union, only the child with values needs to be + // allocated + UnionMode::Dense => new_null_array(dt, 0), + } }; let field = (**field).clone(); child_arrays.push(ar); @@ -2470,7 +2507,10 @@ impl ScalarValue { } let type_ids = repeat(*v_id).take(size); let type_ids = ScalarBuffer::::from_iter(type_ids); - let value_offsets: Option> = None; + let value_offsets = match mode { + UnionMode::Sparse => None, + UnionMode::Dense => Some(ScalarBuffer::from_iter(0..size as i32)), + }; let ar = UnionArray::try_new( fields.clone(), type_ids, @@ -2533,7 +2573,7 @@ impl ScalarValue { } fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { - let arrays = std::iter::repeat(arr).take(size).collect::>(); + let arrays = repeat(arr).take(size).collect::>(); let ret = match !arrays.is_empty() { true => arrow::compute::concat(arrays.as_slice())?, false => arr.slice(0, 0), @@ -2578,7 +2618,7 @@ impl ScalarValue { /// use datafusion_common::ScalarValue; /// use arrow::array::ListArray; /// use arrow::datatypes::{DataType, Int32Type}; - /// use datafusion_common::utils::array_into_list_array_nullable; + /// use datafusion_common::utils::SingleRowListArrayBuilder; /// use std::sync::Arc; /// /// let list_arr = ListArray::from_iter_primitive::(vec![ @@ -2587,7 +2627,7 @@ impl ScalarValue { /// ]); /// /// // Wrap into another layer of list, we got nested array as [ [[1,2,3], [4,5]] ] - /// let list_arr = array_into_list_array_nullable(Arc::new(list_arr)); + /// let list_arr = SingleRowListArrayBuilder::new(Arc::new(list_arr)).build_list_array(); /// /// // Convert the array into Scalar Values for each row, we got 1D arrays in this example /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); @@ -2678,29 +2718,27 @@ impl ScalarValue { let list_array = array.as_list::(); let nested_array = list_array.value(index); // Produces a single element `ListArray` with the value at `index`. 
- let arr = - Arc::new(array_into_list_array(nested_array, field.is_nullable())); - - ScalarValue::List(arr) + SingleRowListArrayBuilder::new(nested_array) + .with_field(field) + .build_list_scalar() } - DataType::LargeList(_) => { + DataType::LargeList(field) => { let list_array = as_large_list_array(array); let nested_array = list_array.value(index); // Produces a single element `LargeListArray` with the value at `index`. - let arr = Arc::new(array_into_large_list_array(nested_array)); - - ScalarValue::LargeList(arr) + SingleRowListArrayBuilder::new(nested_array) + .with_field(field) + .build_large_list_scalar() } // TODO: There is no test for FixedSizeList now, add it later - DataType::FixedSizeList(_, _) => { + DataType::FixedSizeList(field, _) => { let list_array = as_fixed_size_list_array(array)?; let nested_array = list_array.value(index); - // Produces a single element `ListArray` with the value at `index`. + // Produces a single element `FixedSizeListArray` with the value at `index`. let list_size = nested_array.len(); - let arr = - Arc::new(array_into_fixed_size_list_array(nested_array, list_size)); - - ScalarValue::FixedSizeList(arr) + SingleRowListArrayBuilder::new(nested_array) + .with_field(field) + .build_fixed_size_list_scalar(list_size) } DataType::Date32 => typed_cast!(array, index, Date32Array, Date32)?, DataType::Date64 => typed_cast!(array, index, Date64Array, Date64)?, @@ -2831,22 +2869,74 @@ impl ScalarValue { /// Try to parse `value` into a ScalarValue of type `target_type` pub fn try_from_string(value: String, target_type: &DataType) -> Result { - let value = ScalarValue::from(value); - let cast_options = CastOptions { - safe: false, - format_options: Default::default(), + ScalarValue::from(value).cast_to(target_type) + } + + /// Returns the Some(`&str`) representation of `ScalarValue` of logical string type + /// + /// Returns `None` if this `ScalarValue` is not a logical string type or the + /// `ScalarValue` represents the `NULL` value. + /// + /// Note you can use [`Option::flatten`] to check for non null logical + /// strings. + /// + /// For example, [`ScalarValue::Utf8`], [`ScalarValue::LargeUtf8`], and + /// [`ScalarValue::Dictionary`] with a logical string value and store + /// strings and can be accessed as `&str` using this method. 
+ /// + /// # Example: logical strings + /// ``` + /// # use datafusion_common::ScalarValue; + /// /// non strings return None + /// let scalar = ScalarValue::from(42); + /// assert_eq!(scalar.try_as_str(), None); + /// // Non null logical string returns Some(Some(&str)) + /// let scalar = ScalarValue::from("hello"); + /// assert_eq!(scalar.try_as_str(), Some(Some("hello"))); + /// // Null logical string returns Some(None) + /// let scalar = ScalarValue::Utf8(None); + /// assert_eq!(scalar.try_as_str(), Some(None)); + /// ``` + /// + /// # Example: use [`Option::flatten`] to check for non-null logical strings + /// ``` + /// # use datafusion_common::ScalarValue; + /// // Non null logical string returns Some(Some(&str)) + /// let scalar = ScalarValue::from("hello"); + /// assert_eq!(scalar.try_as_str().flatten(), Some("hello")); + /// ``` + pub fn try_as_str(&self) -> Option> { + let v = match self { + ScalarValue::Utf8(v) => v, + ScalarValue::LargeUtf8(v) => v, + ScalarValue::Utf8View(v) => v, + ScalarValue::Dictionary(_, v) => return v.try_as_str(), + _ => return None, }; - let cast_arr = cast_with_options(&value.to_array()?, target_type, &cast_options)?; - ScalarValue::try_from_array(&cast_arr, 0) + Some(v.as_ref().map(|v| v.as_str())) } /// Try to cast this value to a ScalarValue of type `data_type` - pub fn cast_to(&self, data_type: &DataType) -> Result { - let cast_options = CastOptions { - safe: false, - format_options: Default::default(), + pub fn cast_to(&self, target_type: &DataType) -> Result { + self.cast_to_with_options(target_type, &DEFAULT_CAST_OPTIONS) + } + + /// Try to cast this value to a ScalarValue of type `data_type` with [`CastOptions`] + pub fn cast_to_with_options( + &self, + target_type: &DataType, + cast_options: &CastOptions<'static>, + ) -> Result { + let scalar_array = match (self, target_type) { + ( + ScalarValue::Float64(Some(float_ts)), + DataType::Timestamp(TimeUnit::Nanosecond, None), + ) => ScalarValue::Int64(Some((float_ts * 1_000_000_000_f64).trunc() as i64)) + .to_array()?, + _ => self.to_array()?, }; - let cast_arr = cast_with_options(&self.to_array()?, data_type, &cast_options)?; + + let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?; ScalarValue::try_from_array(&cast_arr, 0) } @@ -2902,7 +2992,7 @@ impl ScalarValue { /// preferred over this function if at all possible as they can be /// vectorized and are generally much faster. /// - /// This function has a few narrow usescases such as hash table key + /// This function has a few narrow use cases such as hash table key /// comparisons where comparing a single row at a time is necessary. /// /// # Errors @@ -3096,7 +3186,7 @@ impl ScalarValue { /// Estimate size if bytes including `Self`. 
For values with internal containers such as `String` /// includes the allocated size (`capacity`) rather than the current length (`len`) pub fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) + match self { ScalarValue::Null | ScalarValue::Boolean(_) @@ -3150,12 +3240,12 @@ impl ScalarValue { ScalarValue::Map(arr) => arr.get_array_memory_size(), ScalarValue::Union(vals, fields, _mode) => { vals.as_ref() - .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv)) + .map(|(_id, sv)| sv.size() - size_of_val(sv)) .unwrap_or_default() // `fields` is boxed, so it is NOT already included in `self` - + std::mem::size_of_val(fields) - + (std::mem::size_of::() * fields.len()) - + fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::() + + size_of_val(fields) + + (size_of::() * fields.len()) + + fields.iter().map(|(_idx, field)| field.size() - size_of_val(field)).sum::() } ScalarValue::Dictionary(dt, sv) => { // `dt` and `sv` are boxed, so they are NOT already included in `self` @@ -3168,11 +3258,11 @@ impl ScalarValue { /// /// Includes the size of the [`Vec`] container itself. pub fn size_of_vec(vec: &Vec) -> usize { - std::mem::size_of_val(vec) - + (std::mem::size_of::() * vec.capacity()) + size_of_val(vec) + + (size_of::() * vec.capacity()) + vec .iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .map(|sv| sv.size() - size_of_val(sv)) .sum::() } @@ -3180,11 +3270,11 @@ impl ScalarValue { /// /// Includes the size of the [`VecDeque`] container itself. pub fn size_of_vec_deque(vec_deque: &VecDeque) -> usize { - std::mem::size_of_val(vec_deque) - + (std::mem::size_of::() * vec_deque.capacity()) + size_of_val(vec_deque) + + (size_of::() * vec_deque.capacity()) + vec_deque .iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .map(|sv| sv.size() - size_of_val(sv)) .sum::() } @@ -3192,11 +3282,11 @@ impl ScalarValue { /// /// Includes the size of the [`HashSet`] container itself. 
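As an aside on the `cast_to_with_options` hunk earlier in this file's changes: the special case for `Float64` to `Timestamp(Nanosecond, None)` multiplies the value by 1e9 and truncates before the arrow cast, so a float number of seconds becomes nanoseconds. A short sketch of the expected behavior (assuming arrow's `Int64` to timestamp cast keeps the integer value as-is in the target unit):

    use arrow::datatypes::{DataType, TimeUnit};
    use datafusion_common::{Result, ScalarValue};

    fn main() -> Result<()> {
        // 1.5 seconds, expressed as a Float64 scalar
        let seconds = ScalarValue::Float64(Some(1.5));
        // cast_to uses the default cast options, which route through the new special case
        let ts = seconds.cast_to(&DataType::Timestamp(TimeUnit::Nanosecond, None))?;
        assert_eq!(ts, ScalarValue::TimestampNanosecond(Some(1_500_000_000), None));
        Ok(())
    }
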
pub fn size_of_hashset(set: &HashSet) -> usize { - std::mem::size_of_val(set) - + (std::mem::size_of::() * set.capacity()) + size_of_val(set) + + (size_of::() * set.capacity()) + set .iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .map(|sv| sv.size() - size_of_val(sv)) .sum::() } } @@ -3599,9 +3689,8 @@ impl fmt::Display for ScalarValue { columns .iter() .zip(fields.iter()) - .enumerate() - .map(|(index, (column, field))| { - if nulls.is_some_and(|b| b.is_null(index)) { + .map(|(column, field)| { + if nulls.is_some_and(|b| b.is_null(0)) { format!("{}:NULL", field.name()) } else if let DataType::Struct(_) = field.data_type() { let sv = ScalarValue::Struct(Arc::new( @@ -3892,12 +3981,12 @@ mod tests { }; use crate::assert_batches_eq; - use crate::utils::array_into_list_array_nullable; use arrow::buffer::OffsetBuffer; use arrow::compute::{is_null, kernels}; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_columns; - use arrow_buffer::Buffer; + use arrow_array::types::Float64Type; + use arrow_buffer::{Buffer, NullBuffer}; use arrow_schema::Fields; use chrono::NaiveDate; use rand::Rng; @@ -4034,7 +4123,7 @@ mod tests { #[test] fn test_to_array_of_size_for_fsl() { let values = Int32Array::from_iter([Some(1), None, Some(2)]); - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let arr = FixedSizeListArray::new(Arc::clone(&field), 3, Arc::new(values), None); let sv = ScalarValue::FixedSizeList(Arc::new(arr)); let actual_arr = sv @@ -4068,14 +4157,15 @@ mod tests { let result = ScalarValue::new_list_nullable(scalars.as_slice(), &DataType::Utf8); - let expected = array_into_list_array_nullable(Arc::new(StringArray::from(vec![ - "rust", - "arrow", - "data-fusion", - ]))); + let expected = single_row_list_array(vec!["rust", "arrow", "data-fusion"]); assert_eq!(*result, expected); } + fn single_row_list_array(items: Vec<&str>) -> ListArray { + SingleRowListArrayBuilder::new(Arc::new(StringArray::from(items))) + .build_list_array() + } + fn build_list( values: Vec>>>, ) -> Vec { @@ -4090,8 +4180,7 @@ mod tests { ) } else if O::IS_LARGE { new_null_array( - &DataType::LargeList(Arc::new(Field::new( - "item", + &DataType::LargeList(Arc::new(Field::new_list_field( DataType::Int64, true, ))), @@ -4099,8 +4188,7 @@ mod tests { ) } else { new_null_array( - &DataType::List(Arc::new(Field::new( - "item", + &DataType::List(Arc::new(Field::new_list_field( DataType::Int64, true, ))), @@ -4119,7 +4207,7 @@ mod tests { #[test] fn test_iter_to_array_fixed_size_list() { - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let f1 = Arc::new(FixedSizeListArray::new( Arc::clone(&field), 3, @@ -4280,12 +4368,8 @@ mod tests { #[test] fn iter_to_array_string_test() { - let arr1 = array_into_list_array_nullable(Arc::new(StringArray::from(vec![ - "foo", "bar", "baz", - ]))); - let arr2 = array_into_list_array_nullable(Arc::new(StringArray::from(vec![ - "rust", "world", - ]))); + let arr1 = single_row_list_array(vec!["foo", "bar", "baz"]); + let arr2 = single_row_list_array(vec!["rust", "world"]); let scalars = vec![ ScalarValue::List(Arc::new(arr1)), @@ -4448,7 +4532,7 @@ mod tests { Ok(()) } - // Verifies that ScalarValue has the same behavior with compute kernal when it overflows. + // Verifies that ScalarValue has the same behavior with compute kernel when it overflows. 
fn check_scalar_add_overflow(left: ScalarValue, right: ScalarValue) where T: ArrowNumericType, @@ -4459,7 +4543,7 @@ mod tests { let right_array = right.to_array().expect("Failed to convert to array"); let arrow_left_array = left_array.as_primitive::(); let arrow_right_array = right_array.as_primitive::(); - let arrow_result = kernels::numeric::add(arrow_left_array, arrow_right_array); + let arrow_result = add(arrow_left_array, arrow_right_array); assert_eq!(scalar_result.is_ok(), arrow_result.is_ok()); } @@ -4958,7 +5042,7 @@ mod tests { let null_list_scalar = ScalarValue::try_from_array(&list, 1).unwrap(); let data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); assert_eq!(non_null_list_scalar.data_type(), data_type); assert_eq!(null_list_scalar.data_type(), data_type); @@ -4966,7 +5050,7 @@ mod tests { #[test] fn scalar_try_from_list_datatypes() { - let inner_field = Arc::new(Field::new("item", DataType::Int32, true)); + let inner_field = Arc::new(Field::new_list_field(DataType::Int32, true)); // Test for List let data_type = &DataType::List(Arc::clone(&inner_field)); @@ -5007,9 +5091,8 @@ mod tests { #[test] fn scalar_try_from_list_of_list() { - let data_type = DataType::List(Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let data_type = DataType::List(Arc::new(Field::new_list_field( + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), true, ))); let data_type = &data_type; @@ -5017,9 +5100,11 @@ mod tests { let expected = ScalarValue::List( new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + &DataType::List(Arc::new(Field::new_list_field( + DataType::List(Arc::new(Field::new_list_field( + DataType::Int32, + true, + ))), true, ))), 1, @@ -5035,13 +5120,12 @@ mod tests { #[test] fn scalar_try_from_not_equal_list_nested_list() { let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); let data_type = &list_data_type; let list_scalar: ScalarValue = data_type.try_into().unwrap(); - let nested_list_data_type = DataType::List(Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let nested_list_data_type = DataType::List(Arc::new(Field::new_list_field( + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), true, ))); let data_type = &nested_list_data_type; @@ -5074,13 +5158,13 @@ mod tests { // thus the size of the enum appears to as well // The value may also change depending on rust version - assert_eq!(std::mem::size_of::(), 64); + assert_eq!(size_of::(), 64); } #[test] fn memory_size() { let sv = ScalarValue::Binary(Some(Vec::with_capacity(10))); - assert_eq!(sv.size(), std::mem::size_of::() + 10,); + assert_eq!(sv.size(), size_of::() + 10,); let sv_size = sv.size(); let mut v = Vec::with_capacity(10); @@ -5089,9 +5173,7 @@ mod tests { assert_eq!(v.capacity(), 10); assert_eq!( ScalarValue::size_of_vec(&v), - std::mem::size_of::>() - + (9 * std::mem::size_of::()) - + sv_size, + size_of::>() + (9 * size_of::()) + sv_size, ); let mut s = HashSet::with_capacity(0); @@ -5101,8 +5183,8 @@ mod tests { let s_capacity = s.capacity(); assert_eq!( ScalarValue::size_of_hashset(&s), - std::mem::size_of::>() - + ((s_capacity - 1) * std::mem::size_of::()) + 
size_of::>() + + ((s_capacity - 1) * size_of::()) + sv_size, ); } @@ -5566,6 +5648,194 @@ mod tests { assert_eq!(&array, &expected); } + #[test] + fn round_trip() { + // Each array type should be able to round tripped through a scalar + let cases: Vec = vec![ + // int + Arc::new(Int8Array::from(vec![Some(1), None, Some(3)])), + Arc::new(Int16Array::from(vec![Some(1), None, Some(3)])), + Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])), + Arc::new(Int64Array::from(vec![Some(1), None, Some(3)])), + Arc::new(UInt8Array::from(vec![Some(1), None, Some(3)])), + Arc::new(UInt16Array::from(vec![Some(1), None, Some(3)])), + Arc::new(UInt32Array::from(vec![Some(1), None, Some(3)])), + Arc::new(UInt64Array::from(vec![Some(1), None, Some(3)])), + // bool + Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)])), + // float + Arc::new(Float32Array::from(vec![Some(1.0), None, Some(3.0)])), + Arc::new(Float64Array::from(vec![Some(1.0), None, Some(3.0)])), + // string array + Arc::new(StringArray::from(vec![Some("foo"), None, Some("bar")])), + Arc::new(LargeStringArray::from(vec![Some("foo"), None, Some("bar")])), + Arc::new(StringViewArray::from(vec![Some("foo"), None, Some("bar")])), + // string dictionary + { + let mut builder = StringDictionaryBuilder::::new(); + builder.append("foo").unwrap(); + builder.append_null(); + builder.append("bar").unwrap(); + Arc::new(builder.finish()) + }, + // binary array + Arc::new(BinaryArray::from_iter(vec![ + Some(b"foo"), + None, + Some(b"bar"), + ])), + Arc::new(LargeBinaryArray::from_iter(vec![ + Some(b"foo"), + None, + Some(b"bar"), + ])), + Arc::new(BinaryViewArray::from_iter(vec![ + Some(b"foo"), + None, + Some(b"bar"), + ])), + // timestamp + Arc::new(TimestampSecondArray::from(vec![Some(1), None, Some(3)])), + Arc::new(TimestampMillisecondArray::from(vec![ + Some(1), + None, + Some(3), + ])), + Arc::new(TimestampMicrosecondArray::from(vec![ + Some(1), + None, + Some(3), + ])), + Arc::new(TimestampNanosecondArray::from(vec![Some(1), None, Some(3)])), + // timestamp with timezone + Arc::new( + TimestampSecondArray::from(vec![Some(1), None, Some(3)]) + .with_timezone_opt(Some("UTC")), + ), + Arc::new( + TimestampMillisecondArray::from(vec![Some(1), None, Some(3)]) + .with_timezone_opt(Some("UTC")), + ), + Arc::new( + TimestampMicrosecondArray::from(vec![Some(1), None, Some(3)]) + .with_timezone_opt(Some("UTC")), + ), + Arc::new( + TimestampNanosecondArray::from(vec![Some(1), None, Some(3)]) + .with_timezone_opt(Some("UTC")), + ), + // date + Arc::new(Date32Array::from(vec![Some(1), None, Some(3)])), + Arc::new(Date64Array::from(vec![Some(1), None, Some(3)])), + // time + Arc::new(Time32SecondArray::from(vec![Some(1), None, Some(3)])), + Arc::new(Time32MillisecondArray::from(vec![Some(1), None, Some(3)])), + Arc::new(Time64MicrosecondArray::from(vec![Some(1), None, Some(3)])), + Arc::new(Time64NanosecondArray::from(vec![Some(1), None, Some(3)])), + // null array + Arc::new(NullArray::new(3)), + // dense union + { + let mut builder = UnionBuilder::new_dense(); + builder.append::("a", 1).unwrap(); + builder.append::("b", 3.4).unwrap(); + Arc::new(builder.build().unwrap()) + }, + // sparse union + { + let mut builder = UnionBuilder::new_sparse(); + builder.append::("a", 1).unwrap(); + builder.append::("b", 3.4).unwrap(); + Arc::new(builder.build().unwrap()) + }, + // list array + { + let values_builder = StringBuilder::new(); + let mut builder = ListBuilder::new(values_builder); + // [A, B] + builder.values().append_value("A"); + 
builder.values().append_value("B"); + builder.append(true); + // [ ] (empty list) + builder.append(true); + // Null + builder.values().append_value("?"); // irrelevant + builder.append(false); + Arc::new(builder.finish()) + }, + // large list array + { + let values_builder = StringBuilder::new(); + let mut builder = LargeListBuilder::new(values_builder); + // [A, B] + builder.values().append_value("A"); + builder.values().append_value("B"); + builder.append(true); + // [ ] (empty list) + builder.append(true); + // Null + builder.append(false); + Arc::new(builder.finish()) + }, + // fixed size list array + { + let values_builder = Int32Builder::new(); + let mut builder = FixedSizeListBuilder::new(values_builder, 3); + + // [[0, 1, 2], null, [3, null, 5] + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.values().append_null(); + builder.values().append_null(); + builder.values().append_null(); + builder.append(false); + builder.values().append_value(3); + builder.values().append_null(); + builder.values().append_value(5); + builder.append(true); + Arc::new(builder.finish()) + }, + // map + { + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::with_capacity(4); + + let mut builder = MapBuilder::new(None, string_builder, int_builder); + // {"joe": 1} + builder.keys().append_value("joe"); + builder.values().append_value(1); + builder.append(true).unwrap(); + // {} + builder.append(true).unwrap(); + // null + builder.append(false).unwrap(); + + Arc::new(builder.finish()) + }, + ]; + + for arr in cases { + round_trip_through_scalar(arr); + } + } + + /// for each row in `arr`: + /// 1. convert to a `ScalarValue` + /// 2. Convert `ScalarValue` back to an `ArrayRef` + /// 3. 
Compare the original array (sliced) and new array for equality + fn round_trip_through_scalar(arr: ArrayRef) { + for i in 0..arr.len() { + // convert Scalar --> Array + let scalar = ScalarValue::try_from_array(&arr, i).unwrap(); + let array = scalar.to_array_of_size(1).unwrap(); + assert_eq!(array.len(), 1); + assert_eq!(array.data_type(), arr.data_type()); + assert_eq!(array.as_ref(), arr.slice(i, 1).as_ref()); + } + } + #[test] fn test_scalar_union_sparse() { let field_a = Arc::new(Field::new("A", DataType::Int32, true)); @@ -5677,7 +5947,7 @@ mod tests { let field_a = Arc::new(Field::new("A", DataType::Utf8, false)); let field_primitive_list = Arc::new(Field::new( "primitive_list", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), false, )); @@ -5744,13 +6014,13 @@ mod tests { // Define list-of-structs scalars let nl0_array = ScalarValue::iter_to_array(vec![s0, s1.clone()]).unwrap(); - let nl0 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl0_array))); + let nl0 = SingleRowListArrayBuilder::new(nl0_array).build_list_scalar(); let nl1_array = ScalarValue::iter_to_array(vec![s2]).unwrap(); - let nl1 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl1_array))); + let nl1 = SingleRowListArrayBuilder::new(nl1_array).build_list_scalar(); let nl2_array = ScalarValue::iter_to_array(vec![s1]).unwrap(); - let nl2 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl2_array))); + let nl2 = SingleRowListArrayBuilder::new(nl2_array).build_list_scalar(); // iter_to_array for list-of-struct let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); @@ -5878,9 +6148,8 @@ mod tests { fn build_2d_list(data: Vec>) -> ListArray { let a1 = ListArray::from_iter_primitive::(vec![Some(data)]); ListArray::new( - Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + Arc::new(Field::new_list_field( + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), true, )), OffsetBuffer::::from_lengths([1]), @@ -5948,9 +6217,9 @@ mod tests { &DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) ); - let newscalar = ScalarValue::try_from_array(&array, 0).unwrap(); + let new_scalar = ScalarValue::try_from_array(&array, 0).unwrap(); assert_eq!( - newscalar.data_type(), + new_scalar.data_type(), DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) ); } @@ -5980,6 +6249,51 @@ mod tests { ScalarValue::from("larger than 12 bytes string"), DataType::Utf8View, ); + check_scalar_cast( + { + let element_field = + Arc::new(Field::new("element", DataType::Int32, true)); + + let mut builder = + ListBuilder::new(Int32Builder::new()).with_field(element_field); + builder.append_value([Some(1)]); + builder.append(true); + + ScalarValue::List(Arc::new(builder.finish())) + }, + DataType::List(Arc::new(Field::new("element", DataType::Int64, true))), + ); + check_scalar_cast( + { + let element_field = + Arc::new(Field::new("element", DataType::Int32, true)); + + let mut builder = FixedSizeListBuilder::new(Int32Builder::new(), 1) + .with_field(element_field); + builder.values().append_value(1); + builder.append(true); + + ScalarValue::FixedSizeList(Arc::new(builder.finish())) + }, + DataType::FixedSizeList( + Arc::new(Field::new("element", DataType::Int64, true)), + 1, + ), + ); + check_scalar_cast( + { + let element_field = + Arc::new(Field::new("element", DataType::Int32, true)); + + let mut builder = + 
LargeListBuilder::new(Int32Builder::new()).with_field(element_field); + builder.append_value([Some(1)]); + builder.append(true); + + ScalarValue::LargeList(Arc::new(builder.finish())) + }, + DataType::LargeList(Arc::new(Field::new("element", DataType::Int64, true))), + ); } // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` @@ -6611,6 +6925,43 @@ mod tests { assert_batches_eq!(&expected, &[batch]); } + #[test] + fn test_null_bug() { + let field_a = Field::new("a", DataType::Int32, true); + let field_b = Field::new("b", DataType::Int32, true); + let fields = Fields::from(vec![field_a, field_b]); + + let array_a = Arc::new(Int32Array::from_iter_values([1])); + let array_b = Arc::new(Int32Array::from_iter_values([2])); + let arrays: Vec = vec![array_a, array_b]; + + let mut not_nulls = BooleanBufferBuilder::new(1); + not_nulls.append(true); + let not_nulls = not_nulls.finish(); + let not_nulls = Some(NullBuffer::new(not_nulls)); + + let ar = StructArray::new(fields, arrays, not_nulls); + let s = ScalarValue::Struct(Arc::new(ar)); + + assert_eq!(s.to_string(), "{a:1,b:2}"); + assert_eq!(format!("{s:?}"), r#"Struct({a:1,b:2})"#); + + let ScalarValue::Struct(arr) = s else { + panic!("Expected struct"); + }; + + //verify compared to arrow display + let batch = RecordBatch::try_from_iter(vec![("s", arr as _)]).unwrap(); + let expected = [ + "+--------------+", + "| s |", + "+--------------+", + "| {a: 1, b: 2} |", + "+--------------+", + ]; + assert_batches_eq!(&expected, &[batch]); + } + #[test] fn test_struct_display_null() { let fields = vec![Field::new("a", DataType::Int32, false)]; @@ -6793,8 +7144,7 @@ mod tests { assert_eq!(1, arr.len()); assert_eq!( arr.data_type(), - &DataType::List(Arc::new(Field::new( - "item", + &DataType::List(Arc::new(Field::new_list_field( DataType::Timestamp(TimeUnit::Millisecond, Some(s.into())), true, ))) diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index d8e62b3045f93..d2ce965c5c493 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -190,7 +190,7 @@ impl Precision { } } -impl Debug for Precision { +impl Debug for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Precision::Exact(inner) => write!(f, "Exact({:?})", inner), @@ -200,7 +200,7 @@ impl Debug for Precision } } -impl Display for Precision { +impl Display for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Precision::Exact(inner) => write!(f, "Exact({:?})", inner), @@ -258,6 +258,48 @@ impl Statistics { self } + /// Project the statistics to the given column indices. + /// + /// For example, if we had statistics for columns `{"a", "b", "c"}`, + /// projecting to `vec![2, 1]` would return statistics for columns `{"c", + /// "b"}`. 
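A brief sketch of the projection semantics described above (column statistics are reordered while table-level statistics stay untouched); it mirrors the tests added further below, and uses `ColumnStatistics::new_unknown` only to keep the example short:

    use datafusion_common::stats::Precision;
    use datafusion_common::{ColumnStatistics, Statistics};

    fn main() {
        // Statistics for three columns, say "a", "b" and "c"
        let stats = Statistics {
            num_rows: Precision::Exact(42),
            total_byte_size: Precision::Absent,
            column_statistics: vec![ColumnStatistics::new_unknown(); 3],
        };

        // Keep columns "c" and "b", in that order
        let projected = stats.project(Some(&vec![2, 1]));
        assert_eq!(projected.column_statistics.len(), 2);
        // Table-level statistics such as the row count are not affected
        assert_eq!(projected.num_rows, Precision::Exact(42));
    }
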
+ pub fn project(mut self, projection: Option<&Vec>) -> Self { + let Some(projection) = projection else { + return self; + }; + + enum Slot { + /// The column is taken and put into the specified statistics location + Taken(usize), + /// The original columns is present + Present(ColumnStatistics), + } + + // Convert to Vec so we can avoid copying the statistics + let mut columns: Vec<_> = std::mem::take(&mut self.column_statistics) + .into_iter() + .map(Slot::Present) + .collect(); + + for idx in projection { + let next_idx = self.column_statistics.len(); + let slot = std::mem::replace( + columns.get_mut(*idx).expect("projection out of bounds"), + Slot::Taken(next_idx), + ); + match slot { + // The column was there, so just move it + Slot::Present(col) => self.column_statistics.push(col), + // The column was taken, so copy from the previous location + Slot::Taken(prev_idx) => self + .column_statistics + .push(self.column_statistics[prev_idx].clone()), + } + } + + self + } + /// Calculates the statistics after `fetch` and `skip` operations apply. /// Here, `self` denotes per-partition statistics. Use the `n_partitions` /// parameter to compute global statistics in a multi-partition setting. @@ -341,7 +383,7 @@ fn check_num_rows(value: Option, is_exact: bool) -> Precision { } impl Display for Statistics { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // string of column statistics let column_stats = self .column_statistics @@ -561,4 +603,50 @@ mod tests { let p2 = precision.clone(); assert_eq!(precision, p2); } + + #[test] + fn test_project_none() { + let projection = None; + let stats = make_stats(vec![10, 20, 30]).project(projection.as_ref()); + assert_eq!(stats, make_stats(vec![10, 20, 30])); + } + + #[test] + fn test_project_empty() { + let projection = Some(vec![]); + let stats = make_stats(vec![10, 20, 30]).project(projection.as_ref()); + assert_eq!(stats, make_stats(vec![])); + } + + #[test] + fn test_project_swap() { + let projection = Some(vec![2, 1]); + let stats = make_stats(vec![10, 20, 30]).project(projection.as_ref()); + assert_eq!(stats, make_stats(vec![30, 20])); + } + + #[test] + fn test_project_repeated() { + let projection = Some(vec![1, 2, 1, 1, 0, 2]); + let stats = make_stats(vec![10, 20, 30]).project(projection.as_ref()); + assert_eq!(stats, make_stats(vec![20, 30, 20, 20, 10, 30])); + } + + // Make a Statistics structure with the specified null counts for each column + fn make_stats(counts: impl IntoIterator) -> Statistics { + Statistics { + num_rows: Precision::Exact(42), + total_byte_size: Precision::Exact(500), + column_statistics: counts.into_iter().map(col_stats_i64).collect(), + } + } + + fn col_stats_i64(null_count: usize) -> ColumnStatistics { + ColumnStatistics { + null_count: Precision::Exact(null_count), + max_value: Precision::Exact(ScalarValue::Int64(Some(42))), + min_value: Precision::Exact(ScalarValue::Int64(Some(64))), + distinct_count: Precision::Exact(100), + } + } } diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index 67f3da4f48deb..bb53a30dcb234 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -19,7 +19,7 @@ use crate::utils::{parse_identifiers_normalized, quote_identifier}; use std::sync::Arc; /// A fully resolved path to a table of the form "catalog.schema.table" -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct 
ResolvedTableReference { /// The catalog (aka database) containing the table pub catalog: Arc, diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index 36254192550c8..d3b8c84512583 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -279,8 +279,88 @@ pub fn get_data_dir( } } +#[macro_export] +macro_rules! create_array { + (Boolean, $values: expr) => { + std::sync::Arc::new(arrow::array::BooleanArray::from($values)) + }; + (Int8, $values: expr) => { + std::sync::Arc::new(arrow::array::Int8Array::from($values)) + }; + (Int16, $values: expr) => { + std::sync::Arc::new(arrow::array::Int16Array::from($values)) + }; + (Int32, $values: expr) => { + std::sync::Arc::new(arrow::array::Int32Array::from($values)) + }; + (Int64, $values: expr) => { + std::sync::Arc::new(arrow::array::Int64Array::from($values)) + }; + (UInt8, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt8Array::from($values)) + }; + (UInt16, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt16Array::from($values)) + }; + (UInt32, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt32Array::from($values)) + }; + (UInt64, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt64Array::from($values)) + }; + (Float16, $values: expr) => { + std::sync::Arc::new(arrow::array::Float16Array::from($values)) + }; + (Float32, $values: expr) => { + std::sync::Arc::new(arrow::array::Float32Array::from($values)) + }; + (Float64, $values: expr) => { + std::sync::Arc::new(arrow::array::Float64Array::from($values)) + }; + (Utf8, $values: expr) => { + std::sync::Arc::new(arrow::array::StringArray::from($values)) + }; +} + +/// Creates a record batch from literal slice of values, suitable for rapid +/// testing and development. +/// +/// Example: +/// ``` +/// use datafusion_common::{record_batch, create_array}; +/// let batch = record_batch!( +/// ("a", Int32, vec![1, 2, 3]), +/// ("b", Float64, vec![Some(4.0), None, Some(5.0)]), +/// ("c", Utf8, vec!["alpha", "beta", "gamma"]) +/// ); +/// ``` +#[macro_export] +macro_rules! record_batch { + ($(($name: expr, $type: ident, $values: expr)),*) => { + { + let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![ + $( + arrow_schema::Field::new($name, arrow_schema::DataType::$type, true), + )* + ])); + + let batch = arrow_array::RecordBatch::try_new( + schema, + vec![$( + $crate::create_array!($type, $values), + )*] + ); + + batch + } + } +} + #[cfg(test)] mod tests { + use crate::cast::{as_float64_array, as_int32_array, as_string_array}; + use crate::error::Result; + use super::*; use std::env; @@ -333,4 +413,44 @@ mod tests { let res = parquet_test_data(); assert!(PathBuf::from(res).is_dir()); } + + #[test] + fn test_create_record_batch() -> Result<()> { + use arrow_array::Array; + + let batch = record_batch!( + ("a", Int32, vec![1, 2, 3, 4]), + ("b", Float64, vec![Some(4.0), None, Some(5.0), None]), + ("c", Utf8, vec!["alpha", "beta", "gamma", "delta"]) + )?; + + assert_eq!(3, batch.num_columns()); + assert_eq!(4, batch.num_rows()); + + let values: Vec<_> = as_int32_array(batch.column(0))? + .values() + .iter() + .map(|v| v.to_owned()) + .collect(); + assert_eq!(values, vec![1, 2, 3, 4]); + + let values: Vec<_> = as_float64_array(batch.column(1))? + .values() + .iter() + .map(|v| v.to_owned()) + .collect(); + assert_eq!(values, vec![4.0, 0.0, 5.0, 0.0]); + + let nulls: Vec<_> = as_float64_array(batch.column(1))? 
+ .nulls() + .unwrap() + .iter() + .collect(); + assert_eq!(nulls, vec![true, false, true, false]); + + let values: Vec<_> = as_string_array(batch.column(2))?.iter().flatten().collect(); + assert_eq!(values, vec!["alpha", "beta", "gamma", "delta"]); + + Ok(()) + } } diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index b4d3251fd2630..c70389b631773 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -17,9 +17,10 @@ //! [`TreeNode`] for visiting and rewriting expression and plan trees -use std::sync::Arc; - use crate::Result; +use std::collections::HashMap; +use std::hash::Hash; +use std::sync::Arc; /// These macros are used to determine continuation during transforming traversals. macro_rules! handle_transform_recursion { @@ -123,6 +124,7 @@ pub trait TreeNode: Sized { /// TreeNodeVisitor::f_up(ChildNode2) /// TreeNodeVisitor::f_up(ParentNode) /// ``` + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>( &'n self, visitor: &mut V, @@ -172,6 +174,7 @@ pub trait TreeNode: Sized { /// TreeNodeRewriter::f_up(ChildNode2) /// TreeNodeRewriter::f_up(ParentNode) /// ``` + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn rewrite>( self, rewriter: &mut R, @@ -194,6 +197,7 @@ pub trait TreeNode: Sized { &'n self, mut f: F, ) -> Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn apply_impl<'n, N: TreeNode, F: FnMut(&'n N) -> Result>( node: &'n N, f: &mut F, @@ -228,6 +232,7 @@ pub trait TreeNode: Sized { self, mut f: F, ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn transform_down_impl Result>>( node: N, f: &mut F, @@ -238,15 +243,6 @@ pub trait TreeNode: Sized { transform_down_impl(self, &mut f) } - /// Same as [`Self::transform_down`] but with a mutable closure. - #[deprecated(since = "38.0.0", note = "Use `transform_down` instead")] - fn transform_down_mut Result>>( - self, - f: &mut F, - ) -> Result> { - self.transform_down(f) - } - /// Recursively rewrite the node using `f` in a bottom-up (post-order) /// fashion. /// @@ -260,6 +256,7 @@ pub trait TreeNode: Sized { self, mut f: F, ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn transform_up_impl Result>>( node: N, f: &mut F, @@ -271,15 +268,6 @@ pub trait TreeNode: Sized { transform_up_impl(self, &mut f) } - /// Same as [`Self::transform_up`] but with a mutable closure. - #[deprecated(since = "38.0.0", note = "Use `transform_up` instead")] - fn transform_up_mut Result>>( - self, - f: &mut F, - ) -> Result> { - self.transform_up(f) - } - /// Transforms the node using `f_down` while traversing the tree top-down /// (pre-order), and using `f_up` while traversing the tree bottom-up /// (post-order). @@ -383,6 +371,7 @@ pub trait TreeNode: Sized { mut f_down: FD, mut f_up: FU, ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn transform_down_up_impl< N: TreeNode, FD: FnMut(N) -> Result>, @@ -780,6 +769,297 @@ impl Transformed { } } +/// [`TreeNodeContainer`] contains elements that a function can be applied on or mapped. +/// The elements of the container are siblings so the continuation rules are similar to +/// [`TreeNodeRecursion::visit_sibling`] / [`Transformed::transform_sibling`]. +pub trait TreeNodeContainer<'a, T: 'a>: Sized { + /// Applies `f` to all elements of the container. 
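To make the container abstraction introduced here concrete, a minimal sketch follows: a toy leaf type implements `TreeNodeContainer` for itself (mirroring what real node types such as expressions do elsewhere in the codebase), and the blanket `Vec`/`Option` impls from this hunk then compose around it. The `Item` type and counter are purely illustrative:

    use datafusion_common::tree_node::{Transformed, TreeNodeContainer, TreeNodeRecursion};
    use datafusion_common::Result;

    /// A toy leaf element; real node types provide impls of this shape.
    #[derive(Debug, PartialEq)]
    struct Item(i32);

    impl<'a> TreeNodeContainer<'a, Item> for Item {
        fn apply_elements<F: FnMut(&'a Item) -> Result<TreeNodeRecursion>>(
            &'a self,
            mut f: F,
        ) -> Result<TreeNodeRecursion> {
            f(self)
        }

        fn map_elements<F: FnMut(Item) -> Result<Transformed<Item>>>(
            self,
            mut f: F,
        ) -> Result<Transformed<Item>> {
            f(self)
        }
    }

    fn main() -> Result<()> {
        // The blanket impls compose: a Vec<Option<Item>> is itself a container of Items
        let items = vec![Some(Item(1)), None, Some(Item(3))];
        let mut visited = 0;
        items.apply_elements(|_item| {
            visited += 1;
            Ok(TreeNodeRecursion::Continue)
        })?;
        assert_eq!(visited, 2); // the None element contributes nothing
        Ok(())
    }
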
+ /// This method is usually called from [`TreeNode::apply_children`] implementations as + /// a node is actually a container of the node's children. + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result; + + /// Maps all elements of the container with `f`. + /// This method is usually called from [`TreeNode::map_children`] implementations as + /// a node is actually a container of the node's children. + fn map_elements Result>>( + self, + f: F, + ) -> Result>; +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Box { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.as_ref().apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + (*self).map_elements(f)?.map_data(|c| Ok(Self::new(c))) + } +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Clone> TreeNodeContainer<'a, T> for Arc { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.as_ref().apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + Arc::unwrap_or_clone(self) + .map_elements(f)? + .map_data(|c| Ok(Arc::new(c))) + } +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Option { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + match self { + Some(t) => t.apply_elements(f), + None => Ok(TreeNodeRecursion::Continue), + } + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.map_or(Ok(Transformed::no(None)), |c| { + c.map_elements(f)?.map_data(|c| Ok(Some(c))) + }) + } +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Vec { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for c in self { + tnr = c.apply_elements(&mut f)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + let mut tnr = TreeNodeRecursion::Continue; + let mut transformed = false; + self.into_iter() + .map(|c| match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + c.map_elements(&mut f).map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + result.data + }) + } + TreeNodeRecursion::Stop => Ok(c), + }) + .collect::>>() + .map(|data| Transformed::new(data, transformed, tnr)) + } +} + +impl<'a, T: 'a, K: Eq + Hash, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> + for HashMap +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for c in self.values() { + tnr = c.apply_elements(&mut f)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + let mut tnr = TreeNodeRecursion::Continue; + let mut transformed = false; + self.into_iter() + .map(|(k, c)| match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + c.map_elements(&mut f).map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + (k, result.data) + }) + } + TreeNodeRecursion::Stop => Ok((k, c)), + }) + .collect::>>() + .map(|data| Transformed::new(data, transformed, tnr)) + } +} + +impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> + TreeNodeContainer<'a, T> for (C0, C1) +{ + fn apply_elements Result>( 
+ &'a self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f)) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + self.0 + .map_elements(&mut f)? + .map_data(|new_c0| Ok((new_c0, self.1)))? + .transform_sibling(|(new_c0, c1)| { + c1.map_elements(&mut f)? + .map_data(|new_c1| Ok((new_c0, new_c1))) + }) + } +} + +impl< + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + > TreeNodeContainer<'a, T> for (C0, C1, C2) +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f))? + .visit_sibling(|| self.2.apply_elements(&mut f)) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + self.0 + .map_elements(&mut f)? + .map_data(|new_c0| Ok((new_c0, self.1, self.2)))? + .transform_sibling(|(new_c0, c1, c2)| { + c1.map_elements(&mut f)? + .map_data(|new_c1| Ok((new_c0, new_c1, c2))) + })? + .transform_sibling(|(new_c0, new_c1, c2)| { + c2.map_elements(&mut f)? + .map_data(|new_c2| Ok((new_c0, new_c1, new_c2))) + }) + } +} + +/// [`TreeNodeRefContainer`] contains references to elements that a function can be +/// applied on. The elements of the container are siblings so the continuation rules are +/// similar to [`TreeNodeRecursion::visit_sibling`]. +/// +/// This container is similar to [`TreeNodeContainer`], but the lifetime of the reference +/// elements (`T`) are not derived from the container's lifetime. +/// A typical usage of this container is in `Expr::apply_children` when we need to +/// construct a temporary container to be able to call `apply_ref_elements` on a +/// collection of tree node references. But in that case the container's temporary +/// lifetime is different to the lifetime of tree nodes that we put into it. +/// Please find an example use case in `Expr::apply_children` with the `Expr::Case` case. +/// +/// Most of the cases we don't need to create a temporary container with +/// `TreeNodeRefContainer`, but we can just call `TreeNodeContainer::apply_elements`. +/// Please find an example use case in `Expr::apply_children` with the `Expr::GroupingSet` +/// case. +pub trait TreeNodeRefContainer<'a, T: 'a>: Sized { + /// Applies `f` to all elements of the container. + /// This method is usually called from [`TreeNode::apply_children`] implementations as + /// a node is actually a container of the node's children. + fn apply_ref_elements Result>( + &self, + f: F, + ) -> Result; +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeRefContainer<'a, T> for Vec<&'a C> { + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for c in self { + tnr = c.apply_elements(&mut f)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } +} + +impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> + TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1) +{ + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? 
+ .visit_sibling(|| self.1.apply_elements(&mut f)) + } +} + +impl< + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2) +{ + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f))? + .visit_sibling(|| self.2.apply_elements(&mut f)) + } +} + /// Transformation helper to process a sequence of iterable tree nodes that are siblings. pub trait TreeNodeIterator: Iterator { /// Apples `f` to each item in this iterator @@ -854,50 +1134,6 @@ impl TreeNodeIterator for I { } } -/// Transformation helper to process a heterogeneous sequence of tree node containing -/// expressions. -/// -/// This macro is very similar to [TreeNodeIterator::map_until_stop_and_collect] to -/// process nodes that are siblings, but it accepts an initial transformation (`F0`) and -/// a sequence of pairs. Each pair is made of an expression (`EXPR`) and its -/// transformation (`F`). -/// -/// The macro builds up a tuple that contains `Transformed.data` result of `F0` as the -/// first element and further elements from the sequence of pairs. An element from a pair -/// is either the value of `EXPR` or the `Transformed.data` result of `F`, depending on -/// the `Transformed.tnr` result of previous `F`s (`F0` initially). -/// -/// # Returns -/// Error if any of the transformations returns an error -/// -/// Ok(Transformed<(data0, ..., dataN)>) such that: -/// 1. `transformed` is true if any of the transformations had transformed true -/// 2. `(data0, ..., dataN)`, where `data0` is the `Transformed.data` from `F0` and -/// `data1` ... `dataN` are from either `EXPR` or the `Transformed.data` of `F` -/// 3. `tnr` from `F0` or the last invocation of `F` -#[macro_export] -macro_rules! map_until_stop_and_collect { - ($F0:expr, $($EXPR:expr, $F:expr),*) => {{ - $F0.and_then(|Transformed { data: data0, mut transformed, mut tnr }| { - let all_datas = ( - data0, - $( - if tnr == TreeNodeRecursion::Continue || tnr == TreeNodeRecursion::Jump { - $F.map(|result| { - tnr = result.tnr; - transformed |= result.transformed; - result.data - })? - } else { - $EXPR - }, - )* - ); - Ok(Transformed::new(all_datas, transformed, tnr)) - }) - }} -} - /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. 
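A sketch of how the new `TreeNodeContainer` / `TreeNodeRefContainer` traits are intended to be used from `TreeNode::apply_children` / `map_children` implementations. `MyNode` is a hypothetical node type modeled on the `TestTreeNode` used in the tests below; the trait signatures follow the hunk above:

```rust
use datafusion_common::tree_node::{
    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer,
};
use datafusion_common::Result;

// Hypothetical node with one optional child plus a list of further children
struct MyNode {
    maybe_child: Option<Box<MyNode>>,
    children: Vec<MyNode>,
}

// A node is itself an element, so the container impl just applies/maps `f` on it
impl<'a> TreeNodeContainer<'a, Self> for MyNode {
    fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
        &'a self,
        mut f: F,
    ) -> Result<TreeNodeRecursion> {
        f(self)
    }

    fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
        self,
        mut f: F,
    ) -> Result<Transformed<Self>> {
        f(self)
    }
}

impl TreeNode for MyNode {
    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
        &'n self,
        f: F,
    ) -> Result<TreeNodeRecursion> {
        // Tuples of container references are visited as siblings
        (&self.maybe_child, &self.children).apply_ref_elements(f)
    }

    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
        self,
        f: F,
    ) -> Result<Transformed<Self>> {
        // Tuples of owned containers are mapped as siblings, then reassembled
        (self.maybe_child, self.children)
            .map_elements(f)?
            .map_data(|(maybe_child, children)| {
                Ok(Self {
                    maybe_child,
                    children,
                })
            })
    }
}
```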
/// /// # Example @@ -1027,26 +1263,37 @@ impl TreeNode for T { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use std::collections::HashMap; use std::fmt::Display; use crate::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use crate::Result; - #[derive(Debug, Eq, Hash, PartialEq)] - struct TestTreeNode { - children: Vec>, - data: T, + #[derive(Debug, Eq, Hash, PartialEq, Clone)] + pub struct TestTreeNode { + pub(crate) children: Vec>, + pub(crate) data: T, } impl TestTreeNode { - fn new(children: Vec>, data: T) -> Self { + pub(crate) fn new(children: Vec>, data: T) -> Self { Self { children, data } } + + pub(crate) fn new_leaf(data: T) -> Self { + Self { + children: vec![], + data, + } + } + + pub(crate) fn is_leaf(&self) -> bool { + self.children.is_empty() + } } impl TreeNode for TestTreeNode { @@ -1054,7 +1301,7 @@ mod tests { &'n self, f: F, ) -> Result { - self.children.iter().apply_until_stop(f) + self.children.apply_elements(f) } fn map_children Result>>( @@ -1063,8 +1310,7 @@ mod tests { ) -> Result> { Ok(self .children - .into_iter() - .map_until_stop_and_collect(f)? + .map_elements(f)? .update_data(|new_children| Self { children: new_children, ..self @@ -1072,6 +1318,22 @@ mod tests { } } + impl<'a, T: 'a> TreeNodeContainer<'a, Self> for TestTreeNode { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + f(self) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + f(self) + } + } + // J // | // I @@ -1086,12 +1348,12 @@ mod tests { // | // A fn test_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); @@ -1130,13 +1392,13 @@ mod tests { // Expected transformed tree after a combined traversal fn transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); @@ -1146,12 +1408,12 @@ mod tests { // Expected transformed tree after a top-down traversal fn transformed_down_tree() -> TestTreeNode { - let node_a = 
TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1160,12 +1422,12 @@ mod tests { // Expected transformed tree after a bottom-up traversal fn transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); @@ -1202,12 +1464,12 @@ mod tests { } fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1236,12 +1498,12 @@ mod tests { } fn f_down_jump_on_e_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); @@ -1250,12 
+1512,12 @@ mod tests { } fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1289,12 +1551,12 @@ mod tests { } fn f_up_jump_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); @@ -1303,12 +1565,12 @@ mod tests { } fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); @@ -1372,12 +1634,12 @@ mod tests { } fn f_down_stop_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = 
TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1385,12 +1647,12 @@ mod tests { } fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1406,12 +1668,12 @@ mod tests { } fn f_down_stop_on_e_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1419,12 +1681,12 @@ mod tests { } fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1451,12 +1713,12 @@ mod tests { } fn f_up_stop_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let 
node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1464,12 +1726,12 @@ mod tests { } fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); @@ -1499,13 +1761,13 @@ mod tests { } fn f_up_stop_on_e_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1513,12 +1775,12 @@ mod tests { } fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); @@ -2016,16 +2278,16 @@ mod tests { // A #[test] fn test_apply_and_visit_references() -> Result<()> { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_a_2 = TestTreeNode::new(vec![], "a".to_string()); - let node_b_2 = TestTreeNode::new(vec![], "b".to_string()); + let node_a_2 = TestTreeNode::new_leaf("a".to_string()); + let node_b_2 = TestTreeNode::new_leaf("b".to_string()); let node_d_2 = 
TestTreeNode::new(vec![node_a_2], "d".to_string()); let node_c_2 = TestTreeNode::new(vec![node_b_2, node_d_2], "c".to_string()); - let node_a_3 = TestTreeNode::new(vec![], "a".to_string()); + let node_a_3 = TestTreeNode::new_leaf("a".to_string()); let tree = TestTreeNode::new(vec![node_e, node_c_2, node_a_3], "f".to_string()); let node_f_ref = &tree; @@ -2086,4 +2348,18 @@ mod tests { Ok(()) } + + #[cfg(feature = "recursive_protection")] + #[test] + fn test_large_tree() { + let mut item = TestTreeNode::new_leaf("initial".to_string()); + for i in 0..3000 { + item = TestTreeNode::new(vec![item], format!("parent-{}", i)); + } + + let mut visitor = + TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); + + item.visit(&mut visitor).unwrap(); + } } diff --git a/datafusion/common/src/types/builtin.rs b/datafusion/common/src/types/builtin.rs index 51fca8f21ebe8..ec69db7903779 100644 --- a/datafusion/common/src/types/builtin.rs +++ b/datafusion/common/src/types/builtin.rs @@ -15,25 +15,35 @@ // specific language governing permissions and limitations // under the License. -use super::{LogicalType, NativeType}; +use crate::types::{LogicalTypeRef, NativeType}; +use std::sync::{Arc, LazyLock}; -#[derive(Debug)] -pub struct BuiltinType { - native: NativeType, -} - -impl LogicalType for BuiltinType { - fn native(&self) -> &NativeType { - &self.native - } +macro_rules! singleton { + ($name:ident, $getter:ident, $ty:ident) => { + static $name: LazyLock = + LazyLock::new(|| Arc::new(NativeType::$ty)); - fn name(&self) -> Option<&str> { - None - } + #[doc = "Getter for singleton instance of a logical type representing"] + #[doc = concat!("[`NativeType::", stringify!($ty), "`].")] + pub fn $getter() -> LogicalTypeRef { + Arc::clone(&$name) + } + }; } -impl From for BuiltinType { - fn from(native: NativeType) -> Self { - Self { native } - } -} +singleton!(LOGICAL_NULL, logical_null, Null); +singleton!(LOGICAL_BOOLEAN, logical_boolean, Boolean); +singleton!(LOGICAL_INT8, logical_int8, Int8); +singleton!(LOGICAL_INT16, logical_int16, Int16); +singleton!(LOGICAL_INT32, logical_int32, Int32); +singleton!(LOGICAL_INT64, logical_int64, Int64); +singleton!(LOGICAL_UINT8, logical_uint8, UInt8); +singleton!(LOGICAL_UINT16, logical_uint16, UInt16); +singleton!(LOGICAL_UINT32, logical_uint32, UInt32); +singleton!(LOGICAL_UINT64, logical_uint64, UInt64); +singleton!(LOGICAL_FLOAT16, logical_float16, Float16); +singleton!(LOGICAL_FLOAT32, logical_float32, Float32); +singleton!(LOGICAL_FLOAT64, logical_float64, Float64); +singleton!(LOGICAL_DATE, logical_date, Date); +singleton!(LOGICAL_BINARY, logical_binary, Binary); +singleton!(LOGICAL_STRING, logical_string, String); diff --git a/datafusion/common/src/types/field.rs b/datafusion/common/src/types/field.rs new file mode 100644 index 0000000000000..85c7c157272ae --- /dev/null +++ b/datafusion/common/src/types/field.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::{Field, Fields, UnionFields}; +use std::hash::{Hash, Hasher}; +use std::{ops::Deref, sync::Arc}; + +use super::{LogicalTypeRef, NativeType}; + +/// A record of a logical type, its name and its nullability. +#[derive(Debug, Clone, Eq, PartialOrd, Ord)] +pub struct LogicalField { + pub name: String, + pub logical_type: LogicalTypeRef, + pub nullable: bool, +} + +impl PartialEq for LogicalField { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.logical_type.eq(&other.logical_type) + && self.nullable == other.nullable + } +} + +impl Hash for LogicalField { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.logical_type.hash(state); + self.nullable.hash(state); + } +} + +impl From<&Field> for LogicalField { + fn from(value: &Field) -> Self { + Self { + name: value.name().clone(), + logical_type: Arc::new(NativeType::from(value.data_type().clone())), + nullable: value.is_nullable(), + } + } +} + +/// A reference counted [`LogicalField`]. +pub type LogicalFieldRef = Arc; + +/// A cheaply cloneable, owned collection of [`LogicalFieldRef`]. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct LogicalFields(Arc<[LogicalFieldRef]>); + +impl Deref for LogicalFields { + type Target = [LogicalFieldRef]; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl From<&Fields> for LogicalFields { + fn from(value: &Fields) -> Self { + value + .iter() + .map(|field| Arc::new(LogicalField::from(field.as_ref()))) + .collect() + } +} + +impl FromIterator for LogicalFields { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +/// A cheaply cloneable, owned collection of [`LogicalFieldRef`] and their +/// corresponding type ids. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct LogicalUnionFields(Arc<[(i8, LogicalFieldRef)]>); + +impl Deref for LogicalUnionFields { + type Target = [(i8, LogicalFieldRef)]; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl From<&UnionFields> for LogicalUnionFields { + fn from(value: &UnionFields) -> Self { + value + .iter() + .map(|(i, field)| (i, Arc::new(LogicalField::from(field.as_ref())))) + .collect() + } +} + +impl FromIterator<(i8, LogicalFieldRef)> for LogicalUnionFields { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} diff --git a/datafusion/common/src/types/logical.rs b/datafusion/common/src/types/logical.rs index 0121e33d9d5e3..a65392cae3444 100644 --- a/datafusion/common/src/types/logical.rs +++ b/datafusion/common/src/types/logical.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. +use super::NativeType; +use crate::error::Result; +use arrow_schema::DataType; use core::fmt; use std::{cmp::Ordering, hash::Hash, sync::Arc}; -use super::NativeType; - /// Signature that uniquely identifies a type among other types. 
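A small sketch combining the singleton getters from `types/builtin.rs` with the new `LogicalField` from `types/field.rs` (assuming both are re-exported from `datafusion_common::types`, as the `mod.rs` hunk below indicates):

```rust
use datafusion_common::types::{logical_string, LogicalField, LogicalType, NativeType};

// A nullable string column described with the shared String logical type singleton
let field = LogicalField {
    name: "name".to_string(),
    logical_type: logical_string(),
    nullable: true,
};

// The singleton is backed by NativeType::String
assert_eq!(field.logical_type.native(), &NativeType::String);
```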
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum TypeSignature<'a> { @@ -63,7 +64,7 @@ pub type LogicalTypeRef = Arc; /// /// impl LogicalType for JSON { /// fn native(&self) -> &NativeType { -/// &NativeType::Utf8 +/// &NativeType::String /// } /// /// fn signature(&self) -> TypeSignature<'_> { @@ -75,12 +76,21 @@ pub type LogicalTypeRef = Arc; /// } /// ``` pub trait LogicalType: Sync + Send { + /// Get the native backing type of this logical type. fn native(&self) -> &NativeType; + /// Get the unique type signature for this logical type. Logical types with identical + /// signatures are considered equal. fn signature(&self) -> TypeSignature<'_>; + + /// Get the default physical type to cast `origin` to in order to obtain a physical type + /// that is logically compatible with this logical type. + fn default_cast_for(&self, origin: &DataType) -> Result { + self.native().default_cast_for(origin) + } } impl fmt::Debug for dyn LogicalType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("LogicalType") .field(&self.signature()) .field(&self.native()) @@ -88,9 +98,15 @@ impl fmt::Debug for dyn LogicalType { } } +impl std::fmt::Display for dyn LogicalType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + impl PartialEq for dyn LogicalType { fn eq(&self, other: &Self) -> bool { - self.native().eq(other.native()) && self.signature().eq(&other.signature()) + self.signature().eq(&other.signature()) } } diff --git a/datafusion/common/src/types/mod.rs b/datafusion/common/src/types/mod.rs index 4e1bcb75cb56e..2f9ce4ce02827 100644 --- a/datafusion/common/src/types/mod.rs +++ b/datafusion/common/src/types/mod.rs @@ -15,8 +15,12 @@ // specific language governing permissions and limitations // under the License. +mod builtin; +mod field; mod logical; mod native; +pub use builtin::*; +pub use field::*; pub use logical::*; pub use native::*; diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 1ebf4e19b01ed..c5f180a150352 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -15,61 +15,16 @@ // specific language governing permissions and limitations // under the License. -use std::{ops::Deref, sync::Arc}; - -use arrow_schema::{DataType, Field, Fields, IntervalUnit, TimeUnit, UnionFields}; - -use super::{LogicalType, TypeSignature}; - -/// A record of a native type, its name and its nullability. -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct NativeField { - name: String, - native_type: NativeType, - nullable: bool, -} - -impl NativeField { - pub fn name(&self) -> &str { - &self.name - } - - pub fn native_type(&self) -> &NativeType { - &self.native_type - } - - pub fn nullable(&self) -> bool { - self.nullable - } -} - -/// A reference counted [`NativeField`]. -pub type NativeFieldRef = Arc; - -/// A cheaply cloneable, owned collection of [`NativeFieldRef`]. -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct NativeFields(Arc<[NativeFieldRef]>); - -impl Deref for NativeFields { - type Target = [NativeFieldRef]; - - fn deref(&self) -> &Self::Target { - self.0.as_ref() - } -} - -/// A cheaply cloneable, owned collection of [`NativeFieldRef`] and their -/// corresponding type ids. 
-#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct NativeUnionFields(Arc<[(i8, NativeFieldRef)]>); - -impl Deref for NativeUnionFields { - type Target = [(i8, NativeFieldRef)]; - - fn deref(&self) -> &Self::Target { - self.0.as_ref() - } -} +use super::{ + LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, + TypeSignature, +}; +use crate::error::{Result, _internal_err}; +use arrow::compute::can_cast_types; +use arrow_schema::{ + DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, +}; +use std::{fmt::Display, sync::Arc}; /// Representation of a type that DataFusion can handle natively. It is a subset /// of the physical variants in Arrow's native [`DataType`]. @@ -194,15 +149,15 @@ pub enum NativeType { /// Enum parameter specifies the number of bytes per value. FixedSizeBinary(i32), /// A variable-length string in Unicode with UTF-8 encoding. - Utf8, + String, /// A list of some logical data type with variable length. - List(NativeFieldRef), + List(LogicalFieldRef), /// A list of some logical data type with fixed length. - FixedSizeList(NativeFieldRef, i32), + FixedSizeList(LogicalFieldRef, i32), /// A nested type that contains a number of sub-fields. - Struct(NativeFields), + Struct(LogicalFields), /// A nested type that can represent slots of differing types. - Union(NativeUnionFields), + Union(LogicalUnionFields), /// Decimal value with precision and scale /// /// * precision is the total number of digits @@ -222,11 +177,16 @@ pub enum NativeType { /// The key and value types are not constrained, but keys should be /// hashable and unique. /// - /// In a field with Map type, the field has a child Struct field, which then - /// has two children: key type and the second the value type. The names of the + /// In a field with Map type, key type and the second the value type. The names of the /// child fields may be respectively "entries", "key", and "value", but this is /// not enforced. 
- Map(NativeFieldRef), + Map(LogicalFieldRef), +} + +impl Display for NativeType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NativeType::{self:?}") + } } impl LogicalType for NativeType { @@ -237,12 +197,171 @@ impl LogicalType for NativeType { fn signature(&self) -> TypeSignature<'_> { TypeSignature::Native(self) } + + fn default_cast_for(&self, origin: &DataType) -> Result { + use DataType::*; + + fn default_field_cast(to: &LogicalField, from: &Field) -> Result { + Ok(Arc::new(Field::new( + to.name.clone(), + to.logical_type.default_cast_for(from.data_type())?, + to.nullable, + ))) + } + + Ok(match (self, origin) { + (Self::Null, _) => Null, + (Self::Boolean, _) => Boolean, + (Self::Int8, _) => Int8, + (Self::Int16, _) => Int16, + (Self::Int32, _) => Int32, + (Self::Int64, _) => Int64, + (Self::UInt8, _) => UInt8, + (Self::UInt16, _) => UInt16, + (Self::UInt32, _) => UInt32, + (Self::UInt64, _) => UInt64, + (Self::Float16, _) => Float16, + (Self::Float32, _) => Float32, + (Self::Float64, _) => Float64, + (Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s), + (Self::Decimal(p, s), _) => Decimal256(*p, *s), + (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()), + (Self::Date, _) => Date32, + (Self::Time(tu), _) => match tu { + TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu), + TimeUnit::Microsecond | TimeUnit::Nanosecond => Time64(*tu), + }, + (Self::Duration(tu), _) => Duration(*tu), + (Self::Interval(iu), _) => Interval(*iu), + (Self::Binary, LargeUtf8) => LargeBinary, + (Self::Binary, Utf8View) => BinaryView, + (Self::Binary, data_type) if can_cast_types(data_type, &BinaryView) => { + BinaryView + } + (Self::Binary, data_type) if can_cast_types(data_type, &LargeBinary) => { + LargeBinary + } + (Self::Binary, data_type) if can_cast_types(data_type, &Binary) => Binary, + (Self::FixedSizeBinary(size), _) => FixedSizeBinary(*size), + (Self::String, LargeBinary) => LargeUtf8, + (Self::String, BinaryView) => Utf8View, + // We don't cast to another kind of string type if the origin one is already a string type + (Self::String, Utf8 | LargeUtf8 | Utf8View) => origin.to_owned(), + (Self::String, data_type) if can_cast_types(data_type, &Utf8View) => Utf8View, + (Self::String, data_type) if can_cast_types(data_type, &LargeUtf8) => { + LargeUtf8 + } + (Self::String, data_type) if can_cast_types(data_type, &Utf8) => Utf8, + (Self::List(to_field), List(from_field) | FixedSizeList(from_field, _)) => { + List(default_field_cast(to_field, from_field)?) + } + (Self::List(to_field), LargeList(from_field)) => { + LargeList(default_field_cast(to_field, from_field)?) + } + (Self::List(to_field), ListView(from_field)) => { + ListView(default_field_cast(to_field, from_field)?) + } + (Self::List(to_field), LargeListView(from_field)) => { + LargeListView(default_field_cast(to_field, from_field)?) 
+ } + // List array where each element is a len 1 list of the origin type + (Self::List(field), _) => List(Arc::new(Field::new( + field.name.clone(), + field.logical_type.default_cast_for(origin)?, + field.nullable, + ))), + ( + Self::FixedSizeList(to_field, to_size), + FixedSizeList(from_field, from_size), + ) if from_size == to_size => { + FixedSizeList(default_field_cast(to_field, from_field)?, *to_size) + } + ( + Self::FixedSizeList(to_field, size), + List(from_field) + | LargeList(from_field) + | ListView(from_field) + | LargeListView(from_field), + ) => FixedSizeList(default_field_cast(to_field, from_field)?, *size), + // FixedSizeList array where each element is a len 1 list of the origin type + (Self::FixedSizeList(field, size), _) => FixedSizeList( + Arc::new(Field::new( + field.name.clone(), + field.logical_type.default_cast_for(origin)?, + field.nullable, + )), + *size, + ), + // From https://github.com/apache/arrow-rs/blob/56525efbd5f37b89d1b56aa51709cab9f81bc89e/arrow-cast/src/cast/mod.rs#L189-L196 + (Self::Struct(to_fields), Struct(from_fields)) + if from_fields.len() == to_fields.len() => + { + Struct( + from_fields + .iter() + .zip(to_fields.iter()) + .map(|(from, to)| default_field_cast(to, from)) + .collect::>()?, + ) + } + (Self::Struct(to_fields), Null) => Struct( + to_fields + .iter() + .map(|field| { + Ok(Arc::new(Field::new( + field.name.clone(), + field.logical_type.default_cast_for(&Null)?, + field.nullable, + ))) + }) + .collect::>()?, + ), + (Self::Map(to_field), Map(from_field, sorted)) => { + Map(default_field_cast(to_field, from_field)?, *sorted) + } + (Self::Map(field), Null) => Map( + Arc::new(Field::new( + field.name.clone(), + field.logical_type.default_cast_for(&Null)?, + field.nullable, + )), + false, + ), + (Self::Union(to_fields), Union(from_fields, mode)) + if from_fields.len() == to_fields.len() => + { + Union( + from_fields + .iter() + .zip(to_fields.iter()) + .map(|((_, from), (i, to))| { + Ok((*i, default_field_cast(to, from)?)) + }) + .collect::>()?, + *mode, + ) + } + _ => { + return _internal_err!( + "Unavailable default cast for native type {:?} from physical type {:?}", + self, + origin + ) + } + }) + } } // The following From, From, ... implementations are temporary // mapping solutions to provide backwards compatibility while transitioning from // the purely physical system to a logical / physical system. 
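To illustrate the new `default_cast_for` logic, a couple of cases taken directly from the match arms above (a sketch, assuming `NativeType` and the `LogicalType` trait are reachable via `datafusion_common::types`):

```rust
use arrow_schema::DataType;
use datafusion_common::types::{LogicalType, NativeType};

// Integer logical types always resolve to their canonical physical type
assert_eq!(
    NativeType::Int64.default_cast_for(&DataType::Utf8).unwrap(),
    DataType::Int64
);

// A value that already has a string physical type is kept as-is ...
assert_eq!(
    NativeType::String.default_cast_for(&DataType::Utf8View).unwrap(),
    DataType::Utf8View
);

// ... while LargeBinary is mapped to the matching LargeUtf8
assert_eq!(
    NativeType::String.default_cast_for(&DataType::LargeBinary).unwrap(),
    DataType::LargeUtf8
);
```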
+impl From<&DataType> for NativeType { + fn from(value: &DataType) -> Self { + value.clone().into() + } +} + impl From for NativeType { fn from(value: DataType) -> Self { use NativeType::*; @@ -267,7 +386,7 @@ impl From for NativeType { DataType::Interval(iu) => Interval(iu), DataType::Binary | DataType::LargeBinary | DataType::BinaryView => Binary, DataType::FixedSizeBinary(size) => FixedSizeBinary(size), - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Utf8, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => String, DataType::List(field) | DataType::ListView(field) | DataType::LargeList(field) @@ -275,54 +394,70 @@ impl From for NativeType { DataType::FixedSizeList(field, size) => { FixedSizeList(Arc::new(field.as_ref().into()), size) } - DataType::Struct(fields) => Struct(NativeFields::from(&fields)), + DataType::Struct(fields) => Struct(LogicalFields::from(&fields)), DataType::Union(union_fields, _) => { - Union(NativeUnionFields::from(&union_fields)) + Union(LogicalUnionFields::from(&union_fields)) } - DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(), DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(p, s), DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())), + DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(), DataType::RunEndEncoded(_, field) => field.data_type().clone().into(), } } } -impl From<&Field> for NativeField { - fn from(value: &Field) -> Self { - Self { - name: value.name().clone(), - native_type: value.data_type().clone().into(), - nullable: value.is_nullable(), - } +impl NativeType { + #[inline] + pub fn is_numeric(&self) -> bool { + use NativeType::*; + matches!( + self, + UInt8 + | UInt16 + | UInt32 + | UInt64 + | Int8 + | Int16 + | Int32 + | Int64 + | Float16 + | Float32 + | Float64 + | Decimal(_, _) + ) } -} -impl From<&Fields> for NativeFields { - fn from(value: &Fields) -> Self { - value - .iter() - .map(|field| Arc::new(NativeField::from(field.as_ref()))) - .collect() + #[inline] + pub fn is_integer(&self) -> bool { + use NativeType::*; + matches!( + self, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 + ) } -} -impl FromIterator for NativeFields { - fn from_iter>(iter: T) -> Self { - Self(iter.into_iter().collect()) + #[inline] + pub fn is_timestamp(&self) -> bool { + matches!(self, NativeType::Timestamp(_, _)) } -} -impl From<&UnionFields> for NativeUnionFields { - fn from(value: &UnionFields) -> Self { - value - .iter() - .map(|(i, field)| (i, Arc::new(NativeField::from(field.as_ref())))) - .collect() + #[inline] + pub fn is_date(&self) -> bool { + matches!(self, NativeType::Date) + } + + #[inline] + pub fn is_time(&self) -> bool { + matches!(self, NativeType::Time(_)) + } + + #[inline] + pub fn is_interval(&self) -> bool { + matches!(self, NativeType::Interval(_)) } -} -impl FromIterator<(i8, NativeFieldRef)> for NativeUnionFields { - fn from_iter>(iter: T) -> Self { - Self(iter.into_iter().collect()) + #[inline] + pub fn is_duration(&self) -> bool { + matches!(self, NativeType::Duration(_)) } } diff --git a/datafusion/common/src/unnest.rs b/datafusion/common/src/unnest.rs index fd92267f9b4c3..db48edd061605 100644 --- a/datafusion/common/src/unnest.rs +++ b/datafusion/common/src/unnest.rs @@ -17,6 +17,8 @@ //! [`UnnestOptions`] for unnesting structured types +use crate::Column; + /// Options for unnesting a column that contains a list type, /// replicating values in the other, non nested rows. 
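The `recursions` field documented in the hunk that follows is populated through the new builder method; a hypothetical sketch (assuming `UnnestOptions`, `RecursionUnnestOption` and `Column` remain re-exported from the crate root):

```rust
use datafusion_common::{Column, RecursionUnnestOption, UnnestOptions};

// Unnest "c1" down to depth 2 into an output column named "c1_depth_2";
// columns not listed in `recursions` keep the default depth of 1
let options = UnnestOptions::default()
    .with_preserve_nulls(true)
    .with_recursions(RecursionUnnestOption {
        input_column: Column::from_name("c1"),
        output_column: Column::from_name("c1_depth_2"),
        depth: 2,
    });
assert_eq!(options.recursions.len(), 1);
```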
/// @@ -60,10 +62,27 @@ /// └─────────┘ └─────┘ └─────────┘ └─────┘ /// c1 c2 c1 c2 /// ``` +/// +/// `recursions` instruct how a column should be unnested (e.g unnesting a column multiple +/// time, with depth = 1 and depth = 2). Any unnested column not being mentioned inside this +/// options is inferred to be unnested with depth = 1 #[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq)] pub struct UnnestOptions { /// Should nulls in the input be preserved? Defaults to true pub preserve_nulls: bool, + /// If specific columns need to be unnested multiple times (e.g at different depth), + /// declare them here. Any unnested columns not being mentioned inside this option + /// will be unnested with depth = 1 + pub recursions: Vec, +} + +/// Instruction on how to unnest a column (mostly with a list type) +/// such as how to name the output, and how many level it should be unnested +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] +pub struct RecursionUnnestOption { + pub input_column: Column, + pub output_column: Column, + pub depth: usize, } impl Default for UnnestOptions { @@ -71,6 +90,7 @@ impl Default for UnnestOptions { Self { // default to true to maintain backwards compatible behavior preserve_nulls: true, + recursions: vec![], } } } @@ -87,4 +107,10 @@ impl UnnestOptions { self.preserve_nulls = preserve_nulls; self } + + /// Set the recursions for the unnest operation + pub fn with_recursions(mut self, recursion: RecursionUnnestOption) -> Self { + self.recursions.push(recursion); + self + } } diff --git a/datafusion/common/src/utils/memory.rs b/datafusion/common/src/utils/memory.rs index 2c34b61bd0930..ab73996fcd8b7 100644 --- a/datafusion/common/src/utils/memory.rs +++ b/datafusion/common/src/utils/memory.rs @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! This module provides a function to estimate the memory size of a HashTable prior to alloaction +//! This module provides a function to estimate the memory size of a HashTable prior to allocation use crate::{DataFusionError, Result}; +use std::mem::size_of; /// Estimates the memory size required for a hash table prior to allocation. /// @@ -78,7 +79,7 @@ pub fn estimate_memory_size(num_elements: usize, fixed_size: usize) -> Result // For the majority of cases hashbrown overestimates the bucket quantity // to keep ~1/8 of them empty. We take this factor into account by // multiplying the number of elements with a fixed ratio of 8/7 (~1.14). - // This formula leads to overallocation for small tables (< 8 elements) + // This formula leads to over-allocation for small tables (< 8 elements) // but should be fine overall. num_elements .checked_mul(8) @@ -87,7 +88,7 @@ pub fn estimate_memory_size(num_elements: usize, fixed_size: usize) -> Result // + size of entry * number of buckets // + 1 byte for each bucket // + fixed size of collection (HashSet/HashTable) - std::mem::size_of::() + size_of::() .checked_mul(estimated_buckets)? .checked_add(estimated_buckets)? 
.checked_add(fixed_size) @@ -101,14 +102,14 @@ pub fn estimate_memory_size(num_elements: usize, fixed_size: usize) -> Result #[cfg(test)] mod tests { - use std::collections::HashSet; + use std::{collections::HashSet, mem::size_of}; use super::estimate_memory_size; #[test] fn test_estimate_memory() { // size (bytes): 48 - let fixed_size = std::mem::size_of::>(); + let fixed_size = size_of::>(); // estimated buckets: 16 = (8 * 8 / 7).next_power_of_two() let num_elements = 8; @@ -126,7 +127,7 @@ mod tests { #[test] fn test_estimate_memory_overflow() { let num_elements = usize::MAX; - let fixed_size = std::mem::size_of::>(); + let fixed_size = size_of::>(); let estimated = estimate_memory_size::(num_elements, fixed_size); assert!(estimated.is_err()); diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 5bf0f08b092a4..29d33fec14abe 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -23,17 +23,14 @@ pub mod proxy; pub mod string_utils; use crate::error::{_internal_datafusion_err, _internal_err}; -use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; -use arrow::array::{ArrayRef, PrimitiveArray}; +use crate::{DataFusionError, Result, ScalarValue}; +use arrow::array::ArrayRef; use arrow::buffer::OffsetBuffer; -use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; -use arrow::datatypes::{Field, SchemaRef, UInt32Type}; -use arrow::record_batch::RecordBatch; +use arrow::datatypes::{Field, SchemaRef}; use arrow_array::cast::AsArray; use arrow_array::{ Array, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, - RecordBatchOptions, }; use arrow_schema::DataType; use sqlparser::ast::Ident; @@ -42,8 +39,10 @@ use sqlparser::parser::Parser; use std::borrow::{Borrow, Cow}; use std::cmp::{min, Ordering}; use std::collections::HashSet; +use std::num::NonZero; use std::ops::Range; use std::sync::Arc; +use std::thread::available_parallelism; /// Applies an optional projection to a [`SchemaRef`], returning the /// projected schema @@ -93,20 +92,6 @@ pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result, -) -> Result { - let new_columns = take_arrays(record_batch.columns(), indices)?; - RecordBatch::try_new_with_options( - record_batch.schema(), - new_columns, - &RecordBatchOptions::new().with_row_count(Some(indices.len())), - ) - .map_err(|e| arrow_datafusion_err!(e)) -} - /// This function compares two tuples depending on the given sort options. pub fn compare_rows( x: &[ScalarValue], @@ -290,24 +275,6 @@ pub(crate) fn parse_identifiers(s: &str) -> Result> { Ok(idents) } -/// Construct a new [`Vec`] of [`ArrayRef`] from the rows of the `arrays` at the `indices`. 
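For reference, the `estimate_memory_size` helper whose docs are touched up in `utils/memory.rs` above can be called the same way the new test does; a sketch assuming it is exported as `datafusion_common::utils::memory::estimate_memory_size`:

```rust
use std::collections::HashSet;
use std::mem::size_of;
use datafusion_common::utils::memory::estimate_memory_size;

// Estimate the bytes a HashSet<u32> holding ~8 elements will need:
// bucket count is rounded up to a power of two with ~1/8 kept empty
let fixed_size = size_of::<HashSet<u32>>();
let estimated = estimate_memory_size::<u32>(8, fixed_size).unwrap();
assert!(estimated >= fixed_size);
```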
-/// -/// TODO: use implementation in arrow-rs when available: -/// -pub fn take_arrays(arrays: &[ArrayRef], indices: &dyn Array) -> Result> { - arrays - .iter() - .map(|array| { - compute::take( - array.as_ref(), - indices, - None, // None: no index check - ) - .map_err(|e| arrow_datafusion_err!(e)) - }) - .collect() -} - pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec { parse_identifiers(s) .unwrap_or_default() @@ -354,52 +321,201 @@ pub fn longest_consecutive_prefix>( count } -/// Array Utils +/// Creates single element [`ListArray`], [`LargeListArray`] and +/// [`FixedSizeListArray`] from other arrays +/// +/// For example this builder can convert `[1, 2, 3]` into `[[1, 2, 3]]` +/// +/// # Example +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{Array, ListArray}; +/// # use arrow_array::types::Int64Type; +/// # use datafusion_common::utils::SingleRowListArrayBuilder; +/// // Array is [1, 2, 3] +/// let arr = ListArray::from_iter_primitive::(vec![ +/// Some(vec![Some(1), Some(2), Some(3)]), +/// ]); +/// // Wrap as a list array: [[1, 2, 3]] +/// let list_arr = SingleRowListArrayBuilder::new(Arc::new(arr)).build_list_array(); +/// assert_eq!(list_arr.len(), 1); +/// ``` +#[derive(Debug, Clone)] +pub struct SingleRowListArrayBuilder { + /// array to be wrapped + arr: ArrayRef, + /// Should the resulting array be nullable? Defaults to `true`. + nullable: bool, + /// Specify the field name for the resulting array. Defaults to value used in + /// [`Field::new_list_field`] + field_name: Option, +} + +impl SingleRowListArrayBuilder { + /// Create a new instance of [`SingleRowListArrayBuilder`] + pub fn new(arr: ArrayRef) -> Self { + Self { + arr, + nullable: true, + field_name: None, + } + } + + /// Set the nullable flag + pub fn with_nullable(mut self, nullable: bool) -> Self { + self.nullable = nullable; + self + } + + /// sets the field name for the resulting array + pub fn with_field_name(mut self, field_name: Option) -> Self { + self.field_name = field_name; + self + } + + /// Copies field name and nullable from the specified field + pub fn with_field(self, field: &Field) -> Self { + self.with_field_name(Some(field.name().to_owned())) + .with_nullable(field.is_nullable()) + } + + /// Build a single element [`ListArray`] + pub fn build_list_array(self) -> ListArray { + let (field, arr) = self.into_field_and_arr(); + let offsets = OffsetBuffer::from_lengths([arr.len()]); + ListArray::new(field, offsets, arr, None) + } + + /// Build a single element [`ListArray`] and wrap as [`ScalarValue::List`] + pub fn build_list_scalar(self) -> ScalarValue { + ScalarValue::List(Arc::new(self.build_list_array())) + } + + /// Build a single element [`LargeListArray`] + pub fn build_large_list_array(self) -> LargeListArray { + let (field, arr) = self.into_field_and_arr(); + let offsets = OffsetBuffer::from_lengths([arr.len()]); + LargeListArray::new(field, offsets, arr, None) + } + + /// Build a single element [`LargeListArray`] and wrap as [`ScalarValue::LargeList`] + pub fn build_large_list_scalar(self) -> ScalarValue { + ScalarValue::LargeList(Arc::new(self.build_large_list_array())) + } + + /// Build a single element [`FixedSizeListArray`] + pub fn build_fixed_size_list_array(self, list_size: usize) -> FixedSizeListArray { + let (field, arr) = self.into_field_and_arr(); + FixedSizeListArray::new(field, list_size as i32, arr, None) + } + + /// Build a single element [`FixedSizeListArray`] and wrap as [`ScalarValue::FixedSizeList`] + pub fn 
build_fixed_size_list_scalar(self, list_size: usize) -> ScalarValue { + ScalarValue::FixedSizeList(Arc::new(self.build_fixed_size_list_array(list_size))) + } + + /// Helper function: convert this builder into a tuple of field and array + fn into_field_and_arr(self) -> (Arc, ArrayRef) { + let Self { + arr, + nullable, + field_name, + } = self; + let data_type = arr.data_type().to_owned(); + let field = match field_name { + Some(name) => Field::new(name, data_type, nullable), + None => Field::new_list_field(data_type, nullable), + }; + (Arc::new(field), arr) + } +} /// Wrap an array into a single element `ListArray`. /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` /// The field in the list array is nullable. +#[deprecated( + since = "44.0.0", + note = "please use `SingleRowListArrayBuilder` instead" +)] pub fn array_into_list_array_nullable(arr: ArrayRef) -> ListArray { - array_into_list_array(arr, true) + SingleRowListArrayBuilder::new(arr) + .with_nullable(true) + .build_list_array() } -/// Array Utils - /// Wrap an array into a single element `ListArray`. /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +#[deprecated( + since = "44.0.0", + note = "please use `SingleRowListArrayBuilder` instead" +)] pub fn array_into_list_array(arr: ArrayRef, nullable: bool) -> ListArray { - let offsets = OffsetBuffer::from_lengths([arr.len()]); - ListArray::new( - Arc::new(Field::new_list_field(arr.data_type().to_owned(), nullable)), - offsets, - arr, - None, - ) + SingleRowListArrayBuilder::new(arr) + .with_nullable(nullable) + .build_list_array() +} + +#[deprecated( + since = "44.0.0", + note = "please use `SingleRowListArrayBuilder` instead" +)] +pub fn array_into_list_array_with_field_name( + arr: ArrayRef, + nullable: bool, + field_name: &str, +) -> ListArray { + SingleRowListArrayBuilder::new(arr) + .with_nullable(nullable) + .with_field_name(Some(field_name.to_string())) + .build_list_array() } /// Wrap an array into a single element `LargeListArray`. 
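For anyone migrating off the deprecated `array_into_*` helpers in this hunk, a minimal sketch of the replacement pattern with `SingleRowListArrayBuilder` (the `Int32Array` input and the `"item"` field name are illustrative choices, not taken from the patch):

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int32Array};
use datafusion_common::utils::SingleRowListArrayBuilder;

fn main() {
    // [1, 2, 3] becomes the single-row list [[1, 2, 3]].
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));

    // Replacement for the deprecated `array_into_list_array(arr, /* nullable = */ false)`.
    let list = SingleRowListArrayBuilder::new(Arc::clone(&values))
        .with_nullable(false)
        .build_list_array();
    assert_eq!(list.len(), 1);

    // Replacement for `array_into_large_list_array_with_field_name(arr, "item")`.
    let large_list = SingleRowListArrayBuilder::new(values)
        .with_field_name(Some("item".to_string()))
        .build_large_list_array();
    assert_eq!(large_list.len(), 1);
}
```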
/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +#[deprecated( + since = "44.0.0", + note = "please use `SingleRowListArrayBuilder` instead" +)] pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { - let offsets = OffsetBuffer::from_lengths([arr.len()]); - LargeListArray::new( - Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), - offsets, - arr, - None, - ) + SingleRowListArrayBuilder::new(arr).build_large_list_array() } +#[deprecated( + since = "44.0.0", + note = "please use `SingleRowListArrayBuilder` instead" +)] +pub fn array_into_large_list_array_with_field_name( + arr: ArrayRef, + field_name: &str, +) -> LargeListArray { + SingleRowListArrayBuilder::new(arr) + .with_field_name(Some(field_name.to_string())) + .build_large_list_array() +} + +#[deprecated( + since = "44.0.0", + note = "please use `SingleRowListArrayBuilder` instead" +)] pub fn array_into_fixed_size_list_array( arr: ArrayRef, list_size: usize, ) -> FixedSizeListArray { - let list_size = list_size as i32; - FixedSizeListArray::new( - Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), - list_size, - arr, - None, - ) + SingleRowListArrayBuilder::new(arr).build_fixed_size_list_array(list_size) +} + +#[deprecated( + since = "44.0.0", + note = "please use `SingleRowListArrayBuilder` instead" +)] +pub fn array_into_fixed_size_list_array_with_field_name( + arr: ArrayRef, + list_size: usize, + field_name: &str, +) -> FixedSizeListArray { + SingleRowListArrayBuilder::new(arr) + .with_field_name(Some(field_name.to_string())) + .build_fixed_size_list_array(list_size) } /// Wrap arrays into a single element `ListArray`. @@ -563,7 +679,7 @@ pub mod datafusion_strsim { struct StringWrapper<'a>(&'a str); - impl<'a, 'b> IntoIterator for &'a StringWrapper<'b> { + impl<'b> IntoIterator for &StringWrapper<'b> { type Item = char; type IntoIter = Chars<'b>; @@ -759,12 +875,22 @@ pub fn combine_limit( (combined_skip, combined_fetch) } +/// Returns the estimated number of threads available for parallel execution. +/// +/// This is a wrapper around `std::thread::available_parallelism`, providing a default value +/// of `1` if the system's parallelism cannot be determined. 
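A tiny usage sketch for the helper defined immediately below; the assertion only illustrates the documented fallback to `1`, and the surrounding `main` is an assumption made to keep the example runnable:

```rust
use datafusion_common::utils::get_available_parallelism;

fn main() {
    // Always at least 1, even when the OS cannot report a core count.
    let target_partitions = get_available_parallelism();
    assert!(target_partitions >= 1);
    println!("planning with {target_partitions} partitions");
}
```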
+pub fn get_available_parallelism() -> usize { + available_parallelism() + .unwrap_or(NonZero::new(1).expect("literal value `1` shouldn't be zero")) + .get() +} + #[cfg(test)] mod tests { + use super::*; use crate::ScalarValue::Null; use arrow::array::Float64Array; - - use super::*; + use sqlparser::tokenizer::Span; #[test] fn test_bisect_linear_left_and_right() -> Result<()> { @@ -992,6 +1118,7 @@ mod tests { let expected_parsed = vec![Ident { value: identifier.to_string(), quote_style, + span: Span::empty(), }]; assert_eq!( @@ -1003,40 +1130,6 @@ mod tests { Ok(()) } - #[test] - fn test_take_arrays() -> Result<()> { - let arrays: Vec = vec![ - Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.])), - Arc::new(Float64Array::from(vec![2.0, 3.0, 3.0, 4.0, 5.0])), - Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 10., 11.0])), - Arc::new(Float64Array::from(vec![15.0, 13.0, 8.0, 5., 0.0])), - ]; - - let row_indices_vec: Vec> = vec![ - // Get rows 0 and 1 - vec![0, 1], - // Get rows 0 and 1 - vec![0, 2], - // Get rows 1 and 3 - vec![1, 3], - // Get rows 2 and 4 - vec![2, 4], - ]; - for row_indices in row_indices_vec { - let indices: PrimitiveArray = - PrimitiveArray::from_iter_values(row_indices.iter().cloned()); - let chunk = take_arrays(&arrays, &indices)?; - for (arr_orig, arr_chunk) in arrays.iter().zip(&chunk) { - for (idx, orig_idx) in row_indices.iter().enumerate() { - let res1 = ScalarValue::try_from_array(arr_orig, *orig_idx as usize)?; - let res2 = ScalarValue::try_from_array(arr_chunk, idx)?; - assert_eq!(res1, res2); - } - } - } - Ok(()) - } - #[test] fn test_get_at_indices() -> Result<()> { let in_vec = vec![1, 2, 3, 4, 5, 6, 7]; diff --git a/datafusion/common/src/utils/proxy.rs b/datafusion/common/src/utils/proxy.rs index d68b5e354384a..d940677a5fb3b 100644 --- a/datafusion/common/src/utils/proxy.rs +++ b/datafusion/common/src/utils/proxy.rs @@ -17,7 +17,11 @@ //! [`VecAllocExt`] and [`RawTableAllocExt`] to help tracking of memory allocations -use hashbrown::raw::{Bucket, RawTable}; +use hashbrown::{ + hash_table::HashTable, + raw::{Bucket, RawTable}, +}; +use std::mem::size_of; /// Extension trait for [`Vec`] to account for allocations. 
pub trait VecAllocExt { @@ -88,12 +92,12 @@ impl VecAllocExt for Vec { type T = T; fn push_accounted(&mut self, x: Self::T, accounting: &mut usize) { - let prev_capacty = self.capacity(); + let prev_capacity = self.capacity(); self.push(x); let new_capacity = self.capacity(); - if new_capacity > prev_capacty { + if new_capacity > prev_capacity { // capacity changed, so we allocated more - let bump_size = (new_capacity - prev_capacty) * std::mem::size_of::(); + let bump_size = (new_capacity - prev_capacity) * size_of::(); // Note multiplication should never overflow because `push` would // have panic'd first, but the checked_add could potentially // overflow since accounting could be tracking additional values, and @@ -102,7 +106,7 @@ impl VecAllocExt for Vec { } } fn allocated_size(&self) -> usize { - std::mem::size_of::() * self.capacity() + size_of::() * self.capacity() } } @@ -157,7 +161,7 @@ impl RawTableAllocExt for RawTable { // need to request more memory let bump_elements = self.capacity().max(16); - let bump_size = bump_elements * std::mem::size_of::(); + let bump_size = bump_elements * size_of::(); *accounting = (*accounting).checked_add(bump_size).expect("overflow"); self.reserve(bump_elements, hasher); @@ -172,3 +176,71 @@ impl RawTableAllocExt for RawTable { } } } + +/// Extension trait for hash browns [`HashTable`] to account for allocations. +pub trait HashTableAllocExt { + /// Item type. + type T; + + /// Insert new element into table and increase + /// `accounting` by any newly allocated bytes. + /// + /// Returns the bucket where the element was inserted. + /// Note that allocation counts capacity, not size. + /// + /// # Example: + /// ``` + /// # use datafusion_common::utils::proxy::HashTableAllocExt; + /// # use hashbrown::hash_table::HashTable; + /// let mut table = HashTable::new(); + /// let mut allocated = 0; + /// let hash_fn = |x: &u32| (*x as u64) % 1000; + /// // pretend 0x3117 is the hash value for 1 + /// table.insert_accounted(1, hash_fn, &mut allocated); + /// assert_eq!(allocated, 64); + /// + /// // insert more values + /// for i in 0..100 { table.insert_accounted(i, hash_fn, &mut allocated); } + /// assert_eq!(allocated, 400); + /// ``` + fn insert_accounted( + &mut self, + x: Self::T, + hasher: impl Fn(&Self::T) -> u64, + accounting: &mut usize, + ); +} + +impl HashTableAllocExt for HashTable +where + T: Eq, +{ + type T = T; + + fn insert_accounted( + &mut self, + x: Self::T, + hasher: impl Fn(&Self::T) -> u64, + accounting: &mut usize, + ) { + let hash = hasher(&x); + + // NOTE: `find_entry` does NOT grow! 
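+        // Because `find_entry` never allocates, growth is accounted for
+        // explicitly here: when the table is already full, the requested
+        // number of additional buckets (capacity, not live entries) is
+        // converted to bytes and added to `accounting` before `reserve`.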
+ match self.find_entry(hash, |y| y == &x) { + Ok(_occupied) => {} + Err(_absent) => { + if self.len() == self.capacity() { + // need to request more memory + let bump_elements = self.capacity().max(16); + let bump_size = bump_elements * size_of::(); + *accounting = (*accounting).checked_add(bump_size).expect("overflow"); + + self.reserve(bump_elements, &hasher); + } + + // still need to insert the element since first try failed + self.entry(hash, |y| y == &x, hasher).insert(x); + } + } + } +} diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 01ba90ee5de87..e341816b2b8a9 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -19,7 +19,7 @@ name = "datafusion" description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model" keywords = ["arrow", "query", "sql"] -include = ["benches/*.rs", "src/**/*.rs", "Cargo.toml"] +include = ["benches/*.rs", "src/**/*.rs", "Cargo.toml", "LICENSE.txt", "NOTICE.txt"] readme = "../../README.md" version = { workspace = true } edition = { workspace = true } @@ -27,10 +27,7 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -# Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with -# "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'" -# https://github.com/foresterre/cargo-msrv/issues/590 -rust-version = "1.78" +rust-version = { workspace = true } [lints] workspace = true @@ -59,6 +56,7 @@ default = [ "unicode_expressions", "compression", "parquet", + "recursive_protection", ] encoding_expressions = ["datafusion-functions/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) @@ -67,10 +65,15 @@ math_expressions = ["datafusion-functions/math_expressions"] parquet = ["datafusion-common/parquet", "dep:parquet"] pyarrow = ["datafusion-common/pyarrow", "parquet"] regex_expressions = [ - "datafusion-physical-expr/regex_expressions", - "datafusion-optimizer/regex_expressions", "datafusion-functions/regex_expressions", ] +recursive_protection = [ + "datafusion-common/recursive_protection", + "datafusion-expr/recursive_protection", + "datafusion-optimizer/recursive_protection", + "datafusion-physical-optimizer/recursive_protection", + "datafusion-sql/recursive_protection", +] serde = ["arrow-schema/serde"] string_expressions = ["datafusion-functions/string_expressions"] unicode_expressions = [ @@ -79,8 +82,7 @@ unicode_expressions = [ ] [dependencies] -ahash = { workspace = true } -apache-avro = { version = "0.16", optional = true } +apache-avro = { version = "0.17", optional = true } arrow = { workspace = true } arrow-array = { workspace = true } arrow-ipc = { workspace = true } @@ -90,12 +92,11 @@ async-compression = { version = "0.4.0", features = [ "gzip", "xz", "zstd", - "futures-io", "tokio", ], optional = true } async-trait = { workspace = true } bytes = { workspace = true } -bzip2 = { version = "0.4.3", optional = true } +bzip2 = { version = "0.5.0", optional = true } chrono = { workspace = true } dashmap = { workspace = true } datafusion-catalog = { workspace = true } @@ -106,6 +107,7 @@ datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true, optional = true } +datafusion-functions-table = { workspace = true } 
datafusion-functions-window = { workspace = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } @@ -116,19 +118,14 @@ datafusion-sql = { workspace = true } flate2 = { version = "1.0.24", optional = true } futures = { workspace = true } glob = "0.3.0" -half = { workspace = true } -hashbrown = { workspace = true } -indexmap = { workspace = true } itertools = { workspace = true } log = { workspace = true } num-traits = { version = "0.2", optional = true } -num_cpus = { workspace = true } object_store = { workspace = true } parking_lot = { workspace = true } parquet = { workspace = true, optional = true, default-features = true } -paste = "1.0.15" -pin-project-lite = "^0.2.7" rand = { workspace = true } +regex = { workspace = true } sqlparser = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true } @@ -141,27 +138,20 @@ zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] arrow-buffer = { workspace = true } async-trait = { workspace = true } -bigdecimal = { workspace = true } criterion = { version = "0.5", features = ["async_tokio"] } -csv = "1.1.6" ctor = { workspace = true } datafusion-functions-window-common = { workspace = true } +datafusion-physical-optimizer = { workspace = true } doc-comment = { workspace = true } env_logger = { workspace = true } -half = { workspace = true, default-features = true } paste = "^1.0" -postgres-protocol = "0.6.4" -postgres-types = { version = "0.2.4", features = ["derive", "with-chrono-0_4"] } rand = { workspace = true, features = ["small_rng"] } rand_distr = "0.4.3" regex = { workspace = true } rstest = { workspace = true } -rust_decimal = { version = "1.27.0", features = ["tokio-pg"] } serde_json = { workspace = true } test-utils = { path = "../../test-utils" } -thiserror = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } -tokio-postgres = "0.7.7" [target.'cfg(not(target_os = "windows"))'.dev-dependencies] nix = { version = "0.29.0", features = ["fs"] } @@ -170,6 +160,10 @@ nix = { version = "0.29.0", features = ["fs"] } harness = false name = "aggregate_query_sql" +[[bench]] +harness = false +name = "csv_load" + [[bench]] harness = false name = "distinct_query_sql" diff --git a/datafusion/core/LICENSE.txt b/datafusion/core/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/core/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/core/NOTICE.txt b/datafusion/core/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/core/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/core/benches/csv_load.rs b/datafusion/core/benches/csv_load.rs new file mode 100644 index 0000000000000..2d42121ec9b25 --- /dev/null +++ b/datafusion/core/benches/csv_load.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +extern crate arrow; +extern crate datafusion; + +mod data_utils; +use crate::criterion::Criterion; +use datafusion::error::Result; +use datafusion::execution::context::SessionContext; +use datafusion::prelude::CsvReadOptions; +use datafusion::test_util::csv::TestCsvFile; +use parking_lot::Mutex; +use std::sync::Arc; +use std::time::Duration; +use test_utils::AccessLogGenerator; +use tokio::runtime::Runtime; + +fn load_csv(ctx: Arc>, path: &str, options: CsvReadOptions) { + let rt = Runtime::new().unwrap(); + let df = rt.block_on(ctx.lock().read_csv(path, options)).unwrap(); + criterion::black_box(rt.block_on(df.collect()).unwrap()); +} + +fn create_context() -> Result>> { + let ctx = SessionContext::new(); + Ok(Arc::new(Mutex::new(ctx))) +} + +fn generate_test_file() -> TestCsvFile { + let write_location = std::env::current_dir() + .unwrap() + .join("benches") + .join("data"); + + // Make sure the write directory exists. + std::fs::create_dir_all(&write_location).unwrap(); + let file_path = write_location.join("logs.csv"); + + let generator = AccessLogGenerator::new().with_include_nulls(true); + let num_batches = 2; + TestCsvFile::try_new(file_path.clone(), generator.take(num_batches as usize)) + .expect("Failed to create test file.") +} + +fn criterion_benchmark(c: &mut Criterion) { + let ctx = create_context().unwrap(); + let test_file = generate_test_file(); + + let mut group = c.benchmark_group("load csv testing"); + group.measurement_time(Duration::from_secs(20)); + + group.bench_function("default csv read options", |b| { + b.iter(|| { + load_csv( + ctx.clone(), + test_file.path().to_str().unwrap(), + CsvReadOptions::default(), + ) + }) + }); + + group.bench_function("null regex override", |b| { + b.iter(|| { + load_csv( + ctx.clone(), + test_file.path().to_str().unwrap(), + CsvReadOptions::default().null_regex(Some("^NULL$|^$".to_string())), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index bc4298786002e..f82a126c56520 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -249,7 +249,7 @@ fn criterion_benchmark(c: &mut Criterion) { } // Temporary file must outlive the benchmarks, it is deleted when dropped - std::mem::drop(temp_file); + drop(temp_file); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/benches/physical_plan.rs b/datafusion/core/benches/physical_plan.rs index 3ad71be1f4478..7d87a37b3b9c8 100644 --- a/datafusion/core/benches/physical_plan.rs +++ b/datafusion/core/benches/physical_plan.rs @@ -36,8 +36,9 @@ use datafusion::physical_plan::{ memory::MemoryExec, }; use datafusion::prelude::SessionContext; +use datafusion_physical_expr_common::sort_expr::LexOrdering; -// Initialise the operator using the provided record batches and the sort key +// Initialize the operator using the provided record batches and the sort key // as inputs. 
All record batches must have the same schema. fn sort_preserving_merge_operator( session_ctx: Arc, @@ -52,7 +53,7 @@ fn sort_preserving_merge_operator( expr: col(name, &schema).unwrap(), options: Default::default(), }) - .collect::>(); + .collect::(); let exec = MemoryExec::try_new( &batches.into_iter().map(|rb| vec![rb]).collect::>(), diff --git a/datafusion/core/benches/sort.rs b/datafusion/core/benches/sort.rs index 99a74b61b3e0a..14e80ce364e33 100644 --- a/datafusion/core/benches/sort.rs +++ b/datafusion/core/benches/sort.rs @@ -89,6 +89,7 @@ use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; /// Benchmarks for SortPreservingMerge stream use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use futures::StreamExt; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -257,7 +258,7 @@ impl BenchCase { } /// Make sort exprs for each column in `schema` -fn make_sort_exprs(schema: &Schema) -> Vec { +fn make_sort_exprs(schema: &Schema) -> LexOrdering { schema .fields() .iter() diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 00f6d5916751b..44320e7a287a1 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -15,22 +15,33 @@ // specific language governing permissions and limitations // under the License. +extern crate arrow; #[macro_use] extern crate criterion; -extern crate arrow; extern crate datafusion; mod data_utils; + use crate::criterion::Criterion; use arrow::datatypes::{DataType, Field, Fields, Schema}; +use criterion::Bencher; use datafusion::datasource::MemTable; use datafusion::execution::context::SessionContext; +use datafusion_common::ScalarValue; +use itertools::Itertools; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::PathBuf; use std::sync::Arc; use test_utils::tpcds::tpcds_schemas; use test_utils::tpch::tpch_schemas; use test_utils::TableDef; use tokio::runtime::Runtime; +const BENCHMARKS_PATH_1: &str = "../../benchmarks/"; +const BENCHMARKS_PATH_2: &str = "./benchmarks/"; +const CLICKBENCH_DATA_PATH: &str = "data/hits_partitioned/"; + /// Create a logical plan from the specified sql fn logical_plan(ctx: &SessionContext, sql: &str) { let rt = Runtime::new().unwrap(); @@ -60,7 +71,9 @@ fn create_schema(column_prefix: &str, num_columns: usize) -> Schema { fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc { let schema = Arc::new(create_schema(column_prefix, num_columns)); - MemTable::try_new(schema, vec![]).map(Arc::new).unwrap() + MemTable::try_new(schema, vec![vec![]]) + .map(Arc::new) + .unwrap() } fn create_context() -> SessionContext { @@ -89,7 +102,60 @@ fn register_defs(ctx: SessionContext, defs: Vec) -> SessionContext { ctx } +fn register_clickbench_hits_table() -> SessionContext { + let ctx = SessionContext::new(); + let rt = Runtime::new().unwrap(); + + // use an external table for clickbench benchmarks + let path = + if PathBuf::from(format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}")).exists() { + format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}") + } else { + format!("{BENCHMARKS_PATH_2}{CLICKBENCH_DATA_PATH}") + }; + + let sql = format!("CREATE EXTERNAL TABLE hits STORED AS PARQUET LOCATION '{path}'"); + + rt.block_on(ctx.sql(&sql)).unwrap(); + + let count = + rt.block_on(async { ctx.table("hits").await.unwrap().count().await.unwrap() }); + assert!(count > 0); + ctx +} + +/// Target of this benchmark: control that placeholders 
replacing does not get slower, +/// if the query does not contain placeholders at all. +fn benchmark_with_param_values_many_columns(ctx: &SessionContext, b: &mut Bencher) { + const COLUMNS_NUM: usize = 200; + let mut aggregates = String::new(); + for i in 0..COLUMNS_NUM { + if i > 0 { + aggregates.push_str(", "); + } + aggregates.push_str(format!("MAX(a{})", i).as_str()); + } + // SELECT max(attr0), ..., max(attrN) FROM t1. + let query = format!("SELECT {} FROM t1", aggregates); + let statement = ctx.state().sql_to_statement(&query, "Generic").unwrap(); + let rt = Runtime::new().unwrap(); + let plan = + rt.block_on(async { ctx.state().statement_to_plan(statement).await.unwrap() }); + b.iter(|| { + let plan = plan.clone(); + criterion::black_box(plan.with_param_values(vec![ScalarValue::from(1)]).unwrap()); + }); +} + fn criterion_benchmark(c: &mut Criterion) { + // verify that we can load the clickbench data prior to running the benchmark + if !PathBuf::from(format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}")).exists() + && !PathBuf::from(format!("{BENCHMARKS_PATH_2}{CLICKBENCH_DATA_PATH}")).exists() + { + panic!("benchmarks/data/hits_partitioned/ could not be loaded. Please run \ + 'benchmarks/bench.sh data clickbench_partitioned' prior to running this benchmark") + } + let ctx = create_context(); // Test simplest @@ -144,6 +210,85 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + c.bench_function("physical_select_aggregates_from_200", |b| { + let mut aggregates = String::new(); + for i in 0..200 { + if i > 0 { + aggregates.push_str(", "); + } + aggregates.push_str(format!("MAX(a{})", i).as_str()); + } + let query = format!("SELECT {} FROM t1", aggregates); + b.iter(|| { + physical_plan(&ctx, &query); + }); + }); + + // Benchmark for Physical Planning Joins + c.bench_function("physical_join_consider_sort", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7, t2.b8 \ + FROM t1, t2 WHERE a7 = b7 \ + ORDER BY a7", + ); + }); + }); + + c.bench_function("physical_theta_join_consider_sort", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7, t2.b8 \ + FROM t1, t2 WHERE a7 < b7 \ + ORDER BY a7", + ); + }); + }); + + c.bench_function("physical_many_self_joins", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT ta.a9, tb.a10, tc.a11, td.a12, te.a13, tf.a14 \ + FROM t1 AS ta, t1 AS tb, t1 AS tc, t1 AS td, t1 AS te, t1 AS tf \ + WHERE ta.a9 = tb.a10 AND tb.a10 = tc.a11 AND tc.a11 = td.a12 AND \ + td.a12 = te.a13 AND te.a13 = tf.a14", + ); + }); + }); + + c.bench_function("physical_unnest_to_join", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7 \ + FROM t1 WHERE a7 = (SELECT b8 FROM t2)", + ); + }); + }); + + c.bench_function("physical_intersection", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7 FROM t1 \ + INTERSECT SELECT t2.b8 FROM t2", + ); + }); + }); + // these two queries should be equivalent + c.bench_function("physical_join_distinct", |b| { + b.iter(|| { + logical_plan( + &ctx, + "SELECT DISTINCT t1.a7 \ + FROM t1, t2 WHERE t1.a7 = t2.b8", + ); + }); + }); + // --- TPC-H --- let tpch_ctx = register_defs(SessionContext::new(), tpch_schemas()); @@ -154,9 +299,15 @@ fn criterion_benchmark(c: &mut Criterion) { "q16", "q17", "q18", "q19", "q20", "q21", "q22", ]; + let benchmarks_path = if PathBuf::from(BENCHMARKS_PATH_1).exists() { + BENCHMARKS_PATH_1 + } else { + BENCHMARKS_PATH_2 + }; + for q in tpch_queries { let sql = - std::fs::read_to_string(format!("../../benchmarks/queries/{q}.sql")).unwrap(); + 
std::fs::read_to_string(format!("{benchmarks_path}queries/{q}.sql")).unwrap(); c.bench_function(&format!("physical_plan_tpch_{}", q), |b| { b.iter(|| physical_plan(&tpch_ctx, &sql)) }); @@ -165,7 +316,7 @@ fn criterion_benchmark(c: &mut Criterion) { let all_tpch_sql_queries = tpch_queries .iter() .map(|q| { - std::fs::read_to_string(format!("../../benchmarks/queries/{q}.sql")).unwrap() + std::fs::read_to_string(format!("{benchmarks_path}queries/{q}.sql")).unwrap() }) .collect::>(); @@ -177,26 +328,25 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("logical_plan_tpch_all", |b| { - b.iter(|| { - for sql in &all_tpch_sql_queries { - logical_plan(&tpch_ctx, sql) - } - }) - }); + // c.bench_function("logical_plan_tpch_all", |b| { + // b.iter(|| { + // for sql in &all_tpch_sql_queries { + // logical_plan(&tpch_ctx, sql) + // } + // }) + // }); // --- TPC-DS --- let tpcds_ctx = register_defs(SessionContext::new(), tpcds_schemas()); - - // 10, 35: Physical plan does not support logical expression Exists() - // 45: Physical plan does not support logical expression () - // 41: Optimizing disjunctions not supported - let ignored = [10, 35, 41, 45]; + let tests_path = if PathBuf::from("./tests/").exists() { + "./tests/" + } else { + "datafusion/core/tests/" + }; let raw_tpcds_sql_queries = (1..100) - .filter(|q| !ignored.contains(q)) - .map(|q| std::fs::read_to_string(format!("./tests/tpc-ds/{q}.sql")).unwrap()) + .map(|q| std::fs::read_to_string(format!("{tests_path}tpc-ds/{q}.sql")).unwrap()) .collect::>(); // some queries have multiple statements @@ -213,13 +363,60 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("logical_plan_tpcds_all", |b| { + // c.bench_function("logical_plan_tpcds_all", |b| { + // b.iter(|| { + // for sql in &all_tpcds_sql_queries { + // logical_plan(&tpcds_ctx, sql) + // } + // }) + // }); + + // -- clickbench -- + + let queries_file = + File::open(format!("{benchmarks_path}queries/clickbench/queries.sql")).unwrap(); + let extended_file = + File::open(format!("{benchmarks_path}queries/clickbench/extended.sql")).unwrap(); + + let clickbench_queries: Vec = BufReader::new(queries_file) + .lines() + .chain(BufReader::new(extended_file).lines()) + .map(|l| l.expect("Could not parse line")) + .collect_vec(); + + let clickbench_ctx = register_clickbench_hits_table(); + + // for (i, sql) in clickbench_queries.iter().enumerate() { + // c.bench_function(&format!("logical_plan_clickbench_q{}", i + 1), |b| { + // b.iter(|| logical_plan(&clickbench_ctx, sql)) + // }); + // } + + for (i, sql) in clickbench_queries.iter().enumerate() { + c.bench_function(&format!("physical_plan_clickbench_q{}", i + 1), |b| { + b.iter(|| physical_plan(&clickbench_ctx, sql)) + }); + } + + // c.bench_function("logical_plan_clickbench_all", |b| { + // b.iter(|| { + // for sql in &clickbench_queries { + // logical_plan(&clickbench_ctx, sql) + // } + // }) + // }); + + c.bench_function("physical_plan_clickbench_all", |b| { b.iter(|| { - for sql in &all_tpcds_sql_queries { - logical_plan(&tpcds_ctx, sql) + for sql in &clickbench_queries { + physical_plan(&clickbench_ctx, sql) } }) }); + + c.bench_function("with_param_values_many_columns", |b| { + benchmark_with_param_values_many_columns(&ctx, b); + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 53cfe94ecab3e..8b453d5e96987 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ 
b/datafusion/core/src/bin/print_functions_docs.rs @@ -16,6 +16,7 @@ // under the License. use datafusion::execution::SessionStateDefaults; +use datafusion_common::{not_impl_err, HashSet, Result}; use datafusion_expr::{ aggregate_doc_sections, scalar_doc_sections, window_doc_sections, AggregateUDF, DocSection, Documentation, ScalarUDF, WindowUDF, @@ -24,7 +25,12 @@ use itertools::Itertools; use std::env::args; use std::fmt::Write as _; -fn main() { +/// Print documentation for all functions of a given type to stdout +/// +/// Usage: `cargo run --bin print_functions_docs -- ` +/// +/// Called from `dev/update_function_docs.sh` +fn main() -> Result<()> { let args: Vec = args().collect(); if args.len() != 2 { @@ -42,12 +48,13 @@ fn main() { _ => { panic!("Unknown function type: {}", function_type) } - }; + }?; println!("{docs}"); + Ok(()) } -fn print_aggregate_docs() -> String { +fn print_aggregate_docs() -> Result { let mut providers: Vec> = vec![]; for f in SessionStateDefaults::default_aggregate_functions() { @@ -57,7 +64,7 @@ fn print_aggregate_docs() -> String { print_docs(providers, aggregate_doc_sections::doc_sections()) } -fn print_scalar_docs() -> String { +fn print_scalar_docs() -> Result { let mut providers: Vec> = vec![]; for f in SessionStateDefaults::default_scalar_functions() { @@ -67,7 +74,7 @@ fn print_scalar_docs() -> String { print_docs(providers, scalar_doc_sections::doc_sections()) } -fn print_window_docs() -> String { +fn print_window_docs() -> Result { let mut providers: Vec> = vec![]; for f in SessionStateDefaults::default_window_functions() { @@ -77,15 +84,42 @@ fn print_window_docs() -> String { print_docs(providers, window_doc_sections::doc_sections()) } +// Temporary method useful to semi automate +// the migration of UDF documentation generation from code based +// to attribute based +// To be removed +#[allow(dead_code)] +fn save_doc_code_text(documentation: &Documentation, name: &str) { + let attr_text = documentation.to_doc_attribute(); + + let file_path = format!("{}.txt", name); + if std::path::Path::new(&file_path).exists() { + std::fs::remove_file(&file_path).unwrap(); + } + + // Open the file in append mode, create it if it doesn't exist + let mut file = std::fs::OpenOptions::new() + .append(true) // Open in append mode + .create(true) // Create the file if it doesn't exist + .open(file_path) + .unwrap(); + + use std::io::Write; + file.write_all(attr_text.as_bytes()).unwrap(); +} + fn print_docs( providers: Vec>, doc_sections: Vec, -) -> String { +) -> Result { let mut docs = "".to_string(); + // Ensure that all providers have documentation + let mut providers_with_no_docs = HashSet::new(); + // doc sections only includes sections that have 'include' == true for doc_section in doc_sections { - // make sure there is a function that is in this doc section + // make sure there is at least one function that is in this doc section if !&providers.iter().any(|f| { if let Some(documentation) = f.get_documentation() { documentation.doc_section == doc_section @@ -96,19 +130,21 @@ fn print_docs( continue; } + // filter out functions that are not in this doc section let providers: Vec<&Box> = providers .iter() .filter(|&f| { if let Some(documentation) = f.get_documentation() { documentation.doc_section == doc_section } else { + providers_with_no_docs.insert(f.get_name()); false } }) .collect::>(); // write out section header - let _ = writeln!(docs, "## {} ", doc_section.label); + let _ = writeln!(docs, "\n## {} \n", doc_section.label); if let Some(description) = 
doc_section.description { let _ = writeln!(docs, "{description}"); @@ -146,6 +182,9 @@ fn print_docs( unreachable!() }; + // Temporary for doc gen migration, see `save_doc_code_text` comments + // save_doc_code_text(documentation, &name); + // first, the name, description and syntax example let _ = write!( docs, @@ -182,6 +221,13 @@ fn print_docs( ); } + if let Some(alt_syntax) = &documentation.alternative_syntax { + let _ = writeln!(docs, "#### Alternative Syntax\n"); + for syntax in alt_syntax { + let _ = writeln!(docs, "```sql\n{}\n```", syntax); + } + } + // next, aliases if !f.get_aliases().is_empty() { let _ = writeln!(docs, "#### Aliases"); @@ -202,9 +248,20 @@ fn print_docs( } } - docs + // If there are any functions that do not have documentation, print them out + // eventually make this an error: https://github.com/apache/datafusion/issues/12872 + if !providers_with_no_docs.is_empty() { + eprintln!("INFO: The following functions do not have documentation:"); + for f in &providers_with_no_docs { + eprintln!(" - {f}"); + } + not_impl_err!("Some functions do not have documentation. Please implement `documentation` for: {providers_with_no_docs:?}") + } else { + Ok(docs) + } } +/// Trait for accessing name / aliases / documentation for differnet functions trait DocProvider { fn get_name(&self) -> String; fn get_aliases(&self) -> Vec; diff --git a/datafusion/core/src/catalog_common/information_schema.rs b/datafusion/core/src/catalog_common/information_schema.rs index 180994b1cbe89..ce3092acfdf1f 100644 --- a/datafusion/core/src/catalog_common/information_schema.rs +++ b/datafusion/core/src/catalog_common/information_schema.rs @@ -19,26 +19,29 @@ //! //! [Information Schema]: https://en.wikipedia.org/wiki/Information_schema -use arrow::{ - array::{StringBuilder, UInt64Builder}, - datatypes::{DataType, Field, Schema, SchemaRef}, - record_batch::RecordBatch, -}; -use async_trait::async_trait; -use datafusion_common::DataFusionError; -use std::fmt::Debug; -use std::{any::Any, sync::Arc}; - use crate::catalog::{CatalogProviderList, SchemaProvider, TableProvider}; use crate::datasource::streaming::StreamingTable; use crate::execution::context::TaskContext; -use crate::logical_expr::TableType; +use crate::logical_expr::{TableType, Volatility}; use crate::physical_plan::stream::RecordBatchStreamAdapter; use crate::physical_plan::SendableRecordBatchStream; use crate::{ config::{ConfigEntry, ConfigOptions}, physical_plan::streaming::PartitionStream, }; +use arrow::{ + array::{StringBuilder, UInt64Builder}, + datatypes::{DataType, Field, Schema, SchemaRef}, + record_batch::RecordBatch, +}; +use arrow_array::builder::{BooleanBuilder, UInt8Builder}; +use async_trait::async_trait; +use datafusion_common::error::Result; +use datafusion_common::DataFusionError; +use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; +use std::{any::Any, sync::Arc}; pub(crate) const INFORMATION_SCHEMA: &str = "information_schema"; pub(crate) const TABLES: &str = "tables"; @@ -46,10 +49,19 @@ pub(crate) const VIEWS: &str = "views"; pub(crate) const COLUMNS: &str = "columns"; pub(crate) const DF_SETTINGS: &str = "df_settings"; pub(crate) const SCHEMATA: &str = "schemata"; +pub(crate) const ROUTINES: &str = "routines"; +pub(crate) const PARAMETERS: &str = "parameters"; /// All information schema tables -pub const INFORMATION_SCHEMA_TABLES: &[&str] = - &[TABLES, VIEWS, COLUMNS, DF_SETTINGS, SCHEMATA]; +pub const 
INFORMATION_SCHEMA_TABLES: &[&str] = &[ + TABLES, + VIEWS, + COLUMNS, + DF_SETTINGS, + SCHEMATA, + ROUTINES, + PARAMETERS, +]; /// Implements the `information_schema` virtual schema and tables /// @@ -208,6 +220,256 @@ impl InformationSchemaConfig { builder.add_setting(entry); } } + + fn make_routines( + &self, + udfs: &HashMap>, + udafs: &HashMap>, + udwfs: &HashMap>, + config_options: &ConfigOptions, + builder: &mut InformationSchemaRoutinesBuilder, + ) -> Result<()> { + let catalog_name = &config_options.catalog.default_catalog; + let schema_name = &config_options.catalog.default_schema; + + for (name, udf) in udfs { + let return_types = get_udf_args_and_return_types(udf)? + .into_iter() + .map(|(_, return_type)| return_type) + .collect::>(); + for return_type in return_types { + builder.add_routine( + catalog_name, + schema_name, + name, + "FUNCTION", + Self::is_deterministic(udf.signature()), + return_type, + "SCALAR", + udf.documentation().map(|d| d.description.to_string()), + udf.documentation().map(|d| d.syntax_example.to_string()), + ) + } + } + + for (name, udaf) in udafs { + let return_types = get_udaf_args_and_return_types(udaf)? + .into_iter() + .map(|(_, return_type)| return_type) + .collect::>(); + for return_type in return_types { + builder.add_routine( + catalog_name, + schema_name, + name, + "FUNCTION", + Self::is_deterministic(udaf.signature()), + return_type, + "AGGREGATE", + udaf.documentation().map(|d| d.description.to_string()), + udaf.documentation().map(|d| d.syntax_example.to_string()), + ) + } + } + + for (name, udwf) in udwfs { + let return_types = get_udwf_args_and_return_types(udwf)? + .into_iter() + .map(|(_, return_type)| return_type) + .collect::>(); + for return_type in return_types { + builder.add_routine( + catalog_name, + schema_name, + name, + "FUNCTION", + Self::is_deterministic(udwf.signature()), + return_type, + "WINDOW", + udwf.documentation().map(|d| d.description.to_string()), + udwf.documentation().map(|d| d.syntax_example.to_string()), + ) + } + } + Ok(()) + } + + fn is_deterministic(signature: &Signature) -> bool { + signature.volatility == Volatility::Immutable + } + fn make_parameters( + &self, + udfs: &HashMap>, + udafs: &HashMap>, + udwfs: &HashMap>, + config_options: &ConfigOptions, + builder: &mut InformationSchemaParametersBuilder, + ) -> Result<()> { + let catalog_name = &config_options.catalog.default_catalog; + let schema_name = &config_options.catalog.default_schema; + let mut add_parameters = |func_name: &str, + args: Option<&Vec<(String, String)>>, + arg_types: Vec, + return_type: Option, + is_variadic: bool, + rid: u8| { + for (position, type_name) in arg_types.iter().enumerate() { + let param_name = + args.and_then(|a| a.get(position).map(|arg| arg.0.as_str())); + builder.add_parameter( + catalog_name, + schema_name, + func_name, + position as u64 + 1, + "IN", + param_name, + type_name, + None::<&str>, + is_variadic, + rid, + ); + } + if let Some(return_type) = return_type { + builder.add_parameter( + catalog_name, + schema_name, + func_name, + 1, + "OUT", + None::<&str>, + return_type.as_str(), + None::<&str>, + false, + rid, + ); + } + }; + + for (func_name, udf) in udfs { + let args = udf.documentation().and_then(|d| d.arguments.clone()); + let combinations = get_udf_args_and_return_types(udf)?; + for (rid, (arg_types, return_type)) in combinations.into_iter().enumerate() { + add_parameters( + func_name, + args.as_ref(), + arg_types, + return_type, + Self::is_variadic(udf.signature()), + rid as u8, + ); + } + } + + for 
(func_name, udaf) in udafs { + let args = udaf.documentation().and_then(|d| d.arguments.clone()); + let combinations = get_udaf_args_and_return_types(udaf)?; + for (rid, (arg_types, return_type)) in combinations.into_iter().enumerate() { + add_parameters( + func_name, + args.as_ref(), + arg_types, + return_type, + Self::is_variadic(udaf.signature()), + rid as u8, + ); + } + } + + for (func_name, udwf) in udwfs { + let args = udwf.documentation().and_then(|d| d.arguments.clone()); + let combinations = get_udwf_args_and_return_types(udwf)?; + for (rid, (arg_types, return_type)) in combinations.into_iter().enumerate() { + add_parameters( + func_name, + args.as_ref(), + arg_types, + return_type, + Self::is_variadic(udwf.signature()), + rid as u8, + ); + } + } + + Ok(()) + } + + fn is_variadic(signature: &Signature) -> bool { + matches!( + signature.type_signature, + TypeSignature::Variadic(_) | TypeSignature::VariadicAny + ) + } +} + +/// get the arguments and return types of a UDF +/// returns a tuple of (arg_types, return_type) +fn get_udf_args_and_return_types( + udf: &Arc, +) -> Result, Option)>> { + let signature = udf.signature(); + let arg_types = signature.type_signature.get_possible_types(); + if arg_types.is_empty() { + Ok(vec![(vec![], None)]) + } else { + Ok(arg_types + .into_iter() + .map(|arg_types| { + // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let return_type = udf.return_type(&arg_types).ok().map(|t| t.to_string()); + let arg_types = arg_types + .into_iter() + .map(|t| t.to_string()) + .collect::>(); + (arg_types, return_type) + }) + .collect::>()) + } +} + +fn get_udaf_args_and_return_types( + udaf: &Arc, +) -> Result, Option)>> { + let signature = udaf.signature(); + let arg_types = signature.type_signature.get_possible_types(); + if arg_types.is_empty() { + Ok(vec![(vec![], None)]) + } else { + Ok(arg_types + .into_iter() + .map(|arg_types| { + // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let return_type = + udaf.return_type(&arg_types).ok().map(|t| t.to_string()); + let arg_types = arg_types + .into_iter() + .map(|t| t.to_string()) + .collect::>(); + (arg_types, return_type) + }) + .collect::>()) + } +} + +fn get_udwf_args_and_return_types( + udwf: &Arc, +) -> Result, Option)>> { + let signature = udwf.signature(); + let arg_types = signature.type_signature.get_possible_types(); + if arg_types.is_empty() { + Ok(vec![(vec![], None)]) + } else { + Ok(arg_types + .into_iter() + .map(|arg_types| { + // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let arg_types = arg_types + .into_iter() + .map(|t| t.to_string()) + .collect::>(); + (arg_types, None) + }) + .collect::>()) + } } #[async_trait] @@ -234,11 +496,13 @@ impl SchemaProvider for InformationSchemaProvider { VIEWS => Arc::new(InformationSchemaViews::new(config)), DF_SETTINGS => Arc::new(InformationSchemaDfSettings::new(config)), SCHEMATA => Arc::new(InformationSchemata::new(config)), + ROUTINES => Arc::new(InformationSchemaRoutines::new(config)), + PARAMETERS => Arc::new(InformationSchemaParameters::new(config)), _ => return Ok(None), }; Ok(Some(Arc::new( - StreamingTable::try_new(table.schema().clone(), vec![table]).unwrap(), + StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(), ))) } @@ -271,7 +535,7 @@ impl InformationSchemaTables { schema_names: StringBuilder::new(), table_names: StringBuilder::new(), table_types: StringBuilder::new(), - schema: self.schema.clone(), + schema: 
Arc::clone(&self.schema), } } } @@ -285,7 +549,7 @@ impl PartitionStream for InformationSchemaTables { let mut builder = self.builder(); let config = self.config.clone(); Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), // TODO: Stream this futures::stream::once(async move { config.make_tables(&mut builder).await?; @@ -327,7 +591,7 @@ impl InformationSchemaTablesBuilder { fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(self.catalog_names.finish()), Arc::new(self.schema_names.finish()), @@ -363,7 +627,7 @@ impl InformationSchemaViews { schema_names: StringBuilder::new(), table_names: StringBuilder::new(), definitions: StringBuilder::new(), - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), } } } @@ -377,7 +641,7 @@ impl PartitionStream for InformationSchemaViews { let mut builder = self.builder(); let config = self.config.clone(); Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), // TODO: Stream this futures::stream::once(async move { config.make_views(&mut builder).await?; @@ -415,7 +679,7 @@ impl InformationSchemaViewBuilder { fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(self.catalog_names.finish()), Arc::new(self.schema_names.finish()), @@ -478,7 +742,7 @@ impl InformationSchemaColumns { numeric_scales: UInt64Builder::with_capacity(default_capacity), datetime_precisions: UInt64Builder::with_capacity(default_capacity), interval_types: StringBuilder::new(), - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), } } } @@ -492,7 +756,7 @@ impl PartitionStream for InformationSchemaColumns { let mut builder = self.builder(); let config = self.config.clone(); Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), // TODO: Stream this futures::stream::once(async move { config.make_columns(&mut builder).await?; @@ -621,7 +885,7 @@ impl InformationSchemaColumnsBuilder { fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(self.catalog_names.finish()), Arc::new(self.schema_names.finish()), @@ -666,7 +930,7 @@ impl InformationSchemata { fn builder(&self) -> InformationSchemataBuilder { InformationSchemataBuilder { - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), catalog_name: StringBuilder::new(), schema_name: StringBuilder::new(), schema_owner: StringBuilder::new(), @@ -712,7 +976,7 @@ impl InformationSchemataBuilder { fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(self.catalog_name.finish()), Arc::new(self.schema_name.finish()), @@ -736,7 +1000,7 @@ impl PartitionStream for InformationSchemata { let mut builder = self.builder(); let config = self.config.clone(); Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), // TODO: Stream this futures::stream::once(async move { config.make_schemata(&mut builder).await; @@ -768,7 +1032,7 @@ impl InformationSchemaDfSettings { names: StringBuilder::new(), values: StringBuilder::new(), descriptions: StringBuilder::new(), - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), } } } @@ -782,7 +1046,7 @@ impl PartitionStream for InformationSchemaDfSettings { let config = self.config.clone(); let mut builder = self.builder(); 
Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), // TODO: Stream this futures::stream::once(async move { // create a mem table with the names of tables @@ -809,7 +1073,7 @@ impl InformationSchemaDfSettingsBuilder { fn finish(&mut self) -> RecordBatch { RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(self.names.finish()), Arc::new(self.values.finish()), @@ -819,3 +1083,270 @@ impl InformationSchemaDfSettingsBuilder { .unwrap() } } + +#[derive(Debug)] +struct InformationSchemaRoutines { + schema: SchemaRef, + config: InformationSchemaConfig, +} + +impl InformationSchemaRoutines { + fn new(config: InformationSchemaConfig) -> Self { + let schema = Arc::new(Schema::new(vec![ + Field::new("specific_catalog", DataType::Utf8, false), + Field::new("specific_schema", DataType::Utf8, false), + Field::new("specific_name", DataType::Utf8, false), + Field::new("routine_catalog", DataType::Utf8, false), + Field::new("routine_schema", DataType::Utf8, false), + Field::new("routine_name", DataType::Utf8, false), + Field::new("routine_type", DataType::Utf8, false), + Field::new("is_deterministic", DataType::Boolean, true), + Field::new("data_type", DataType::Utf8, true), + Field::new("function_type", DataType::Utf8, true), + Field::new("description", DataType::Utf8, true), + Field::new("syntax_example", DataType::Utf8, true), + ])); + + Self { schema, config } + } + + fn builder(&self) -> InformationSchemaRoutinesBuilder { + InformationSchemaRoutinesBuilder { + schema: Arc::clone(&self.schema), + specific_catalog: StringBuilder::new(), + specific_schema: StringBuilder::new(), + specific_name: StringBuilder::new(), + routine_catalog: StringBuilder::new(), + routine_schema: StringBuilder::new(), + routine_name: StringBuilder::new(), + routine_type: StringBuilder::new(), + is_deterministic: BooleanBuilder::new(), + data_type: StringBuilder::new(), + function_type: StringBuilder::new(), + description: StringBuilder::new(), + syntax_example: StringBuilder::new(), + } + } +} + +struct InformationSchemaRoutinesBuilder { + schema: SchemaRef, + specific_catalog: StringBuilder, + specific_schema: StringBuilder, + specific_name: StringBuilder, + routine_catalog: StringBuilder, + routine_schema: StringBuilder, + routine_name: StringBuilder, + routine_type: StringBuilder, + is_deterministic: BooleanBuilder, + data_type: StringBuilder, + function_type: StringBuilder, + description: StringBuilder, + syntax_example: StringBuilder, +} + +impl InformationSchemaRoutinesBuilder { + #[allow(clippy::too_many_arguments)] + fn add_routine( + &mut self, + catalog_name: impl AsRef, + schema_name: impl AsRef, + routine_name: impl AsRef, + routine_type: impl AsRef, + is_deterministic: bool, + data_type: Option>, + function_type: impl AsRef, + description: Option>, + syntax_example: Option>, + ) { + self.specific_catalog.append_value(catalog_name.as_ref()); + self.specific_schema.append_value(schema_name.as_ref()); + self.specific_name.append_value(routine_name.as_ref()); + self.routine_catalog.append_value(catalog_name.as_ref()); + self.routine_schema.append_value(schema_name.as_ref()); + self.routine_name.append_value(routine_name.as_ref()); + self.routine_type.append_value(routine_type.as_ref()); + self.is_deterministic.append_value(is_deterministic); + self.data_type.append_option(data_type.as_ref()); + self.function_type.append_value(function_type.as_ref()); + self.description.append_option(description); + 
self.syntax_example.append_option(syntax_example); + } + + fn finish(&mut self) -> RecordBatch { + RecordBatch::try_new( + Arc::clone(&self.schema), + vec![ + Arc::new(self.specific_catalog.finish()), + Arc::new(self.specific_schema.finish()), + Arc::new(self.specific_name.finish()), + Arc::new(self.routine_catalog.finish()), + Arc::new(self.routine_schema.finish()), + Arc::new(self.routine_name.finish()), + Arc::new(self.routine_type.finish()), + Arc::new(self.is_deterministic.finish()), + Arc::new(self.data_type.finish()), + Arc::new(self.function_type.finish()), + Arc::new(self.description.finish()), + Arc::new(self.syntax_example.finish()), + ], + ) + .unwrap() + } +} + +impl PartitionStream for InformationSchemaRoutines { + fn schema(&self) -> &SchemaRef { + &self.schema + } + + fn execute(&self, ctx: Arc) -> SendableRecordBatchStream { + let config = self.config.clone(); + let mut builder = self.builder(); + Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.schema), + futures::stream::once(async move { + config.make_routines( + ctx.scalar_functions(), + ctx.aggregate_functions(), + ctx.window_functions(), + ctx.session_config().options(), + &mut builder, + )?; + Ok(builder.finish()) + }), + )) + } +} + +#[derive(Debug)] +struct InformationSchemaParameters { + schema: SchemaRef, + config: InformationSchemaConfig, +} + +impl InformationSchemaParameters { + fn new(config: InformationSchemaConfig) -> Self { + let schema = Arc::new(Schema::new(vec![ + Field::new("specific_catalog", DataType::Utf8, false), + Field::new("specific_schema", DataType::Utf8, false), + Field::new("specific_name", DataType::Utf8, false), + Field::new("ordinal_position", DataType::UInt64, false), + Field::new("parameter_mode", DataType::Utf8, false), + Field::new("parameter_name", DataType::Utf8, true), + Field::new("data_type", DataType::Utf8, false), + Field::new("parameter_default", DataType::Utf8, true), + Field::new("is_variadic", DataType::Boolean, false), + // `rid` (short for `routine id`) is used to differentiate parameters from different signatures + // (It serves as the group-by key when generating the `SHOW FUNCTIONS` query). 
+ // For example, the following signatures have different `rid` values: + // - `datetrunc(Utf8, Timestamp(Microsecond, Some("+TZ"))) -> Timestamp(Microsecond, Some("+TZ"))` + // - `datetrunc(Utf8View, Timestamp(Nanosecond, None)) -> Timestamp(Nanosecond, None)` + Field::new("rid", DataType::UInt8, false), + ])); + + Self { schema, config } + } + + fn builder(&self) -> InformationSchemaParametersBuilder { + InformationSchemaParametersBuilder { + schema: Arc::clone(&self.schema), + specific_catalog: StringBuilder::new(), + specific_schema: StringBuilder::new(), + specific_name: StringBuilder::new(), + ordinal_position: UInt64Builder::new(), + parameter_mode: StringBuilder::new(), + parameter_name: StringBuilder::new(), + data_type: StringBuilder::new(), + parameter_default: StringBuilder::new(), + is_variadic: BooleanBuilder::new(), + rid: UInt8Builder::new(), + } + } +} + +struct InformationSchemaParametersBuilder { + schema: SchemaRef, + specific_catalog: StringBuilder, + specific_schema: StringBuilder, + specific_name: StringBuilder, + ordinal_position: UInt64Builder, + parameter_mode: StringBuilder, + parameter_name: StringBuilder, + data_type: StringBuilder, + parameter_default: StringBuilder, + is_variadic: BooleanBuilder, + rid: UInt8Builder, +} + +impl InformationSchemaParametersBuilder { + #[allow(clippy::too_many_arguments)] + fn add_parameter( + &mut self, + specific_catalog: impl AsRef, + specific_schema: impl AsRef, + specific_name: impl AsRef, + ordinal_position: u64, + parameter_mode: impl AsRef, + parameter_name: Option>, + data_type: impl AsRef, + parameter_default: Option>, + is_variadic: bool, + rid: u8, + ) { + self.specific_catalog + .append_value(specific_catalog.as_ref()); + self.specific_schema.append_value(specific_schema.as_ref()); + self.specific_name.append_value(specific_name.as_ref()); + self.ordinal_position.append_value(ordinal_position); + self.parameter_mode.append_value(parameter_mode.as_ref()); + self.parameter_name.append_option(parameter_name.as_ref()); + self.data_type.append_value(data_type.as_ref()); + self.parameter_default.append_option(parameter_default); + self.is_variadic.append_value(is_variadic); + self.rid.append_value(rid); + } + + fn finish(&mut self) -> RecordBatch { + RecordBatch::try_new( + Arc::clone(&self.schema), + vec![ + Arc::new(self.specific_catalog.finish()), + Arc::new(self.specific_schema.finish()), + Arc::new(self.specific_name.finish()), + Arc::new(self.ordinal_position.finish()), + Arc::new(self.parameter_mode.finish()), + Arc::new(self.parameter_name.finish()), + Arc::new(self.data_type.finish()), + Arc::new(self.parameter_default.finish()), + Arc::new(self.is_variadic.finish()), + Arc::new(self.rid.finish()), + ], + ) + .unwrap() + } +} + +impl PartitionStream for InformationSchemaParameters { + fn schema(&self) -> &SchemaRef { + &self.schema + } + + fn execute(&self, ctx: Arc) -> SendableRecordBatchStream { + let config = self.config.clone(); + let mut builder = self.builder(); + Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.schema), + futures::stream::once(async move { + config.make_parameters( + ctx.scalar_functions(), + ctx.aggregate_functions(), + ctx.window_functions(), + ctx.session_config().options(), + &mut builder, + )?; + Ok(builder.finish()) + }), + )) + } +} diff --git a/datafusion/core/src/catalog_common/listing_schema.rs b/datafusion/core/src/catalog_common/listing_schema.rs index e45c8a8d4aeb0..dc55a07ef82d4 100644 --- a/datafusion/core/src/catalog_common/listing_schema.rs +++ 
b/datafusion/core/src/catalog_common/listing_schema.rs @@ -18,14 +18,16 @@ //! [`ListingSchemaProvider`]: [`SchemaProvider`] that scans ObjectStores for tables automatically use std::any::Any; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::path::Path; use std::sync::{Arc, Mutex}; use crate::catalog::{SchemaProvider, TableProvider, TableProviderFactory}; use crate::execution::context::SessionState; -use datafusion_common::{Constraints, DFSchema, DataFusionError, TableReference}; +use datafusion_common::{ + Constraints, DFSchema, DataFusionError, HashMap, TableReference, +}; use datafusion_expr::CreateExternalTable; use async_trait::async_trait; @@ -136,6 +138,7 @@ impl ListingSchemaProvider { file_type: self.format.clone(), table_partition_cols: vec![], if_not_exists: false, + temporary: false, definition: None, order_exprs: vec![], unbounded: false, @@ -145,7 +148,8 @@ impl ListingSchemaProvider { }, ) .await?; - let _ = self.register_table(table_name.to_string(), provider.clone())?; + let _ = + self.register_table(table_name.to_string(), Arc::clone(&provider))?; } } Ok(()) @@ -187,7 +191,7 @@ impl SchemaProvider for ListingSchemaProvider { self.tables .lock() .expect("Can't lock tables") - .insert(name, table.clone()); + .insert(name, Arc::clone(&table)); Ok(Some(table)) } diff --git a/datafusion/core/src/catalog_common/memory.rs b/datafusion/core/src/catalog_common/memory.rs index f25146616891f..6cdefc31f18c3 100644 --- a/datafusion/core/src/catalog_common/memory.rs +++ b/datafusion/core/src/catalog_common/memory.rs @@ -67,7 +67,7 @@ impl CatalogProviderList for MemoryCatalogProviderList { } fn catalog(&self, name: &str) -> Option> { - self.catalogs.get(name).map(|c| c.value().clone()) + self.catalogs.get(name).map(|c| Arc::clone(c.value())) } } @@ -102,7 +102,7 @@ impl CatalogProvider for MemoryCatalogProvider { } fn schema(&self, name: &str) -> Option> { - self.schemas.get(name).map(|s| s.value().clone()) + self.schemas.get(name).map(|s| Arc::clone(s.value())) } fn register_schema( @@ -175,7 +175,7 @@ impl SchemaProvider for MemorySchemaProvider { &self, name: &str, ) -> datafusion_common::Result>, DataFusionError> { - Ok(self.tables.get(name).map(|table| table.value().clone())) + Ok(self.tables.get(name).map(|table| Arc::clone(table.value()))) } fn register_table( diff --git a/datafusion/core/src/catalog_common/mod.rs b/datafusion/core/src/catalog_common/mod.rs index 85207845a0054..68c78dda48999 100644 --- a/datafusion/core/src/catalog_common/mod.rs +++ b/datafusion/core/src/catalog_common/mod.rs @@ -36,10 +36,6 @@ pub use datafusion_sql::{ResolvedTableReference, TableReference}; use std::collections::BTreeSet; use std::ops::ControlFlow; -/// See [`CatalogProviderList`] -#[deprecated(since = "35.0.0", note = "use [`CatalogProviderList`] instead")] -pub trait CatalogList: CatalogProviderList {} - /// Collects all tables and views referenced in the SQL statement. CTEs are collected separately. /// This can be used to determine which tables need to be in the catalog for a query to be planned. 
/// diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 72cffd2e7c39f..0a2a5368eb647 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -20,11 +20,6 @@ #[cfg(feature = "parquet")] mod parquet; -use std::any::Any; -use std::borrow::Cow; -use std::collections::HashMap; -use std::sync::Arc; - use crate::arrow::record_batch::RecordBatch; use crate::arrow::util::pretty; use crate::datasource::file_format::csv::CsvFormatFactory; @@ -43,6 +38,10 @@ use crate::physical_plan::{ ExecutionPlan, SendableRecordBatchStream, }; use crate::prelude::SessionContext; +use std::any::Any; +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; @@ -50,7 +49,8 @@ use arrow::datatypes::{DataType, Field}; use arrow_schema::{Schema, SchemaRef}; use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ - plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions, + exec_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, ParamValues, + SchemaError, UnnestOptions, }; use datafusion_expr::dml::InsertOp; use datafusion_expr::{case, is_null, lit, SortExpr}; @@ -76,6 +76,9 @@ pub struct DataFrameWriteOptions { /// Sets which columns should be used for hive-style partitioned writes by name. /// Can be set to empty vec![] for non-partitioned writes. partition_by: Vec, + /// Sets which columns should be used for sorting the output by name. + /// Can be set to empty vec![] for non-sorted writes. + sort_by: Vec, } impl DataFrameWriteOptions { @@ -85,6 +88,7 @@ impl DataFrameWriteOptions { insert_op: InsertOp::Append, single_file_output: false, partition_by: vec![], + sort_by: vec![], } } @@ -105,6 +109,12 @@ impl DataFrameWriteOptions { self.partition_by = partition_by; self } + + /// Sets the sort_by columns for output sorting + pub fn with_sort_by(mut self, sort_by: Vec) -> Self { + self.sort_by = sort_by; + self + } } impl Default for DataFrameWriteOptions { @@ -373,32 +383,9 @@ impl DataFrame { self.select(expr) } - /// Expand each list element of a column to multiple rows. - #[deprecated(since = "37.0.0", note = "use unnest_columns instead")] - pub fn unnest_column(self, column: &str) -> Result { - self.unnest_columns(&[column]) - } - - /// Expand each list element of a column to multiple rows, with - /// behavior controlled by [`UnnestOptions`]. - /// - /// Please see the documentation on [`UnnestOptions`] for more - /// details about the meaning of unnest. - #[deprecated(since = "37.0.0", note = "use unnest_columns_with_options instead")] - pub fn unnest_column_with_options( - self, - column: &str, - options: UnnestOptions, - ) -> Result { - self.unnest_columns_with_options(&[column], options) - } - /// Expand multiple list/struct columns into a set of rows and new columns. /// - /// See also: - /// - /// 1. [`UnnestOptions`] documentation for the behavior of `unnest` - /// 2. 
[`Self::unnest_column_with_options`] + /// See also: [`UnnestOptions`] documentation for the behavior of `unnest` /// /// # Example /// ``` @@ -892,16 +879,16 @@ impl DataFrame { for result in describe_record_batch.iter() { let array_ref = match result { Ok(df) => { - let batchs = df.clone().collect().await; - match batchs { - Ok(batchs) - if batchs.len() == 1 - && batchs[0] + let batches = df.clone().collect().await; + match batches { + Ok(batches) + if batches.len() == 1 + && batches[0] .column_by_name(field.name()) .is_some() => { let column = - batchs[0].column_by_name(field.name()).unwrap(); + batches[0].column_by_name(field.name()).unwrap(); if column.data_type().is_null() { Arc::new(StringArray::from(vec!["null"])) @@ -924,9 +911,7 @@ impl DataFrame { { Arc::new(StringArray::from(vec!["null"])) } - Err(other_err) => { - panic!("{other_err}") - } + Err(e) => return exec_err!("{}", e), }; array_datas.push(array_ref); } @@ -1541,8 +1526,17 @@ impl DataFrame { write_options: DataFrameWriteOptions, ) -> Result, DataFusionError> { let arrow_schema = Schema::from(self.schema()); + + let plan = if write_options.sort_by.is_empty() { + self.plan + } else { + LogicalPlanBuilder::from(self.plan) + .sort(write_options.sort_by)? + .build()? + }; + let plan = LogicalPlanBuilder::insert_into( - self.plan, + plan, table_name.to_owned(), &arrow_schema, write_options.insert_op, @@ -1587,10 +1581,10 @@ impl DataFrame { writer_options: Option, ) -> Result, DataFusionError> { if options.insert_op != InsertOp::Append { - return Err(DataFusionError::NotImplemented(format!( + return not_impl_err!( "{} is not implemented for DataFrame::write_csv.", options.insert_op - ))); + ); } let format = if let Some(csv_opts) = writer_options { @@ -1601,8 +1595,16 @@ impl DataFrame { let file_type = format_as_file_type(format); + let plan = if options.sort_by.is_empty() { + self.plan + } else { + LogicalPlanBuilder::from(self.plan) + .sort(options.sort_by)? + .build()? + }; + let plan = LogicalPlanBuilder::copy_to( - self.plan, + plan, path.into(), file_type, HashMap::new(), @@ -1648,10 +1650,10 @@ impl DataFrame { writer_options: Option, ) -> Result, DataFusionError> { if options.insert_op != InsertOp::Append { - return Err(DataFusionError::NotImplemented(format!( + return not_impl_err!( "{} is not implemented for DataFrame::write_json.", options.insert_op - ))); + ); } let format = if let Some(json_opts) = writer_options { @@ -1662,8 +1664,16 @@ impl DataFrame { let file_type = format_as_file_type(format); + let plan = if options.sort_by.is_empty() { + self.plan + } else { + LogicalPlanBuilder::from(self.plan) + .sort(options.sort_by)? + .build()? + }; + let plan = LogicalPlanBuilder::copy_to( - self.plan, + plan, path.into(), file_type, Default::default(), @@ -1694,7 +1704,7 @@ impl DataFrame { /// # } /// ``` pub fn with_column(self, name: &str, expr: Expr) -> Result { - let window_func_exprs = find_window_exprs(&[expr.clone()]); + let window_func_exprs = find_window_exprs(std::slice::from_ref(&expr)); let (window_fn_str, plan) = if window_func_exprs.is_empty() { (None, self.plan) @@ -1893,6 +1903,17 @@ impl DataFrame { let mem_table = MemTable::try_new(schema, partitions)?; context.read_table(Arc::new(mem_table)) } + + /// Apply an alias to the DataFrame. + /// + /// This method replaces the qualifiers of output columns with the given alias. 
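The `alias` method added just below is mainly useful when the same table appears on both sides of a join and each side needs its own qualifier. A minimal usage sketch, assuming a registered table named `people` with an `id` column (table name and join keys are illustrative, not from this patch):

    use datafusion::common::JoinType;
    use datafusion::error::Result;
    use datafusion::prelude::*;

    async fn self_join(ctx: &SessionContext) -> Result<DataFrame> {
        // Give each side of the self-join its own qualifier so the joined
        // schema has unambiguous column references (l.id, r.id, ...).
        let left = ctx.table("people").await?.alias("l")?;
        let right = ctx.table("people").await?.alias("r")?;
        left.join(right, JoinType::Inner, &["id"], &["id"], None)
    }
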
+ pub fn alias(self, alias: &str) -> Result { + let plan = LogicalPlanBuilder::from(self.plan).alias(alias)?.build()?; + Ok(DataFrame { + session_state: self.session_state, + plan, + }) + } } #[derive(Debug)] @@ -1964,20 +1985,22 @@ mod tests { use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr}; use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; - use arrow::array::{self, Int32Array}; + use crate::prelude::{CsvReadOptions, NdJsonReadOptions, ParquetReadOptions}; + use arrow::array::Int32Array; use datafusion_common::{assert_batches_eq, Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, - ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunctionDefinition, + cast, create_udf, lit, ExprFunctionExt, ScalarFunctionImplementation, Volatility, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_functions_window::expr_fn::row_number; + use datafusion_functions_window::nth_value::first_value_udwf; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; use sqlparser::ast::NullTreatment; + use tempfile::TempDir; // Get string representation of the plan async fn assert_physical_plan(df: &DataFrame, expected: Vec<&str>) { @@ -2002,8 +2025,8 @@ mod tests { let batch = RecordBatch::try_new( dual_schema.clone(), vec![ - Arc::new(array::Int32Array::from(vec![1])), - Arc::new(array::StringArray::from(vec!["a"])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), ], ) .unwrap(); @@ -2199,10 +2222,8 @@ mod tests { async fn select_with_window_exprs() -> Result<()> { // build plan using Table API let t = test_table().await?; - let first_row = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::FirstValue, - ), + let first_row = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF(first_value_udwf()), vec![col("aggregate_test_100.c1")], )) .partition_by(vec![col("aggregate_test_100.c2")]) @@ -2406,6 +2427,30 @@ mod tests { Ok(()) } + #[tokio::test] + async fn aggregate_assert_no_empty_batches() -> Result<()> { + // build plan using DataFrame API + let df = test_table().await?; + let group_expr = vec![col("c1")]; + let aggr_expr = vec![ + min(col("c12")), + max(col("c12")), + avg(col("c12")), + sum(col("c12")), + count(col("c12")), + count_distinct(col("c12")), + median(col("c12")), + ]; + + let df: Vec = df.aggregate(group_expr, aggr_expr)?.collect().await?; + // Empty batches should not be produced + for batch in df { + assert!(batch.num_rows() > 0); + } + + Ok(()) + } + #[tokio::test] async fn test_aggregate_with_pk() -> Result<()> { // create the dataframe @@ -2623,6 +2668,54 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_with_union() -> Result<()> { + let df = test_table().await?; + + let df1 = df + .clone() + // GROUP BY `c1` + .aggregate(vec![col("c1")], vec![min(col("c2"))])? + // SELECT `c1` , min(c2) as `result` + .select(vec![col("c1"), min(col("c2")).alias("result")])?; + let df2 = df + .clone() + // GROUP BY `c1` + .aggregate(vec![col("c1")], vec![max(col("c3"))])? 
+ // SELECT `c1` , max(c3) as `result` + .select(vec![col("c1"), max(col("c3")).alias("result")])?; + + let df_union = df1.union(df2)?; + let df = df_union + // GROUP BY `c1` + .aggregate( + vec![col("c1")], + vec![sum(col("result")).alias("sum_result")], + )? + // SELECT `c1`, sum(result) as `sum_result` + .select(vec![(col("c1")), col("sum_result")])?; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + [ + "+----+------------+", + "| c1 | sum_result |", + "+----+------------+", + "| a | 84 |", + "| b | 69 |", + "| c | 124 |", + "| d | 126 |", + "| e | 121 |", + "+----+------------+" + ], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_aggregate_subexpr() -> Result<()> { let df = test_table().await?; @@ -2987,9 +3080,7 @@ mod tests { JoinType::Inner, Some(Expr::from(ScalarValue::Null)), )?; - let expected_plan = "CrossJoin:\ - \n TableScan: a projection=[c1], full_filters=[Boolean(NULL)]\ - \n TableScan: b projection=[c1]"; + let expected_plan = "EmptyRelation"; assert_eq!(expected_plan, format!("{}", join.into_optimized_plan()?)); // JOIN ON expression must be boolean type @@ -3235,7 +3326,7 @@ mod tests { &df_results ); - // check that col with the same name ovwewritten + // check that col with the same name overwritten let df_results_overwrite = df .clone() .with_column("c1", col("c2") + col("c3"))? @@ -3258,7 +3349,7 @@ mod tests { &df_results_overwrite ); - // check that col with the same name ovwewritten using same name as reference + // check that col with the same name overwritten using same name as reference let df_results_overwrite_self = df .clone() .with_column("c2", col("c2") + lit(1))? @@ -3547,11 +3638,10 @@ mod tests { #[tokio::test] async fn with_column_renamed_case_sensitive() -> Result<()> { - let config = - SessionConfig::from_string_hash_map(&std::collections::HashMap::from([( - "datafusion.sql_parser.enable_ident_normalization".to_owned(), - "false".to_owned(), - )]))?; + let config = SessionConfig::from_string_hash_map(&HashMap::from([( + "datafusion.sql_parser.enable_ident_normalization".to_owned(), + "false".to_owned(), + )]))?; let ctx = SessionContext::new_with_config(config); let name = "aggregate_test_100"; register_aggregate_csv(&ctx, name).await?; @@ -3623,7 +3713,7 @@ mod tests { #[tokio::test] async fn row_writer_resize_test() -> Result<()> { - let schema = Arc::new(Schema::new(vec![arrow::datatypes::Field::new( + let schema = Arc::new(Schema::new(vec![Field::new( "column_1", DataType::Utf8, false, @@ -3632,7 +3722,7 @@ mod tests { let data = RecordBatch::try_new( schema, vec![ - Arc::new(arrow::array::StringArray::from(vec![ + Arc::new(StringArray::from(vec![ Some("2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), Some("3a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800"), ])) @@ -3842,6 +3932,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftAnti, JoinType::RightAnti, + JoinType::LeftMark, ]; let default_partition_count = SessionConfig::new().target_partitions(); @@ -3859,7 +3950,10 @@ mod tests { let join_schema = physical_plan.schema(); match join_type { - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark => { let left_exprs: Vec> = vec![ Arc::new(Column::new_with_schema("c1", &join_schema)?), 
Arc::new(Column::new_with_schema("c2", &join_schema)?), @@ -4010,4 +4104,237 @@ mod tests { Ok(()) } + + // Test issue: https://github.com/apache/datafusion/issues/13873 + #[tokio::test] + async fn write_parquet_with_order() -> Result<()> { + let tmp_dir = TempDir::new()?; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let ctx = SessionContext::new(); + let write_df = ctx.read_batch(RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 5, 7, 3, 2])), + Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6])), + ], + )?)?; + + let test_path = tmp_dir.path().join("test.parquet"); + + write_df + .clone() + .write_parquet( + test_path.to_str().unwrap(), + DataFrameWriteOptions::new() + .with_sort_by(vec![col("a").sort(true, true)]), + None, + ) + .await?; + + let ctx = SessionContext::new(); + ctx.register_parquet( + "data", + test_path.to_str().unwrap(), + ParquetReadOptions::default(), + ) + .await?; + + let df = ctx.sql("SELECT * FROM data").await?; + let results = df.collect().await?; + + let df_explain = ctx.sql("explain SELECT a FROM data").await?; + let explain_result = df_explain.collect().await?; + + println!("explain_result {:?}", explain_result); + + assert_batches_eq!( + &[ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | 2 |", + "| 2 | 6 |", + "| 3 | 5 |", + "| 5 | 3 |", + "| 7 | 4 |", + "+---+---+", + ], + &results + ); + Ok(()) + } + + // Test issue: https://github.com/apache/datafusion/issues/13873 + #[tokio::test] + async fn write_csv_with_order() -> Result<()> { + let tmp_dir = TempDir::new()?; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let ctx = SessionContext::new(); + let write_df = ctx.read_batch(RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 5, 7, 3, 2])), + Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6])), + ], + )?)?; + + let test_path = tmp_dir.path().join("test.csv"); + + write_df + .clone() + .write_csv( + test_path.to_str().unwrap(), + DataFrameWriteOptions::new() + .with_sort_by(vec![col("a").sort(true, true)]), + None, + ) + .await?; + + let ctx = SessionContext::new(); + ctx.register_csv( + "data", + test_path.to_str().unwrap(), + CsvReadOptions::new().schema(&schema), + ) + .await?; + + let df = ctx.sql("SELECT * FROM data").await?; + let results = df.collect().await?; + + assert_batches_eq!( + &[ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | 2 |", + "| 2 | 6 |", + "| 3 | 5 |", + "| 5 | 3 |", + "| 7 | 4 |", + "+---+---+", + ], + &results + ); + Ok(()) + } + + // Test issue: https://github.com/apache/datafusion/issues/13873 + #[tokio::test] + async fn write_json_with_order() -> Result<()> { + let tmp_dir = TempDir::new()?; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let ctx = SessionContext::new(); + let write_df = ctx.read_batch(RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 5, 7, 3, 2])), + Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6])), + ], + )?)?; + + let test_path = tmp_dir.path().join("test.json"); + + write_df + .clone() + .write_json( + test_path.to_str().unwrap(), + DataFrameWriteOptions::new() + .with_sort_by(vec![col("a").sort(true, true)]), + None, + ) + .await?; + + let ctx = SessionContext::new(); + ctx.register_json( + "data", + test_path.to_str().unwrap(), + 
NdJsonReadOptions::default().schema(&schema), + ) + .await?; + + let df = ctx.sql("SELECT * FROM data").await?; + let results = df.collect().await?; + + assert_batches_eq!( + &[ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | 2 |", + "| 2 | 6 |", + "| 3 | 5 |", + "| 5 | 3 |", + "| 7 | 4 |", + "+---+---+", + ], + &results + ); + Ok(()) + } + + // Test issue: https://github.com/apache/datafusion/issues/13873 + #[tokio::test] + async fn write_table_with_order() -> Result<()> { + let tmp_dir = TempDir::new()?; + let ctx = SessionContext::new(); + let location = tmp_dir.path().join("test_table/"); + + let mut write_df = ctx + .sql("values ('z'), ('x'), ('a'), ('b'), ('c')") + .await + .unwrap(); + + // Ensure the column names and types match the target table + write_df = write_df + .with_column_renamed("column1", "tablecol1") + .unwrap(); + let sql_str = + "create external table data(tablecol1 varchar) stored as parquet location '" + .to_owned() + + location.to_str().unwrap() + + "'"; + + ctx.sql(sql_str.as_str()).await?.collect().await?; + + // This is equivalent to INSERT INTO test. + write_df + .clone() + .write_table( + "data", + DataFrameWriteOptions::new() + .with_sort_by(vec![col("tablecol1").sort(true, true)]), + ) + .await?; + + let df = ctx.sql("SELECT * FROM data").await?; + let results = df.collect().await?; + + assert_batches_eq!( + &[ + "+-----------+", + "| tablecol1 |", + "+-----------+", + "| a |", + "| b |", + "| c |", + "| x |", + "| z |", + "+-----------+", + ], + &results + ); + Ok(()) + } } diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index f90b35fde6baf..1dd4d68fca6b3 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -26,6 +26,7 @@ use super::{ }; use datafusion_common::config::TableParquetOptions; +use datafusion_common::not_impl_err; use datafusion_expr::dml::InsertOp; impl DataFrame { @@ -59,10 +60,10 @@ impl DataFrame { writer_options: Option, ) -> Result, DataFusionError> { if options.insert_op != InsertOp::Append { - return Err(DataFusionError::NotImplemented(format!( + return not_impl_err!( "{} is not implemented for DataFrame::write_parquet.", options.insert_op - ))); + ); } let format = if let Some(parquet_opts) = writer_options { @@ -73,8 +74,16 @@ impl DataFrame { let file_type = format_as_file_type(format); + let plan = if options.sort_by.is_empty() { + self.plan + } else { + LogicalPlanBuilder::from(self.plan) + .sort(options.sort_by)? + .build()? 
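This hunk gives `DataFrame::write_parquet` the same `sort_by` handling that `write_table`, `write_csv`, and `write_json` gained earlier in the diff. From the user side it is driven by `DataFrameWriteOptions::with_sort_by`; a minimal sketch, where the file paths and column name are illustrative only:

    use datafusion::dataframe::DataFrameWriteOptions;
    use datafusion::error::Result;
    use datafusion::prelude::*;

    async fn write_sorted(ctx: &SessionContext) -> Result<()> {
        let df = ctx.read_csv("input.csv", CsvReadOptions::new()).await?;
        // Ask the writer to sort by column "a" (ascending, nulls first)
        // before rows are handed to the Parquet sink.
        df.write_parquet(
            "sorted.parquet",
            DataFrameWriteOptions::new().with_sort_by(vec![col("a").sort(true, true)]),
            None,
        )
        .await?;
        Ok(())
    }
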
+ }; + let plan = LogicalPlanBuilder::copy_to( - self.plan, + plan, path.into(), file_type, Default::default(), diff --git a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs index 98b6702bc3834..8f0e3792ffec7 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs @@ -60,7 +60,7 @@ pub struct AvroArrowArrayReader<'a, R: Read> { schema_lookup: BTreeMap, } -impl<'a, R: Read> AvroArrowArrayReader<'a, R> { +impl AvroArrowArrayReader<'_, R> { pub fn try_new( reader: R, schema: SchemaRef, @@ -138,7 +138,11 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { } AvroSchema::Array(schema) => { let sub_parent_field_name = format!("{}.element", parent_field_name); - Self::child_schema_lookup(&sub_parent_field_name, schema, schema_lookup)?; + Self::child_schema_lookup( + &sub_parent_field_name, + &schema.items, + schema_lookup, + )?; } _ => (), } @@ -206,7 +210,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { fn build_primitive_array(&self, rows: RecordSlice, col_name: &str) -> ArrayRef where T: ArrowNumericType + Resolver, - T::Native: num_traits::cast::NumCast, + T::Native: NumCast, { Arc::new( rows.iter() @@ -354,7 +358,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { let builder = builder .as_any_mut() .downcast_mut::>() - .ok_or_else(||ArrowError::SchemaError( + .ok_or_else(||SchemaError( "Cast failed for ListBuilder during nested data parsing".to_string(), ))?; for val in vals { @@ -369,7 +373,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { builder.append(true); } DataType::Dictionary(_, _) => { - let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||ArrowError::SchemaError( + let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||SchemaError( "Cast failed for ListBuilder during nested data parsing".to_string(), ))?; for val in vals { @@ -402,7 +406,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { col_name: &str, ) -> ArrowResult where - T::Native: num_traits::cast::NumCast, + T::Native: NumCast, T: ArrowPrimitiveType + ArrowDictionaryKeyType, { let mut builder: StringDictionaryBuilder = @@ -453,12 +457,10 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::UInt64 => { self.build_dictionary_array::(rows, col_name) } - _ => Err(ArrowError::SchemaError( - "unsupported dictionary key type".to_string(), - )), + _ => Err(SchemaError("unsupported dictionary key type".to_string())), } } else { - Err(ArrowError::SchemaError( + Err(SchemaError( "dictionary types other than UTF-8 not yet supported".to_string(), )) } @@ -532,7 +534,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::UInt32 => self.read_primitive_list_values::(rows), DataType::UInt64 => self.read_primitive_list_values::(rows), DataType::Float16 => { - return Err(ArrowError::SchemaError("Float16 not supported".to_string())) + return Err(SchemaError("Float16 not supported".to_string())) } DataType::Float32 => self.read_primitive_list_values::(rows), DataType::Float64 => self.read_primitive_list_values::(rows), @@ -541,7 +543,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { | DataType::Date64 | DataType::Time32(_) | DataType::Time64(_) => { - return Err(ArrowError::SchemaError( + return Err(SchemaError( "Temporal types are not yet supported, see ARROW-4803".to_string(), )) } @@ -623,7 +625,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { .unwrap() } datatype => { - return 
Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "Nested list of {datatype:?} not supported" ))); } @@ -737,7 +739,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { &field_path, ), t => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "TimeUnit {t:?} not supported with Time64" ))) } @@ -751,7 +753,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { &field_path, ), t => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "TimeUnit {t:?} not supported with Time32" ))) } @@ -854,7 +856,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { make_array(data) } _ => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "type {:?} not supported", field.data_type() ))) @@ -870,7 +872,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { fn read_primitive_list_values(&self, rows: &[&Value]) -> ArrayData where T: ArrowPrimitiveType + ArrowNumericType, - T::Native: num_traits::cast::NumCast, + T::Native: NumCast, { let values = rows .iter() @@ -970,7 +972,7 @@ fn resolve_u8(v: &Value) -> AvroResult { other => Err(AvroError::GetU8(other.into())), }?; if let Value::Int(n) = int { - if n >= 0 && n <= std::convert::From::from(u8::MAX) { + if n >= 0 && n <= From::from(u8::MAX) { return Ok(n as u8); } } @@ -1048,7 +1050,7 @@ fn maybe_resolve_union(value: &Value) -> &Value { impl Resolver for N where N: ArrowNumericType, - N::Native: num_traits::cast::NumCast, + N::Native: NumCast, { fn resolve(value: &Value) -> Option { let value = maybe_resolve_union(value); diff --git a/datafusion/core/src/datasource/avro_to_arrow/mod.rs b/datafusion/core/src/datasource/avro_to_arrow/mod.rs index c59078c89dd00..71184a78c96f5 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/mod.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/mod.rs @@ -39,7 +39,7 @@ use std::io::Read; pub fn read_avro_schema_from_reader(reader: &mut R) -> Result { let avro_reader = apache_avro::Reader::new(reader)?; let schema = avro_reader.writer_schema(); - schema::to_arrow_schema(schema) + to_arrow_schema(schema) } #[cfg(not(feature = "avro"))] diff --git a/datafusion/core/src/datasource/avro_to_arrow/reader.rs b/datafusion/core/src/datasource/avro_to_arrow/reader.rs index 5dc53c5c86c87..dbc24da463667 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/reader.rs @@ -128,7 +128,7 @@ pub struct Reader<'a, R: Read> { batch_size: usize, } -impl<'a, R: Read> Reader<'a, R> { +impl Reader<'_, R> { /// Create a new Avro Reader from any value that implements the `Read` trait. 
/// /// If reading a `File`, you can customise the Reader, such as to enable schema @@ -142,7 +142,7 @@ impl<'a, R: Read> Reader<'a, R> { Ok(Self { array_reader: AvroArrowArrayReader::try_new( reader, - schema.clone(), + Arc::clone(&schema), projection, )?, schema, @@ -153,11 +153,11 @@ impl<'a, R: Read> Reader<'a, R> { /// Returns the schema of the reader, useful for getting the schema without reading /// record batches pub fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } -impl<'a, R: Read> Iterator for Reader<'a, R> { +impl Iterator for Reader<'_, R> { type Item = ArrowResult; /// Returns the next batch of results (defined by `self.batch_size`), or `None` if there diff --git a/datafusion/core/src/datasource/avro_to_arrow/schema.rs b/datafusion/core/src/datasource/avro_to_arrow/schema.rs index 039a6aacc07eb..991f648e58bd2 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/schema.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/schema.rs @@ -73,11 +73,15 @@ fn schema_to_field_with_props( AvroSchema::Bytes => DataType::Binary, AvroSchema::String => DataType::Utf8, AvroSchema::Array(item_schema) => DataType::List(Arc::new( - schema_to_field_with_props(item_schema, Some("element"), false, None)?, + schema_to_field_with_props(&item_schema.items, Some("element"), false, None)?, )), AvroSchema::Map(value_schema) => { - let value_field = - schema_to_field_with_props(value_schema, Some("value"), false, None)?; + let value_field = schema_to_field_with_props( + &value_schema.types, + Some("value"), + false, + None, + )?; DataType::Dictionary( Box::new(DataType::Utf8), Box::new(value_field.data_type().clone()), @@ -144,14 +148,17 @@ fn schema_to_field_with_props( AvroSchema::Decimal(DecimalSchema { precision, scale, .. 
}) => DataType::Decimal128(*precision as u8, *scale as i8), + AvroSchema::BigDecimal => DataType::LargeBinary, AvroSchema::Uuid => DataType::FixedSizeBinary(16), AvroSchema::Date => DataType::Date32, AvroSchema::TimeMillis => DataType::Time32(TimeUnit::Millisecond), AvroSchema::TimeMicros => DataType::Time64(TimeUnit::Microsecond), AvroSchema::TimestampMillis => DataType::Timestamp(TimeUnit::Millisecond, None), AvroSchema::TimestampMicros => DataType::Timestamp(TimeUnit::Microsecond, None), + AvroSchema::TimestampNanos => DataType::Timestamp(TimeUnit::Nanosecond, None), AvroSchema::LocalTimestampMillis => todo!(), AvroSchema::LocalTimestampMicros => todo!(), + AvroSchema::LocalTimestampNanos => todo!(), AvroSchema::Duration => DataType::Duration(TimeUnit::Millisecond), }; @@ -371,6 +378,7 @@ mod test { aliases: Some(vec![alias("foofixed"), alias("barfixed")]), size: 1, doc: None, + default: None, attributes: Default::default(), }); let props = external_props(&fixed_schema); diff --git a/datafusion/core/src/datasource/cte_worktable.rs b/datafusion/core/src/datasource/cte_worktable.rs index 23f57b12ae08c..b63755f644a84 100644 --- a/datafusion/core/src/datasource/cte_worktable.rs +++ b/datafusion/core/src/datasource/cte_worktable.rs @@ -39,8 +39,6 @@ use crate::datasource::{TableProvider, TableType}; #[derive(Debug)] pub struct CteWorkTable { /// The name of the CTE work table - // WIP, see https://github.com/apache/datafusion/issues/462 - #[allow(dead_code)] name: String, /// This schema must be shared across both the static and recursive terms of a recursive query table_schema: SchemaRef, @@ -56,6 +54,16 @@ impl CteWorkTable { table_schema, } } + + /// The user-provided name of the CTE + pub fn name(&self) -> &str { + &self.name + } + + /// The schema of the recursive term of the query + pub fn schema(&self) -> SchemaRef { + Arc::clone(&self.table_schema) + } } #[async_trait] @@ -69,7 +77,7 @@ impl TableProvider for CteWorkTable { } fn schema(&self) -> SchemaRef { - self.table_schema.clone() + Arc::clone(&self.table_schema) } fn table_type(&self) -> TableType { @@ -86,7 +94,7 @@ impl TableProvider for CteWorkTable { // TODO: pushdown filters and limits Ok(Arc::new(WorkTableExec::new( self.name.clone(), - self.table_schema.clone(), + Arc::clone(&self.table_schema), ))) } diff --git a/datafusion/core/src/datasource/default_table_source.rs b/datafusion/core/src/datasource/default_table_source.rs index b4a5a76fc9ff6..91c1e0ac97fcf 100644 --- a/datafusion/core/src/datasource/default_table_source.rs +++ b/datafusion/core/src/datasource/default_table_source.rs @@ -24,7 +24,7 @@ use crate::datasource::TableProvider; use arrow::datatypes::SchemaRef; use datafusion_common::{internal_err, Constraints}; -use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource}; +use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource, TableType}; /// DataFusion default table source, wrapping TableProvider. /// @@ -61,8 +61,13 @@ impl TableSource for DefaultTableSource { self.table_provider.constraints() } + /// Get the type of this table for metadata/catalog purposes. + fn table_type(&self) -> TableType { + self.table_provider.table_type() + } + /// Tests whether the table provider can make use of any or all filter expressions - /// to optimise data retrieval. + /// to optimize data retrieval. 
fn supports_filters_pushdown( &self, filter: &[&Expr], @@ -96,7 +101,45 @@ pub fn source_as_provider( .as_any() .downcast_ref::() { - Some(source) => Ok(source.table_provider.clone()), + Some(source) => Ok(Arc::clone(&source.table_provider)), _ => internal_err!("TableSource was not DefaultTableSource"), } } + +#[test] +fn preserves_table_type() { + use async_trait::async_trait; + use datafusion_common::DataFusionError; + + #[derive(Debug)] + struct TestTempTable; + + #[async_trait] + impl TableProvider for TestTempTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_type(&self) -> TableType { + TableType::Temporary + } + + fn schema(&self) -> SchemaRef { + unimplemented!() + } + + async fn scan( + &self, + _: &dyn datafusion_catalog::Session, + _: Option<&Vec>, + _: &[Expr], + _: Option, + ) -> Result, DataFusionError> + { + unimplemented!() + } + } + + let table_source = DefaultTableSource::new(Arc::new(TestTempTable)); + assert_eq!(table_source.table_type(), TableType::Temporary); +} diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs index bc5b82bd8c5b5..abda7fa9ec4b6 100644 --- a/datafusion/core/src/datasource/empty.rs +++ b/datafusion/core/src/datasource/empty.rs @@ -61,7 +61,7 @@ impl TableProvider for EmptyTable { } fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } fn table_type(&self) -> TableType { diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index c10ebbd6c9eab..1d9827ae0ab57 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -26,12 +26,13 @@ use std::fmt::{self, Debug}; use std::sync::Arc; use super::file_compression_type::FileCompressionType; -use super::write::demux::start_demuxer_task; +use super::write::demux::DemuxedStreamReceiver; use super::write::{create_writer, SharedBuffer}; use super::FileFormatFactory; +use crate::datasource::file_format::write::get_writer_schema; use crate::datasource::file_format::FileFormat; use crate::datasource::physical_plan::{ - ArrowExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, + ArrowExec, FileGroupDisplay, FileScanConfig, FileSink, FileSinkConfig, }; use crate::error::Result; use crate::execution::context::SessionState; @@ -46,11 +47,11 @@ use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ not_impl_err, DataFusionError, GetExt, Statistics, DEFAULT_ARROW_EXTENSION, }; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; -use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; use bytes::Bytes; @@ -186,19 +187,13 @@ impl FileFormat for ArrowFormat { return not_impl_err!("Overwrites are not implemented yet for Arrow format"); } - let sink_schema = conf.output_schema().clone(); let sink = Arc::new(ArrowFileSink::new(conf)); - Ok(Arc::new(DataSinkExec::new( - input, - sink, - sink_schema, - order_requirements, - )) as _) + Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) } } -/// Implements [`DataSink`] for writing to arrow_ipc files +/// Implements [`FileSink`] for writing to arrow_ipc files struct ArrowFileSink { config: FileSinkConfig, } @@ -207,85 +202,21 @@ impl ArrowFileSink { fn new(config: FileSinkConfig) -> Self { Self 
{ config } } - - /// Converts table schema to writer schema, which may differ in the case - /// of hive style partitioning where some columns are removed from the - /// underlying files. - fn get_writer_schema(&self) -> Arc { - if !self.config.table_partition_cols.is_empty() { - let schema = self.config.output_schema(); - let partition_names: Vec<_> = self - .config - .table_partition_cols - .iter() - .map(|(s, _)| s) - .collect(); - Arc::new(Schema::new( - schema - .fields() - .iter() - .filter(|f| !partition_names.contains(&f.name())) - .map(|f| (**f).clone()) - .collect::>(), - )) - } else { - self.config.output_schema().clone() - } - } -} - -impl Debug for ArrowFileSink { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ArrowFileSink").finish() - } -} - -impl DisplayAs for ArrowFileSink { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "ArrowFileSink(file_groups=",)?; - FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; - write!(f, ")") - } - } - } } #[async_trait] -impl DataSink for ArrowFileSink { - fn as_any(&self) -> &dyn Any { - self +impl FileSink for ArrowFileSink { + fn config(&self) -> &FileSinkConfig { + &self.config } - fn metrics(&self) -> Option { - None - } - - async fn write_all( + async fn spawn_writer_tasks_and_join( &self, - data: SendableRecordBatchStream, - context: &Arc, + _context: &Arc, + demux_task: SpawnedTask>, + mut file_stream_rx: DemuxedStreamReceiver, + object_store: Arc, ) -> Result { - let object_store = context - .runtime_env() - .object_store(&self.config.object_store_url)?; - - let part_col = if !self.config.table_partition_cols.is_empty() { - Some(self.config.table_partition_cols.clone()) - } else { - None - }; - - let (demux_task, mut file_stream_rx) = start_demuxer_task( - data, - context, - part_col, - self.config.table_paths[0].clone(), - "arrow".into(), - self.config.keep_partition_by_columns, - ); - let mut file_write_tasks: JoinSet> = JoinSet::new(); @@ -296,13 +227,13 @@ impl DataSink for ArrowFileSink { let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES); let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options( shared_buffer.clone(), - &self.get_writer_schema(), + &get_writer_schema(&self.config), ipc_options.clone(), )?; let mut object_store_writer = create_writer( FileCompressionType::UNCOMPRESSED, &path, - object_store.clone(), + Arc::clone(&object_store), ) .await?; file_write_tasks.spawn(async move { @@ -351,6 +282,43 @@ impl DataSink for ArrowFileSink { } } +impl Debug for ArrowFileSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ArrowFileSink").finish() + } +} + +impl DisplayAs for ArrowFileSink { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ArrowFileSink(file_groups=",)?; + FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; + write!(f, ")") + } + } + } +} + +#[async_trait] +impl DataSink for ArrowFileSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> &SchemaRef { + self.config.output_schema() + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + FileSink::write_all(self, data, context).await + } +} + const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; const CONTINUATION_MARKER: [u8; 4] = 
[0xff; 4]; @@ -476,7 +444,7 @@ mod tests { .infer_schema( &state, &(store.clone() as Arc), - &[object_meta.clone()], + std::slice::from_ref(&object_meta), ) .await?; let actual_fields = inferred_schema @@ -515,7 +483,7 @@ mod tests { .infer_schema( &state, &(store.clone() as Arc), - &[object_meta.clone()], + std::slice::from_ref(&object_meta), ) .await; diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 5190bdbe153a2..f854b9506a647 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -25,8 +25,8 @@ use std::sync::Arc; use arrow::datatypes::Schema; use arrow::datatypes::SchemaRef; use async_trait::async_trait; +use datafusion_common::internal_err; use datafusion_common::parsers::CompressionTypeVariant; -use datafusion_common::DataFusionError; use datafusion_common::GetExt; use datafusion_common::DEFAULT_AVRO_EXTENSION; use datafusion_physical_expr::PhysicalExpr; @@ -105,9 +105,7 @@ impl FileFormat for AvroFormat { let ext = self.get_ext(); match file_compression_type.get_variant() { CompressionTypeVariant::UNCOMPRESSED => Ok(ext), - _ => Err(DataFusionError::Internal( - "Avro FileFormat does not support compression.".into(), - )), + _ => internal_err!("Avro FileFormat does not support compression."), } } diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index f235c3b628a0d..edf757e539a9e 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -22,12 +22,16 @@ use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug}; use std::sync::Arc; -use super::write::orchestration::stateless_multipart_put; -use super::{FileFormat, FileFormatFactory}; +use super::write::orchestration::spawn_writer_tasks_and_join; +use super::{ + Decoder, DecoderDeserializer, FileFormat, FileFormatFactory, + DEFAULT_SCHEMA_INFER_MAX_RECORD, +}; use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::write::demux::DemuxedStreamReceiver; use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::physical_plan::{ - CsvExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, + CsvExec, FileGroupDisplay, FileScanConfig, FileSink, FileSinkConfig, }; use crate::error::Result; use crate::execution::context::SessionState; @@ -38,17 +42,17 @@ use crate::physical_plan::{ use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; -use arrow::datatypes::SchemaRef; -use arrow::datatypes::{DataType, Field, Fields, Schema}; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow_schema::ArrowError; use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions}; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::{ exec_err, not_impl_err, DataFusionError, GetExt, DEFAULT_CSV_EXTENSION, }; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; use bytes::{Buf, Bytes}; @@ -56,6 +60,7 @@ use datafusion_physical_expr_common::sort_expr::LexRequirement; use futures::stream::BoxStream; use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; use object_store::{delimited::newline_delimited_stream, ObjectMeta, 
ObjectStore}; +use regex::Regex; #[derive(Default)] /// Factory struct used to create [CsvFormatFactory] @@ -78,7 +83,7 @@ impl CsvFormatFactory { } } -impl fmt::Debug for CsvFormatFactory { +impl Debug for CsvFormatFactory { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("CsvFormatFactory") .field("options", &self.options) @@ -137,11 +142,11 @@ impl CsvFormat { /// Return a newline delimited stream from the specified file on /// Stream, decompressing if necessary /// Each returned `Bytes` has a whole number of newline delimited rows - async fn read_to_delimited_chunks( + async fn read_to_delimited_chunks<'a>( &self, store: &Arc, object: &ObjectMeta, - ) -> BoxStream<'static, Result> { + ) -> BoxStream<'a, Result> { // stream to only read as many rows as needed into memory let stream = store .get(&object.location) @@ -165,10 +170,10 @@ impl CsvFormat { stream.boxed() } - async fn read_to_delimited_chunks_from_stream( + async fn read_to_delimited_chunks_from_stream<'a>( &self, - stream: BoxStream<'static, Result>, - ) -> BoxStream<'static, Result> { + stream: BoxStream<'a, Result>, + ) -> BoxStream<'a, Result> { let file_compression_type: FileCompressionType = self.options.compression.into(); let decoder = file_compression_type.convert_stream(stream); let steam = match decoder { @@ -204,7 +209,7 @@ impl CsvFormat { /// Set a limit in terms of records to scan to infer the schema /// - default to `DEFAULT_SCHEMA_INFER_MAX_RECORD` pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self { - self.options.schema_infer_max_rec = max_rec; + self.options.schema_infer_max_rec = Some(max_rec); self } @@ -215,6 +220,13 @@ impl CsvFormat { self } + /// Set the regex to use for null values in the CSV reader. + /// - default to treat empty values as null. + pub fn with_null_regex(mut self, null_regex: Option) -> Self { + self.options.null_regex = null_regex; + self + } + /// Returns `Some(true)` if the first line is a header, `Some(false)` if /// it is not, and `None` if it is not specified. 
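Together with the now-optional `schema_infer_max_rec`, the new `null_regex` option composes as sketched below when registering a CSV listing table. The table name and path are illustrative; the regex is the same one used in the tests further down (`^NULL$|^$`, i.e. a literal `NULL` or an empty field is read as null):

    use std::sync::Arc;

    use datafusion::datasource::file_format::csv::CsvFormat;
    use datafusion::datasource::listing::ListingOptions;
    use datafusion::error::Result;
    use datafusion::prelude::*;

    async fn register_csv_with_nulls(ctx: &SessionContext) -> Result<()> {
        let format = CsvFormat::default()
            .with_has_header(true)
            // Only look at the first 100 records when inferring the schema.
            .with_schema_infer_max_rec(100)
            // Treat empty fields and the literal string NULL as null values.
            .with_null_regex(Some("^NULL$|^$".to_string()));

        let options =
            ListingOptions::new(Arc::new(format)).with_file_extension(".csv");
        ctx.register_listing_table("t", "data/", options, None, None)
            .await
    }
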
pub fn has_header(&self) -> Option { @@ -293,6 +305,45 @@ impl CsvFormat { } } +#[derive(Debug)] +pub(crate) struct CsvDecoder { + inner: arrow::csv::reader::Decoder, +} + +impl CsvDecoder { + pub(crate) fn new(decoder: arrow::csv::reader::Decoder) -> Self { + Self { inner: decoder } + } +} + +impl Decoder for CsvDecoder { + fn decode(&mut self, buf: &[u8]) -> Result { + self.inner.decode(buf) + } + + fn flush(&mut self) -> Result, ArrowError> { + self.inner.flush() + } + + fn can_flush_early(&self) -> bool { + self.inner.capacity() == 0 + } +} + +impl Debug for CsvSerializer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CsvSerializer") + .field("header", &self.header) + .finish() + } +} + +impl From for DecoderDeserializer { + fn from(decoder: arrow::csv::reader::Decoder) -> Self { + DecoderDeserializer::new(CsvDecoder::new(decoder)) + } +} + #[async_trait] impl FileFormat for CsvFormat { fn as_any(&self) -> &dyn Any { @@ -319,13 +370,22 @@ impl FileFormat for CsvFormat { ) -> Result { let mut schemas = vec![]; - let mut records_to_read = self.options.schema_infer_max_rec; + let mut records_to_read = self + .options + .schema_infer_max_rec + .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD); for object in objects { let stream = self.read_to_delimited_chunks(store, object).await; let (schema, records_read) = self .infer_schema_from_stream(state, records_to_read, stream) - .await?; + .await + .map_err(|err| { + DataFusionError::Context( + format!("Error when processing CSV file {}", &object.location), + Box::new(err), + ) + })?; records_to_read -= records_read; schemas.push(schema); if records_to_read == 0 { @@ -408,15 +468,9 @@ impl FileFormat for CsvFormat { let writer_options = CsvWriterOptions::try_from(&options)?; - let sink_schema = conf.output_schema().clone(); let sink = Arc::new(CsvSink::new(conf, writer_options)); - Ok(Arc::new(DataSinkExec::new( - input, - sink, - sink_schema, - order_requirements, - )) as _) + Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) } } @@ -433,11 +487,13 @@ impl CsvFormat { let mut total_records_read = 0; let mut column_names = vec![]; let mut column_type_possibilities = vec![]; - let mut first_chunk = true; + let mut record_number = -1; pin_mut!(stream); while let Some(chunk) = stream.next().await.transpose()? { + record_number += 1; + let first_chunk = record_number == 0; let mut format = arrow::csv::reader::Format::default() .with_header( first_chunk @@ -446,7 +502,18 @@ impl CsvFormat { .has_header .unwrap_or(state.config_options().catalog.has_header), ) - .with_delimiter(self.options.delimiter); + .with_delimiter(self.options.delimiter) + .with_quote(self.options.quote); + + if let Some(null_regex) = &self.options.null_regex { + let regex = Regex::new(null_regex.as_str()) + .expect("Unable to parse CSV null regex."); + format = format.with_null_regex(regex); + } + + if let Some(escape) = self.options.escape { + format = format.with_escape(escape); + } if let Some(comment) = self.options.comment { format = format.with_comment(comment); @@ -471,14 +538,14 @@ impl CsvFormat { (field.name().clone(), possibilities) }) .unzip(); - first_chunk = false; } else { if fields.len() != column_type_possibilities.len() { return exec_err!( "Encountered unequal lengths between records on CSV file whilst inferring schema. 
\ - Expected {} records, found {} records", + Expected {} fields, found {} fields at record {}", column_type_possibilities.len(), - fields.len() + fields.len(), + record_number + 1 ); } @@ -613,43 +680,42 @@ impl CsvSink { } } - /// Retrieve the inner [`FileSinkConfig`]. - pub fn config(&self) -> &FileSinkConfig { + /// Retrieve the writer options + pub fn writer_options(&self) -> &CsvWriterOptions { + &self.writer_options + } +} + +#[async_trait] +impl FileSink for CsvSink { + fn config(&self) -> &FileSinkConfig { &self.config } - async fn multipartput_all( + async fn spawn_writer_tasks_and_join( &self, - data: SendableRecordBatchStream, context: &Arc, + demux_task: SpawnedTask>, + file_stream_rx: DemuxedStreamReceiver, + object_store: Arc, ) -> Result { - let builder = &self.writer_options.writer_options; - - let builder_clone = builder.clone(); - let options_clone = self.writer_options.clone(); - let get_serializer = move || { - Arc::new( - CsvSerializer::new() - .with_builder(builder_clone.clone()) - .with_header(options_clone.writer_options.header()), - ) as _ - }; - - stateless_multipart_put( - data, + let builder = self.writer_options.writer_options.clone(); + let header = builder.header(); + let serializer = Arc::new( + CsvSerializer::new() + .with_builder(builder) + .with_header(header), + ) as _; + spawn_writer_tasks_and_join( context, - "csv".into(), - Box::new(get_serializer), - &self.config, + serializer, self.writer_options.compression.into(), + object_store, + demux_task, + file_stream_rx, ) .await } - - /// Retrieve the writer options - pub fn writer_options(&self) -> &CsvWriterOptions { - &self.writer_options - } } #[async_trait] @@ -658,8 +724,8 @@ impl DataSink for CsvSink { self } - fn metrics(&self) -> Option { - None + fn schema(&self) -> &SchemaRef { + self.config.output_schema() } async fn write_all( @@ -667,8 +733,7 @@ impl DataSink for CsvSink { data: SendableRecordBatchStream, context: &Arc, ) -> Result { - let total_count = self.multipartput_all(data, context).await?; - Ok(total_count) + FileSink::write_all(self, data, context).await } } @@ -676,23 +741,27 @@ impl DataSink for CsvSink { mod tests { use super::super::test_util::scan_format; use super::*; - use crate::arrow::util::pretty; use crate::assert_batches_eq; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::test_util::VariableStream; + use crate::datasource::file_format::{ + BatchDeserializer, DecoderDeserializer, DeserializerOutput, + }; use crate::datasource::listing::ListingOptions; + use crate::execution::session_state::SessionStateBuilder; use crate::physical_plan::collect; use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use crate::test_util::arrow_test_data; use arrow::compute::concat_batches; + use arrow::csv::ReaderBuilder; + use arrow::util::pretty::pretty_format_batches; + use arrow_array::{BooleanArray, Float64Array, Int32Array, StringArray}; use datafusion_common::cast::as_string_array; use datafusion_common::internal_err; use datafusion_common::stats::Precision; - use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::{col, lit}; - use crate::execution::session_state::SessionStateBuilder; use chrono::DateTime; use object_store::local::LocalFileSystem; use object_store::path::Path; @@ -751,8 +820,67 @@ mod tests { let state = session_ctx.state(); let projection = None; - let exec = - get_exec(&state, "aggregate_test_100.csv", projection, None, true).await?; + let root = 
"./tests/data/csv"; + let format = CsvFormat::default().with_has_header(true); + let exec = scan_format( + &state, + &format, + root, + "aggregate_test_100_with_nulls.csv", + projection, + None, + ) + .await?; + + let x: Vec = exec + .schema() + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect(); + assert_eq!( + vec![ + "c1: Utf8", + "c2: Int64", + "c3: Int64", + "c4: Int64", + "c5: Int64", + "c6: Int64", + "c7: Int64", + "c8: Int64", + "c9: Int64", + "c10: Utf8", + "c11: Float64", + "c12: Float64", + "c13: Utf8", + "c14: Null", + "c15: Utf8" + ], + x + ); + + Ok(()) + } + + #[tokio::test] + async fn infer_schema_with_null_regex() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + let projection = None; + let root = "./tests/data/csv"; + let format = CsvFormat::default() + .with_has_header(true) + .with_null_regex(Some("^NULL$|^$".to_string())); + let exec = scan_format( + &state, + &format, + root, + "aggregate_test_100_with_nulls.csv", + projection, + None, + ) + .await?; let x: Vec = exec .schema() @@ -774,7 +902,9 @@ mod tests { "c10: Utf8", "c11: Float64", "c12: Float64", - "c13: Utf8" + "c13: Utf8", + "c14: Null", + "c15: Null" ], x ); @@ -859,6 +989,55 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_infer_schema_escape_chars() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let variable_object_store = Arc::new(VariableStream::new( + Bytes::from( + r#"c1,c2,c3,c4 +0.3,"Here, is a comma\"",third,3 +0.31,"double quotes are ok, "" quote",third again,9 +0.314,abc,xyz,27"#, + ), + 1, + )); + let object_meta = ObjectMeta { + location: Path::parse("/")?, + last_modified: DateTime::default(), + size: usize::MAX, + e_tag: None, + version: None, + }; + + let num_rows_to_read = 3; + let csv_format = CsvFormat::default() + .with_has_header(true) + .with_schema_infer_max_rec(num_rows_to_read) + .with_quote(b'"') + .with_escape(Some(b'\\')); + + let inferred_schema = csv_format + .infer_schema( + &state, + &(variable_object_store.clone() as Arc), + &[object_meta], + ) + .await?; + + let actual_fields: Vec<_> = inferred_schema + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect(); + + assert_eq!( + vec!["c1: Float64", "c2: Utf8", "c3: Utf8", "c4: Int64",], + actual_fields + ); + Ok(()) + } + #[rstest( file_compression_type, case(FileCompressionType::UNCOMPRESSED), @@ -872,18 +1051,19 @@ mod tests { async fn query_compress_data( file_compression_type: FileCompressionType, ) -> Result<()> { - let runtime = Arc::new(RuntimeEnvBuilder::new().build()?); let mut cfg = SessionConfig::new(); cfg.options_mut().catalog.has_header = true; let session_state = SessionStateBuilder::new() .with_config(cfg) - .with_runtime_env(runtime) .with_default_features() .build(); let integration = LocalFileSystem::new_with_prefix(arrow_test_data()).unwrap(); let path = Path::from("csv/aggregate_test_100.csv"); let csv = CsvFormat::default().with_has_header(true); - let records_to_read = csv.options().schema_infer_max_rec; + let records_to_read = csv + .options() + .schema_infer_max_rec + .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD); let store = Arc::new(integration) as Arc; let original_stream = store.get(&path).await?; @@ -968,7 +1148,7 @@ mod tests { limit: Option, has_header: bool, ) -> Result> { - let root = format!("{}/csv", crate::test_util::arrow_test_data()); + let root = format!("{}/csv", arrow_test_data()); let format = 
CsvFormat::default().with_has_header(has_header); scan_format(state, &format, &root, file_name, projection, limit).await } @@ -1029,7 +1209,7 @@ mod tests { ) -> Result { let df = ctx.sql(&format!("EXPLAIN {sql}")).await?; let result = df.collect().await?; - let plan = format!("{}", &pretty::pretty_format_batches(&result)?); + let plan = format!("{}", &pretty_format_batches(&result)?); let re = Regex::new(r"CsvExec: file_groups=\{(\d+) group").unwrap(); @@ -1147,18 +1327,13 @@ mod tests { Ok(()) } - /// Read a single empty csv file in parallel + /// Read a single empty csv file /// /// empty_0_byte.csv: /// (file is empty) - #[rstest(n_partitions, case(1), case(2), case(3), case(4))] #[tokio::test] - async fn test_csv_parallel_empty_file(n_partitions: usize) -> Result<()> { - let config = SessionConfig::new() - .with_repartition_file_scans(true) - .with_repartition_file_min_size(0) - .with_target_partitions(n_partitions); - let ctx = SessionContext::new_with_config(config); + async fn test_csv_empty_file() -> Result<()> { + let ctx = SessionContext::new(); ctx.register_csv( "empty", "tests/data/empty_0_byte.csv", @@ -1166,32 +1341,24 @@ mod tests { ) .await?; - // Require a predicate to enable repartition for the optimizer let query = "select * from empty where random() > 0.5;"; let query_result = ctx.sql(query).await?.collect().await?; - let actual_partitions = count_query_csv_partitions(&ctx, query).await?; #[rustfmt::skip] let expected = ["++", "++"]; assert_batches_eq!(expected, &query_result); - assert_eq!(1, actual_partitions); // Won't get partitioned if all files are empty Ok(()) } - /// Read a single empty csv file with header in parallel + /// Read a single empty csv file with header /// /// empty.csv: /// c1,c2,c3 - #[rstest(n_partitions, case(1), case(2), case(3))] #[tokio::test] - async fn test_csv_parallel_empty_with_header(n_partitions: usize) -> Result<()> { - let config = SessionConfig::new() - .with_repartition_file_scans(true) - .with_repartition_file_min_size(0) - .with_target_partitions(n_partitions); - let ctx = SessionContext::new_with_config(config); + async fn test_csv_empty_with_header() -> Result<()> { + let ctx = SessionContext::new(); ctx.register_csv( "empty", "tests/data/empty.csv", @@ -1199,21 +1366,18 @@ mod tests { ) .await?; - // Require a predicate to enable repartition for the optimizer let query = "select * from empty where random() > 0.5;"; let query_result = ctx.sql(query).await?.collect().await?; - let actual_partitions = count_query_csv_partitions(&ctx, query).await?; #[rustfmt::skip] let expected = ["++", "++"]; assert_batches_eq!(expected, &query_result); - assert_eq!(n_partitions, actual_partitions); Ok(()) } - /// Read multiple empty csv files in parallel + /// Read multiple empty csv files /// /// all_empty /// ├── empty0.csv @@ -1222,13 +1386,13 @@ mod tests { /// /// empty0.csv/empty1.csv/empty2.csv: /// (file is empty) - #[rstest(n_partitions, case(1), case(2), case(3), case(4))] #[tokio::test] - async fn test_csv_parallel_multiple_empty_files(n_partitions: usize) -> Result<()> { + async fn test_csv_multiple_empty_files() -> Result<()> { + // Testing that partitioning doesn't break with empty files let config = SessionConfig::new() .with_repartition_file_scans(true) .with_repartition_file_min_size(0) - .with_target_partitions(n_partitions); + .with_target_partitions(4); let ctx = SessionContext::new_with_config(config); let file_format = Arc::new(CsvFormat::default().with_has_header(false)); let listing_options = 
ListingOptions::new(file_format.clone()) @@ -1246,13 +1410,11 @@ mod tests { // Require a predicate to enable repartition for the optimizer let query = "select * from empty where random() > 0.5;"; let query_result = ctx.sql(query).await?.collect().await?; - let actual_partitions = count_query_csv_partitions(&ctx, query).await?; #[rustfmt::skip] let expected = ["++", "++"]; assert_batches_eq!(expected, &query_result); - assert_eq!(1, actual_partitions); // Won't get partitioned if all files are empty Ok(()) } @@ -1396,4 +1558,180 @@ mod tests { Ok(()) } + + #[rstest] + fn test_csv_deserializer_with_finish( + #[values(1, 5, 17)] batch_size: usize, + #[values(0, 5, 93)] line_count: usize, + ) -> Result<()> { + let schema = csv_schema(); + let generator = CsvBatchGenerator::new(batch_size, line_count); + let mut deserializer = csv_deserializer(batch_size, &schema); + + for data in generator { + deserializer.digest(data); + } + deserializer.finish(); + + let batch_count = line_count.div_ceil(batch_size); + + let mut all_batches = RecordBatch::new_empty(schema.clone()); + for _ in 0..batch_count { + let output = deserializer.next()?; + let DeserializerOutput::RecordBatch(batch) = output else { + panic!("Expected RecordBatch, got {:?}", output); + }; + all_batches = concat_batches(&schema, &[all_batches, batch])?; + } + assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted); + + let expected = csv_expected_batch(schema, line_count)?; + + assert_eq!( + expected.clone(), + all_batches.clone(), + "Expected:\n{}\nActual:\n{}", + pretty_format_batches(&[expected])?, + pretty_format_batches(&[all_batches])?, + ); + + Ok(()) + } + + #[rstest] + fn test_csv_deserializer_without_finish( + #[values(1, 5, 17)] batch_size: usize, + #[values(0, 5, 93)] line_count: usize, + ) -> Result<()> { + let schema = csv_schema(); + let generator = CsvBatchGenerator::new(batch_size, line_count); + let mut deserializer = csv_deserializer(batch_size, &schema); + + for data in generator { + deserializer.digest(data); + } + + let batch_count = line_count / batch_size; + + let mut all_batches = RecordBatch::new_empty(schema.clone()); + for _ in 0..batch_count { + let output = deserializer.next()?; + let DeserializerOutput::RecordBatch(batch) = output else { + panic!("Expected RecordBatch, got {:?}", output); + }; + all_batches = concat_batches(&schema, &[all_batches, batch])?; + } + assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData); + + let expected = csv_expected_batch(schema, batch_count * batch_size)?; + + assert_eq!( + expected.clone(), + all_batches.clone(), + "Expected:\n{}\nActual:\n{}", + pretty_format_batches(&[expected])?, + pretty_format_batches(&[all_batches])?, + ); + + Ok(()) + } + + struct CsvBatchGenerator { + batch_size: usize, + line_count: usize, + offset: usize, + } + + impl CsvBatchGenerator { + fn new(batch_size: usize, line_count: usize) -> Self { + Self { + batch_size, + line_count, + offset: 0, + } + } + } + + impl Iterator for CsvBatchGenerator { + type Item = Bytes; + + fn next(&mut self) -> Option { + // Return `batch_size` rows per batch: + let mut buffer = Vec::new(); + for _ in 0..self.batch_size { + if self.offset >= self.line_count { + break; + } + buffer.extend_from_slice(&csv_line(self.offset)); + self.offset += 1; + } + + (!buffer.is_empty()).then(|| buffer.into()) + } + } + + fn csv_expected_batch( + schema: SchemaRef, + line_count: usize, + ) -> Result { + let mut c1 = Vec::with_capacity(line_count); + let mut c2 = Vec::with_capacity(line_count); + let 
mut c3 = Vec::with_capacity(line_count); + let mut c4 = Vec::with_capacity(line_count); + + for i in 0..line_count { + let (int_value, float_value, bool_value, char_value) = csv_values(i); + c1.push(int_value); + c2.push(float_value); + c3.push(bool_value); + c4.push(char_value); + } + + let expected = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(c1)), + Arc::new(Float64Array::from(c2)), + Arc::new(BooleanArray::from(c3)), + Arc::new(StringArray::from(c4)), + ], + )?; + Ok(expected) + } + + fn csv_line(line_number: usize) -> Bytes { + let (int_value, float_value, bool_value, char_value) = csv_values(line_number); + format!( + "{},{},{},{}\n", + int_value, float_value, bool_value, char_value + ) + .into() + } + + fn csv_values(line_number: usize) -> (i32, f64, bool, String) { + let int_value = line_number as i32; + let float_value = line_number as f64; + let bool_value = line_number % 2 == 0; + let char_value = format!("{}-string", line_number); + (int_value, float_value, bool_value, char_value) + } + + fn csv_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::Boolean, true), + Field::new("c4", DataType::Utf8, true), + ])) + } + + fn csv_deserializer( + batch_size: usize, + schema: &Arc, + ) -> impl BatchDeserializer { + let decoder = ReaderBuilder::new(schema.clone()) + .with_batch_size(batch_size) + .build_decoder(); + DecoderDeserializer::new(CsvDecoder::new(decoder)) + } } diff --git a/datafusion/core/src/datasource/file_format/file_compression_type.rs b/datafusion/core/src/datasource/file_format/file_compression_type.rs index a054094822d01..6612de077988d 100644 --- a/datafusion/core/src/datasource/file_format/file_compression_type.rs +++ b/datafusion/core/src/datasource/file_format/file_compression_type.rs @@ -123,10 +123,10 @@ impl FileCompressionType { } /// Given a `Stream`, create a `Stream` which data are compressed with `FileCompressionType`. - pub fn convert_to_compress_stream( + pub fn convert_to_compress_stream<'a>( &self, - s: BoxStream<'static, Result>, - ) -> Result>> { + s: BoxStream<'a, Result>, + ) -> Result>> { Ok(match self.variant { #[cfg(feature = "compression")] GZIP => ReaderStream::new(AsyncGzEncoder::new(StreamReader::new(s))) @@ -180,10 +180,10 @@ impl FileCompressionType { } /// Given a `Stream`, create a `Stream` which data are decompressed with `FileCompressionType`. 
- pub fn convert_stream( + pub fn convert_stream<'a>( &self, - s: BoxStream<'static, Result>, - ) -> Result>> { + s: BoxStream<'a, Result>, + ) -> Result>> { Ok(match self.variant { #[cfg(feature = "compression")] GZIP => { diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index c9ed0c0d28059..5bffb7e582c1d 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -24,11 +24,15 @@ use std::fmt::Debug; use std::io::BufReader; use std::sync::Arc; -use super::write::orchestration::stateless_multipart_put; -use super::{FileFormat, FileFormatFactory, FileScanConfig}; +use super::write::orchestration::spawn_writer_tasks_and_join; +use super::{ + Decoder, DecoderDeserializer, FileFormat, FileFormatFactory, FileScanConfig, + DEFAULT_SCHEMA_INFER_MAX_RECORD, +}; use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::write::demux::DemuxedStreamReceiver; use crate::datasource::file_format::write::BatchSerializer; -use crate::datasource::physical_plan::FileGroupDisplay; +use crate::datasource::physical_plan::{FileGroupDisplay, FileSink}; use crate::datasource::physical_plan::{FileSinkConfig, NdJsonExec}; use crate::error::Result; use crate::execution::context::SessionState; @@ -42,13 +46,14 @@ use arrow::datatypes::SchemaRef; use arrow::json; use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; use arrow_array::RecordBatch; +use arrow_schema::ArrowError; use datafusion_common::config::{ConfigField, ConfigFileType, JsonOptions}; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::{not_impl_err, GetExt, DEFAULT_JSON_EXTENSION}; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::ExecutionPlan; use async_trait::async_trait; @@ -118,7 +123,7 @@ impl GetExt for JsonFormatFactory { } } -impl fmt::Debug for JsonFormatFactory { +impl Debug for JsonFormatFactory { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("JsonFormatFactory") .field("options", &self.options) @@ -147,7 +152,7 @@ impl JsonFormat { /// Set a limit in terms of records to scan to infer the schema /// - defaults to `DEFAULT_SCHEMA_INFER_MAX_RECORD` pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self { - self.options.schema_infer_max_rec = max_rec; + self.options.schema_infer_max_rec = Some(max_rec); self } @@ -187,7 +192,10 @@ impl FileFormat for JsonFormat { objects: &[ObjectMeta], ) -> Result { let mut schemas = Vec::new(); - let mut records_to_read = self.options.schema_infer_max_rec; + let mut records_to_read = self + .options + .schema_infer_max_rec + .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD); let file_compression_type = FileCompressionType::from(self.options.compression); for object in objects { let mut take_while = || { @@ -259,15 +267,9 @@ impl FileFormat for JsonFormat { let writer_options = JsonWriterOptions::try_from(&self.options)?; - let sink_schema = conf.output_schema().clone(); let sink = Arc::new(JsonSink::new(conf, writer_options)); - Ok(Arc::new(DataSinkExec::new( - input, - sink, - sink_schema, - order_requirements, - )) as _) + Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) } } @@ -331,32 +333,36 @@ impl JsonSink { } } 
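A minimal usage sketch of the JSON schema-inference change above (illustrative only, not part of the patch; the `JsonFormat` import path is assumed): since `schema_infer_max_rec` is now an `Option`, leaving it unset makes `infer_schema` fall back to `DEFAULT_SCHEMA_INFER_MAX_RECORD`, while callers can still cap it explicitly.

use datafusion::datasource::file_format::json::JsonFormat;

// Inspect at most 100 records when inferring the NDJSON schema; if
// `with_schema_infer_max_rec` is never called, the option stays `None`
// and the default limit applies.
let format = JsonFormat::default().with_schema_infer_max_rec(100);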
- /// Retrieve the inner [`FileSinkConfig`]. - pub fn config(&self) -> &FileSinkConfig { + /// Retrieve the writer options + pub fn writer_options(&self) -> &JsonWriterOptions { + &self.writer_options + } +} + +#[async_trait] +impl FileSink for JsonSink { + fn config(&self) -> &FileSinkConfig { &self.config } - async fn multipartput_all( + async fn spawn_writer_tasks_and_join( &self, - data: SendableRecordBatchStream, context: &Arc, + demux_task: SpawnedTask>, + file_stream_rx: DemuxedStreamReceiver, + object_store: Arc, ) -> Result { - let get_serializer = move || Arc::new(JsonSerializer::new()) as _; - - stateless_multipart_put( - data, + let serializer = Arc::new(JsonSerializer::new()) as _; + spawn_writer_tasks_and_join( context, - "json".into(), - Box::new(get_serializer), - &self.config, + serializer, self.writer_options.compression.into(), + object_store, + demux_task, + file_stream_rx, ) .await } - /// Retrieve the writer options - pub fn writer_options(&self) -> &JsonWriterOptions { - &self.writer_options - } } #[async_trait] @@ -365,8 +371,8 @@ impl DataSink for JsonSink { self } - fn metrics(&self) -> Option { - None + fn schema(&self) -> &SchemaRef { + self.config.output_schema() } async fn write_all( @@ -374,8 +380,38 @@ impl DataSink for JsonSink { data: SendableRecordBatchStream, context: &Arc, ) -> Result { - let total_count = self.multipartput_all(data, context).await?; - Ok(total_count) + FileSink::write_all(self, data, context).await + } +} + +#[derive(Debug)] +pub(crate) struct JsonDecoder { + inner: json::reader::Decoder, +} + +impl JsonDecoder { + pub(crate) fn new(decoder: json::reader::Decoder) -> Self { + Self { inner: decoder } + } +} + +impl Decoder for JsonDecoder { + fn decode(&mut self, buf: &[u8]) -> Result { + self.inner.decode(buf) + } + + fn flush(&mut self) -> Result, ArrowError> { + self.inner.flush() + } + + fn can_flush_early(&self) -> bool { + false + } +} + +impl From for DecoderDeserializer { + fn from(decoder: json::reader::Decoder) -> Self { + DecoderDeserializer::new(JsonDecoder::new(decoder)) } } @@ -383,12 +419,18 @@ impl DataSink for JsonSink { mod tests { use super::super::test_util::scan_format; use super::*; + use crate::datasource::file_format::{ + BatchDeserializer, DecoderDeserializer, DeserializerOutput, + }; use crate::execution::options::NdJsonReadOptions; use crate::physical_plan::collect; use crate::prelude::{SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; + use arrow::compute::concat_batches; + use arrow::json::ReaderBuilder; use arrow::util::pretty; + use arrow_schema::{DataType, Field}; use datafusion_common::cast::as_int64_array; use datafusion_common::stats::Precision; use datafusion_common::{assert_batches_eq, internal_err}; @@ -575,13 +617,11 @@ mod tests { Ok(()) } - #[rstest(n_partitions, case(1), case(2), case(3), case(4))] #[tokio::test] - async fn it_can_read_empty_ndjson_in_parallel(n_partitions: usize) -> Result<()> { + async fn it_can_read_empty_ndjson() -> Result<()> { let config = SessionConfig::new() .with_repartition_file_scans(true) - .with_repartition_file_min_size(0) - .with_target_partitions(n_partitions); + .with_repartition_file_min_size(0); let ctx = SessionContext::new_with_config(config); @@ -594,7 +634,6 @@ mod tests { let query = "SELECT * FROM json_parallel_empty WHERE random() > 0.5;"; let result = ctx.sql(query).await?.collect().await?; - let actual_partitions = count_num_partitions(&ctx, query).await?; #[rustfmt::skip] let expected = [ @@ -603,8 +642,100 @@ 
mod tests { ]; assert_batches_eq!(expected, &result); - assert_eq!(1, actual_partitions); Ok(()) } + + #[test] + fn test_json_deserializer_finish() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, true), + Field::new("c2", DataType::Int64, true), + Field::new("c3", DataType::Int64, true), + Field::new("c4", DataType::Int64, true), + Field::new("c5", DataType::Int64, true), + ])); + let mut deserializer = json_deserializer(1, &schema)?; + + deserializer.digest(r#"{ "c1": 1, "c2": 2, "c3": 3, "c4": 4, "c5": 5 }"#.into()); + deserializer.digest(r#"{ "c1": 6, "c2": 7, "c3": 8, "c4": 9, "c5": 10 }"#.into()); + deserializer + .digest(r#"{ "c1": 11, "c2": 12, "c3": 13, "c4": 14, "c5": 15 }"#.into()); + deserializer.finish(); + + let mut all_batches = RecordBatch::new_empty(schema.clone()); + for _ in 0..3 { + let output = deserializer.next()?; + let DeserializerOutput::RecordBatch(batch) = output else { + panic!("Expected RecordBatch, got {:?}", output); + }; + all_batches = concat_batches(&schema, &[all_batches, batch])? + } + assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted); + + let expected = [ + "+----+----+----+----+----+", + "| c1 | c2 | c3 | c4 | c5 |", + "+----+----+----+----+----+", + "| 1 | 2 | 3 | 4 | 5 |", + "| 6 | 7 | 8 | 9 | 10 |", + "| 11 | 12 | 13 | 14 | 15 |", + "+----+----+----+----+----+", + ]; + + assert_batches_eq!(expected, &[all_batches]); + + Ok(()) + } + + #[test] + fn test_json_deserializer_no_finish() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, true), + Field::new("c2", DataType::Int64, true), + Field::new("c3", DataType::Int64, true), + Field::new("c4", DataType::Int64, true), + Field::new("c5", DataType::Int64, true), + ])); + let mut deserializer = json_deserializer(1, &schema)?; + + deserializer.digest(r#"{ "c1": 1, "c2": 2, "c3": 3, "c4": 4, "c5": 5 }"#.into()); + deserializer.digest(r#"{ "c1": 6, "c2": 7, "c3": 8, "c4": 9, "c5": 10 }"#.into()); + deserializer + .digest(r#"{ "c1": 11, "c2": 12, "c3": 13, "c4": 14, "c5": 15 }"#.into()); + + let mut all_batches = RecordBatch::new_empty(schema.clone()); + // We get RequiresMoreData after 2 batches because of how json::Decoder works + for _ in 0..2 { + let output = deserializer.next()?; + let DeserializerOutput::RecordBatch(batch) = output else { + panic!("Expected RecordBatch, got {:?}", output); + }; + all_batches = concat_batches(&schema, &[all_batches, batch])? 
+ } + assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData); + + let expected = [ + "+----+----+----+----+----+", + "| c1 | c2 | c3 | c4 | c5 |", + "+----+----+----+----+----+", + "| 1 | 2 | 3 | 4 | 5 |", + "| 6 | 7 | 8 | 9 | 10 |", + "+----+----+----+----+----+", + ]; + + assert_batches_eq!(expected, &[all_batches]); + + Ok(()) + } + + fn json_deserializer( + batch_size: usize, + schema: &Arc, + ) -> Result> { + let decoder = ReaderBuilder::new(schema.clone()) + .with_batch_size(batch_size) + .build_decoder()?; + Ok(DecoderDeserializer::new(JsonDecoder::new(decoder))) + } } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index e16986c660adf..f47e2107ade6f 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -32,9 +32,10 @@ pub mod parquet; pub mod write; use std::any::Any; -use std::collections::HashMap; -use std::fmt::{self, Display}; +use std::collections::{HashMap, VecDeque}; +use std::fmt::{self, Debug, Display}; use std::sync::Arc; +use std::task::Poll; use crate::arrow::datatypes::SchemaRef; use crate::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; @@ -42,17 +43,20 @@ use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::{ExecutionPlan, Statistics}; -use arrow_schema::{DataType, Field, Schema}; +use arrow_array::RecordBatch; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{internal_err, not_impl_err, GetExt}; use datafusion_expr::Expr; use datafusion_physical_expr::PhysicalExpr; use async_trait::async_trait; +use bytes::{Buf, Bytes}; use datafusion_physical_expr_common::sort_expr::LexRequirement; use file_compression_type::FileCompressionType; +use futures::stream::BoxStream; +use futures::{ready, Stream, StreamExt}; use object_store::{ObjectMeta, ObjectStore}; -use std::fmt::Debug; /// Factory for creating [`FileFormat`] instances based on session and command level options /// @@ -79,7 +83,7 @@ pub trait FileFormatFactory: Sync + Send + GetExt + Debug { /// /// [`TableProvider`]: crate::catalog::TableProvider #[async_trait] -pub trait FileFormat: Send + Sync + fmt::Debug { +pub trait FileFormat: Send + Sync + Debug { /// Returns the table provider as [`Any`](std::any::Any) so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -168,6 +172,165 @@ pub enum FilePushdownSupport { Supported, } +/// Possible outputs of a [`BatchDeserializer`]. +#[derive(Debug, PartialEq)] +pub enum DeserializerOutput { + /// A successfully deserialized [`RecordBatch`]. + RecordBatch(RecordBatch), + /// The deserializer requires more data to make progress. + RequiresMoreData, + /// The input data has been exhausted. + InputExhausted, +} + +/// Trait defining a scheme for deserializing byte streams into structured data. +/// Implementors of this trait are responsible for converting raw bytes into +/// `RecordBatch` objects. +pub trait BatchDeserializer: Send + Debug { + /// Feeds a message for deserialization, updating the internal state of + /// this `BatchDeserializer`. Note that one can call this function multiple + /// times before calling `next`, which will queue multiple messages for + /// deserialization. Returns the number of bytes consumed. 
+ fn digest(&mut self, message: T) -> usize; + + /// Attempts to deserialize any pending messages and returns a + /// `DeserializerOutput` to indicate progress. + fn next(&mut self) -> Result; + + /// Informs the deserializer that no more messages will be provided for + /// deserialization. + fn finish(&mut self); +} + +/// A general interface for decoders such as [`arrow::json::reader::Decoder`] and +/// [`arrow::csv::reader::Decoder`]. Defines an interface similar to +/// [`Decoder::decode`] and [`Decoder::flush`] methods, but also includes +/// a method to check if the decoder can flush early. Intended to be used in +/// conjunction with [`DecoderDeserializer`]. +/// +/// [`arrow::json::reader::Decoder`]: ::arrow::json::reader::Decoder +/// [`arrow::csv::reader::Decoder`]: ::arrow::csv::reader::Decoder +/// [`Decoder::decode`]: ::arrow::json::reader::Decoder::decode +/// [`Decoder::flush`]: ::arrow::json::reader::Decoder::flush +pub(crate) trait Decoder: Send + Debug { + /// See [`arrow::json::reader::Decoder::decode`]. + /// + /// [`arrow::json::reader::Decoder::decode`]: ::arrow::json::reader::Decoder::decode + fn decode(&mut self, buf: &[u8]) -> Result; + + /// See [`arrow::json::reader::Decoder::flush`]. + /// + /// [`arrow::json::reader::Decoder::flush`]: ::arrow::json::reader::Decoder::flush + fn flush(&mut self) -> Result, ArrowError>; + + /// Whether the decoder can flush early in its current state. + fn can_flush_early(&self) -> bool; +} + +impl Debug for DecoderDeserializer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Deserializer") + .field("buffered_queue", &self.buffered_queue) + .field("finalized", &self.finalized) + .finish() + } +} + +impl BatchDeserializer for DecoderDeserializer { + fn digest(&mut self, message: Bytes) -> usize { + if message.is_empty() { + return 0; + } + + let consumed = message.len(); + self.buffered_queue.push_back(message); + consumed + } + + fn next(&mut self) -> Result { + while let Some(buffered) = self.buffered_queue.front_mut() { + let decoded = self.decoder.decode(buffered)?; + buffered.advance(decoded); + + if buffered.is_empty() { + self.buffered_queue.pop_front(); + } + + // Flush when the stream ends or batch size is reached + // Certain implementations can flush early + if decoded == 0 || self.decoder.can_flush_early() { + return match self.decoder.flush() { + Ok(Some(batch)) => Ok(DeserializerOutput::RecordBatch(batch)), + Ok(None) => continue, + Err(e) => Err(e), + }; + } + } + if self.finalized { + Ok(DeserializerOutput::InputExhausted) + } else { + Ok(DeserializerOutput::RequiresMoreData) + } + } + + fn finish(&mut self) { + self.finalized = true; + // Ensure the decoder is flushed: + self.buffered_queue.push_back(Bytes::new()); + } +} + +/// A generic, decoder-based deserialization scheme for processing encoded data. +/// +/// This struct is responsible for converting a stream of bytes, which represent +/// encoded data, into a stream of `RecordBatch` objects, following the specified +/// schema and formatting options. It also handles any buffering necessary to satisfy +/// the `Decoder` interface. 
+pub(crate) struct DecoderDeserializer { + /// The underlying decoder used for deserialization + pub(crate) decoder: T, + /// The buffer used to store the remaining bytes to be decoded + pub(crate) buffered_queue: VecDeque, + /// Whether the input stream has been fully consumed + pub(crate) finalized: bool, +} + +impl DecoderDeserializer { + /// Creates a new `DecoderDeserializer` with the provided decoder. + pub(crate) fn new(decoder: T) -> Self { + DecoderDeserializer { + decoder, + buffered_queue: VecDeque::new(), + finalized: false, + } + } +} + +/// Deserializes a stream of bytes into a stream of [`RecordBatch`] objects using the +/// provided deserializer. +/// +/// Returns a boxed stream of `Result`. The stream yields [`RecordBatch`] +/// objects as they are produced by the deserializer, or an [`ArrowError`] if an error +/// occurs while polling the input or deserializing. +pub(crate) fn deserialize_stream<'a>( + mut input: impl Stream> + Unpin + Send + 'a, + mut deserializer: impl BatchDeserializer + 'a, +) -> BoxStream<'a, Result> { + futures::stream::poll_fn(move |cx| loop { + match ready!(input.poll_next_unpin(cx)).transpose()? { + Some(b) => _ = deserializer.digest(b), + None => deserializer.finish(), + }; + + return match deserializer.next()? { + DeserializerOutput::RecordBatch(rb) => Poll::Ready(Some(Ok(rb))), + DeserializerOutput::InputExhausted => Poll::Ready(None), + DeserializerOutput::RequiresMoreData => continue, + }; + }) + .boxed() +} + /// A container of [FileFormatFactory] which also implements [FileType]. /// This enables converting a dyn FileFormat to a dyn FileType. /// The former trait is a superset of the latter trait, which includes execution time @@ -224,32 +387,38 @@ pub fn format_as_file_type( /// downcasted to a [DefaultFileType]. 
pub fn file_type_to_format( file_type: &Arc, -) -> datafusion_common::Result> { +) -> Result> { match file_type .as_ref() .as_any() .downcast_ref::() { - Some(source) => Ok(source.file_format_factory.clone()), + Some(source) => Ok(Arc::clone(&source.file_format_factory)), _ => internal_err!("FileType was not DefaultFileType"), } } +/// Create a new field with the specified data type, copying the other +/// properties from the input field +fn field_with_new_type(field: &FieldRef, new_type: DataType) -> FieldRef { + Arc::new(field.as_ref().clone().with_data_type(new_type)) +} + /// Transform a schema to use view types for Utf8 and Binary +/// +/// See [parquet::ParquetFormat::force_view_types] for details pub fn transform_schema_to_view(schema: &Schema) -> Schema { let transformed_fields: Vec> = schema .fields .iter() .map(|field| match field.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => Arc::new( - Field::new(field.name(), DataType::Utf8View, field.is_nullable()) - .with_metadata(field.metadata().to_owned()), - ), - DataType::Binary | DataType::LargeBinary => Arc::new( - Field::new(field.name(), DataType::BinaryView, field.is_nullable()) - .with_metadata(field.metadata().to_owned()), - ), - _ => field.clone(), + DataType::Utf8 | DataType::LargeUtf8 => { + field_with_new_type(field, DataType::Utf8View) + } + DataType::Binary | DataType::LargeBinary => { + field_with_new_type(field, DataType::BinaryView) + } + _ => Arc::clone(field), }) .collect(); Schema::new_with_metadata(transformed_fields, schema.metadata.clone()) @@ -274,6 +443,7 @@ pub(crate) fn coerce_file_schema_to_view_type( (f.name(), dt) }) .collect(); + if !transform { return None; } @@ -283,15 +453,14 @@ pub(crate) fn coerce_file_schema_to_view_type( .iter() .map( |field| match (table_fields.get(field.name()), field.data_type()) { - (Some(DataType::Utf8View), DataType::Utf8) - | (Some(DataType::Utf8View), DataType::LargeUtf8) => Arc::new( - Field::new(field.name(), DataType::Utf8View, field.is_nullable()), - ), - (Some(DataType::BinaryView), DataType::Binary) - | (Some(DataType::BinaryView), DataType::LargeBinary) => Arc::new( - Field::new(field.name(), DataType::BinaryView, field.is_nullable()), - ), - _ => field.clone(), + (Some(DataType::Utf8View), DataType::Utf8 | DataType::LargeUtf8) => { + field_with_new_type(field, DataType::Utf8View) + } + ( + Some(DataType::BinaryView), + DataType::Binary | DataType::LargeBinary, + ) => field_with_new_type(field, DataType::BinaryView), + _ => Arc::clone(field), }, ) .collect(); @@ -302,6 +471,78 @@ pub(crate) fn coerce_file_schema_to_view_type( )) } +/// Transform a schema so that any binary types are strings +pub fn transform_binary_to_string(schema: &Schema) -> Schema { + let transformed_fields: Vec> = schema + .fields + .iter() + .map(|field| match field.data_type() { + DataType::Binary => field_with_new_type(field, DataType::Utf8), + DataType::LargeBinary => field_with_new_type(field, DataType::LargeUtf8), + DataType::BinaryView => field_with_new_type(field, DataType::Utf8View), + _ => Arc::clone(field), + }) + .collect(); + Schema::new_with_metadata(transformed_fields, schema.metadata.clone()) +} + +/// If the table schema uses a string type, coerce the file schema to use a string type. 
+/// +/// See [parquet::ParquetFormat::binary_as_string] for details +pub(crate) fn coerce_file_schema_to_string_type( + table_schema: &Schema, + file_schema: &Schema, +) -> Option { + let mut transform = false; + let table_fields: HashMap<_, _> = table_schema + .fields + .iter() + .map(|f| (f.name(), f.data_type())) + .collect(); + let transformed_fields: Vec> = file_schema + .fields + .iter() + .map( + |field| match (table_fields.get(field.name()), field.data_type()) { + // table schema uses string type, coerce the file schema to use string type + ( + Some(DataType::Utf8), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::Utf8) + } + // table schema uses large string type, coerce the file schema to use large string type + ( + Some(DataType::LargeUtf8), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::LargeUtf8) + } + // table schema uses string view type, coerce the file schema to use view type + ( + Some(DataType::Utf8View), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::Utf8View) + } + _ => Arc::clone(field), + }, + ) + .collect(); + + if !transform { + None + } else { + Some(Schema::new_with_metadata( + transformed_fields, + file_schema.metadata.clone(), + )) + } +} + #[cfg(test)] pub(crate) mod test_util { use std::ops::Range; @@ -332,7 +573,9 @@ pub(crate) mod test_util { let store = Arc::new(LocalFileSystem::new()) as _; let meta = local_unpartitioned_file(format!("{store_root}/{file_name}")); - let file_schema = format.infer_schema(state, &store, &[meta.clone()]).await?; + let file_schema = format + .infer_schema(state, &store, std::slice::from_ref(&meta)) + .await?; let statistics = format .infer_stats(state, &store, file_schema.clone(), &meta) @@ -344,6 +587,7 @@ pub(crate) mod test_util { range: None, statistics: None, extensions: None, + metadata_size_hint: None, }]]; let exec = format @@ -369,8 +613,8 @@ pub(crate) mod test_util { iterations_detected: Arc>, } - impl std::fmt::Display for VariableStream { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl Display for VariableStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "VariableStream") } } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 1e0e28ef88cb2..2f32479ed2b00 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -87,9 +87,11 @@ pub struct CsvReadOptions<'a> { pub file_compression_type: FileCompressionType, /// Indicates how the file is sorted pub file_sort_order: Vec>, + /// Optional regex to match null values + pub null_regex: Option, } -impl<'a> Default for CsvReadOptions<'a> { +impl Default for CsvReadOptions<'_> { fn default() -> Self { Self::new() } @@ -112,6 +114,7 @@ impl<'a> CsvReadOptions<'a> { file_compression_type: FileCompressionType::UNCOMPRESSED, file_sort_order: vec![], comment: None, + null_regex: None, } } @@ -212,6 +215,12 @@ impl<'a> CsvReadOptions<'a> { self.file_sort_order = file_sort_order; self } + + /// Configure the null parsing regex. + pub fn null_regex(mut self, null_regex: Option) -> Self { + self.null_regex = null_regex; + self + } } /// Options that control the reading of Parquet files. 
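A short usage sketch for the new `null_regex` CSV read option (illustrative only; the file path is a placeholder and the snippet assumes an async context that returns a `Result`): the pattern below treats empty fields and the literal `NULL` as SQL NULLs, mirroring the `infer_schema_with_null_regex` test earlier in this diff.

use datafusion::prelude::{CsvReadOptions, SessionContext};

// Read a CSV file, mapping "" and "NULL" values to SQL NULL during parsing.
let ctx = SessionContext::new();
let options = CsvReadOptions::new()
    .has_header(true)
    .null_regex(Some("^NULL$|^$".to_string()));
let df = ctx.read_csv("data.csv", options).await?;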
@@ -243,7 +252,7 @@ pub struct ParquetReadOptions<'a> { pub file_sort_order: Vec>, } -impl<'a> Default for ParquetReadOptions<'a> { +impl Default for ParquetReadOptions<'_> { fn default() -> Self { Self { file_extension: DEFAULT_PARQUET_EXTENSION, @@ -262,6 +271,12 @@ impl<'a> ParquetReadOptions<'a> { Default::default() } + /// Specify file_extension + pub fn file_extension(mut self, file_extension: &'a str) -> Self { + self.file_extension = file_extension; + self + } + /// Specify parquet_pruning pub fn parquet_pruning(mut self, parquet_pruning: bool) -> Self { self.parquet_pruning = Some(parquet_pruning); @@ -317,7 +332,7 @@ pub struct ArrowReadOptions<'a> { pub table_partition_cols: Vec<(String, DataType)>, } -impl<'a> Default for ArrowReadOptions<'a> { +impl Default for ArrowReadOptions<'_> { fn default() -> Self { Self { schema: None, @@ -362,7 +377,7 @@ pub struct AvroReadOptions<'a> { pub table_partition_cols: Vec<(String, DataType)>, } -impl<'a> Default for AvroReadOptions<'a> { +impl Default for AvroReadOptions<'_> { fn default() -> Self { Self { schema: None, @@ -414,7 +429,7 @@ pub struct NdJsonReadOptions<'a> { pub file_sort_order: Vec>, } -impl<'a> Default for NdJsonReadOptions<'a> { +impl Default for NdJsonReadOptions<'_> { fn default() -> Self { Self { schema: None, @@ -528,7 +543,8 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { .with_terminator(self.terminator) .with_newlines_in_values(self.newlines_in_values) .with_schema_infer_max_rec(self.schema_infer_max_records) - .with_file_compression_type(self.file_compression_type.to_owned()); + .with_file_compression_type(self.file_compression_type.to_owned()) + .with_null_regex(self.null_regex.clone()); ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 8647b5df90be1..4c7169764a769 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -23,16 +23,21 @@ use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; -use super::write::demux::start_demuxer_task; +use super::write::demux::DemuxedStreamReceiver; use super::write::{create_writer, SharedBuffer}; use super::{ - coerce_file_schema_to_view_type, transform_schema_to_view, FileFormat, - FileFormatFactory, FilePushdownSupport, FileScanConfig, + coerce_file_schema_to_string_type, coerce_file_schema_to_view_type, + transform_binary_to_string, transform_schema_to_view, FileFormat, FileFormatFactory, + FilePushdownSupport, FileScanConfig, }; use crate::arrow::array::RecordBatch; use crate::arrow::datatypes::{Fields, Schema, SchemaRef}; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::physical_plan::{FileGroupDisplay, FileSinkConfig}; +use crate::datasource::file_format::write::get_writer_schema; +use crate::datasource::physical_plan::parquet::{ + can_expr_be_pushed_down_with_schemas, ParquetExecBuilder, +}; +use crate::datasource::physical_plan::{FileGroupDisplay, FileSink, FileSinkConfig}; use crate::datasource::statistics::{create_max_min_accs, get_col_stats}; use crate::error::Result; use crate::execution::context::SessionState; @@ -44,11 +49,11 @@ use crate::physical_plan::{ use arrow::compute::sum; use datafusion_common::config::{ConfigField, ConfigFileType, TableParquetOptions}; -use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use 
datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; +use datafusion_common::HashMap; use datafusion_common::{ - internal_datafusion_err, not_impl_err, DataFusionError, GetExt, + internal_datafusion_err, internal_err, not_impl_err, DataFusionError, GetExt, DEFAULT_PARQUET_EXTENSION, }; use datafusion_common_runtime::SpawnedTask; @@ -58,40 +63,32 @@ use datafusion_expr::dml::InsertOp; use datafusion_expr::Expr; use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_expr_common::sort_expr::LexRequirement; use async_trait::async_trait; use bytes::Bytes; -use hashbrown::HashMap; +use futures::future::BoxFuture; +use futures::{FutureExt, StreamExt, TryStreamExt}; use log::debug; use object_store::buffered::BufWriter; +use object_store::path::Path; +use object_store::{ObjectMeta, ObjectStore}; +use parquet::arrow::arrow_reader::statistics::StatisticsConverter; use parquet::arrow::arrow_writer::{ compute_leaves, get_column_writers, ArrowColumnChunk, ArrowColumnWriter, - ArrowLeafColumn, -}; -use parquet::arrow::{ - arrow_to_parquet_schema, parquet_to_arrow_schema, AsyncArrowWriter, + ArrowLeafColumn, ArrowWriterOptions, }; +use parquet::arrow::async_reader::MetadataFetch; +use parquet::arrow::{parquet_to_arrow_schema, ArrowSchemaConverter, AsyncArrowWriter}; +use parquet::errors::ParquetError; use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataReader, RowGroupMetaData}; -use parquet::file::properties::WriterProperties; +use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder}; use parquet::file::writer::SerializedFileWriter; use parquet::format::FileMetaData; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::task::JoinSet; -use crate::datasource::physical_plan::parquet::{ - can_expr_be_pushed_down_with_schemas, ParquetExecBuilder, -}; -use datafusion_physical_expr_common::sort_expr::LexRequirement; -use futures::future::BoxFuture; -use futures::{FutureExt, StreamExt, TryStreamExt}; -use object_store::path::Path; -use object_store::{ObjectMeta, ObjectStore}; -use parquet::arrow::arrow_reader::statistics::StatisticsConverter; -use parquet::arrow::async_reader::MetadataFetch; -use parquet::errors::ParquetError; - /// Initial writing buffer size. Note this is just a size hint for efficiency. It /// will grow beyond the set value if needed. const INITIAL_BUFFER_BYTES: usize = 1048576; @@ -164,7 +161,7 @@ impl GetExt for ParquetFormatFactory { } } -impl fmt::Debug for ParquetFormatFactory { +impl Debug for ParquetFormatFactory { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ParquetFormatFactory") .field("ParquetFormatFactory", &self.options) @@ -253,13 +250,29 @@ impl ParquetFormat { self.options.global.schema_force_view_types } - /// If true, will use view types (StringView and BinaryView). - /// - /// Refer to [`Self::force_view_types`]. + /// If true, will use view types. See [`Self::force_view_types`] for details pub fn with_force_view_types(mut self, use_views: bool) -> Self { self.options.global.schema_force_view_types = use_views; self } + + /// Return `true` if binary types will be read as strings. 
+ /// + /// If this returns true, DataFusion will instruct the parquet reader + /// to read binary columns such as `Binary` or `BinaryView` as the + /// corresponding string type such as `Utf8` or `LargeUtf8`. + /// The parquet reader has special optimizations for `Utf8` and `LargeUtf8` + /// validation, and such queries are significantly faster than reading + /// binary columns and then casting to string columns. + pub fn binary_as_string(&self) -> bool { + self.options.global.binary_as_string + } + + /// If true, will read binary types as strings. See [`Self::binary_as_string`] for details + pub fn with_binary_as_string(mut self, binary_as_string: bool) -> Self { + self.options.global.binary_as_string = binary_as_string; + self + } } /// Clears all metadata (Schema level and field level) on an iterator @@ -306,9 +319,7 @@ impl FileFormat for ParquetFormat { let ext = self.get_ext(); match file_compression_type.get_variant() { CompressionTypeVariant::UNCOMPRESSED => Ok(ext), - _ => Err(DataFusionError::Internal( - "Parquet FileFormat does not support compression.".into(), - )), + _ => internal_err!("Parquet FileFormat does not support compression."), } } @@ -350,6 +361,12 @@ impl FileFormat for ParquetFormat { Schema::try_merge(schemas) }?; + let schema = if self.binary_as_string() { + transform_binary_to_string(&schema) + } else { + schema + }; + let schema = if self.force_view_types() { transform_schema_to_view(&schema) } else { @@ -411,15 +428,9 @@ impl FileFormat for ParquetFormat { return not_impl_err!("Overwrites are not implemented yet for Parquet"); } - let sink_schema = conf.output_schema().clone(); let sink = Arc::new(ParquetSink::new(conf, self.options.clone())); - Ok(Arc::new(DataSinkExec::new( - input, - sink, - sink_schema, - order_requirements, - )) as _) + Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) } fn supports_filters_pushdown( @@ -456,7 +467,7 @@ impl<'a> ObjectStoreFetch<'a> { } } -impl<'a> MetadataFetch for ObjectStoreFetch<'a> { +impl MetadataFetch for ObjectStoreFetch<'_> { fn fetch( &mut self, range: Range, @@ -509,7 +520,7 @@ async fn fetch_schema( /// Read and parse the statistics of the Parquet file at location `path` /// -/// See [`statistics_from_parquet_meta`] for more details +/// See [`statistics_from_parquet_meta_calc`] for more details async fn fetch_statistics( store: &dyn ObjectStore, table_schema: SchemaRef, @@ -552,6 +563,10 @@ pub fn statistics_from_parquet_meta_calc( file_metadata.schema_descr(), file_metadata.key_value_metadata(), )?; + if let Some(merged) = coerce_file_schema_to_string_type(&table_schema, &file_schema) { + file_schema = merged; + } + if let Some(merged) = coerce_file_schema_to_view_type(&table_schema, &file_schema) { file_schema = merged; } @@ -686,42 +701,32 @@ impl ParquetSink { } } - /// Retrieve the inner [`FileSinkConfig`]. - pub fn config(&self) -> &FileSinkConfig { - &self.config - } - /// Retrieve the file metadata for the written files, keyed to the path /// which may be partitioned (in the case of hive style partitioning). pub fn written(&self) -> HashMap { self.written.lock().clone() } - /// Converts table schema to writer schema, which may differ in the case - /// of hive style partitioning where some columns are removed from the - /// underlying files. 
- fn get_writer_schema(&self) -> Arc { - if !self.config.table_partition_cols.is_empty() - && !self.config.keep_partition_by_columns - { - let schema = self.config.output_schema(); - let partition_names: Vec<_> = self - .config - .table_partition_cols - .iter() - .map(|(s, _)| s) - .collect(); - Arc::new(Schema::new( - schema - .fields() - .iter() - .filter(|f| !partition_names.contains(&f.name())) - .map(|f| (**f).clone()) - .collect::>(), - )) + /// Create writer properties based upon configuration settings, + /// including partitioning and the inclusion of arrow schema metadata. + fn create_writer_props(&self) -> Result { + let schema = if self.parquet_options.global.allow_single_file_parallelism { + // If parallelizing writes, we may be also be doing hive style partitioning + // into multiple files which impacts the schema per file. + // Refer to `get_writer_schema()` + &get_writer_schema(&self.config) } else { - self.config.output_schema().clone() + self.config.output_schema() + }; + + // TODO: avoid this clone in follow up PR, where the writer properties & schema + // are calculated once on `ParquetSink::new` + let mut parquet_opts = self.parquet_options.clone(); + if !self.parquet_options.global.skip_arrow_metadata { + parquet_opts.arrow_schema(schema); } + + Ok(WriterPropertiesBuilder::try_from(&parquet_opts)?.build()) } /// Creates an AsyncArrowWriter which serializes a parquet file to an ObjectStore @@ -733,10 +738,14 @@ impl ParquetSink { parquet_props: WriterProperties, ) -> Result> { let buf_writer = BufWriter::new(object_store, location.clone()); - let writer = AsyncArrowWriter::try_new( + let options = ArrowWriterOptions::new() + .with_properties(parquet_props) + .with_skip_arrow_metadata(self.parquet_options.global.skip_arrow_metadata); + + let writer = AsyncArrowWriter::try_new_with_options( buf_writer, - self.get_writer_schema(), - Some(parquet_props), + get_writer_schema(&self.config), + options, )?; Ok(writer) } @@ -748,36 +757,27 @@ impl ParquetSink { } #[async_trait] -impl DataSink for ParquetSink { - fn as_any(&self) -> &dyn Any { - self - } - - fn metrics(&self) -> Option { - None +impl FileSink for ParquetSink { + fn config(&self) -> &FileSinkConfig { + &self.config } - async fn write_all( + async fn spawn_writer_tasks_and_join( &self, - data: SendableRecordBatchStream, context: &Arc, + demux_task: SpawnedTask>, + mut file_stream_rx: DemuxedStreamReceiver, + object_store: Arc, ) -> Result { - let parquet_props = ParquetWriterOptions::try_from(&self.parquet_options)?; - - let object_store = context - .runtime_env() - .object_store(&self.config.object_store_url)?; - let parquet_opts = &self.parquet_options; let allow_single_file_parallelism = parquet_opts.global.allow_single_file_parallelism; - let part_col = if !self.config.table_partition_cols.is_empty() { - Some(self.config.table_partition_cols.clone()) - } else { - None - }; + let mut file_write_tasks: JoinSet< + std::result::Result<(Path, FileMetaData), DataFusionError>, + > = JoinSet::new(); + let parquet_props = self.create_writer_props()?; let parallel_options = ParallelParquetWriterOptions { max_parallel_row_groups: parquet_opts .global @@ -787,26 +787,13 @@ impl DataSink for ParquetSink { .maximum_buffered_record_batches_per_stream, }; - let (demux_task, mut file_stream_rx) = start_demuxer_task( - data, - context, - part_col, - self.config.table_paths[0].clone(), - "parquet".into(), - self.config.keep_partition_by_columns, - ); - - let mut file_write_tasks: JoinSet< - std::result::Result<(Path, 
FileMetaData), DataFusionError>, - > = JoinSet::new(); - while let Some((path, mut rx)) = file_stream_rx.recv().await { if !allow_single_file_parallelism { let mut writer = self .create_async_arrow_writer( &path, - object_store.clone(), - parquet_props.writer_options().clone(), + Arc::clone(&object_store), + parquet_props.clone(), ) .await?; let mut reservation = @@ -829,10 +816,10 @@ impl DataSink for ParquetSink { // manage compressed blocks themselves. FileCompressionType::UNCOMPRESSED, &path, - object_store.clone(), + Arc::clone(&object_store), ) .await?; - let schema = self.get_writer_schema(); + let schema = get_writer_schema(&self.config); let props = parquet_props.clone(); let parallel_options_clone = parallel_options.clone(); let pool = Arc::clone(context.memory_pool()); @@ -841,7 +828,7 @@ impl DataSink for ParquetSink { writer, rx, schema, - props.writer_options(), + &props, parallel_options_clone, pool, ) @@ -882,6 +869,25 @@ impl DataSink for ParquetSink { } } +#[async_trait] +impl DataSink for ParquetSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> &SchemaRef { + self.config.output_schema() + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + FileSink::write_all(self, data, context).await + } +} + /// Consumes a stream of [ArrowLeafColumn] via a channel and serializes them using an [ArrowColumnWriter] /// Once the channel is exhausted, returns the ArrowColumnWriter. async fn column_serializer_task( @@ -908,7 +914,7 @@ fn spawn_column_parallel_row_group_writer( max_buffer_size: usize, pool: &Arc, ) -> Result<(Vec, Vec)> { - let schema_desc = arrow_to_parquet_schema(&schema)?; + let schema_desc = ArrowSchemaConverter::new().convert(&schema)?; let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?; let num_columns = col_writers.len(); @@ -1016,8 +1022,8 @@ fn spawn_parquet_parallel_serialization_task( let max_row_group_rows = writer_props.max_row_group_size(); let (mut column_writer_handles, mut col_array_channels) = spawn_column_parallel_row_group_writer( - schema.clone(), - writer_props.clone(), + Arc::clone(&schema), + Arc::clone(&writer_props), max_buffer_rb, &pool, )?; @@ -1029,15 +1035,23 @@ fn spawn_parquet_parallel_serialization_task( // function. 
loop { if current_rg_rows + rb.num_rows() < max_row_group_rows { - send_arrays_to_col_writers(&col_array_channels, &rb, schema.clone()) - .await?; + send_arrays_to_col_writers( + &col_array_channels, + &rb, + Arc::clone(&schema), + ) + .await?; current_rg_rows += rb.num_rows(); break; } else { let rows_left = max_row_group_rows - current_rg_rows; let a = rb.slice(0, rows_left); - send_arrays_to_col_writers(&col_array_channels, &a, schema.clone()) - .await?; + send_arrays_to_col_writers( + &col_array_channels, + &a, + Arc::clone(&schema), + ) + .await?; // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup // on a separate task, so that we can immediately start on the next RG before waiting @@ -1060,8 +1074,8 @@ fn spawn_parquet_parallel_serialization_task( (column_writer_handles, col_array_channels) = spawn_column_parallel_row_group_writer( - schema.clone(), - writer_props.clone(), + Arc::clone(&schema), + Arc::clone(&writer_props), max_buffer_rb, &pool, )?; @@ -1103,7 +1117,7 @@ async fn concatenate_parallel_row_groups( let mut file_reservation = MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool); - let schema_desc = arrow_to_parquet_schema(schema.as_ref())?; + let schema_desc = ArrowSchemaConverter::new().convert(schema.as_ref())?; let mut parquet_writer = SerializedFileWriter::new( merged_buff.clone(), schema_desc.root_schema_ptr(), @@ -1164,15 +1178,15 @@ async fn output_single_parquet_file_parallelized( let launch_serialization_task = spawn_parquet_parallel_serialization_task( data, serialize_tx, - output_schema.clone(), - arc_props.clone(), + Arc::clone(&output_schema), + Arc::clone(&arc_props), parallel_options, Arc::clone(&pool), ); let file_metadata = concatenate_parallel_row_groups( serialize_rx, - output_schema.clone(), - arc_props.clone(), + Arc::clone(&output_schema), + Arc::clone(&arc_props), object_store_writer, pool, ) @@ -1278,7 +1292,7 @@ mod tests { use crate::datasource::file_format::parquet::test_util::store_parquet; use crate::physical_plan::metrics::MetricValue; - use crate::prelude::{SessionConfig, SessionContext}; + use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use arrow::array::{Array, ArrayRef, StringArray}; use arrow_array::types::Int32Type; use arrow_array::{DictionaryArray, Int32Array, Int64Array}; @@ -1289,8 +1303,8 @@ mod tests { as_float64_array, as_int32_array, as_timestamp_nanosecond_array, }; use datafusion_common::config::ParquetOptions; - use datafusion_common::ScalarValue; use datafusion_common::ScalarValue::Utf8; + use datafusion_common::{assert_batches_eq, ScalarValue}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; @@ -1411,7 +1425,7 @@ mod tests { } impl Display for RequestCountingObjectStore { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "RequestCounting({})", self.inner) } } @@ -1679,7 +1693,7 @@ mod tests { let null_utf8 = if force_views { ScalarValue::Utf8View(None) } else { - ScalarValue::Utf8(None) + Utf8(None) }; // Fetch statistics for first file @@ -1692,7 +1706,7 @@ mod tests { let expected_type = if force_views { ScalarValue::Utf8View } else { - ScalarValue::Utf8 + Utf8 }; assert_eq!( c1_stats.max_value, @@ -2217,6 +2231,59 @@ mod tests { scan_format(state, &*format, &testdata, file_name, projection, limit).await } + /// Test that 0-byte files don't 
break while reading + #[tokio::test] + async fn test_read_empty_parquet() -> Result<()> { + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}/empty.parquet", tmp_dir.path().to_string_lossy()); + File::create(&path).await?; + + let ctx = SessionContext::new(); + + let df = ctx + .read_parquet(&path, ParquetReadOptions::default()) + .await + .expect("read_parquet should succeed"); + + let result = df.collect().await?; + #[rustfmt::skip] + let expected = ["++", + "++"]; + assert_batches_eq!(expected, &result); + + Ok(()) + } + + /// Test that 0-byte files don't break while reading + #[tokio::test] + async fn test_read_partitioned_empty_parquet() -> Result<()> { + let tmp_dir = tempfile::TempDir::new().unwrap(); + let partition_dir = tmp_dir.path().join("col1=a"); + std::fs::create_dir(&partition_dir).unwrap(); + File::create(partition_dir.join("empty.parquet")) + .await + .unwrap(); + + let ctx = SessionContext::new(); + + let df = ctx + .read_parquet( + tmp_dir.path().to_str().unwrap(), + ParquetReadOptions::new() + .table_partition_cols(vec![("col1".to_string(), DataType::Utf8)]), + ) + .await + .expect("read_parquet should succeed"); + + let result = df.collect().await?; + #[rustfmt::skip] + let expected = ["++", + "++"]; + assert_batches_eq!(expected, &result); + + Ok(()) + } + fn build_ctx(store_url: &url::Url) -> Arc { let tmp_dir = tempfile::TempDir::new().unwrap(); let local = Arc::new( @@ -2246,6 +2313,216 @@ mod tests { #[tokio::test] async fn parquet_sink_write() -> Result<()> { + let parquet_sink = create_written_parquet_sink("file:///").await?; + + // assert written to proper path + let (path, file_metadata) = get_written(parquet_sink)?; + let path_parts = path.parts().collect::>(); + assert_eq!(path_parts.len(), 1, "should not have path prefix"); + + // check the file metadata + let expected_kv_meta = vec![ + // default is to include arrow schema + KeyValue { + key: "ARROW:schema".to_string(), + value: Some(ENCODED_ARROW_SCHEMA.to_string()), + }, + KeyValue { + key: "my-data".to_string(), + value: Some("stuff".to_string()), + }, + KeyValue { + key: "my-data-bool-key".to_string(), + value: None, + }, + ]; + assert_file_metadata(file_metadata, &expected_kv_meta); + + Ok(()) + } + + #[tokio::test] + async fn parquet_sink_parallel_write() -> Result<()> { + let opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 2, + maximum_buffered_record_batches_per_stream: 2, + ..Default::default() + }; + + let parquet_sink = + create_written_parquet_sink_using_config("file:///", opts).await?; + + // assert written to proper path + let (path, file_metadata) = get_written(parquet_sink)?; + let path_parts = path.parts().collect::>(); + assert_eq!(path_parts.len(), 1, "should not have path prefix"); + + // check the file metadata + let expected_kv_meta = vec![ + // default is to include arrow schema + KeyValue { + key: "ARROW:schema".to_string(), + value: Some(ENCODED_ARROW_SCHEMA.to_string()), + }, + KeyValue { + key: "my-data".to_string(), + value: Some("stuff".to_string()), + }, + KeyValue { + key: "my-data-bool-key".to_string(), + value: None, + }, + ]; + assert_file_metadata(file_metadata, &expected_kv_meta); + + Ok(()) + } + + #[tokio::test] + async fn parquet_sink_write_insert_schema_into_metadata() -> Result<()> { + // expected kv metadata without schema + let expected_without = vec![ + KeyValue { + key: "my-data".to_string(), + value: Some("stuff".to_string()), + }, + KeyValue { + key: "my-data-bool-key".to_string(), + 
value: None, + }, + ]; + // expected kv metadata with schema + let expected_with = [ + vec![KeyValue { + key: "ARROW:schema".to_string(), + value: Some(ENCODED_ARROW_SCHEMA.to_string()), + }], + expected_without.clone(), + ] + .concat(); + + // single threaded write, skip insert + let opts = ParquetOptions { + allow_single_file_parallelism: false, + skip_arrow_metadata: true, + ..Default::default() + }; + let parquet_sink = + create_written_parquet_sink_using_config("file:///", opts).await?; + let (_, file_metadata) = get_written(parquet_sink)?; + assert_file_metadata(file_metadata, &expected_without); + + // single threaded write, do not skip insert + let opts = ParquetOptions { + allow_single_file_parallelism: false, + skip_arrow_metadata: false, + ..Default::default() + }; + let parquet_sink = + create_written_parquet_sink_using_config("file:///", opts).await?; + let (_, file_metadata) = get_written(parquet_sink)?; + assert_file_metadata(file_metadata, &expected_with); + + // multithreaded write, skip insert + let opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 2, + maximum_buffered_record_batches_per_stream: 2, + skip_arrow_metadata: true, + ..Default::default() + }; + let parquet_sink = + create_written_parquet_sink_using_config("file:///", opts).await?; + let (_, file_metadata) = get_written(parquet_sink)?; + assert_file_metadata(file_metadata, &expected_without); + + // multithreaded write, do not skip insert + let opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 2, + maximum_buffered_record_batches_per_stream: 2, + skip_arrow_metadata: false, + ..Default::default() + }; + let parquet_sink = + create_written_parquet_sink_using_config("file:///", opts).await?; + let (_, file_metadata) = get_written(parquet_sink)?; + assert_file_metadata(file_metadata, &expected_with); + + Ok(()) + } + + #[tokio::test] + async fn parquet_sink_write_with_extension() -> Result<()> { + let filename = "test_file.custom_ext"; + let file_path = format!("file:///path/to/{}", filename); + let parquet_sink = create_written_parquet_sink(file_path.as_str()).await?; + + // assert written to proper path + let (path, _) = get_written(parquet_sink)?; + let path_parts = path.parts().collect::>(); + assert_eq!( + path_parts.len(), + 3, + "Expected 3 path parts, instead found {}", + path_parts.len() + ); + assert_eq!(path_parts.last().unwrap().as_ref(), filename); + + Ok(()) + } + + #[tokio::test] + async fn parquet_sink_write_with_directory_name() -> Result<()> { + let file_path = "file:///path/to"; + let parquet_sink = create_written_parquet_sink(file_path).await?; + + // assert written to proper path + let (path, _) = get_written(parquet_sink)?; + let path_parts = path.parts().collect::>(); + assert_eq!( + path_parts.len(), + 3, + "Expected 3 path parts, instead found {}", + path_parts.len() + ); + assert!(path_parts.last().unwrap().as_ref().ends_with(".parquet")); + + Ok(()) + } + + #[tokio::test] + async fn parquet_sink_write_with_folder_ending() -> Result<()> { + let file_path = "file:///path/to/"; + let parquet_sink = create_written_parquet_sink(file_path).await?; + + // assert written to proper path + let (path, _) = get_written(parquet_sink)?; + let path_parts = path.parts().collect::>(); + assert_eq!( + path_parts.len(), + 3, + "Expected 3 path parts, instead found {}", + path_parts.len() + ); + assert!(path_parts.last().unwrap().as_ref().ends_with(".parquet")); + + Ok(()) + } + + async fn 
create_written_parquet_sink(table_path: &str) -> Result> { + create_written_parquet_sink_using_config(table_path, ParquetOptions::default()) + .await + } + + static ENCODED_ARROW_SCHEMA: &str = "/////5QAAAAQAAAAAAAKAAwACgAJAAQACgAAABAAAAAAAQQACAAIAAAABAAIAAAABAAAAAIAAAA8AAAABAAAANz///8UAAAADAAAAAAAAAUMAAAAAAAAAMz///8BAAAAYgAAABAAFAAQAAAADwAEAAAACAAQAAAAGAAAAAwAAAAAAAAFEAAAAAAAAAAEAAQABAAAAAEAAABhAAAA"; + + async fn create_written_parquet_sink_using_config( + table_path: &str, + global: ParquetOptions, + ) -> Result> { + // schema should match the ENCODED_ARROW_SCHEMA bove let field_a = Field::new("a", DataType::Utf8, false); let field_b = Field::new("b", DataType::Utf8, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); @@ -2254,11 +2531,12 @@ mod tests { let file_sink_config = FileSinkConfig { object_store_url: object_store_url.clone(), file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], - table_paths: vec![ListingTableUrl::parse("file:///")?], + table_paths: vec![ListingTableUrl::parse(table_path)?], output_schema: schema.clone(), table_partition_cols: vec![], insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, + file_extension: "parquet".into(), }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -2267,6 +2545,7 @@ mod tests { ("my-data".to_string(), Some("stuff".to_string())), ("my-data-bool-key".to_string(), None), ]), + global, ..Default::default() }, )); @@ -2277,18 +2556,20 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); // write stream - parquet_sink - .write_all( - Box::pin(RecordBatchStreamAdapter::new( - schema, - futures::stream::iter(vec![Ok(batch)]), - )), - &build_ctx(object_store_url.as_ref()), - ) - .await - .unwrap(); + FileSink::write_all( + parquet_sink.as_ref(), + Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::iter(vec![Ok(batch)]), + )), + &build_ctx(object_store_url.as_ref()), + ) + .await?; - // assert written + Ok(parquet_sink) + } + + fn get_written(parquet_sink: Arc) -> Result<(Path, FileMetaData)> { let mut written = parquet_sink.written(); let written = written.drain(); assert_eq!( @@ -2298,19 +2579,17 @@ mod tests { written.len() ); - // check the file metadata - let ( - path, - FileMetaData { - num_rows, - schema, - key_value_metadata, - .. - }, - ) = written.take(1).next().unwrap(); - let path_parts = path.parts().collect::>(); - assert_eq!(path_parts.len(), 1, "should not have path prefix"); + let (path, file_metadata) = written.take(1).next().unwrap(); + Ok((path, file_metadata)) + } + fn assert_file_metadata(file_metadata: FileMetaData, expected_kv: &Vec) { + let FileMetaData { + num_rows, + schema, + key_value_metadata, + .. 
+ } = file_metadata; assert_eq!(num_rows, 2, "file metadata to have 2 rows"); assert!( schema.iter().any(|col_schema| col_schema.name == "a"), @@ -2323,19 +2602,7 @@ mod tests { let mut key_value_metadata = key_value_metadata.unwrap(); key_value_metadata.sort_by(|a, b| a.key.cmp(&b.key)); - let expected_metadata = vec![ - KeyValue { - key: "my-data".to_string(), - value: Some("stuff".to_string()), - }, - KeyValue { - key: "my-data-bool-key".to_string(), - value: None, - }, - ]; - assert_eq!(key_value_metadata, expected_metadata); - - Ok(()) + assert_eq!(&key_value_metadata, expected_kv); } #[tokio::test] @@ -2354,6 +2621,7 @@ mod tests { table_partition_cols: vec![("a".to_string(), DataType::Utf8)], // add partitioning insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, + file_extension: "parquet".into(), }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -2366,16 +2634,15 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); // write stream - parquet_sink - .write_all( - Box::pin(RecordBatchStreamAdapter::new( - schema, - futures::stream::iter(vec![Ok(batch)]), - )), - &build_ctx(object_store_url.as_ref()), - ) - .await - .unwrap(); + FileSink::write_all( + parquet_sink.as_ref(), + Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::iter(vec![Ok(batch)]), + )), + &build_ctx(object_store_url.as_ref()), + ) + .await?; // assert written let mut written = parquet_sink.written(); @@ -2437,6 +2704,7 @@ mod tests { table_partition_cols: vec![], insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, + file_extension: "parquet".into(), }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -2464,7 +2732,8 @@ mod tests { "no bytes are reserved yet" ); - let mut write_task = parquet_sink.write_all( + let mut write_task = FileSink::write_all( + parquet_sink.as_ref(), Box::pin(RecordBatchStreamAdapter::new( schema, bounded_stream(batch, 1000), diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index 427b28db40301..48db2c0802559 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -20,11 +20,10 @@ use std::borrow::Cow; use std::collections::HashMap; - use std::sync::Arc; use crate::datasource::listing::ListingTableUrl; - +use crate::datasource::physical_plan::FileSinkConfig; use crate::error::Result; use crate::physical_plan::SendableRecordBatchStream; @@ -32,37 +31,51 @@ use arrow_array::builder::UInt64Builder; use arrow_array::cast::AsArray; use arrow_array::{downcast_dictionary_array, RecordBatch, StringArray, StructArray}; use arrow_schema::{DataType, Schema}; -use chrono::NaiveDate; use datafusion_common::cast::{ as_boolean_array, as_date32_array, as_date64_array, as_int32_array, as_int64_array, as_string_array, as_string_view_array, }; -use datafusion_common::{exec_datafusion_err, DataFusionError}; +use datafusion_common::{exec_datafusion_err, not_impl_err, DataFusionError}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; +use chrono::NaiveDate; use futures::StreamExt; use object_store::path::Path; - use rand::distributions::DistString; - use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; type RecordBatchReceiver = Receiver; -type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; +pub type DemuxedStreamReceiver = 
UnboundedReceiver<(Path, RecordBatchReceiver)>; /// Splits a single [SendableRecordBatchStream] into a dynamically determined -/// number of partitions at execution time. The partitions are determined by -/// factors known only at execution time, such as total number of rows and -/// partition column values. The demuxer task communicates to the caller -/// by sending channels over a channel. The inner channels send RecordBatches -/// which should be contained within the same output file. The outer channel -/// is used to send a dynamic number of inner channels, representing a dynamic -/// number of total output files. The caller is also responsible to monitor -/// the demux task for errors and abort accordingly. The single_file_output parameter -/// overrides all other settings to force only a single file to be written. -/// partition_by parameter will additionally split the input based on the unique -/// values of a specific column ``` +/// number of partitions at execution time. +/// +/// The partitions are determined by factors known only at execution time, such +/// as total number of rows and partition column values. The demuxer task +/// communicates to the caller by sending channels over a channel. The inner +/// channels send RecordBatches which should be contained within the same output +/// file. The outer channel is used to send a dynamic number of inner channels, +/// representing a dynamic number of total output files. +/// +/// The caller is also responsible to monitor the demux task for errors and +/// abort accordingly. +/// +/// A path with an extension will force only a single file to +/// be written with the extension from the path. Otherwise the default extension +/// will be used and the output will be split into multiple files. +/// +/// Examples of `base_output_path` +/// * `tmp/dataset/` -> is a folder since it ends in `/` +/// * `tmp/dataset` -> is still a folder since it does not end in `/` but has no valid file extension +/// * `tmp/file.parquet` -> is a file since it does not end in `/` and has a valid file extension `.parquet` +/// * `tmp/file.parquet/` -> is a folder since it ends in `/` +/// +/// The `partition_by` parameter will additionally split the input based on the +/// unique values of a specific column, see +/// +/// +/// ```text /// ┌───────────┐ ┌────────────┐ ┌─────────────┐ /// ┌──────▶ │ batch 1 ├────▶...──────▶│ Batch a │ │ Output File1│ /// │ └───────────┘ └────────────┘ └─────────────┘ @@ -74,45 +87,47 @@ type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; /// └──────────┘ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ /// └──────▶ │ batch d ├────▶...──────▶│ Batch n │ │ Output FileN│ /// └───────────┘ └────────────┘ └─────────────┘ +/// ``` pub(crate) fn start_demuxer_task( - input: SendableRecordBatchStream, + config: &FileSinkConfig, + data: SendableRecordBatchStream, context: &Arc, - partition_by: Option>, - base_output_path: ListingTableUrl, - file_extension: String, - keep_partition_by_columns: bool, ) -> (SpawnedTask>, DemuxedStreamReceiver) { let (tx, rx) = mpsc::unbounded_channel(); - let context = context.clone(); - let single_file_output = !base_output_path.is_collection(); - let task = match partition_by { - Some(parts) => { - // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot - // bound this channel without risking a deadlock. 
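The folder-vs-file rule spelled out in the `base_output_path` examples above boils down to two checks: the path must not end in `/` and its last segment must carry a file extension. The following standalone sketch (a stand-in helper, not the DataFusion implementation, which combines `ListingTableUrl::is_collection` with the new `file_extension` accessor) reproduces that decision for the four documented examples:

```rust
/// Sketch of the single-file-output rule described in the demuxer docs above.
fn is_single_file_output(base_output_path: &str) -> bool {
    let is_collection = base_output_path.ends_with('/');
    let has_extension = base_output_path
        .rsplit('/')
        .next()
        .map(|segment| segment.contains('.') && !segment.ends_with('.'))
        .unwrap_or(false);
    !is_collection && has_extension
}

fn main() {
    assert!(!is_single_file_output("tmp/dataset/"));      // folder: trailing `/`
    assert!(!is_single_file_output("tmp/dataset"));       // folder: no file extension
    assert!(is_single_file_output("tmp/file.parquet"));   // single output file
    assert!(!is_single_file_output("tmp/file.parquet/")); // folder: trailing `/`
}
```

When the path is treated as a folder, the demuxer remains free to emit multiple output files using the format's default extension.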
- SpawnedTask::spawn(async move { - hive_style_partitions_demuxer( - tx, - input, - context, - parts, - base_output_path, - file_extension, - keep_partition_by_columns, - ) - .await - }) - } - None => SpawnedTask::spawn(async move { + let context = Arc::clone(context); + let file_extension = config.file_extension.clone(); + let base_output_path = config.table_paths[0].clone(); + let task = if config.table_partition_cols.is_empty() { + let single_file_output = !base_output_path.is_collection() + && base_output_path.file_extension().is_some(); + SpawnedTask::spawn(async move { row_count_demuxer( tx, - input, + data, context, base_output_path, file_extension, single_file_output, ) .await - }), + }) + } else { + // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot + // bound this channel without risking a deadlock. + let partition_by = config.table_partition_cols.clone(); + let keep_partition_by_columns = config.keep_partition_by_columns; + SpawnedTask::spawn(async move { + hive_style_partitions_demuxer( + tx, + data, + context, + partition_by, + base_output_path, + file_extension, + keep_partition_by_columns, + ) + .await + }) }; (task, rx) @@ -280,9 +295,8 @@ async fn hive_style_partitions_demuxer( Some(part_tx) => part_tx, None => { // Create channel for previously unseen distinct partition key and notify consumer of new file - let (part_tx, part_rx) = tokio::sync::mpsc::channel::( - max_buffered_recordbatches, - ); + let (part_tx, part_rx) = + mpsc::channel::(max_buffered_recordbatches); let file_path = compute_hive_style_file_path( &part_key, &partition_by, @@ -421,10 +435,10 @@ fn compute_partition_keys_by_row<'a>( ) } _ => { - return Err(DataFusionError::NotImplemented(format!( + return not_impl_err!( "it is not yet supported to write to hive partitions with datatype {}", dtype - ))) + ) } } @@ -461,7 +475,7 @@ fn remove_partition_by_columns( .zip(parted_batch.schema().fields()) .filter_map(|(a, f)| { if !partition_names.contains(&f.name()) { - Some((a.clone(), (**f).clone())) + Some((Arc::clone(a), (**f).clone())) } else { None } diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs index 42115fc7b93fb..c064999c1e5be 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -22,10 +22,11 @@ use std::io::Write; use std::sync::Arc; use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::physical_plan::FileSinkConfig; use crate::error::Result; use arrow_array::RecordBatch; - +use arrow_schema::Schema; use bytes::Bytes; use object_store::buffered::BufWriter; use object_store::path::Path; @@ -86,3 +87,25 @@ pub(crate) async fn create_writer( let buf_writer = BufWriter::new(object_store, location.clone()); file_compression_type.convert_async_writer(buf_writer) } + +/// Converts table schema to writer schema, which may differ in the case +/// of hive style partitioning where some columns are removed from the +/// underlying files. 
+pub(crate) fn get_writer_schema(config: &FileSinkConfig) -> Arc { + if !config.table_partition_cols.is_empty() && !config.keep_partition_by_columns { + let schema = config.output_schema(); + let partition_names: Vec<_> = + config.table_partition_cols.iter().map(|(s, _)| s).collect(); + Arc::new(Schema::new_with_metadata( + schema + .fields() + .iter() + .filter(|f| !partition_names.contains(&f.name())) + .map(|f| (**f).clone()) + .collect::>(), + schema.metadata().clone(), + )) + } else { + Arc::clone(config.output_schema()) + } +} diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index 6f27e6f3889f2..7a271def7dc64 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -21,12 +21,10 @@ use std::sync::Arc; -use super::demux::start_demuxer_task; +use super::demux::DemuxedStreamReceiver; use super::{create_writer, BatchSerializer}; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::physical_plan::FileSinkConfig; use crate::error::Result; -use crate::physical_plan::SendableRecordBatchStream; use arrow_array::RecordBatch; use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError}; @@ -35,6 +33,7 @@ use datafusion_execution::TaskContext; use bytes::Bytes; use futures::join; +use object_store::ObjectStore; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver}; use tokio::task::JoinSet; @@ -53,11 +52,11 @@ pub(crate) enum SerializedRecordBatchResult { }, Failure { /// As explained in [`serialize_rb_stream_to_object_store`]: - /// - If an IO error occured that involved the ObjectStore writer, then the writer will not be returned to the caller + /// - If an IO error occurred that involved the ObjectStore writer, then the writer will not be returned to the caller /// - Otherwise, the writer is returned to the caller writer: Option, - /// the actual error that occured + /// the actual error that occurred err: DataFusionError, }, } @@ -94,7 +93,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store( // subsequent batches, so we track that here. let mut initial = true; while let Some(batch) = data_rx.recv().await { - let serializer_clone = serializer.clone(); + let serializer_clone = Arc::clone(&serializer); let task = SpawnedTask::spawn(async move { let num_rows = batch.num_rows(); let bytes = serializer_clone.serialize(batch, initial)?; @@ -238,34 +237,14 @@ pub(crate) async fn stateless_serialize_and_write_files( /// Orchestrates multipart put of a dynamic number of output files from a single input stream /// for any statelessly serialized file type. That is, any file type for which each [RecordBatch] /// can be serialized independently of all other [RecordBatch]s. 
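To make `get_writer_schema` above concrete: for a hypothetical table partitioned by a `date` column with `keep_partition_by_columns` disabled, the partition value is encoded in the directory name, so the data files themselves only carry the remaining columns. A minimal sketch of the same filtering step using `arrow_schema` directly:

```rust
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema};

fn main() {
    // Hypothetical table schema: "date" is a hive-style partition column.
    let table_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Utf8, false),
        Field::new("b", DataType::Int64, false),
        Field::new("date", DataType::Utf8, false),
    ]));
    let partition_names = ["date"];

    // Same idea as `get_writer_schema`: keep every field that is not a
    // partition column, carrying the schema-level metadata over unchanged.
    let writer_schema = Schema::new_with_metadata(
        table_schema
            .fields()
            .iter()
            .filter(|f| !partition_names.contains(&f.name().as_str()))
            .map(|f| (**f).clone())
            .collect::<Vec<_>>(),
        table_schema.metadata().clone(),
    );

    // The files on disk only contain "a" and "b"; "date" lives in the path.
    assert_eq!(writer_schema.fields().len(), 2);
}
```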
-pub(crate) async fn stateless_multipart_put( - data: SendableRecordBatchStream, +pub(crate) async fn spawn_writer_tasks_and_join( context: &Arc, - file_extension: String, - get_serializer: Box Arc + Send>, - config: &FileSinkConfig, + serializer: Arc, compression: FileCompressionType, + object_store: Arc, + demux_task: SpawnedTask>, + mut file_stream_rx: DemuxedStreamReceiver, ) -> Result { - let object_store = context - .runtime_env() - .object_store(&config.object_store_url)?; - - let base_output_path = &config.table_paths[0]; - let part_cols = if !config.table_partition_cols.is_empty() { - Some(config.table_partition_cols.clone()) - } else { - None - }; - - let (demux_task, mut file_stream_rx) = start_demuxer_task( - data, - context, - part_cols, - base_output_path.clone(), - file_extension, - config.keep_partition_by_columns, - ); - let rb_buffer_size = &context .session_config() .options() @@ -278,17 +257,18 @@ pub(crate) async fn stateless_multipart_put( stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await }); while let Some((location, rb_stream)) = file_stream_rx.recv().await { - let serializer = get_serializer(); - let writer = create_writer(compression, &location, object_store.clone()).await?; + let writer = + create_writer(compression, &location, Arc::clone(&object_store)).await?; - tx_file_bundle - .send((rb_stream, serializer, writer)) + if tx_file_bundle + .send((rb_stream, Arc::clone(&serializer), writer)) .await - .map_err(|_| { - internal_datafusion_err!( - "Writer receive file bundle channel closed unexpectedly!" - ) - })?; + .is_err() + { + internal_datafusion_err!( + "Writer receive file bundle channel closed unexpectedly!" + ); + } } // Signal to the write coordinator that no more files are coming @@ -301,9 +281,8 @@ pub(crate) async fn stateless_multipart_put( r1.map_err(DataFusionError::ExecutionJoin)??; r2.map_err(DataFusionError::ExecutionJoin)??; - let total_count = rx_row_cnt.await.map_err(|_| { + // Return total row count: + rx_row_cnt.await.map_err(|_| { internal_datafusion_err!("Did not receive row count from write coordinator") - })?; - - Ok(total_count) + }) } diff --git a/datafusion/core/src/datasource/function.rs b/datafusion/core/src/datasource/function.rs deleted file mode 100644 index 37ce59f8207b2..0000000000000 --- a/datafusion/core/src/datasource/function.rs +++ /dev/null @@ -1,63 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
A table that uses a function to generate data - -use super::TableProvider; - -use datafusion_common::Result; -use datafusion_expr::Expr; - -use std::fmt::Debug; -use std::sync::Arc; - -/// A trait for table function implementations -pub trait TableFunctionImpl: Debug + Sync + Send { - /// Create a table provider - fn call(&self, args: &[Expr]) -> Result>; -} - -/// A table that uses a function to generate data -#[derive(Debug)] -pub struct TableFunction { - /// Name of the table function - name: String, - /// Function implementation - fun: Arc, -} - -impl TableFunction { - /// Create a new table function - pub fn new(name: String, fun: Arc) -> Self { - Self { name, fun } - } - - /// Get the name of the table function - pub fn name(&self) -> &str { - &self.name - } - - /// Get the implementation of the table function - pub fn function(&self) -> &Arc { - &self.fun - } - - /// Get the function implementation and generate a table - pub fn create_table_provider(&self, args: &[Expr]) -> Result> { - self.fun.call(args) - } -} diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index e18fb8fc7ba39..47710726a25ef 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -17,14 +17,14 @@ //! Helper functions for the table implementation -use std::collections::HashMap; use std::mem; use std::sync::Arc; use super::ListingTableUrl; use super::PartitionedFile; use crate::execution::context::SessionState; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::internal_err; +use datafusion_common::{HashMap, Result, ScalarValue}; use datafusion_expr::{BinaryExpr, Operator}; use arrow::{ @@ -135,7 +135,7 @@ pub fn split_files( partitioned_files.sort_by(|a, b| a.path().cmp(b.path())); // effectively this is div with rounding up instead of truncating - let chunk_size = (partitioned_files.len() + n - 1) / n; + let chunk_size = partitioned_files.len().div_ceil(n); let mut chunks = Vec::with_capacity(n); let mut current_chunk = Vec::with_capacity(chunk_size); for file in partitioned_files.drain(..) { @@ -171,7 +171,13 @@ impl Partition { trace!("Listing partition {}", self.path); let prefix = Some(&self.path).filter(|p| !p.as_ref().is_empty()); let result = store.list_with_delimiter(prefix).await?; - self.files = Some(result.objects); + self.files = Some( + result + .objects + .into_iter() + .filter(|object_meta| object_meta.size > 0) + .collect(), + ); Ok((self, result.common_prefixes)) } } @@ -285,25 +291,20 @@ async fn prune_partitions( let props = ExecutionProps::new(); // Applies `filter` to `batch` returning `None` on error - let do_filter = |filter| -> Option { - let expr = create_physical_expr(filter, &df_schema, &props).ok()?; - expr.evaluate(&batch) - .ok()? 
- .into_array(partitions.len()) - .ok() + let do_filter = |filter| -> Result { + let expr = create_physical_expr(filter, &df_schema, &props)?; + expr.evaluate(&batch)?.into_array(partitions.len()) }; - //.Compute the conjunction of the filters, ignoring errors + //.Compute the conjunction of the filters let mask = filters .iter() - .fold(None, |acc, filter| match (acc, do_filter(filter)) { - (Some(a), Some(b)) => Some(and(&a, b.as_boolean()).unwrap_or(a)), - (None, Some(r)) => Some(r.as_boolean().clone()), - (r, None) => r, - }); + .map(|f| do_filter(f).map(|a| a.as_boolean().clone())) + .reduce(|a, b| Ok(and(&a?, &b?)?)); let mask = match mask { - Some(mask) => mask, + Some(Ok(mask)) => mask, + Some(Err(err)) => return Err(err), None => return Ok(partitions), }; @@ -401,8 +402,8 @@ fn evaluate_partition_prefix<'a>( /// Discover the partitions on the given path and prune out files /// that belong to irrelevant partitions using `filters` expressions. -/// `filters` might contain expressions that can be resolved only at the -/// file level (e.g. Parquet row group pruning). +/// `filters` should only contain expressions that can be evaluated +/// using only the partition columns. pub async fn pruned_partition_list<'a>( ctx: &'a SessionState, store: &'a dyn ObjectStore, @@ -413,10 +414,17 @@ pub async fn pruned_partition_list<'a>( ) -> Result>> { // if no partition col => simply list all the files if partition_cols.is_empty() { + if !filters.is_empty() { + return internal_err!( + "Got partition filters for unpartitioned table {}", + table_path + ); + } return Ok(Box::pin( table_path .list_all_files(ctx, store, file_extension) .await? + .try_filter(|object_meta| futures::future::ready(object_meta.size > 0)) .map_ok(|object_meta| object_meta.into()), )); } @@ -467,6 +475,7 @@ pub async fn pruned_partition_list<'a>( range: None, statistics: None, extensions: None, + metadata_size_hint: None, }) })); @@ -564,6 +573,7 @@ mod tests { async fn test_pruned_partition_list_empty() { let (store, state) = make_test_store_and_state(&[ ("tablepath/mypartition=val1/notparquetfile", 100), + ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), ("tablepath/file.parquet", 100), ]); let filter = Expr::eq(col("mypartition"), lit("val1")); @@ -588,6 +598,7 @@ mod tests { let (store, state) = make_test_store_and_state(&[ ("tablepath/mypartition=val1/file.parquet", 100), ("tablepath/mypartition=val2/file.parquet", 100), + ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), ("tablepath/mypartition=val1/other=val3/file.parquet", 100), ]); let filter = Expr::eq(col("mypartition"), lit("val1")); @@ -631,13 +642,11 @@ mod tests { ]); let filter1 = Expr::eq(col("part1"), lit("p1v2")); let filter2 = Expr::eq(col("part2"), lit("p2v1")); - // filter3 cannot be resolved at partition pruning - let filter3 = Expr::eq(col("part2"), col("other")); let pruned = pruned_partition_list( &state, store.as_ref(), &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter1, filter2, filter3], + &[filter1, filter2], ".parquet", &[ (String::from("part1"), DataType::Utf8), @@ -671,6 +680,107 @@ mod tests { ); } + /// Describe a partition as a (path, depth, files) tuple for easier assertions + fn describe_partition(partition: &Partition) -> (&str, usize, Vec<&str>) { + ( + partition.path.as_ref(), + partition.depth, + partition + .files + .as_ref() + .map(|f| f.iter().map(|f| f.location.filename().unwrap()).collect()) + .unwrap_or_default(), + ) + } + + #[tokio::test] + async fn test_list_partition() { + let (store, _) 
= make_test_store_and_state(&[ + ("tablepath/part1=p1v1/file.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), + ("tablepath/part1=p1v3/part2=p2v1/file3.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v2/file4.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v2/empty.parquet", 0), + ]); + + let partitions = list_partitions( + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + 0, + None, + ) + .await + .expect("listing partitions failed"); + + assert_eq!( + &partitions + .iter() + .map(describe_partition) + .collect::>(), + &vec![ + ("tablepath", 0, vec![]), + ("tablepath/part1=p1v1", 1, vec![]), + ("tablepath/part1=p1v2", 1, vec![]), + ("tablepath/part1=p1v3", 1, vec![]), + ] + ); + + let partitions = list_partitions( + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + 1, + None, + ) + .await + .expect("listing partitions failed"); + + assert_eq!( + &partitions + .iter() + .map(describe_partition) + .collect::>(), + &vec![ + ("tablepath", 0, vec![]), + ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), + ("tablepath/part1=p1v2", 1, vec![]), + ("tablepath/part1=p1v2/part2=p2v1", 2, vec![]), + ("tablepath/part1=p1v2/part2=p2v2", 2, vec![]), + ("tablepath/part1=p1v3", 1, vec![]), + ("tablepath/part1=p1v3/part2=p2v1", 2, vec![]), + ] + ); + + let partitions = list_partitions( + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + 2, + None, + ) + .await + .expect("listing partitions failed"); + + assert_eq!( + &partitions + .iter() + .map(describe_partition) + .collect::>(), + &vec![ + ("tablepath", 0, vec![]), + ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), + ("tablepath/part1=p1v2", 1, vec![]), + ("tablepath/part1=p1v3", 1, vec![]), + ( + "tablepath/part1=p1v2/part2=p2v1", + 2, + vec!["file1.parquet", "file2.parquet"] + ), + ("tablepath/part1=p1v2/part2=p2v2", 2, vec!["file4.parquet"]), + ("tablepath/part1=p1v3/part2=p2v1", 2, vec!["file3.parquet"]), + ] + ); + } + #[test] fn test_parse_partitions_for_path() { assert_eq!( diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index c5a441aacf1d2..f11653ce1e52a 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -81,6 +81,8 @@ pub struct PartitionedFile { pub statistics: Option, /// An optional field for user defined per object metadata pub extensions: Option>, + /// The estimated size of the parquet metadata, in bytes + pub metadata_size_hint: Option, } impl PartitionedFile { @@ -98,6 +100,7 @@ impl PartitionedFile { range: None, statistics: None, extensions: None, + metadata_size_hint: None, } } @@ -115,10 +118,19 @@ impl PartitionedFile { range: Some(FileRange { start, end }), statistics: None, extensions: None, + metadata_size_hint: None, } .with_range(start, end) } + /// Provide a hint to the size of the file metadata. If a hint is provided + /// the reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. + /// Without an appropriate hint, two read may be required to fetch the metadata. 
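A hedged usage sketch of the builder defined just below; the file path and the 64 KiB figure are made up for illustration, and a good hint depends on how large the Parquet footer actually is:

```rust
use datafusion::datasource::listing::PartitionedFile;

fn main() {
    // Hypothetical 1 MiB parquet file whose footer is expected to fit in the
    // last 64 KiB, so the reader can try to fetch the metadata in one request.
    let file = PartitionedFile::new("data/part-0.parquet".to_string(), 1024 * 1024)
        .with_metadata_size_hint(64 * 1024);
    assert_eq!(file.metadata_size_hint, Some(64 * 1024));
}
```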
+ pub fn with_metadata_size_hint(mut self, metadata_size_hint: usize) -> Self { + self.metadata_size_hint = Some(metadata_size_hint); + self + } + /// Return a file reference from the given path pub fn from_path(path: String) -> Result { let size = std::fs::metadata(path.clone())?.len(); @@ -156,6 +168,7 @@ impl From for PartitionedFile { range: None, statistics: None, extensions: None, + metadata_size_hint: None, } } } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 0da146c558a0a..a3d7a26d98114 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -114,19 +114,26 @@ impl ListingTableConfig { } } - fn infer_file_extension(path: &str) -> Result { + ///Returns a tupe of (file_extension, optional compression_extension) + /// + /// For example a path ending with blah.test.csv.gz returns `("csv", Some("gz"))` + /// For example a path ending with blah.test.csv returns `("csv", None)` + fn infer_file_extension_and_compression_type( + path: &str, + ) -> Result<(String, Option)> { let mut exts = path.rsplit('.'); - let mut splitted = exts.next().unwrap_or(""); + let splitted = exts.next().unwrap_or(""); let file_compression_type = FileCompressionType::from_str(splitted) .unwrap_or(FileCompressionType::UNCOMPRESSED); if file_compression_type.is_compressed() { - splitted = exts.next().unwrap_or(""); + let splitted2 = exts.next().unwrap_or(""); + Ok((splitted2.to_string(), Some(splitted.to_string()))) + } else { + Ok((splitted.to_string(), None)) } - - Ok(splitted.to_string()) } /// Infer `ListingOptions` based on `table_path` suffix. @@ -147,18 +154,33 @@ impl ListingTableConfig { .await .ok_or_else(|| DataFusionError::Internal("No files for table".into()))??; - let file_extension = - ListingTableConfig::infer_file_extension(file.location.as_ref())?; + let (file_extension, maybe_compression_type) = + ListingTableConfig::infer_file_extension_and_compression_type( + file.location.as_ref(), + )?; + + let mut format_options = HashMap::new(); + if let Some(ref compression_type) = maybe_compression_type { + format_options + .insert("format.compression".to_string(), compression_type.clone()); + } let file_format = state .get_file_format_factory(&file_extension) .ok_or(config_datafusion_err!( "No file_format found with extension {file_extension}" ))? - .create(state, &HashMap::new())?; + .create(state, &format_options)?; + + let listing_file_extension = + if let Some(compression_type) = maybe_compression_type { + format!("{}.{}", &file_extension, &compression_type) + } else { + file_extension + }; let listing_options = ListingOptions::new(file_format) - .with_file_extension(file_extension) + .with_file_extension(listing_file_extension) .with_target_partitions(state.config().target_partitions()); Ok(Self { @@ -470,6 +492,8 @@ impl ListingOptions { let files: Vec<_> = table_path .list_all_files(state, store.as_ref(), &self.file_extension) .await? 
+ // Empty files cannot affect schema but may throw when trying to read for it + .try_filter(|object_meta| future::ready(object_meta.size > 0)) .try_collect() .await?; @@ -719,10 +743,16 @@ impl ListingTable { builder.push(Field::new(part_col_name, part_col_type.clone(), false)); } + let table_schema = Arc::new( + builder + .finish() + .with_metadata(file_schema.metadata().clone()), + ); + let table = Self { table_paths: config.table_paths, file_schema, - table_schema: Arc::new(builder.finish()), + table_schema, options, definition: None, collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), @@ -782,6 +812,16 @@ impl ListingTable { } } +// Expressions can be used for parttion pruning if they can be evaluated using +// only the partiton columns and there are partition columns. +fn can_be_evaluted_for_partition_pruning( + partition_column_names: &[&str], + expr: &Expr, +) -> bool { + !partition_column_names.is_empty() + && expr_applicable_for_cols(partition_column_names, expr) +} + #[async_trait] impl TableProvider for ListingTable { fn as_any(&self) -> &dyn Any { @@ -807,10 +847,33 @@ impl TableProvider for ListingTable { filters: &[Expr], limit: Option, ) -> Result> { + // extract types of partition columns + let table_partition_cols = self + .options + .table_partition_cols + .iter() + .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) + .collect::>>()?; + + let table_partition_col_names = table_partition_cols + .iter() + .map(|field| field.name().as_str()) + .collect::>(); + // If the filters can be resolved using only partition cols, there is no need to + // pushdown it to TableScan, otherwise, `unhandled` pruning predicates will be generated + let (partition_filters, filters): (Vec<_>, Vec<_>) = + filters.iter().cloned().partition(|filter| { + can_be_evaluted_for_partition_pruning(&table_partition_col_names, filter) + }); // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here? let session_state = state.as_any().downcast_ref::().unwrap(); + + // We should not limit the number of partitioned files to scan if there are filters and limit + // at the same time. This is because the limit should be applied after the filters are applied. + let statistic_file_limit = if filters.is_empty() { limit } else { None }; + let (mut partitioned_file_lists, statistics) = self - .list_files_for_scan(session_state, filters, limit) + .list_files_for_scan(session_state, &partition_filters, statistic_file_limit) .await?; // if no files need to be read, return an `EmptyExec` @@ -846,40 +909,18 @@ impl TableProvider for ListingTable { None => {} // no ordering required }; - // extract types of partition columns - let table_partition_cols = self - .options - .table_partition_cols - .iter() - .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) - .collect::>>()?; - - // If the filters can be resolved using only partition cols, there is no need to - // pushdown it to TableScan, otherwise, `unhandled` pruning predicates will be generated - let table_partition_col_names = table_partition_cols - .iter() - .map(|field| field.name().as_str()) - .collect::>(); - let filters = filters - .iter() - .filter(|filter| { - !expr_applicable_for_cols(&table_partition_col_names, filter) - }) - .cloned() - .collect::>(); - - let filters = conjunction(filters.to_vec()) - .map(|expr| -> Result<_> { - // NOTE: Use the table schema (NOT file schema) here because `expr` may contain references to partition columns. 
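The practical effect of the filter split in `scan` above: on a table partitioned by `date`, a predicate such as `date = '2024-01-01'` is consumed entirely by partition pruning, while `value > 10` remains a pushdown filter for the file scan. A purely illustrative sketch of that partitioning step, with plain strings standing in for real `Expr`s and for `expr_applicable_for_cols`:

```rust
fn main() {
    // (column the predicate refers to, rendered predicate) pairs standing in
    // for real `Expr`s; the table is assumed to be partitioned by "date".
    let partition_cols = ["date"];
    let filters = vec![
        ("date", "date = '2024-01-01'"),
        ("value", "value > 10"),
    ];

    // Same `Iterator::partition` idea as in `scan`: filters that only touch
    // partition columns go to pruning, the rest stay with the table scan.
    let (partition_filters, scan_filters): (Vec<_>, Vec<_>) = filters
        .into_iter()
        .partition(|(col, _)| partition_cols.contains(col));

    assert_eq!(partition_filters, vec![("date", "date = '2024-01-01'")]);
    assert_eq!(scan_filters, vec![("value", "value > 10")]);
}
```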
+ let filters = match conjunction(filters.to_vec()) { + Some(expr) => { let table_df_schema = self.table_schema.as_ref().clone().to_dfschema()?; let filters = create_physical_expr( &expr, &table_df_schema, state.execution_props(), )?; - Ok(Some(filters)) - }) - .unwrap_or(Ok(None))?; + Some(filters) + } + None => None, + }; let Some(object_store_url) = self.table_paths.first().map(ListingTableUrl::object_store) @@ -894,6 +935,7 @@ impl TableProvider for ListingTable { session_state, FileScanConfig::new(object_store_url, Arc::clone(&self.file_schema)) .with_file_groups(partitioned_file_lists) + .with_constraints(self.constraints.clone()) .with_statistics(statistics) .with_projection(projection.cloned()) .with_limit(limit) @@ -908,18 +950,17 @@ impl TableProvider for ListingTable { &self, filters: &[&Expr], ) -> Result> { + let partition_column_names = self + .options + .table_partition_cols + .iter() + .map(|col| col.0.as_str()) + .collect::>(); filters .iter() .map(|filter| { - if expr_applicable_for_cols( - &self - .options - .table_partition_cols - .iter() - .map(|col| col.0.as_str()) - .collect::>(), - filter, - ) { + if can_be_evaluted_for_partition_pruning(&partition_column_names, filter) + { // if filter can be handled by partition pruning, it is exact return Ok(TableProviderFilterPushDown::Exact); } @@ -1010,21 +1051,22 @@ impl TableProvider for ListingTable { table_partition_cols: self.options.table_partition_cols.clone(), insert_op, keep_partition_by_columns, + file_extension: self.options().format.get_ext(), }; let order_requirements = if !self.options().file_sort_order.is_empty() { // Multiple sort orders in outer vec are equivalent, so we pass only the first one - let ordering = self - .try_create_output_ordering()? - .first() - .ok_or(DataFusionError::Internal( - "Expected ListingTable to have a sort order, but none found!".into(), - ))? - .clone(); + let orderings = self.try_create_output_ordering()?; + let Some(ordering) = orderings.first() else { + return internal_err!( + "Expected ListingTable to have a sort order, but none found!" + ); + }; // Converts Vec> into type required by execution plan to specify its required input ordering Some(LexRequirement::new( ordering .into_iter() + .cloned() .map(PhysicalSortRequirement::from) .collect::>(), )) @@ -1108,8 +1150,8 @@ impl ListingTable { /// This method first checks if the statistics for the given file are already cached. /// If they are, it returns the cached statistics. /// If they are not, it infers the statistics from the file and stores them in the cache. 
- async fn do_collect_statistics<'a>( - &'a self, + async fn do_collect_statistics( + &self, ctx: &SessionState, store: &Arc, part_file: &PartitionedFile, @@ -1126,14 +1168,14 @@ impl ListingTable { .infer_stats( ctx, store, - self.file_schema.clone(), + Arc::clone(&self.file_schema), &part_file.object_meta, ) .await?; let statistics = Arc::new(statistics); self.collected_statistics.put_with_extra( &part_file.object_meta.location, - statistics.clone(), + Arc::clone(&statistics), &part_file.object_meta, ); Ok(statistics) @@ -1272,13 +1314,16 @@ mod tests { // ok with one column ( vec![vec![col("string_col").sort(true, false)]], - Ok(vec![vec![PhysicalSortExpr { - expr: physical_col("string_col", &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }]]) + Ok(vec![LexOrdering::new( + vec![PhysicalSortExpr { + expr: physical_col("string_col", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + ) + ]) ), // ok with two columns, different options ( @@ -1286,15 +1331,17 @@ mod tests { col("string_col").sort(true, false), col("int_col").sort(false, true), ]], - Ok(vec![vec![ - PhysicalSortExpr::new_default(physical_col("string_col", &schema).unwrap()) - .asc() - .nulls_last(), - - PhysicalSortExpr::new_default(physical_col("int_col", &schema).unwrap()) - .desc() - .nulls_first() - ]]) + Ok(vec![LexOrdering::new( + vec![ + PhysicalSortExpr::new_default(physical_col("string_col", &schema).unwrap()) + .asc() + .nulls_last(), + PhysicalSortExpr::new_default(physical_col("int_col", &schema).unwrap()) + .desc() + .nulls_first() + ], + ) + ]) ), ]; @@ -2171,4 +2218,23 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_infer_options_compressed_csv() -> Result<()> { + let testdata = crate::test_util::arrow_test_data(); + let filename = format!("{}/csv/aggregate_test_100.csv.gz", testdata); + let table_path = ListingTableUrl::parse(filename).unwrap(); + + let ctx = SessionContext::new(); + + let config = ListingTableConfig::new(table_path); + let config_with_opts = config.infer_options(&ctx.state()).await?; + let config_with_schema = config_with_opts.infer_schema(&ctx.state()).await?; + + let schema = config_with_schema.file_schema.unwrap(); + + assert_eq!(schema.fields.len(), 13); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index 1701707fdb726..6fb536ca2f05b 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -170,7 +170,7 @@ impl ListingTableUrl { if ignore_subdirectory { segments .next() - .map_or(false, |file_name| glob.matches(file_name)) + .is_some_and(|file_name| glob.matches(file_name)) } else { let stripped = segments.join(DELIMITER); glob.matches(&stripped) @@ -190,6 +190,30 @@ impl ListingTableUrl { self.url.path().ends_with(DELIMITER) } + /// Returns the file extension of the last path segment if it exists + /// + /// Examples: + /// ```rust + /// use datafusion::datasource::listing::ListingTableUrl; + /// let url = ListingTableUrl::parse("file:///foo/bar.csv").unwrap(); + /// assert_eq!(url.file_extension(), Some("csv")); + /// let url = ListingTableUrl::parse("file:///foo/bar").unwrap(); + /// assert_eq!(url.file_extension(), None); + /// let url = ListingTableUrl::parse("file:///foo/bar.").unwrap(); + /// assert_eq!(url.file_extension(), None); + /// ``` + pub fn file_extension(&self) -> Option<&str> { + if let Some(segments) = self.url.path_segments() { + 
if let Some(last_segment) = segments.last() { + if last_segment.contains(".") && !last_segment.ends_with(".") { + return last_segment.split('.').last(); + } + } + } + + None + } + /// Strips the prefix of this [`ListingTableUrl`] from the provided path, returning /// an iterator of the remaining path segments pub(crate) fn strip_prefix<'a, 'b: 'a>( @@ -493,4 +517,54 @@ mod tests { "path not ends with / - fragment ends with / - not collection", ); } + + #[test] + fn test_file_extension() { + fn test(input: &str, expected: Option<&str>, message: &str) { + let url = ListingTableUrl::parse(input).unwrap(); + assert_eq!(url.file_extension(), expected, "{message}"); + } + + test("https://a.b.c/path/", None, "path ends with / - not a file"); + test( + "https://a.b.c/path/?a=b", + None, + "path ends with / - with query args - not a file", + ); + test( + "https://a.b.c/path?a=b/", + None, + "path not ends with / - query ends with / but no file extension", + ); + test( + "https://a.b.c/path/#a=b", + None, + "path ends with / - with fragment - not a file", + ); + test( + "https://a.b.c/path#a=b/", + None, + "path not ends with / - fragment ends with / but no file extension", + ); + test( + "file///some/path/", + None, + "file path ends with / - not a file", + ); + test( + "file///some/path/file", + None, + "file path does not end with - no extension", + ); + test( + "file///some/path/file.", + None, + "file path ends with . - no value after .", + ); + test( + "file///some/path/file.ext", + Some("ext"), + "file path ends with .ext - extension is ext", + ); + } } diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index fed63ec12b496..636d1623c5e91 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -91,7 +91,7 @@ impl TableProviderFactory for ListingTableFactory { .field_with_name(col) .map_err(|e| arrow_datafusion_err!(e)) }) - .collect::>>()? + .collect::>>()? 
.into_iter() .map(|f| (f.name().to_owned(), f.data_type().to_owned())) .collect(); @@ -127,7 +127,7 @@ impl TableProviderFactory for ListingTableFactory { // See: https://github.com/apache/datafusion/issues/7317 None => { let schema = options.infer_schema(session_state, &table_path).await?; - let df_schema = schema.clone().to_dfschema()?; + let df_schema = Arc::clone(&schema).to_dfschema()?; let column_refs: HashSet<_> = cmd .order_exprs .iter() @@ -197,6 +197,7 @@ mod tests { schema: Arc::new(DFSchema::empty()), table_partition_cols: vec![], if_not_exists: false, + temporary: false, definition: None, order_exprs: vec![], unbounded: false, @@ -236,6 +237,7 @@ mod tests { schema: Arc::new(DFSchema::empty()), table_partition_cols: vec![], if_not_exists: false, + temporary: false, definition: None, order_exprs: vec![], unbounded: false, @@ -252,7 +254,7 @@ mod tests { let format = listing_table.options().format.clone(); let csv_format = format.as_any().downcast_ref::().unwrap(); let csv_options = csv_format.options().clone(); - assert_eq!(csv_options.schema_infer_max_rec, 1000); + assert_eq!(csv_options.schema_infer_max_rec, Some(1000)); let listing_options = listing_table.options(); assert_eq!(".tbl", listing_options.file_extension); } diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 24a4938e7b2bf..31239ed332aea 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -37,14 +37,13 @@ use crate::physical_planner::create_physical_sort_exprs; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion_catalog::Session; use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt}; use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; -use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_expr::SortExpr; use async_trait::async_trait; -use datafusion_catalog::Session; -use datafusion_expr::SortExpr; use futures::StreamExt; use log::debug; use parking_lot::Mutex; @@ -132,6 +131,7 @@ impl MemTable { state: &SessionState, ) -> Result { let schema = t.schema(); + let constraints = t.constraints(); let exec = t.scan(state, None, &[], None).await?; let partition_count = exec.output_partitioning().partition_count(); @@ -139,7 +139,7 @@ impl MemTable { for part_idx in 0..partition_count { let task = state.task_ctx(); - let exec = exec.clone(); + let exec = Arc::clone(&exec); join_set.spawn(async move { let stream = exec.execute(part_idx, task)?; common::collect(stream).await @@ -162,7 +162,10 @@ impl MemTable { } } - let exec = MemoryExec::try_new(&data, schema.clone(), None)?; + let mut exec = MemoryExec::try_new(&data, Arc::clone(&schema), None)?; + if let Some(cons) = constraints { + exec = exec.with_constraints(cons.clone()); + } if let Some(num_partitions) = output_partitions { let exec = RepartitionExec::try_new( @@ -183,9 +186,9 @@ impl MemTable { output_partitions.push(batches); } - return MemTable::try_new(schema.clone(), output_partitions); + return MemTable::try_new(Arc::clone(&schema), output_partitions); } - MemTable::try_new(schema.clone(), data) + MemTable::try_new(Arc::clone(&schema), data) } } @@ -196,7 +199,7 @@ impl TableProvider for MemTable { } fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } fn constraints(&self) -> Option<&Constraints> { @@ -241,7 +244,7 @@ impl TableProvider for MemTable { ) }) .collect::>>()?; - exec = exec.with_sort_information(file_sort_order); + 
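Related to the `MemSink` change a little further below in this memory.rs hunk: a `MemTable` now only accepts inserts when it was created with at least one (possibly empty) partition. A hedged end-to-end sketch, assuming the `datafusion` and `tokio` crates:

```rust
use std::sync::Arc;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // One initially empty partition, so the table can receive inserts;
    // zero partitions would fail when the INSERT is planned.
    let table = MemTable::try_new(Arc::clone(&schema), vec![vec![]])?;
    let ctx = SessionContext::new();
    ctx.register_table("t", Arc::new(table))?;

    // The INSERT only runs once the resulting DataFrame is collected.
    ctx.sql("INSERT INTO t VALUES (1), (2), (3)").await?.collect().await?;
    let batches = ctx.sql("SELECT count(*) AS n FROM t").await?.collect().await?;
    println!("{batches:?}");
    Ok(())
}
```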
exec = exec.try_with_sort_information(file_sort_order)?; } Ok(Arc::new(exec)) @@ -293,13 +296,8 @@ impl TableProvider for MemTable { if insert_op != InsertOp::Append { return not_impl_err!("{insert_op} not implemented for MemoryTable yet"); } - let sink = Arc::new(MemSink::new(self.batches.clone())); - Ok(Arc::new(DataSinkExec::new( - input, - sink, - self.schema.clone(), - None, - ))) + let sink = MemSink::try_new(self.batches.clone(), Arc::clone(&self.schema))?; + Ok(Arc::new(DataSinkExec::new(input, Arc::new(sink), None))) } fn get_column_default(&self, column: &str) -> Option<&Expr> { @@ -311,6 +309,7 @@ impl TableProvider for MemTable { struct MemSink { /// Target locations for writing data batches: Vec, + schema: SchemaRef, } impl Debug for MemSink { @@ -333,8 +332,14 @@ impl DisplayAs for MemSink { } impl MemSink { - fn new(batches: Vec) -> Self { - Self { batches } + /// Creates a new [`MemSink`]. + /// + /// The caller is responsible for ensuring that there is at least one partition to insert into. + fn try_new(batches: Vec, schema: SchemaRef) -> Result { + if batches.is_empty() { + return plan_err!("Cannot insert into MemTable with zero partitions"); + } + Ok(Self { batches, schema }) } } @@ -344,8 +349,8 @@ impl DataSink for MemSink { self } - fn metrics(&self) -> Option { - None + fn schema(&self) -> &SchemaRef { + &self.schema } async fn write_all( @@ -779,4 +784,27 @@ mod tests { assert_eq!(resulting_data_in_table[0].len(), 2); Ok(()) } + + // Test inserting a batch into a MemTable without any partitions + #[tokio::test] + async fn test_insert_into_zero_partition() -> Result<()> { + // Create a new schema with one field called "a" of type Int32 + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a new batch of data to insert into the table + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + // Run the experiment and expect an error + let experiment_result = experiment(schema, vec![], vec![vec![batch.clone()]]) + .await + .unwrap_err(); + // Ensure that there is a descriptive error message + assert_eq!( + "Error during planning: Cannot insert into MemTable with zero partitions", + experiment_result.strip_backtrace() + ); + Ok(()) + } } diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 0ed53418fe32c..7d3fe9ddd7515 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -25,7 +25,6 @@ pub mod default_table_source; pub mod dynamic_file; pub mod empty; pub mod file_format; -pub mod function; pub mod listing; pub mod listing_table_factory; pub mod memory; @@ -62,7 +61,7 @@ fn create_ordering( for exprs in sort_order { // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = vec![]; + let mut sort_exprs = LexOrdering::default(); for sort in exprs { match &sort.expr { Expr::Column(col) => match expressions::col(&col.name, schema) { diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index 39625a55ca15e..54344d55bbd11 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -35,10 +35,11 @@ use arrow::buffer::Buffer; use arrow_ipc::reader::FileDecoder; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::Statistics; +use datafusion_common::{Constraints, 
Statistics}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; -use datafusion_physical_plan::{ExecutionMode, PlanProperties}; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion_physical_plan::PlanProperties; use futures::StreamExt; use itertools::Itertools; @@ -46,7 +47,6 @@ use object_store::{GetOptions, GetRange, GetResultPayload, ObjectStore}; /// Execution plan for scanning Arrow data source #[derive(Debug, Clone)] -#[allow(dead_code)] pub struct ArrowExec { base_config: FileScanConfig, projected_statistics: Statistics, @@ -60,11 +60,16 @@ pub struct ArrowExec { impl ArrowExec { /// Create a new Arrow reader execution plan provided base configurations pub fn new(base_config: FileScanConfig) -> Self { - let (projected_schema, projected_statistics, projected_output_ordering) = - base_config.project(); + let ( + projected_schema, + projected_constraints, + projected_statistics, + projected_output_ordering, + ) = base_config.project(); let cache = Self::compute_properties( - projected_schema.clone(), + Arc::clone(&projected_schema), &projected_output_ordering, + projected_constraints, &base_config, ); Self { @@ -88,17 +93,20 @@ impl ArrowExec { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties( schema: SchemaRef, - projected_output_ordering: &[LexOrdering], + output_ordering: &[LexOrdering], + constraints: Constraints, file_scan_config: &FileScanConfig, ) -> PlanProperties { // Equivalence Properties let eq_properties = - EquivalenceProperties::new_with_orderings(schema, projected_output_ordering); + EquivalenceProperties::new_with_orderings(schema, output_ordering) + .with_constraints(constraints); PlanProperties::new( eq_properties, Self::output_partitioning_helper(file_scan_config), // Output Partitioning - ExecutionMode::Bounded, // Execution Mode + EmissionType::Incremental, + Boundedness::Bounded, ) } @@ -207,7 +215,7 @@ impl ExecutionPlan for ArrowExec { Some(Arc::new(Self { base_config: new_config, projected_statistics: self.projected_statistics.clone(), - projected_schema: self.projected_schema.clone(), + projected_schema: Arc::clone(&self.projected_schema), projected_output_ordering: self.projected_output_ordering.clone(), metrics: self.metrics.clone(), cache: self.cache.clone(), @@ -222,7 +230,7 @@ pub struct ArrowOpener { impl FileOpener for ArrowOpener { fn open(&self, file_meta: FileMeta) -> Result { - let object_store = self.object_store.clone(); + let object_store = Arc::clone(&self.object_store); let projection = self.projection.clone(); Ok(Box::pin(async move { let range = file_meta.range.clone(); diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index ce72c4087424e..87d8964bed6af 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -24,17 +24,18 @@ use super::FileScanConfig; use crate::error::Result; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, - PlanProperties, SendableRecordBatchStream, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, Statistics, }; use arrow::datatypes::SchemaRef; +use 
datafusion_common::Constraints; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; /// Execution plan for scanning Avro data source #[derive(Debug, Clone)] -#[allow(dead_code)] pub struct AvroExec { base_config: FileScanConfig, projected_statistics: Statistics, @@ -48,11 +49,16 @@ pub struct AvroExec { impl AvroExec { /// Create a new Avro reader execution plan provided base configurations pub fn new(base_config: FileScanConfig) -> Self { - let (projected_schema, projected_statistics, projected_output_ordering) = - base_config.project(); + let ( + projected_schema, + projected_constraints, + projected_statistics, + projected_output_ordering, + ) = base_config.project(); let cache = Self::compute_properties( - projected_schema.clone(), + Arc::clone(&projected_schema), &projected_output_ordering, + projected_constraints, &base_config, ); Self { @@ -73,16 +79,19 @@ impl AvroExec { fn compute_properties( schema: SchemaRef, orderings: &[LexOrdering], + constraints: Constraints, file_scan_config: &FileScanConfig, ) -> PlanProperties { // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings); + let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings) + .with_constraints(constraints); let n_partitions = file_scan_config.file_groups.len(); PlanProperties::new( eq_properties, Partitioning::UnknownPartitioning(n_partitions), // Output Partitioning - ExecutionMode::Bounded, // Execution Mode + EmissionType::Incremental, + Boundedness::Bounded, ) } } @@ -175,7 +184,7 @@ impl ExecutionPlan for AvroExec { Some(Arc::new(Self { base_config: new_config, projected_statistics: self.projected_statistics.clone(), - projected_schema: self.projected_schema.clone(), + projected_schema: Arc::clone(&self.projected_schema), projected_output_ordering: self.projected_output_ordering.clone(), metrics: self.metrics.clone(), cache: self.cache.clone(), @@ -205,7 +214,7 @@ mod private { fn open(&self, reader: R) -> Result> { AvroReader::try_new( reader, - self.schema.clone(), + Arc::clone(&self.schema), self.batch_size, self.projection.clone(), ) @@ -218,7 +227,7 @@ mod private { impl FileOpener for AvroOpener { fn open(&self, file_meta: FileMeta) -> Result { - let config = self.config.clone(); + let config = Arc::clone(&self.config); Ok(Box::pin(async move { let r = config.object_store.get(file_meta.location()).await?; match r.payload { @@ -285,7 +294,7 @@ mod tests { let meta = local_unpartitioned_file(filename); let file_schema = AvroFormat {} - .infer_schema(&state, &store, &[meta.clone()]) + .infer_schema(&state, &store, std::slice::from_ref(&meta)) .await?; let avro_exec = AvroExec::new( @@ -350,7 +359,7 @@ mod tests { let object_store_url = ObjectStoreUrl::local_filesystem(); let meta = local_unpartitioned_file(filename); let actual_schema = AvroFormat {} - .infer_schema(&state, &object_store, &[meta.clone()]) + .infer_schema(&state, &object_store, std::slice::from_ref(&meta)) .await?; let mut builder = SchemaBuilder::from(actual_schema.fields()); @@ -423,7 +432,7 @@ mod tests { let object_store_url = ObjectStoreUrl::local_filesystem(); let meta = local_unpartitioned_file(filename); let file_schema = AvroFormat {} - .infer_schema(&state, &object_store, &[meta.clone()]) + .infer_schema(&state, &object_store, std::slice::from_ref(&meta)) .await?; let mut partitioned_file = PartitionedFile::from(meta); diff --git 
a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 6cd1864deb1d4..dd5736806eebe 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -24,6 +24,7 @@ use std::task::Poll; use super::{calculate_range, FileGroupPartitioner, FileScanConfig, RangeCalculation}; use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::{deserialize_stream, DecoderDeserializer}; use crate::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile}; use crate::datasource::physical_plan::file_stream::{ FileOpenFuture, FileOpener, FileStream, @@ -32,18 +33,19 @@ use crate::datasource::physical_plan::FileMeta; use crate::error::{DataFusionError, Result}; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, ExecutionPlanProperties, - Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, + PlanProperties, SendableRecordBatchStream, Statistics, }; use arrow::csv; use arrow::datatypes::SchemaRef; use datafusion_common::config::ConfigOptions; +use datafusion_common::Constraints; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; -use bytes::{Buf, Bytes}; -use futures::{ready, StreamExt, TryStreamExt}; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; +use futures::{StreamExt, TryStreamExt}; use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; @@ -208,11 +210,16 @@ impl CsvExecBuilder { newlines_in_values, } = self; - let (projected_schema, projected_statistics, projected_output_ordering) = - base_config.project(); + let ( + projected_schema, + projected_constraints, + projected_statistics, + projected_output_ordering, + ) = base_config.project(); let cache = CsvExec::compute_properties( projected_schema, &projected_output_ordering, + projected_constraints, &base_config, ); @@ -319,15 +326,18 @@ impl CsvExec { fn compute_properties( schema: SchemaRef, orderings: &[LexOrdering], + constraints: Constraints, file_scan_config: &FileScanConfig, ) -> PlanProperties { // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings); + let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings) + .with_constraints(constraints); PlanProperties::new( eq_properties, Self::output_partitioning_helper(file_scan_config), // Output Partitioning - ExecutionMode::Bounded, // Execution Mode + EmissionType::Incremental, + Boundedness::Bounded, ) } @@ -521,7 +531,7 @@ impl CsvConfig { } fn builder(&self) -> csv::ReaderBuilder { - let mut builder = csv::ReaderBuilder::new(self.file_schema.clone()) + let mut builder = csv::ReaderBuilder::new(Arc::clone(&self.file_schema)) .with_delimiter(self.delimiter) .with_batch_size(self.batch_size) .with_header(self.has_header) @@ -611,12 +621,14 @@ impl FileOpener for CsvOpener { ); } - let store = self.config.object_store.clone(); + let store = Arc::clone(&self.config.object_store); + let terminator = self.config.terminator; Ok(Box::pin(async move { // Current partition contains bytes [start_byte, end_byte) (might contain incomplete lines at boundaries) - let calculated_range = 
calculate_range(&file_meta, &store).await?; + let calculated_range = + calculate_range(&file_meta, &store, terminator).await?; let range = match calculated_range { RangeCalculation::Range(None) => None, @@ -651,36 +663,14 @@ impl FileOpener for CsvOpener { Ok(futures::stream::iter(config.open(decoder)?).boxed()) } GetResultPayload::Stream(s) => { - let mut decoder = config.builder().build_decoder(); + let decoder = config.builder().build_decoder(); let s = s.map_err(DataFusionError::from); - let mut input = - file_compression_type.convert_stream(s.boxed())?.fuse(); - let mut buffered = Bytes::new(); - - let s = futures::stream::poll_fn(move |cx| { - loop { - if buffered.is_empty() { - match ready!(input.poll_next_unpin(cx)) { - Some(Ok(b)) => buffered = b, - Some(Err(e)) => { - return Poll::Ready(Some(Err(e.into()))) - } - None => {} - }; - } - let decoded = match decoder.decode(buffered.as_ref()) { - // Note: the decoder needs to be called with an empty - // array to delimt the final record - Ok(0) => break, - Ok(decoded) => decoded, - Err(e) => return Poll::Ready(Some(Err(e))), - }; - buffered.advance(decoded); - } - - Poll::Ready(decoder.flush().transpose()) - }); - Ok(s.boxed()) + let input = file_compression_type.convert_stream(s.boxed())?.fuse(); + + Ok(deserialize_stream( + input, + DecoderDeserializer::from(decoder), + )) } } })) @@ -698,12 +688,12 @@ pub async fn plan_to_csv( let store = task_ctx.runtime_env().object_store(&object_store_url)?; let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { - let storeref = store.clone(); - let plan: Arc = plan.clone(); + let storeref = Arc::clone(&store); + let plan: Arc = Arc::clone(&plan); let filename = format!("{}/part-{i}.csv", parsed.prefix()); let file = object_store::path::Path::parse(filename)?; - let mut stream = plan.execute(i, task_ctx.clone())?; + let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { let mut buf_writer = BufWriter::new(storeref, file.clone()); let mut buffer = Vec::with_capacity(1024); @@ -753,6 +743,7 @@ mod tests { use crate::{scalar::ScalarValue, test_util::aggr_test_schema}; use arrow::datatypes::*; + use bytes::Bytes; use datafusion_common::test_util::arrow_test_data; use datafusion_common::config::CsvOptions; @@ -1216,7 +1207,7 @@ mod tests { let session_ctx = SessionContext::new(); let store = object_store::memory::InMemory::new(); - let data = bytes::Bytes::from("a,b\n1,2\n3,4"); + let data = Bytes::from("a,b\n1,2\n3,4"); let path = object_store::path::Path::from("a.csv"); store.put(&path, data.into()).await.unwrap(); @@ -1247,7 +1238,7 @@ mod tests { let session_ctx = SessionContext::new(); let store = object_store::memory::InMemory::new(); - let data = bytes::Bytes::from("a,b\r1,2\r3,4"); + let data = Bytes::from("a,b\r1,2\r3,4"); let path = object_store::path::Path::from("a.csv"); store.put(&path, data.into()).await.unwrap(); diff --git a/datafusion/core/src/datasource/physical_plan/file_groups.rs b/datafusion/core/src/datasource/physical_plan/file_groups.rs index 28f975ae193d5..f681dfe219b51 100644 --- a/datafusion/core/src/datasource/physical_plan/file_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/file_groups.rs @@ -217,8 +217,7 @@ impl FileGroupPartitioner { return None; } - let target_partition_size = - (total_size as usize + (target_partitions) - 1) / (target_partitions); + let target_partition_size = (total_size as usize).div_ceil(target_partitions); let current_partition_index: usize = 0; let 
current_partition_size: usize = 0; @@ -782,7 +781,7 @@ mod test { assert_partitioned_files(expected, actual); } - /// Asserts that the two groups of `ParititonedFile` are the same + /// Asserts that the two groups of [`PartitionedFile`] are the same /// (PartitionedFile doesn't implement PartialEq) fn assert_partitioned_files( expected: Option>>, diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 2c438e8b0e78b..5a38886bb16f9 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -19,7 +19,8 @@ //! file sources. use std::{ - borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc, vec, + borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, mem::size_of, + sync::Arc, vec, }; use super::{get_projected_output_ordering, statistics::MinMaxStatistics}; @@ -32,8 +33,10 @@ use arrow::datatypes::{ArrowNativeType, UInt16Type}; use arrow_array::{ArrayRef, DictionaryArray, RecordBatch, RecordBatchOptions}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_common::stats::Precision; -use datafusion_common::{exec_err, ColumnStatistics, DataFusionError, Statistics}; -use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_common::{ + exec_err, ColumnStatistics, Constraints, DataFusionError, Statistics, +}; +use datafusion_physical_expr::LexOrdering; use log::warn; @@ -113,6 +116,8 @@ pub struct FileScanConfig { /// concurrently, however files *within* a partition will be read /// sequentially, one after the next. pub file_groups: Vec>, + /// Table constraints + pub constraints: Constraints, /// Estimated overall statistics of the files, taking `filters` into account. /// Defaults to [`Statistics::new_unknown`]. 
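Taken with the `constraints` field added just above and the `with_constraints` setter below, `project()` now hands callers the constraints projected onto the selected columns. A rough, hypothetical usage sketch (the schema, file path, file size, and primary-key choice are invented for illustration, and `Constraints::new_unverified` / `Constraint::PrimaryKey` are assumed to be the existing `datafusion_common` constructors):

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::physical_plan::FileScanConfig;
use datafusion_common::{Constraint, Constraints, Result};
use datafusion_execution::object_store::ObjectStoreUrl;

fn main() -> Result<()> {
    let file_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, true),
    ]));

    // Declare column 0 (`id`) as a primary key.
    let constraints = Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);

    let config = FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema)
        .with_file(PartitionedFile::new("data/part-0.parquet", 1234))
        .with_projection(Some(vec![0]))
        .with_constraints(constraints);

    // `project()` now returns the constraints (projected onto the selected
    // columns) alongside the schema, statistics, and output orderings.
    let (schema, constraints, statistics, orderings) = config.project();
    println!(
        "schema={schema:?} constraints={constraints} rows={:?} orderings={}",
        statistics.num_rows,
        orderings.len()
    );
    Ok(())
}
```

The projected constraints are then what each exec feeds into `EquivalenceProperties::with_constraints` when computing its `PlanProperties`, as the `compute_properties` changes elsewhere in this diff show.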
pub statistics: Statistics, @@ -145,6 +150,7 @@ impl FileScanConfig { object_store_url, file_schema, file_groups: vec![], + constraints: Constraints::empty(), statistics, projection: None, limit: None, @@ -153,6 +159,12 @@ impl FileScanConfig { } } + /// Set the table constraints of the files + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + /// Set the statistics of the files pub fn with_statistics(mut self, statistics: Statistics) -> Self { self.statistics = statistics; @@ -209,30 +221,31 @@ impl FileScanConfig { self } - /// Project the schema and the statistics on the given column indices - pub fn project(&self) -> (SchemaRef, Statistics, Vec) { + /// Project the schema, constraints, and the statistics on the given column indices + pub fn project(&self) -> (SchemaRef, Constraints, Statistics, Vec) { if self.projection.is_none() && self.table_partition_cols.is_empty() { return ( Arc::clone(&self.file_schema), + self.constraints.clone(), self.statistics.clone(), self.output_ordering.clone(), ); } - let proj_iter: Box> = match &self.projection { - Some(proj) => Box::new(proj.iter().copied()), - None => Box::new( - 0..(self.file_schema.fields().len() + self.table_partition_cols.len()), - ), + let proj_indices = if let Some(proj) = &self.projection { + proj + } else { + let len = self.file_schema.fields().len() + self.table_partition_cols.len(); + &(0..len).collect::>() }; let mut table_fields = vec![]; let mut table_cols_stats = vec![]; - for idx in proj_iter { - if idx < self.file_schema.fields().len() { - let field = self.file_schema.field(idx); + for idx in proj_indices { + if *idx < self.file_schema.fields().len() { + let field = self.file_schema.field(*idx); table_fields.push(field.clone()); - table_cols_stats.push(self.statistics.column_statistics[idx].clone()) + table_cols_stats.push(self.statistics.column_statistics[*idx].clone()) } else { let partition_idx = idx - self.file_schema.fields().len(); table_fields.push(self.table_partition_cols[partition_idx].to_owned()); @@ -248,14 +261,25 @@ impl FileScanConfig { column_statistics: table_cols_stats, }; - let projected_schema = Arc::new( - Schema::new(table_fields).with_metadata(self.file_schema.metadata().clone()), - ); + let projected_schema = Arc::new(Schema::new_with_metadata( + table_fields, + self.file_schema.metadata().clone(), + )); + + let projected_constraints = self + .constraints + .project(proj_indices) + .unwrap_or_else(Constraints::empty); let projected_output_ordering = get_projected_output_ordering(self, &projected_schema); - (projected_schema, table_stats, projected_output_ordering) + ( + projected_schema, + projected_constraints, + table_stats, + projected_output_ordering, + ) } #[cfg_attr(not(feature = "avro"), allow(unused))] // Only used by avro @@ -281,7 +305,12 @@ impl FileScanConfig { fields.map_or_else( || Arc::clone(&self.file_schema), - |f| Arc::new(Schema::new(f).with_metadata(self.file_schema.metadata.clone())), + |f| { + Arc::new(Schema::new_with_metadata( + f, + self.file_schema.metadata.clone(), + )) + }, ) } @@ -300,7 +329,7 @@ impl FileScanConfig { pub fn split_groups_by_statistics( table_schema: &SchemaRef, file_groups: &[Vec], - sort_order: &[PhysicalSortExpr], + sort_order: &LexOrdering, ) -> Result>> { let flattened_files = file_groups.iter().flatten().collect::>(); // First Fit: @@ -491,7 +520,7 @@ impl ZeroBufferGenerator where T: ArrowNativeType, { - const SIZE: usize = std::mem::size_of::(); + const SIZE: usize = 
size_of::(); fn get_buffer(&mut self, n_vals: usize) -> Buffer { match &mut self.cache { @@ -628,7 +657,7 @@ mod tests { )]), ); - let (proj_schema, proj_statistics, _) = conf.project(); + let (proj_schema, _, proj_statistics, _) = conf.project(); assert_eq!(proj_schema.fields().len(), file_schema.fields().len() + 1); assert_eq!( proj_schema.field(file_schema.fields().len()).name(), @@ -668,7 +697,7 @@ mod tests { ); // verify the proj_schema includes the last column and exactly the same the field it is defined - let (proj_schema, _proj_statistics, _) = conf.project(); + let (proj_schema, _, _, _) = conf.project(); assert_eq!(proj_schema.fields().len(), file_schema.fields().len() + 1); assert_eq!( *proj_schema.field(file_schema.fields().len()), @@ -701,7 +730,7 @@ mod tests { )]), ); - let (proj_schema, proj_statistics, _) = conf.project(); + let (proj_schema, _, proj_statistics, _) = conf.project(); assert_eq!( columns(&proj_schema), vec!["date".to_owned(), "c1".to_owned()] @@ -1105,17 +1134,18 @@ mod tests { )))) .collect::>(), )); - let sort_order = case - .sort - .into_iter() - .map(|expr| { - crate::physical_planner::create_physical_sort_expr( - &expr, - &DFSchema::try_from(table_schema.as_ref().clone())?, - &ExecutionProps::default(), - ) - }) - .collect::>>()?; + let sort_order = LexOrdering::from( + case.sort + .into_iter() + .map(|expr| { + crate::physical_planner::create_physical_sort_expr( + &expr, + &DFSchema::try_from(table_schema.as_ref().clone())?, + &ExecutionProps::default(), + ) + }) + .collect::>>()?, + ); let partitioned_files = case.files.into_iter().map(From::from).collect::>(); @@ -1156,7 +1186,7 @@ mod tests { }) .collect::>() }) - .map_err(|e| e.to_string().leak() as &'static str); + .map_err(|e| e.strip_backtrace().leak() as &'static str); assert_eq!(results_by_name, case.expected_result, "{}", case.name); } @@ -1200,6 +1230,7 @@ mod tests { .collect::>(), }), extensions: None, + metadata_size_hint: None, } } } diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index 6f354b31ae878..18cda4524ab25 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -24,6 +24,7 @@ use std::collections::VecDeque; use std::mem; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; use crate::datasource::listing::PartitionedFile; @@ -252,7 +253,7 @@ impl FileStream { ) -> Result { let (projected_schema, ..) 
= config.project(); let pc_projector = PartitionColumnProjector::new( - projected_schema.clone(), + Arc::clone(&projected_schema), &config .table_partition_cols .iter() @@ -295,6 +296,7 @@ impl FileStream { object_meta: part_file.object_meta, range: part_file.range, extensions: part_file.extensions, + metadata_size_hint: part_file.metadata_size_hint, }; Some( @@ -509,7 +511,7 @@ impl Stream for FileStream { impl RecordBatchStream for FileStream { fn schema(&self) -> SchemaRef { - self.projected_schema.clone() + Arc::clone(&self.projected_schema) } } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index cf8f129a50369..7ac062e549c42 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -24,6 +24,7 @@ use std::task::Poll; use super::{calculate_range, FileGroupPartitioner, FileScanConfig, RangeCalculation}; use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::{deserialize_stream, DecoderDeserializer}; use crate::datasource::listing::{ListingTableUrl, PartitionedFile}; use crate::datasource::physical_plan::file_stream::{ FileOpenFuture, FileOpener, FileStream, @@ -32,17 +33,18 @@ use crate::datasource::physical_plan::FileMeta; use crate::error::{DataFusionError, Result}; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, ExecutionPlanProperties, - Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, + PlanProperties, SendableRecordBatchStream, Statistics, }; use arrow::json::ReaderBuilder; use arrow::{datatypes::SchemaRef, json}; +use datafusion_common::Constraints; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use bytes::{Buf, Bytes}; -use futures::{ready, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; @@ -65,11 +67,16 @@ impl NdJsonExec { base_config: FileScanConfig, file_compression_type: FileCompressionType, ) -> Self { - let (projected_schema, projected_statistics, projected_output_ordering) = - base_config.project(); + let ( + projected_schema, + projected_constraints, + projected_statistics, + projected_output_ordering, + ) = base_config.project(); let cache = Self::compute_properties( projected_schema, &projected_output_ordering, + projected_constraints, &base_config, ); Self { @@ -86,6 +93,11 @@ impl NdJsonExec { &self.base_config } + /// Ref to file compression type + pub fn file_compression_type(&self) -> &FileCompressionType { + &self.file_compression_type + } + fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) } @@ -94,15 +106,18 @@ impl NdJsonExec { fn compute_properties( schema: SchemaRef, orderings: &[LexOrdering], + constraints: Constraints, file_scan_config: &FileScanConfig, ) -> PlanProperties { // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings); + let eq_properties = 
EquivalenceProperties::new_with_orderings(schema, orderings) + .with_constraints(constraints); PlanProperties::new( eq_properties, Self::output_partitioning_helper(file_scan_config), // Output Partitioning - ExecutionMode::Bounded, // Execution Mode + EmissionType::Incremental, + Boundedness::Bounded, ) } @@ -262,13 +277,13 @@ impl FileOpener for JsonOpener { /// /// See [`CsvOpener`](super::CsvOpener) for an example. fn open(&self, file_meta: FileMeta) -> Result { - let store = self.object_store.clone(); - let schema = self.projected_schema.clone(); + let store = Arc::clone(&self.object_store); + let schema = Arc::clone(&self.projected_schema); let batch_size = self.batch_size; let file_compression_type = self.file_compression_type.to_owned(); Ok(Box::pin(async move { - let calculated_range = calculate_range(&file_meta, &store).await?; + let calculated_range = calculate_range(&file_meta, &store, None).await?; let range = match calculated_range { RangeCalculation::Range(None) => None, @@ -307,37 +322,15 @@ impl FileOpener for JsonOpener { GetResultPayload::Stream(s) => { let s = s.map_err(DataFusionError::from); - let mut decoder = ReaderBuilder::new(schema) + let decoder = ReaderBuilder::new(schema) .with_batch_size(batch_size) .build_decoder()?; - let mut input = - file_compression_type.convert_stream(s.boxed())?.fuse(); - let mut buffer = Bytes::new(); - - let s = futures::stream::poll_fn(move |cx| { - loop { - if buffer.is_empty() { - match ready!(input.poll_next_unpin(cx)) { - Some(Ok(b)) => buffer = b, - Some(Err(e)) => { - return Poll::Ready(Some(Err(e.into()))) - } - None => {} - }; - } - - let decoded = match decoder.decode(buffer.as_ref()) { - Ok(0) => break, - Ok(decoded) => decoded, - Err(e) => return Poll::Ready(Some(Err(e))), - }; - - buffer.advance(decoded); - } + let input = file_compression_type.convert_stream(s.boxed())?.fuse(); - Poll::Ready(decoder.flush().transpose()) - }); - Ok(s.boxed()) + Ok(deserialize_stream( + input, + DecoderDeserializer::from(decoder), + )) } } })) @@ -355,12 +348,12 @@ pub async fn plan_to_json( let store = task_ctx.runtime_env().object_store(&object_store_url)?; let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { - let storeref = store.clone(); - let plan: Arc = plan.clone(); + let storeref = Arc::clone(&store); + let plan: Arc = Arc::clone(&plan); let filename = format!("{}/part-{i}.json", parsed.prefix()); let file = object_store::path::Path::parse(filename)?; - let mut stream = plan.execute(i, task_ctx.clone())?; + let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { let mut buf_writer = BufWriter::new(storeref, file.clone()); @@ -447,7 +440,7 @@ mod tests { .object_meta; let schema = JsonFormat::default() .with_file_compression_type(file_compression_type.to_owned()) - .infer_schema(state, &store, &[meta.clone()]) + .infer_schema(state, &store, std::slice::from_ref(&meta)) .await .unwrap(); @@ -885,7 +878,7 @@ mod tests { )] #[cfg(feature = "compression")] #[tokio::test] - async fn test_json_with_repartitioing( + async fn test_json_with_repartitioning( file_compression_type: FileCompressionType, ) -> Result<()> { let config = SessionConfig::new() diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 6e8752ccfbf4b..5bb7da8376a2e 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -51,7 +51,8 @@ use std::{ vec, }; -use 
super::listing::ListingTableUrl; +use super::{file_format::write::demux::start_demuxer_task, listing::ListingTableUrl}; +use crate::datasource::file_format::write::demux::DemuxedStreamReceiver; use crate::error::Result; use crate::physical_plan::{DisplayAs, DisplayFormatType}; use crate::{ @@ -63,13 +64,73 @@ use crate::{ }; use arrow::datatypes::{DataType, SchemaRef}; +use datafusion_common_runtime::SpawnedTask; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::insert::DataSink; +use async_trait::async_trait; use futures::StreamExt; use log::debug; use object_store::{path::Path, GetOptions, GetRange, ObjectMeta, ObjectStore}; +/// General behaviors for files that do `DataSink` operations +#[async_trait] +pub trait FileSink: DataSink { + /// Retrieves the file sink configuration. + fn config(&self) -> &FileSinkConfig; + + /// Spawns writer tasks and joins them to perform file writing operations. + /// Is a critical part of `FileSink` trait, since it's the very last step for `write_all`. + /// + /// This function handles the process of writing data to files by: + /// 1. Spawning tasks for writing data to individual files. + /// 2. Coordinating the tasks using a demuxer to distribute data among files. + /// 3. Collecting results using `tokio::join`, ensuring that all tasks complete successfully. + /// + /// # Parameters + /// - `context`: The execution context (`TaskContext`) that provides resources + /// like memory management and runtime environment. + /// - `demux_task`: A spawned task that handles demuxing, responsible for splitting + /// an input [`SendableRecordBatchStream`] into dynamically determined partitions. + /// See `start_demuxer_task()` + /// - `file_stream_rx`: A receiver that yields streams of record batches and their + /// corresponding file paths for writing. See `start_demuxer_task()` + /// - `object_store`: A handle to the object store where the files are written. + /// + /// # Returns + /// - `Result`: Returns the total number of rows written across all files. + async fn spawn_writer_tasks_and_join( + &self, + context: &Arc, + demux_task: SpawnedTask>, + file_stream_rx: DemuxedStreamReceiver, + object_store: Arc, + ) -> Result; + + /// File sink implementation of the [`DataSink::write_all`] method. + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + let config = self.config(); + let object_store = context + .runtime_env() + .object_store(&config.object_store_url)?; + let (demux_task, file_stream_rx) = start_demuxer_task(config, data, context); + self.spawn_writer_tasks_and_join( + context, + demux_task, + file_stream_rx, + object_store, + ) + .await + } +} + /// The base configurations to provide when creating a physical plan for /// writing to any given file format. pub struct FileSinkConfig { @@ -89,6 +150,8 @@ pub struct FileSinkConfig { pub insert_op: InsertOp, /// Controls whether partition columns are kept for the file pub keep_partition_by_columns: bool, + /// File extension without a dot(.) 
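The doc comments above describe the split precisely: the provided `write_all` looks up the object store, starts the demux task, and hands everything to `spawn_writer_tasks_and_join`, so concrete sinks only implement the writer half. The standalone sketch below shows only that default-method pattern with made-up names (`SimpleFileSink`, `demux`, `write_file`); it is not the real `FileSink` API, which is async and operates on `SendableRecordBatchStream`s.

```rust
// Purely illustrative stand-in for the real trait: plain vectors instead of
// record-batch streams, synchronous calls instead of spawned tasks.
trait SimpleFileSink {
    /// Split the input into per-file chunks (stands in for the demux task).
    fn demux(&self, rows: Vec<u64>) -> Vec<Vec<u64>>;

    /// Write one file's worth of data and return the number of rows written
    /// (stands in for `spawn_writer_tasks_and_join`).
    fn write_file(&self, rows: &[u64]) -> u64;

    /// Default method, analogous to the provided `FileSink::write_all`:
    /// demux the input, write every file, and sum the row counts.
    fn write_all(&self, rows: Vec<u64>) -> u64 {
        self.demux(rows).iter().map(|file| self.write_file(file)).sum()
    }
}

/// A sink that always splits its input into two files.
struct TwoFileSink;

impl SimpleFileSink for TwoFileSink {
    fn demux(&self, rows: Vec<u64>) -> Vec<Vec<u64>> {
        let mid = rows.len() / 2;
        vec![rows[..mid].to_vec(), rows[mid..].to_vec()]
    }

    fn write_file(&self, rows: &[u64]) -> u64 {
        rows.len() as u64
    }
}

fn main() {
    // Only `demux` and `write_file` were implemented; `write_all` came for free.
    assert_eq!(TwoFileSink.write_all((0..10).collect()), 10);
}
```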
+ pub file_extension: String, } impl FileSinkConfig { @@ -110,7 +173,7 @@ impl Debug for FileScanConfig { impl DisplayAs for FileScanConfig { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { - let (schema, _, orderings) = self.project(); + let (schema, _, _, orderings) = self.project(); write!(f, "file_groups=")?; FileGroupsDisplay(&self.file_groups).fmt_as(t, f)?; @@ -125,6 +188,10 @@ impl DisplayAs for FileScanConfig { display_orderings(f, &orderings)?; + if !self.constraints.is_empty() { + write!(f, ", {}", self.constraints)?; + } + Ok(()) } } @@ -138,7 +205,7 @@ impl DisplayAs for FileScanConfig { #[derive(Debug)] struct FileGroupsDisplay<'a>(&'a [Vec]); -impl<'a> DisplayAs for FileGroupsDisplay<'a> { +impl DisplayAs for FileGroupsDisplay<'_> { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { let n_groups = self.0.len(); let groups = if n_groups == 1 { "group" } else { "groups" }; @@ -170,7 +237,7 @@ impl<'a> DisplayAs for FileGroupsDisplay<'a> { #[derive(Debug)] pub(crate) struct FileGroupDisplay<'a>(pub &'a [PartitionedFile]); -impl<'a> DisplayAs for FileGroupDisplay<'a> { +impl DisplayAs for FileGroupDisplay<'_> { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { write!(f, "[")?; match t { @@ -247,6 +314,8 @@ pub struct FileMeta { pub range: Option, /// An optional field for user defined per object metadata pub extensions: Option>, + /// Size hint for the metadata of this file + pub metadata_size_hint: Option, } impl FileMeta { @@ -262,6 +331,7 @@ impl From for FileMeta { object_meta, range: None, extensions: None, + metadata_size_hint: None, } } } @@ -328,11 +398,11 @@ impl From for FileMeta { fn get_projected_output_ordering( base_config: &FileScanConfig, projected_schema: &SchemaRef, -) -> Vec> { +) -> Vec { let mut all_orderings = vec![]; for output_ordering in &base_config.output_ordering { - let mut new_ordering = vec![]; - for PhysicalSortExpr { expr, options } in output_ordering { + let mut new_ordering = LexOrdering::default(); + for PhysicalSortExpr { expr, options } in output_ordering.iter() { if let Some(col) = expr.as_any().downcast_ref::() { let name = col.name(); if let Some((idx, _)) = projected_schema.column_with_name(name) { @@ -422,9 +492,11 @@ enum RangeCalculation { async fn calculate_range( file_meta: &FileMeta, store: &Arc, + terminator: Option, ) -> Result { let location = file_meta.location(); let file_size = file_meta.object_meta.size; + let newline = terminator.unwrap_or(b'\n'); match file_meta.range { None => Ok(RangeCalculation::Range(None)), @@ -432,13 +504,13 @@ async fn calculate_range( let (start, end) = (start as usize, end as usize); let start_delta = if start != 0 { - find_first_newline(store, location, start - 1, file_size).await? + find_first_newline(store, location, start - 1, file_size, newline).await? } else { 0 }; let end_delta = if end != file_size { - find_first_newline(store, location, end - 1, file_size).await? + find_first_newline(store, location, end - 1, file_size, newline).await? } else { 0 }; @@ -458,7 +530,7 @@ async fn calculate_range( /// within an object, such as a file, in an object store. /// /// This function scans the contents of the object starting from the specified `start` position -/// up to the `end` position, looking for the first occurrence of a newline (`'\n'`) character. +/// up to the `end` position, looking for the first occurrence of a newline character. /// It returns the position of the first newline relative to the start of the range. 
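Because `calculate_range` and `find_first_newline` now take the terminator as a parameter (falling back to `\n`), the boundary search itself reduces to a byte scan. A minimal synchronous sketch of that scan, with a hypothetical helper name and in-memory buffers standing in for the object-store stream:

```rust
/// Scan a byte buffer for the first occurrence of the configured line
/// terminator (defaulting to `\n`); return its offset, or the buffer length
/// when no terminator is found in the scanned range.
fn first_terminator_offset(data: &[u8], terminator: Option<u8>) -> usize {
    let newline = terminator.unwrap_or(b'\n');
    data.iter()
        .position(|&byte| byte == newline)
        .unwrap_or(data.len())
}

fn main() {
    // Default `\n` terminator.
    assert_eq!(first_terminator_offset(b"abc\ndef", None), 3);
    // CSV configured with a carriage-return terminator, as in the `\r` test data.
    assert_eq!(first_terminator_offset(b"a,b\r1,2\r3,4", Some(b'\r')), 3);
    // No terminator in the scanned range: fall back to the scanned length.
    assert_eq!(first_terminator_offset(b"abcdef", None), 6);
}
```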
/// /// Returns a `Result` wrapping a `usize` that represents the position of the first newline character found within the specified range. If no newline is found, it returns the length of the scanned data, effectively indicating the end of the range. @@ -470,6 +542,7 @@ async fn find_first_newline( location: &Path, start: usize, end: usize, + newline: u8, ) -> Result { let options = GetOptions { range: Some(GetRange::Bounded(start..end)), @@ -482,7 +555,7 @@ async fn find_first_newline( let mut index = 0; while let Some(chunk) = result_stream.next().await.transpose()? { - if let Some(position) = chunk.iter().position(|&byte| byte == b'\n') { + if let Some(position) = chunk.iter().position(|&byte| byte == newline) { return Ok(index + position); } @@ -763,7 +836,7 @@ mod tests { /// create a PartitionedFile for testing fn partitioned_file(path: &str) -> PartitionedFile { let object_meta = ObjectMeta { - location: object_store::path::Path::parse(path).unwrap(), + location: Path::parse(path).unwrap(), last_modified: Utc::now(), size: 42, e_tag: None, @@ -776,6 +849,7 @@ mod tests { range: None, statistics: None, extensions: None, + metadata_size_hint: None, } } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs index ea3030664b7b3..0d77a99699bd7 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs @@ -345,7 +345,7 @@ mod test { use parquet::basic::LogicalType; use parquet::file::metadata::ColumnChunkMetaData; use parquet::schema::types::{SchemaDescPtr, SchemaDescriptor}; - use std::sync::{Arc, OnceLock}; + use std::sync::{Arc, LazyLock}; #[test] fn test_only_scans() { @@ -358,7 +358,7 @@ mod test { let row_group_indexes = access_plan.row_group_indexes(); let row_selection = access_plan - .into_overall_row_selection(row_group_metadata()) + .into_overall_row_selection(&ROW_GROUP_METADATA) .unwrap(); // scan all row groups, no selection @@ -377,7 +377,7 @@ mod test { let row_group_indexes = access_plan.row_group_indexes(); let row_selection = access_plan - .into_overall_row_selection(row_group_metadata()) + .into_overall_row_selection(&ROW_GROUP_METADATA) .unwrap(); // skip all row groups, no selection @@ -403,7 +403,7 @@ mod test { let row_group_indexes = access_plan.row_group_indexes(); let row_selection = access_plan - .into_overall_row_selection(row_group_metadata()) + .into_overall_row_selection(&ROW_GROUP_METADATA) .unwrap(); assert_eq!(row_group_indexes, vec![0, 1]); @@ -442,7 +442,7 @@ mod test { let row_group_indexes = access_plan.row_group_indexes(); let row_selection = access_plan - .into_overall_row_selection(row_group_metadata()) + .into_overall_row_selection(&ROW_GROUP_METADATA) .unwrap(); assert_eq!(row_group_indexes, vec![1, 2, 3]); @@ -478,7 +478,7 @@ mod test { let row_group_indexes = access_plan.row_group_indexes(); let err = access_plan - .into_overall_row_selection(row_group_metadata()) + .into_overall_row_selection(&ROW_GROUP_METADATA) .unwrap_err() .to_string(); assert_eq!(row_group_indexes, vec![0, 1, 2, 3]); @@ -504,39 +504,35 @@ mod test { let row_group_indexes = access_plan.row_group_indexes(); let err = access_plan - .into_overall_row_selection(row_group_metadata()) + .into_overall_row_selection(&ROW_GROUP_METADATA) .unwrap_err() .to_string(); assert_eq!(row_group_indexes, vec![0, 1, 2, 3]); assert_contains!(err, "Invalid ParquetAccessPlan Selection. 
Row group 1 has 20 rows but selection only specifies 22 rows"); } - static ROW_GROUP_METADATA: OnceLock> = OnceLock::new(); - /// [`RowGroupMetaData`] that returns 4 row groups with 10, 20, 30, 40 rows /// respectively - fn row_group_metadata() -> &'static [RowGroupMetaData] { - ROW_GROUP_METADATA.get_or_init(|| { - let schema_descr = get_test_schema_descr(); - let row_counts = [10, 20, 30, 40]; - - row_counts - .into_iter() - .map(|num_rows| { - let column = ColumnChunkMetaData::builder(schema_descr.column(0)) - .set_num_values(num_rows) - .build() - .unwrap(); - - RowGroupMetaData::builder(schema_descr.clone()) - .set_num_rows(num_rows) - .set_column_metadata(vec![column]) - .build() - .unwrap() - }) - .collect() - }) - } + static ROW_GROUP_METADATA: LazyLock> = LazyLock::new(|| { + let schema_descr = get_test_schema_descr(); + let row_counts = [10, 20, 30, 40]; + + row_counts + .into_iter() + .map(|num_rows| { + let column = ColumnChunkMetaData::builder(schema_descr.column(0)) + .set_num_values(num_rows) + .build() + .unwrap(); + + RowGroupMetaData::builder(schema_descr.clone()) + .set_num_rows(num_rows) + .set_column_metadata(vec![column]) + .build() + .unwrap() + }) + .collect() + }); /// Single column schema with a single column named "a" of type `BYTE_ARRAY`/`String` fn get_test_schema_descr() -> SchemaDescPtr { diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 6afb66cc7c02e..085f44191b8a8 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -34,13 +34,15 @@ use crate::{ physical_optimizer::pruning::PruningPredicate, physical_plan::{ metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, - DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, + DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, }, }; use arrow::datatypes::SchemaRef; +use datafusion_common::Constraints; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalExpr}; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use itertools::Itertools; use log::debug; @@ -166,6 +168,33 @@ pub use writer::plan_to_parquet; /// [`RowFilter`]: parquet::arrow::arrow_reader::RowFilter /// [Parquet PageIndex]: https://github.com/apache/parquet-format/blob/master/PageIndex.md /// +/// # Example: rewriting `ParquetExec` +/// +/// You can modify a `ParquetExec` using [`ParquetExecBuilder`], for example +/// to change files or add a predicate. 
+/// +/// ```no_run +/// # use std::sync::Arc; +/// # use arrow::datatypes::Schema; +/// # use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +/// # use datafusion::datasource::listing::PartitionedFile; +/// # fn parquet_exec() -> ParquetExec { unimplemented!() } +/// // Split a single ParquetExec into multiple ParquetExecs, one for each file +/// let exec = parquet_exec(); +/// let existing_file_groups = &exec.base_config().file_groups; +/// let new_execs = existing_file_groups +/// .iter() +/// .map(|file_group| { +/// // create a new exec by copying the existing exec into a builder +/// let new_exec = exec.clone() +/// .into_builder() +/// .with_file_groups(vec![file_group.clone()]) +/// .build(); +/// new_exec +/// }) +/// .collect::>(); +/// ``` +/// /// # Implementing External Indexes /// /// It is possible to restrict the row groups and selections within those row @@ -257,6 +286,12 @@ pub struct ParquetExec { schema_adapter_factory: Option>, } +impl From for ParquetExecBuilder { + fn from(exec: ParquetExec) -> Self { + exec.into_builder() + } +} + /// [`ParquetExecBuilder`], builder for [`ParquetExec`]. /// /// See example on [`ParquetExec`]. @@ -291,9 +326,15 @@ impl ParquetExecBuilder { } } + /// Update the list of files groups to read + pub fn with_file_groups(mut self, file_groups: Vec>) -> Self { + self.file_scan_config.file_groups = file_groups; + self + } + /// Set the filter predicate when reading. /// - /// See the "Predicate Pushdown" section of the [`ParquetExec`] documenation + /// See the "Predicate Pushdown" section of the [`ParquetExec`] documentation /// for more details. pub fn with_predicate(mut self, predicate: Arc) -> Self { self.predicate = Some(predicate); @@ -386,7 +427,7 @@ impl ParquetExecBuilder { let pruning_predicate = predicate .clone() .and_then(|predicate_expr| { - match PruningPredicate::try_new(predicate_expr, file_schema.clone()) { + match PruningPredicate::try_new(predicate_expr, Arc::clone(file_schema)) { Ok(pruning_predicate) => Some(Arc::new(pruning_predicate)), Err(e) => { debug!("Could not create pruning predicate for: {e}"); @@ -400,16 +441,21 @@ impl ParquetExecBuilder { let page_pruning_predicate = predicate .as_ref() .map(|predicate_expr| { - PagePruningAccessPlanFilter::new(predicate_expr, file_schema.clone()) + PagePruningAccessPlanFilter::new(predicate_expr, Arc::clone(file_schema)) }) .map(Arc::new); - let (projected_schema, projected_statistics, projected_output_ordering) = - base_config.project(); + let ( + projected_schema, + projected_constraints, + projected_statistics, + projected_output_ordering, + ) = base_config.project(); let cache = ParquetExec::compute_properties( projected_schema, &projected_output_ordering, + projected_constraints, &base_config, ); ParquetExec { @@ -459,6 +505,34 @@ impl ParquetExec { ParquetExecBuilder::new(file_scan_config) } + /// Convert this `ParquetExec` into a builder for modification + pub fn into_builder(self) -> ParquetExecBuilder { + // list out fields so it is clear what is being dropped + // (note the fields which are dropped are re-created as part of calling + // `build` on the builder) + let Self { + base_config, + projected_statistics: _, + metrics: _, + predicate, + pruning_predicate: _, + page_pruning_predicate: _, + metadata_size_hint, + parquet_file_reader_factory, + cache: _, + table_parquet_options, + schema_adapter_factory, + } = self; + ParquetExecBuilder { + file_scan_config: base_config, + predicate, + metadata_size_hint, + table_parquet_options, + 
parquet_file_reader_factory, + schema_adapter_factory, + } + } + /// [`FileScanConfig`] that controls this scan (such as which files to read) pub fn base_config(&self) -> &FileScanConfig { &self.base_config @@ -479,9 +553,15 @@ impl ParquetExec { self.pruning_predicate.as_ref() } + /// return the optional file reader factory + pub fn parquet_file_reader_factory( + &self, + ) -> Option<&Arc> { + self.parquet_file_reader_factory.as_ref() + } + /// Optional user defined parquet file reader factory. /// - /// See documentation on [`ParquetExecBuilder::with_parquet_file_reader_factory`] pub fn with_parquet_file_reader_factory( mut self, parquet_file_reader_factory: Arc, @@ -490,6 +570,11 @@ impl ParquetExec { self } + /// return the optional schema adapter factory + pub fn schema_adapter_factory(&self) -> Option<&Arc> { + self.schema_adapter_factory.as_ref() + } + /// Optional schema adapter factory. /// /// See documentation on [`ParquetExecBuilder::with_schema_adapter_factory`] @@ -532,7 +617,7 @@ impl ParquetExec { } /// If enabled, the reader will read the page index - /// This is used to optimise filter pushdown + /// This is used to optimize filter pushdown /// via `RowSelector` and `RowFilter` by /// eliminating unnecessary IO and decoding pub fn with_enable_page_index(mut self, enable_page_index: bool) -> Self { @@ -574,19 +659,26 @@ impl ParquetExec { fn compute_properties( schema: SchemaRef, orderings: &[LexOrdering], + constraints: Constraints, file_config: &FileScanConfig, ) -> PlanProperties { - // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings); - PlanProperties::new( - eq_properties, + EquivalenceProperties::new_with_orderings(schema, orderings) + .with_constraints(constraints), Self::output_partitioning_helper(file_config), // Output Partitioning - ExecutionMode::Bounded, // Execution Mode + EmissionType::Incremental, + Boundedness::Bounded, ) } - fn with_file_groups(mut self, file_groups: Vec>) -> Self { + /// Updates the file groups to read and recalculates the output partitioning + /// + /// Note this function does not update statistics or other properties + /// that depend on the file groups. + fn with_file_groups_and_update_partitioning( + mut self, + file_groups: Vec>, + ) -> Self { self.base_config.file_groups = file_groups; // Changing file groups may invalidate output partitioning. 
Update it also let output_partitioning = Self::output_partitioning_helper(&self.base_config); @@ -679,7 +771,8 @@ impl ExecutionPlan for ParquetExec { let mut new_plan = self.clone(); if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { - new_plan = new_plan.with_file_groups(repartitioned_file_groups); + new_plan = new_plan + .with_file_groups_and_update_partitioning(repartitioned_file_groups); } Ok(Some(Arc::new(new_plan))) } @@ -721,7 +814,7 @@ impl ExecutionPlan for ParquetExec { predicate: self.predicate.clone(), pruning_predicate: self.pruning_predicate.clone(), page_pruning_predicate: self.page_pruning_predicate.clone(), - table_schema: self.base_config.file_schema.clone(), + table_schema: Arc::clone(&self.base_config.file_schema), metadata_size_hint: self.metadata_size_hint, metrics: self.metrics.clone(), parquet_file_reader_factory, @@ -797,6 +890,7 @@ mod tests { // See also `parquet_exec` integration test use std::fs::{self, File}; use std::io::Write; + use std::sync::Mutex; use super::*; use crate::dataframe::DataFrameWriteOptions; @@ -822,6 +916,7 @@ mod tests { use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; use arrow_schema::{DataType, Fields}; + use bytes::{BufMut, BytesMut}; use datafusion_common::{assert_contains, ScalarValue}; use datafusion_expr::{col, lit, when, Expr}; use datafusion_physical_expr::planner::logical2physical; @@ -831,7 +926,7 @@ mod tests { use futures::StreamExt; use object_store::local::LocalFileSystem; use object_store::path::Path; - use object_store::ObjectMeta; + use object_store::{ObjectMeta, ObjectStore}; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; use tempfile::TempDir; @@ -1570,6 +1665,7 @@ mod tests { range: Some(FileRange { start, end }), statistics: None, extensions: None, + metadata_size_hint: None, } } @@ -1613,7 +1709,7 @@ mod tests { let store = Arc::new(LocalFileSystem::new()) as _; let file_schema = ParquetFormat::default() - .infer_schema(&state, &store, &[meta.clone()]) + .infer_schema(&state, &store, std::slice::from_ref(&meta)) .await?; let group_empty = vec![vec![file_range(&meta, 0, 2)]]; @@ -1645,7 +1741,7 @@ mod tests { let meta = local_unpartitioned_file(filename); let schema = ParquetFormat::default() - .infer_schema(&state, &store, &[meta.clone()]) + .infer_schema(&state, &store, std::slice::from_ref(&meta)) .await .unwrap(); @@ -1662,6 +1758,7 @@ mod tests { range: None, statistics: None, extensions: None, + metadata_size_hint: None, }; let expected_schema = Schema::new(vec![ @@ -1749,6 +1846,7 @@ mod tests { range: None, statistics: None, extensions: None, + metadata_size_hint: None, }; let file_schema = Arc::new(Schema::empty()); @@ -1911,7 +2009,7 @@ mod tests { assert_contains!( &display, - "pruning_predicate=CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != bar OR bar != c1_max@1 END" + "pruning_predicate=c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != bar OR bar != c1_max@1)" ); assert_contains!(&display, r#"predicate=c1@0 != bar"#); @@ -2141,7 +2239,7 @@ mod tests { // execute a simple query and write the results to parquet let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; - std::fs::create_dir(&out_dir).unwrap(); + fs::create_dir(&out_dir).unwrap(); let df = ctx.sql("SELECT c1, c2 FROM test").await?; let schema: Schema = df.schema().into(); // Register a listing table - this will use all files in the directory as data sources @@ -2292,4 +2390,134 @@ mod tests { 
writer.flush().unwrap(); writer.close().unwrap(); } + + /// Write out a batch to a parquet file and return the total size of the file + async fn write_batch( + path: &str, + store: Arc, + batch: RecordBatch, + ) -> usize { + let mut writer = + ArrowWriter::try_new(BytesMut::new().writer(), batch.schema(), None).unwrap(); + writer.write(&batch).unwrap(); + writer.flush().unwrap(); + let bytes = writer.into_inner().unwrap().into_inner().freeze(); + let total_size = bytes.len(); + let path = Path::from(path); + let payload = object_store::PutPayload::from_bytes(bytes); + store + .put_opts(&path, payload, object_store::PutOptions::default()) + .await + .unwrap(); + total_size + } + + /// A ParquetFileReaderFactory that tracks the metadata_size_hint passed to it + #[derive(Debug, Clone)] + struct TrackingParquetFileReaderFactory { + inner: Arc, + metadata_size_hint_calls: Arc>>>, + } + + impl TrackingParquetFileReaderFactory { + fn new(store: Arc) -> Self { + Self { + inner: Arc::new(DefaultParquetFileReaderFactory::new(store)) as _, + metadata_size_hint_calls: Arc::new(Mutex::new(vec![])), + } + } + } + + impl ParquetFileReaderFactory for TrackingParquetFileReaderFactory { + fn create_reader( + &self, + partition_index: usize, + file_meta: crate::datasource::physical_plan::FileMeta, + metadata_size_hint: Option, + metrics: &ExecutionPlanMetricsSet, + ) -> Result> + { + self.metadata_size_hint_calls + .lock() + .unwrap() + .push(metadata_size_hint); + self.inner.create_reader( + partition_index, + file_meta, + metadata_size_hint, + metrics, + ) + } + } + + /// Test passing `metadata_size_hint` to either a single file or the whole exec + #[tokio::test] + async fn test_metadata_size_hint() { + let store = + Arc::new(object_store::memory::InMemory::new()) as Arc; + let store_url = ObjectStoreUrl::parse("memory://test").unwrap(); + + let ctx = SessionContext::new(); + ctx.register_object_store(store_url.as_ref(), store.clone()); + + // write some data out, it doesn't matter what it is + let c1: ArrayRef = Arc::new(Int32Array::from(vec![Some(1)])); + let batch = create_batch(vec![("c1", c1)]); + let schema = batch.schema(); + let name_1 = "test1.parquet"; + let name_2 = "test2.parquet"; + let total_size_1 = write_batch(name_1, store.clone(), batch.clone()).await; + let total_size_2 = write_batch(name_2, store.clone(), batch.clone()).await; + + let reader_factory = + Arc::new(TrackingParquetFileReaderFactory::new(store.clone())); + + let size_hint_calls = reader_factory.metadata_size_hint_calls.clone(); + + let exec = ParquetExec::builder( + FileScanConfig::new(store_url, schema) + .with_file( + PartitionedFile { + object_meta: ObjectMeta { + location: Path::from(name_1), + last_modified: Utc::now(), + size: total_size_1, + e_tag: None, + version: None, + }, + partition_values: vec![], + range: None, + statistics: None, + extensions: None, + metadata_size_hint: None, + } + .with_metadata_size_hint(123), + ) + .with_file(PartitionedFile { + object_meta: ObjectMeta { + location: Path::from(name_2), + last_modified: Utc::now(), + size: total_size_2, + e_tag: None, + version: None, + }, + partition_values: vec![], + range: None, + statistics: None, + extensions: None, + metadata_size_hint: None, + }), + ) + .with_parquet_file_reader_factory(reader_factory) + .with_metadata_size_hint(456) + .build(); + + let exec = Arc::new(exec); + let res = collect(exec, ctx.task_ctx()).await.unwrap(); + assert_eq!(res.len(), 2); + + let calls = size_hint_calls.lock().unwrap().clone(); + assert_eq!(calls.len(), 2); + 
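The `vec![Some(123), Some(456)]` expected just below follows from the precedence rule this PR adds to `ParquetOpener::open` (`file_meta.metadata_size_hint.or(self.metadata_size_hint)`): a hint set on an individual `PartitionedFile` wins, otherwise the exec-wide hint applies. A tiny sketch of that rule with a made-up helper name:

```rust
/// Hypothetical helper mirroring the rule in `ParquetOpener::open`:
/// `file_meta.metadata_size_hint.or(self.metadata_size_hint)`.
fn effective_metadata_size_hint(
    per_file_hint: Option<usize>,
    exec_wide_hint: Option<usize>,
) -> Option<usize> {
    per_file_hint.or(exec_wide_hint)
}

fn main() {
    // File-level hint present (the `test1.parquet` case): it wins.
    assert_eq!(effective_metadata_size_hint(Some(123), Some(456)), Some(123));
    // No file-level hint (the `test2.parquet` case): fall back to the exec-wide hint.
    assert_eq!(effective_metadata_size_hint(None, Some(456)), Some(456));
}
```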
assert_eq!(calls, vec![Some(123), Some(456)]); + } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index a818a88502842..883f296f3b95d 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -17,7 +17,9 @@ //! [`ParquetOpener`] for opening Parquet files -use crate::datasource::file_format::coerce_file_schema_to_view_type; +use crate::datasource::file_format::{ + coerce_file_schema_to_string_type, coerce_file_schema_to_view_type, +}; use crate::datasource::physical_plan::parquet::page_filter::PagePruningAccessPlanFilter; use crate::datasource::physical_plan::parquet::row_group_filter::RowGroupAccessPlanFilter; use crate::datasource::physical_plan::parquet::{ @@ -80,18 +82,20 @@ pub(super) struct ParquetOpener { } impl FileOpener for ParquetOpener { - fn open(&self, file_meta: FileMeta) -> datafusion_common::Result { + fn open(&self, file_meta: FileMeta) -> Result { let file_range = file_meta.range.clone(); let extensions = file_meta.extensions.clone(); let file_name = file_meta.location().to_string(); let file_metrics = ParquetFileMetrics::new(self.partition_index, &file_name, &self.metrics); + let metadata_size_hint = file_meta.metadata_size_hint.or(self.metadata_size_hint); + let mut reader: Box = self.parquet_file_reader_factory.create_reader( self.partition_index, file_meta, - self.metadata_size_hint, + metadata_size_hint, &self.metrics, )?; @@ -101,11 +105,11 @@ impl FileOpener for ParquetOpener { SchemaRef::from(self.table_schema.project(&self.projection)?); let schema_adapter = self .schema_adapter_factory - .create(projected_schema, self.table_schema.clone()); + .create(projected_schema, Arc::clone(&self.table_schema)); let predicate = self.predicate.clone(); let pruning_predicate = self.pruning_predicate.clone(); let page_pruning_predicate = self.page_pruning_predicate.clone(); - let table_schema = self.table_schema.clone(); + let table_schema = Arc::clone(&self.table_schema); let reorder_predicates = self.reorder_filters; let pushdown_filters = self.pushdown_filters; let enable_page_index = should_enable_page_index( @@ -121,7 +125,14 @@ impl FileOpener for ParquetOpener { let mut metadata_timer = file_metrics.metadata_load_time.timer(); let metadata = ArrowReaderMetadata::load_async(&mut reader, options.clone()).await?; - let mut schema = metadata.schema().clone(); + let mut schema = Arc::clone(metadata.schema()); + + if let Some(merged) = + coerce_file_schema_to_string_type(&table_schema, &schema) + { + schema = Arc::new(merged); + } + // read with view types if let Some(merged) = coerce_file_schema_to_view_type(&table_schema, &schema) { @@ -130,16 +141,16 @@ impl FileOpener for ParquetOpener { let options = ArrowReaderOptions::new() .with_page_index(enable_page_index) - .with_schema(schema.clone()); + .with_schema(Arc::clone(&schema)); let metadata = - ArrowReaderMetadata::try_new(metadata.metadata().clone(), options)?; + ArrowReaderMetadata::try_new(Arc::clone(metadata.metadata()), options)?; metadata_timer.stop(); let mut builder = ParquetRecordBatchStreamBuilder::new_with_metadata(reader, metadata); - let file_schema = builder.schema().clone(); + let file_schema = Arc::clone(builder.schema()); let (schema_mapping, adapted_projections) = schema_adapter.map_schema(&file_schema)?; @@ -177,7 +188,7 @@ impl FileOpener for ParquetOpener { // Determine which row groups to actually read. 
The idea is to skip // as many row groups as possible based on the metadata and query - let file_metadata = builder.metadata().clone(); + let file_metadata = Arc::clone(builder.metadata()); let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); let rg_metadata = file_metadata.row_groups(); // track which row groups to actually read diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index ced07de974f68..4d0a8451a0d4e 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -118,14 +118,16 @@ impl PagePruningAccessPlanFilter { let predicates = split_conjunction(expr) .into_iter() .filter_map(|predicate| { - let pp = - match PruningPredicate::try_new(predicate.clone(), schema.clone()) { - Ok(pp) => pp, - Err(e) => { - debug!("Ignoring error creating page pruning predicate: {e}"); - return None; - } - }; + let pp = match PruningPredicate::try_new( + Arc::clone(predicate), + Arc::clone(&schema), + ) { + Ok(pp) => pp, + Err(e) => { + debug!("Ignoring error creating page pruning predicate: {e}"); + return None; + } + }; if pp.always_true() { debug!("Ignoring always true page pruning predicate: {predicate}"); @@ -447,7 +449,7 @@ impl<'a> PagesPruningStatistics<'a> { Some(vec) } } -impl<'a> PruningStatistics for PagesPruningStatistics<'a> { +impl PruningStatistics for PagesPruningStatistics<'_> { fn min_values(&self, _column: &datafusion_common::Column) -> Option { match self.converter.data_page_mins( self.column_index, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index b161d55f2e6eb..9c2599a809cbe 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -336,7 +336,7 @@ impl<'schema> PushdownChecker<'schema> { } } -impl<'schema> TreeNodeRewriter for PushdownChecker<'schema> { +impl TreeNodeRewriter for PushdownChecker<'_> { type Node = Arc; fn f_down( @@ -422,7 +422,7 @@ fn would_column_prevent_pushdown( checker.prevents_pushdown() } -/// Recurses through expr as a trea, finds all `column`s, and checks if any of them would prevent +/// Recurses through expr as a tree, finds all `column`s, and checks if any of them would prevent /// this expression from being predicate pushed down. If any of them would, this returns false. /// Otherwise, true. pub fn can_expr_be_pushed_down_with_schemas( @@ -541,7 +541,7 @@ pub fn build_row_filter( let mut candidates: Vec = predicates .into_iter() .map(|expr| { - FilterCandidateBuilder::new(expr.clone(), file_schema, table_schema) + FilterCandidateBuilder::new(Arc::clone(expr), file_schema, table_schema) .build(metadata) }) .collect::, _>>()? 
@@ -692,7 +692,7 @@ mod test { let mut parquet_reader = parquet_reader_builder.build().expect("building reader"); - // Parquet file is small, we only need 1 recordbatch + // Parquet file is small, we only need 1 record batch let first_rb = parquet_reader .next() .expect("expected record batch") diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs index 1e5a0195449d8..1e6b8d59005fa 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs @@ -228,6 +228,80 @@ struct BloomFilterStatistics { column_sbbf: HashMap, } +impl BloomFilterStatistics { + /// Helper function for checking if [`Sbbf`] filter contains [`ScalarValue`]. + /// + /// In case the type of scalar is not supported, returns `true`, assuming that the + /// value may be present. + fn check_scalar(sbbf: &Sbbf, value: &ScalarValue, parquet_type: &Type) -> bool { + match value { + ScalarValue::Utf8(Some(v)) + | ScalarValue::Utf8View(Some(v)) + | ScalarValue::LargeUtf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Binary(Some(v)) + | ScalarValue::BinaryView(Some(v)) + | ScalarValue::LargeBinary(Some(v)) => sbbf.check(v), + ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), + ScalarValue::Boolean(Some(v)) => sbbf.check(v), + ScalarValue::Float64(Some(v)) => sbbf.check(v), + ScalarValue::Float32(Some(v)) => sbbf.check(v), + ScalarValue::Int64(Some(v)) => sbbf.check(v), + ScalarValue::Int32(Some(v)) => sbbf.check(v), + ScalarValue::UInt64(Some(v)) => sbbf.check(v), + ScalarValue::UInt32(Some(v)) => sbbf.check(v), + ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { + Type::INT32 => { + //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 + // All physical type are little-endian + if *p > 9 { + //DECIMAL can be used to annotate the following types: + // + // int32: for 1 <= precision <= 9 + // int64: for 1 <= precision <= 18 + return true; + } + let b = (*v as i32).to_le_bytes(); + // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 + let decimal = Decimal::Int32 { + value: b, + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + Type::INT64 => { + if *p > 18 { + return true; + } + let b = (*v as i64).to_le_bytes(); + let decimal = Decimal::Int64 { + value: b, + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + Type::FIXED_LEN_BYTE_ARRAY => { + // keep with from_bytes_to_i128 + let b = v.to_be_bytes().to_vec(); + // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 + let decimal = Decimal::Bytes { + value: b.into(), + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + _ => true, + }, + ScalarValue::Dictionary(_, inner) => { + BloomFilterStatistics::check_scalar(sbbf, inner, parquet_type) + } + _ => true, + } + } +} + impl PruningStatistics for BloomFilterStatistics { fn min_values(&self, _column: &Column) -> Option { None @@ -268,70 +342,7 @@ impl PruningStatistics for BloomFilterStatistics { let known_not_present = values .iter() - .map(|value| { - match value { - ScalarValue::Utf8(Some(v)) | ScalarValue::Utf8View(Some(v)) => { - sbbf.check(&v.as_str()) - } - ScalarValue::Binary(Some(v)) | ScalarValue::BinaryView(Some(v)) => { - sbbf.check(v) - } - ScalarValue::FixedSizeBinary(_size, Some(v)) => 
sbbf.check(v), - ScalarValue::Boolean(Some(v)) => sbbf.check(v), - ScalarValue::Float64(Some(v)) => sbbf.check(v), - ScalarValue::Float32(Some(v)) => sbbf.check(v), - ScalarValue::Int64(Some(v)) => sbbf.check(v), - ScalarValue::Int32(Some(v)) => sbbf.check(v), - ScalarValue::UInt64(Some(v)) => sbbf.check(v), - ScalarValue::UInt32(Some(v)) => sbbf.check(v), - ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { - Type::INT32 => { - //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 - // All physical type are little-endian - if *p > 9 { - //DECIMAL can be used to annotate the following types: - // - // int32: for 1 <= precision <= 9 - // int64: for 1 <= precision <= 18 - return true; - } - let b = (*v as i32).to_le_bytes(); - // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 - let decimal = Decimal::Int32 { - value: b, - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - Type::INT64 => { - if *p > 18 { - return true; - } - let b = (*v as i64).to_le_bytes(); - let decimal = Decimal::Int64 { - value: b, - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - Type::FIXED_LEN_BYTE_ARRAY => { - // keep with from_bytes_to_i128 - let b = v.to_be_bytes().to_vec(); - // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 - let decimal = Decimal::Bytes { - value: b.into(), - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - _ => true, - }, - _ => true, - } - }) + .map(|value| BloomFilterStatistics::check_scalar(sbbf, value, parquet_type)) // The row group doesn't contain any of the values if // all the checks are false .all(|v| !v); @@ -374,7 +385,7 @@ impl<'a> RowGroupPruningStatistics<'a> { } } -impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { +impl PruningStatistics for RowGroupPruningStatistics<'_> { fn min_values(&self, column: &Column) -> Option { self.statistics_converter(column) .and_then(|c| Ok(c.row_group_mins(self.metadata_iter())?)) @@ -431,8 +442,8 @@ mod tests { use datafusion_expr::{cast, col, lit, Expr}; use datafusion_physical_expr::planner::logical2physical; - use parquet::arrow::arrow_to_parquet_schema; use parquet::arrow::async_reader::ParquetObjectReader; + use parquet::arrow::ArrowSchemaConverter; use parquet::basic::LogicalType; use parquet::data_type::{ByteArray, FixedLenByteArray}; use parquet::file::metadata::ColumnChunkMetaData; @@ -719,7 +730,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); - let schema_descr = arrow_to_parquet_schema(&schema).unwrap(); + let schema_descr = ArrowSchemaConverter::new().convert(&schema).unwrap(); let expr = col("c1").gt(lit(15)).and(col("c2").is_null()); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); @@ -748,7 +759,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); - let schema_descr = arrow_to_parquet_schema(&schema).unwrap(); + let schema_descr = ArrowSchemaConverter::new().convert(&schema).unwrap(); let expr = col("c1") .gt(lit(15)) .and(col("c2").eq(lit(ScalarValue::Boolean(None)))); @@ -779,11 +790,8 @@ mod tests { // INT32: c1 > 5, the c1 is decimal(9,2) // The type of scalar value if decimal(9,2), don't need to do cast - let schema = Arc::new(Schema::new(vec![Field::new( - "c1", - DataType::Decimal128(9, 2), - false, - )])); + 
let schema = + Arc::new(Schema::new(vec![Field::new("c1", Decimal128(9, 2), false)])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -849,11 +857,8 @@ mod tests { // The c1 type is decimal(9,0) in the parquet file, and the type of scalar is decimal(5,2). // We should convert all type to the coercion type, which is decimal(11,2) // The decimal of arrow is decimal(5,2), the decimal of parquet is decimal(9,0) - let schema = Arc::new(Schema::new(vec![Field::new( - "c1", - DataType::Decimal128(9, 0), - false, - )])); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", Decimal128(9, 0), false)])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { @@ -863,7 +868,7 @@ mod tests { .with_scale(0) .with_precision(9); let schema_descr = get_test_schema_descr(vec![field]); - let expr = cast(col("c1"), DataType::Decimal128(11, 2)).gt(cast( + let expr = cast(col("c1"), Decimal128(11, 2)).gt(cast( lit(ScalarValue::Decimal128(Some(500), 5, 2)), Decimal128(11, 2), )); @@ -947,7 +952,7 @@ mod tests { // INT64: c1 < 5, the c1 is decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( "c1", - DataType::Decimal128(18, 2), + Decimal128(18, 2), false, )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT64) @@ -1005,7 +1010,7 @@ mod tests { // the type of parquet is decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( "c1", - DataType::Decimal128(18, 2), + Decimal128(18, 2), false, )])); let field = PrimitiveTypeField::new("c1", PhysicalType::FIXED_LEN_BYTE_ARRAY) @@ -1018,7 +1023,7 @@ mod tests { .with_byte_len(16); let schema_descr = get_test_schema_descr(vec![field]); // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), DataType::Decimal128(28, 3)); + let left = cast(col("c1"), Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); @@ -1083,7 +1088,7 @@ mod tests { // the type of parquet is decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( "c1", - DataType::Decimal128(18, 2), + Decimal128(18, 2), false, )])); let field = PrimitiveTypeField::new("c1", PhysicalType::BYTE_ARRAY) @@ -1096,7 +1101,7 @@ mod tests { .with_byte_len(16); let schema_descr = get_test_schema_descr(vec![field]); // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), DataType::Decimal128(28, 3)); + let left = cast(col("c1"), Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/writer.rs b/datafusion/core/src/datasource/physical_plan/parquet/writer.rs index 0c0c54691068c..00926dc2330b1 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/writer.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/writer.rs @@ -40,14 +40,14 @@ pub async fn plan_to_parquet( let store = task_ctx.runtime_env().object_store(&object_store_url)?; let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { - let plan: Arc = plan.clone(); + let plan: Arc = Arc::clone(&plan); let filename = format!("{}/part-{i}.parquet", parsed.prefix()); let file = Path::parse(filename)?; let propclone = writer_properties.clone(); - 
let storeref = store.clone(); + let storeref = Arc::clone(&store); let buf_writer = BufWriter::new(storeref, file.clone()); - let mut stream = plan.execute(i, task_ctx.clone())?; + let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { let mut writer = AsyncArrowWriter::try_new(buf_writer, plan.schema(), propclone)?; diff --git a/datafusion/core/src/datasource/physical_plan/statistics.rs b/datafusion/core/src/datasource/physical_plan/statistics.rs index e1c61ec1a7129..b4a8f377d2565 100644 --- a/datafusion/core/src/datasource/physical_plan/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/statistics.rs @@ -26,16 +26,17 @@ use std::sync::Arc; +use crate::datasource::listing::PartitionedFile; + use arrow::{ compute::SortColumn, row::{Row, Rows}, }; use arrow_array::RecordBatch; use arrow_schema::SchemaRef; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_physical_expr::{expressions::Column, PhysicalSortExpr}; - -use crate::datasource::listing::PartitionedFile; +use datafusion_physical_expr_common::sort_expr::LexOrdering; /// A normalized representation of file min/max statistics that allows for efficient sorting & comparison. /// The min/max values are ordered by [`Self::sort_order`]. @@ -43,13 +44,13 @@ use crate::datasource::listing::PartitionedFile; pub(crate) struct MinMaxStatistics { min_by_sort_order: Rows, max_by_sort_order: Rows, - sort_order: Vec, + sort_order: LexOrdering, } impl MinMaxStatistics { /// Sort order used to sort the statistics #[allow(unused)] - pub fn sort_order(&self) -> &[PhysicalSortExpr] { + pub fn sort_order(&self) -> &LexOrdering { &self.sort_order } @@ -65,8 +66,8 @@ impl MinMaxStatistics { } pub fn new_from_files<'a>( - projected_sort_order: &[PhysicalSortExpr], // Sort order with respect to projected schema - projected_schema: &SchemaRef, // Projected schema + projected_sort_order: &LexOrdering, // Sort order with respect to projected schema + projected_schema: &SchemaRef, // Projected schema projection: Option<&[usize]>, // Indices of projection in full table schema (None = all columns) files: impl IntoIterator, ) -> Result { @@ -118,15 +119,17 @@ impl MinMaxStatistics { projected_schema .project(&(sort_columns.iter().map(|c| c.index()).collect::>()))?, ); - let min_max_sort_order = sort_columns - .iter() - .zip(projected_sort_order.iter()) - .enumerate() - .map(|(i, (col, sort))| PhysicalSortExpr { - expr: Arc::new(Column::new(col.name(), i)), - options: sort.options, - }) - .collect::>(); + let min_max_sort_order = LexOrdering::from( + sort_columns + .iter() + .zip(projected_sort_order.iter()) + .enumerate() + .map(|(i, (col, sort))| PhysicalSortExpr { + expr: Arc::new(Column::new(col.name(), i)), + options: sort.options, + }) + .collect::>(), + ); let (min_values, max_values): (Vec<_>, Vec<_>) = sort_columns .iter() @@ -166,7 +169,7 @@ impl MinMaxStatistics { } pub fn new( - sort_order: &[PhysicalSortExpr], + sort_order: &LexOrdering, schema: &SchemaRef, min_values: RecordBatch, max_values: RecordBatch, @@ -229,9 +232,7 @@ impl MinMaxStatistics { // check that sort columns are non-nullable if field.is_nullable() { - return Err(DataFusionError::Plan( - "cannot sort by nullable column".to_string(), - )); + return plan_err!("cannot sort by nullable column"); } Ok(SortColumn { @@ -256,7 +257,7 @@ impl MinMaxStatistics { Ok(Self { min_by_sort_order: min.map_err(|e| e.context("build min rows"))?, max_by_sort_order: max.map_err(|e| 
e.context("build max rows"))?, - sort_order: sort_order.to_vec(), + sort_order: sort_order.clone(), }) } @@ -277,14 +278,10 @@ impl MinMaxStatistics { } fn sort_columns_from_physical_sort_exprs( - sort_order: &[PhysicalSortExpr], -) -> Option> { + sort_order: &LexOrdering, +) -> Option> { sort_order .iter() - .map(|expr| { - expr.expr - .as_any() - .downcast_ref::() - }) + .map(|expr| expr.expr.as_any().downcast_ref::()) .collect::>>() } diff --git a/datafusion/core/src/datasource/schema_adapter.rs b/datafusion/core/src/datasource/schema_adapter.rs index fdf3381758a48..b27cf9c5f8338 100644 --- a/datafusion/core/src/datasource/schema_adapter.rs +++ b/datafusion/core/src/datasource/schema_adapter.rs @@ -32,11 +32,19 @@ use std::sync::Arc; /// /// This interface provides a way to implement custom schema adaptation logic /// for ParquetExec (for example, to fill missing columns with default value -/// other than null) +/// other than null). +/// +/// Most users should use [`DefaultSchemaAdapterFactory`]. See that struct for +/// more details and examples. pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static { - /// Provides `SchemaAdapter`. - // The design of this function is mostly modeled for the needs of DefaultSchemaAdapterFactory, - // read its implementation docs for the reasoning + /// Create a [`SchemaAdapter`] + /// + /// Arguments: + /// + /// * `projected_table_schema`: The schema for the table, projected to + /// include only the fields being output (projected) by the this mapping. + /// + /// * `table_schema`: The entire table schema for the table fn create( &self, projected_table_schema: SchemaRef, @@ -44,53 +52,57 @@ pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static { ) -> Box; } -/// Adapt file-level [`RecordBatch`]es to a table schema, which may have a schema -/// obtained from merging multiple file-level schemas. -/// -/// This is useful for enabling schema evolution in partitioned datasets. -/// -/// This has to be done in two stages. +/// Creates [`SchemaMapper`]s to map file-level [`RecordBatch`]es to a table +/// schema, which may have a schema obtained from merging multiple file-level +/// schemas. /// -/// 1. Before reading the file, we have to map projected column indexes from the -/// table schema to the file schema. +/// This is useful for implementing schema evolution in partitioned datasets. /// -/// 2. After reading a record batch map the read columns back to the expected -/// columns indexes and insert null-valued columns wherever the file schema was -/// missing a column present in the table schema. +/// See [`DefaultSchemaAdapterFactory`] for more details and examples. pub trait SchemaAdapter: Send + Sync { /// Map a column index in the table schema to a column index in a particular /// file schema /// + /// This is used while reading a file to push down projections by mapping + /// projected column indexes from the table schema to the file schema + /// /// Panics if index is not in range for the table schema fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option; - /// Creates a `SchemaMapping` that can be used to cast or map the columns - /// from the file schema to the table schema. + /// Creates a mapping for casting columns from the file schema to the table + /// schema. /// - /// If the provided `file_schema` contains columns of a different type to the expected - /// `table_schema`, the method will attempt to cast the array data from the file schema - /// to the table schema where possible. 
+ /// This is used after reading a record batch. The returned [`SchemaMapper`]: /// - /// Returns a [`SchemaMapper`] that can be applied to the output batch - /// along with an ordered list of columns to project from the file + /// 1. Maps columns to the expected column indexes + /// 2. Handles missing values (e.g. fills nulls or a default value) for + /// columns in the table schema not in the file schema + /// 3. Handles different types: if the column in the file schema has a + /// different type than `table_schema`, the mapper will resolve this + /// difference (e.g. by casting to the appropriate type) + /// + /// Returns: + /// * a [`SchemaMapper`] + /// * an ordered list of columns to project from the file fn map_schema( &self, file_schema: &Schema, ) -> datafusion_common::Result<(Arc, Vec)>; } -/// Maps, by casting or reordering columns from the file schema to the table -/// schema. +/// Maps columns from a specific file schema to the table schema. +/// +/// See [`DefaultSchemaAdapterFactory`] for more details and examples. pub trait SchemaMapper: Debug + Send + Sync { - /// Adapts a `RecordBatch` to match the `table_schema` using the stored - /// mapping and conversions. + /// Adapts a `RecordBatch` to match the `table_schema` fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result; /// Adapts a [`RecordBatch`] that does not have all the columns from the /// file schema. /// - /// This method is used when applying a filter to a subset of the columns as - /// part of `DataFusionArrowPredicate` when `filter_pushdown` is enabled. + /// This method is used, for example, when applying a filter to a subset of + /// the columns as part of `DataFusionArrowPredicate` when `filter_pushdown` + /// is enabled. /// /// This method is slower than `map_batch` as it looks up columns by name. fn map_partial_batch( @@ -99,11 +111,106 @@ pub trait SchemaMapper: Debug + Send + Sync { ) -> datafusion_common::Result; } -/// Implementation of [`SchemaAdapterFactory`] that maps columns by name -/// and casts columns to the expected type. +/// Default [`SchemaAdapterFactory`] for mapping schemas. +/// +/// This can be used to adapt file-level record batches to a table schema and +/// implement schema evolution. +/// +/// Given an input file schema and a table schema, this factory returns +/// a [`SchemaAdapter`] that returns [`SchemaMapper`]s that: +/// +/// 1. Reorder columns +/// 2. Cast columns to the correct type +/// 3. Fill missing columns with nulls +/// +/// # Errors: +/// +/// * If a column in the table schema is non-nullable but is not present in the +/// file schema (i.e. it is missing), the returned mapper tries to fill it with +/// nulls resulting in a schema error.
+/// +/// # Illustration of Schema Mapping +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌───────┐ ┌───────┐ │ ┌───────┐ ┌───────┐ ┌───────┐ │ +/// ││ 1.0 │ │ "foo" │ ││ NULL │ │ "foo" │ │ "1.0" │ +/// ├───────┤ ├───────┤ │ Schema mapping ├───────┤ ├───────┤ ├───────┤ │ +/// ││ 2.0 │ │ "bar" │ ││ NULL │ │ "bar" │ │ "2.0" │ +/// └───────┘ └───────┘ │────────────────▶ └───────┘ └───────┘ └───────┘ │ +/// │ │ +/// column "c" column "b"│ column "a" column "b" column "c"│ +/// │ Float64 Utf8 │ Int32 Utf8 Utf8 +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// Input Record Batch Output Record Batch +/// +/// Schema { Schema { +/// "c": Float64, "a": Int32, +/// "b": Utf8, "b": Utf8, +/// } "c": Utf8, +/// } +/// ``` +/// +/// # Example of using the `DefaultSchemaAdapterFactory` to map [`RecordBatch`]s +/// +/// Note `SchemaMapping` also supports mapping partial batches, which is used as +/// part of predicate pushdown. +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow::datatypes::{DataType, Field, Schema}; +/// # use datafusion::datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapterFactory}; +/// # use datafusion_common::record_batch; +/// // Table has fields "a", "b" and "c" +/// let table_schema = Schema::new(vec![ +/// Field::new("a", DataType::Int32, true), +/// Field::new("b", DataType::Utf8, true), +/// Field::new("c", DataType::Utf8, true), +/// ]); +/// +/// // create an adapter to map the table schema to the file schema +/// let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); +/// +/// // The file schema has fields "c" and "b" but "b" is stored as an 'Float64' +/// // instead of 'Utf8' +/// let file_schema = Schema::new(vec![ +/// Field::new("c", DataType::Utf8, true), +/// Field::new("b", DataType::Float64, true), +/// ]); +/// +/// // Get a mapping from the file schema to the table schema +/// let (mapper, _indices) = adapter.map_schema(&file_schema).unwrap(); +/// +/// let file_batch = record_batch!( +/// ("c", Utf8, vec!["foo", "bar"]), +/// ("b", Float64, vec![1.0, 2.0]) +/// ).unwrap(); +/// +/// let mapped_batch = mapper.map_batch(file_batch).unwrap(); +/// +/// // the mapped batch has the correct schema and the "b" column has been cast to Utf8 +/// let expected_batch = record_batch!( +/// ("a", Int32, vec![None, None]), // missing column filled with nulls +/// ("b", Utf8, vec!["1.0", "2.0"]), // b was cast to string and order was changed +/// ("c", Utf8, vec!["foo", "bar"]) +/// ).unwrap(); +/// assert_eq!(mapped_batch, expected_batch); +/// ``` #[derive(Clone, Debug, Default)] pub struct DefaultSchemaAdapterFactory; +impl DefaultSchemaAdapterFactory { + /// Create a new factory for mapping batches from a file schema to a table + /// schema. + /// + /// This is a convenience for [`DefaultSchemaAdapterFactory::create`] with + /// the same schema for both the projected table schema and the table + /// schema. + pub fn from_schema(table_schema: SchemaRef) -> Box { + Self.create(Arc::clone(&table_schema), table_schema) + } +} + impl SchemaAdapterFactory for DefaultSchemaAdapterFactory { fn create( &self, @@ -117,8 +224,8 @@ impl SchemaAdapterFactory for DefaultSchemaAdapterFactory { } } -/// This SchemaAdapter requires both the table schema and the projected table schema because of the -/// needs of the [`SchemaMapping`] it creates. Read its documentation for more details +/// This SchemaAdapter requires both the table schema and the projected table +/// schema. 
See [`SchemaMapping`] for more details #[derive(Clone, Debug)] pub(crate) struct DefaultSchemaAdapter { /// The schema for the table, projected to include only the fields being output (projected) by the @@ -142,11 +249,12 @@ impl SchemaAdapter for DefaultSchemaAdapter { Some(file_schema.fields.find(field.name())?.0) } - /// Creates a `SchemaMapping` that can be used to cast or map the columns from the file schema to the table schema. + /// Creates a `SchemaMapping` for casting or mapping the columns from the + /// file schema to the table schema. /// - /// If the provided `file_schema` contains columns of a different type to the expected - /// `table_schema`, the method will attempt to cast the array data from the file schema - /// to the table schema where possible. + /// If the provided `file_schema` contains columns of a different type to + /// the expected `table_schema`, the method will attempt to cast the array + /// data from the file schema to the table schema where possible. /// /// Returns a [`SchemaMapping`] that can be applied to the output batch /// along with an ordered list of columns to project from the file @@ -180,45 +288,54 @@ impl SchemaAdapter for DefaultSchemaAdapter { Ok(( Arc::new(SchemaMapping { - projected_table_schema: self.projected_table_schema.clone(), + projected_table_schema: Arc::clone(&self.projected_table_schema), field_mappings, - table_schema: self.table_schema.clone(), + table_schema: Arc::clone(&self.table_schema), }), projection, )) } } -/// The SchemaMapping struct holds a mapping from the file schema to the table schema -/// and any necessary type conversions that need to be applied. +/// The SchemaMapping struct holds a mapping from the file schema to the table +/// schema and any necessary type conversions. +/// +/// Note, because `map_batch` and `map_partial_batch` functions have different +/// needs, this struct holds two schemas: +/// +/// 1. The projected **table** schema +/// 2. The full table schema /// -/// This needs both the projected table schema and full table schema because its different -/// functions have different needs. The [`map_batch`] function is only used by the ParquetOpener to -/// produce a RecordBatch which has the projected schema, since that's the schema which is supposed -/// to come out of the execution of this query. [`map_partial_batch`], however, is used to create a -/// RecordBatch with a schema that can be used for Parquet pushdown, meaning that it may contain -/// fields which are not in the projected schema (as the fields that parquet pushdown filters -/// operate can be completely distinct from the fields that are projected (output) out of the -/// ParquetExec). +/// [`map_batch`] is used by the ParquetOpener to produce a RecordBatch which +/// has the projected schema, since that's the schema which is supposed to come +/// out of the execution of this query. Thus `map_batch` uses +/// `projected_table_schema` as it can only operate on the projected fields. /// -/// [`map_partial_batch`] uses `table_schema` to create the resulting RecordBatch (as it could be -/// operating on any fields in the schema), while [`map_batch`] uses `projected_table_schema` (as -/// it can only operate on the projected fields). 
+/// [`map_partial_batch`] is used to create a RecordBatch with a schema that +/// can be used for Parquet predicate pushdown, meaning that it may contain +/// fields which are not in the projected schema (as the fields that parquet +/// pushdown filters operate on can be completely distinct from the fields that are +/// projected (output) out of the ParquetExec). `map_partial_batch` thus uses +/// `table_schema` to create the resulting RecordBatch (as it could be operating +/// on any fields in the schema). /// /// [`map_batch`]: Self::map_batch /// [`map_partial_batch`]: Self::map_partial_batch #[derive(Debug)] pub struct SchemaMapping { - /// The schema of the table. This is the expected schema after conversion and it should match - /// the schema of the query result. + /// The schema of the table. This is the expected schema after conversion + /// and it should match the schema of the query result. projected_table_schema: SchemaRef, - /// Mapping from field index in `projected_table_schema` to index in projected file_schema. - /// They are Options instead of just plain `usize`s because the table could have fields that - /// don't exist in the file. + /// Mapping from field index in `projected_table_schema` to index in + /// projected file_schema. + /// + /// They are Options instead of just plain `usize`s because the table could + /// have fields that don't exist in the file. field_mappings: Vec>, - /// The entire table schema, as opposed to the projected_table_schema (which only contains the - /// columns that we are projecting out of this query). This contains all fields in the table, - /// regardless of if they will be projected out or not. + /// The entire table schema, as opposed to the projected_table_schema (which + /// only contains the columns that we are projecting out of this query). + /// This contains all fields in the table, regardless of whether they will be + /// projected out or not. 
table_schema: SchemaRef, } @@ -255,7 +372,7 @@ impl SchemaMapper for SchemaMapping { // Necessary to handle empty batches let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); - let schema = self.projected_table_schema.clone(); + let schema = Arc::clone(&self.projected_table_schema); let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; Ok(record_batch) } @@ -304,7 +421,8 @@ impl SchemaMapper for SchemaMapping { // Necessary to handle empty batches let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); - let schema = Arc::new(Schema::new(fields)); + let schema = + Arc::new(Schema::new_with_metadata(fields, schema.metadata().clone())); let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; Ok(record_batch) } @@ -330,8 +448,9 @@ mod tests { use crate::datasource::listing::PartitionedFile; use crate::datasource::schema_adapter::{ - SchemaAdapter, SchemaAdapterFactory, SchemaMapper, + DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper, }; + use datafusion_common::record_batch; #[cfg(feature = "parquet")] use parquet::arrow::ArrowWriter; use tempfile::TempDir; @@ -359,7 +478,7 @@ mod tests { writer.close().unwrap(); let location = Path::parse(path.to_str().unwrap()).unwrap(); - let metadata = std::fs::metadata(path.as_path()).expect("Local file metadata"); + let metadata = fs::metadata(path.as_path()).expect("Local file metadata"); let meta = ObjectMeta { location, last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), @@ -374,6 +493,7 @@ mod tests { range: None, statistics: None, extensions: None, + metadata_size_hint: None, }; let f1 = Field::new("id", DataType::Int32, true); @@ -404,6 +524,58 @@ mod tests { assert_batches_sorted_eq!(expected, &read); } + #[test] + fn default_schema_adapter() { + let table_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]); + + // file has a subset of the table schema fields and different type + let file_schema = Schema::new(vec![ + Field::new("c", DataType::Float64, true), // not in table schema + Field::new("b", DataType::Float64, true), + ]); + + let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); + let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); + assert_eq!(indices, vec![1]); + + let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); + + let mapped_batch = mapper.map_batch(file_batch).unwrap(); + + // the mapped batch has the correct schema and the "b" column has been cast to Utf8 + let expected_batch = record_batch!( + ("a", Int32, vec![None, None]), // missing column filled with nulls + ("b", Utf8, vec!["1.0", "2.0"]) // b was cast to string and order was changed + ) + .unwrap(); + assert_eq!(mapped_batch, expected_batch); + } + + #[test] + fn default_schema_adapter_non_nullable_columns() { + let table_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), // "a"" is declared non nullable + Field::new("b", DataType::Utf8, true), + ]); + let file_schema = Schema::new(vec![ + // since file doesn't have "a" it will be filled with nulls + Field::new("b", DataType::Float64, true), + ]); + + let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); + let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); + assert_eq!(indices, vec![0]); + + let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); + + // Mapping fails because 
it tries to fill in a non-nullable column with nulls + let err = mapper.map_batch(file_batch).unwrap_err().to_string(); + assert!(err.contains("Invalid argument error: Column 'a' is declared as non-nullable but contains null values"), "{err}"); + } + #[derive(Debug)] struct TestSchemaAdapterFactory; diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index 34023fbbb6207..56cbb126d02c0 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -36,7 +36,6 @@ use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; -use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; @@ -62,7 +61,7 @@ impl TableProviderFactory for StreamTableFactory { let header = if let Ok(opt) = cmd .options .get("format.has_header") - .map(|has_header| bool::from_str(has_header)) + .map(|has_header| bool::from_str(has_header.to_lowercase().as_str())) .transpose() { opt.unwrap_or(false) @@ -101,7 +100,7 @@ impl FromStr for StreamEncoding { match s.to_ascii_lowercase().as_str() { "csv" => Ok(Self::Csv), "json" => Ok(Self::Json), - _ => plan_err!("Unrecognised StreamEncoding {}", s), + _ => plan_err!("Unrecognized StreamEncoding {}", s), } } } @@ -187,7 +186,7 @@ impl StreamProvider for FileStreamProvider { fn reader(&self) -> Result> { let file = File::open(&self.location)?; - let schema = self.schema.clone(); + let schema = Arc::clone(&self.schema); match &self.encoding { StreamEncoding::Csv => { let reader = arrow::csv::ReaderBuilder::new(schema) @@ -311,7 +310,7 @@ impl TableProvider for StreamTable { } fn schema(&self) -> SchemaRef { - self.0.source.schema().clone() + Arc::clone(self.0.source.schema()) } fn constraints(&self) -> Option<&Constraints> { @@ -338,8 +337,8 @@ impl TableProvider for StreamTable { }; Ok(Arc::new(StreamingTableExec::try_new( - self.0.source.schema().clone(), - vec![Arc::new(StreamRead(self.0.clone())) as _], + Arc::clone(self.0.source.schema()), + vec![Arc::new(StreamRead(Arc::clone(&self.0))) as _], projection, projected_schema, true, @@ -365,8 +364,7 @@ impl TableProvider for StreamTable { Ok(Arc::new(DataSinkExec::new( input, - Arc::new(StreamWrite(self.0.clone())), - self.0.source.schema().clone(), + Arc::new(StreamWrite(Arc::clone(&self.0))), ordering, ))) } @@ -381,8 +379,8 @@ impl PartitionStream for StreamRead { } fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { - let config = self.0.clone(); - let schema = self.0.source.schema().clone(); + let config = Arc::clone(&self.0); + let schema = Arc::clone(self.0.source.schema()); let mut builder = RecordBatchReceiverStreamBuilder::new(schema, 2); let tx = builder.tx(); builder.spawn_blocking(move || { @@ -413,8 +411,8 @@ impl DataSink for StreamWrite { self } - fn metrics(&self) -> Option { - None + fn schema(&self) -> &SchemaRef { + self.0.source.schema() } async fn write_all( @@ -422,7 +420,7 @@ impl DataSink for StreamWrite { mut data: SendableRecordBatchStream, _context: &Arc, ) -> Result { - let config = self.0.clone(); + let config = Arc::clone(&self.0); let (sender, mut receiver) = tokio::sync::mpsc::channel::(2); // Note: FIFO Files 
support poll so this could use AsyncFd let write_task = SpawnedTask::spawn_blocking(move || { diff --git a/datafusion/core/src/datasource/streaming.rs b/datafusion/core/src/datasource/streaming.rs index 0a14cfefcdf23..1da3c3da9c897 100644 --- a/datafusion/core/src/datasource/streaming.rs +++ b/datafusion/core/src/datasource/streaming.rs @@ -76,7 +76,7 @@ impl TableProvider for StreamingTable { } fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } fn table_type(&self) -> TableType { @@ -91,7 +91,7 @@ impl TableProvider for StreamingTable { limit: Option, ) -> Result> { Ok(Arc::new(StreamingTableExec::try_new( - self.schema.clone(), + Arc::clone(&self.schema), self.partitions.clone(), projection, None, diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs index 1ffe54e4b06c1..33a3f4da68430 100644 --- a/datafusion/core/src/datasource/view.rs +++ b/datafusion/core/src/datasource/view.rs @@ -31,6 +31,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::Column; use datafusion_expr::{LogicalPlanBuilder, TableProviderFilterPushDown}; use datafusion_optimizer::analyzer::expand_wildcard_rule::ExpandWildcardRule; +use datafusion_optimizer::analyzer::type_coercion::TypeCoercion; use datafusion_optimizer::Analyzer; use crate::datasource::{TableProvider, TableType}; @@ -67,11 +68,11 @@ impl ViewTable { fn apply_required_rule(logical_plan: LogicalPlan) -> Result { let options = ConfigOptions::default(); - Analyzer::with_rules(vec![Arc::new(ExpandWildcardRule::new())]).execute_and_check( - logical_plan, - &options, - |_, _| {}, - ) + Analyzer::with_rules(vec![ + Arc::new(ExpandWildcardRule::new()), + Arc::new(TypeCoercion::new()), + ]) + .execute_and_check(logical_plan, &options, |_, _| {}) } /// Get definition ref diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index b0951d9ec44cd..6d5d151c0be7b 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -30,9 +30,8 @@ use crate::{ catalog_common::memory::MemorySchemaProvider, catalog_common::MemoryCatalogProvider, dataframe::DataFrame, - datasource::{ - function::{TableFunction, TableFunctionImpl}, - listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, + datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }, datasource::{provider_as_source, MemTable, ViewTable}, error::{DataFusionError, Result}, @@ -42,7 +41,8 @@ use crate::{ logical_expr::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable, - DropView, LogicalPlan, LogicalPlanBuilder, SetVariable, TableType, UNNAMED_TABLE, + DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable, + TableType, UNNAMED_TABLE, }, physical_expr::PhysicalExpr, physical_plan::ExecutionPlan, @@ -54,9 +54,9 @@ use arrow::record_batch::RecordBatch; use arrow_schema::Schema; use datafusion_common::{ config::{ConfigExtension, TableOptions}, - exec_err, not_impl_err, plan_datafusion_err, plan_err, + exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err, tree_node::{TreeNodeRecursion, TreeNodeVisitor}, - DFSchema, SchemaReference, TableReference, + DFSchema, ParamValues, ScalarValue, SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ @@ -73,7 +73,9 @@ use 
crate::datasource::dynamic_file::DynamicListTableFactory; use crate::execution::session_state::SessionStateBuilder; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use datafusion_catalog::{DynamicFileCatalog, SessionStore, UrlTableFactory}; +use datafusion_catalog::{ + DynamicFileCatalog, SessionStore, TableFunction, TableFunctionImpl, UrlTableFactory, +}; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; @@ -389,8 +391,11 @@ impl SessionContext { current_catalog_list, Arc::clone(&factory) as Arc, )); + + let session_id = self.session_id.clone(); let ctx: SessionContext = self .into_state_builder() + .with_session_id(session_id) .with_catalog_list(catalog_list) .build() .into(); @@ -507,7 +512,7 @@ impl SessionContext { /// Return the [RuntimeEnv] used to run queries with this `SessionContext` pub fn runtime_env(&self) -> Arc { - self.state.read().runtime_env().clone() + Arc::clone(self.state.read().runtime_env()) } /// Returns an id that uniquely identifies this `SessionContext`. @@ -687,7 +692,39 @@ impl SessionContext { LogicalPlan::Statement(Statement::SetVariable(stmt)) => { self.set_variable(stmt).await } - + LogicalPlan::Statement(Statement::Prepare(Prepare { + name, + input, + data_types, + })) => { + // The number of parameters must match the specified data types length. + if !data_types.is_empty() { + let param_names = input.get_parameter_names()?; + if param_names.len() != data_types.len() { + return plan_err!( + "Prepare specifies {} data types but query has {} parameters", + data_types.len(), + param_names.len() + ); + } + } + // Store the unoptimized plan into the session state. Although storing the + // optimized plan or the physical plan would be more efficient, doing so is + // not currently feasible. This is because `now()` would be optimized to a + // constant value, causing each EXECUTE to yield the same result, which is + // incorrect behavior. 
+ self.state.write().store_prepared(name, data_types, input)?; + self.return_empty_dataframe() + } + LogicalPlan::Statement(Statement::Execute(execute)) => { + self.execute_prepared(execute) + } + LogicalPlan::Statement(Statement::Deallocate(deallocate)) => { + self.state + .write() + .remove_prepared(deallocate.name.as_str())?; + self.return_empty_dataframe() + } plan => Ok(DataFrame::new(self.state(), plan)), } } @@ -738,6 +775,11 @@ impl SessionContext { cmd: &CreateExternalTable, ) -> Result { let exist = self.table_exist(cmd.name.clone())?; + + if cmd.temporary { + return not_impl_err!("Temporary tables not supported"); + } + if exist { match cmd.if_not_exists { true => return self.return_empty_dataframe(), @@ -761,10 +803,16 @@ impl SessionContext { or_replace, constraints, column_defaults, + temporary, } = cmd; let input = Arc::unwrap_or_clone(input); let input = self.state().optimize(&input)?; + + if temporary { + return not_impl_err!("Temporary tables not supported"); + } + let table = self.table(name.clone()).await; match (if_not_exists, or_replace, table) { (true, false, Ok(_)) => self.return_empty_dataframe(), @@ -813,10 +861,15 @@ impl SessionContext { input, or_replace, definition, + temporary, } = cmd; let view = self.table(name.clone()).await; + if temporary { + return not_impl_err!("Temporary views not supported"); + } + match (or_replace, view) { (true, Ok(_)) => { self.deregister_table(name.clone())?; @@ -994,7 +1047,7 @@ impl SessionContext { Ok(table) } - async fn find_and_deregister<'a>( + async fn find_and_deregister( &self, table_ref: impl Into, table_type: TableType, @@ -1072,6 +1125,49 @@ impl SessionContext { } } + fn execute_prepared(&self, execute: Execute) -> Result { + let Execute { + name, parameters, .. + } = execute; + let prepared = self.state.read().get_prepared(&name).ok_or_else(|| { + exec_datafusion_err!("Prepared statement '{}' does not exist", name) + })?; + + // Only allow literals as parameters for now. + let mut params: Vec = parameters + .into_iter() + .map(|e| match e { + Expr::Literal(scalar) => Ok(scalar.into_value()), + _ => not_impl_err!("Unsupported parameter type: {}", e), + }) + .collect::>()?; + + // If the prepared statement provides data types, cast the params to those types. + if !prepared.data_types.is_empty() { + if params.len() != prepared.data_types.len() { + return exec_err!( + "Prepared statement '{}' expects {} parameters, but {} provided", + name, + prepared.data_types.len(), + params.len() + ); + } + params = params + .into_iter() + .zip(prepared.data_types.iter()) + .map(|(e, dt)| e.cast_to(dt)) + .collect::>()?; + } + + let params = ParamValues::List(params); + let plan = prepared + .plan + .as_ref() + .clone() + .replace_params_with_values(¶ms)?; + Ok(DataFrame::new(self.state(), plan)) + } + /// Registers a variable provider within this context. pub fn register_variable( &self, @@ -1336,8 +1432,8 @@ impl SessionContext { /// Registers a [`TableProvider`] as a table that can be /// referenced from SQL statements executed against this context. /// - /// Returns the [`TableProvider`] previously registered for this - /// reference, if any + /// If a table of the same name was already registered, returns "Table + /// already exists" error. pub fn register_table( &self, table_ref: impl Into, @@ -1385,10 +1481,7 @@ impl SessionContext { /// provided reference. 
/// /// [`register_table`]: SessionContext::register_table - pub async fn table<'a>( - &self, - table_ref: impl Into, - ) -> Result { + pub async fn table(&self, table_ref: impl Into) -> Result { let table_ref: TableReference = table_ref.into(); let provider = self.table_provider(table_ref.clone()).await?; let plan = LogicalPlanBuilder::scan( @@ -1415,7 +1508,7 @@ impl SessionContext { } /// Return a [`TableProvider`] for the specified table. - pub async fn table_provider<'a>( + pub async fn table_provider( &self, table_ref: impl Into, ) -> Result> { @@ -1453,7 +1546,7 @@ impl SessionContext { /// Get reference to [`SessionState`] pub fn state_ref(&self) -> Arc> { - self.state.clone() + Arc::clone(&self.state) } /// Get weak reference to [`SessionState`] @@ -1672,7 +1765,7 @@ impl<'a> BadPlanVisitor<'a> { } } -impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> { +impl<'n> TreeNodeVisitor<'n> for BadPlanVisitor<'_> { type Node = LogicalPlan; fn f_down(&mut self, node: &'n Self::Node) -> Result { @@ -1696,15 +1789,14 @@ impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> { #[cfg(test)] mod tests { - use std::env; - use std::path::PathBuf; - use super::{super::options::CsvReadOptions, *}; use crate::assert_batches_eq; use crate::execution::memory_pool::MemoryConsumer; - use crate::execution::runtime_env::RuntimeEnvBuilder; use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; + use arrow_schema::{DataType, TimeUnit}; + use std::env; + use std::path::PathBuf; use datafusion_common_runtime::SpawnedTask; @@ -1712,6 +1804,8 @@ mod tests { use crate::execution::session_state::SessionStateBuilder; use crate::physical_planner::PhysicalPlanner; use async_trait::async_trait; + use datafusion_expr::planner::TypePlanner; + use sqlparser::ast; use tempfile::TempDir; #[tokio::test] @@ -1809,7 +1903,7 @@ mod tests { #[tokio::test] async fn send_context_to_threads() -> Result<()> { // ensure SessionContexts can be used in a multi-threaded - // environment. Usecase is for concurrent planing. + // environment. Use case is for concurrent planing. let tmp_dir = TempDir::new()?; let partition_count = 4; let ctx = Arc::new(create_ctx(&tmp_dir, partition_count).await?); @@ -1837,14 +1931,12 @@ mod tests { let path = path.join("tests/tpch-csv"); let url = format!("file://{}", path.display()); - let runtime = RuntimeEnvBuilder::new().build_arc()?; let cfg = SessionConfig::new() .set_str("datafusion.catalog.location", url.as_str()) .set_str("datafusion.catalog.format", "CSV") .set_str("datafusion.catalog.has_header", "true"); let session_state = SessionStateBuilder::new() .with_config(cfg) - .with_runtime_env(runtime) .with_default_features() .build(); let ctx = SessionContext::new_with_state(session_state); @@ -2108,6 +2200,39 @@ mod tests { Ok(()) } + #[tokio::test] + async fn custom_type_planner() -> Result<()> { + let state = SessionStateBuilder::new() + .with_default_features() + .with_type_planner(Arc::new(MyTypePlanner {})) + .build(); + let ctx = SessionContext::new_with_state(state); + let result = ctx + .sql("SELECT DATETIME '2021-01-01 00:00:00'") + .await? 
+ .collect() + .await?; + let expected = [ + "+-----------------------------+", + "| Utf8(\"2021-01-01 00:00:00\") |", + "+-----------------------------+", + "| 2021-01-01T00:00:00 |", + "+-----------------------------+", + ]; + assert_batches_eq!(expected, &result); + Ok(()) + } + #[test] + fn preserve_session_context_id() -> Result<()> { + let ctx = SessionContext::new(); + // it does make sense to preserve session id in this case + // as `enable_url_table()` can be seen as additional configuration + // option on ctx. + // some systems like datafusion ballista relies on stable session_id + assert_eq!(ctx.session_id(), ctx.enable_url_table().session_id()); + Ok(()) + } + struct MyPhysicalPlanner {} #[async_trait] @@ -2123,9 +2248,9 @@ mod tests { fn create_physical_expr( &self, _expr: &Expr, - _input_dfschema: &crate::common::DFSchema, + _input_dfschema: &DFSchema, _session_state: &SessionState, - ) -> Result> { + ) -> Result> { unimplemented!() } } @@ -2168,4 +2293,25 @@ mod tests { Ok(ctx) } + + #[derive(Debug)] + struct MyTypePlanner {} + + impl TypePlanner for MyTypePlanner { + fn plan_type(&self, sql_type: &ast::DataType) -> Result> { + match sql_type { + ast::DataType::Datetime(precision) => { + let precision = match precision { + Some(0) => TimeUnit::Second, + Some(3) => TimeUnit::Millisecond, + Some(6) => TimeUnit::Microsecond, + None | Some(9) => TimeUnit::Nanosecond, + _ => unreachable!(), + }; + Ok(Some(DataType::Timestamp(precision, None))) + } + _ => Ok(None), + } + } + } } diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index 3f23c150be839..be87c7cac1d2a 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -281,10 +281,10 @@ mod tests { ) .await; let binding = DataFilePaths::to_urls(&path2).unwrap(); - let expexted_path = binding[0].as_str(); + let expected_path = binding[0].as_str(); assert_eq!( read_df.unwrap_err().strip_backtrace(), - format!("Execution error: File path '{}' does not match the expected extension '.parquet'", expexted_path) + format!("Execution error: File path '{}' does not match the expected extension '.parquet'", expected_path) ); // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension. 
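The `SessionContext` changes above route SQL `PREPARE`, `EXECUTE`, and `DEALLOCATE` through `execute_logical_plan`: the unoptimized plan is stored under the statement name, each `EXECUTE` re-binds literal parameters (casting them to the declared types), and `DEALLOCATE` removes the stored plan. A hedged end-to-end usage sketch; the table name, statement name, and values are illustrative only:

```
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // A small table to query; VALUES names its first column `column1`.
    ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)").await?;

    // PREPARE stores the unoptimized plan under the name `q`.
    ctx.sql("PREPARE q(INT) AS SELECT column1 FROM t WHERE column1 > $1")
        .await?;

    // EXECUTE binds the literal parameter against the stored plan and runs it.
    let batches = ctx.sql("EXECUTE q(1)").await?.collect().await?;
    assert!(batches.iter().map(|b| b.num_rows()).sum::<usize>() > 0);

    // DEALLOCATE removes the stored plan; a later EXECUTE of `q` would error.
    ctx.sql("DEALLOCATE q").await?;
    Ok(())
}
```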
@@ -316,7 +316,7 @@ mod tests { let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); assert_eq!(total_rows, 0); - // Read the datafram from doule dot folder; + // Read the dataframe from double dot folder; let read_df = ctx .read_parquet( &path5, diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 4953eecd66e39..c5874deb6ed50 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -24,7 +24,6 @@ use crate::catalog_common::information_schema::{ use crate::catalog_common::MemoryCatalogProviderList; use crate::datasource::cte_worktable::CteWorkTable; use crate::datasource::file_format::{format_as_file_type, FileFormatFactory}; -use crate::datasource::function::{TableFunction, TableFunctionImpl}; use crate::datasource::provider_as_source; use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner}; use crate::execution::SessionStateDefaults; @@ -33,14 +32,14 @@ use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; use arrow_schema::{DataType, SchemaRef}; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use datafusion_catalog::Session; +use datafusion_catalog::{Session, TableFunction, TableFunctionImpl}; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::tree_node::TreeNode; use datafusion_common::{ - config_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, + config_err, exec_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, ResolvedTableReference, TableReference, }; use datafusion_execution::config::SessionConfig; @@ -48,7 +47,7 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::planner::ExprPlanner; +use datafusion_expr::planner::{ExprPlanner, TypePlanner}; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::var_provider::{is_system_variables, VarType}; @@ -69,7 +68,7 @@ use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, Sq use itertools::Itertools; use log::{debug, info}; use object_store::ObjectStore; -use sqlparser::ast::Expr as SQLExpr; +use sqlparser::ast::{Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias}; use sqlparser::dialect::dialect_from_str; use std::any::Any; use std::collections::hash_map::Entry; @@ -126,8 +125,10 @@ pub struct SessionState { session_id: String, /// Responsible for analyzing and rewrite a logical plan before optimization analyzer: Analyzer, - /// Provides support for customising the SQL planner, e.g. to add support for custom operators like `->>` or `?` + /// Provides support for customizing the SQL planner, e.g. to add support for custom operators like `->>` or `?` expr_planners: Vec>, + /// Provides support for customizing the SQL type planning + type_planner: Option>, /// Responsible for optimizing a logical plan optimizer: Optimizer, /// Responsible for optimizing a physical execution plan @@ -171,6 +172,9 @@ pub struct SessionState { /// It will be invoked on `CREATE FUNCTION` statements. 
/// thus, changing dialect o PostgreSql is required function_factory: Option>, + /// Cache logical plans of prepared statements for later execution. + /// Key is the prepared statement name. + prepared_plans: HashMap>, } impl Debug for SessionState { @@ -189,6 +193,7 @@ impl Debug for SessionState { .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) .field("expr_planners", &self.expr_planners) + .field("type_planner", &self.type_planner) .field("query_planners", &self.query_planner) .field("analyzer", &self.analyzer) .field("optimizer", &self.optimizer) @@ -197,6 +202,7 @@ impl Debug for SessionState { .field("scalar_functions", &self.scalar_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) + .field("prepared_plans", &self.prepared_plans) .finish() } } @@ -227,15 +233,15 @@ impl Session for SessionState { } fn scalar_functions(&self) -> &HashMap> { - self.scalar_functions() + &self.scalar_functions } fn aggregate_functions(&self) -> &HashMap> { - self.aggregate_functions() + &self.aggregate_functions } fn window_functions(&self) -> &HashMap> { - self.window_functions() + &self.window_functions } fn runtime_env(&self) -> &Arc { @@ -289,16 +295,18 @@ impl SessionState { .resolve(&catalog.default_catalog, &catalog.default_schema) } - pub(crate) fn schema_for_ref( + /// Retrieve the [`SchemaProvider`] for a specific [`TableReference`], if it + /// exists. + pub fn schema_for_ref( &self, table_ref: impl Into, ) -> datafusion_common::Result> { let resolved_ref = self.resolve_table_ref(table_ref); if self.config.information_schema() && *resolved_ref.schema == *INFORMATION_SCHEMA { - return Ok(Arc::new(InformationSchemaProvider::new( - self.catalog_list.clone(), - ))); + return Ok(Arc::new(InformationSchemaProvider::new(Arc::clone( + &self.catalog_list, + )))); } self.catalog_list @@ -492,11 +500,22 @@ impl SessionState { sql: &str, dialect: &str, ) -> datafusion_common::Result { + self.sql_to_expr_with_alias(sql, dialect).map(|x| x.expr) + } + + /// parse a sql string into a sqlparser-rs AST [`SQLExprWithAlias`]. + /// + /// See [`Self::create_logical_expr`] for parsing sql to [`Expr`]. + pub fn sql_to_expr_with_alias( + &self, + sql: &str, + dialect: &str, + ) -> datafusion_common::Result { let dialect = dialect_from_str(dialect).ok_or_else(|| { plan_datafusion_err!( "Unsupported SQL dialect: {dialect}. Available dialects: \ - Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ - MsSQL, ClickHouse, BigQuery, Ansi." + Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ + MsSQL, ClickHouse, BigQuery, Ansi." 
) })?; @@ -512,7 +531,7 @@ impl SessionState { /// [`catalog::resolve_table_references`]: crate::catalog_common::resolve_table_references pub fn resolve_table_references( &self, - statement: &datafusion_sql::parser::Statement, + statement: &Statement, ) -> datafusion_common::Result> { let enable_ident_normalization = self.config.options().sql_parser.enable_ident_normalization; @@ -526,7 +545,7 @@ impl SessionState { /// Convert an AST Statement into a LogicalPlan pub async fn statement_to_plan( &self, - statement: datafusion_sql::parser::Statement, + statement: Statement, ) -> datafusion_common::Result { let references = self.resolve_table_references(&statement)?; @@ -536,8 +555,9 @@ impl SessionState { }; for reference in references { - let resolved = &self.resolve_table_ref(reference); - if let Entry::Vacant(v) = provider.tables.entry(resolved.to_string()) { + let resolved = self.resolve_table_ref(reference); + if let Entry::Vacant(v) = provider.tables.entry(resolved) { + let resolved = v.key(); if let Ok(schema) = self.schema_for_ref(resolved.clone()) { if let Some(table) = schema.table(&resolved.table).await? { v.insert(provider_as_source(table)); @@ -594,7 +614,7 @@ impl SessionState { ) -> datafusion_common::Result { let dialect = self.config.options().sql_parser.dialect.as_str(); - let sql_expr = self.sql_to_expr(sql, dialect)?; + let sql_expr = self.sql_to_expr_with_alias(sql, dialect)?; let provider = SessionContextProvider { state: self, @@ -602,7 +622,7 @@ impl SessionState { }; let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); - query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new()) + query.sql_to_expr_with_alias(sql_expr, df_schema, &mut PlannerContext::new()) } /// Returns the [`Analyzer`] for this session @@ -644,9 +664,9 @@ impl SessionState { return Ok(LogicalPlan::Explain(Explain { verbose: e.verbose, - plan: e.plan.clone(), + plan: Arc::clone(&e.plan), stringified_plans, - schema: e.schema.clone(), + schema: Arc::clone(&e.schema), logical_optimization_succeeded: false, })); } @@ -673,7 +693,7 @@ impl SessionState { let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; stringified_plans .push(StringifiedPlan::new(plan_type, err.to_string())); - (e.plan.clone(), false) + (Arc::clone(&e.plan), false) } Err(e) => return Err(e), }; @@ -682,7 +702,7 @@ impl SessionState { verbose: e.verbose, plan, stringified_plans, - schema: e.schema.clone(), + schema: Arc::clone(&e.schema), logical_optimization_succeeded, })) } else { @@ -904,7 +924,41 @@ impl SessionState { name: &str, ) -> datafusion_common::Result>> { let udtf = self.table_functions.remove(name); - Ok(udtf.map(|x| x.function().clone())) + Ok(udtf.map(|x| Arc::clone(x.function()))) + } + + /// Store the logical plan and the parameter types of a prepared statement. + pub(crate) fn store_prepared( + &mut self, + name: String, + data_types: Vec, + plan: Arc, + ) -> datafusion_common::Result<()> { + match self.prepared_plans.entry(name) { + Entry::Vacant(e) => { + e.insert(Arc::new(PreparedPlan { data_types, plan })); + Ok(()) + } + Entry::Occupied(e) => { + exec_err!("Prepared statement '{}' already exists", e.key()) + } + } + } + + /// Get the prepared plan with the given name. + pub(crate) fn get_prepared(&self, name: &str) -> Option> { + self.prepared_plans.get(name).map(Arc::clone) + } + + /// Remove the prepared plan with the given name. 
+ pub(crate) fn remove_prepared( + &mut self, + name: &str, + ) -> datafusion_common::Result<()> { + match self.prepared_plans.remove(name) { + Some(_) => Ok(()), + None => exec_err!("Prepared statement '{}' does not exist", name), + } } } @@ -916,6 +970,7 @@ pub struct SessionStateBuilder { session_id: Option, analyzer: Option, expr_planners: Option>>, + type_planner: Option>, optimizer: Option, physical_optimizers: Option, query_planner: Option>, @@ -945,6 +1000,7 @@ impl SessionStateBuilder { session_id: None, analyzer: None, expr_planners: None, + type_planner: None, optimizer: None, physical_optimizers: None, query_planner: None, @@ -992,6 +1048,7 @@ impl SessionStateBuilder { session_id: None, analyzer: Some(existing.analyzer), expr_planners: Some(existing.expr_planners), + type_planner: existing.type_planner, optimizer: Some(existing.optimizer), physical_optimizers: Some(existing.physical_optimizers), query_planner: Some(existing.query_planner), @@ -1027,6 +1084,7 @@ impl SessionStateBuilder { .with_scalar_functions(SessionStateDefaults::default_scalar_functions()) .with_aggregate_functions(SessionStateDefaults::default_aggregate_functions()) .with_window_functions(SessionStateDefaults::default_window_functions()) + .with_table_function_list(SessionStateDefaults::default_table_functions()) } /// Set the session id. @@ -1086,6 +1144,12 @@ impl SessionStateBuilder { self } + /// Set the [`TypePlanner`] used to customize the behavior of the SQL planner. + pub fn with_type_planner(mut self, type_planner: Arc) -> Self { + self.type_planner = Some(type_planner); + self + } + /// Set the [`PhysicalOptimizerRule`]s used to optimize plans. pub fn with_physical_optimizer_rules( mut self, @@ -1135,6 +1199,19 @@ impl SessionStateBuilder { self } + /// Set the list of [`TableFunction`]s + pub fn with_table_function_list( + mut self, + table_functions: Vec>, + ) -> Self { + let functions = table_functions + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(); + self.table_functions = Some(functions); + self + } + /// Set the map of [`ScalarUDF`]s pub fn with_scalar_functions( mut self, @@ -1279,6 +1356,7 @@ impl SessionStateBuilder { session_id, analyzer, expr_planners, + type_planner, optimizer, physical_optimizers, query_planner, @@ -1307,6 +1385,7 @@ impl SessionStateBuilder { session_id: session_id.unwrap_or(Uuid::new_v4().to_string()), analyzer: analyzer.unwrap_or_default(), expr_planners: expr_planners.unwrap_or_default(), + type_planner, optimizer: optimizer.unwrap_or_default(), physical_optimizers: physical_optimizers.unwrap_or_default(), query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})), @@ -1327,6 +1406,7 @@ impl SessionStateBuilder { table_factories: table_factories.unwrap_or_default(), runtime_env, function_factory, + prepared_plans: HashMap::new(), }; if let Some(file_formats) = file_formats { @@ -1416,6 +1496,11 @@ impl SessionStateBuilder { &mut self.expr_planners } + /// Returns the current type_planner value + pub fn type_planner(&mut self) -> &mut Option> { + &mut self.type_planner + } + /// Returns the current optimizer value pub fn optimizer(&mut self) -> &mut Option { &mut self.optimizer @@ -1538,6 +1623,7 @@ impl Debug for SessionStateBuilder { .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) .field("expr_planners", &self.expr_planners) + .field("type_planner", &self.type_planner) .field("query_planners", &self.query_planner) .field("analyzer_rules", &self.analyzer_rules) .field("analyzer", 
&self.analyzer) @@ -1571,19 +1657,27 @@ impl From for SessionStateBuilder { /// having a direct dependency on the [`SessionState`] struct (and core crate) struct SessionContextProvider<'a> { state: &'a SessionState, - tables: HashMap>, + tables: HashMap>, } -impl<'a> ContextProvider for SessionContextProvider<'a> { +impl ContextProvider for SessionContextProvider<'_> { fn get_expr_planners(&self) -> &[Arc] { &self.state.expr_planners } + fn get_type_planner(&self) -> Option> { + if let Some(type_planner) = &self.state.type_planner { + Some(Arc::clone(type_planner)) + } else { + None + } + } + fn get_table_source( &self, name: TableReference, ) -> datafusion_common::Result> { - let name = self.state.resolve_table_ref(name).to_string(); + let name = self.state.resolve_table_ref(name); self.tables .get(&name) .cloned() @@ -1671,7 +1765,7 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { .ok_or(plan_datafusion_err!( "There is no registered file format with ext {ext}" )) - .map(|file_type| format_as_file_type(file_type.clone())) + .map(|file_type| format_as_file_type(Arc::clone(file_type))) } } @@ -1709,7 +1803,8 @@ impl FunctionRegistry for SessionState { udf: Arc, ) -> datafusion_common::Result>> { udf.aliases().iter().for_each(|alias| { - self.scalar_functions.insert(alias.clone(), udf.clone()); + self.scalar_functions + .insert(alias.clone(), Arc::clone(&udf)); }); Ok(self.scalar_functions.insert(udf.name().into(), udf)) } @@ -1719,7 +1814,8 @@ impl FunctionRegistry for SessionState { udaf: Arc, ) -> datafusion_common::Result>> { udaf.aliases().iter().for_each(|alias| { - self.aggregate_functions.insert(alias.clone(), udaf.clone()); + self.aggregate_functions + .insert(alias.clone(), Arc::clone(&udaf)); }); Ok(self.aggregate_functions.insert(udaf.name().into(), udaf)) } @@ -1729,7 +1825,8 @@ impl FunctionRegistry for SessionState { udwf: Arc, ) -> datafusion_common::Result>> { udwf.aliases().iter().for_each(|alias| { - self.window_functions.insert(alias.clone(), udwf.clone()); + self.window_functions + .insert(alias.clone(), Arc::clone(&udwf)); }); Ok(self.window_functions.insert(udwf.name().into(), udwf)) } @@ -1823,7 +1920,7 @@ impl From<&SessionState> for TaskContext { state.scalar_functions.clone(), state.aggregate_functions.clone(), state.window_functions.clone(), - state.runtime_env.clone(), + Arc::clone(&state.runtime_env), ) } } @@ -1858,7 +1955,7 @@ impl<'a> SessionSimplifyProvider<'a> { } } -impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> { +impl SimplifyInfo for SessionSimplifyProvider<'_> { fn is_boolean_type(&self, expr: &Expr) -> datafusion_common::Result { Ok(expr.get_type(self.df_schema)? 
== DataType::Boolean) } @@ -1876,6 +1973,14 @@ impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> { } } +#[derive(Debug)] +pub(crate) struct PreparedPlan { + /// Data types of the parameters + pub(crate) data_types: Vec, + /// The prepared logical plan + pub(crate) plan: Arc, +} + #[cfg(test)] mod tests { use super::{SessionContextProvider, SessionStateBuilder}; diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index b5370efa0a979..106082bc7b3bf 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -29,7 +29,8 @@ use crate::datasource::provider::DefaultTableFactory; use crate::execution::context::SessionState; #[cfg(feature = "nested_expressions")] use crate::functions_nested; -use crate::{functions, functions_aggregate, functions_window}; +use crate::{functions, functions_aggregate, functions_table, functions_window}; +use datafusion_catalog::TableFunction; use datafusion_execution::config::SessionConfig; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; @@ -119,6 +120,11 @@ impl SessionStateDefaults { functions_window::all_default_window_functions() } + /// returns the list of default [`TableFunction`]s + pub fn default_table_functions() -> Vec> { + functions_table::all_default_table_functions() + } + /// returns the list of default [`FileFormatFactory']'s pub fn default_file_formats() -> Vec> { let file_formats: Vec> = vec![ @@ -193,8 +199,13 @@ impl SessionStateDefaults { Some(factory) => factory, _ => return, }; - let schema = - ListingSchemaProvider::new(authority, path, factory.clone(), store, format); + let schema = ListingSchemaProvider::new( + authority, + path, + Arc::clone(factory), + store, + format, + ); let _ = default_catalog .register_schema("default", Arc::new(schema)) .expect("Failed to register default schema"); diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 63d4fbc0bba5e..e9501bd37a8ab 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -14,6 +14,9 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] #![warn(missing_docs, clippy::needless_borrow)] //! [DataFusion] is an extensible query engine written in Rust that @@ -179,7 +182,7 @@ //! //! DataFusion is designed to be highly extensible, so you can //! start with a working, full featured engine, and then -//! specialize any behavior for your usecase. For example, +//! specialize any behavior for your use case. For example, //! some projects may add custom [`ExecutionPlan`] operators, or create their own //! query language that directly creates [`LogicalPlan`] rather than using the //! built in SQL planner, [`SqlToRel`]. @@ -379,14 +382,14 @@ //! //! Calling [`execute`] produces 1 or more partitions of data, //! as a [`SendableRecordBatchStream`], which implements a pull based execution -//! API. Calling `.next().await` will incrementally compute and return the next +//! API. Calling [`next()`]`.await` will incrementally compute and return the next //! [`RecordBatch`]. Balanced parallelism is achieved using [Volcano style] //! "Exchange" operations implemented by [`RepartitionExec`]. //! //! 
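To make the pull-based execution model described above concrete, here is a minimal sketch (an editor's illustration, not part of this patch) of driving a query to completion one `RecordBatch` at a time. It assumes the `datafusion`, `tokio`, and `futures` crates as dependencies and uses a trivial query so it stays self-contained:

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use futures::StreamExt;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // A trivial query keeps the sketch self-contained; any registered table
    // or DataFrame works the same way.
    let df = ctx.sql("SELECT 1 AS id, now() AS event_time").await?;

    // `execute_stream` returns a `SendableRecordBatchStream`; each call to
    // `next().await` incrementally computes and yields one `RecordBatch`.
    let mut stream = df.execute_stream().await?;
    while let Some(batch) = stream.next().await {
        println!("got {} rows", batch?.num_rows());
    }
    Ok(())
}
```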
While some recent research such as [Morsel-Driven Parallelism] describes challenges //! with the pull style Volcano execution model on NUMA architectures, in practice DataFusion achieves -//! similar scalability as systems that use morsel driven approach such as DuckDB. -//! See the [DataFusion paper submitted to SIGMOD] for more details. +//! similar scalability as systems that use push driven schedulers [such as DuckDB]. +//! See the [DataFusion paper in SIGMOD 2024] for more details. //! //! [`execute`]: physical_plan::ExecutionPlan::execute //! [`SendableRecordBatchStream`]: crate::physical_plan::SendableRecordBatchStream @@ -400,24 +403,189 @@ //! [`RepartitionExec`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/repartition/struct.RepartitionExec.html //! [Volcano style]: https://w6113.github.io/files/papers/volcanoparallelism-89.pdf //! [Morsel-Driven Parallelism]: https://db.in.tum.de/~leis/papers/morsels.pdf -//! [DataFusion paper submitted SIGMOD]: https://github.com/apache/datafusion/files/13874720/DataFusion_Query_Engine___SIGMOD_2024.pdf +//! [DataFusion paper in SIGMOD 2024]: https://github.com/apache/datafusion/files/15149988/DataFusion_Query_Engine___SIGMOD_2024-FINAL-mk4.pdf +//! [such as DuckDB]: https://github.com/duckdb/duckdb/issues/1583 //! [implementors of `ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html#implementors //! -//! ## Thread Scheduling +//! ## Streaming Execution +//! +//! DataFusion is a "streaming" query engine which means `ExecutionPlan`s incrementally +//! read from their input(s) and compute output one [`RecordBatch`] at a time +//! by continually polling [`SendableRecordBatchStream`]s. Output and +//! intermediate `RecordBatch`s each have approximately `batch_size` rows, +//! which amortizes per-batch overhead of execution. +//! +//! Note that certain operations, sometimes called "pipeline breakers", +//! (for example full sorts or hash aggregations) are fundamentally non streaming and +//! must read their input fully before producing **any** output. As much as possible, +//! other operators read a single [`RecordBatch`] from their input to produce a +//! single `RecordBatch` as output. +//! +//! For example, given this SQL query: +//! +//! ```sql +//! SELECT date_trunc('month', time) FROM data WHERE id IN (10,20,30); +//! ``` +//! +//! The diagram below shows the call sequence when a consumer calls [`next()`] to +//! get the next `RecordBatch` of output. While it is possible that some +//! steps run on different threads, typically tokio will use the same thread +//! that called `next()` to read from the input, apply the filter, and +//! return the results without interleaving any other operations. This results +//! in excellent cache locality as the same CPU core that produces the data often +//! consumes it immediately as well. //! -//! DataFusion incrementally computes output from a [`SendableRecordBatchStream`] -//! with `target_partitions` threads. Parallelism is implementing using multiple -//! [Tokio] [`task`]s, which are executed by threads managed by a tokio Runtime. -//! While tokio is most commonly used -//! for asynchronous network I/O, its combination of an efficient, work-stealing -//! scheduler, first class compiler support for automatic continuation generation, -//! and exceptional performance makes it a compelling choice for CPU intensive -//! applications as well. This is explained in more detail in [Using Rustlang’s Async Tokio -//! Runtime for CPU-Bound Tasks]. +//! 
```text +//! +//! Step 3: FilterExec calls next() Step 2: ProjectionExec calls +//! on input Stream next() on input Stream +//! ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +//! │ Step 1: Consumer +//! ▼ ▼ │ calls next() +//! ┏━━━━━━━━━━━━━━┓ ┏━━━━━┻━━━━━━━━━━━━━┓ ┏━━━━━━━━━━━━━━━━━━━━━━━━┓ +//! ┃ ┃ ┃ ┃ ┃ ◀ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +//! ┃ DataSource ┃ ┃ ┃ ┃ ┃ +//! ┃ (e.g. ┃ ┃ FilterExec ┃ ┃ ProjectionExec ┃ +//! ┃ ParquetExec) ┃ ┃id IN (10, 20, 30) ┃ ┃date_bin('month', time) ┃ +//! ┃ ┃ ┃ ┃ ┃ ┣ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ▶ +//! ┃ ┃ ┃ ┃ ┃ ┃ +//! ┗━━━━━━━━━━━━━━┛ ┗━━━━━━━━━━━┳━━━━━━━┛ ┗━━━━━━━━━━━━━━━━━━━━━━━━┛ +//! │ ▲ ▲ Step 6: ProjectionExec +//! │ │ │ computes date_trunc into a +//! └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ new RecordBatch returned +//! ┌─────────────────────┐ ┌─────────────┐ from client +//! │ RecordBatch │ │ RecordBatch │ +//! └─────────────────────┘ └─────────────┘ +//! +//! Step 4: DataSource returns a Step 5: FilterExec returns a new +//! single RecordBatch RecordBatch with only matching rows +//! ``` +//! +//! [`next()`]: futures::StreamExt::next +//! +//! ## Thread Scheduling, CPU / IO Thread Pools, and [Tokio] [`Runtime`]s +//! +//! DataFusion automatically runs each plan with multiple CPU cores using +//! a [Tokio] [`Runtime`] as a thread pool. While tokio is most commonly used +//! for asynchronous network I/O, the combination of an efficient, work-stealing +//! scheduler and first class compiler support for automatic continuation +//! generation (`async`), also makes it a compelling choice for CPU intensive +//! applications as explained in the [Using Rustlang’s Async Tokio +//! Runtime for CPU-Bound Tasks] blog. +//! +//! The number of cores used is determined by the `target_partitions` +//! configuration setting, which defaults to the number of CPU cores. +//! While preparing for execution, DataFusion tries to create this many distinct +//! `async` [`Stream`]s for each `ExecutionPlan`. +//! The `Stream`s for certain `ExecutionPlans`, such as as [`RepartitionExec`] +//! and [`CoalescePartitionsExec`], spawn [Tokio] [`task`]s, that are run by +//! threads managed by the `Runtime`. +//! Many DataFusion `Stream`s perform CPU intensive processing. +//! +//! Using `async` for CPU intensive tasks makes it easy for [`TableProvider`]s +//! to perform network I/O using standard Rust `async` during execution. +//! However, this design also makes it very easy to mix CPU intensive and latency +//! sensitive I/O work on the same thread pool ([`Runtime`]). +//! Using the same (default) `Runtime` is convenient, and often works well for +//! initial development and processing local files, but it can lead to problems +//! under load and/or when reading from network sources such as AWS S3. +//! +//! If your system does not fully utilize either the CPU or network bandwidth +//! during execution, or you see significantly higher tail (e.g. p99) latencies +//! responding to network requests, **it is likely you need to use a different +//! `Runtime` for CPU intensive DataFusion plans**. This effect can be especially +//! pronounced when running several queries concurrently. +//! +//! As shown in the following figure, using the same `Runtime` for both CPU +//! intensive processing and network requests can introduce significant +//! delays in responding to those network requests. Delays in processing network +//! requests can and does lead network flow control to throttle the available +//! bandwidth in response. +//! +//! ```text +//! Legend +//! +//! ┏━━━━━━┓ +//! 
Processing network request ┃ ┃ CPU bound work +//! is delayed due to processing ┗━━━━━━┛ +//! CPU bound work ┌─┐ +//! │ │ Network request +//! ││ └─┘ processing +//! +//! ││ +//! ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +//! │ │ +//! +//! ▼ ▼ +//! ┌─────────────┐ ┌─┐┌─┐┏━━━━━━━━━━━━━━━━━━━┓┏━━━━━━━━━━━━━━━━━━━┓┌─┐ +//! │ │thread 1 │ ││ │┃ Decoding ┃┃ Filtering ┃│ │ +//! │ │ └─┘└─┘┗━━━━━━━━━━━━━━━━━━━┛┗━━━━━━━━━━━━━━━━━━━┛└─┘ +//! │ │ ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓ +//! │Tokio Runtime│thread 2 ┃ Decoding ┃ Filtering ┃ Decoding ┃ ... +//! │(thread pool)│ ┗━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━┛ +//! │ │ ... ... +//! │ │ ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓┌─┐ ┏━━━━━━━━━━━━━━┓ +//! │ │thread N ┃ Decoding ┃ Filtering ┃│ │ ┃ Decoding ┃ +//! └─────────────┘ ┗━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┛└─┘ ┗━━━━━━━━━━━━━━┛ +//! ─────────────────────────────────────────────────────────────▶ +//! time +//! ``` +//! +//! The bottleneck resulting from network throttling can be avoided +//! by using separate [`Runtime`]s for the different types of work, as shown +//! in the diagram below. +//! +//! ```text +//! A separate thread pool processes network Legend +//! requests, reducing the latency for +//! processing each request ┏━━━━━━┓ +//! ┃ ┃ CPU bound work +//! │ ┗━━━━━━┛ +//! │ ┌─┐ +//! ┌ ─ ─ ─ ─ ┘ │ │ Network request +//! ┌ ─ ─ ─ ┘ └─┘ processing +//! │ +//! ▼ ▼ +//! ┌─────────────┐ ┌─┐┌─┐┌─┐ +//! │ │thread 1 │ ││ ││ │ +//! │ │ └─┘└─┘└─┘ +//! │Tokio Runtime│ ... +//! │(thread pool)│thread 2 +//! │ │ +//! │"IO Runtime" │ ... +//! │ │ ┌─┐ +//! │ │thread N │ │ +//! └─────────────┘ └─┘ +//! ─────────────────────────────────────────────────────────────▶ +//! time +//! +//! ┌─────────────┐ ┏━━━━━━━━━━━━━━━━━━━┓┏━━━━━━━━━━━━━━━━━━━┓ +//! │ │thread 1 ┃ Decoding ┃┃ Filtering ┃ +//! │ │ ┗━━━━━━━━━━━━━━━━━━━┛┗━━━━━━━━━━━━━━━━━━━┛ +//! │Tokio Runtime│ ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓ +//! │(thread pool)│thread 2 ┃ Decoding ┃ Filtering ┃ Decoding ┃ ... +//! │ │ ┗━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━┛ +//! │ CPU Runtime │ ... ... +//! │ │ ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓ +//! │ │thread N ┃ Decoding ┃ Filtering ┃ Decoding ┃ +//! └─────────────┘ ┗━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━┛ +//! ─────────────────────────────────────────────────────────────▶ +//! time +//!``` +//! +//! Note that DataFusion does not use [`tokio::task::spawn_blocking`] for +//! CPU-bounded work, because `spawn_blocking` is designed for blocking **IO**, +//! not designed CPU bound tasks. Among other challenges, spawned blocking +//! tasks can't yield waiting for input (can't call `await`) so they +//! can't be used to limit the number of concurrent CPU bound tasks or +//! keep the processing pipeline to the same core. //! //! [Tokio]: https://tokio.rs +//! [`Runtime`]: tokio::runtime::Runtime //! [`task`]: tokio::task //! [Using Rustlang’s Async Tokio Runtime for CPU-Bound Tasks]: https://thenewstack.io/using-rustlangs-async-tokio-runtime-for-cpu-bound-tasks/ +//! [`RepartitionExec`]: physical_plan::repartition::RepartitionExec +//! [`CoalescePartitionsExec`]: physical_plan::coalesce_partitions::CoalescePartitionsExec //! //! ## State Management and Configuration //! 
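The following is a hedged sketch (not part of this patch) of one way an application might apply the advice above: execute the DataFusion plan on a dedicated CPU-bound runtime and hand `RecordBatch`es back to the I/O runtime over a channel. The runtime sizes, channel capacity, and the trivial query are illustrative placeholders; it assumes the `datafusion`, `tokio` (with `rt-multi-thread`), and `futures` crates:

```rust
use datafusion::error::{DataFusionError, Result};
use datafusion::prelude::SessionContext;
use futures::StreamExt;
use tokio::runtime::Builder;
use tokio::sync::mpsc;

fn main() -> Result<()> {
    // One runtime for latency-sensitive work (network requests, object store I/O)...
    let io_runtime = Builder::new_multi_thread()
        .worker_threads(2)
        .enable_all()
        .build()?;
    // ...and a separate runtime dedicated to CPU-bound query execution
    // (worker threads default to the number of cores).
    let cpu_runtime = Builder::new_multi_thread().enable_all().build()?;

    io_runtime.block_on(async {
        let ctx = SessionContext::new();
        let df = ctx.sql("SELECT 1 AS v").await?;

        // Drive the plan on the CPU runtime and send batches back over a
        // channel, so decoding/filtering never competes with I/O tasks.
        let (tx, mut rx) = mpsc::channel(8);
        cpu_runtime.spawn(async move {
            let mut stream = df.execute_stream().await?;
            while let Some(batch) = stream.next().await {
                if tx.send(batch?).await.is_err() {
                    break; // receiver hung up
                }
            }
            Ok::<_, DataFusionError>(())
        });

        // Back on the I/O runtime: consume results as they arrive.
        while let Some(batch) = rx.recv().await {
            println!("received {} rows on the IO runtime", batch.num_rows());
        }
        Ok::<_, DataFusionError>(())
    })?;
    Ok(())
}
```

In a real server the I/O runtime would also be the one accepting requests and talking to object storage; the key point is only that plan execution is spawned onto a different thread pool.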
@@ -605,6 +773,11 @@ pub mod functions_window { pub use datafusion_functions_window::*; } +/// re-export of [`datafusion_functions_table`] crate +pub mod functions_table { + pub use datafusion_functions_table::*; +} + /// re-export of variable provider for `@name` and `@@name` style runtime values. pub mod variable { pub use datafusion_expr::var_provider::{VarProvider, VarType}; diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index b64d1d5a83f81..67cea5c3d596f 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -52,11 +52,13 @@ use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ physical_exprs_equal, EquivalenceProperties, PhysicalExpr, PhysicalExprRef, }; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::windows::{get_best_fitting_window, BoundedWindowAggExec}; use datafusion_physical_plan::ExecutionPlanProperties; -use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; -use datafusion_physical_optimizer::PhysicalOptimizerRule; use itertools::izip; /// The `EnforceDistribution` rule ensures that distribution requirements are @@ -273,7 +275,7 @@ impl PhysicalOptimizerRule for EnforceDistribution { fn adjust_input_keys_ordering( mut requirements: PlanWithKeyRequirements, ) -> Result> { - let plan = requirements.plan.clone(); + let plan = Arc::clone(&requirements.plan); if let Some(HashJoinExec { left, @@ -294,8 +296,8 @@ fn adjust_input_keys_ordering( Vec, )| { HashJoinExec::try_new( - left.clone(), - right.clone(), + Arc::clone(left), + Arc::clone(right), new_conditions.0, filter.clone(), join_type, @@ -328,7 +330,8 @@ fn adjust_input_keys_ordering( JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti - | JoinType::Full => vec![], + | JoinType::Full + | JoinType::LeftMark => vec![], }; } PartitionMode::Auto => { @@ -360,8 +363,8 @@ fn adjust_input_keys_ordering( Vec, )| { SortMergeJoinExec::try_new( - left.clone(), - right.clone(), + Arc::clone(left), + Arc::clone(right), new_conditions.0, filter.clone(), *join_type, @@ -493,8 +496,8 @@ fn reorder_aggregate_keys( PhysicalGroupBy::new_single(new_group_exprs), agg_exec.aggr_expr().to_vec(), agg_exec.filter_expr().to_vec(), - agg_exec.input().clone(), - agg_exec.input_schema.clone(), + Arc::clone(agg_exec.input()), + Arc::clone(&agg_exec.input_schema), )?); // Build new group expressions that correspond to the output // of the "reordered" aggregator: @@ -512,11 +515,11 @@ fn reorder_aggregate_keys( new_group_by, agg_exec.aggr_expr().to_vec(), agg_exec.filter_expr().to_vec(), - partial_agg.clone(), + Arc::clone(&partial_agg) as _, agg_exec.input_schema(), )?); - agg_node.plan = new_final_agg.clone(); + agg_node.plan = Arc::clone(&new_final_agg) as _; agg_node.data.clear(); agg_node.children = vec![PlanWithKeyRequirements::new( partial_agg as _, @@ -615,15 +618,15 @@ pub(crate) fn reorder_join_keys_to_inputs( left.equivalence_properties(), right.equivalence_properties(), ); - if positions.map_or(false, |idxs| !idxs.is_empty()) { + if positions.is_some_and(|idxs| !idxs.is_empty()) { let JoinKeyPairs { left_keys, right_keys, } = join_keys; let 
new_join_on = new_join_conditions(&left_keys, &right_keys); return Ok(Arc::new(HashJoinExec::try_new( - left.clone(), - right.clone(), + Arc::clone(left), + Arc::clone(right), new_join_on, filter.clone(), join_type, @@ -662,8 +665,8 @@ pub(crate) fn reorder_join_keys_to_inputs( .map(|idx| sort_options[positions[idx]]) .collect(); return SortMergeJoinExec::try_new( - left.clone(), - right.clone(), + Arc::clone(left), + Arc::clone(right), new_join_on, filter.clone(), *join_type, @@ -724,19 +727,19 @@ fn try_reorder( } else if !equivalence_properties.eq_group().is_empty() { normalized_expected = expected .iter() - .map(|e| eq_groups.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) .collect(); normalized_left_keys = join_keys .left_keys .iter() - .map(|e| eq_groups.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) .collect(); normalized_right_keys = join_keys .right_keys .iter() - .map(|e| eq_groups.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) .collect(); if physical_exprs_equal(&normalized_expected, &normalized_left_keys) @@ -759,8 +762,8 @@ fn try_reorder( let mut new_left_keys = vec![]; let mut new_right_keys = vec![]; for pos in positions.iter() { - new_left_keys.push(join_keys.left_keys[*pos].clone()); - new_right_keys.push(join_keys.right_keys[*pos].clone()); + new_left_keys.push(Arc::clone(&join_keys.left_keys[*pos])); + new_right_keys.push(Arc::clone(&join_keys.right_keys[*pos])); } let pairs = JoinKeyPairs { left_keys: new_left_keys, @@ -798,7 +801,7 @@ fn expected_expr_positions( fn extract_join_keys(on: &[(PhysicalExprRef, PhysicalExprRef)]) -> JoinKeyPairs { let (left_keys, right_keys) = on .iter() - .map(|(l, r)| (l.clone() as _, r.clone() as _)) + .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _)) .unzip(); JoinKeyPairs { left_keys, @@ -813,7 +816,7 @@ fn new_join_conditions( new_left_keys .iter() .zip(new_right_keys.iter()) - .map(|(l_key, r_key)| (l_key.clone(), r_key.clone())) + .map(|(l_key, r_key)| (Arc::clone(l_key), Arc::clone(r_key))) .collect() } @@ -842,8 +845,9 @@ fn add_roundrobin_on_top( // - Usage of order preserving variants is not desirable // (determined by flag `config.optimizer.prefer_existing_sort`) let partitioning = Partitioning::RoundRobinBatch(n_target); - let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? - .with_preserve_order(); + let repartition = + RepartitionExec::try_new(Arc::clone(&input.plan), partitioning)? + .with_preserve_order(); let new_plan = Arc::new(repartition) as _; @@ -900,8 +904,9 @@ fn add_hash_on_top( // - Usage of order preserving variants is not desirable (per the flag // `config.optimizer.prefer_existing_sort`). let partitioning = dist.create_partitioning(n_target); - let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? - .with_preserve_order(); + let repartition = + RepartitionExec::try_new(Arc::clone(&input.plan), partitioning)? 
+ .with_preserve_order(); let plan = Arc::new(repartition) as _; return Ok(DistributionContext::new(plan, true, vec![input])); @@ -934,11 +939,15 @@ fn add_spm_on_top(input: DistributionContext) -> DistributionContext { let new_plan = if should_preserve_ordering { Arc::new(SortPreservingMergeExec::new( - input.plan.output_ordering().unwrap_or(&[]).to_vec(), - input.plan.clone(), + input + .plan + .output_ordering() + .unwrap_or(&LexOrdering::default()) + .clone(), + Arc::clone(&input.plan), )) as _ } else { - Arc::new(CoalescePartitionsExec::new(input.plan.clone())) as _ + Arc::new(CoalescePartitionsExec::new(Arc::clone(&input.plan))) as _ }; DistributionContext::new(new_plan, true, vec![input]) @@ -1014,7 +1023,7 @@ fn replace_order_preserving_variants( .collect::>>()?; if is_sort_preserving_merge(&context.plan) { - let child_plan = context.children[0].plan.clone(); + let child_plan = Arc::clone(&context.children[0].plan); context.plan = Arc::new(CoalescePartitionsExec::new(child_plan)); return Ok(context); } else if let Some(repartition) = @@ -1022,7 +1031,7 @@ fn replace_order_preserving_variants( { if repartition.preserve_order() { context.plan = Arc::new(RepartitionExec::try_new( - context.children[0].plan.clone(), + Arc::clone(&context.children[0].plan), repartition.partitioning().clone(), )?); return Ok(context); @@ -1153,12 +1162,17 @@ fn ensure_distribution( let should_use_estimates = config .execution .use_row_number_estimates_to_optimize_partitioning; - let is_unbounded = dist_context.plan.execution_mode().is_unbounded(); + let unbounded_and_pipeline_friendly = dist_context.plan.boundedness().is_unbounded() + && matches!( + dist_context.plan.pipeline_behavior(), + EmissionType::Incremental | EmissionType::Both + ); // Use order preserving variants either of the conditions true // - it is desired according to config // - when plan is unbounded + // - when it is pipeline friendly (can incrementally produce results) let order_preserving_variants_desirable = - is_unbounded || config.optimizer.prefer_existing_sort; + unbounded_and_pipeline_friendly || config.optimizer.prefer_existing_sort; // Remove unnecessary repartition from the physical plan if any let DistributionContext { @@ -1194,7 +1208,7 @@ fn ensure_distribution( // We store the updated children in `new_children`. let children = izip!( children.into_iter(), - plan.required_input_ordering().iter(), + plan.required_input_ordering(), plan.maintains_input_order(), repartition_status_flags.into_iter() ) @@ -1238,7 +1252,7 @@ fn ensure_distribution( // to increase parallelism. child = add_roundrobin_on_top(child, target_partitions)?; } - // When inserting hash is necessary to satisy hash requirement, insert hash repartition. + // When inserting hash is necessary to satisfy hash requirement, insert hash repartition. 
if hash_necessary { child = add_hash_on_top(child, exprs.to_vec(), target_partitions)?; @@ -1261,7 +1275,7 @@ fn ensure_distribution( let ordering_satisfied = child .plan .equivalence_properties() - .ordering_satisfy_requirement(required_input_ordering); + .ordering_satisfy_requirement(&required_input_ordering); if (!ordering_satisfied || !order_preserving_variants_desirable) && child.data { @@ -1300,7 +1314,10 @@ fn ensure_distribution( ) .collect::>>()?; - let children_plans = children.iter().map(|c| c.plan.clone()).collect::>(); + let children_plans = children + .iter() + .map(|c| Arc::clone(&c.plan)) + .collect::>(); plan = if plan.as_any().is::() && !config.optimizer.prefer_existing_union @@ -1399,9 +1416,6 @@ pub(crate) mod tests { use crate::datasource::object_store::ObjectStoreUrl; use crate::datasource::physical_plan::{CsvExec, FileScanConfig, ParquetExec}; use crate::physical_optimizer::enforce_sorting::EnforceSorting; - use crate::physical_optimizer::test_utils::{ - check_integrity, coalesce_partitions_exec, repartition_exec, - }; use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; use crate::physical_plan::expressions::col; use crate::physical_plan::filter::FilterExec; @@ -1410,14 +1424,16 @@ pub(crate) mod tests { use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{displayable, DisplayAs, DisplayFormatType, Statistics}; use datafusion_physical_optimizer::output_requirements::OutputRequirements; + use datafusion_physical_optimizer::test_utils::{ + check_integrity, coalesce_partitions_exec, repartition_exec, + }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::ScalarValue; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; use datafusion_physical_expr::{ - expressions, expressions::binary, expressions::lit, LexOrdering, - PhysicalSortExpr, PhysicalSortRequirement, + expressions::binary, expressions::lit, LexOrdering, PhysicalSortExpr, }; use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::PlanProperties; @@ -1434,7 +1450,7 @@ pub(crate) mod tests { impl SortRequiredExec { fn new_with_requirement( input: Arc, - requirement: Vec, + requirement: LexOrdering, ) -> Self { let cache = Self::compute_properties(&input); Self { @@ -1449,7 +1465,8 @@ pub(crate) mod tests { PlanProperties::new( input.equivalence_properties().clone(), // Equivalence Properties input.output_partitioning().clone(), // Output Partitioning - input.execution_mode(), // Execution Mode + input.pipeline_behavior(), // Pipeline Behavior + input.boundedness(), // Boundedness ) } } @@ -1460,11 +1477,7 @@ pub(crate) mod tests { _t: DisplayFormatType, f: &mut std::fmt::Formatter, ) -> std::fmt::Result { - write!( - f, - "SortRequiredExec: [{}]", - PhysicalSortExpr::format_list(&self.expr) - ) + write!(f, "SortRequiredExec: [{}]", self.expr) } } @@ -1494,7 +1507,7 @@ pub(crate) mod tests { if self.expr.is_empty() { vec![None] } else { - vec![Some(PhysicalSortRequirement::from_sort_exprs(&self.expr))] + vec![Some(LexRequirement::from(self.expr.clone()))] } } @@ -1539,7 +1552,7 @@ pub(crate) mod tests { /// create a single parquet file that is sorted pub(crate) fn parquet_exec_with_sort( - output_ordering: Vec>, + output_ordering: Vec, ) -> Arc { ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema()) @@ -1555,7 +1568,7 @@ pub(crate) mod tests { /// Created a sorted parquet exec with multiple files fn 
parquet_exec_multiple_sorted( - output_ordering: Vec>, + output_ordering: Vec, ) -> Arc { ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema()) @@ -1572,7 +1585,7 @@ pub(crate) mod tests { csv_exec_with_sort(vec![]) } - fn csv_exec_with_sort(output_ordering: Vec>) -> Arc { + fn csv_exec_with_sort(output_ordering: Vec) -> Arc { Arc::new( CsvExec::builder( FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema()) @@ -1595,9 +1608,7 @@ pub(crate) mod tests { } // Created a sorted parquet exec with multiple files - fn csv_exec_multiple_sorted( - output_ordering: Vec>, - ) -> Arc { + fn csv_exec_multiple_sorted(output_ordering: Vec) -> Arc { Arc::new( CsvExec::builder( FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema()) @@ -1646,8 +1657,7 @@ pub(crate) mod tests { .enumerate() .map(|(index, (_col, name))| { ( - Arc::new(expressions::Column::new(name, index)) - as Arc, + Arc::new(Column::new(name, index)) as Arc, name.clone(), ) }) @@ -1728,7 +1738,7 @@ pub(crate) mod tests { } fn sort_exec( - sort_exprs: Vec, + sort_exprs: LexOrdering, input: Arc, preserve_partitioning: bool, ) -> Arc { @@ -1738,7 +1748,7 @@ pub(crate) mod tests { } fn sort_preserving_merge_exec( - sort_exprs: Vec, + sort_exprs: LexOrdering, input: Arc, ) -> Arc { Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) @@ -1960,6 +1970,7 @@ pub(crate) mod tests { JoinType::Full, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightSemi, JoinType::RightAnti, ]; @@ -1982,7 +1993,8 @@ pub(crate) mod tests { | JoinType::Right | JoinType::Full | JoinType::LeftSemi - | JoinType::LeftAnti => { + | JoinType::LeftAnti + | JoinType::LeftMark => { // Join on (a == c) let top_join_on = vec![( Arc::new(Column::new_with_schema("a", &join.schema()).unwrap()) @@ -2000,7 +2012,7 @@ pub(crate) mod tests { let expected = match join_type { // Should include 3 RepartitionExecs - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![ + JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => vec![ top_join_plan.as_str(), join_plan.as_str(), "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", @@ -2099,7 +2111,7 @@ pub(crate) mod tests { assert_optimized!(expected, top_join.clone(), true); assert_optimized!(expected, top_join, false); } - JoinType::LeftSemi | JoinType::LeftAnti => {} + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {} } } @@ -2821,11 +2833,11 @@ pub(crate) mod tests { ], // Should include 7 RepartitionExecs (4 hash, 3 round-robin), 4 SortExecs // Since ordering of the left child is not preserved after SortMergeJoin - // when mode is Right, RgihtSemi, RightAnti, Full + // when mode is Right, RightSemi, RightAnti, Full // - We need to add one additional SortExec after SortMergeJoin in contrast the test cases // when mode is Inner, Left, LeftSemi, LeftAnti // Similarly, since partitioning of the left side is not preserved - // when mode is Right, RgihtSemi, RightAnti, Full + // when mode is Right, RightSemi, RightAnti, Full // - We need to add one additional Hash Repartition after SortMergeJoin in contrast the test // cases when mode is Inner, Left, LeftSemi, LeftAnti _ => vec![ @@ -2873,11 +2885,11 @@ pub(crate) mod tests { ], // Should include 8 RepartitionExecs (4 hash, 8 round-robin), 4 SortExecs // Since ordering of the left child is not preserved after SortMergeJoin - // when mode is Right, RgihtSemi, RightAnti, Full + // 
when mode is Right, RightSemi, RightAnti, Full // - We need to add one additional SortExec after SortMergeJoin in contrast the test cases // when mode is Inner, Left, LeftSemi, LeftAnti // Similarly, since partitioning of the left side is not preserved - // when mode is Right, RgihtSemi, RightAnti, Full + // when mode is Right, RightSemi, RightAnti, Full // - We need to add one additional Hash Repartition and Roundrobin repartition after // SortMergeJoin in contrast the test cases when mode is Inner, Left, LeftSemi, LeftAnti _ => vec![ @@ -3074,7 +3086,7 @@ pub(crate) mod tests { // Only two RepartitionExecs added let expected = &[ "SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]", - "SortExec: expr=[b3@1 ASC,a3@0 ASC], preserve_partitioning=[true]", + "SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[true]", "ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]", "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", "AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]", @@ -3082,7 +3094,7 @@ pub(crate) mod tests { "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[b2@1 ASC,a2@0 ASC], preserve_partitioning=[true]", + "SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[true]", "ProjectionExec: expr=[a@1 as a2, b@0 as b2]", "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", "RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10", @@ -3094,9 +3106,9 @@ pub(crate) mod tests { let expected_first_sort_enforcement = &[ "SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]", - "RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b3@1 ASC,a3@0 ASC", + "RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b3@1 ASC, a3@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[b3@1 ASC,a3@0 ASC], preserve_partitioning=[false]", + "SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", "ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]", "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", @@ -3105,9 +3117,9 @@ pub(crate) mod tests { "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b2@1 ASC,a2@0 ASC", + "RepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b2@1 ASC, a2@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[b2@1 ASC,a2@0 ASC], preserve_partitioning=[false]", + "SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", "ProjectionExec: expr=[a@1 as a2, b@0 as b2]", "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", @@ -3125,10 +3137,10 @@ pub(crate) mod tests { fn merge_does_not_need_sort() -> Result<()> { // see https://github.com/apache/datafusion/issues/4331 let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", 
&schema).unwrap(), options: SortOptions::default(), - }]; + }]); // Scan some sorted parquet files let exec = parquet_exec_multiple_sorted(vec![sort_key.clone()]); @@ -3327,10 +3339,10 @@ pub(crate) mod tests { #[test] fn repartition_sorted_limit() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = limit_exec(sort_exec(sort_key, parquet_exec(), false)); let expected = &[ @@ -3349,10 +3361,10 @@ pub(crate) mod tests { #[test] fn repartition_sorted_limit_with_filter() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_required_exec_with_req( filter_exec(sort_exec(sort_key.clone(), parquet_exec(), false)), sort_key, @@ -3428,10 +3440,10 @@ pub(crate) mod tests { fn repartition_through_sort_preserving_merge() -> Result<()> { // sort preserving merge with non-sorted input let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_preserving_merge_exec(sort_key, parquet_exec()); // need resort as the data was not sorted correctly @@ -3449,10 +3461,10 @@ pub(crate) mod tests { fn repartition_ignores_sort_preserving_merge() -> Result<()> { // sort preserving merge already sorted input, let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_preserving_merge_exec( sort_key.clone(), parquet_exec_multiple_sorted(vec![sort_key]), @@ -3481,10 +3493,10 @@ pub(crate) mod tests { fn repartition_ignores_sort_preserving_merge_with_union() -> Result<()> { // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); let plan = sort_preserving_merge_exec(sort_key, input); @@ -3515,10 +3527,10 @@ pub(crate) mod tests { // SortRequired // Parquet(sorted) let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("d", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_required_exec_with_req( filter_exec(parquet_exec_with_sort(vec![sort_key.clone()])), sort_key, @@ -3550,10 +3562,10 @@ pub(crate) mod tests { // Parquet(unsorted) let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input1 = sort_required_exec_with_req( parquet_exec_with_sort(vec![sort_key.clone()]), sort_key, @@ -3592,10 +3604,10 @@ pub(crate) mod tests { )]; // non sorted input let proj = Arc::new(ProjectionExec::try_new(proj_exprs, parquet_exec())?); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("sum", &proj.schema()).unwrap(), options: SortOptions::default(), - }]; 
+ }]); let plan = sort_preserving_merge_exec(sort_key, proj); let expected = &[ @@ -3625,10 +3637,10 @@ pub(crate) mod tests { #[test] fn repartition_ignores_transitively_with_projection() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let alias = vec![ ("a".to_string(), "a".to_string()), ("b".to_string(), "b".to_string()), @@ -3658,10 +3670,10 @@ pub(crate) mod tests { #[test] fn repartition_transitively_past_sort_with_projection() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let alias = vec![ ("a".to_string(), "a".to_string()), ("b".to_string(), "b".to_string()), @@ -3691,10 +3703,10 @@ pub(crate) mod tests { #[test] fn repartition_transitively_past_sort_with_filter() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_exec(sort_key, filter_exec(parquet_exec()), false); let expected = &[ @@ -3725,10 +3737,10 @@ pub(crate) mod tests { #[cfg(feature = "parquet")] fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_exec( sort_key, projection_exec_with_alias( @@ -3795,10 +3807,10 @@ pub(crate) mod tests { #[test] fn parallelization_multiple_files() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = filter_exec(parquet_exec_multiple_sorted(vec![sort_key.clone()])); let plan = sort_required_exec_with_req(plan, sort_key); @@ -3959,10 +3971,10 @@ pub(crate) mod tests { #[test] fn parallelization_sorted_limit() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan_parquet = limit_exec(sort_exec(sort_key.clone(), parquet_exec(), false)); let plan_csv = limit_exec(sort_exec(sort_key, csv_exec(), false)); @@ -3991,10 +4003,10 @@ pub(crate) mod tests { #[test] fn parallelization_limit_with_filter() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan_parquet = limit_exec(filter_exec(sort_exec( sort_key.clone(), parquet_exec(), @@ -4114,10 +4126,10 @@ pub(crate) mod tests { #[test] fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); // sort preserving merge already sorted input, let plan_parquet = sort_preserving_merge_exec( sort_key.clone(), @@ -4144,10 +4156,10 @@ pub(crate) mod tests { #[test] fn 
parallelization_sort_preserving_merge_with_union() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) let input_parquet = union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); @@ -4178,10 +4190,10 @@ pub(crate) mod tests { #[test] fn parallelization_does_not_benefit() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); // SortRequired // Parquet(sorted) let plan_parquet = sort_required_exec_with_req( @@ -4212,10 +4224,10 @@ pub(crate) mod tests { fn parallelization_ignores_transitively_with_projection_parquet() -> Result<()> { // sorted input let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); //Projection(a as a2, b as b2) let alias_pairs: Vec<(String, String)> = vec![ @@ -4226,10 +4238,10 @@ pub(crate) mod tests { parquet_exec_with_sort(vec![sort_key]), alias_pairs, ); - let sort_key_after_projection = vec![PhysicalSortExpr { + let sort_key_after_projection = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c2", &proj_parquet.schema()).unwrap(), options: SortOptions::default(), - }]; + }]); let plan_parquet = sort_preserving_merge_exec(sort_key_after_projection, proj_parquet); let expected = &[ @@ -4253,10 +4265,10 @@ pub(crate) mod tests { fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { // sorted input let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); //Projection(a as a2, b as b2) let alias_pairs: Vec<(String, String)> = vec![ @@ -4266,10 +4278,10 @@ pub(crate) mod tests { let proj_csv = projection_exec_with_alias(csv_exec_with_sort(vec![sort_key]), alias_pairs); - let sort_key_after_projection = vec![PhysicalSortExpr { + let sort_key_after_projection = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c2", &proj_csv.schema()).unwrap(), options: SortOptions::default(), - }]; + }]); let plan_csv = sort_preserving_merge_exec(sort_key_after_projection, proj_csv); let expected = &[ "SortPreservingMergeExec: [c2@1 ASC]", @@ -4316,10 +4328,10 @@ pub(crate) mod tests { #[test] fn remove_unnecessary_spm_after_filter() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -4341,10 +4353,10 @@ pub(crate) mod tests { #[test] fn preserve_ordering_through_repartition() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("d", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -4364,10 
+4376,10 @@ pub(crate) mod tests { #[test] fn do_not_preserve_ordering_through_repartition() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -4396,10 +4408,10 @@ pub(crate) mod tests { #[test] fn no_need_for_sort_after_filter() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -4420,16 +4432,16 @@ pub(crate) mod tests { #[test] fn do_not_preserve_ordering_through_repartition2() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key]); - let sort_req = vec![PhysicalSortExpr { + let sort_req = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let physical_plan = sort_preserving_merge_exec(sort_req, filter_exec(input)); let expected = &[ @@ -4458,10 +4470,10 @@ pub(crate) mod tests { #[test] fn do_not_preserve_ordering_through_repartition3() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key]); let physical_plan = filter_exec(input); @@ -4479,10 +4491,10 @@ pub(crate) mod tests { #[test] fn do_not_put_sort_when_input_is_invalid() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec(); let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); let expected = &[ @@ -4516,10 +4528,10 @@ pub(crate) mod tests { #[test] fn put_sort_when_input_is_valid() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); @@ -4553,10 +4565,10 @@ pub(crate) mod tests { #[test] fn do_not_add_unnecessary_hash() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let alias = vec![("a".to_string(), "a".to_string())]; let input = parquet_exec_with_sort(vec![sort_key]); let physical_plan = aggregate_exec_with_alias(input, alias); @@ -4576,10 +4588,10 @@ pub(crate) mod tests { #[test] fn do_not_add_unnecessary_hash2() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = 
LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let alias = vec![("a".to_string(), "a".to_string())]; let input = parquet_exec_multiple_sorted(vec![sort_key]); let aggregate = aggregate_exec_with_alias(input, alias.clone()); diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index aa28f9d6b6aa6..167f9d6d45e75 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -61,13 +61,14 @@ use crate::physical_plan::{Distribution, ExecutionPlan, InputOrderMode}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_expr::{Partitioning, PhysicalSortExpr, PhysicalSortRequirement}; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::partial_sort::PartialSortExec; use datafusion_physical_plan::ExecutionPlanProperties; -use datafusion_physical_optimizer::PhysicalOptimizerRule; use itertools::izip; /// This rule inspects [`SortExec`]'s in the given physical plan and removes the @@ -212,27 +213,27 @@ fn replace_with_partial_sort( ) -> Result> { let plan_any = plan.as_any(); if let Some(sort_plan) = plan_any.downcast_ref::() { - let child = sort_plan.children()[0].clone(); - if !child.execution_mode().is_unbounded() { + let child = Arc::clone(sort_plan.children()[0]); + if !child.boundedness().is_unbounded() { return Ok(plan); } // here we're trying to find the common prefix for sorted columns that is required for the // sort and already satisfied by the given ordering let child_eq_properties = child.equivalence_properties(); - let sort_req = PhysicalSortRequirement::from_sort_exprs(sort_plan.expr()); + let sort_req = LexRequirement::from(sort_plan.expr().clone()); let mut common_prefix_length = 0; - while child_eq_properties - .ordering_satisfy_requirement(&sort_req[0..common_prefix_length + 1]) - { + while child_eq_properties.ordering_satisfy_requirement(&LexRequirement { + inner: sort_req[0..common_prefix_length + 1].to_vec(), + }) { common_prefix_length += 1; } if common_prefix_length > 0 { return Ok(Arc::new( PartialSortExec::new( - sort_plan.expr().to_vec(), - sort_plan.input().clone(), + LexOrdering::new(sort_plan.expr().to_vec()), + Arc::clone(sort_plan.input()), common_prefix_length, ) .with_preserve_partitioning(sort_plan.preserve_partitioning()) @@ -274,8 +275,8 @@ fn parallelize_sorts( { // Take the initial sort expressions and requirements let (sort_exprs, fetch) = get_sort_exprs(&requirements.plan)?; - let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs); - let sort_exprs = sort_exprs.to_vec(); + let sort_reqs = LexRequirement::from(sort_exprs.clone()); + let sort_exprs = sort_exprs.clone(); // If there is a connection between a `CoalescePartitionsExec` and a // global sort that satisfy the requirements (i.e. 
intermediate @@ -289,7 +290,8 @@ fn parallelize_sorts( requirements = add_sort_above_with_check(requirements, sort_reqs, fetch); - let spm = SortPreservingMergeExec::new(sort_exprs, requirements.plan.clone()); + let spm = + SortPreservingMergeExec::new(sort_exprs, Arc::clone(&requirements.plan)); Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( Arc::new(spm.with_fetch(fetch)), @@ -306,7 +308,7 @@ fn parallelize_sorts( Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( - Arc::new(CoalescePartitionsExec::new(requirements.plan.clone())), + Arc::new(CoalescePartitionsExec::new(Arc::clone(&requirements.plan))), false, vec![requirements], ), @@ -390,16 +392,18 @@ fn analyze_immediate_sort_removal( if let Some(sort_exec) = node.plan.as_any().downcast_ref::() { let sort_input = sort_exec.input(); // If this sort is unnecessary, we should remove it: - if sort_input - .equivalence_properties() - .ordering_satisfy(sort_exec.properties().output_ordering().unwrap_or(&[])) - { + if sort_input.equivalence_properties().ordering_satisfy( + sort_exec + .properties() + .output_ordering() + .unwrap_or(LexOrdering::empty()), + ) { node.plan = if !sort_exec.preserve_partitioning() && sort_input.output_partitioning().partition_count() > 1 { // Replace the sort with a sort-preserving merge: - let expr = sort_exec.expr().to_vec(); - Arc::new(SortPreservingMergeExec::new(expr, sort_input.clone())) as _ + let expr = LexOrdering::new(sort_exec.expr().to_vec()); + Arc::new(SortPreservingMergeExec::new(expr, Arc::clone(sort_input))) as _ } else { // Remove the sort: node.children = node.children.swap_remove(0).children; @@ -411,12 +415,16 @@ fn analyze_immediate_sort_removal( .partition_count() == 1 { - Arc::new(GlobalLimitExec::new(sort_input.clone(), 0, Some(fetch))) + Arc::new(GlobalLimitExec::new( + Arc::clone(sort_input), + 0, + Some(fetch), + )) } else { - Arc::new(LocalLimitExec::new(sort_input.clone(), fetch)) + Arc::new(LocalLimitExec::new(Arc::clone(sort_input), fetch)) } } else { - sort_input.clone() + Arc::clone(sort_input) } }; for child in node.children.iter_mut() { @@ -476,7 +484,7 @@ fn adjust_window_sort_removal( // Satisfy the ordering requirement so that the window can run: let mut child_node = window_tree.children.swap_remove(0); child_node = add_sort_above(child_node, reqs, None); - let child_plan = child_node.plan.clone(); + let child_plan = Arc::clone(&child_node.plan); window_tree.children.push(child_node); if window_expr.iter().all(|e| e.uses_bounded_memory()) { @@ -601,12 +609,12 @@ fn remove_corresponding_sort_from_sub_plan( // Replace with variants that do not preserve order. if is_sort_preserving_merge(&node.plan) { node.children = node.children.swap_remove(0).children; - node.plan = node.plan.children().swap_remove(0).clone(); + node.plan = Arc::clone(node.plan.children().swap_remove(0)); } else if let Some(repartition) = node.plan.as_any().downcast_ref::() { node.plan = Arc::new(RepartitionExec::try_new( - node.children[0].plan.clone(), + Arc::clone(&node.children[0].plan), repartition.properties().output_partitioning().clone(), )?) as _; } @@ -617,9 +625,12 @@ fn remove_corresponding_sort_from_sub_plan( { // If there is existing ordering, to preserve ordering use // `SortPreservingMergeExec` instead of a `CoalescePartitionsExec`. 
- let plan = node.plan.clone(); + let plan = Arc::clone(&node.plan); let plan = if let Some(ordering) = plan.output_ordering() { - Arc::new(SortPreservingMergeExec::new(ordering.to_vec(), plan)) as _ + Arc::new(SortPreservingMergeExec::new( + LexOrdering::new(ordering.to_vec()), + plan, + )) as _ } else { Arc::new(CoalescePartitionsExec::new(plan)) as _ }; @@ -629,10 +640,10 @@ fn remove_corresponding_sort_from_sub_plan( Ok(node) } -/// Converts an [ExecutionPlan] trait object to a [PhysicalSortExpr] slice when possible. +/// Converts an [ExecutionPlan] trait object to a [LexOrdering] reference when possible. fn get_sort_exprs( sort_any: &Arc, -) -> Result<(&[PhysicalSortExpr], Option)> { +) -> Result<(&LexOrdering, Option)> { if let Some(sort_exec) = sort_any.as_any().downcast_ref::() { Ok((sort_exec.expr(), sort_exec.fetch())) } else if let Some(spm) = sort_any.as_any().downcast_ref::() @@ -645,20 +656,19 @@ fn get_sort_exprs( #[cfg(test)] mod tests { - use super::*; use crate::physical_optimizer::enforce_distribution::EnforceDistribution; - use crate::physical_optimizer::test_utils::{ - aggregate_exec, bounded_window_exec, check_integrity, coalesce_batches_exec, - coalesce_partitions_exec, filter_exec, global_limit_exec, hash_join_exec, - limit_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_sorted, - repartition_exec, sort_exec, sort_expr, sort_expr_options, sort_merge_join_exec, - sort_preserving_merge_exec, spr_repartition_exec, union_exec, - RequirementsTestExec, - }; + use crate::physical_optimizer::test_utils::{parquet_exec, parquet_exec_sorted}; use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::{SessionConfig, SessionContext}; use crate::test::{csv_exec_ordered, csv_exec_sorted, stream_exec_ordered}; + use datafusion_physical_optimizer::test_utils::{ + aggregate_exec, bounded_window_exec, check_integrity, coalesce_batches_exec, + coalesce_partitions_exec, filter_exec, global_limit_exec, hash_join_exec, + limit_exec, local_limit_exec, memory_exec, repartition_exec, sort_exec, + sort_expr, sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, + spr_repartition_exec, union_exec, RequirementsTestExec, + }; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -710,7 +720,7 @@ mod tests { let state = session_ctx.state(); // This file has 4 rules that use tree node, apply these rules as in the - // EnforSorting::optimize implementation + // EnforceSorting::optimize implementation // After these operations tree nodes should be in a consistent state. // This code block makes sure that these rules doesn't violate tree node integrity. 
{ @@ -936,8 +946,8 @@ mod tests { "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", @@ -961,10 +971,10 @@ mod tests { let sort = sort_exec(sort_exprs.clone(), source); let spm = sort_preserving_merge_exec(sort_exprs, sort); - let sort_exprs = vec![ + let sort_exprs = LexOrdering::new(vec![ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; + ]); let repartition_exec = repartition_exec(spm); let sort2 = Arc::new( SortExec::new(sort_exprs.clone(), repartition_exec) @@ -979,8 +989,8 @@ mod tests { // it with a `CoalescePartitionsExec` instead of directly removing it. let expected_input = [ "AggregateExec: mode=Final, gby=[], aggr=[]", - " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[true]", + " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", @@ -1006,7 +1016,7 @@ mod tests { let source2 = repartition_exec(memory_exec(&schema)); let union = union_exec(vec![source1, source2]); - let sort_exprs = vec![sort_expr("non_nullable_col", &schema)]; + let sort_exprs = LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]); // let sort = sort_exec(sort_exprs.clone(), union); let sort = Arc::new( SortExec::new(sort_exprs.clone(), union).with_preserve_partitioning(true), @@ -1029,7 +1039,7 @@ mod tests { // When removing a `SortPreservingMergeExec`, make sure that partitioning // requirements are not violated. In some cases, we may need to replace // it with a `CoalescePartitionsExec` instead of directly removing it. 
- let expected_input = ["SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + let expected_input = ["SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " FilterExec: NOT non_nullable_col@1", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[true]", @@ -1039,8 +1049,8 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " MemoryExec: partitions=1, partition_sizes=[0]"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[true]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", " FilterExec: NOT non_nullable_col@1", " UnionExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -1085,8 +1095,11 @@ mod tests { let schema = create_test_schema()?; let source = memory_exec(&schema); let input = Arc::new( - SortExec::new(vec![sort_expr("non_nullable_col", &schema)], source) - .with_fetch(Some(2)), + SortExec::new( + LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), + source, + ) + .with_fetch(Some(2)), ); let physical_plan = sort_exec( vec![ @@ -1097,12 +1110,12 @@ mod tests { ); let expected_input = [ - "SortExec: expr=[non_nullable_col@1 ASC,nullable_col@0 ASC], preserve_partitioning=[false]", + "SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", " SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; let expected_optimized = [ - "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC,nullable_col@0 ASC], preserve_partitioning=[false]", + "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1115,26 +1128,29 @@ mod tests { let schema = create_test_schema()?; let source = memory_exec(&schema); let input = Arc::new(SortExec::new( - vec![ + LexOrdering::new(vec![ sort_expr("non_nullable_col", &schema), sort_expr("nullable_col", &schema), - ], + ]), source, )); let physical_plan = Arc::new( - SortExec::new(vec![sort_expr("non_nullable_col", &schema)], input) - .with_fetch(Some(2)), + SortExec::new( + LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), + input, + ) + .with_fetch(Some(2)), ) as Arc; let expected_input = [ "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " SortExec: expr=[non_nullable_col@1 ASC,nullable_col@0 ASC], preserve_partitioning=[false]", + " SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; let expected_optimized = [ "GlobalLimitExec: skip=0, fetch=2", - " SortExec: expr=[non_nullable_col@1 ASC,nullable_col@0 ASC], preserve_partitioning=[false]", + " SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1147,7 +1163,7 @@ mod tests { let schema = 
create_test_schema()?; let source = memory_exec(&schema); let input = Arc::new(SortExec::new( - vec![sort_expr("non_nullable_col", &schema)], + LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), source, )); let limit = Arc::new(LocalLimitExec::new(input, 2)); @@ -1160,14 +1176,14 @@ mod tests { ); let expected_input = [ - "SortExec: expr=[non_nullable_col@1 ASC,nullable_col@0 ASC], preserve_partitioning=[false]", + "SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", " LocalLimitExec: fetch=2", " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; let expected_optimized = [ "LocalLimitExec: fetch=2", - " SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC,nullable_col@0 ASC], preserve_partitioning=[false]", + " SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1181,7 +1197,7 @@ mod tests { let source = memory_exec(&schema); // let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source); let input = Arc::new(SortExec::new( - vec![sort_expr("non_nullable_col", &schema)], + LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), source, )); let limit = Arc::new(GlobalLimitExec::new(input, 0, Some(5))) as _; @@ -1253,24 +1269,24 @@ mod tests { let repartition = repartition_exec(union); let physical_plan = sort_preserving_merge_exec(sort_exprs, repartition); - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // We should keep the bottom `SortExec`. 
- let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[true]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1288,12 +1304,12 @@ mod tests { let sort = sort_exec(vec![sort_exprs[0].clone()], source); let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); let expected_input = [ - "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + "SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1317,7 +1333,7 @@ mod tests { let expected_input = [ "SortPreservingMergeExec: [non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; let expected_optimized = [ @@ -1409,17 +1425,17 @@ mod tests { // Input is an invalid plan. In this case rule should add required sorting in appropriate places. // First ParquetExec has output ordering(nullable_col@0 ASC). However, it doesn't satisfy the // required ordering of SortPreservingMergeExec. 
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1450,7 +1466,7 @@ mod tests { // Third input to the union is not Sorted (SortExec is matches required ordering by the SortPreservingMergeExec above). let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", @@ -1490,20 +1506,20 @@ mod tests { // Should modify the plan to ensure that all three inputs to the // `UnionExec` satisfy the ordering, OR add a single sort after // the `UnionExec` (both of which are equally good for this example). 
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1542,9 +1558,9 @@ mod tests { // fine `SortExec`s below with required `SortExec`s that are absolutely necessary. let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", @@ -1588,7 +1604,7 @@ mod tests { " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // Should adjust the requirement in the third input of the union so @@ -1625,9 +1641,9 @@ mod tests { // Union has unnecessarily fine ordering below it. 
We should be able to replace them with absolutely necessary ordering. let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // Union preserves the inputs ordering and we should not change any of the SortExecs under UnionExec let expected_output = ["SortPreservingMergeExec: [nullable_col@0 ASC]", @@ -1676,9 +1692,9 @@ mod tests { // The `UnionExec` doesn't preserve any of the inputs ordering in the // example below. let expected_input = ["UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[nullable_col@0 DESC NULLS LAST,non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // Since `UnionExec` doesn't preserve ordering in the plan above. // We shouldn't keep SortExecs in the plan. @@ -1744,10 +1760,10 @@ mod tests { async fn test_window_multi_path_sort2() -> Result<()> { let schema = create_test_schema()?; - let sort_exprs1 = vec![ + let sort_exprs1 = LexOrdering::new(vec![ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; + ]); let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; let source1 = parquet_exec_sorted(&schema, sort_exprs2.clone()); let source2 = parquet_exec_sorted(&schema, sort_exprs2.clone()); @@ -1761,11 +1777,11 @@ mod tests { // The `WindowAggExec` can get its required sorting from the leaf nodes directly. 
// The unnecessary SortExecs should be removed let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]"]; let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 ASC]", @@ -1810,11 +1826,11 @@ mod tests { // Should not change the unnecessarily fine `SortExec`s because there is `LimitExec` let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", @@ -1822,7 +1838,7 @@ mod tests { " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1867,7 +1883,7 @@ mod tests { let join_plan2 = format!( " SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" ); - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", join_plan2.as_str(), " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, 
col_b]"]; @@ -1879,7 +1895,7 @@ mod tests { // can push down the sort requirements and save 1 SortExec vec![ join_plan.as_str(), - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", @@ -1888,7 +1904,7 @@ mod tests { _ => { // can not push down the sort requirements vec![ - "SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + "SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", join_plan2.as_str(), " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", @@ -1938,9 +1954,9 @@ mod tests { ); let spm_plan = match join_type { JoinType::RightAnti => { - "SortPreservingMergeExec: [col_a@0 ASC,col_b@1 ASC]" + "SortPreservingMergeExec: [col_a@0 ASC, col_b@1 ASC]" } - _ => "SortPreservingMergeExec: [col_a@2 ASC,col_b@3 ASC]", + _ => "SortPreservingMergeExec: [col_a@2 ASC, col_b@3 ASC]", }; let join_plan2 = format!( " SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" @@ -1956,14 +1972,14 @@ mod tests { join_plan.as_str(), " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[col_a@0 ASC,col_b@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", ] } _ => { // can not push down the sort requirements for Left and Full join. 
vec![ - "SortExec: expr=[col_a@2 ASC,col_b@3 ASC], preserve_partitioning=[false]", + "SortExec: expr=[col_a@2 ASC, col_b@3 ASC], preserve_partitioning=[false]", join_plan2.as_str(), " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", @@ -2001,13 +2017,13 @@ mod tests { ]; let physical_plan = sort_preserving_merge_exec(sort_exprs1, join.clone()); - let expected_input = ["SortPreservingMergeExec: [col_b@3 ASC,col_a@2 ASC]", + let expected_input = ["SortPreservingMergeExec: [col_b@3 ASC, col_a@2 ASC]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; // can not push down the sort requirements, need to add SortExec - let expected_optimized = ["SortExec: expr=[col_b@3 ASC,col_a@2 ASC], preserve_partitioning=[false]", + let expected_optimized = ["SortExec: expr=[col_b@3 ASC, col_a@2 ASC], preserve_partitioning=[false]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", @@ -2023,13 +2039,13 @@ mod tests { ]; let physical_plan = sort_preserving_merge_exec(sort_exprs2, join); - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,col_b@3 ASC,col_a@2 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; // can not push down the sort requirements, need to add SortExec - let expected_optimized = ["SortExec: expr=[nullable_col@0 ASC,col_b@3 ASC,col_a@2 ASC], preserve_partitioning=[false]", + let expected_optimized = ["SortExec: expr=[nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC], preserve_partitioning=[false]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", @@ -2069,7 +2085,7 @@ mod tests { let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, 
partition_sizes=[0]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -2124,7 +2140,7 @@ mod tests { let state = session_ctx.state(); let memory_exec = memory_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; + let sort_exprs = LexOrdering::new(vec![sort_expr("nullable_col", &schema)]); let window = bounded_window_exec("nullable_col", sort_exprs.clone(), memory_exec); let repartition = repartition_exec(window); @@ -2174,7 +2190,7 @@ mod tests { let repartition = repartition_exec(source); let coalesce_partitions = Arc::new(CoalescePartitionsExec::new(repartition)); let repartition = repartition_exec(coalesce_partitions); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; + let sort_exprs = LexOrdering::new(vec![sort_expr("nullable_col", &schema)]); // Add local sort let sort = Arc::new( SortExec::new(sort_exprs.clone(), repartition) @@ -2332,11 +2348,11 @@ mod tests { let physical_plan = sort_exec(vec![sort_expr("b", &schema)], spm); let expected_input = ["SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], has_header=false",]; let expected_optimized = ["SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], has_header=false",]; assert_optimized!(expected_input, expected_optimized, physical_plan, false); @@ -2360,12 +2376,12 @@ mod tests { spm, ); - let expected_input = ["SortExec: expr=[a@0 ASC,b@1 ASC,c@2 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + let expected_input = ["SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", + " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], has_header=false",]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", - " SortExec: expr=[a@0 ASC,b@1 ASC,c@2 ASC], preserve_partitioning=[true]", + let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", + " SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], has_header=false",]; assert_optimized!(expected_input, expected_optimized, physical_plan, false); @@ -2387,15 +2403,15 @@ mod tests { let expected_input = [ "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC,b@1 ASC", + " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", + " 
RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC, b@1 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC,b@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", ]; let expected_optimized = [ "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[a@0 ASC,b@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -2418,11 +2434,11 @@ mod tests { ); let expected_input = [ - "SortExec: expr=[a@0 ASC,c@2 ASC], preserve_partitioning=[false]", + "SortExec: expr=[a@0 ASC, c@2 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]" ]; let expected_optimized = [ - "PartialSortExec: expr=[a@0 ASC,c@2 ASC], common_prefix_length=[1]", + "PartialSortExec: expr=[a@0 ASC, c@2 ASC], common_prefix_length=[1]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -2445,12 +2461,12 @@ mod tests { ); let expected_input = [ - "SortExec: expr=[a@0 ASC,c@2 ASC,d@3 ASC], preserve_partitioning=[false]", + "SortExec: expr=[a@0 ASC, c@2 ASC, d@3 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC, c@2 ASC]" ]; // let optimized let expected_optimized = [ - "PartialSortExec: expr=[a@0 ASC,c@2 ASC,d@3 ASC], common_prefix_length=[2]", + "PartialSortExec: expr=[a@0 ASC, c@2 ASC, d@3 ASC], common_prefix_length=[2]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC, c@2 ASC]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -2472,7 +2488,7 @@ mod tests { parquet_input, ); let expected_input = [ - "SortExec: expr=[a@0 ASC,b@1 ASC,c@2 ASC], preserve_partitioning=[false]", + "SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[b@1 ASC, c@2 ASC]" ]; let expected_no_change = expected_input; @@ -2495,7 +2511,7 @@ mod tests { unbounded_input, ); let expected_input = [ - "SortExec: expr=[a@0 ASC,b@1 ASC,c@2 ASC], preserve_partitioning=[false]", + "SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" ]; let expected_no_change = expected_input; @@ -2510,8 +2526,8 @@ mod tests { // SortExec: expr=[a] // MemoryExec let schema = create_test_schema3()?; - let sort_exprs_a = vec![sort_expr("a", &schema)]; - let sort_exprs_b = vec![sort_expr("b", &schema)]; + let sort_exprs_a = LexOrdering::new(vec![sort_expr("a", 
&schema)]); + let sort_exprs_b = LexOrdering::new(vec![sort_expr("b", &schema)]); let plan = memory_exec(&schema); let plan = sort_exec(sort_exprs_a.clone(), plan); let plan = RequirementsTestExec::new(plan) @@ -2540,8 +2556,9 @@ mod tests { // SortExec: expr=[a] // MemoryExec let schema = create_test_schema3()?; - let sort_exprs_a = vec![sort_expr("a", &schema)]; - let sort_exprs_ab = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs_a = LexOrdering::new(vec![sort_expr("a", &schema)]); + let sort_exprs_ab = + LexOrdering::new(vec![sort_expr("a", &schema), sort_expr("b", &schema)]); let plan = memory_exec(&schema); let plan = sort_exec(sort_exprs_a.clone(), plan); let plan = RequirementsTestExec::new(plan) @@ -2551,7 +2568,7 @@ mod tests { let plan = sort_exec(sort_exprs_ab, plan); let expected_input = [ - "SortExec: expr=[a@0 ASC,b@1 ASC], preserve_partitioning=[false]", + "SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", " RequiredInputOrderingExec", " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", @@ -2559,7 +2576,7 @@ mod tests { // should able to push shorts let expected = [ "RequiredInputOrderingExec", - " SortExec: expr=[a@0 ASC,b@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; assert_optimized!(expected_input, expected, plan, true); diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index efdd3148d03f4..63fe115e602c9 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -21,18 +21,14 @@ //! "Repartition" or "Sortedness" //! //! 
[`ExecutionPlan`]: crate::physical_plan::ExecutionPlan -pub mod coalesce_batches; + pub mod enforce_distribution; pub mod enforce_sorting; -pub mod join_selection; pub mod optimizer; pub mod projection_pushdown; -pub mod pruning; pub mod replace_with_order_preserving_variants; -pub mod sanity_checker; #[cfg(test)] pub mod test_utils; -pub mod update_aggr_exprs; mod sort_pushdown; mod utils; diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 7ac70d701cf61..e04931f7c7975 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -51,10 +51,11 @@ use datafusion_physical_expr::{ utils::collect_columns, Partitioning, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, }; +use datafusion_physical_plan::joins::utils::{JoinOn, JoinOnRef}; use datafusion_physical_plan::streaming::StreamingTableExec; use datafusion_physical_plan::union::UnionExec; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; use datafusion_physical_optimizer::PhysicalOptimizerRule; use itertools::Itertools; @@ -101,7 +102,7 @@ pub fn remove_unnecessary_projections( // If the projection does not cause any change on the input, we can // safely remove it: if is_projection_removable(projection) { - return Ok(Transformed::yes(projection.input().clone())); + return Ok(Transformed::yes(Arc::clone(projection.input()))); } // If it does, check if we can push it under its child(ren): let input = projection.input().as_any(); @@ -137,14 +138,11 @@ pub fn remove_unnecessary_projections( } else if let Some(union) = input.downcast_ref::() { try_pushdown_through_union(projection, union)? } else if let Some(hash_join) = input.downcast_ref::() { - try_pushdown_through_hash_join(projection, hash_join)?.map_or_else( - || try_embed_projection(projection, hash_join), - |e| Ok(Some(e)), - )? + try_pushdown_through_hash_join(projection, hash_join)? } else if let Some(cross_join) = input.downcast_ref::() { try_swapping_with_cross_join(projection, cross_join)? } else if let Some(nl_join) = input.downcast_ref::() { - try_swapping_with_nested_loop_join(projection, nl_join)? + try_pushdown_through_nested_loop_join(projection, nl_join)? } else if let Some(sm_join) = input.downcast_ref::() { try_swapping_with_sort_merge_join(projection, sm_join)? } else if let Some(sym_join) = input.downcast_ref::() { @@ -246,7 +244,7 @@ fn try_swapping_with_streaming_table( let mut lex_orderings = vec![]; for lex_ordering in streaming_table.projected_output_ordering().into_iter() { - let mut orderings = vec![]; + let mut orderings = LexOrdering::default(); for order in lex_ordering { let Some(new_ordering) = update_expr(&order.expr, projection.expr(), false)? else { @@ -261,7 +259,7 @@ fn try_swapping_with_streaming_table( } StreamingTableExec::try_new( - streaming_table.partition_schema().clone(), + Arc::clone(streaming_table.partition_schema()), streaming_table.partitions().clone(), Some(new_projections.as_ref()), lex_orderings, @@ -297,7 +295,7 @@ fn try_unifying_projections( // beneficial as caching mechanism for non-trivial computations. 
// See discussion in: https://github.com/apache/datafusion/issues/8296 if column_ref_map.iter().any(|(column, count)| { - *count > 1 && !is_expr_trivial(&child.expr()[column.index()].0.clone()) + *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].0)) }) { return Ok(None); } @@ -312,7 +310,7 @@ fn try_unifying_projections( projected_exprs.push((expr, alias.clone())); } - ProjectionExec::try_new(projected_exprs, child.input().clone()) + ProjectionExec::try_new(projected_exprs, Arc::clone(child.input())) .map(|e| Some(Arc::new(e) as _)) } @@ -467,7 +465,7 @@ fn try_swapping_with_sort( return Ok(None); } - let mut updated_exprs = vec![]; + let mut updated_exprs = LexOrdering::default(); for sort in sort.expr() { let Some(new_expr) = update_expr(&sort.expr, projection.expr(), false)? else { return Ok(None); @@ -497,7 +495,7 @@ fn try_swapping_with_sort_preserving_merge( return Ok(None); } - let mut updated_exprs = vec![]; + let mut updated_exprs = LexOrdering::default(); for sort in spm.expr() { let Some(updated_expr) = update_expr(&sort.expr, projection.expr(), false)? else { @@ -549,6 +547,12 @@ impl EmbeddedProjection for HashJoinExec { } } +impl EmbeddedProjection for NestedLoopJoinExec { + fn with_projection(&self, projection: Option>) -> Result { + self.with_projection(projection) + } +} + impl EmbeddedProjection for FilterExec { fn with_projection(&self, projection: Option>) -> Result { self.with_projection(projection) @@ -603,7 +607,7 @@ fn try_embed_projection( // Old projection may contain some alias or expression such as `a + 1` and `CAST('true' AS BOOLEAN)`, but our projection_exprs in hash join just contain column, so we need to create the new projection to keep the original projection. let new_projection = Arc::new(ProjectionExec::try_new( new_projection_exprs, - new_execution_plan.clone(), + Arc::clone(&new_execution_plan) as _, )?); if is_projection_removable(&new_projection) { Ok(Some(new_execution_plan)) @@ -615,64 +619,55 @@ fn try_embed_projection( /// Collect all column indices from the given projection expressions. fn collect_column_indices(exprs: &[(Arc, String)]) -> Vec { // Collect indices and remove duplicates. - let mut indexs = exprs + let mut indices = exprs .iter() .flat_map(|(expr, _)| collect_columns(expr)) .map(|x| x.index()) .collect::>() .into_iter() .collect::>(); - indexs.sort(); - indexs + indices.sort(); + indices } -/// Tries to push `projection` down through `hash_join`. If possible, performs the -/// pushdown and returns a new [`HashJoinExec`] as the top plan which has projections -/// as its children. Otherwise, returns `None`. -fn try_pushdown_through_hash_join( - projection: &ProjectionExec, - hash_join: &HashJoinExec, -) -> Result>> { - // TODO: currently if there is projection in HashJoinExec, we can't push down projection to left or right input. Maybe we can pushdown the mixed projection later. - if hash_join.contain_projection() { - return Ok(None); - } +struct JoinData { + projected_left_child: ProjectionExec, + projected_right_child: ProjectionExec, + join_filter: Option, + join_on: JoinOn, +} - // Convert projected expressions to columns. We can not proceed if this is - // not possible. +fn try_pushdown_through_join( + projection: &ProjectionExec, + join_left: &Arc, + join_right: &Arc, + join_on: JoinOnRef, + schema: SchemaRef, + filter: Option<&JoinFilter>, +) -> Result> { + // Convert projected expressions to columns. We can not proceed if this is not possible. 
let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { return Ok(None); }; - let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( - hash_join.left().schema().fields().len(), - &projection_as_columns, - ); + let (far_right_left_col_ind, far_left_right_col_ind) = + join_table_borders(join_left.schema().fields().len(), &projection_as_columns); if !join_allows_pushdown( &projection_as_columns, - &hash_join.schema(), + &schema, far_right_left_col_ind, far_left_right_col_ind, ) { return Ok(None); } - let Some(new_on) = update_join_on( - &projection_as_columns[0..=far_right_left_col_ind as _], - &projection_as_columns[far_left_right_col_ind as _..], - hash_join.on(), - hash_join.left().schema().fields().len(), - ) else { - return Ok(None); - }; - - let new_filter = if let Some(filter) = hash_join.filter() { + let new_filter = if let Some(filter) = filter { match update_join_filter( &projection_as_columns[0..=far_right_left_col_ind as _], &projection_as_columns[far_left_right_col_ind as _..], filter, - hash_join.left().schema().fields().len(), + join_left.schema().fields().len(), ) { Some(updated_filter) => Some(updated_filter), None => return Ok(None), @@ -681,72 +676,116 @@ fn try_pushdown_through_hash_join( None }; + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + join_on, + join_left.schema().fields().len(), + ) else { + return Ok(None); + }; + let (new_left, new_right) = new_join_children( &projection_as_columns, far_right_left_col_ind, far_left_right_col_ind, - hash_join.left(), - hash_join.right(), + join_left, + join_right, )?; - Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::new(new_left), - Arc::new(new_right), - new_on, - new_filter, - hash_join.join_type(), - hash_join.projection.clone(), - *hash_join.partition_mode(), - hash_join.null_equals_null, - )?))) + Ok(Some(JoinData { + projected_left_child: new_left, + projected_right_child: new_right, + join_filter: new_filter, + join_on: new_on, + })) } -/// Tries to swap the projection with its input [`CrossJoinExec`]. If it can be done, -/// it returns the new swapped version having the [`CrossJoinExec`] as the top plan. -/// Otherwise, it returns None. -fn try_swapping_with_cross_join( +/// Tries to push `projection` down through `nested_loop_join`. If possible, performs the +/// pushdown and returns a new [`NestedLoopJoinExec`] as the top plan which has projections +/// as its children. Otherwise, returns `None`. +fn try_pushdown_through_nested_loop_join( projection: &ProjectionExec, - cross_join: &CrossJoinExec, + nl_join: &NestedLoopJoinExec, ) -> Result>> { - // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. - let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + // TODO: currently if there is projection in NestedLoopJoinExec, we can't push down projection to left or right input. Maybe we can pushdown the mixed projection later. + if nl_join.contains_projection() { return Ok(None); - }; + } - let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( - cross_join.left().schema().fields().len(), - &projection_as_columns, - ); + if let Some(JoinData { + projected_left_child, + projected_right_child, + join_filter, + .. + }) = try_pushdown_through_join( + projection, + nl_join.left(), + nl_join.right(), + &[], + nl_join.schema(), + nl_join.filter(), + )? 
{ + Ok(Some(Arc::new(NestedLoopJoinExec::try_new( + Arc::new(projected_left_child), + Arc::new(projected_right_child), + join_filter, + nl_join.join_type(), + // Returned early if projection is not None + None, + )?))) + } else { + try_embed_projection(projection, nl_join) + } +} - if !join_allows_pushdown( - &projection_as_columns, - &cross_join.schema(), - far_right_left_col_ind, - far_left_right_col_ind, - ) { +/// Tries to push `projection` down through `hash_join`. If possible, performs the +/// pushdown and returns a new [`HashJoinExec`] as the top plan which has projections +/// as its children. Otherwise, returns `None`. +fn try_pushdown_through_hash_join( + projection: &ProjectionExec, + hash_join: &HashJoinExec, +) -> Result>> { + // TODO: currently if there is projection in HashJoinExec, we can't push down projection to left or right input. Maybe we can pushdown the mixed projection later. + if hash_join.contains_projection() { return Ok(None); } - let (new_left, new_right) = new_join_children( - &projection_as_columns, - far_right_left_col_ind, - far_left_right_col_ind, - cross_join.left(), - cross_join.right(), - )?; - - Ok(Some(Arc::new(CrossJoinExec::new( - Arc::new(new_left), - Arc::new(new_right), - )))) + if let Some(JoinData { + projected_left_child, + projected_right_child, + join_filter, + join_on, + }) = try_pushdown_through_join( + projection, + hash_join.left(), + hash_join.right(), + hash_join.on(), + hash_join.schema(), + hash_join.filter(), + )? { + Ok(Some(Arc::new(HashJoinExec::try_new( + Arc::new(projected_left_child), + Arc::new(projected_right_child), + join_on, + join_filter, + hash_join.join_type(), + // Returned early if projection is not None + None, + *hash_join.partition_mode(), + hash_join.null_equals_null, + )?))) + } else { + try_embed_projection(projection, hash_join) + } } -/// Tries to swap the projection with its input [`NestedLoopJoinExec`]. If it can be done, -/// it returns the new swapped version having the [`NestedLoopJoinExec`] as the top plan. +/// Tries to swap the projection with its input [`CrossJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`CrossJoinExec`] as the top plan. /// Otherwise, it returns None. -fn try_swapping_with_nested_loop_join( +fn try_swapping_with_cross_join( projection: &ProjectionExec, - nl_join: &NestedLoopJoinExec, + cross_join: &CrossJoinExec, ) -> Result>> { // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. 
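The new `try_pushdown_through_nested_loop_join` and the reworked `try_pushdown_through_hash_join` now share one shape: call the common `try_pushdown_through_join`, rebuild the join on top of the projected children when it returns `JoinData`, and otherwise fall back to `try_embed_projection`. A schematic sketch of that control flow, with string stand-ins instead of real plan nodes (none of these types are DataFusion's):

```rust
// All types here are illustrative stand-ins, not DataFusion's.
struct JoinData {
    projected_left: String,
    projected_right: String,
}

// The real helper inspects the projection columns, join keys and filter;
// this sketch just gates on a flag.
fn try_pushdown_through_join(pushdown_possible: bool) -> Option<JoinData> {
    pushdown_possible.then(|| JoinData {
        projected_left: "ProjectionExec(left)".to_string(),
        projected_right: "ProjectionExec(right)".to_string(),
    })
}

// Fallback: keep a single join node and let it carry the projection itself.
fn try_embed_projection() -> Option<String> {
    Some("JoinExec projection=[..]".to_string())
}

fn optimize_join(pushdown_possible: bool) -> Option<String> {
    if let Some(JoinData { projected_left, projected_right }) =
        try_pushdown_through_join(pushdown_possible)
    {
        // Success: the join is rebuilt on top of the two projected children.
        Some(format!("JoinExec({projected_left}, {projected_right})"))
    } else {
        try_embed_projection()
    }
}

fn main() {
    assert_eq!(
        optimize_join(true).unwrap(),
        "JoinExec(ProjectionExec(left), ProjectionExec(right))"
    );
    assert_eq!(optimize_join(false).unwrap(), "JoinExec projection=[..]");
}
```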
let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { @@ -754,47 +793,31 @@ fn try_swapping_with_nested_loop_join( }; let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( - nl_join.left().schema().fields().len(), + cross_join.left().schema().fields().len(), &projection_as_columns, ); if !join_allows_pushdown( &projection_as_columns, - &nl_join.schema(), + &cross_join.schema(), far_right_left_col_ind, far_left_right_col_ind, ) { return Ok(None); } - let new_filter = if let Some(filter) = nl_join.filter() { - match update_join_filter( - &projection_as_columns[0..=far_right_left_col_ind as _], - &projection_as_columns[far_left_right_col_ind as _..], - filter, - nl_join.left().schema().fields().len(), - ) { - Some(updated_filter) => Some(updated_filter), - None => return Ok(None), - } - } else { - None - }; - let (new_left, new_right) = new_join_children( &projection_as_columns, far_right_left_col_ind, far_left_right_col_ind, - nl_join.left(), - nl_join.right(), + cross_join.left(), + cross_join.right(), )?; - Ok(Some(Arc::new(NestedLoopJoinExec::try_new( + Ok(Some(Arc::new(CrossJoinExec::new( Arc::new(new_left), Arc::new(new_right), - new_filter, - nl_join.join_type(), - )?))) + )))) } /// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done, @@ -915,8 +938,14 @@ fn try_swapping_with_sym_hash_join( new_filter, sym_join.join_type(), sym_join.null_equals_null(), - sym_join.right().output_ordering().map(|p| p.to_vec()), - sym_join.left().output_ordering().map(|p| p.to_vec()), + sym_join + .right() + .output_ordering() + .map(|p| LexOrdering::new(p.to_vec())), + sym_join + .left() + .output_ordering() + .map(|p| LexOrdering::new(p.to_vec())), sym_join.partition_mode(), )?))) } @@ -999,8 +1028,7 @@ fn update_expr( let mut state = RewriteState::Unchanged; - let new_expr = expr - .clone() + let new_expr = Arc::clone(expr) .transform_up(|expr: Arc| { if state == RewriteState::RewrittenInvalid { return Ok(Transformed::no(expr)); @@ -1012,7 +1040,9 @@ fn update_expr( if sync_with_child { state = RewriteState::RewrittenValid; // Update the index of `column`: - Ok(Transformed::yes(projected_exprs[column.index()].0.clone())) + Ok(Transformed::yes(Arc::clone( + &projected_exprs[column.index()].0, + ))) } else { // default to invalid, in case we can't find the relevant column state = RewriteState::RewrittenInvalid; @@ -1049,7 +1079,7 @@ fn make_with_child( projection: &ProjectionExec, child: &Arc, ) -> Result> { - ProjectionExec::try_new(projection.expr().to_vec(), child.clone()) + ProjectionExec::try_new(projection.expr().to_vec(), Arc::clone(child)) .map(|e| Arc::new(e) as _) } @@ -1149,8 +1179,7 @@ fn new_columns_for_join_on( .iter() .filter_map(|on| { // Rewrite all columns in `on` - (*on) - .clone() + Arc::clone(*on) .transform(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { // Find the column in the projection expressions @@ -1213,7 +1242,7 @@ fn update_join_filter( == join_filter.column_indices().len()) .then(|| { JoinFilter::new( - join_filter.expression().clone(), + Arc::clone(join_filter.expression()), join_filter .column_indices() .iter() @@ -1296,7 +1325,7 @@ fn new_join_children( ) }) .collect_vec(), - left_child.clone(), + Arc::clone(left_child), )?; let left_size = left_child.schema().fields().len() as i32; let new_right = ProjectionExec::try_new( @@ -1314,7 +1343,7 @@ fn new_join_children( ) }) .collect_vec(), - right_child.clone(), + Arc::clone(right_child), )?; Ok((new_left, new_right)) @@ 
-1339,7 +1368,7 @@ mod tests { ColumnarValue, Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_physical_expr::expressions::{ - BinaryExpr, CaseExpr, CastExpr, NegativeExpr, + binary, col, BinaryExpr, CaseExpr, CastExpr, NegativeExpr, }; use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_plan::joins::PartitionMode; @@ -1376,7 +1405,11 @@ mod tests { Ok(DataType::Int32) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { unimplemented!("DummyUDF::invoke") } } @@ -1863,7 +1896,7 @@ mod tests { }) as _], Some(&vec![0_usize, 2, 4, 3]), vec![ - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("e", 2)), options: SortOptions::default(), @@ -1872,11 +1905,11 @@ mod tests { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), }, - ], - vec![PhysicalSortExpr { + ]), + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("d", 3)), options: SortOptions::default(), - }], + }]), ] .into_iter(), true, @@ -1923,7 +1956,7 @@ mod tests { assert_eq!( result.projected_output_ordering().into_iter().collect_vec(), vec![ - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("e", 1)), options: SortOptions::default(), @@ -1932,11 +1965,11 @@ mod tests { expr: Arc::new(Column::new("a", 2)), options: SortOptions::default(), }, - ], - vec![PhysicalSortExpr { + ]), + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("d", 0)), options: SortOptions::default(), - }], + }]), ] ); assert!(result.is_infinite()); @@ -2392,6 +2425,73 @@ mod tests { Ok(()) } + #[test] + fn test_nested_loop_join_after_projection() -> Result<()> { + let left_csv = create_simple_csv_exec(); + let right_csv = create_simple_csv_exec(); + + let col_left_a = col("a", &left_csv.schema())?; + let col_right_b = col("b", &right_csv.schema())?; + let col_left_c = col("c", &left_csv.schema())?; + // left_a < right_b + let filter_expr = + binary(col_left_a, Operator::Lt, col_right_b, &Schema::empty())?; + let filter_column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 1, + side: JoinSide::Right, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ]; + let filter_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + + let join: Arc = Arc::new(NestedLoopJoinExec::try_new( + left_csv, + right_csv, + Some(JoinFilter::new( + filter_expr, + filter_column_indices, + filter_schema, + )), + &JoinType::Inner, + None, + )?); + + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![(col_left_c, "c".to_string())], + Arc::clone(&join), + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c]", + " NestedLoopJoinExec: join_type=Inner, filter=a@0 < b@1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + let expected = [ + "NestedLoopJoinExec: join_type=Inner, filter=a@0 < b@1, projection=[c@2]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: 
[[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + Ok(()) + } + #[test] fn test_hash_join_after_projection() -> Result<()> { // sql like @@ -2553,7 +2653,7 @@ mod tests { fn test_sort_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); let sort_req: Arc = Arc::new(SortExec::new( - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("b", 1)), options: SortOptions::default(), @@ -2566,7 +2666,7 @@ mod tests { )), options: SortOptions::default(), }, - ], + ]), csv.clone(), )); let projection: Arc = Arc::new(ProjectionExec::try_new( @@ -2581,7 +2681,7 @@ mod tests { let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " SortExec: expr=[b@1 ASC,c@2 + a@0 ASC], preserve_partitioning=[false]", + " SortExec: expr=[b@1 ASC, c@2 + a@0 ASC], preserve_partitioning=[false]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; assert_eq!(initial, expected_initial); @@ -2590,7 +2690,7 @@ mod tests { ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; let expected = [ - "SortExec: expr=[b@2 ASC,c@0 + new_a@1 ASC], preserve_partitioning=[false]", + "SortExec: expr=[b@2 ASC, c@0 + new_a@1 ASC], preserve_partitioning=[false]", " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; @@ -2603,7 +2703,7 @@ mod tests { fn test_sort_preserving_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); let sort_req: Arc = Arc::new(SortPreservingMergeExec::new( - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("b", 1)), options: SortOptions::default(), @@ -2616,7 +2716,7 @@ mod tests { )), options: SortOptions::default(), }, - ], + ]), csv.clone(), )); let projection: Arc = Arc::new(ProjectionExec::try_new( @@ -2631,7 +2731,7 @@ mod tests { let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " SortPreservingMergeExec: [b@1 ASC,c@2 + a@0 ASC]", + " SortPreservingMergeExec: [b@1 ASC, c@2 + a@0 ASC]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; assert_eq!(initial, expected_initial); @@ -2640,7 +2740,7 @@ mod tests { ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; let expected = [ - "SortPreservingMergeExec: [b@2 ASC,c@0 + new_a@1 ASC]", + "SortPreservingMergeExec: [b@2 ASC, c@0 + new_a@1 ASC]", " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index a989be987d3db..f32ffa8a58308 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -29,7 +29,9 @@ use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use 
datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::tree_node::PlanContext; use datafusion_physical_plan::ExecutionPlanProperties; @@ -119,7 +121,7 @@ fn plan_with_order_preserving_variants( { // When a `RepartitionExec` doesn't preserve ordering, replace it with // a sort-preserving variant if appropriate: - let child = sort_input.children[0].plan.clone(); + let child = Arc::clone(&sort_input.children[0].plan); let partitioning = sort_input.plan.output_partitioning().clone(); sort_input.plan = Arc::new( RepartitionExec::try_new(child, partitioning)?.with_preserve_order(), @@ -128,10 +130,10 @@ fn plan_with_order_preserving_variants( return Ok(sort_input); } else if is_coalesce_partitions(&sort_input.plan) && is_spm_better { let child = &sort_input.children[0].plan; - if let Some(ordering) = child.output_ordering().map(Vec::from) { + if let Some(ordering) = child.output_ordering() { // When the input of a `CoalescePartitionsExec` has an ordering, // replace it with a `SortPreservingMergeExec` if appropriate: - let spm = SortPreservingMergeExec::new(ordering, child.clone()); + let spm = SortPreservingMergeExec::new(ordering.clone(), Arc::clone(child)); sort_input.plan = Arc::new(spm) as _; sort_input.children[0].data = true; return Ok(sort_input); @@ -158,7 +160,7 @@ fn plan_with_order_breaking_variants( // not required by intermediate operators: if maintains && (is_sort_preserving_merge(plan) - || !required_ordering.map_or(false, |required_ordering| { + || !required_ordering.is_some_and(|required_ordering| { node.plan .equivalence_properties() .ordering_satisfy_requirement(&required_ordering) @@ -175,12 +177,12 @@ fn plan_with_order_breaking_variants( if is_repartition(plan) && plan.maintains_input_order()[0] { // When a `RepartitionExec` preserves ordering, replace it with a // non-sort-preserving variant: - let child = sort_input.children[0].plan.clone(); + let child = Arc::clone(&sort_input.children[0].plan); let partitioning = plan.output_partitioning().clone(); sort_input.plan = Arc::new(RepartitionExec::try_new(child, partitioning)?) as _; } else if is_sort_preserving_merge(plan) { // Replace `SortPreservingMergeExec` with a `CoalescePartitionsExec`: - let child = sort_input.children[0].plan.clone(); + let child = Arc::clone(&sort_input.children[0].plan); let coalesce = CoalescePartitionsExec::new(child); sort_input.plan = Arc::new(coalesce) as _; } else { @@ -242,7 +244,8 @@ pub(crate) fn replace_with_order_preserving_variants( // For unbounded cases, we replace with the order-preserving variant in any // case, as doing so helps fix the pipeline. Also replace if config allows. 
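The rewritten condition below replaces the old `!plan.execution_mode().pipeline_friendly()` check: the order-preserving variant is used when `prefer_existing_sort` is set, or when an unbounded input feeds an operator whose emission type is `Final`. A standalone sketch of that decision, using simplified stand-ins for `Boundedness` and `EmissionType` (the real enums in `datafusion_physical_plan` carry more detail):

```rust
// Simplified stand-ins; the real enums live in `datafusion_physical_plan`.
#[derive(PartialEq)]
enum Boundedness {
    Bounded,
    Unbounded,
}

#[derive(PartialEq)]
enum EmissionType {
    Incremental,
    Final,
}

impl Boundedness {
    fn is_unbounded(&self) -> bool {
        *self == Boundedness::Unbounded
    }
}

// Mirrors how `use_order_preserving_variant` is computed after this change.
fn use_order_preserving_variant(
    prefer_existing_sort: bool,
    boundedness: Boundedness,
    pipeline_behavior: EmissionType,
) -> bool {
    prefer_existing_sort
        || (boundedness.is_unbounded() && pipeline_behavior == EmissionType::Final)
}

fn main() {
    // Unbounded source into an operator that emits only at the end: replace.
    assert!(use_order_preserving_variant(
        false,
        Boundedness::Unbounded,
        EmissionType::Final
    ));
    // Bounded source without the config flag: leave the plan alone.
    assert!(!use_order_preserving_variant(
        false,
        Boundedness::Bounded,
        EmissionType::Incremental
    ));
}
```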
let use_order_preserving_variant = config.optimizer.prefer_existing_sort - || !requirements.plan.execution_mode().pipeline_friendly(); + || (requirements.plan.boundedness().is_unbounded() + && requirements.plan.pipeline_behavior() == EmissionType::Final); // Create an alternate plan with order-preserving variants: let mut alternate_plan = plan_with_order_preserving_variants( @@ -255,7 +258,12 @@ pub(crate) fn replace_with_order_preserving_variants( if alternate_plan .plan .equivalence_properties() - .ordering_satisfy(requirements.plan.output_ordering().unwrap_or(&[])) + .ordering_satisfy( + requirements + .plan + .output_ordering() + .unwrap_or(LexOrdering::empty()), + ) { for child in alternate_plan.children.iter_mut() { child.data = false; @@ -274,10 +282,7 @@ pub(crate) fn replace_with_order_preserving_variants( mod tests { use super::*; - use crate::datasource::file_format::file_compression_type::FileCompressionType; - use crate::datasource::listing::PartitionedFile; - use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; - use crate::physical_optimizer::test_utils::check_integrity; + use crate::execution::TaskContext; use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::joins::{HashJoinExec, PartitionMode}; @@ -285,18 +290,25 @@ mod tests { use crate::physical_plan::{ displayable, get_plan_string, ExecutionPlan, Partitioning, }; - use crate::prelude::SessionConfig; + use crate::prelude::{SessionConfig, SessionContext}; use crate::test::TestStreamPartition; + use datafusion_physical_optimizer::test_utils::check_integrity; + use arrow::array::{ArrayRef, Int32Array}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::Result; - use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{self, col, Column}; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_plan::collect; + use datafusion_physical_plan::memory::MemoryExec; use datafusion_physical_plan::streaming::StreamingTableExec; + use object_store::memory::InMemory; + use object_store::ObjectStore; + use url::Url; use rstest::rstest; @@ -317,20 +329,24 @@ mod tests { /// * `$PLAN`: The plan to optimize. /// * `$SOURCE_UNBOUNDED`: Whether the given plan contains an unbounded source. macro_rules! 
assert_optimized_in_all_boundedness_situations { - ($EXPECTED_UNBOUNDED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PLAN_LINES: expr, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $SOURCE_UNBOUNDED: expr) => { + ($EXPECTED_UNBOUNDED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PLAN_LINES: expr, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $SOURCE_UNBOUNDED: expr, $PREFER_EXISTING_SORT: expr) => { if $SOURCE_UNBOUNDED { assert_optimized_prefer_sort_on_off!( $EXPECTED_UNBOUNDED_PLAN_LINES, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, - $PLAN + $PLAN, + $PREFER_EXISTING_SORT, + $SOURCE_UNBOUNDED ); } else { assert_optimized_prefer_sort_on_off!( $EXPECTED_BOUNDED_PLAN_LINES, $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES, $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, - $PLAN + $PLAN, + $PREFER_EXISTING_SORT, + $SOURCE_UNBOUNDED ); } }; @@ -348,19 +364,24 @@ mod tests { /// the flag `prefer_existing_sort` is `true`. /// * `$PLAN`: The plan to optimize. macro_rules! assert_optimized_prefer_sort_on_off { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { - assert_optimized!( - $EXPECTED_PLAN_LINES, - $EXPECTED_OPTIMIZED_PLAN_LINES, - $PLAN.clone(), - false - ); - assert_optimized!( - $EXPECTED_PLAN_LINES, - $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, - $PLAN, - true - ); + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr, $SOURCE_UNBOUNDED: expr) => { + if $PREFER_EXISTING_SORT { + assert_optimized!( + $EXPECTED_PLAN_LINES, + $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, + $PLAN, + $PREFER_EXISTING_SORT, + $SOURCE_UNBOUNDED + ); + } else { + assert_optimized!( + $EXPECTED_PLAN_LINES, + $EXPECTED_OPTIMIZED_PLAN_LINES, + $PLAN, + $PREFER_EXISTING_SORT, + $SOURCE_UNBOUNDED + ); + } }; } @@ -374,7 +395,7 @@ mod tests { /// * `$PLAN`: The plan to optimize. /// * `$PREFER_EXISTING_SORT`: Value of the `prefer_existing_sort` flag. macro_rules! 
assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr) => { + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr, $SOURCE_UNBOUNDED: expr) => { let physical_plan = $PLAN; let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -401,6 +422,19 @@ mod tests { expected_optimized_lines, actual, "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected_optimized_lines:#?}\nactual:\n\n{actual:#?}\n\n" ); + + if !$SOURCE_UNBOUNDED { + let ctx = SessionContext::new(); + let object_store = InMemory::new(); + object_store.put(&object_store::path::Path::from("file_path"), bytes::Bytes::from("").into()).await?; + ctx.register_object_store(&Url::parse("test://").unwrap(), Arc::new(object_store)); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let res = collect(optimized_physical_plan, task_ctx).await; + assert!( + res.is_ok(), + "Some errors occurred while executing the optimized physical plan: {:?}", res.unwrap_err() + ); + } }; } @@ -409,13 +443,14 @@ mod tests { // Searches for a simple sort and a repartition just after it, the second repartition with 1 input partition should not be affected async fn test_replace_multiple_input_repartition_1( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); @@ -436,7 +471,7 @@ mod tests { " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -453,13 +488,13 @@ mod tests { " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -468,7 +503,8 @@ mod tests { expected_optimized_bounded, 
expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -477,13 +513,14 @@ mod tests { #[tokio::test] async fn test_with_inter_children_change_only( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr_default("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -527,7 +564,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC", ]; // Expected unbounded result (same for with and without flag) @@ -553,7 +590,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC]", @@ -563,7 +600,7 @@ mod tests { " SortPreservingMergeExec: [a@0 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -572,7 +609,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -581,13 +619,14 @@ mod tests { #[tokio::test] async fn test_replace_multiple_input_repartition_2( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let filter = filter_exec(repartition_rr); @@ -612,7 +651,7 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -631,14 +670,14 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " 
FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -647,7 +686,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -656,13 +696,14 @@ mod tests { #[tokio::test] async fn test_replace_multiple_input_repartition_with_extra_steps( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -690,7 +731,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -711,7 +752,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -719,7 +760,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -728,7 +769,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -737,13 +779,14 @@ mod tests { #[tokio::test] async fn 
test_replace_multiple_input_repartition_with_extra_steps_2( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); @@ -775,7 +818,7 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -798,7 +841,7 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -807,7 +850,7 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -816,7 +859,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -825,13 +869,14 @@ mod tests { #[tokio::test] async fn test_not_replacing_when_no_need_to_preserve_sorting( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -856,7 +901,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -876,7 +921,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), 
input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; @@ -887,7 +932,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -896,13 +942,14 @@ mod tests { #[tokio::test] async fn test_with_multiple_replacable_repartitions( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -933,7 +980,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -956,7 +1003,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -965,7 +1012,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -974,7 +1021,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -983,13 +1031,14 @@ mod tests { #[tokio::test] async fn test_not_replace_with_different_orderings( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -1017,7 +1066,7 @@ mod tests { " SortExec: 
expr=[c@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -1035,7 +1084,7 @@ mod tests { " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; @@ -1046,7 +1095,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -1055,13 +1105,14 @@ mod tests { #[tokio::test] async fn test_with_lost_ordering( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -1082,7 +1133,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -1099,13 +1150,13 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -1114,7 +1165,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -1123,13 +1175,14 @@ mod tests { #[tokio::test] async fn test_with_lost_and_kept_ordering( 
#[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; let source = if source_unbounded { stream_exec_ordered(&schema, sort_exprs) } else { - csv_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); @@ -1173,7 +1226,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -1200,7 +1253,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [c@1 ASC]", @@ -1211,7 +1264,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; assert_optimized_in_all_boundedness_situations!( expected_input_unbounded, @@ -1220,7 +1273,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -1229,6 +1283,7 @@ mod tests { #[tokio::test] async fn test_with_multiple_child_trees( #[values(false, true)] source_unbounded: bool, + #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; @@ -1236,7 +1291,7 @@ mod tests { let left_source = if source_unbounded { stream_exec_ordered(&schema, left_sort_exprs) } else { - csv_exec_sorted(&schema, left_sort_exprs) + memory_exec_sorted(&schema, left_sort_exprs) }; let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); @@ -1247,7 +1302,7 @@ mod tests { let right_source = if source_unbounded { stream_exec_ordered(&schema, right_sort_exprs) } else { - csv_exec_sorted(&schema, right_sort_exprs) + memory_exec_sorted(&schema, right_sort_exprs) }; let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); @@ -1288,11 +1343,11 @@ mod tests { " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, 
partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; // Expected unbounded result (same for with and without flag) @@ -1319,11 +1374,11 @@ mod tests { " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " MemoryExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", ]; let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; @@ -1334,7 +1389,8 @@ mod tests { expected_optimized_bounded, expected_optimized_bounded_sort_preserve, physical_plan, - source_unbounded + source_unbounded, + prefer_existing_sort ); Ok(()) } @@ -1481,33 +1537,36 @@ mod tests { ) } - // creates a csv exec source for the test purposes - // projection and has_header parameters are given static due to testing needs - fn csv_exec_sorted( + // creates a memory exec source for the test purposes + // projection parameter is given static due to testing needs + fn memory_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - let projection: Vec = vec![0, 2, 3]; + pub fn make_partition(schema: &SchemaRef, sz: i32) -> RecordBatch { + let values = (0..sz).collect::>(); + let arr = Arc::new(Int32Array::from(values)); + let arr = arr as ArrayRef; - Arc::new( - CsvExec::builder( - FileScanConfig::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - ) - .with_file(PartitionedFile::new("file_path".to_string(), 100)) - .with_projection(Some(projection)) - .with_output_ordering(vec![sort_exprs]), + RecordBatch::try_new( + schema.clone(), + vec![arr.clone(), arr.clone(), arr.clone(), arr], ) - .with_has_header(true) - .with_delimeter(0) - .with_quote(b'"') - .with_escape(None) - .with_comment(None) - .with_newlines_in_values(false) - .with_file_compression_type(FileCompressionType::UNCOMPRESSED) - .build(), - ) + .unwrap() + } + + let rows = 5; + let partitions = 1; + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new({ + let data: Vec> = (0..partitions) + .map(|_| vec![make_partition(schema, rows)]) + .collect(); + let projection: Vec = vec![0, 2, 3]; + MemoryExec::try_new(&data, schema.clone(), Some(projection)) + .unwrap() + .try_with_sort_information(vec![sort_exprs]) + .unwrap() + }) } } diff --git a/datafusion/core/src/physical_optimizer/sanity_checker.rs b/datafusion/core/src/physical_optimizer/sanity_checker.rs deleted file mode 100644 index 4d2baf1fe1ab2..0000000000000 --- 
a/datafusion/core/src/physical_optimizer/sanity_checker.rs +++ /dev/null @@ -1,671 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -//http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! The [SanityCheckPlan] rule ensures that a given plan can -//! accommodate its infinite sources, if there are any. It will reject -//! non-runnable query plans that use pipeline-breaking operators on -//! infinite input(s). In addition, it will check if all order and -//! distribution requirements of a plan are satisfied by its children. - -use std::sync::Arc; - -use crate::error::Result; -use crate::physical_plan::ExecutionPlan; - -use datafusion_common::config::{ConfigOptions, OptimizerOptions}; -use datafusion_common::plan_err; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; -use datafusion_physical_plan::joins::SymmetricHashJoinExec; -use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; - -use datafusion_physical_expr_common::sort_expr::format_physical_sort_requirement_list; -use datafusion_physical_optimizer::PhysicalOptimizerRule; -use itertools::izip; - -/// The SanityCheckPlan rule rejects the following query plans: -/// 1. Invalid plans containing nodes whose order and/or distribution requirements -/// are not satisfied by their children. -/// 2. Plans that use pipeline-breaking operators on infinite input(s), -/// it is impossible to execute such queries (they will never generate output nor finish) -#[derive(Default, Debug)] -pub struct SanityCheckPlan {} - -impl SanityCheckPlan { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl PhysicalOptimizerRule for SanityCheckPlan { - fn optimize( - &self, - plan: Arc, - config: &ConfigOptions, - ) -> Result> { - plan.transform_up(|p| check_plan_sanity(p, &config.optimizer)) - .data() - } - - fn name(&self) -> &str { - "SanityCheckPlan" - } - - fn schema_check(&self) -> bool { - true - } -} - -/// This function propagates finiteness information and rejects any plan with -/// pipeline-breaking operators acting on infinite inputs. -pub fn check_finiteness_requirements( - input: Arc, - optimizer_options: &OptimizerOptions, -) -> Result>> { - if let Some(exec) = input.as_any().downcast_ref::() { - if !(optimizer_options.allow_symmetric_joins_without_pruning - || (exec.check_if_order_information_available()? 
&& is_prunable(exec))) - { - return plan_err!("Join operation cannot operate on a non-prunable stream without enabling \ - the 'allow_symmetric_joins_without_pruning' configuration flag"); - } - } - if !input.execution_mode().pipeline_friendly() { - plan_err!( - "Cannot execute pipeline breaking queries, operator: {:?}", - input - ) - } else { - Ok(Transformed::no(input)) - } -} - -/// This function returns whether a given symmetric hash join is amenable to -/// data pruning. For this to be possible, it needs to have a filter where -/// all involved [`PhysicalExpr`]s, [`Operator`]s and data types support -/// interval calculations. -/// -/// [`PhysicalExpr`]: crate::physical_plan::PhysicalExpr -/// [`Operator`]: datafusion_expr::Operator -fn is_prunable(join: &SymmetricHashJoinExec) -> bool { - join.filter().map_or(false, |filter| { - check_support(filter.expression(), &join.schema()) - && filter - .schema() - .fields() - .iter() - .all(|f| is_datatype_supported(f.data_type())) - }) -} - -/// Ensures that the plan is pipeline friendly and the order and -/// distribution requirements from its children are satisfied. -pub fn check_plan_sanity( - plan: Arc, - optimizer_options: &OptimizerOptions, -) -> Result>> { - check_finiteness_requirements(plan.clone(), optimizer_options)?; - - for ((idx, child), sort_req, dist_req) in izip!( - plan.children().iter().enumerate(), - plan.required_input_ordering().iter(), - plan.required_input_distribution().iter() - ) { - let child_eq_props = child.equivalence_properties(); - if let Some(sort_req) = sort_req { - if !child_eq_props.ordering_satisfy_requirement(sort_req) { - let plan_str = get_plan_string(&plan); - return plan_err!( - "Plan: {:?} does not satisfy order requirements: {}. Child-{} order: {}", - plan_str, - format_physical_sort_requirement_list(sort_req), - idx, - child_eq_props.oeq_class - ); - } - } - - if !child - .output_partitioning() - .satisfy(dist_req, child_eq_props) - { - let plan_str = get_plan_string(&plan); - return plan_err!( - "Plan: {:?} does not satisfy distribution requirements: {}. Child-{} output partitioning: {}", - plan_str, - dist_req, - idx, - child.output_partitioning() - ); - } - } - - Ok(Transformed::no(plan)) -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::physical_optimizer::test_utils::{ - bounded_window_exec, global_limit_exec, local_limit_exec, memory_exec, - repartition_exec, sort_exec, sort_expr_options, sort_merge_join_exec, - BinaryTestCase, QueryCase, SourceType, UnaryTestCase, - }; - - use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::Result; - use datafusion_expr::JoinType; - use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr::Partitioning; - use datafusion_physical_plan::displayable; - use datafusion_physical_plan::repartition::RepartitionExec; - - fn create_test_schema() -> SchemaRef { - Arc::new(Schema::new(vec![Field::new("c9", DataType::Int32, true)])) - } - - fn create_test_schema2() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ])) - } - - /// Check if sanity checker should accept or reject plans. 
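The removed `check_plan_sanity` (the rule appears to move out of `datafusion/core`, judging by the `datafusion_physical_optimizer::test_utils` import added earlier in this diff) walks each child together with its ordering and distribution requirements using `itertools::izip!`. A standalone sketch of that lockstep iteration, with plain strings standing in for plan nodes and requirement types:

```rust
// Assumes the `itertools` crate; everything else is a plain-data stand-in.
use itertools::izip;

fn main() {
    let children = ["child-0", "child-1"];
    let sort_reqs: [Option<&str>; 2] = [Some("a@0 ASC"), None];
    let dist_reqs = ["SinglePartition", "Hash([c@1], 8)"];

    // Walk children, per-child ordering requirements and per-child
    // distribution requirements in lockstep, as `check_plan_sanity` does.
    for ((idx, child), sort_req, dist_req) in
        izip!(children.iter().enumerate(), sort_reqs.iter(), dist_reqs.iter())
    {
        // The real check returns a `plan_err!` when a requirement is not
        // satisfied; here we just print what would be verified.
        println!("{idx}: {child} -> ordering {sort_req:?}, distribution {dist_req}");
    }
}
```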
- fn assert_sanity_check(plan: &Arc, is_sane: bool) { - let sanity_checker = SanityCheckPlan::new(); - let opts = ConfigOptions::default(); - assert_eq!( - sanity_checker.optimize(plan.clone(), &opts).is_ok(), - is_sane - ); - } - - /// Check if the plan we created is as expected by comparing the plan - /// formatted as a string. - fn assert_plan(plan: &dyn ExecutionPlan, expected_lines: Vec<&str>) { - let plan_str = displayable(plan).indent(true).to_string(); - let actual_lines: Vec<&str> = plan_str.trim().lines().collect(); - assert_eq!(actual_lines, expected_lines); - } - - #[tokio::test] - async fn test_hash_left_join_swap() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: false, - }; - - let test2 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: true, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 LEFT JOIN right as t2 ON t1.c1 = t2.c1" - .to_string(), - cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], - error_operator: "operator: HashJoinExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_hash_right_join_swap() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: true, - }; - let test2 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: false, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 RIGHT JOIN right as t2 ON t1.c1 = t2.c1" - .to_string(), - cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], - error_operator: "operator: HashJoinExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_hash_inner_join_swap() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: false, - }; - let test2 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: false, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 JOIN right as t2 ON t1.c1 = t2.c1" - .to_string(), - cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], - error_operator: "Join Error".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_hash_full_outer_join_swap() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: true, - }; - let test2 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: true, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 FULL JOIN right as t2 ON t1.c1 = t2.c1" - .to_string(), - cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], - error_operator: "operator: HashJoinExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_aggregate() -> Result<()> { - let test1 = UnaryTestCase { - source_type: SourceType::Bounded, - 
expect_fail: false, - }; - let test2 = UnaryTestCase { - source_type: SourceType::Unbounded, - expect_fail: true, - }; - let case = QueryCase { - sql: "SELECT c1, MIN(c4) FROM test GROUP BY c1".to_string(), - cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "operator: AggregateExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_window_agg_hash_partition() -> Result<()> { - let test1 = UnaryTestCase { - source_type: SourceType::Bounded, - expect_fail: false, - }; - let test2 = UnaryTestCase { - source_type: SourceType::Unbounded, - expect_fail: true, - }; - let case = QueryCase { - sql: "SELECT - c9, - SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 - FROM test - LIMIT 5".to_string(), - cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "operator: SortExec".to_string() - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_window_agg_single_partition() -> Result<()> { - let test1 = UnaryTestCase { - source_type: SourceType::Bounded, - expect_fail: false, - }; - let test2 = UnaryTestCase { - source_type: SourceType::Unbounded, - expect_fail: true, - }; - let case = QueryCase { - sql: "SELECT - c9, - SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 - FROM test".to_string(), - cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "operator: SortExec".to_string() - }; - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_hash_cross_join() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: true, - }; - let test2 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Unbounded), - expect_fail: true, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: true, - }; - let test4 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 CROSS JOIN right as t2".to_string(), - cases: vec![ - Arc::new(test1), - Arc::new(test2), - Arc::new(test3), - Arc::new(test4), - ], - error_operator: "operator: CrossJoinExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_analyzer() -> Result<()> { - let test1 = UnaryTestCase { - source_type: SourceType::Bounded, - expect_fail: false, - }; - let test2 = UnaryTestCase { - source_type: SourceType::Unbounded, - expect_fail: false, - }; - let case = QueryCase { - sql: "EXPLAIN ANALYZE SELECT * FROM test".to_string(), - cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "Analyze Error".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - /// Tests that plan is valid when the sort requirements are satisfied. 
- async fn test_bounded_window_agg_sort_requirement() -> Result<()> { - let schema = create_test_schema(); - let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr_options( - "c9", - &source.schema(), - SortOptions { - descending: false, - nulls_first: false, - }, - )]; - let sort = sort_exec(sort_exprs.clone(), source); - let bw = bounded_window_exec("c9", sort_exprs, sort); - assert_plan(bw.as_ref(), vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false]", - " MemoryExec: partitions=1, partition_sizes=[0]" - ]); - assert_sanity_check(&bw, true); - Ok(()) - } - - #[tokio::test] - /// Tests that plan is invalid when the sort requirements are not satisfied. - async fn test_bounded_window_agg_no_sort_requirement() -> Result<()> { - let schema = create_test_schema(); - let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr_options( - "c9", - &source.schema(), - SortOptions { - descending: false, - nulls_first: false, - }, - )]; - let bw = bounded_window_exec("c9", sort_exprs, source); - assert_plan(bw.as_ref(), vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " MemoryExec: partitions=1, partition_sizes=[0]" - ]); - // Order requirement of the `BoundedWindowAggExec` is not satisfied. We expect to receive error during sanity check. - assert_sanity_check(&bw, false); - Ok(()) - } - - #[tokio::test] - /// A valid when a single partition requirement - /// is satisfied. - async fn test_global_limit_single_partition() -> Result<()> { - let schema = create_test_schema(); - let source = memory_exec(&schema); - let limit = global_limit_exec(source); - - assert_plan( - limit.as_ref(), - vec![ - "GlobalLimitExec: skip=0, fetch=100", - " MemoryExec: partitions=1, partition_sizes=[0]", - ], - ); - assert_sanity_check(&limit, true); - Ok(()) - } - - #[tokio::test] - /// An invalid plan when a single partition requirement - /// is not satisfied. - async fn test_global_limit_multi_partition() -> Result<()> { - let schema = create_test_schema(); - let source = memory_exec(&schema); - let limit = global_limit_exec(repartition_exec(source)); - - assert_plan( - limit.as_ref(), - vec![ - "GlobalLimitExec: skip=0, fetch=100", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[0]", - ], - ); - // Distribution requirement of the `GlobalLimitExec` is not satisfied. We expect to receive error during sanity check. - assert_sanity_check(&limit, false); - Ok(()) - } - - #[tokio::test] - /// A plan with no requirements should satisfy. - async fn test_local_limit() -> Result<()> { - let schema = create_test_schema(); - let source = memory_exec(&schema); - let limit = local_limit_exec(source); - - assert_plan( - limit.as_ref(), - vec![ - "LocalLimitExec: fetch=100", - " MemoryExec: partitions=1, partition_sizes=[0]", - ], - ); - assert_sanity_check(&limit, true); - Ok(()) - } - - #[tokio::test] - /// Valid plan with multiple children satisfy both order and distribution. 
- async fn test_sort_merge_join_satisfied() -> Result<()> { - let schema1 = create_test_schema(); - let schema2 = create_test_schema2(); - let source1 = memory_exec(&schema1); - let source2 = memory_exec(&schema2); - let sort_opts = SortOptions::default(); - let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; - let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; - let left = sort_exec(sort_exprs1, source1); - let right = sort_exec(sort_exprs2, source2); - let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); - let left = Arc::new(RepartitionExec::try_new( - left, - Partitioning::Hash(vec![left_jcol.clone()], 10), - )?); - - let right = Arc::new(RepartitionExec::try_new( - right, - Partitioning::Hash(vec![right_jcol.clone()], 10), - )?); - - let join_on = vec![(left_jcol as _, right_jcol as _)]; - let join_ty = JoinType::Inner; - let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); - - assert_plan( - smj.as_ref(), - vec![ - "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", - " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", - " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", - " MemoryExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " MemoryExec: partitions=1, partition_sizes=[0]", - ], - ); - assert_sanity_check(&smj, true); - Ok(()) - } - - #[tokio::test] - /// Invalid case when the order is not satisfied by the 2nd - /// child. - async fn test_sort_merge_join_order_missing() -> Result<()> { - let schema1 = create_test_schema(); - let schema2 = create_test_schema2(); - let source1 = memory_exec(&schema1); - let right = memory_exec(&schema2); - let sort_exprs1 = vec![sort_expr_options( - "c9", - &source1.schema(), - SortOptions::default(), - )]; - let left = sort_exec(sort_exprs1, source1); - // Missing sort of the right child here.. - let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); - let left = Arc::new(RepartitionExec::try_new( - left, - Partitioning::Hash(vec![left_jcol.clone()], 10), - )?); - - let right = Arc::new(RepartitionExec::try_new( - right, - Partitioning::Hash(vec![right_jcol.clone()], 10), - )?); - - let join_on = vec![(left_jcol as _, right_jcol as _)]; - let join_ty = JoinType::Inner; - let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); - - assert_plan( - smj.as_ref(), - vec![ - "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", - " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", - " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", - " MemoryExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[0]", - ], - ); - // Order requirement for the `SortMergeJoin` is not satisfied for right child. We expect to receive error during sanity check. - assert_sanity_check(&smj, false); - Ok(()) - } - - #[tokio::test] - /// Invalid case when the distribution is not satisfied by the 2nd - /// child. 
- async fn test_sort_merge_join_dist_missing() -> Result<()> { - let schema1 = create_test_schema(); - let schema2 = create_test_schema2(); - let source1 = memory_exec(&schema1); - let source2 = memory_exec(&schema2); - let sort_opts = SortOptions::default(); - let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; - let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; - let left = sort_exec(sort_exprs1, source1); - let right = sort_exec(sort_exprs2, source2); - let right = Arc::new(RepartitionExec::try_new( - right, - Partitioning::RoundRobinBatch(10), - )?); - let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); - let left = Arc::new(RepartitionExec::try_new( - left, - Partitioning::Hash(vec![left_jcol.clone()], 10), - )?); - - // Missing hash partitioning on right child. - - let join_on = vec![(left_jcol as _, right_jcol as _)]; - let join_ty = JoinType::Inner; - let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); - - assert_plan( - smj.as_ref(), - vec![ - "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", - " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", - " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", - " MemoryExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " MemoryExec: partitions=1, partition_sizes=[0]", - ], - ); - // Distribution requirement for the `SortMergeJoin` is not satisfied for right child (has round-robin partitioning). We expect to receive error during sanity check. - assert_sanity_check(&smj, false); - Ok(()) - } -} diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index c7677d725b036..6c761f674b3b5 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -28,20 +28,19 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::tree_node::PlanContext; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use arrow_schema::SchemaRef; use datafusion_common::tree_node::{ ConcreteTreeNode, Transformed, TreeNode, TreeNodeRecursion, }; -use datafusion_common::{plan_err, JoinSide, Result}; +use datafusion_common::{plan_err, HashSet, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::{ - LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, -}; -use datafusion_physical_expr_common::sort_expr::LexRequirement; - -use hashbrown::HashSet; +use datafusion_physical_expr::PhysicalSortRequirement; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_plan::joins::utils::ColumnIndex; +use datafusion_physical_plan::joins::HashJoinExec; /// This is a "data class" we use within the [`EnforceSorting`] rule to push /// down [`SortExec`] in the plan. 
In some cases, we can reduce the total @@ -89,15 +88,17 @@ fn pushdown_sorts_helper( let parent_reqs = requirements .data .ordering_requirement - .as_deref() - .unwrap_or(&[]); + .clone() + .unwrap_or_default(); let satisfy_parent = plan .equivalence_properties() - .ordering_satisfy_requirement(parent_reqs); + .ordering_satisfy_requirement(&parent_reqs); + if is_sort(plan) { let required_ordering = plan .output_ordering() - .map(PhysicalSortRequirement::from_sort_exprs) + .cloned() + .map(LexRequirement::from) .unwrap_or_default(); if !satisfy_parent { // Make sure this `SortExec` satisfies parent requirements: @@ -141,7 +142,7 @@ fn pushdown_sorts_helper( for (child, order) in requirements.children.iter_mut().zip(reqs) { child.data.ordering_requirement = order; } - } else if let Some(adjusted) = pushdown_requirement_to_children(plan, parent_reqs)? { + } else if let Some(adjusted) = pushdown_requirement_to_children(plan, &parent_reqs)? { // Can not satisfy the parent requirements, check whether we can push // requirements down: for (child, order) in requirements.children.iter_mut().zip(adjusted) { @@ -164,14 +165,16 @@ fn pushdown_sorts_helper( fn pushdown_requirement_to_children( plan: &Arc, - parent_required: LexRequirementRef, + parent_required: &LexRequirement, ) -> Result>>> { let maintains_input_order = plan.maintains_input_order(); if is_window(plan) { let required_input_ordering = plan.required_input_ordering(); - let request_child = required_input_ordering[0].as_deref().unwrap_or(&[]); + let request_child = required_input_ordering[0].clone().unwrap_or_default(); let child_plan = plan.children().swap_remove(0); - match determine_children_requirement(parent_required, request_child, child_plan) { + + match determine_children_requirement(parent_required, &request_child, child_plan) + { RequirementsCompatibility::Satisfy => { let req = (!request_child.is_empty()) .then(|| LexRequirement::new(request_child.to_vec())); @@ -181,8 +184,12 @@ fn pushdown_requirement_to_children( RequirementsCompatibility::NonCompatible => Ok(None), } } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let sort_req = PhysicalSortRequirement::from_sort_exprs( - sort_exec.properties().output_ordering().unwrap_or(&[]), + let sort_req = LexRequirement::from( + sort_exec + .properties() + .output_ordering() + .cloned() + .unwrap_or(LexOrdering::default()), ); if sort_exec .properties() @@ -203,8 +210,11 @@ fn pushdown_requirement_to_children( .iter() .all(|maintain| *maintain) { - let output_req = PhysicalSortRequirement::from_sort_exprs( - plan.properties().output_ordering().unwrap_or(&[]), + let output_req = LexRequirement::from( + plan.properties() + .output_ordering() + .cloned() + .unwrap_or(LexOrdering::default()), ); // Push down through operator with fetch when: // - requirement is aligned with output ordering @@ -223,19 +233,21 @@ fn pushdown_requirement_to_children( } else if is_union(plan) { // UnionExec does not have real sort requirements for its input. 
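The `pushdown_sorts_helper` changes above are mostly a type migration (owned `LexRequirement`/`LexOrdering` values instead of slices), but the decision ladder stays the same: do nothing if the node already satisfies the parent's ordering, otherwise try to hand the requirement to the children, otherwise keep a sort at this level. A rough sketch of that ladder with stand-in types (not the real rule, which also merges and re-creates `SortExec` nodes):

```rust
/// Stand-in for a lexicographic sort requirement (column names only).
#[derive(Clone)]
struct SortRequirement(Vec<&'static str>);

enum Action {
    /// The node's output already satisfies the parent requirement.
    AlreadySatisfied,
    /// Per-child requirements to push further down; a sort above can go away.
    PushedToChildren(Vec<Option<SortRequirement>>),
    /// Nothing below can help; a SortExec has to stay (or be added) here.
    KeepSortAbove(SortRequirement),
}

/// One step of the pushdown, parameterized by the two questions the real rule
/// asks the plan node: "do you already satisfy this?" and "can your children
/// take it?" (`ordering_satisfy_requirement` / `pushdown_requirement_to_children`).
fn pushdown_step(
    node_satisfies: impl Fn(&SortRequirement) -> bool,
    children_accept: impl Fn(&SortRequirement) -> Option<Vec<Option<SortRequirement>>>,
    parent_required: SortRequirement,
) -> Action {
    if parent_required.0.is_empty() || node_satisfies(&parent_required) {
        Action::AlreadySatisfied
    } else if let Some(per_child) = children_accept(&parent_required) {
        Action::PushedToChildren(per_child)
    } else {
        Action::KeepSortAbove(parent_required)
    }
}

fn main() {
    // A pass-through operator that does not provide `a ASC` itself but lets the
    // requirement travel to its single child.
    let action = pushdown_step(
        |_| false,
        |req| Some(vec![Some(req.clone())]),
        SortRequirement(vec!["a ASC"]),
    );
    assert!(matches!(action, Action::PushedToChildren(_)));
}
```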
Here we change the adjusted_request_ordering to UnionExec's output ordering and // propagate the sort requirements down to correct the unnecessary descendant SortExec under the UnionExec - let req = (!parent_required.is_empty()) - .then(|| LexRequirement::new(parent_required.to_vec())); + let req = (!parent_required.is_empty()).then(|| parent_required.clone()); Ok(Some(vec![req; plan.children().len()])) } else if let Some(smj) = plan.as_any().downcast_ref::() { // If the current plan is SortMergeJoinExec let left_columns_len = smj.left().schema().fields().len(); - let parent_required_expr = - PhysicalSortRequirement::to_sort_exprs(parent_required.iter().cloned()); - match expr_source_side(&parent_required_expr, smj.join_type(), left_columns_len) { + let parent_required_expr = LexOrdering::from(parent_required.clone()); + match expr_source_side( + parent_required_expr.as_ref(), + smj.join_type(), + left_columns_len, + ) { Some(JoinSide::Left) => try_pushdown_requirements_to_join( smj, parent_required, - &parent_required_expr, + parent_required_expr.as_ref(), JoinSide::Left, ), Some(JoinSide::Right) => { @@ -243,12 +255,11 @@ fn pushdown_requirement_to_children( smj.schema().fields.len() - smj.right().schema().fields.len(); let new_right_required = shift_right_required(parent_required, right_offset)?; - let new_right_required_expr = - PhysicalSortRequirement::to_sort_exprs(new_right_required); + let new_right_required_expr = LexOrdering::from(new_right_required); try_pushdown_requirements_to_join( smj, parent_required, - &new_right_required_expr, + new_right_required_expr.as_ref(), JoinSide::Right, ) } @@ -270,14 +281,14 @@ fn pushdown_requirement_to_children( // Pushing down is not beneficial Ok(None) } else if is_sort_preserving_merge(plan) { - let new_ordering = - PhysicalSortRequirement::to_sort_exprs(parent_required.to_vec()); + let new_ordering = LexOrdering::from(parent_required.clone()); let mut spm_eqs = plan.equivalence_properties().clone(); // Sort preserving merge will have new ordering, one requirement above is pushed down to its below. spm_eqs = spm_eqs.with_reorder(new_ordering); // Do not push-down through SortPreservingMergeExec when // ordering requirement invalidates requirement of sort preserving merge exec. - if !spm_eqs.ordering_satisfy(plan.output_ordering().unwrap_or(&[])) { + if !spm_eqs.ordering_satisfy(&plan.output_ordering().cloned().unwrap_or_default()) + { Ok(None) } else { // Can push-down through SortPreservingMergeExec, because parent requirement is finer @@ -286,6 +297,8 @@ fn pushdown_requirement_to_children( .then(|| LexRequirement::new(parent_required.to_vec())); Ok(Some(vec![req])) } + } else if let Some(hash_join) = plan.as_any().downcast_ref::() { + handle_hash_join(hash_join, parent_required) } else { handle_custom_pushdown(plan, parent_required, maintains_input_order) } @@ -295,7 +308,7 @@ fn pushdown_requirement_to_children( /// Return true if pushing the sort requirements through a node would violate /// the input sorting requirements for the plan fn pushdown_would_violate_requirements( - parent_required: LexRequirementRef, + parent_required: &LexRequirement, child: &dyn ExecutionPlan, ) -> bool { child @@ -321,8 +334,8 @@ fn pushdown_would_violate_requirements( /// - If parent requirements are more specific, push down parent requirements. /// - If they are not compatible, need to add a sort. 
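The three outcomes listed above map onto `RequirementsCompatibility`. A toy version of the decision, using a plain prefix test where the real code consults the child plan's equivalence properties, so this is an approximation of the semantics rather than the actual check:

```rust
/// Sort requirement simplified to an ordered list of column names.
type Req = Vec<&'static str>;

enum Compat {
    /// The child's own requirement already covers the parent's: keep the child's.
    Satisfy,
    /// The parent's requirement is the more specific one: push the parent's down.
    Compatible(Req),
    /// Neither covers the other: a sort has to be added instead.
    NonCompatible,
}

/// Rough notion of "finer": `a` begins with `b`. The real check goes through the
/// child's equivalence properties, which also understand constants and
/// equivalent expressions, so this is only an approximation.
fn is_finer(a: &Req, b: &Req) -> bool {
    a.len() >= b.len() && a.iter().zip(b).all(|(x, y)| x == y)
}

fn determine_children_requirement(parent: &Req, child: &Req) -> Compat {
    if is_finer(child, parent) {
        Compat::Satisfy
    } else if is_finer(parent, child) {
        Compat::Compatible(parent.clone())
    } else {
        Compat::NonCompatible
    }
}

fn main() {
    // The child already asks for (a, b); the parent only needs (a): keep the child's.
    assert!(matches!(
        determine_children_requirement(&vec!["a"], &vec!["a", "b"]),
        Compat::Satisfy
    ));
    // The parent needs (a, b) but the child only asks for (a): push (a, b) down.
    assert!(matches!(
        determine_children_requirement(&vec!["a", "b"], &vec!["a"]),
        Compat::Compatible(_)
    ));
}
```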
fn determine_children_requirement( - parent_required: LexRequirementRef, - request_child: LexRequirementRef, + parent_required: &LexRequirement, + request_child: &LexRequirement, child_plan: &Arc, ) -> RequirementsCompatibility { if child_plan @@ -344,10 +357,11 @@ fn determine_children_requirement( RequirementsCompatibility::NonCompatible } } + fn try_pushdown_requirements_to_join( smj: &SortMergeJoinExec, - parent_required: LexRequirementRef, - sort_expr: &[PhysicalSortExpr], + parent_required: &LexRequirement, + sort_expr: &LexOrdering, push_side: JoinSide, ) -> Result>>> { let left_eq_properties = smj.left().equivalence_properties(); @@ -355,13 +369,13 @@ fn try_pushdown_requirements_to_join( let mut smj_required_orderings = smj.required_input_ordering(); let right_requirement = smj_required_orderings.swap_remove(1); let left_requirement = smj_required_orderings.swap_remove(0); - let left_ordering = smj.left().output_ordering().unwrap_or(&[]); - let right_ordering = smj.right().output_ordering().unwrap_or(&[]); + let left_ordering = &smj.left().output_ordering().cloned().unwrap_or_default(); + let right_ordering = &smj.right().output_ordering().cloned().unwrap_or_default(); + let (new_left_ordering, new_right_ordering) = match push_side { JoinSide::Left => { - let left_eq_properties = left_eq_properties - .clone() - .with_reorder(Vec::from(sort_expr)); + let left_eq_properties = + left_eq_properties.clone().with_reorder(sort_expr.clone()); if left_eq_properties .ordering_satisfy_requirement(&left_requirement.unwrap_or_default()) { @@ -372,9 +386,8 @@ fn try_pushdown_requirements_to_join( } } JoinSide::Right => { - let right_eq_properties = right_eq_properties - .clone() - .with_reorder(Vec::from(sort_expr)); + let right_eq_properties = + right_eq_properties.clone().with_reorder(sort_expr.clone()); if right_eq_properties .ordering_satisfy_requirement(&right_requirement.unwrap_or_default()) { @@ -384,6 +397,7 @@ fn try_pushdown_requirements_to_join( return Ok(None); } } + JoinSide::None => return Ok(None), }; let join_type = smj.join_type(); let probe_side = SortMergeJoinExec::probe_side(&join_type); @@ -402,7 +416,7 @@ fn try_pushdown_requirements_to_join( let should_pushdown = smj_eqs.ordering_satisfy_requirement(parent_required); Ok(should_pushdown.then(|| { let mut required_input_ordering = smj.required_input_ordering(); - let new_req = Some(PhysicalSortRequirement::from_sort_exprs(sort_expr)); + let new_req = Some(LexRequirement::from(sort_expr.clone())); match push_side { JoinSide::Left => { required_input_ordering[0] = new_req; @@ -410,18 +424,23 @@ fn try_pushdown_requirements_to_join( JoinSide::Right => { required_input_ordering[1] = new_req; } + JoinSide::None => unreachable!(), } required_input_ordering })) } fn expr_source_side( - required_exprs: &[PhysicalSortExpr], + required_exprs: &LexOrdering, join_type: JoinType, left_columns_len: usize, ) -> Option { match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftMark => { let all_column_sides = required_exprs .iter() .filter_map(|r| { @@ -464,7 +483,7 @@ fn expr_source_side( } fn shift_right_required( - parent_required: LexRequirementRef, + parent_required: &LexRequirement, left_columns_len: usize, ) -> Result { let new_right_required = parent_required @@ -502,7 +521,7 @@ fn shift_right_required( /// pushed down, `Ok(None)` if not. On error, returns a `Result::Err`. 
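When the requirement can only be served by the right side of a `SortMergeJoinExec`, `shift_right_required` re-indexes it into the right child's schema. A self-contained illustration with a toy requirement type (the real function works on `PhysicalSortRequirement`s and reports a plan error rather than returning `None` when a left-side column shows up):

```rust
/// Toy sort requirement: just the column index in the referencing schema.
#[derive(Clone, Debug, PartialEq)]
struct ColumnReq {
    index: usize,
}

/// Re-express requirements stated against the join's output schema in terms of
/// the right child's schema: every index must fall on the right side and gets
/// shifted down by the number of left-side columns.
fn shift_right_required(
    parent_required: &[ColumnReq],
    left_columns_len: usize,
) -> Option<Vec<ColumnReq>> {
    parent_required
        .iter()
        .map(|req| {
            (req.index >= left_columns_len).then(|| ColumnReq {
                index: req.index - left_columns_len,
            })
        })
        .collect()
}

fn main() {
    // Join output: two columns from the left child, so the right child starts at 2.
    assert_eq!(
        shift_right_required(&[ColumnReq { index: 2 }], 2),
        Some(vec![ColumnReq { index: 0 }])
    );
    // A requirement on a left-side column cannot be handed to the right child.
    assert_eq!(shift_right_required(&[ColumnReq { index: 1 }], 2), None);
}
```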
fn handle_custom_pushdown( plan: &Arc, - parent_required: LexRequirementRef, + parent_required: &LexRequirement, maintains_input_order: Vec, ) -> Result>>> { // If there's no requirement from the parent or the plan has no children, return early @@ -557,9 +576,7 @@ fn handle_custom_pushdown( .iter() .map(|req| { let child_schema = plan.children()[maintained_child_idx].schema(); - let updated_columns = req - .expr - .clone() + let updated_columns = Arc::clone(&req.expr) .transform_up(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { let new_index = col.index() - sub_offset; @@ -594,6 +611,102 @@ fn handle_custom_pushdown( } } +// For hash join we only maintain the input order for the right child +// for join type: Inner, Right, RightSemi, RightAnti +fn handle_hash_join( + plan: &HashJoinExec, + parent_required: &LexRequirement, +) -> Result>>> { + // If there's no requirement from the parent or the plan has no children + // or the join type is not Inner, Right, RightSemi, RightAnti, return early + if parent_required.is_empty() || !plan.maintains_input_order()[1] { + return Ok(None); + } + + // Collect all unique column indices used in the parent-required sorting expression + let all_indices: HashSet = parent_required + .iter() + .flat_map(|order| { + collect_columns(&order.expr) + .into_iter() + .map(|col| col.index()) + .collect::>() + }) + .collect(); + + let column_indices = build_join_column_index(plan); + let projected_indices: Vec<_> = if let Some(projection) = &plan.projection { + projection.iter().map(|&i| &column_indices[i]).collect() + } else { + column_indices.iter().collect() + }; + let len_of_left_fields = projected_indices + .iter() + .filter(|ci| ci.side == JoinSide::Left) + .count(); + + let all_from_right_child = all_indices.iter().all(|i| *i >= len_of_left_fields); + + // If all columns are from the right child, update the parent requirements + if all_from_right_child { + // Transform the parent-required expression for the child schema by adjusting columns + let updated_parent_req = parent_required + .iter() + .map(|req| { + let child_schema = plan.children()[1].schema(); + let updated_columns = Arc::clone(&req.expr) + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let index = projected_indices[col.index()].index; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(index).name(), + index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? 
+ .data; + Ok(PhysicalSortRequirement::new(updated_columns, req.options)) + }) + .collect::>>()?; + + // Populating with the updated requirements for children that maintain order + Ok(Some(vec![ + None, + Some(LexRequirement::new(updated_parent_req)), + ])) + } else { + Ok(None) + } +} + +// this function is used to build the column index for the hash join +// push down sort requirements to the right child +fn build_join_column_index(plan: &HashJoinExec) -> Vec { + let map_fields = |schema: SchemaRef, side: JoinSide| { + schema + .fields() + .iter() + .enumerate() + .map(|(index, _)| ColumnIndex { index, side }) + .collect::>() + }; + + match plan.join_type() { + JoinType::Inner | JoinType::Right => { + map_fields(plan.left().schema(), JoinSide::Left) + .into_iter() + .chain(map_fields(plan.right().schema(), JoinSide::Right)) + .collect::>() + } + JoinType::RightSemi | JoinType::RightAnti => { + map_fields(plan.right().schema(), JoinSide::Right) + } + _ => unreachable!("unexpected join type: {}", plan.join_type()), + } +} + /// Define the Requirements Compatibility #[derive(Debug)] enum RequirementsCompatibility { diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 98f1a7c21a39b..aba24309b2a07 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -17,270 +17,17 @@ //! Collection of testing utility functions that are leveraged by the query optimizer rules -use std::any::Any; -use std::fmt::Formatter; +#![allow(missing_docs)] + use std::sync::Arc; use crate::datasource::listing::PartitionedFile; use crate::datasource::physical_plan::{FileScanConfig, ParquetExec}; -use crate::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; -use crate::error::Result; -use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; -use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; -use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; -use crate::physical_plan::filter::FilterExec; -use crate::physical_plan::joins::utils::{JoinFilter, JoinOn}; -use crate::physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; -use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; -use crate::physical_plan::memory::MemoryExec; -use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use crate::physical_plan::union::UnionExec; -use crate::physical_plan::windows::create_window_expr; -use crate::physical_plan::{ExecutionPlan, InputOrderMode, Partitioning}; -use crate::prelude::{CsvReadOptions, SessionContext}; +use crate::physical_plan::ExecutionPlan; -use arrow_schema::{Schema, SchemaRef, SortOptions}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::JoinType; +use arrow_schema::SchemaRef; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; -use datafusion_functions_aggregate::count::count_udaf; -use datafusion_physical_expr::expressions::col; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use datafusion_physical_plan::tree_node::PlanContext; -use datafusion_physical_plan::{ - displayable, DisplayAs, DisplayFormatType, PlanProperties, -}; - -use async_trait::async_trait; -use 
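The new `handle_hash_join` above only pushes an ordering requirement through when the hash join keeps the right input's order (`maintains_input_order()[1]`) and every column the requirement touches comes from the right side, after accounting for any output projection; the column indices are then rewritten into the right child's schema. A simplified model of that mapping (toy types, no projection handling):

```rust
#[derive(Clone, Copy, PartialEq)]
enum Side {
    Left,
    Right,
}

/// For each column of the join output: which input it comes from and its index
/// there. A simplified version of the `ColumnIndex` list built by
/// `build_join_column_index`.
struct JoinColumn {
    side: Side,
    input_index: usize,
}

/// If every column the parent wants to sort on comes from the right input,
/// translate the output indices into right-child indices; otherwise the
/// requirement cannot be pushed through the hash join.
fn push_sort_to_right_child(
    output_columns: &[JoinColumn],
    required_output_indices: &[usize],
) -> Option<Vec<usize>> {
    required_output_indices
        .iter()
        .map(|&i| {
            let col = &output_columns[i];
            (col.side == Side::Right).then_some(col.input_index)
        })
        .collect()
}

fn main() {
    // Inner join output = 2 left columns followed by 2 right columns.
    let cols = [
        JoinColumn { side: Side::Left, input_index: 0 },
        JoinColumn { side: Side::Left, input_index: 1 },
        JoinColumn { side: Side::Right, input_index: 0 },
        JoinColumn { side: Side::Right, input_index: 1 },
    ];
    // ORDER BY the fourth output column maps to index 1 of the right child.
    assert_eq!(push_sort_to_right_child(&cols, &[3]), Some(vec![1]));
    // Anything touching a left-side column blocks the pushdown.
    assert_eq!(push_sort_to_right_child(&cols, &[1, 3]), None);
}
```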
datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr_common::sort_expr::{ - LexRequirement, PhysicalSortRequirement, -}; - -async fn register_current_csv( - ctx: &SessionContext, - table_name: &str, - infinite: bool, -) -> Result<()> { - let testdata = crate::test_util::arrow_test_data(); - let schema = crate::test_util::aggr_test_schema(); - let path = format!("{testdata}/csv/aggregate_test_100.csv"); - - match infinite { - true => { - let source = FileStreamProvider::new_file(schema, path.into()); - let config = StreamConfig::new(Arc::new(source)); - ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; - } - false => { - ctx.register_csv(table_name, &path, CsvReadOptions::new().schema(&schema)) - .await?; - } - } - - Ok(()) -} - -#[derive(Eq, PartialEq, Debug)] -pub enum SourceType { - Unbounded, - Bounded, -} - -#[async_trait] -pub trait SqlTestCase { - async fn register_table(&self, ctx: &SessionContext) -> Result<()>; - fn expect_fail(&self) -> bool; -} - -/// [UnaryTestCase] is designed for single input [ExecutionPlan]s. -pub struct UnaryTestCase { - pub(crate) source_type: SourceType, - pub(crate) expect_fail: bool, -} - -#[async_trait] -impl SqlTestCase for UnaryTestCase { - async fn register_table(&self, ctx: &SessionContext) -> Result<()> { - let table_is_infinite = self.source_type == SourceType::Unbounded; - register_current_csv(ctx, "test", table_is_infinite).await?; - Ok(()) - } - - fn expect_fail(&self) -> bool { - self.expect_fail - } -} -/// [BinaryTestCase] is designed for binary input [ExecutionPlan]s. -pub struct BinaryTestCase { - pub(crate) source_types: (SourceType, SourceType), - pub(crate) expect_fail: bool, -} - -#[async_trait] -impl SqlTestCase for BinaryTestCase { - async fn register_table(&self, ctx: &SessionContext) -> Result<()> { - let left_table_is_infinite = self.source_types.0 == SourceType::Unbounded; - let right_table_is_infinite = self.source_types.1 == SourceType::Unbounded; - register_current_csv(ctx, "left", left_table_is_infinite).await?; - register_current_csv(ctx, "right", right_table_is_infinite).await?; - Ok(()) - } - - fn expect_fail(&self) -> bool { - self.expect_fail - } -} - -pub struct QueryCase { - pub(crate) sql: String, - pub(crate) cases: Vec>, - pub(crate) error_operator: String, -} - -impl QueryCase { - /// Run the test cases - pub(crate) async fn run(&self) -> Result<()> { - for case in &self.cases { - let ctx = SessionContext::new(); - case.register_table(&ctx).await?; - let error = if case.expect_fail() { - Some(&self.error_operator) - } else { - None - }; - self.run_case(ctx, error).await?; - } - Ok(()) - } - async fn run_case(&self, ctx: SessionContext, error: Option<&String>) -> Result<()> { - let dataframe = ctx.sql(self.sql.as_str()).await?; - let plan = dataframe.create_physical_plan().await; - if let Some(error) = error { - let plan_error = plan.unwrap_err(); - assert!( - plan_error.to_string().contains(error.as_str()), - "plan_error: {:?} doesn't contain message: {:?}", - plan_error, - error.as_str() - ); - } else { - assert!(plan.is_ok()) - } - Ok(()) - } -} - -pub fn sort_merge_join_exec( - left: Arc, - right: Arc, - join_on: &JoinOn, - join_type: &JoinType, -) -> Arc { - Arc::new( - SortMergeJoinExec::try_new( - left, - right, - join_on.clone(), - None, - *join_type, - vec![SortOptions::default(); join_on.len()], - false, - ) - .unwrap(), - ) -} - -/// make PhysicalSortExpr with default options -pub fn sort_expr(name: &str, schema: &Schema) -> 
PhysicalSortExpr { - sort_expr_options(name, schema, SortOptions::default()) -} - -/// PhysicalSortExpr with specified options -pub fn sort_expr_options( - name: &str, - schema: &Schema, - options: SortOptions, -) -> PhysicalSortExpr { - PhysicalSortExpr { - expr: col(name, schema).unwrap(), - options, - } -} - -pub fn coalesce_partitions_exec(input: Arc) -> Arc { - Arc::new(CoalescePartitionsExec::new(input)) -} - -pub(crate) fn memory_exec(schema: &SchemaRef) -> Arc { - Arc::new(MemoryExec::try_new(&[vec![]], schema.clone(), None).unwrap()) -} - -pub fn hash_join_exec( - left: Arc, - right: Arc, - on: JoinOn, - filter: Option, - join_type: &JoinType, -) -> Result> { - Ok(Arc::new(HashJoinExec::try_new( - left, - right, - on, - filter, - join_type, - None, - PartitionMode::Partitioned, - true, - )?)) -} - -pub fn bounded_window_exec( - col_name: &str, - sort_exprs: impl IntoIterator, - input: Arc, -) -> Arc { - let sort_exprs: Vec<_> = sort_exprs.into_iter().collect(); - let schema = input.schema(); - - Arc::new( - crate::physical_plan::windows::BoundedWindowAggExec::try_new( - vec![create_window_expr( - &WindowFunctionDefinition::AggregateUDF(count_udaf()), - "count".to_owned(), - &[col(col_name, &schema).unwrap()], - &[], - &sort_exprs, - Arc::new(WindowFrame::new(Some(false))), - schema.as_ref(), - false, - ) - .unwrap()], - input.clone(), - vec![], - InputOrderMode::Sorted, - ) - .unwrap(), - ) -} - -pub fn filter_exec( - predicate: Arc, - input: Arc, -) -> Arc { - Arc::new(FilterExec::try_new(predicate, input).unwrap()) -} - -pub fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, - input: Arc, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) -} +use datafusion_physical_expr::PhysicalSortExpr; /// Create a non sorted parquet exec pub fn parquet_exec(schema: &SchemaRef) -> Arc { @@ -305,175 +52,3 @@ pub fn parquet_exec_sorted( ) .build_arc() } - -pub fn union_exec(input: Vec>) -> Arc { - Arc::new(UnionExec::new(input)) -} - -pub fn limit_exec(input: Arc) -> Arc { - global_limit_exec(local_limit_exec(input)) -} - -pub fn local_limit_exec(input: Arc) -> Arc { - Arc::new(LocalLimitExec::new(input, 100)) -} - -pub fn global_limit_exec(input: Arc) -> Arc { - Arc::new(GlobalLimitExec::new(input, 0, Some(100))) -} - -pub fn repartition_exec(input: Arc) -> Arc { - Arc::new(RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)).unwrap()) -} - -pub fn spr_repartition_exec(input: Arc) -> Arc { - Arc::new( - RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)) - .unwrap() - .with_preserve_order(), - ) -} - -pub fn aggregate_exec(input: Arc) -> Arc { - let schema = input.schema(); - Arc::new( - AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![], - vec![], - input, - schema, - ) - .unwrap(), - ) -} - -pub fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 128)) -} - -pub fn sort_exec( - sort_exprs: impl IntoIterator, - input: Arc, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortExec::new(sort_exprs, input)) -} - -/// A test [`ExecutionPlan`] whose requirements can be configured. 
-#[derive(Debug)] -pub struct RequirementsTestExec { - required_input_ordering: Vec, - maintains_input_order: bool, - input: Arc, -} - -impl RequirementsTestExec { - pub fn new(input: Arc) -> Self { - Self { - required_input_ordering: vec![], - maintains_input_order: true, - input, - } - } - - /// sets the required input ordering - pub fn with_required_input_ordering( - mut self, - required_input_ordering: Vec, - ) -> Self { - self.required_input_ordering = required_input_ordering; - self - } - - /// set the maintains_input_order flag - pub fn with_maintains_input_order(mut self, maintains_input_order: bool) -> Self { - self.maintains_input_order = maintains_input_order; - self - } - - /// returns this ExecutionPlan as an Arc - pub fn into_arc(self) -> Arc { - Arc::new(self) - } -} - -impl DisplayAs for RequirementsTestExec { - fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - write!(f, "RequiredInputOrderingExec") - } -} - -impl ExecutionPlan for RequirementsTestExec { - fn name(&self) -> &str { - "RequiredInputOrderingExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.input.properties() - } - - fn required_input_ordering(&self) -> Vec> { - let requirement = - PhysicalSortRequirement::from_sort_exprs(&self.required_input_ordering); - vec![Some(requirement)] - } - - fn maintains_input_order(&self) -> Vec { - vec![self.maintains_input_order] - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - assert_eq!(children.len(), 1); - Ok(RequirementsTestExec::new(children[0].clone()) - .with_required_input_ordering(self.required_input_ordering.clone()) - .with_maintains_input_order(self.maintains_input_order) - .into_arc()) - } - - fn execute( - &self, - _partition: usize, - _context: Arc, - ) -> Result { - unimplemented!("Test exec does not support execution") - } -} - -/// A [`PlanContext`] object is susceptible to being left in an inconsistent state after -/// untested mutable operations. It is crucial that there be no discrepancies between a plan -/// associated with the root node and the plan generated after traversing all nodes -/// within the [`PlanContext`] tree. In addition to verifying the plans resulting from optimizer -/// rules, it is essential to ensure that the overall tree structure corresponds with the plans -/// contained within the node contexts. -/// TODO: Once [`ExecutionPlan`] implements [`PartialEq`], string comparisons should be -/// replaced with direct plan equality checks. 
-pub fn check_integrity(context: PlanContext) -> Result> { - context - .transform_up(|node| { - let children_plans = node.plan.children(); - assert_eq!(node.children.len(), children_plans.len()); - for (child_plan, child_node) in - children_plans.iter().zip(node.children.iter()) - { - assert_eq!( - displayable(child_plan.as_ref()).one_line().to_string(), - displayable(child_node.plan.as_ref()).one_line().to_string() - ); - } - Ok(Transformed::no(node)) - }) - .data() -} diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 2c0d042281e6f..9f2c28d564f05 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -27,7 +27,8 @@ use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; -use datafusion_physical_expr::{LexRequirement, PhysicalSortRequirement}; +use datafusion_physical_expr::LexRequirement; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::tree_node::PlanContext; @@ -38,14 +39,14 @@ pub fn add_sort_above( sort_requirements: LexRequirement, fetch: Option, ) -> PlanContext { - let mut sort_expr = PhysicalSortRequirement::to_sort_exprs(sort_requirements); + let mut sort_expr = LexOrdering::from(sort_requirements); sort_expr.retain(|sort_expr| { !node .plan .equivalence_properties() .is_expr_constant(&sort_expr.expr) }); - let mut new_sort = SortExec::new(sort_expr, node.plan.clone()).with_fetch(fetch); + let mut new_sort = SortExec::new(sort_expr, Arc::clone(&node.plan)).with_fetch(fetch); if node.plan.output_partitioning().partition_count() > 1 { new_sort = new_sort.with_preserve_partitioning(true); } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index dac0634316f3c..3ad820eb5263e 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -29,13 +29,12 @@ use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; use crate::logical_expr::{ - Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Window, + Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Values, Window, }; use crate::logical_expr::{ Expr, LogicalPlan, Partitioning as LogicalPartitioning, PlanType, Repartition, UserDefinedLogicalNode, }; -use crate::logical_expr::{Limit, Values}; use crate::physical_expr::{create_physical_expr, create_physical_exprs}; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; @@ -55,7 +54,6 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::unnest::UnnestExec; -use crate::physical_plan::values::ValuesExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ displayable, windows, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, @@ -78,8 +76,8 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, 
DmlStatement, Extension, Filter, RecursiveQuery, SortExpr, - StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, Extension, FetchType, Filter, JoinType, RecursiveQuery, + SkipType, SortExpr, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Literal; @@ -328,7 +326,7 @@ impl DefaultPhysicalPlanner { // Spawning tasks which will traverse leaf up to the root. let tasks = flat_tree_leaf_indices .into_iter() - .map(|index| self.task_helper(index, flat_tree.clone(), session_state)); + .map(|index| self.task_helper(index, Arc::clone(&flat_tree), session_state)); let mut outputs = futures::stream::iter(tasks) .buffer_unordered(max_concurrency) .try_collect::>() @@ -467,7 +465,8 @@ impl DefaultPhysicalPlanner { .collect::>>>() }) .collect::>>()?; - let value_exec = ValuesExec::try_new(SchemaRef::new(exec_schema), exprs)?; + let value_exec = + MemoryExec::try_new_as_values(SchemaRef::new(exec_schema), exprs)?; Arc::new(value_exec) } LogicalPlan::EmptyRelation(EmptyRelation { @@ -487,7 +486,7 @@ impl DefaultPhysicalPlanner { output_schema, }) => { let output_schema: Schema = output_schema.as_ref().into(); - self.plan_describe(schema.clone(), Arc::new(output_schema))? + self.plan_describe(Arc::clone(schema), Arc::new(output_schema))? } // 1 Child @@ -522,6 +521,9 @@ impl DefaultPhysicalPlanner { return Err(DataFusionError::Configuration(format!("provided value for 'execution.keep_partition_by_columns' was not recognized: \"{}\"", value))), }; + let sink_format = file_type_to_format(file_type)? + .create(session_state, source_option_tuples)?; + // Set file sink related options let config = FileSinkConfig { object_store_url, @@ -531,11 +533,9 @@ impl DefaultPhysicalPlanner { table_partition_cols, insert_op: InsertOp::Append, keep_partition_by_columns, + file_extension: sink_format.get_ext(), }; - let sink_format = file_type_to_format(file_type)? - .create(session_state, source_option_tuples)?; - sink_format .create_writer_physical_plan(input_exec, session_state, config, None) .await? @@ -650,15 +650,51 @@ impl DefaultPhysicalPlanner { aggr_expr, .. 
}) => { + let options = session_state.config().options(); // Initially need to perform the aggregate and then merge the partitions let input_exec = children.one()?; let physical_input_schema = input_exec.schema(); let logical_input_schema = input.as_ref().schema(); - let physical_input_schema_from_logical: Arc = - logical_input_schema.as_ref().clone().into(); + let physical_input_schema_from_logical = logical_input_schema.inner(); - if physical_input_schema != physical_input_schema_from_logical { - return internal_err!("Physical input schema should be the same as the one converted from logical input schema."); + if !options.execution.skip_physical_aggregate_schema_check + && &physical_input_schema != physical_input_schema_from_logical + { + let mut differences = Vec::new(); + if physical_input_schema.fields().len() + != physical_input_schema_from_logical.fields().len() + { + differences.push(format!( + "Different number of fields: (physical) {} vs (logical) {}", + physical_input_schema.fields().len(), + physical_input_schema_from_logical.fields().len() + )); + } + for (i, (physical_field, logical_field)) in physical_input_schema + .fields() + .iter() + .zip(physical_input_schema_from_logical.fields()) + .enumerate() + { + if physical_field.name() != logical_field.name() { + differences.push(format!( + "field name at index {}: (physical) {} vs (logical) {}", + i, + physical_field.name(), + logical_field.name() + )); + } + if physical_field.data_type() != logical_field.data_type() { + differences.push(format!("field data type at index {} [{}]: (physical) {} vs (logical) {}", i, physical_field.name(), physical_field.data_type(), logical_field.data_type())); + } + if physical_field.is_nullable() != logical_field.is_nullable() { + differences.push(format!("field nullability at index {} [{}]: (physical) {} vs (logical) {}", i, physical_field.name(), physical_field.is_nullable(), logical_field.is_nullable())); + } + } + return internal_err!("Physical input schema should be the same as the one converted from logical input schema. Differences: {}", differences + .iter() + .map(|s| format!("\n\t- {}", s)) + .join("")); } let groups = self.create_grouping_physical_expr( @@ -689,7 +725,7 @@ impl DefaultPhysicalPlanner { aggregates, filters.clone(), input_exec, - physical_input_schema.clone(), + Arc::clone(&physical_input_schema), )?); let can_repartition = !groups.is_empty() @@ -720,7 +756,7 @@ impl DefaultPhysicalPlanner { updated_aggregates, filters, initial_aggr, - physical_input_schema.clone(), + Arc::clone(&physical_input_schema), )?) } LogicalPlan::Projection(Projection { input, expr, .. }) => self @@ -796,8 +832,20 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Subquery(_) => todo!(), LogicalPlan::SubqueryAlias(_) => children.one()?, - LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + LogicalPlan::Limit(limit) => { let input = children.one()?; + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return not_impl_err!( + "Unsupported OFFSET expression: {:?}", + limit.skip + ); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return not_impl_err!( + "Unsupported LIMIT expression: {:?}", + limit.fetch + ); + }; // GlobalLimitExec requires a single partition for input let input = if input.output_partitioning().partition_count() == 1 { @@ -806,13 +854,13 @@ impl DefaultPhysicalPlanner { // Apply a LocalLimitExec to each partition. 
The optimizer will also insert // a CoalescePartitionsExec between the GlobalLimitExec and LocalLimitExec if let Some(fetch) = fetch { - Arc::new(LocalLimitExec::new(input, *fetch + skip)) + Arc::new(LocalLimitExec::new(input, fetch + skip)) } else { input } }; - Arc::new(GlobalLimitExec::new(input, *skip, *fetch)) + Arc::new(GlobalLimitExec::new(input, skip, fetch)) } LogicalPlan::Unnest(Unnest { list_type_columns, @@ -880,8 +928,8 @@ impl DefaultPhysicalPlanner { let right = Arc::new(right); let new_join = LogicalPlan::Join(Join::try_new_with_project_input( node, - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), column_on, )?); @@ -1014,14 +1062,21 @@ impl DefaultPhysicalPlanner { }) .collect(); + let metadata: HashMap<_, _> = left_df_schema + .metadata() + .clone() + .into_iter() + .chain(right_df_schema.metadata().clone()) + .collect(); + // Construct intermediate schemas used for filtering data and // convert logical expression to physical according to filter schema let filter_df_schema = DFSchema::new_with_metadata( filter_df_fields, - HashMap::new(), + metadata.clone(), )?; let filter_schema = - Schema::new_with_metadata(filter_fields, HashMap::new()); + Schema::new_with_metadata(filter_fields, metadata); let filter_expr = create_physical_expr( expr, &filter_df_schema, @@ -1045,14 +1100,19 @@ impl DefaultPhysicalPlanner { session_state.config_options().optimizer.prefer_hash_join; let join: Arc = if join_on.is_empty() { - // there is no equal join condition, use the nested loop join - // TODO optimize the plan, and use the config of `target_partitions` and `repartition_joins` - Arc::new(NestedLoopJoinExec::try_new( - physical_left, - physical_right, - join_filter, - join_type, - )?) + if join_filter.is_none() && matches!(join_type, JoinType::Inner) { + // cross join if there is no join conditions and no join filter set + Arc::new(CrossJoinExec::new(physical_left, physical_right)) + } else { + // there is no equal join condition, use the nested loop join + Arc::new(NestedLoopJoinExec::try_new( + physical_left, + physical_right, + join_filter, + join_type, + None, + )?) + } } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() && !prefer_hash_join @@ -1112,10 +1172,6 @@ impl DefaultPhysicalPlanner { join } } - LogicalPlan::CrossJoin(_) => { - let [left, right] = children.two()?; - Arc::new(CrossJoinExec::new(left, right)) - } LogicalPlan::RecursiveQuery(RecursiveQuery { name, is_distinct, .. 
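In the limit hunk above, `skip`/`fetch` now come from `SkipType::Literal`/`FetchType::Literal`, and when the input has multiple partitions each partition gets a `LocalLimitExec` with `fetch + skip` while the single-partition `GlobalLimitExec` applies the actual `skip` and `fetch`. The toy model below shows why `fetch + skip` rows per partition are enough; it uses plain vectors, not DataFusion's operators:

```rust
/// Keep at most `n` rows of one partition, like `LocalLimitExec`.
fn local_limit(partition: &[i64], n: usize) -> Vec<i64> {
    partition.iter().copied().take(n).collect()
}

/// Skip `skip` rows of the merged stream, then keep at most `fetch`,
/// like `GlobalLimitExec` over a single partition.
fn global_limit(rows: impl Iterator<Item = i64>, skip: usize, fetch: Option<usize>) -> Vec<i64> {
    let rows = rows.skip(skip);
    match fetch {
        Some(fetch) => rows.take(fetch).collect(),
        None => rows.collect(),
    }
}

fn main() {
    let (skip, fetch) = (2usize, 3usize);
    let partitions = vec![vec![1, 2, 3, 4, 5, 6], vec![10, 20, 30, 40]];

    // Each partition only needs `fetch + skip` rows: even if every skipped row
    // happens to come from one partition, that partition still has `fetch` left.
    let merged: Vec<i64> = partitions
        .iter()
        .flat_map(|p| local_limit(p, fetch + skip))
        .collect();

    let result = global_limit(merged.into_iter(), skip, Some(fetch));
    assert_eq!(result.len(), fetch);
}
```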
}) => { @@ -1175,12 +1231,6 @@ impl DefaultPhysicalPlanner { let name = statement.name(); return not_impl_err!("Unsupported logical plan: Statement({name})"); } - LogicalPlan::Prepare(_) => { - // There is no default plan for "PREPARE" -- it must be - // handled at a higher level (so that the appropriate - // statement can be prepared) - return not_impl_err!("Unsupported logical plan: Prepare"); - } LogicalPlan::Dml(dml) => { // DataFusion is a read-only query engine, but also a library, so consumers may implement this return not_impl_err!("Unsupported logical plan: Dml({0})", dml.op); @@ -1498,7 +1548,7 @@ pub fn create_window_expr_with_name( name, &physical_args, &partition_by, - &order_by, + order_by.as_ref(), window_frame, physical_schema, ignore_nulls, @@ -1523,11 +1573,11 @@ pub fn create_window_expr( } type AggregateExprWithOptionalArgs = ( - AggregateFunctionExpr, + Arc, // The filter clause, if any Option>, // Ordering requirements, if any - Option>, + Option, ); /// Create an aggregate expression with a name from a logical expression @@ -1577,17 +1627,18 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( None => None, }; - let ordering_reqs: Vec = - physical_sort_exprs.clone().unwrap_or(vec![]); + let ordering_reqs: LexOrdering = + physical_sort_exprs.clone().unwrap_or_default(); let agg_expr = AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec()) - .order_by(ordering_reqs.to_vec()) + .order_by(ordering_reqs) .schema(Arc::new(physical_input_schema.to_owned())) .alias(name) .with_ignore_nulls(ignore_nulls) .with_distinct(*distinct) - .build()?; + .build() + .map(Arc::new)?; (agg_expr, filter, physical_sort_exprs) }; @@ -1650,7 +1701,7 @@ pub fn create_physical_sort_exprs( exprs .iter() .map(|expr| create_physical_sort_expr(expr, input_dfschema, execution_props)) - .collect::>>() + .collect::>() } impl DefaultPhysicalPlanner { @@ -1773,8 +1824,12 @@ impl DefaultPhysicalPlanner { Err(e) => return Err(e), } } - Err(e) => stringified_plans - .push(StringifiedPlan::new(InitialPhysicalPlan, e.to_string())), + Err(err) => { + stringified_plans.push(StringifiedPlan::new( + PhysicalPlanError, + err.strip_backtrace(), + )); + } } } @@ -1964,7 +2019,7 @@ mod tests { use crate::datasource::file_format::options::CsvReadOptions; use crate::datasource::MemTable; use crate::physical_plan::{ - expressions, DisplayAs, DisplayFormatType, ExecutionMode, PlanProperties, + expressions, DisplayAs, DisplayFormatType, PlanProperties, SendableRecordBatchStream, }; use crate::prelude::{SessionConfig, SessionContext}; @@ -1979,6 +2034,7 @@ mod tests { use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore}; use datafusion_functions_aggregate::expr_fn::sum; use datafusion_physical_expr::EquivalenceProperties; + use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; fn make_session_state() -> SessionState { let runtime = Arc::new(RuntimeEnv::default()); @@ -2566,13 +2622,11 @@ mod tests { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
fn compute_properties(schema: SchemaRef) -> PlanProperties { - let eq_properties = EquivalenceProperties::new(schema); PlanProperties::new( - eq_properties, - // Output Partitioning + EquivalenceProperties::new(schema), Partitioning::UnknownPartitioning(1), - // Execution Mode - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ) } } diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 08740daa0c8e7..e91785c7421aa 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -17,7 +17,8 @@ //! Common unit test utility methods -use std::any::Any; +#![allow(missing_docs)] + use std::fs::File; use std::io::prelude::*; use std::io::{BufReader, BufWriter}; @@ -40,13 +41,10 @@ use crate::test_util::{aggr_test_schema, arrow_test_data}; use arrow::array::{self, Array, ArrayRef, Decimal128Builder, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, Statistics}; +use datafusion_common::DataFusionError; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalSortExpr}; +use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; -use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionMode, PlanProperties, -}; #[cfg(feature = "compression")] use bzip2::write::BzEncoder; @@ -69,7 +67,7 @@ pub fn create_table_dual() -> Arc { let batch = RecordBatch::try_new( dual_schema.clone(), vec![ - Arc::new(array::Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![1])), Arc::new(array::StringArray::from(vec!["a"])), ], ) @@ -105,6 +103,28 @@ pub fn scan_partitioned_csv(partitions: usize, work_dir: &Path) -> Result(BzEncoder); + +#[cfg(feature = "compression")] +impl Write for AutoFinishBzEncoder { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.0.flush() + } +} + +#[cfg(feature = "compression")] +impl Drop for AutoFinishBzEncoder { + fn drop(&mut self) { + let _ = self.0.try_finish(); + } +} + /// Returns file groups [`Vec>`] for scanning `partitions` of `filename` pub fn partitioned_file_groups( path: &str, @@ -146,9 +166,10 @@ pub fn partitioned_file_groups( Box::new(encoder) } #[cfg(feature = "compression")] - FileCompressionType::BZIP2 => { - Box::new(BzEncoder::new(file, BzCompression::default())) - } + FileCompressionType::BZIP2 => Box::new(AutoFinishBzEncoder(BzEncoder::new( + file, + BzCompression::default(), + ))), #[cfg(not(feature = "compression"))] FileCompressionType::GZIP | FileCompressionType::BZIP2 @@ -182,8 +203,8 @@ pub fn partitioned_file_groups( } } - // Must drop the stream before creating ObjectMeta below as drop - // triggers finish for ZstdEncoder which writes additional data + // Must drop the stream before creating ObjectMeta below as drop triggers + // finish for ZstdEncoder/BzEncoder which writes additional data for mut w in writers.into_iter() { w.flush().unwrap(); } @@ -360,96 +381,5 @@ pub fn csv_exec_ordered( ) } -/// A mock execution plan that simply returns the provided statistics -#[derive(Debug, Clone)] -pub struct StatisticsExec { - stats: Statistics, - schema: Arc, - cache: PlanProperties, -} - -impl StatisticsExec { - pub fn new(stats: Statistics, schema: Schema) -> Self { - assert_eq!( - stats.column_statistics.len(), 
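The new `AutoFinishBzEncoder` above exists so that dropping the writer still finishes the bz2 stream, in line with the updated comment that drop triggers finish for ZstdEncoder/BzEncoder and writes the trailing data. The same drop-to-finish pattern in miniature, with a fake encoder so the sketch runs on the standard library alone (the trailer bytes are made up for illustration):

```rust
use std::io::{self, Write};

/// Fake encoder that must append a trailer before its output is complete
/// (the role `BzEncoder::try_finish` plays in the helper above).
struct NeedsTrailer<W: Write> {
    inner: W,
    finished: bool,
}

impl<W: Write> NeedsTrailer<W> {
    fn new(inner: W) -> Self {
        Self { inner, finished: false }
    }

    fn finish(&mut self) -> io::Result<()> {
        if !self.finished {
            self.inner.write_all(b"<trailer>")?;
            self.finished = true;
        }
        Ok(())
    }
}

/// Wrapper that finishes the stream when it goes out of scope, mirroring
/// `AutoFinishBzEncoder`'s `Drop` impl.
struct AutoFinish<W: Write>(NeedsTrailer<W>);

impl<W: Write> Write for AutoFinish<W> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.0.inner.write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        self.0.inner.flush()
    }
}

impl<W: Write> Drop for AutoFinish<W> {
    fn drop(&mut self) {
        // Errors cannot escape `drop`, so this is best effort, like `try_finish`.
        let _ = self.0.finish();
    }
}

fn main() -> io::Result<()> {
    let mut out = Vec::new();
    {
        let mut writer = AutoFinish(NeedsTrailer::new(&mut out));
        writer.write_all(b"data")?;
    } // dropped here: the trailer lands without an explicit `finish()` call
    assert_eq!(out, b"data<trailer>".to_vec());
    Ok(())
}
```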
schema.fields().len(), - "if defined, the column statistics vector length should be the number of fields" - ); - let cache = Self::compute_properties(Arc::new(schema.clone())); - Self { - stats, - schema: Arc::new(schema), - cache, - } - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties(schema: SchemaRef) -> PlanProperties { - let eq_properties = EquivalenceProperties::new(schema); - PlanProperties::new( - eq_properties, - // Output Partitioning - Partitioning::UnknownPartitioning(2), - // Execution Mode - ExecutionMode::Bounded, - ) - } -} - -impl DisplayAs for StatisticsExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "StatisticsExec: col_count={}, row_count={:?}", - self.schema.fields().len(), - self.stats.num_rows, - ) - } - } - } -} - -impl ExecutionPlan for StatisticsExec { - fn name(&self) -> &'static str { - Self::static_name() - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - &self.cache - } - - fn children(&self) -> Vec<&Arc> { - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - fn execute( - &self, - _partition: usize, - _context: Arc, - ) -> Result { - unimplemented!("This plan only serves for testing statistics") - } - - fn statistics(&self) -> Result { - Ok(self.stats.clone()) - } -} - pub mod object_store; pub mod variable; diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index 6c0a2fc7bec47..cac430c5b49d4 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -14,7 +14,9 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + //! Object store implementation used for testing + use crate::execution::context::SessionState; use crate::execution::session_state::SessionStateBuilder; use crate::prelude::SessionContext; diff --git a/datafusion/core/src/test_util/csv.rs b/datafusion/core/src/test_util/csv.rs new file mode 100644 index 0000000000000..94c7efb954022 --- /dev/null +++ b/datafusion/core/src/test_util/csv.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Helpers for writing csv files and reading them back + +use std::fs::File; +use std::path::PathBuf; +use std::sync::Arc; + +use crate::arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use crate::error::Result; + +use arrow::csv::WriterBuilder; + +/// a CSV file that has been created for testing. +pub struct TestCsvFile { + path: PathBuf, + schema: SchemaRef, +} + +impl TestCsvFile { + /// Creates a new csv file at the specified location + pub fn try_new( + path: PathBuf, + batches: impl IntoIterator, + ) -> Result { + let file = File::create(&path).unwrap(); + let builder = WriterBuilder::new().with_header(true); + let mut writer = builder.build(file); + + let mut batches = batches.into_iter(); + let first_batch = batches.next().expect("need at least one record batch"); + let schema = first_batch.schema(); + + let mut num_rows = 0; + for batch in batches { + writer.write(&batch)?; + num_rows += batch.num_rows(); + } + + println!("Generated test dataset with {num_rows} rows"); + + Ok(Self { path, schema }) + } + + /// The schema of this csv file + pub fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + /// The path to the csv file + pub fn path(&self) -> &std::path::Path { + self.path.as_path() + } +} diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index e03c18fec7c4a..d608db25fe981 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -20,6 +20,8 @@ #[cfg(feature = "parquet")] pub mod parquet; +pub mod csv; + use std::any::Any; use std::collections::HashMap; use std::fs::File; @@ -34,25 +36,21 @@ use crate::dataframe::DataFrame; use crate::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use crate::datasource::{empty::EmptyTable, provider_as_source}; use crate::error::Result; -use crate::execution::context::TaskContext; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; -use crate::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, -}; +use crate::physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; use crate::prelude::{CsvReadOptions, SessionContext}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion_catalog::Session; use datafusion_common::TableReference; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; use datafusion_functions_aggregate::count::count_udaf; -use datafusion_physical_expr::{expressions, EquivalenceProperties, PhysicalExpr}; +use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; +use datafusion_physical_expr::{expressions, PhysicalExpr}; use async_trait::async_trait; -use datafusion_catalog::Session; -use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use futures::Stream; use tempfile::TempDir; // backwards compatibility @@ -209,7 +207,7 @@ impl TableProvider for TestTableProvider { } fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } fn table_type(&self) -> TableType { @@ -227,135 +225,6 @@ impl TableProvider for TestTableProvider { } } -/// A mock execution plan that simply returns the provided data source characteristic -#[derive(Debug, Clone)] -pub struct UnboundedExec { - batch_produce: Option, - batch: RecordBatch, - cache: PlanProperties, -} -impl UnboundedExec { - /// 
Create new exec that clones the given record batch to its output. - /// - /// Set `batch_produce` to `Some(n)` to emit exactly `n` batches per partition. - pub fn new( - batch_produce: Option, - batch: RecordBatch, - partitions: usize, - ) -> Self { - let cache = Self::compute_properties(batch.schema(), batch_produce, partitions); - Self { - batch_produce, - batch, - cache, - } - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - batch_produce: Option, - n_partitions: usize, - ) -> PlanProperties { - let eq_properties = EquivalenceProperties::new(schema); - let mode = if batch_produce.is_none() { - ExecutionMode::Unbounded - } else { - ExecutionMode::Bounded - }; - PlanProperties::new( - eq_properties, - Partitioning::UnknownPartitioning(n_partitions), - mode, - ) - } -} - -impl DisplayAs for UnboundedExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "UnboundableExec: unbounded={}", - self.batch_produce.is_none(), - ) - } - } - } -} - -impl ExecutionPlan for UnboundedExec { - fn name(&self) -> &'static str { - Self::static_name() - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - &self.cache - } - - fn children(&self) -> Vec<&Arc> { - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - fn execute( - &self, - _partition: usize, - _context: Arc, - ) -> Result { - Ok(Box::pin(UnboundedStream { - batch_produce: self.batch_produce, - count: 0, - batch: self.batch.clone(), - })) - } -} - -#[derive(Debug)] -struct UnboundedStream { - batch_produce: Option, - count: usize, - batch: RecordBatch, -} - -impl Stream for UnboundedStream { - type Item = Result; - - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - if let Some(val) = self.batch_produce { - if val <= self.count { - return Poll::Ready(None); - } - } - self.count += 1; - Poll::Ready(Some(Ok(self.batch.clone()))) - } -} - -impl RecordBatchStream for UnboundedStream { - fn schema(&self) -> SchemaRef { - self.batch.schema() - } -} - /// This function creates an unbounded sorted file for testing purposes. 
pub fn register_unbounded_file_with_ordering( ctx: &SessionContext, @@ -425,7 +294,7 @@ impl TestAggregate { /// Create a new COUNT(column) aggregate pub fn new_count_column(schema: &Arc) -> Self { - Self::ColumnA(schema.clone()) + Self::ColumnA(Arc::clone(schema)) } /// Return appropriate expr depending if COUNT is for col or table (*) diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 9f06ad9308ab8..685ed14777b40 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -87,7 +87,8 @@ impl TestParquetFile { let first_batch = batches.next().expect("need at least one record batch"); let schema = first_batch.schema(); - let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(props)).unwrap(); + let mut writer = + ArrowWriter::try_new(file, Arc::clone(&schema), Some(props)).unwrap(); writer.write(&first_batch).unwrap(); let mut num_rows = first_batch.num_rows(); @@ -102,7 +103,17 @@ impl TestParquetFile { let size = std::fs::metadata(&path)?.len() as usize; - let canonical_path = path.canonicalize()?; + let mut canonical_path = path.canonicalize()?; + + if cfg!(target_os = "windows") { + canonical_path = canonical_path + .to_str() + .unwrap() + .replace("\\", "/") + .strip_prefix("//?/") + .unwrap() + .into(); + }; let object_store_url = ListingTableUrl::parse(canonical_path.to_str().unwrap_or_default())? @@ -144,20 +155,21 @@ impl TestParquetFile { maybe_filter: Option, ) -> Result> { let scan_config = - FileScanConfig::new(self.object_store_url.clone(), self.schema.clone()) + FileScanConfig::new(self.object_store_url.clone(), Arc::clone(&self.schema)) .with_file(PartitionedFile { object_meta: self.object_meta.clone(), partition_values: vec![], range: None, statistics: None, extensions: None, + metadata_size_hint: None, }); - let df_schema = self.schema.clone().to_dfschema_ref()?; + let df_schema = Arc::clone(&self.schema).to_dfschema_ref()?; // run coercion on the filters to coerce types etc. 
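// (Flow in the lines that follow: build a `SimplifyContext` over the file's
// DFSchema, run the optional filter through `ExprSimplifier`, lower the result
// to a physical expression, hand it to `ParquetExecBuilder::with_predicate` for
// pruning, and finally wrap the scan in a `FilterExec` that applies the same
// predicate row by row.)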
let props = ExecutionProps::new(); - let context = SimplifyContext::new(&props).with_schema(df_schema.clone()); + let context = SimplifyContext::new(&props).with_schema(Arc::clone(&df_schema)); let parquet_options = ctx.copied_table_options().parquet; if let Some(filter) = maybe_filter { let simplifier = ExprSimplifier::new(context); @@ -167,7 +179,7 @@ impl TestParquetFile { let parquet_exec = ParquetExecBuilder::new_with_options(scan_config, parquet_options) - .with_predicate(physical_filter_expr.clone()) + .with_predicate(Arc::clone(&physical_filter_expr)) .build_arc(); let exec = Arc::new(FilterExec::try_new(physical_filter_expr, parquet_exec)?); @@ -199,7 +211,7 @@ impl TestParquetFile { /// The schema of this parquet file pub fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } /// The path to the parquet file diff --git a/datafusion/core/tests/config_from_env.rs b/datafusion/core/tests/config_from_env.rs index a5a5a4524e609..976597c8a9ac5 100644 --- a/datafusion/core/tests/config_from_env.rs +++ b/datafusion/core/tests/config_from_env.rs @@ -22,10 +22,19 @@ use std::env; fn from_env() { // Note: these must be a single test to avoid interference from concurrent execution let env_key = "DATAFUSION_OPTIMIZER_FILTER_NULL_JOIN_KEYS"; - env::set_var(env_key, "true"); - let config = ConfigOptions::from_env().unwrap(); + // valid testing in different cases + for bool_option in ["true", "TRUE", "True", "tRUe"] { + env::set_var(env_key, bool_option); + let config = ConfigOptions::from_env().unwrap(); + env::remove_var(env_key); + assert!(config.optimizer.filter_null_join_keys); + } + + // invalid testing + env::set_var(env_key, "ttruee"); + let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); + assert_eq!(err, "Error parsing 'ttruee' as bool\ncaused by\nExternal error: provided string was not `true` or `false`"); env::remove_var(env_key); - assert!(config.optimizer.filter_null_join_keys); let env_key = "DATAFUSION_EXECUTION_BATCH_SIZE"; @@ -37,7 +46,7 @@ fn from_env() { // for invalid testing env::set_var(env_key, "abc"); let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); - assert_eq!(err, "Error parsing abc as usize\ncaused by\nExternal error: invalid digit found in string"); + assert_eq!(err, "Error parsing 'abc' as usize\ncaused by\nExternal error: invalid digit found in string"); env::remove_var(env_key); let config = ConfigOptions::from_env().unwrap(); diff --git a/datafusion/core/tests/core_integration.rs b/datafusion/core/tests/core_integration.rs index 79e5056e3cf5b..e0917e6cca198 100644 --- a/datafusion/core/tests/core_integration.rs +++ b/datafusion/core/tests/core_integration.rs @@ -24,6 +24,9 @@ mod dataframe; /// Run all tests that are found in the `macro_hygiene` directory mod macro_hygiene; +/// Run all tests that are found in the `execution` directory +mod execution; + /// Run all tests that are found in the `expr_api` directory mod expr_api; diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index e1bd14105e23e..aafefac04e321 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -35,15 +35,16 @@ use datafusion::physical_plan::{ RecordBatchStream, SendableRecordBatchStream, Statistics, }; use datafusion::scalar::ScalarValue; +use datafusion_catalog::Session; use datafusion_common::cast::as_primitive_array; use datafusion_common::project_schema; use datafusion_common::stats::Precision; 
use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion_physical_plan::{ExecutionMode, PlanProperties}; +use datafusion_physical_plan::PlanProperties; use async_trait::async_trait; -use datafusion_catalog::Session; use futures::stream::Stream; mod provider_filter_pushdown; @@ -91,12 +92,11 @@ impl CustomExecutionPlan { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties(schema: SchemaRef) -> PlanProperties { - let eq_properties = EquivalenceProperties::new(schema); PlanProperties::new( - eq_properties, - // Output Partitioning + EquivalenceProperties::new(schema), Partitioning::UnknownPartitioning(1), - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ) } } diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index ca390ada2b696..30e0e736c0619 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -28,19 +28,20 @@ use datafusion::execution::context::TaskContext; use datafusion::logical_expr::TableProviderFilterPushDown; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, - PlanProperties, SendableRecordBatchStream, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, Statistics, }; use datafusion::prelude::*; use datafusion::scalar::ScalarValue; +use datafusion_catalog::Session; use datafusion_common::cast::as_primitive_array; use datafusion_common::{internal_err, not_impl_err}; use datafusion_expr::expr::{BinaryExpr, Cast}; use datafusion_functions_aggregate::expr_fn::count; use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use async_trait::async_trait; -use datafusion_catalog::Session; fn create_batch(value: i32, num_rows: usize) -> Result { let mut builder = Int32Builder::with_capacity(num_rows); @@ -72,11 +73,11 @@ impl CustomPlan { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
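// Note: throughout this PR the single `ExecutionMode` argument of
// `PlanProperties::new` is replaced by the pair (`EmissionType`, `Boundedness`).
// The hunk below declares a bounded, incrementally emitting plan; a minimal
// sketch of the unbounded counterpart, assuming the `Boundedness::Unbounded`
// variant carries a `requires_infinite_memory` flag (an assumption, not shown
// in this diff), would look like:
//
//     PlanProperties::new(
//         EquivalenceProperties::new(schema),
//         Partitioning::UnknownPartitioning(1),
//         EmissionType::Incremental,
//         Boundedness::Unbounded { requires_infinite_memory: false },
//     )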
fn compute_properties(schema: SchemaRef) -> PlanProperties { - let eq_properties = EquivalenceProperties::new(schema); PlanProperties::new( - eq_properties, + EquivalenceProperties::new(schema), Partitioning::UnknownPartitioning(1), - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ) } } diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index 41d182a3767b3..9d3bd594a9299 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -26,17 +26,18 @@ use datafusion::{ error::Result, logical_expr::Expr, physical_plan::{ - ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, - Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, + ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + PlanProperties, SendableRecordBatchStream, Statistics, }, prelude::SessionContext, scalar::ScalarValue, }; +use datafusion_catalog::Session; use datafusion_common::{project_schema, stats::Precision}; use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use async_trait::async_trait; -use datafusion_catalog::Session; /// This is a testing structure for statistics /// It will act both as a table provider and execution plan @@ -64,12 +65,11 @@ impl StatisticsValidation { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties(schema: SchemaRef) -> PlanProperties { - let eq_properties = EquivalenceProperties::new(schema); - PlanProperties::new( - eq_properties, + EquivalenceProperties::new(schema), Partitioning::UnknownPartitioning(2), - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ) } } diff --git a/datafusion/core/tests/data/csv/aggregate_test_100_with_nulls.csv b/datafusion/core/tests/data/csv/aggregate_test_100_with_nulls.csv new file mode 100644 index 0000000000000..0aabb2785250a --- /dev/null +++ b/datafusion/core/tests/data/csv/aggregate_test_100_with_nulls.csv @@ -0,0 +1,101 @@ +c1,c2,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13,c14,c15 +c,2,1,18109,2033001162,-6513304855495910254,25,43062,1491205016,5863949479783605708,0.110830784,0.9294097332465232,6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW,,NULL +d,5,-40,22614,706441268,-7542719935673075327,155,14337,3373581039,11720144131976083864,0.69632107,0.3114712539863804,C2GT5KVyOPZpgKVl110TyZO0NcJ434,,NULL +b,1,29,-18218,994303988,5983957848665088916,204,9489,3275293996,14857091259186476033,0.53840446,0.17909035118828576,AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz,,NULL +a,1,-85,-15154,1171968280,1919439543497968449,77,52286,774637006,12101411955859039553,0.12285209,0.6864391962767343,0keZ5G8BffGwgF2RwQD59TFzMStxCB,,NULL +b,5,-82,22080,1824882165,7373730676428214987,208,34331,3342719438,3330177516592499461,0.82634634,0.40975383525297016,Ig1QcuKsjHXkproePdERo2w0mYzIqd,,NULL +b,4,-111,-1967,-4229382,1892872227362838079,67,9832,1243785310,8382489916947120498,0.06563997,0.152498292971736,Sfx0vxv1skzZWT1PqVdoRDdO6Sb6xH,,NULL +e,3,104,-25136,1738331255,300633854973581194,139,20807,3577318119,13079037564113702254,0.40154034,0.7764360990307122,DuJNG8tufSqW0ZstHqWj3aGvFLMg4A,,NULL +a,3,13,12613,1299719633,2020498574254265315,191,17835,3998790955,14881411008939145569,0.041445434,0.8813167497816289,Amn2K87Db5Es3dFQO9cw9cvpAM6h35,,NULL 
+d,1,38,18384,-335410409,-1632237090406591229,26,57510,2712615025,1842662804748246269,0.6064476,0.6404495093354053,4HX6feIvmNXBN7XGqgO4YVBkhu8GDI,,NULL +a,4,-38,20744,762932956,308913475857409919,7,45465,1787652631,878137512938218976,0.7459874,0.02182578039211991,ydkwycaISlYSlEq3TlkS2m15I2pcp8,,NULL +d,1,57,28781,-1143802338,2662536767954229885,202,62167,879082834,4338034436871150616,0.7618384,0.42950521730777025,VY0zXmXeksCT8BzvpzpPLbmU9Kp9Y4,,NULL +a,4,-54,-2376,434021400,5502271306323260832,113,15777,2502326480,7966148640299601101,0.5720931,0.30585375151301186,KJFcmTVjdkCMv94wYCtfHMFhzyRsmH,,NULL +e,3,112,-6823,-421042466,8535335158538929274,129,32712,3759340273,9916295859593918600,0.6424343,0.6316565296547284,BsM5ZAYifRh5Lw3Y8X1r53I0cTJnfE,,NULL +d,2,113,3917,-108973366,-7220140168410319165,197,24380,63044568,4225581724448081782,0.11867094,0.2944158618048994,90gAtmGEeIqUTbo1ZrxCvWtsseukXC,,NULL +b,1,54,-18410,1413111008,-7145106120930085900,249,5382,1842680163,17818611040257178339,0.8881188,0.24899794314659673,6FPJlLAcaQ5uokyOWZ9HGdLZObFvOZ,,NULL +c,1,103,-22186,431378678,1346564663822463162,146,12393,3766999078,10901819591635583995,0.064453244,0.7784918983501654,2T3wSlHdEmASmO0xcXHnndkKEt6bz8,,NULL +e,2,49,24495,-587831330,9178511478067509438,129,12757,1289293657,10948666249269100825,0.5610077,0.5991138115095911,bgK1r6v3BCTh0aejJUhkA1Hn6idXGp,,NULL +d,1,-98,13630,-1991133944,1184110014998006843,220,2986,225513085,9634106610243643486,0.89651865,0.1640882545084913,y7C453hRWd4E7ImjNDWlpexB8nUqjh,,NULL +d,3,77,15091,-1302295658,8795481303066536947,154,35477,2093538928,17419098323248948387,0.11952883,0.7035635283169166,O66j6PaYuZhEUtqV6fuU7TyjM2WxC5,,NULL +e,2,97,18167,1593800404,-9112448817105133638,163,45185,3188005828,2792105417953811674,0.38175434,0.4094218353587008,ukOiFGGFnQJDHFgZxHMpvhD3zybF0M,,NULL +e,4,-56,-31500,1544188174,3096047390018154410,220,417,557517119,2774306934041974261,0.15459597,0.19113293583306745,IZTkHMLvIKuiLjhDjYMmIHxh166we4,,NULL +d,1,-99,5613,1213926989,-8863698443222021480,19,18736,4216440507,14933742247195536130,0.6067944,0.33639590659276175,aDxBtor7Icd9C5hnTvvw5NrIre740e,,NULL +a,5,36,-16974,623103518,6834444206535996609,71,29458,141047417,17448660630302620693,0.17100024,0.04429073092078406,OF7fQ37GzaZ5ikA2oMyvleKtgnLjXh,,NULL +e,4,-53,13788,2064155045,-691093532952651300,243,35106,2778168728,9463973906560740422,0.34515214,0.27159190516490006,0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm,,NULL +c,2,-29,25305,-537142430,-7683452043175617798,150,31648,598822671,11759014161799384683,0.8315913,0.946325164889271,9UbObCsVkmYpJGcGrgfK90qOnwb2Lj,,NULL +a,1,-25,15295,383352709,4980135132406487265,231,102,3276123488,12763583666216333412,0.53796273,0.17592486905979987,XemNcT1xp61xcM1Qz3wZ1VECCnq06O,,NULL +c,4,123,16620,852509237,-3087630526856906991,196,33715,3566741189,4546434653720168472,0.07606989,0.819715865079681,8LIh0b6jmDGm87BmIyjdxNIpX4ugjD,,NULL +a,5,-31,-12907,586844478,-4862189775214031241,170,28086,1013876852,11005002152861474932,0.35319167,0.05573662213439634,MeSTAXq8gVxVjbEjgkvU9YLte0X9uE,,NULL +a,2,45,15673,-1899175111,398282800995316041,99,2555,145294611,8554426087132697832,0.17333257,0.6405262429561641,b3b9esRhTzFEawbs6XhpKnD9ojutHB,,NULL +b,3,17,14457,670497898,-2390782464845307388,255,24770,1538863055,12662506238151717757,0.34077626,0.7614304100703713,6x93sxYioWuq5c9Kkk8oTAAORM7cH0,,NULL +e,4,97,-13181,2047637360,6176835796788944083,158,53000,2042457019,9726016502640071617,0.7085086,0.12357539988406441,oHJMNvWuunsIMIWFnYG31RCfkOo2V7,,NULL 
+c,2,-60,-16312,-1808210365,-3368300253197863813,71,39635,2844041986,7045482583778080653,0.805363,0.6425694115212065,BJqx5WokrmrrezZA0dUbleMYkG5U2O,,NULL +e,1,36,-21481,-928766616,-3471238138418013024,150,52569,2610290479,7788847578701297242,0.2578469,0.7670021786149205,gpo8K5qtYePve6jyPt6xgJx4YOVjms,,NULL +b,5,-5,24896,1955646088,2430204191283109071,118,43655,2424630722,11429640193932435507,0.87989986,0.7328050041291218,JafwVLSVk5AVoXFuzclesQ000EE2k1,,NULL +a,3,13,32064,912707948,3826618523497875379,42,21463,2214035726,10771380284714693539,0.6133468,0.7325106678655877,i6RQVXKUh7MzuGMDaNclUYnFUAireU,,NULL +c,1,41,-4667,-644225469,7049620391314639084,196,48099,2125812933,15419512479294091215,0.5780736,0.9255031346434324,mzbkwXKrPeZnxg2Kn1LRF5hYSsmksS,,NULL +d,2,93,-12642,2053379412,6468763445799074329,147,50842,1000948272,5536487915963301239,0.4279275,0.28534428578703896,lqhzgLsXZ8JhtpeeUWWNbMz8PHI705,,NULL +c,3,73,-9565,-382483011,1765659477910680019,186,1535,1088543984,2906943497598597237,0.680652,0.6009475544728957,Ow5PGpfTm4dXCfTDsXAOTatXRoAydR,,NULL +c,3,-2,-18655,-2141999138,-3154042970870838072,251,34970,3862393166,13062025193350212516,0.034291923,0.7697753383420857,IWl0G3ZlMNf7WT8yjIB49cx7MmYOmr,,NULL +c,3,22,13741,-2098805236,8604102724776612452,45,2516,1362369177,196777795886465166,0.94669616,0.0494924465469434,6oIXZuIPIqEoPBvFmbt2Nxy3tryGUE,,NULL +b,2,63,21456,-2138770630,-2380041687053733364,181,57594,2705709344,13144161537396946288,0.09683716,0.3051364088814128,nYVJnVicpGRqKZibHyBAmtmzBXAFfT,,NULL +d,4,102,-24558,1991172974,-7823479531661596016,14,36599,1534194097,2240998421986827216,0.028003037,0.8824879447595726,0og6hSkhbX8AC1ktFS4kounvTzy8Vo,,NULL +d,1,-8,27138,-1383162419,7682021027078563072,36,64517,2861376515,9904216782086286050,0.80954456,0.9463098243875633,AFGCj7OWlEB5QfniEFgonMq90Tq5uH,,NULL +a,3,17,-22796,1337043149,-1282905594104562444,167,2809,754775609,732272194388185106,0.3884129,0.658671129040488,VDhtJkYjAYPykCgOU9x3v7v3t4SO1a,,NULL +e,2,52,23388,715235348,605432070100399212,165,56980,3314983189,7386391799827871203,0.46076488,0.980809631269599,jQimhdepw3GKmioWUlVSWeBVRKFkY3,,NULL +b,5,68,21576,1188285940,5717755781990389024,224,27600,974297360,9865419128970328044,0.80895734,0.7973920072996036,ioEncce3mPOXD2hWhpZpCPWGATG6GU,,NULL +b,2,31,23127,-800561771,-8706387435232961848,153,27034,1098639440,3343692635488765507,0.35692692,0.5590205548347534,okOkcWflkNXIy4R8LzmySyY1EC3sYd,,NULL +c,1,-24,-24085,-1882293856,7385529783747709716,41,48048,520189543,2402288956117186783,0.39761502,0.3600766362333053,Fi4rJeTQq4eXj8Lxg3Hja5hBVTVV5u,,NULL +a,4,65,-28462,-1813935549,7602389238442209730,18,363,1865307672,11378396836996498283,0.09130204,0.5593249815276734,WHmjWk2AY4c6m7DA4GitUx6nmb1yYS,,NULL +d,1,125,31106,-1176490478,-4306856842351827308,90,17910,3625286410,17869394731126786457,0.8882508,0.7631239070049998,dVdvo6nUD5FgCgsbOZLds28RyGTpnx,,NULL +b,4,17,-28070,-673237643,1904316899655860234,188,27744,933879086,3732692885824435932,0.41860116,0.40342283197779727,JHNgc2UCaiXOdmkxwDDyGhRlO0mnBQ,,NULL +c,2,-106,-1114,-1927628110,1080308211931669384,177,20421,141680161,7464432081248293405,0.56749094,0.565352842229935,Vp3gmWunM5A7wOC9YW2JroFqTWjvTi,,NULL +d,5,-59,2045,-2117946883,1170799768349713170,189,63353,1365198901,2501626630745849169,0.75173044,0.18628859265874176,F7NSTjWvQJyBburN7CXRUlbgp2dIrA,,NULL +d,4,55,-1471,1902023838,1252101628560265705,157,3691,811650497,1524771507450695976,0.2968701,0.5437595540422571,f9ALCzwDAKmdu7Rk2msJaB1wxe5IBX,,NULL 
+b,2,-60,-21739,-1908480893,-8897292622858103761,59,50009,2525744318,1719090662556698549,0.52930677,0.560333188635217,l7uwDoTepWwnAP0ufqtHJS3CRi7RfP,,NULL +d,3,-76,8809,141218956,-9110406195556445909,58,5494,1824517658,12046662515387914426,0.8557294,0.6668423897406515,Z2sWcQr0qyCJRMHDpRy3aQr7PkHtkK,,NULL +e,4,73,-22501,1282464673,2541794052864382235,67,21119,538589788,9575476605699527641,0.48515016,0.296036538664718,4JznSdBajNWhu4hRQwjV1FjTTxY68i,,NULL +b,4,-117,19316,2051224722,-5534418579506232438,133,52046,3023531799,13684453606722360110,0.62608826,0.8506721053047003,mhjME0zBHbrK6NMkytMTQzOssOa1gF,,NULL +a,4,-101,11640,1993193190,2992662416070659899,230,40566,466439833,16778113360088370541,0.3991115,0.574210838214554,NEhyk8uIx4kEULJGa8qIyFjjBcP2G6,,NULL +b,5,62,16337,41423756,-2274773899098124524,121,34206,2307004493,10575647935385523483,0.23794776,0.1754261586710173,qnPOOmslCJaT45buUisMRnM0rc77EK,,NULL +c,4,-79,5281,-237425046,373011991904079451,121,55620,2818832252,2464584078983135763,0.49774808,0.9237877978193884,t6fQUjJejPcjc04wHvHTPe55S65B4V,,NULL +b,2,68,15874,49866617,1179733259727844435,121,23948,3455216719,3898128009708892708,0.6306253,0.9185813970744787,802bgTGl6Bk5TlkPYYTxp5JkKyaYUA,,NULL +c,1,70,27752,1325868318,1241882478563331892,63,61637,473294098,4976799313755010034,0.13801557,0.5081765563442366,Ktb7GQ0N1DrxwkCkEUsTaIXk0xYinn,,NULL +e,2,-61,-2888,-1660426473,2553892468492435401,126,35429,4144173353,939909697866979632,0.4405142,0.9231889896940375,BPtQMxnuSPpxMExYV9YkDa6cAN7GP3,,NULL +e,4,74,-12612,-1885422396,1702850374057819332,130,3583,3198969145,10767179755613315144,0.5518061,0.5614503754617461,QEHVvcP8gxI6EMJIrvcnIhgzPNjIvv,,NULL +d,2,122,10130,-168758331,-3179091803916845592,30,794,4061635107,15695681119022625322,0.69592506,0.9748360509016578,OPwBqCEK5PWTjWaiOyL45u2NLTaDWv,,NULL +e,3,71,194,1436496767,-5639533800082367925,158,44507,3105312559,3998472996619161534,0.930117,0.6108938307533,pTeu0WMjBRTaNRT15rLCuEh3tBJVc5,,NULL +c,5,-94,-15880,2025611582,-3348824099853919681,5,40622,4268716378,12849419495718510869,0.34163946,0.4830878559436823,RilTlL1tKkPOUFuzmLydHAVZwv1OGl,,NULL +d,1,-72,25590,1188089983,3090286296481837049,241,832,3542840110,5885937420286765261,0.41980565,0.21535402343780985,wwXqSGKLyBQyPkonlzBNYUJTCo4LRS,,NULL +e,1,71,-5479,-1339586153,-3920238763788954243,123,53012,4229654142,10297218950720052365,0.73473036,0.5773498217058918,cBGc0kSm32ylBDnxogG727C0uhZEYZ,,NULL +e,4,96,-30336,427197269,7506304308750926996,95,48483,3521368277,5437030162957481122,0.58104324,0.42073125331890115,3BEOHQsMEFZ58VcNTOJYShTBpAPzbt,,NULL +a,2,-48,-18025,439738328,-313657814587041987,222,13763,3717551163,9135746610908713318,0.055064857,0.9800193410444061,ukyD7b0Efj7tNlFSRmzZ0IqkEzg2a8,,NULL +a,1,-56,8692,2106705285,-7811675384226570375,231,15573,1454057357,677091006469429514,0.42794758,0.2739938529235548,JN0VclewmjwYlSl8386MlWv5rEhWCz,,NULL +e,2,52,-12056,-1090239422,9011500141803970147,238,4168,2013662838,12565360638488684051,0.6694766,0.39144436569161134,xipQ93429ksjNcXPX5326VSg1xJZcW,,NULL +a,1,-5,12636,794623392,2909750622865366631,15,24022,2669374863,4776679784701509574,0.29877836,0.2537253407987472,waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs,,NULL +b,1,12,7652,-1448995523,-5332734971209541785,136,49283,4076864659,15449267433866484283,0.6214579,0.05636955101974106,akiiY5N0I44CMwEnBL6RTBk7BRkxEj,,NULL +e,5,64,-26526,1689098844,8950618259486183091,224,45253,662099130,16127995415060805595,0.2897315,0.5759450483859969,56MZa5O1hVtX4c5sbnCfxuX5kDChqI,,NULL 
+c,4,-90,-2935,1579876740,6733733506744649678,254,12876,3593959807,4094315663314091142,0.5708688,0.5603062368164834,Ld2ej8NEv5zNcqU60FwpHeZKBhfpiV,,NULL +e,5,-86,32514,-467659022,-8012578250188146150,254,2684,2861911482,2126626171973341689,0.12559289,0.01479305307777301,gxfHWUF8XgY2KdFxigxvNEXe2V2XMl,,NULL +c,2,-117,-30187,-1222533990,-191957437217035800,136,47061,2293105904,12659011877190539078,0.2047385,0.9706712283358269,pLk3i59bZwd5KBZrI1FiweYTd5hteG,,NULL +a,3,14,28162,397430452,-452851601758273256,57,14722,431948861,8164671015278284913,0.40199697,0.07260475960924484,TtDKUZxzVxsq758G6AWPSYuZgVgbcl,,NULL +c,2,29,-3855,1354539333,4742062657200940467,81,53815,3398507249,562977550464243101,0.7124534,0.991517828651004,Oq6J4Rx6nde0YlhOIJkFsX2MsSvAQ0,,NULL +b,4,-59,25286,1423957796,2646602445954944051,0,61069,3570297463,15100310750150419896,0.49619365,0.04893135681998029,fuyvs0w7WsKSlXqJ1e6HFSoLmx03AG,,NULL +a,1,83,-14704,2143473091,-4387559599038777245,37,829,4015442341,4602675983996931623,0.89542526,0.9567595541247681,ErJFw6hzZ5fmI5r8bhE4JzlscnhKZU,,NULL +a,3,-12,-9168,1489733240,-1569376002217735076,206,33821,3959216334,16060348691054629425,0.9488028,0.9293883502480845,oLZ21P2JEDooxV1pU31cIxQHEeeoLu,,NULL +c,4,3,-30508,659422734,-6455460736227846736,133,59663,2306130875,8622584762448622224,0.16999894,0.4273123318932347,EcCuckwsF3gV1Ecgmh5v4KM8g1ozif,,NULL +a,3,-72,-11122,-2141451704,-2578916903971263854,83,30296,1995343206,17452974532402389080,0.94209343,0.3231750610081745,e2Gh6Ov8XkXoFdJWhl0EjwEHlMDYyG,,NULL +c,2,-107,-2904,-1011669561,782342092880993439,18,29527,1157161427,4403623840168496677,0.31988364,0.36936304600612724,QYlaIAnJA6r8rlAb6f59wcxvcPcWFf,,NULL +c,5,118,19208,-134213907,-2120241105523909127,86,57751,1229567292,16493024289408725403,0.5536642,0.9723580396501548,TTQUwpMNSXZqVBKAFvXu7OlWvKXJKX,,NULL +c,3,97,29106,-903316089,2874859437662206732,207,42171,3473924576,8188072741116415408,0.32792538,0.2667177795079635,HKSMQ9nTnwXCJIte1JrM1dtYnDtJ8g,,NULL +b,3,-101,-13217,-346989627,5456800329302529236,26,54276,243203849,17929716297117857676,0.05422181,0.09465635123783445,MXhhH1Var3OzzJCtI9VNyYvA0q8UyJ,,NULL +a,2,-43,13080,370975815,5881039805148485053,2,20120,2939920218,906367167997372130,0.42733806,0.16301110515739792,m6jD0LBIQWaMfenwRCTANI9eOdyyto,,NULL +a,5,-101,-12484,-842693467,-6140627905445351305,57,57885,2496054700,2243924747182709810,0.59520596,0.9491397432856566,QJYm7YRA3YetcBHI5wkMZeLXVmfuNy,,NULL +b,5,-44,15788,-629486480,5822642169425315613,13,11872,3457053821,2413406423648025909,0.44318348,0.32869374687050157,ALuRhobVWbnQTTWZdSOk0iVe8oYFhW,,NULL +d,4,5,-7688,702611616,6239356364381313700,4,39363,3126475872,35363005357834672,0.3766935,0.061029375346466685,H5j5ZHy1FGesOAHjkQEDYCucbpKWRu,,NULL +e,1,120,10837,-1331533190,6342019705133850847,245,3975,2830981072,16439861276703750332,0.6623719,0.9965400387585364,LiEBxds3X0Uw0lxiYjDqrkAaAwoiIW,,NULL +e,3,-95,13611,2030965207,927403809957470678,119,59134,559847112,10966649192992996919,0.5301289,0.047343434291126085,gTpyQnEODMcpsPnJMZC66gh33i3m0b,,NULL +d,3,123,29533,240273900,1176001466590906949,117,30972,2592330556,12883447461717956514,0.39075065,0.38870280983958583,1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO,,NULL +b,4,47,20690,-1009656194,-2027442591571700798,200,7781,326151275,2881913079548128905,0.57360977,0.2145232647388039,52mKlRE3aHCBZtjECq6sY9OqVf8Dze,,NULL +e,4,30,-16110,61035129,-3356533792537910152,159,299,28774375,13526465947516666293,0.6999775,0.03968347085780355,cq4WSAIFwx3wwTUS5bp1wCe71R6U5I,,NULL \ No newline at end of 
file diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 41219c7116cef..60c2f23640d83 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -114,10 +114,7 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { .await? .aggregate(vec![], vec![count(wildcard())])? .select(vec![count(wildcard())])? - .into_unoptimized_plan(), - // Usually, into_optimized_plan() should be used here, but due to - // https://github.com/apache/datafusion/issues/5771, - // subqueries in SQL cannot be optimized, resulting in differences in logical_plan. Therefore, into_unoptimized_plan() is temporarily used here. + .into_optimized_plan()?, ), ))? .select(vec![col("a"), col("b")])? @@ -1143,7 +1140,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { } #[tokio::test] -async fn unnest_fixed_list_nonull() -> Result<()> { +async fn unnest_fixed_list_non_null() -> Result<()> { let mut shape_id_builder = UInt32Builder::new(); let mut tags_builder = FixedSizeListBuilder::new(StringBuilder::new(), 2); @@ -1249,6 +1246,43 @@ async fn unnest_aggregate_columns() -> Result<()> { Ok(()) } +#[tokio::test] +async fn unnest_no_empty_batches() -> Result<()> { + let mut shape_id_builder = UInt32Builder::new(); + let mut tag_id_builder = UInt32Builder::new(); + + for shape_id in 1..=10 { + for tag_id in 1..=10 { + shape_id_builder.append_value(shape_id as u32); + tag_id_builder.append_value((shape_id * 10 + tag_id) as u32); + } + } + + let batch = RecordBatch::try_from_iter(vec![ + ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef), + ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + let df = ctx.table("shapes").await?; + + let results = df + .clone() + .aggregate( + vec![col("shape_id")], + vec![array_agg(col("tag_id")).alias("tag_id")], + )? + .collect() + .await?; + + // Assert that there are no empty batches in result + for rb in results { + assert!(rb.num_rows() > 0); + } + Ok(()) +} + #[tokio::test] async fn unnest_array_agg() -> Result<()> { let mut shape_id_builder = UInt32Builder::new(); @@ -1271,6 +1305,12 @@ async fn unnest_array_agg() -> Result<()> { let df = ctx.table("shapes").await?; let results = df.clone().collect().await?; + + // Assert that there are no empty batches in result + for rb in results.clone() { + assert!(rb.num_rows() > 0); + } + let expected = vec![ "+----------+--------+", "| shape_id | tag_id |", @@ -1434,9 +1474,7 @@ async fn unnest_analyze_metrics() -> Result<()> { .explain(false, true)? .collect() .await?; - let formatted = arrow::util::pretty::pretty_format_batches(&results) - .unwrap() - .to_string(); + let formatted = pretty_format_batches(&results).unwrap().to_string(); assert_contains!(&formatted, "elapsed_compute="); assert_contains!(&formatted, "input_batches=1"); assert_contains!(&formatted, "input_rows=5"); @@ -1990,6 +2028,7 @@ async fn test_array_agg() -> Result<()> { async fn test_dataframe_placeholder_missing_param_values() -> Result<()> { let ctx = SessionContext::new(); + // Creating LogicalPlans with placeholders should work. let df = ctx .read_empty() .unwrap() @@ -2011,17 +2050,16 @@ async fn test_dataframe_placeholder_missing_param_values() -> Result<()> { "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - // The placeholder is not replaced with a value, - // so the filter data type is not know, i.e. a = $0. - // Therefore, the optimization fails. 
- let optimized_plan = ctx.state().optimize(logical_plan); - assert!(optimized_plan.is_err()); - assert!(optimized_plan - .unwrap_err() - .to_string() - .contains("Placeholder type could not be resolved. Make sure that the placeholder is bound to a concrete type, e.g. by providing parameter values.")); - - // Prodiving a parameter value should resolve the error + // Executing LogicalPlans with placeholders that don't have bound values + // should fail. + let results = df.collect().await; + let err_msg = results.unwrap_err().strip_backtrace(); + assert_eq!( + err_msg, + "Execution error: Placeholder '$0' was not provided a value for execution." + ); + + // Providing a parameter value should resolve the error let df = ctx .read_empty() .unwrap() @@ -2045,12 +2083,152 @@ async fn test_dataframe_placeholder_missing_param_values() -> Result<()> { "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - let optimized_plan = ctx.state().optimize(logical_plan); - assert!(optimized_plan.is_ok()); + // N.B., the test is basically `SELECT 1 as a WHERE a = 3;` which returns no results. + #[rustfmt::skip] + let expected = [ + "++", + "++" + ]; + + assert_batches_eq!(expected, &df.collect().await.unwrap()); + + Ok(()) +} + +#[tokio::test] +async fn test_dataframe_placeholder_column_parameter() -> Result<()> { + let ctx = SessionContext::new(); + + // Creating LogicalPlans with placeholders should work + let df = ctx.read_empty().unwrap().select_exprs(&["$1"]).unwrap(); + let logical_plan = df.logical_plan(); + let formatted = logical_plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + + #[rustfmt::skip] + let expected = vec![ + "Projection: $1 [$1:Null;N]", + " EmptyRelation []" + ]; + + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Executing LogicalPlans with placeholders that don't have bound values + // should fail. + let results = df.collect().await; + let err_msg = results.unwrap_err().strip_backtrace(); + assert_eq!( + err_msg, + "Execution error: Placeholder '$1' was not provided a value for execution." 
+ ); + + // Providing a parameter value should resolve the error + let df = ctx + .read_empty() + .unwrap() + .select_exprs(&["$1"]) + .unwrap() + .with_param_values(vec![("1", ScalarValue::from(3i32))]) + .unwrap(); + + let logical_plan = df.logical_plan(); + let formatted = logical_plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + let expected = vec![ + "Projection: Int32(3) AS $1 [$1:Null;N]", + " EmptyRelation []", + ]; + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + #[rustfmt::skip] + let expected = [ + "+----+", + "| $1 |", + "+----+", + "| 3 |", + "+----+" + ]; + + assert_batches_eq!(expected, &df.collect().await.unwrap()); + + Ok(()) +} + +#[tokio::test] +async fn test_dataframe_placeholder_like_expression() -> Result<()> { + let ctx = SessionContext::new(); + + // Creating LogicalPlans with placeholders should work + let df = ctx + .read_empty() + .unwrap() + .with_column("a", lit("foo")) + .unwrap() + .filter(col("a").like(placeholder("$1"))) + .unwrap(); + + let logical_plan = df.logical_plan(); + let formatted = logical_plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + let expected = vec![ + "Filter: a LIKE $1 [a:Utf8]", + " Projection: Utf8(\"foo\") AS a [a:Utf8]", + " EmptyRelation []", + ]; + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Executing LogicalPlans with placeholders that don't have bound values + // should fail. + let results = df.collect().await; + let err_msg = results.unwrap_err().strip_backtrace(); + assert_eq!( + err_msg, + "Execution error: Placeholder '$1' was not provided a value for execution." + ); + + // Providing a parameter value should resolve the error + let df = ctx + .read_empty() + .unwrap() + .with_column("a", lit("foo")) + .unwrap() + .filter(col("a").like(placeholder("$1"))) + .unwrap() + .with_param_values(vec![("1", ScalarValue::from("f%"))]) + .unwrap(); - let actual = optimized_plan.unwrap().display_indent_schema().to_string(); - let expected = "EmptyRelation [a:Int32]"; - assert_eq!(expected, actual); + let logical_plan = df.logical_plan(); + let formatted = logical_plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + let expected = vec![ + "Filter: a LIKE Utf8(\"f%\") [a:Utf8]", + " Projection: Utf8(\"foo\") AS a [a:Utf8]", + " EmptyRelation []", + ]; + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + #[rustfmt::skip] + let expected = [ + "+-----+", + "| a |", + "+-----+", + "| foo |", + "+-----+" + ]; + + assert_batches_eq!(expected, &df.collect().await.unwrap()); Ok(()) } @@ -2099,12 +2277,12 @@ async fn write_partitioned_parquet_results() -> Result<()> { // Explicitly read the parquet file at c2=123 to verify the physical files are partitioned let partitioned_file = format!("{out_dir}/c2=123", out_dir = out_dir); - let filted_df = ctx + let filter_df = ctx .read_parquet(&partitioned_file, ParquetReadOptions::default()) .await?; // Check that the c2 column is gone and that c1 is abc. 
- let results = filted_df.collect().await?; + let results = filter_df.collect().await?; let expected = ["+-----+", "| c1 |", "+-----+", "| abc |", "+-----+"]; assert_batches_eq!(expected, &results); @@ -2468,3 +2646,135 @@ async fn boolean_dictionary_as_filter() { assert_batches_eq!(expected, &result_df.collect().await.unwrap()); } + +#[tokio::test] +async fn test_alias() -> Result<()> { + let df = create_test_table("test") + .await? + .select(vec![col("a"), col("test.b"), lit(1).alias("one")])? + .alias("table_alias")?; + // All ouput column qualifiers are changed to "table_alias" + df.schema().columns().iter().for_each(|c| { + assert_eq!(c.relation, Some("table_alias".into())); + }); + let expected = "SubqueryAlias: table_alias [a:Utf8, b:Int32, one:Int32]\ + \n Projection: test.a, test.b, Int32(1) AS one [a:Utf8, b:Int32, one:Int32]\ + \n TableScan: test [a:Utf8, b:Int32]"; + let plan = df + .clone() + .into_unoptimized_plan() + .display_indent_schema() + .to_string(); + assert_eq!(plan, expected); + + // Select over the aliased DataFrame + let df = df.select(vec![ + col("table_alias.a"), + col("b") + col("table_alias.one"), + ])?; + let expected = [ + "+-----------+---------------------------------+", + "| a | table_alias.b + table_alias.one |", + "+-----------+---------------------------------+", + "| abcDEF | 2 |", + "| abc123 | 11 |", + "| CBAdef | 11 |", + "| 123AbcDef | 101 |", + "+-----------+---------------------------------+", + ]; + assert_batches_sorted_eq!(expected, &df.collect().await?); + Ok(()) +} + +// Use alias to perform a self-join +// Issue: https://github.com/apache/datafusion/issues/14112 +#[tokio::test] +async fn test_alias_self_join() -> Result<()> { + let left = create_test_table("t1").await?; + let right = left.clone().alias("t2")?; + let joined = left.join(right, JoinType::Full, &["a"], &["a"], None)?; + let expected = [ + "+-----------+-----+-----------+-----+", + "| a | b | a | b |", + "+-----------+-----+-----------+-----+", + "| abcDEF | 1 | abcDEF | 1 |", + "| abc123 | 10 | abc123 | 10 |", + "| CBAdef | 10 | CBAdef | 10 |", + "| 123AbcDef | 100 | 123AbcDef | 100 |", + "+-----------+-----+-----------+-----+", + ]; + assert_batches_sorted_eq!(expected, &joined.collect().await?); + Ok(()) +} + +#[tokio::test] +async fn test_alias_empty() -> Result<()> { + let df = create_test_table("test").await?.alias("")?; + let expected = "SubqueryAlias: [a:Utf8, b:Int32]\ + \n TableScan: test [a:Utf8, b:Int32]"; + let plan = df + .clone() + .into_unoptimized_plan() + .display_indent_schema() + .to_string(); + assert_eq!(plan, expected); + let expected = [ + "+-----------+-----+", + "| a | b |", + "+-----------+-----+", + "| abcDEF | 1 |", + "| abc123 | 10 |", + "| CBAdef | 10 |", + "| 123AbcDef | 100 |", + "+-----------+-----+", + ]; + assert_batches_sorted_eq!( + expected, + &df.select(vec![col("a"), col("b")])?.collect().await? + ); + Ok(()) +} + +#[tokio::test] +async fn test_alias_nested() -> Result<()> { + let df = create_test_table("test") + .await? + .select(vec![col("a"), col("test.b"), lit(1).alias("one")])? + .alias("alias1")? + .alias("alias2")?; + let expected = "SubqueryAlias: alias2 [a:Utf8, b:Int32, one:Int32]\ + \n SubqueryAlias: alias1 [a:Utf8, b:Int32, one:Int32]\ + \n Projection: test.a, test.b, Int32(1) AS one [a:Utf8, b:Int32, one:Int32]\ + \n TableScan: test projection=[a, b] [a:Utf8, b:Int32]"; + let plan = df + .clone() + .into_optimized_plan()? 
+ .display_indent_schema() + .to_string(); + assert_eq!(plan, expected); + + // Select over the aliased DataFrame + let select1 = df + .clone() + .select(vec![col("alias2.a"), col("b") + col("alias2.one")])?; + let expected = [ + "+-----------+-----------------------+", + "| a | alias2.b + alias2.one |", + "+-----------+-----------------------+", + "| 123AbcDef | 101 |", + "| CBAdef | 11 |", + "| abc123 | 11 |", + "| abcDEF | 2 |", + "+-----------+-----------------------+", + ]; + assert_batches_sorted_eq!(expected, &select1.collect().await?); + + // Only the outermost alias is visible + let select2 = df.select(vec![col("alias1.a")]); + assert_eq!( + select2.unwrap_err().strip_backtrace(), + "Schema error: No field named alias1.a. \ + Valid fields are alias2.a, alias2.b, alias2.one." + ); + Ok(()) +} diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs new file mode 100644 index 0000000000000..2d3e4217c8d95 --- /dev/null +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::Int64Array; +use arrow_schema::{DataType, Field}; +use datafusion::execution::session_state::SessionStateBuilder; +use datafusion_common::{Column, DFSchema, Result, ScalarValue}; +use datafusion_execution::TaskContext; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::logical_plan::{LogicalPlan, Values}; +use datafusion_expr::{Aggregate, AggregateUDF, Expr}; +use datafusion_functions_aggregate::count::Count; +use datafusion_physical_plan::collect; +use std::collections::HashMap; +use std::fmt::Debug; +use std::ops::Deref; +use std::sync::Arc; + +///! Logical plans need to provide stable semantics, as downstream projects +///! create them and depend on them. Test executable semantics of logical plans. 
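// The test below hand-builds the logical plan for (roughly)
//
//     SELECT count(col) FROM (VALUES (NULL), (NULL), (NULL)) AS _(col)
//
// and executes it, expecting a single Int64 result of 0 since only NULLs are
// counted. The SQL form is an illustrative paraphrase of the plan constructed
// programmatically below, not a query the test actually parses.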
+ +#[tokio::test] +async fn count_only_nulls() -> Result<()> { + // Input: VALUES (NULL), (NULL), (NULL) AS _(col) + let input_schema = Arc::new(DFSchema::from_unqualified_fields( + vec![Field::new("col", DataType::Null, true)].into(), + HashMap::new(), + )?); + let input = Arc::new(LogicalPlan::Values(Values { + schema: input_schema, + values: vec![ + vec![Expr::from(ScalarValue::Null)], + vec![Expr::from(ScalarValue::Null)], + vec![Expr::from(ScalarValue::Null)], + ], + })); + let input_col_ref = Expr::Column(Column { + relation: None, + name: "col".to_string(), + }); + + // Aggregation: count(col) AS count + let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + input, + vec![], + vec![Expr::AggregateFunction(AggregateFunction { + func: Arc::new(AggregateUDF::new_from_impl(Count::new())), + args: vec![input_col_ref], + distinct: false, + filter: None, + order_by: None, + null_treatment: None, + })], + )?); + + // Execute and verify results + let session_state = SessionStateBuilder::new().build(); + let physical_plan = session_state.create_physical_plan(&aggregate).await?; + let result = + collect(physical_plan, Arc::new(TaskContext::from(&session_state))).await?; + + let result = only(result.as_slice()); + let result_schema = result.schema(); + let field = only(result_schema.fields().deref()); + let column = only(result.columns()); + + assert_eq!(field.data_type(), &DataType::Int64); // TODO should be UInt64 + assert_eq!(column.deref(), &Int64Array::from(vec![0])); + + Ok(()) +} + +fn only(elements: &[T]) -> &T +where + T: Debug, +{ + let [element] = elements else { + panic!("Expected exactly one element, got {:?}", elements); + }; + element +} diff --git a/datafusion/core/tests/execution/mod.rs b/datafusion/core/tests/execution/mod.rs new file mode 100644 index 0000000000000..8169db1a4611e --- /dev/null +++ b/datafusion/core/tests/execution/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod logical_plan; diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 81a33361008f0..b9f1632ea6957 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -28,7 +28,7 @@ use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_functions_nested::expr_ext::{IndexAccessor, SliceAccessor}; use sqlparser::ast::NullTreatment; /// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan -use std::sync::{Arc, OnceLock}; +use std::sync::{Arc, LazyLock}; mod parse_sql_expr; mod simplification; @@ -305,13 +305,11 @@ async fn test_aggregate_ext_null_treatment() { /// Evaluates the specified expr as an aggregate and compares the result to the /// expected result. 
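// The helpers below now read their shared input from the
// `static TEST_BATCH: LazyLock<RecordBatch>` defined later in this hunk, instead
// of the old `OnceLock`-backed `test_batch()` function. A minimal sketch of the
// pattern, with `build_batch` as a hypothetical constructor:
//
//     use std::sync::LazyLock;
//     static TEST_BATCH: LazyLock<RecordBatch> = LazyLock::new(build_batch);
//     // call sites: `TEST_BATCH.clone()` for an owned batch, `&TEST_BATCH` to borrow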
async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) { - let batch = test_batch(); - let ctx = SessionContext::new(); let group_expr = vec![]; let agg_expr = vec![expr]; let result = ctx - .read_batch(batch) + .read_batch(TEST_BATCH.clone()) .unwrap() .aggregate(group_expr, agg_expr) .unwrap() @@ -332,13 +330,13 @@ async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) { /// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided /// `RecordBatch` and compares the result to the expected result. fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) { - let batch = test_batch(); + let batch = &TEST_BATCH; let df_schema = DFSchema::try_from(batch.schema()).unwrap(); let physical_expr = SessionContext::new() .create_physical_expr(expr, &df_schema) .unwrap(); - let result = physical_expr.evaluate(&batch).unwrap(); + let result = physical_expr.evaluate(batch).unwrap(); let array = result.into_array(1).unwrap(); let result = pretty_format_columns("expr", &[array]).unwrap().to_string(); let actual_lines = result.lines().collect::>(); @@ -350,39 +348,33 @@ fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) { ); } -static TEST_BATCH: OnceLock = OnceLock::new(); - -fn test_batch() -> RecordBatch { - TEST_BATCH - .get_or_init(|| { - let string_array: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3"])); - let int_array: ArrayRef = - Arc::new(Int64Array::from_iter(vec![Some(10), None, Some(5)])); - - // { a: "2021-02-01" } { a: "2021-02-02" } { a: "2021-02-03" } - let struct_array: ArrayRef = Arc::from(StructArray::from(vec![( - Arc::new(Field::new("a", DataType::Utf8, false)), - Arc::new(StringArray::from(vec![ - "2021-02-01", - "2021-02-02", - "2021-02-03", - ])) as _, - )])); - - // ["one"] ["two", "three", "four"] ["five"] - let mut builder = ListBuilder::new(StringBuilder::new()); - builder.append_value([Some("one")]); - builder.append_value([Some("two"), Some("three"), Some("four")]); - builder.append_value([Some("five")]); - let list_array: ArrayRef = Arc::new(builder.finish()); - - RecordBatch::try_from_iter(vec![ - ("id", string_array), - ("i", int_array), - ("props", struct_array), - ("list", list_array), - ]) - .unwrap() - }) - .clone() -} +static TEST_BATCH: LazyLock = LazyLock::new(|| { + let string_array: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3"])); + let int_array: ArrayRef = + Arc::new(Int64Array::from_iter(vec![Some(10), None, Some(5)])); + + // { a: "2021-02-01" } { a: "2021-02-02" } { a: "2021-02-03" } + let struct_array: ArrayRef = Arc::from(StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Utf8, false)), + Arc::new(StringArray::from(vec![ + "2021-02-01", + "2021-02-02", + "2021-02-03", + ])) as _, + )])); + + // ["one"] ["two", "three", "four"] ["five"] + let mut builder = ListBuilder::new(StringBuilder::new()); + builder.append_value([Some("one")]); + builder.append_value([Some("two"), Some("three"), Some("four")]); + builder.append_value([Some("five")]); + let list_array: ArrayRef = Arc::new(builder.finish()); + + RecordBatch::try_from_iter(vec![ + ("id", string_array), + ("i", int_array), + ("props", struct_array), + ("list", list_array), + ]) + .unwrap() +}); diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index a279f516b996c..e3d2b0a019923 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -29,10 +29,10 @@ use datafusion_expr::expr::ScalarFunction; use 
datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ - expr, table_scan, Cast, ColumnarValue, ExprSchemable, LogicalPlan, - LogicalPlanBuilder, ScalarUDF, Volatility, + table_scan, Cast, ColumnarValue, ExprSchemable, LogicalPlan, LogicalPlanBuilder, + ScalarUDF, Volatility, }; -use datafusion_functions::{math, string}; +use datafusion_functions::math; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; use datafusion_optimizer::{OptimizerContext, OptimizerRule}; @@ -333,8 +333,8 @@ fn simplify_scan_predicate() -> Result<()> { .build()?; // before simplify: t.g = power(t.f, 1.0) - // after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)" - let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]"; + // after simplify: t.g = t.f" + let expected = "TableScan: test, full_filters=[g = f]"; let actual = get_optimized_plan_formatted(plan, &Utc::now()); assert_eq!(expected, actual); Ok(()) @@ -368,13 +368,13 @@ fn test_const_evaluator() { #[test] fn test_const_evaluator_scalar_functions() { // concat("foo", "bar") --> "foobar" - let expr = string::expr_fn::concat(vec![lit("foo"), lit("bar")]); + let expr = concat(vec![lit("foo"), lit("bar")]); test_evaluate(expr, lit("foobar")); // ensure arguments are also constant folded // concat("foo", concat("bar", "baz")) --> "foobarbaz" - let concat1 = string::expr_fn::concat(vec![lit("bar"), lit("baz")]); - let expr = string::expr_fn::concat(vec![lit("foo"), concat1]); + let concat1 = concat(vec![lit("bar"), lit("baz")]); + let expr = concat(vec![lit("foo"), concat1]); test_evaluate(expr, lit("foobarbaz")); // Check non string arguments @@ -407,7 +407,7 @@ fn test_const_evaluator_scalar_functions() { #[test] fn test_const_evaluator_now() { let ts_nanos = 1599566400000000000i64; - let time = chrono::Utc.timestamp_nanos(ts_nanos); + let time = Utc.timestamp_nanos(ts_nanos); let ts_string = "2020-09-08T12:05:00+00:00"; // now() --> ts test_evaluate_with_start_time(now(), lit_timestamp_nano(ts_nanos), &time); @@ -429,7 +429,7 @@ fn test_evaluator_udfs() { // immutable UDF should get folded // udf_add(1+2, 30+40) --> 73 - let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + let expr = Expr::ScalarFunction(ScalarFunction::new_udf( make_udf_add(Volatility::Immutable), args.clone(), )); @@ -438,21 +438,16 @@ fn test_evaluator_udfs() { // stable UDF should be entirely folded // udf_add(1+2, 30+40) --> 73 let fun = make_udf_add(Volatility::Stable); - let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( - Arc::clone(&fun), - args.clone(), - )); + let expr = + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), args.clone())); test_evaluate(expr, lit(73)); // volatile UDF should have args folded // udf_add(1+2, 30+40) --> udf_add(3, 70) let fun = make_udf_add(Volatility::Volatile); - let expr = - Expr::ScalarFunction(expr::ScalarFunction::new_udf(Arc::clone(&fun), args)); - let expected_expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( - Arc::clone(&fun), - folded_args, - )); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), args)); + let expected_expr = + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), folded_args)); test_evaluate(expr, expected_expr); } @@ -488,10 +483,12 @@ fn expr_test_schema() -> DFSchemaRef { Field::new("c2", DataType::Boolean, true), Field::new("c3", DataType::Int64, true), 
Field::new("c4", DataType::UInt32, true), + Field::new("c5", DataType::Utf8View, true), Field::new("c1_non_null", DataType::Utf8, false), Field::new("c2_non_null", DataType::Boolean, false), Field::new("c3_non_null", DataType::Int64, false), Field::new("c4_non_null", DataType::UInt32, false), + Field::new("c5_non_null", DataType::Utf8View, false), ]) .to_dfschema_ref() .unwrap() @@ -670,20 +667,32 @@ fn test_simplify_concat_ws_with_null() { } #[test] -fn test_simplify_concat() { +fn test_simplify_concat() -> Result<()> { + let schema = expr_test_schema(); let null = lit(ScalarValue::Utf8(None)); let expr = concat(vec![ null.clone(), - col("c0"), + col("c1"), lit("hello "), null.clone(), lit("rust"), - col("c1"), + lit(ScalarValue::Utf8View(Some("!".to_string()))), + col("c2"), lit(""), null, + col("c5"), ]); - let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]); - test_simplify(expr, expected) + let expr_datatype = expr.get_type(schema.as_ref())?; + let expected = concat(vec![ + col("c1"), + lit(ScalarValue::Utf8View(Some("hello rust!".to_string()))), + col("c2"), + col("c5"), + ]); + let expected_datatype = expected.get_type(schema.as_ref())?; + assert_eq!(expr_datatype, expected_datatype); + test_simplify(expr, expected); + Ok(()) } #[test] fn test_simplify_cycles() { diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 62e9be63983cb..09d0c8d5ca2e0 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -23,6 +23,10 @@ use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use arrow_array::types::Int64Type; +use arrow_schema::{ + IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, +}; use datafusion::common::Result; use datafusion::datasource::MemTable; use datafusion::physical_expr::aggregate::AggregateExprBuilder; @@ -39,11 +43,215 @@ use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::InputOrderMode; use test_utils::{add_empty_batches, StringBatchGenerator}; -use hashbrown::HashMap; +use crate::fuzz_cases::aggregation_fuzzer::{ + AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, QueryBuilder, +}; +use datafusion_common::HashMap; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; +use rand::{thread_rng, Rng, SeedableRng}; +use std::str; use tokio::task::JoinSet; +// ======================================================================== +// The new aggregation fuzz tests based on [`AggregationFuzzer`] +// ======================================================================== +// +// Notes on tests: +// +// Since the supported types differ for each aggregation function, the tests +// below are structured so they enumerate each different aggregate function. 
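+// (For example, with the `baseline_config()` columns defined below, the `min` test lets the
+// fuzzer generate queries such as `SELECT min(i8) as col1 FROM fuzz_table GROUP BY u8_low`.)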
+// +// The test framework handles varying combinations of arguments (data types), +// sortedness, and grouping parameters +// +// TODO: Test floating point values (where output needs to be compared with some +// acceptable range due to floating point rounding) +// +// TODO: test other aggregate functions +// - AVG (unstable given the wide range of inputs) +#[tokio::test(flavor = "multi_thread")] +async fn test_min() { + let data_gen_config = baseline_config(); + + // Queries like SELECT min(a) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("min") + // min works on all column types + .with_aggregate_arguments(data_gen_config.all_columns()) + .set_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_max() { + let data_gen_config = baseline_config(); + + // Queries like SELECT max(a) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("max") + // max works on all column types + .with_aggregate_arguments(data_gen_config.all_columns()) + .set_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_sum() { + let data_gen_config = baseline_config(); + + // Queries like SELECT sum(a), sum(distinct) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("sum") + .with_distinct_aggregate_function("sum") + // sum only works on numeric columns + .with_aggregate_arguments(data_gen_config.numeric_columns()) + .set_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_count() { + let data_gen_config = baseline_config(); + + // Queries like SELECT count(a), count(distinct) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("count") + .with_distinct_aggregate_function("count") + // count work for all arguments + .with_aggregate_arguments(data_gen_config.all_columns()) + .set_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; +} + +/// Return a standard set of columns for testing data generation +/// +/// Includes numeric and string types +/// +/// Does not include: +/// 1. Floating point numbers +/// 1. 
structured types +fn baseline_config() -> DatasetGeneratorConfig { + let mut rng = thread_rng(); + let columns = vec![ + ColumnDescr::new("i8", DataType::Int8), + ColumnDescr::new("i16", DataType::Int16), + ColumnDescr::new("i32", DataType::Int32), + ColumnDescr::new("i64", DataType::Int64), + ColumnDescr::new("u8", DataType::UInt8), + ColumnDescr::new("u16", DataType::UInt16), + ColumnDescr::new("u32", DataType::UInt32), + ColumnDescr::new("u64", DataType::UInt64), + ColumnDescr::new("date32", DataType::Date32), + ColumnDescr::new("date64", DataType::Date64), + ColumnDescr::new("time32_s", DataType::Time32(TimeUnit::Second)), + ColumnDescr::new("time32_ms", DataType::Time32(TimeUnit::Millisecond)), + ColumnDescr::new("time64_us", DataType::Time64(TimeUnit::Microsecond)), + ColumnDescr::new("time64_ns", DataType::Time64(TimeUnit::Nanosecond)), + // `None` is passed in here however when generating the array, it will generate + // random timezones. + ColumnDescr::new("timestamp_s", DataType::Timestamp(TimeUnit::Second, None)), + ColumnDescr::new( + "timestamp_ms", + DataType::Timestamp(TimeUnit::Millisecond, None), + ), + ColumnDescr::new( + "timestamp_us", + DataType::Timestamp(TimeUnit::Microsecond, None), + ), + ColumnDescr::new( + "timestamp_ns", + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ColumnDescr::new("float32", DataType::Float32), + ColumnDescr::new("float64", DataType::Float64), + ColumnDescr::new( + "interval_year_month", + DataType::Interval(IntervalUnit::YearMonth), + ), + ColumnDescr::new( + "interval_day_time", + DataType::Interval(IntervalUnit::DayTime), + ), + ColumnDescr::new( + "interval_month_day_nano", + DataType::Interval(IntervalUnit::MonthDayNano), + ), + // begin decimal columns + ColumnDescr::new("decimal128", { + // Generate valid precision and scale for Decimal128 randomly. + let precision: u8 = rng.gen_range(1..=DECIMAL128_MAX_PRECISION); + // It's safe to cast `precision` to i8 type directly. + let scale: i8 = rng.gen_range( + i8::MIN..=std::cmp::min(precision as i8, DECIMAL128_MAX_SCALE), + ); + DataType::Decimal128(precision, scale) + }), + ColumnDescr::new("decimal256", { + // Generate valid precision and scale for Decimal256 randomly. + let precision: u8 = rng.gen_range(1..=DECIMAL256_MAX_PRECISION); + // It's safe to cast `precision` to i8 type directly. 
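+            // (DECIMAL256_MAX_PRECISION is 76, well below i8::MAX, so this cast cannot overflow.)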
+ let scale: i8 = rng.gen_range( + i8::MIN..=std::cmp::min(precision as i8, DECIMAL256_MAX_SCALE), + ); + DataType::Decimal256(precision, scale) + }), + // begin string columns + ColumnDescr::new("utf8", DataType::Utf8), + ColumnDescr::new("largeutf8", DataType::LargeUtf8), + ColumnDescr::new("utf8view", DataType::Utf8View), + // low cardinality columns + ColumnDescr::new("u8_low", DataType::UInt8).with_max_num_distinct(10), + ColumnDescr::new("utf8_low", DataType::Utf8).with_max_num_distinct(10), + ColumnDescr::new("bool", DataType::Boolean), + ColumnDescr::new("binary", DataType::Binary), + ColumnDescr::new("large_binary", DataType::LargeBinary), + ColumnDescr::new("binaryview", DataType::BinaryView), + ]; + + let min_num_rows = 512; + let max_num_rows = 1024; + + DatasetGeneratorConfig { + columns, + rows_num_range: (min_num_rows, max_num_rows), + sort_keys_set: vec![ + // low cardinality to try and get many repeated runs + vec![String::from("u8_low")], + vec![String::from("utf8_low"), String::from("u8_low")], + ], + } +} + +// ======================================================================== +// The old aggregation fuzz tests +// ======================================================================== + +/// Tracks if this stream is generating input or output /// Tests that streaming aggregate and batch (non streaming) aggregate produce /// same results #[tokio::test(flavor = "multi_thread")] @@ -58,7 +266,7 @@ async fn streaming_aggregate_test() { vec!["d", "c", "a"], vec!["d", "c", "b", "a"], ]; - let n = 300; + let n = 10; let distincts = vec![10, 20]; for distinct in distincts { let mut join_set = JoinSet::new(); @@ -84,7 +292,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); - let mut sort_keys = vec![]; + let mut sort_keys = LexOrdering::default(); for ordering_col in ["a", "b", "c"] { sort_keys.push(PhysicalSortExpr { expr: col(ordering_col, &schema).unwrap(), @@ -100,7 +308,8 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let running_source = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None) .unwrap() - .with_sort_information(vec![sort_keys]), + .try_with_sort_information(vec![sort_keys]) + .unwrap(), ); let aggregate_expr = @@ -109,6 +318,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .schema(Arc::clone(&schema)) .alias("sum1") .build() + .map(Arc::new) .unwrap(), ]; let expr = group_by_columns @@ -311,6 +521,7 @@ async fn group_by_string_test( let actual = extract_result_counts(results); assert_eq!(expected, actual); } + async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { struct Visitor { expected_sort: bool, diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs new file mode 100644 index 0000000000000..2aeecd8ff2eae --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -0,0 +1,343 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{cmp, sync::Arc}; + +use datafusion::{ + datasource::MemTable, + prelude::{SessionConfig, SessionContext}, +}; +use datafusion_catalog::TableProvider; +use datafusion_common::ScalarValue; +use datafusion_common::{error::Result, utils::get_available_parallelism}; +use datafusion_expr::col; +use rand::{thread_rng, Rng}; + +use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset; + +/// SessionContext generator +/// +/// During testing, `generate_baseline` will be called firstly to generate a standard [`SessionContext`], +/// and we will run `sql` on it to get the `expected result`. Then `generate` will be called some times to +/// generate some random [`SessionContext`]s, and we will run the same `sql` on them to get `actual results`. +/// Finally, we compare the `actual results` with `expected result`, the test only success while all they are +/// same with the expected. +/// +/// Following parameters of [`SessionContext`] used in query running will be generated randomly: +/// - `batch_size` +/// - `target_partitions` +/// - `skip_partial parameters` +/// - hint `sorted` or not +/// - `spilling` or not (TODO, I think a special `MemoryPool` may be needed +/// to support this) +/// +pub struct SessionContextGenerator { + /// Current testing dataset + dataset: Arc, + + /// Table name of the test table + table_name: String, + + /// Used in generate the random `batch_size` + /// + /// The generated `batch_size` is between (0, total_rows_num] + max_batch_size: usize, + + /// Candidate `SkipPartialParams` which will be picked randomly + candidate_skip_partial_params: Vec, + + /// The upper bound of the randomly generated target partitions, + /// and the lower bound will be 1 + max_target_partitions: usize, +} + +impl SessionContextGenerator { + pub fn new(dataset_ref: Arc, table_name: &str) -> Self { + let candidate_skip_partial_params = vec![ + SkipPartialParams::ensure_trigger(), + SkipPartialParams::ensure_not_trigger(), + ]; + + let max_batch_size = cmp::max(1, dataset_ref.total_rows_num); + let max_target_partitions = get_available_parallelism(); + + Self { + dataset: dataset_ref, + table_name: table_name.to_string(), + max_batch_size, + candidate_skip_partial_params, + max_target_partitions, + } + } +} + +impl SessionContextGenerator { + /// Generate the `SessionContext` for the baseline run + pub fn generate_baseline(&self) -> Result { + let schema = self.dataset.batches[0].schema(); + let batches = self.dataset.batches.clone(); + let provider = MemTable::try_new(schema, vec![batches])?; + + // The baseline context should try best to disable all optimizations, + // and pursuing the rightness. 
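+        // (i.e. a single partition, one batch holding all rows, no sort hint, and partial
+        // aggregation skipping disabled)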
+ let batch_size = self.max_batch_size; + let target_partitions = 1; + let skip_partial_params = SkipPartialParams::ensure_not_trigger(); + + let builder = GeneratedSessionContextBuilder { + batch_size, + target_partitions, + skip_partial_params, + sort_hint: false, + table_name: self.table_name.clone(), + table_provider: Arc::new(provider), + }; + + builder.build() + } + + /// Randomly generate session context + pub fn generate(&self) -> Result { + let mut rng = thread_rng(); + let schema = self.dataset.batches[0].schema(); + let batches = self.dataset.batches.clone(); + let provider = MemTable::try_new(schema, vec![batches])?; + + // We will randomly generate following options: + // - `batch_size`, from range: [1, `total_rows_num`] + // - `target_partitions`, from range: [1, cpu_num] + // - `skip_partial`, trigger or not trigger currently for simplicity + // - `sorted`, if found a sorted dataset, will or will not push down this information + // - `spilling`(TODO) + let batch_size = rng.gen_range(1..=self.max_batch_size); + + let target_partitions = rng.gen_range(1..=self.max_target_partitions); + + let skip_partial_params_idx = + rng.gen_range(0..self.candidate_skip_partial_params.len()); + let skip_partial_params = + self.candidate_skip_partial_params[skip_partial_params_idx]; + + let (provider, sort_hint) = + if rng.gen_bool(0.5) && !self.dataset.sort_keys.is_empty() { + // Sort keys exist and random to push down + let sort_exprs = self + .dataset + .sort_keys + .iter() + .map(|key| col(key).sort(true, true)) + .collect::>(); + (provider.with_sort_order(vec![sort_exprs]), true) + } else { + (provider, false) + }; + + let builder = GeneratedSessionContextBuilder { + batch_size, + target_partitions, + sort_hint, + skip_partial_params, + table_name: self.table_name.clone(), + table_provider: Arc::new(provider), + }; + + builder.build() + } +} + +/// The generated [`SessionContext`] with its params +/// +/// Storing the generated `params` is necessary for +/// reporting the broken test case. 
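+///
+/// A rough usage sketch (only items defined in this file; `dataset_ref` and `sql` stand
+/// for a prepared [`Dataset`] and a candidate query string):
+/// ```ignore
+/// let ctx_generator = SessionContextGenerator::new(dataset_ref, "fuzz_table");
+/// let wrapped = ctx_generator.generate()?;
+/// // `wrapped.ctx` runs the query, `wrapped.params` records how it was configured
+/// let results = wrapped.ctx.sql(sql).await?.collect().await?;
+/// ```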
+pub struct SessionContextWithParams { + pub ctx: SessionContext, + pub params: SessionContextParams, +} + +/// Collect the generated params, and build the [`SessionContext`] +struct GeneratedSessionContextBuilder { + batch_size: usize, + target_partitions: usize, + sort_hint: bool, + skip_partial_params: SkipPartialParams, + table_name: String, + table_provider: Arc, +} + +impl GeneratedSessionContextBuilder { + fn build(self) -> Result { + // Build session context + let mut session_config = SessionConfig::default(); + session_config = session_config.set( + "datafusion.execution.batch_size", + &ScalarValue::UInt64(Some(self.batch_size as u64)), + ); + session_config = session_config.set( + "datafusion.execution.target_partitions", + &ScalarValue::UInt64(Some(self.target_partitions as u64)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &ScalarValue::UInt64(Some(self.skip_partial_params.rows_threshold as u64)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold", + &ScalarValue::Float64(Some(self.skip_partial_params.ratio_threshold)), + ); + + let ctx = SessionContext::new_with_config(session_config); + ctx.register_table(self.table_name, self.table_provider)?; + + let params = SessionContextParams { + batch_size: self.batch_size, + target_partitions: self.target_partitions, + sort_hint: self.sort_hint, + skip_partial_params: self.skip_partial_params, + }; + + Ok(SessionContextWithParams { ctx, params }) + } +} + +/// The generated params for [`SessionContext`] +#[derive(Debug)] +#[allow(dead_code)] +pub struct SessionContextParams { + batch_size: usize, + target_partitions: usize, + sort_hint: bool, + skip_partial_params: SkipPartialParams, +} + +/// Partial skipping parameters +#[derive(Debug, Clone, Copy)] +pub struct SkipPartialParams { + /// Related to `skip_partial_aggregation_probe_ratio_threshold` in `ExecutionOptions` + pub ratio_threshold: f64, + + /// Related to `skip_partial_aggregation_probe_rows_threshold` in `ExecutionOptions` + pub rows_threshold: usize, +} + +impl SkipPartialParams { + /// Generate `SkipPartialParams` ensuring to trigger partial skipping + pub fn ensure_trigger() -> Self { + Self { + ratio_threshold: 0.0, + rows_threshold: 0, + } + } + + /// Generate `SkipPartialParams` ensuring not to trigger partial skipping + pub fn ensure_not_trigger() -> Self { + Self { + ratio_threshold: 1.0, + rows_threshold: usize::MAX, + } + } +} + +#[cfg(test)] +mod test { + use arrow_array::{RecordBatch, StringArray, UInt32Array}; + use arrow_schema::{DataType, Field, Schema}; + + use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; + + use super::*; + + #[tokio::test] + async fn test_generated_context() { + // 1. 
Define a test dataset firstly + let a_col: StringArray = [ + Some("rust"), + Some("java"), + Some("cpp"), + Some("go"), + Some("go1"), + Some("python"), + Some("python1"), + Some("python2"), + ] + .into_iter() + .collect(); + // Sort by "b" + let b_col: UInt32Array = [ + Some(1), + Some(2), + Some(4), + Some(8), + Some(8), + Some(16), + Some(16), + Some(16), + ] + .into_iter() + .collect(); + let schema = Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::UInt32, true), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(a_col), Arc::new(b_col)], + ) + .unwrap(); + + // One row a group to create batches + let mut batches = Vec::with_capacity(batch.num_rows()); + for start in 0..batch.num_rows() { + let sub_batch = batch.slice(start, 1); + batches.push(sub_batch); + } + + let dataset = Dataset::new(batches, vec!["b".to_string()]); + + // 2. Generate baseline context, and some randomly session contexts. + // Run the same query on them, and all randoms' results should equal to baseline's + let ctx_generator = SessionContextGenerator::new(Arc::new(dataset), "fuzz_table"); + + let query = "select b, count(a) from fuzz_table group by b"; + let baseline_wrapped_ctx = ctx_generator.generate_baseline().unwrap(); + let mut random_wrapped_ctxs = Vec::with_capacity(8); + for _ in 0..8 { + let ctx = ctx_generator.generate().unwrap(); + random_wrapped_ctxs.push(ctx); + } + + let base_result = baseline_wrapped_ctx + .ctx + .sql(query) + .await + .unwrap() + .collect() + .await + .unwrap(); + + for wrapped_ctx in random_wrapped_ctxs { + let random_result = wrapped_ctx + .ctx + .sql(query) + .await + .unwrap() + .collect() + .await + .unwrap(); + check_equality_of_batches(&base_result, &random_result).unwrap(); + } + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs new file mode 100644 index 0000000000000..e4c0cb6fe77f7 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -0,0 +1,812 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
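+
+//! Random dataset generation for the aggregation fuzzer: builds `RecordBatch`es with
+//! randomly generated columns of many Arrow types, random row counts, and an additional
+//! sorted copy of the data for each configured sort key.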
+ +use std::sync::Arc; + +use arrow::datatypes::{ + BinaryType, BinaryViewType, BooleanType, ByteArrayType, ByteViewType, Date32Type, + Date64Type, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalYearMonthType, LargeBinaryType, LargeUtf8Type, StringViewType, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, Utf8Type, +}; +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; +use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::sorts::sort::sort_batch; +use rand::{ + rngs::{StdRng, ThreadRng}, + thread_rng, Rng, SeedableRng, +}; +use test_utils::{ + array_gen::{ + BinaryArrayGenerator, BooleanArrayGenerator, DecimalArrayGenerator, + PrimitiveArrayGenerator, StringArrayGenerator, + }, + stagger_batch, +}; + +/// Config for Dataset generator +/// +/// # Parameters +/// - `columns`, you just need to define `column name`s and `column data type`s +/// for the test datasets, and then they will be randomly generated from the generator +/// when you call `generate` function +/// +/// - `rows_num_range`, the number of rows in the datasets will be randomly generated +/// within this range +/// +/// - `sort_keys`, if `sort_keys` are defined, when you call the `generate` function, the generator +/// will generate one `base dataset` firstly. Then the `base dataset` will be sorted +/// based on each `sort_key` respectively. And finally `len(sort_keys) + 1` datasets +/// will be returned +/// +#[derive(Debug, Clone)] +pub struct DatasetGeneratorConfig { + /// Descriptions of columns in datasets, it's `required` + pub columns: Vec, + + /// Rows num range of the generated datasets, it's `required` + pub rows_num_range: (usize, usize), + + /// Additional optional sort keys + /// + /// The generated datasets always include a non-sorted copy. For each + /// element in `sort_keys_set`, an additional dataset is created that + /// is sorted by these values as well. + pub sort_keys_set: Vec>, +} + +impl DatasetGeneratorConfig { + /// Return a list of all column names + pub fn all_columns(&self) -> Vec<&str> { + self.columns.iter().map(|d| d.name.as_str()).collect() + } + + /// Return a list of column names that are "numeric" + pub fn numeric_columns(&self) -> Vec<&str> { + self.columns + .iter() + .filter_map(|d| { + if d.column_type.is_numeric() + && !matches!(d.column_type, DataType::Float32 | DataType::Float64) + { + Some(d.name.as_str()) + } else { + None + } + }) + .collect() + } +} + +/// Dataset generator +/// +/// It will generate one random [`Dataset`] when `generate` function is called. +/// +/// The generation logic in `generate`: +/// +/// - Randomly generate a base record from `batch_generator` firstly. +/// And `columns`, `rows_num_range` in `config`(detail can see `DataSetsGeneratorConfig`), +/// will be used in generation. +/// +/// - Sort the batch according to `sort_keys` in `config` to generate another +/// `len(sort_keys)` sorted batches. 
+/// +/// - Split each batch to multiple batches which each sub-batch in has the randomly `rows num`, +/// and this multiple batches will be used to create the `Dataset`. +/// +pub struct DatasetGenerator { + batch_generator: RecordBatchGenerator, + sort_keys_set: Vec>, +} + +impl DatasetGenerator { + pub fn new(config: DatasetGeneratorConfig) -> Self { + let batch_generator = RecordBatchGenerator::new( + config.rows_num_range.0, + config.rows_num_range.1, + config.columns, + ); + + Self { + batch_generator, + sort_keys_set: config.sort_keys_set, + } + } + + pub fn generate(&self) -> Result> { + let mut datasets = Vec::with_capacity(self.sort_keys_set.len() + 1); + + // Generate the base batch (unsorted) + let base_batch = self.batch_generator.generate()?; + let batches = stagger_batch(base_batch.clone()); + let dataset = Dataset::new(batches, Vec::new()); + datasets.push(dataset); + + // Generate the related sorted batches + let schema = base_batch.schema_ref(); + for sort_keys in self.sort_keys_set.clone() { + let sort_exprs = sort_keys + .iter() + .map(|key| { + let col_expr = col(key, schema)?; + Ok(PhysicalSortExpr::new_default(col_expr)) + }) + .collect::>()?; + let sorted_batch = sort_batch(&base_batch, sort_exprs.as_ref(), None)?; + + let batches = stagger_batch(sorted_batch); + let dataset = Dataset::new(batches, sort_keys); + datasets.push(dataset); + } + + Ok(datasets) + } +} + +/// Single test data set +#[derive(Debug)] +pub struct Dataset { + pub batches: Vec, + pub total_rows_num: usize, + pub sort_keys: Vec, +} + +impl Dataset { + pub fn new(batches: Vec, sort_keys: Vec) -> Self { + let total_rows_num = batches.iter().map(|batch| batch.num_rows()).sum::(); + + Self { + batches, + total_rows_num, + sort_keys, + } + } +} + +#[derive(Debug, Clone)] +pub struct ColumnDescr { + /// Column name + name: String, + + /// Data type of this column + column_type: DataType, + + /// The maximum number of distinct values in this column. + /// + /// See [`ColumnDescr::with_max_num_distinct`] for more information + max_num_distinct: Option, +} + +impl ColumnDescr { + #[inline] + pub fn new(name: &str, column_type: DataType) -> Self { + Self { + name: name.to_string(), + column_type, + max_num_distinct: None, + } + } + + /// set the maximum number of distinct values in this column + /// + /// If `None`, the number of distinct values is randomly selected between 1 + /// and the number of rows. + pub fn with_max_num_distinct(mut self, num_distinct: usize) -> Self { + self.max_num_distinct = Some(num_distinct); + self + } +} + +/// Record batch generator +struct RecordBatchGenerator { + min_rows_nun: usize, + + max_rows_num: usize, + + columns: Vec, + + candidate_null_pcts: Vec, +} + +macro_rules! generate_string_array { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + let max_len = $BATCH_GEN_RNG.gen_range(1..50); + + let mut generator = StringArrayGenerator { + max_len, + num_strings: $NUM_ROWS, + num_distinct_strings: $MAX_NUM_DISTINCT, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + match $ARROW_TYPE::DATA_TYPE { + DataType::Utf8 => generator.gen_data::(), + DataType::LargeUtf8 => generator.gen_data::(), + DataType::Utf8View => generator.gen_string_view(), + _ => unreachable!(), + } + }}; +} + +macro_rules! 
generate_decimal_array { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT: expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $PRECISION: ident, $SCALE: ident, $ARROW_TYPE: ident) => {{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + + let mut generator = DecimalArrayGenerator { + precision: $PRECISION, + scale: $SCALE, + num_decimals: $NUM_ROWS, + num_distinct_decimals: $MAX_NUM_DISTINCT, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$ARROW_TYPE>() + }}; +} + +// Generating `BooleanArray` due to it being a special type in Arrow (bit-packed) +macro_rules! generate_boolean_array { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ + // Select a null percentage from the candidate percentages + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + + let num_distinct_booleans = if $MAX_NUM_DISTINCT >= 2 { 2 } else { 1 }; + + let mut generator = BooleanArrayGenerator { + num_booleans: $NUM_ROWS, + num_distinct_booleans, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$ARROW_TYPE>() + }}; +} + +macro_rules! generate_primitive_array { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => {{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + + let mut generator = PrimitiveArrayGenerator { + num_primitives: $NUM_ROWS, + num_distinct_primitives: $MAX_NUM_DISTINCT, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$ARROW_TYPE>() + }}; +} + +macro_rules! 
generate_binary_array { + ( + $SELF:ident, + $NUM_ROWS:ident, + $MAX_NUM_DISTINCT:expr, + $BATCH_GEN_RNG:ident, + $ARRAY_GEN_RNG:ident, + $ARROW_TYPE:ident + ) => {{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + + let max_len = $BATCH_GEN_RNG.gen_range(1..100); + + let mut generator = BinaryArrayGenerator { + max_len, + num_binaries: $NUM_ROWS, + num_distinct_binaries: $MAX_NUM_DISTINCT, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + match $ARROW_TYPE::DATA_TYPE { + DataType::Binary => generator.gen_data::(), + DataType::LargeBinary => generator.gen_data::(), + DataType::BinaryView => generator.gen_binary_view(), + _ => unreachable!(), + } + }}; +} + +impl RecordBatchGenerator { + fn new(min_rows_nun: usize, max_rows_num: usize, columns: Vec) -> Self { + let candidate_null_pcts = vec![0.0, 0.01, 0.1, 0.5]; + + Self { + min_rows_nun, + max_rows_num, + columns, + candidate_null_pcts, + } + } + + fn generate(&self) -> Result { + let mut rng = thread_rng(); + let num_rows = rng.gen_range(self.min_rows_nun..=self.max_rows_num); + let array_gen_rng = StdRng::from_seed(rng.gen()); + + // Build arrays + let mut arrays = Vec::with_capacity(self.columns.len()); + for col in self.columns.iter() { + let array = self.generate_array_of_type( + col, + num_rows, + &mut rng, + array_gen_rng.clone(), + ); + arrays.push(array); + } + + // Build schema + let fields = self + .columns + .iter() + .map(|col| Field::new(col.name.clone(), col.column_type.clone(), true)) + .collect::>(); + let schema = Arc::new(Schema::new(fields)); + + RecordBatch::try_new(schema, arrays).map_err(|e| arrow_datafusion_err!(e)) + } + + fn generate_array_of_type( + &self, + col: &ColumnDescr, + num_rows: usize, + batch_gen_rng: &mut ThreadRng, + array_gen_rng: StdRng, + ) -> ArrayRef { + let num_distinct = if num_rows > 1 { + batch_gen_rng.gen_range(1..num_rows) + } else { + num_rows + }; + // cap to at most the num_distinct values + let max_num_distinct = col + .max_num_distinct + .map(|max| num_distinct.min(max)) + .unwrap_or(num_distinct); + + match col.column_type { + DataType::Int8 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Int8Type + ) + } + DataType::Int16 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Int16Type + ) + } + DataType::Int32 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Int32Type + ) + } + DataType::Int64 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Int64Type + ) + } + DataType::UInt8 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + UInt8Type + ) + } + DataType::UInt16 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + UInt16Type + ) + } + DataType::UInt32 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + UInt32Type + ) + } + DataType::UInt64 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + UInt64Type + ) + } + DataType::Float32 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Float32Type + ) + } + DataType::Float64 => { + generate_primitive_array!( + self, + num_rows, + 
max_num_distinct, + batch_gen_rng, + array_gen_rng, + Float64Type + ) + } + DataType::Date32 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Date32Type + ) + } + DataType::Date64 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Date64Type + ) + } + DataType::Time32(TimeUnit::Second) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Time32SecondType + ) + } + DataType::Time32(TimeUnit::Millisecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Time32MillisecondType + ) + } + DataType::Time64(TimeUnit::Microsecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Time64MicrosecondType + ) + } + DataType::Time64(TimeUnit::Nanosecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Time64NanosecondType + ) + } + DataType::Interval(IntervalUnit::YearMonth) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + IntervalYearMonthType + ) + } + DataType::Interval(IntervalUnit::DayTime) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + IntervalDayTimeType + ) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + IntervalMonthDayNanoType + ) + } + DataType::Timestamp(TimeUnit::Second, None) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + TimestampSecondType + ) + } + DataType::Timestamp(TimeUnit::Millisecond, None) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + TimestampMillisecondType + ) + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + TimestampMicrosecondType + ) + } + DataType::Timestamp(TimeUnit::Nanosecond, None) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + TimestampNanosecondType + ) + } + DataType::Binary => { + generate_binary_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + BinaryType + ) + } + DataType::LargeBinary => { + generate_binary_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + LargeBinaryType + ) + } + DataType::BinaryView => { + generate_binary_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + BinaryViewType + ) + } + DataType::Decimal128(precision, scale) => { + generate_decimal_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + precision, + scale, + Decimal128Type + ) + } + DataType::Decimal256(precision, scale) => { + generate_decimal_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + precision, + scale, + Decimal256Type + ) + } + DataType::Utf8 => { + generate_string_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Utf8Type + ) + } + DataType::LargeUtf8 => { + generate_string_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + LargeUtf8Type + ) + } + 
DataType::Utf8View => { + generate_string_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + StringViewType + ) + } + DataType::Boolean => { + generate_boolean_array! { + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + BooleanType + } + } + _ => { + panic!("Unsupported data generator type: {}", col.column_type) + } + } + } +} + +#[cfg(test)] +mod test { + use arrow_array::UInt32Array; + + use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; + + use super::*; + + #[test] + fn test_generated_datasets() { + // The test datasets generation config + // We expect that after calling `generate` + // - Generates two datasets + // - They have two columns, "a" and "b", + // "a"'s type is `Utf8`, and "b"'s type is `UInt32` + // - One of them is unsorted, another is sorted by column "b" + // - Their rows num should be same and between [16, 32] + let config = DatasetGeneratorConfig { + columns: vec![ + ColumnDescr::new("a", DataType::Utf8), + ColumnDescr::new("b", DataType::UInt32), + ], + rows_num_range: (16, 32), + sort_keys_set: vec![vec!["b".to_string()]], + }; + + let gen = DatasetGenerator::new(config); + let datasets = gen.generate().unwrap(); + + // Should Generate 2 datasets + assert_eq!(datasets.len(), 2); + + // Should have 2 column "a" and "b", + // "a"'s type is `Utf8`, and "b"'s type is `UInt32` + let check_fields = |batch: &RecordBatch| { + assert_eq!(batch.num_columns(), 2); + let fields = batch.schema().fields().clone(); + assert_eq!(fields[0].name(), "a"); + assert_eq!(*fields[0].data_type(), DataType::Utf8); + assert_eq!(fields[1].name(), "b"); + assert_eq!(*fields[1].data_type(), DataType::UInt32); + }; + + let batch = &datasets[0].batches[0]; + check_fields(batch); + let batch = &datasets[1].batches[0]; + check_fields(batch); + + // One of the batches should be sorted by "b" + let sorted_batches = &datasets[1].batches; + let b_vals = sorted_batches.iter().flat_map(|batch| { + let uint_array = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + uint_array.iter() + }); + let mut prev_b_val = u32::MIN; + for b_val in b_vals { + let b_val = b_val.unwrap_or(u32::MIN); + assert!(b_val >= prev_b_val); + prev_b_val = b_val; + } + + // Two batches should be the same after sorting + check_equality_of_batches(&datasets[0].batches, &datasets[1].batches).unwrap(); + + // The number of rows should be between [16, 32] + let rows_num0 = datasets[0] + .batches + .iter() + .map(|batch| batch.num_rows()) + .sum::(); + let rows_num1 = datasets[1] + .batches + .iter() + .map(|batch| batch.num_rows()) + .sum::(); + assert_eq!(rows_num0, rows_num1); + assert!(rows_num0 >= 16); + assert!(rows_num0 <= 32); + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs new file mode 100644 index 0000000000000..d021e73f35b20 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -0,0 +1,527 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashSet; +use std::sync::Arc; + +use arrow::util::pretty::pretty_format_batches; +use arrow_array::RecordBatch; +use datafusion_common::{DataFusionError, Result}; +use rand::{thread_rng, Rng}; +use tokio::task::JoinSet; + +use crate::fuzz_cases::aggregation_fuzzer::{ + check_equality_of_batches, + context_generator::{SessionContextGenerator, SessionContextWithParams}, + data_generator::{Dataset, DatasetGenerator, DatasetGeneratorConfig}, + run_sql, +}; + +/// Rounds to call `generate` of [`SessionContextGenerator`] +/// in [`AggregationFuzzer`], `ctx_gen_rounds` random [`SessionContext`] +/// will generated for each dataset for testing. +const CTX_GEN_ROUNDS: usize = 16; + +/// Aggregation fuzzer's builder +pub struct AggregationFuzzerBuilder { + /// See `candidate_sqls` in [`AggregationFuzzer`], no default, and required to set + candidate_sqls: Vec>, + + /// See `table_name` in [`AggregationFuzzer`], no default, and required to set + table_name: Option>, + + /// Used to generate `dataset_generator` in [`AggregationFuzzer`], + /// no default, and required to set + data_gen_config: Option, + + /// See `data_gen_rounds` in [`AggregationFuzzer`], default 16 + data_gen_rounds: usize, +} + +impl AggregationFuzzerBuilder { + fn new() -> Self { + Self { + candidate_sqls: Vec::new(), + table_name: None, + data_gen_config: None, + data_gen_rounds: 16, + } + } + + /// Adds random SQL queries to the fuzzer along with the table name + /// + /// Adds + /// - 3 random queries + /// - 3 random queries for each group by selected from the sort keys + /// - 1 random query with no grouping + pub fn add_query_builder(mut self, mut query_builder: QueryBuilder) -> Self { + const NUM_QUERIES: usize = 3; + for _ in 0..NUM_QUERIES { + let sql = query_builder.generate_query(); + self.candidate_sqls.push(Arc::from(sql)); + } + // also add several queries limited to grouping on the group by columns only, if any + // So if the data is sorted on `a,b` only group by `a,b` or`a` or `b` + if let Some(data_gen_config) = &self.data_gen_config { + for sort_keys in &data_gen_config.sort_keys_set { + let group_by_columns = sort_keys.iter().map(|s| s.as_str()); + query_builder = query_builder.set_group_by_columns(group_by_columns); + for _ in 0..NUM_QUERIES { + let sql = query_builder.generate_query(); + self.candidate_sqls.push(Arc::from(sql)); + } + } + } + // also add a query with no grouping + query_builder = query_builder.set_group_by_columns(vec![]); + let sql = query_builder.generate_query(); + self.candidate_sqls.push(Arc::from(sql)); + + self.table_name(query_builder.table_name()) + } + + pub fn table_name(mut self, table_name: &str) -> Self { + self.table_name = Some(Arc::from(table_name)); + self + } + + pub fn data_gen_config(mut self, data_gen_config: DatasetGeneratorConfig) -> Self { + self.data_gen_config = Some(data_gen_config); + self + } + + pub fn build(self) -> AggregationFuzzer { + assert!(!self.candidate_sqls.is_empty()); + let candidate_sqls = self.candidate_sqls; + let table_name = self.table_name.expect("table_name is required"); + let data_gen_config = 
self.data_gen_config.expect("data_gen_config is required"); + let data_gen_rounds = self.data_gen_rounds; + + let dataset_generator = DatasetGenerator::new(data_gen_config); + + AggregationFuzzer { + candidate_sqls, + table_name, + dataset_generator, + data_gen_rounds, + } + } +} + +impl Default for AggregationFuzzerBuilder { + fn default() -> Self { + Self::new() + } +} + +impl From for AggregationFuzzerBuilder { + fn from(value: DatasetGeneratorConfig) -> Self { + Self::default().data_gen_config(value) + } +} + +/// AggregationFuzzer randomly generating multiple [`AggregationFuzzTestTask`], +/// and running them to check the correctness of the optimizations +/// (e.g. sorted, partial skipping, spilling...) +pub struct AggregationFuzzer { + /// Candidate test queries represented by sqls + candidate_sqls: Vec>, + + /// The queried table name + table_name: Arc, + + /// Dataset generator used to randomly generate datasets + dataset_generator: DatasetGenerator, + + /// Rounds to call `generate` of [`DatasetGenerator`], + /// len(sort_keys_set) + 1` datasets will be generated for testing. + /// + /// It is suggested to set value 2x or more bigger than num of + /// `candidate_sqls` for better test coverage. + data_gen_rounds: usize, +} + +/// Query group including the tested dataset and its sql query +struct QueryGroup { + dataset: Dataset, + sql: Arc, +} + +impl AggregationFuzzer { + /// Run the fuzzer, printing an error and panicking if any of the tasks fail + pub async fn run(&self) { + let res = self.run_inner().await; + + if let Err(e) = res { + // Print the error via `Display` so that it displays nicely (the default `unwrap()` + // prints using `Debug` which escapes newlines, and makes multi-line messages + // hard to read + println!("{e}"); + panic!("Error!"); + } + } + + async fn run_inner(&self) -> Result<()> { + let mut join_set = JoinSet::new(); + let mut rng = thread_rng(); + + // Loop to generate datasets and its query + for _ in 0..self.data_gen_rounds { + // Generate datasets first + let datasets = self + .dataset_generator + .generate() + .expect("should success to generate dataset"); + + // Then for each of them, we random select a test sql for it + let query_groups = datasets + .into_iter() + .map(|dataset| { + let sql_idx = rng.gen_range(0..self.candidate_sqls.len()); + let sql = self.candidate_sqls[sql_idx].clone(); + + QueryGroup { dataset, sql } + }) + .collect::>(); + + for q in &query_groups { + println!(" Testing with query {}", q.sql); + } + + let tasks = self.generate_fuzz_tasks(query_groups).await; + for task in tasks { + join_set.spawn(async move { task.run().await }); + } + } + + while let Some(join_handle) = join_set.join_next().await { + // propagate errors + join_handle.map_err(|e| { + DataFusionError::Internal(format!( + "AggregationFuzzer task error: {:?}", + e + )) + })??; + } + Ok(()) + } + + async fn generate_fuzz_tasks( + &self, + query_groups: Vec, + ) -> Vec { + let mut tasks = Vec::with_capacity(query_groups.len() * CTX_GEN_ROUNDS); + for QueryGroup { dataset, sql } in query_groups { + let dataset_ref = Arc::new(dataset); + let ctx_generator = + SessionContextGenerator::new(dataset_ref.clone(), &self.table_name); + + // Generate the baseline context, and get the baseline result firstly + let baseline_ctx_with_params = ctx_generator + .generate_baseline() + .expect("should success to generate baseline session context"); + let baseline_result = run_sql(&sql, &baseline_ctx_with_params.ctx) + .await + .expect("should success to run baseline sql"); + let 
baseline_result = Arc::new(baseline_result); + // Generate test tasks + for _ in 0..CTX_GEN_ROUNDS { + let ctx_with_params = ctx_generator + .generate() + .expect("should success to generate session context"); + let task = AggregationFuzzTestTask { + dataset_ref: dataset_ref.clone(), + expected_result: baseline_result.clone(), + sql: sql.clone(), + ctx_with_params, + }; + + tasks.push(task); + } + } + tasks + } +} + +/// One test task generated by [`AggregationFuzzer`] +/// +/// It includes: +/// - `expected_result`, the expected result generated by baseline [`SessionContext`] +/// (disable all possible optimizations for ensuring correctness). +/// +/// - `ctx`, a randomly generated [`SessionContext`], `sql` will be run +/// on it after, and check if the result is equal to expected. +/// +/// - `sql`, the selected test sql +/// +/// - `dataset_ref`, the input dataset, store it for error reported when found +/// the inconsistency between the one for `ctx` and `expected results`. +/// +struct AggregationFuzzTestTask { + /// Generated session context in current test case + ctx_with_params: SessionContextWithParams, + + /// Expected result in current test case + /// It is generate from `query` + `baseline session context` + expected_result: Arc>, + + /// The test query + /// Use sql to represent it currently. + sql: Arc, + + /// The test dataset for error reporting + dataset_ref: Arc, +} + +impl AggregationFuzzTestTask { + async fn run(&self) -> Result<()> { + let task_result = run_sql(&self.sql, &self.ctx_with_params.ctx) + .await + .map_err(|e| e.context(self.context_error_report()))?; + self.check_result(&task_result, &self.expected_result) + } + + fn check_result( + &self, + task_result: &[RecordBatch], + expected_result: &[RecordBatch], + ) -> Result<()> { + check_equality_of_batches(task_result, expected_result).map_err(|e| { + // If we found inconsistent result, we print the test details for reproducing at first + let message = format!( + "##### AggregationFuzzer error report #####\n\ + ### Sql:\n{}\n\ + ### Schema:\n{}\n\ + ### Session context params:\n{:?}\n\ + ### Inconsistent row:\n\ + - row_idx:{}\n\ + - task_row:{}\n\ + - expected_row:{}\n\ + ### Task total result:\n{}\n\ + ### Expected total result:\n{}\n\ + ### Input:\n{}\n\ + ", + self.sql, + self.dataset_ref.batches[0].schema_ref(), + self.ctx_with_params.params, + e.row_idx, + e.lhs_row, + e.rhs_row, + format_batches_with_limit(task_result), + format_batches_with_limit(expected_result), + format_batches_with_limit(&self.dataset_ref.batches), + ); + DataFusionError::Internal(message) + }) + } + + /// Returns a formatted error message + fn context_error_report(&self) -> String { + format!( + "##### AggregationFuzzer error report #####\n\ + ### Sql:\n{}\n\ + ### Schema:\n{}\n\ + ### Session context params:\n{:?}\n\ + ### Input:\n{}\n\ + ", + self.sql, + self.dataset_ref.batches[0].schema_ref(), + self.ctx_with_params.params, + pretty_format_batches(&self.dataset_ref.batches).unwrap(), + ) + } +} + +/// Pretty prints the `RecordBatch`es, limited to the first 100 rows +fn format_batches_with_limit(batches: &[RecordBatch]) -> impl std::fmt::Display { + const MAX_ROWS: usize = 100; + let mut row_count = 0; + let to_print = batches + .iter() + .filter_map(|b| { + if row_count >= MAX_ROWS { + None + } else if row_count + b.num_rows() > MAX_ROWS { + // output last rows before limit + let slice_len = MAX_ROWS - row_count; + let b = b.slice(0, slice_len); + row_count += slice_len; + Some(b) + } else { + row_count += b.num_rows(); + 
Some(b.clone()) + } + }) + .collect::>(); + + pretty_format_batches(&to_print).unwrap() +} + +/// Random aggregate query builder +/// +/// Creates queries like +/// ```sql +/// SELECT AGG(..) FROM table_name GROUP BY +///``` +#[derive(Debug, Default, Clone)] +pub struct QueryBuilder { + /// The name of the table to query + table_name: String, + /// Aggregate functions to be used in the query + /// (function_name, is_distinct) + aggregate_functions: Vec<(String, bool)>, + /// Columns to be used in group by + group_by_columns: Vec, + /// Possible columns for arguments in the aggregate functions + /// + /// Assumes each + arguments: Vec, +} +impl QueryBuilder { + pub fn new() -> Self { + Default::default() + } + + /// return the table name if any + pub fn table_name(&self) -> &str { + &self.table_name + } + + /// Set the table name for the query builder + pub fn with_table_name(mut self, table_name: impl Into) -> Self { + self.table_name = table_name.into(); + self + } + + /// Add a new possible aggregate function to the query builder + pub fn with_aggregate_function( + mut self, + aggregate_function: impl Into, + ) -> Self { + self.aggregate_functions + .push((aggregate_function.into(), false)); + self + } + + /// Add a new possible `DISTINCT` aggregate function to the query + /// + /// This is different than `with_aggregate_function` because only certain + /// aggregates support `DISTINCT` + pub fn with_distinct_aggregate_function( + mut self, + aggregate_function: impl Into, + ) -> Self { + self.aggregate_functions + .push((aggregate_function.into(), true)); + self + } + + /// Set the columns to be used in the group bys clauses + pub fn set_group_by_columns<'a>( + mut self, + group_by: impl IntoIterator, + ) -> Self { + self.group_by_columns = group_by.into_iter().map(String::from).collect(); + self + } + + /// Add one or more columns to be used as an argument in the aggregate functions + pub fn with_aggregate_arguments<'a>( + mut self, + arguments: impl IntoIterator, + ) -> Self { + let arguments = arguments.into_iter().map(String::from); + self.arguments.extend(arguments); + self + } + + pub fn generate_query(&self) -> String { + let group_by = self.random_group_by(); + let mut query = String::from("SELECT "); + query.push_str(&self.random_aggregate_functions().join(", ")); + query.push_str(" FROM "); + query.push_str(&self.table_name); + if !group_by.is_empty() { + query.push_str(" GROUP BY "); + query.push_str(&group_by.join(", ")); + } + query + } + + /// Generate a some random aggregate function invocations (potentially repeating). 
+ /// + /// Each aggregate function invocation is of the form + /// + /// ```sql + /// function_name( argument) as alias + /// ``` + /// + /// where + /// * `function_names` are randomly selected from [`Self::aggregate_functions`] + /// * ` argument` is randomly selected from [`Self::arguments`] + /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names) + fn random_aggregate_functions(&self) -> Vec { + const MAX_NUM_FUNCTIONS: usize = 5; + let mut rng = thread_rng(); + let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS); + + let mut alias_gen = 1; + + let mut aggregate_functions = vec![]; + while aggregate_functions.len() < num_aggregate_functions { + let idx = rng.gen_range(0..self.aggregate_functions.len()); + let (function_name, is_distinct) = &self.aggregate_functions[idx]; + let argument = self.random_argument(); + let alias = format!("col{}", alias_gen); + let distinct = if *is_distinct { "DISTINCT " } else { "" }; + alias_gen += 1; + let function = format!("{function_name}({distinct}{argument}) as {alias}"); + aggregate_functions.push(function); + } + aggregate_functions + } + + /// Pick a random aggregate function argument + fn random_argument(&self) -> String { + let mut rng = thread_rng(); + let idx = rng.gen_range(0..self.arguments.len()); + self.arguments[idx].clone() + } + + /// Pick a random number of fields to group by (non-repeating) + /// + /// Limited to 3 group by columns to ensure coverage for large groups. With + /// larger numbers of columns, each group has many fewer values. + fn random_group_by(&self) -> Vec { + let mut rng = thread_rng(); + const MAX_GROUPS: usize = 3; + let max_groups = self.group_by_columns.len().max(MAX_GROUPS); + let num_group_by = rng.gen_range(1..max_groups); + + let mut already_used = HashSet::new(); + let mut group_by = vec![]; + while group_by.len() < num_group_by + && already_used.len() != self.group_by_columns.len() + { + let idx = rng.gen_range(0..self.group_by_columns.len()); + if already_used.insert(idx) { + group_by.push(self.group_by_columns[idx].clone()); + } + } + group_by + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs new file mode 100644 index 0000000000000..d93a5b7b9360b --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
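+
+//! Shared helpers for the aggregation fuzzer: `check_equality_of_batches` compares two
+//! result sets line-by-line after sorting their pretty-printed rows, and `run_sql`
+//! executes a query against a `SessionContext` and collects the result batches.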
+ +use arrow::util::pretty::pretty_format_batches; +use arrow_array::RecordBatch; +use datafusion::prelude::SessionContext; +use datafusion_common::error::Result; + +mod context_generator; +mod data_generator; +mod fuzzer; + +pub use data_generator::{ColumnDescr, DatasetGeneratorConfig}; +pub use fuzzer::*; + +#[derive(Debug)] +pub(crate) struct InconsistentResult { + pub row_idx: usize, + pub lhs_row: String, + pub rhs_row: String, +} + +pub(crate) fn check_equality_of_batches( + lhs: &[RecordBatch], + rhs: &[RecordBatch], +) -> std::result::Result<(), InconsistentResult> { + let lhs_formatted_batches = pretty_format_batches(lhs).unwrap().to_string(); + let mut lhs_formatted_batches_sorted: Vec<&str> = + lhs_formatted_batches.trim().lines().collect(); + lhs_formatted_batches_sorted.sort_unstable(); + let rhs_formatted_batches = pretty_format_batches(rhs).unwrap().to_string(); + let mut rhs_formatted_batches_sorted: Vec<&str> = + rhs_formatted_batches.trim().lines().collect(); + rhs_formatted_batches_sorted.sort_unstable(); + + for (row_idx, (lhs_row, rhs_row)) in lhs_formatted_batches_sorted + .iter() + .zip(&rhs_formatted_batches_sorted) + .enumerate() + { + if lhs_row != rhs_row { + return Err(InconsistentResult { + row_idx, + lhs_row: lhs_row.to_string(), + rhs_row: rhs_row.to_string(), + }); + } + } + + Ok(()) +} + +pub(crate) async fn run_sql(sql: &str, ctx: &SessionContext) -> Result> { + ctx.sql(sql).await?.collect().await +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/mod.rs b/datafusion/core/tests/fuzz_cases/equivalence/mod.rs new file mode 100644 index 0000000000000..2f8a38200bf12 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `EquivalenceProperties` fuzz testing + +mod ordering; +mod projection; +mod properties; +mod utils; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs new file mode 100644 index 0000000000000..ecf267185baef --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -0,0 +1,395 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::fuzz_cases::equivalence::utils::{ + convert_to_orderings, create_random_schema, create_test_params, create_test_schema_2, + generate_table_for_eq_properties, generate_table_for_orderings, + is_table_same_after_sort, TestScalarUDF, +}; +use arrow_schema::SortOptions; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use itertools::Itertools; +use std::sync::Arc; + +#[test] +fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 5; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + let col_exprs = [ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + ]; + + for n_req in 0..=col_exprs.len() { + for exprs in col_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SORT_OPTIONS, + }) + .collect::(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties {}", + requirement, expected, eq_properties + ); + // Check whether ordering_satisfy API result and + // experimental result matches. 
+ assert_eq!( + eq_properties.ordering_satisfy(requirement.as_ref()), + expected, + "{}", + err_msg + ); + } + } + } + + Ok(()) +} + +#[test] +fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = [ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SORT_OPTIONS, + }) + .collect::(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties: {}", + requirement, expected, eq_properties, + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + + assert_eq!( + eq_properties.ordering_satisfy(requirement.as_ref()), + (expected | false), + "{}", + err_msg + ); + } + } + } + + Ok(()) +} + +#[test] +fn test_ordering_satisfy_with_equivalence() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). 
+ let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, 625, 5)?; + + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it + (vec![(col_a, option_asc)], true), + (vec![(col_a, option_desc)], false), + // Test whether equivalence works as expected + (vec![(col_c, option_asc)], true), + (vec![(col_c, option_desc)], false), + // Test whether ordering equivalence works as expected + (vec![(col_d, option_asc)], true), + (vec![(col_d, option_asc), (col_b, option_asc)], true), + (vec![(col_d, option_desc), (col_b, option_asc)], false), + ( + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + true, + ), + (vec![(col_e, option_desc), (col_f, option_asc)], true), + (vec![(col_e, option_asc), (col_f, option_asc)], false), + (vec![(col_e, option_desc), (col_b, option_asc)], false), + (vec![(col_e, option_asc), (col_b, option_asc)], false), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_f, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_f, option_asc), + ], + false, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_b, option_asc), + ], + false, + ), + (vec![(col_d, option_asc), (col_e, option_desc)], true), + ( + vec![ + (col_d, option_asc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_f, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + (col_f, option_asc), + ], + true, + ), + ]; + + for (cols, expected) in requirements { + let err_msg = format!("Error in test case:{cols:?}"); + let required = cols + .into_iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: Arc::clone(expr), + options, + }) + .collect::(); + + // Check expected result with experimental result. 
+ assert_eq!( + is_table_same_after_sort( + required.clone(), + table_data_with_properties.clone() + )?, + expected + ); + assert_eq!( + eq_properties.ordering_satisfy(required.as_ref()), + expected, + "{err_msg}" + ); + } + + Ok(()) +} + +// This test checks given a table is ordered with `[a ASC, b ASC, c ASC, d ASC]` and `[a ASC, c ASC, b ASC, d ASC]` +// whether the table is also ordered with `[a ASC, b ASC, d ASC]` and `[a ASC, c ASC, d ASC]` +// Since these orderings cannot be deduced, these orderings shouldn't be satisfied by the table generated. +// For background see discussion: https://github.com/apache/datafusion/issues/12700#issuecomment-2411134296 +#[test] +fn test_ordering_satisfy_on_data() -> Result<()> { + let schema = create_test_schema_2()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let orderings = vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + (col_d, option_asc), + ], + // [a ASC, c ASC, b ASC, d ASC] + vec![ + (col_a, option_asc), + (col_c, option_asc), + (col_b, option_asc), + (col_d, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + + let batch = generate_table_for_orderings(orderings, schema, 1000, 10)?; + + // [a ASC, c ASC, d ASC] cannot be deduced + let ordering = vec![ + (col_a, option_asc), + (col_c, option_asc), + (col_d, option_asc), + ]; + let ordering = convert_to_orderings(&[ordering])[0].clone(); + assert!(!is_table_same_after_sort(ordering, batch.clone())?); + + // [a ASC, b ASC, d ASC] cannot be deduced + let ordering = vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_d, option_asc), + ]; + let ordering = convert_to_orderings(&[ordering])[0].clone(); + assert!(!is_table_same_after_sort(ordering, batch.clone())?); + + // [a ASC, b ASC] can be deduced + let ordering = vec![(col_a, option_asc), (col_b, option_asc)]; + let ordering = convert_to_orderings(&[ordering])[0].clone(); + assert!(is_table_same_after_sort(ordering, batch.clone())?); + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs new file mode 100644 index 0000000000000..f71df50fce2f1 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::fuzz_cases::equivalence::utils::{ + apply_projection, create_random_schema, generate_table_for_eq_properties, + is_table_same_after_sort, TestScalarUDF, +}; +use arrow_schema::SortOptions; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use itertools::Itertools; +use std::sync::Arc; + +#[test] +fn project_orderings_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (Arc::clone(expr), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + // Make sure each ordering after projection is valid. + for ordering in projected_eq.oeq_class().iter() { + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties {}, proj_exprs: {:?}", + ordering, eq_properties, proj_exprs, + ); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). 
+ assert!( + is_table_same_after_sort( + ordering.clone(), + projected_batch.clone(), + )?, + "{}", + err_msg + ); + } + } + } + } + + Ok(()) +} + +#[test] +fn ordering_satisfy_after_projection_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (Arc::clone(expr), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &test_schema)?; + + let projected_exprs = projection_mapping + .iter() + .map(|(_source, target)| Arc::clone(target)) + .collect::>(); + + for n_req in 0..=projected_exprs.len() { + for exprs in projected_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SORT_OPTIONS, + }) + .collect::(); + let expected = is_table_same_after_sort( + requirement.clone(), + projected_batch.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties: {}, projected_eq: {}, projection_mapping: {:?}", + requirement, expected, eq_properties, projected_eq, projection_mapping + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + projected_eq.ordering_satisfy(requirement.as_ref()), + expected, + "{}", + err_msg + ); + } + } + } + } + } + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs new file mode 100644 index 0000000000000..fc21c620a711b --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::fuzz_cases::equivalence::utils::{ + create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, + TestScalarUDF, +}; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use itertools::Itertools; +use std::sync::Arc; + +#[test] +fn test_find_longest_permutation_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = [ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let exprs = exprs.into_iter().cloned().collect::>(); + let (ordering, indices) = eq_properties.find_longest_permutation(&exprs); + // Make sure that find_longest_permutation return values are consistent + let ordering2 = indices + .iter() + .zip(ordering.iter()) + .map(|(&idx, sort_expr)| PhysicalSortExpr { + expr: Arc::clone(&exprs[idx]), + options: sort_expr.options, + }) + .collect::(); + assert_eq!( + ordering, ordering2, + "indices and lexicographical ordering do not match" + ); + + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties: {}", + ordering, eq_properties + ); + assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + table_data_with_properties.clone(), + )?, + "{}", + err_msg + ); + } + } + } + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs new file mode 100644 index 0000000000000..5bf42ea6889f4 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -0,0 +1,631 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::physical_plan::expressions::col; +use datafusion::physical_plan::expressions::Column; +use datafusion_physical_expr::{ConstExpr, EquivalenceProperties, PhysicalSortExpr}; +use std::any::Any; +use std::cmp::Ordering; +use std::sync::Arc; + +use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; +use arrow_schema::{SchemaRef, SortOptions}; +use datafusion_common::utils::{compare_rows, get_row_at_idx}; +use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::equivalence::{EquivalenceClass, ProjectionMapping}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; + +use itertools::izip; +use rand::prelude::*; + +pub fn output_schema( + mapping: &ProjectionMapping, + input_schema: &Arc, +) -> Result { + // Calculate output schema + let fields: Result> = mapping + .iter() + .map(|(source, target)| { + let name = target + .as_any() + .downcast_ref::() + .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? + .name(); + let field = Field::new( + name, + source.data_type(input_schema)?, + source.nullable(input_schema)?, + ); + + Ok(field) + }) + .collect(); + + let output_schema = Arc::new(Schema::new_with_metadata( + fields?, + input_schema.metadata().clone(), + )); + + Ok(output_schema) +} + +// Generate a schema which consists of 6 columns (a, b, c, d, e, f) +pub fn create_test_schema_2() -> Result { + let a = Field::new("a", DataType::Float64, true); + let b = Field::new("b", DataType::Float64, true); + let c = Field::new("c", DataType::Float64, true); + let d = Field::new("d", DataType::Float64, true); + let e = Field::new("e", DataType::Float64, true); + let f = Field::new("f", DataType::Float64, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) +} + +/// Construct a schema with random ordering +/// among column a, b, c, d +/// where +/// Column [a=f] (e.g they are aliases). +/// Column e is constant. +pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema_2()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); + // Define a and f are aliases + eq_properties.add_equal_conditions(col_a, col_f)?; + // Column e has constant value. 
+ eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) +} + +// Apply projection to the input_data, return projected equivalence properties and record batch +pub fn apply_projection( + proj_exprs: Vec<(Arc, String)>, + input_data: &RecordBatch, + input_eq_properties: &EquivalenceProperties, +) -> Result<(RecordBatch, EquivalenceProperties)> { + let input_schema = input_data.schema(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let output_schema = output_schema(&projection_mapping, &input_schema)?; + let num_rows = input_data.num_rows(); + // Apply projection to the input record batch. + let projected_values = projection_mapping + .iter() + .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) + .collect::>>()?; + let projected_batch = if projected_values.is_empty() { + RecordBatch::new_empty(Arc::clone(&output_schema)) + } else { + RecordBatch::try_new(Arc::clone(&output_schema), projected_values)? + }; + + let projected_eq = input_eq_properties.project(&projection_mapping, output_schema); + Ok((projected_batch, projected_eq)) +} + +#[test] +fn add_equal_conditions_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("x", DataType::Int64, true), + Field::new("y", DataType::Int64, true), + ])); + + let mut eq_properties = EquivalenceProperties::new(schema); + let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; + let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; + let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + + // a and b are aliases + eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + + // This new entry is redundant, size shouldn't increase + eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); + assert_eq!(eq_groups.len(), 2); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + + // b and c are aliases. Existing equivalence class should expand, + // however there shouldn't be any new equivalence class + eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); + assert_eq!(eq_groups.len(), 3); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + + // This is a new set of equality. Hence equivalent class count should be 2. 
+ eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?; + assert_eq!(eq_properties.eq_group().len(), 2); + + // This equality bridges distinct equality sets. + // Hence equivalent class count should decrease from 2 to 1. + eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); + assert_eq!(eq_groups.len(), 5); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_x_expr)); + assert!(eq_groups.contains(&col_y_expr)); + + Ok(()) +} + +/// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. +/// +/// The function works by adding a unique column of ascending integers to the original table. This column ensures +/// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can +/// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce +/// deterministic sorting results. +/// +/// If the table remains the same after sorting with the added unique column, it indicates that the table was +/// already sorted according to `required_ordering` to begin with. +pub fn is_table_same_after_sort( + mut required_ordering: LexOrdering, + batch: RecordBatch, +) -> Result { + // Clone the original schema and columns + let original_schema = batch.schema(); + let mut columns = batch.columns().to_vec(); + + // Create a new unique column + let n_row = batch.num_rows(); + let vals: Vec = (0..n_row).collect::>(); + let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); + let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; + columns.push(Arc::clone(&unique_col)); + + // Create a new schema with the added unique column + let unique_col_name = "unique"; + let unique_field = Arc::new(Field::new(unique_col_name, DataType::Float64, false)); + let fields: Vec<_> = original_schema + .fields() + .iter() + .cloned() + .chain(std::iter::once(unique_field)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + + // Create a new batch with the added column + let new_batch = RecordBatch::try_new(Arc::clone(&schema), columns)?; + + // Add the unique column to the required ordering to ensure deterministic results + required_ordering.push(PhysicalSortExpr { + expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), + options: Default::default(), + }); + + // Convert the required ordering to a list of SortColumn + let sort_columns = required_ordering + .iter() + .map(|order_expr| { + let expr_result = order_expr.expr.evaluate(&new_batch)?; + let values = expr_result.into_array(new_batch.num_rows())?; + Ok(SortColumn { + values, + options: Some(order_expr.options), + }) + }) + .collect::>>()?; + + // Check if the indices after sorting match the initial ordering + let sorted_indices = lexsort_to_indices(&sort_columns, None)?; + let original_indices = UInt32Array::from_iter_values(0..n_row as u32); + + Ok(sorted_indices == original_indices) +} + +// If we already generated a random result for one of the +// expressions in the equivalence classes. For other expressions in the same +// equivalence class use same result. This util gets already calculated result, when available. 
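The tie-breaker idea behind `is_table_same_after_sort` above can be illustrated with a minimal sketch (not part of the diff): once a unique ascending column is appended, a table that already satisfies the requested ordering lexsorts to the identity permutation.

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Float64Array, UInt32Array};
use arrow::compute::{lexsort_to_indices, SortColumn};
use arrow::error::ArrowError;

fn main() -> Result<(), ArrowError> {
    // A column that is already ascending, plus a unique tie-breaker column.
    let data: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 1.0, 2.0]));
    let unique: ArrayRef = Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0]));
    let cols = vec![
        SortColumn { values: data, options: None },
        SortColumn { values: unique, options: None },
    ];

    // Sorting leaves the row order unchanged, so the indices are 0..n.
    let indices = lexsort_to_indices(&cols, None)?;
    assert_eq!(indices, UInt32Array::from_iter_values(0..3u32));
    Ok(())
}
```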
+fn get_representative_arr( + eq_group: &EquivalenceClass, + existing_vec: &[Option], + schema: SchemaRef, +) -> Option { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(Arc::clone(res)); + } + } + None +} + +// Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) +pub fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let g = Field::new("g", DataType::Int32, true); + let h = Field::new("h", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h])); + + Ok(schema) +} + +/// Construct a schema with following properties +/// Schema satisfies following orderings: +/// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] +/// and +/// Column [a=c] (e.g they are aliases). +pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); + eq_properties.add_equal_conditions(col_a, col_c)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let orderings = vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [e DESC, f ASC, g ASC] + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + Ok((test_schema, eq_properties)) +} + +// Generate a table that satisfies the given equivalence properties; i.e. +// equivalences, ordering equivalences, and constants. 
+pub fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, +) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in eq_properties.constants() { + let col = constant.expr().as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = + Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class().iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group().iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, Arc::clone(schema)) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(Arc::clone(&representative_array)); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) +} + +// Generate a table that satisfies the given orderings; +pub fn generate_table_for_orderings( + mut orderings: Vec, + schema: SchemaRef, + n_elem: usize, + n_distinct: usize, +) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + assert!(!orderings.is_empty()); + // Sort the inner vectors by their lengths (longest first) + orderings.sort_by_key(|v| std::cmp::Reverse(v.len())); + + let arrays = schema + .fields + .iter() + .map(|field| { + ( + field.name(), + generate_random_f64_array(n_elem, n_distinct, &mut rng), + ) + }) + .collect::>(); + let batch = RecordBatch::try_from_iter(arrays)?; + + // Sort batch according to first ordering expression + let sort_columns = get_sort_columns(&batch, orderings[0].as_ref())?; + let sort_indices = lexsort_to_indices(&sort_columns, None)?; + let mut batch = take_record_batch(&batch, &sort_indices)?; + + // prune out rows that is invalid according to remaining orderings. + for ordering in orderings.iter().skip(1) { + let sort_columns = get_sort_columns(&batch, ordering)?; + + // Collect sort options and values into separate vectors. 
+ let (sort_options, sort_col_values): (Vec<_>, Vec<_>) = sort_columns + .into_iter() + .map(|sort_col| (sort_col.options.unwrap(), sort_col.values)) + .unzip(); + + let mut cur_idx = 0; + let mut keep_indices = vec![cur_idx as u32]; + for next_idx in 1..batch.num_rows() { + let cur_row = get_row_at_idx(&sort_col_values, cur_idx)?; + let next_row = get_row_at_idx(&sort_col_values, next_idx)?; + + if compare_rows(&cur_row, &next_row, &sort_options)? != Ordering::Greater { + // next row satisfies ordering relation given, compared to the current row. + keep_indices.push(next_idx as u32); + cur_idx = next_idx; + } + } + // Only keep valid rows, that satisfies given ordering relation. + batch = take_record_batch(&batch, &UInt32Array::from_iter_values(keep_indices))?; + } + + Ok(batch) +} + +// Convert each tuple to PhysicalSortExpr +pub fn convert_to_sort_exprs( + in_data: &[(&Arc, SortOptions)], +) -> LexOrdering { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: Arc::clone(*expr), + options: *options, + }) + .collect() +} + +// Convert each inner tuple to PhysicalSortExpr +pub fn convert_to_orderings( + orderings: &[Vec<(&Arc, SortOptions)>], +) -> Vec { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) + .collect() +} + +// Utility function to generate random f64 array +fn generate_random_f64_array( + n_elems: usize, + n_distinct: usize, + rng: &mut StdRng, +) -> ArrayRef { + let values: Vec = (0..n_elems) + .map(|_| rng.gen_range(0..n_distinct) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) +} + +// Helper function to get sort columns from a batch +fn get_sort_columns( + batch: &RecordBatch, + ordering: &LexOrdering, +) -> Result> { + ordering + .iter() + .map(|expr| expr.evaluate_to_sort_column(batch)) + .collect::>>() +} + +#[derive(Debug, Clone)] +pub struct TestScalarUDF { + pub(crate) signature: Signature, +} + +impl TestScalarUDF { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "test-scalar-udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float32 => Ok(DataType::Float32), + _ => Ok(DataType::Float64), + } + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) + } + + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new({ + let arg = &args[0].as_any().downcast_ref::().ok_or_else( + || { + DataFusionError::Internal(format!( + "could not cast {} to {}", + self.name(), + std::any::type_name::() + )) + }, + )?; + + arg.iter() + .map(|a| a.map(f64::floor)) + .collect::() + }), + DataType::Float32 => Arc::new({ + let arg = &args[0].as_any().downcast_ref::().ok_or_else( + || { + DataFusionError::Internal(format!( + "could not cast {} to {}", + self.name(), + std::any::type_name::() + )) + }, + )?; + + arg.iter() + .map(|a| a.map(f32::floor)) + .collect::() + }), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + Ok(ColumnarValue::Array(arr)) + } +} diff --git 
a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 2535b94eaec49..4a5598e8da8f2 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -41,6 +41,7 @@ use datafusion::physical_plan::joins::{ }; use datafusion::physical_plan::memory::MemoryExec; +use crate::fuzz_cases::join_fuzz::JoinTestType::{HjSmj, NljHj}; use datafusion::prelude::{SessionConfig, SessionContext}; use test_utils::stagger_batch_with_seed; @@ -96,7 +97,7 @@ async fn test_inner_join_1k_filtered() { JoinType::Inner, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -108,7 +109,7 @@ async fn test_inner_join_1k() { JoinType::Inner, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -120,13 +121,11 @@ async fn test_left_join_1k() { JoinType::Left, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 async fn test_left_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -134,7 +133,7 @@ async fn test_left_join_1k_filtered() { JoinType::Left, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -146,13 +145,11 @@ async fn test_right_join_1k() { JoinType::Right, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 async fn test_right_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -160,7 +157,7 @@ async fn test_right_join_1k_filtered() { JoinType::Right, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -172,13 +169,11 @@ async fn test_full_join_1k() { JoinType::Full, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 async fn test_full_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -186,7 +181,7 @@ async fn test_full_join_1k_filtered() { JoinType::Full, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[NljHj, HjSmj], false) .await } @@ -198,7 +193,7 @@ async fn test_semi_join_1k() { JoinType::LeftSemi, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -210,33 +205,79 @@ async fn test_semi_join_1k_filtered() { JoinType::LeftSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } #[tokio::test] -async fn test_anti_join_1k() { +async fn test_left_anti_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftAnti, None, ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } #[tokio::test] -// flaky for HjSmj case, giving 1 rows difference sometimes -// https://github.com/apache/datafusion/issues/11555 -async fn test_anti_join_1k_filtered() { +async fn 
test_left_anti_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_anti_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_anti_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_left_mark_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::LeftMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_left_mark_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::LeftMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) .await } @@ -445,7 +486,7 @@ impl JoinFuzzTestCase { let filter = JoinFilter::new(expression, column_indices, intermediate_schema); Arc::new( - NestedLoopJoinExec::try_new(left, right, Some(filter), &self.join_type) + NestedLoopJoinExec::try_new(left, right, Some(filter), &self.join_type, None) .unwrap(), ) } @@ -494,8 +535,8 @@ impl JoinFuzzTestCase { nlj_formatted_sorted.sort_unstable(); if debug - && ((join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows) - || (join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows)) + && ((join_tests.contains(&NljHj) && nlj_rows != hj_rows) + || (join_tests.contains(&HjSmj) && smj_rows != hj_rows)) { let fuzz_debug = "fuzz_test_debug"; std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); @@ -515,14 +556,11 @@ impl JoinFuzzTestCase { "input2", ); - if join_tests.contains(&JoinTestType::NljHj) - && join_tests.contains(&JoinTestType::NljHj) - && nlj_rows != hj_rows - { + if join_tests.contains(&NljHj) && nlj_rows != hj_rows { println!("=============== HashJoinExec =================="); hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); println!("=============== NestedLoopJoinExec =================="); - smj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + nlj_formatted_sorted.iter().for_each(|s| println!("{}", s)); Self::save_partitioned_batches_as_parquet( &nlj_collected, @@ -536,7 +574,7 @@ impl JoinFuzzTestCase { ); } - if join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows { + if join_tests.contains(&HjSmj) && smj_rows != hj_rows { println!("=============== HashJoinExec =================="); hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); println!("=============== SortMergeJoinExec =================="); @@ -555,7 +593,7 @@ impl JoinFuzzTestCase { } } - if join_tests.contains(&JoinTestType::NljHj) { + if join_tests.contains(&NljHj) { let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size); assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str()); @@ -576,7 +614,7 @@ impl JoinFuzzTestCase { } } - if join_tests.contains(&JoinTestType::HjSmj) { + if join_tests.contains(&HjSmj) { let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size); 
assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str()); diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index 95d97709f3195..a82849f4ea929 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -281,7 +281,7 @@ fn i64string_batch<'a>( .unwrap() } -/// Run the TopK test, sorting the input batches with the specified ftch +/// Run the TopK test, sorting the input batches with the specified fetch /// (limit) and compares the results to the expected values. async fn run_limit_test(fetch: usize, data: &SortedData) { let input = data.batches(); @@ -341,7 +341,7 @@ async fn run_limit_test(fetch: usize, data: &SortedData) { /// Return random ASCII String with len fn get_random_string(len: usize) -> String { - rand::thread_rng() + thread_rng() .sample_iter(rand::distributions::Alphanumeric) .take(len) .map(char::from) diff --git a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs index 4eb1070e6c857..4e895920dd3da 100644 --- a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs @@ -31,6 +31,7 @@ use datafusion::physical_plan::{ sorts::sort_preserving_merge::SortPreservingMergeExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use test_utils::{batches_to_vec, partitions_to_sorted_vec, stagger_batch_with_seed}; #[tokio::test] @@ -107,13 +108,13 @@ async fn run_merge_test(input: Vec>) { .expect("at least one batch"); let schema = first_batch.schema(); - let sort = vec![PhysicalSortExpr { + let sort = LexOrdering::new(vec![PhysicalSortExpr { expr: col("x", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]; + }]); let exec = MemoryExec::try_new(&input, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 69241571b4af0..d5511e2970f4d 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -21,6 +21,11 @@ mod join_fuzz; mod merge_fuzz; mod sort_fuzz; +mod aggregation_fuzzer; +mod equivalence; + +mod pruning; + mod limit_fuzz; mod sort_preserving_repartition_fuzz; mod window_fuzz; diff --git a/datafusion/core/tests/fuzz_cases/pruning.rs b/datafusion/core/tests/fuzz_cases/pruning.rs new file mode 100644 index 0000000000000..8ce980ee080b8 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/pruning.rs @@ -0,0 +1,389 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::{Arc, OnceLock}; + +use arrow_array::{Array, RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use bytes::{BufMut, Bytes, BytesMut}; +use datafusion::{ + datasource::{ + listing::PartitionedFile, + physical_plan::{parquet::ParquetExecBuilder, FileScanConfig}, + }, + prelude::*, +}; +use datafusion_common::DFSchema; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_plan::{collect, filter::FilterExec, ExecutionPlan}; +use itertools::Itertools; +use object_store::{memory::InMemory, path::Path, ObjectStore, PutPayload}; +use parquet::{ + arrow::ArrowWriter, + file::properties::{EnabledStatistics, WriterProperties}, +}; +use rand::seq::SliceRandom; +use tokio::sync::Mutex; +use url::Url; + +#[tokio::test] +async fn test_utf8_eq() { + Utf8Test::new(|value| col("a").eq(lit(value))).run().await; +} + +#[tokio::test] +async fn test_utf8_not_eq() { + Utf8Test::new(|value| col("a").not_eq(lit(value))) + .run() + .await; +} + +#[tokio::test] +async fn test_utf8_lt() { + Utf8Test::new(|value| col("a").lt(lit(value))).run().await; +} + +#[tokio::test] +async fn test_utf8_lt_eq() { + Utf8Test::new(|value| col("a").lt_eq(lit(value))) + .run() + .await; +} + +#[tokio::test] +async fn test_utf8_gt() { + Utf8Test::new(|value| col("a").gt(lit(value))).run().await; +} + +#[tokio::test] +async fn test_utf8_gt_eq() { + Utf8Test::new(|value| col("a").gt_eq(lit(value))) + .run() + .await; +} + +#[tokio::test] +async fn test_utf8_like() { + Utf8Test::new(|value| col("a").like(lit(value))).run().await; +} + +#[tokio::test] +async fn test_utf8_not_like() { + Utf8Test::new(|value| col("a").not_like(lit(value))) + .run() + .await; +} + +#[tokio::test] +async fn test_utf8_like_prefix() { + Utf8Test::new(|value| col("a").like(lit(format!("%{}", value)))) + .run() + .await; +} + +#[tokio::test] +async fn test_utf8_like_suffix() { + Utf8Test::new(|value| col("a").like(lit(format!("{}%", value)))) + .run() + .await; +} + +#[tokio::test] +async fn test_utf8_not_like_prefix() { + Utf8Test::new(|value| col("a").not_like(lit(format!("%{}", value)))) + .run() + .await; +} + +#[tokio::test] +async fn test_utf8_not_like_suffix() { + Utf8Test::new(|value| col("a").not_like(lit(format!("{}%", value)))) + .run() + .await; +} + +/// Fuzz testing for UTF8 predicate pruning +/// The basic idea is that query results should always be the same with or without stats/pruning +/// If we get this right we at least guarantee that there are no incorrect results +/// There may still be suboptimal pruning or stats but that's something we can try to catch +/// with more targeted tests. +// +/// Since we know where the edge cases might be we don't do random black box fuzzing. +/// Instead we fuzz on specific pre-defined axis: +/// +/// - Which characters are in each value. We want to make sure to include characters that when +/// incremented, truncated or otherwise manipulated might cause issues. +/// - The values in each row group. This impacts which min/max stats are generated for each rg. +/// We'll generate combinations of the characters with lengths ranging from 1 to 4. +/// - Truncation of statistics to 1, 2 or 3 characters as well as no truncation. 
+struct Utf8Test { + /// Test queries the parquet files with this predicate both with and without + /// pruning enabled + predicate_generator: Box Expr + 'static>, +} + +impl Utf8Test { + /// Create a new test with the given predicate generator + fn new Expr + 'static>(f: F) -> Self { + Self { + predicate_generator: Box::new(f), + } + } + + /// Run the test by evaluating the predicate on the test files with and + /// without pruning enable + async fn run(&self) { + let ctx = SessionContext::new(); + + let mut predicates = vec![]; + for value in Self::values() { + predicates.push((self.predicate_generator)(value)); + } + + let store = Self::memory_store(); + ctx.register_object_store(&Url::parse("memory://").unwrap(), Arc::clone(store)); + + let files = Self::test_files().await; + let schema = Self::schema(); + let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); + + println!("Testing {} predicates", predicates.len()); + for predicate in predicates { + // println!("Testing predicate {:?}", predicate); + let phys_expr_predicate = ctx + .create_physical_expr(predicate.clone(), &df_schema) + .unwrap(); + let expected = execute_with_predicate( + &files, + Arc::clone(&phys_expr_predicate), + false, + schema.clone(), + &ctx, + ) + .await; + let with_pruning = execute_with_predicate( + &files, + phys_expr_predicate, + true, + schema.clone(), + &ctx, + ) + .await; + assert_eq!(expected, with_pruning); + } + } + + /// all combinations of interesting charactes with lengths ranging from 1 to 4 + fn values() -> &'static [String] { + VALUES.get_or_init(|| { + let mut rng = rand::thread_rng(); + + let characters = [ + "z", + "0", + "~", + "ß", + "℣", + "%", // this one is useful for like/not like tests since it will result in randomly inserted wildcards + "_", // this one is useful for like/not like tests since it will result in randomly inserted wildcards + "\u{7F}", + "\u{7FF}", + "\u{FF}", + "\u{10FFFF}", + "\u{D7FF}", + "\u{FDCF}", + // null character + "\u{0}", + ]; + let value_lengths = [1, 2, 3]; + let mut values = vec![]; + for length in &value_lengths { + values.extend( + characters + .iter() + .cloned() + .combinations(*length) + // now get all permutations of each combination + .flat_map(|c| c.into_iter().permutations(*length)) + // and join them into strings + .map(|c| c.join("")), + ); + } + println!("Generated {} values", values.len()); + // randomly pick 100 values + values.shuffle(&mut rng); + values.truncate(100); + values + }) + } + + /// return the in memory object store + fn memory_store() -> &'static Arc { + MEMORY_STORE.get_or_init(|| Arc::new(InMemory::new())) + } + + /// return the schema of the created test files + fn schema() -> Arc { + let schema = SCHEMA.get_or_init(|| { + Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)])) + }); + Arc::clone(schema) + } + + /// Return a list of test files with UTF8 data and combinations of + /// [`Self::values`] + async fn test_files() -> Vec { + let files_mutex = TESTFILES.get_or_init(|| Mutex::new(vec![])); + let mut files = files_mutex.lock().await; + if !files.is_empty() { + return (*files).clone(); + } + + let mut rng = rand::thread_rng(); + let values = Self::values(); + + let mut row_groups = vec![]; + // generate all combinations of values for row groups (1 or 2 values per rg, more is unnecessary since we only get min/max stats out) + for rg_length in [1, 2] { + row_groups.extend(values.iter().cloned().combinations(rg_length)); + } + + println!("Generated {} row groups", row_groups.len()); + + // Randomly 
pick 100 row groups (combinations of said values) + row_groups.shuffle(&mut rng); + row_groups.truncate(100); + + let schema = Self::schema(); + + let store = Self::memory_store(); + for (idx, truncation_length) in [Some(1), Some(2), None].iter().enumerate() { + // parquet files only support 32767 row groups per file, so chunk up into multiple files so we don't error if running on a large number of row groups + for (rg_idx, row_groups) in row_groups.chunks(32766).enumerate() { + let buf = write_parquet_file( + *truncation_length, + Arc::clone(&schema), + row_groups.to_vec(), + ) + .await; + let filename = format!("test_fuzz_utf8_{idx}_{rg_idx}.parquet"); + let size = buf.len(); + let path = Path::from(filename); + let payload = PutPayload::from(buf); + store.put(&path, payload).await.unwrap(); + + files.push(TestFile { path, size }); + } + } + + println!("Generated {} parquet files", files.len()); + files.clone() + } +} + +async fn execute_with_predicate( + files: &[TestFile], + predicate: Arc, + prune_stats: bool, + schema: Arc, + ctx: &SessionContext, +) -> Vec { + let scan = + FileScanConfig::new(ObjectStoreUrl::parse("memory://").unwrap(), schema.clone()) + .with_file_group( + files + .iter() + .map(|test_file| { + PartitionedFile::new( + test_file.path.clone(), + test_file.size as u64, + ) + }) + .collect(), + ); + let mut builder = ParquetExecBuilder::new(scan); + if prune_stats { + builder = builder.with_predicate(predicate.clone()) + } + let exec = Arc::new(builder.build()) as Arc; + let exec = + Arc::new(FilterExec::try_new(predicate, exec).unwrap()) as Arc; + + let batches = collect(exec, ctx.task_ctx()).await.unwrap(); + let mut values = vec![]; + for batch in batches { + let column = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..column.len() { + values.push(column.value(i).to_string()); + } + } + values +} + +async fn write_parquet_file( + truncation_length: Option, + schema: Arc, + row_groups: Vec>, +) -> Bytes { + let mut buf = BytesMut::new().writer(); + let mut props = WriterProperties::builder(); + if let Some(truncation_length) = truncation_length { + props = { + #[allow(deprecated)] + props.set_max_statistics_size(truncation_length) + } + } + props = props.set_statistics_enabled(EnabledStatistics::Chunk); // row group level + let props = props.build(); + { + let mut writer = + ArrowWriter::try_new(&mut buf, schema.clone(), Some(props)).unwrap(); + for rg_values in row_groups.iter() { + let arr = StringArray::from_iter_values(rg_values.iter()); + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(arr)]).unwrap(); + writer.write(&batch).unwrap(); + writer.flush().unwrap(); // finishes the current row group and starts a new one + } + writer.finish().unwrap(); + } + buf.into_inner().freeze() +} + +/// The string values for [Utf8Test::values] +static VALUES: OnceLock> = OnceLock::new(); +/// The schema for the [Utf8Test::schema] +static SCHEMA: OnceLock> = OnceLock::new(); + +/// The InMemory object store +static MEMORY_STORE: OnceLock> = OnceLock::new(); + +/// List of in memory parquet files with UTF8 data +// Use a mutex rather than OnceLock to allow for async initialization +static TESTFILES: OnceLock>> = OnceLock::new(); + +/// Holds a temporary parquet file path and its size +#[derive(Debug, Clone)] +struct TestFile { + path: Path, + size: usize, +} diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index fae4731569b69..19ffa69f11d36 100644 --- 
a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -30,6 +30,7 @@ use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_execution::memory_pool::GreedyMemoryPool; use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use rand::Rng; use std::sync::Arc; use test_utils::{batches_to_vec, partitions_to_sorted_vec}; @@ -37,8 +38,8 @@ use test_utils::{batches_to_vec, partitions_to_sorted_vec}; const KB: usize = 1 << 10; #[tokio::test] #[cfg_attr(tarpaulin, ignore)] -async fn test_sort_1k_mem() { - for (batch_size, should_spill) in [(5, false), (20000, true), (1000000, true)] { +async fn test_sort_10k_mem() { + for (batch_size, should_spill) in [(5, false), (20000, true), (500000, true)] { SortTest::new() .with_int32_batches(batch_size) .with_pool_size(10 * KB) @@ -92,7 +93,7 @@ impl SortTest { self } - /// specify that this test should use a memory pool of the specifeid size + /// specify that this test should use a memory pool of the specified size fn with_pool_size(mut self, pool_size: usize) -> Self { self.pool_size = Some(pool_size); self @@ -114,13 +115,13 @@ impl SortTest { .expect("at least one batch"); let schema = first_batch.schema(); - let sort = vec![PhysicalSortExpr { + let sort = LexOrdering::new(vec![PhysicalSortExpr { expr: col("x", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]; + }]); let exec = MemoryExec::try_new(&input, schema, None).unwrap(); let sort = Arc::new(SortExec::new(sort, Arc::new(exec))); diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 0cd702372f7ce..daa282c8fe4a9 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -45,6 +45,7 @@ mod sp_repartition_fuzz_tests { }; use test_utils::add_empty_batches; + use datafusion_physical_expr_common::sort_expr::LexOrdering; use itertools::izip; use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; @@ -174,7 +175,7 @@ mod sp_repartition_fuzz_tests { }) .unzip(); - let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + let sort_arrs = lexsort(&sort_columns, None)?; for (idx, arr) in izip!(indices, sort_arrs) { schema_vec[idx] = Some(arr); } @@ -260,15 +261,15 @@ mod sp_repartition_fuzz_tests { for ordering in eq_properties.oeq_class().iter() { let err_msg = format!("error in eq properties: {:?}", eq_properties); - let sort_solumns = ordering + let sort_columns = ordering .iter() .map(|sort_expr| sort_expr.evaluate_to_sort_column(&res)) .collect::>>()?; - let orig_columns = sort_solumns + let orig_columns = sort_columns .iter() .map(|sort_column| sort_column.values.clone()) .collect::>(); - let sorted_columns = lexsort(&sort_solumns, None)?; + let sorted_columns = lexsort(&sort_columns, None)?; // Make sure after merging ordering is still valid. 
assert_eq!(orig_columns.len(), sorted_columns.len(), "{}", err_msg); @@ -345,7 +346,7 @@ mod sp_repartition_fuzz_tests { let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); - let mut sort_keys = vec![]; + let mut sort_keys = LexOrdering::default(); for ordering_col in ["a", "b", "c"] { sort_keys.push(PhysicalSortExpr { expr: col(ordering_col, &schema).unwrap(), @@ -358,7 +359,8 @@ mod sp_repartition_fuzz_tests { let running_source = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None) .unwrap() - .with_sort_information(vec![sort_keys.clone()]), + .try_with_sort_information(vec![sort_keys.clone()]) + .unwrap(), ); let hash_exprs = vec![col("c", &schema).unwrap()]; diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index a6c2cf700cc4e..979aa5a2da035 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -34,8 +34,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf; use datafusion_expr::{ - BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; @@ -45,7 +44,13 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use test_utils::add_empty_batches; use datafusion::functions_window::row_number::row_number_udwf; -use hashbrown::HashMap; +use datafusion_common::HashMap; +use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf}; +use datafusion_functions_window::nth_value::{ + first_value_udwf, last_value_udwf, nth_value_udwf, +}; +use datafusion_functions_window::rank::{dense_rank_udwf, rank_udwf}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use rand::distributions::Alphanumeric; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -196,7 +201,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lag), + WindowFunctionDefinition::WindowUDF(lag_udwf()), // its name "LAG", // no argument @@ -210,7 +215,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + WindowFunctionDefinition::WindowUDF(lead_udwf()), // its name "LEAD", // no argument @@ -224,9 +229,9 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Rank), + WindowFunctionDefinition::WindowUDF(rank_udwf()), // its name - "RANK", + "rank", // no argument vec![], // Expected causality, for None cases causality will be determined from window frame boundaries @@ -238,11 +243,9 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::DenseRank, - ), + WindowFunctionDefinition::WindowUDF(dense_rank_udwf()), // its name - "DENSE_RANK", + "dense_rank", // no argument vec![], // Expected causality, for None cases causality will be determined from window frame boundaries 
@@ -251,7 +254,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { ]; let partitionby_exprs = vec![]; - let orderby_exprs = vec![]; + let orderby_exprs = LexOrdering::default(); // Window frame starts with "UNBOUNDED PRECEDING": let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); @@ -284,7 +287,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { fn_name.to_string(), &args, &partitionby_exprs, - &orderby_exprs, + orderby_exprs.as_ref(), Arc::new(window_frame), &extended_schema, false, @@ -293,12 +296,10 @@ async fn bounded_window_causal_non_causal() -> Result<()> { vec![window_expr], memory_exec.clone(), vec![], - InputOrderMode::Linear, + Linear, )?); let task_ctx = ctx.task_ctx(); - let mut collected_results = - collect(running_window_exec, task_ctx).await?; - collected_results.retain(|batch| batch.num_rows() > 0); + let collected_results = collect(running_window_exec, task_ctx).await?; let input_batch_sizes = batches .iter() .map(|batch| batch.num_rows()) @@ -307,6 +308,8 @@ async fn bounded_window_causal_non_causal() -> Result<()> { .iter() .map(|batch| batch.num_rows()) .collect::>(); + // There should be no empty batches at results + assert!(result_batch_sizes.iter().all(|e| *e > 0)); if causal { // For causal window frames, we can generate results immediately // for each input batch. Hence, batch sizes should match. @@ -382,28 +385,19 @@ fn get_random_function( ); window_fn_map.insert( "rank", - ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::Rank, - ), - vec![], - ), + (WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![]), ); window_fn_map.insert( "dense_rank", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::DenseRank, - ), + WindowFunctionDefinition::WindowUDF(dense_rank_udwf()), vec![], ), ); window_fn_map.insert( "lead", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::Lead, - ), + WindowFunctionDefinition::WindowUDF(lead_udwf()), vec![ arg.clone(), lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), @@ -414,9 +408,7 @@ fn get_random_function( window_fn_map.insert( "lag", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::Lag, - ), + WindowFunctionDefinition::WindowUDF(lag_udwf()), vec![ arg.clone(), lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), @@ -428,27 +420,21 @@ fn get_random_function( window_fn_map.insert( "first_value", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::FirstValue, - ), + WindowFunctionDefinition::WindowUDF(first_value_udwf()), vec![arg.clone()], ), ); window_fn_map.insert( "last_value", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::LastValue, - ), + WindowFunctionDefinition::WindowUDF(last_value_udwf()), vec![arg.clone()], ), ); window_fn_map.insert( "nth_value", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::NthValue, - ), + WindowFunctionDefinition::WindowUDF(nth_value_udwf()), vec![ arg.clone(), lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), @@ -603,14 +589,14 @@ async fn run_window_test( orderby_columns: Vec<&str>, search_mode: InputOrderMode, ) -> Result<()> { - let is_linear = !matches!(search_mode, InputOrderMode::Sorted); + let is_linear = !matches!(search_mode, Sorted); let mut rng = StdRng::seed_from_u64(random_seed); let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); let 
(window_fn, args, fn_name) = get_random_function(&schema, &mut rng, is_linear); let window_frame = get_random_window_frame(&mut rng, is_linear); - let mut orderby_exprs = vec![]; + let mut orderby_exprs = LexOrdering::default(); for column in &orderby_columns { orderby_exprs.push(PhysicalSortExpr { expr: col(column, &schema)?, @@ -618,13 +604,13 @@ async fn run_window_test( }) } if orderby_exprs.len() > 1 && !window_frame.can_accept_multi_orderby() { - orderby_exprs = orderby_exprs[0..1].to_vec(); + orderby_exprs = LexOrdering::new(orderby_exprs[0..1].to_vec()); } let mut partitionby_exprs = vec![]; for column in &partition_by_columns { partitionby_exprs.push(col(column, &schema)?); } - let mut sort_keys = vec![]; + let mut sort_keys = LexOrdering::default(); for partition_by_expr in &partitionby_exprs { sort_keys.push(PhysicalSortExpr { expr: partition_by_expr.clone(), @@ -638,7 +624,7 @@ async fn run_window_test( } let concat_input_record = concat_batches(&schema, &input1)?; - let source_sort_keys = vec![ + let source_sort_keys = LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: Default::default(), @@ -651,10 +637,10 @@ async fn run_window_test( expr: col("c", &schema)?, options: Default::default(), }, - ]; + ]); let mut exec1 = Arc::new( MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(), None)? - .with_sort_information(vec![source_sort_keys.clone()]), + .try_with_sort_information(vec![source_sort_keys.clone()])?, ) as _; // Table is ordered according to ORDER BY a, b, c In linear test we use PARTITION BY b, ORDER BY a // For WindowAggExec to produce correct result it need table to be ordered by b,a. Hence add a sort. @@ -670,7 +656,7 @@ async fn run_window_test( fn_name.clone(), &args, &partitionby_exprs, - &orderby_exprs, + orderby_exprs.as_ref(), Arc::new(window_frame.clone()), &extended_schema, false, @@ -680,7 +666,7 @@ async fn run_window_test( )?) as _; let exec2 = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None)? - .with_sort_information(vec![source_sort_keys.clone()]), + .try_with_sort_information(vec![source_sort_keys.clone()])?, ); let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![create_window_expr( @@ -688,7 +674,7 @@ async fn run_window_test( fn_name, &args, &partitionby_exprs, - &orderby_exprs, + orderby_exprs.as_ref(), Arc::new(window_frame.clone()), &extended_schema, false, @@ -702,8 +688,8 @@ async fn run_window_test( let collected_running = collect(running_window_exec, task_ctx) .await? .into_iter() - .filter(|b| b.num_rows() > 0) .collect::>(); + assert!(collected_running.iter().all(|rb| rb.num_rows() > 0)); // BoundedWindowAggExec should produce more chunk than the usual WindowAggExec. // Otherwise it means that we cannot generate result in running mode. diff --git a/datafusion/core/tests/macro_hygiene/mod.rs b/datafusion/core/tests/macro_hygiene/mod.rs index 72ac6e64fb0c5..5aff1d5e32961 100644 --- a/datafusion/core/tests/macro_hygiene/mod.rs +++ b/datafusion/core/tests/macro_hygiene/mod.rs @@ -14,9 +14,11 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -//! Verifies [Macro Hygene] + +//! Verifies [Macro Hygiene] //! -//! [Macro Hygene]: https://en.wikipedia.org/wiki/Hygienic_macro +//! [Macro Hygiene]: https://en.wikipedia.org/wiki/Hygienic_macro + mod plan_err { // NO other imports! 
use datafusion_common::plan_err; @@ -37,3 +39,13 @@ mod plan_datafusion_err { plan_datafusion_err!("foo"); } } + +mod record_batch { + // NO other imports! + use datafusion_common::record_batch; + + #[test] + fn test_macro() { + record_batch!(("column_name", Int32, vec![1, 2, 3])).unwrap(); + } +} diff --git a/datafusion/core/tests/memory_limit/memory_limit_validation/mod.rs b/datafusion/core/tests/memory_limit/memory_limit_validation/mod.rs new file mode 100644 index 0000000000000..32df6c5d62937 --- /dev/null +++ b/datafusion/core/tests/memory_limit/memory_limit_validation/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Validates query's actual memory usage is consistent with the specified memory +//! limit. + +mod sort_mem_validation; +mod utils; diff --git a/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs b/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs new file mode 100644 index 0000000000000..1789f37535a94 --- /dev/null +++ b/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs @@ -0,0 +1,223 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Memory limit validation tests for the sort queries +//! +//! These tests must run in separate processes to accurately measure memory usage. +//! This file is organized as: +//! - Test runners that spawn individual test processes +//! 
- Test cases that contain the actual validation logic +use std::{process::Command, str}; + +use log::info; + +use crate::memory_limit::memory_limit_validation::utils; + +// =========================================================================== +// Test runners: +// Runners are splitted into multiple tests to run in parallel +// =========================================================================== + +#[test] +fn memory_limit_validation_runner_works_runner() { + spawn_test_process("memory_limit_validation_runner_works"); +} + +#[test] +fn sort_no_mem_limit_runner() { + spawn_test_process("sort_no_mem_limit"); +} + +#[test] +fn sort_with_mem_limit_1_runner() { + spawn_test_process("sort_with_mem_limit_1"); +} + +#[test] +fn sort_with_mem_limit_2_runner() { + spawn_test_process("sort_with_mem_limit_2"); +} + +#[test] +fn sort_with_mem_limit_3_runner() { + spawn_test_process("sort_with_mem_limit_3"); +} + +#[test] +fn sort_with_mem_limit_2_cols_1_runner() { + spawn_test_process("sort_with_mem_limit_2_cols_1"); +} + +#[test] +fn sort_with_mem_limit_2_cols_2_runner() { + spawn_test_process("sort_with_mem_limit_2_cols_2"); +} + +/// Helper function that executes a test in a separate process with the required environment +/// variable set. Memory limit validation tasks need to measure memory resident set +/// size (RSS), so they must run in a separate process. +fn spawn_test_process(test: &str) { + let test_path = format!( + "memory_limit::memory_limit_validation::sort_mem_validation::{}", + test + ); + info!("Running test: {}", test_path); + + // Run the test command + let output = Command::new("cargo") + .arg("test") + .arg("--package") + .arg("datafusion") + .arg("--test") + .arg("core_integration") + .arg("--features") + .arg("extended_tests") + .arg("--") + .arg(&test_path) + .arg("--exact") + .arg("--nocapture") + .env("DATAFUSION_TEST_MEM_LIMIT_VALIDATION", "1") + .output() + .expect("Failed to execute test command"); + + // Convert output to strings + let stdout = str::from_utf8(&output.stdout).unwrap_or(""); + let stderr = str::from_utf8(&output.stderr).unwrap_or(""); + + info!("{}", stdout); + + assert!( + output.status.success(), + "Test '{}' failed with status: {}\nstdout:\n{}\nstderr:\n{}", + test, + output.status, + stdout, + stderr + ); +} + +// =========================================================================== +// Test cases: +// All following tests need to be run through their individual test wrapper. +// When run directly, environment variable `DATAFUSION_TEST_MEM_LIMIT_VALIDATION` +// is not set, test will return with a no-op. +// +// If some tests consistently fail, suppress by setting a larger expected memory +// usage (e.g. 80_000_000 * 3 -> 80_000_000 * 4) +// =========================================================================== + +/// Test runner itself: if memory limit violated, test should fail. 
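Since each test case is a no-op unless DATAFUSION_TEST_MEM_LIMIT_VALIDATION is set, running one by hand means reproducing what spawn_test_process does. Reconstructed from the Command invocation above (not taken from project documentation), the equivalent manual invocation for, e.g., sort_no_mem_limit is roughly:

DATAFUSION_TEST_MEM_LIMIT_VALIDATION=1 cargo test --package datafusion --test core_integration --features extended_tests -- memory_limit::memory_limit_validation::sort_mem_validation::sort_no_mem_limit --exact --nocapture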
+#[tokio::test] +async fn memory_limit_validation_runner_works() { + if std::env::var("DATAFUSION_TEST_MEM_LIMIT_VALIDATION").is_err() { + println!("Skipping test because DATAFUSION_TEST_MEM_LIMIT_VALIDATION is not set"); + + return; + } + + let result = std::panic::catch_unwind(|| { + tokio::runtime::Runtime::new().unwrap().block_on(async { + utils::validate_query_with_memory_limits( + 20_000_000, // set an impossible limit: query requires at least 80MB + None, + "select * from generate_series(1,10000000) as t1(c1) order by c1", + "select * from generate_series(1,1000000) as t1(c1) order by c1", // Baseline query with ~10% of data + ) + .await; + }) + }); + + assert!( + result.is_err(), + "Expected the query to panic due to memory limit" + ); +} + +#[tokio::test] +async fn sort_no_mem_limit() { + utils::validate_query_with_memory_limits( + 80_000_000 * 3, + None, + "select * from generate_series(1,10000000) as t1(c1) order by c1", + "select * from generate_series(1,1000000) as t1(c1) order by c1", // Baseline query with ~10% of data + ) + .await; +} + +#[tokio::test] +async fn sort_with_mem_limit_1() { + utils::validate_query_with_memory_limits( + 40_000_000 * 5, + Some(40_000_000), + "select * from generate_series(1,10000000) as t1(c1) order by c1", + "select * from generate_series(1,1000000) as t1(c1) order by c1", // Baseline query with ~10% of data + ) + .await; +} + +#[tokio::test] +async fn sort_with_mem_limit_2() { + utils::validate_query_with_memory_limits( + 80_000_000 * 3, + Some(80_000_000), + "select * from generate_series(1,10000000) as t1(c1) order by c1", + "select * from generate_series(1,1000000) as t1(c1) order by c1", // Baseline query with ~10% of data + ) + .await; +} + +#[tokio::test] +async fn sort_with_mem_limit_3() { + utils::validate_query_with_memory_limits( + 80_000_000 * 3, + Some(80_000_000 * 10), // mem limit is large enough so that no spill happens + "select * from generate_series(1,10000000) as t1(c1) order by c1", + "select * from generate_series(1,1000000) as t1(c1) order by c1", // Baseline query with ~10% of data + ) + .await; +} + +#[tokio::test] +async fn sort_with_mem_limit_2_cols_1() { + let memory_usage_in_theory = 80_000_000 * 2; // 2 columns + let expected_max_mem_usage = memory_usage_in_theory * 4; + utils::validate_query_with_memory_limits( + expected_max_mem_usage, + None, + "select c1, c1 as c2 from generate_series(1,10000000) as t1(c1) order by c2 DESC, c1 ASC NULLS LAST", + "select c1, c1 as c2 from generate_series(1,1000000) as t1(c1) order by c2 DESC, c1 ASC NULLS LAST", // Baseline query with ~10% of data + ) + .await; +} + +// TODO: Query fails, fix it +// Issue: https://github.com/apache/datafusion/issues/14143 +#[ignore] +#[tokio::test] +async fn sort_with_mem_limit_2_cols_2() { + let memory_usage_in_theory = 80_000_000 * 2; // 2 columns + let expected_max_mem_usage = memory_usage_in_theory * 3; + let mem_limit = memory_usage_in_theory as f64 * 0.5; + + utils::validate_query_with_memory_limits( + expected_max_mem_usage, + Some(mem_limit as i64), + "select c1, c1 as c2 from generate_series(1,10000000) as t1(c1) order by c2 DESC, c1 ASC NULLS LAST", + "select c1, c1 as c2 from generate_series(1,1000000) as t1(c1) order by c2 DESC, c1 ASC NULLS LAST", // Baseline query with ~10% of data + ) + .await; +} diff --git a/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs b/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs new file mode 100644 index 0000000000000..bdf30c140afff --- /dev/null +++ 
b/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs @@ -0,0 +1,186 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common_runtime::SpawnedTask; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use sysinfo::System; +use tokio::time::{interval, Duration}; + +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_execution::{ + memory_pool::{human_readable_size, FairSpillPool}, + runtime_env::RuntimeEnvBuilder, +}; + +/// Measures the maximum RSS (in bytes) during the execution of an async task. RSS +/// will be sampled every 7ms. +/// +/// # Arguments +/// +/// * `f` - A closure that returns the async task to be measured. +/// +/// # Returns +/// +/// A tuple containing the result of the async task and the maximum RSS observed. +async fn measure_max_rss(f: F) -> (T, usize) +where + F: FnOnce() -> Fut, + Fut: std::future::Future, +{ + // Initialize system information + let mut system = System::new_all(); + system.refresh_all(); + + // Get the current process ID + let pid = sysinfo::get_current_pid().expect("Failed to get current PID"); + + // Shared atomic variable to store max RSS + let max_rss = Arc::new(AtomicUsize::new(0)); + + // Clone for the monitoring task + let max_rss_clone = Arc::clone(&max_rss); + + // Spawn a monitoring task + let monitor_handle = SpawnedTask::spawn(async move { + let mut sys = System::new_all(); + let mut interval = interval(Duration::from_millis(7)); + + loop { + interval.tick().await; + sys.refresh_all(); + if let Some(process) = sys.process(pid) { + let rss_bytes = process.memory(); + max_rss_clone + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| { + if rss_bytes as usize > current { + Some(rss_bytes as usize) + } else { + None + } + }) + .ok(); + } else { + // Process no longer exists + break; + } + } + }); + + // Execute the async task + let result = f().await; + + // Give some time for the monitor to catch the final memory state + tokio::time::sleep(Duration::from_millis(200)).await; + + // Terminate the monitoring task + drop(monitor_handle); + + // Retrieve the maximum RSS + let peak_rss = max_rss.load(Ordering::Relaxed); + + (result, peak_rss) +} + +/// Query runner that validates the memory usage of the query. +/// +/// Note this function is supposed to run in a separate process for accurate memory +/// estimation. If environment variable `DATAFUSION_TEST_MEM_LIMIT_VALIDATION` is +/// not set, this function will return immediately, so test cases calls this function +/// should first set the environment variable, then create a new process to run. +/// See `sort_mem_validation.rs` for more details. +/// +/// # Arguments +/// +/// * `expected_mem_bytes` - The maximum expected memory usage for the query. 
+/// * `mem_limit_bytes` - The memory limit of the query in bytes. `None` means no +/// memory limit is presented. +/// * `query` - The SQL query to execute +/// * `baseline_query` - The SQL query to execute for estimating constant overhead. +/// This query should use 10% of the data of the main query. +/// +/// # Example +/// +/// utils::validate_query_with_memory_limits( +/// 40_000_000 * 2, +/// Some(40_000_000), +/// "SELECT * FROM generate_series(1, 100000000) AS t(i) ORDER BY i", +/// "SELECT * FROM generate_series(1, 10000000) AS t(i) ORDER BY i" +/// ); +/// +/// The above function call means: +/// Set the memory limit to 40MB, and the profiled memory usage of {query - baseline_query} +/// should be less than 40MB * 2. +pub async fn validate_query_with_memory_limits( + expected_mem_bytes: i64, + mem_limit_bytes: Option, + query: &str, + baseline_query: &str, +) { + if std::env::var("DATAFUSION_TEST_MEM_LIMIT_VALIDATION").is_err() { + println!("Skipping test because DATAFUSION_TEST_MEM_LIMIT_VALIDATION is not set"); + + return; + } + + println!("Current process ID: {}", std::process::id()); + + let runtime_builder = RuntimeEnvBuilder::new(); + + let runtime = match mem_limit_bytes { + Some(mem_limit_bytes) => runtime_builder + .with_memory_pool(Arc::new(FairSpillPool::new(mem_limit_bytes as usize))) + .build_arc() + .unwrap(), + None => runtime_builder.build_arc().unwrap(), + }; + + let session_config = SessionConfig::new().with_target_partitions(4); // Make sure the configuration is the same if test is running on different machines + + let ctx = SessionContext::new_with_config_rt(session_config, runtime); + + let df = ctx.sql(query).await.unwrap(); + // Run a query with 10% data to estimate the constant overhead + let df_small = ctx.sql(baseline_query).await.unwrap(); + + let (_, baseline_max_rss) = + measure_max_rss(|| async { df_small.collect().await.unwrap() }).await; + + let (_, max_rss) = measure_max_rss(|| async { df.collect().await.unwrap() }).await; + + println!( + "Memory before: {}, Memory after: {}", + human_readable_size(baseline_max_rss), + human_readable_size(max_rss) + ); + + let actual_mem_usage = max_rss as f64 - baseline_max_rss as f64; + + println!( + "Query: {}, Memory usage: {}, Memory limit: {}", + query, + human_readable_size(actual_mem_usage as usize), + human_readable_size(expected_mem_bytes as usize) + ); + + assert!( + actual_mem_usage < expected_mem_bytes as f64, + "Memory usage exceeded the theoretical limit. 
Actual: {}, Expected limit: {}", + human_readable_size(actual_mem_usage as usize), + human_readable_size(expected_mem_bytes as usize) + ); +} diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index ec66df45c7baa..c7514d1c24b1b 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -31,10 +31,11 @@ use datafusion_execution::memory_pool::{ }; use datafusion_expr::{Expr, TableType}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::spill::get_record_batch_memory_size; use futures::StreamExt; use std::any::Any; use std::num::NonZeroUsize; -use std::sync::{Arc, OnceLock}; +use std::sync::{Arc, LazyLock}; use tokio::fs::File; use datafusion::datasource::streaming::StreamingTable; @@ -238,15 +239,15 @@ async fn sort_preserving_merge() { // SortPreservingMergeExec (not a Sort which would compete // with the SortPreservingMergeExec for memory) &[ - "+---------------+-----------------------------------------------------------------------------------------------------------+", - "| plan_type | plan |", - "+---------------+-----------------------------------------------------------------------------------------------------------+", - "| logical_plan | Sort: t.a ASC NULLS LAST, t.b ASC NULLS LAST, fetch=10 |", - "| | TableScan: t projection=[a, b] |", - "| physical_plan | SortPreservingMergeExec: [a@0 ASC NULLS LAST,b@1 ASC NULLS LAST], fetch=10 |", - "| | MemoryExec: partitions=2, partition_sizes=[5, 5], output_ordering=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST |", - "| | |", - "+---------------+-----------------------------------------------------------------------------------------------------------+", + "+---------------+------------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+------------------------------------------------------------------------------------------------------------+", + "| logical_plan | Sort: t.a ASC NULLS LAST, t.b ASC NULLS LAST, fetch=10 |", + "| | TableScan: t projection=[a, b] |", + "| physical_plan | SortPreservingMergeExec: [a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], fetch=10 |", + "| | MemoryExec: partitions=2, partition_sizes=[5, 5], output_ordering=a@0 ASC NULLS LAST, b@1 ASC NULLS LAST |", + "| | |", + "+---------------+------------------------------------------------------------------------------------------------------------+", ] ) .run() @@ -265,6 +266,10 @@ async fn sort_spill_reservation() { // This test case shows how sort_spill_reservation works by // purposely sorting data that requires non trivial memory to // sort/merge. + + // Merge operation needs extra memory to do row conversion, so make the + // memory limit larger. + let mem_limit = partition_size * 2; let test = TestCase::new() // This query uses a different order than the input table to // force a sort. 
It also needs to have multiple columns to @@ -272,7 +277,7 @@ async fn sort_spill_reservation() { // substantial memory .with_query("select * from t ORDER BY a , b DESC") // enough memory to sort if we don't try to merge it all at once - .with_memory_limit(partition_size) + .with_memory_limit(mem_limit) // use a single partition so only a sort is needed .with_scenario(scenario) .with_disk_manager_config(DiskManagerConfig::NewOs) @@ -281,15 +286,15 @@ async fn sort_spill_reservation() { // also merge, so we can ensure the sort could finish // given enough merging memory &[ - "+---------------+--------------------------------------------------------------------------------------------------------+", - "| plan_type | plan |", - "+---------------+--------------------------------------------------------------------------------------------------------+", - "| logical_plan | Sort: t.a ASC NULLS LAST, t.b DESC NULLS FIRST |", - "| | TableScan: t projection=[a, b] |", - "| physical_plan | SortExec: expr=[a@0 ASC NULLS LAST,b@1 DESC], preserve_partitioning=[false] |", - "| | MemoryExec: partitions=1, partition_sizes=[5], output_ordering=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST |", - "| | |", - "+---------------+--------------------------------------------------------------------------------------------------------+", + "+---------------+---------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+---------------------------------------------------------------------------------------------------------+", + "| logical_plan | Sort: t.a ASC NULLS LAST, t.b DESC NULLS FIRST |", + "| | TableScan: t projection=[a, b] |", + "| physical_plan | SortExec: expr=[a@0 ASC NULLS LAST, b@1 DESC], preserve_partitioning=[false] |", + "| | MemoryExec: partitions=1, partition_sizes=[5], output_ordering=a@0 ASC NULLS LAST, b@1 ASC NULLS LAST |", + "| | |", + "+---------------+---------------------------------------------------------------------------------------------------------+", ] ); @@ -311,7 +316,7 @@ async fn sort_spill_reservation() { // reserve sufficient space up front for merge and this time, // which will force the spills to happen with less buffered // input and thus with enough to merge. - .with_sort_spill_reservation_bytes(partition_size / 2); + .with_sort_spill_reservation_bytes(mem_limit / 2); test.with_config(config).with_expected_success().run().await; } @@ -654,7 +659,7 @@ impl Scenario { descending: false, nulls_first: false, }; - let sort_information = vec![vec![ + let sort_information = vec![LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema).unwrap(), options, @@ -663,7 +668,7 @@ impl Scenario { expr: col("b", &schema).unwrap(), options, }, - ]]; + ])]; let table = SortedTableProvider::new(batches, sort_information); Arc::new(table) @@ -725,15 +730,14 @@ fn maybe_split_batches( .collect() } -static DICT_BATCHES: OnceLock> = OnceLock::new(); - /// Returns 5 sorted string dictionary batches each with 50 rows with /// this schema. /// /// a: Dictionary, /// b: Dictionary, fn dict_batches() -> Vec { - DICT_BATCHES.get_or_init(make_dict_batches).clone() + static DICT_BATCHES: LazyLock> = LazyLock::new(make_dict_batches); + DICT_BATCHES.clone() } fn make_dict_batches() -> Vec { @@ -774,7 +778,7 @@ fn make_dict_batches() -> Vec { // How many bytes does the memory from dict_batches consume? 
fn batches_byte_size(batches: &[RecordBatch]) -> usize { - batches.iter().map(|b| b.get_array_memory_size()).sum() + batches.iter().map(get_record_batch_memory_size).sum() } #[derive(Debug)] @@ -840,7 +844,7 @@ impl TableProvider for SortedTableProvider { ) -> Result> { let mem_exec = MemoryExec::try_new(&self.batches, self.schema(), projection.cloned())? - .with_sort_information(self.sort_information.clone()); + .try_with_sort_information(self.sort_information.clone())?; Ok(Arc::new(mem_exec)) } diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 7c1e199ceb95a..dc57ba1e443a7 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -71,6 +71,7 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { range: None, statistics: None, extensions: Some(Arc::new(String::from(EXPECTED_USER_DEFINED_METADATA))), + metadata_size_hint: None, }) .collect(); diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs index 03afc858dfcaa..61a9e9b5757c8 100644 --- a/datafusion/core/tests/parquet/external_access_plan.rs +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -33,7 +33,8 @@ use datafusion_physical_plan::ExecutionPlan; use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; -use std::sync::{Arc, OnceLock}; +use std::path::Path; +use std::sync::Arc; use tempfile::NamedTempFile; #[tokio::test] @@ -160,7 +161,7 @@ async fn plan_and_filter() { RowGroupAccess::Scan, ])); - // initia + // initial let parquet_metrics = TestFull { access_plan, expected_rows: 0, @@ -273,7 +274,7 @@ struct Test { impl Test { /// Runs the test case, panic'ing on error. /// - /// Returns the `MetricsSet` from the ParqeutExec + /// Returns the [`MetricsSet`] from the [`ParquetExec`] async fn run_success(self) -> MetricsSet { let Self { access_plan, @@ -313,13 +314,20 @@ impl TestFull { } = self; let TestData { - temp_file: _, - schema, - file_name, - file_size, + _temp_file: _, + ref schema, + ref file_name, + ref file_size, } = get_test_data(); - let mut partitioned_file = PartitionedFile::new(file_name, *file_size); + let new_file_name = if cfg!(target_os = "windows") { + // Windows path separator is different from Unix + file_name.replace("\\", "/") + } else { + file_name.clone() + }; + + let mut partitioned_file = PartitionedFile::new(new_file_name, *file_size); // add the access plan, if any, as an extension if let Some(access_plan) = access_plan { @@ -355,59 +363,57 @@ impl TestFull { pretty_format_batches(&results).unwrap() ); + std::fs::remove_file(file_name).unwrap(); + Ok(MetricsFinder::find_metrics(plan.as_ref()).unwrap()) } } // Holds necessary data for these tests to reuse the same parquet file struct TestData { - // field is present as on drop the file is deleted - #[allow(dead_code)] - temp_file: NamedTempFile, + /// Pointer to temporary file storage. 
Keeping it in scope to prevent temporary folder + /// to be deleted prematurely + _temp_file: NamedTempFile, schema: SchemaRef, file_name: String, file_size: u64, } -static TEST_DATA: OnceLock = OnceLock::new(); - /// Return a parquet file with 2 row groups each with 5 rows -fn get_test_data() -> &'static TestData { - TEST_DATA.get_or_init(|| { - let scenario = Scenario::UTF8; - let row_per_group = 5; +fn get_test_data() -> TestData { + let scenario = Scenario::UTF8; + let row_per_group = 5; - let mut temp_file = tempfile::Builder::new() - .prefix("user_access_plan") - .suffix(".parquet") - .tempfile() - .expect("tempfile creation"); + let mut temp_file = tempfile::Builder::new() + .prefix("user_access_plan") + .suffix(".parquet") + .tempfile_in(Path::new("")) + .expect("tempfile creation"); - let props = WriterProperties::builder() - .set_max_row_group_size(row_per_group) - .build(); + let props = WriterProperties::builder() + .set_max_row_group_size(row_per_group) + .build(); - let batches = create_data_batch(scenario); - let schema = batches[0].schema(); + let batches = create_data_batch(scenario); + let schema = batches[0].schema(); - let mut writer = - ArrowWriter::try_new(&mut temp_file, schema.clone(), Some(props)).unwrap(); + let mut writer = + ArrowWriter::try_new(&mut temp_file, schema.clone(), Some(props)).unwrap(); - for batch in batches { - writer.write(&batch).expect("writing batch"); - } - writer.close().unwrap(); + for batch in batches { + writer.write(&batch).expect("writing batch"); + } + writer.close().unwrap(); - let file_name = temp_file.path().to_string_lossy().to_string(); - let file_size = temp_file.path().metadata().unwrap().len(); + let file_name = temp_file.path().to_string_lossy().to_string(); + let file_size = temp_file.path().metadata().unwrap().len(); - TestData { - temp_file, - schema, - file_name, - file_size, - } - }) + TestData { + _temp_file: temp_file, + schema, + file_name, + file_size, + } } /// Return the total value of the specified metric name diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 18d8300fb254d..4b5d22bfa71ff 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -28,7 +28,6 @@ use datafusion::execution::context::SessionState; use datafusion::prelude::SessionContext; use datafusion_common::stats::Precision; use datafusion_execution::cache::cache_manager::CacheManagerConfig; -use datafusion_execution::cache::cache_unit; use datafusion_execution::cache::cache_unit::{ DefaultFileStatisticsCache, DefaultListFilesCache, }; @@ -211,8 +210,8 @@ fn get_cache_runtime_state() -> ( SessionState, ) { let cache_config = CacheManagerConfig::default(); - let file_static_cache = Arc::new(cache_unit::DefaultFileStatisticsCache::default()); - let list_file_cache = Arc::new(cache_unit::DefaultListFilesCache::default()); + let file_static_cache = Arc::new(DefaultFileStatisticsCache::default()); + let list_file_cache = Arc::new(DefaultListFilesCache::default()); let cache_config = cache_config .with_files_statistics_cache(Some(file_static_cache.clone())) diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index 8def192f9331d..02fb59740493f 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -26,6 +26,8 @@ //! select * from data limit 10; //! 
``` +use std::path::Path; + use arrow::compute::concat_batches; use arrow::record_batch::RecordBatch; use datafusion::physical_plan::collect; @@ -67,7 +69,7 @@ fn generate_file(tempdir: &TempDir, props: WriterProperties) -> TestParquetFile async fn single_file() { // Only create the parquet file once as it is fairly large - let tempdir = TempDir::new().unwrap(); + let tempdir = TempDir::new_in(Path::new(".")).unwrap(); // Set row group size smaller so can test with fewer rows let props = WriterProperties::builder() .set_max_row_group_size(1024) @@ -223,7 +225,7 @@ async fn single_file() { #[tokio::test] async fn single_file_small_data_pages() { - let tempdir = TempDir::new().unwrap(); + let tempdir = TempDir::new_in(Path::new(".")).unwrap(); // Set low row count limit to improve page filtering let props = WriterProperties::builder() diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index cfa2a3df3ba23..f45eacce18df5 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -17,14 +17,14 @@ //! Parquet integration tests use crate::parquet::utils::MetricsFinder; -use arrow::array::Decimal128Array; use arrow::{ array::{ make_array, Array, ArrayRef, BinaryArray, Date32Array, Date64Array, - FixedSizeBinaryArray, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, LargeBinaryArray, LargeStringArray, StringArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Decimal128Array, DictionaryArray, FixedSizeBinaryArray, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, + StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, @@ -43,11 +43,8 @@ use std::sync::Arc; use tempfile::NamedTempFile; mod custom_reader; -// Don't run on windows as tempfiles don't seem to work the same -#[cfg(not(target_os = "windows"))] mod external_access_plan; mod file_statistics; -#[cfg(not(target_family = "windows"))] mod filter_pushdown; mod page_pruning; mod row_group_pruning; @@ -67,7 +64,7 @@ fn init() { // ---------------------- /// What data to use -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] enum Scenario { Timestamps, Dates, @@ -87,6 +84,7 @@ enum Scenario { WithNullValues, WithNullValuesPageLevel, UTF8, + Dictionary, } enum Unit { @@ -100,10 +98,9 @@ enum Unit { /// table "t" registered, pointing at a parquet file made with /// `make_test_file` struct ContextWithParquet { - #[allow(dead_code)] /// temp file parquet data is written to. 
The file is cleaned up /// when dropped - file: NamedTempFile, + _file: NamedTempFile, provider: Arc, ctx: SessionContext, } @@ -217,7 +214,7 @@ impl ContextWithParquet { ctx.register_table("t", provider.clone()).unwrap(); Self { - file, + _file: file, provider, ctx, } @@ -744,6 +741,54 @@ fn make_utf8_batch(value: Vec>) -> RecordBatch { .unwrap() } +fn make_dictionary_batch(strings: Vec<&str>, integers: Vec) -> RecordBatch { + let keys = Int32Array::from_iter(0..strings.len() as i32); + let small_keys = Int16Array::from_iter(0..strings.len() as i16); + + let utf8_values = StringArray::from(strings.clone()); + let utf8_dict = DictionaryArray::new(keys.clone(), Arc::new(utf8_values)); + + let large_utf8 = LargeStringArray::from(strings.clone()); + let large_utf8_dict = DictionaryArray::new(keys.clone(), Arc::new(large_utf8)); + + let binary = + BinaryArray::from_iter_values(strings.iter().cloned().map(|v| v.as_bytes())); + let binary_dict = DictionaryArray::new(keys.clone(), Arc::new(binary)); + + let large_binary = + LargeBinaryArray::from_iter_values(strings.iter().cloned().map(|v| v.as_bytes())); + let large_binary_dict = DictionaryArray::new(keys.clone(), Arc::new(large_binary)); + + let int32 = Int32Array::from_iter_values(integers.clone()); + let int32_dict = DictionaryArray::new(small_keys.clone(), Arc::new(int32)); + + let int64 = Int64Array::from_iter_values(integers.iter().cloned().map(|v| v as i64)); + let int64_dict = DictionaryArray::new(keys.clone(), Arc::new(int64)); + + let uint32 = + UInt32Array::from_iter_values(integers.iter().cloned().map(|v| v as u32)); + let uint32_dict = DictionaryArray::new(small_keys.clone(), Arc::new(uint32)); + + let decimal = Decimal128Array::from_iter_values( + integers.iter().cloned().map(|v| (v * 100) as i128), + ) + .with_precision_and_scale(6, 2) + .unwrap(); + let decimal_dict = DictionaryArray::new(keys.clone(), Arc::new(decimal)); + + RecordBatch::try_from_iter(vec![ + ("utf8", Arc::new(utf8_dict) as _), + ("large_utf8", Arc::new(large_utf8_dict) as _), + ("binary", Arc::new(binary_dict) as _), + ("large_binary", Arc::new(large_binary_dict) as _), + ("int32", Arc::new(int32_dict) as _), + ("int64", Arc::new(int64_dict) as _), + ("uint32", Arc::new(uint32_dict) as _), + ("decimal", Arc::new(decimal_dict) as _), + ]) + .unwrap() +} + fn create_data_batch(scenario: Scenario) -> Vec { match scenario { Scenario::Timestamps => { @@ -965,6 +1010,13 @@ fn create_data_batch(scenario: Scenario) -> Vec { ]), ] } + + Scenario::Dictionary => { + vec![ + make_dictionary_batch(vec!["a", "b", "c", "d", "e"], vec![0, 1, 2, 5, 6]), + make_dictionary_batch(vec!["f", "g", "h", "i", "j"], vec![0, 1, 3, 8, 9]), + ] + } } } diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 15efd4bcd9ddf..65bfd03401258 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -54,7 +54,7 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { }; let schema = ParquetFormat::default() - .infer_schema(state, &store, &[meta.clone()]) + .infer_schema(state, &store, std::slice::from_ref(&meta)) .await .unwrap(); @@ -64,6 +64,7 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { range: None, statistics: None, extensions: None, + metadata_size_hint: None, }; let df_schema = schema.clone().to_dfschema().unwrap(); @@ -149,8 +150,9 @@ async fn page_index_filter_one_col() { let session_ctx = SessionContext::new(); let 
task_ctx = session_ctx.task_ctx(); - // 5.create filter date_string_col == 1; - let filter = col("date_string_col").eq(lit("01/01/09")); + // 5.create filter date_string_col == "01/01/09"`; + // Note this test doesn't apply type coercion so the literal must match the actual view type + let filter = col("date_string_col").eq(lit(ScalarValue::new_utf8view("01/01/09"))); let parquet_exec = get_parquet_exec(&state, filter).await; let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); let batch = results.next().await.unwrap().unwrap(); diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 536ac5414a9a8..d8ce2970bdf74 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -1323,3 +1323,215 @@ async fn test_row_group_with_null_values() { .test_row_group_prune() .await; } + +#[tokio::test] +async fn test_bloom_filter_utf8_dict() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE utf8 = 'h'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE utf8 = 'ab'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE large_utf8 = 'b'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE large_utf8 = 'cd'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn test_bloom_filter_integer_dict() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE int32 = arrow_cast(8, 'Int32')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE int32 = arrow_cast(7, 'Int32')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE int64 = 8") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + 
.with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE int64 = 7") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn test_bloom_filter_unsigned_integer_dict() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE uint32 = arrow_cast(8, 'UInt32')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE uint32 = arrow_cast(7, 'UInt32')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn test_bloom_filter_binary_dict() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE binary = arrow_cast('b', 'Binary')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE binary = arrow_cast('banana', 'Binary')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE large_binary = arrow_cast('d', 'LargeBinary')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query( + "SELECT * FROM t WHERE large_binary = arrow_cast('dre', 'LargeBinary')", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +} + +// Makes sense to enable (or at least try to) after +// https://github.com/apache/datafusion/issues/13821 +#[ignore] +#[tokio::test] +async fn test_bloom_filter_decimal_dict() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE decimal = arrow_cast(8, 'Decimal128(6, 2)')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + 
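// --- Illustrative sketch (not part of the patch) ---------------------------
// The dictionary bloom-filter pruning tests above assume Parquet files whose
// dictionary-encoded columns carry bloom filters. A minimal, self-contained
// way to produce such a file with the `arrow`/`parquet` crates is sketched
// below; the helper name `write_dict_parquet` and the exact import paths are
// assumptions for a recent crate release, not code taken from this patch.
use std::{fs::File, sync::Arc};

use arrow::array::{DictionaryArray, Int32Array, StringArray};
use arrow::record_batch::RecordBatch;
use parquet::arrow::ArrowWriter;
use parquet::file::properties::WriterProperties;

fn write_dict_parquet(path: &str) -> Result<(), Box<dyn std::error::Error>> {
    // Dictionary-encoded Utf8 column, mirroring `make_dictionary_batch` above.
    let keys = Int32Array::from_iter(0..5);
    let values = StringArray::from(vec!["a", "b", "c", "d", "e"]);
    let utf8_dict = DictionaryArray::new(keys, Arc::new(values));
    let batch = RecordBatch::try_from_iter(vec![("utf8", Arc::new(utf8_dict) as _)])?;

    // Ask the writer to emit a bloom filter for every column chunk.
    let props = WriterProperties::builder()
        .set_bloom_filter_enabled(true)
        .build();
    let mut writer =
        ArrowWriter::try_new(File::create(path)?, batch.schema(), Some(props))?;
    writer.write(&batch)?;
    writer.close()?;
    Ok(())
}
// ---------------------------------------------------------------------------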
RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE decimal = arrow_cast(7, 'Decimal128(6, 2)')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +} diff --git a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs deleted file mode 100644 index bbf4dcd2b799d..0000000000000 --- a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs +++ /dev/null @@ -1,325 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Tests for the physical optimizer - -use datafusion_common::config::ConfigOptions; -use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics; -use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::aggregates::AggregateExec; -use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::ExecutionPlan; -use std::sync::Arc; - -use datafusion::error::Result; -use datafusion::logical_expr::Operator; -use datafusion::prelude::SessionContext; -use datafusion::test_util::TestAggregate; -use datafusion_physical_plan::aggregates::PhysicalGroupBy; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; -use datafusion_physical_plan::common; -use datafusion_physical_plan::filter::FilterExec; -use datafusion_physical_plan::memory::MemoryExec; - -use arrow::array::Int32Array; -use arrow::datatypes::{DataType, Field, Schema}; -use arrow::record_batch::RecordBatch; -use datafusion_common::cast::as_int64_array; -use datafusion_physical_expr::expressions::{self, cast}; -use datafusion_physical_plan::aggregates::AggregateMode; - -/// Mock data using a MemoryExec which has an exact count statistic -fn mock_data() -> Result> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ])); - - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![Some(1), Some(2), None])), - Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])), - ], - )?; - - Ok(Arc::new(MemoryExec::try_new( - &[vec![batch]], - Arc::clone(&schema), - None, - )?)) -} - -/// Checks that the count optimization was applied and we still get the right result -async fn assert_count_optim_success( - plan: AggregateExec, - agg: TestAggregate, -) -> Result<()> { - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); - let plan: Arc = Arc::new(plan); - - let optimized = - AggregateStatistics::new().optimize(Arc::clone(&plan), 
state.config_options())?; - - // A ProjectionExec is a sign that the count optimization was applied - assert!(optimized.as_any().is::()); - - // run both the optimized and nonoptimized plan - let optimized_result = - common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?; - let nonoptimized_result = - common::collect(plan.execute(0, session_ctx.task_ctx())?).await?; - assert_eq!(optimized_result.len(), nonoptimized_result.len()); - - // and validate the results are the same and expected - assert_eq!(optimized_result.len(), 1); - check_batch(optimized_result.into_iter().next().unwrap(), &agg); - // check the non optimized one too to ensure types and names remain the same - assert_eq!(nonoptimized_result.len(), 1); - check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg); - - Ok(()) -} - -fn check_batch(batch: RecordBatch, agg: &TestAggregate) { - let schema = batch.schema(); - let fields = schema.fields(); - assert_eq!(fields.len(), 1); - - let field = &fields[0]; - assert_eq!(field.name(), agg.column_name()); - assert_eq!(field.data_type(), &DataType::Int64); - // note that nullabiolity differs - - assert_eq!( - as_int64_array(batch.column(0)).unwrap().values(), - &[agg.expected_count()] - ); -} - -#[tokio::test] -async fn test_count_partial_direct_child() -> Result<()> { - // basic test case with the aggregation applied on a source with exact statistics - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_count_partial_with_nulls_direct_child() -> Result<()> { - // basic test case with the aggregation applied on a source with exact statistics - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_column(&schema); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_count_partial_indirect_child() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - // We introduce an intermediate optimization step between the partial and final aggregtator - let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(coalesce), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) -} - -#[tokio::test] -async fn 
test_count_partial_with_nulls_indirect_child() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_column(&schema); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - // We introduce an intermediate optimization step between the partial and final aggregtator - let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(coalesce), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_count_inexact_stat() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - // adding a filter makes the statistics inexact - let filter = Arc::new(FilterExec::try_new( - expressions::binary( - expressions::col("a", &schema)?, - Operator::Gt, - cast(expressions::lit(1u32), &schema, DataType::Int32)?, - &schema, - )?, - source, - )?); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - filter, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - let conf = ConfigOptions::new(); - let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; - - // check that the original ExecutionPlan was not replaced - assert!(optimized.as_any().is::()); - - Ok(()) -} - -#[tokio::test] -async fn test_count_with_nulls_inexact_stat() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_column(&schema); - - // adding a filter makes the statistics inexact - let filter = Arc::new(FilterExec::try_new( - expressions::binary( - expressions::col("a", &schema)?, - Operator::Gt, - cast(expressions::lit(1u32), &schema, DataType::Int32)?, - &schema, - )?, - source, - )?); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - filter, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - let conf = ConfigOptions::new(); - let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; - - // check that the original ExecutionPlan was not replaced - assert!(optimized.as_any().is::()); - - Ok(()) -} diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs index 24e46b3ad97c7..b8a96f0f5a222 100644 --- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. +//! Tests for [`CombinePartialFinalAggregate`] physical optimizer rule +//! +//! Note these tests are not in the same module as the optimizer pass because +//! 
they rely on `ParquetExec` which is in the core crate. use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -84,7 +88,7 @@ fn parquet_exec(schema: &SchemaRef) -> Arc { fn partial_aggregate_exec( input: Arc, group_by: PhysicalGroupBy, - aggr_expr: Vec, + aggr_expr: Vec>, ) -> Arc { let schema = input.schema(); let n_aggr = aggr_expr.len(); @@ -104,7 +108,7 @@ fn partial_aggregate_exec( fn final_aggregate_exec( input: Arc, group_by: PhysicalGroupBy, - aggr_expr: Vec, + aggr_expr: Vec>, ) -> Arc { let schema = input.schema(); let n_aggr = aggr_expr.len(); @@ -130,11 +134,12 @@ fn count_expr( expr: Arc, name: &str, schema: &Schema, -) -> AggregateFunctionExpr { +) -> Arc { AggregateExprBuilder::new(count_udaf(), vec![expr]) .schema(Arc::new(schema.clone())) .alias(name) .build() + .map(Arc::new) .unwrap() } @@ -218,6 +223,7 @@ fn aggregations_with_group_combined() -> datafusion_common::Result<()> { .schema(Arc::clone(&schema)) .alias("Sum(b)") .build() + .map(Arc::new) .unwrap(), ]; let groups: Vec<(Arc, String)> = diff --git a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs deleted file mode 100644 index 1b4c28d41d198..0000000000000 --- a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs +++ /dev/null @@ -1,490 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
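// --- Illustrative sketch (not part of the patch) ---------------------------
// The limit-pushdown tests removed below all follow one pattern: build a small
// plan, run the `LimitPushdown` rule, and compare formatted plan strings. A
// minimal, self-contained version of that pattern (using only APIs the removed
// file itself imported) looks roughly like this; the helper name
// `limit_pushdown_roundtrip` is invented for illustration.
use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_physical_optimizer::limit_pushdown::LimitPushdown;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::empty::EmptyExec;
use datafusion_physical_plan::limit::GlobalLimitExec;
use datafusion_physical_plan::{get_plan_string, ExecutionPlan};

fn limit_pushdown_roundtrip() -> datafusion_common::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)]));
    // GlobalLimitExec(skip=0, fetch=5) over an empty source.
    let plan: Arc<dyn ExecutionPlan> =
        Arc::new(GlobalLimitExec::new(Arc::new(EmptyExec::new(schema)), 0, Some(5)));

    let optimized = LimitPushdown::new().optimize(plan, &ConfigOptions::new())?;
    // Inspect the optimized plan the same way the removed tests did.
    println!("{:?}", get_plan_string(&optimized));
    Ok(())
}
// ---------------------------------------------------------------------------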
- -use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; -use datafusion_common::config::ConfigOptions; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::BinaryExpr; -use datafusion_physical_expr::expressions::{col, lit}; -use datafusion_physical_expr::Partitioning; -use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; -use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; -use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; -use datafusion_physical_plan::empty::EmptyExec; -use datafusion_physical_plan::filter::FilterExec; -use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; -use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::sorts::sort::SortExec; -use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; -use datafusion_physical_plan::{get_plan_string, ExecutionPlan, ExecutionPlanProperties}; -use std::sync::Arc; - -#[derive(Debug)] -struct DummyStreamPartition { - schema: SchemaRef, -} -impl PartitionStream for DummyStreamPartition { - fn schema(&self) -> &SchemaRef { - &self.schema - } - fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { - unreachable!() - } -} - -#[test] -fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero( -) -> datafusion_common::Result<()> { - let schema = create_schema(); - let streaming_table = streaming_table_exec(schema)?; - let global_limit = global_limit_exec(streaming_table, 0, Some(5)); - - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(initial, expected_initial); - - let after_optimize = - LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - - let expected = [ - "StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); - - Ok(()) -} - -#[test] -fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_limit_when_skip_is_nonzero( -) -> datafusion_common::Result<()> { - let schema = create_schema(); - let streaming_table = streaming_table_exec(schema)?; - let global_limit = global_limit_exec(streaming_table, 2, Some(5)); - - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=2, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(initial, expected_initial); - - let after_optimize = - LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - - let expected = [ - "GlobalLimitExec: skip=2, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); - - Ok(()) -} - -#[test] -fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limit( -) -> datafusion_common::Result<()> { - let schema = create_schema(); - let streaming_table = streaming_table_exec(schema.clone())?; - 
let repartition = repartition_exec(streaming_table)?; - let filter = filter_exec(schema, repartition)?; - let coalesce_batches = coalesce_batches_exec(filter); - let local_limit = local_limit_exec(coalesce_batches, 5); - let coalesce_partitions = coalesce_partitions_exec(local_limit); - let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); - - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " CoalescePartitionsExec", - " LocalLimitExec: fetch=5", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(initial, expected_initial); - - let after_optimize = - LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - - let expected = [ - "GlobalLimitExec: skip=0, fetch=5", - " CoalescePartitionsExec", - " CoalesceBatchesExec: target_batch_size=8192, fetch=5", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); - - Ok(()) -} - -#[test] -fn pushes_global_limit_exec_through_projection_exec() -> datafusion_common::Result<()> { - let schema = create_schema(); - let streaming_table = streaming_table_exec(schema.clone())?; - let filter = filter_exec(schema.clone(), streaming_table)?; - let projection = projection_exec(schema, filter)?; - let global_limit = global_limit_exec(projection, 0, Some(5)); - - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " FilterExec: c3@2 > 0", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(initial, expected_initial); - - let after_optimize = - LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - - let expected = [ - "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " GlobalLimitExec: skip=0, fetch=5", - " FilterExec: c3@2 > 0", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); - - Ok(()) -} - -#[test] -fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batches_exec_into_fetching_version( -) -> datafusion_common::Result<()> { - let schema = create_schema(); - let streaming_table = streaming_table_exec(schema.clone()).unwrap(); - let coalesce_batches = coalesce_batches_exec(streaming_table); - let projection = projection_exec(schema, coalesce_batches)?; - let global_limit = global_limit_exec(projection, 0, Some(5)); - - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - - assert_eq!(initial, expected_initial); - - let after_optimize = - LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - - let expected = [ - "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192, fetch=5", - " StreamingTableExec: 
partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); - - Ok(()) -} - -#[test] -fn pushes_global_limit_into_multiple_fetch_plans() -> datafusion_common::Result<()> { - let schema = create_schema(); - let streaming_table = streaming_table_exec(schema.clone()).unwrap(); - let coalesce_batches = coalesce_batches_exec(streaming_table); - let projection = projection_exec(schema.clone(), coalesce_batches)?; - let repartition = repartition_exec(projection)?; - let sort = sort_exec( - vec![PhysicalSortExpr { - expr: col("c1", &schema)?, - options: SortOptions::default(), - }], - repartition, - ); - let spm = sort_preserving_merge_exec(sort.output_ordering().unwrap().to_vec(), sort); - let global_limit = global_limit_exec(spm, 0, Some(5)); - - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " SortPreservingMergeExec: [c1@0 ASC]", - " SortExec: expr=[c1@0 ASC], preserve_partitioning=[false]", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - - assert_eq!(initial, expected_initial); - - let after_optimize = - LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - - let expected = [ - "SortPreservingMergeExec: [c1@0 ASC], fetch=5", - " SortExec: TopK(fetch=5), expr=[c1@0 ASC], preserve_partitioning=[false]", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); - - Ok(()) -} - -#[test] -fn keeps_pushed_local_limit_exec_when_there_are_multiple_input_partitions( -) -> datafusion_common::Result<()> { - let schema = create_schema(); - let streaming_table = streaming_table_exec(schema.clone())?; - let repartition = repartition_exec(streaming_table)?; - let filter = filter_exec(schema, repartition)?; - let coalesce_partitions = coalesce_partitions_exec(filter); - let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); - - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " CoalescePartitionsExec", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(initial, expected_initial); - - let after_optimize = - LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - - let expected = [ - "GlobalLimitExec: skip=0, fetch=5", - " CoalescePartitionsExec", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); - - Ok(()) -} - -#[test] -fn merges_local_limit_with_local_limit() -> datafusion_common::Result<()> { - let schema = create_schema(); - let empty_exec = empty_exec(schema); - let child_local_limit = local_limit_exec(empty_exec, 10); - let parent_local_limit = local_limit_exec(child_local_limit, 20); - - let 
initial = get_plan_string(&parent_local_limit); - let expected_initial = [ - "LocalLimitExec: fetch=20", - " LocalLimitExec: fetch=10", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); - - let after_optimize = - LimitPushdown::new().optimize(parent_local_limit, &ConfigOptions::new())?; - - let expected = ["GlobalLimitExec: skip=0, fetch=10", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); - - Ok(()) -} - -#[test] -fn merges_global_limit_with_global_limit() -> datafusion_common::Result<()> { - let schema = create_schema(); - let empty_exec = empty_exec(schema); - let child_global_limit = global_limit_exec(empty_exec, 10, Some(30)); - let parent_global_limit = global_limit_exec(child_global_limit, 10, Some(20)); - - let initial = get_plan_string(&parent_global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=10, fetch=20", - " GlobalLimitExec: skip=10, fetch=30", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); - - let after_optimize = - LimitPushdown::new().optimize(parent_global_limit, &ConfigOptions::new())?; - - let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); - - Ok(()) -} - -#[test] -fn merges_global_limit_with_local_limit() -> datafusion_common::Result<()> { - let schema = create_schema(); - let empty_exec = empty_exec(schema); - let local_limit = local_limit_exec(empty_exec, 40); - let global_limit = global_limit_exec(local_limit, 20, Some(30)); - - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=20, fetch=30", - " LocalLimitExec: fetch=40", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); - - let after_optimize = - LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - - let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); - - Ok(()) -} - -#[test] -fn merges_local_limit_with_global_limit() -> datafusion_common::Result<()> { - let schema = create_schema(); - let empty_exec = empty_exec(schema); - let global_limit = global_limit_exec(empty_exec, 20, Some(30)); - let local_limit = local_limit_exec(global_limit, 20); - - let initial = get_plan_string(&local_limit); - let expected_initial = [ - "LocalLimitExec: fetch=20", - " GlobalLimitExec: skip=20, fetch=30", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); - - let after_optimize = - LimitPushdown::new().optimize(local_limit, &ConfigOptions::new())?; - - let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); - - Ok(()) -} - -fn create_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Int32, true), - Field::new("c3", DataType::Int32, true), - ])) -} - -fn streaming_table_exec( - schema: SchemaRef, -) -> datafusion_common::Result> { - Ok(Arc::new(StreamingTableExec::try_new( - schema.clone(), - vec![Arc::new(DummyStreamPartition { schema }) as _], - None, - None, - true, - None, - )?)) -} - -fn global_limit_exec( - input: Arc, - skip: usize, - fetch: Option, -) -> Arc { - Arc::new(GlobalLimitExec::new(input, skip, fetch)) -} - -fn local_limit_exec( - input: Arc, - fetch: usize, -) -> Arc { - Arc::new(LocalLimitExec::new(input, fetch)) -} - -fn sort_exec( - sort_exprs: impl IntoIterator, - input: Arc, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - 
Arc::new(SortExec::new(sort_exprs, input)) -} - -fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, - input: Arc, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) -} - -fn projection_exec( - schema: SchemaRef, - input: Arc, -) -> datafusion_common::Result> { - Ok(Arc::new(ProjectionExec::try_new( - vec![ - (col("c1", schema.as_ref()).unwrap(), "c1".to_string()), - (col("c2", schema.as_ref()).unwrap(), "c2".to_string()), - (col("c3", schema.as_ref()).unwrap(), "c3".to_string()), - ], - input, - )?)) -} - -fn filter_exec( - schema: SchemaRef, - input: Arc, -) -> datafusion_common::Result> { - Ok(Arc::new(FilterExec::try_new( - Arc::new(BinaryExpr::new( - col("c3", schema.as_ref()).unwrap(), - Operator::Gt, - lit(0), - )), - input, - )?)) -} - -fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 8192)) -} - -fn coalesce_partitions_exec( - local_limit: Arc, -) -> Arc { - Arc::new(CoalescePartitionsExec::new(local_limit)) -} - -fn repartition_exec( - streaming_table: Arc, -) -> datafusion_common::Result> { - Ok(Arc::new(RepartitionExec::try_new( - streaming_table, - Partitioning::RoundRobinBatch(8), - )?)) -} - -fn empty_exec(schema: SchemaRef) -> Arc { - Arc::new(EmptyExec::new(schema)) -} diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index 042f6d622565c..46a56fe1fb756 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! Tests for the limited distinct aggregation optimizer rule - +//! Tests for [`LimitedDistinctAggregation`] physical optimizer rule +//! +//! Note these tests are not in the same module as the optimizer pass because +//! they rely on `ParquetExec` which is in the core crate. 
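// --- Illustrative sketch (not part of the patch) ---------------------------
// The hunks below switch the test helpers from a bare `Vec<PhysicalSortExpr>`
// to the `LexOrdering` wrapper and wrap aggregate expressions in `Arc`. A
// minimal example of the new ordering construction is sketched here; the
// `LexOrdering` import path matches the one added below, the other paths and
// the column/function names are assumptions chosen for illustration.
use arrow_schema::{DataType, Field, Schema, SortOptions};
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::PhysicalSortExpr;
use datafusion_physical_expr_common::sort_expr::LexOrdering;

fn example_ordering() -> LexOrdering {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    // One ascending sort key on column "a", wrapped in the LexOrdering type
    // that the updated helpers now expect.
    LexOrdering::new(vec![PhysicalSortExpr {
        expr: col("a", &schema).unwrap(),
        options: SortOptions::default(),
    }])
}
// ---------------------------------------------------------------------------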
use super::test_util::{parquet_exec_with_sort, schema, trim_plan_display}; use std::sync::Arc; @@ -37,6 +39,7 @@ use datafusion_physical_expr::{ expressions::{cast, col}, PhysicalExpr, PhysicalSortExpr, }; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_optimizer::{ limited_distinct_aggregation::LimitedDistinctAggregation, PhysicalOptimizerRule, }; @@ -347,10 +350,10 @@ fn test_has_aggregate_expression() -> Result<()> { let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema, vec!["a".to_string()]), - vec![agg.count_expr(&schema)], /* aggr_expr */ - vec![None], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), @@ -375,7 +378,7 @@ fn test_has_filter() -> Result<()> { // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec // the `a > 1` filter is applied in the AggregateExec let filter_expr = Some(expressions::binary( - expressions::col("a", &schema)?, + col("a", &schema)?, Operator::Gt, cast(expressions::lit(1u32), &schema, DataType::Int32)?, &schema, @@ -384,10 +387,10 @@ fn test_has_filter() -> Result<()> { let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![agg.count_expr(&schema)], /* aggr_expr */ - vec![filter_expr], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */ + vec![filter_expr], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), @@ -407,10 +410,10 @@ fn test_has_filter() -> Result<()> { #[test] fn test_has_order_by() -> Result<()> { - let sort_key = vec![PhysicalSortExpr { - expr: expressions::col("a", &schema()).unwrap(), + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema()).unwrap(), options: SortOptions::default(), - }]; + }]); let source = parquet_exec_with_sort(vec![sort_key]); let schema = source.schema(); diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs index 4ec981bf2a741..1fac68e2505c5 100644 --- a/datafusion/core/tests/physical_optimizer/mod.rs +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -mod aggregate_statistics; mod combine_partial_final_agg; -mod limit_pushdown; mod limited_distinct_aggregation; -mod test_util; +mod sanity_checker; +pub(crate) mod test_util; diff --git a/datafusion/core/tests/physical_optimizer/sanity_checker.rs b/datafusion/core/tests/physical_optimizer/sanity_checker.rs new file mode 100644 index 0000000000000..538f0e443ddb5 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/sanity_checker.rs @@ -0,0 +1,536 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for [`SanityCheckPlan`] physical optimizer rule +//! +//! Note these tests are not in the same module as the optimizer pass because +//! they rely on `ParquetExec` which is in the core crate. + +use crate::physical_optimizer::test_util::{ + BinaryTestCase, QueryCase, SourceType, UnaryTestCase, +}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::Result; +use datafusion_expr::JoinType; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_optimizer::test_utils::{ + bounded_window_exec, global_limit_exec, local_limit_exec, memory_exec, + repartition_exec, sort_exec, sort_expr_options, sort_merge_join_exec, +}; +use datafusion_physical_optimizer::{sanity_checker::*, PhysicalOptimizerRule}; +use datafusion_physical_plan::displayable; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::ExecutionPlan; +use std::sync::Arc; + +fn create_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("c9", DataType::Int32, true)])) +} + +fn create_test_schema2() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])) +} + +/// Check if sanity checker should accept or reject plans. +fn assert_sanity_check(plan: &Arc, is_sane: bool) { + let sanity_checker = SanityCheckPlan::new(); + let opts = ConfigOptions::default(); + assert_eq!( + sanity_checker.optimize(plan.clone(), &opts).is_ok(), + is_sane + ); +} + +/// Check if the plan we created is as expected by comparing the plan +/// formatted as a string. +fn assert_plan(plan: &dyn ExecutionPlan, expected_lines: Vec<&str>) { + let plan_str = displayable(plan).indent(true).to_string(); + let actual_lines: Vec<&str> = plan_str.trim().lines().collect(); + assert_eq!(actual_lines, expected_lines); +} + +#[tokio::test] +async fn test_hash_left_join_swap() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: false, + }; + + let test2 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + // Left join for bounded build side and unbounded probe side can generate + // both incremental matched rows and final non-matched rows. 
+ expect_fail: false, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 LEFT JOIN right as t2 ON t1.c1 = t2.c1" + .to_string(), + cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], + error_operator: "operator: HashJoinExec".to_string(), + }; + + case.run().await?; + Ok(()) +} + +#[tokio::test] +async fn test_hash_right_join_swap() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: true, + }; + let test2 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: false, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 RIGHT JOIN right as t2 ON t1.c1 = t2.c1" + .to_string(), + cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], + error_operator: "operator: HashJoinExec".to_string(), + }; + + case.run().await?; + Ok(()) +} + +#[tokio::test] +async fn test_hash_inner_join_swap() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: false, + }; + let test2 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: false, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 JOIN right as t2 ON t1.c1 = t2.c1".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], + error_operator: "Join Error".to_string(), + }; + + case.run().await?; + Ok(()) +} + +#[tokio::test] +async fn test_hash_full_outer_join_swap() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: true, + }; + let test2 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + // Full join for bounded build side and unbounded probe side can generate + // both incremental matched rows and final non-matched rows. 
+ expect_fail: false, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 FULL JOIN right as t2 ON t1.c1 = t2.c1" + .to_string(), + cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], + error_operator: "operator: HashJoinExec".to_string(), + }; + + case.run().await?; + Ok(()) +} + +#[tokio::test] +async fn test_aggregate() -> Result<()> { + let test1 = UnaryTestCase { + source_type: SourceType::Bounded, + expect_fail: false, + }; + let test2 = UnaryTestCase { + source_type: SourceType::Unbounded, + expect_fail: true, + }; + let case = QueryCase { + sql: "SELECT c1, MIN(c4) FROM test GROUP BY c1".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2)], + error_operator: "operator: AggregateExec".to_string(), + }; + + case.run().await?; + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_hash_partition() -> Result<()> { + let test1 = UnaryTestCase { + source_type: SourceType::Bounded, + expect_fail: false, + }; + let test2 = UnaryTestCase { + source_type: SourceType::Unbounded, + expect_fail: true, + }; + let case = QueryCase { + sql: "SELECT + c9, + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 + FROM test + LIMIT 5".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2)], + error_operator: "operator: SortExec".to_string() + }; + + case.run().await?; + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_single_partition() -> Result<()> { + let test1 = UnaryTestCase { + source_type: SourceType::Bounded, + expect_fail: false, + }; + let test2 = UnaryTestCase { + source_type: SourceType::Unbounded, + expect_fail: true, + }; + let case = QueryCase { + sql: "SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 + FROM test".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2)], + error_operator: "operator: SortExec".to_string() + }; + case.run().await?; + Ok(()) +} + +#[tokio::test] +async fn test_hash_cross_join() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: true, + }; + let test2 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Unbounded), + expect_fail: true, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: true, + }; + let test4 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 CROSS JOIN right as t2".to_string(), + cases: vec![ + Arc::new(test1), + Arc::new(test2), + Arc::new(test3), + Arc::new(test4), + ], + error_operator: "operator: CrossJoinExec".to_string(), + }; + + case.run().await?; + Ok(()) +} + +#[tokio::test] +async fn test_analyzer() -> Result<()> { + let test1 = UnaryTestCase { + source_type: SourceType::Bounded, + expect_fail: false, + }; + let test2 = UnaryTestCase { + source_type: SourceType::Unbounded, + expect_fail: false, + }; + let case = QueryCase { + sql: "EXPLAIN ANALYZE SELECT * FROM test".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2)], + error_operator: "Analyze Error".to_string(), + }; + + case.run().await?; + Ok(()) +} + +#[tokio::test] +/// Tests that plan is valid when the sort requirements are satisfied. 
+async fn test_bounded_window_agg_sort_requirement() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let sort_exprs = vec![sort_expr_options( + "c9", + &source.schema(), + SortOptions { + descending: false, + nulls_first: false, + }, + )]; + let sort = sort_exec(sort_exprs.clone(), source); + let bw = bounded_window_exec("c9", sort_exprs, sort); + assert_plan(bw.as_ref(), vec![ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]" + ]); + assert_sanity_check(&bw, true); + Ok(()) +} + +#[tokio::test] +/// Tests that plan is invalid when the sort requirements are not satisfied. +async fn test_bounded_window_agg_no_sort_requirement() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let sort_exprs = vec![sort_expr_options( + "c9", + &source.schema(), + SortOptions { + descending: false, + nulls_first: false, + }, + )]; + let bw = bounded_window_exec("c9", sort_exprs, source); + assert_plan(bw.as_ref(), vec![ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " MemoryExec: partitions=1, partition_sizes=[0]" + ]); + // Order requirement of the `BoundedWindowAggExec` is not satisfied. We expect to receive error during sanity check. + assert_sanity_check(&bw, false); + Ok(()) +} + +#[tokio::test] +/// A valid when a single partition requirement +/// is satisfied. +async fn test_global_limit_single_partition() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let limit = global_limit_exec(source); + + assert_plan( + limit.as_ref(), + vec![ + "GlobalLimitExec: skip=0, fetch=100", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + assert_sanity_check(&limit, true); + Ok(()) +} + +#[tokio::test] +/// An invalid plan when a single partition requirement +/// is not satisfied. +async fn test_global_limit_multi_partition() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let limit = global_limit_exec(repartition_exec(source)); + + assert_plan( + limit.as_ref(), + vec![ + "GlobalLimitExec: skip=0, fetch=100", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + // Distribution requirement of the `GlobalLimitExec` is not satisfied. We expect to receive error during sanity check. + assert_sanity_check(&limit, false); + Ok(()) +} + +#[tokio::test] +/// A plan with no requirements should satisfy. +async fn test_local_limit() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let limit = local_limit_exec(source); + + assert_plan( + limit.as_ref(), + vec![ + "LocalLimitExec: fetch=100", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + assert_sanity_check(&limit, true); + Ok(()) +} + +#[tokio::test] +/// Valid plan with multiple children satisfy both order and distribution. 
+async fn test_sort_merge_join_satisfied() -> Result<()> { + let schema1 = create_test_schema(); + let schema2 = create_test_schema2(); + let source1 = memory_exec(&schema1); + let source2 = memory_exec(&schema2); + let sort_opts = SortOptions::default(); + let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; + let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; + let left = sort_exec(sort_exprs1, source1); + let right = sort_exec(sort_exprs2, source2); + let left_jcol = col("c9", &left.schema()).unwrap(); + let right_jcol = col("a", &right.schema()).unwrap(); + let left = Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(vec![left_jcol.clone()], 10), + )?); + + let right = Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(vec![right_jcol.clone()], 10), + )?); + + let join_on = vec![(left_jcol as _, right_jcol as _)]; + let join_ty = JoinType::Inner; + let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); + + assert_plan( + smj.as_ref(), + vec![ + "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", + " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", + " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + assert_sanity_check(&smj, true); + Ok(()) +} + +#[tokio::test] +/// Invalid case when the order is not satisfied by the 2nd +/// child. +async fn test_sort_merge_join_order_missing() -> Result<()> { + let schema1 = create_test_schema(); + let schema2 = create_test_schema2(); + let source1 = memory_exec(&schema1); + let right = memory_exec(&schema2); + let sort_exprs1 = vec![sort_expr_options( + "c9", + &source1.schema(), + SortOptions::default(), + )]; + let left = sort_exec(sort_exprs1, source1); + // Missing sort of the right child here.. + let left_jcol = col("c9", &left.schema()).unwrap(); + let right_jcol = col("a", &right.schema()).unwrap(); + let left = Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(vec![left_jcol.clone()], 10), + )?); + + let right = Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(vec![right_jcol.clone()], 10), + )?); + + let join_on = vec![(left_jcol as _, right_jcol as _)]; + let join_ty = JoinType::Inner; + let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); + + assert_plan( + smj.as_ref(), + vec![ + "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", + " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", + " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + // Order requirement for the `SortMergeJoin` is not satisfied for right child. We expect to receive error during sanity check. + assert_sanity_check(&smj, false); + Ok(()) +} + +#[tokio::test] +/// Invalid case when the distribution is not satisfied by the 2nd +/// child. 
+async fn test_sort_merge_join_dist_missing() -> Result<()> { + let schema1 = create_test_schema(); + let schema2 = create_test_schema2(); + let source1 = memory_exec(&schema1); + let source2 = memory_exec(&schema2); + let sort_opts = SortOptions::default(); + let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; + let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; + let left = sort_exec(sort_exprs1, source1); + let right = sort_exec(sort_exprs2, source2); + let right = Arc::new(RepartitionExec::try_new( + right, + Partitioning::RoundRobinBatch(10), + )?); + let left_jcol = col("c9", &left.schema()).unwrap(); + let right_jcol = col("a", &right.schema()).unwrap(); + let left = Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(vec![left_jcol.clone()], 10), + )?); + + // Missing hash partitioning on right child. + + let join_on = vec![(left_jcol as _, right_jcol as _)]; + let join_ty = JoinType::Inner; + let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); + + assert_plan( + smj.as_ref(), + vec![ + "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", + " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", + " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + // Distribution requirement for the `SortMergeJoin` is not satisfied for right child (has round-robin partitioning). We expect to receive error during sanity check. + assert_sanity_check(&smj, false); + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/test_util.rs b/datafusion/core/tests/physical_optimizer/test_util.rs index 131b887c4ec72..ea4b80a7899c0 100644 --- a/datafusion/core/tests/physical_optimizer/test_util.rs +++ b/datafusion/core/tests/physical_optimizer/test_util.rs @@ -19,17 +19,22 @@ use std::sync::Arc; +use async_trait::async_trait; +use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; +use datafusion::error::Result; +use datafusion::prelude::{CsvReadOptions, SessionContext}; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::{ listing::PartitionedFile, physical_plan::{FileScanConfig, ParquetExec}, }; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; /// create a single parquet file that is sorted pub(crate) fn parquet_exec_with_sort( - output_ordering: Vec>, + output_ordering: Vec, ) -> Arc { ParquetExec::builder( FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema()) @@ -55,3 +60,117 @@ pub(crate) fn trim_plan_display(plan: &str) -> Vec<&str> { .filter(|s| !s.is_empty()) .collect() } + +async fn register_current_csv( + ctx: &SessionContext, + table_name: &str, + infinite: bool, +) -> Result<()> { + let testdata = datafusion::test_util::arrow_test_data(); + let schema = datafusion::test_util::aggr_test_schema(); + let path = format!("{testdata}/csv/aggregate_test_100.csv"); + + match infinite { + true => { + let source = FileStreamProvider::new_file(schema, path.into()); + let config = StreamConfig::new(Arc::new(source)); + ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; + } + false => { + ctx.register_csv(table_name, &path, 
CsvReadOptions::new().schema(&schema)) + .await?; + } + } + + Ok(()) +} + +#[derive(Eq, PartialEq, Debug)] +pub enum SourceType { + Unbounded, + Bounded, +} + +#[async_trait] +pub trait SqlTestCase { + async fn register_table(&self, ctx: &SessionContext) -> Result<()>; + fn expect_fail(&self) -> bool; +} + +/// [UnaryTestCase] is designed for single input [ExecutionPlan]s. +pub struct UnaryTestCase { + pub source_type: SourceType, + pub expect_fail: bool, +} + +#[async_trait] +impl SqlTestCase for UnaryTestCase { + async fn register_table(&self, ctx: &SessionContext) -> Result<()> { + let table_is_infinite = self.source_type == SourceType::Unbounded; + register_current_csv(ctx, "test", table_is_infinite).await?; + Ok(()) + } + + fn expect_fail(&self) -> bool { + self.expect_fail + } +} +/// [BinaryTestCase] is designed for binary input [ExecutionPlan]s. +pub struct BinaryTestCase { + pub source_types: (SourceType, SourceType), + pub expect_fail: bool, +} + +#[async_trait] +impl SqlTestCase for BinaryTestCase { + async fn register_table(&self, ctx: &SessionContext) -> Result<()> { + let left_table_is_infinite = self.source_types.0 == SourceType::Unbounded; + let right_table_is_infinite = self.source_types.1 == SourceType::Unbounded; + register_current_csv(ctx, "left", left_table_is_infinite).await?; + register_current_csv(ctx, "right", right_table_is_infinite).await?; + Ok(()) + } + + fn expect_fail(&self) -> bool { + self.expect_fail + } +} + +pub struct QueryCase { + pub sql: String, + pub cases: Vec>, + pub error_operator: String, +} + +impl QueryCase { + /// Run the test cases + pub async fn run(&self) -> Result<()> { + for case in &self.cases { + let ctx = SessionContext::new(); + case.register_table(&ctx).await?; + let error = if case.expect_fail() { + Some(&self.error_operator) + } else { + None + }; + self.run_case(ctx, error).await?; + } + Ok(()) + } + async fn run_case(&self, ctx: SessionContext, error: Option<&String>) -> Result<()> { + let dataframe = ctx.sql(self.sql.as_str()).await?; + let plan = dataframe.create_physical_plan().await; + if let Some(error) = error { + let plan_error = plan.unwrap_err(); + assert!( + plan_error.to_string().contains(error.as_str()), + "plan_error: {:?} doesn't contain message: {:?}", + plan_error, + error.as_str() + ); + } else { + assert!(plan.is_ok()) + } + Ok(()) + } +} diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 1f10cb244e83c..7b1f349e15b5a 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -36,7 +36,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { *actual[0].schema(), Schema::new(vec![Field::new_list( "array_agg(DISTINCT aggregate_test_100.c2)", - Field::new("item", DataType::UInt32, true), + Field::new_list_field(DataType::UInt32, true), true ),]) ); diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 39fd492786bc7..5fb0b9852641b 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -566,7 +566,7 @@ async fn csv_explain_verbose_plans() { #[tokio::test] async fn explain_analyze_runs_optimizers(#[values("*", "1")] count_expr: &str) { // repro for https://github.com/apache/datafusion/issues/917 - // where EXPLAIN ANALYZE was not correctly running optiimizer + // where EXPLAIN ANALYZE was not correctly running optimizer let ctx = SessionContext::new(); register_alltypes_parquet(&ctx).await; diff --git 
a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index addabc8a36127..fab92c0f9c2bf 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -33,7 +33,7 @@ async fn join_change_in_planner() -> Result<()> { Field::new("a2", DataType::UInt32, false), ])); // Specify the ordering: - let file_sort_order = vec![[datafusion_expr::col("a1")] + let file_sort_order = vec![[col("a1")] .into_iter() .map(|e| { let ascending = true; @@ -101,7 +101,7 @@ async fn join_no_order_on_filter() -> Result<()> { Field::new("a3", DataType::UInt32, false), ])); // Specify the ordering: - let file_sort_order = vec![[datafusion_expr::col("a1")] + let file_sort_order = vec![[col("a1")] .into_iter() .map(|e| { let ascending = true; diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index dc9d047860213..03c4ad7c013ec 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -32,6 +32,7 @@ use datafusion::prelude::*; use datafusion::test_util; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use datafusion::{execution::context::SessionContext, physical_plan::displayable}; +use datafusion_common::utils::get_available_parallelism; use datafusion_common::{assert_contains, assert_not_contains}; use object_store::path::Path; use std::fs::File; @@ -65,7 +66,7 @@ pub mod select; mod sql_api; async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { - let testdata = datafusion::test_util::arrow_test_data(); + let testdata = test_util::arrow_test_data(); let df = ctx .sql(&format!( @@ -103,7 +104,7 @@ async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { } async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::arrow_test_data(); + let testdata = test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); ctx.register_csv( "aggregate_test_100", @@ -195,7 +196,7 @@ fn populate_csv_partitions( Ok(schema) } -/// Specialised String representation +/// Specialized String representation fn col_str(column: &ArrayRef, row_index: usize) -> String { // NullArray::is_null() does not work on NullArray. 
// can remove check for DataType::Null when @@ -227,7 +228,7 @@ fn result_vec(results: &[RecordBatch]) -> Vec> { } async fn register_alltypes_parquet(ctx: &SessionContext) { - let testdata = datafusion::test_util::parquet_test_data(); + let testdata = test_util::parquet_test_data(); ctx.register_parquet( "alltypes_plain", &format!("{testdata}/alltypes_plain.parquet"), @@ -259,7 +260,7 @@ impl ExplainNormalizer { // convert things like partitioning=RoundRobinBatch(16) // to partitioning=RoundRobinBatch(NUM_CORES) - let needle = format!("RoundRobinBatch({})", num_cpus::get()); + let needle = format!("RoundRobinBatch({})", get_available_parallelism()); replacements.push((needle, "RoundRobinBatch(NUM_CORES)".to_string())); Self { replacements } diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs index 0c983bd732d03..9eb09cb883124 100644 --- a/datafusion/core/tests/sql/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -47,7 +47,6 @@ use bytes::Bytes; use chrono::{TimeZone, Utc}; use datafusion_expr::{col, lit, Expr, Operator}; use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; -use datafusion_physical_expr::PhysicalExpr; use futures::stream::{self, BoxStream}; use object_store::{ path::Path, GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, @@ -97,7 +96,7 @@ async fn parquet_partition_pruning_filter() -> Result<()> { assert!(pred.as_any().is::()); let pred = pred.as_any().downcast_ref::().unwrap(); - assert_eq!(pred, expected.as_any()); + assert_eq!(pred, expected.as_ref()); Ok(()) } @@ -184,7 +183,7 @@ async fn parquet_distinct_partition_col() -> Result<()> { max_limit += 1; let last_batch = results .last() - .expect("There shouled be at least one record batch returned"); + .expect("There should be at least one record batch returned"); let last_row_idx = last_batch.num_rows() - 1; let mut min_limit = match ScalarValue::try_from_array(last_batch.column(0), last_row_idx)? 
{ @@ -219,10 +218,11 @@ async fn parquet_distinct_partition_col() -> Result<()> { assert_eq!(min_limit, resulting_limit); let s = ScalarValue::try_from_array(results[0].column(1), 0)?; - let month = match extract_as_utf(&s) { - Some(month) => month, - s => panic!("Expected month as Dict(_, Utf8) found {s:?}"), - }; + assert!( + matches!(s.data_type(), DataType::Dictionary(_, v) if v.as_ref() == &DataType::Utf8), + "Expected month as Dict(_, Utf8) found {s:?}" + ); + let month = s.try_as_str().flatten().unwrap(); let sql_on_partition_boundary = format!( "SELECT month from t where month = '{}' LIMIT {}", @@ -242,15 +242,6 @@ async fn parquet_distinct_partition_col() -> Result<()> { Ok(()) } -fn extract_as_utf(v: &ScalarValue) -> Option { - if let ScalarValue::Dictionary(_, v) = v { - if let ScalarValue::Utf8(v) = v.as_ref() { - return v.clone(); - } - } - None -} - #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { let ctx = SessionContext::new_with_config( @@ -569,7 +560,7 @@ async fn parquet_overlapping_columns() -> Result<()> { assert!( result.is_err(), - "Dupplicate qualified name should raise error" + "Duplicate qualified name should raise error" ); Ok(()) } diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index dd660512f3469..6e81bf6410c11 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -57,7 +57,6 @@ async fn test_named_query_parameters() -> Result<()> { let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?; // sql to statement then to logical plan with parameters - // c1 defined as UINT32, c2 defined as UInt64 let results = ctx .sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo") .await? @@ -106,9 +105,9 @@ async fn test_prepare_statement() -> Result<()> { let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?; // sql to statement then to prepare logical plan with parameters - // c1 defined as UINT32, c2 defined as UInt64 but the params are Int32 and Float64 - let dataframe = - ctx.sql("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1").await?; + let dataframe = ctx + .sql("SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1") + .await?; // prepare logical plan to logical plan without parameters let param_values = vec![ScalarValue::Int32(Some(3)), ScalarValue::Float64(Some(0.0))]; @@ -156,7 +155,7 @@ async fn prepared_statement_type_coercion() -> Result<()> { ("unsigned", Arc::new(unsigned_ints) as ArrayRef), ])?; ctx.register_batch("test", batch)?; - let results = ctx.sql("PREPARE my_plan(BIGINT, INT, TEXT) AS SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3") + let results = ctx.sql("SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3") .await? .with_param_values(vec![ ScalarValue::from(1_i64), @@ -176,27 +175,6 @@ async fn prepared_statement_type_coercion() -> Result<()> { Ok(()) } -#[tokio::test] -async fn prepared_statement_invalid_types() -> Result<()> { - let ctx = SessionContext::new(); - let signed_ints: Int32Array = vec![-1, 0, 1].into(); - let unsigned_ints: UInt64Array = vec![1, 2, 3].into(); - let batch = RecordBatch::try_from_iter(vec![ - ("signed", Arc::new(signed_ints) as ArrayRef), - ("unsigned", Arc::new(unsigned_ints) as ArrayRef), - ])?; - ctx.register_batch("test", batch)?; - let results = ctx - .sql("PREPARE my_plan(INT) AS SELECT signed FROM test WHERE signed = $1") - .await? 
- .with_param_values(vec![ScalarValue::from("1")]); - assert_eq!( - results.unwrap_err().strip_backtrace(), - "Error during planning: Expected parameter of type Int32, got Utf8 at index 0" - ); - Ok(()) -} - #[tokio::test] async fn test_parameter_type_coercion() -> Result<()> { let ctx = SessionContext::new(); @@ -251,6 +229,98 @@ async fn test_parameter_invalid_types() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_positional_parameter_not_bound() -> Result<()> { + let ctx = SessionContext::new(); + let signed_ints: Int32Array = vec![-1, 0, 1].into(); + let unsigned_ints: UInt64Array = vec![1, 2, 3].into(); + let batch = RecordBatch::try_from_iter(vec![ + ("signed", Arc::new(signed_ints) as ArrayRef), + ("unsigned", Arc::new(unsigned_ints) as ArrayRef), + ])?; + ctx.register_batch("test", batch)?; + + let query = "SELECT signed, unsigned FROM test \ + WHERE $1 >= signed AND signed <= $2 \ + AND unsigned <= $3 AND unsigned = $4"; + + let results = ctx.sql(query).await?.collect().await; + + assert_eq!( + results.unwrap_err().strip_backtrace(), + "Execution error: Placeholder '$1' was not provided a value for execution." + ); + + let results = ctx + .sql(query) + .await? + .with_param_values(vec![ + ScalarValue::from(4_i32), + ScalarValue::from(-1_i64), + ScalarValue::from(2_i32), + ScalarValue::from("1"), + ])? + .collect() + .await?; + + let expected = [ + "+--------+----------+", + "| signed | unsigned |", + "+--------+----------+", + "| -1 | 1 |", + "+--------+----------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn test_named_parameter_not_bound() -> Result<()> { + let ctx = SessionContext::new(); + let signed_ints: Int32Array = vec![-1, 0, 1].into(); + let unsigned_ints: UInt64Array = vec![1, 2, 3].into(); + let batch = RecordBatch::try_from_iter(vec![ + ("signed", Arc::new(signed_ints) as ArrayRef), + ("unsigned", Arc::new(unsigned_ints) as ArrayRef), + ])?; + ctx.register_batch("test", batch)?; + + let query = "SELECT signed, unsigned FROM test \ + WHERE $foo >= signed AND signed <= $bar \ + AND unsigned <= $baz AND unsigned = $str"; + + let results = ctx.sql(query).await?.collect().await; + + assert_eq!( + results.unwrap_err().strip_backtrace(), + "Execution error: Placeholder '$foo' was not provided a value for execution." + ); + + let results = ctx + .sql(query) + .await? + .with_param_values(vec![ + ("foo", ScalarValue::from(4_i32)), + ("bar", ScalarValue::from(-1_i64)), + ("baz", ScalarValue::from(2_i32)), + ("str", ScalarValue::from("1")), + ])? 
+ .collect() + .await?; + + let expected = [ + "+--------+----------+", + "| signed | unsigned |", + "+--------+----------+", + "| -1 | 1 |", + "+--------+----------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + #[tokio::test] async fn test_version_function() { let expected_version = format!( diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index 48f4a66b65dcf..034d6fa23d9cb 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -113,6 +113,30 @@ async fn unsupported_statement_returns_error() { ctx.sql_with_options(sql, options).await.unwrap(); } +// Disallow PREPARE and EXECUTE statements if `allow_statements` is false +#[tokio::test] +async fn disable_prepare_and_execute_statement() { + let ctx = SessionContext::new(); + + let prepare_sql = "PREPARE plan(INT) AS SELECT $1"; + let execute_sql = "EXECUTE plan(1)"; + let options = SQLOptions::new().with_allow_statements(false); + let df = ctx.sql_with_options(prepare_sql, options).await; + assert_eq!( + df.unwrap_err().strip_backtrace(), + "Error during planning: Statement not supported: Prepare" + ); + let df = ctx.sql_with_options(execute_sql, options).await; + assert_eq!( + df.unwrap_err().strip_backtrace(), + "Error during planning: Statement not supported: Execute" + ); + + let options = options.with_allow_statements(true); + ctx.sql_with_options(prepare_sql, options).await.unwrap(); + ctx.sql_with_options(execute_sql, options).await.unwrap(); +} + #[tokio::test] async fn empty_statement_returns_error() { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index b99bc26800449..252d76d0f9d92 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -229,9 +229,6 @@ async fn tpcds_logical_q40() -> Result<()> { } #[tokio::test] -#[ignore] -// Optimizer rule 'scalar_subquery_to_join' failed: Optimizing disjunctions not supported! -// issue: https://github.com/apache/datafusion/issues/5368 async fn tpcds_logical_q41() -> Result<()> { create_logical_plan(41).await } @@ -571,7 +568,6 @@ async fn tpcds_physical_q9() -> Result<()> { create_physical_plan(9).await } -#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q10() -> Result<()> { create_physical_plan(10).await @@ -697,7 +693,6 @@ async fn tpcds_physical_q34() -> Result<()> { create_physical_plan(34).await } -#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q35() -> Result<()> { create_physical_plan(35).await @@ -728,8 +723,6 @@ async fn tpcds_physical_q40() -> Result<()> { create_physical_plan(40).await } -#[ignore] -// Context("check_analyzed_plan", Plan("Correlated column is not allowed in predicate: (..) 
#[tokio::test] async fn tpcds_physical_q41() -> Result<()> { create_physical_plan(41).await @@ -750,7 +743,6 @@ async fn tpcds_physical_q44() -> Result<()> { create_physical_plan(44).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q45() -> Result<()> { create_physical_plan(45).await diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs index ff14fa0be3fb6..aa531632c60b7 100644 --- a/datafusion/core/tests/user_defined/insert_operation.rs +++ b/datafusion/core/tests/user_defined/insert_operation.rs @@ -26,7 +26,10 @@ use datafusion::{ use datafusion_catalog::{Session, TableProvider}; use datafusion_expr::{dml::InsertOp, Expr, TableType}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion_physical_plan::{DisplayAs, ExecutionMode, ExecutionPlan, PlanProperties}; +use datafusion_physical_plan::{ + execution_plan::{Boundedness, EmissionType}, + DisplayAs, ExecutionPlan, PlanProperties, +}; #[tokio::test] async fn insert_operation_is_passed_correctly_to_table_provider() { @@ -122,15 +125,14 @@ struct TestInsertExec { impl TestInsertExec { fn new(op: InsertOp) -> Self { - let eq_properties = EquivalenceProperties::new(make_count_schema()); - let plan_properties = PlanProperties::new( - eq_properties, - Partitioning::UnknownPartitioning(1), - ExecutionMode::Bounded, - ); Self { op, - plan_properties, + plan_properties: PlanProperties::new( + EquivalenceProperties::new(make_count_schema()), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ), } } } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 1e0d3d9d514e8..bf32eef3b011c 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -19,6 +19,7 @@ //! 
user defined aggregate functions use std::hash::{DefaultHasher, Hash, Hasher}; +use std::mem::{size_of, size_of_val}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -723,7 +724,7 @@ impl Accumulator for FirstSelector { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - // cast argumets to the appropriate type (DataFusion will type + // cast arguments to the appropriate type (DataFusion will type // check these based on the declared allowed input types) let v = as_primitive_array::(&values[0])?; let t = as_primitive_array::(&values[1])?; @@ -747,7 +748,7 @@ impl Accumulator for FirstSelector { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -816,7 +817,7 @@ impl Accumulator for TestGroupsAccumulator { } fn size(&self) -> usize { - std::mem::size_of::() + size_of::() } fn state(&mut self) -> Result> { @@ -864,6 +865,6 @@ impl GroupsAccumulator for TestGroupsAccumulator { } fn size(&self) -> usize { - std::mem::size_of::() + size_of::() } } diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index d9965c20890a9..1b8e742fef49f 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -68,9 +68,6 @@ use arrow::{ record_batch::RecordBatch, util::pretty::pretty_format_batches, }; -use async_trait::async_trait; -use futures::{Stream, StreamExt}; - use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ common::cast::{as_int64_array, as_string_array}, @@ -81,15 +78,14 @@ use datafusion::{ runtime_env::RuntimeEnv, }, logical_expr::{ - Expr, Extension, Limit, LogicalPlan, Sort, UserDefinedLogicalNode, + Expr, Extension, LogicalPlan, Sort, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, }, optimizer::{OptimizerConfig, OptimizerRule}, physical_expr::EquivalenceProperties, physical_plan::{ - DisplayAs, DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan, - Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, - Statistics, + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }, physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, prelude::{SessionConfig, SessionContext}, @@ -97,10 +93,13 @@ use datafusion::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; -use datafusion_expr::tree_node::replace_sort_expression; -use datafusion_expr::{Projection, Scalar, SortExpr}; +use datafusion_expr::{FetchType, Projection, Scalar, SortExpr}; use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; + +use async_trait::async_trait; +use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches /// pretty printed as a String. @@ -361,28 +360,28 @@ impl OptimizerRule for TopKOptimizerRule { // Note: this code simply looks for the pattern of a Limit followed by a // Sort and replaces it by a TopK node. It does not handle many // edge cases (e.g multiple sort columns, sort ASC / DESC), etc. 
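// To make that pattern concrete (illustrative plan shapes only; the table and
// column names are made up for this sketch): a plan such as
//
//   Limit: fetch=3
//     Sort: revenue DESC
//       TableScan: sales
//
// is rewritten by this rule into
//
//   TopK: k=3, expr=revenue DESC
//     TableScan: sales
//
// i.e. the Limit/Sort pair collapses into the single TopK extension node built below.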
- if let LogicalPlan::Limit(Limit { - fetch: Some(fetch), - input, + let LogicalPlan::Limit(ref limit) = plan else { + return Ok(Transformed::no(plan)); + }; + let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else { + return Ok(Transformed::no(plan)); + }; + + if let LogicalPlan::Sort(Sort { + ref expr, + ref input, .. - }) = &plan + }) = limit.input.as_ref() { - if let LogicalPlan::Sort(Sort { - ref expr, - ref input, - .. - }) = **input - { - if expr.len() == 1 { - // we found a sort with a single sort expr, replace with a a TopK - return Ok(Transformed::yes(LogicalPlan::Extension(Extension { - node: Arc::new(TopKPlanNode { - k: *fetch, - input: input.as_ref().clone(), - expr: expr[0].clone(), - }), - }))); - } + if expr.len() == 1 { + // we found a sort with a single sort expr, replace with a a TopK + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: fetch, + input: input.as_ref().clone(), + expr: expr[0].clone(), + }), + }))); } } @@ -440,7 +439,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { Ok(Self { k: self.k, input: inputs.swap_remove(0), - expr: replace_sort_expression(self.expr.clone(), exprs.swap_remove(0)), + expr: self.expr.with_expr(exprs.swap_remove(0)), }) } @@ -483,7 +482,7 @@ impl ExtensionPlanner for TopKPlanner { /// code is not general and is meant as an illustration only struct TopKExec { input: Arc, - /// The maxium number of values + /// The maximum number of values k: usize, cache: PlanProperties, } @@ -496,12 +495,11 @@ impl TopKExec { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties(schema: SchemaRef) -> PlanProperties { - let eq_properties = EquivalenceProperties::new(schema); - PlanProperties::new( - eq_properties, + EquivalenceProperties::new(schema), Partitioning::UnknownPartitioning(1), - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ) } } @@ -513,11 +511,7 @@ impl Debug for TopKExec { } impl DisplayAs for TopKExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!(f, "TopKExec: k={}", self.k) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index a6e7847d463b1..f5b4fdf37bfe4 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; +use arrow::array::as_string_array; use arrow::compute::kernels::numeric::add; use arrow_array::builder::BooleanBuilder; use arrow_array::cast::AsArray; @@ -26,10 +27,6 @@ use arrow_array::{ Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, }; use arrow_schema::{DataType, Field, Schema}; -use parking_lot::Mutex; -use regex::Regex; -use sqlparser::ast::Ident; - use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; @@ -37,7 +34,8 @@ use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::tree_node::{Transformed, 
TreeNode}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err, - not_impl_err, plan_err, DFSchema, DataFusionError, ExprSchema, Result, ScalarValue, + not_impl_err, plan_err, DFSchema, DataFusionError, ExprSchema, HashMap, Result, + ScalarValue, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ @@ -46,6 +44,10 @@ use datafusion_expr::{ Volatility, }; use datafusion_functions_nested::range::range_udf; +use parking_lot::Mutex; +use regex::Regex; +use sqlparser::ast::Ident; +use sqlparser::tokenizer::Span; /// test that casting happens on udfs. /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and @@ -207,11 +209,11 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF { Ok(self.return_type.clone()) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { - not_impl_err!("{} function does not accept arguments", self.name()) - } - - fn invoke_no_args(&self, _number_rows: usize) -> Result { + fn invoke_batch( + &self, + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { Ok(ColumnarValue::from(ScalarValue::Int32(Some(100)))) } } @@ -249,7 +251,7 @@ async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> { .err() .unwrap() .to_string(), - "UDF returned a different number of rows than expected" + "Internal error: UDF buggy_func returned a different number of rows than expected. Expected: 2, Got: 1" ); Ok(()) } @@ -483,6 +485,183 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { Ok(()) } +/// Volatile UDF that should append a different value to each row +#[derive(Debug)] +struct AddIndexToStringVolatileScalarUDF { + name: String, + signature: Signature, + return_type: DataType, +} + +impl AddIndexToStringVolatileScalarUDF { + fn new() -> Self { + Self { + name: "add_index_to_string".to_string(), + signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), + return_type: DataType::Utf8, + } + } +} + +impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + let answer = match &args[0] { + // When called with static arguments, the result is returned as an array. + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(value)) => { + let mut answer = vec![]; + for index in 1..=number_rows { + // When calling a function with immutable arguments, the result is returned with ")". + // Example: SELECT add_index_to_string('const_value') FROM table; + answer.push(index.to_string() + ") " + value); + } + answer + } + _ => return exec_err!("Unexpected scalar value"), + }, + // The result is returned as an array when called with dynamic arguments. + ColumnarValue::Array(array) => { + let string_array = as_string_array(array); + let mut counter = HashMap::<&str, u64>::new(); + string_array + .iter() + .map(|value| { + let value = value.expect("Unexpected null"); + let index = counter.get(value).unwrap_or(&0) + 1; + counter.insert(value, index); + + // When calling a function with mutable arguments, the result is returned with ".". + // Example: SELECT add_index_to_string(table.value) FROM table; + index.to_string() + ". 
" + value + }) + .collect() + } + }; + Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer)))) + } +} + +#[tokio::test] +async fn volatile_scalar_udf_with_params() -> Result<()> { + { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from(vec![ + "test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2", + ]))], + )?; + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); + + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); + + let result = + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters + .await?; + let expected = [ + "+-----------+", + "| str |", + "+-----------+", + "| 1. test_1 |", + "| 2. test_1 |", + "| 3. test_1 |", + "| 1. test_2 |", + "| 2. test_2 |", + "| 4. test_1 |", + "| 3. test_2 |", + "+-----------+", + ]; + assert_batches_eq!(expected, &result); + + let result = + plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters + .await?; + let expected = [ + "+---------+", + "| str |", + "+---------+", + "| 1) test |", + "| 2) test |", + "| 3) test |", + "| 4) test |", + "| 5) test |", + "| 6) test |", + "| 7) test |", + "+---------+", + ]; + assert_batches_eq!(expected, &result); + + let result = + plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters + .await?; + let expected = [ + "+---------------+", + "| str |", + "+---------------+", + "| 1) test_value |", + "+---------------+", + ]; + assert_batches_eq!(expected, &result); + } + { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from(vec![ + "test_1", "test_1", "test_1", + ]))], + )?; + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); + + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); + + let result = + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") + .await?; + let expected = [ + "+-----------+", // + "| str |", // + "+-----------+", // + "| 1. test_1 |", // + "| 2. test_1 |", // + "| 3. 
test_1 |", // + "+-----------+", + ]; + assert_batches_eq!(expected, &result); + } + Ok(()) +} + #[derive(Debug)] struct CastToI64UDF { signature: Signature, @@ -539,7 +718,11 @@ impl ScalarUDFImpl for CastToI64UDF { Ok(ExprSimplifyResult::Simplified(new_expr)) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { unimplemented!("Function should have been simplified prior to evaluation") } } @@ -669,7 +852,11 @@ impl ScalarUDFImpl for TakeUDF { } // The actual implementation - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let take_idx = match &args[2] { ColumnarValue::Scalar(scalar) => match scalar.value() { ScalarValue::Int64(Some(v)) if v < &2 => *v as usize, @@ -760,11 +947,11 @@ struct ScalarFunctionWrapper { name: String, expr: Expr, signature: Signature, - return_type: arrow_schema::DataType, + return_type: DataType, } impl ScalarUDFImpl for ScalarFunctionWrapper { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -772,21 +959,19 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { &self.name } - fn signature(&self) -> &datafusion_expr::Signature { + fn signature(&self) -> &Signature { &self.signature } - fn return_type( - &self, - _arg_types: &[arrow_schema::DataType], - ) -> Result { + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(self.return_type.clone()) } - fn invoke( + fn invoke_batch( &self, - _args: &[datafusion_expr::ColumnarValue], - ) -> Result { + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { internal_err!("This function should not get invoked!") } @@ -866,10 +1051,7 @@ impl TryFrom for ScalarFunctionWrapper { .into_iter() .map(|a| a.data_type) .collect(), - definition - .params - .behavior - .unwrap_or(datafusion_expr::Volatility::Volatile), + definition.params.behavior.unwrap_or(Volatility::Volatile), ), }) } @@ -1012,6 +1194,7 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<( name: Some(Ident { value: "name".into(), quote_style: None, + span: Span::empty(), }), data_type: DataType::Utf8, default_expr: None, @@ -1021,6 +1204,7 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<( language: Some(Ident { value: "plrust".into(), quote_style: None, + span: Span::empty(), }), behavior: None, function_body: Some(lit(body)), @@ -1073,7 +1257,11 @@ impl ScalarUDFImpl for MyRegexUdf { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args { [ColumnarValue::Scalar(scalar)] => match scalar.value() { ScalarValue::Utf8(value) => Ok(ColumnarValue::from( @@ -1175,7 +1363,7 @@ fn custom_sqrt(args: &[ColumnarValue]) -> Result { } async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::arrow_test_data(); + let testdata = test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); ctx.register_csv( "aggregate_test_100", @@ -1187,7 +1375,7 @@ async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { } async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::parquet_test_data(); + let testdata = test_util::parquet_test_data(); ctx.register_parquet( "alltypes_plain", &format!("{testdata}/alltypes_plain.parquet"), diff --git 
a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index bf5c822be240c..ef999c1420eea 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -21,7 +21,6 @@ use arrow::csv::ReaderBuilder; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::datasource::function::TableFunctionImpl; use datafusion::datasource::TableProvider; use datafusion::error::Result; use datafusion::execution::TaskContext; @@ -29,6 +28,7 @@ use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::SessionContext; use datafusion_catalog::Session; +use datafusion_catalog::TableFunctionImpl; use datafusion_common::{assert_batches_eq, DFSchema, ScalarValue}; use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, Scalar, TableType}; use std::fs::File; @@ -231,8 +231,8 @@ fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec i64 { } /// returns an array of num_rows that has the number of odd values in `arr` -fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef { +fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> ArrayRef { let array: Int64Array = std::iter::repeat(odd_count(arr)).take(num_rows).collect(); Arc::new(array) } + +#[derive(Debug)] +struct VariadicWindowUDF { + signature: Signature, +} + +impl VariadicWindowUDF { + fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Any(0), + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ), + } + } +} + +impl WindowUDFImpl for VariadicWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "variadic_window_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _: PartitionEvaluatorArgs, + ) -> Result> { + unimplemented!("unnecessary for testing"); + } + + fn field(&self, _: WindowUDFFieldArgs) -> Result { + unimplemented!("unnecessary for testing"); + } +} + +#[test] +// Fixes: default implementation of `WindowUDFImpl::expressions` +// returns all input expressions to the user-defined window +// function unmodified. 
+// +// See: https://github.com/apache/datafusion/pull/13169 +fn test_default_expressions() -> Result<()> { + let udwf = WindowUDF::from(VariadicWindowUDF::new()); + + let field_a = Field::new("a", DataType::Int32, false); + let field_b = Field::new("b", DataType::Float32, false); + let field_c = Field::new("c", DataType::Boolean, false); + let schema = Schema::new(vec![field_a, field_b, field_c]); + + let test_cases = vec![ + // + // Zero arguments + // + vec![], + // + // Single argument + // + vec![col("a", &schema)?], + vec![lit(1)], + // + // Two arguments + // + vec![col("a", &schema)?, col("b", &schema)?], + vec![col("a", &schema)?, lit(2)], + vec![lit(false), col("a", &schema)?], + // + // Three arguments + // + vec![col("a", &schema)?, col("b", &schema)?, col("c", &schema)?], + vec![col("a", &schema)?, col("b", &schema)?, lit(false)], + vec![col("a", &schema)?, lit(0.5), col("c", &schema)?], + vec![lit(3), col("b", &schema)?, col("c", &schema)?], + ]; + + for input_exprs in &test_cases { + let input_types = input_exprs + .iter() + .map(|expr: &Arc| expr.data_type(&schema).unwrap()) + .collect::>(); + let expr_args = ExpressionArgs::new(input_exprs, &input_types); + + let ret_exprs = udwf.expressions(expr_args); + + // Verify same number of input expressions are returned + assert_eq!( + input_exprs.len(), + ret_exprs.len(), + "\nInput expressions: {:?}\nReturned expressions: {:?}", + input_exprs, + ret_exprs + ); + + // Compares each returned expression with original input expressions + for (expected, actual) in input_exprs.iter().zip(&ret_exprs) { + assert_eq!( + format!("{expected:?}"), + format!("{actual:?}"), + "\nInput expressions: {:?}\nReturned expressions: {:?}", + input_exprs, + ret_exprs + ); + } + } + Ok(()) +} diff --git a/dev/release/run-rat.sh b/datafusion/doc/Cargo.toml old mode 100755 new mode 100644 similarity index 51% rename from dev/release/run-rat.sh rename to datafusion/doc/Cargo.toml index 94fa55fbe0974..c188bcb2a5352 --- a/dev/release/run-rat.sh +++ b/datafusion/doc/Cargo.toml @@ -1,5 +1,3 @@ -#!/bin/bash -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,28 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# - -RAT_VERSION=0.13 - -# download apache rat -if [ ! -f apache-rat-${RAT_VERSION}.jar ]; then - curl -s https://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar > apache-rat-${RAT_VERSION}.jar -fi - -RAT="java -jar apache-rat-${RAT_VERSION}.jar -x " -RELEASE_DIR=$(cd "$(dirname "$BASH_SOURCE")"; pwd) +[package] +name = "datafusion-doc" +description = "Documentation module for DataFusion query engine" +keywords = ["datafusion", "query", "sql"] +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } -# generate the rat report -$RAT $1 > rat.txt -python $RELEASE_DIR/check-rat-report.py $RELEASE_DIR/rat_exclude_files.txt rat.txt > filtered_rat.txt -cat filtered_rat.txt -UNAPPROVED=`cat filtered_rat.txt | grep "NOT APPROVED" | wc -l` +[lints] +workspace = true -if [ "0" -eq "${UNAPPROVED}" ]; then - echo "No unapproved licenses" -else - echo "${UNAPPROVED} unapproved licences. 
Check rat report: rat.txt" - exit 1 -fi +[lib] +name = "datafusion_doc" +path = "src/lib.rs" diff --git a/datafusion/doc/LICENSE.txt b/datafusion/doc/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/doc/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/doc/NOTICE.txt b/datafusion/doc/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/doc/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/doc/src/lib.rs b/datafusion/doc/src/lib.rs new file mode 100644 index 0000000000000..6940a8ef3ca26 --- /dev/null +++ b/datafusion/doc/src/lib.rs @@ -0,0 +1,335 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[allow(rustdoc::broken_intra_doc_links)] +/// Documentation for use by [`ScalarUDFImpl`](ScalarUDFImpl), +/// [`AggregateUDFImpl`](AggregateUDFImpl) and [`WindowUDFImpl`](WindowUDFImpl) functions. +/// +/// See the [`DocumentationBuilder`] to create a new [`Documentation`] struct. +/// +/// The DataFusion [SQL function documentation] is automatically generated from these structs. +/// The name of the udf will be pulled from the [`ScalarUDFImpl::name`](ScalarUDFImpl::name), +/// [`AggregateUDFImpl::name`](AggregateUDFImpl::name) or [`WindowUDFImpl::name`](WindowUDFImpl::name) +/// function as appropriate. +/// +/// All strings in the documentation are required to be +/// in [markdown format](https://www.markdownguide.org/basic-syntax/). +/// +/// Currently, documentation only supports a single language +/// thus all text should be in English. +/// +/// [SQL function documentation]: https://datafusion.apache.org/user-guide/sql/index.html +#[derive(Debug, Clone)] +pub struct Documentation { + /// The section in the documentation where the UDF will be documented + pub doc_section: DocSection, + /// The description for the UDF + pub description: String, + /// A brief example of the syntax. For example "ascii(str)" + pub syntax_example: String, + /// A sql example for the UDF, usually in the form of a sql prompt + /// query and output. It is strongly recommended to provide an + /// example for anything but the most basic UDF's + pub sql_example: Option, + /// Arguments for the UDF which will be displayed in array order. + /// Left member of a pair is the argument name, right is a + /// description for the argument + pub arguments: Option>, + /// A list of alternative syntax examples for a function + pub alternative_syntax: Option>, + /// Related functions if any. Values should match the related + /// udf's name exactly. 
Related udf's must be of the same + /// UDF type (scalar, aggregate or window) for proper linking to + /// occur + pub related_udfs: Option>, +} + +impl Documentation { + /// Returns a new [`DocumentationBuilder`] with no options set. + pub fn builder( + doc_section: DocSection, + description: impl Into, + syntax_example: impl Into, + ) -> DocumentationBuilder { + DocumentationBuilder::new_with_details(doc_section, description, syntax_example) + } + + /// Output the `Documentation` struct in form of custom Rust documentation attributes + /// It is useful to semi automate during tmigration of UDF documentation + /// generation from code based to attribute based and can be safely removed after + pub fn to_doc_attribute(&self) -> String { + let mut result = String::new(); + + result.push_str("#[user_doc("); + // Doc Section + result.push_str( + format!( + "\n doc_section({}label = \"{}\"{}),", + if !self.doc_section.include { + "include = \"false\", " + } else { + "" + }, + self.doc_section.label, + self.doc_section + .description + .map(|s| format!(", description = \"{}\"", s)) + .unwrap_or_default(), + ) + .as_ref(), + ); + + // Description + result.push_str(format!("\n description=\"{}\",", self.description).as_ref()); + // Syntax Example + result.push_str( + format!("\n syntax_example=\"{}\",", self.syntax_example).as_ref(), + ); + // SQL Example + result.push_str( + &self + .sql_example + .clone() + .map(|s| format!("\n sql_example = r#\"{}\"#,", s)) + .unwrap_or_default(), + ); + + let st_arg_token = " expression to operate on. Can be a constant, column, or function, and any combination of operators."; + // Standard Arguments + if let Some(args) = self.arguments.clone() { + args.iter().for_each(|(name, value)| { + if value.contains(st_arg_token) { + if name.starts_with("The ") { + result.push_str(format!("\n standard_argument(\n name = \"{}\"),", name).as_ref()); + } else { + result.push_str(format!("\n standard_argument(\n name = \"{}\",\n prefix = \"{}\"\n ),", name, value.replace(st_arg_token, "")).as_ref()); + } + } + }); + } + + // Arguments + if let Some(args) = self.arguments.clone() { + args.iter().for_each(|(name, value)| { + if !value.contains(st_arg_token) { + result.push_str(format!("\n argument(\n name = \"{}\",\n description = \"{}\"\n ),", name, value).as_ref()); + } + }); + } + + if let Some(alt_syntax) = self.alternative_syntax.clone() { + alt_syntax.iter().for_each(|syntax| { + result.push_str( + format!("\n alternative_syntax = \"{}\",", syntax).as_ref(), + ); + }); + } + + // Related UDFs + if let Some(related_udf) = self.related_udfs.clone() { + related_udf.iter().for_each(|udf| { + result + .push_str(format!("\n related_udf(name = \"{}\"),", udf).as_ref()); + }); + } + + result.push_str("\n)]"); + + result + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct DocSection { + /// True to include this doc section in the public + /// documentation, false otherwise + pub include: bool, + /// A display label for the doc section. For example: "Math Expressions" + pub label: &'static str, + /// An optional description for the doc section + pub description: Option<&'static str>, +} + +impl Default for DocSection { + /// Returns a "default" Doc section. + /// + /// This is suitable for user defined functions that do not appear in the + /// DataFusion documentation. + fn default() -> Self { + Self { + include: true, + label: "Default", + description: None, + } + } +} + +/// A builder for [`Documentation`]'s. 
+/// +/// Example: +/// +/// ```rust +/// +/// # fn main() { +/// use datafusion_doc::{DocSection, Documentation}; +/// let doc_section = DocSection { +/// include: true, +/// label: "Display Label", +/// description: None, +/// }; +/// +/// let documentation = Documentation::builder(doc_section, "Add one to an int32".to_owned(), "add_one(2)".to_owned()) +/// .with_argument("arg_1", "The int32 number to add one to") +/// .build(); +/// # } +pub struct DocumentationBuilder { + pub doc_section: DocSection, + pub description: String, + pub syntax_example: String, + pub sql_example: Option, + pub arguments: Option>, + pub alternative_syntax: Option>, + pub related_udfs: Option>, +} + +impl DocumentationBuilder { + #[allow(clippy::new_without_default)] + #[deprecated( + since = "44.0.0", + note = "please use `DocumentationBuilder::new_with_details` instead" + )] + pub fn new() -> Self { + Self::new_with_details(DocSection::default(), "", "") + } + + /// Creates a new [`DocumentationBuilder`] with all required fields + pub fn new_with_details( + doc_section: DocSection, + description: impl Into, + syntax_example: impl Into, + ) -> Self { + Self { + doc_section, + description: description.into(), + syntax_example: syntax_example.into(), + sql_example: None, + arguments: None, + alternative_syntax: None, + related_udfs: None, + } + } + + pub fn with_doc_section(mut self, doc_section: DocSection) -> Self { + self.doc_section = doc_section; + self + } + + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = description.into(); + self + } + + pub fn with_syntax_example(mut self, syntax_example: impl Into) -> Self { + self.syntax_example = syntax_example.into(); + self + } + + pub fn with_sql_example(mut self, sql_example: impl Into) -> Self { + self.sql_example = Some(sql_example.into()); + self + } + + /// Adds documentation for a specific argument to the documentation. + /// + /// Arguments are displayed in the order they are added. + pub fn with_argument( + mut self, + arg_name: impl Into, + arg_description: impl Into, + ) -> Self { + let mut args = self.arguments.unwrap_or_default(); + args.push((arg_name.into(), arg_description.into())); + self.arguments = Some(args); + self + } + + /// Add a standard "expression" argument to the documentation + /// + /// The argument is rendered like below if Some() is passed through: + /// + /// ```text + /// : + /// expression to operate on. Can be a constant, column, or function, and any combination of operators. + /// ``` + /// + /// The argument is rendered like below if None is passed through: + /// + /// ```text + /// : + /// The expression to operate on. Can be a constant, column, or function, and any combination of operators. + /// ``` + pub fn with_standard_argument( + self, + arg_name: impl Into, + expression_type: Option<&str>, + ) -> Self { + let description = format!( + "{} expression to operate on. 
Can be a constant, column, or function, and any combination of operators.", + expression_type.unwrap_or("The") + ); + self.with_argument(arg_name, description) + } + + pub fn with_alternative_syntax(mut self, syntax_name: impl Into) -> Self { + let mut alternative_syntax_array = self.alternative_syntax.unwrap_or_default(); + alternative_syntax_array.push(syntax_name.into()); + self.alternative_syntax = Some(alternative_syntax_array); + self + } + + pub fn with_related_udf(mut self, related_udf: impl Into) -> Self { + let mut related = self.related_udfs.unwrap_or_default(); + related.push(related_udf.into()); + self.related_udfs = Some(related); + self + } + + /// Build the documentation from provided components + /// + /// Panics if `doc_section`, `description` or `syntax_example` is not set + pub fn build(self) -> Documentation { + let Self { + doc_section, + description, + syntax_example, + sql_example, + arguments, + alternative_syntax, + related_udfs, + } = self; + + Documentation { + doc_section, + description, + syntax_example, + sql_example, + arguments, + alternative_syntax, + related_udfs, + } + } +} diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml index fb2e7e914fe5b..bb86868a82146 100644 --- a/datafusion/execution/Cargo.toml +++ b/datafusion/execution/Cargo.toml @@ -37,15 +37,16 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } -chrono = { workspace = true } dashmap = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } futures = { workspace = true } -hashbrown = { workspace = true } log = { workspace = true } object_store = { workspace = true } parking_lot = { workspace = true } rand = { workspace = true } tempfile = { workspace = true } url = { workspace = true } + +[dev-dependencies] +chrono = { workspace = true } diff --git a/datafusion/execution/LICENSE.txt b/datafusion/execution/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/execution/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/execution/NOTICE.txt b/datafusion/execution/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/execution/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/execution/src/cache/mod.rs b/datafusion/execution/src/cache/mod.rs index da19bff5658af..4271bebd0b326 100644 --- a/datafusion/execution/src/cache/mod.rs +++ b/datafusion/execution/src/cache/mod.rs @@ -22,7 +22,6 @@ pub mod cache_unit; /// This interface does not get `mut` references and thus has to handle its own /// locking via internal mutability. It can be accessed via multiple concurrent queries /// during planning and execution. - pub trait CacheAccessor: Send + Sync { // Extra info but not part of the cache key or cache value. 
type Extra: Clone; diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index cede75d21ca47..53646dc5b468e 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -432,6 +432,20 @@ impl SessionConfig { self } + /// Enables or disables the enforcement of batch size in joins + pub fn with_enforce_batch_size_in_joins( + mut self, + enforce_batch_size_in_joins: bool, + ) -> Self { + self.options.execution.enforce_batch_size_in_joins = enforce_batch_size_in_joins; + self + } + + /// Returns true if the joins will be enforced to output batches of the configured size + pub fn enforce_batch_size_in_joins(&self) -> bool { + self.options.execution.enforce_batch_size_in_joins + } + /// Convert configuration options to name-value pairs with values /// converted to strings. /// diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index c98d7e5579f0f..756da7ed5b468 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Manages files generated during query execution, files are -//! hashed among the directories listed in RuntimeConfig::local_dirs. +//! [`DiskManager`]: Manages files generated during query execution use datafusion_common::{resources_datafusion_err, DataFusionError, Result}; use log::debug; @@ -139,7 +138,7 @@ impl DiskManager { let dir_index = thread_rng().gen_range(0..local_dirs.len()); Ok(RefCountedTempFile { - parent_temp_dir: Arc::clone(&local_dirs[dir_index]), + _parent_temp_dir: Arc::clone(&local_dirs[dir_index]), tempfile: Builder::new() .tempfile_in(local_dirs[dir_index].as_ref()) .map_err(DataFusionError::IoError)?, @@ -153,8 +152,7 @@ impl DiskManager { pub struct RefCountedTempFile { /// The reference to the directory in which temporary files are created to ensure /// it is not cleaned up prior to the NamedTempFile - #[allow(dead_code)] - parent_temp_dir: Arc, + _parent_temp_dir: Arc, tempfile: NamedTempFile, } @@ -173,7 +171,7 @@ fn create_local_dirs(local_dirs: Vec) -> Result>> { local_dirs .iter() .map(|root| { - if !std::path::Path::new(root).exists() { + if !Path::new(root).exists() { std::fs::create_dir(root)?; } Builder::new() diff --git a/datafusion/execution/src/lib.rs b/datafusion/execution/src/lib.rs index 909364fa805da..317bd3203ab1b 100644 --- a/datafusion/execution/src/lib.rs +++ b/datafusion/execution/src/lib.rs @@ -14,6 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index dcd59acbd49eb..45d467f133bf0 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -23,7 +23,9 @@ use std::{cmp::Ordering, sync::Arc}; mod pool; pub mod proxy { - pub use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; + pub use datafusion_common::utils::proxy::{ + HashTableAllocExt, RawTableAllocExt, VecAllocExt, + }; } pub use pool::*; @@ -68,11 +70,35 @@ pub use pool::*; /// Note that a `MemoryPool` can be shared by concurrently executing plans, /// which can be used to control memory usage in a multi-tenant system. 
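As background for the reservation-based accounting described above and for the scenario walkthrough that follows, here is a minimal consumer-side sketch using the public `MemoryConsumer`, `MemoryReservation` and `GreedyMemoryPool` API; the consumer name and byte counts are arbitrary illustrations rather than values taken from this patch.

```rust
use std::sync::Arc;

use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

fn reservation_demo() -> datafusion_common::Result<()> {
    // A pool that hands out at most 1 MiB in total.
    let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024 * 1024));

    // An operator registers itself as a consumer and grows its reservation
    // as its intermediate state grows.
    let mut reservation = MemoryConsumer::new("ExampleJoin").register(&pool);
    reservation.try_grow(512 * 1024)?; // succeeds: 512 KiB reserved
    assert!(reservation.try_grow(1024 * 1024).is_err()); // would exceed the 1 MiB limit

    // A spilling operator would release part of its reservation and retry.
    reservation.shrink(256 * 1024);
    reservation.try_grow(128 * 1024)?; // now fits again
    assert!(pool.reserved() > 0);
    Ok(())
}
```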
/// +/// # How MemoryPool works by example +/// +/// Scenario 1: +/// For the `Filter` operator, `RecordBatch`es will stream through it, so it +/// doesn't have to keep track of memory usage through [`MemoryPool`]. +/// +/// Scenario 2: +/// For the `CrossJoin` operator, if the input size gets larger, the intermediate +/// state will also grow. So the `CrossJoin` operator will use [`MemoryPool`] to +/// limit the memory usage. +/// 2.1 The `CrossJoin` operator has read a new batch and asks the memory pool for +/// additional memory. The memory pool updates the usage and returns success. +/// 2.2 `CrossJoin` has read another batch and tries to reserve more memory +/// again, but the memory pool does not have enough memory. Since the `CrossJoin` operator +/// has not implemented spilling, it will stop execution and return an error. +/// +/// Scenario 3: +/// For the `Aggregate` operator, its intermediate state will also accumulate as +/// the input size gets larger, but it has spilling capability. When it tries to +/// reserve more memory from the memory pool and the memory pool has already +/// reached the memory limit, it will return an error. Then, the `Aggregate` +/// operator will spill the intermediate buffers to disk, release memory +/// from the memory pool, and continue to retry the memory reservation. +/// /// # Implementing `MemoryPool` /// /// You can implement a custom allocation policy by implementing the /// [`MemoryPool`] trait and configuring a `SessionContext` appropriately. -/// However, mDataFusion comes with the following simple memory pool implementations that +/// However, DataFusion comes with the following simple memory pool implementations that /// handle many common cases: /// /// * [`UnboundedMemoryPool`]: no memory limits (the default) @@ -310,13 +336,17 @@ impl Drop for MemoryReservation { } } -const TB: u64 = 1 << 40; -const GB: u64 = 1 << 30; -const MB: u64 = 1 << 20; -const KB: u64 = 1 << 10; +pub mod units { + pub const TB: u64 = 1 << 40; + pub const GB: u64 = 1 << 30; + pub const MB: u64 = 1 << 20; + pub const KB: u64 = 1 << 10; +} /// Present size in human readable form pub fn human_readable_size(size: usize) -> String { + use units::*; + let size = size as u64; let (value, unit) = { if size >= 2 * TB { diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index e169c1f319cca..261332180e571 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -16,8 +16,8 @@ // under the License. use crate::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; +use datafusion_common::HashMap; use datafusion_common::{resources_datafusion_err, DataFusionError, Result}; -use hashbrown::HashMap; use log::debug; use parking_lot::Mutex; use std::{ @@ -62,7 +62,7 @@ pub struct GreedyMemoryPool { } impl GreedyMemoryPool { - /// Allocate up to `limit` bytes + /// Create a new pool that can allocate up to `pool_size` bytes pub fn new(pool_size: usize) -> Self { debug!("Created new GreedyMemoryPool(pool_size={pool_size})"); Self { diff --git a/datafusion/execution/src/runtime_env.rs b/datafusion/execution/src/runtime_env.rs index 4022eb07de0c7..2b08b7ff9e889 100644 --- a/datafusion/execution/src/runtime_env.rs +++ b/datafusion/execution/src/runtime_env.rs @@ -41,13 +41,32 @@ use url::Url; /// Execution runtime environment that manages system resources such /// as memory, disk, cache and storage.
/// -/// A [`RuntimeEnv`] is created from a [`RuntimeEnvBuilder`] and has the +/// A [`RuntimeEnv`] can be created using [`RuntimeEnvBuilder`] and has the /// following resource management functionality: /// /// * [`MemoryPool`]: Manage memory /// * [`DiskManager`]: Manage temporary files on local disk /// * [`CacheManager`]: Manage temporary cache data during the session lifetime /// * [`ObjectStoreRegistry`]: Manage mapping URLs to object store instances +/// +/// # Example: Create default `RuntimeEnv` +/// ``` +/// # use datafusion_execution::runtime_env::RuntimeEnv; +/// let runtime_env = RuntimeEnv::default(); +/// ``` +/// +/// # Example: Create a `RuntimeEnv` from [`RuntimeEnvBuilder`] with a new memory pool +/// ``` +/// # use std::sync::Arc; +/// # use datafusion_execution::memory_pool::GreedyMemoryPool; +/// # use datafusion_execution::runtime_env::{RuntimeEnv, RuntimeEnvBuilder}; +/// // restrict to using at most 100MB of memory +/// let pool_size = 100 * 1024 * 1024; +/// let runtime_env = RuntimeEnvBuilder::new() +/// .with_memory_pool(Arc::new(GreedyMemoryPool::new(pool_size))) +/// .build() +/// .unwrap(); +/// ``` pub struct RuntimeEnv { /// Runtime memory management pub memory_pool: Arc, @@ -66,28 +85,16 @@ impl Debug for RuntimeEnv { } impl RuntimeEnv { - #[deprecated(note = "please use `try_new` instead")] + #[deprecated(since = "43.0.0", note = "please use `RuntimeEnvBuilder` instead")] + #[allow(deprecated)] pub fn new(config: RuntimeConfig) -> Result { Self::try_new(config) } /// Create env based on configuration + #[deprecated(since = "44.0.0", note = "please use `RuntimeEnvBuilder` instead")] + #[allow(deprecated)] pub fn try_new(config: RuntimeConfig) -> Result { - let RuntimeConfig { - memory_pool, - disk_manager, - cache_manager, - object_store_registry, - } = config; - - let memory_pool = - memory_pool.unwrap_or_else(|| Arc::new(UnboundedMemoryPool::default())); - - Ok(Self { - memory_pool, - disk_manager: DiskManager::try_new(disk_manager)?, - cache_manager: CacheManager::try_new(&cache_manager)?, - object_store_registry, - }) + config.build() } /// Registers a custom `ObjectStore` to be used with a specific url. 
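Since the size constants are now exposed through the public `units` module in the `memory_pool` change above, here is a small hedged sketch of pairing them with `human_readable_size`; the exact rendered string is an assumption based on the surrounding formatting code, so only the unit suffix is checked.

```rust
use datafusion_execution::memory_pool::{human_readable_size, units};

// The constants are plain byte counts; `human_readable_size` picks the unit.
let four_mb = (4 * units::MB) as usize;
let rendered = human_readable_size(four_mb);
assert!(rendered.ends_with("MB"), "expected megabytes, got {rendered}");
```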
@@ -104,26 +111,26 @@ impl RuntimeEnv { /// # use std::sync::Arc; /// # use url::Url; /// # use datafusion_execution::runtime_env::RuntimeEnv; - /// # let runtime_env = RuntimeEnv::try_new(Default::default()).unwrap(); + /// # let runtime_env = RuntimeEnv::default(); /// let url = Url::try_from("file://").unwrap(); /// let object_store = object_store::local::LocalFileSystem::new(); /// // register the object store with the runtime environment /// runtime_env.register_object_store(&url, Arc::new(object_store)); /// ``` /// - /// # Example: Register local file system object store + /// # Example: Register remote URL object store like [Github](https://github.com) /// - /// To register reading from urls such as ` /// /// ``` /// # use std::sync::Arc; /// # use url::Url; /// # use datafusion_execution::runtime_env::RuntimeEnv; - /// # let runtime_env = RuntimeEnv::try_new(Default::default()).unwrap(); + /// # let runtime_env = RuntimeEnv::default(); /// # // use local store for example as http feature is not enabled /// # let http_store = object_store::local::LocalFileSystem::new(); /// // create a new object store via object_store::http::HttpBuilder; /// let base_url = Url::parse("https://github.com").unwrap(); + /// // (note this example can't depend on the http feature) /// // let http_store = HttpBuilder::new() /// // .with_url(base_url.clone()) /// // .build() @@ -157,10 +164,13 @@ impl Default for RuntimeEnv { /// Please see: /// This a type alias for backwards compatibility. +#[deprecated(since = "43.0.0", note = "please use `RuntimeEnvBuilder` instead")] pub type RuntimeConfig = RuntimeEnvBuilder; #[derive(Clone)] -/// Execution runtime configuration +/// Execution runtime configuration builder. +/// +/// See example on [`RuntimeEnv`] pub struct RuntimeEnvBuilder { /// DiskManager to manage temporary disk file usage pub disk_manager: DiskManagerConfig, @@ -239,15 +249,20 @@ impl RuntimeEnvBuilder { /// Build a RuntimeEnv pub fn build(self) -> Result { - let memory_pool = self - .memory_pool - .unwrap_or_else(|| Arc::new(UnboundedMemoryPool::default())); + let Self { + disk_manager, + memory_pool, + cache_manager, + object_store_registry, + } = self; + let memory_pool = + memory_pool.unwrap_or_else(|| Arc::new(UnboundedMemoryPool::default())); Ok(RuntimeEnv { memory_pool, - disk_manager: DiskManager::try_new(self.disk_manager)?, - cache_manager: CacheManager::try_new(&self.cache_manager)?, - object_store_registry: self.object_store_registry, + disk_manager: DiskManager::try_new(disk_manager)?, + cache_manager: CacheManager::try_new(&cache_manager)?, + object_store_registry, }) } diff --git a/datafusion/execution/src/stream.rs b/datafusion/execution/src/stream.rs index f3eb7b77e03cc..5b309210aa37c 100644 --- a/datafusion/execution/src/stream.rs +++ b/datafusion/execution/src/stream.rs @@ -40,6 +40,13 @@ pub trait RecordBatchStream: Stream> { /// `RecordBatch` returned by the stream should have the same schema as returned /// by [`schema`](`RecordBatchStream::schema`). 
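As a hedged illustration of the `RecordBatchStream` contract described here (and of the `RecordBatchStreamAdapter` cross-link added just below), this sketch adapts a one-shot `futures` stream. It assumes the adapter lives in the separate `datafusion-physical-plan` crate that the new doc link points at.

```rust
use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_execution::SendableRecordBatchStream;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use futures::stream;

fn one_batch_stream() -> SendableRecordBatchStream {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
    )
    .expect("batch matches schema");
    // Every batch yielded by the adapter shares the schema passed in here.
    Box::pin(RecordBatchStreamAdapter::new(
        schema,
        stream::once(async move { Ok(batch) }),
    ))
}
```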
/// +/// # See Also +/// +/// * [`RecordBatchStreamAdapter`] to convert an existing [`Stream`] +/// to [`SendableRecordBatchStream`] +/// +/// [`RecordBatchStreamAdapter`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/stream/struct.RecordBatchStreamAdapter.html +/// /// # Error Handling /// /// Once a stream returns an error, it should not be polled again (the caller diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index 57fcac0ee5ab6..7cdb53c90d0ef 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -15,20 +15,15 @@ // specific language governing permissions and limitations // under the License. -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; - use crate::{ - config::SessionConfig, - memory_pool::MemoryPool, - registry::FunctionRegistry, - runtime_env::{RuntimeEnv, RuntimeEnvBuilder}, + config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry, + runtime_env::RuntimeEnv, }; use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; use datafusion_expr::planner::ExprPlanner; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use std::collections::HashSet; +use std::{collections::HashMap, sync::Arc}; /// Task Execution Context /// @@ -57,9 +52,7 @@ pub struct TaskContext { impl Default for TaskContext { fn default() -> Self { - let runtime = RuntimeEnvBuilder::new() - .build_arc() - .expect("default runtime created successfully"); + let runtime = Arc::new(RuntimeEnv::default()); // Create a default task context, mostly useful for testing Self { @@ -125,6 +118,18 @@ impl TaskContext { Arc::clone(&self.runtime) } + pub fn scalar_functions(&self) -> &HashMap> { + &self.scalar_functions + } + + pub fn aggregate_functions(&self) -> &HashMap> { + &self.aggregate_functions + } + + pub fn window_functions(&self) -> &HashMap> { + &self.window_functions + } + /// Update the [`SessionConfig`] pub fn with_session_config(mut self, session_config: SessionConfig) -> Self { self.session_config = session_config; diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml index 7e477efc4ebc1..1ccc6fc17293c 100644 --- a/datafusion/expr-common/Cargo.toml +++ b/datafusion/expr-common/Cargo.toml @@ -19,7 +19,6 @@ name = "datafusion-expr-common" description = "Logical plan and expression representation for DataFusion query engine" keywords = ["datafusion", "logical", "plan", "expressions"] -readme = "README.md" version = { workspace = true } edition = { workspace = true } homepage = { workspace = true } @@ -40,4 +39,7 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } datafusion-common = { workspace = true } +itertools = { workspace = true } + +[dev-dependencies] paste = "^1.0" diff --git a/datafusion/expr-common/LICENSE.txt b/datafusion/expr-common/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/expr-common/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/expr-common/NOTICE.txt b/datafusion/expr-common/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/expr-common/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/expr-common/src/accumulator.rs b/datafusion/expr-common/src/accumulator.rs index 75335209451e1..dc1e023d4c3cf 100644 --- a/datafusion/expr-common/src/accumulator.rs +++ b/datafusion/expr-common/src/accumulator.rs @@ -39,7 +39,7 @@ use 
std::fmt::Debug; /// function]) /// /// * convert its internal state to a vector of aggregate values via -/// [`state`] and combine the state from multiple accumulators' +/// [`state`] and combine the state from multiple accumulators /// via [`merge_batch`], as part of efficient multi-phase grouping. /// /// [`GroupsAccumulator`]: crate::GroupsAccumulator @@ -68,7 +68,7 @@ pub trait Accumulator: Send + Sync + Debug { /// result in potentially non-deterministic behavior. /// /// This function gets `&mut self` to allow for the accumulator to build - /// arrow compatible internal state that can be returned without copying + /// arrow-compatible internal state that can be returned without copying /// when possible (for example distinct strings) fn evaluate(&mut self) -> Result; @@ -89,14 +89,14 @@ pub trait Accumulator: Send + Sync + Debug { /// result in potentially non-deterministic behavior. /// /// This function gets `&mut self` to allow for the accumulator to build - /// arrow compatible internal state that can be returned without copying + /// arrow-compatible internal state that can be returned without copying /// when possible (for example distinct strings). /// /// Intermediate state is used for "multi-phase" grouping in /// DataFusion, where an aggregate is computed in parallel with /// multiple `Accumulator` instances, as described below: /// - /// # MultiPhase Grouping + /// # Multi-Phase Grouping /// /// ```text /// ▲ @@ -115,7 +115,7 @@ pub trait Accumulator: Send + Sync + Debug { /// │ │ /// │ │ /// ┌─────────────────────────┐ ┌─────────────────────────┐ - /// │ GroubyBy │ │ GroubyBy │ + /// │ GroupBy │ │ GroupBy │ /// │(AggregateMode::Partial) │ │(AggregateMode::Partial) │ /// └─────────────────────────┘ └─────────────────────────┘ /// ▲ ▲ @@ -140,9 +140,9 @@ pub trait Accumulator: Send + Sync + Debug { /// to be summed together) /// /// Some accumulators can return multiple values for their - /// intermediate states. For example average, tracks `sum` and - /// `n`, and this function should return - /// a vector of two values, sum and n. + /// intermediate states. For example, the average accumulator + /// tracks `sum` and `n`, and this function should return a vector + /// of two values, sum and n. /// /// Note that [`ScalarValue::List`] can be used to pass multiple /// values if the number of intermediate values is not known at @@ -181,7 +181,7 @@ pub trait Accumulator: Send + Sync + Debug { /// │ │ /// │ │ /// ┌─────────────────────────┐ ┌──────────────────────────┐ 2. Each AggregateMode::Partial - /// │ GroubyBy │ │ GroubyBy │ GroupBy has an entry for *all* + /// │ GroupBy │ │ GroupBy │ GroupBy has an entry for *all* /// │(AggregateMode::Partial) │ │ (AggregateMode::Partial) │ the groups /// └─────────────────────────┘ └──────────────────────────┘ /// ▲ ▲ @@ -204,7 +204,7 @@ pub trait Accumulator: Send + Sync + Debug { /// The final output is computed by repartitioning the result of /// [`Self::state`] from each Partial aggregate and `hash(group keys)` so /// that each distinct group key appears in exactly one of the - /// `AggregateMode::Final` GroupBy nodes. The output of the final nodes are + /// `AggregateMode::Final` GroupBy nodes. The outputs of the final nodes are /// then unioned together to produce the overall final output. /// /// Here is an example that shows the distribution of groups in the @@ -254,7 +254,7 @@ pub trait Accumulator: Send + Sync + Debug { /// or more intermediate values. 
/// /// For some aggregates (such as `SUM`), merge_batch is the same - /// as `update_batch`, but for some aggregrates (such as `COUNT`) + /// as `update_batch`, but for some aggregates (such as `COUNT`) /// the operations differ. See [`Self::state`] for more details on how /// state is used and merged. /// diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index 51113557335c6..bf977affaed36 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -17,10 +17,9 @@ //! [`ColumnarValue`] represents the result of evaluating an expression. -use arrow::array::ArrayRef; -use arrow::array::NullArray; +use arrow::array::{Array, ArrayRef, NullArray}; use arrow::compute::{kernels, CastOptions}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::DataType; use datafusion_common::format::DEFAULT_CAST_OPTIONS; use datafusion_common::{internal_err, Result, ScalarValue}; use std::sync::Arc; @@ -132,7 +131,25 @@ impl ColumnarValue { }) } - /// null columnar values are implemented as a null array in order to pass batch + /// Convert a columnar value into an Arrow [`ArrayRef`] with the specified + /// number of rows. [`Self::Scalar`] is converted by repeating the same + /// scalar multiple times which is not as efficient as handling the scalar + /// directly. + /// + /// See [`Self::values_to_arrays`] to convert multiple columnar values into + /// arrays of the same length. + /// + /// # Errors + /// + /// Errors if `self` is a Scalar that fails to be converted into an array of size + pub fn to_array(&self, num_rows: usize) -> Result { + Ok(match self { + ColumnarValue::Array(array) => Arc::clone(array), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows)?, + }) + } + + /// Null columnar values are implemented as a null array in order to pass batch /// num_rows pub fn create_null_array(num_rows: usize) -> Self { ColumnarValue::Array(Arc::new(NullArray::new(num_rows))) @@ -179,7 +196,7 @@ impl ColumnarValue { let args = args .iter() - .map(|arg| arg.clone().into_array(inferred_length)) + .map(|arg| arg.to_array(inferred_length)) .collect::>>()?; Ok(args) @@ -196,31 +213,13 @@ impl ColumnarValue { ColumnarValue::Array(array) => Ok(ColumnarValue::Array( kernels::cast::cast_with_options(array, cast_type, &cast_options)?, )), - ColumnarValue::Scalar(scalar) => { + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::from( // TODO(@notfilippo, logical vs physical): if `scalar.data_type` is *logically equivalent* // to `cast_type` then skip the kernel cast and only change the `data_type` of the scalar. - - let scalar_array = - if cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None) { - if let ScalarValue::Float64(Some(float_ts)) = scalar.value() { - ScalarValue::Int64(Some( - (float_ts * 1_000_000_000_f64).trunc() as i64, - )) - .to_array()? - } else { - scalar.to_array()? - } - } else { - scalar.to_array()? 
- }; - let cast_array = kernels::cast::cast_with_options( - &scalar_array, - cast_type, - &cast_options, - )?; - let cast_scalar = Scalar::try_from_array(&cast_array, 0)?; - Ok(ColumnarValue::Scalar(cast_scalar)) - } + scalar + .value() + .cast_to_with_options(cast_type, &cast_options)?, + )), } } } diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 8e81c51d8460f..5ff1c1d072164 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -82,7 +82,7 @@ impl EmitTo { /// group /// ``` /// -/// # Notes on Implementing `GroupAccumulator` +/// # Notes on Implementing `GroupsAccumulator` /// /// All aggregates must first implement the simpler [`Accumulator`] trait, which /// handles state for a single group. Implementing `GroupsAccumulator` is @@ -90,12 +90,17 @@ impl EmitTo { /// faster for queries with many group values. See the [Aggregating Millions of /// Groups Fast blog] for more background. /// +/// [`NullState`] can help keep the state for groups that have not seen any +/// values and produce the correct output for those groups. +/// +/// [`NullState`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/struct.NullState.html +/// /// # Details /// Each group is assigned a `group_index` by the hash table and each /// accumulator manages the specific state, one per `group_index`. /// /// `group_index`es are contiguous (there aren't gaps), and thus it is -/// expected that each `GroupAccumulator` will use something like `Vec<..>` +/// expected that each `GroupsAccumulator` will use something like `Vec<..>` /// to store the group states. /// /// [`Accumulator`]: crate::accumulator::Accumulator @@ -106,8 +111,7 @@ pub trait GroupsAccumulator: Send { /// /// * `values`: the input arguments to the accumulator /// - /// * `group_indices`: To which groups do the rows in `values` - /// belong, group id) + /// * `group_indices`: The group indices to which each row in `values` belongs. /// /// * `opt_filter`: if present, only update aggregate state using /// `values[i]` if `opt_filter[i]` is true @@ -117,6 +121,11 @@ pub trait GroupsAccumulator: Send { /// /// Note that subsequent calls to update_batch may have larger /// total_num_groups as new groups are seen. + /// + /// See [`NullState`] to help keep the state for groups that have not seen any + /// values and produce the correct output for those groups. + /// + /// [`NullState`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/struct.NullState.html fn update_batch( &mut self, values: &[ArrayRef], @@ -175,9 +184,9 @@ pub trait GroupsAccumulator: Send { /// differ. See [`Self::state`] for more details on how state is /// used and merged. /// - /// * `values`: arrays produced from calling `state` previously to the accumulator + /// * `values`: arrays produced from previously calling `state` on other accumulators. /// - /// Other arguments are the same as for [`Self::update_batch`]; + /// Other arguments are the same as for [`Self::update_batch`]. fn merge_batch( &mut self, values: &[ArrayRef], @@ -186,7 +195,7 @@ pub trait GroupsAccumulator: Send { total_num_groups: usize, ) -> Result<()>; - /// Converts an input batch directly the intermediate aggregate state. + /// Converts an input batch directly to the intermediate aggregate state. /// /// This is the equivalent of treating each input row as its own group. 
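A small hedged sketch of the `ColumnarValue::to_array` helper added earlier in this diff; only the `Array` branch is shown because the exact scalar wrapper type used by this branch of the code is not fully visible here, and the module path is assumed from the file location in `datafusion-expr-common`.

```rust
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Int64Array};
use datafusion_expr_common::columnar_value::ColumnarValue;

// For the Array variant, `to_array` is just a cheap Arc clone; for the Scalar
// variant it would materialize `num_rows` copies of the single value.
let array: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
let value = ColumnarValue::Array(Arc::clone(&array));
let materialized = value.to_array(3).unwrap();
assert_eq!(materialized.len(), 3);
```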
It /// is invoked when the Partial phase of a multi-phase aggregation is not diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 6424888c896a5..ffaa32f08075c 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -1223,8 +1223,8 @@ pub fn satisfy_greater( } } - // Only the lower bound of left hand side and the upper bound of the right - // hand side can change after propagating the greater-than operation. + // Only the lower bound of left-hand side and the upper bound of the right-hand + // side can change after propagating the greater-than operation. let new_left_lower = if left.lower.is_null() || left.lower <= right.lower { if strict { next_value(right.lower.clone()) @@ -1753,7 +1753,7 @@ impl NullableInterval { } _ => Ok(Self::MaybeNull { values }), } - } else if op.is_comparison_operator() { + } else if op.supports_propagation() { Ok(Self::Null { datatype: DataType::Boolean, }) diff --git a/datafusion/expr-common/src/operator.rs b/datafusion/expr-common/src/operator.rs index e013b6fafa22d..6ca0f04897aca 100644 --- a/datafusion/expr-common/src/operator.rs +++ b/datafusion/expr-common/src/operator.rs @@ -142,10 +142,11 @@ impl Operator { ) } - /// Return true if the operator is a comparison operator. + /// Return true if the comparison operator can be used in interval arithmetic and constraint + /// propagation /// - /// For example, 'Binary(a, >, b)' would be a comparison expression. - pub fn is_comparison_operator(&self) -> bool { + /// For example, 'Binary(a, >, b)' expression supports propagation. + pub fn supports_propagation(&self) -> bool { matches!( self, Operator::Eq @@ -163,6 +164,15 @@ impl Operator { ) } + /// Return true if the comparison operator can be used in interval arithmetic and constraint + /// propagation + /// + /// For example, 'Binary(a, >, b)' expression supports propagation. + #[deprecated(since = "43.0.0", note = "please use `supports_propagation` instead")] + pub fn is_comparison_operator(&self) -> bool { + self.supports_propagation() + } + /// Return true if the operator is a logic operator. /// /// For example, 'Binary(Binary(a, >, b), AND, Binary(a, <, b + 3))' would diff --git a/datafusion/expr-common/src/scalar.rs b/datafusion/expr-common/src/scalar.rs index a1e2cc40bb273..5bed206edc11e 100644 --- a/datafusion/expr-common/src/scalar.rs +++ b/datafusion/expr-common/src/scalar.rs @@ -161,6 +161,10 @@ impl Scalar { &self.data_type } + pub fn is_null(&self) -> bool { + self.value.is_null() + } + #[inline] pub fn to_array_of_size(&self, size: usize) -> Result { self.value.to_array_of_size_and_type(size, &self.data_type) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 320e1303a21b7..56f3029a4d7a5 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -18,7 +18,12 @@ //! Signature module contains foundational types that are used to represent signatures, types, //! and return types of functions in DataFusion. -use arrow::datatypes::DataType; +use std::fmt::Display; + +use crate::type_coercion::aggregates::NUMERICS; +use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; +use datafusion_common::types::{LogicalTypeRef, NativeType}; +use itertools::Itertools; /// Constant that is used as a placeholder for any valid timezone. 
/// This is used where a function can accept a timestamp type with any @@ -35,7 +40,7 @@ pub const TIMEZONE_WILDCARD: &str = "+TZ"; /// valid length. It exists to avoid the need to enumerate all possible fixed size list lengths. pub const FIXED_SIZE_LIST_WILDCARD: i32 = i32::MIN; -///A function's volatility, which defines the functions eligibility for certain optimizations +/// A function's volatility, which defines the functions eligibility for certain optimizations #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] pub enum Volatility { /// An immutable function will always return the same output when given the same @@ -69,6 +74,9 @@ pub enum Volatility { /// adds a cast such as `cos(CAST int_column AS DOUBLE)` during planning. /// /// # Data Types +/// +/// ## Timestamps +/// /// Types to match are represented using Arrow's [`DataType`]. [`DataType::Timestamp`] has an optional variable /// timezone specification. To specify a function can handle a timestamp with *ANY* timezone, use /// the [`TIMEZONE_WILDCARD`]. For example: @@ -86,50 +94,134 @@ pub enum Volatility { /// ``` #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum TypeSignature { - /// One or more arguments of an common type out of a list of valid types. + /// One or more arguments of a common type out of a list of valid types. + /// + /// For functions that take no arguments (e.g. `random()` see [`TypeSignature::Nullary`]). /// /// # Examples - /// A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` + /// + /// A function such as `concat` is `Variadic(vec![DataType::Utf8, + /// DataType::LargeUtf8])` Variadic(Vec), - /// The acceptable signature and coercions rules to coerce arguments to this - /// signature are special for this function. If this signature is specified, - /// DataFusion will call `ScalarUDFImpl::coerce_types` to prepare argument types. + /// The acceptable signature and coercions rules are special for this + /// function. + /// + /// If this signature is specified, + /// DataFusion will call [`ScalarUDFImpl::coerce_types`] to prepare argument types. + /// + /// [`ScalarUDFImpl::coerce_types`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.coerce_types UserDefined, /// One or more arguments with arbitrary types VariadicAny, - /// Fixed number of arguments of an arbitrary but equal type out of a list of valid types. + /// One or more arguments of an arbitrary but equal type out of a list of valid types. /// /// # Examples + /// /// 1. A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` /// 2. A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` Uniform(usize, Vec), - /// Exact number of arguments of an exact type + /// One or more arguments with exactly the specified types in order. + /// + /// For functions that take no arguments (e.g. `random()`) use [`TypeSignature::Nullary`]. Exact(Vec), - /// The number of arguments that can be coerced to in order - /// For example, `Coercible(vec![DataType::Float64])` accepts - /// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]` - /// since i32 and f32 can be casted to f64 - Coercible(Vec), - /// Fixed number of arguments of arbitrary types - /// If a function takes 0 argument, its `TypeSignature` should be `Any(0)` + /// One or more arguments belonging to the [`TypeSignatureClass`], in order. 
+ /// + /// For example, `Coercible(vec![logical_float64()])` accepts + /// arguments like `vec![Int32]` or `vec![Float32]` + /// since i32 and f32 can be cast to f64 + /// + /// For functions that take no arguments (e.g. `random()`) see [`TypeSignature::Nullary`]. + Coercible(Vec), + /// One or more arguments coercible to a single, comparable type. + /// + /// Each argument will be coerced to a single type using the + /// coercion rules described in [`comparison_coercion_numeric`]. + /// + /// # Examples + /// + /// If the `nullif(1, 2)` function is called with `i32` and `i64` arguments + /// the types will both be coerced to `i64` before the function is invoked. + /// + /// If the `nullif('1', 2)` function is called with `Utf8` and `i64` arguments + /// the types will both be coerced to `Utf8` before the function is invoked. + /// + /// Note: + /// - For functions that take no arguments (e.g. `random()` see [`TypeSignature::Nullary`]). + /// - If all arguments have type [`DataType::Null`], they are coerced to `Utf8` + /// + /// [`comparison_coercion_numeric`]: crate::type_coercion::binary::comparison_coercion_numeric + Comparable(usize), + /// One or more arguments of arbitrary types. + /// + /// For functions that take no arguments (e.g. `random()`) use [`TypeSignature::Nullary`]. Any(usize), - /// Matches exactly one of a list of [`TypeSignature`]s. Coercion is attempted to match - /// the signatures in order, and stops after the first success, if any. + /// Matches exactly one of a list of [`TypeSignature`]s. + /// + /// Coercion is attempted to match the signatures in order, and stops after + /// the first success, if any. /// /// # Examples - /// Function `make_array` takes 0 or more arguments with arbitrary types, its `TypeSignature` + /// + /// Since `make_array` takes 0 or more arguments with arbitrary types, its `TypeSignature` /// is `OneOf(vec![Any(0), VariadicAny])`. OneOf(Vec), - /// Specifies Signatures for array functions + /// A function that has an [`ArrayFunctionSignature`] ArraySignature(ArrayFunctionSignature), - /// Fixed number of arguments of numeric types. - /// See to know which type is considered numeric + /// One or more arguments of numeric types. + /// + /// See [`NativeType::is_numeric`] to know which type is considered numeric + /// + /// For functions that take no arguments (e.g. `random()`) use [`TypeSignature::Nullary`]. + /// + /// [`NativeType::is_numeric`]: datafusion_common::types::NativeType::is_numeric Numeric(usize), - /// Fixed number of arguments of all the same string types. + /// One or arguments of all the same string types. + /// /// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8. - /// Null is considerd as Utf8 by default + /// Null is considered as `Utf8` by default /// Dictionary with string value type is also handled. + /// + /// For example, if a function is called with (utf8, large_utf8), all + /// arguments will be coerced to `LargeUtf8` + /// + /// For functions that take no arguments (e.g. `random()` use [`TypeSignature::Nullary`]). String(usize), + /// No arguments + Nullary, +} + +impl TypeSignature { + #[inline] + pub fn is_one_of(&self) -> bool { + matches!(self, TypeSignature::OneOf(_)) + } +} + +/// Represents the class of types that can be used in a function signature. +/// +/// This is used to specify what types are valid for function arguments in a more flexible way than +/// just listing specific DataTypes. 
For example, TypeSignatureClass::Timestamp matches any timestamp +/// type regardless of timezone or precision. +/// +/// Used primarily with TypeSignature::Coercible to define function signatures that can accept +/// arguments that can be coerced to a particular class of types. +#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Hash)] +pub enum TypeSignatureClass { + Timestamp, + Date, + Time, + Interval, + Duration, + Native(LogicalTypeRef), + // TODO: + // Numeric + // Integer +} + +impl Display for TypeSignatureClass { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TypeSignatureClass::{self:?}") + } } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -153,12 +245,15 @@ pub enum ArrayFunctionSignature { /// The function takes a single argument that must be a List/LargeList/FixedSizeList /// or something that can be coerced to one of those types. Array, + /// A function takes a single argument that must be a List/LargeList/FixedSizeList + /// which gets coerced to List, with element type recursively coerced to List too if it is list-like. + RecursiveArray, /// Specialized Signature for MapArray /// The function takes a single argument that must be a MapArray MapArray, } -impl std::fmt::Display for ArrayFunctionSignature { +impl Display for ArrayFunctionSignature { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ArrayFunctionSignature::ArrayAndElement => { @@ -176,6 +271,9 @@ impl std::fmt::Display for ArrayFunctionSignature { ArrayFunctionSignature::Array => { write!(f, "array") } + ArrayFunctionSignature::RecursiveArray => { + write!(f, "recursive_array") + } ArrayFunctionSignature::MapArray => { write!(f, "map_array") } @@ -186,6 +284,9 @@ impl std::fmt::Display for ArrayFunctionSignature { impl TypeSignature { pub fn to_string_repr(&self) -> Vec { match self { + TypeSignature::Nullary => { + vec!["NullAry()".to_string()] + } TypeSignature::Variadic(types) => { vec![format!("{}, ..", Self::join_types(types, "/"))] } @@ -201,7 +302,13 @@ impl TypeSignature { TypeSignature::Numeric(num) => { vec![format!("Numeric({num})")] } - TypeSignature::Exact(types) | TypeSignature::Coercible(types) => { + TypeSignature::Comparable(num) => { + vec![format!("Comparable({num})")] + } + TypeSignature::Coercible(types) => { + vec![Self::join_types(types, ", ")] + } + TypeSignature::Exact(types) => { vec![Self::join_types(types, ", ")] } TypeSignature::Any(arg_count) => { @@ -224,7 +331,7 @@ impl TypeSignature { } /// Helper function to join types with specified delimiter. - pub fn join_types(types: &[T], delimiter: &str) -> String { + pub fn join_types(types: &[T], delimiter: &str) -> String { types .iter() .map(|t| t.to_string()) @@ -236,13 +343,116 @@ impl TypeSignature { pub fn supports_zero_argument(&self) -> bool { match &self { TypeSignature::Exact(vec) => vec.is_empty(), - TypeSignature::Uniform(0, _) | TypeSignature::Any(0) => true, + TypeSignature::Nullary => true, TypeSignature::OneOf(types) => types .iter() .any(|type_sig| type_sig.supports_zero_argument()), _ => false, } } + + /// Returns true if the signature currently supports or used to supported 0 + /// input arguments in a previous version of DataFusion. 
+ pub fn used_to_support_zero_arguments(&self) -> bool { + match &self { + TypeSignature::Any(num) => *num == 0, + _ => self.supports_zero_argument(), + } + } + + /// get all possible types for the given `TypeSignature` + pub fn get_possible_types(&self) -> Vec> { + match self { + TypeSignature::Exact(types) => vec![types.clone()], + TypeSignature::OneOf(types) => types + .iter() + .flat_map(|type_sig| type_sig.get_possible_types()) + .collect(), + TypeSignature::Uniform(arg_count, types) => types + .iter() + .cloned() + .map(|data_type| vec![data_type; *arg_count]) + .collect(), + TypeSignature::Coercible(types) => types + .iter() + .map(|logical_type| match logical_type { + TypeSignatureClass::Native(l) => get_data_types(l.native()), + TypeSignatureClass::Timestamp => { + vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp( + TimeUnit::Nanosecond, + Some(TIMEZONE_WILDCARD.into()), + ), + ] + } + TypeSignatureClass::Date => { + vec![DataType::Date64] + } + TypeSignatureClass::Time => { + vec![DataType::Time64(TimeUnit::Nanosecond)] + } + TypeSignatureClass::Interval => { + vec![DataType::Interval(IntervalUnit::DayTime)] + } + TypeSignatureClass::Duration => { + vec![DataType::Duration(TimeUnit::Nanosecond)] + } + }) + .multi_cartesian_product() + .collect(), + TypeSignature::Variadic(types) => types + .iter() + .cloned() + .map(|data_type| vec![data_type]) + .collect(), + TypeSignature::Numeric(arg_count) => NUMERICS + .iter() + .cloned() + .map(|numeric_type| vec![numeric_type; *arg_count]) + .collect(), + TypeSignature::String(arg_count) => get_data_types(&NativeType::String) + .into_iter() + .map(|dt| vec![dt; *arg_count]) + .collect::>(), + // TODO: Implement for other types + TypeSignature::Any(_) + | TypeSignature::Comparable(_) + | TypeSignature::Nullary + | TypeSignature::VariadicAny + | TypeSignature::ArraySignature(_) + | TypeSignature::UserDefined => vec![], + } + } +} + +fn get_data_types(native_type: &NativeType) -> Vec { + match native_type { + NativeType::Null => vec![DataType::Null], + NativeType::Boolean => vec![DataType::Boolean], + NativeType::Int8 => vec![DataType::Int8], + NativeType::Int16 => vec![DataType::Int16], + NativeType::Int32 => vec![DataType::Int32], + NativeType::Int64 => vec![DataType::Int64], + NativeType::UInt8 => vec![DataType::UInt8], + NativeType::UInt16 => vec![DataType::UInt16], + NativeType::UInt32 => vec![DataType::UInt32], + NativeType::UInt64 => vec![DataType::UInt64], + NativeType::Float16 => vec![DataType::Float16], + NativeType::Float32 => vec![DataType::Float32], + NativeType::Float64 => vec![DataType::Float64], + NativeType::Date => vec![DataType::Date32, DataType::Date64], + NativeType::Binary => vec![ + DataType::Binary, + DataType::LargeBinary, + DataType::BinaryView, + ], + NativeType::String => { + vec![DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View] + } + // TODO: support other native types + _ => vec![], + } } /// Defines the supported argument types ([`TypeSignature`]) and [`Volatility`] for a function. @@ -322,13 +532,31 @@ impl Signature { } } /// Target coerce types in order - pub fn coercible(target_types: Vec, volatility: Volatility) -> Self { + pub fn coercible( + target_types: Vec, + volatility: Volatility, + ) -> Self { Self { type_signature: TypeSignature::Coercible(target_types), volatility, } } + /// Used for function that expects comparable data types, it will try to coerced all the types into single final one. 
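To ground the new signature variants, here is a hedged sketch using the constructors and `get_possible_types` introduced in this diff; the crate/module path is assumed from the file location (`datafusion-expr-common`), and the expected expansions mirror the unit tests added further below.

```rust
use arrow::datatypes::DataType;
use datafusion_expr_common::signature::{Signature, TypeSignature, Volatility};

// Zero-argument functions such as `random()` now use the explicit Nullary variant.
let no_args = Signature::nullary(Volatility::Volatile);
assert!(no_args.type_signature.supports_zero_argument());

// `Comparable(2)` coerces both arguments to one common comparable type.
let comparable = Signature::comparable(2, Volatility::Immutable);
assert_eq!(comparable.type_signature, TypeSignature::Comparable(2));

// `get_possible_types` enumerates the concrete argument combinations a signature accepts.
let possible = TypeSignature::String(2).get_possible_types();
assert!(possible.contains(&vec![DataType::Utf8, DataType::Utf8]));
```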
+ pub fn comparable(arg_count: usize, volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::Comparable(arg_count), + volatility, + } + } + + pub fn nullary(volatility: Volatility) -> Self { + Signature { + type_signature: TypeSignature::Nullary, + volatility, + } + } + /// A specified number of arguments of any type pub fn any(arg_count: usize, volatility: Volatility) -> Self { Signature { @@ -390,6 +618,8 @@ impl Signature { #[cfg(test)] mod tests { + use datafusion_common::types::{logical_int64, logical_string}; + use super::*; #[test] @@ -397,13 +627,12 @@ mod tests { // Testing `TypeSignature`s which supports 0 arg let positive_cases = vec![ TypeSignature::Exact(vec![]), - TypeSignature::Uniform(0, vec![DataType::Float64]), - TypeSignature::Any(0), TypeSignature::OneOf(vec![ TypeSignature::Exact(vec![DataType::Int8]), - TypeSignature::Any(0), + TypeSignature::Nullary, TypeSignature::Uniform(1, vec![DataType::Int8]), ]), + TypeSignature::Nullary, ]; for case in positive_cases { @@ -454,4 +683,101 @@ mod tests { < TypeSignature::Exact(vec![DataType::Null]) ); } + + #[test] + fn test_get_possible_types() { + let type_signature = TypeSignature::Exact(vec![DataType::Int32, DataType::Int64]); + let possible_types = type_signature.get_possible_types(); + assert_eq!(possible_types, vec![vec![DataType::Int32, DataType::Int64]]); + + let type_signature = TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Int32, DataType::Int64]), + TypeSignature::Exact(vec![DataType::Float32, DataType::Float64]), + ]); + let possible_types = type_signature.get_possible_types(); + assert_eq!( + possible_types, + vec![ + vec![DataType::Int32, DataType::Int64], + vec![DataType::Float32, DataType::Float64] + ] + ); + + let type_signature = TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Int32, DataType::Int64]), + TypeSignature::Exact(vec![DataType::Float32, DataType::Float64]), + TypeSignature::Exact(vec![DataType::Utf8]), + ]); + let possible_types = type_signature.get_possible_types(); + assert_eq!( + possible_types, + vec![ + vec![DataType::Int32, DataType::Int64], + vec![DataType::Float32, DataType::Float64], + vec![DataType::Utf8] + ] + ); + + let type_signature = + TypeSignature::Uniform(2, vec![DataType::Float32, DataType::Int64]); + let possible_types = type_signature.get_possible_types(); + assert_eq!( + possible_types, + vec![ + vec![DataType::Float32, DataType::Float32], + vec![DataType::Int64, DataType::Int64] + ] + ); + + let type_signature = TypeSignature::Coercible(vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Native(logical_int64()), + ]); + let possible_types = type_signature.get_possible_types(); + assert_eq!( + possible_types, + vec![ + vec![DataType::Utf8, DataType::Int64], + vec![DataType::LargeUtf8, DataType::Int64], + vec![DataType::Utf8View, DataType::Int64] + ] + ); + + let type_signature = + TypeSignature::Variadic(vec![DataType::Int32, DataType::Int64]); + let possible_types = type_signature.get_possible_types(); + assert_eq!( + possible_types, + vec![vec![DataType::Int32], vec![DataType::Int64]] + ); + + let type_signature = TypeSignature::Numeric(2); + let possible_types = type_signature.get_possible_types(); + assert_eq!( + possible_types, + vec![ + vec![DataType::Int8, DataType::Int8], + vec![DataType::Int16, DataType::Int16], + vec![DataType::Int32, DataType::Int32], + vec![DataType::Int64, DataType::Int64], + vec![DataType::UInt8, DataType::UInt8], + vec![DataType::UInt16, DataType::UInt16], + 
vec![DataType::UInt32, DataType::UInt32], + vec![DataType::UInt64, DataType::UInt64], + vec![DataType::Float32, DataType::Float32], + vec![DataType::Float64, DataType::Float64] + ] + ); + + let type_signature = TypeSignature::String(2); + let possible_types = type_signature.get_possible_types(); + assert_eq!( + possible_types, + vec![ + vec![DataType::Utf8, DataType::Utf8], + vec![DataType::LargeUtf8, DataType::LargeUtf8], + vec![DataType::Utf8View, DataType::Utf8View] + ] + ); + } } diff --git a/datafusion/expr-common/src/sort_properties.rs b/datafusion/expr-common/src/sort_properties.rs index 7778be2ecf0dd..5d17a34a96fbc 100644 --- a/datafusion/expr-common/src/sort_properties.rs +++ b/datafusion/expr-common/src/sort_properties.rs @@ -129,19 +129,30 @@ impl Neg for SortProperties { } } -/// Represents the properties of a `PhysicalExpr`, including its sorting and range attributes. +/// Represents the properties of a `PhysicalExpr`, including its sorting, +/// range, and whether it preserves lexicographical ordering. #[derive(Debug, Clone)] pub struct ExprProperties { + /// Properties that describe the sorting behavior of the expression, + /// such as whether it is ordered, unordered, or a singleton value. pub sort_properties: SortProperties, + /// A closed interval representing the range of possible values for + /// the expression. Used to compute reliable bounds. pub range: Interval, + /// Indicates whether the expression preserves lexicographical ordering + /// of its inputs. For example, string concatenation preserves ordering, + /// while addition does not. + pub preserves_lex_ordering: bool, } impl ExprProperties { - /// Creates a new `ExprProperties` instance with unknown sort properties and unknown range. + /// Creates a new `ExprProperties` instance with unknown sort properties, + /// unknown range, and unknown lexicographical ordering preservation. pub fn new_unknown() -> Self { Self { sort_properties: SortProperties::default(), range: Interval::make_unbounded(&DataType::Null).unwrap(), + preserves_lex_ordering: false, } } @@ -156,4 +167,10 @@ impl ExprProperties { self.range = range; self } + + /// Sets whether the expression maintains lexicographical ordering and returns the modified instance. 
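A brief hedged sketch of the new `preserves_lex_ordering` flag on `ExprProperties`, using only the constructor and builder method visible in this diff (the module path is assumed from the file location).

```rust
use datafusion_expr_common::sort_properties::ExprProperties;

// Unknown properties default to not preserving lexicographical ordering;
// an expression like string concatenation could opt in via the builder method.
let props = ExprProperties::new_unknown().with_preserves_lex_ordering(true);
assert!(props.preserves_lex_ordering);
```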
+ pub fn with_preserves_lex_ordering(mut self, preserves_lex_ordering: bool) -> Self { + self.preserves_lex_ordering = preserves_lex_ordering; + self + } } diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index 2add9e7c1867c..13d52959aba65 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -23,7 +23,8 @@ use arrow::datatypes::{ use datafusion_common::{internal_err, plan_err, Result}; -pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8]; +pub static STRINGS: &[DataType] = + &[DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View]; pub static SIGNED_INTEGERS: &[DataType] = &[ DataType::Int8, @@ -143,21 +144,21 @@ pub fn check_arg_count( Ok(()) } -/// function return type of a sum +/// Function return type of a sum pub fn sum_return_type(arg_type: &DataType) -> Result { match arg_type { DataType::Int64 => Ok(DataType::Int64), DataType::UInt64 => Ok(DataType::UInt64), DataType::Float64 => Ok(DataType::Float64), DataType::Decimal128(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+10), s) - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + // In the spark, the result type is DECIMAL(min(38,precision+10), s) + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal128(new_precision, *scale)) } DataType::Decimal256(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+10), s) - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + // In the spark, the result type is DECIMAL(min(38,precision+10), s) + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal256(new_precision, *scale)) } @@ -165,7 +166,7 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { } } -/// function return type of variance +/// Function return type of variance pub fn variance_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) @@ -174,7 +175,7 @@ pub fn variance_return_type(arg_type: &DataType) -> Result { } } -/// function return type of covariance +/// Function return type of covariance pub fn covariance_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) @@ -183,7 +184,7 @@ pub fn covariance_return_type(arg_type: &DataType) -> Result { } } -/// function return type of correlation +/// Function return type of correlation pub fn correlation_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) @@ -192,19 +193,19 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result { } } -/// function return type of an average +/// Function return type of an average pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result { match arg_type { 
DataType::Decimal128(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4); let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal128(new_precision, new_scale)) } DataType::Decimal256(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4); let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal256(new_precision, new_scale)) @@ -217,16 +218,16 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result } } -/// internal sum type of an average +/// Internal sum type of an average pub fn avg_sum_type(arg_type: &DataType) -> Result { match arg_type { DataType::Decimal128(precision, scale) => { - // in the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) + // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal128(new_precision, *scale)) } DataType::Decimal256(precision, scale) => { - // in Spark the sum type of avg is DECIMAL(min(38,precision+10), s) + // In Spark the sum type of avg is DECIMAL(min(38,precision+10), s) let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal256(new_precision, *scale)) } @@ -293,19 +294,19 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result Result { - return match &data_type { + match &data_type { DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)), DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)), d if d.is_numeric() => Ok(DataType::Float64), - DataType::Dictionary(_, v) => return coerced_type(func_name, v.as_ref()), + DataType::Dictionary(_, v) => coerced_type(func_name, v.as_ref()), _ => { - return plan_err!( + plan_err!( "The function {:?} does not support inputs of type {:?}.", func_name, data_type ) } - }; + } } Ok(vec![coerced_type(func_name, &arg_types[0])?]) } diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 6d66b8b4df447..12292502e683d 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -25,10 +25,14 @@ use crate::operator::Operator; use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; use arrow::datatypes::{ - DataType, Field, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - 
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION, + DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result}; +use datafusion_common::types::NativeType; +use datafusion_common::{ + exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err, Result, +}; +use itertools::Itertools; /// The type signature of an instantiation of binary operator expression such as /// `lhs + rhs` @@ -86,7 +90,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) { // Logical binary boolean operators can only be evaluated for // boolean or null arguments. - Ok(Signature::uniform(DataType::Boolean)) + Ok(Signature::uniform(Boolean)) } else { plan_err!( "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}" @@ -191,7 +195,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result } } -/// returns the resulting type of a binary expression evaluating the `op` with the left and right hand types +/// Returns the resulting type of a binary expression evaluating the `op` with the left and right hand types pub fn get_result_type( lhs: &DataType, op: &Operator, @@ -327,11 +331,13 @@ impl From<&DataType> for TypeCategory { return TypeCategory::Array; } - // String literal is possible to cast to many other types like numeric or datetime, - // therefore, it is categorized as a unknown type + // It is categorized as unknown type because the type will be resolved later on if matches!( data_type, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Null + DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Utf8View + | DataType::Null ) { return TypeCategory::Unknown; } @@ -370,17 +376,21 @@ impl From<&DataType> for TypeCategory { /// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules /// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted /// decimal precision and scale when coercing decimal types. +/// +/// This function doesn't preserve correct field name and nullability for the struct type, we only care about data type. 
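To illustrate the union-style coercion documented here, a hedged sketch of `type_union_resolution`: the numeric case follows the "narrowest type that accommodates both" rule, and the all-`Null` case follows the comment in the code above.

```rust
use arrow::datatypes::DataType;
use datafusion_expr_common::type_coercion::binary::type_union_resolution;

// Numeric inputs resolve to the narrowest type that can hold both.
assert_eq!(
    type_union_resolution(&[DataType::Int32, DataType::Int64]),
    Some(DataType::Int64)
);
// All-null inputs fall back to Utf8, as documented above.
assert_eq!(
    type_union_resolution(&[DataType::Null, DataType::Null]),
    Some(DataType::Utf8)
);
```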
+/// +/// Returns Option because we might want to continue on the code even if the data types are not coercible to the common type pub fn type_union_resolution(data_types: &[DataType]) -> Option { if data_types.is_empty() { return None; } - // if all the data_types is the same return first one + // If all the data_types is the same return first one if data_types.iter().all(|t| t == &data_types[0]) { return Some(data_types[0].clone()); } - // if all the data_types are null, return string + // If all the data_types are null, return string if data_types.iter().all(|t| t == &DataType::Null) { return Some(DataType::Utf8); } @@ -399,7 +409,7 @@ pub fn type_union_resolution(data_types: &[DataType]) -> Option { return None; } - // check if there is only one category excluding Unknown + // Check if there is only one category excluding Unknown let categories: HashSet = HashSet::from_iter( data_types_category .iter() @@ -471,15 +481,56 @@ fn type_union_resolution_coercion( let new_value_type = type_union_resolution_coercion(value_type, other_type); new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t))) } + (DataType::Struct(lhs), DataType::Struct(rhs)) => { + if lhs.len() != rhs.len() { + return None; + } + + // Search the field in the right hand side with the SAME field name + fn search_corresponding_coerced_type( + lhs_field: &FieldRef, + rhs: &Fields, + ) -> Option { + for rhs_field in rhs.iter() { + if lhs_field.name() == rhs_field.name() { + if let Some(t) = type_union_resolution_coercion( + lhs_field.data_type(), + rhs_field.data_type(), + ) { + return Some(t); + } else { + return None; + } + } + } + + None + } + + let types = lhs + .iter() + .map(|lhs_field| search_corresponding_coerced_type(lhs_field, rhs)) + .collect::>>()?; + + let fields = types + .into_iter() + .enumerate() + .map(|(i, datatype)| { + Arc::new(Field::new(format!("c{i}"), datatype, true)) + }) + .collect::>(); + Some(DataType::Struct(fields.into())) + } (DataType::List(lhs), DataType::List(rhs)) => { let new_item_type = type_union_resolution_coercion(lhs.data_type(), rhs.data_type()); new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true)))) } _ => { - // numeric coercion is the same as comparison coercion, both find the narrowest type + // Numeric coercion is the same as comparison coercion, both find the narrowest type // that can accommodate both types binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| list_coercion(lhs_type, rhs_type)) .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) @@ -487,6 +538,89 @@ fn type_union_resolution_coercion( } } +/// Handle type union resolution including struct type and others. +pub fn try_type_union_resolution(data_types: &[DataType]) -> Result> { + let err = match try_type_union_resolution_with_struct(data_types) { + Ok(struct_types) => return Ok(struct_types), + Err(e) => Some(e), + }; + + if let Some(new_type) = type_union_resolution(data_types) { + Ok(vec![new_type; data_types.len()]) + } else { + exec_err!("Fail to find the coerced type, errors: {:?}", err) + } +} + +// Handle struct where we only change the data type but preserve the field name and nullability. 
+// Since field name is the key of the struct, so it shouldn't be updated to the common column name like "c0" or "c1" +pub fn try_type_union_resolution_with_struct( + data_types: &[DataType], +) -> Result> { + let mut keys_string: Option = None; + for data_type in data_types { + if let DataType::Struct(fields) = data_type { + let keys = fields.iter().map(|f| f.name().to_owned()).join(","); + if let Some(ref k) = keys_string { + if *k != keys { + return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys); + } + } else { + keys_string = Some(keys); + } + } else { + return exec_err!("Expect to get struct but got {}", data_type); + } + } + + let mut struct_types: Vec = if let DataType::Struct(fields) = &data_types[0] + { + fields.iter().map(|f| f.data_type().to_owned()).collect() + } else { + return internal_err!("Struct type is checked is the previous function, so this should be unreachable"); + }; + + for data_type in data_types.iter().skip(1) { + if let DataType::Struct(fields) = data_type { + let incoming_struct_types: Vec = + fields.iter().map(|f| f.data_type().to_owned()).collect(); + // The order of field is verified above + for (lhs_type, rhs_type) in + struct_types.iter_mut().zip(incoming_struct_types.iter()) + { + if let Some(coerced_type) = + type_union_resolution_coercion(lhs_type, rhs_type) + { + *lhs_type = coerced_type; + } else { + return exec_err!( + "Fail to find the coerced type for {} and {}", + lhs_type, + rhs_type + ); + } + } + } else { + return exec_err!("Expect to get struct but got {}", data_type); + } + } + + let mut final_struct_types = vec![]; + for s in data_types { + let mut new_fields = vec![]; + if let DataType::Struct(fields) = s { + for (i, f) in fields.iter().enumerate() { + let field = Arc::unwrap_or_clone(Arc::clone(f)) + .with_data_type(struct_types[i].to_owned()); + new_fields.push(Arc::new(field)); + } + } + final_struct_types.push(DataType::Struct(new_fields.into())) + } + + Ok(final_struct_types) +} + /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a /// comparison operation /// @@ -496,6 +630,19 @@ fn type_union_resolution_coercion( /// data type. However, users can write queries where the two arguments are /// different data types. In such cases, the data types are automatically cast /// (coerced) to a single data type to pass to the kernels. +/// +/// # Numeric comparisons +/// +/// When comparing numeric values, the lower precision type is coerced to the +/// higher precision type to avoid losing data. For example when comparing +/// `Int32` to `Int64` the coerced type is `Int64` so the `Int32` argument will +/// be cast. +/// +/// # Numeric / String comparisons +/// +/// When comparing numeric values and strings, both values will be coerced to +/// strings. For example when comparing `'2' > 1`, the arguments will be +/// coerced to `Utf8` for comparison pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { if lhs_type == rhs_type { // same type => equality is possible @@ -513,6 +660,28 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option 1` if `1` is an `Int32`, the arguments +/// will be coerced to `Int32`. 
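A hedged sketch contrasting the existing `comparison_coercion` with the numeric-preferring variant documented just above; the expected results restate the `'2' > 1` example from the doc comments rather than asserting anything new.

```rust
use arrow::datatypes::DataType;
use datafusion_expr_common::type_coercion::binary::{
    comparison_coercion, comparison_coercion_numeric,
};

// Default comparison coercion turns a numeric/string pair into strings...
assert_eq!(
    comparison_coercion(&DataType::Int32, &DataType::Utf8),
    Some(DataType::Utf8)
);
// ...while the numeric-preferring variant keeps the numeric side, so `'2' > 1`
// is evaluated as an Int32 comparison.
assert_eq!(
    comparison_coercion_numeric(&DataType::Int32, &DataType::Utf8),
    Some(DataType::Int32)
);
```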
+pub fn comparison_coercion_numeric( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + if lhs_type == rhs_type { + // same type => equality is possible + return Some(lhs_type.clone()); + } + binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| string_coercion(lhs_type, rhs_type)) + .or_else(|| null_coercion(lhs_type, rhs_type)) + .or_else(|| string_numeric_coercion_as_numeric(lhs_type, rhs_type)) +} + /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation /// where one is numeric and one is `Utf8`/`LargeUtf8`. fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { @@ -526,6 +695,24 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { + let lhs_logical_type = NativeType::from(lhs_type); + let rhs_logical_type = NativeType::from(rhs_type); + if lhs_logical_type.is_numeric() && rhs_logical_type == NativeType::String { + return Some(lhs_type.to_owned()); + } + if rhs_logical_type.is_numeric() && lhs_logical_type == NativeType::String { + return Some(rhs_type.to_owned()); + } + + None +} + /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation /// where one is temporal and one is `Utf8View`/`Utf8`/`LargeUtf8`. /// @@ -588,7 +775,7 @@ pub fn binary_numeric_coercion( return Some(t); } - // these are ordered from most informative to least informative so + // These are ordered from most informative to least informative so // that the coercion does not lose information via truncation match (lhs_type, rhs_type) { (Float64, _) | (_, Float64) => Some(Float64), @@ -814,12 +1001,12 @@ fn mathematics_numerical_coercion( ) -> Option { use arrow::datatypes::DataType::*; - // error on any non-numeric type + // Error on any non-numeric type if !both_numeric_or_null_and_numeric(lhs_type, rhs_type) { return None; }; - // these are ordered from most informative to least informative so + // These are ordered from most informative to least informative so // that the coercion removes the least amount of information match (lhs_type, rhs_type) { (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { @@ -1006,27 +1193,46 @@ fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { + let data_types = vec![lhs_field.data_type().clone(), rhs_field.data_type().clone()]; + Some(Arc::new( + (**lhs_field) + .clone() + .with_data_type(type_union_resolution(&data_types)?) + .with_nullable(lhs_field.is_nullable() || rhs_field.is_nullable()), + )) +} + /// Coercion rules for list types. 
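As a quick illustration of the documented difference between `comparison_coercion` and the new `comparison_coercion_numeric`, here is a hedged test sketch (not part of the diff, assumed to live alongside these functions in this module):

```rust
use arrow::datatypes::DataType;

#[test]
fn comparison_coercion_examples() {
    // Numeric comparison: the lower precision type is widened (Int32 vs Int64 -> Int64).
    assert_eq!(
        comparison_coercion(&DataType::Int32, &DataType::Int64),
        Some(DataType::Int64)
    );

    // Numeric vs. string: `comparison_coercion` coerces both sides to strings...
    assert_eq!(
        comparison_coercion(&DataType::Utf8, &DataType::Int32),
        Some(DataType::Utf8)
    );

    // ...while the `_numeric` variant prefers the numeric side.
    assert_eq!(
        comparison_coercion_numeric(&DataType::Utf8, &DataType::Int32),
        Some(DataType::Int32)
    );
}
```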
fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (List(_), List(_)) => Some(lhs_type.clone()), - (LargeList(_), List(_)) => Some(lhs_type.clone()), - (List(_), LargeList(_)) => Some(rhs_type.clone()), - (LargeList(_), LargeList(_)) => Some(lhs_type.clone()), - (List(_), FixedSizeList(_, _)) => Some(lhs_type.clone()), - (FixedSizeList(_, _), List(_)) => Some(rhs_type.clone()), // Coerce to the left side FixedSizeList type if the list lengths are the same, // otherwise coerce to list with the left type for dynamic length - (FixedSizeList(lf, ls), FixedSizeList(_, rs)) => { + (FixedSizeList(lhs_field, ls), FixedSizeList(rhs_field, rs)) => { if ls == rs { - Some(lhs_type.clone()) + Some(FixedSizeList( + coerce_list_children(lhs_field, rhs_field)?, + *rs, + )) } else { - Some(List(Arc::clone(lf))) + Some(List(coerce_list_children(lhs_field, rhs_field)?)) } } - (LargeList(_), FixedSizeList(_, _)) => Some(lhs_type.clone()), - (FixedSizeList(_, _), LargeList(_)) => Some(rhs_type.clone()), + // LargeList on any side + ( + LargeList(lhs_field), + List(rhs_field) | LargeList(rhs_field) | FixedSizeList(rhs_field, _), + ) + | (List(lhs_field) | FixedSizeList(lhs_field, _), LargeList(rhs_field)) => { + Some(LargeList(coerce_list_children(lhs_field, rhs_field)?)) + } + // Lists on both sides + (List(lhs_field), List(rhs_field) | FixedSizeList(rhs_field, _)) + | (FixedSizeList(lhs_field, _), List(rhs_field)) => { + Some(List(coerce_list_children(lhs_field, rhs_field)?)) + } _ => None, } } @@ -1080,7 +1286,7 @@ fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option } } -/// coercion rules for like operations. +/// Coercion rules for like operations. /// This is a union of string coercion rules and dictionary coercion rules pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_coercion(lhs_type, rhs_type) @@ -1091,13 +1297,13 @@ pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (DataType::Null, Utf8View | Utf8 | LargeUtf8) => Some(rhs_type.clone()), - (Utf8View | Utf8 | LargeUtf8, DataType::Null) => Some(lhs_type.clone()), - (DataType::Null, DataType::Null) => Some(Utf8), + (Null, Utf8View | Utf8 | LargeUtf8) => Some(rhs_type.clone()), + (Utf8View | Utf8 | LargeUtf8, Null) => Some(lhs_type.clone()), + (Null, Null) => Some(Utf8), _ => None, } } @@ -1253,7 +1459,7 @@ fn timeunit_coercion(lhs_unit: &TimeUnit, rhs_unit: &TimeUnit) -> TimeUnit { } } -/// coercion rules from NULL type. Since NULL can be casted to any other type in arrow, +/// Coercion rules from NULL type. Since NULL can be casted to any other type in arrow, /// either lhs or rhs is NULL, if NULL can be casted to type of the other side, the coercion is valid. 
fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { match (lhs_type, rhs_type) { @@ -1917,7 +2123,7 @@ mod tests { ); // list - let inner_field = Arc::new(Field::new("item", DataType::Int64, true)); + let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true)); test_coercion_binary_rule!( DataType::List(Arc::clone(&inner_field)), DataType::List(Arc::clone(&inner_field)), @@ -1973,10 +2179,35 @@ mod tests { DataType::List(Arc::clone(&inner_field)) ); + // Negative test: inner_timestamp_field and inner_field are not compatible because their inner types are not compatible + let inner_timestamp_field = Arc::new(Field::new_list_field( + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )); + let result_type = get_input_types( + &DataType::List(Arc::clone(&inner_field)), + &Operator::Eq, + &DataType::List(Arc::clone(&inner_timestamp_field)), + ); + assert!(result_type.is_err()); + // TODO add other data type Ok(()) } + #[test] + fn test_list_coercion() { + let lhs_type = DataType::List(Arc::new(Field::new("lhs", DataType::Int8, false))); + + let rhs_type = DataType::List(Arc::new(Field::new("rhs", DataType::Int64, true))); + + let coerced_type = list_coercion(&lhs_type, &rhs_type).unwrap(); + assert_eq!( + coerced_type, + DataType::List(Arc::new(Field::new("lhs", DataType::Int64, true))) + ); // nullable because the RHS is nullable + } + #[test] fn test_type_coercion_logical_op() -> Result<()> { test_coercion_binary_rule!( diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index d7dc1afe4d505..b4f3f7fb680f6 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -36,24 +36,22 @@ name = "datafusion_expr" path = "src/lib.rs" [features] +recursive_protection = ["dep:recursive"] [dependencies] -ahash = { workspace = true } arrow = { workspace = true } -arrow-array = { workspace = true } -arrow-buffer = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true } +datafusion-doc = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } indexmap = { workspace = true } paste = "^1.0" +recursive = { workspace = true, optional = true } serde_json = { workspace = true } sqlparser = { workspace = true } -strum = { version = "0.26.1", features = ["derive"] } -strum_macros = "0.26.0" [dev-dependencies] ctor = { workspace = true } diff --git a/datafusion/expr/LICENSE.txt b/datafusion/expr/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/expr/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/expr/NOTICE.txt b/datafusion/expr/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/expr/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs deleted file mode 100644 index b136d6cacec8f..0000000000000 --- a/datafusion/expr/src/built_in_window_function.rs +++ /dev/null @@ -1,201 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Built-in functions module contains all the built-in functions definitions. - -use std::fmt; -use std::str::FromStr; - -use crate::type_coercion::functions::data_types; -use crate::utils; -use crate::{Signature, TypeSignature, Volatility}; -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; - -use arrow::datatypes::DataType; - -use strum_macros::EnumIter; - -impl fmt::Display for BuiltInWindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.name()) - } -} - -/// A [window function] built in to DataFusion -/// -/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] -pub enum BuiltInWindowFunction { - /// rank of the current row with gaps; same as row_number of its first peer - Rank, - /// rank of the current row without gaps; this function counts peer groups - DenseRank, - /// relative rank of the current row: (rank - 1) / (total rows - 1) - PercentRank, - /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) - CumeDist, - /// integer ranging from 1 to the argument value, dividing the partition as equally as possible - Ntile, - /// returns value evaluated at the row that is offset rows before the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// returns value evaluated at the row that is offset rows after the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. 
- /// If omitted, offset defaults to 1 and default to null - Lead, - /// returns value evaluated at the row that is the first row of the window frame - FirstValue, - /// returns value evaluated at the row that is the last row of the window frame - LastValue, - /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row - NthValue, -} - -impl BuiltInWindowFunction { - pub fn name(&self) -> &str { - use BuiltInWindowFunction::*; - match self { - Rank => "RANK", - DenseRank => "DENSE_RANK", - PercentRank => "PERCENT_RANK", - CumeDist => "CUME_DIST", - Ntile => "NTILE", - Lag => "LAG", - Lead => "LEAD", - FirstValue => "first_value", - LastValue => "last_value", - NthValue => "NTH_VALUE", - } - } -} - -impl FromStr for BuiltInWindowFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name.to_uppercase().as_str() { - "RANK" => BuiltInWindowFunction::Rank, - "DENSE_RANK" => BuiltInWindowFunction::DenseRank, - "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, - "CUME_DIST" => BuiltInWindowFunction::CumeDist, - "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, - "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, - "LAST_VALUE" => BuiltInWindowFunction::LastValue, - "NTH_VALUE" => BuiltInWindowFunction::NthValue, - _ => return plan_err!("There is no built-in window function named {name}"), - }) - } -} - -/// Returns the datatype of the built-in window function -impl BuiltInWindowFunction { - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message - .map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - ) - ) - })?; - - match self { - BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) - } - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue - | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), - } - } - - /// the signatures supported by the built-in window function `fun`. - pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. 
- match self { - BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::PercentRank - | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), - BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { - Signature::one_of( - vec![ - TypeSignature::Any(1), - TypeSignature::Any(2), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ) - } - BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { - Signature::any(1, Volatility::Immutable) - } - BuiltInWindowFunction::Ntile => Signature::uniform( - 1, - vec![ - DataType::UInt64, - DataType::UInt32, - DataType::UInt16, - DataType::UInt8, - DataType::Int64, - DataType::Int32, - DataType::Int16, - DataType::Int8, - ], - Volatility::Immutable, - ), - BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use strum::IntoEnumIterator; - #[test] - // Test for BuiltInWindowFunction's Display and from_str() implementations. - // For each variant in BuiltInWindowFunction, it converts the variant to a string - // and then back to a variant. The test asserts that the original variant and - // the reconstructed variant are the same. This assertion is also necessary for - // function suggestion. See https://github.com/apache/datafusion/issues/8082 - fn test_display_and_from_str() { - for func_original in BuiltInWindowFunction::iter() { - let func_name = func_original.to_string(); - let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap(); - assert_eq!(func_from_str, func_original); - } - } -} diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index 7a2bf4b6c44a0..9cb51612d0cab 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -19,8 +19,7 @@ use crate::expr::Case; use crate::{expr_schema::ExprSchemable, Expr}; use arrow::datatypes::DataType; -use datafusion_common::{plan_err, DFSchema, Result}; -use std::collections::HashSet; +use datafusion_common::{plan_err, DFSchema, HashSet, Result}; /// Helper struct for building [Expr::Case] pub struct CaseBuilder { @@ -64,7 +63,7 @@ impl CaseBuilder { } fn build(&self) -> Result { - // collect all "then" expressions + // Collect all "then" expressions let mut then_expr = self.then_expr.clone(); if let Some(e) = &self.else_expr { then_expr.push(e.as_ref().to_owned()); @@ -79,7 +78,7 @@ impl CaseBuilder { .collect::>>()?; if then_types.contains(&DataType::Null) { - // cannot verify types until execution type + // Cannot verify types until execution type } else { let unique_types: HashSet<&DataType> = then_types.iter().collect(); if unique_types.len() != 1 { diff --git a/datafusion/expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs index 3401a94b2736e..d672bd1acc460 100644 --- a/datafusion/expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -18,7 +18,7 @@ use crate::var_provider::{VarProvider, VarType}; use chrono::{DateTime, TimeZone, Utc}; use datafusion_common::alias::AliasGenerator; -use std::collections::HashMap; +use datafusion_common::HashMap; use std::sync::Arc; /// Holds per-query execution properties and data (such as statement diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 93c7703ed9560..670a33b88eaf9 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,28 +17,25 @@ //! 
Logical Expressions: [`Expr`] -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::fmt::{self, Display, Formatter, Write}; use std::hash::{Hash, Hasher}; use std::mem; -use std::str::FromStr; use std::sync::Arc; use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::expr_to_columns; use crate::Volatility; -use crate::{ - built_in_window_function, udaf, BuiltInWindowFunction, ExprSchemable, Operator, - Signature, WindowFrame, WindowUDF, -}; +use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, FieldRef}; +use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, + Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; use datafusion_common::{ - plan_err, Column, DFSchema, Result, ScalarValue, TableReference, + plan_err, Column, DFSchema, HashMap, Result, ScalarValue, TableReference, }; use datafusion_expr_common::scalar::Scalar; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -317,7 +314,7 @@ pub enum Expr { /// plan into physical plan. Wildcard { qualifier: Option, - options: WildcardOptions, + options: Box, }, /// List of grouping set expressions. Only valid in the context of an aggregate /// GROUP BY expression list @@ -355,6 +352,22 @@ impl<'a> From<(Option<&'a TableReference>, &'a FieldRef)> for Expr { } } +impl<'a> TreeNodeContainer<'a, Self> for Expr { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + f(self) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + f(self) + } +} + /// UNNEST expression. #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Unnest { @@ -629,6 +642,15 @@ impl Sort { nulls_first: !self.nulls_first, } } + + /// Replaces the Sort expressions with `expr` + pub fn with_expr(&self, expr: Expr) -> Self { + Self { + expr, + asc: self.asc, + nulls_first: self.nulls_first, + } + } } impl Display for Sort { @@ -648,6 +670,24 @@ impl Display for Sort { } } +impl<'a> TreeNodeContainer<'a, Expr> for Sort { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.expr.apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.expr + .map_elements(f)? + .map_data(|expr| Ok(Self { expr, ..self })) + } +} + /// Aggregate function /// /// See also [`ExprFunctionExt`] to set these fields on `Expr` @@ -689,17 +729,17 @@ impl AggregateFunction { } } -/// WindowFunction +/// A function used as a SQL window function +/// +/// In SQL, you can use: +/// - Actual window functions ([`WindowUDF`]) +/// - Normal aggregate functions ([`AggregateUDF`]) #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] -/// Defines which implementation of an aggregate function DataFusion should call. 
pub enum WindowFunctionDefinition { - /// A built in aggregate function that leverages an aggregate function - /// A a built-in window function - BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction), /// A user defined aggregate function AggregateUDF(Arc), /// A user defined aggregate function - WindowUDF(Arc), + WindowUDF(Arc), } impl WindowFunctionDefinition { @@ -711,9 +751,6 @@ impl WindowFunctionDefinition { display_name: &str, ) -> Result { match self { - WindowFunctionDefinition::BuiltInWindowFunction(fun) => { - fun.return_type(input_expr_types) - } WindowFunctionDefinition::AggregateUDF(fun) => { fun.return_type(input_expr_types) } @@ -723,10 +760,9 @@ impl WindowFunctionDefinition { } } - /// the signatures supported by the function `fun`. + /// The signatures supported by the function `fun`. pub fn signature(&self) -> Signature { match self { - WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(), WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(), WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(), } @@ -735,31 +771,21 @@ impl WindowFunctionDefinition { /// Function's name for display pub fn name(&self) -> &str { match self { - WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.name(), WindowFunctionDefinition::WindowUDF(fun) => fun.name(), WindowFunctionDefinition::AggregateUDF(fun) => fun.name(), } } } -impl fmt::Display for WindowFunctionDefinition { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl Display for WindowFunctionDefinition { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { - WindowFunctionDefinition::BuiltInWindowFunction(fun) => { - std::fmt::Display::fmt(fun, f) - } - WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Display::fmt(fun, f), - WindowFunctionDefinition::WindowUDF(fun) => std::fmt::Display::fmt(fun, f), + WindowFunctionDefinition::AggregateUDF(fun) => Display::fmt(fun, f), + WindowFunctionDefinition::WindowUDF(fun) => Display::fmt(fun, f), } } } -impl From for WindowFunctionDefinition { - fn from(value: BuiltInWindowFunction) -> Self { - Self::BuiltInWindowFunction(value) - } -} - impl From> for WindowFunctionDefinition { fn from(value: Arc) -> Self { Self::AggregateUDF(value) @@ -774,26 +800,16 @@ impl From> for WindowFunctionDefinition { /// Window function /// -/// Holds the actual actual function to call [`WindowFunction`] as well as its +/// Holds the actual function to call [`WindowFunction`] as well as its /// arguments (`args`) and the contents of the `OVER` clause: /// /// 1. `PARTITION BY` /// 2. `ORDER BY` /// 3. Window frame (e.g. `ROWS 1 PRECEDING AND 1 FOLLOWING`) /// -/// # Example -/// ``` -/// # use datafusion_expr::{Expr, BuiltInWindowFunction, col, ExprFunctionExt}; -/// # use datafusion_expr::expr::WindowFunction; -/// // Create FIRST_VALUE(a) OVER (PARTITION BY b ORDER BY c) -/// let expr = Expr::WindowFunction( -/// WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![col("a")]) -/// ) -/// .partition_by(vec![col("b")]) -/// .order_by(vec![col("b").sort(true, true)]) -/// .build() -/// .unwrap(); -/// ``` +/// See [`ExprFunctionExt`] for examples of how to create a `WindowFunction`. +/// +/// [`ExprFunctionExt`]: crate::ExprFunctionExt #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct WindowFunction { /// Name of the function @@ -825,29 +841,10 @@ impl WindowFunction { } } -/// Find DataFusion's built-in window function by name. 
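Since the inline `BuiltInWindowFunction` example was removed from the `WindowFunction` docs, a hedged replacement sketch of the UDWF-based equivalent is shown below. `first_value_udwf()` is a hypothetical stand-in for any function returning an `Arc<WindowUDF>` (for example one provided by the separate window-functions crate); it is not defined in this diff.

```rust
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{col, Expr, ExprFunctionExt};

// Build `first_value(a) OVER (PARTITION BY b ORDER BY c)` from a user defined
// window function; `first_value_udwf()` is assumed to return an Arc<WindowUDF>.
let expr = Expr::WindowFunction(WindowFunction::new(first_value_udwf(), vec![col("a")]))
    .partition_by(vec![col("b")])
    .order_by(vec![col("c").sort(true, true)])
    .build()
    .unwrap();
```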
-pub fn find_df_window_func(name: &str) -> Option { - let name = name.to_lowercase(); - // Code paths for window functions leveraging ordinary aggregators and - // built-in window functions are quite different, and the same function - // may have different implementations for these cases. If the sought - // function is not found among built-in window functions, we search for - // it among aggregate functions. - if let Ok(built_in_function) = - built_in_window_function::BuiltInWindowFunction::from_str(name.as_str()) - { - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_function, - )) - } else { - None - } -} - /// EXISTS expression #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Exists { - /// subquery that will produce a single column of data + /// Subquery that will produce a single column of data pub subquery: Subquery, /// Whether the expression is negated pub negated: bool, @@ -1122,7 +1119,7 @@ impl Expr { } /// Returns a full and complete string representation of this expression. - #[deprecated(note = "use format! instead")] + #[deprecated(since = "42.0.0", note = "use format! instead")] pub fn canonical_name(&self) -> String { format!("{self}") } @@ -1291,7 +1288,7 @@ impl Expr { /// let expr = col("foo").alias("bar") + col("baz"); /// assert_eq!(expr.clone().unalias(), expr); /// - /// // `foo as "bar" as "baz" is unalaised to foo as "bar" + /// // `foo as "bar" as "baz" is unaliased to foo as "bar" /// let expr = col("foo").alias("bar").alias("baz"); /// assert_eq!(expr.unalias(), col("foo").alias("bar")); /// ``` @@ -1330,7 +1327,7 @@ impl Expr { expr, Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) ) { - // subqueries could contain aliases so don't recurse into those + // Subqueries could contain aliases so don't recurse into those TreeNodeRecursion::Jump } else { TreeNodeRecursion::Continue @@ -1347,7 +1344,7 @@ impl Expr { } }, ) - // unreachable code: internal closure doesn't return err + // Unreachable code: internal closure doesn't return err .unwrap() } @@ -1417,7 +1414,7 @@ impl Expr { )) } - /// return `self NOT BETWEEN low AND high` + /// Return `self NOT BETWEEN low AND high` pub fn not_between(self, low: Expr, high: Expr) -> Expr { Expr::Between(Between::new( Box::new(self), @@ -1558,13 +1555,13 @@ impl Expr { /// Returns true if there are any column references in this Expr pub fn any_column_refs(&self) -> bool { self.exists(|expr| Ok(matches!(expr, Expr::Column(_)))) - .unwrap() + .expect("exists closure is infallible") } - /// Return true when the expression contains out reference(correlated) expressions. + /// Return true if the expression contains out reference(correlated) expressions. pub fn contains_outer(&self) -> bool { self.exists(|expr| Ok(matches!(expr, Expr::OuterReferenceColumn { .. }))) - .unwrap() + .expect("exists closure is infallible") } /// Returns true if the expression node is volatile, i.e. whether it can return @@ -1578,16 +1575,26 @@ impl Expr { /// Returns true if the expression is volatile, i.e. whether it can return different /// results when evaluated multiple times with the same input. - pub fn is_volatile(&self) -> Result { + /// + /// For example the function call `RANDOM()` is volatile as each call will + /// return a different value. + /// + /// See [`Volatility`] for more information. 
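With `is_volatile` now returning `bool` directly (see the new signature just below), callers no longer unwrap a `Result`. A minimal sketch, assuming `col` and `lit` from `datafusion_expr`:

```rust
use datafusion_expr::{col, lit};

#[test]
fn volatility_check_example() {
    // No volatile function (such as RANDOM()) appears in this expression tree,
    // so the check now returns `false` directly instead of `Ok(false)`.
    let expr = col("a").gt(lit(5));
    assert!(!expr.is_volatile());
}
```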
+ pub fn is_volatile(&self) -> bool { self.exists(|expr| Ok(expr.is_volatile_node())) + .expect("exists closure is infallible") } /// Recursively find all [`Expr::Placeholder`] expressions, and /// to infer their [`DataType`] from the context of their use. /// - /// For example, gicen an expression like ` = $0` will infer `$0` to + /// For example, given an expression like ` = $0` will infer `$0` to /// have type `int32`. - pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result { + /// + /// Returns transformed expression and flag that is true if expression contains + /// at least one placeholder. + pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { + let mut has_placeholder = false; self.transform(|mut expr| { // Default to assuming the arguments are the same type if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { @@ -1604,9 +1611,13 @@ impl Expr { rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; } + if let Expr::Placeholder(_) = &expr { + has_placeholder = true; + } Ok(Transformed::yes(expr)) }) .data() + .map(|data| (data, has_placeholder)) } /// Returns true if some of this `exprs` subexpressions may not be evaluated @@ -1653,47 +1664,426 @@ impl Expr { | Expr::Placeholder(..) => false, } } +} - /// Hashes the direct content of an `Expr` without recursing into its children. - /// - /// This method is useful to incrementally compute hashes, such as in - /// `CommonSubexprEliminate` which builds a deep hash of a node and its descendants - /// during the bottom-up phase of the first traversal and so avoid computing the hash - /// of the node and then the hash of its descendants separately. - /// - /// If a node doesn't have any children then this method is similar to `.hash()`, but - /// not necessarily returns the same value. - /// +impl Normalizeable for Expr { + fn can_normalize(&self) -> bool { + #[allow(clippy::match_like_matches_macro)] + match self { + Expr::BinaryExpr(BinaryExpr { + op: + _op @ (Operator::Plus + | Operator::Multiply + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::Eq + | Operator::NotEq), + .. 
+ }) => true, + _ => false, + } + } +} + +impl NormalizeEq for Expr { + fn normalize_eq(&self, other: &Self) -> bool { + match (self, other) { + ( + Expr::BinaryExpr(BinaryExpr { + left: self_left, + op: self_op, + right: self_right, + }), + Expr::BinaryExpr(BinaryExpr { + left: other_left, + op: other_op, + right: other_right, + }), + ) => { + if self_op != other_op { + return false; + } + + if matches!( + self_op, + Operator::Plus + | Operator::Multiply + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::Eq + | Operator::NotEq + ) { + (self_left.normalize_eq(other_left) + && self_right.normalize_eq(other_right)) + || (self_left.normalize_eq(other_right) + && self_right.normalize_eq(other_left)) + } else { + self_left.normalize_eq(other_left) + && self_right.normalize_eq(other_right) + } + } + ( + Expr::Alias(Alias { + expr: self_expr, + relation: self_relation, + name: self_name, + }), + Expr::Alias(Alias { + expr: other_expr, + relation: other_relation, + name: other_name, + }), + ) => { + self_name == other_name + && self_relation == other_relation + && self_expr.normalize_eq(other_expr) + } + ( + Expr::Like(Like { + negated: self_negated, + expr: self_expr, + pattern: self_pattern, + escape_char: self_escape_char, + case_insensitive: self_case_insensitive, + }), + Expr::Like(Like { + negated: other_negated, + expr: other_expr, + pattern: other_pattern, + escape_char: other_escape_char, + case_insensitive: other_case_insensitive, + }), + ) + | ( + Expr::SimilarTo(Like { + negated: self_negated, + expr: self_expr, + pattern: self_pattern, + escape_char: self_escape_char, + case_insensitive: self_case_insensitive, + }), + Expr::SimilarTo(Like { + negated: other_negated, + expr: other_expr, + pattern: other_pattern, + escape_char: other_escape_char, + case_insensitive: other_case_insensitive, + }), + ) => { + self_negated == other_negated + && self_escape_char == other_escape_char + && self_case_insensitive == other_case_insensitive + && self_expr.normalize_eq(other_expr) + && self_pattern.normalize_eq(other_pattern) + } + (Expr::Not(self_expr), Expr::Not(other_expr)) + | (Expr::IsNull(self_expr), Expr::IsNull(other_expr)) + | (Expr::IsTrue(self_expr), Expr::IsTrue(other_expr)) + | (Expr::IsFalse(self_expr), Expr::IsFalse(other_expr)) + | (Expr::IsUnknown(self_expr), Expr::IsUnknown(other_expr)) + | (Expr::IsNotNull(self_expr), Expr::IsNotNull(other_expr)) + | (Expr::IsNotTrue(self_expr), Expr::IsNotTrue(other_expr)) + | (Expr::IsNotFalse(self_expr), Expr::IsNotFalse(other_expr)) + | (Expr::IsNotUnknown(self_expr), Expr::IsNotUnknown(other_expr)) + | (Expr::Negative(self_expr), Expr::Negative(other_expr)) + | ( + Expr::Unnest(Unnest { expr: self_expr }), + Expr::Unnest(Unnest { expr: other_expr }), + ) => self_expr.normalize_eq(other_expr), + ( + Expr::Between(Between { + expr: self_expr, + negated: self_negated, + low: self_low, + high: self_high, + }), + Expr::Between(Between { + expr: other_expr, + negated: other_negated, + low: other_low, + high: other_high, + }), + ) => { + self_negated == other_negated + && self_expr.normalize_eq(other_expr) + && self_low.normalize_eq(other_low) + && self_high.normalize_eq(other_high) + } + ( + Expr::Cast(Cast { + expr: self_expr, + data_type: self_data_type, + }), + Expr::Cast(Cast { + expr: other_expr, + data_type: other_data_type, + }), + ) + | ( + Expr::TryCast(TryCast { + expr: self_expr, + data_type: self_data_type, + }), + Expr::TryCast(TryCast { + expr: other_expr, + data_type: other_data_type, + }), + ) 
=> self_data_type == other_data_type && self_expr.normalize_eq(other_expr), + ( + Expr::ScalarFunction(ScalarFunction { + func: self_func, + args: self_args, + }), + Expr::ScalarFunction(ScalarFunction { + func: other_func, + args: other_args, + }), + ) => { + self_func.name() == other_func.name() + && self_args.len() == other_args.len() + && self_args + .iter() + .zip(other_args.iter()) + .all(|(a, b)| a.normalize_eq(b)) + } + ( + Expr::AggregateFunction(AggregateFunction { + func: self_func, + args: self_args, + distinct: self_distinct, + filter: self_filter, + order_by: self_order_by, + null_treatment: self_null_treatment, + }), + Expr::AggregateFunction(AggregateFunction { + func: other_func, + args: other_args, + distinct: other_distinct, + filter: other_filter, + order_by: other_order_by, + null_treatment: other_null_treatment, + }), + ) => { + self_func.name() == other_func.name() + && self_distinct == other_distinct + && self_null_treatment == other_null_treatment + && self_args.len() == other_args.len() + && self_args + .iter() + .zip(other_args.iter()) + .all(|(a, b)| a.normalize_eq(b)) + && match (self_filter, other_filter) { + (Some(self_filter), Some(other_filter)) => { + self_filter.normalize_eq(other_filter) + } + (None, None) => true, + _ => false, + } + && match (self_order_by, other_order_by) { + (Some(self_order_by), Some(other_order_by)) => self_order_by + .iter() + .zip(other_order_by.iter()) + .all(|(a, b)| { + a.asc == b.asc + && a.nulls_first == b.nulls_first + && a.expr.normalize_eq(&b.expr) + }), + (None, None) => true, + _ => false, + } + } + ( + Expr::WindowFunction(WindowFunction { + fun: self_fun, + args: self_args, + partition_by: self_partition_by, + order_by: self_order_by, + window_frame: self_window_frame, + null_treatment: self_null_treatment, + }), + Expr::WindowFunction(WindowFunction { + fun: other_fun, + args: other_args, + partition_by: other_partition_by, + order_by: other_order_by, + window_frame: other_window_frame, + null_treatment: other_null_treatment, + }), + ) => { + self_fun.name() == other_fun.name() + && self_window_frame == other_window_frame + && self_null_treatment == other_null_treatment + && self_args.len() == other_args.len() + && self_args + .iter() + .zip(other_args.iter()) + .all(|(a, b)| a.normalize_eq(b)) + && self_partition_by + .iter() + .zip(other_partition_by.iter()) + .all(|(a, b)| a.normalize_eq(b)) + && self_order_by + .iter() + .zip(other_order_by.iter()) + .all(|(a, b)| { + a.asc == b.asc + && a.nulls_first == b.nulls_first + && a.expr.normalize_eq(&b.expr) + }) + } + ( + Expr::Exists(Exists { + subquery: self_subquery, + negated: self_negated, + }), + Expr::Exists(Exists { + subquery: other_subquery, + negated: other_negated, + }), + ) => { + self_negated == other_negated + && self_subquery.normalize_eq(other_subquery) + } + ( + Expr::InSubquery(InSubquery { + expr: self_expr, + subquery: self_subquery, + negated: self_negated, + }), + Expr::InSubquery(InSubquery { + expr: other_expr, + subquery: other_subquery, + negated: other_negated, + }), + ) => { + self_negated == other_negated + && self_expr.normalize_eq(other_expr) + && self_subquery.normalize_eq(other_subquery) + } + ( + Expr::ScalarSubquery(self_subquery), + Expr::ScalarSubquery(other_subquery), + ) => self_subquery.normalize_eq(other_subquery), + ( + Expr::GroupingSet(GroupingSet::Rollup(self_exprs)), + Expr::GroupingSet(GroupingSet::Rollup(other_exprs)), + ) + | ( + Expr::GroupingSet(GroupingSet::Cube(self_exprs)), + 
Expr::GroupingSet(GroupingSet::Cube(other_exprs)), + ) => { + self_exprs.len() == other_exprs.len() + && self_exprs + .iter() + .zip(other_exprs.iter()) + .all(|(a, b)| a.normalize_eq(b)) + } + ( + Expr::GroupingSet(GroupingSet::GroupingSets(self_exprs)), + Expr::GroupingSet(GroupingSet::GroupingSets(other_exprs)), + ) => { + self_exprs.len() == other_exprs.len() + && self_exprs.iter().zip(other_exprs.iter()).all(|(a, b)| { + a.len() == b.len() + && a.iter().zip(b.iter()).all(|(x, y)| x.normalize_eq(y)) + }) + } + ( + Expr::InList(InList { + expr: self_expr, + list: self_list, + negated: self_negated, + }), + Expr::InList(InList { + expr: other_expr, + list: other_list, + negated: other_negated, + }), + ) => { + // TODO: normalize_eq for lists, for example `a IN (c1 + c3, c3)` is equal to `a IN (c3, c1 + c3)` + self_negated == other_negated + && self_expr.normalize_eq(other_expr) + && self_list.len() == other_list.len() + && self_list + .iter() + .zip(other_list.iter()) + .all(|(a, b)| a.normalize_eq(b)) + } + ( + Expr::Case(Case { + expr: self_expr, + when_then_expr: self_when_then_expr, + else_expr: self_else_expr, + }), + Expr::Case(Case { + expr: other_expr, + when_then_expr: other_when_then_expr, + else_expr: other_else_expr, + }), + ) => { + // TODO: normalize_eq for when_then_expr + // for example `CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END` is equal to `CASE a WHEN 3 THEN 4 WHEN 1 THEN 2 ELSE 5 END` + self_when_then_expr.len() == other_when_then_expr.len() + && self_when_then_expr + .iter() + .zip(other_when_then_expr.iter()) + .all(|((self_when, self_then), (other_when, other_then))| { + self_when.normalize_eq(other_when) + && self_then.normalize_eq(other_then) + }) + && match (self_expr, other_expr) { + (Some(self_expr), Some(other_expr)) => { + self_expr.normalize_eq(other_expr) + } + (None, None) => true, + (_, _) => false, + } + && match (self_else_expr, other_else_expr) { + (Some(self_else_expr), Some(other_else_expr)) => { + self_else_expr.normalize_eq(other_else_expr) + } + (None, None) => true, + (_, _) => false, + } + } + (_, _) => self == other, + } + } +} + +impl HashNode for Expr { /// As it is pretty easy to forget changing this method when `Expr` changes the /// implementation doesn't use wildcard patterns (`..`, `_`) to catch changes /// compile time. 
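The `NormalizeEq` rules above are easiest to see on a commutative operator. A small sketch (not part of the diff), assuming the trait is in scope from `datafusion_common::cse`:

```rust
use datafusion_common::cse::NormalizeEq;
use datafusion_expr::col;

#[test]
fn normalize_eq_example() {
    // `+` is commutative, so operand order does not matter for CSE purposes...
    assert!((col("a") + col("b")).normalize_eq(&(col("b") + col("a"))));
    // ...but `-` is not, so these two expressions stay distinct.
    assert!(!(col("a") - col("b")).normalize_eq(&(col("b") - col("a"))));
}
```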
- pub fn hash_node(&self, hasher: &mut H) { - mem::discriminant(self).hash(hasher); + fn hash_node(&self, state: &mut H) { + mem::discriminant(self).hash(state); match self { Expr::Alias(Alias { expr: _expr, relation, name, }) => { - relation.hash(hasher); - name.hash(hasher); + relation.hash(state); + name.hash(state); } Expr::Column(column) => { - column.hash(hasher); + column.hash(state); } Expr::ScalarVariable(data_type, name) => { - data_type.hash(hasher); - name.hash(hasher); + data_type.hash(state); + name.hash(state); } Expr::Literal(scalar_value) => { - scalar_value.hash(hasher); + scalar_value.hash(state); } Expr::BinaryExpr(BinaryExpr { left: _left, op, right: _right, }) => { - op.hash(hasher); + op.hash(state); } Expr::Like(Like { negated, @@ -1709,9 +2099,9 @@ impl Expr { escape_char, case_insensitive, }) => { - negated.hash(hasher); - escape_char.hash(hasher); - case_insensitive.hash(hasher); + negated.hash(state); + escape_char.hash(state); + case_insensitive.hash(state); } Expr::Not(_expr) | Expr::IsNotNull(_expr) @@ -1729,7 +2119,7 @@ impl Expr { low: _low, high: _high, }) => { - negated.hash(hasher); + negated.hash(state); } Expr::Case(Case { expr: _expr, @@ -1744,10 +2134,10 @@ impl Expr { expr: _expr, data_type, }) => { - data_type.hash(hasher); + data_type.hash(state); } Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { - func.hash(hasher); + func.hash(state); } Expr::AggregateFunction(AggregateFunction { func, @@ -1757,9 +2147,9 @@ impl Expr { order_by: _order_by, null_treatment, }) => { - func.hash(hasher); - distinct.hash(hasher); - null_treatment.hash(hasher); + func.hash(state); + distinct.hash(state); + null_treatment.hash(state); } Expr::WindowFunction(WindowFunction { fun, @@ -1769,56 +2159,56 @@ impl Expr { window_frame, null_treatment, }) => { - fun.hash(hasher); - window_frame.hash(hasher); - null_treatment.hash(hasher); + fun.hash(state); + window_frame.hash(state); + null_treatment.hash(state); } Expr::InList(InList { expr: _expr, list: _list, negated, }) => { - negated.hash(hasher); + negated.hash(state); } Expr::Exists(Exists { subquery, negated }) => { - subquery.hash(hasher); - negated.hash(hasher); + subquery.hash(state); + negated.hash(state); } Expr::InSubquery(InSubquery { expr: _expr, subquery, negated, }) => { - subquery.hash(hasher); - negated.hash(hasher); + subquery.hash(state); + negated.hash(state); } Expr::ScalarSubquery(subquery) => { - subquery.hash(hasher); + subquery.hash(state); } Expr::Wildcard { qualifier, options } => { - qualifier.hash(hasher); - options.hash(hasher); + qualifier.hash(state); + options.hash(state); } Expr::GroupingSet(grouping_set) => { - mem::discriminant(grouping_set).hash(hasher); + mem::discriminant(grouping_set).hash(state); match grouping_set { GroupingSet::Rollup(_exprs) | GroupingSet::Cube(_exprs) => {} GroupingSet::GroupingSets(_exprs) => {} } } Expr::Placeholder(place_holder) => { - place_holder.hash(hasher); + place_holder.hash(state); } Expr::OuterReferenceColumn(data_type, column) => { - data_type.hash(hasher); - column.hash(hasher); + data_type.hash(state); + column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} }; } } -// modifies expr if it is a placeholder with datatype of right +// Modifies expr if it is a placeholder with datatype of right fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr { if data_type.is_none() { @@ -1850,7 +2240,7 @@ macro_rules! 
expr_vec_fmt { } struct SchemaDisplay<'a>(&'a Expr); -impl<'a> Display for SchemaDisplay<'a> { +impl Display for SchemaDisplay<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self.0 { // The same as Display @@ -1874,7 +2264,7 @@ impl<'a> Display for SchemaDisplay<'a> { "{}({}{})", func.name(), if *distinct { "DISTINCT " } else { "" }, - schema_name_from_exprs_comma_seperated_without_space(args)? + schema_name_from_exprs_comma_separated_without_space(args)? )?; if let Some(null_treatment) = null_treatment { @@ -1891,7 +2281,7 @@ impl<'a> Display for SchemaDisplay<'a> { Ok(()) } - // expr is not shown since it is aliased + // Expr is not shown since it is aliased Expr::Alias(Alias { name, .. }) => write!(f, "{name}"), Expr::Between(Between { expr, @@ -1946,7 +2336,7 @@ impl<'a> Display for SchemaDisplay<'a> { write!(f, "END") } - // cast expr is not shown to be consistant with Postgres and Spark + // Cast expr is not shown to be consistent with Postgres and Spark Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) => { write!(f, "{}", SchemaDisplay(expr)) } @@ -2076,7 +2466,7 @@ impl<'a> Display for SchemaDisplay<'a> { f, "{}({})", fun, - schema_name_from_exprs_comma_seperated_without_space(args)? + schema_name_from_exprs_comma_separated_without_space(args)? )?; if let Some(null_treatment) = null_treatment { @@ -2106,7 +2496,7 @@ impl<'a> Display for SchemaDisplay<'a> { /// Internal usage. Please call `schema_name_from_exprs` instead // TODO: Use ", " to standardize the formatting of Vec, // -pub(crate) fn schema_name_from_exprs_comma_seperated_without_space( +pub(crate) fn schema_name_from_exprs_comma_separated_without_space( exprs: &[Expr], ) -> Result { schema_name_from_exprs_inner(exprs, ",") @@ -2147,14 +2537,19 @@ pub fn schema_name_from_sorts(sorts: &[Sort]) -> Result { Ok(s) } +pub const OUTER_REFERENCE_COLUMN_PREFIX: &str = "outer_ref"; +pub const UNNEST_COLUMN_PREFIX: &str = "UNNEST"; + /// Format expressions for display as part of a logical plan. In many cases, this will produce /// similar output to `Expr.name()` except that column names will be prefixed with '#'. -impl fmt::Display for Expr { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl Display for Expr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"), Expr::Column(c) => write!(f, "{c}"), - Expr::OuterReferenceColumn(_, c) => write!(f, "outer_ref({c})"), + Expr::OuterReferenceColumn(_, c) => { + write!(f, "{OUTER_REFERENCE_COLUMN_PREFIX}({c})") + } Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")), Expr::Literal(v) => write!(f, "{v:?}"), Expr::Case(case) => { @@ -2209,7 +2604,7 @@ impl fmt::Display for Expr { Expr::ScalarFunction(fun) => { fmt_function(f, fun.name(), false, &fun.args, true) } - // TODO: use udf's display_name, need to fix the seperator issue, + // TODO: use udf's display_name, need to fix the separator issue, // Expr::ScalarFunction(ScalarFunction { func, args }) => { // write!(f, "{}", func.display_name(args).unwrap()) // } @@ -2347,14 +2742,14 @@ impl fmt::Display for Expr { }, Expr::Placeholder(Placeholder { id, .. 
}) => write!(f, "{id}"), Expr::Unnest(Unnest { expr }) => { - write!(f, "UNNEST({expr})") + write!(f, "{UNNEST_COLUMN_PREFIX}({expr})") } } } } fn fmt_function( - f: &mut fmt::Formatter, + f: &mut Formatter, fun: &str, distinct: bool, args: &[Expr], @@ -2387,7 +2782,7 @@ mod test { use crate::expr_fn::col; use crate::{ case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue, - ScalarUDF, ScalarUDFImpl, Volatility, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, }; use sqlparser::ast; use sqlparser::ast::{Ident, IdentWithAlias}; @@ -2416,7 +2811,7 @@ mod test { let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, expr.canonical_name()); assert_eq!(expected_canonical, format!("{expr}")); - // note that CAST intentionally has a name that is different from its `Display` + // Note that CAST intentionally has a name that is different from its `Display` // representation. CAST does not change the name of expressions. assert_eq!("Float32(1.23)", expr.schema_name().to_string()); Ok(()) @@ -2516,7 +2911,10 @@ mod test { Ok(DataType::Utf8) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { + fn invoke_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result { Ok(ColumnarValue::from(ScalarValue::from("a"))) } } @@ -2537,157 +2935,6 @@ mod test { use super::*; - #[test] - fn test_first_value_return_type() -> Result<()> { - let fun = find_df_window_func("first_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8], &[true], "")?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::UInt64], &[true], "")?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - - #[test] - fn test_last_value_return_type() -> Result<()> { - let fun = find_df_window_func("last_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8], &[true], "")?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64], &[true], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lead_return_type() -> Result<()> { - let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8], &[true], "")?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64], &[true], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lag_return_type() -> Result<()> { - let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8], &[true], "")?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64], &[true], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_nth_value_return_type() -> Result<()> { - let fun = find_df_window_func("nth_value").unwrap(); - let observed = - fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true], "")?; - assert_eq!(DataType::Utf8, observed); - - let observed = - fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_percent_rank_return_type() -> Result<()> { - let fun = find_df_window_func("percent_rank").unwrap(); - let observed = fun.return_type(&[], &[], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_cume_dist_return_type() -> Result<()> { - let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[], &[], "")?; - 
assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_ntile_return_type() -> Result<()> { - let fun = find_df_window_func("ntile").unwrap(); - let observed = fun.return_type(&[DataType::Int16], &[true], "")?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - - #[test] - fn test_window_function_case_insensitive() -> Result<()> { - let names = vec![ - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", - "lag", - "lead", - "first_value", - "last_value", - "nth_value", - ]; - for name in names { - let fun = find_df_window_func(name).unwrap(); - let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); - assert_eq!(fun, fun2); - if fun.to_string() == "first_value" || fun.to_string() == "last_value" { - assert_eq!(fun.to_string(), name); - } else { - assert_eq!(fun.to_string(), name.to_uppercase()); - } - } - Ok(()) - } - - #[test] - fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("cume_dist"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::CumeDist - )) - ); - assert_eq!( - find_df_window_func("first_value"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::FirstValue - )) - ); - assert_eq!( - find_df_window_func("LAST_value"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::LastValue - )) - ); - assert_eq!( - find_df_window_func("LAG"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::Lag - )) - ); - assert_eq!( - find_df_window_func("LEAD"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::Lead - )) - ); - assert_eq!(find_df_window_func("not_exist"), None) - } - #[test] fn test_display_wildcard() { assert_eq!(format!("{}", wildcard()), "*"); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4b30def27630f..ede7540e09579 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -27,8 +27,8 @@ use crate::function::{ }; use crate::{ conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, - AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, - Signature, Volatility, + AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, + ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; use crate::{ AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, @@ -39,6 +39,7 @@ use arrow::compute::kernels::cast_utils::{ use arrow::datatypes::{DataType, Field}; use datafusion_common::{plan_err, Column, Result, ScalarValue, TableReference}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use sqlparser::ast::NullTreatment; use std::any::Any; use std::fmt::Debug; @@ -122,7 +123,7 @@ pub fn placeholder(id: impl Into) -> Expr { pub fn wildcard() -> Expr { Expr::Wildcard { qualifier: None, - options: WildcardOptions::default(), + options: Box::new(WildcardOptions::default()), } } @@ -130,7 +131,7 @@ pub fn wildcard() -> Expr { pub fn wildcard_with_options(options: WildcardOptions) -> Expr { Expr::Wildcard { qualifier: None, - options, + options: Box::new(options), } } @@ -147,7 +148,7 @@ pub fn wildcard_with_options(options: WildcardOptions) -> Expr { pub fn qualified_wildcard(qualifier: impl 
Into) -> Expr { Expr::Wildcard { qualifier: Some(qualifier.into()), - options: WildcardOptions::default(), + options: Box::new(WildcardOptions::default()), } } @@ -158,7 +159,7 @@ pub fn qualified_wildcard_with_options( ) -> Expr { Expr::Wildcard { qualifier: Some(qualifier.into()), - options, + options: Box::new(options), } } @@ -415,9 +416,10 @@ pub struct SimpleScalarUDF { impl Debug for SimpleScalarUDF { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("ScalarUDF") + f.debug_struct("SimpleScalarUDF") .field("name", &self.name) .field("signature", &self.signature) + .field("return_type", &self.return_type) .field("fun", &"") .finish() } @@ -433,10 +435,24 @@ impl SimpleScalarUDF { volatility: Volatility, fun: ScalarFunctionImplementation, ) -> Self { - let name = name.into(); - let signature = Signature::exact(input_types, volatility); - Self { + Self::new_with_signature( name, + Signature::exact(input_types, volatility), + return_type, + fun, + ) + } + + /// Create a new `SimpleScalarUDF` from a name, signature, return type and + /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility + pub fn new_with_signature( + name: impl Into, + signature: Signature, + return_type: DataType, + fun: ScalarFunctionImplementation, + ) -> Self { + Self { + name: name.into(), signature, return_type, fun, @@ -461,7 +477,11 @@ impl ScalarUDFImpl for SimpleScalarUDF { Ok(self.return_type.clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { (self.fun)(args) } } @@ -505,16 +525,17 @@ pub struct SimpleAggregateUDF { impl Debug for SimpleAggregateUDF { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("AggregateUDF") + f.debug_struct("SimpleAggregateUDF") .field("name", &self.name) .field("signature", &self.signature) + .field("return_type", &self.return_type) .field("fun", &"") .finish() } } impl SimpleAggregateUDF { - /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and + /// Create a new `SimpleAggregateUDF` from a name, input types, return type, state type and /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility pub fn new( name: impl Into, @@ -535,6 +556,8 @@ impl SimpleAggregateUDF { } } + /// Create a new `SimpleAggregateUDF` from a name, signature, return type, state type and + /// implementation. 
Implementing [`AggregateUDFImpl`] allows more flexibility pub fn new_with_signature( name: impl Into, signature: Signature, @@ -658,7 +681,10 @@ impl WindowUDFImpl for SimpleWindowUDF { &self.signature } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { (self.partition_evaluator_factory)() } @@ -697,7 +723,6 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { /// # use datafusion_expr::test::function_stub::count; /// # use sqlparser::ast::NullTreatment; /// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col}; -/// # use datafusion_expr::window_function::percent_rank; /// # // first_value is an aggregate function in another crate /// # fn first_value(_arg: Expr) -> Expr { /// unimplemented!() } @@ -717,6 +742,9 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { /// // Create a window expression for percent rank partitioned on column a /// // equivalent to: /// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)` +/// // percent_rank is an udwf function in another crate +/// # fn percent_rank() -> Expr { +/// unimplemented!() } /// let window = percent_rank() /// .partition_by(vec![col("a")]) /// .order_by(vec![col("b").sort(true, true)]) diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index a304f43e6bee5..7e97e2c35cbd4 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -306,9 +306,16 @@ impl NamePreserver { /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan pub fn new(plan: &LogicalPlan) -> Self { Self { - // The schema of Filter and Join nodes comes from their inputs rather than their output expressions, - // so there is no need to use aliases to preserve expression names. - use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)), + // The expressions of these plans do not contribute to their output schema, + // so there is no need to preserve expression names to prevent a schema change. + use_alias: !matches!( + plan, + LogicalPlan::Filter(_) + | LogicalPlan::Join(_) + | LogicalPlan::TableScan(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Statement(_) + ), } } @@ -454,10 +461,9 @@ mod test { normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[]) .unwrap_err() .strip_backtrace(); - assert_eq!( - error, - r#"Schema error: No field named b. Valid fields are "tableA".a."# - ); + let expected = "Schema error: No field named b. 
\ + Valid fields are \"tableA\".a."; + assert_eq!(error, expected); } #[test] diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 03c63e9b2de71..089092fe70bd4 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -28,34 +28,34 @@ use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, Column, ExprSchema, Result, - TableReference, + not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, + Result, TableReference, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use std::collections::HashMap; use std::sync::Arc; -/// trait to allow expr to typable with respect to a schema +/// Trait to allow expr to typable with respect to a schema pub trait ExprSchemable { - /// given a schema, return the type of the expr + /// Given a schema, return the type of the expr fn get_type(&self, schema: &dyn ExprSchema) -> Result; - /// given a schema, return the nullability of the expr + /// Given a schema, return the nullability of the expr fn nullable(&self, input_schema: &dyn ExprSchema) -> Result; - /// given a schema, return the expr's optional metadata + /// Given a schema, return the expr's optional metadata fn metadata(&self, schema: &dyn ExprSchema) -> Result>; - /// convert to a field with respect to a schema + /// Convert to a field with respect to a schema fn to_field( &self, input_schema: &dyn ExprSchema, ) -> Result<(Option, Arc)>; - /// cast to a type with respect to a schema + /// Cast to a type with respect to a schema fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result; - /// given a schema, return the type and nullability of the expr + /// Given a schema, return the type and nullability of the expr fn data_type_and_nullable(&self, schema: &dyn ExprSchema) -> Result<(DataType, bool)>; } @@ -99,6 +99,7 @@ impl ExprSchemable for Expr { /// expression refers to a column that does not exist in the /// schema, or when the expression is incorrectly typed /// (e.g. `[utf8] + [bool]`). + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn get_type(&self, schema: &dyn ExprSchema) -> Result { match self { Expr::Alias(Alias { expr, name, .. }) => match &**expr { @@ -150,12 +151,15 @@ impl ExprSchemable for Expr { .map(|e| e.get_type(schema)) .collect::>>()?; - // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_data_types, func) .map_err(|err| { plan_datafusion_err!( "{} {}", - err, + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, utils::generate_signature_error_msg( func.name(), func.signature().clone(), @@ -164,7 +168,7 @@ impl ExprSchemable for Expr { ) })?; - // perform additional function arguments validation (due to limited + // Perform additional function arguments validation (due to limited // expressiveness of `TypeSignature`), then infer return type Ok(func.return_type_from_exprs(args, schema, &new_data_types)?) 
} @@ -180,7 +184,10 @@ impl ExprSchemable for Expr { .map_err(|err| { plan_datafusion_err!( "{} {}", - err, + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, utils::generate_signature_error_msg( func.name(), func.signature().clone(), @@ -213,17 +220,17 @@ impl ExprSchemable for Expr { }) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?), Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean), Expr::Placeholder(Placeholder { data_type, .. }) => { - data_type.clone().ok_or_else(|| { - plan_datafusion_err!( - "Placeholder type could not be resolved. Make sure that the \ - placeholder is bound to a concrete type, e.g. by providing \ - parameter values." - ) - }) + if let Some(dtype) = data_type { + Ok(dtype.clone()) + } else { + // If the placeholder's type hasn't been specified, treat it as + // null (unspecified placeholders generate an error during planning) + Ok(DataType::Null) + } } Expr::Wildcard { .. } => Ok(DataType::Null), Expr::GroupingSet(_) => { - // grouping sets do not really have a type and do not appear in projections + // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } } @@ -279,7 +286,7 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(_, _) => Ok(true), Expr::Literal(value) => Ok(value.value().is_null()), Expr::Case(case) => { - // this expression is nullable if any of the input expressions are nullable + // This expression is nullable if any of the input expressions are nullable let then_nullable = case .when_then_expr .iter() @@ -336,7 +343,7 @@ impl ExprSchemable for Expr { } Expr::Wildcard { .. } => Ok(false), Expr::GroupingSet(_) => { - // grouping sets do not really have the concept of nullable and do not appear + // Grouping sets do not really have the concept of nullable and do not appear // in projections Ok(true) } @@ -347,6 +354,7 @@ impl ExprSchemable for Expr { match self { Expr::Column(c) => Ok(schema.metadata(c)?.clone()), Expr::Alias(Alias { expr, .. }) => expr.metadata(schema), + Expr::Cast(Cast { expr, .. }) => expr.metadata(schema), _ => Ok(HashMap::new()), } } @@ -439,7 +447,7 @@ impl ExprSchemable for Expr { return Ok(self); } - // TODO(kszucs): most of the operations do not validate the type correctness + // TODO(kszucs): Most of the operations do not validate the type correctness // like all of the binary expressions below. Perhaps Expr should track the // type of the expression? @@ -478,18 +486,15 @@ impl Expr { .map(|e| e.get_type(schema)) .collect::>>()?; match fun { - WindowFunctionDefinition::BuiltInWindowFunction(window_fun) => { - let return_type = window_fun.return_type(&data_types)?; - let nullable = - !["RANK", "NTILE", "CUME_DIST"].contains(&window_fun.name()); - Ok((return_type, nullable)) - } WindowFunctionDefinition::AggregateUDF(udaf) => { let new_types = data_types_with_aggregate_udf(&data_types, udaf) .map_err(|err| { plan_datafusion_err!( "{} {}", - err, + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, utils::generate_signature_error_msg( fun.name(), fun.signature(), @@ -508,7 +513,10 @@ impl Expr { data_types_with_window_udf(&data_types, udwf).map_err(|err| { plan_datafusion_err!( "{} {}", - err, + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, utils::generate_signature_error_msg( fun.name(), fun.signature(), @@ -526,7 +534,14 @@ impl Expr { } } -/// cast subquery in InSubquery/ScalarSubquery to a given type. 
+/// Cast subquery in InSubquery/ScalarSubquery to a given type. +/// +/// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific +/// columns), it casts the first expression in the projection to the target type and creates a +/// new projection with the casted expression. +/// 2. **Non-projection plan**: If the subquery isn't a projection, it adds a projection to the plan +/// with the casted first column. +/// pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result { if subquery.subquery.schema().field(0).data_type() == cast_to_type { return Ok(subquery); @@ -674,13 +689,11 @@ mod tests { .with_data_type(DataType::Int32) .with_metadata(meta.clone()); - // col and alias should be metadata-preserving + // col, alias, and cast should be metadata-preserving assert_eq!(meta, expr.metadata(&schema).unwrap()); assert_eq!(meta, expr.clone().alias("bar").metadata(&schema).unwrap()); - - // cast should drop input metadata since the type has changed assert_eq!( - HashMap::new(), + meta, expr.clone() .cast_to(&DataType::Int64, &schema) .unwrap() diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 9814d16ddfa36..e0235d32292fa 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -27,7 +27,9 @@ pub use datafusion_functions_aggregate_common::accumulator::{ AccumulatorArgs, AccumulatorFactoryFunction, StateFieldsArgs, }; +pub use datafusion_functions_window_common::expr::ExpressionArgs; pub use datafusion_functions_window_common::field::WindowUDFFieldArgs; +pub use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; #[derive(Debug, Clone, Copy)] pub enum Hint { @@ -67,7 +69,7 @@ pub type StateTypeFunction = /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked /// * 'info': [crate::simplify::SimplifyInfo] /// -/// closure returns simplified [Expr] or an error. +/// Closure returns simplified [Expr] or an error. pub type AggregateFunctionSimplification = Box< dyn Fn( crate::expr::AggregateFunction, @@ -80,7 +82,7 @@ pub type AggregateFunctionSimplification = Box< /// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked /// * 'info': [crate::simplify::SimplifyInfo] /// -/// closure returns simplified [Expr] or an error. +/// Closure returns simplified [Expr] or an error. pub type WindowFunctionSimplification = Box< dyn Fn( crate::expr::WindowFunction, diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 014b271453ed3..ce4d087c69fff 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -14,6 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -27,14 +28,12 @@ //! //! The [expr_fn] module contains functions for creating expressions. 
-mod built_in_window_function; mod literal; mod operation; mod partition_evaluator; mod table_source; mod udaf; mod udf; -mod udf_docs; mod udwf; pub mod conditional_expressions; @@ -64,17 +63,17 @@ pub mod type_coercion; pub mod utils; pub mod var_provider; pub mod window_frame; -pub mod window_function; pub mod window_state; -pub use built_in_window_function::BuiltInWindowFunction; +pub use datafusion_doc::{DocSection, Documentation, DocumentationBuilder}; pub use datafusion_expr_common::accumulator::Accumulator; pub use datafusion_expr_common::columnar_value::ColumnarValue; pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; pub use datafusion_expr_common::operator::Operator; pub use datafusion_expr_common::scalar::Scalar; pub use datafusion_expr_common::signature::{ - ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, + ArrayFunctionSignature, Signature, TypeSignature, TypeSignatureClass, Volatility, + TIMEZONE_WILDCARD, }; pub use datafusion_expr_common::type_coercion::binary; pub use expr::{ @@ -95,8 +94,7 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{ aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs, }; -pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl}; -pub use udf_docs::{DocSection, Documentation, DocumentationBuilder}; +pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index fea34c5c60d3b..4f55764aadd74 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -20,6 +20,7 @@ use std::any::Any; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; +use std::iter::once; use std::sync::Arc; use crate::dml::CopyTo; @@ -30,8 +31,8 @@ use crate::expr_rewriter::{ rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ - Aggregate, Analyze, CrossJoin, Distinct, DistinctOn, EmptyRelation, Explain, Filter, - Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join, + JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, Window, }; @@ -40,22 +41,25 @@ use crate::utils::{ find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, - TableProviderFilterPushDown, TableSource, WriteOp, + and, binary_expr, lit, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, + Statement, TableProviderFilterPushDown, TableSource, WriteOp, }; +use super::dml::InsertOp; +use super::plan::ColumnUnnestList; +use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ - get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err, - plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, - TableReference, ToDFSchema, UnnestOptions, + exec_err, get_target_functional_dependencies, 
internal_err, not_impl_err, + plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, + FunctionalDependencies, Result, ScalarValue, TableReference, ToDFSchema, + UnnestOptions, }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; -use super::dml::InsertOp; -use super::plan::{ColumnUnnestList, ColumnUnnestType}; +use indexmap::IndexSet; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -151,11 +155,11 @@ impl LogicalPlanBuilder { } // Ensure that the static term and the recursive term have the same number of fields let static_fields_len = self.plan.schema().fields().len(); - let recurive_fields_len = recursive_term.schema().fields().len(); - if static_fields_len != recurive_fields_len { + let recursive_fields_len = recursive_term.schema().fields().len(); + if static_fields_len != recursive_fields_len { return plan_err!( "Non-recursive term and recursive term must have the same number of columns ({} != {})", - static_fields_len, recurive_fields_len + static_fields_len, recursive_fields_len ); } // Ensure that the recursive term has the same field types as the static term @@ -173,16 +177,49 @@ impl LogicalPlanBuilder { /// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. /// + /// so it's usually better to override the default names with a table alias list. + /// + /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. + pub fn values(values: Vec>) -> Result { + if values.is_empty() { + return plan_err!("Values list cannot be empty"); + } + let n_cols = values[0].len(); + if n_cols == 0 { + return plan_err!("Values list cannot be zero length"); + } + for (i, row) in values.iter().enumerate() { + if row.len() != n_cols { + return plan_err!( + "Inconsistent data length across values list: got {} values in row {} but expected {}", + row.len(), + i, + n_cols + ); + } + } + + // Infer from data itself + Self::infer_data(values) + } + + /// Create a values list based relation, and the schema is inferred from data itself or table schema if provided, consuming + /// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) + /// documentation for more details. + /// /// By default, it assigns the names column1, column2, etc. to the columns of a VALUES table. /// The column names are not specified by the SQL standard and different database systems do it differently, /// so it's usually better to override the default names with a table alias list. /// /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. 
- pub fn values(mut values: Vec>) -> Result { + pub fn values_with_schema( + values: Vec>, + schema: &DFSchemaRef, + ) -> Result { if values.is_empty() { return plan_err!("Values list cannot be empty"); } - let n_cols = values[0].len(); + let n_cols = schema.fields().len(); if n_cols == 0 { return plan_err!("Values list cannot be zero length"); } @@ -197,16 +234,53 @@ impl LogicalPlanBuilder { } } - let empty_schema = DFSchema::empty(); + // Check the type of value against the schema + Self::infer_values_from_schema(values, schema) + } + + fn infer_values_from_schema( + values: Vec>, + schema: &DFSchema, + ) -> Result { + let n_cols = values[0].len(); + let mut field_types: Vec = Vec::with_capacity(n_cols); + for j in 0..n_cols { + let field_type = schema.field(j).data_type(); + for row in values.iter() { + let value = &row[j]; + let data_type = value.get_type(schema)?; + + if !data_type.equals_datatype(field_type) { + if can_cast_types(&data_type, field_type) { + } else { + return exec_err!( + "type mismatch and can't cast to got {} and {}", + data_type, + field_type + ); + } + } + } + field_types.push(field_type.to_owned()); + } + + Self::infer_inner(values, &field_types, schema) + } + + fn infer_data(values: Vec>) -> Result { + let n_cols = values[0].len(); + let schema = DFSchema::empty(); + let mut field_types: Vec = Vec::with_capacity(n_cols); for j in 0..n_cols { let mut common_type: Option = None; for (i, row) in values.iter().enumerate() { let value = &row[j]; - let data_type = value.get_type(&empty_schema)?; + let data_type = value.get_type(&schema)?; if data_type == DataType::Null { continue; } + if let Some(prev_type) = common_type { // get common type of each column values. let data_types = vec![prev_type.clone(), data_type.clone()]; @@ -222,6 +296,15 @@ impl LogicalPlanBuilder { // since the code loop skips NULL field_types.push(common_type.unwrap_or(DataType::Null)); } + + Self::infer_inner(values, &field_types, &schema) + } + + fn infer_inner( + mut values: Vec>, + field_types: &[DataType], + schema: &DFSchema, + ) -> Result { // wrap cast if data type is not same as common type. for row in &mut values { for (j, field_type) in field_types.iter().enumerate() { @@ -229,9 +312,7 @@ impl LogicalPlanBuilder { Expr::Literal(scalar) if scalar.value().is_null() => { Expr::from(ScalarValue::try_from(field_type)?) } - _ => { - std::mem::take(&mut row[j]).cast_to(field_type, &empty_schema)? - } + _ => std::mem::take(&mut row[j]).cast_to(field_type, schema)?, }; } } @@ -246,6 +327,7 @@ impl LogicalPlanBuilder { .collect::>(); let dfschema = DFSchema::from_unqualified_fields(fields.into(), HashMap::new())?; let schema = DFSchemaRef::new(dfschema); + Ok(Self::new(LogicalPlan::Values(Values { schema, values }))) } @@ -420,11 +502,13 @@ impl LogicalPlanBuilder { /// Make a builder for a prepare logical plan from the builder's plan pub fn prepare(self, name: String, data_types: Vec) -> Result { - Ok(Self::new(LogicalPlan::Prepare(Prepare { - name, - data_types, - input: self.plan, - }))) + Ok(Self::new(LogicalPlan::Statement(Statement::Prepare( + Prepare { + name, + data_types, + input: self.plan, + }, + )))) } /// Limit the number of rows returned @@ -434,9 +518,22 @@ impl LogicalPlanBuilder { /// `fetch` - Maximum number of rows to fetch, after skipping `skip` rows, /// if specified. 
pub fn limit(self, skip: usize, fetch: Option) -> Result { + let skip_expr = if skip == 0 { + None + } else { + Some(lit(skip as i64)) + }; + let fetch_expr = fetch.map(|f| lit(f as i64)); + self.limit_by_expr(skip_expr, fetch_expr) + } + + /// Limit the number of rows returned + /// + /// Similar to `limit` but uses expressions for `skip` and `fetch` + pub fn limit_by_expr(self, skip: Option, fetch: Option) -> Result { Ok(Self::new(LogicalPlan::Limit(Limit { - skip, - fetch, + skip: skip.map(Box::new), + fetch: fetch.map(Box::new), input: self.plan, }))) } @@ -476,7 +573,7 @@ impl LogicalPlanBuilder { /// See for more details fn add_missing_columns( curr_plan: LogicalPlan, - missing_cols: &[Column], + missing_cols: &IndexSet, is_distinct: bool, ) -> Result { match curr_plan { @@ -521,7 +618,7 @@ impl LogicalPlanBuilder { fn ambiguous_distinct_check( missing_exprs: &[Expr], - missing_cols: &[Column], + missing_cols: &IndexSet, projection_exprs: &[Expr], ) -> Result<()> { if missing_exprs.is_empty() { @@ -586,15 +683,16 @@ impl LogicalPlanBuilder { let schema = self.plan.schema(); // Collect sort columns that are missing in the input plan's schema - let mut missing_cols: Vec = vec![]; + let mut missing_cols: IndexSet = IndexSet::new(); sorts.iter().try_for_each::<_, Result<()>>(|sort| { let columns = sort.expr.column_refs(); - columns.into_iter().for_each(|c| { - if !schema.has_column(c) { - missing_cols.push(c.clone()); - } - }); + missing_cols.extend( + columns + .into_iter() + .filter(|c| !schema.has_column(c)) + .cloned(), + ); Ok(()) })?; @@ -683,7 +781,7 @@ impl LogicalPlanBuilder { self.join_detailed(right, join_type, join_keys, filter, false) } - /// Apply a join with using the specified expressions. + /// Apply a join using the specified expressions. /// /// Note that DataFusion automatically optimizes joins, including /// identifying and optimizing equality predicates. @@ -952,9 +1050,14 @@ impl LogicalPlanBuilder { pub fn cross_join(self, right: LogicalPlan) -> Result { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; - Ok(Self::new(LogicalPlan::CrossJoin(CrossJoin { + Ok(Self::new(LogicalPlan::Join(Join { left: self.plan, right: Arc::new(right), + on: vec![], + filter: None, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + null_equals_null: false, schema: DFSchemaRef::new(join_schema), }))) } @@ -1101,12 +1204,20 @@ impl LogicalPlanBuilder { Ok(Arc::unwrap_or_clone(self.plan)) } - /// Apply a join with the expression on constraint. + /// Apply a join with both explicit equijoin and non equijoin predicates. + /// + /// Note this is a low level API that requires identifying specific + /// predicate types. Most users should use [`join_on`](Self::join_on) that + /// automatically identifies predicates appropriately. /// - /// equi_exprs are "equijoin" predicates expressions on the existing and right inputs, respectively. + /// `equi_exprs` defines equijoin predicates, of the form `l = r)` for each + /// `(l, r)` tuple. `l`, the first element of the tuple, must only refer + /// to columns from the existing input. `r`, the second element of the tuple, + /// must only refer to columns from the right input. /// - /// filter: any other filter expression to apply during the join. equi_exprs predicates are likely - /// to be evaluated more quickly than the filter expressions + /// `filter` contains any other other filter expression to apply during the + /// join. 
Note that `equi_exprs` predicates are evaluated more efficiently + /// than the filter expressions, so they are preferred. pub fn join_with_expr_keys( self, right: LogicalPlan, @@ -1121,25 +1232,24 @@ impl LogicalPlanBuilder { let join_key_pairs = equi_exprs .0 .into_iter() - .zip(equi_exprs.1.into_iter()) + .zip(equi_exprs.1) .map(|(l, r)| { let left_key = l.into(); let right_key = r.into(); - - let mut left_using_columns = HashSet::new(); + let mut left_using_columns = HashSet::new(); expr_to_columns(&left_key, &mut left_using_columns)?; let normalized_left_key = normalize_col_with_schemas_and_ambiguity_check( left_key, - &[&[self.plan.schema(), right.schema()]], - &[left_using_columns], + &[&[self.plan.schema()]], + &[], )?; let mut right_using_columns = HashSet::new(); expr_to_columns(&right_key, &mut right_using_columns)?; let normalized_right_key = normalize_col_with_schemas_and_ambiguity_check( right_key, - &[&[self.plan.schema(), right.schema()]], - &[right_using_columns], + &[&[right.schema()]], + &[], )?; // find valid equijoin @@ -1183,7 +1293,7 @@ impl LogicalPlanBuilder { ) -> Result { unnest_with_options( Arc::unwrap_or_clone(self.plan), - vec![(column.into(), ColumnUnnestType::Inferred)], + vec![column.into()], options, ) .map(Self::new) @@ -1194,26 +1304,6 @@ impl LogicalPlanBuilder { self, columns: Vec, options: UnnestOptions, - ) -> Result { - unnest_with_options( - Arc::unwrap_or_clone(self.plan), - columns - .into_iter() - .map(|c| (c, ColumnUnnestType::Inferred)) - .collect(), - options, - ) - .map(Self::new) - } - - /// Unnest the given columns with the given [`UnnestOptions`] - /// if one column is a list type, it can be recursively and simultaneously - /// unnested into the desired recursion levels - /// e.g select unnest(list_col,depth=1), unnest(list_col,depth=2) - pub fn unnest_columns_recursive_with_options( - self, - columns: Vec<(Column, ColumnUnnestType)>, - options: UnnestOptions, ) -> Result { unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options) .map(Self::new) @@ -1248,6 +1338,25 @@ pub fn change_redundant_column(fields: &Fields) -> Vec { }) .collect() } + +fn mark_field(schema: &DFSchema) -> (Option, Arc) { + let mut table_references = schema + .iter() + .filter_map(|(qualifier, _)| qualifier) + .collect::>(); + table_references.dedup(); + let table_reference = if table_references.len() == 1 { + table_references.pop().cloned() + } else { + None + }; + + ( + table_reference, + Arc::new(Field::new("mark", DataType::Boolean, false)), + ) +} + /// Creates a schema for a join operation. 
/// The fields from the left side are first pub fn build_join_schema( @@ -1314,6 +1423,10 @@ pub fn build_join_schema( .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect() } + JoinType::LeftMark => left_fields + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .chain(once(mark_field(right))) + .collect(), JoinType::RightSemi | JoinType::RightAnti => { // Only use the right side for the schema right_fields @@ -1326,8 +1439,12 @@ pub fn build_join_schema( join_type, left.fields().len(), ); - let mut metadata = left.metadata().clone(); - metadata.extend(right.metadata().clone()); + let metadata = left + .metadata() + .clone() + .into_iter() + .chain(right.metadata().clone()) + .collect(); let dfschema = DFSchema::new_with_metadata(qualified_fields, metadata)?; dfschema.with_functional_dependencies(func_dependencies) } @@ -1402,9 +1519,23 @@ pub fn validate_unique_names<'a>( /// [`TypeCoercionRewriter::coerce_union`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/struct.TypeCoercionRewriter.html#method.coerce_union /// [`coerce_union_schema`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/fn.coerce_union_schema.html pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result { + if left_plan.schema().fields().len() != right_plan.schema().fields().len() { + return plan_err!( + "UNION queries have different number of columns: \ + left has {} columns whereas right has {} columns", + left_plan.schema().fields().len(), + right_plan.schema().fields().len() + ); + } + // Temporarily use the schema from the left input and later rely on the analyzer to // coerce the two schemas into a common one. - let schema = Arc::clone(left_plan.schema()); + + // Functional Dependencies doesn't preserve after UNION operation + let schema = (**left_plan.schema()).clone(); + let schema = + Arc::new(schema.with_functional_dependencies(FunctionalDependencies::empty())?); + Ok(LogicalPlan::Union(Union { inputs: vec![Arc::new(left_plan), Arc::new(right_plan)], schema, @@ -1505,7 +1636,7 @@ pub fn wrap_projection_for_join_if_necessary( .iter() .map(|key| { // The display_name() of cast expression will ignore the cast info, and show the inner expression name. - // If we do not add alais, it will throw same field name error in the schema when adding projection. + // If we do not add alias, it will throw same field name error in the schema when adding projection. 
// For example: // input scan : [a, b, c], // join keys: [cast(a as int)] @@ -1584,21 +1715,19 @@ impl TableSource for LogicalTableSource { fn supports_filters_pushdown( &self, filters: &[&Expr], - ) -> Result> { + ) -> Result> { Ok(vec![TableProviderFilterPushDown::Exact; filters.len()]) } } /// Create a [`LogicalPlan::Unnest`] plan pub fn unnest(input: LogicalPlan, columns: Vec) -> Result { - let unnestings = columns - .into_iter() - .map(|c| (c, ColumnUnnestType::Inferred)) - .collect(); - unnest_with_options(input, unnestings, UnnestOptions::default()) + unnest_with_options(input, columns, UnnestOptions::default()) } -pub fn get_unnested_list_datatype_recursive( +// Get the data type of a multi-dimensional type after unnesting it +// with a given depth +fn get_unnested_list_datatype_recursive( data_type: &DataType, depth: usize, ) -> Result { @@ -1617,27 +1746,6 @@ pub fn get_unnested_list_datatype_recursive( internal_err!("trying to unnest on invalid data type {:?}", data_type) } -/// Infer the unnest type based on the data type: -/// - list type: infer to unnest(list(col, depth=1)) -/// - struct type: infer to unnest(struct) -fn infer_unnest_type( - col_name: &String, - data_type: &DataType, -) -> Result { - match data_type { - DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => { - Ok(ColumnUnnestType::List(vec![ColumnUnnestList { - output_column: Column::from_name(col_name), - depth: 1, - }])) - } - DataType::Struct(_) => Ok(ColumnUnnestType::Struct), - _ => { - internal_err!("trying to unnest on invalid data type {:?}", data_type) - } - } -} - pub fn get_struct_unnested_columns( col_name: &String, inner_fields: &Fields, @@ -1669,7 +1777,7 @@ pub fn get_unnested_columns( let new_field = Arc::new(Field::new( col_name, data_type, // Unnesting may produce NULLs even if the list is not null. - // For example: unnset([1], []) -> 1, null + // For example: unnest([1], []) -> 1, null true, )); let column = Column::from_name(col_name); @@ -1726,20 +1834,15 @@ pub fn get_unnested_columns( /// ``` pub fn unnest_with_options( input: LogicalPlan, - columns_to_unnest: Vec<(Column, ColumnUnnestType)>, + columns_to_unnest: Vec, options: UnnestOptions, ) -> Result { let mut list_columns: Vec<(usize, ColumnUnnestList)> = vec![]; let mut struct_columns = vec![]; let indices_to_unnest = columns_to_unnest .iter() - .map(|col_unnesting| { - Ok(( - input.schema().index_of_column(&col_unnesting.0)?, - col_unnesting, - )) - }) - .collect::>>()?; + .map(|c| Ok((input.schema().index_of_column(c)?, c))) + .collect::>>()?; let input_schema = input.schema(); @@ -1764,51 +1867,59 @@ pub fn unnest_with_options( .enumerate() .map(|(index, (original_qualifier, original_field))| { match indices_to_unnest.get(&index) { - Some((column_to_unnest, unnest_type)) => { - let mut inferred_unnest_type = unnest_type.clone(); - if let ColumnUnnestType::Inferred = unnest_type { - inferred_unnest_type = infer_unnest_type( + Some(column_to_unnest) => { + let recursions_on_column = options + .recursions + .iter() + .filter(|p| -> bool { &p.input_column == *column_to_unnest }) + .collect::>(); + let mut transformed_columns = recursions_on_column + .iter() + .map(|r| { + list_columns.push(( + index, + ColumnUnnestList { + output_column: r.output_column.clone(), + depth: r.depth, + }, + )); + Ok(get_unnested_columns( + &r.output_column.name, + original_field.data_type(), + r.depth, + )? 
+ .into_iter() + .next() + .unwrap()) // because unnesting a list column always result into one result + }) + .collect::)>>>()?; + if transformed_columns.is_empty() { + transformed_columns = get_unnested_columns( &column_to_unnest.name, original_field.data_type(), + 1, )?; - } - let transformed_columns: Vec<(Column, Arc)> = - match inferred_unnest_type { - ColumnUnnestType::Struct => { + match original_field.data_type() { + DataType::Struct(_) => { struct_columns.push(index); - get_unnested_columns( - &column_to_unnest.name, - original_field.data_type(), - 1, - )? } - ColumnUnnestType::List(unnest_lists) => { - list_columns.extend( - unnest_lists - .iter() - .map(|ul| (index, ul.to_owned().clone())), - ); - unnest_lists - .iter() - .map( - |ColumnUnnestList { - output_column, - depth, - }| { - get_unnested_columns( - &output_column.name, - original_field.data_type(), - *depth, - ) - }, - ) - .collect::)>>>>()? - .into_iter() - .flatten() - .collect::>() + DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) => { + list_columns.push(( + index, + ColumnUnnestList { + output_column: Column::from_name( + &column_to_unnest.name, + ), + depth: 1, + }, + )); } - _ => return internal_err!("Invalid unnest type"), + _ => {} }; + } + // new columns dependent on the same original index dependency_indices .extend(std::iter::repeat(index).take(transformed_columns.len())); @@ -1857,7 +1968,7 @@ mod tests { use crate::logical_plan::StringifiedPlan; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; - use datafusion_common::SchemaError; + use datafusion_common::{RecursionUnnestOption, SchemaError}; #[test] fn plan_builder_simple() -> Result<()> { @@ -2265,24 +2376,19 @@ mod tests { // Simultaneously unnesting a list (with different depth) and a struct column let plan = nested_table_scan("test_table")? - .unnest_columns_recursive_with_options( - vec![ - ( - "stringss".into(), - ColumnUnnestType::List(vec![ - ColumnUnnestList { - output_column: Column::from_name("stringss_depth_1"), - depth: 1, - }, - ColumnUnnestList { - output_column: Column::from_name("stringss_depth_2"), - depth: 2, - }, - ]), - ), - ("struct_singular".into(), ColumnUnnestType::Inferred), - ], - UnnestOptions::default(), + .unnest_columns_with_options( + vec!["stringss".into(), "struct_singular".into()], + UnnestOptions::default() + .with_recursions(RecursionUnnestOption { + input_column: "stringss".into(), + output_column: "stringss_depth_1".into(), + depth: 1, + }) + .with_recursions(RecursionUnnestOption { + input_column: "stringss".into(), + output_column: "stringss_depth_2".into(), + depth: 2, + }), )? 
.build()?; @@ -2328,7 +2434,7 @@ mod tests { ], false, ); - let string_field = Field::new("item", DataType::Utf8, false); + let string_field = Field::new_list_field(DataType::Utf8, false); let strings_field = Field::new_list("item", string_field.clone(), false); let schema = Schema::new(vec![ Field::new("scalar", DataType::UInt32, false), diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 9aaa5c98037ac..bf8c5ad3d7017 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -26,7 +26,10 @@ use std::{ use crate::expr::Sort; use arrow::datatypes::DataType; -use datafusion_common::{Constraints, DFSchemaRef, SchemaReference, TableReference}; +use datafusion_common::tree_node::{Transformed, TreeNodeContainer, TreeNodeRecursion}; +use datafusion_common::{ + Constraints, DFSchemaRef, Result, SchemaReference, TableReference, +}; use sqlparser::ast::Ident; /// Various types of DDL (CREATE / DROP) catalog manipulation @@ -120,9 +123,9 @@ impl DdlStatement { /// children. /// /// See [crate::LogicalPlan::display] for an example - pub fn display(&self) -> impl fmt::Display + '_ { + pub fn display(&self) -> impl Display + '_ { struct Wrapper<'a>(&'a DdlStatement); - impl<'a> Display for Wrapper<'a> { + impl Display for Wrapper<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.0 { DdlStatement::CreateExternalTable(CreateExternalTable { @@ -130,14 +133,22 @@ impl DdlStatement { constraints, .. }) => { - write!(f, "CreateExternalTable: {name:?}{constraints}") + if constraints.is_empty() { + write!(f, "CreateExternalTable: {name:?}") + } else { + write!(f, "CreateExternalTable: {name:?} {constraints}") + } } DdlStatement::CreateMemoryTable(CreateMemoryTable { name, constraints, .. }) => { - write!(f, "CreateMemoryTable: {name:?}{constraints}") + if constraints.is_empty() { + write!(f, "CreateMemoryTable: {name:?}") + } else { + write!(f, "CreateMemoryTable: {name:?} {constraints}") + } } DdlStatement::CreateView(CreateView { name, .. }) => { write!(f, "CreateView: {name:?}") @@ -202,6 +213,8 @@ pub struct CreateExternalTable { pub table_partition_cols: Vec, /// Option to not error if table already exists pub if_not_exists: bool, + /// Whether the table is a temporary table + pub temporary: bool, /// SQL used to create the table, if available pub definition: Option, /// Order expressions supplied by user @@ -298,6 +311,8 @@ pub struct CreateMemoryTable { pub or_replace: bool, /// Default values for columns pub column_defaults: Vec<(String, Expr)>, + /// Whether the table is `TableType::Temporary` + pub temporary: bool, } /// Creates a view. @@ -311,6 +326,8 @@ pub struct CreateView { pub or_replace: bool, /// SQL used to create the view, if available pub definition: Option, + /// Whether the view is ephemeral + pub temporary: bool, } /// Creates a catalog (aka "Database"). 
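For readers following the unnest changes in builder.rs above: recursive unnesting is now driven by `UnnestOptions::with_recursions` and `RecursionUnnestOption` rather than the removed `ColumnUnnestType`. A minimal usage sketch, not part of the patch, assuming `plan` is a `LogicalPlan` with a doubly nested list column `stringss` and a struct column `struct_singular` (matching the `nested_table_scan` test fixture), and that `LogicalPlanBuilder`, `UnnestOptions`, `RecursionUnnestOption`, and `Column` are in scope:

    // Sketch only: unnest `stringss` to depth 1 and depth 2 under distinct
    // output names, while also unnesting the struct column in the same call.
    let options = UnnestOptions::default()
        .with_recursions(RecursionUnnestOption {
            input_column: "stringss".into(),
            output_column: "stringss_depth_1".into(),
            depth: 1,
        })
        .with_recursions(RecursionUnnestOption {
            input_column: "stringss".into(),
            output_column: "stringss_depth_2".into(),
            depth: 2,
        });

    let unnested = LogicalPlanBuilder::from(plan)
        .unnest_columns_with_options(
            vec!["stringss".into(), "struct_singular".into()],
            options,
        )?
        .build()?;

Passing plain column names with a default `UnnestOptions` still infers a depth-1 list unnest or a struct unnest from each column's data type, as the rewritten `unnest_with_options` above shows.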
@@ -481,6 +498,28 @@ pub struct OperateFunctionArg { pub data_type: DataType, pub default_expr: Option, } + +impl<'a> TreeNodeContainer<'a, Expr> for OperateFunctionArg { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.default_expr.apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.default_expr.map_elements(f)?.map_data(|default_expr| { + Ok(Self { + default_expr, + ..self + }) + }) + } +} + #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct CreateFunctionBody { /// LANGUAGE lang_name @@ -491,6 +530,29 @@ pub struct CreateFunctionBody { pub function_body: Option, } +impl<'a> TreeNodeContainer<'a, Expr> for CreateFunctionBody { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.function_body.apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.function_body + .map_elements(f)? + .map_data(|function_body| { + Ok(Self { + function_body, + ..self + }) + }) + } +} + #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct DropFunction { pub name: String, diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 26d54803d4036..14758b61e859d 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -14,6 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + //! This module provides logic for displaying LogicalPlans in various styles use std::collections::HashMap; @@ -21,7 +22,7 @@ use std::fmt; use crate::{ expr_vec_fmt, Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Expr, - Filter, Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, + Filter, Join, Limit, LogicalPlan, Partitioning, Projection, RecursiveQuery, Repartition, Sort, Subquery, SubqueryAlias, TableProviderFilterPushDown, TableScan, Unnest, Values, Window, }; @@ -58,7 +59,7 @@ impl<'a, 'b> IndentVisitor<'a, 'b> { } } -impl<'n, 'a, 'b> TreeNodeVisitor<'n> for IndentVisitor<'a, 'b> { +impl<'n> TreeNodeVisitor<'n> for IndentVisitor<'_, '_> { type Node = LogicalPlan; fn f_down( @@ -112,7 +113,7 @@ impl<'n, 'a, 'b> TreeNodeVisitor<'n> for IndentVisitor<'a, 'b> { pub fn display_schema(schema: &Schema) -> impl fmt::Display + '_ { struct Wrapper<'a>(&'a Schema); - impl<'a> fmt::Display for Wrapper<'a> { + impl fmt::Display for Wrapper<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "[")?; for (idx, field) in self.0.fields().iter().enumerate() { @@ -180,7 +181,7 @@ impl<'a, 'b> GraphvizVisitor<'a, 'b> { } } -impl<'n, 'a, 'b> TreeNodeVisitor<'n> for GraphvizVisitor<'a, 'b> { +impl<'n> TreeNodeVisitor<'n> for GraphvizVisitor<'_, '_> { type Node = LogicalPlan; fn f_down( @@ -504,11 +505,6 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Filter": format!("{}", filter_expr) }) } - LogicalPlan::CrossJoin(_) => { - json!({ - "Node Type": "Cross Join" - }) - } LogicalPlan::Repartition(Repartition { partitioning_scheme, .. 
@@ -549,11 +545,13 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { let mut object = serde_json::json!( { "Node Type": "Limit", - "Skip": skip, } ); + if let Some(s) = skip { + object["Skip"] = s.to_string().into() + }; if let Some(f) = fetch { - object["Fetch"] = serde_json::Value::Number((*f).into()); + object["Fetch"] = f.to_string().into() }; object } @@ -620,15 +618,6 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Detail": format!("{:?}", e.node) }) } - LogicalPlan::Prepare(Prepare { - name, data_types, .. - }) => { - json!({ - "Node Type": "Prepare", - "Name": name, - "Data Types": format!("{:?}", data_types) - }) - } LogicalPlan::DescribeTable(DescribeTable { .. }) => { json!({ "Node Type": "DescribeTable" @@ -665,7 +654,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { } } -impl<'n, 'a, 'b> TreeNodeVisitor<'n> for PgJsonVisitor<'a, 'b> { +impl<'n> TreeNodeVisitor<'n> for PgJsonVisitor<'_, '_> { type Node = LogicalPlan; fn f_down( diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 68b3ac41fa083..669bc8e8a7d34 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -165,7 +165,7 @@ impl WriteOp { } impl Display for WriteOp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}", self.name()) } } @@ -196,7 +196,7 @@ impl InsertOp { } impl Display for InsertOp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}", self.name()) } } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/expr/src/logical_plan/invariants.rs similarity index 69% rename from datafusion/optimizer/src/analyzer/subquery.rs rename to datafusion/expr/src/logical_plan/invariants.rs index aabc549de5837..bde4acaae5629 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -15,20 +15,99 @@ // specific language governing permissions and limitations // under the License. 
-use std::ops::Deref; - -use crate::analyzer::check_plan; -use crate::utils::collect_subquery_cols; - -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{plan_err, Result}; -use datafusion_expr::expr_rewriter::strip_outer_reference; -use datafusion_expr::utils::split_conjunction; -use datafusion_expr::{ - Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator, - Window, +use datafusion_common::{ + internal_err, plan_err, + tree_node::{TreeNode, TreeNodeRecursion}, + DFSchemaRef, Result, }; +use crate::{ + expr::{Exists, InSubquery}, + expr_rewriter::strip_outer_reference, + utils::{collect_subquery_cols, split_conjunction}, + Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window, +}; + +pub enum InvariantLevel { + /// Invariants that are always true in DataFusion `LogicalPlan`s + /// such as the number of expected children and no duplicated output fields + Always, + /// Invariants that must hold true for the plan to be "executable" + /// such as the type and number of function arguments are correct and + /// that wildcards have been expanded + /// + /// To ensure a LogicalPlan satisfies the `Executable` invariants, run the + /// `Analyzer` + Executable, +} + +pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> { + // Refer to + assert_unique_field_names(plan)?; + + Ok(()) +} + +pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> { + assert_always_invariants(plan)?; + assert_valid_semantic_plan(plan)?; + Ok(()) +} + +/// Returns an error if plan, and subplans, do not have unique fields. +/// +/// This invariant is subject to change. +/// refer: +fn assert_unique_field_names(plan: &LogicalPlan) -> Result<()> { + plan.schema().check_names() +} + +/// Returns an error if the plan is not sematically valid. +fn assert_valid_semantic_plan(plan: &LogicalPlan) -> Result<()> { + assert_subqueries_are_valid(plan)?; + + Ok(()) +} + +/// Returns an error if the plan does not have the expected schema. +/// Ignores metadata and nullability. +pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Result<()> { + let equivalent = plan.schema().equivalent_names_and_types(schema); + + if !equivalent { + internal_err!( + "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", + schema, + plan.schema() + ) + } else { + Ok(()) + } +} + +/// Asserts that the subqueries are structured properly with valid node placement. +/// +/// Refer to [`check_subquery_expr`] for more details. +fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> { + plan.apply_with_subqueries(|plan: &LogicalPlan| { + plan.apply_expressions(|expr| { + // recursively look for subqueries + expr.apply(|expr| { + match expr { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + check_subquery_expr(plan, &subquery.subquery, expr)?; + } + _ => {} + }; + Ok(TreeNodeRecursion::Continue) + }) + }) + }) + .map(|_| ()) +} + /// Do necessary check on subquery expressions and fail the invalid plan /// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions, /// the allowed while list: [Projection, Filter, Window, Aggregate, Join]. 
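Since the new `InvariantLevel` enum and the `assert_*` helpers above are the core of the relocated module, here is a brief crate-internal sketch, not part of the patch, of how the two levels might be routed to the checks defined above. The helper name `check_invariants` is illustrative and assumes it lives in the same `invariants` module so that `LogicalPlan`, `Result`, and the assertion functions are in scope:

    // Illustrative helper only: dispatch an InvariantLevel to the matching
    // assertion defined earlier in this module.
    fn check_invariants(plan: &LogicalPlan, level: InvariantLevel) -> Result<()> {
        match level {
            // Structural checks that every LogicalPlan must satisfy
            InvariantLevel::Always => assert_always_invariants(plan),
            // Also requires the plan to be semantically valid (e.g. correct
            // subquery placement), typically after the Analyzer has run
            InvariantLevel::Executable => assert_executable_invariants(plan),
        }
    }

Note that `assert_executable_invariants` already calls `assert_always_invariants` first, so the `Executable` arm is a strict superset of the `Always` arm.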
@@ -41,7 +120,7 @@ pub fn check_subquery_expr( inner_plan: &LogicalPlan, expr: &Expr, ) -> Result<()> { - check_plan(inner_plan)?; + assert_subqueries_are_valid(inner_plan)?; if let Expr::ScalarSubquery(subquery) = expr { // Scalar subquery should only return one column if subquery.subquery.schema().fields().len() > 1 { @@ -83,7 +162,7 @@ pub fn check_subquery_expr( match outer_plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) => Ok(()), - LogicalPlan::Aggregate(Aggregate {group_expr, aggr_expr,..}) => { + LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. }) => { if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic plan_err!( @@ -92,13 +171,13 @@ pub fn check_subquery_expr( } else { Ok(()) } - }, + } _ => plan_err!( "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes" ) }?; } - check_correlations_in_subquery(inner_plan, true) + check_correlations_in_subquery(inner_plan) } else { if let Expr::InSubquery(subquery) = expr { // InSubquery should only return one column @@ -113,33 +192,29 @@ pub fn check_subquery_expr( match outer_plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) + | LogicalPlan::TableScan(_) | LogicalPlan::Window(_) | LogicalPlan::Aggregate(_) | LogicalPlan::Join(_) => Ok(()), _ => plan_err!( "In/Exist subquery can only be used in \ - Projection, Filter, Window functions, Aggregate and Join plan nodes" + Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \ + but was used in [{}]", + outer_plan.display() ), }?; - check_correlations_in_subquery(inner_plan, false) + check_correlations_in_subquery(inner_plan) } } // Recursively check the unsupported outer references in the sub query plan. -fn check_correlations_in_subquery( - inner_plan: &LogicalPlan, - is_scalar: bool, -) -> Result<()> { - check_inner_plan(inner_plan, is_scalar, false, true) +fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> { + check_inner_plan(inner_plan, true) } // Recursively check the unsupported outer references in the sub query plan. -fn check_inner_plan( - inner_plan: &LogicalPlan, - is_scalar: bool, - is_aggregate: bool, - can_contain_outer_ref: bool, -) -> Result<()> { +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] +fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Result<()> { if !can_contain_outer_ref && inner_plan.contains_outer_reference() { return plan_err!("Accessing outer reference columns is not allowed in the plan"); } @@ -147,32 +222,18 @@ fn check_inner_plan( match inner_plan { LogicalPlan::Aggregate(_) => { inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) } - LogicalPlan::Filter(Filter { - predicate, input, .. - }) => { - let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) - .into_iter() - .partition(|e| e.contains_outer()); - let maybe_unsupported = correlated - .into_iter() - .filter(|expr| !can_pullup_over_aggregation(expr)) - .collect::>(); - if is_aggregate && is_scalar && !maybe_unsupported.is_empty() { - return plan_err!( - "Correlated column is not allowed in predicate: {predicate}" - ); - } - check_inner_plan(input, is_scalar, is_aggregate, can_contain_outer_ref) + LogicalPlan::Filter(Filter { input, .. 
}) => { + check_inner_plan(input, can_contain_outer_ref) } LogicalPlan::Window(window) => { check_mixed_out_refer_in_window(window)?; inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) @@ -180,16 +241,16 @@ fn check_inner_plan( LogicalPlan::Projection(_) | LogicalPlan::Distinct(_) | LogicalPlan::Sort(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Union(_) | LogicalPlan::TableScan(_) | LogicalPlan::EmptyRelation(_) | LogicalPlan::Limit(_) | LogicalPlan::Values(_) | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) => { + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Unnest(_) => { inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) @@ -202,27 +263,25 @@ fn check_inner_plan( }) => match join_type { JoinType::Inner => { inner_plan.apply_children(|plan| { - check_inner_plan( - plan, - is_scalar, - is_aggregate, - can_contain_outer_ref, - )?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) } - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { - check_inner_plan(left, is_scalar, is_aggregate, can_contain_outer_ref)?; - check_inner_plan(right, is_scalar, is_aggregate, false) + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark => { + check_inner_plan(left, can_contain_outer_ref)?; + check_inner_plan(right, false) } JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - check_inner_plan(left, is_scalar, is_aggregate, false)?; - check_inner_plan(right, is_scalar, is_aggregate, can_contain_outer_ref) + check_inner_plan(left, false)?; + check_inner_plan(right, can_contain_outer_ref) } JoinType::Full => { inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, is_aggregate, false)?; + check_inner_plan(plan, false)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) @@ -291,34 +350,6 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { Ok(exprs) } -/// Check whether the expression can pull up over the aggregation without change the result of the query -fn can_pullup_over_aggregation(expr: &Expr) -> bool { - if let Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) = expr - { - match (left.deref(), right.deref()) { - (Expr::Column(_), right) => !right.any_column_refs(), - (left, Expr::Column(_)) => !left.any_column_refs(), - (Expr::Cast(Cast { expr, .. }), right) - if matches!(expr.deref(), Expr::Column(_)) => - { - !right.any_column_refs() - } - (left, Expr::Cast(Cast { expr, .. 
})) - if matches!(expr.deref(), Expr::Column(_)) => - { - !left.any_column_refs() - } - (_, _) => false, - } - } else { - false - } -} - /// Check whether the window expressions contain a mixture of out reference columns and inner columns fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> { let mixed = window @@ -339,8 +370,8 @@ mod test { use std::cmp::Ordering; use std::sync::Arc; + use crate::{Extension, UserDefinedLogicalNodeCore}; use datafusion_common::{DFSchema, DFSchemaRef}; - use datafusion_expr::{Extension, UserDefinedLogicalNodeCore}; use super::*; @@ -364,7 +395,7 @@ mod test { vec![] } - fn schema(&self) -> &datafusion_common::DFSchemaRef { + fn schema(&self) -> &DFSchemaRef { &self.empty_schema } @@ -399,6 +430,6 @@ mod test { }), }); - check_inner_plan(&plan, false, false, true).unwrap(); + check_inner_plan(&plan, true).unwrap(); } } diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index a189d4635e001..4049413786636 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -20,6 +20,8 @@ mod ddl; pub mod display; pub mod dml; mod extension; +pub(crate) mod invariants; +pub use invariants::{assert_expected_schema, check_subquery_expr, InvariantLevel}; mod plan; mod statement; pub mod tree_node; @@ -35,15 +37,15 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - projection_schema, Aggregate, Analyze, ColumnUnnestList, ColumnUnnestType, CrossJoin, - DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, - JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, - Projection, RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery, + projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, + DistinctOn, EmptyRelation, Explain, Extension, FetchType, Filter, Join, + JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Projection, + RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ - SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd, - TransactionIsolationLevel, TransactionStart, + Deallocate, Execute, Prepare, SetVariable, Statement, TransactionAccessMode, + TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart, }; pub use display::display_schema; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6cd130ba7703b..100b72a8e55aa 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -21,9 +21,12 @@ use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::sync::{Arc, OnceLock}; +use std::sync::{Arc, LazyLock}; use super::dml::CopyTo; +use super::invariants::{ + assert_always_invariants, assert_executable_invariants, InvariantLevel, +}; use super::DdlStatement; use crate::builder::{change_redundant_column, unnest_with_options}; use crate::expr::{Placeholder, Sort as SortExpr, WindowFunction}; @@ -39,23 +42,26 @@ use crate::utils::{ split_conjunction, }; use crate::{ - build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr, - ExprSchemable, LogicalPlanBuilder, Operator, TableProviderFilterPushDown, - TableSource, WindowFunctionDefinition, + 
build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Execute, + Expr, ExprSchemable, LogicalPlanBuilder, Operator, Prepare, + TableProviderFilterPushDown, TableSource, WindowFunctionDefinition, }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::cse::{NormalizeEq, Normalizeable}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, +}; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, - FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions, + FunctionalDependencies, ParamValues, Result, ScalarValue, TableReference, + UnnestOptions, }; use indexmap::IndexSet; // backwards compatibility use crate::display::PgJsonVisitor; -use crate::tree_node::replace_sort_expressions; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -209,10 +215,14 @@ pub enum LogicalPlan { /// Windows input based on a set of window spec and window /// function (e.g. SUM or RANK). This is used to implement SQL /// window functions, and the `OVER` clause. + /// + /// See [`Window`] for more details Window(Window), /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). This is used to implement SQL aggregates /// and `GROUP BY`. + /// + /// See [`Aggregate`] for more details Aggregate(Aggregate), /// Sorts its input according to a list of sort expressions. This /// is used to implement SQL `ORDER BY` @@ -220,9 +230,6 @@ pub enum LogicalPlan { /// Join two logical plans on one or more join columns. /// This is used to implement SQL `JOIN` Join(Join), - /// Apply Cross Join to two logical plans. - /// This is used to implement SQL `CROSS JOIN` - CrossJoin(CrossJoin), /// Repartitions the input based on a partitioning scheme. This is /// used to add parallelism and is sometimes referred to as an /// "exchange" operator in other systems @@ -265,9 +272,6 @@ pub enum LogicalPlan { /// Remove duplicate rows from the input. This is used to /// implement SQL `SELECT DISTINCT ...`. Distinct(Distinct), - /// Prepare a statement and find any bind parameters - /// (e.g. `?`). This is used to implement SQL-prepared statements. - Prepare(Prepare), /// Data Manipulation Language (DML): Insert / Update / Delete Dml(DmlStatement), /// Data Definition Language (DDL): CREATE / DROP TABLES / VIEWS / SCHEMAS @@ -293,6 +297,22 @@ impl Default for LogicalPlan { } } +impl<'a> TreeNodeContainer<'a, Self> for LogicalPlan { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + f(self) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + f(self) + } +} + impl LogicalPlan { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { @@ -310,13 +330,11 @@ impl LogicalPlan { LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, LogicalPlan::Sort(Sort { input, .. }) => input.schema(), LogicalPlan::Join(Join { schema, .. }) => schema, - LogicalPlan::CrossJoin(CrossJoin { schema, .. }) => schema, LogicalPlan::Repartition(Repartition { input, .. }) => input.schema(), LogicalPlan::Limit(Limit { input, .. }) => input.schema(), LogicalPlan::Statement(statement) => statement.schema(), LogicalPlan::Subquery(Subquery { subquery, .. 
}) => subquery.schema(), LogicalPlan::SubqueryAlias(SubqueryAlias { schema, .. }) => schema, - LogicalPlan::Prepare(Prepare { input, .. }) => input.schema(), LogicalPlan::Explain(explain) => &explain.schema, LogicalPlan::Analyze(analyze) => &analyze.schema, LogicalPlan::Extension(extension) => extension.node.schema(), @@ -343,8 +361,7 @@ impl LogicalPlan { | LogicalPlan::Projection(_) | LogicalPlan::Aggregate(_) | LogicalPlan::Unnest(_) - | LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) => self + | LogicalPlan::Join(_) => self .inputs() .iter() .map(|input| input.schema().as_ref()) @@ -422,27 +439,6 @@ impl LogicalPlan { exprs } - #[deprecated(since = "37.0.0", note = "Use `apply_expressions` instead")] - pub fn inspect_expressions(self: &LogicalPlan, mut f: F) -> Result<(), E> - where - F: FnMut(&Expr) -> Result<(), E>, - { - let mut err = Ok(()); - self.apply_expressions(|e| { - if let Err(e) = f(e) { - // save the error for later (it may not be a DataFusionError - err = Err(e); - Ok(TreeNodeRecursion::Stop) - } else { - Ok(TreeNodeRecursion::Continue) - } - }) - // The closure always returns OK, so this will always too - .expect("no way to return error during recursion"); - - err - } - /// Returns all inputs / children of this `LogicalPlan` node. /// /// Note does not include inputs to inputs, or subqueries. @@ -455,7 +451,6 @@ impl LogicalPlan { LogicalPlan::Aggregate(Aggregate { input, .. }) => vec![input], LogicalPlan::Sort(Sort { input, .. }) => vec![input], LogicalPlan::Join(Join { left, right, .. }) => vec![left, right], - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => vec![left, right], LogicalPlan::Limit(Limit { input, .. }) => vec![input], LogicalPlan::Subquery(Subquery { subquery, .. }) => vec![subquery], LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => vec![input], @@ -472,15 +467,14 @@ impl LogicalPlan { LogicalPlan::Copy(copy) => vec![©.input], LogicalPlan::Ddl(ddl) => ddl.inputs(), LogicalPlan::Unnest(Unnest { input, .. }) => vec![input], - LogicalPlan::Prepare(Prepare { input, .. }) => vec![input], LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, recursive_term, .. }) => vec![static_term, recursive_term], + LogicalPlan::Statement(stmt) => stmt.inputs(), // plans without inputs LogicalPlan::TableScan { .. } - | LogicalPlan::Statement { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Values { .. } | LogicalPlan::DescribeTable(_) => vec![], @@ -558,16 +552,11 @@ impl LogicalPlan { left.head_output_expr() } } - JoinType::LeftSemi | JoinType::LeftAnti => left.head_output_expr(), + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + left.head_output_expr() + } JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(), }, - LogicalPlan::CrossJoin(cross) => { - if cross.left.schema().fields().is_empty() { - cross.right.head_output_expr() - } else { - cross.left.head_output_expr() - } - } LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. 
}) => { static_term.head_output_expr() } @@ -590,7 +579,6 @@ impl LogicalPlan { } LogicalPlan::Subquery(_) => Ok(None), LogicalPlan::EmptyRelation(_) - | LogicalPlan::Prepare(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) | LogicalPlan::Explain(_) @@ -693,20 +681,6 @@ impl LogicalPlan { null_equals_null, })) } - LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema: _, - }) => { - let join_schema = - build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; - - Ok(LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema: join_schema.into(), - })) - } LogicalPlan::Subquery(_) => Ok(self), LogicalPlan::SubqueryAlias(SubqueryAlias { input, @@ -756,7 +730,6 @@ impl LogicalPlan { LogicalPlan::RecursiveQuery(_) => Ok(self), LogicalPlan::Analyze(_) => Ok(self), LogicalPlan::Explain(_) => Ok(self), - LogicalPlan::Prepare(_) => Ok(self), LogicalPlan::TableScan(_) => Ok(self), LogicalPlan::EmptyRelation(_) => Ok(self), LogicalPlan::Statement(_) => Ok(self), @@ -905,7 +878,11 @@ impl LogicalPlan { }) => { let input = self.only_input(inputs)?; Ok(LogicalPlan::Sort(Sort { - expr: replace_sort_expressions(sort_expr.clone(), expr), + expr: expr + .into_iter() + .zip(sort_expr.iter()) + .map(|(expr, sort)| sort.with_expr(expr)) + .collect(), input: Arc::new(input), fetch: *fetch, })) @@ -957,11 +934,6 @@ impl LogicalPlan { null_equals_null: *null_equals_null, })) } - LogicalPlan::CrossJoin(_) => { - self.assert_no_expressions(expr)?; - let (left, right) = self.only_two_inputs(inputs)?; - LogicalPlanBuilder::from(left).cross_join(right)?.build() - } LogicalPlan::Subquery(Subquery { outer_ref_columns, .. }) => { @@ -980,11 +952,20 @@ impl LogicalPlan { .map(LogicalPlan::SubqueryAlias) } LogicalPlan::Limit(Limit { skip, fetch, .. }) => { - self.assert_no_expressions(expr)?; + let old_expr_len = skip.iter().chain(fetch.iter()).count(); + if old_expr_len != expr.len() { + return internal_err!( + "Invalid number of new Limit expressions: expected {}, got {}", + old_expr_len, + expr.len() + ); + } + let new_skip = skip.as_ref().and_then(|_| expr.pop()); + let new_fetch = fetch.as_ref().and_then(|_| expr.pop()); let input = self.only_input(inputs)?; Ok(LogicalPlan::Limit(Limit { - skip: *skip, - fetch: *fetch, + skip: new_skip.map(Box::new), + fetch: new_fetch.map(Box::new), input: Arc::new(input), })) } @@ -993,6 +974,7 @@ impl LogicalPlan { if_not_exists, or_replace, column_defaults, + temporary, .. })) => { self.assert_no_expressions(expr)?; @@ -1005,6 +987,7 @@ impl LogicalPlan { if_not_exists: *if_not_exists, or_replace: *or_replace, column_defaults: column_defaults.clone(), + temporary: *temporary, }, ))) } @@ -1012,6 +995,7 @@ impl LogicalPlan { name, or_replace, definition, + temporary, .. })) => { self.assert_no_expressions(expr)?; @@ -1020,6 +1004,7 @@ impl LogicalPlan { input: Arc::new(input), name: name.clone(), or_replace: *or_replace, + temporary: *temporary, definition: definition.clone(), }))) } @@ -1098,16 +1083,25 @@ impl LogicalPlan { logical_optimization_succeeded: e.logical_optimization_succeeded, })) } - LogicalPlan::Prepare(Prepare { - name, data_types, .. - }) => { + LogicalPlan::Statement(Statement::Prepare(Prepare { + name, + data_types, + .. 
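A small migration sketch (hypothetical, not part of this patch): code that previously matched the removed `LogicalPlan::Prepare` variant now reaches the same data through the `Statement` wrapper.

use datafusion_expr::logical_plan::{LogicalPlan, Prepare, Statement};

/// Returns the name of a prepared statement, if this plan node is one.
fn prepared_statement_name(plan: &LogicalPlan) -> Option<&str> {
    match plan {
        LogicalPlan::Statement(Statement::Prepare(Prepare { name, .. })) => {
            Some(name.as_str())
        }
        _ => None,
    }
}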
+ })) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Prepare(Prepare { + Ok(LogicalPlan::Statement(Statement::Prepare(Prepare { name: name.clone(), data_types: data_types.clone(), input: Arc::new(input), - })) + }))) + } + LogicalPlan::Statement(Statement::Execute(Execute { name, .. })) => { + self.assert_no_inputs(inputs)?; + Ok(LogicalPlan::Statement(Statement::Execute(Execute { + name: name.clone(), + parameters: expr, + }))) } LogicalPlan::TableScan(ts) => { self.assert_no_inputs(inputs)?; @@ -1140,6 +1134,14 @@ impl LogicalPlan { } } + /// checks that the plan conforms to the listed invariant level, returning an Error if not + pub fn check_invariants(&self, check: InvariantLevel) -> Result<()> { + match check { + InvariantLevel::Always => assert_always_invariants(self), + InvariantLevel::Executable => assert_executable_invariants(self), + } + } + /// Helper for [Self::with_new_exprs] to use when no expressions are expected. #[inline] #[allow(clippy::needless_pass_by_value)] // expr is moved intentionally to ensure it's not used again @@ -1204,8 +1206,8 @@ impl LogicalPlan { /// Replaces placeholder param values (like `$1`, `$2`) in [`LogicalPlan`] /// with the specified `param_values`. /// - /// [`LogicalPlan::Prepare`] are - /// converted to their inner logical plan for execution. + /// [`Prepare`] statements are converted to + /// their inner logical plan for execution. /// /// # Example /// ``` @@ -1215,7 +1217,7 @@ impl LogicalPlan { /// # let schema = Schema::new(vec![ /// # Field::new("id", DataType::Int32, false), /// # ]); - /// // Build SELECT * FROM t1 WHRERE id = $1 + /// // Build SELECT * FROM t1 WHERE id = $1 /// let plan = table_scan(Some("t1"), &schema, None).unwrap() /// .filter(col("id").eq(placeholder("$1"))).unwrap() /// .build().unwrap(); @@ -1238,7 +1240,7 @@ impl LogicalPlan { /// ); /// /// // Note you can also used named parameters - /// // Build SELECT * FROM t1 WHRERE id = $my_param + /// // Build SELECT * FROM t1 WHERE id = $my_param /// let plan = table_scan(Some("t1"), &schema, None).unwrap() /// .filter(col("id").eq(placeholder("$my_param"))).unwrap() /// .build().unwrap() @@ -1262,13 +1264,17 @@ impl LogicalPlan { let plan_with_values = self.replace_params_with_values(¶m_values)?; // unwrap Prepare - Ok(if let LogicalPlan::Prepare(prepare_lp) = plan_with_values { - param_values.verify(&prepare_lp.data_types)?; - // try and take ownership of the input if is not shared, clone otherwise - Arc::unwrap_or_clone(prepare_lp.input) - } else { - plan_with_values - }) + Ok( + if let LogicalPlan::Statement(Statement::Prepare(prepare_lp)) = + plan_with_values + { + param_values.verify(&prepare_lp.data_types)?; + // try and take ownership of the input if is not shared, clone otherwise + Arc::unwrap_or_clone(prepare_lp.input) + } else { + plan_with_values + }, + ) } /// Returns the maximum number of rows that this plan can output, if known. @@ -1315,47 +1321,36 @@ impl LogicalPlan { join_type, .. }) => match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - match (left.max_rows(), right.max_rows()) { - (Some(left_max), Some(right_max)) => { - let min_rows = match join_type { - JoinType::Left => left_max, - JoinType::Right => right_max, - JoinType::Full => left_max + right_max, - _ => 0, - }; - Some((left_max * right_max).max(min_rows)) - } - _ => None, + JoinType::Inner => Some(left.max_rows()? 
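A minimal sketch of the new invariant checks (hypothetical usage, not part of this patch; import paths are assumed from the `pub use` changes in `logical_plan/mod.rs` earlier in this diff):

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::logical_plan::{table_scan, InvariantLevel};
use datafusion_expr::{col, lit};

fn main() -> Result<()> {
    let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
    let plan = table_scan(Some("t1"), &schema, None)?
        .filter(col("id").eq(lit(1)))?
        .build()?;

    // Cheap structural checks that must hold for every LogicalPlan
    plan.check_invariants(InvariantLevel::Always)?;
    // Stricter checks for plans that are about to be executed
    plan.check_invariants(InvariantLevel::Executable)?;
    Ok(())
}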
* right.max_rows()?), + JoinType::Left | JoinType::Right | JoinType::Full => { + match (left.max_rows()?, right.max_rows()?, join_type) { + (0, 0, _) => Some(0), + (max_rows, 0, JoinType::Left | JoinType::Full) => Some(max_rows), + (0, max_rows, JoinType::Right | JoinType::Full) => Some(max_rows), + (left_max, right_max, _) => Some(left_max * right_max), } } - JoinType::LeftSemi | JoinType::LeftAnti => left.max_rows(), + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + left.max_rows() + } JoinType::RightSemi | JoinType::RightAnti => right.max_rows(), }, - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - match (left.max_rows(), right.max_rows()) { - (Some(left_max), Some(right_max)) => Some(left_max * right_max), - _ => None, - } - } LogicalPlan::Repartition(Repartition { input, .. }) => input.max_rows(), - LogicalPlan::Union(Union { inputs, .. }) => inputs - .iter() - .map(|plan| plan.max_rows()) - .try_fold(0usize, |mut acc, input_max| { - if let Some(i_max) = input_max { - acc += i_max; - Some(acc) - } else { - None - } - }), + LogicalPlan::Union(Union { inputs, .. }) => { + inputs.iter().try_fold(0usize, |mut acc, plan| { + acc += plan.max_rows()?; + Some(acc) + }) + } LogicalPlan::TableScan(TableScan { fetch, .. }) => *fetch, LogicalPlan::EmptyRelation(_) => Some(0), LogicalPlan::RecursiveQuery(_) => None, LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), - LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, + LogicalPlan::Limit(limit) => match limit.get_fetch_type() { + Ok(FetchType::Literal(s)) => s, + _ => None, + }, LogicalPlan::Distinct( Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), ) => input.max_rows(), @@ -1367,7 +1362,6 @@ impl LogicalPlan { | LogicalPlan::Dml(_) | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Prepare(_) | LogicalPlan::Statement(_) | LogicalPlan::Extension(_) => None, } @@ -1443,9 +1437,15 @@ impl LogicalPlan { let schema = Arc::clone(plan.schema()); let name_preserver = NamePreserver::new(&plan); plan.map_expressions(|e| { - let original_name = name_preserver.save(&e); - let transformed_expr = - e.infer_placeholder_types(&schema)?.transform_up(|e| { + let (e, has_placeholder) = e.infer_placeholder_types(&schema)?; + if !has_placeholder { + // Performance optimization: + // avoid NamePreserver copy and second pass over expression + // if no placeholders. + Ok(Transformed::no(e)) + } else { + let original_name = name_preserver.save(&e); + let transformed_expr = e.transform_up(|e| { if let Expr::Placeholder(Placeholder { id, .. }) = e { let value = param_values.get_placeholders_with_values(&id)?; Ok(Transformed::yes(Expr::from(value))) @@ -1453,13 +1453,30 @@ impl LogicalPlan { Ok(Transformed::no(e)) } })?; - // Preserve name to avoid breaking column references to this expression - Ok(transformed_expr.update_data(|expr| original_name.restore(expr))) + // Preserve name to avoid breaking column references to this expression + Ok(transformed_expr.update_data(|expr| original_name.restore(expr))) + } }) }) .map(|res| res.data) } + /// Walk the logical plan, find any `Placeholder` tokens, and return a set of their names. + pub fn get_parameter_names(&self) -> Result> { + let mut param_names = HashSet::new(); + self.apply_with_subqueries(|plan| { + plan.apply_expressions(|expr| { + expr.apply(|expr| { + if let Expr::Placeholder(Placeholder { id, .. 
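A short sketch of the reworked `max_rows` (hypothetical usage, not part of this patch; it assumes the builder's `limit(skip, fetch)` helper still produces literal skip/fetch expressions):

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::logical_plan::table_scan;

fn main() -> Result<()> {
    let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);

    // A bare table scan has no known row bound
    let scan = table_scan(Some("t"), &schema, None)?.build()?;
    assert_eq!(scan.max_rows(), None);

    // A literal LIMIT 10 gives one, resolved through FetchType::Literal
    let limited = table_scan(Some("t"), &schema, None)?
        .limit(0, Some(10))?
        .build()?;
    assert_eq!(limited.max_rows(), Some(10));
    Ok(())
}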
}) = expr { + param_names.insert(id.clone()); + } + Ok(TreeNodeRecursion::Continue) + }) + }) + }) + .map(|_| param_names) + } + /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes pub fn get_parameter_types( &self, @@ -1525,7 +1542,7 @@ impl LogicalPlan { // Boilerplate structure to wrap LogicalPlan with something // that that can be formatted struct Wrapper<'a>(&'a LogicalPlan); - impl<'a> Display for Wrapper<'a> { + impl Display for Wrapper<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let with_schema = false; let mut visitor = IndentVisitor::new(f, with_schema); @@ -1568,7 +1585,7 @@ impl LogicalPlan { // Boilerplate structure to wrap LogicalPlan with something // that that can be formatted struct Wrapper<'a>(&'a LogicalPlan); - impl<'a> Display for Wrapper<'a> { + impl Display for Wrapper<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let with_schema = true; let mut visitor = IndentVisitor::new(f, with_schema); @@ -1588,7 +1605,7 @@ impl LogicalPlan { // Boilerplate structure to wrap LogicalPlan with something // that that can be formatted struct Wrapper<'a>(&'a LogicalPlan); - impl<'a> Display for Wrapper<'a> { + impl Display for Wrapper<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = PgJsonVisitor::new(f); visitor.with_schema(true); @@ -1634,7 +1651,7 @@ impl LogicalPlan { // Boilerplate structure to wrap LogicalPlan with something // that that can be formatted struct Wrapper<'a>(&'a LogicalPlan); - impl<'a> Display for Wrapper<'a> { + impl Display for Wrapper<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = GraphvizVisitor::new(f); @@ -1685,7 +1702,7 @@ impl LogicalPlan { // Boilerplate structure to wrap LogicalPlan with something // that that can be formatted struct Wrapper<'a>(&'a LogicalPlan); - impl<'a> Display for Wrapper<'a> { + impl Display for Wrapper<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self.0 { LogicalPlan::EmptyRelation(_) => write!(f, "EmptyRelation"), @@ -1869,6 +1886,11 @@ impl LogicalPlan { .as_ref() .map(|expr| format!(" Filter: {expr}")) .unwrap_or_else(|| "".to_string()); + let join_type = if filter.is_none() && keys.is_empty() && matches!(join_type, JoinType::Inner) { + "Cross".to_string() + } else { + join_type.to_string() + }; match join_constraint { JoinConstraint::On => { write!( @@ -1890,9 +1912,6 @@ impl LogicalPlan { } } } - LogicalPlan::CrossJoin(_) => { - write!(f, "CrossJoin:") - } LogicalPlan::Repartition(Repartition { partitioning_scheme, .. @@ -1920,16 +1939,20 @@ impl LogicalPlan { ) } }, - LogicalPlan::Limit(Limit { - ref skip, - ref fetch, - .. - }) => { + LogicalPlan::Limit(limit) => { + // Attempt to display `skip` and `fetch` as literals if possible, otherwise as expressions. + let skip_str = match limit.get_skip_type() { + Ok(SkipType::Literal(n)) => n.to_string(), + _ => limit.skip.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()), + }; + let fetch_str = match limit.get_fetch_type() { + Ok(FetchType::Literal(Some(n))) => n.to_string(), + Ok(FetchType::Literal(None)) => "None".to_string(), + _ => limit.fetch.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()) + }; write!( f, - "Limit: skip={}, fetch={}", - skip, - fetch.map_or_else(|| "None".to_string(), |x| x.to_string()) + "Limit: skip={}, fetch={}", skip_str,fetch_str, ) } LogicalPlan::Subquery(Subquery { .. }) => { @@ -1960,11 +1983,6 @@ impl LogicalPlan { LogicalPlan::Analyze { .. 
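A minimal sketch of the new `get_parameter_names` helper (hypothetical usage, not part of this patch), mirroring the placeholder example in the doc comments above:

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{col, placeholder};

fn main() -> Result<()> {
    // Build SELECT * FROM t1 WHERE id = $1
    let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
    let plan = table_scan(Some("t1"), &schema, None)?
        .filter(col("id").eq(placeholder("$1")))?
        .build()?;

    // Collects placeholder names from the plan and any subqueries
    let names = plan.get_parameter_names()?;
    assert!(names.contains("$1"));
    Ok(())
}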
} => write!(f, "Analyze"), LogicalPlan::Union(_) => write!(f, "Union"), LogicalPlan::Extension(e) => e.node.fmt_for_explain(f), - LogicalPlan::Prepare(Prepare { - name, data_types, .. - }) => { - write!(f, "Prepare: {name:?} {data_types:?} ") - } LogicalPlan::DescribeTable(DescribeTable { .. }) => { write!(f, "DescribeTable") } @@ -2351,6 +2369,19 @@ impl Filter { } /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) +/// +/// # Output Schema +/// +/// The output schema is the input schema followed by the window function +/// expressions, in order. +/// +/// For example, given the input schema `"A", "B", "C"` and the window function +/// `SUM(A) OVER (PARTITION BY B+1 ORDER BY C)`, the output schema will be `"A", +/// "B", "C", "SUM(A) OVER ..."` where `"SUM(A) OVER ..."` is the name of the +/// output column. +/// +/// Note that the `PARTITION BY` expression "B+1" is not produced in the output +/// schema. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Window { /// The incoming logical plan @@ -2594,28 +2625,7 @@ impl TableScan { } } -/// Apply Cross Join to two logical plans -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct CrossJoin { - /// Left input - pub left: Arc, - /// Right input - pub right: Arc, - /// The output schema, containing fields from the left and right inputs - pub schema: DFSchemaRef, -} - -// Manual implementation needed because of `schema` field. Comparison excludes this field. -impl PartialOrd for CrossJoin { - fn partial_cmp(&self, other: &Self) -> Option { - match self.left.partial_cmp(&other.left) { - Some(Ordering::Equal) => self.right.partial_cmp(&other.right), - cmp => cmp, - } - } -} - -/// Repartition the plan based on a partitioning scheme. +// Repartition the plan based on a partitioning scheme. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Repartition { /// The incoming logical plan @@ -2640,18 +2650,6 @@ impl PartialOrd for Union { } } -/// Prepare a statement but do not execute it. Prepare statements can have 0 or more -/// `Expr::Placeholder` expressions that are filled in during execution -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] -pub struct Prepare { - /// The name of the statement - pub name: String, - /// Data types of the parameters ([`Expr::Placeholder`]) - pub data_types: Vec, - /// The logical plan of the statements - pub input: Arc, -} - /// Describe the schema of table /// /// # Example output: @@ -2789,14 +2787,77 @@ impl PartialOrd for Extension { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Limit { /// Number of rows to skip before fetch - pub skip: usize, + pub skip: Option>, /// Maximum number of rows to fetch, /// None means fetching all rows - pub fetch: Option, + pub fetch: Option>, /// The logical plan pub input: Arc, } +/// Different types of skip expression in Limit plan. +pub enum SkipType { + /// The skip expression is a literal value. + Literal(usize), + /// Currently only supports expressions that can be folded into constants. + UnsupportedExpr, +} + +/// Different types of fetch expression in Limit plan. +pub enum FetchType { + /// The fetch expression is a literal value. + /// `Literal(None)` means the fetch expression is not provided. + Literal(Option), + /// Currently only supports expressions that can be folded into constants. + UnsupportedExpr, +} + +impl Limit { + /// Get the skip type from the limit plan. 
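A small sketch of how callers are expected to consume the new `SkipType`/`FetchType` accessors (hypothetical helper, not part of this patch):

use datafusion_common::Result;
use datafusion_expr::logical_plan::{FetchType, Limit, SkipType};

/// Resolve a Limit node to plain (skip, fetch) numbers when both bounds are
/// literal; returns None while either bound is still an arbitrary expression.
fn literal_bounds(limit: &Limit) -> Result<Option<(usize, Option<usize>)>> {
    match (limit.get_skip_type()?, limit.get_fetch_type()?) {
        (SkipType::Literal(skip), FetchType::Literal(fetch)) => Ok(Some((skip, fetch))),
        _ => Ok(None),
    }
}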
+ pub fn get_skip_type(&self) -> Result { + match self.skip.as_deref() { + Some(expr) => match expr { + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Int64(s) => { + // `skip = NULL` is equivalent to `skip = 0` + let s = s.unwrap_or(0); + if s >= 0 { + Ok(SkipType::Literal(s as usize)) + } else { + plan_err!("OFFSET must be >=0, '{}' was provided", s) + } + } + _ => Ok(SkipType::UnsupportedExpr), + }, + _ => Ok(SkipType::UnsupportedExpr), + }, + // `skip = None` is equivalent to `skip = 0` + None => Ok(SkipType::Literal(0)), + } + } + + /// Get the fetch type from the limit plan. + pub fn get_fetch_type(&self) -> Result { + match self.fetch.as_deref() { + Some(expr) => match expr { + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Int64(Some(s)) => { + if *s >= 0 { + Ok(FetchType::Literal(Some(*s as usize))) + } else { + plan_err!("LIMIT must be >= 0, '{}' was provided", s) + } + } + ScalarValue::Int64(None) => Ok(FetchType::Literal(None)), + _ => Ok(FetchType::UnsupportedExpr), + }, + _ => Ok(FetchType::UnsupportedExpr), + }, + None => Ok(FetchType::Literal(None)), + } + } +} + /// Removes duplicate rows from the input #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum Distinct { @@ -2930,6 +2991,16 @@ impl PartialOrd for DistinctOn { /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). +/// +/// # Output Schema +/// +/// The output schema is the group expressions followed by the aggregate +/// expressions in order. +/// +/// For example, given the input schema `"A", "B", "C"` and the aggregate +/// `SUM(A) GROUP BY C+B`, the output schema will be `"C+B", "SUM(A)"` where +/// "C+B" and "SUM(A)" are the names of the output columns. Note that "C+B" is a +/// single new column #[derive(Debug, Clone, PartialEq, Eq, Hash)] // mark non_exhaustive to encourage use of try_new/new() #[non_exhaustive] @@ -3031,12 +3102,12 @@ impl Aggregate { /// Get the output expressions. fn output_expressions(&self) -> Result> { - static INTERNAL_ID_EXPR: OnceLock = OnceLock::new(); + static INTERNAL_ID_EXPR: LazyLock = LazyLock::new(|| { + Expr::Column(Column::from_name(Aggregate::INTERNAL_GROUPING_ID)) + }); let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?; if self.is_grouping_set() { - exprs.push(INTERNAL_ID_EXPR.get_or_init(|| { - Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID)) - })); + exprs.push(&INTERNAL_ID_EXPR); } exprs.extend(self.aggr_expr.iter()); debug_assert!(exprs.len() == self.schema.fields().len()); @@ -3318,6 +3389,25 @@ pub struct Subquery { pub outer_ref_columns: Vec, } +impl Normalizeable for Subquery { + fn can_normalize(&self) -> bool { + false + } +} + +impl NormalizeEq for Subquery { + fn normalize_eq(&self, other: &Self) -> bool { + // TODO: may be implement NormalizeEq for LogicalPlan? + *self.subquery == *other.subquery + && self.outer_ref_columns.len() == other.outer_ref_columns.len() + && self + .outer_ref_columns + .iter() + .zip(other.outer_ref_columns.iter()) + .all(|(a, b)| a.normalize_eq(b)) + } +} + impl Subquery { pub fn try_from_expr(plan: &Expr) -> Result<&Subquery> { match plan { @@ -3357,39 +3447,6 @@ pub enum Partitioning { DistributeBy(Vec), } -/// Represents the unnesting operation on a column based on the context (a known struct -/// column, a list column, or let the planner infer the unnesting type). -/// -/// The inferred unnesting type works for both struct and list column, but the unnesting -/// will only be done once (depth = 1). 
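To make the documented Aggregate output schema concrete, a minimal sketch (hypothetical usage, not part of this patch); the output column name mentioned in the comment is indicative only:

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{col, Expr};

fn main() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
        Field::new("c", DataType::Int32, false),
    ]);
    // GROUP BY c + b with no aggregate expressions
    let plan = table_scan(Some("t"), &schema, None)?
        .aggregate(vec![col("c") + col("b")], Vec::<Expr>::new())?
        .build()?;

    // One output column per group expression (e.g. "t.c + t.b"),
    // then one per aggregate expression (none here)
    assert_eq!(plan.schema().fields().len(), 1);
    Ok(())
}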
In case recursion is needed on a multi-dimensional -/// list type, use [`ColumnUnnestList`] -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] -pub enum ColumnUnnestType { - // Unnesting a list column, a vector of ColumnUnnestList is used because - // a column can be unnested at different levels, resulting different output columns - List(Vec), - // for struct, there can only be one unnest performed on one column at a time - Struct, - // Infer the unnest type based on column schema - // If column is a list column, the unnest depth will be 1 - // This value is to support sugar syntax of old api in Dataframe (unnest(either_list_or_struct_column)) - Inferred, -} - -impl fmt::Display for ColumnUnnestType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ColumnUnnestType::List(lists) => { - let list_strs: Vec = - lists.iter().map(|list| list.to_string()).collect(); - write!(f, "List([{}])", list_strs.join(", ")) - } - ColumnUnnestType::Struct => write!(f, "Struct"), - ColumnUnnestType::Inferred => write!(f, "Inferred"), - } - } -} - /// Represent the unnesting operation on a list column, such as the recursion depth and /// the output column name after unnesting /// @@ -3399,15 +3456,15 @@ impl fmt::Display for ColumnUnnestType { /// input output_name /// ┌─────────┐ ┌─────────┐ /// │{{1,2}} │ │ 1 │ -/// ├─────────┼─────►├─────────┤ -/// │{{3}} │ │ 2 │ -/// ├─────────┤ ├─────────┤ -/// │{{4},{5}}│ │ 3 │ -/// └─────────┘ ├─────────┤ -/// │ 4 │ -/// ├─────────┤ -/// │ 5 │ -/// └─────────┘ +/// ├─────────┼─────►├─────────┤ +/// │{{3}} │ │ 2 │ +/// ├─────────┤ ├─────────┤ +/// │{{4},{5}}│ │ 3 │ +/// └─────────┘ ├─────────┤ +/// │ 4 │ +/// ├─────────┤ +/// │ 5 │ +/// └─────────┘ /// ``` #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] pub struct ColumnUnnestList { @@ -3415,8 +3472,8 @@ pub struct ColumnUnnestList { pub depth: usize, } -impl fmt::Display for ColumnUnnestList { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl Display for ColumnUnnestList { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}|depth={}", self.output_column, self.depth) } } @@ -3428,7 +3485,7 @@ pub struct Unnest { /// The incoming logical plan pub input: Arc, /// Columns to run unnest on, can be a list of (List/Struct) columns - pub exec_columns: Vec<(Column, ColumnUnnestType)>, + pub exec_columns: Vec, /// refer to the indices(in the input schema) of columns /// that have type list to run unnest on pub list_type_columns: Vec<(usize, ColumnUnnestList)>, @@ -3452,7 +3509,7 @@ impl PartialOrd for Unnest { /// The incoming logical plan pub input: &'a Arc, /// Columns to run unnest on, can be a list of (List/Struct) columns - pub exec_columns: &'a Vec<(Column, ColumnUnnestType)>, + pub exec_columns: &'a Vec, /// refer to the indices(in the input schema) of columns /// that have type list to run unnest on pub list_type_columns: &'a Vec<(usize, ColumnUnnestList)>, @@ -3491,9 +3548,13 @@ mod tests { use super::*; use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; - use crate::{col, exists, in_subquery, lit, placeholder, GroupingSet}; + use crate::{ + col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet, + }; - use datafusion_common::tree_node::{TransformedResult, TreeNodeVisitor}; + use datafusion_common::tree_node::{ + TransformedResult, TreeNodeRewriter, TreeNodeVisitor, + }; use datafusion_common::{not_impl_err, Constraint, ScalarValue}; use crate::test::function_stub::count; @@ -3606,7 +3667,7 @@ 
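A tiny sketch of the `ColumnUnnestList` display format shown above (hypothetical usage, not part of this patch):

use datafusion_common::Column;
use datafusion_expr::logical_plan::ColumnUnnestList;

fn main() {
    let unnest = ColumnUnnestList {
        output_column: Column::from_name("c"),
        depth: 2,
    };
    // Rendered as "<output column>|depth=<n>"
    assert_eq!(unnest.to_string(), "c|depth=2");
}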
digraph { "#; // just test for a few key lines in the output rather than the - // whole thing to make test mainteance easier. + // whole thing to make test maintenance easier. let graphviz = format!("{}", plan.display_graphviz()); assert_eq!(expected_graphviz, graphviz); @@ -4133,4 +4194,163 @@ digraph { ); assert_eq!(describe_table.partial_cmp(&describe_table_clone), None); } + + #[test] + fn test_limit_with_new_children() { + let limit = LogicalPlan::Limit(Limit { + skip: None, + fetch: Some(Box::new(Expr::from( + ScalarValue::new_ten(&DataType::UInt32).unwrap(), + ))), + input: Arc::new(LogicalPlan::Values(Values { + schema: Arc::new(DFSchema::empty()), + values: vec![vec![]], + })), + }); + let new_limit = limit + .with_new_exprs( + limit.expressions(), + limit.inputs().into_iter().cloned().collect(), + ) + .unwrap(); + assert_eq!(limit, new_limit); + } + + #[test] + fn test_with_subqueries_jump() { + // The test plan contains a `Project` node above a `Filter` node, and the + // `Project` node contains a subquery plan with a `Filter` root node, so returning + // `TreeNodeRecursion::Jump` on `Project` should cause not visiting any of the + // `Filter`s. + let subquery_schema = + Schema::new(vec![Field::new("sub_id", DataType::Int32, false)]); + + let subquery_plan = + table_scan(TableReference::none(), &subquery_schema, Some(vec![0])) + .unwrap() + .filter(col("sub_id").eq(lit(0))) + .unwrap() + .build() + .unwrap(); + + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + + let plan = table_scan(TableReference::none(), &schema, Some(vec![0])) + .unwrap() + .filter(col("id").eq(lit(0))) + .unwrap() + .project(vec![col("id"), scalar_subquery(Arc::new(subquery_plan))]) + .unwrap() + .build() + .unwrap(); + + let mut filter_found = false; + plan.apply_with_subqueries(|plan| { + match plan { + LogicalPlan::Projection(..) => return Ok(TreeNodeRecursion::Jump), + LogicalPlan::Filter(..) => filter_found = true, + _ => {} + } + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + assert!(!filter_found); + + struct ProjectJumpVisitor { + filter_found: bool, + } + + impl ProjectJumpVisitor { + fn new() -> Self { + Self { + filter_found: false, + } + } + } + + impl<'n> TreeNodeVisitor<'n> for ProjectJumpVisitor { + type Node = LogicalPlan; + + fn f_down(&mut self, node: &'n Self::Node) -> Result { + match node { + LogicalPlan::Projection(..) => return Ok(TreeNodeRecursion::Jump), + LogicalPlan::Filter(..) => self.filter_found = true, + _ => {} + } + Ok(TreeNodeRecursion::Continue) + } + } + + let mut visitor = ProjectJumpVisitor::new(); + plan.visit_with_subqueries(&mut visitor).unwrap(); + assert!(!visitor.filter_found); + + let mut filter_found = false; + plan.clone() + .transform_down_with_subqueries(|plan| { + match plan { + LogicalPlan::Projection(..) => { + return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) + } + LogicalPlan::Filter(..) => filter_found = true, + _ => {} + } + Ok(Transformed::no(plan)) + }) + .unwrap(); + assert!(!filter_found); + + let mut filter_found = false; + plan.clone() + .transform_down_up_with_subqueries( + |plan| { + match plan { + LogicalPlan::Projection(..) => { + return Ok(Transformed::new( + plan, + false, + TreeNodeRecursion::Jump, + )) + } + LogicalPlan::Filter(..) 
=> filter_found = true, + _ => {} + } + Ok(Transformed::no(plan)) + }, + |plan| Ok(Transformed::no(plan)), + ) + .unwrap(); + assert!(!filter_found); + + struct ProjectJumpRewriter { + filter_found: bool, + } + + impl ProjectJumpRewriter { + fn new() -> Self { + Self { + filter_found: false, + } + } + } + + impl TreeNodeRewriter for ProjectJumpRewriter { + type Node = LogicalPlan; + + fn f_down(&mut self, node: Self::Node) -> Result> { + match node { + LogicalPlan::Projection(..) => { + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)) + } + LogicalPlan::Filter(..) => self.filter_found = true, + _ => {} + } + Ok(Transformed::no(node)) + } + } + + let mut rewriter = ProjectJumpRewriter::new(); + plan.rewrite_with_subqueries(&mut rewriter).unwrap(); + assert!(!rewriter.filter_found); + } } diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index ed06375157c94..93be04c275647 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -15,17 +15,20 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::DFSchemaRef; -use std::cmp::Ordering; +use arrow::datatypes::DataType; +use datafusion_common::{DFSchema, DFSchemaRef}; use std::fmt::{self, Display}; +use std::sync::{Arc, LazyLock}; + +use crate::{expr_vec_fmt, Expr, LogicalPlan}; /// Various types of Statements. /// /// # Transactions: /// /// While DataFusion does not offer support transactions, it provides -/// [`LogicalPlan`](crate::LogicalPlan) support to assist building -/// database systems using DataFusion +/// [`LogicalPlan`] support to assist building database systems +/// using DataFusion #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum Statement { // Begin a transaction @@ -34,16 +37,24 @@ pub enum Statement { TransactionEnd(TransactionEnd), /// Set a Variable SetVariable(SetVariable), + /// Prepare a statement and find any bind parameters + /// (e.g. `?`). This is used to implement SQL-prepared statements. + Prepare(Prepare), + /// Execute a prepared statement. This is used to implement SQL 'EXECUTE'. + Execute(Execute), + /// Deallocate a prepared statement. + /// This is used to implement SQL 'DEALLOCATE'. + Deallocate(Deallocate), } impl Statement { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { - match self { - Statement::TransactionStart(TransactionStart { schema, .. }) => schema, - Statement::TransactionEnd(TransactionEnd { schema, .. }) => schema, - Statement::SetVariable(SetVariable { schema, .. }) => schema, - } + // Statements have an unchanging empty schema. + static STATEMENT_EMPTY_SCHEMA: LazyLock = + LazyLock::new(|| Arc::new(DFSchema::empty())); + + &STATEMENT_EMPTY_SCHEMA } /// Return a descriptive string describing the type of this @@ -53,6 +64,17 @@ impl Statement { Statement::TransactionStart(_) => "TransactionStart", Statement::TransactionEnd(_) => "TransactionEnd", Statement::SetVariable(_) => "SetVariable", + Statement::Prepare(_) => "Prepare", + Statement::Execute(_) => "Execute", + Statement::Deallocate(_) => "Deallocate", + } + } + + /// Returns input LogicalPlans in the current `Statement`. + pub(super) fn inputs(&self) -> Vec<&LogicalPlan> { + match self { + Statement::Prepare(Prepare { input, .. }) => vec![input.as_ref()], + _ => vec![], } } @@ -61,9 +83,9 @@ impl Statement { /// children. 
/// /// See [crate::LogicalPlan::display] for an example - pub fn display(&self) -> impl fmt::Display + '_ { + pub fn display(&self) -> impl Display + '_ { struct Wrapper<'a>(&'a Statement); - impl<'a> Display for Wrapper<'a> { + impl Display for Wrapper<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.0 { Statement::TransactionStart(TransactionStart { @@ -85,6 +107,24 @@ impl Statement { }) => { write!(f, "SetVariable: set {variable:?} to {value:?}") } + Statement::Prepare(Prepare { + name, data_types, .. + }) => { + write!(f, "Prepare: {name:?} {data_types:?} ") + } + Statement::Execute(Execute { + name, parameters, .. + }) => { + write!( + f, + "Execute: {} params=[{}]", + name, + expr_vec_fmt!(parameters) + ) + } + Statement::Deallocate(Deallocate { name }) => { + write!(f, "Deallocate: {}", name) + } } } } @@ -116,67 +156,57 @@ pub enum TransactionIsolationLevel { } /// Indicator that the following statements should be committed or rolled back atomically -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] pub struct TransactionStart { /// indicates if transaction is allowed to write pub access_mode: TransactionAccessMode, // indicates ANSI isolation level pub isolation_level: TransactionIsolationLevel, - /// Empty schema - pub schema: DFSchemaRef, -} - -// Manual implementation needed because of `schema` field. Comparison excludes this field. -impl PartialOrd for TransactionStart { - fn partial_cmp(&self, other: &Self) -> Option { - match self.access_mode.partial_cmp(&other.access_mode) { - Some(Ordering::Equal) => { - self.isolation_level.partial_cmp(&other.isolation_level) - } - cmp => cmp, - } - } } /// Indicator that any current transaction should be terminated -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] pub struct TransactionEnd { /// whether the transaction committed or aborted pub conclusion: TransactionConclusion, /// if specified a new transaction is immediately started with same characteristics pub chain: bool, - /// Empty schema - pub schema: DFSchemaRef, -} - -// Manual implementation needed because of `schema` field. Comparison excludes this field. -impl PartialOrd for TransactionEnd { - fn partial_cmp(&self, other: &Self) -> Option { - match self.conclusion.partial_cmp(&other.conclusion) { - Some(Ordering::Equal) => self.chain.partial_cmp(&other.chain), - cmp => cmp, - } - } } /// Set a Variable's value -- value in /// [`ConfigOptions`](datafusion_common::config::ConfigOptions) -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] pub struct SetVariable { /// The variable name pub variable: String, /// The value to set pub value: String, - /// Dummy schema - pub schema: DFSchemaRef, } -// Manual implementation needed because of `schema` field. Comparison excludes this field. -impl PartialOrd for SetVariable { - fn partial_cmp(&self, other: &Self) -> Option { - match self.variable.partial_cmp(&other.value) { - Some(Ordering::Equal) => self.value.partial_cmp(&other.value), - cmp => cmp, - } - } +/// Prepare a statement but do not execute it. 
Prepare statements can have 0 or more +/// `Expr::Placeholder` expressions that are filled in during execution +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub struct Prepare { + /// The name of the statement + pub name: String, + /// Data types of the parameters ([`Expr::Placeholder`]) + pub data_types: Vec, + /// The logical plan of the statements + pub input: Arc, +} + +/// Execute a prepared statement. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] +pub struct Execute { + /// The name of the prepared statement to execute + pub name: String, + /// The execute parameters + pub parameters: Vec, +} + +/// Deallocate a prepared statement. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] +pub struct Deallocate { + /// The name of the prepared statement to deallocate + pub name: String, } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 8ba68697bd4d7..9a6103afd4b41 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -36,31 +36,29 @@ //! (Re)creation APIs (these require substantial cloning and thus are slow): //! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions //! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions + use crate::{ - dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, - DdlStatement, Distinct, DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, - Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, - Repartition, Sort, Subquery, SubqueryAlias, TableScan, Union, Unnest, + dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, + Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, + Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, + Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, Window, }; -use std::sync::Arc; +use datafusion_common::tree_node::TreeNodeRefContainer; use crate::expr::{Exists, InSubquery}; -use crate::tree_node::{transform_sort_option_vec, transform_sort_vec}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, - TreeNodeVisitor, -}; -use datafusion_common::{ - internal_err, map_until_stop_and_collect, DataFusionError, Result, + Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion, + TreeNodeRewriter, TreeNodeVisitor, }; +use datafusion_common::{internal_err, Result}; impl TreeNode for LogicalPlan { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( &'n self, f: F, ) -> Result { - self.inputs().into_iter().apply_until_stop(f) + self.inputs().apply_ref_elements(f) } /// Applies `f` to each child (input) of this plan node, rewriting them *in place.* @@ -73,14 +71,14 @@ impl TreeNode for LogicalPlan { /// [`Expr::Exists`]: crate::Expr::Exists fn map_children Result>>( self, - mut f: F, + f: F, ) -> Result> { Ok(match self { LogicalPlan::Projection(Projection { expr, input, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Projection(Projection { expr, input, @@ -91,7 +89,7 @@ impl TreeNode for LogicalPlan { predicate, input, having, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Filter(Filter { predicate, input, @@ -101,7 +99,7 @@ 
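A brief sketch of the new statement variants (hypothetical usage, not part of this patch): they carry no schema of their own and render through `Statement::display`.

use datafusion_expr::logical_plan::{Deallocate, Statement};

fn main() {
    let stmt = Statement::Deallocate(Deallocate {
        name: "my_stmt".to_string(),
    });
    // All statements share the static empty schema introduced above
    assert!(stmt.schema().fields().is_empty());
    assert_eq!(format!("{}", stmt.display()), "Deallocate: my_stmt");
}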
impl TreeNode for LogicalPlan { LogicalPlan::Repartition(Repartition { input, partitioning_scheme, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -111,7 +109,7 @@ impl TreeNode for LogicalPlan { input, window_expr, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Window(Window { input, window_expr, @@ -123,7 +121,7 @@ impl TreeNode for LogicalPlan { group_expr, aggr_expr, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Aggregate(Aggregate { input, group_expr, @@ -131,7 +129,8 @@ impl TreeNode for LogicalPlan { schema, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => rewrite_arc(input, f)? + LogicalPlan::Sort(Sort { expr, input, fetch }) => input + .map_elements(f)? .update_data(|input| LogicalPlan::Sort(Sort { expr, input, fetch })), LogicalPlan::Join(Join { left, @@ -142,12 +141,7 @@ impl TreeNode for LogicalPlan { join_constraint, schema, null_equals_null, - }) => map_until_stop_and_collect!( - rewrite_arc(left, &mut f), - right, - rewrite_arc(right, &mut f) - )? - .update_data(|(left, right)| { + }) => (left, right).map_elements(f)?.update_data(|(left, right)| { LogicalPlan::Join(Join { left, right, @@ -159,28 +153,13 @@ impl TreeNode for LogicalPlan { null_equals_null, }) }), - LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema, - }) => map_until_stop_and_collect!( - rewrite_arc(left, &mut f), - right, - rewrite_arc(right, &mut f) - )? - .update_data(|(left, right)| { - LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema, - }) - }), - LogicalPlan::Limit(Limit { skip, fetch, input }) => rewrite_arc(input, f)? + LogicalPlan::Limit(Limit { skip, fetch, input }) => input + .map_elements(f)? .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), LogicalPlan::Subquery(Subquery { subquery, outer_ref_columns, - }) => rewrite_arc(subquery, f)?.update_data(|subquery| { + }) => subquery.map_elements(f)?.update_data(|subquery| { LogicalPlan::Subquery(Subquery { subquery, outer_ref_columns, @@ -190,7 +169,7 @@ impl TreeNode for LogicalPlan { input, alias, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, @@ -199,17 +178,18 @@ impl TreeNode for LogicalPlan { }), LogicalPlan::Extension(extension) => rewrite_extension_inputs(extension, f)? .update_data(LogicalPlan::Extension), - LogicalPlan::Union(Union { inputs, schema }) => rewrite_arcs(inputs, f)? + LogicalPlan::Union(Union { inputs, schema }) => inputs + .map_elements(f)? 
.update_data(|inputs| LogicalPlan::Union(Union { inputs, schema })), LogicalPlan::Distinct(distinct) => match distinct { - Distinct::All(input) => rewrite_arc(input, f)?.update_data(Distinct::All), + Distinct::All(input) => input.map_elements(f)?.update_data(Distinct::All), Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, input, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { Distinct::On(DistinctOn { on_expr, select_expr, @@ -226,7 +206,7 @@ impl TreeNode for LogicalPlan { stringified_plans, schema, logical_optimization_succeeded, - }) => rewrite_arc(plan, f)?.update_data(|plan| { + }) => plan.map_elements(f)?.update_data(|plan| { LogicalPlan::Explain(Explain { verbose, plan, @@ -239,7 +219,7 @@ impl TreeNode for LogicalPlan { verbose, input, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Analyze(Analyze { verbose, input, @@ -252,7 +232,7 @@ impl TreeNode for LogicalPlan { op, input, output_schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Dml(DmlStatement { table_name, table_schema, @@ -267,7 +247,7 @@ impl TreeNode for LogicalPlan { partition_by, file_type, options, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Copy(CopyTo { input, output_url, @@ -285,7 +265,8 @@ impl TreeNode for LogicalPlan { if_not_exists, or_replace, column_defaults, - }) => rewrite_arc(input, f)?.update_data(|input| { + temporary, + }) => input.map_elements(f)?.update_data(|input| { DdlStatement::CreateMemoryTable(CreateMemoryTable { name, constraints, @@ -293,6 +274,7 @@ impl TreeNode for LogicalPlan { if_not_exists, or_replace, column_defaults, + temporary, }) }), DdlStatement::CreateView(CreateView { @@ -300,12 +282,14 @@ impl TreeNode for LogicalPlan { input, or_replace, definition, - }) => rewrite_arc(input, f)?.update_data(|input| { + temporary, + }) => input.map_elements(f)?.update_data(|input| { DdlStatement::CreateView(CreateView { name, input, or_replace, definition, + temporary, }) }), // no inputs in these statements @@ -329,7 +313,7 @@ impl TreeNode for LogicalPlan { dependency_indices, schema, options, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Unnest(Unnest { input, exec_columns: input_columns, @@ -340,38 +324,31 @@ impl TreeNode for LogicalPlan { options, }) }), - LogicalPlan::Prepare(Prepare { - name, - data_types, - input, - }) => rewrite_arc(input, f)?.update_data(|input| { - LogicalPlan::Prepare(Prepare { - name, - data_types, - input, - }) - }), LogicalPlan::RecursiveQuery(RecursiveQuery { name, static_term, recursive_term, is_distinct, - }) => map_until_stop_and_collect!( - rewrite_arc(static_term, &mut f), - recursive_term, - rewrite_arc(recursive_term, &mut f) - )? - .update_data(|(static_term, recursive_term)| { - LogicalPlan::RecursiveQuery(RecursiveQuery { - name, - static_term, - recursive_term, - is_distinct, - }) - }), + }) => (static_term, recursive_term).map_elements(f)?.update_data( + |(static_term, recursive_term)| { + LogicalPlan::RecursiveQuery(RecursiveQuery { + name, + static_term, + recursive_term, + is_distinct, + }) + }, + ), + LogicalPlan::Statement(stmt) => match stmt { + Statement::Prepare(p) => p + .input + .map_elements(f)? 
+ .update_data(|input| Statement::Prepare(Prepare { input, ..p })), + _ => Transformed::no(stmt), + } + .update_data(LogicalPlan::Statement), // plans without inputs LogicalPlan::TableScan { .. } - | LogicalPlan::Statement { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Values { .. } | LogicalPlan::DescribeTable(_) => Transformed::no(self), @@ -379,24 +356,6 @@ impl TreeNode for LogicalPlan { } } -/// Applies `f` to rewrite a `Arc` without copying, if possible -fn rewrite_arc Result>>( - plan: Arc, - mut f: F, -) -> Result>> { - f(Arc::unwrap_or_clone(plan))?.map_data(|new_plan| Ok(Arc::new(new_plan))) -} - -/// rewrite a `Vec` of `Arc` without copying, if possible -fn rewrite_arcs Result>>( - input_plans: Vec>, - mut f: F, -) -> Result>>> { - input_plans - .into_iter() - .map_until_stop_and_collect(|plan| rewrite_arc(plan, &mut f)) -} - /// Rewrites all inputs for an Extension node "in place" /// (it currently has to copy values because there are no APIs for in place modification) /// @@ -425,8 +384,10 @@ fn rewrite_extension_inputs Result {{ $F_DOWN? - .transform_children(|n| n.map_subqueries($F_CHILD))? - .transform_sibling(|n| n.map_children($F_CHILD))? + .transform_children(|n| { + n.map_subqueries($F_CHILD)? + .transform_sibling(|n| n.map_children($F_CHILD)) + })? .transform_parent($F_UP) }}; } @@ -443,82 +404,70 @@ impl LogicalPlan { mut f: F, ) -> Result { match self { - LogicalPlan::Projection(Projection { expr, .. }) => { - expr.iter().apply_until_stop(f) - } - LogicalPlan::Values(Values { values, .. }) => values - .iter() - .apply_until_stop(|value| value.iter().apply_until_stop(&mut f)), + LogicalPlan::Projection(Projection { expr, .. }) => expr.apply_elements(f), + LogicalPlan::Values(Values { values, .. }) => values.apply_elements(f), LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. }) => match partitioning_scheme { Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { - expr.iter().apply_until_stop(f) + expr.apply_elements(f) } Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, LogicalPlan::Window(Window { window_expr, .. }) => { - window_expr.iter().apply_until_stop(f) + window_expr.apply_elements(f) } LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. - }) => group_expr - .iter() - .chain(aggr_expr.iter()) - .apply_until_stop(f), + }) => (group_expr, aggr_expr).apply_ref_elements(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). LogicalPlan::Join(Join { on, filter, .. }) => { - on.iter() - // TODO: why we need to create an `Expr::eq`? Cloning `Expr` is costly... - // it not ideal to create an expr here to analyze them, but could cache it on the Join itself - .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .apply_until_stop(|e| f(&e))? - .visit_sibling(|| filter.iter().apply_until_stop(f)) - } - LogicalPlan::Sort(Sort { expr, .. }) => { - expr.iter().apply_until_stop(|sort| f(&sort.expr)) + (on, filter).apply_ref_elements(f) } + LogicalPlan::Sort(Sort { expr, .. 
}) => expr.apply_elements(f), LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - extension.node.expressions().iter().apply_until_stop(f) + extension.node.expressions().apply_elements(f) } LogicalPlan::TableScan(TableScan { filters, .. }) => { - filters.iter().apply_until_stop(f) + filters.apply_elements(f) } LogicalPlan::Unnest(unnest) => { let columns = unnest.exec_columns.clone(); let exprs = columns .iter() - .map(|(c, _)| Expr::Column(c.clone())) + .map(|c| Expr::Column(c.clone())) .collect::>(); - exprs.iter().apply_until_stop(f) + exprs.apply_elements(f) } LogicalPlan::Distinct(Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, .. - })) => on_expr - .iter() - .chain(select_expr.iter()) - .chain(sort_expr.iter().flatten().map(|sort| &sort.expr)) - .apply_until_stop(f), + })) => (on_expr, select_expr, sort_expr).apply_ref_elements(f), + LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + (skip, fetch).apply_ref_elements(f) + } + LogicalPlan::Statement(stmt) => match stmt { + Statement::Execute(Execute { parameters, .. }) => { + parameters.apply_elements(f) + } + _ => Ok(TreeNodeRecursion::Continue), + }, // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Statement(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) @@ -526,8 +475,7 @@ impl LogicalPlan { | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::Prepare(_) => Ok(TreeNodeRecursion::Continue), + | LogicalPlan::DescribeTable(_) => Ok(TreeNodeRecursion::Continue), } } @@ -547,21 +495,15 @@ impl LogicalPlan { expr, input, schema, - }) => expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| { - LogicalPlan::Projection(Projection { - expr, - input, - schema, - }) - }), + }) => expr.map_elements(f)?.update_data(|expr| { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) + }), LogicalPlan::Values(Values { schema, values }) => values - .into_iter() - .map_until_stop_and_collect(|value| { - value.into_iter().map_until_stop_and_collect(&mut f) - })? + .map_elements(f)? .update_data(|values| LogicalPlan::Values(Values { schema, values })), LogicalPlan::Filter(Filter { predicate, @@ -579,12 +521,10 @@ impl LogicalPlan { partitioning_scheme, }) => match partitioning_scheme { Partitioning::Hash(expr, usize) => expr - .into_iter() - .map_until_stop_and_collect(f)? + .map_elements(f)? .update_data(|expr| Partitioning::Hash(expr, usize)), Partitioning::DistributeBy(expr) => expr - .into_iter() - .map_until_stop_and_collect(f)? + .map_elements(f)? .update_data(Partitioning::DistributeBy), Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), } @@ -598,34 +538,28 @@ impl LogicalPlan { input, window_expr, schema, - }) => window_expr - .into_iter() - .map_until_stop_and_collect(f)? 
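A short sketch showing that `Limit`'s skip/fetch now participate in expression traversal (hypothetical usage, not part of this patch), constructed the same way as the `test_limit_with_new_children` test earlier in this diff:

use std::sync::Arc;
use datafusion_common::tree_node::TreeNodeRecursion;
use datafusion_common::{DFSchema, Result};
use datafusion_expr::lit;
use datafusion_expr::logical_plan::{Limit, LogicalPlan, Values};

fn main() -> Result<()> {
    let input = Arc::new(LogicalPlan::Values(Values {
        schema: Arc::new(DFSchema::empty()),
        values: vec![vec![]],
    }));
    let limit = LogicalPlan::Limit(Limit {
        skip: Some(Box::new(lit(5_i64))),
        fetch: Some(Box::new(lit(10_i64))),
        input,
    });

    // Both bounds are visited via (skip, fetch).apply_ref_elements(f)
    let mut seen = 0;
    limit.apply_expressions(|_expr| {
        seen += 1;
        Ok(TreeNodeRecursion::Continue)
    })?;
    assert_eq!(seen, 2);
    Ok(())
}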
- .update_data(|window_expr| { - LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) - }), + }) => window_expr.map_elements(f)?.update_data(|window_expr| { + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) + }), LogicalPlan::Aggregate(Aggregate { input, group_expr, aggr_expr, schema, - }) => map_until_stop_and_collect!( - group_expr.into_iter().map_until_stop_and_collect(&mut f), - aggr_expr, - aggr_expr.into_iter().map_until_stop_and_collect(&mut f) - )? - .update_data(|(group_expr, aggr_expr)| { - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - }) - }), + }) => (group_expr, aggr_expr).map_elements(f)?.update_data( + |(group_expr, aggr_expr)| { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) + }, + ), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. @@ -639,16 +573,7 @@ impl LogicalPlan { join_constraint, schema, null_equals_null, - }) => map_until_stop_and_collect!( - on.into_iter().map_until_stop_and_collect( - |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1)) - ), - filter, - filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { - Ok(f(e)?.update_data(Some)) - }) - )? - .update_data(|(on, filter)| { + }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { LogicalPlan::Join(Join { left, right, @@ -660,17 +585,13 @@ impl LogicalPlan { null_equals_null, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => { - transform_sort_vec(expr, &mut f)? - .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })) - } + LogicalPlan::Sort(Sort { expr, input, fetch }) => expr + .map_elements(f)? + .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), LogicalPlan::Extension(Extension { node }) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - let exprs = node - .expressions() - .into_iter() - .map_until_stop_and_collect(f)?; + let exprs = node.expressions().map_elements(f)?; let plan = LogicalPlan::Extension(Extension { node: UserDefinedLogicalNode::with_exprs_and_inputs( node.as_ref(), @@ -687,50 +608,53 @@ impl LogicalPlan { projected_schema, filters, fetch, - }) => filters - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|filters| { - LogicalPlan::TableScan(TableScan { - table_name, - source, - projection, - projected_schema, - filters, - fetch, - }) - }), + }) => filters.map_elements(f)?.update_data(|filters| { + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) + }), LogicalPlan::Distinct(Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, input, schema, - })) => map_until_stop_and_collect!( - on_expr.into_iter().map_until_stop_and_collect(&mut f), - select_expr, - select_expr.into_iter().map_until_stop_and_collect(&mut f), - sort_expr, - transform_sort_option_vec(sort_expr, &mut f) - )? - .update_data(|(on_expr, select_expr, sort_expr)| { - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - input, - schema, - })) - }), + })) => (on_expr, select_expr, sort_expr) + .map_elements(f)? 
+ .update_data(|(on_expr, select_expr, sort_expr)| { + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + })) + }), + LogicalPlan::Limit(Limit { skip, fetch, input }) => { + (skip, fetch).map_elements(f)?.update_data(|(skip, fetch)| { + LogicalPlan::Limit(Limit { skip, fetch, input }) + }) + } + LogicalPlan::Statement(stmt) => match stmt { + Statement::Execute(e) => { + e.parameters.map_elements(f)?.update_data(|parameters| { + Statement::Execute(Execute { parameters, ..e }) + }) + } + _ => Transformed::no(stmt), + } + .update_data(LogicalPlan::Statement), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Statement(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) @@ -738,13 +662,13 @@ impl LogicalPlan { | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::Prepare(_) => Transformed::no(self), + | LogicalPlan::DescribeTable(_) => Transformed::no(self), }) } /// Visits a plan similarly to [`Self::visit`], including subqueries that /// may appear in expressions such as `IN (SELECT ...)`. + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] pub fn visit_with_subqueries TreeNodeVisitor<'n, Node = Self>>( &self, visitor: &mut V, @@ -752,15 +676,18 @@ impl LogicalPlan { visitor .f_down(self)? .visit_children(|| { - self.apply_subqueries(|c| c.visit_with_subqueries(visitor)) + self.apply_subqueries(|c| c.visit_with_subqueries(visitor))? + .visit_sibling(|| { + self.apply_children(|c| c.visit_with_subqueries(visitor)) + }) })? - .visit_sibling(|| self.apply_children(|c| c.visit_with_subqueries(visitor)))? .visit_parent(|| visitor.f_up(self)) } /// Similarly to [`Self::rewrite`], rewrites this node and its inputs using `f`, /// including subqueries that may appear in expressions such as `IN (SELECT /// ...)`. + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] pub fn rewrite_with_subqueries>( self, rewriter: &mut R, @@ -779,19 +706,19 @@ impl LogicalPlan { &self, mut f: F, ) -> Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn apply_with_subqueries_impl< F: FnMut(&LogicalPlan) -> Result, >( node: &LogicalPlan, f: &mut F, ) -> Result { - f(node)? - .visit_children(|| { - node.apply_subqueries(|c| apply_with_subqueries_impl(c, f)) - })? - .visit_sibling(|| { - node.apply_children(|c| apply_with_subqueries_impl(c, f)) - }) + f(node)?.visit_children(|| { + node.apply_subqueries(|c| apply_with_subqueries_impl(c, f))? + .visit_sibling(|| { + node.apply_children(|c| apply_with_subqueries_impl(c, f)) + }) + }) } apply_with_subqueries_impl(self, &mut f) @@ -814,19 +741,19 @@ impl LogicalPlan { self, mut f: F, ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn transform_down_with_subqueries_impl< F: FnMut(LogicalPlan) -> Result>, >( node: LogicalPlan, f: &mut F, ) -> Result> { - f(node)? - .transform_children(|n| { - n.map_subqueries(|c| transform_down_with_subqueries_impl(c, f)) - })? - .transform_sibling(|n| { - n.map_children(|c| transform_down_with_subqueries_impl(c, f)) - }) + f(node)?.transform_children(|n| { + n.map_subqueries(|c| transform_down_with_subqueries_impl(c, f))? 
+ .transform_sibling(|n| { + n.map_children(|c| transform_down_with_subqueries_impl(c, f)) + }) + }) } transform_down_with_subqueries_impl(self, &mut f) @@ -839,6 +766,7 @@ impl LogicalPlan { self, mut f: F, ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn transform_up_with_subqueries_impl< F: FnMut(LogicalPlan) -> Result>, >( @@ -866,6 +794,7 @@ impl LogicalPlan { mut f_down: FD, mut f_up: FU, ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn transform_down_up_with_subqueries_impl< FD: FnMut(LogicalPlan) -> Result>, FU: FnMut(LogicalPlan) -> Result>, diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 7dd7360e478f2..42047e8e6caa2 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -25,6 +25,7 @@ use datafusion_common::{ config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, Result, TableReference, }; +use sqlparser::ast; use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF}; @@ -66,6 +67,11 @@ pub trait ContextProvider { &[] } + /// Getter for the data type planner + fn get_type_planner(&self) -> Option> { + None + } + /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description @@ -216,7 +222,7 @@ pub trait ExprPlanner: Debug + Send + Sync { /// custom expressions. #[derive(Debug, Clone)] pub struct RawBinaryExpr { - pub op: sqlparser::ast::BinaryOperator, + pub op: ast::BinaryOperator, pub left: Expr, pub right: Expr, } @@ -249,3 +255,13 @@ pub enum PlannerResult { /// The raw expression could not be planned, and is returned unmodified Original(T), } + +/// This trait allows users to customize the behavior of the data type planning +pub trait TypePlanner: Debug + Send + Sync { + /// Plan SQL type to DataFusion data type + /// + /// Returns None if not possible + fn plan_type(&self, _sql_type: &ast::DataType) -> Result> { + Ok(None) + } +} diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 6d3457f70d4c7..4eb49710bcf85 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -20,8 +20,8 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; -use datafusion_common::{not_impl_err, plan_datafusion_err, Result}; -use std::collections::{HashMap, HashSet}; +use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result}; +use std::collections::HashSet; use std::fmt::Debug; use std::sync::Arc; diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index a55cb49b1f402..467ce8bf53e2d 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -29,10 +29,10 @@ use crate::{execution_props::ExecutionProps, Expr, ExprSchemable}; /// information in without having to create `DFSchema` objects. 
If you /// have a [`DFSchemaRef`] you can use [`SimplifyContext`] pub trait SimplifyInfo { - /// returns true if this Expr has boolean type + /// Returns true if this Expr has boolean type fn is_boolean_type(&self, expr: &Expr) -> Result; - /// returns true of this expr is nullable (could possibly be NULL) + /// Returns true of this expr is nullable (could possibly be NULL) fn nullable(&self, expr: &Expr) -> Result; /// Returns details needed for partial expression evaluation @@ -71,8 +71,8 @@ impl<'a> SimplifyContext<'a> { } } -impl<'a> SimplifyInfo for SimplifyContext<'a> { - /// returns true if this Expr has boolean type +impl SimplifyInfo for SimplifyContext<'_> { + /// Returns true if this Expr has boolean type fn is_boolean_type(&self, expr: &Expr) -> Result { if let Some(schema) = &self.schema { if let Ok(DataType::Boolean) = expr.get_type(schema) { @@ -113,7 +113,7 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> { pub enum ExprSimplifyResult { /// The function call was simplified to an entirely new Expr Simplified(Expr), - /// the function call could not be simplified, and the arguments + /// The function call could not be simplified, and the arguments /// are return unmodified. Original(Vec), } diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index bdb602d48dee5..d62484153f530 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -55,7 +55,7 @@ pub enum TableProviderFilterPushDown { pub enum TableType { /// An ordinary physical table. Base, - /// A non-materialised table that itself uses a query internally to provide data. + /// A non-materialized table that itself uses a query internally to provide data. View, /// A transient table. Temporary, @@ -99,7 +99,7 @@ pub trait TableSource: Sync + Send { } /// Tests whether the table provider can make use of any or all filter expressions - /// to optimise data retrieval. + /// to optimize data retrieval. Only non-volatile expressions are passed to this function. fn supports_filters_pushdown( &self, filters: &[&Expr], diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index b4f768085fcc3..71ab1ad6ef9b2 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -34,25 +34,19 @@ use crate::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::AggregateOrderSensitivity, Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature, - Volatility, }; macro_rules! create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { paste::paste! { - /// Singleton instance of [$UDAF], ensures the UDAF is only created once - /// named STATIC_$(UDAF). 
For example `STATIC_FirstValue` - #[allow(non_upper_case_globals)] - static [< STATIC_ $UDAF >]: std::sync::OnceLock> = - std::sync::OnceLock::new(); - #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")] pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { - [< STATIC_ $UDAF >] - .get_or_init(|| { + // Singleton instance of [$UDAF], ensures the UDAF is only created once + static INSTANCE: std::sync::LazyLock> = + std::sync::LazyLock::new(|| { std::sync::Arc::new(crate::AggregateUDF::from(<$UDAF>::default())) - }) - .clone() + }); + std::sync::Arc::clone(&INSTANCE) } } } @@ -106,7 +100,7 @@ pub struct Sum { impl Sum { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::user_defined(Immutable), } } } @@ -236,13 +230,13 @@ impl Count { pub fn new() -> Self { Self { aliases: vec!["count".to_string()], - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::variadic_any(Immutable), } } } impl AggregateUDFImpl for Count { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -258,6 +252,10 @@ impl AggregateUDFImpl for Count { Ok(DataType::Int64) } + fn is_nullable(&self) -> bool { + false + } + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -318,13 +316,13 @@ impl Default for Min { impl Min { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::variadic_any(Immutable), } } } impl AggregateUDFImpl for Min { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -403,13 +401,13 @@ impl Default for Max { impl Max { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::variadic_any(Immutable), } } } impl AggregateUDFImpl for Max { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index c7c498dd3f017..eacace5ed0461 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -15,27 +15,36 @@ // specific language governing permissions and limitations // under the License. -//! Tree node implementation for logical expr +//! Tree node implementation for Logical Expressions use crate::expr::{ AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, - InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, + InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, }; use crate::{Expr, ExprFunctionExt}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, }; -use datafusion_common::{map_until_stop_and_collect, Result}; +use datafusion_common::Result; +/// Implementation of the [`TreeNode`] trait +/// +/// This allows logical expressions (`Expr`) to be traversed and transformed +/// Facilitates tasks such as optimization and rewriting during query +/// planning. impl TreeNode for Expr { + /// Applies a function `f` to each child expression of `self`. + /// + /// The function `f` determines whether to continue traversing the tree or to stop. + /// This method collects all child expressions and applies `f` to each. 
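
The hunk above swaps the hand-rolled `Vec<&Expr>` child collection for the new `TreeNodeRefContainer::apply_ref_elements` plumbing, but the caller-facing traversal API is unchanged. For reference (not part of this diff), a minimal caller-side sketch of that API: `TreeNode::apply` walks an `Expr` and the closure steers recursion through `TreeNodeRecursion`. The `referenced_columns` helper and the sample expression are illustrative only.

```rust
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::Result;
use datafusion_expr::{col, lit, Expr};

/// Collect the names of all columns referenced anywhere in `expr`.
fn referenced_columns(expr: &Expr) -> Result<Vec<String>> {
    let mut names = vec![];
    expr.apply(|e| {
        if let Expr::Column(c) = e {
            names.push(c.name.clone());
        }
        // Keep walking; `Stop`/`Jump` would short-circuit the traversal.
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(names)
}

fn main() -> Result<()> {
    let expr = col("a").eq(lit(1)).and(col("b").gt(lit(2)));
    assert_eq!(referenced_columns(&expr)?, vec!["a", "b"]);
    Ok(())
}
```
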
fn apply_children<'n, F: FnMut(&'n Self) -> Result>( &'n self, f: F, ) -> Result { - let children = match self { - Expr::Alias(Alias{expr,..}) - | Expr::Unnest(Unnest{expr}) + match self { + Expr::Alias(Alias { expr, .. }) + | Expr::Unnest(Unnest { expr }) | Expr::Not(expr) | Expr::IsNotNull(expr) | Expr::IsTrue(expr) @@ -48,80 +57,56 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref()], + | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f), Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().collect(), - Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { - args.iter().collect() + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f), + Expr::ScalarFunction(ScalarFunction { args, .. }) => { + args.apply_elements(f) } Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { - lists_of_exprs.iter().flatten().collect() + lists_of_exprs.apply_elements(f) } Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) | Expr::Literal(_) - | Expr::Exists {..} + | Expr::Exists { .. } | Expr::ScalarSubquery(_) - | Expr::Wildcard {..} - | Expr::Placeholder (_) => vec![], + | Expr::Wildcard { .. } + | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - vec![left.as_ref(), right.as_ref()] + (left, right).apply_ref_elements(f) } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - vec![expr.as_ref(), pattern.as_ref()] + (expr, pattern).apply_ref_elements(f) } Expr::Between(Between { - expr, low, high, .. - }) => vec![expr.as_ref(), low.as_ref(), high.as_ref()], - Expr::Case(case) => { - let mut expr_vec = vec![]; - if let Some(expr) = case.expr.as_ref() { - expr_vec.push(expr.as_ref()); - }; - for (when, then) in case.when_then_expr.iter() { - expr_vec.push(when.as_ref()); - expr_vec.push(then.as_ref()); - } - if let Some(else_expr) = case.else_expr.as_ref() { - expr_vec.push(else_expr.as_ref()); - } - expr_vec - } - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - => { - let mut expr_vec = args.iter().collect::>(); - if let Some(f) = filter { - expr_vec.push(f.as_ref()); - } - if let Some(order_by) = order_by { - expr_vec.extend(order_by.iter().map(|sort| &sort.expr)); - } - expr_vec - } + expr, low, high, .. + }) => (expr, low, high).apply_ref_elements(f), + Expr::Case(Case { expr, when_then_expr, else_expr }) => + (expr, when_then_expr, else_expr).apply_ref_elements(f), + Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => + (args, filter, order_by).apply_ref_elements(f), Expr::WindowFunction(WindowFunction { - args, - partition_by, - order_by, - .. - }) => { - let mut expr_vec = args.iter().collect::>(); - expr_vec.extend(partition_by); - expr_vec.extend(order_by.iter().map(|sort| &sort.expr)); - expr_vec + args, + partition_by, + order_by, + .. + }) => { + (args, partition_by, order_by).apply_ref_elements(f) } Expr::InList(InList { expr, list, .. }) => { - let mut expr_vec = vec![expr.as_ref()]; - expr_vec.extend(list); - expr_vec + (expr, list).apply_ref_elements(f) } - }; - - children.into_iter().apply_until_stop(f) + } } + /// Maps each child of `self` using the provided closure `f`. 
+ /// + /// The closure `f` takes ownership of an expression and returns a `Transformed` result, + /// indicating whether the expression was transformed or left unchanged. fn map_children Result>>( self, mut f: F, @@ -135,137 +120,103 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) | Expr::Literal(_) => Transformed::no(self), - Expr::Unnest(Unnest { expr, .. }) => transform_box(expr, &mut f)? - .update_data(|be| Expr::Unnest(Unnest::new_boxed(be))), + Expr::Unnest(Unnest { expr, .. }) => expr + .map_elements(f)? + .update_data(|expr| Expr::Unnest(Unnest { expr })), Expr::Alias(Alias { expr, relation, name, - }) => f(*expr)?.update_data(|e| Expr::Alias(Alias::new(e, relation, name))), + }) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => transform_box(expr, &mut f)?.update_data(|be| { + }) => expr.map_elements(f)?.update_data(|be| { Expr::InSubquery(InSubquery::new(be, subquery, negated)) }), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - map_until_stop_and_collect!( - transform_box(left, &mut f), - right, - transform_box(right, &mut f) - )? + Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right) + .map_elements(f)? .update_data(|(new_left, new_right)| { Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) - }) - } + }), Expr::Like(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - pattern, - transform_box(pattern, &mut f) - )? - .update_data(|(new_expr, new_pattern)| { - Expr::Like(Like::new( - negated, - new_expr, - new_pattern, - escape_char, - case_insensitive, - )) - }), + }) => { + (expr, pattern) + .map_elements(f)? + .update_data(|(new_expr, new_pattern)| { + Expr::Like(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }) + } Expr::SimilarTo(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - pattern, - transform_box(pattern, &mut f) - )? - .update_data(|(new_expr, new_pattern)| { - Expr::SimilarTo(Like::new( - negated, - new_expr, - new_pattern, - escape_char, - case_insensitive, - )) - }), - Expr::Not(expr) => transform_box(expr, &mut f)?.update_data(Expr::Not), - Expr::IsNotNull(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotNull) - } - Expr::IsNull(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsNull), - Expr::IsTrue(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsTrue), - Expr::IsFalse(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsFalse) - } - Expr::IsUnknown(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsUnknown) - } - Expr::IsNotTrue(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotTrue) - } - Expr::IsNotFalse(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotFalse) + }) => { + (expr, pattern) + .map_elements(f)? 
+ .update_data(|(new_expr, new_pattern)| { + Expr::SimilarTo(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }) } + Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not), + Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull), + Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull), + Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue), + Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse), + Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown), + Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue), + Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse), Expr::IsNotUnknown(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotUnknown) - } - Expr::Negative(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::Negative) + expr.map_elements(f)?.update_data(Expr::IsNotUnknown) } + Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative), Expr::Between(Between { expr, negated, low, high, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - low, - transform_box(low, &mut f), - high, - transform_box(high, &mut f) - )? - .update_data(|(new_expr, new_low, new_high)| { - Expr::Between(Between::new(new_expr, negated, new_low, new_high)) - }), + }) => (expr, low, high).map_elements(f)?.update_data( + |(new_expr, new_low, new_high)| { + Expr::Between(Between::new(new_expr, negated, new_low, new_high)) + }, + ), Expr::Case(Case { expr, when_then_expr, else_expr, - }) => map_until_stop_and_collect!( - transform_option_box(expr, &mut f), - when_then_expr, - when_then_expr - .into_iter() - .map_until_stop_and_collect(|(when, then)| { - map_until_stop_and_collect!( - transform_box(when, &mut f), - then, - transform_box(then, &mut f) - ) - }), - else_expr, - transform_option_box(else_expr, &mut f) - )? - .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { - Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) - }), - Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)? + }) => (expr, when_then_expr, else_expr) + .map_elements(f)? + .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { + Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) + }), + Expr::Cast(Cast { expr, data_type }) => expr + .map_elements(f)? .update_data(|be| Expr::Cast(Cast::new(be, data_type))), - Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? + Expr::TryCast(TryCast { expr, data_type }) => expr + .map_elements(f)? .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), Expr::ScalarFunction(ScalarFunction { func, args }) => { - transform_vec(args, &mut f)?.map_data(|new_args| { + args.map_elements(f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( func, new_args, ))) @@ -278,22 +229,17 @@ impl TreeNode for Expr { order_by, window_frame, null_treatment, - }) => map_until_stop_and_collect!( - transform_vec(args, &mut f), - partition_by, - transform_vec(partition_by, &mut f), - order_by, - transform_sort_vec(order_by, &mut f) - )? 
- .update_data(|(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) - .partition_by(new_partition_by) - .order_by(new_order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap() - }), + }) => (args, partition_by, order_by).map_elements(f)?.update_data( + |(new_args, new_partition_by, new_order_by)| { + Expr::WindowFunction(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() + }, + ), Expr::AggregateFunction(AggregateFunction { args, func, @@ -301,31 +247,27 @@ impl TreeNode for Expr { filter, order_by, null_treatment, - }) => map_until_stop_and_collect!( - transform_vec(args, &mut f), - filter, - transform_option_box(filter, &mut f), - order_by, - transform_sort_option_vec(order_by, &mut f) - )? - .map_data(|(new_args, new_filter, new_order_by)| { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( - func, - new_args, - distinct, - new_filter, - new_order_by, - null_treatment, - ))) - })?, + }) => (args, filter, order_by).map_elements(f)?.map_data( + |(new_args, new_filter, new_order_by)| { + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + func, + new_args, + distinct, + new_filter, + new_order_by, + null_treatment, + ))) + }, + )?, Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)? + GroupingSet::Rollup(exprs) => exprs + .map_elements(f)? .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), - GroupingSet::Cube(exprs) => transform_vec(exprs, &mut f)? + GroupingSet::Cube(exprs) => exprs + .map_elements(f)? .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))), GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs - .into_iter() - .map_until_stop_and_collect(|exprs| transform_vec(exprs, &mut f))? + .map_elements(f)? .update_data(|new_lists_of_exprs| { Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs)) }), @@ -334,86 +276,11 @@ impl TreeNode for Expr { expr, list, negated, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - list, - transform_vec(list, &mut f) - )? - .update_data(|(new_expr, new_list)| { - Expr::InList(InList::new(new_expr, new_list, negated)) - }), + }) => (expr, list) + .map_elements(f)? 
+ .update_data(|(new_expr, new_list)| { + Expr::InList(InList::new(new_expr, new_list, negated)) + }), }) } } - -fn transform_box Result>>( - be: Box, - f: &mut F, -) -> Result>> { - Ok(f(*be)?.update_data(Box::new)) -} - -fn transform_option_box Result>>( - obe: Option>, - f: &mut F, -) -> Result>>> { - obe.map_or(Ok(Transformed::no(None)), |be| { - Ok(transform_box(be, f)?.update_data(Some)) - }) -} - -/// &mut transform a Option<`Vec` of `Expr`s> -pub fn transform_option_vec Result>>( - ove: Option>, - f: &mut F, -) -> Result>>> { - ove.map_or(Ok(Transformed::no(None)), |ve| { - Ok(transform_vec(ve, f)?.update_data(Some)) - }) -} - -/// &mut transform a `Vec` of `Expr`s -fn transform_vec Result>>( - ve: Vec, - f: &mut F, -) -> Result>> { - ve.into_iter().map_until_stop_and_collect(f) -} - -pub fn transform_sort_option_vec Result>>( - sorts_option: Option>, - f: &mut F, -) -> Result>>> { - sorts_option.map_or(Ok(Transformed::no(None)), |sorts| { - Ok(transform_sort_vec(sorts, f)?.update_data(Some)) - }) -} - -pub fn transform_sort_vec Result>>( - sorts: Vec, - mut f: &mut F, -) -> Result>> { - Ok(sorts - .iter() - .map(|sort| sort.expr.clone()) - .map_until_stop_and_collect(&mut f)? - .update_data(|transformed_exprs| { - replace_sort_expressions(sorts, transformed_exprs) - })) -} - -pub fn replace_sort_expressions(sorts: Vec, new_expr: Vec) -> Vec { - assert_eq!(sorts.len(), new_expr.len()); - sorts - .into_iter() - .zip(new_expr) - .map(|(sort, expr)| replace_sort_expression(sort, expr)) - .collect() -} - -pub fn replace_sort_expression(sort: Sort, new_expr: Expr) -> Sort { - Sort { - expr: new_expr, - ..sort - } -} diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 143e00fa409ea..650619e6de4c1 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,13 +21,19 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; +use datafusion_common::utils::coerced_fixed_size_list_to_list; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, plan_err, - utils::{coerced_fixed_size_list_to_list, list_ndims}, + exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, + types::{LogicalType, NativeType}, + utils::list_ndims, Result, }; use datafusion_expr_common::{ - signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, + signature::{ + ArrayFunctionSignature, TypeSignatureClass, FIXED_SIZE_LIST_WILDCARD, + TIMEZONE_WILDCARD, + }, + type_coercion::binary::comparison_coercion_numeric, type_coercion::binary::string_coercion, }; use std::sync::Arc; @@ -44,17 +50,21 @@ pub fn data_types_with_scalar_udf( func: &ScalarUDF, ) -> Result> { let signature = func.signature(); + let type_signature = &signature.type_signature; if current_types.is_empty() { - if signature.type_signature.supports_zero_argument() { + if type_signature.supports_zero_argument() { return Ok(vec![]); + } else if type_signature.used_to_support_zero_arguments() { + // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763 + return plan_err!("{} does not support zero arguments. 
Use TypeSignature::Nullary for zero arguments.", func.name()); } else { return plan_err!("{} does not support zero arguments.", func.name()); } } let valid_types = - get_valid_types_with_scalar_udf(&signature.type_signature, current_types, func)?; + get_valid_types_with_scalar_udf(type_signature, current_types, func)?; if valid_types .iter() @@ -63,7 +73,7 @@ pub fn data_types_with_scalar_udf( return Ok(current_types.to_vec()); } - try_coerce_types(valid_types, current_types, &signature.type_signature) + try_coerce_types(func.name(), valid_types, current_types, type_signature) } /// Performs type coercion for aggregate function arguments. @@ -78,20 +88,21 @@ pub fn data_types_with_aggregate_udf( func: &AggregateUDF, ) -> Result> { let signature = func.signature(); + let type_signature = &signature.type_signature; if current_types.is_empty() { - if signature.type_signature.supports_zero_argument() { + if type_signature.supports_zero_argument() { return Ok(vec![]); + } else if type_signature.used_to_support_zero_arguments() { + // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763 + return plan_err!("{} does not support zero arguments. Use TypeSignature::Nullary for zero arguments.", func.name()); } else { return plan_err!("{} does not support zero arguments.", func.name()); } } - let valid_types = get_valid_types_with_aggregate_udf( - &signature.type_signature, - current_types, - func, - )?; + let valid_types = + get_valid_types_with_aggregate_udf(type_signature, current_types, func)?; if valid_types .iter() .any(|data_type| data_type == current_types) @@ -99,7 +110,7 @@ pub fn data_types_with_aggregate_udf( return Ok(current_types.to_vec()); } - try_coerce_types(valid_types, current_types, &signature.type_signature) + try_coerce_types(func.name(), valid_types, current_types, type_signature) } /// Performs type coercion for window function arguments. @@ -114,17 +125,21 @@ pub fn data_types_with_window_udf( func: &WindowUDF, ) -> Result> { let signature = func.signature(); + let type_signature = &signature.type_signature; if current_types.is_empty() { - if signature.type_signature.supports_zero_argument() { + if type_signature.supports_zero_argument() { return Ok(vec![]); + } else if type_signature.used_to_support_zero_arguments() { + // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763 + return plan_err!("{} does not support zero arguments. Use TypeSignature::Nullary for zero arguments.", func.name()); } else { return plan_err!("{} does not support zero arguments.", func.name()); } } let valid_types = - get_valid_types_with_window_udf(&signature.type_signature, current_types, func)?; + get_valid_types_with_window_udf(type_signature, current_types, func)?; if valid_types .iter() .any(|data_type| data_type == current_types) @@ -132,7 +147,7 @@ pub fn data_types_with_window_udf( return Ok(current_types.to_vec()); } - try_coerce_types(valid_types, current_types, &signature.type_signature) + try_coerce_types(func.name(), valid_types, current_types, type_signature) } /// Performs type coercion for function arguments. @@ -143,21 +158,30 @@ pub fn data_types_with_window_udf( /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. 
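
The hunks above thread the function name through `data_types_with_scalar_udf`, `data_types_with_aggregate_udf`, `data_types_with_window_udf` and `data_types` so coercion failures can name the function they came from. A minimal caller-side sketch, assuming the post-diff `data_types(function_name, current_types, signature)` shape shown here; the `my_func` name and argument types are placeholders.

```rust
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::type_coercion::functions::data_types;
use datafusion_expr::{Signature, Volatility};

fn main() -> Result<()> {
    // Two numeric arguments: Int32 and Float64 coerce to the common type Float64.
    let sig = Signature::numeric(2, Volatility::Immutable);
    let coerced = data_types("my_func", &[DataType::Int32, DataType::Float64], &sig)?;
    assert_eq!(coerced, vec![DataType::Float64, DataType::Float64]);

    // A non-numeric argument is rejected with a plan error.
    assert!(data_types("my_func", &[DataType::Utf8, DataType::Int32], &sig).is_err());
    Ok(())
}
```
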
pub fn data_types( + function_name: impl AsRef, current_types: &[DataType], signature: &Signature, ) -> Result> { + let type_signature = &signature.type_signature; + if current_types.is_empty() { - if signature.type_signature.supports_zero_argument() { + if type_signature.supports_zero_argument() { return Ok(vec![]); + } else if type_signature.used_to_support_zero_arguments() { + // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763 + return plan_err!( + "signature {:?} does not support zero arguments. Use TypeSignature::Nullary for zero arguments.", + type_signature + ); } else { return plan_err!( "signature {:?} does not support zero arguments.", - &signature.type_signature + type_signature ); } } - let valid_types = get_valid_types(&signature.type_signature, current_types)?; + let valid_types = get_valid_types(type_signature, current_types)?; if valid_types .iter() .any(|data_type| data_type == current_types) @@ -165,7 +189,7 @@ pub fn data_types( return Ok(current_types.to_vec()); } - try_coerce_types(valid_types, current_types, &signature.type_signature) + try_coerce_types(function_name, valid_types, current_types, type_signature) } fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { @@ -180,10 +204,13 @@ fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { | TypeSignature::String(_) | TypeSignature::Coercible(_) | TypeSignature::Any(_) + | TypeSignature::Nullary + | TypeSignature::Comparable(_) ) } fn try_coerce_types( + function_name: impl AsRef, valid_types: Vec>, current_types: &[DataType], type_signature: &TypeSignature, @@ -192,13 +219,18 @@ fn try_coerce_types( // Well-supported signature that returns exact valid types. if !valid_types.is_empty() && is_well_supported_signature(type_signature) { - // exact valid types - assert_eq!(valid_types.len(), 1); + // There may be many valid types if valid signature is OneOf + // Otherwise, there should be only one valid type + if !type_signature.is_one_of() { + assert_eq!(valid_types.len(), 1); + } + let valid_types = valid_types.swap_remove(0); if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) { return Ok(t); } } else { + // TODO: Deprecate this branch after all signatures are well-supported (aka coercion has happened already) // Try and coerce the argument types to match the signature, returning the // coerced types from the first matching signature. 
for valid_types in valid_types { @@ -210,7 +242,8 @@ fn try_coerce_types( // none possible -> Error plan_err!( - "Coercion from {:?} to the signature {:?} failed.", + "Failed to coerce arguments to satisfy a call to {} function: coercion from {:?} to the signature {:?} failed.", + function_name.as_ref(), current_types, type_signature ) @@ -221,20 +254,37 @@ fn get_valid_types_with_scalar_udf( current_types: &[DataType], func: &ScalarUDF, ) -> Result>> { - let valid_types = match signature { + match signature { TypeSignature::UserDefined => match func.coerce_types(current_types) { - Ok(coerced_types) => vec![coerced_types], - Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), + Ok(coerced_types) => Ok(vec![coerced_types]), + Err(e) => exec_err!("User-defined coercion failed with {:?}", e), }, - TypeSignature::OneOf(signatures) => signatures - .iter() - .filter_map(|t| get_valid_types_with_scalar_udf(t, current_types, func).ok()) - .flatten() - .collect::>(), - _ => get_valid_types(signature, current_types)?, - }; + TypeSignature::OneOf(signatures) => { + let mut res = vec![]; + let mut errors = vec![]; + for sig in signatures { + match get_valid_types_with_scalar_udf(sig, current_types, func) { + Ok(valid_types) => { + res.extend(valid_types); + } + Err(e) => { + errors.push(e.to_string()); + } + } + } - Ok(valid_types) + // Every signature failed, return the joined error + if res.is_empty() { + internal_err!( + "Failed to match any signature, errors: {}", + errors.join(",") + ) + } else { + Ok(res) + } + } + _ => get_valid_types(signature, current_types), + } } fn get_valid_types_with_aggregate_udf( @@ -366,7 +416,16 @@ fn get_valid_types( _ => Ok(vec![vec![]]), } } + fn array(array_type: &DataType) -> Option { + match array_type { + DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()), + DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))), + _ => None, + } + } + + fn recursive_array(array_type: &DataType) -> Option { match array_type { DataType::List(_) | DataType::LargeList(_) @@ -378,40 +437,49 @@ fn get_valid_types( } } + fn function_length_check(length: usize, expected_length: usize) -> Result<()> { + if length != expected_length { + return plan_err!( + "The signature expected {expected_length} arguments but received {length}" + ); + } + Ok(()) + } + let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) .collect(), TypeSignature::String(number) => { - if *number < 1 { - return plan_err!( - "The signature expected at least one argument but received {}", - current_types.len() - ); - } - if *number != current_types.len() { - return plan_err!( - "The signature expected {} arguments but received {}", - number, - current_types.len() - ); + function_length_check(current_types.len(), *number)?; + + let mut new_types = Vec::with_capacity(current_types.len()); + for data_type in current_types.iter() { + let logical_data_type: NativeType = data_type.into(); + if logical_data_type == NativeType::String { + new_types.push(data_type.to_owned()); + } else if logical_data_type == NativeType::Null { + // TODO: Switch to Utf8View if all the string functions supports Utf8View + new_types.push(DataType::Utf8); + } else { + return plan_err!( + "The signature expected NativeType::String but received {logical_data_type}" + ); + } } - fn coercion_rule( + // Find the common string type for the given types + fn find_common_type( 
lhs_type: &DataType, rhs_type: &DataType, ) -> Result { match (lhs_type, rhs_type) { - (DataType::Null, DataType::Null) => Ok(DataType::Utf8), - (DataType::Null, data_type) | (data_type, DataType::Null) => { - coercion_rule(data_type, &DataType::Utf8) - } (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => { - coercion_rule(lhs, rhs) + find_common_type(lhs, rhs) } (DataType::Dictionary(_, v), other) - | (other, DataType::Dictionary(_, v)) => coercion_rule(v, other), + | (other, DataType::Dictionary(_, v)) => find_common_type(v, other), _ => { if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) { Ok(coerced_type) @@ -427,15 +495,13 @@ fn get_valid_types( } // Length checked above, safe to unwrap - let mut coerced_type = current_types.first().unwrap().to_owned(); - for t in current_types.iter().skip(1) { - coerced_type = coercion_rule(&coerced_type, t)?; + let mut coerced_type = new_types.first().unwrap().to_owned(); + for t in new_types.iter().skip(1) { + coerced_type = find_common_type(&coerced_type, t)?; } fn base_type_or_default_type(data_type: &DataType) -> DataType { - if data_type.is_null() { - DataType::Utf8 - } else if let DataType::Dictionary(_, v) = data_type { + if let DataType::Dictionary(_, v) = data_type { base_type_or_default_type(v) } else { data_type.to_owned() @@ -445,22 +511,22 @@ fn get_valid_types( vec![vec![base_type_or_default_type(&coerced_type); *number]] } TypeSignature::Numeric(number) => { - if *number < 1 { - return plan_err!( - "The signature expected at least one argument but received {}", - current_types.len() - ); - } - if *number != current_types.len() { - return plan_err!( - "The signature expected {} arguments but received {}", - number, - current_types.len() - ); - } + function_length_check(current_types.len(), *number)?; - let mut valid_type = current_types.first().unwrap().clone(); + // Find common numeric type among given types except string + let mut valid_type = current_types.first().unwrap().to_owned(); for t in current_types.iter().skip(1) { + let logical_data_type: NativeType = t.into(); + if logical_data_type == NativeType::Null { + continue; + } + + if !logical_data_type.is_numeric() { + return plan_err!( + "The signature expected NativeType::Numeric but received {logical_data_type}" + ); + } + if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) { valid_type = coerced_type; } else { @@ -472,42 +538,128 @@ fn get_valid_types( } } - vec![vec![valid_type; *number]] - } - TypeSignature::Coercible(target_types) => { - if target_types.is_empty() { + let logical_data_type: NativeType = valid_type.clone().into(); + // Fallback to default type if we don't know which type to coerced to + // f64 is chosen since most of the math functions utilize Signature::numeric, + // and their default type is double precision + if logical_data_type == NativeType::Null { + valid_type = DataType::Float64; + } else if !logical_data_type.is_numeric() { return plan_err!( - "The signature expected at least one argument but received {}", - current_types.len() + "The signature expected NativeType::Numeric but received {logical_data_type}" ); } - if target_types.len() != current_types.len() { - return plan_err!( - "The signature expected {} arguments but received {}", - target_types.len(), - current_types.len() - ); + + vec![vec![valid_type; *number]] + } + TypeSignature::Comparable(num) => { + function_length_check(current_types.len(), *num)?; + let mut target_type = current_types[0].to_owned(); + for data_type in 
current_types.iter().skip(1) { + if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) { + target_type = dt; + } else { + return plan_err!("{target_type} and {data_type} is not comparable"); + } } + // Convert null to String type. + if target_type.is_null() { + vec![vec![DataType::Utf8View; *num]] + } else { + vec![vec![target_type; *num]] + } + } + TypeSignature::Coercible(target_types) => { + function_length_check(current_types.len(), target_types.len())?; + + // Aim to keep this logic as SIMPLE as possible! + // Make sure the corresponding test is covered + // If this function becomes COMPLEX, create another new signature! + fn can_coerce_to( + current_type: &DataType, + target_type_class: &TypeSignatureClass, + ) -> Result { + let logical_type: NativeType = current_type.into(); - for (data_type, target_type) in current_types.iter().zip(target_types.iter()) - { - if !can_cast_types(data_type, target_type) { - return plan_err!("{data_type} is not coercible to {target_type}"); + match target_type_class { + TypeSignatureClass::Native(native_type) => { + let target_type = native_type.native(); + if &logical_type == target_type { + return target_type.default_cast_for(current_type); + } + + if logical_type == NativeType::Null { + return target_type.default_cast_for(current_type); + } + + if target_type.is_integer() && logical_type.is_integer() { + return target_type.default_cast_for(current_type); + } + + internal_err!( + "Expect {} but received {}", + target_type_class, + current_type + ) + } + // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp + TypeSignatureClass::Timestamp + if logical_type == NativeType::String => + { + Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { + Ok(current_type.to_owned()) + } + TypeSignatureClass::Date if logical_type.is_date() => { + Ok(current_type.to_owned()) + } + TypeSignatureClass::Time if logical_type.is_time() => { + Ok(current_type.to_owned()) + } + TypeSignatureClass::Interval if logical_type.is_interval() => { + Ok(current_type.to_owned()) + } + TypeSignatureClass::Duration if logical_type.is_duration() => { + Ok(current_type.to_owned()) + } + _ => { + not_impl_err!("Got logical_type: {logical_type} with target_type_class: {target_type_class}") + } } } - vec![target_types.to_owned()] + let mut new_types = Vec::with_capacity(current_types.len()); + for (current_type, target_type_class) in + current_types.iter().zip(target_types.iter()) + { + let target_type = can_coerce_to(current_type, target_type_class)?; + new_types.push(target_type); + } + + vec![new_types] + } + TypeSignature::Uniform(number, valid_types) => { + if *number == 0 { + return plan_err!("The function expected at least one argument"); + } + + valid_types + .iter() + .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) + .collect() } - TypeSignature::Uniform(number, valid_types) => valid_types - .iter() - .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) - .collect(), TypeSignature::UserDefined => { return internal_err!( "User-defined signature should be handled by function-specific coerce_types." 
) } TypeSignature::VariadicAny => { + if current_types.is_empty() { + return plan_err!( + "The function expected at least one argument but received 0" + ); + } vec![current_types.to_vec()] } TypeSignature::Exact(valid_types) => vec![valid_types.clone()], @@ -539,6 +691,13 @@ fn get_valid_types( array(¤t_types[0]) .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) } + ArrayFunctionSignature::RecursiveArray => { + if current_types.len() != 1 { + return Ok(vec![vec![]]); + } + recursive_array(¤t_types[0]) + .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) + } ArrayFunctionSignature::MapArray => { if current_types.len() != 1 { return Ok(vec![vec![]]); @@ -550,7 +709,22 @@ fn get_valid_types( } } }, + TypeSignature::Nullary => { + if !current_types.is_empty() { + return plan_err!( + "The function expected zero argument but received {}", + current_types.len() + ); + } + vec![vec![]] + } TypeSignature::Any(number) => { + if current_types.is_empty() { + return plan_err!( + "The function expected at least one argument but received 0" + ); + } + if current_types.len() != *number { return plan_err!( "The function expected {} arguments but received {}", @@ -762,6 +936,7 @@ mod tests { use super::*; use arrow::datatypes::Field; + use datafusion_common::assert_contains; #[test] fn test_string_conversion() { @@ -826,6 +1001,67 @@ mod tests { } } + #[test] + fn test_get_valid_types_numeric() -> Result<()> { + let get_valid_types_flatten = + |signature: &TypeSignature, current_types: &[DataType]| { + get_valid_types(signature, current_types) + .unwrap() + .into_iter() + .flatten() + .collect::>() + }; + + // Trivial case. + let got = get_valid_types_flatten(&TypeSignature::Numeric(1), &[DataType::Int32]); + assert_eq!(got, [DataType::Int32]); + + // Args are coerced into a common numeric type. + let got = get_valid_types_flatten( + &TypeSignature::Numeric(2), + &[DataType::Int32, DataType::Int64], + ); + assert_eq!(got, [DataType::Int64, DataType::Int64]); + + // Args are coerced into a common numeric type, specifically, int would be coerced to float. + let got = get_valid_types_flatten( + &TypeSignature::Numeric(3), + &[DataType::Int32, DataType::Int64, DataType::Float64], + ); + assert_eq!( + got, + [DataType::Float64, DataType::Float64, DataType::Float64] + ); + + // Cannot coerce args to a common numeric type. + let got = get_valid_types( + &TypeSignature::Numeric(2), + &[DataType::Int32, DataType::Utf8], + ) + .unwrap_err(); + assert_contains!( + got.to_string(), + "The signature expected NativeType::Numeric but received NativeType::String" + ); + + // Fallbacks to float64 if the arg is of type null. + let got = get_valid_types_flatten(&TypeSignature::Numeric(1), &[DataType::Null]); + assert_eq!(got, [DataType::Float64]); + + // Rejects non-numeric arg. 
+ let got = get_valid_types( + &TypeSignature::Numeric(1), + &[DataType::Timestamp(TimeUnit::Second, None)], + ) + .unwrap_err(); + assert_contains!( + got.to_string(), + "The signature expected NativeType::Numeric but received NativeType::Timestamp(Second, None)" + ); + + Ok(()) + } + #[test] fn test_get_valid_types_one_of() -> Result<()> { let signature = @@ -850,9 +1086,32 @@ mod tests { Ok(()) } + #[test] + fn test_get_valid_types_length_check() -> Result<()> { + let signature = TypeSignature::Numeric(1); + + let err = get_valid_types(&signature, &[]).unwrap_err(); + assert_contains!( + err.to_string(), + "The signature expected 1 arguments but received 0" + ); + + let err = get_valid_types( + &signature, + &[DataType::Int32, DataType::Int32, DataType::Int32], + ) + .unwrap_err(); + assert_contains!( + err.to_string(), + "The signature expected 1 arguments but received 3" + ); + + Ok(()) + } + #[test] fn test_fixed_list_wildcard_coerce() -> Result<()> { - let inner = Arc::new(Field::new("item", DataType::Int32, false)); + let inner = Arc::new(Field::new_list_field(DataType::Int32, false)); let current_types = vec![ DataType::FixedSizeList(Arc::clone(&inner), 2), // able to coerce for any size ]; @@ -865,7 +1124,7 @@ mod tests { Volatility::Stable, ); - let coerced_data_types = data_types(¤t_types, &signature).unwrap(); + let coerced_data_types = data_types("test", ¤t_types, &signature).unwrap(); assert_eq!(coerced_data_types, current_types); // make sure it can't coerce to a different size @@ -873,7 +1132,7 @@ mod tests { vec![DataType::FixedSizeList(Arc::clone(&inner), 3)], Volatility::Stable, ); - let coerced_data_types = data_types(¤t_types, &signature); + let coerced_data_types = data_types("test", ¤t_types, &signature); assert!(coerced_data_types.is_err()); // make sure it works with the same type. @@ -881,7 +1140,7 @@ mod tests { vec![DataType::FixedSizeList(Arc::clone(&inner), 2)], Volatility::Stable, ); - let coerced_data_types = data_types(¤t_types, &signature).unwrap(); + let coerced_data_types = data_types("test", ¤t_types, &signature).unwrap(); assert_eq!(coerced_data_types, current_types); Ok(()) @@ -890,10 +1149,9 @@ mod tests { #[test] fn test_nested_wildcard_fixed_size_lists() -> Result<()> { let type_into = DataType::FixedSizeList( - Arc::new(Field::new( - "item", + Arc::new(Field::new_list_field( DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Int32, false)), + Arc::new(Field::new_list_field(DataType::Int32, false)), FIXED_SIZE_LIST_WILDCARD, ), false, @@ -902,10 +1160,9 @@ mod tests { ); let type_from = DataType::FixedSizeList( - Arc::new(Field::new( - "item", + Arc::new(Field::new_list_field( DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Int8, false)), + Arc::new(Field::new_list_field(DataType::Int8, false)), 4, ), false, @@ -916,10 +1173,9 @@ mod tests { assert_eq!( coerced_from(&type_into, &type_from), Some(DataType::FixedSizeList( - Arc::new(Field::new( - "item", + Arc::new(Field::new_list_field( DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Int32, false)), + Arc::new(Field::new_list_field(DataType::Int32, false)), 4, ), false, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 6e48054bcf3d6..56c9822495f84 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -140,7 +140,7 @@ impl AggregateUDF { )) } - /// creates an [`Expr`] that calls the aggregate function. + /// Creates an [`Expr`] that calls the aggregate function. 
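
For context on the "without requiring access to the registry" remark in this doc comment, a hedged sketch of that usage; the `geo_mean` UDAF and the `price` column are hypothetical and only mirror the doc example used elsewhere in this file.

```rust
use datafusion_expr::{col, AggregateUDF, Expr};

/// Build the logical expression `geo_mean(price)` directly from the UDAF,
/// e.g. for use with the DataFrame API, without looking it up in a registry.
fn geo_mean_expr(geo_mean: &AggregateUDF) -> Expr {
    geo_mean.call(vec![col("price")])
}
```
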
/// /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. @@ -333,13 +333,9 @@ where /// /// fn get_doc() -> &'static Documentation { /// DOCUMENTATION.get_or_init(|| { -/// Documentation::builder() -/// .with_doc_section(DOC_SECTION_AGGREGATE) -/// .with_description("calculates a geometric mean") -/// .with_syntax_example("geo_mean(2.0)") +/// Documentation::builder(DOC_SECTION_AGGREGATE, "calculates a geometric mean", "geo_mean(2.0)") /// .with_argument("arg1", "The Float64 number for the geometric mean") /// .build() -/// .unwrap() /// }) /// } /// @@ -603,8 +599,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { } /// If this function is max, return true - /// if the function is min, return false - /// otherwise return None (the default) + /// If the function is min, return false + /// Otherwise return None (the default) /// /// /// Note: this is used to use special aggregate implementations in certain conditions @@ -647,7 +643,7 @@ impl PartialEq for dyn AggregateUDFImpl { } } -// manual implementation of `PartialOrd` +// Manual implementation of `PartialOrd` // There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl // https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5 impl PartialOrd for dyn AggregateUDFImpl { diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 3759fb18f56df..ffac82265a008 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,7 +17,7 @@ //! [`ScalarUDF`]: Scalar User Defined Functions -use crate::expr::schema_name_from_exprs_comma_seperated_without_space; +use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ @@ -164,6 +164,19 @@ impl ScalarUDF { self.inner.signature() } + /// The datatype this function returns given the input argument types. + /// This function is used when the input arguments are [`DataType`]s. + /// + /// # Notes + /// + /// If a function implement [`ScalarUDFImpl::return_type_from_exprs`], + /// its [`ScalarUDFImpl::return_type`] should raise an error. + /// + /// See [`ScalarUDFImpl::return_type`] for more details. + pub fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner.return_type(arg_types) + } + /// The datatype this function returns given the input argument input types. /// This function is used when the input arguments are [`Expr`]s. /// @@ -190,10 +203,9 @@ impl ScalarUDF { self.inner.simplify(args, info) } - /// Invoke the function on `args`, returning the appropriate result. - /// - /// See [`ScalarUDFImpl::invoke`] for more details. + #[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")] pub fn invoke(&self, args: &[ColumnarValue]) -> Result { + #[allow(deprecated)] self.inner.invoke(args) } @@ -201,18 +213,37 @@ impl ScalarUDF { self.inner.is_nullable(args, schema) } + pub fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + self.inner.invoke_batch(args, number_rows) + } + + /// Invoke the function on `args`, returning the appropriate result. + /// + /// See [`ScalarUDFImpl::invoke_with_args`] for details. + pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + self.inner.invoke_with_args(args) + } + /// Invoke the function without `args` but number of rows, returning the appropriate result. 
/// - /// See [`ScalarUDFImpl::invoke_no_args`] for more details. + /// Note: This method is deprecated and will be removed in future releases. + /// User defined functions should implement [`Self::invoke_with_args`] instead. + #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] pub fn invoke_no_args(&self, number_rows: usize) -> Result { + #[allow(deprecated)] self.inner.invoke_no_args(number_rows) } /// Returns a `ScalarFunctionImplementation` that can invoke the function /// during execution - #[deprecated(since = "42.0.0", note = "Use `invoke` or `invoke_no_args` instead")] + #[deprecated(since = "42.0.0", note = "Use `invoke_batch` instead")] pub fn fun(&self) -> ScalarFunctionImplementation { let captured = Arc::clone(&self.inner); + #[allow(deprecated)] Arc::new(move |args| captured.invoke(args)) } @@ -272,6 +303,10 @@ impl ScalarUDF { self.inner.output_ordering(inputs) } + pub fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result { + self.inner.preserves_lex_ordering(inputs) + } + /// See [`ScalarUDFImpl::coerce_types`] for more details. pub fn coerce_types(&self, arg_types: &[DataType]) -> Result> { self.inner.coerce_types(arg_types) @@ -288,14 +323,26 @@ impl ScalarUDF { impl From for ScalarUDF where - F: ScalarUDFImpl + Send + Sync + 'static, + F: ScalarUDFImpl + 'static, { fn from(fun: F) -> Self { Self::new_from_impl(fun) } } -/// Trait for implementing [`ScalarUDF`]. +/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a +/// scalar function. +pub struct ScalarFunctionArgs<'a> { + /// The evaluated arguments to the function + pub args: Vec, + /// The number of rows in record batch being evaluated + pub number_rows: usize, + /// The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`) + /// when creating the physical expression from the logical expression + pub return_type: &'a DataType, +} + +/// Trait for implementing user defined scalar functions. /// /// This trait exposes the full API for implementing user defined functions and /// can be used to implement any function. @@ -303,18 +350,19 @@ where /// See [`advanced_udf.rs`] for a full example with complete implementation and /// [`ScalarUDF`] for other available options. 
/// -/// /// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +/// /// # Basic Example /// ``` /// # use std::any::Any; /// # use std::sync::OnceLock; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility}; +/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility}; /// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; /// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; /// +/// /// This struct for a simple UDF that adds one to an int32 /// #[derive(Debug)] /// struct AddOne { /// signature: Signature, @@ -327,18 +375,14 @@ where /// } /// } /// } -/// +/// /// static DOCUMENTATION: OnceLock = OnceLock::new(); /// /// fn get_doc() -> &'static Documentation { /// DOCUMENTATION.get_or_init(|| { -/// Documentation::builder() -/// .with_doc_section(DOC_SECTION_MATH) -/// .with_description("Add one to an int32") -/// .with_syntax_example("add_one(2)") +/// Documentation::builder(DOC_SECTION_MATH, "Add one to an int32", "add_one(2)") /// .with_argument("arg1", "The int32 number to add one to") /// .build() -/// .unwrap() /// }) /// } /// @@ -354,7 +398,9 @@ where /// Ok(DataType::Int32) /// } /// // The actual implementation would add one to the argument -/// fn invoke(&self, args: &[ColumnarValue]) -> Result { unimplemented!() } +/// fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { +/// unimplemented!() +/// } /// fn documentation(&self) -> Option<&Documentation> { /// Some(get_doc()) /// } @@ -390,7 +436,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { Ok(format!( "{}({})", self.name(), - schema_name_from_exprs_comma_seperated_without_space(args)? + schema_name_from_exprs_comma_separated_without_space(args)? )) } @@ -450,27 +496,64 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// Invoke the function on `args`, returning the appropriate result /// - /// The function will be invoked passed with the slice of [`ColumnarValue`] - /// (either scalar or array). + /// Note: This method is deprecated and will be removed in future releases. + /// User defined functions should implement [`Self::invoke_with_args`] instead. + #[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")] + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + not_impl_err!( + "Function {} does not implement invoke but called", + self.name() + ) + } + + /// Invoke the function with `args` and the number of rows, + /// returning the appropriate result. + /// + /// Note: See notes on [`Self::invoke_with_args`] /// - /// If the function does not take any arguments, please use [invoke_no_args] - /// instead and return [not_impl_err] for this function. + /// Note: This method is deprecated and will be removed in future releases. + /// User defined functions should implement [`Self::invoke_with_args`] instead. /// + /// See for more details. + fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + match args.is_empty() { + true => + { + #[allow(deprecated)] + self.invoke_no_args(number_rows) + } + false => + { + #[allow(deprecated)] + self.invoke(args) + } + } + } + + /// Invoke the function returning the appropriate result. 
/// /// # Performance /// - /// For the best performance, the implementations of `invoke` should handle - /// the common case when one or more of their arguments are constant values - /// (aka [`ColumnarValue::Scalar`]). + /// For the best performance, the implementations should handle the common case + /// when one or more of their arguments are constant values (aka + /// [`ColumnarValue::Scalar`]). /// /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments /// to arrays, which will likely be simpler code, but be slower. - /// - /// [invoke_no_args]: ScalarUDFImpl::invoke_no_args - fn invoke(&self, _args: &[ColumnarValue]) -> Result; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + self.invoke_batch(&args.args, args.number_rows) + } /// Invoke the function without `args`, instead the number of rows are provided, /// returning the appropriate result. + /// + /// Note: This method is deprecated and will be removed in future releases. + /// User defined functions should implement [`Self::invoke_with_args`] instead. + #[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")] fn invoke_no_args(&self, _number_rows: usize) -> Result { not_impl_err!( "Function {} does not implement invoke_no_args but called", @@ -571,10 +654,30 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { Ok(Some(vec![])) } - /// Calculates the [`SortProperties`] of this function based on its - /// children's properties. - fn output_ordering(&self, _inputs: &[ExprProperties]) -> Result { - Ok(SortProperties::Unordered) + /// Calculates the [`SortProperties`] of this function based on its children's properties. + fn output_ordering(&self, inputs: &[ExprProperties]) -> Result { + if !self.preserves_lex_ordering(inputs)? { + return Ok(SortProperties::Unordered); + } + + let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else { + return Ok(SortProperties::Singleton); + }; + + if inputs + .iter() + .skip(1) + .all(|input| &input.sort_properties == first_order) + { + Ok(*first_order) + } else { + Ok(SortProperties::Unordered) + } + } + + /// Whether the function preserves lexicographical ordering based on the input ordering + fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result { + Ok(false) } /// Coerce arguments of a function call to types that the function can evaluate. 
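
A compact sketch of the invocation path this diff migrates to: the UDF implements `invoke_with_args` and receives the evaluated arguments, row count, and return type bundled in `ScalarFunctionArgs`, while the deprecated `invoke`/`invoke_no_args` defaults route through `invoke_batch`. The `AddOne` type is illustrative (it mirrors the doc example above) and is not part of the diff.

```rust
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, Result};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};

#[derive(Debug)]
struct AddOne {
    signature: Signature,
}

impl AddOne {
    fn new() -> Self {
        Self {
            signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for AddOne {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "add_one"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int32)
    }
    // The new entry point: arguments, row count and return type arrive
    // bundled in `ScalarFunctionArgs` instead of a bare `&[ColumnarValue]`.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let arrays = ColumnarValue::values_to_arrays(&args.args)?;
        let Some(input) = arrays[0].as_any().downcast_ref::<Int32Array>() else {
            return exec_err!("add_one expects a single Int32 argument");
        };
        let result: Int32Array = input.iter().map(|v| v.map(|v| v + 1)).collect();
        Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
    }
}
```
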
@@ -693,12 +796,13 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type_from_exprs(args, schema, arg_types) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - self.inner.invoke(args) - } - - fn invoke_no_args(&self, number_rows: usize) -> Result { - self.inner.invoke_no_args(number_rows) + fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + #[allow(deprecated)] + self.inner.invoke_batch(args, number_rows) } fn simplify( @@ -729,6 +833,10 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.output_ordering(inputs) } + fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result { + self.inner.preserves_lex_ordering(inputs) + } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { self.inner.coerce_types(arg_types) } @@ -773,6 +881,22 @@ pub mod scalar_doc_sections { ] } + pub const fn doc_sections_const() -> &'static [DocSection] { + &[ + DOC_SECTION_MATH, + DOC_SECTION_CONDITIONAL, + DOC_SECTION_STRING, + DOC_SECTION_BINARY_STRING, + DOC_SECTION_REGEX, + DOC_SECTION_DATETIME, + DOC_SECTION_ARRAY, + DOC_SECTION_STRUCT, + DOC_SECTION_MAP, + DOC_SECTION_HASHING, + DOC_SECTION_OTHER, + ] + } + pub const DOC_SECTION_MATH: DocSection = DocSection { include: true, label: "Math Functions", diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 69f357d48f8c1..39e1e8f261a21 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -28,17 +28,27 @@ use std::{ use arrow::datatypes::{DataType, Field}; -use datafusion_common::{not_impl_err, Result}; -use datafusion_functions_window_common::field::WindowUDFFieldArgs; - use crate::expr::WindowFunction; use crate::{ - function::WindowFunctionSimplification, Documentation, Expr, PartitionEvaluator, - Signature, + function::WindowFunctionSimplification, Expr, PartitionEvaluator, Signature, }; +use datafusion_common::{not_impl_err, Result}; +use datafusion_doc::Documentation; +use datafusion_functions_window_common::expr::ExpressionArgs; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -/// Logical representation of a user-defined window function (UDWF) -/// A UDWF is different from a UDF in that it is stateful across batches. +/// Logical representation of a user-defined window function (UDWF). +/// +/// A Window Function is called via the SQL `OVER` clause: +/// +/// ```sql +/// SELECT first_value(col) OVER (PARTITION BY a, b ORDER BY c) FROM foo; +/// ``` +/// +/// A UDWF is different from a user defined function (UDF) in that it is +/// stateful across batches. /// /// See the documentation on [`PartitionEvaluator`] for more details /// @@ -149,9 +159,18 @@ impl WindowUDF { self.inner.simplify() } + /// Expressions that are passed to the [`PartitionEvaluator`]. + /// + /// See [`WindowUDFImpl::expressions`] for more details. + pub fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + self.inner.expressions(expr_args) + } /// Return a `PartitionEvaluator` for evaluating this window function - pub fn partition_evaluator_factory(&self) -> Result> { - self.inner.partition_evaluator() + pub fn partition_evaluator_factory( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + self.inner.partition_evaluator(partition_evaluator_args) } /// Returns the field of the final result of evaluating this window function. 
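To show the shape of the new `expressions` hook and the `PartitionEvaluatorArgs`-based `partition_evaluator` API above, here is a minimal sketch (not part of this patch) of a window UDF that forwards only its first argument to its `PartitionEvaluator`; the `FirstArgOnly` type, its output field type, and its constructor are hypothetical, and the evaluator body is left `unimplemented!()` in the style of the doc examples in this patch.

```rust
use std::any::Any;
use std::sync::Arc;

use arrow::datatypes::{DataType, Field};
use datafusion_common::Result;
use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl};
use datafusion_functions_window_common::expr::ExpressionArgs;
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

/// Hypothetical UDWF that only hands its first argument to the evaluator.
#[derive(Debug)]
struct FirstArgOnly {
    signature: Signature,
}

impl FirstArgOnly {
    fn new() -> Self {
        Self {
            // Accept a single argument of any type.
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl WindowUDFImpl for FirstArgOnly {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "first_arg_only"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    // Forward only the first input expression to the PartitionEvaluator,
    // dropping any additional arguments.
    fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
        expr_args
            .input_exprs()
            .first()
            .map_or(vec![], |expr| vec![Arc::clone(expr)])
    }
    // The evaluator itself is omitted from this sketch.
    fn partition_evaluator(
        &self,
        _partition_evaluator_args: PartitionEvaluatorArgs,
    ) -> Result<Box<dyn PartitionEvaluator>> {
        unimplemented!()
    }
    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
        // Output type is assumed to be Int64 purely for illustration.
        Ok(Field::new(field_args.name(), DataType::Int64, true))
    }
}
```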
@@ -218,8 +237,9 @@ where /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt, Documentation}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; -/// # use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; /// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// # use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; /// /// #[derive(Debug, Clone)] /// struct SmoothIt { @@ -238,13 +258,9 @@ where /// /// fn get_doc() -> &'static Documentation { /// DOCUMENTATION.get_or_init(|| { -/// Documentation::builder() -/// .with_doc_section(DOC_SECTION_ANALYTICAL) -/// .with_description("smooths the windows") -/// .with_syntax_example("smooth_it(2)") +/// Documentation::builder(DOC_SECTION_ANALYTICAL, "smooths the windows", "smooth_it(2)") /// .with_argument("arg1", "The int32 number to smooth by") /// .build() -/// .unwrap() /// }) /// } /// @@ -254,7 +270,12 @@ where /// fn name(&self) -> &str { "smooth_it" } /// fn signature(&self) -> &Signature { &self.signature } /// // The actual implementation would smooth the window -/// fn partition_evaluator(&self) -> Result> { unimplemented!() } +/// fn partition_evaluator( +/// &self, +/// _partition_evaluator_args: PartitionEvaluatorArgs, +/// ) -> Result> { +/// unimplemented!() +/// } /// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { /// if let Some(DataType::Int32) = field_args.get_input_type(0) { /// Ok(Field::new(field_args.name(), DataType::Int32, false)) @@ -293,8 +314,16 @@ pub trait WindowUDFImpl: Debug + Send + Sync { /// types are accepted and the function's Volatility. fn signature(&self) -> &Signature; + /// Returns the expressions that are passed to the [`PartitionEvaluator`]. + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + expr_args.input_exprs().into() + } + /// Invoke the function, returning the [`PartitionEvaluator`] instance - fn partition_evaluator(&self) -> Result>; + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result>; /// Returns any aliases (alternate names) for this function. /// @@ -315,7 +344,7 @@ pub trait WindowUDFImpl: Debug + Send + Sync { /// optimizations manually for specific UDFs. 
/// /// Example: - /// [`simplify_udwf_expression.rs`]: + /// [`advanced_udwf.rs`]: /// /// # Returns /// [None] if simplify is not defined or, @@ -468,8 +497,18 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { self.inner.signature() } - fn partition_evaluator(&self) -> Result> { - self.inner.partition_evaluator() + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + expr_args + .input_exprs() + .first() + .map_or(vec![], |expr| vec![Arc::clone(expr)]) + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + self.inner.partition_evaluator(partition_evaluator_args) } fn aliases(&self) -> &[String] { @@ -514,7 +553,7 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { // Window UDF doc sections for use in public documentation pub mod window_doc_sections { - use crate::DocSection; + use datafusion_doc::DocSection; pub fn doc_sections() -> Vec { vec![ @@ -550,6 +589,7 @@ mod test { use datafusion_common::Result; use datafusion_expr_common::signature::{Signature, Volatility}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; + use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use std::any::Any; use std::cmp::Ordering; @@ -581,7 +621,10 @@ mod test { fn signature(&self) -> &Signature { &self.signature } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { unimplemented!() } fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { @@ -617,7 +660,10 @@ mod test { fn signature(&self) -> &Signature { &self.signature } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { unimplemented!() } fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 02b36d0feab94..049926fb0bcd6 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -18,7 +18,7 @@ //! 
Expression utilities use std::cmp::Ordering; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashSet}; use std::ops::Deref; use std::sync::Arc; @@ -29,14 +29,14 @@ use crate::{ }; use datafusion_expr_common::signature::{Signature, TypeSignature}; -use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::utils::get_at_indices; use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, - DataFusionError, Result, TableReference, + internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, HashMap, + Result, TableReference, }; use indexmap::IndexSet; @@ -67,7 +67,7 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { "Invalid group by expressions, GroupingSet must be the only expression" ); } - // Groupings sets have an additional interal column for the grouping id + // Groupings sets have an additional integral column for the grouping id Ok(grouping_set.distinct_expr().len() + 1) } else { grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len()) @@ -205,7 +205,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { if !has_grouping_set || group_expr.len() == 1 { return Ok(group_expr); } - // only process mix grouping sets + // Only process mix grouping sets let partial_sets = group_expr .iter() .map(|expr| { @@ -234,7 +234,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { }) .collect::>>()?; - // cross join + // Cross Join let grouping_sets = partial_sets .into_iter() .map(Ok) @@ -342,7 +342,7 @@ fn get_excluded_columns( // Excluded columns should be unique let n_elem = idents.len(); let unique_idents = idents.into_iter().collect::>(); - // if HashSet size, and vector length are different, this means that some of the excluded columns + // If HashSet size, and vector length are different, this means that some of the excluded columns // are not unique. In this case return error. if n_elem != unique_idents.len() { return plan_err!("EXCLUDE or EXCEPT contains duplicate column names"); @@ -379,14 +379,12 @@ fn get_exprs_except_skipped( } } -/// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s. -pub fn expand_wildcard( - schema: &DFSchema, - plan: &LogicalPlan, - wildcard_options: Option<&WildcardOptions>, -) -> Result> { +/// For each column specified in the USING JOIN condition, the JOIN plan outputs it twice +/// (once for each join side), but an unqualified wildcard should include it only once. +/// This function returns the columns that should be excluded. 
+fn exclude_using_columns(plan: &LogicalPlan) -> Result> { let using_columns = plan.using_columns()?; - let mut columns_to_skip = using_columns + let excluded = using_columns .into_iter() // For each USING JOIN condition, only expand to one of each join column in projection .flat_map(|cols| { @@ -395,18 +393,26 @@ pub fn expand_wildcard( // qualified column cols.sort(); let mut out_column_names: HashSet = HashSet::new(); - cols.into_iter() - .filter_map(|c| { - if out_column_names.contains(&c.name) { - Some(c) - } else { - out_column_names.insert(c.name); - None - } - }) - .collect::>() + cols.into_iter().filter_map(move |c| { + if out_column_names.contains(&c.name) { + Some(c) + } else { + out_column_names.insert(c.name); + None + } + }) }) .collect::>(); + Ok(excluded) +} + +/// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s. +pub fn expand_wildcard( + schema: &DFSchema, + plan: &LogicalPlan, + wildcard_options: Option<&WildcardOptions>, +) -> Result> { + let mut columns_to_skip = exclude_using_columns(plan)?; let excluded_columns = if let Some(WildcardOptions { exclude: opt_exclude, except: opt_except, @@ -437,7 +443,10 @@ pub fn expand_qualified_wildcard( return plan_err!("Invalid qualifier {qualifier}"); } - let qualified_schema = Arc::new(Schema::new(fields_with_qualified)); + let qualified_schema = Arc::new(Schema::new_with_metadata( + fields_with_qualified, + schema.metadata().clone(), + )); let qualified_dfschema = DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)? .with_functional_dependencies(projected_func_dependencies)?; @@ -466,7 +475,7 @@ pub fn expand_qualified_wildcard( } /// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)") -/// if bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column +/// If bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column type WindowSortKey = Vec<(Sort, bool)>; /// Generate a sort key for a given window expr's partition_by and order_by expr @@ -573,7 +582,7 @@ pub fn compare_sort_expr( Ordering::Equal } -/// group a slice of window expression expr by their order by expressions +/// Group a slice of window expression expr by their order by expressions pub fn group_window_expr_by_sort_keys( window_expr: Vec, ) -> Result)>> { @@ -600,7 +609,7 @@ pub fn group_window_expr_by_sort_keys( /// Collect all deeply nested `Expr::AggregateFunction`. /// They are returned in order of occurrence (depth /// first), with duplicates omitted. -pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { +pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { matches!(nested_expr, Expr::AggregateFunction { .. }) }) @@ -625,12 +634,15 @@ pub fn find_out_reference_exprs(expr: &Expr) -> Vec { /// Search the provided `Expr`'s, and all of their nested `Expr`, for any that /// pass the provided test. The returned `Expr`'s are deduplicated and returned /// in order of appearance (depth first). 
-fn find_exprs_in_exprs(exprs: &[Expr], test_fn: &F) -> Vec +fn find_exprs_in_exprs<'a, F>( + exprs: impl IntoIterator, + test_fn: &F, +) -> Vec where F: Fn(&Expr) -> bool, { exprs - .iter() + .into_iter() .flat_map(|expr| find_exprs_in_expr(expr, test_fn)) .fold(vec![], |mut acc, expr| { if !acc.contains(&expr) { @@ -653,7 +665,7 @@ where if !(exprs.contains(expr)) { exprs.push(expr.clone()) } - // stop recursing down this expr once we find a match + // Stop recursing down this expr once we find a match return Ok(TreeNodeRecursion::Jump); } @@ -672,7 +684,7 @@ where let mut err = Ok(()); expr.apply(|expr| { if let Err(e) = f(expr) { - // save the error for later (it may not be a DataFusionError + // Save the error for later (it may not be a DataFusionError) err = Err(e); Ok(TreeNodeRecursion::Stop) } else { @@ -691,7 +703,7 @@ pub fn exprlist_to_fields<'a>( exprs: impl IntoIterator, plan: &LogicalPlan, ) -> Result, Arc)>> { - // look for exact match in plan's output schema + // Look for exact match in plan's output schema let wildcard_schema = find_base_plan(plan).schema(); let input_schema = plan.schema(); let result = exprs @@ -699,27 +711,20 @@ pub fn exprlist_to_fields<'a>( .map(|e| match e { Expr::Wildcard { qualifier, options } => match qualifier { None => { - let excluded: Vec = get_excluded_columns( + let mut excluded = exclude_using_columns(plan)?; + excluded.extend(get_excluded_columns( options.exclude.as_ref(), options.except.as_ref(), wildcard_schema, None, - )? - .into_iter() - .map(|c| c.flat_name()) - .collect(); - Ok::<_, DataFusionError>( - wildcard_schema - .field_names() - .iter() - .enumerate() - .filter(|(_, s)| !excluded.contains(s)) - .map(|(i, _)| wildcard_schema.qualified_field(i)) - .map(|(qualifier, f)| { - (qualifier.cloned(), Arc::new(f.to_owned())) - }) - .collect::>(), - ) + )?); + Ok(wildcard_schema + .iter() + .filter(|(q, f)| { + !excluded.contains(&Column::new(q.cloned(), f.name())) + }) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .collect::>()) } Some(qualifier) => { let excluded: Vec = get_excluded_columns( @@ -950,9 +955,9 @@ pub(crate) fn find_column_indexes_referenced_by_expr( indexes } -/// can this data type be used in hash join equal conditions?? -/// data types here come from function 'equal_rows', if more data types are supported -/// in equal_rows(hash join), add those data types here to generate join logical plan. +/// Can this data type be used in hash join equal conditions?? +/// Data types here come from function 'equal_rows', if more data types are supported +/// in create_hashes, add those data types here to generate join logical plan. 
pub fn can_hash(data_type: &DataType) -> bool { match data_type { DataType::Null => true, @@ -965,30 +970,38 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::UInt16 => true, DataType::UInt32 => true, DataType::UInt64 => true, + DataType::Float16 => true, DataType::Float32 => true, DataType::Float64 => true, - DataType::Timestamp(time_unit, _) => match time_unit { - TimeUnit::Second => true, - TimeUnit::Millisecond => true, - TimeUnit::Microsecond => true, - TimeUnit::Nanosecond => true, - }, + DataType::Decimal128(_, _) => true, + DataType::Decimal256(_, _) => true, + DataType::Timestamp(_, _) => true, DataType::Utf8 => true, DataType::LargeUtf8 => true, - DataType::Decimal128(_, _) => true, + DataType::Utf8View => true, + DataType::Binary => true, + DataType::LargeBinary => true, + DataType::BinaryView => true, DataType::Date32 => true, DataType::Date64 => true, + DataType::Time32(_) => true, + DataType::Time64(_) => true, + DataType::Duration(_) => true, + DataType::Interval(_) => true, DataType::FixedSizeBinary(_) => true, - DataType::Dictionary(key_type, value_type) - if *value_type.as_ref() == DataType::Utf8 => - { - DataType::is_dictionary_key_type(key_type) + DataType::Dictionary(key_type, value_type) => { + DataType::is_dictionary_key_type(key_type) && can_hash(value_type) } - DataType::List(_) => true, - DataType::LargeList(_) => true, - DataType::FixedSizeList(_, _) => true, + DataType::List(value_type) => can_hash(value_type.data_type()), + DataType::LargeList(value_type) => can_hash(value_type.data_type()), + DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()), + DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()), DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())), - _ => false, + + DataType::ListView(_) + | DataType::LargeListView(_) + | DataType::Union(_, _) + | DataType::RunEndEncoded(_, _) => false, } } @@ -1098,6 +1111,54 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& } } +/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// See [`split_conjunction_owned`] for more details and an example. +pub fn iter_conjunction(expr: &Expr) -> impl Iterator { + let mut stack = vec![expr]; + std::iter::from_fn(move || { + while let Some(expr) = stack.pop() { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + stack.push(right); + stack.push(left); + } + Expr::Alias(Alias { expr, .. }) => stack.push(expr), + other => return Some(other), + } + } + None + }) +} + +/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// See [`split_conjunction_owned`] for more details and an example. +pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator { + let mut stack = vec![expr]; + std::iter::from_fn(move || { + while let Some(expr) = stack.pop() { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + stack.push(*right); + stack.push(*left); + } + Expr::Alias(Alias { expr, .. 
}) => stack.push(*expr), + other => return Some(other), + } + } + None + }) +} + /// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// This is often used to "split" filter expressions such as `col1 = 5 @@ -1239,7 +1300,7 @@ pub fn conjunction(filters: impl IntoIterator) -> Option { /// col("b").eq(lit(2)), /// ]; /// -/// // use disjuncton to join them together with `OR` +/// // use disjunction to join them together with `OR` /// assert_eq!(disjunction(split), Some(expr)); /// ``` pub fn disjunction(filters: impl IntoIterator) -> Option { @@ -1340,14 +1401,33 @@ pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{name}[{state_name}]") } +/// Determine the set of [`Column`]s produced by the subquery. +pub fn collect_subquery_cols( + exprs: &[Expr], + subquery_schema: &DFSchema, +) -> Result> { + exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { + let mut using_cols: Vec = vec![]; + for col in expr.column_refs().into_iter() { + if subquery_schema.has_column(col) { + using_cols.push(col.clone()); + } + } + + cols.extend(using_cols); + Result::<_>::Ok(cols) + }) +} + #[cfg(test)] mod tests { use super::*; use crate::{ - col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, + col, cube, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::max_udaf, test::function_stub::min_udaf, test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, }; + use arrow::datatypes::{UnionFields, UnionMode}; #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { @@ -1359,19 +1439,19 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let max1 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(expr::WindowFunction::new( + let min3 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); - let sum4 = Expr::WindowFunction(expr::WindowFunction::new( + let sum4 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )); @@ -1386,28 +1466,28 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys() -> Result<()> { - let age_asc = expr::Sort::new(col("age"), true, true); - let name_desc = expr::Sort::new(col("name"), false, true); - let created_at_desc = expr::Sort::new(col("created_at"), false, true); - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let age_asc = Sort::new(col("age"), true, true); + let name_desc = Sort::new(col("name"), false, true); + let created_at_desc = Sort::new(col("created_at"), false, true); + let max1 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(expr::WindowFunction::new( + let min3 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), 
vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let sum4 = Expr::WindowFunction(expr::WindowFunction::new( + let sum4 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )) @@ -1750,4 +1830,21 @@ mod tests { assert!(accum.contains(&Column::from_name("a"))); Ok(()) } + + #[test] + fn test_can_hash() { + let union_fields: UnionFields = [ + (0, Arc::new(Field::new("A", DataType::Int32, true))), + (1, Arc::new(Field::new("B", DataType::Float64, true))), + ] + .into_iter() + .collect(); + + let union_type = DataType::Union(union_fields, UnionMode::Sparse); + assert!(!can_hash(&union_type)); + + let list_union_type = + DataType::List(Arc::new(Field::new("my_union", union_type, true))); + assert!(!can_hash(&list_union_type)); + } } diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index b2e8268aa332c..815d5742afd23 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,11 +23,11 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. +use crate::{expr::Sort, lit}; +use arrow::datatypes::DataType; use std::fmt::{self, Formatter}; use std::hash::Hash; -use crate::{expr::Sort, lit}; - use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; use sqlparser::parser::ParserError::ParserError; @@ -94,7 +94,7 @@ pub struct WindowFrame { } impl fmt::Display for WindowFrame { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!( f, "{} BETWEEN {} AND {}", @@ -119,9 +119,9 @@ impl TryFrom for WindowFrame { type Error = DataFusionError; fn try_from(value: ast::WindowFrame) -> Result { - let start_bound = value.start_bound.try_into()?; + let start_bound = WindowFrameBound::try_parse(value.start_bound, &value.units)?; let end_bound = match value.end_bound { - Some(value) => value.try_into()?, + Some(bound) => WindowFrameBound::try_parse(bound, &value.units)?, None => WindowFrameBound::CurrentRow, }; @@ -138,6 +138,7 @@ impl TryFrom for WindowFrame { )? } }; + let units = value.units.into(); Ok(Self::new_bounds(units, start_bound, end_bound)) } @@ -273,7 +274,7 @@ impl WindowFrame { Ok(()) } - /// Returns whether the window frame can accept multiple ORDER BY expressons. + /// Returns whether the window frame can accept multiple ORDER BY expressions. pub fn can_accept_multi_orderby(&self) -> bool { match self.units { WindowFrameUnits::Rows => true, @@ -334,17 +335,18 @@ impl WindowFrameBound { } } -impl TryFrom for WindowFrameBound { - type Error = DataFusionError; - - fn try_from(value: ast::WindowFrameBound) -> Result { +impl WindowFrameBound { + fn try_parse( + value: ast::WindowFrameBound, + units: &ast::WindowFrameUnits, + ) -> Result { Ok(match value { ast::WindowFrameBound::Preceding(Some(v)) => { - Self::Preceding(convert_frame_bound_to_scalar_value(*v)?) + Self::Preceding(convert_frame_bound_to_scalar_value(*v, units)?) } ast::WindowFrameBound::Preceding(None) => Self::Preceding(ScalarValue::Null), ast::WindowFrameBound::Following(Some(v)) => { - Self::Following(convert_frame_bound_to_scalar_value(*v)?) + Self::Following(convert_frame_bound_to_scalar_value(*v, units)?) 
} ast::WindowFrameBound::Following(None) => Self::Following(ScalarValue::Null), ast::WindowFrameBound::CurrentRow => Self::CurrentRow, @@ -352,37 +354,69 @@ impl TryFrom for WindowFrameBound { } } -pub fn convert_frame_bound_to_scalar_value(v: ast::Expr) -> Result { - Ok(ScalarValue::Utf8(Some(match v { - ast::Expr::Value(ast::Value::Number(value, false)) - | ast::Expr::Value(ast::Value::SingleQuotedString(value)) => value, - ast::Expr::Interval(ast::Interval { - value, - leading_field, - .. - }) => { - let result = match *value { - ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, - e => { - return sql_err!(ParserError(format!( - "INTERVAL expression cannot be {e:?}" - ))); +fn convert_frame_bound_to_scalar_value( + v: ast::Expr, + units: &ast::WindowFrameUnits, +) -> Result { + match units { + // For ROWS and GROUPS we are sure that the ScalarValue must be a non-negative integer ... + ast::WindowFrameUnits::Rows | ast::WindowFrameUnits::Groups => match v { + ast::Expr::Value(ast::Value::Number(value, false)) => { + Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) + }, + ast::Expr::Interval(ast::Interval { + value, + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }) => { + let value = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + return sql_err!(ParserError(format!( + "INTERVAL expression cannot be {e:?}" + ))); + } + }; + Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) + } + _ => plan_err!( + "Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers" + ), + }, + // ... instead for RANGE it could be anything depending on the type of the ORDER BY clause, + // so we use a ScalarValue::Utf8. + ast::WindowFrameUnits::Range => Ok(ScalarValue::Utf8(Some(match v { + ast::Expr::Value(ast::Value::Number(value, false)) => value, + ast::Expr::Interval(ast::Interval { + value, + leading_field, + .. 
+ }) => { + let result = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + return sql_err!(ParserError(format!( + "INTERVAL expression cannot be {e:?}" + ))); + } + }; + if let Some(leading_field) = leading_field { + format!("{result} {leading_field}") + } else { + result } - }; - if let Some(leading_field) = leading_field { - format!("{result} {leading_field}") - } else { - result } - } - _ => plan_err!( - "Invalid window frame: frame offsets must be non negative integers" - )?, - }))) + _ => plan_err!( + "Invalid window frame: frame offsets for RANGE must be either a numeric value, a string value or an interval" + )?, + }))), + } } impl fmt::Display for WindowFrameBound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { WindowFrameBound::Preceding(n) => { if n.is_null() { @@ -423,7 +457,7 @@ pub enum WindowFrameUnits { } impl fmt::Display for WindowFrameUnits { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.write_str(match self { WindowFrameUnits::Rows => "ROWS", WindowFrameUnits::Range => "RANGE", @@ -479,8 +513,91 @@ mod tests { ast::Expr::Value(ast::Value::Number("1".to_string(), false)), )))), }; - let result = WindowFrame::try_from(window_frame); - assert!(result.is_ok()); + + let window_frame = WindowFrame::try_from(window_frame)?; + assert_eq!(window_frame.units, WindowFrameUnits::Rows); + assert_eq!( + window_frame.start_bound, + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))) + ); + assert_eq!( + window_frame.end_bound, + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))) + ); + + Ok(()) + } + + macro_rules! test_bound { + ($unit:ident, $value:expr, $expected:expr) => { + let preceding = WindowFrameBound::try_parse( + ast::WindowFrameBound::Preceding($value), + &ast::WindowFrameUnits::$unit, + )?; + assert_eq!(preceding, WindowFrameBound::Preceding($expected)); + let following = WindowFrameBound::try_parse( + ast::WindowFrameBound::Following($value), + &ast::WindowFrameUnits::$unit, + )?; + assert_eq!(following, WindowFrameBound::Following($expected)); + }; + } + + macro_rules! 
test_bound_err { + ($unit:ident, $value:expr, $expected:expr) => { + let err = WindowFrameBound::try_parse( + ast::WindowFrameBound::Preceding($value), + &ast::WindowFrameUnits::$unit, + ) + .unwrap_err(); + assert_eq!(err.strip_backtrace(), $expected); + let err = WindowFrameBound::try_parse( + ast::WindowFrameBound::Following($value), + &ast::WindowFrameUnits::$unit, + ) + .unwrap_err(); + assert_eq!(err.strip_backtrace(), $expected); + }; + } + + #[test] + fn test_window_frame_bound_creation() -> Result<()> { + // Unbounded + test_bound!(Rows, None, ScalarValue::Null); + test_bound!(Groups, None, ScalarValue::Null); + test_bound!(Range, None, ScalarValue::Null); + + // Number + let number = Some(Box::new(ast::Expr::Value(ast::Value::Number( + "42".to_string(), + false, + )))); + test_bound!(Rows, number.clone(), ScalarValue::UInt64(Some(42))); + test_bound!(Groups, number.clone(), ScalarValue::UInt64(Some(42))); + test_bound!( + Range, + number.clone(), + ScalarValue::Utf8(Some("42".to_string())) + ); + + // Interval + let number = Some(Box::new(ast::Expr::Interval(ast::Interval { + value: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + "1".to_string(), + ))), + leading_field: Some(ast::DateTimeField::Day), + fractional_seconds_precision: None, + last_field: None, + leading_precision: None, + }))); + test_bound_err!(Rows, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"); + test_bound_err!(Groups, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"); + test_bound!( + Range, + number.clone(), + ScalarValue::Utf8(Some("1 DAY".to_string())) + ); + Ok(()) } } diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs deleted file mode 100644 index a80718147c3a4..0000000000000 --- a/datafusion/expr/src/window_function.rs +++ /dev/null @@ -1,91 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use datafusion_common::ScalarValue; - -use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; - -/// Create an expression to represent the `rank` window function -pub fn rank() -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Rank, vec![])) -} - -/// Create an expression to represent the `dense_rank` window function -pub fn dense_rank() -> Expr { - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::DenseRank, - vec![], - )) -} - -/// Create an expression to represent the `percent_rank` window function -pub fn percent_rank() -> Expr { - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::PercentRank, - vec![], - )) -} - -/// Create an expression to represent the `cume_dist` window function -pub fn cume_dist() -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::CumeDist, vec![])) -} - -/// Create an expression to represent the `ntile` window function -pub fn ntile(arg: Expr) -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg])) -} - -/// Create an expression to represent the `lag` window function -pub fn lag( - arg: Expr, - shift_offset: Option, - default_value: Option, -) -> Expr { - let shift_offset_lit = shift_offset - .map(|v| v.lit()) - .unwrap_or(ScalarValue::Null.lit()); - let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::Lag, - vec![arg, shift_offset_lit, default_lit], - )) -} - -/// Create an expression to represent the `lead` window function -pub fn lead( - arg: Expr, - shift_offset: Option, - default_value: Option, -) -> Expr { - let shift_offset_lit = shift_offset - .map(|v| v.lit()) - .unwrap_or(ScalarValue::Null.lit()); - let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::Lead, - vec![arg, shift_offset_lit, default_lit], - )) -} - -/// Create an expression to represent the `nth_value` window function -pub fn nth_value(arg: Expr, n: i64) -> Expr { - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::NthValue, - vec![arg, n.lit()], - )) -} diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index e7f31bbfbf2bd..f1d0ead23ab19 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -48,7 +48,7 @@ pub struct WindowAggState { /// Keeps track of how many rows should be generated to be in sync with input record_batch. // (For each row in the input record batch we need to generate a window result). pub n_row_result_missing: usize, - /// flag indicating whether we have received all data for this partition + /// Flag indicating whether we have received all data for this partition pub is_end: bool, } diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml new file mode 100644 index 0000000000000..a0179ec44d7f0 --- /dev/null +++ b/datafusion/ffi/Cargo.toml @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-ffi" +description = "Foreign Function Interface implementation for DataFusion" +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_ffi" +path = "src/lib.rs" + +[dependencies] +abi_stable = "0.11.3" +arrow = { workspace = true, features = ["ffi"] } +async-ffi = { version = "0.5.0", features = ["abi_stable"] } +async-trait = { workspace = true } +datafusion = { workspace = true, default-features = false } +datafusion-proto = { workspace = true } +futures = { workspace = true } +log = { workspace = true } +prost = { workspace = true } + +[dev-dependencies] +doc-comment = { workspace = true } +tokio = { workspace = true } diff --git a/datafusion/ffi/LICENSE.txt b/datafusion/ffi/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/ffi/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/ffi/NOTICE.txt b/datafusion/ffi/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/ffi/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/ffi/README.md b/datafusion/ffi/README.md new file mode 100644 index 0000000000000..48283f4cfdc14 --- /dev/null +++ b/datafusion/ffi/README.md @@ -0,0 +1,112 @@ + + +# `datafusion-ffi`: Apache DataFusion Foreign Function Interface + +This crate contains code to allow interoperability of Apache [DataFusion] with +functions from other libraries and/or [DataFusion] versions using a stable +interface. + +One of the limitations of the Rust programming language is that there is no +stable [Rust ABI] (Application Binary Interface). If a library is compiled with +one version of the Rust compiler and you attempt to use that library with a +program compiled by a different Rust compiler, there is no guarantee that you +can access the data structures. In order to share code between libraries loaded +at runtime, you need to use Rust's Foreign Function Interface ([FFI]). + +The purpose of this crate is to define interfaces between [DataFusion] libraries +that will remain stable across different versions of [DataFusion]. This allows +users to write libraries that can interface between each other at runtime rather +than require compiling all of the code into a single executable. + +In general, it is recommended to run the same version of DataFusion for both the +producer and consumer of the data and functions shared across the [FFI], but +this is not strictly required. + +See [API Docs] for details and examples. + +## Use Cases + +Two use cases have been identified for this crate, but they are not intended to +be all-inclusive. + +1. `datafusion-python` which will use the FFI to provide external services such + as a `TableProvider` without needing to re-export the entire `datafusion-python` + code base.
With `datafusion-ffi` these packages do not need `datafusion-python` + as a dependency at all. +2. Users may want to create a modular interface that allows runtime loading of + libraries. For example, you may wish to design a program that only uses the + built-in table sources, but also allows for extension from the community-led + [datafusion-contrib] repositories. You could enable module loading so that + users could at runtime load a library to access additional data sources. + Alternatively, you could use this approach so that customers could interface + with their own proprietary data sources. + +## Limitations + +One limitation of the approach in this crate is that it is designed specifically +to work across Rust libraries. In general, you can use Rust's [FFI] to +operate across different programming languages, but that is not the design +intent of this crate. Instead, we are using external crates that provide +stable interfaces that closely mirror the Rust native approach. To learn more +about this approach, see the [abi_stable] and [async-ffi] crates. + +If you have a library in another language that you wish to interface to +[DataFusion], the recommendation is to create a Rust wrapper crate to interface +with your library and then to connect it to [DataFusion] using this crate. +Alternatively, you could use [bindgen] to interface directly to the [FFI] provided +by this crate, but that is currently not supported. + +## FFI Boundary + +We expect this crate to be used by both sides of the FFI Boundary. This should +provide ergonomic ways to both produce and consume structs and functions across +this layer. + +For example, if you have a library that provides a custom `TableProvider`, you +can expose it by using `FFI_TableProvider::new()`. When you need to consume a +`FFI_TableProvider`, you can access it by converting it using +`ForeignTableProvider::from()`, which will create a struct that implements +`TableProvider`. + +There is a complete end-to-end demonstration in the +[examples](https://github.com/apache/datafusion/tree/main/datafusion-examples/examples/ffi). + +## Asynchronous Calls + +Some of the functions within this crate require asynchronous operation. These +will perform similarly to their pure Rust counterparts by using the [async-ffi] +crate. In general, any call to an asynchronous function in this interface will +not block the rest of the program's execution. + +## Struct Layout + +In this crate we have a variety of structs which closely mimic the behavior of +their internal counterparts. To see detailed notes about how to use them, see +the example in `FFI_TableProvider`. + +[datafusion]: https://datafusion.apache.org +[api docs]: http://docs.rs/datafusion-ffi/latest +[rust abi]: https://doc.rust-lang.org/reference/abi.html +[ffi]: https://doc.rust-lang.org/nomicon/ffi.html +[abi_stable]: https://crates.io/crates/abi_stable +[async-ffi]: https://crates.io/crates/async-ffi +[bindgen]: https://crates.io/crates/bindgen +[datafusion-python]: https://datafusion.apache.org/python/ +[datafusion-contrib]: https://github.com/datafusion-contrib diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs new file mode 100644 index 0000000000000..c5add8782c51b --- /dev/null +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use abi_stable::StableAbi; +use arrow::{ + datatypes::{Schema, SchemaRef}, + ffi::{FFI_ArrowArray, FFI_ArrowSchema}, +}; +use log::error; + +/// This is a wrapper struct around FFI_ArrowSchema simply to indicate +/// to the StableAbi macros that the underlying struct is FFI safe. +#[repr(C)] +#[derive(Debug, StableAbi)] +pub struct WrappedSchema(#[sabi(unsafe_opaque_field)] pub FFI_ArrowSchema); + +impl From for WrappedSchema { + fn from(value: SchemaRef) -> Self { + let ffi_schema = match FFI_ArrowSchema::try_from(value.as_ref()) { + Ok(s) => s, + Err(e) => { + error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {}", e); + FFI_ArrowSchema::empty() + } + }; + + WrappedSchema(ffi_schema) + } +} + +impl From for SchemaRef { + fn from(value: WrappedSchema) -> Self { + let schema = match Schema::try_from(&value.0) { + Ok(s) => s, + Err(e) => { + error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {}", e); + Schema::empty() + } + }; + Arc::new(schema) + } +} + +/// This is a wrapper struct for FFI_ArrowArray to indicate to StableAbi +/// that the struct is FFI Safe. For convenience, we also include the +/// schema needed to create a record batch from the array. +#[repr(C)] +#[derive(Debug, StableAbi)] +pub struct WrappedArray { + #[sabi(unsafe_opaque_field)] + pub array: FFI_ArrowArray, + + pub schema: WrappedSchema, +} diff --git a/datafusion/ffi/src/execution_plan.rs b/datafusion/ffi/src/execution_plan.rs new file mode 100644 index 0000000000000..5ab321cc0114f --- /dev/null +++ b/datafusion/ffi/src/execution_plan.rs @@ -0,0 +1,368 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
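As a usage note for the `WrappedSchema` conversions defined just above, a round trip between `SchemaRef` and `WrappedSchema` might look like the following sketch (not part of this patch); the `main` wrapper and the field comparison are illustrative only, and the assertion assumes the conversion succeeds rather than falling back to the empty schema used on error.

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_ffi::arrow_wrappers::WrappedSchema;

fn main() {
    let schema: SchemaRef =
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // SchemaRef -> WrappedSchema converts through FFI_ArrowSchema so the
    // schema can cross the FFI boundary.
    let wrapped: WrappedSchema = Arc::clone(&schema).into();

    // WrappedSchema -> SchemaRef rebuilds the Arrow schema on the other side.
    let unwrapped: SchemaRef = wrapped.into();

    assert_eq!(schema.fields(), unwrapped.fields());
}
```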
+ +use std::{ffi::c_void, pin::Pin, sync::Arc}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use datafusion::error::Result; +use datafusion::{ + error::DataFusionError, + execution::{SendableRecordBatchStream, TaskContext}, + physical_plan::{DisplayAs, ExecutionPlan, PlanProperties}, +}; + +use crate::{ + plan_properties::FFI_PlanProperties, record_batch_stream::FFI_RecordBatchStream, +}; + +/// A stable struct for sharing a [`ExecutionPlan`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_ExecutionPlan { + /// Return the plan properties + pub properties: unsafe extern "C" fn(plan: &Self) -> FFI_PlanProperties, + + /// Return a vector of children plans + pub children: unsafe extern "C" fn(plan: &Self) -> RVec, + + /// Return the plan name. + pub name: unsafe extern "C" fn(plan: &Self) -> RString, + + /// Execute the plan and return a record batch stream. Errors + /// will be returned as a string. + pub execute: unsafe extern "C" fn( + plan: &Self, + partition: usize, + ) -> RResult, + + /// Used to create a clone on the provider of the execution plan. This should + /// only need to be called by the receiver of the plan. + pub clone: unsafe extern "C" fn(plan: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the plan. + /// A [`ForeignExecutionPlan`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_ExecutionPlan {} +unsafe impl Sync for FFI_ExecutionPlan {} + +pub struct ExecutionPlanPrivateData { + pub plan: Arc, + pub context: Arc, +} + +unsafe extern "C" fn properties_fn_wrapper( + plan: &FFI_ExecutionPlan, +) -> FFI_PlanProperties { + let private_data = plan.private_data as *const ExecutionPlanPrivateData; + let plan = &(*private_data).plan; + + plan.properties().into() +} + +unsafe extern "C" fn children_fn_wrapper( + plan: &FFI_ExecutionPlan, +) -> RVec { + let private_data = plan.private_data as *const ExecutionPlanPrivateData; + let plan = &(*private_data).plan; + let ctx = &(*private_data).context; + + let children: Vec<_> = plan + .children() + .into_iter() + .map(|child| FFI_ExecutionPlan::new(Arc::clone(child), Arc::clone(ctx))) + .collect(); + + children.into() +} + +unsafe extern "C" fn execute_fn_wrapper( + plan: &FFI_ExecutionPlan, + partition: usize, +) -> RResult { + let private_data = plan.private_data as *const ExecutionPlanPrivateData; + let plan = &(*private_data).plan; + let ctx = &(*private_data).context; + + match plan.execute(partition, Arc::clone(ctx)) { + Ok(rbs) => RResult::ROk(rbs.into()), + Err(e) => RResult::RErr( + format!("Error occurred during FFI_ExecutionPlan execute: {}", e).into(), + ), + } +} +unsafe extern "C" fn name_fn_wrapper(plan: &FFI_ExecutionPlan) -> RString { + let private_data = plan.private_data as *const ExecutionPlanPrivateData; + let plan = &(*private_data).plan; + + plan.name().into() +} + +unsafe extern "C" fn release_fn_wrapper(plan: &mut FFI_ExecutionPlan) { + let private_data = Box::from_raw(plan.private_data as *mut ExecutionPlanPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(plan: &FFI_ExecutionPlan) -> FFI_ExecutionPlan { + let private_data = plan.private_data as *const ExecutionPlanPrivateData; + let plan_data = &(*private_data); + + 
FFI_ExecutionPlan::new(Arc::clone(&plan_data.plan), Arc::clone(&plan_data.context)) +} + +impl Clone for FFI_ExecutionPlan { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl FFI_ExecutionPlan { + /// This function is called on the provider's side. + pub fn new(plan: Arc, context: Arc) -> Self { + let private_data = Box::new(ExecutionPlanPrivateData { plan, context }); + + Self { + properties: properties_fn_wrapper, + children: children_fn_wrapper, + name: name_fn_wrapper, + execute: execute_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_ExecutionPlan { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an execution plan provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignExecutionPlan is to be used by the caller of the plan, so it has +/// no knowledge or access to the private data. All interaction with the plan +/// must occur through the functions defined in FFI_ExecutionPlan. +#[derive(Debug)] +pub struct ForeignExecutionPlan { + name: String, + plan: FFI_ExecutionPlan, + properties: PlanProperties, + children: Vec>, +} + +unsafe impl Send for ForeignExecutionPlan {} +unsafe impl Sync for ForeignExecutionPlan {} + +impl DisplayAs for ForeignExecutionPlan { + fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!( + f, + "FFI_ExecutionPlan(number_of_children={})", + self.children.len(), + ) + } +} + +impl TryFrom<&FFI_ExecutionPlan> for ForeignExecutionPlan { + type Error = DataFusionError; + + fn try_from(plan: &FFI_ExecutionPlan) -> Result { + unsafe { + let name = (plan.name)(plan).into(); + + let properties: PlanProperties = (plan.properties)(plan).try_into()?; + + let children_rvec = (plan.children)(plan); + let children: Result> = children_rvec + .iter() + .map(ForeignExecutionPlan::try_from) + .map(|child| child.map(|c| Arc::new(c) as Arc)) + .collect(); + + Ok(Self { + name, + plan: plan.clone(), + properties, + children: children?, + }) + } + } +} + +impl ExecutionPlan for ForeignExecutionPlan { + fn name(&self) -> &str { + &self.name + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + self.children + .iter() + .map(|p| p as &Arc) + .collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(ForeignExecutionPlan { + plan: self.plan.clone(), + name: self.name.clone(), + children, + properties: self.properties.clone(), + })) + } + + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> Result { + unsafe { + match (self.plan.execute)(&self.plan, partition) { + RResult::ROk(stream) => { + let stream = Pin::new(Box::new(stream)) as SendableRecordBatchStream; + Ok(stream) + } + RResult::RErr(e) => Err(DataFusionError::Execution(format!( + "Error occurred during FFI call to FFI_ExecutionPlan execute. 
{}", + e + ))), + } + } + } +} + +#[cfg(test)] +mod tests { + use datafusion::{ + physical_plan::{ + execution_plan::{Boundedness, EmissionType}, + Partitioning, + }, + prelude::SessionContext, + }; + + use super::*; + + #[derive(Debug)] + pub struct EmptyExec { + props: PlanProperties, + } + + impl EmptyExec { + pub fn new(schema: arrow::datatypes::SchemaRef) -> Self { + Self { + props: PlanProperties::new( + datafusion::physical_expr::EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(3), + EmissionType::Incremental, + Boundedness::Bounded, + ), + } + } + } + + impl DisplayAs for EmptyExec { + fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + _f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + unimplemented!() + } + } + + impl ExecutionPlan for EmptyExec { + fn name(&self) -> &'static str { + "empty-exec" + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.props + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!() + } + + fn statistics(&self) -> Result { + unimplemented!() + } + } + + #[test] + fn test_round_trip_ffi_execution_plan() -> Result<()> { + use arrow::datatypes::{DataType, Field, Schema}; + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + let ctx = SessionContext::new(); + + let original_plan = Arc::new(EmptyExec::new(schema)); + let original_name = original_plan.name().to_string(); + + let local_plan = FFI_ExecutionPlan::new(original_plan, ctx.task_ctx()); + + let foreign_plan: ForeignExecutionPlan = (&local_plan).try_into()?; + + assert!(original_name == foreign_plan.name()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs new file mode 100644 index 0000000000000..8e09780edf03d --- /dev/null +++ b/datafusion/ffi/src/lib.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] + +pub mod arrow_wrappers; +pub mod execution_plan; +pub mod plan_properties; +pub mod record_batch_stream; +pub mod session_config; +pub mod table_provider; +pub mod table_source; + +#[cfg(doctest)] +doc_comment::doctest!("../README.md", readme_example_test); diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs new file mode 100644 index 0000000000000..3c7bc886aede2 --- /dev/null +++ b/datafusion/ffi/src/plan_properties.rs @@ -0,0 +1,352 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{ + RResult::{self, RErr, ROk}, + RStr, RVec, + }, + StableAbi, +}; +use arrow::datatypes::SchemaRef; +use datafusion::{ + error::{DataFusionError, Result}, + physical_expr::EquivalenceProperties, + physical_plan::{ + execution_plan::{Boundedness, EmissionType}, + PlanProperties, + }, + prelude::SessionContext, +}; +use datafusion_proto::{ + physical_plan::{ + from_proto::{parse_physical_sort_exprs, parse_protobuf_partitioning}, + to_proto::{serialize_partitioning, serialize_physical_sort_exprs}, + DefaultPhysicalExtensionCodec, + }, + protobuf::{Partitioning, PhysicalSortExprNodeCollection}, +}; +use prost::Message; + +use crate::arrow_wrappers::WrappedSchema; + +/// A stable struct for sharing [`PlanProperties`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_PlanProperties { + /// The output partitioning is a [`Partitioning`] protobuf message serialized + /// into bytes to pass across the FFI boundary. + pub output_partitioning: + unsafe extern "C" fn(plan: &Self) -> RResult, RStr<'static>>, + + /// Return the emission type of the plan. + pub emission_type: unsafe extern "C" fn(plan: &Self) -> FFI_EmissionType, + + /// Indicate boundedness of the plan and its memory requirements. + pub boundedness: unsafe extern "C" fn(plan: &Self) -> FFI_Boundedness, + + /// The output ordering is a [`PhysicalSortExprNodeCollection`] protobuf message + /// serialized into bytes to pass across the FFI boundary. + pub output_ordering: + unsafe extern "C" fn(plan: &Self) -> RResult, RStr<'static>>, + + /// Return the schema of the plan. + pub schema: unsafe extern "C" fn(plan: &Self) -> WrappedSchema, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the plan. + /// The foreign library should never attempt to access this data. 
+ pub private_data: *mut c_void, +} + +struct PlanPropertiesPrivateData { + props: PlanProperties, +} + +unsafe extern "C" fn output_partitioning_fn_wrapper( + properties: &FFI_PlanProperties, +) -> RResult, RStr<'static>> { + let private_data = properties.private_data as *const PlanPropertiesPrivateData; + let props = &(*private_data).props; + + let codec = DefaultPhysicalExtensionCodec {}; + let partitioning_data = + match serialize_partitioning(props.output_partitioning(), &codec) { + Ok(p) => p, + Err(_) => { + return RErr( + "unable to serialize output_partitioning in FFI_PlanProperties" + .into(), + ) + } + }; + let output_partitioning = partitioning_data.encode_to_vec(); + + ROk(output_partitioning.into()) +} + +unsafe extern "C" fn emission_type_fn_wrapper( + properties: &FFI_PlanProperties, +) -> FFI_EmissionType { + let private_data = properties.private_data as *const PlanPropertiesPrivateData; + let props = &(*private_data).props; + props.emission_type.into() +} + +unsafe extern "C" fn boundedness_fn_wrapper( + properties: &FFI_PlanProperties, +) -> FFI_Boundedness { + let private_data = properties.private_data as *const PlanPropertiesPrivateData; + let props = &(*private_data).props; + props.boundedness.into() +} + +unsafe extern "C" fn output_ordering_fn_wrapper( + properties: &FFI_PlanProperties, +) -> RResult, RStr<'static>> { + let private_data = properties.private_data as *const PlanPropertiesPrivateData; + let props = &(*private_data).props; + + let codec = DefaultPhysicalExtensionCodec {}; + let output_ordering = + match props.output_ordering() { + Some(ordering) => { + let physical_sort_expr_nodes = + match serialize_physical_sort_exprs(ordering.to_owned(), &codec) { + Ok(v) => v, + Err(_) => return RErr( + "unable to serialize output_ordering in FFI_PlanProperties" + .into(), + ), + }; + + let ordering_data = PhysicalSortExprNodeCollection { + physical_sort_expr_nodes, + }; + + ordering_data.encode_to_vec() + } + None => Vec::default(), + }; + ROk(output_ordering.into()) +} + +unsafe extern "C" fn schema_fn_wrapper(properties: &FFI_PlanProperties) -> WrappedSchema { + let private_data = properties.private_data as *const PlanPropertiesPrivateData; + let props = &(*private_data).props; + + let schema: SchemaRef = Arc::clone(props.eq_properties.schema()); + schema.into() +} + +unsafe extern "C" fn release_fn_wrapper(props: &mut FFI_PlanProperties) { + let private_data = + Box::from_raw(props.private_data as *mut PlanPropertiesPrivateData); + drop(private_data); +} + +impl Drop for FFI_PlanProperties { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +impl From<&PlanProperties> for FFI_PlanProperties { + fn from(props: &PlanProperties) -> Self { + let private_data = Box::new(PlanPropertiesPrivateData { + props: props.clone(), + }); + + FFI_PlanProperties { + output_partitioning: output_partitioning_fn_wrapper, + emission_type: emission_type_fn_wrapper, + boundedness: boundedness_fn_wrapper, + output_ordering: output_ordering_fn_wrapper, + schema: schema_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl TryFrom for PlanProperties { + type Error = DataFusionError; + + fn try_from(ffi_props: FFI_PlanProperties) -> Result { + let ffi_schema = unsafe { (ffi_props.schema)(&ffi_props) }; + let schema = (&ffi_schema.0).try_into()?; + + // TODO Extend FFI to get the registry and codex + let default_ctx = SessionContext::new(); + let codex = DefaultPhysicalExtensionCodec {}; + + let 
ffi_orderings = unsafe { (ffi_props.output_ordering)(&ffi_props) }; + let orderings = match ffi_orderings { + ROk(ordering_vec) => { + let proto_output_ordering = + PhysicalSortExprNodeCollection::decode(ordering_vec.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + Some(parse_physical_sort_exprs( + &proto_output_ordering.physical_sort_expr_nodes, + &default_ctx, + &schema, + &codex, + )?) + } + RErr(e) => return Err(DataFusionError::Plan(e.to_string())), + }; + + let ffi_partitioning = unsafe { (ffi_props.output_partitioning)(&ffi_props) }; + let partitioning = match ffi_partitioning { + ROk(partitioning_vec) => { + let proto_output_partitioning = + Partitioning::decode(partitioning_vec.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + parse_protobuf_partitioning( + Some(&proto_output_partitioning), + &default_ctx, + &schema, + &codex, + )? + .ok_or(DataFusionError::Plan( + "Unable to deserialize partitioning protobuf in FFI_PlanProperties" + .to_string(), + )) + } + RErr(e) => Err(DataFusionError::Plan(e.to_string())), + }?; + + let eq_properties = match orderings { + Some(ordering) => { + EquivalenceProperties::new_with_orderings(Arc::new(schema), &[ordering]) + } + None => EquivalenceProperties::new(Arc::new(schema)), + }; + + let emission_type: EmissionType = + unsafe { (ffi_props.emission_type)(&ffi_props).into() }; + + let boundedness: Boundedness = + unsafe { (ffi_props.boundedness)(&ffi_props).into() }; + + Ok(PlanProperties::new( + eq_properties, + partitioning, + emission_type, + boundedness, + )) + } +} + +/// FFI safe version of [`Boundedness`]. +#[repr(C)] +#[allow(non_camel_case_types)] +#[derive(Clone, StableAbi)] +pub enum FFI_Boundedness { + Bounded, + Unbounded { requires_infinite_memory: bool }, +} + +impl From for FFI_Boundedness { + fn from(value: Boundedness) -> Self { + match value { + Boundedness::Bounded => FFI_Boundedness::Bounded, + Boundedness::Unbounded { + requires_infinite_memory, + } => FFI_Boundedness::Unbounded { + requires_infinite_memory, + }, + } + } +} + +impl From for Boundedness { + fn from(value: FFI_Boundedness) -> Self { + match value { + FFI_Boundedness::Bounded => Boundedness::Bounded, + FFI_Boundedness::Unbounded { + requires_infinite_memory, + } => Boundedness::Unbounded { + requires_infinite_memory, + }, + } + } +} + +/// FFI safe version of [`EmissionType`]. 
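+///
+/// The paired `From` implementations below allow a lossless round trip between
+/// the two representations, e.g. (illustrative sketch, not compiled as a doctest):
+///
+/// ```ignore
+/// use datafusion::physical_plan::execution_plan::EmissionType;
+///
+/// let ffi: FFI_EmissionType = EmissionType::Incremental.into();
+/// let back: EmissionType = ffi.into();
+/// assert!(matches!(back, EmissionType::Incremental));
+/// ```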
+#[repr(C)] +#[allow(non_camel_case_types)] +#[derive(Clone, StableAbi)] +pub enum FFI_EmissionType { + Incremental, + Final, + Both, +} + +impl From for FFI_EmissionType { + fn from(value: EmissionType) -> Self { + match value { + EmissionType::Incremental => FFI_EmissionType::Incremental, + EmissionType::Final => FFI_EmissionType::Final, + EmissionType::Both => FFI_EmissionType::Both, + } + } +} + +impl From for EmissionType { + fn from(value: FFI_EmissionType) -> Self { + match value { + FFI_EmissionType::Incremental => EmissionType::Incremental, + FFI_EmissionType::Final => EmissionType::Final, + FFI_EmissionType::Both => EmissionType::Both, + } + } +} + +#[cfg(test)] +mod tests { + use datafusion::physical_plan::Partitioning; + + use super::*; + + #[test] + fn test_round_trip_ffi_plan_properties() -> Result<()> { + use arrow::datatypes::{DataType, Field, Schema}; + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + + let original_props = PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(3), + EmissionType::Incremental, + Boundedness::Bounded, + ); + + let local_props_ptr = FFI_PlanProperties::from(&original_props); + + let foreign_props: PlanProperties = local_props_ptr.try_into()?; + + assert!(format!("{:?}", foreign_props) == format!("{:?}", original_props)); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs new file mode 100644 index 0000000000000..c944e56c5cde9 --- /dev/null +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, task::Poll}; + +use abi_stable::{ + std_types::{ROption, RResult, RString}, + StableAbi, +}; +use arrow::array::{Array, RecordBatch}; +use arrow::{ + array::{make_array, StructArray}, + ffi::{from_ffi, to_ffi}, +}; +use async_ffi::{ContextExt, FfiContext, FfiPoll}; +use datafusion::error::Result; +use datafusion::{ + error::DataFusionError, + execution::{RecordBatchStream, SendableRecordBatchStream}, +}; +use futures::{Stream, TryStreamExt}; + +use crate::arrow_wrappers::{WrappedArray, WrappedSchema}; + +/// A stable struct for sharing [`RecordBatchStream`] across FFI boundaries. +/// We use the async-ffi crate for handling async calls across libraries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_RecordBatchStream { + /// This mirrors the `poll_next` of [`RecordBatchStream`] but does so + /// in a FFI safe manner. 
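+    ///
+    /// The `async-ffi` crate translates the standard `Context`/`Poll` pair into
+    /// the FFI-stable `FfiContext`/`FfiPoll` types, so the result is reported as
+    /// `Ready`, `Pending`, or `Panicked` rather than a plain `Poll` (see the
+    /// `Stream` implementation below).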
+    pub poll_next:
+        unsafe extern "C" fn(
+            stream: &Self,
+            cx: &mut FfiContext,
+        ) -> FfiPoll<ROption<RResult<WrappedArray, RString>>>,
+
+    /// Return the schema of the record batch
+    pub schema: unsafe extern "C" fn(stream: &Self) -> WrappedSchema,
+
+    /// Internal data. This is only to be accessed by the provider of the plan.
+    /// The foreign library should never attempt to access this data.
+    pub private_data: *mut c_void,
+}
+
+impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
+    fn from(stream: SendableRecordBatchStream) -> Self {
+        FFI_RecordBatchStream {
+            poll_next: poll_next_fn_wrapper,
+            schema: schema_fn_wrapper,
+            private_data: Box::into_raw(Box::new(stream)) as *mut c_void,
+        }
+    }
+}
+
+unsafe impl Send for FFI_RecordBatchStream {}
+
+unsafe extern "C" fn schema_fn_wrapper(stream: &FFI_RecordBatchStream) -> WrappedSchema {
+    let stream = stream.private_data as *const SendableRecordBatchStream;
+
+    (*stream).schema().into()
+}
+
+fn record_batch_to_wrapped_array(
+    record_batch: RecordBatch,
+) -> RResult<WrappedArray, RString> {
+    let struct_array = StructArray::from(record_batch);
+    match to_ffi(&struct_array.to_data()) {
+        Ok((array, schema)) => RResult::ROk(WrappedArray {
+            array,
+            schema: WrappedSchema(schema),
+        }),
+        Err(e) => RResult::RErr(e.to_string().into()),
+    }
+}
+
+// probably want to use pub unsafe fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> Result<ArrayData, ArrowError> {
+fn maybe_record_batch_to_wrapped_stream(
+    record_batch: Option<Result<RecordBatch>>,
+) -> ROption<RResult<WrappedArray, RString>> {
+    match record_batch {
+        Some(Ok(record_batch)) => {
+            ROption::RSome(record_batch_to_wrapped_array(record_batch))
+        }
+        Some(Err(e)) => ROption::RSome(RResult::RErr(e.to_string().into())),
+        None => ROption::RNone,
+    }
+}
+
+unsafe extern "C" fn poll_next_fn_wrapper(
+    stream: &FFI_RecordBatchStream,
+    cx: &mut FfiContext,
+) -> FfiPoll<ROption<RResult<WrappedArray, RString>>> {
+    let stream = stream.private_data as *mut SendableRecordBatchStream;
+
+    let poll_result = cx.with_context(|std_cx| {
+        (*stream)
+            .try_poll_next_unpin(std_cx)
+            .map(maybe_record_batch_to_wrapped_stream)
+    });
+
+    poll_result.into()
+}
+
+impl RecordBatchStream for FFI_RecordBatchStream {
+    fn schema(&self) -> arrow::datatypes::SchemaRef {
+        let wrapped_schema = unsafe { (self.schema)(self) };
+        wrapped_schema.into()
+    }
+}
+
+fn wrapped_array_to_record_batch(array: WrappedArray) -> Result<RecordBatch> {
+    let array_data =
+        unsafe { from_ffi(array.array, &array.schema.0).map_err(DataFusionError::from)?
}; + let array = make_array(array_data); + let struct_array = array + .as_any() + .downcast_ref::() + .ok_or(DataFusionError::Execution( + "Unexpected array type during record batch collection in FFI_RecordBatchStream" + .to_string(), + ))?; + + Ok(struct_array.into()) +} + +fn maybe_wrapped_array_to_record_batch( + array: ROption>, +) -> Option> { + match array { + ROption::RSome(RResult::ROk(wrapped_array)) => { + Some(wrapped_array_to_record_batch(wrapped_array)) + } + ROption::RSome(RResult::RErr(e)) => { + Some(Err(DataFusionError::Execution(e.to_string()))) + } + ROption::RNone => None, + } +} + +impl Stream for FFI_RecordBatchStream { + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let poll_result = + unsafe { cx.with_ffi_context(|ffi_cx| (self.poll_next)(&self, ffi_cx)) }; + + match poll_result { + FfiPoll::Ready(array) => { + Poll::Ready(maybe_wrapped_array_to_record_batch(array)) + } + FfiPoll::Pending => Poll::Pending, + FfiPoll::Panicked => Poll::Ready(Some(Err(DataFusionError::Execution( + "Error occurred during poll_next on FFI_RecordBatchStream".to_string(), + )))), + } + } +} diff --git a/datafusion/ffi/src/session_config.rs b/datafusion/ffi/src/session_config.rs new file mode 100644 index 0000000000000..aea03cf94e0af --- /dev/null +++ b/datafusion/ffi/src/session_config.rs @@ -0,0 +1,187 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + collections::HashMap, + ffi::{c_char, c_void, CString}, +}; + +use abi_stable::{ + std_types::{RHashMap, RString}, + StableAbi, +}; +use datafusion::{config::ConfigOptions, error::Result}; +use datafusion::{error::DataFusionError, prelude::SessionConfig}; + +/// A stable struct for sharing [`SessionConfig`] across FFI boundaries. +/// Instead of attempting to expose the entire SessionConfig interface, we +/// convert the config options into a map from a string to string and pass +/// those values across the FFI boundary. On the receiver side, we +/// reconstruct a SessionConfig from those values. +/// +/// It is possible that using different versions of DataFusion across the +/// FFI boundary could have differing expectations of the config options. +/// This is a limitation of this approach, but exposing the entire +/// SessionConfig via a FFI interface would be extensive and provide limited +/// value over this version. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_SessionConfig { + /// Return a hash map from key to value of the config options represented + /// by string values. + pub config_options: unsafe extern "C" fn(config: &Self) -> RHashMap, + + /// Used to create a clone on the provider of the execution plan. 
This should + /// only need to be called by the receiver of the plan. + pub clone: unsafe extern "C" fn(plan: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the plan. + /// A [`ForeignSessionConfig`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_SessionConfig {} +unsafe impl Sync for FFI_SessionConfig {} + +unsafe extern "C" fn config_options_fn_wrapper( + config: &FFI_SessionConfig, +) -> RHashMap { + let private_data = config.private_data as *mut SessionConfigPrivateData; + let config_options = &(*private_data).config; + + let mut options = RHashMap::default(); + for config_entry in config_options.entries() { + if let Some(value) = config_entry.value { + options.insert(config_entry.key.into(), value.into()); + } + } + + options +} + +unsafe extern "C" fn release_fn_wrapper(config: &mut FFI_SessionConfig) { + let private_data = + Box::from_raw(config.private_data as *mut SessionConfigPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(config: &FFI_SessionConfig) -> FFI_SessionConfig { + let old_private_data = config.private_data as *mut SessionConfigPrivateData; + let old_config = &(*old_private_data).config; + + let private_data = Box::new(SessionConfigPrivateData { + config: old_config.clone(), + }); + + FFI_SessionConfig { + config_options: config_options_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + } +} + +struct SessionConfigPrivateData { + pub config: ConfigOptions, +} + +impl From<&SessionConfig> for FFI_SessionConfig { + fn from(session: &SessionConfig) -> Self { + let mut config_keys = Vec::new(); + let mut config_values = Vec::new(); + for config_entry in session.options().entries() { + if let Some(value) = config_entry.value { + let key_cstr = CString::new(config_entry.key).unwrap_or_default(); + let key_ptr = key_cstr.into_raw() as *const c_char; + config_keys.push(key_ptr); + + config_values + .push(CString::new(value).unwrap_or_default().into_raw() + as *const c_char); + } + } + + let private_data = Box::new(SessionConfigPrivateData { + config: session.options().clone(), + }); + + Self { + config_options: config_options_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + } + } +} + +impl Clone for FFI_SessionConfig { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl Drop for FFI_SessionConfig { + fn drop(&mut self) { + unsafe { (self.release)(self) }; + } +} + +/// A wrapper struct for accessing [`SessionConfig`] across a FFI boundary. 
+/// The [`SessionConfig`] will be generated from a hash map of the config +/// options in the provider and will be reconstructed on this side of the +/// interface.s +pub struct ForeignSessionConfig(pub SessionConfig); + +impl TryFrom<&FFI_SessionConfig> for ForeignSessionConfig { + type Error = DataFusionError; + + fn try_from(config: &FFI_SessionConfig) -> Result { + let config_options = unsafe { (config.config_options)(config) }; + + let mut options_map = HashMap::new(); + config_options.iter().for_each(|kv_pair| { + options_map.insert(kv_pair.0.to_string(), kv_pair.1.to_string()); + }); + + Ok(Self(SessionConfig::from_string_hash_map(&options_map)?)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_round_trip_ffi_session_config() -> Result<()> { + let session_config = SessionConfig::new(); + let original_options = session_config.options().entries(); + + let ffi_config: FFI_SessionConfig = (&session_config).into(); + + let foreign_config: ForeignSessionConfig = (&ffi_config).try_into()?; + + let returned_options = foreign_config.0.options().entries(); + + assert!(original_options.len() == returned_options.len()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/table_provider.rs b/datafusion/ffi/src/table_provider.rs new file mode 100644 index 0000000000000..b229d908d10dd --- /dev/null +++ b/datafusion/ffi/src/table_provider.rs @@ -0,0 +1,481 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::datatypes::SchemaRef; +use async_ffi::{FfiFuture, FutureExt}; +use async_trait::async_trait; +use datafusion::{ + catalog::{Session, TableProvider}, + datasource::TableType, + error::DataFusionError, + execution::session_state::SessionStateBuilder, + logical_expr::TableProviderFilterPushDown, + physical_plan::ExecutionPlan, + prelude::{Expr, SessionContext}, +}; +use datafusion_proto::{ + logical_plan::{ + from_proto::parse_exprs, to_proto::serialize_exprs, DefaultLogicalExtensionCodec, + }, + protobuf::LogicalExprList, +}; +use prost::Message; + +use crate::{ + arrow_wrappers::WrappedSchema, + session_config::ForeignSessionConfig, + table_source::{FFI_TableProviderFilterPushDown, FFI_TableType}, +}; + +use super::{ + execution_plan::{FFI_ExecutionPlan, ForeignExecutionPlan}, + session_config::FFI_SessionConfig, +}; +use datafusion::error::Result; + +/// A stable struct for sharing [`TableProvider`] across FFI boundaries. +/// +/// # Struct Layout +/// +/// The following description applies to all structs provided in this crate. +/// +/// Each of the exposed structs in this crate is provided with a variant prefixed +/// with `Foreign`. 
This variant is designed to be used by the consumer of the +/// foreign code. The `Foreign` structs should _never_ access the `private_data` +/// fields. Instead they should only access the data returned through the function +/// calls defined on the `FFI_` structs. The second purpose of the `Foreign` +/// structs is to contain additional data that may be needed by the traits that +/// are implemented on them. Some of these traits require borrowing data which +/// can be far more convenient to be locally stored. +/// +/// For example, we have a struct `FFI_TableProvider` to give access to the +/// `TableProvider` functions like `table_type()` and `scan()`. If we write a +/// library that wishes to expose it's `TableProvider`, then we can access the +/// private data that contains the Arc reference to the `TableProvider` via +/// `FFI_TableProvider`. This data is local to the library. +/// +/// If we have a program that accesses a `TableProvider` via FFI, then it +/// will use `ForeignTableProvider`. When using `ForeignTableProvider` we **must** +/// not attempt to access the `private_data` field in `FFI_TableProvider`. If a +/// user is testing locally, you may be able to successfully access this field, but +/// it will only work if you are building against the exact same version of +/// `DataFusion` for both libraries **and** the same compiler. It will not work +/// in general. +/// +/// It is worth noting that which library is the `local` and which is `foreign` +/// depends on which interface we are considering. For example, suppose we have a +/// Python library called `my_provider` that exposes a `TableProvider` called +/// `MyProvider` via `FFI_TableProvider`. Within the library `my_provider` we can +/// access the `private_data` via `FFI_TableProvider`. We connect this to +/// `datafusion-python`, where we access it as a `ForeignTableProvider`. Now when +/// we call `scan()` on this interface, we have to pass it a `FFI_SessionConfig`. +/// The `SessionConfig` is local to `datafusion-python` and **not** `my_provider`. +/// It is important to be careful when expanding these functions to be certain which +/// side of the interface each object refers to. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_TableProvider { + /// Return the table schema + pub schema: unsafe extern "C" fn(provider: &Self) -> WrappedSchema, + + /// Perform a scan on the table. See [`TableProvider`] for detailed usage information. + /// + /// # Arguments + /// + /// * `provider` - the table provider + /// * `session_config` - session configuration + /// * `projections` - if specified, only a subset of the columns are returned + /// * `filters_serialized` - filters to apply to the scan, which are a + /// [`LogicalExprList`] protobuf message serialized into bytes to pass + /// across the FFI boundary. + /// * `limit` - if specified, limit the number of rows returned + pub scan: unsafe extern "C" fn( + provider: &Self, + session_config: &FFI_SessionConfig, + projections: RVec, + filters_serialized: RVec, + limit: ROption, + ) -> FfiFuture>, + + /// Return the type of table. See [`TableType`] for options. + pub table_type: unsafe extern "C" fn(provider: &Self) -> FFI_TableType, + + /// Based upon the input filters, identify which are supported. The filters + /// are a [`LogicalExprList`] protobuf message serialized into bytes to pass + /// across the FFI boundary. 
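+    ///
+    /// This callback is optional. When it is `None`, the consuming side reports
+    /// every filter as `TableProviderFilterPushDown::Unsupported` (see
+    /// `ForeignTableProvider::supports_filters_pushdown` below).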
+ pub supports_filters_pushdown: Option< + unsafe extern "C" fn( + provider: &FFI_TableProvider, + filters_serialized: RVec, + ) + -> RResult, RString>, + >, + + /// Used to create a clone on the provider of the execution plan. This should + /// only need to be called by the receiver of the plan. + pub clone: unsafe extern "C" fn(plan: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the plan. + /// A [`ForeignExecutionPlan`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_TableProvider {} +unsafe impl Sync for FFI_TableProvider {} + +struct ProviderPrivateData { + provider: Arc, +} + +unsafe extern "C" fn schema_fn_wrapper(provider: &FFI_TableProvider) -> WrappedSchema { + let private_data = provider.private_data as *const ProviderPrivateData; + let provider = &(*private_data).provider; + + provider.schema().into() +} + +unsafe extern "C" fn table_type_fn_wrapper( + provider: &FFI_TableProvider, +) -> FFI_TableType { + let private_data = provider.private_data as *const ProviderPrivateData; + let provider = &(*private_data).provider; + + provider.table_type().into() +} + +fn supports_filters_pushdown_internal( + provider: &Arc, + filters_serialized: &[u8], +) -> Result> { + let default_ctx = SessionContext::new(); + let codec = DefaultLogicalExtensionCodec {}; + + let filters = match filters_serialized.is_empty() { + true => vec![], + false => { + let proto_filters = LogicalExprList::decode(filters_serialized) + .map_err(|e| DataFusionError::Plan(e.to_string()))?; + + parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)? + } + }; + let filters_borrowed: Vec<&Expr> = filters.iter().collect(); + + let results: RVec<_> = provider + .supports_filters_pushdown(&filters_borrowed)? 
+ .iter() + .map(|v| v.into()) + .collect(); + + Ok(results) +} + +unsafe extern "C" fn supports_filters_pushdown_fn_wrapper( + provider: &FFI_TableProvider, + filters_serialized: RVec, +) -> RResult, RString> { + let private_data = provider.private_data as *const ProviderPrivateData; + let provider = &(*private_data).provider; + + supports_filters_pushdown_internal(provider, &filters_serialized) + .map_err(|e| e.to_string().into()) + .into() +} + +unsafe extern "C" fn scan_fn_wrapper( + provider: &FFI_TableProvider, + session_config: &FFI_SessionConfig, + projections: RVec, + filters_serialized: RVec, + limit: ROption, +) -> FfiFuture> { + let private_data = provider.private_data as *mut ProviderPrivateData; + let internal_provider = &(*private_data).provider; + let session_config = session_config.clone(); + + async move { + let config = match ForeignSessionConfig::try_from(&session_config) { + Ok(c) => c, + Err(e) => return RResult::RErr(e.to_string().into()), + }; + let session = SessionStateBuilder::new() + .with_default_features() + .with_config(config.0) + .build(); + let ctx = SessionContext::new_with_state(session); + + let filters = match filters_serialized.is_empty() { + true => vec![], + false => { + let default_ctx = SessionContext::new(); + let codec = DefaultLogicalExtensionCodec {}; + + let proto_filters = + match LogicalExprList::decode(filters_serialized.as_ref()) { + Ok(f) => f, + Err(e) => return RResult::RErr(e.to_string().into()), + }; + + match parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec) { + Ok(f) => f, + Err(e) => return RResult::RErr(e.to_string().into()), + } + } + }; + + let projections: Vec<_> = projections.into_iter().collect(); + let maybe_projections = match projections.is_empty() { + true => None, + false => Some(&projections), + }; + + let plan = match internal_provider + .scan(&ctx.state(), maybe_projections, &filters, limit.into()) + .await + { + Ok(p) => p, + Err(e) => return RResult::RErr(e.to_string().into()), + }; + + RResult::ROk(FFI_ExecutionPlan::new(plan, ctx.task_ctx())) + } + .into_ffi() +} + +unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_TableProvider) { + let private_data = Box::from_raw(provider.private_data as *mut ProviderPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(provider: &FFI_TableProvider) -> FFI_TableProvider { + let old_private_data = provider.private_data as *const ProviderPrivateData; + + let private_data = Box::into_raw(Box::new(ProviderPrivateData { + provider: Arc::clone(&(*old_private_data).provider), + })) as *mut c_void; + + FFI_TableProvider { + schema: schema_fn_wrapper, + scan: scan_fn_wrapper, + table_type: table_type_fn_wrapper, + supports_filters_pushdown: provider.supports_filters_pushdown, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data, + } +} + +impl Drop for FFI_TableProvider { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +impl FFI_TableProvider { + /// Creates a new [`FFI_TableProvider`]. 
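+    ///
+    /// A minimal construction sketch (not compiled as a doctest); `schema` and
+    /// `partitions` are placeholders for your own table definition:
+    ///
+    /// ```ignore
+    /// use std::sync::Arc;
+    /// use datafusion::datasource::MemTable;
+    ///
+    /// let provider = Arc::new(MemTable::try_new(schema, partitions)?);
+    /// // `true` makes `supports_filters_pushdown` available across the boundary.
+    /// let ffi_provider = FFI_TableProvider::new(provider, true);
+    /// ```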
+ pub fn new( + provider: Arc, + can_support_pushdown_filters: bool, + ) -> Self { + let private_data = Box::new(ProviderPrivateData { provider }); + + Self { + schema: schema_fn_wrapper, + scan: scan_fn_wrapper, + table_type: table_type_fn_wrapper, + supports_filters_pushdown: match can_support_pushdown_filters { + true => Some(supports_filters_pushdown_fn_wrapper), + false => None, + }, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +/// This wrapper struct exists on the receiver side of the FFI interface, so it has +/// no guarantees about being able to access the data in `private_data`. Any functions +/// defined on this struct must only use the stable functions provided in +/// FFI_TableProvider to interact with the foreign table provider. +#[derive(Debug)] +pub struct ForeignTableProvider(FFI_TableProvider); + +unsafe impl Send for ForeignTableProvider {} +unsafe impl Sync for ForeignTableProvider {} + +impl From<&FFI_TableProvider> for ForeignTableProvider { + fn from(provider: &FFI_TableProvider) -> Self { + Self(provider.clone()) + } +} + +impl Clone for FFI_TableProvider { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +#[async_trait] +impl TableProvider for ForeignTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + let wrapped_schema = unsafe { (self.0.schema)(&self.0) }; + wrapped_schema.into() + } + + fn table_type(&self) -> TableType { + unsafe { (self.0.table_type)(&self.0).into() } + } + + async fn scan( + &self, + session: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let session_config: FFI_SessionConfig = session.config().into(); + + let projections: Option> = + projection.map(|p| p.iter().map(|v| v.to_owned()).collect()); + + let codec = DefaultLogicalExtensionCodec {}; + let filter_list = LogicalExprList { + expr: serialize_exprs(filters, &codec)?, + }; + let filters_serialized = filter_list.encode_to_vec().into(); + + let plan = unsafe { + let maybe_plan = (self.0.scan)( + &self.0, + &session_config, + projections.unwrap_or_default(), + filters_serialized, + limit.into(), + ) + .await; + + match maybe_plan { + RResult::ROk(p) => ForeignExecutionPlan::try_from(&p)?, + RResult::RErr(_) => { + return Err(DataFusionError::Internal( + "Unable to perform scan via FFI".to_string(), + )) + } + } + }; + + Ok(Arc::new(plan)) + } + + /// Tests whether the table provider can make use of a filter expression + /// to optimize data retrieval. 
+ fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + unsafe { + let pushdown_fn = match self.0.supports_filters_pushdown { + Some(func) => func, + None => { + return Ok(vec![ + TableProviderFilterPushDown::Unsupported; + filters.len() + ]) + } + }; + + let codec = DefaultLogicalExtensionCodec {}; + + let expr_list = LogicalExprList { + expr: serialize_exprs(filters.iter().map(|f| f.to_owned()), &codec)?, + }; + let serialized_filters = expr_list.encode_to_vec(); + + let pushdowns = pushdown_fn(&self.0, serialized_filters.into()); + + match pushdowns { + RResult::ROk(p) => Ok(p.iter().map(|v| v.into()).collect()), + RResult::RErr(e) => Err(DataFusionError::Plan(e.to_string())), + } + } + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::Schema; + use datafusion::prelude::{col, lit}; + + use super::*; + + #[tokio::test] + async fn test_round_trip_ffi_table_provider() -> Result<()> { + use arrow::datatypes::Field; + use datafusion::arrow::{ + array::Float32Array, datatypes::DataType, record_batch::RecordBatch, + }; + use datafusion::datasource::MemTable; + + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + + // define data in two partitions + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))], + )?; + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Float32Array::from(vec![64.0]))], + )?; + + let ctx = SessionContext::new(); + + let provider = + Arc::new(MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?); + + let ffi_provider = FFI_TableProvider::new(provider, true); + + let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into(); + + ctx.register_table("t", Arc::new(foreign_table_provider))?; + + let df = ctx.table("t").await?; + + df.select(vec![col("a")])? + .filter(col("a").gt(lit(3.0)))? + .show() + .await?; + + Ok(()) + } +} diff --git a/datafusion/ffi/src/table_source.rs b/datafusion/ffi/src/table_source.rs new file mode 100644 index 0000000000000..a59836622ee65 --- /dev/null +++ b/datafusion/ffi/src/table_source.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use abi_stable::StableAbi; +use datafusion::{datasource::TableType, logical_expr::TableProviderFilterPushDown}; + +/// FFI safe version of [`TableProviderFilterPushDown`]. 
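+///
+/// The conversions below are implemented on references because pushdown results
+/// cross the FFI boundary as an `RVec` and are converted element by element on
+/// each side (see `supports_filters_pushdown` in `table_provider`).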
+#[repr(C)] +#[derive(StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_TableProviderFilterPushDown { + Unsupported, + Inexact, + Exact, +} + +impl From<&FFI_TableProviderFilterPushDown> for TableProviderFilterPushDown { + fn from(value: &FFI_TableProviderFilterPushDown) -> Self { + match value { + FFI_TableProviderFilterPushDown::Unsupported => { + TableProviderFilterPushDown::Unsupported + } + FFI_TableProviderFilterPushDown::Inexact => { + TableProviderFilterPushDown::Inexact + } + FFI_TableProviderFilterPushDown::Exact => TableProviderFilterPushDown::Exact, + } + } +} + +impl From<&TableProviderFilterPushDown> for FFI_TableProviderFilterPushDown { + fn from(value: &TableProviderFilterPushDown) -> Self { + match value { + TableProviderFilterPushDown::Unsupported => { + FFI_TableProviderFilterPushDown::Unsupported + } + TableProviderFilterPushDown::Inexact => { + FFI_TableProviderFilterPushDown::Inexact + } + TableProviderFilterPushDown::Exact => FFI_TableProviderFilterPushDown::Exact, + } + } +} + +/// FFI safe version of [`TableType`]. +#[repr(C)] +#[allow(non_camel_case_types)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, StableAbi)] +pub enum FFI_TableType { + Base, + View, + Temporary, +} + +impl From for TableType { + fn from(value: FFI_TableType) -> Self { + match value { + FFI_TableType::Base => TableType::Base, + FFI_TableType::View => TableType::View, + FFI_TableType::Temporary => TableType::Temporary, + } + } +} + +impl From for FFI_TableType { + fn from(value: TableType) -> Self { + match value { + TableType::Base => FFI_TableType::Base, + TableType::View => FFI_TableType::View, + TableType::Temporary => FFI_TableType::Temporary, + } + } +} diff --git a/datafusion/functions-aggregate-common/Cargo.toml b/datafusion/functions-aggregate-common/Cargo.toml index a8296ce11f30d..cf6eb99e60c62 100644 --- a/datafusion/functions-aggregate-common/Cargo.toml +++ b/datafusion/functions-aggregate-common/Cargo.toml @@ -19,7 +19,6 @@ name = "datafusion-functions-aggregate-common" description = "Utility functions for implementing aggregate functions for the DataFusion query engine" keywords = ["datafusion", "logical", "plan", "expressions"] -readme = "README.md" version = { workspace = true } edition = { workspace = true } homepage = { workspace = true } @@ -43,4 +42,11 @@ arrow = { workspace = true } datafusion-common = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } + +[dev-dependencies] +criterion = "0.5" rand = { workspace = true } + +[[bench]] +harness = false +name = "accumulate" diff --git a/datafusion/functions-aggregate-common/LICENSE.txt b/datafusion/functions-aggregate-common/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/functions-aggregate-common/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/functions-aggregate-common/NOTICE.txt b/datafusion/functions-aggregate-common/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/functions-aggregate-common/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/functions-aggregate-common/benches/accumulate.rs b/datafusion/functions-aggregate-common/benches/accumulate.rs new file mode 100644 index 0000000000000..f422f8a2a7bfd --- /dev/null +++ b/datafusion/functions-aggregate-common/benches/accumulate.rs @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) 
under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, Int64Array}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; + +fn generate_group_indices(len: usize) -> Vec { + (0..len).collect() +} + +fn generate_values(len: usize, has_null: bool) -> ArrayRef { + if has_null { + let values = (0..len) + .map(|i| if i % 7 == 0 { None } else { Some(i as i64) }) + .collect::>(); + Arc::new(Int64Array::from(values)) + } else { + let values = (0..len).map(|i| Some(i as i64)).collect::>(); + Arc::new(Int64Array::from(values)) + } +} + +fn generate_filter(len: usize) -> Option { + let values = (0..len) + .map(|i| { + if i % 7 == 0 { + None + } else if i % 5 == 0 { + Some(false) + } else { + Some(true) + } + }) + .collect::>(); + Some(BooleanArray::from(values)) +} + +fn criterion_benchmark(c: &mut Criterion) { + let len = 500_000; + let group_indices = generate_group_indices(len); + let rows_count = group_indices.len(); + let values = generate_values(len, true); + let opt_filter = generate_filter(len); + let mut counts: Vec = vec![0; rows_count]; + accumulate_indices( + &group_indices, + values.logical_nulls().as_ref(), + opt_filter.as_ref(), + |group_index| { + counts[group_index] += 1; + }, + ); + + c.bench_function("Handle both nulls and filter", |b| { + b.iter(|| { + accumulate_indices( + &group_indices, + values.logical_nulls().as_ref(), + opt_filter.as_ref(), + |group_index| { + counts[group_index] += 1; + }, + ); + }) + }); + + c.bench_function("Handle nulls only", |b| { + b.iter(|| { + accumulate_indices( + &group_indices, + values.logical_nulls().as_ref(), + None, + |group_index| { + counts[group_index] += 1; + }, + ); + }) + }); + + let values = generate_values(len, false); + c.bench_function("Handle filter only", |b| { + b.iter(|| { + accumulate_indices( + &group_indices, + values.logical_nulls().as_ref(), + opt_filter.as_ref(), + |group_index| { + counts[group_index] += 1; + }, + ); + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs index ddf0085b9de4c..a230bb0289091 100644 --- a/datafusion/functions-aggregate-common/src/accumulator.rs +++ b/datafusion/functions-aggregate-common/src/accumulator.rs @@ -18,9 +18,8 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; use datafusion_expr_common::accumulator::Accumulator; -use datafusion_physical_expr_common::{ - physical_expr::PhysicalExpr, sort_expr::PhysicalSortExpr, -}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use 
datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; /// [`AccumulatorArgs`] contains information about how an aggregate @@ -53,7 +52,7 @@ pub struct AccumulatorArgs<'a> { /// ``` /// /// If no `ORDER BY` is specified, `ordering_req` will be empty. - pub ordering_req: &'a [PhysicalSortExpr], + pub ordering_req: &'a LexOrdering, /// Whether the aggregation is running in reverse order pub is_reversed: bool, diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs index ee61128979e10..e321df61ddc6a 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs @@ -19,13 +19,13 @@ use arrow::array::{ArrayRef, OffsetSizeTrait}; use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array_nullable; +use datafusion_common::utils::SingleRowListArrayBuilder; use datafusion_common::ScalarValue; use datafusion_expr_common::accumulator::Accumulator; use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewSet; use std::fmt::Debug; -use std::sync::Arc; +use std::mem::size_of_val; /// Specialized implementation of /// `COUNT DISTINCT` for [`StringArray`] [`LargeStringArray`], @@ -48,8 +48,7 @@ impl Accumulator for BytesDistinctCountAccumulator { fn state(&mut self) -> datafusion_common::Result> { let set = self.0.take(); let arr = set.into_state(); - let list = Arc::new(array_into_list_array_nullable(arr)); - Ok(vec![ScalarValue::List(list)]) + Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()]) } fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { @@ -86,7 +85,7 @@ impl Accumulator for BytesDistinctCountAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + self.0.size() + size_of_val(self) + self.0.size() } } @@ -108,8 +107,7 @@ impl Accumulator for BytesViewDistinctCountAccumulator { fn state(&mut self) -> datafusion_common::Result> { let set = self.0.take(); let arr = set.into_state(); - let list = Arc::new(array_into_list_array_nullable(arr)); - Ok(vec![ScalarValue::List(list)]) + Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()]) } fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { @@ -146,6 +144,6 @@ impl Accumulator for BytesViewDistinctCountAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + self.0.size() + size_of_val(self) + self.0.size() } } diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs index d128a8af58eef..e8b6588dc0913 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs @@ -23,6 +23,7 @@ use std::collections::HashSet; use std::fmt::Debug; use std::hash::Hash; +use std::mem::size_of_val; use std::sync::Arc; use ahash::RandomState; @@ -32,8 +33,8 @@ use arrow::array::PrimitiveArray; use arrow::datatypes::DataType; use datafusion_common::cast::{as_list_array, as_primitive_array}; -use datafusion_common::utils::array_into_list_array_nullable; use datafusion_common::utils::memory::estimate_memory_size; +use datafusion_common::utils::SingleRowListArrayBuilder; use 
datafusion_common::ScalarValue; use datafusion_expr_common::accumulator::Accumulator; @@ -72,8 +73,7 @@ where PrimitiveArray::::from_iter_values(self.values.iter().cloned()) .with_data_type(self.data_type.clone()), ); - let list = Arc::new(array_into_list_array_nullable(arr)); - Ok(vec![ScalarValue::List(list)]) + Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()]) } fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { @@ -117,8 +117,7 @@ where fn size(&self) -> usize { let num_elements = self.values.len(); - let fixed_size = - std::mem::size_of_val(self) + std::mem::size_of_val(&self.values); + let fixed_size = size_of_val(self) + size_of_val(&self.values); estimate_memory_size::(num_elements, fixed_size).unwrap() } @@ -160,8 +159,7 @@ where let arr = Arc::new(PrimitiveArray::::from_iter_values( self.values.iter().map(|v| v.0), )) as ArrayRef; - let list = Arc::new(array_into_list_array_nullable(arr)); - Ok(vec![ScalarValue::List(list)]) + Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()]) } fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { @@ -206,8 +204,7 @@ where fn size(&self) -> usize { let num_elements = self.values.len(); - let fixed_size = - std::mem::size_of_val(self) + std::mem::size_of_val(&self.values); + let fixed_size = size_of_val(self) + size_of_val(&self.values); estimate_memory_size::(num_elements, fixed_size).unwrap() } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index fbbf4d303515e..aa2f5a586e877 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -23,14 +23,16 @@ pub mod bool_op; pub mod nulls; pub mod prim_op; +use std::mem::{size_of, size_of_val}; + +use arrow::array::new_empty_array; use arrow::{ array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}, compute, + compute::take_arrays, datatypes::UInt32Type, }; -use datafusion_common::{ - arrow_datafusion_err, utils::take_arrays, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; @@ -77,7 +79,7 @@ use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; /// /// Logical group Current Min/Max value for that group stored /// number as a ScalarValue which points to an -/// indivdually allocated String +/// individually allocated String /// ///``` /// @@ -122,9 +124,7 @@ impl AccumulatorState { /// Returns the amount of memory taken by this structure and its accumulator fn size(&self) -> usize { - self.accumulator.size() - + std::mem::size_of_val(self) - + self.indices.allocated_size() + self.accumulator.size() + size_of_val(self) + self.indices.allocated_size() } } @@ -238,7 +238,7 @@ impl GroupsAccumulatorAdapter { // reorder the values and opt_filter by batch_indices so that // all values for each group are contiguous, then invoke the // accumulator once per group with values - let values = take_arrays(values, &batch_indices)?; + let values = take_arrays(values, &batch_indices, None)?; let opt_filter = get_filter_at_indices(opt_filter, &batch_indices)?; // invoke each accumulator with the appropriate rows, first @@ -281,7 +281,7 @@ impl GroupsAccumulatorAdapter { /// See 
[`Self::allocation_bytes`] for rationale. fn free_allocation(&mut self, size: usize) { // use saturating sub to avoid errors if the accumulators - // report erronious sizes + // report erroneous sizes self.allocation_bytes = self.allocation_bytes.saturating_sub(size) } @@ -405,6 +405,18 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { ) -> Result> { let num_rows = values[0].len(); + // If there are no rows, return empty arrays + if num_rows == 0 { + // create empty accumulator to get the state types + let empty_state = (self.factory)()?.state()?; + let empty_arrays = empty_state + .into_iter() + .map(|state_val| new_empty_array(&state_val.data_type())) + .collect::>(); + + return Ok(empty_arrays); + } + // Each row has its respective group let mut results = vec![]; for row_idx in 0..num_rows { @@ -452,7 +464,7 @@ pub trait VecAllocExt { impl VecAllocExt for Vec { type T = T; fn allocated_size(&self) -> usize { - std::mem::size_of::() * self.capacity() + size_of::() * self.capacity() } } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index a0475fe8e4464..e629e99e1657a 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -95,7 +95,7 @@ impl NullState { /// /// When value_fn is called it also sets /// - /// 1. `self.seen_values[group_index]` to true for all rows that had a non null vale + /// 1. `self.seen_values[group_index]` to true for all rows that had a non null value pub fn accumulate( &mut self, group_indices: &[usize], @@ -371,6 +371,75 @@ pub fn accumulate( } } +/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`) +/// +/// This method assumes that for any input record index, if any of the value column +/// is null, or it's filtered out by `opt_filter`, then the record would be ignored. +/// (won't be accumulated by `value_fn`) +/// +/// # Arguments +/// +/// * `group_indices` - To which groups do the rows in `value_columns` belong +/// * `value_columns` - The input arrays to accumulate +/// * `opt_filter` - Optional filter array. If present, only rows where filter is `Some(true)` are included +/// * `value_fn` - Callback function for each valid row, with parameters: +/// * `group_idx`: The group index for the current row +/// * `batch_idx`: The index of the current row in the input arrays +/// * `columns`: Reference to all input arrays for accessing values +pub fn accumulate_multiple( + group_indices: &[usize], + value_columns: &[&PrimitiveArray], + opt_filter: Option<&BooleanArray>, + mut value_fn: F, +) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, usize, &[&PrimitiveArray]) + Send, +{ + // Calculate `valid_indices` to accumulate, non-valid indices are ignored. + // `valid_indices` is a bit mask corresponding to the `group_indices`. An index + // is considered valid if: + // 1. All columns are non-null at this index. + // 2. Not filtered out by `opt_filter` + + // Take AND from all null buffers of `value_columns`. + let combined_nulls = value_columns + .iter() + .map(|arr| arr.logical_nulls()) + .fold(None, |acc, nulls| { + NullBuffer::union(acc.as_ref(), nulls.as_ref()) + }); + + // Take AND from previous combined nulls and `opt_filter`. 
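+    // For example, with combined nulls = [valid, null, valid, valid] and
+    // opt_filter = [true, true, false, true], only rows 0 and 3 remain valid
+    // and are passed to `value_fn`.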
+ let valid_indices = match (combined_nulls, opt_filter) { + (None, None) => None, + (None, Some(filter)) => Some(filter.clone()), + (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)), + (Some(nulls), Some(filter)) => { + let combined = nulls.inner() & filter.values(); + Some(BooleanArray::new(combined, None)) + } + }; + + for col in value_columns.iter() { + debug_assert_eq!(col.len(), group_indices.len()); + } + + match valid_indices { + None => { + for (batch_idx, &group_idx) in group_indices.iter().enumerate() { + value_fn(group_idx, batch_idx, value_columns); + } + } + Some(valid_indices) => { + for (batch_idx, &group_idx) in group_indices.iter().enumerate() { + if valid_indices.value(batch_idx) { + value_fn(group_idx, batch_idx, value_columns); + } + } + } + } +} + /// This function is called to update the accumulator state per row /// when the value is not needed (e.g. COUNT) /// @@ -395,19 +464,41 @@ pub fn accumulate_indices( } } (None, Some(filter)) => { - assert_eq!(filter.len(), group_indices.len()); - // The performance with a filter could be improved by - // iterating over the filter in chunks, rather than a single - // iterator. TODO file a ticket - let iter = group_indices.iter().zip(filter.iter()); - for (&group_index, filter_value) in iter { - if let Some(true) = filter_value { - index_fn(group_index) - } - } + debug_assert_eq!(filter.len(), group_indices.len()); + let group_indices_chunks = group_indices.chunks_exact(64); + let bit_chunks = filter.values().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + + group_indices_chunks.zip(bit_chunks.iter()).for_each( + |(group_index_chunk, mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().for_each(|&group_index| { + // valid bit was set, real vale + let is_valid = (mask & index_mask) != 0; + if is_valid { + index_fn(group_index); + } + index_mask <<= 1; + }) + }, + ); + + // handle any remaining bits (after the initial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .enumerate() + .for_each(|(i, &group_index)| { + let is_valid = remainder_bits & (1 << i) != 0; + if is_valid { + index_fn(group_index) + } + }); } (Some(valids), None) => { - assert_eq!(valids.len(), group_indices.len()); + debug_assert_eq!(valids.len(), group_indices.len()); // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum // iterate over in chunks of 64 bits for more efficient null checking let group_indices_chunks = group_indices.chunks_exact(64); @@ -444,20 +535,44 @@ pub fn accumulate_indices( } (Some(valids), Some(filter)) => { - assert_eq!(filter.len(), group_indices.len()); - assert_eq!(valids.len(), group_indices.len()); - // The performance with a filter could likely be improved by - // iterating over the filter in chunks, rather than using - // iterators. 
TODO file a ticket - filter + debug_assert_eq!(filter.len(), group_indices.len()); + debug_assert_eq!(valids.len(), group_indices.len()); + + let group_indices_chunks = group_indices.chunks_exact(64); + let valid_bit_chunks = valids.inner().bit_chunks(); + let filter_bit_chunks = filter.values().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + + group_indices_chunks + .zip(valid_bit_chunks.iter()) + .zip(filter_bit_chunks.iter()) + .for_each(|((group_index_chunk, valid_mask), filter_mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().for_each(|&group_index| { + // valid bit was set, real vale + let is_valid = (valid_mask & filter_mask & index_mask) != 0; + if is_valid { + index_fn(group_index); + } + index_mask <<= 1; + }) + }); + + // handle any remaining bits (after the initial 64) + let remainder_valid_bits = valid_bit_chunks.remainder_bits(); + let remainder_filter_bits = filter_bit_chunks.remainder_bits(); + group_indices_remainder .iter() - .zip(group_indices.iter()) - .zip(valids.iter()) - .for_each(|((filter_value, &group_index), is_valid)| { - if let (Some(true), true) = (filter_value, is_valid) { + .enumerate() + .for_each(|(i, &group_index)| { + let is_valid = + remainder_valid_bits & remainder_filter_bits & (1 << i) != 0; + if is_valid { index_fn(group_index) } - }) + }); } } } @@ -482,7 +597,7 @@ fn initialize_builder( mod test { use super::*; - use arrow::array::UInt32Array; + use arrow::array::{Int32Array, UInt32Array}; use rand::{rngs::ThreadRng, Rng}; use std::collections::HashSet; @@ -894,4 +1009,107 @@ mod test { .collect() } } + + #[test] + fn test_accumulate_multiple_no_nulls_no_filter() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![1, 2, 3, 4]); + let values2 = Int32Array::from(vec![10, 20, 30, 40]); + let value_columns = [values1, values2]; + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + None, + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); + }, + ); + + let expected = vec![ + (0, vec![1, 10]), + (1, vec![2, 20]), + (0, vec![3, 30]), + (1, vec![4, 40]), + ]; + assert_eq!(accumulated, expected); + } + + #[test] + fn test_accumulate_multiple_with_nulls() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]); + let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]); + let value_columns = [values1, values2]; + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + None, + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); + }, + ); + + // Only rows where both columns are non-null should be accumulated + let expected = vec![(0, vec![1, 10]), (1, vec![4, 40])]; + assert_eq!(accumulated, expected); + } + + #[test] + fn test_accumulate_multiple_with_filter() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![1, 2, 3, 4]); + let values2 = Int32Array::from(vec![10, 20, 30, 40]); + let value_columns = [values1, values2]; + + let filter = BooleanArray::from(vec![true, false, true, false]); + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + Some(&filter), + |group_idx, 
batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); + }, + ); + + // Only rows where filter is true should be accumulated + let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])]; + assert_eq!(accumulated, expected); + } + + #[test] + fn test_accumulate_multiple_with_nulls_and_filter() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]); + let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]); + let value_columns = [values1, values2]; + + let filter = BooleanArray::from(vec![true, true, true, false]); + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + Some(&filter), + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); + }, + ); + + // Only rows where both: + // 1. Filter is true + // 2. Both columns are non-null + // should be accumulated + let expected = [(0, vec![1, 10])]; + assert_eq!(accumulated, expected); + } } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 25212f7f0f5ff..6a8946034cbc3 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -15,13 +15,22 @@ // specific language governing permissions and limitations // under the License. -//! [`set_nulls`], and [`filtered_null_mask`], utilities for working with nulls +//! [`set_nulls`], other utilities for working with nulls -use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray}; +use arrow::array::{ + Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray, + BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray, + StringViewArray, +}; use arrow::buffer::NullBuffer; +use arrow::datatypes::DataType; +use datafusion_common::{not_impl_err, Result}; +use std::sync::Arc; /// Sets the validity mask for a `PrimitiveArray` to `nulls` /// replacing any existing null mask +/// +/// See [`set_nulls_dyn`] for a version that works with `Array` pub fn set_nulls( array: PrimitiveArray, nulls: Option, @@ -91,3 +100,105 @@ pub fn filtered_null_mask( let opt_filter = opt_filter.and_then(filter_to_nulls); NullBuffer::union(opt_filter.as_ref(), input.nulls()) } + +/// Applies optional filter to input, returning a new array of the same type +/// with the same data, but with any values that were filtered out set to null +pub fn apply_filter_as_nulls( + input: &dyn Array, + opt_filter: Option<&BooleanArray>, +) -> Result { + let nulls = filtered_null_mask(opt_filter, input); + set_nulls_dyn(input, nulls) +} + +/// Replaces the nulls in the input array with the given `NullBuffer` +/// +/// TODO: replace when upstreamed in arrow-rs: +pub fn set_nulls_dyn(input: &dyn Array, nulls: Option) -> Result { + if let Some(nulls) = nulls.as_ref() { + assert_eq!(nulls.len(), input.len()); + } + + let output: ArrayRef = match input.data_type() { + DataType::Utf8 => { + let input = input.as_string::(); + // safety: values / offsets came from a valid string array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(StringArray::new_unchecked( + input.offsets().clone(), + 
input.values().clone(), + nulls, + )) + } + } + DataType::LargeUtf8 => { + let input = input.as_string::(); + // safety: values / offsets came from a valid string array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(LargeStringArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::Utf8View => { + let input = input.as_string_view(); + // safety: values / views came from a valid string view array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(StringViewArray::new_unchecked( + input.views().clone(), + input.data_buffers().to_vec(), + nulls, + )) + } + } + + DataType::Binary => { + let input = input.as_binary::(); + // safety: values / offsets came from a valid binary array + // and we checked nulls has the same length as values + unsafe { + Arc::new(BinaryArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::LargeBinary => { + let input = input.as_binary::(); + // safety: values / offsets came from a valid large binary array + // and we checked nulls has the same length as values + unsafe { + Arc::new(LargeBinaryArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::BinaryView => { + let input = input.as_binary_view(); + // safety: values / views came from a valid binary view array + // and we checked nulls has the same length as values + unsafe { + Arc::new(BinaryViewArray::new_unchecked( + input.views().clone(), + input.data_buffers().to_vec(), + nulls, + )) + } + } + _ => { + return not_impl_err!("Applying nulls {:?}", input.data_type()); + } + }; + assert_eq!(input.len(), output.len()); + assert_eq!(input.data_type(), output.data_type()); + + Ok(output) +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs index 8bbcf756c37c1..078982c983fc7 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::mem::size_of; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}; @@ -195,6 +196,6 @@ where } fn size(&self) -> usize { - self.values.capacity() * std::mem::size_of::() + self.null_state.size() + self.values.capacity() * size_of::() + self.null_state.size() } } diff --git a/datafusion/functions-aggregate-common/src/merge_arrays.rs b/datafusion/functions-aggregate-common/src/merge_arrays.rs index 544bdc182829d..9b9a1240c1a19 100644 --- a/datafusion/functions-aggregate-common/src/merge_arrays.rs +++ b/datafusion/functions-aggregate-common/src/merge_arrays.rs @@ -65,7 +65,7 @@ impl<'a> CustomElement<'a> { // Overwrite ordering implementation such that // - `self.ordering` values are used for comparison, // - When used inside `BinaryHeap` it is a min-heap. 
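The comment above describes the usual trick for getting min-heap behaviour out of `std::collections::BinaryHeap` (which is a max-heap): implement `Ord` with the comparison inverted. A self-contained illustration with a hypothetical `MinItem` type, not part of this patch:

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

#[derive(PartialEq, Eq)]
struct MinItem(u64);

impl Ord for MinItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // reversed comparison: larger values sort as "less",
        // so the max-heap pops the smallest value first
        other.0.cmp(&self.0)
    }
}

impl PartialOrd for MinItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

fn main() {
    let mut heap = BinaryHeap::new();
    for v in [3, 1, 2] {
        heap.push(MinItem(v));
    }
    assert_eq!(heap.pop().map(|i| i.0), Some(1)); // smallest first
}
```

`CustomElement` below achieves the same effect, except the comparison is driven by its stored `ordering` values.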
-impl<'a> Ord for CustomElement<'a> { +impl Ord for CustomElement<'_> { fn cmp(&self, other: &Self) -> Ordering { // Compares according to custom ordering self.ordering(&self.ordering, &other.ordering) @@ -78,7 +78,7 @@ impl<'a> Ord for CustomElement<'a> { } } -impl<'a> PartialOrd for CustomElement<'a> { +impl PartialOrd for CustomElement<'_> { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } diff --git a/datafusion/functions-aggregate-common/src/tdigest.rs b/datafusion/functions-aggregate-common/src/tdigest.rs index 620a68e83ecdc..378fc8c42bc66 100644 --- a/datafusion/functions-aggregate-common/src/tdigest.rs +++ b/datafusion/functions-aggregate-common/src/tdigest.rs @@ -1,17 +1,19 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with this -// work for additional information regarding copyright ownership. The ASF -// licenses this file to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance with the License. -// You may obtain a copy of the License at +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations under -// the License. +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. //! An implementation of the [TDigest sketch algorithm] providing approximate //! quantile calculations. @@ -21,7 +23,7 @@ //! [Facebook's Folly TDigest] implementation. //! //! Alterations include reduction of runtime heap allocations, broader type -//! support, (de-)serialisation support, reduced type conversions and null value +//! support, (de-)serialization support, reduced type conversions and null value //! tolerance. //! //! [TDigest sketch algorithm]: https://arxiv.org/abs/1902.04023 @@ -33,6 +35,7 @@ use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::ScalarValue; use std::cmp::Ordering; +use std::mem::{size_of, size_of_val}; pub const DEFAULT_MAX_SIZE: usize = 100; @@ -203,8 +206,7 @@ impl TDigest { /// Size in bytes including `Self`. pub fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.centroids.capacity()) + size_of_val(self) + (size_of::() * self.centroids.capacity()) } } @@ -610,7 +612,7 @@ impl TDigest { ] } - /// Unpack the serialised state of a [`TDigest`] produced by + /// Unpack the serialized state of a [`TDigest`] produced by /// [`Self::to_scalar_state()`]. 
/// /// # Correctness @@ -644,7 +646,9 @@ impl TDigest { let max = cast_scalar_f64!(&state[3]); let min = cast_scalar_f64!(&state[4]); - assert!(max.total_cmp(&min).is_ge()); + if min.is_finite() && max.is_finite() { + assert!(max.total_cmp(&min).is_ge()); + } Self { max_size, diff --git a/datafusion/functions-aggregate-common/src/utils.rs b/datafusion/functions-aggregate-common/src/utils.rs index 4fba772d8ddc3..083dac615b5d1 100644 --- a/datafusion/functions-aggregate-common/src/utils.rs +++ b/datafusion/functions-aggregate-common/src/utils.rs @@ -30,7 +30,7 @@ use arrow::{ }; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr_common::accumulator::Accumulator; -use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; /// Convert scalar values from an accumulator into arrays. pub fn get_accum_scalar_values_as_arrays( @@ -48,6 +48,7 @@ pub fn get_accum_scalar_values_as_arrays( /// Since `Decimal128Arrays` created from `Vec` have /// default precision and scale, this function adjusts the output to /// match `data_type`, if necessary +#[deprecated(since = "44.0.0", note = "use PrimitiveArray::with_datatype")] pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result { let array = match data_type { DataType::Decimal128(p, s) => Arc::new( @@ -88,7 +89,7 @@ pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result Vec { @@ -107,7 +108,7 @@ pub fn ordering_fields( } /// Selects the sort option attribute from all the given `PhysicalSortExpr`s. -pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { +pub fn get_sort_options(ordering_req: &LexOrdering) -> Vec { ordering_req.iter().map(|item| item.options).collect() } diff --git a/datafusion/functions-aggregate/COMMENTS.md b/datafusion/functions-aggregate/COMMENTS.md index e669e13557115..1cb4cdd7d5a45 100644 --- a/datafusion/functions-aggregate/COMMENTS.md +++ b/datafusion/functions-aggregate/COMMENTS.md @@ -54,7 +54,7 @@ first argument and the definition looks like this: // `input_type` : data type of the first argument let mut fields = vec![Field::new_list( format_state_name(self.name(), "nth_value"), - Field::new("item", args.input_types[0].clone(), true /* nullable of list item */ ), + Field::new_list_field(args.input_types[0].clone(), true /* nullable of list item */ ), false, // nullable of list itself )]; ``` diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 37e4c7f4a5ad8..bd65490c5a78e 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -42,13 +42,14 @@ ahash = { workspace = true } arrow = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true } +datafusion-doc = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } +datafusion-macros = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } -indexmap = { workspace = true } log = { workspace = true } paste = "1.0.14" diff --git a/datafusion/functions-aggregate/LICENSE.txt b/datafusion/functions-aggregate/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/functions-aggregate/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file 
diff --git a/datafusion/functions-aggregate/NOTICE.txt b/datafusion/functions-aggregate/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/functions-aggregate/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs index 65956cb8a1dea..e6b62e6e1856a 100644 --- a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -23,6 +23,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::count::Count; use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; fn prepare_accumulator() -> Box { @@ -31,7 +32,7 @@ fn prepare_accumulator() -> Box { return_type: &DataType::Int64, schema: &schema, ignore_nulls: false, - ordering_req: &[], + ordering_req: &LexOrdering::default(), is_reversed: false, name: "COUNT(f)", is_distinct: false, diff --git a/datafusion/functions-aggregate/benches/sum.rs b/datafusion/functions-aggregate/benches/sum.rs index 652d447129dc1..1c180126a3136 100644 --- a/datafusion/functions-aggregate/benches/sum.rs +++ b/datafusion/functions-aggregate/benches/sum.rs @@ -23,6 +23,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::sum::Sum; use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; fn prepare_accumulator(data_type: &DataType) -> Box { @@ -31,7 +32,7 @@ fn prepare_accumulator(data_type: &DataType) -> Box { return_type: data_type, schema: &schema, ignore_nulls: false, - ordering_req: &[], + ordering_req: &LexOrdering::default(), is_reversed: false, name: "SUM(f)", is_distinct: false, diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index cf8217fe981de..1d378fff176fa 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -33,11 +33,15 @@ use datafusion_common::{ }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; use std::fmt::{Debug, Formatter}; use std::hash::Hash; use std::marker::PhantomData; + make_udaf_expr_and_func!( ApproxDistinct, approx_distinct, @@ -239,6 +243,20 @@ impl Default for ApproxDistinct { } } +#[user_doc( + doc_section(label = "Approximate Functions"), + description = "Returns the approximate number of distinct input values calculated using the HyperLogLog algorithm.", + syntax_example = "approx_distinct(expression)", + sql_example = r#"```sql +> SELECT approx_distinct(column_name) FROM table_name; ++-----------------------------------+ +| approx_distinct(column_name) | ++-----------------------------------+ +| 42 | ++-----------------------------------+ +```"#, + standard_argument(name = "expression",) +)] pub struct ApproxDistinct { signature: Signature, } @@ -303,4 
+321,8 @@ impl AggregateUDFImpl for ApproxDistinct { }; Ok(accumulator) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 7a7b12432544a..5d174a7522966 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -27,7 +27,10 @@ use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; +use datafusion_macros::user_doc; use crate::approx_percentile_cont::ApproxPercentileAccumulator; @@ -40,6 +43,20 @@ make_udaf_expr_and_func!( ); /// APPROX_MEDIAN aggregate expression +#[user_doc( + doc_section(label = "Approximate Functions"), + description = "Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`.", + syntax_example = "approx_median(expression)", + sql_example = r#"```sql +> SELECT approx_median(column_name) FROM table_name; ++-----------------------------------+ +| approx_median(column_name) | ++-----------------------------------+ +| 23.5 | ++-----------------------------------+ +```"#, + standard_argument(name = "expression",) +)] pub struct ApproxMedian { signature: Signature, } @@ -83,7 +100,7 @@ impl AggregateUDFImpl for ApproxMedian { Field::new(format_state_name(args.name, "min"), Float64, false), Field::new_list( format_state_name(args.name, "centroids"), - Field::new("item", Float64, true), + Field::new_list_field(Float64, true), false, ), ]) @@ -116,4 +133,8 @@ impl AggregateUDFImpl for ApproxMedian { acc_args.exprs[0].data_type(acc_args.schema)?, ))) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 51d9ac764c409..44a521ff2ddb0 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -17,6 +17,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; +use std::mem::size_of_val; use std::sync::Arc; use arrow::array::{Array, RecordBatch}; @@ -38,12 +39,13 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature, - Volatility, + Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature, + TypeSignature, Volatility, }; use datafusion_functions_aggregate_common::tdigest::{ TDigest, TryIntoF64, DEFAULT_MAX_SIZE, }; +use datafusion_macros::user_doc; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; create_func!(ApproxPercentileCont, approx_percentile_cont_udaf); @@ -62,6 +64,28 @@ pub fn approx_percentile_cont( approx_percentile_cont_udaf().call(args) } +#[user_doc( + doc_section(label = "Approximate Functions"), + description = "Returns the approximate percentile of input values using the t-digest algorithm.", + syntax_example = 
"approx_percentile_cont(expression, percentile, centroids)", + sql_example = r#"```sql +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++-------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++-------------------------------------------------+ +| 65.0 | ++-------------------------------------------------+ +```"#, + standard_argument(name = "expression",), + argument( + name = "percentile", + description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)." + ), + argument( + name = "centroids", + description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." + ) +)] pub struct ApproxPercentileCont { signature: Signature, } @@ -207,7 +231,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { } #[allow(rustdoc::private_intra_doc_links)] - /// See [`TDigest::to_scalar_state()`] for a description of the serialised + /// See [`TDigest::to_scalar_state()`] for a description of the serialized /// state. fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ @@ -238,7 +262,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { ), Field::new_list( format_state_name(args.name, "centroids"), - Field::new("item", DataType::Float64, true), + Field::new_list_field(DataType::Float64, true), false, ), ]) @@ -268,6 +292,10 @@ impl AggregateUDFImpl for ApproxPercentileCont { } Ok(arg_types[0].clone()) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } #[derive(Debug)] @@ -455,10 +483,9 @@ impl Accumulator for ApproxPercentileAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + self.digest.size() - - std::mem::size_of_val(&self.digest) + size_of_val(self) + self.digest.size() - size_of_val(&self.digest) + self.return_type.size() - - std::mem::size_of_val(&self.return_type) + - size_of_val(&self.return_type) } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index fee67ba1623db..16dac2c1b8f04 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -17,6 +17,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; +use std::mem::size_of_val; use std::sync::Arc; use arrow::{ @@ -29,10 +30,13 @@ use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::Volatility::Immutable; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, +}; use datafusion_functions_aggregate_common::tdigest::{ Centroid, TDigest, DEFAULT_MAX_SIZE, }; +use datafusion_macros::user_doc; use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont}; @@ -45,6 +49,28 @@ make_udaf_expr_and_func!( ); /// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression +#[user_doc( + doc_section(label = "Approximate Functions"), + description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.", + syntax_example = "approx_percentile_cont_with_weight(expression, weight, percentile)", + sql_example = r#"```sql +> SELECT 
approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++----------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++----------------------------------------------------------------------+ +| 78.5 | ++----------------------------------------------------------------------+ +```"#, + standard_argument(name = "expression", prefix = "The"), + argument( + name = "weight", + description = "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators." + ), + argument( + name = "percentile", + description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)." + ) +)] pub struct ApproxPercentileContWithWeight { signature: Signature, approx_percentile_cont: ApproxPercentileCont, @@ -146,11 +172,15 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { } #[allow(rustdoc::private_intra_doc_links)] - /// See [`TDigest::to_scalar_state()`] for a description of the serialised + /// See [`TDigest::to_scalar_state()`] for a description of the serialized /// state. fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.approx_percentile_cont.state_fields(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } #[derive(Debug)] @@ -205,8 +235,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - - std::mem::size_of_val(&self.approx_percentile_cont_accumulator) + size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator) + self.approx_percentile_cont_accumulator.size() } } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 15146fc4a2d89..b75de83f6aceb 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -22,17 +22,19 @@ use arrow::datatypes::DataType; use arrow_schema::{Field, Fields}; use datafusion_common::cast::as_list_array; -use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; +use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{internal_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; -use datafusion_expr::AggregateUDFImpl; use datafusion_expr::{Accumulator, Signature, Volatility}; +use datafusion_expr::{AggregateUDFImpl, Documentation}; use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; use datafusion_functions_aggregate_common::utils::ordering_fields; +use datafusion_macros::user_doc; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use std::collections::{HashSet, VecDeque}; +use std::mem::{size_of, size_of_val}; use std::sync::Arc; make_udaf_expr_and_func!( @@ -43,6 +45,20 @@ make_udaf_expr_and_func!( array_agg_udaf ); +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns an array created from the expression elements. 
If ordering is required, elements are inserted in the specified order.", + syntax_example = "array_agg(expression [ORDER BY expression])", + sql_example = r#"```sql +> SELECT array_agg(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| array_agg(column_name ORDER BY other_column) | ++-----------------------------------------------+ +| [element1, element2, element3] | ++-----------------------------------------------+ +```"#, + standard_argument(name = "expression",) +)] #[derive(Debug)] /// ARRAY_AGG aggregate expression pub struct ArrayAgg { @@ -75,8 +91,7 @@ impl AggregateUDFImpl for ArrayAgg { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(DataType::List(Arc::new(Field::new( - "item", + Ok(DataType::List(Arc::new(Field::new_list_field( arg_types[0].clone(), true, )))) @@ -87,7 +102,7 @@ impl AggregateUDFImpl for ArrayAgg { return Ok(vec![Field::new_list( format_state_name(args.name, "distinct_array_agg"), // See COMMENTS.md to understand why nullable is set to true - Field::new("item", args.input_types[0].clone(), true), + Field::new_list_field(args.input_types[0].clone(), true), true, )]); } @@ -95,7 +110,7 @@ impl AggregateUDFImpl for ArrayAgg { let mut fields = vec![Field::new_list( format_state_name(args.name, "array_agg"), // See COMMENTS.md to understand why nullable is set to true - Field::new("item", args.input_types[0].clone(), true), + Field::new_list_field(args.input_types[0].clone(), true), true, )]; @@ -106,7 +121,7 @@ impl AggregateUDFImpl for ArrayAgg { let orderings = args.ordering_fields.to_vec(); fields.push(Field::new_list( format_state_name(args.name, "array_agg_orderings"), - Field::new("item", DataType::Struct(Fields::from(orderings)), true), + Field::new_list_field(DataType::Struct(Fields::from(orderings)), true), false, )); @@ -133,7 +148,7 @@ impl AggregateUDFImpl for ArrayAgg { OrderSensitiveArrayAggAccumulator::try_new( &data_type, &ordering_dtypes, - acc_args.ordering_req.to_vec(), + acc_args.ordering_req.clone(), acc_args.is_reversed, ) .map(|acc| Box::new(acc) as _) @@ -142,6 +157,10 @@ impl AggregateUDFImpl for ArrayAgg { fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf()) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } #[derive(Debug)] @@ -209,21 +228,20 @@ impl Accumulator for ArrayAggAccumulator { } let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array_nullable(concated_array); - Ok(ScalarValue::List(Arc::new(list_array))) + Ok(SingleRowListArrayBuilder::new(concated_array).build_list_scalar()) } fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) + size_of_val(self) + + (size_of::() * self.values.capacity()) + self .values .iter() .map(|arr| arr.get_array_memory_size()) .sum::() + self.datatype.size() - - std::mem::size_of_val(&self.datatype) + - size_of_val(&self.datatype) } } @@ -288,10 +306,10 @@ impl Accumulator for DistinctArrayAggAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values) - - std::mem::size_of_val(&self.values) + size_of_val(self) + ScalarValue::size_of_hashset(&self.values) + - size_of_val(&self.values) + self.datatype.size() - - std::mem::size_of_val(&self.datatype) + - size_of_val(&self.datatype) } } @@ -456,25 +474,23 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } fn size(&self) -> usize 
{ - let mut total = std::mem::size_of_val(self) - + ScalarValue::size_of_vec(&self.values) - - std::mem::size_of_val(&self.values); + let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values) + - size_of_val(&self.values); // Add size of the `self.ordering_values` - total += - std::mem::size_of::>() * self.ordering_values.capacity(); + total += size_of::>() * self.ordering_values.capacity(); for row in &self.ordering_values { - total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row); + total += ScalarValue::size_of_vec(row) - size_of_val(row); } // Add size of the `self.datatypes` - total += std::mem::size_of::() * self.datatypes.capacity(); + total += size_of::() * self.datatypes.capacity(); for dtype in &self.datatypes { - total += dtype.size() - std::mem::size_of_val(dtype); + total += dtype.size() - size_of_val(dtype); } // Add size of the `self.ordering_req` - total += std::mem::size_of::() * self.ordering_req.capacity(); + total += size_of::() * self.ordering_req.capacity(); // TODO: Calculate size of each `PhysicalSortExpr` more accurately. total } @@ -482,7 +498,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { impl OrderSensitiveArrayAggAccumulator { fn evaluate_orderings(&self) -> Result { - let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); let num_columns = fields.len(); let struct_field = Fields::from(fields.clone()); @@ -503,9 +519,7 @@ impl OrderSensitiveArrayAggAccumulator { let ordering_array = StructArray::try_new(struct_field, column_wise_ordering_values, None)?; - Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( - Arc::new(ordering_array), - )))) + Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) } } diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index ddad76a8734b0..18874f831e9d0 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -18,8 +18,8 @@ //! 
Defines `Avg` & `Mean` aggregate & accumulators use arrow::array::{ - self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, - AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, + Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray, + BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, }; use arrow::compute::sum; @@ -33,7 +33,8 @@ use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_typ use datafusion_expr::utils::format_state_name; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, + Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, + ReversedUDAF, Signature, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; @@ -42,9 +43,11 @@ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls: }; use datafusion_functions_aggregate_common::utils::DecimalAverager; +use datafusion_macros::user_doc; use log::debug; use std::any::Any; use std::fmt::Debug; +use std::mem::{size_of, size_of_val}; use std::sync::Arc; make_udaf_expr_and_func!( @@ -55,6 +58,20 @@ make_udaf_expr_and_func!( avg_udaf ); +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns the average of numeric values in the specified column.", + syntax_example = "avg(expression)", + sql_example = r#"```sql +> SELECT avg(column_name) FROM table_name; ++---------------------------+ +| avg(column_name) | ++---------------------------+ +| 42.75 | ++---------------------------+ +```"#, + standard_argument(name = "expression",) +)] #[derive(Debug)] pub struct Avg { signature: Signature, @@ -235,6 +252,10 @@ impl AggregateUDFImpl for Avg { } coerce_avg_type(self.name(), arg_types) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// An accumulator to compute the average @@ -262,7 +283,7 @@ impl Accumulator for AvgAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -340,7 +361,7 @@ impl Accumulator for DecimalAvgAccumu } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -439,7 +460,7 @@ where &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&array::BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); @@ -522,7 +543,7 @@ where &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&array::BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); @@ -582,7 +603,6 @@ where } fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - + self.sums.capacity() * std::mem::size_of::() + self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() } } diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index c5382c168f17a..6298071a223b2 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -20,6 +20,7 @@ use std::any::Any; use std::collections::HashSet; use std::fmt::{Display, Formatter}; +use std::mem::{size_of, size_of_val}; use 
ahash::RandomState; use arrow::array::{downcast_integer, Array, ArrayRef, AsArray}; @@ -138,13 +139,13 @@ static BIT_AND_DOC: OnceLock = OnceLock::new(); fn get_bit_and_doc() -> &'static Documentation { BIT_AND_DOC.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_GENERAL) - .with_description("Computes the bitwise AND of all non-null input values.") - .with_syntax_example("bit_and(expression)") - .with_standard_argument("expression", "Integer") - .build() - .unwrap() + Documentation::builder( + DOC_SECTION_GENERAL, + "Computes the bitwise AND of all non-null input values.", + "bit_and(expression)", + ) + .with_standard_argument("expression", Some("Integer")) + .build() }) } @@ -152,13 +153,13 @@ static BIT_OR_DOC: OnceLock = OnceLock::new(); fn get_bit_or_doc() -> &'static Documentation { BIT_OR_DOC.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_GENERAL) - .with_description("Computes the bitwise OR of all non-null input values.") - .with_syntax_example("bit_or(expression)") - .with_standard_argument("expression", "Integer") - .build() - .unwrap() + Documentation::builder( + DOC_SECTION_GENERAL, + "Computes the bitwise OR of all non-null input values.", + "bit_or(expression)", + ) + .with_standard_argument("expression", Some("Integer")) + .build() }) } @@ -166,15 +167,13 @@ static BIT_XOR_DOC: OnceLock = OnceLock::new(); fn get_bit_xor_doc() -> &'static Documentation { BIT_XOR_DOC.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_GENERAL) - .with_description( - "Computes the bitwise exclusive OR of all non-null input values.", - ) - .with_syntax_example("bit_xor(expression)") - .with_standard_argument("expression", "Integer") - .build() - .unwrap() + Documentation::builder( + DOC_SECTION_GENERAL, + "Computes the bitwise exclusive OR of all non-null input values.", + "bit_xor(expression)", + ) + .with_standard_argument("expression", Some("Integer")) + .build() }) } @@ -273,7 +272,7 @@ impl AggregateUDFImpl for BitwiseOperation { format!("{} distinct", self.name()).as_str(), ), // See COMMENTS.md to understand why nullable is set to true - Field::new("item", args.return_type.clone(), true), + Field::new_list_field(args.return_type.clone(), true), false, )]) } else { @@ -347,7 +346,7 @@ where } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -392,7 +391,7 @@ where } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -446,7 +445,7 @@ where } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -509,8 +508,7 @@ where } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.values.capacity() * std::mem::size_of::() + size_of_val(self) + self.values.capacity() * size_of::() } fn state(&mut self) -> Result> { diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 7cc7d9ff7fec3..29dfc68e05762 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -18,6 +18,7 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; +use std::mem::size_of_val; use arrow::array::ArrayRef; use arrow::array::BooleanArray; @@ -32,10 +33,12 @@ use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, + Signature, Volatility, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; +use datafusion_macros::user_doc; // returns the new value after bool_and/bool_or with the new values, taking nullability into account macro_rules! typed_bool_and_or_batch { @@ -88,6 +91,20 @@ make_udaf_expr_and_func!( bool_or_udaf ); +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns true if all non-null input values are true, otherwise false.", + syntax_example = "bool_and(expression)", + sql_example = r#"```sql +> SELECT bool_and(column_name) FROM table_name; ++----------------------------+ +| bool_and(column_name) | ++----------------------------+ +| true | ++----------------------------+ +```"#, + standard_argument(name = "expression", prefix = "The") +)] /// BOOL_AND aggregate expression #[derive(Debug)] pub struct BoolAnd { @@ -172,6 +189,10 @@ impl AggregateUDFImpl for BoolAnd { fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::Identical } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } #[derive(Debug, Default)] @@ -196,7 +217,7 @@ impl Accumulator for BoolAndAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -208,6 +229,20 @@ impl Accumulator for BoolAndAccumulator { } } +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns true if all non-null input values are true, otherwise false.", + syntax_example = "bool_and(expression)", + sql_example = r#"```sql +> SELECT bool_and(column_name) FROM table_name; ++----------------------------+ +| bool_and(column_name) | ++----------------------------+ +| true | ++----------------------------+ +```"#, + standard_argument(name = "expression", prefix = "The") +)] /// BOOL_OR aggregate expression #[derive(Debug, Clone)] pub struct BoolOr { @@ -293,6 +328,10 @@ impl AggregateUDFImpl for BoolOr { fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::Identical } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } #[derive(Debug, Default)] @@ -317,7 +356,7 @@ impl Accumulator for BoolOrAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 88f01b06d2d9b..72c1f6dbaed2b 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -19,13 +19,22 @@ use std::any::Any; use std::fmt::Debug; +use std::mem::size_of_val; use std::sync::Arc; -use arrow::compute::{and, filter, is_not_null}; +use arrow::array::{ + downcast_array, Array, AsArray, BooleanArray, BooleanBufferBuilder, Float64Array, + UInt64Array, +}; +use arrow::compute::{and, filter, is_not_null, kernels::cast}; +use 
arrow::datatypes::{Float64Type, UInt64Type}; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple; +use log::debug; use crate::covariance::CovarianceAccumulator; use crate::stddev::StddevAccumulator; @@ -34,9 +43,10 @@ use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, type_coercion::aggregates::NUMERICS, utils::format_state_name, - Accumulator, AggregateUDFImpl, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; use datafusion_functions_aggregate_common::stats::StatsType; +use datafusion_macros::user_doc; make_udaf_expr_and_func!( Correlation, @@ -46,6 +56,21 @@ make_udaf_expr_and_func!( corr_udaf ); +#[user_doc( + doc_section(label = "Statistical Functions"), + description = "Returns the coefficient of correlation between two numeric values.", + syntax_example = "corr(expression1, expression2)", + sql_example = r#"```sql +> SELECT corr(column1, column2) FROM table_name; ++--------------------------------+ +| corr(column1, column2) | ++--------------------------------+ +| 0.85 | ++--------------------------------+ +```"#, + standard_argument(name = "expression1", prefix = "First"), + standard_argument(name = "expression2", prefix = "Second") +)] #[derive(Debug)] pub struct Correlation { signature: Signature, @@ -107,6 +132,22 @@ impl AggregateUDFImpl for Correlation { ), ]) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + debug!("GroupsAccumulator is created for aggregate function `corr(c1, c2)`"); + Ok(Box::new(CorrelationGroupsAccumulator::new())) + } } /// An accumulator to compute correlation @@ -172,11 +213,10 @@ impl Accumulator for CorrelationAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) - + self.covar.size() - - std::mem::size_of_val(&self.stddev1) + size_of_val(self) - size_of_val(&self.covar) + self.covar.size() + - size_of_val(&self.stddev1) + self.stddev1.size() - - std::mem::size_of_val(&self.stddev2) + - size_of_val(&self.stddev2) + self.stddev2.size() } @@ -232,3 +272,308 @@ impl Accumulator for CorrelationAccumulator { Ok(()) } } + +#[derive(Default)] +pub struct CorrelationGroupsAccumulator { + // Number of elements for each group + // This is also used to track nulls: if a group has 0 valid values accumulated, + // final aggregation result will be null. + count: Vec, + // Sum of x values for each group + sum_x: Vec, + // Sum of y + sum_y: Vec, + // Sum of x*y + sum_xy: Vec, + // Sum of x^2 + sum_xx: Vec, + // Sum of y^2 + sum_yy: Vec, +} + +impl CorrelationGroupsAccumulator { + pub fn new() -> Self { + Default::default() + } +} + +/// Specialized version of `accumulate_multiple` for correlation's merge_batch +/// +/// Note: Arrays in `state_arrays` should not have null values, because they are all +/// intermediate states created within the accumulator, instead of inputs from +/// outside. 
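Before the helper below, a quick numerical sanity check on the running-sums form of Pearson's r that `CorrelationGroupsAccumulator::evaluate` uses (numerator `sum_xy - sum_x * mean_y`, denominator `sqrt((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y))`, i.e. the formula in the doc comment with numerator and denominator each divided by n). Illustrative plain-Rust sketch, not part of the patch:

```rust
fn corr_from_sums(
    n: u64,
    sum_x: f64,
    sum_y: f64,
    sum_xy: f64,
    sum_xx: f64,
    sum_yy: f64,
) -> Option<f64> {
    if n < 2 {
        return None; // correlation is undefined for fewer than two observations
    }
    let mean_x = sum_x / n as f64;
    let mean_y = sum_y / n as f64;
    let numerator = sum_xy - sum_x * mean_y;
    let denominator = ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt();
    if denominator == 0.0 {
        None // constant input, r is undefined
    } else {
        Some(numerator / denominator)
    }
}

fn main() {
    // x = [1.0, 2.0, 3.0], y = [2.0, 4.0, 6.0]: perfectly correlated
    let r = corr_from_sums(3, 6.0, 12.0, 28.0, 14.0, 56.0).unwrap();
    assert!((r - 1.0).abs() < 1e-12);
}
```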
+fn accumulate_correlation_states( + group_indices: &[usize], + state_arrays: ( + &UInt64Array, // count + &Float64Array, // sum_x + &Float64Array, // sum_y + &Float64Array, // sum_xy + &Float64Array, // sum_xx + &Float64Array, // sum_yy + ), + mut value_fn: impl FnMut(usize, u64, &[f64]), +) { + let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays; + + assert_eq!(counts.null_count(), 0); + assert_eq!(sum_x.null_count(), 0); + assert_eq!(sum_y.null_count(), 0); + assert_eq!(sum_xy.null_count(), 0); + assert_eq!(sum_xx.null_count(), 0); + assert_eq!(sum_yy.null_count(), 0); + + let counts_values = counts.values().as_ref(); + let sum_x_values = sum_x.values().as_ref(); + let sum_y_values = sum_y.values().as_ref(); + let sum_xy_values = sum_xy.values().as_ref(); + let sum_xx_values = sum_xx.values().as_ref(); + let sum_yy_values = sum_yy.values().as_ref(); + + for (idx, &group_idx) in group_indices.iter().enumerate() { + let row = [ + sum_x_values[idx], + sum_y_values[idx], + sum_xy_values[idx], + sum_xx_values[idx], + sum_yy_values[idx], + ]; + value_fn(group_idx, counts_values[idx], &row); + } +} + +/// GroupsAccumulator implementation for `corr(x, y)` that computes the Pearson correlation coefficient +/// between two numeric columns. +/// +/// Online algorithm for correlation: +/// +/// r = (n * sum_xy - sum_x * sum_y) / sqrt((n * sum_xx - sum_x^2) * (n * sum_yy - sum_y^2)) +/// where: +/// n = number of observations +/// sum_x = sum of x values +/// sum_y = sum of y values +/// sum_xy = sum of (x * y) +/// sum_xx = sum of x^2 values +/// sum_yy = sum of y^2 values +/// +/// Reference: +impl GroupsAccumulator for CorrelationGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.count.resize(total_num_groups, 0); + self.sum_x.resize(total_num_groups, 0.0); + self.sum_y.resize(total_num_groups, 0.0); + self.sum_xy.resize(total_num_groups, 0.0); + self.sum_xx.resize(total_num_groups, 0.0); + self.sum_yy.resize(total_num_groups, 0.0); + + let array_x = &cast(&values[0], &DataType::Float64)?; + let array_x = downcast_array::(array_x); + let array_y = &cast(&values[1], &DataType::Float64)?; + let array_y = downcast_array::(array_y); + + accumulate_multiple( + group_indices, + &[&array_x, &array_y], + opt_filter, + |group_index, batch_index, columns| { + let x = columns[0].value(batch_index); + let y = columns[1].value(batch_index); + self.count[group_index] += 1; + self.sum_x[group_index] += x; + self.sum_y[group_index] += y; + self.sum_xy[group_index] += x * y; + self.sum_xx[group_index] += x * x; + self.sum_yy[group_index] += y * y; + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // Resize vectors to accommodate total number of groups + self.count.resize(total_num_groups, 0); + self.sum_x.resize(total_num_groups, 0.0); + self.sum_y.resize(total_num_groups, 0.0); + self.sum_xy.resize(total_num_groups, 0.0); + self.sum_xx.resize(total_num_groups, 0.0); + self.sum_yy.resize(total_num_groups, 0.0); + + // Extract arrays from input values + let partial_counts = values[0].as_primitive::(); + let partial_sum_x = values[1].as_primitive::(); + let partial_sum_y = values[2].as_primitive::(); + let partial_sum_xy = values[3].as_primitive::(); + let partial_sum_xx = values[4].as_primitive::(); + let partial_sum_yy 
= values[5].as_primitive::(); + + assert!(opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage"); + + accumulate_correlation_states( + group_indices, + ( + partial_counts, + partial_sum_x, + partial_sum_y, + partial_sum_xy, + partial_sum_xx, + partial_sum_yy, + ), + |group_index, count, values| { + self.count[group_index] += count; + self.sum_x[group_index] += values[0]; + self.sum_y[group_index] += values[1]; + self.sum_xy[group_index] += values[2]; + self.sum_xx[group_index] += values[3]; + self.sum_yy[group_index] += values[4]; + }, + ); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let n = match emit_to { + EmitTo::All => self.count.len(), + EmitTo::First(n) => n, + }; + + let mut values = Vec::with_capacity(n); + let mut nulls = BooleanBufferBuilder::new(n); + + // Notes for `Null` handling: + // - If the `count` state of a group is 0, no valid records are accumulated + // for this group, so the aggregation result is `Null`. + // - Correlation can't be calculated when a group only has 1 record, or when + // the `denominator` state is 0. In these cases, the final aggregation + // result should be `Null` (according to PostgreSQL's behavior). + // + // TODO: Old datafusion implementation returns 0.0 for these invalid cases. + // Update this to match PostgreSQL's behavior. + for i in 0..n { + if self.count[i] < 2 { + // TODO: Evaluate as `Null` (see notes above) + values.push(0.0); + nulls.append(false); + continue; + } + + let count = self.count[i]; + let sum_x = self.sum_x[i]; + let sum_y = self.sum_y[i]; + let sum_xy = self.sum_xy[i]; + let sum_xx = self.sum_xx[i]; + let sum_yy = self.sum_yy[i]; + + let mean_x = sum_x / count as f64; + let mean_y = sum_y / count as f64; + + let numerator = sum_xy - sum_x * mean_y; + let denominator = + ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt(); + + if denominator == 0.0 { + // TODO: Evaluate as `Null` (see notes above) + values.push(0.0); + nulls.append(false); + } else { + values.push(numerator / denominator); + nulls.append(true); + } + } + + Ok(Arc::new(Float64Array::new( + values.into(), + Some(nulls.finish().into()), + ))) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let n = match emit_to { + EmitTo::All => self.count.len(), + EmitTo::First(n) => n, + }; + + Ok(vec![ + Arc::new(UInt64Array::from(self.count[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())), + ]) + } + + fn size(&self) -> usize { + size_of_val(&self.count) + + size_of_val(&self.sum_x) + + size_of_val(&self.sum_y) + + size_of_val(&self.sum_xy) + + size_of_val(&self.sum_xx) + + size_of_val(&self.sum_yy) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Float64Array, UInt64Array}; + + #[test] + fn test_accumulate_correlation_states() { + // Test data + let group_indices = vec![0, 1, 0, 1]; + let counts = UInt64Array::from(vec![1, 2, 3, 4]); + let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]); + let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]); + let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]); + let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]); + let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]); + + let mut 
accumulated = vec![]; + accumulate_correlation_states( + &group_indices, + (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy), + |group_idx, count, values| { + accumulated.push((group_idx, count, values.to_vec())); + }, + ); + + let expected = vec![ + (0, 1, vec![10.0, 1.0, 10.0, 100.0, 1.0]), + (1, 2, vec![20.0, 2.0, 40.0, 400.0, 4.0]), + (0, 3, vec![30.0, 3.0, 90.0, 900.0, 9.0]), + (1, 4, vec![40.0, 4.0, 160.0, 1600.0, 16.0]), + ]; + assert_eq!(accumulated, expected); + + // Test that function panics with null values + let counts = UInt64Array::from(vec![Some(1), None, Some(3), Some(4)]); + let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]); + let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]); + let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]); + let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]); + let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]); + + let result = std::panic::catch_unwind(|| { + accumulate_correlation_states( + &group_indices, + (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy), + |_, _, _| {}, + ) + }); + assert!(result.is_err()); + } +} diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 282fa2b95be5b..ea8faa761bb23 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -18,10 +18,13 @@ use ahash::RandomState; use datafusion_common::stats::Precision; use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; +use datafusion_macros::user_doc; use datafusion_physical_expr::expressions; use std::collections::HashSet; +use std::fmt::Debug; +use std::mem::{size_of, size_of_val}; use std::ops::BitAnd; -use std::{fmt::Debug, sync::Arc}; +use std::sync::Arc; use arrow::{ array::{ArrayRef, AsArray}, @@ -46,7 +49,7 @@ use datafusion_common::{ use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, - EmitTo, GroupsAccumulator, Signature, Volatility, + Documentation, EmitTo, GroupsAccumulator, Signature, Volatility, }; use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature}; use datafusion_functions_aggregate_common::aggregate::count_distinct::{ @@ -76,6 +79,27 @@ pub fn count_distinct(expr: Expr) -> Expr { )) } +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns the number of non-null values in the specified column. 
To include null values in the total count, use `count(*)`.", + syntax_example = "count(expression)", + sql_example = r#"```sql +> SELECT count(column_name) FROM table_name; ++-----------------------+ +| count(column_name) | ++-----------------------+ +| 100 | ++-----------------------+ + +> SELECT count(*) FROM table_name; ++------------------+ +| count(*) | ++------------------+ +| 120 | ++------------------+ +```"#, + standard_argument(name = "expression",) +)] pub struct Count { signature: Signature, } @@ -99,8 +123,7 @@ impl Count { pub fn new() -> Self { Self { signature: Signature::one_of( - // TypeSignature::Any(0) is required to handle `Count()` with no args - vec![TypeSignature::VariadicAny, TypeSignature::Any(0)], + vec![TypeSignature::VariadicAny, TypeSignature::Nullary], Volatility::Immutable, ), } @@ -133,7 +156,7 @@ impl AggregateUDFImpl for Count { Ok(vec![Field::new_list( format_state_name(args.name, "count distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new("item", args.input_types[0].clone(), true), + Field::new_list_field(args.input_types[0].clone(), true), false, )]) } else { @@ -324,6 +347,10 @@ impl AggregateUDFImpl for Count { } None } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } #[derive(Debug)] @@ -357,7 +384,7 @@ impl Accumulator for CountAccumulator { fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { let counts = downcast_value!(states[0], Int64Array); - let delta = &arrow::compute::sum(counts); + let delta = &compute::sum(counts); if let Some(d) = delta { self.count += *d; } @@ -373,7 +400,7 @@ impl Accumulator for CountAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -430,7 +457,8 @@ impl GroupsAccumulator for CountGroupsAccumulator { &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&BooleanArray>, + // Since aggregate filter should be applied in partial stage, in final stage there should be no filter + _opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "one argument to merge_batch"); @@ -443,22 +471,11 @@ impl GroupsAccumulator for CountGroupsAccumulator { // Adds the counts with the partial counts self.counts.resize(total_num_groups, 0); - match opt_filter { - Some(filter) => filter - .iter() - .zip(group_indices.iter()) - .zip(partial_counts.iter()) - .for_each(|((filter_value, &group_index), partial_count)| { - if let Some(true) = filter_value { - self.counts[group_index] += partial_count; - } - }), - None => group_indices.iter().zip(partial_counts.iter()).for_each( - |(&group_index, partial_count)| { - self.counts[group_index] += partial_count; - }, - ), - } + group_indices.iter().zip(partial_counts.iter()).for_each( + |(&group_index, partial_count)| { + self.counts[group_index] += partial_count; + }, + ); Ok(()) } @@ -546,7 +563,7 @@ impl GroupsAccumulator for CountGroupsAccumulator { } fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() + self.counts.capacity() * size_of::() } } @@ -590,28 +607,28 @@ impl DistinctCountAccumulator { // number of batches This method is faster than .full_size(), however it is // not suitable for variable length values like strings or complex types fn fixed_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) + size_of_val(self) + + (size_of::() * self.values.capacity()) + self .values .iter() .next() - .map(|vals| ScalarValue::size(vals) - 
std::mem::size_of_val(vals)) + .map(|vals| ScalarValue::size(vals) - size_of_val(vals)) .unwrap_or(0) - + std::mem::size_of::() + + size_of::() } // calculates the size as accurately as possible. Note that calling this // method is expensive fn full_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) + size_of_val(self) + + (size_of::() * self.values.capacity()) + self .values .iter() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) + .map(|vals| ScalarValue::size(vals) - size_of_val(vals)) .sum::() - + std::mem::size_of::() + + size_of::() } } @@ -678,3 +695,17 @@ impl Accumulator for DistinctCountAccumulator { } } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::NullArray; + + #[test] + fn count_accumulator_nulls() -> Result<()> { + let mut accumulator = CountAccumulator::new(); + accumulator.update_batch(&[Arc::new(NullArray::new(10))])?; + assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index d0abb079ef155..d4ae27533c6db 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -18,6 +18,7 @@ //! [`CovarianceSample`]: covariance sample aggregations. use std::fmt::Debug; +use std::mem::size_of_val; use arrow::{ array::{ArrayRef, Float64Array, UInt64Array}, @@ -33,9 +34,10 @@ use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, type_coercion::aggregates::NUMERICS, utils::format_state_name, - Accumulator, AggregateUDFImpl, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; use datafusion_functions_aggregate_common::stats::StatsType; +use datafusion_macros::user_doc; make_udaf_expr_and_func!( CovarianceSample, @@ -53,6 +55,21 @@ make_udaf_expr_and_func!( covar_pop_udaf ); +#[user_doc( + doc_section(label = "Statistical Functions"), + description = "Returns the sample covariance of a set of number pairs.", + syntax_example = "covar_samp(expression1, expression2)", + sql_example = r#"```sql +> SELECT covar_samp(column1, column2) FROM table_name; ++-----------------------------------+ +| covar_samp(column1, column2) | ++-----------------------------------+ +| 8.25 | ++-----------------------------------+ +```"#, + standard_argument(name = "expression1", prefix = "First"), + standard_argument(name = "expression2", prefix = "Second") +)] pub struct CovarianceSample { signature: Signature, aliases: Vec, @@ -124,8 +141,27 @@ impl AggregateUDFImpl for CovarianceSample { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } +#[user_doc( + doc_section(label = "Statistical Functions"), + description = "Returns the sample covariance of a set of number pairs.", + syntax_example = "covar_samp(expression1, expression2)", + sql_example = r#"```sql +> SELECT covar_samp(column1, column2) FROM table_name; ++-----------------------------------+ +| covar_samp(column1, column2) | ++-----------------------------------+ +| 8.25 | ++-----------------------------------+ +```"#, + standard_argument(name = "expression1", prefix = "First"), + standard_argument(name = "expression2", prefix = "Second") +)] pub struct CovariancePopulation { signature: Signature, } @@ -193,6 +229,10 @@ impl AggregateUDFImpl for CovariancePopulation { StatsType::Population, )?)) } + + fn documentation(&self) -> 
Option<&Documentation> { + self.doc() + } } /// An accumulator to compute covariance @@ -206,7 +246,7 @@ impl AggregateUDFImpl for CovariancePopulation { /// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. /// /// Though it is not covered in the original paper but is based on the same idea, as a result the algorithm is online, -/// parallelizable and numerically stable. +/// parallelize and numerically stable. #[derive(Debug)] pub struct CovarianceAccumulator { @@ -388,6 +428,6 @@ impl Accumulator for CovarianceAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 41ac7875795d1..8ef139ae61230 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -19,23 +19,25 @@ use std::any::Any; use std::fmt::Debug; +use std::mem::size_of_val; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use arrow::compute::{self, lexsort_to_indices, SortColumn}; +use arrow::compute::{self, lexsort_to_indices, take_arrays, SortColumn}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::utils::{compare_rows, get_row_at_idx, take_arrays}; +use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, ExprFunctionExt, - Signature, SortExpr, TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, Expr, ExprFunctionExt, Signature, + SortExpr, Volatility, }; use datafusion_functions_aggregate_common::utils::get_sort_options; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_macros::user_doc; +use datafusion_physical_expr_common::sort_expr::LexOrdering; create_func!(FirstValue, first_value_udaf); @@ -53,6 +55,20 @@ pub fn first_value(expression: Expr, order_by: Option>) -> Expr { } } +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns the first element in an aggregation group according to the requested ordering. 
If no ordering is given, returns an arbitrary element from the group.", + syntax_example = "first_value(expression [ORDER BY expression])", + sql_example = r#"```sql +> SELECT first_value(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| first_value(column_name ORDER BY other_column)| ++-----------------------------------------------+ +| first_element | ++-----------------------------------------------+ +```"#, + standard_argument(name = "expression",) +)] pub struct FirstValue { signature: Signature, requirement_satisfied: bool, @@ -77,15 +93,7 @@ impl Default for FirstValue { impl FirstValue { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - // TODO: we can introduce more strict signature that only numeric of array types are allowed - TypeSignature::ArraySignature(ArrayFunctionSignature::Array), - TypeSignature::Numeric(1), - TypeSignature::Uniform(1, vec![DataType::Utf8]), - ], - Volatility::Immutable, - ), + signature: Signature::any(1, Volatility::Immutable), requirement_satisfied: false, } } @@ -128,7 +136,7 @@ impl AggregateUDFImpl for FirstValue { FirstValueAccumulator::try_new( acc_args.return_type, &ordering_dtypes, - acc_args.ordering_req.to_vec(), + acc_args.ordering_req.clone(), acc_args.ignore_nulls, ) .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) @@ -165,6 +173,10 @@ impl AggregateUDFImpl for FirstValue { fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { datafusion_expr::ReversedUDAF::Reversed(last_value_udaf()) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } #[derive(Debug)] @@ -284,7 +296,7 @@ impl Accumulator for FirstValueAccumulator { if compare_rows( &self.orderings, orderings, - &get_sort_options(&self.ordering_req), + &get_sort_options(self.ordering_req.as_ref()), )? .is_gt() { @@ -302,21 +314,23 @@ impl Accumulator for FirstValueAccumulator { let flags = states[is_set_idx].as_boolean(); let filtered_states = filter_states_according_to_is_set(states, flags)?; // 1..is_set_idx range corresponds to ordering section - let sort_cols = - convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); + let sort_cols = convert_to_sort_cols( + &filtered_states[1..is_set_idx], + self.ordering_req.as_ref(), + ); let ordered_states = if sort_cols.is_empty() { // When no ordering is given, use the existing state as is: filtered_states } else { let indices = lexsort_to_indices(&sort_cols, None)?; - take_arrays(&filtered_states, &indices)? + take_arrays(&filtered_states, &indices, None)? }; if !ordered_states[0].is_empty() { let first_row = get_row_at_idx(&ordered_states, 0)?; // When collecting orderings, we exclude the is_set flag from the state. let first_ordering = &first_row[1..is_set_idx]; - let sort_options = get_sort_options(&self.ordering_req); + let sort_options = get_sort_options(self.ordering_req.as_ref()); // Either there is no existing value, or there is an earlier version in new data. 
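// (`is_gt` here means the ordering keys already stored sort after the
// incoming row's keys under the requested sort options, i.e. the incoming
// row comes first and should replace the stored value.)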
if !self.is_set || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt() @@ -335,10 +349,10 @@ impl Accumulator for FirstValueAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.first) + size_of_val(self) - size_of_val(&self.first) + self.first.size() + ScalarValue::size_of_vec(&self.orderings) - - std::mem::size_of_val(&self.orderings) + - size_of_val(&self.orderings) } } @@ -349,6 +363,20 @@ make_udaf_expr_and_func!( last_value_udaf ); +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.", + syntax_example = "last_value(expression [ORDER BY expression])", + sql_example = r#"```sql +> SELECT last_value(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| last_value(column_name ORDER BY other_column) | ++-----------------------------------------------+ +| last_element | ++-----------------------------------------------+ +```"#, + standard_argument(name = "expression",) +)] pub struct LastValue { signature: Signature, requirement_satisfied: bool, @@ -373,15 +401,7 @@ impl Default for LastValue { impl LastValue { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - // TODO: we can introduce more strict signature that only numeric of array types are allowed - TypeSignature::ArraySignature(ArrayFunctionSignature::Array), - TypeSignature::Numeric(1), - TypeSignature::Uniform(1, vec![DataType::Utf8]), - ], - Volatility::Immutable, - ), + signature: Signature::any(1, Volatility::Immutable), requirement_satisfied: false, } } @@ -422,7 +442,7 @@ impl AggregateUDFImpl for LastValue { LastValueAccumulator::try_new( acc_args.return_type, &ordering_dtypes, - acc_args.ordering_req.to_vec(), + acc_args.ordering_req.clone(), acc_args.ignore_nulls, ) .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) @@ -466,6 +486,10 @@ impl AggregateUDFImpl for LastValue { fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { datafusion_expr::ReversedUDAF::Reversed(first_value_udaf()) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } #[derive(Debug)] @@ -587,7 +611,7 @@ impl Accumulator for LastValueAccumulator { if compare_rows( &self.orderings, orderings, - &get_sort_options(&self.ordering_req), + &get_sort_options(self.ordering_req.as_ref()), )? .is_lt() { @@ -605,15 +629,17 @@ impl Accumulator for LastValueAccumulator { let flags = states[is_set_idx].as_boolean(); let filtered_states = filter_states_according_to_is_set(states, flags)?; // 1..is_set_idx range corresponds to ordering section - let sort_cols = - convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); + let sort_cols = convert_to_sort_cols( + &filtered_states[1..is_set_idx], + self.ordering_req.as_ref(), + ); let ordered_states = if sort_cols.is_empty() { // When no ordering is given, use existing state as is: filtered_states } else { let indices = lexsort_to_indices(&sort_cols, None)?; - take_arrays(&filtered_states, &indices)? + take_arrays(&filtered_states, &indices, None)? }; if !ordered_states[0].is_empty() { @@ -621,10 +647,11 @@ impl Accumulator for LastValueAccumulator { let last_row = get_row_at_idx(&ordered_states, last_idx)?; // When collecting orderings, we exclude the is_set flag from the state. 
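// (`requirement_satisfied` in the condition below means the input already
// arrives in the requested order, so the incoming last row is accepted
// without comparing ordering keys.)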
let last_ordering = &last_row[1..is_set_idx]; - let sort_options = get_sort_options(&self.ordering_req); + let sort_options = get_sort_options(self.ordering_req.as_ref()); // Either there is no existing value, or there is a newer (latest) // version in the new data: if !self.is_set + || self.requirement_satisfied || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt() { // Update with last value in the state. Note that we should exclude the @@ -641,10 +668,10 @@ impl Accumulator for LastValueAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.last) + size_of_val(self) - size_of_val(&self.last) + self.last.size() + ScalarValue::size_of_vec(&self.orderings) - - std::mem::size_of_val(&self.orderings) + - size_of_val(&self.orderings) } } @@ -661,10 +688,7 @@ fn filter_states_according_to_is_set( } /// Combines array refs and their corresponding orderings to construct `SortColumn`s. -fn convert_to_sort_cols( - arrs: &[ArrayRef], - sort_exprs: &[PhysicalSortExpr], -) -> Vec { +fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec { arrs.iter() .zip(sort_exprs.iter()) .map(|(item, sort_expr)| SortColumn { @@ -682,10 +706,18 @@ mod tests { #[test] fn test_first_last_value_value() -> Result<()> { - let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; - let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; + let mut first_accumulator = FirstValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; + let mut last_accumulator = LastValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; // first value in the tuple is start of the range (inclusive), // second value in the tuple is end of the range (exclusive) let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)]; @@ -722,14 +754,22 @@ mod tests { .collect::>(); // FirstValueAccumulator - let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; + let mut first_accumulator = FirstValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?; let state1 = first_accumulator.state()?; - let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; + let mut first_accumulator = FirstValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?; let state2 = first_accumulator.state()?; @@ -738,28 +778,40 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(arrow::compute::concat(&[ + states.push(compute::concat(&[ &state1[idx].to_array()?, &state2[idx].to_array()?, ])?); } - let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; + let mut first_accumulator = FirstValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; first_accumulator.merge_batch(&states)?; let merged_state = first_accumulator.state()?; assert_eq!(merged_state.len(), state1.len()); // LastValueAccumulator - let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; + let mut last_accumulator = LastValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; 
last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?; let state1 = last_accumulator.state()?; - let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; + let mut last_accumulator = LastValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?; let state2 = last_accumulator.state()?; @@ -768,14 +820,18 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(arrow::compute::concat(&[ + states.push(compute::concat(&[ &state1[idx].to_array()?, &state2[idx].to_array()?, ])?); } - let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; + let mut last_accumulator = LastValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; last_accumulator.merge_batch(&states)?; let merged_state = last_accumulator.state()?; diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 6fb7c3800f4ed..445774ff11e7d 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -26,7 +26,10 @@ use datafusion_common::{not_impl_err, Result}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; +use datafusion_macros::user_doc; make_udaf_expr_and_func!( Grouping, @@ -36,12 +39,33 @@ make_udaf_expr_and_func!( grouping_udaf ); +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set.", + syntax_example = "grouping(expression)", + sql_example = r#"```sql +> SELECT column_name, GROUPING(column_name) AS group_column + FROM table_name + GROUP BY GROUPING SETS ((column_name), ()); ++-------------+-------------+ +| column_name | group_column | ++-------------+-------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+-------------+ +```"#, + argument( + name = "expression", + description = "Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function." + ) +)] pub struct Grouping { signature: Signature, } impl fmt::Debug for Grouping { - fn fmt(&self, f: &mut std::fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Grouping") .field("name", &self.name()) .field("signature", &self.signature) @@ -59,7 +83,7 @@ impl Grouping { /// Create a new GROUPING aggregate function. pub fn new() -> Self { Self { - signature: Signature::any(1, Volatility::Immutable), + signature: Signature::variadic_any(Volatility::Immutable), } } } @@ -94,4 +118,8 @@ impl AggregateUDFImpl for Grouping { "physical plan is not yet implemented for GROUPING aggregate function" ) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index ca0276d326a49..746873442d9aa 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -14,6 +14,7 @@ // KIND, either express or implied. 
See the License for the // specific language governing permissions and limitations // under the License. + // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index ffb5183278e67..b464dde6ccab5 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -65,19 +65,14 @@ macro_rules! create_func { }; ($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => { paste::paste! { - /// Singleton instance of [$UDAF], ensures the UDAF is only created once - /// named STATIC_$(UDAF). For example `STATIC_FirstValue` - #[allow(non_upper_case_globals)] - static [< STATIC_ $UDAF >]: std::sync::OnceLock> = - std::sync::OnceLock::new(); - #[doc = concat!("AggregateFunction that returns a [`AggregateUDF`](datafusion_expr::AggregateUDF) for [`", stringify!($UDAF), "`]")] pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { - [< STATIC_ $UDAF >] - .get_or_init(|| { + // Singleton instance of [$UDAF], ensures the UDAF is only created once + static INSTANCE: std::sync::LazyLock> = + std::sync::LazyLock::new(|| { std::sync::Arc::new(datafusion_expr::AggregateUDF::from($CREATE)) - }) - .clone() + }); + std::sync::Arc::clone(&INSTANCE) } } } diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 7dd0de14c3c0c..70f192c32ae1b 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; -use std::fmt::Formatter; -use std::{fmt::Debug, sync::Arc}; +use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; +use std::mem::{size_of, size_of_val}; +use std::sync::Arc; use arrow::array::{downcast_integer, ArrowNumericType}; use arrow::{ @@ -30,15 +31,16 @@ use arrow::{ use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; -use arrow::datatypes::ArrowNativeType; +use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, - Signature, Volatility, + Documentation, Signature, Volatility, }; use datafusion_functions_aggregate_common::utils::Hashable; +use datafusion_macros::user_doc; make_udaf_expr_and_func!( Median, @@ -48,6 +50,20 @@ make_udaf_expr_and_func!( median_udaf ); +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns the median value in the specified column.", + syntax_example = "median(expression)", + sql_example = r#"```sql +> SELECT median(column_name) FROM table_name; ++----------------------+ +| median(column_name) | ++----------------------+ +| 45.5 | ++----------------------+ +```"#, + standard_argument(name = "expression", prefix = "The") +)] /// MEDIAN aggregate expression. If using the non-distinct variation, then this uses a /// lot of memory because all values need to be stored in memory before a result can be /// computed. 
If an approximation is sufficient then APPROX_MEDIAN provides a much more @@ -61,7 +77,7 @@ pub struct Median { } impl Debug for Median { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { f.debug_struct("Median") .field("name", &self.name()) .field("signature", &self.signature) @@ -102,7 +118,7 @@ impl AggregateUDFImpl for Median { fn state_fields(&self, args: StateFieldsArgs) -> Result> { //Intermediate state is a list of the elements we have collected so far - let field = Field::new("item", args.input_types[0].clone(), true); + let field = Field::new_list_field(args.input_types[0].clone(), true); let state_name = if args.is_distinct { "distinct_median" } else { @@ -152,6 +168,10 @@ impl AggregateUDFImpl for Median { fn aliases(&self) -> &[String] { &[] } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// The median accumulator accumulates the raw input values @@ -166,7 +186,7 @@ struct MedianAccumulator { all_values: Vec, } -impl std::fmt::Debug for MedianAccumulator { +impl Debug for MedianAccumulator { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "MedianAccumulator({})", self.data_type) } @@ -206,8 +226,7 @@ impl Accumulator for MedianAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.all_values.capacity() * std::mem::size_of::() + size_of_val(self) + self.all_values.capacity() * size_of::() } } @@ -223,7 +242,7 @@ struct DistinctMedianAccumulator { distinct_values: HashSet>, } -impl std::fmt::Debug for DistinctMedianAccumulator { +impl Debug for DistinctMedianAccumulator { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "DistinctMedianAccumulator({})", self.data_type) } @@ -278,11 +297,25 @@ impl Accumulator for DistinctMedianAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.distinct_values.capacity() * std::mem::size_of::() + size_of_val(self) + self.distinct_values.capacity() * size_of::() } } +/// Get maximum entry in the slice, +fn slice_max(array: &[T::Native]) -> T::Native +where + T: ArrowPrimitiveType, + T::Native: PartialOrd, // Ensure the type supports PartialOrd for comparison +{ + // Make sure that, array is not empty. + debug_assert!(!array.is_empty()); + // `.unwrap()` is safe here as the array is supposed to be non-empty + *array + .iter() + .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Less)) + .unwrap() +} + fn calculate_median( mut values: Vec, ) -> Option { @@ -293,8 +326,11 @@ fn calculate_median( None } else if len % 2 == 0 { let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp); - let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp); - let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2)); + // Get the maximum of the low (left side after bi-partitioning) + let left_max = slice_max::(low); + let median = left_max + .add_wrapping(*high) + .div_wrapping(T::Native::usize_as(2)); Some(median) } else { let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp); diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index e0b029f0909db..c4e05bd57de6c 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -2,6 +2,7 @@ // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // @@ -17,20 +18,7 @@ //! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function //! [`Min`] and [`MinAccumulator`] accumulator for the `min` function -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. +mod min_max_bytes; use arrow::array::{ ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, @@ -55,6 +43,7 @@ use datafusion_common::{ }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use datafusion_physical_expr::expressions; +use std::cmp::Ordering; use std::fmt::Debug; use arrow::datatypes::i256; @@ -64,12 +53,16 @@ use arrow::datatypes::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; +use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; use datafusion_common::ScalarValue; use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, + function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, Signature, + Volatility, }; use datafusion_expr::{GroupsAccumulator, StatisticsArgs}; +use datafusion_macros::user_doc; use half::f16; +use std::mem::size_of_val; use std::ops::Deref; fn get_min_max_result_type(input_types: &[DataType]) -> Result> { @@ -93,6 +86,20 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result> { } } +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns the maximum value in the specified column.", + syntax_example = "max(expression)", + sql_example = r#"```sql +> SELECT max(column_name) FROM table_name; ++----------------------+ +| max(column_name) | ++----------------------+ +| 150 | ++----------------------+ +```"#, + standard_argument(name = "expression",) +)] // MAX aggregate UDF #[derive(Debug)] pub struct Max { @@ -116,12 +123,16 @@ impl Default for Max { /// the specified [`ArrowPrimitiveType`]. /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_max_accumulator { +macro_rules! primitive_max_accumulator { ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| { - if *cur < new { - *cur = new + match (new).partial_cmp(cur) { + Some(Ordering::Greater) | None => { + // new is Greater or None + *cur = new + } + _ => {} } }) // Initialize each accumulator to $NATIVE::MIN @@ -135,12 +146,16 @@ macro_rules! instantiate_max_accumulator { /// /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_min_accumulator { +macro_rules! 
primitive_min_accumulator { ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| { - if *cur > new { - *cur = new + match (new).partial_cmp(cur) { + Some(Ordering::Less) | None => { + // new is Less or NaN + *cur = new + } + _ => {} } }) // Initialize each accumulator to $NATIVE::MAX @@ -243,6 +258,12 @@ impl AggregateUDFImpl for Max { | Time32(_) | Time64(_) | Timestamp(_, _) + | Utf8 + | LargeUtf8 + | Utf8View + | Binary + | LargeBinary + | BinaryView ) } @@ -254,58 +275,58 @@ impl AggregateUDFImpl for Max { use TimeUnit::*; let data_type = args.return_type; match data_type { - Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type), - Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type), - Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type), - Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type), - UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type), - UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type), - UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type), - UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type), + Int8 => primitive_max_accumulator!(data_type, i8, Int8Type), + Int16 => primitive_max_accumulator!(data_type, i16, Int16Type), + Int32 => primitive_max_accumulator!(data_type, i32, Int32Type), + Int64 => primitive_max_accumulator!(data_type, i64, Int64Type), + UInt8 => primitive_max_accumulator!(data_type, u8, UInt8Type), + UInt16 => primitive_max_accumulator!(data_type, u16, UInt16Type), + UInt32 => primitive_max_accumulator!(data_type, u32, UInt32Type), + UInt64 => primitive_max_accumulator!(data_type, u64, UInt64Type), Float16 => { - instantiate_max_accumulator!(data_type, f16, Float16Type) + primitive_max_accumulator!(data_type, f16, Float16Type) } Float32 => { - instantiate_max_accumulator!(data_type, f32, Float32Type) + primitive_max_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_max_accumulator!(data_type, f64, Float64Type) + primitive_max_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_max_accumulator!(data_type, i32, Date32Type), - Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type), + Date32 => primitive_max_accumulator!(data_type, i32, Date32Type), + Date64 => primitive_max_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_max_accumulator!(data_type, i32, Time32SecondType) + primitive_max_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_max_accumulator!(data_type, i32, Time32MillisecondType) + primitive_max_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - instantiate_max_accumulator!(data_type, i64, Time64MicrosecondType) + primitive_max_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_max_accumulator!(data_type, i64, Time64NanosecondType) + primitive_max_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampSecondType) + primitive_max_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampMillisecondType) + primitive_max_accumulator!(data_type, i64, TimestampMillisecondType) } Timestamp(Microsecond, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampMicrosecondType) + primitive_max_accumulator!(data_type, i64, 
TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampNanosecondType) + primitive_max_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_max_accumulator!(data_type, i128, Decimal128Type) + primitive_max_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_max_accumulator!(data_type, i256, Decimal256Type) + primitive_max_accumulator!(data_type, i256, Decimal256Type) + } + Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { + Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone()))) } - - // It would be nice to have a fast implementation for Strings as well - // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), @@ -336,6 +357,10 @@ impl AggregateUDFImpl for Max { fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { self.value_from_statistics(statistics_args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } // Statically-typed version of min/max(array) -> ScalarValue for string types @@ -347,7 +372,7 @@ macro_rules! typed_min_max_batch_string { ScalarValue::$SCALAR(value) }}; } -// Statically-typed version of min/max(array) -> ScalarValue for binay types. +// Statically-typed version of min/max(array) -> ScalarValue for binary types. macro_rules! typed_min_max_batch_binary { ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ let array = downcast_value!($VALUES, $ARRAYTYPE); @@ -898,7 +923,7 @@ impl Accumulator for MaxAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() + size_of_val(self) - size_of_val(&self.max) + self.max.size() } } @@ -957,10 +982,24 @@ impl Accumulator for SlidingMaxAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() + size_of_val(self) - size_of_val(&self.max) + self.max.size() } } +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns the minimum value in the specified column.", + syntax_example = "min(expression)", + sql_example = r#"```sql +> SELECT min(column_name) FROM table_name; ++----------------------+ +| min(column_name) | ++----------------------+ +| 12 | ++----------------------+ +```"#, + standard_argument(name = "expression",) +)] #[derive(Debug)] pub struct Min { signature: Signature, @@ -1040,6 +1079,12 @@ impl AggregateUDFImpl for Min { | Time32(_) | Time64(_) | Timestamp(_, _) + | Utf8 + | LargeUtf8 + | Utf8View + | Binary + | LargeBinary + | BinaryView ) } @@ -1051,58 +1096,58 @@ impl AggregateUDFImpl for Min { use TimeUnit::*; let data_type = args.return_type; match data_type { - Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type), - Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type), - Int32 => instantiate_min_accumulator!(data_type, i32, Int32Type), - Int64 => instantiate_min_accumulator!(data_type, i64, Int64Type), - UInt8 => instantiate_min_accumulator!(data_type, u8, UInt8Type), - UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type), - UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type), - UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type), + Int8 => primitive_min_accumulator!(data_type, i8, Int8Type), + Int16 => primitive_min_accumulator!(data_type, 
i16, Int16Type), + Int32 => primitive_min_accumulator!(data_type, i32, Int32Type), + Int64 => primitive_min_accumulator!(data_type, i64, Int64Type), + UInt8 => primitive_min_accumulator!(data_type, u8, UInt8Type), + UInt16 => primitive_min_accumulator!(data_type, u16, UInt16Type), + UInt32 => primitive_min_accumulator!(data_type, u32, UInt32Type), + UInt64 => primitive_min_accumulator!(data_type, u64, UInt64Type), Float16 => { - instantiate_min_accumulator!(data_type, f16, Float16Type) + primitive_min_accumulator!(data_type, f16, Float16Type) } Float32 => { - instantiate_min_accumulator!(data_type, f32, Float32Type) + primitive_min_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_min_accumulator!(data_type, f64, Float64Type) + primitive_min_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_min_accumulator!(data_type, i32, Date32Type), - Date64 => instantiate_min_accumulator!(data_type, i64, Date64Type), + Date32 => primitive_min_accumulator!(data_type, i32, Date32Type), + Date64 => primitive_min_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_min_accumulator!(data_type, i32, Time32SecondType) + primitive_min_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_min_accumulator!(data_type, i32, Time32MillisecondType) + primitive_min_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - instantiate_min_accumulator!(data_type, i64, Time64MicrosecondType) + primitive_min_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_min_accumulator!(data_type, i64, Time64NanosecondType) + primitive_min_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampSecondType) + primitive_min_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampMillisecondType) + primitive_min_accumulator!(data_type, i64, TimestampMillisecondType) } Timestamp(Microsecond, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampMicrosecondType) + primitive_min_accumulator!(data_type, i64, TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampNanosecondType) + primitive_min_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_min_accumulator!(data_type, i128, Decimal128Type) + primitive_min_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_min_accumulator!(data_type, i256, Decimal256Type) + primitive_min_accumulator!(data_type, i256, Decimal256Type) + } + Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { + Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone()))) } - - // It would be nice to have a fast implementation for Strings as well - // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), @@ -1134,7 +1179,12 @@ impl AggregateUDFImpl for Min { fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { datafusion_expr::ReversedUDAF::Identical } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } + /// An accumulator to compute the minimum value #[derive(Debug)] pub struct MinAccumulator { @@ -1173,7 +1223,7 @@ impl Accumulator for MinAccumulator { } fn size(&self) -> 
usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() + size_of_val(self) - size_of_val(&self.min) + self.min.size() } } @@ -1236,7 +1286,7 @@ impl Accumulator for SlidingMinAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() + size_of_val(self) - size_of_val(&self.min) + self.min.size() } } diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs new file mode 100644 index 0000000000000..725b7a29bd479 --- /dev/null +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -0,0 +1,520 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, BooleanArray, + LargeBinaryBuilder, LargeStringBuilder, StringBuilder, StringViewBuilder, +}; +use arrow_schema::DataType; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; +use std::mem::size_of; +use std::sync::Arc; + +/// Implements fast Min/Max [`GroupsAccumulator`] for "bytes" types ([`StringArray`], +/// [`BinaryArray`], [`StringViewArray`], etc) +/// +/// This implementation dispatches to the appropriate specialized code in +/// [`MinMaxBytesState`] based on data type and comparison function +/// +/// [`StringArray`]: arrow::array::StringArray +/// [`BinaryArray`]: arrow::array::BinaryArray +/// [`StringViewArray`]: arrow::array::StringViewArray +#[derive(Debug)] +pub(crate) struct MinMaxBytesAccumulator { + /// Inner data storage. 
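/// (owns the per-group byte buffers plus the running total of stored bytes
/// used for memory accounting)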
+ inner: MinMaxBytesState, + /// if true, is `MIN` otherwise is `MAX` + is_min: bool, +} + +impl MinMaxBytesAccumulator { + /// Create a new accumulator for computing `min(val)` + pub fn new_min(data_type: DataType) -> Self { + Self { + inner: MinMaxBytesState::new(data_type), + is_min: true, + } + } + + /// Create a new accumulator fo computing `max(val)` + pub fn new_max(data_type: DataType) -> Self { + Self { + inner: MinMaxBytesState::new(data_type), + is_min: false, + } + } +} + +impl GroupsAccumulator for MinMaxBytesAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let array = &values[0]; + assert_eq!(array.len(), group_indices.len()); + assert_eq!(array.data_type(), &self.inner.data_type); + + // apply filter if needed + let array = apply_filter_as_nulls(array, opt_filter)?; + + // dispatch to appropriate kernel / specialized implementation + fn string_min(a: &[u8], b: &[u8]) -> bool { + // safety: only called from this function, which ensures a and b come + // from an array with valid utf8 data + unsafe { + let a = std::str::from_utf8_unchecked(a); + let b = std::str::from_utf8_unchecked(b); + a < b + } + } + fn string_max(a: &[u8], b: &[u8]) -> bool { + // safety: only called from this function, which ensures a and b come + // from an array with valid utf8 data + unsafe { + let a = std::str::from_utf8_unchecked(a); + let b = std::str::from_utf8_unchecked(b); + a > b + } + } + fn binary_min(a: &[u8], b: &[u8]) -> bool { + a < b + } + + fn binary_max(a: &[u8], b: &[u8]) -> bool { + a > b + } + + fn str_to_bytes<'a>( + it: impl Iterator>, + ) -> impl Iterator> { + it.map(|s| s.map(|s| s.as_bytes())) + } + + match (self.is_min, &self.inner.data_type) { + // Utf8/LargeUtf8/Utf8View Min + (true, &DataType::Utf8) => self.inner.update_batch( + str_to_bytes(array.as_string::().iter()), + group_indices, + total_num_groups, + string_min, + ), + (true, &DataType::LargeUtf8) => self.inner.update_batch( + str_to_bytes(array.as_string::().iter()), + group_indices, + total_num_groups, + string_min, + ), + (true, &DataType::Utf8View) => self.inner.update_batch( + str_to_bytes(array.as_string_view().iter()), + group_indices, + total_num_groups, + string_min, + ), + + // Utf8/LargeUtf8/Utf8View Max + (false, &DataType::Utf8) => self.inner.update_batch( + str_to_bytes(array.as_string::().iter()), + group_indices, + total_num_groups, + string_max, + ), + (false, &DataType::LargeUtf8) => self.inner.update_batch( + str_to_bytes(array.as_string::().iter()), + group_indices, + total_num_groups, + string_max, + ), + (false, &DataType::Utf8View) => self.inner.update_batch( + str_to_bytes(array.as_string_view().iter()), + group_indices, + total_num_groups, + string_max, + ), + + // Binary/LargeBinary/BinaryView Min + (true, &DataType::Binary) => self.inner.update_batch( + array.as_binary::().iter(), + group_indices, + total_num_groups, + binary_min, + ), + (true, &DataType::LargeBinary) => self.inner.update_batch( + array.as_binary::().iter(), + group_indices, + total_num_groups, + binary_min, + ), + (true, &DataType::BinaryView) => self.inner.update_batch( + array.as_binary_view().iter(), + group_indices, + total_num_groups, + binary_min, + ), + + // Binary/LargeBinary/BinaryView Max + (false, &DataType::Binary) => self.inner.update_batch( + array.as_binary::().iter(), + group_indices, + total_num_groups, + binary_max, + ), + (false, &DataType::LargeBinary) => 
self.inner.update_batch( + array.as_binary::().iter(), + group_indices, + total_num_groups, + binary_max, + ), + (false, &DataType::BinaryView) => self.inner.update_batch( + array.as_binary_view().iter(), + group_indices, + total_num_groups, + binary_max, + ), + + _ => internal_err!( + "Unexpected combination for MinMaxBytesAccumulator: ({:?}, {:?})", + self.is_min, + self.inner.data_type + ), + } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let (data_capacity, min_maxes) = self.inner.emit_to(emit_to); + + // Convert the Vec of bytes to a vec of Strings (at no cost) + fn bytes_to_str( + min_maxes: Vec>>, + ) -> impl Iterator> { + min_maxes.into_iter().map(|opt| { + opt.map(|bytes| { + // Safety: only called on data added from update_batch which ensures + // the input type matched the output type + unsafe { String::from_utf8_unchecked(bytes) } + }) + }) + } + + let result: ArrayRef = match self.inner.data_type { + DataType::Utf8 => { + let mut builder = + StringBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in bytes_to_str(min_maxes) { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_str()), + } + } + Arc::new(builder.finish()) + } + DataType::LargeUtf8 => { + let mut builder = + LargeStringBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in bytes_to_str(min_maxes) { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_str()), + } + } + Arc::new(builder.finish()) + } + DataType::Utf8View => { + let block_size = capacity_to_view_block_size(data_capacity); + + let mut builder = StringViewBuilder::with_capacity(min_maxes.len()) + .with_fixed_block_size(block_size); + for opt in bytes_to_str(min_maxes) { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_str()), + } + } + Arc::new(builder.finish()) + } + DataType::Binary => { + let mut builder = + BinaryBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in min_maxes { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_ref() as &[u8]), + } + } + Arc::new(builder.finish()) + } + DataType::LargeBinary => { + let mut builder = + LargeBinaryBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in min_maxes { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_ref() as &[u8]), + } + } + Arc::new(builder.finish()) + } + DataType::BinaryView => { + let block_size = capacity_to_view_block_size(data_capacity); + + let mut builder = BinaryViewBuilder::with_capacity(min_maxes.len()) + .with_fixed_block_size(block_size); + for opt in min_maxes { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_ref() as &[u8]), + } + } + Arc::new(builder.finish()) + } + _ => { + return internal_err!( + "Unexpected data type for MinMaxBytesAccumulator: {:?}", + self.inner.data_type + ); + } + }; + + assert_eq!(&self.inner.data_type, result.data_type()); + Ok(result) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + // min/max are their own states (no transition needed) + self.evaluate(emit_to).map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // min/max are their own states (no transition needed) + self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + 
opt_filter: Option<&BooleanArray>, + ) -> Result> { + // Min/max do not change the values as they are their own states + // apply the filter by combining with the null mask, if any + let output = apply_filter_as_nulls(&values[0], opt_filter)?; + Ok(vec![output]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.inner.size() + } +} + +/// Returns the block size in (contiguous buffer size) to use +/// for a given data capacity (total string length) +/// +/// This is a heuristic to avoid allocating too many small buffers +fn capacity_to_view_block_size(data_capacity: usize) -> u32 { + let max_block_size = 2 * 1024 * 1024; + // Avoid block size equal to zero when calling `with_fixed_block_size()`. + if data_capacity == 0 { + return 1; + } + if let Ok(block_size) = u32::try_from(data_capacity) { + block_size.min(max_block_size) + } else { + max_block_size + } +} + +/// Stores internal Min/Max state for "bytes" types. +/// +/// This implementation is general and stores the minimum/maximum for each +/// groups in an individual byte array, which balances allocations and memory +/// fragmentation (aka garbage). +/// +/// ```text +/// ┌─────────────────────────────────┐ +/// ┌─────┐ ┌────▶│Option> (["A"]) │───────────▶ "A" +/// │ 0 │────┘ └─────────────────────────────────┘ +/// ├─────┤ ┌─────────────────────────────────┐ +/// │ 1 │─────────▶│Option> (["Z"]) │───────────▶ "Z" +/// └─────┘ └─────────────────────────────────┘ ... +/// ... ... +/// ┌─────┐ ┌────────────────────────────────┐ +/// │ N-2 │─────────▶│Option> (["A"]) │────────────▶ "A" +/// ├─────┤ └────────────────────────────────┘ +/// │ N-1 │────┐ ┌────────────────────────────────┐ +/// └─────┘ └────▶│Option> (["Q"]) │────────────▶ "Q" +/// └────────────────────────────────┘ +/// +/// min_max: Vec> +/// ``` +/// +/// Note that for `StringViewArray` and `BinaryViewArray`, there are potentially +/// more efficient implementations (e.g. by managing a string data buffer +/// directly), but then garbage collection, memory management, and final array +/// construction becomes more complex. 
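/// For that reason each group's current value is kept in its own owned
/// `Vec<u8>`; `total_data_bytes` tracks the sum of their lengths so that
/// `size()` can report memory use without walking every group.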
+/// +/// See discussion on +#[derive(Debug)] +struct MinMaxBytesState { + /// The minimum/maximum value for each group + min_max: Vec>>, + /// The data type of the array + data_type: DataType, + /// The total bytes of the string data (for pre-allocating the final array, + /// and tracking memory usage) + total_data_bytes: usize, +} + +#[derive(Debug, Clone, Copy)] +enum MinMaxLocation<'a> { + /// the min/max value is stored in the existing `min_max` array + ExistingMinMax, + /// the min/max value is stored in the input array at the given index + Input(&'a [u8]), +} + +/// Implement the MinMaxBytesAccumulator with a comparison function +/// for comparing strings +impl MinMaxBytesState { + /// Create a new MinMaxBytesAccumulator + /// + /// # Arguments: + /// * `data_type`: The data type of the arrays that will be passed to this accumulator + fn new(data_type: DataType) -> Self { + Self { + min_max: vec![], + data_type, + total_data_bytes: 0, + } + } + + /// Set the specified group to the given value, updating memory usage appropriately + fn set_value(&mut self, group_index: usize, new_val: &[u8]) { + match self.min_max[group_index].as_mut() { + None => { + self.min_max[group_index] = Some(new_val.to_vec()); + self.total_data_bytes += new_val.len(); + } + Some(existing_val) => { + // Copy data over to avoid re-allocating + self.total_data_bytes -= existing_val.len(); + self.total_data_bytes += new_val.len(); + existing_val.clear(); + existing_val.extend_from_slice(new_val); + } + } + } + + /// Updates the min/max values for the given string values + /// + /// `cmp` is the comparison function to use, called like `cmp(new_val, existing_val)` + /// returns true if the `new_val` should replace `existing_val` + fn update_batch<'a, F, I>( + &mut self, + iter: I, + group_indices: &[usize], + total_num_groups: usize, + mut cmp: F, + ) -> Result<()> + where + F: FnMut(&[u8], &[u8]) -> bool + Send + Sync, + I: IntoIterator>, + { + self.min_max.resize(total_num_groups, None); + // Minimize value copies by calculating the new min/maxes for each group + // in this batch (either the existing min/max or the new input value) + // and updating the owned values in `self.min_maxes` at most once + let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; + + // Figure out the new min value for each group + for (new_val, group_index) in iter.into_iter().zip(group_indices.iter()) { + let group_index = *group_index; + let Some(new_val) = new_val else { + continue; // skip nulls + }; + + let existing_val = match locations[group_index] { + // previous input value was the min/max, so compare it + MinMaxLocation::Input(existing_val) => existing_val, + MinMaxLocation::ExistingMinMax => { + let Some(existing_val) = self.min_max[group_index].as_ref() else { + // no existing min/max, so this is the new min/max + locations[group_index] = MinMaxLocation::Input(new_val); + continue; + }; + existing_val.as_ref() + } + }; + + // Compare the new value to the existing value, replacing if necessary + if cmp(new_val, existing_val) { + locations[group_index] = MinMaxLocation::Input(new_val); + } + } + + // Update self.min_max with any new min/max values we found in the input + for (group_index, location) in locations.iter().enumerate() { + match location { + MinMaxLocation::ExistingMinMax => {} + MinMaxLocation::Input(new_val) => self.set_value(group_index, new_val), + } + } + Ok(()) + } + + /// Emits the specified min_max values + /// + /// Returns (data_capacity, min_maxes), updating the current value of 
total_data_bytes + /// + /// - `data_capacity`: the total length of all strings and their contents, + /// - `min_maxes`: the actual min/max values for each group + fn emit_to(&mut self, emit_to: EmitTo) -> (usize, Vec>>) { + match emit_to { + EmitTo::All => { + ( + std::mem::take(&mut self.total_data_bytes), // reset total bytes and min_max + std::mem::take(&mut self.min_max), + ) + } + EmitTo::First(n) => { + let first_min_maxes: Vec<_> = self.min_max.drain(..n).collect(); + let first_data_capacity: usize = first_min_maxes + .iter() + .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0)) + .sum(); + self.total_data_bytes -= first_data_capacity; + (first_data_capacity, first_min_maxes) + } + } + } + + fn size(&self) -> usize { + self.total_data_bytes + self.min_max.len() * size_of::>>() + } +} diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 399303e7e77b2..4fd46cb61c7ee 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -20,21 +20,23 @@ use std::any::Any; use std::collections::VecDeque; +use std::mem::{size_of, size_of_val}; use std::sync::Arc; use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; use arrow_schema::{DataType, Field, Fields}; -use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; +use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - lit, Accumulator, AggregateUDFImpl, ExprFunctionExt, ReversedUDAF, Signature, - SortExpr, Volatility, + lit, Accumulator, AggregateUDFImpl, Documentation, ExprFunctionExt, ReversedUDAF, + Signature, SortExpr, Volatility, }; use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; use datafusion_functions_aggregate_common::utils::ordering_fields; +use datafusion_macros::user_doc; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -58,6 +60,32 @@ pub fn nth_value( } } +#[user_doc( + doc_section(label = "Statistical Functions"), + description = "Returns the nth value in a group of values.", + syntax_example = "nth_value(expression, n ORDER BY expression)", + sql_example = r#"```sql +> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept + FROM employee; ++---------+--------+-------------------------+ +| dept_id | salary | second_salary_by_dept | ++---------+--------+-------------------------+ +| 1 | 30000 | NULL | +| 1 | 40000 | 40000 | +| 1 | 50000 | 40000 | +| 2 | 35000 | NULL | +| 2 | 45000 | 45000 | ++---------+--------+-------------------------+ +```"#, + argument( + name = "expression", + description = "The column or expression to retrieve the nth value from." + ), + argument( + name = "n", + description = "The position (nth) of the value to retrieve, based on the ordering." + ) +)] /// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi /// partition setting, partial aggregations are computed for every partition, /// and then their results are merged. 
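The two-pass update in `MinMaxBytesState::update_batch` above (track a batch-local winner per group while only borrowing, then copy bytes into owned storage at most once per group) is what keeps allocations to at most one write per group per batch. The sketch below is a simplified, standalone illustration of that pattern for a byte-string `min`, not DataFusion's implementation; the helper name `update_groups_min`, the bare `Vec`-based state, and the `main` demo are assumptions made for the example.

```rust
// Simplified sketch of the two-pass per-group min update (illustrative only;
// `update_groups_min` is a hypothetical helper, not a DataFusion API).
fn update_groups_min(
    min_per_group: &mut Vec<Option<Vec<u8>>>,
    values: &[Option<&[u8]>],
    group_indices: &[usize],
    total_num_groups: usize,
) {
    min_per_group.resize(total_num_groups, None);

    // Pass 1: find each group's winner for this batch, borrowing from the
    // input (or the existing state) without copying any bytes yet.
    let mut winners: Vec<Option<&[u8]>> = vec![None; total_num_groups];
    for (&value, &group) in values.iter().zip(group_indices) {
        let Some(new_val) = value else { continue }; // nulls never win
        let current = winners[group].or_else(|| min_per_group[group].as_deref());
        match current {
            Some(existing) if new_val >= existing => {} // keep the current min
            _ => winners[group] = Some(new_val),        // new (or first) min
        }
    }

    // Pass 2: copy the winning bytes into owned storage, at most once per group.
    for (group, winner) in winners.into_iter().enumerate() {
        if let Some(bytes) = winner {
            min_per_group[group] = Some(bytes.to_vec());
        }
    }
}

fn main() {
    let mut state = Vec::new();
    update_groups_min(
        &mut state,
        &[Some(&b"b"[..]), Some(&b"a"[..]), Some(&b"z"[..]), None],
        &[0, 0, 1, 1],
        2,
    );
    assert_eq!(state[0].as_deref(), Some(&b"a"[..]));
    assert_eq!(state[1].as_deref(), Some(&b"z"[..]));
}
```

The same shape covers `max` by flipping the comparison, which is what the accumulator parameterizes through its `cmp` closure.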
@@ -131,7 +159,7 @@ impl AggregateUDFImpl for NthValueAgg { n, &data_type, &ordering_dtypes, - acc_args.ordering_req.to_vec(), + acc_args.ordering_req.clone(), ) .map(|acc| Box::new(acc) as _) } @@ -140,14 +168,14 @@ impl AggregateUDFImpl for NthValueAgg { let mut fields = vec![Field::new_list( format_state_name(self.name(), "nth_value"), // See COMMENTS.md to understand why nullable is set to true - Field::new("item", args.input_types[0].clone(), true), + Field::new_list_field(args.input_types[0].clone(), true), false, )]; let orderings = args.ordering_fields.to_vec(); if !orderings.is_empty() { fields.push(Field::new_list( format_state_name(self.name(), "nth_value_orderings"), - Field::new("item", DataType::Struct(Fields::from(orderings)), true), + Field::new_list_field(DataType::Struct(Fields::from(orderings)), true), false, )); } @@ -161,6 +189,10 @@ impl AggregateUDFImpl for NthValueAgg { fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::Reversed(nth_value_udaf()) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } #[derive(Debug)] @@ -343,25 +375,23 @@ impl Accumulator for NthValueAccumulator { } fn size(&self) -> usize { - let mut total = std::mem::size_of_val(self) - + ScalarValue::size_of_vec_deque(&self.values) - - std::mem::size_of_val(&self.values); + let mut total = size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values) + - size_of_val(&self.values); // Add size of the `self.ordering_values` - total += - std::mem::size_of::>() * self.ordering_values.capacity(); + total += size_of::>() * self.ordering_values.capacity(); for row in &self.ordering_values { - total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row); + total += ScalarValue::size_of_vec(row) - size_of_val(row); } // Add size of the `self.datatypes` - total += std::mem::size_of::() * self.datatypes.capacity(); + total += size_of::() * self.datatypes.capacity(); for dtype in &self.datatypes { - total += dtype.size() - std::mem::size_of_val(dtype); + total += dtype.size() - size_of_val(dtype); } // Add size of the `self.ordering_req` - total += std::mem::size_of::() * self.ordering_req.capacity(); + total += size_of::() * self.ordering_req.capacity(); // TODO: Calculate size of each `PhysicalSortExpr` more accurately. total } @@ -369,7 +399,7 @@ impl Accumulator for NthValueAccumulator { impl NthValueAccumulator { fn evaluate_orderings(&self) -> Result { - let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); let struct_field = Fields::from(fields.clone()); let mut column_wise_ordering_values = vec![]; @@ -391,9 +421,7 @@ impl NthValueAccumulator { let ordering_array = StructArray::try_new(struct_field, column_wise_ordering_values, None)?; - Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( - Arc::new(ordering_array), - )))) + Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) } fn evaluate_values(&self) -> ScalarValue { diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index 390a769aca7f8..f302b72f94b5b 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -17,9 +17,6 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution -use std::any::Any; -use std::fmt::Debug; - use arrow::array::Float64Array; use arrow::{ array::{ArrayRef, UInt64Array}, @@ -27,12 +24,21 @@ use arrow::{ datatypes::DataType, datatypes::Field, }; -use datafusion_common::{downcast_value, plan_err, unwrap_or_internal_err, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{ + downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, HashMap, Result, + ScalarValue, +}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; +use std::any::Any; +use std::fmt::Debug; +use std::mem::size_of_val; +use std::sync::OnceLock; macro_rules! make_regr_udaf_expr_and_func { ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { @@ -76,23 +82,7 @@ impl Regr { } } -/* -#[derive(Debug)] -pub struct Regr { - name: String, - regr_type: RegrType, - expr_y: Arc, - expr_x: Arc, -} - -impl Regr { - pub fn get_regr_type(&self) -> RegrType { - self.regr_type.clone() - } -} -*/ - -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Hash, Eq)] #[allow(clippy::upper_case_acronyms)] pub enum RegrType { /// Variant for `regr_slope` aggregate expression @@ -135,6 +125,130 @@ pub enum RegrType { SXY, } +impl RegrType { + /// return the documentation for the `RegrType` + fn documentation(&self) -> Option<&Documentation> { + get_regr_docs().get(self) + } +} + +static DOCUMENTATION: OnceLock> = OnceLock::new(); +fn get_regr_docs() -> &'static HashMap { + DOCUMENTATION.get_or_init(|| { + let mut hash_map = HashMap::new(); + hash_map.insert( + RegrType::Slope, + Documentation::builder( + DOC_SECTION_STATISTICAL, + "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \ + Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.", + + "regr_slope(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + ); + + hash_map.insert( + RegrType::Intercept, + Documentation::builder( + DOC_SECTION_STATISTICAL, + "Computes the y-intercept of the linear regression line. 
For the equation (y = kx + b), \ + this function returns b.", + + "regr_intercept(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + ); + + hash_map.insert( + RegrType::Count, + Documentation::builder( + DOC_SECTION_STATISTICAL, + "Counts the number of non-null paired data points.", + + "regr_count(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + ); + + hash_map.insert( + RegrType::R2, + Documentation::builder( + DOC_SECTION_STATISTICAL, + "Computes the square of the correlation coefficient between the independent and dependent variables.", + + "regr_r2(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + ); + + hash_map.insert( + RegrType::AvgX, + Documentation::builder( + DOC_SECTION_STATISTICAL, + "Computes the average of the independent variable (input) expression_x for the non-null paired data points.", + + "regr_avgx(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + ); + + hash_map.insert( + RegrType::AvgY, + Documentation::builder( + DOC_SECTION_STATISTICAL, + "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.", + + "regr_avgy(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + ); + + hash_map.insert( + RegrType::SXX, + Documentation::builder( + DOC_SECTION_STATISTICAL, + "Computes the sum of squares of the independent variable.", + + "regr_sxx(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + ); + + hash_map.insert( + RegrType::SYY, + Documentation::builder( + DOC_SECTION_STATISTICAL, + "Computes the sum of squares of the dependent variable.", + + "regr_syy(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + ); + + hash_map.insert( + RegrType::SXY, + Documentation::builder( + DOC_SECTION_STATISTICAL, + "Computes the sum of products of paired data points.", + + "regr_sxy(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + ); + hash_map + }) +} + impl AggregateUDFImpl for Regr { fn as_any(&self) -> &dyn Any { self @@ -198,22 +312,11 @@ impl AggregateUDFImpl for Regr { ), ]) } -} -/* -impl PartialEq for Regr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.expr_y.eq(&x.expr_y) - && self.expr_x.eq(&x.expr_x) - }) - .unwrap_or(false) + fn documentation(&self) -> Option<&Documentation> { + self.regr_type.documentation() } } -*/ /// `RegrAccumulator` is used to compute linear regression aggregate functions /// by maintaining statistics needed to compute them in an online fashion. 
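For context on the `RegrType` documentation table above: the per-variant docs are built once, on first access, inside a `OnceLock`-guarded `HashMap` and then shared for the life of the process. Below is a minimal standalone illustration of that lazy-initialization pattern; it uses plain string docs and a hypothetical `RegrKind` enum instead of the real `RegrType` and `Documentation` builder.

```rust
use std::collections::HashMap;
use std::sync::OnceLock;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum RegrKind {
    Slope,
    Intercept,
}

// Built on first use and cached for the rest of the process, mirroring the
// `DOCUMENTATION: OnceLock<HashMap<..>>` table in the patch.
static DOCS: OnceLock<HashMap<RegrKind, &'static str>> = OnceLock::new();

fn docs() -> &'static HashMap<RegrKind, &'static str> {
    DOCS.get_or_init(|| {
        let mut map = HashMap::new();
        map.insert(RegrKind::Slope, "slope of the least-squares fit");
        map.insert(RegrKind::Intercept, "y-intercept of the least-squares fit");
        map
    })
}

fn main() {
    // Repeated lookups all hit the same cached map.
    assert_eq!(docs()[&RegrKind::Slope], "slope of the least-squares fit");
    assert!(docs().contains_key(&RegrKind::Intercept));
}
```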
@@ -495,6 +598,6 @@ impl Accumulator for RegrAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index a25ab5e319915..adf86a128cfb1 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; +use std::mem::align_of_val; use std::sync::Arc; use arrow::array::Float64Array; @@ -29,9 +30,11 @@ use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, GroupsAccumulator, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, + Volatility, }; use datafusion_functions_aggregate_common::stats::StatsType; +use datafusion_macros::user_doc; use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator}; @@ -43,6 +46,20 @@ make_udaf_expr_and_func!( stddev_udaf ); +#[user_doc( + doc_section(label = "Statistical Functions"), + description = "Returns the standard deviation of a set of numbers.", + syntax_example = "stddev(expression)", + sql_example = r#"```sql +> SELECT stddev(column_name) FROM table_name; ++----------------------+ +| stddev(column_name) | ++----------------------+ +| 12.34 | ++----------------------+ +```"#, + standard_argument(name = "expression",) +)] /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression pub struct Stddev { signature: Signature, @@ -68,10 +85,7 @@ impl Stddev { /// Create a new STDDEV aggregate function pub fn new() -> Self { Self { - signature: Signature::coercible( - vec![DataType::Float64], - Volatility::Immutable, - ), + signature: Signature::numeric(1, Volatility::Immutable), alias: vec!["stddev_samp".to_string()], } } @@ -132,6 +146,10 @@ impl AggregateUDFImpl for Stddev { ) -> Result> { Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample))) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } make_udaf_expr_and_func!( @@ -142,6 +160,20 @@ make_udaf_expr_and_func!( stddev_pop_udaf ); +#[user_doc( + doc_section(label = "Statistical Functions"), + description = "Returns the population standard deviation of a set of numbers.", + syntax_example = "stddev_pop(expression)", + sql_example = r#"```sql +> SELECT stddev_pop(column_name) FROM table_name; ++--------------------------+ +| stddev_pop(column_name) | ++--------------------------+ +| 10.56 | ++--------------------------+ +```"#, + standard_argument(name = "expression",) +)] /// STDDEV_POP population aggregate expression pub struct StddevPop { signature: Signature, @@ -228,6 +260,10 @@ impl AggregateUDFImpl for StddevPop { StatsType::Population, ))) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// An accumulator to compute the average @@ -285,8 +321,7 @@ impl Accumulator for StddevAccumulator { } fn size(&self) -> usize { - std::mem::align_of_val(self) - std::mem::align_of_val(&self.variance) - + self.variance.size() + align_of_val(self) - align_of_val(&self.variance) + self.variance.size() } fn supports_retract_batch(&self) -> bool { @@ -352,6 +387,7 @@ mod tests { use datafusion_expr::AggregateUDF; use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays; use datafusion_physical_expr::expressions::col; + use 
datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; #[test] @@ -403,7 +439,7 @@ mod tests { return_type: &DataType::Float64, schema, ignore_nulls: false, - ordering_req: &[], + ordering_req: &LexOrdering::default(), name: "a", is_distinct: false, is_reversed: false, @@ -414,7 +450,7 @@ mod tests { return_type: &DataType::Float64, schema, ignore_nulls: false, - ordering_req: &[], + ordering_req: &LexOrdering::default(), name: "a", is_distinct: false, is_reversed: false, diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 1b3654547ef6e..6a77232d4211c 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -24,10 +24,12 @@ use datafusion_common::Result; use datafusion_common::{not_impl_err, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, }; +use datafusion_macros::user_doc; use datafusion_physical_expr::expressions::Literal; use std::any::Any; +use std::mem::size_of_val; make_udaf_expr_and_func!( StringAgg, @@ -37,6 +39,28 @@ make_udaf_expr_and_func!( string_agg_udaf ); +#[user_doc( + doc_section(label = "General Functions"), + description = "Concatenates the values of string expressions and places separator values between them.", + syntax_example = "string_agg(expression, delimiter)", + sql_example = r#"```sql +> SELECT string_agg(name, ', ') AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Alice, Bob, Charlie | ++--------------------------+ +```"#, + argument( + name = "expression", + description = "The string expression to concatenate. Can be a column or any valid string expression." + ), + argument( + name = "delimiter", + description = "A literal string used as a separator between the concatenated values." 
+ ) +)] /// STRING_AGG aggregate expression #[derive(Debug)] pub struct StringAgg { @@ -84,20 +108,26 @@ impl AggregateUDFImpl for StringAgg { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() { - return match lit.scalar().value() { - ScalarValue::Utf8(Some(delimiter)) - | ScalarValue::LargeUtf8(Some(delimiter)) => { - Ok(Box::new(StringAggAccumulator::new(delimiter.as_str()))) + return match lit.scalar().value().try_as_str() { + Some(Some(delimiter)) => { + Ok(Box::new(StringAggAccumulator::new(delimiter))) + } + Some(None) => Ok(Box::new(StringAggAccumulator::new(""))), + None => { + not_impl_err!( + "StringAgg not supported for delimiter {}", + lit.scalar().value() + ) } - ScalarValue::Utf8(None) - | ScalarValue::LargeUtf8(None) - | ScalarValue::Null => Ok(Box::new(StringAggAccumulator::new(""))), - e => not_impl_err!("StringAgg not supported for delimiter {}", e), }; } not_impl_err!("expect literal") } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } #[derive(Debug)] @@ -146,7 +176,7 @@ impl Accumulator for StringAggAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) + self.delimiter.capacity() } diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 7e40c1bd17a8d..6c2854f6bc248 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -21,6 +21,7 @@ use ahash::RandomState; use datafusion_expr::utils::AggregateOrderSensitivity; use std::any::Any; use std::collections::HashSet; +use std::mem::{size_of, size_of_val}; use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; @@ -37,10 +38,12 @@ use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, + Signature, Volatility, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use datafusion_functions_aggregate_common::utils::Hashable; +use datafusion_macros::user_doc; make_udaf_expr_and_func!( Sum, @@ -75,6 +78,20 @@ macro_rules! 
downcast_sum { }; } +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns the sum of all values in the specified column.", + syntax_example = "sum(expression)", + sql_example = r#"```sql +> SELECT sum(column_name) FROM table_name; ++-----------------------+ +| sum(column_name) | ++-----------------------+ +| 12345 | ++-----------------------+ +```"#, + standard_argument(name = "expression",) +)] #[derive(Debug)] pub struct Sum { signature: Signature, @@ -179,7 +196,7 @@ impl AggregateUDFImpl for Sum { Ok(vec![Field::new_list( format_state_name(args.name, "sum distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new("item", args.return_type.clone(), true), + Field::new_list_field(args.return_type.clone(), true), false, )]) } else { @@ -233,6 +250,10 @@ impl AggregateUDFImpl for Sum { fn order_sensitivity(&self) -> AggregateOrderSensitivity { AggregateOrderSensitivity::Insensitive } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// This accumulator computes SUM incrementally @@ -279,7 +300,7 @@ impl Accumulator for SumAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -339,7 +360,7 @@ impl Accumulator for SlidingSumAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { @@ -433,7 +454,6 @@ impl Accumulator for DistinctSumAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.values.capacity() * std::mem::size_of::() + size_of_val(self) + self.values.capacity() * size_of::() } } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 49a30344c2123..8aa7a40ce3207 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -24,13 +24,12 @@ use arrow::{ compute::kernels::cast, datatypes::{DataType, Field}, }; -use std::sync::OnceLock; +use std::mem::{size_of, size_of_val}; use std::{fmt::Debug, sync::Arc}; use datafusion_common::{ downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, @@ -40,6 +39,7 @@ use datafusion_expr::{ use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, }; +use datafusion_macros::user_doc; make_udaf_expr_and_func!( VarianceSample, @@ -57,6 +57,12 @@ make_udaf_expr_and_func!( var_pop_udaf ); +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns the statistical sample variance of a set of numbers.", + syntax_example = "var(expression)", + standard_argument(name = "expression", prefix = "Numeric") +)] pub struct VarianceSample { signature: Signature, aliases: Vec, @@ -81,10 +87,7 @@ impl VarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("var_sample"), String::from("var_samp")], - signature: Signature::coercible( - vec![DataType::Float64], - Volatility::Immutable, - ), + signature: Signature::numeric(1, Volatility::Immutable), } } } @@ -139,26 +142,16 @@ impl AggregateUDFImpl for VarianceSample { } fn documentation(&self) -> Option<&Documentation> { - Some(get_variance_sample_doc()) + self.doc() } } -static VARIANCE_SAMPLE_DOC: OnceLock = OnceLock::new(); - -fn get_variance_sample_doc() -> &'static 
Documentation { - VARIANCE_SAMPLE_DOC.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_GENERAL) - .with_description( - "Returns the statistical sample variance of a set of numbers.", - ) - .with_syntax_example("var(expression)") - .with_standard_argument("expression", "Numeric") - .build() - .unwrap() - }) -} - +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns the statistical population variance of a set of numbers.", + syntax_example = "var_pop(expression)", + standard_argument(name = "expression", prefix = "Numeric") +)] pub struct VariancePopulation { signature: Signature, aliases: Vec, @@ -245,26 +238,10 @@ impl AggregateUDFImpl for VariancePopulation { ))) } fn documentation(&self) -> Option<&Documentation> { - Some(get_variance_population_doc()) + self.doc() } } -static VARIANCE_POPULATION_DOC: OnceLock = OnceLock::new(); - -fn get_variance_population_doc() -> &'static Documentation { - VARIANCE_POPULATION_DOC.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_GENERAL) - .with_description( - "Returns the statistical population variance of a set of numbers.", - ) - .with_syntax_example("var_pop(expression)") - .with_standard_argument("expression", "Numeric") - .build() - .unwrap() - }) -} - /// An accumulator to compute variance /// The algorithm used is an online implementation and numerically stable. It is based on this paper: /// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". @@ -315,6 +292,7 @@ fn merge( mean2: f64, m22: f64, ) -> (u64, f64, f64) { + debug_assert!(count != 0 || count2 != 0, "Cannot merge two empty states"); let new_count = count + count2; let new_mean = mean * count as f64 / new_count as f64 + mean2 * count2 as f64 / new_count as f64; @@ -424,7 +402,7 @@ impl Accumulator for VarianceAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn supports_retract_batch(&self) -> bool { @@ -461,7 +439,7 @@ impl VarianceGroupsAccumulator { counts: &UInt64Array, means: &Float64Array, m2s: &Float64Array, - opt_filter: Option<&BooleanArray>, + _opt_filter: Option<&BooleanArray>, mut value_fn: F, ) where F: FnMut(usize, u64, f64, f64) + Send, @@ -470,33 +448,14 @@ impl VarianceGroupsAccumulator { assert_eq!(means.null_count(), 0); assert_eq!(m2s.null_count(), 0); - match opt_filter { - None => { - group_indices - .iter() - .zip(counts.values().iter()) - .zip(means.values().iter()) - .zip(m2s.values().iter()) - .for_each(|(((&group_index, &count), &mean), &m2)| { - value_fn(group_index, count, mean, m2); - }); - } - Some(filter) => { - group_indices - .iter() - .zip(counts.values().iter()) - .zip(means.values().iter()) - .zip(m2s.values().iter()) - .zip(filter.iter()) - .for_each( - |((((&group_index, &count), &mean), &m2), filter_value)| { - if let Some(true) = filter_value { - value_fn(group_index, count, mean, m2); - } - }, - ); - } - } + group_indices + .iter() + .zip(counts.values().iter()) + .zip(means.values().iter()) + .zip(m2s.values().iter()) + .for_each(|(((&group_index, &count), &mean), &m2)| { + value_fn(group_index, count, mean, m2); + }); } pub fn variance( @@ -529,7 +488,7 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&arrow::array::BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); @@ -555,7 
+514,8 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&arrow::array::BooleanArray>, + // Since aggregate filter should be applied in partial stage, in final stage there should be no filter + _opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 3, "two arguments to merge_batch"); @@ -570,8 +530,11 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { partial_counts, partial_means, partial_m2s, - opt_filter, + None, |group_index, partial_count, partial_mean, partial_m2| { + if partial_count == 0 { + return; + } let (new_count, new_mean, new_m2) = merge( self.counts[group_index], self.means[group_index], @@ -606,8 +569,37 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { } fn size(&self) -> usize { - self.m2s.capacity() * std::mem::size_of::() - + self.means.capacity() * std::mem::size_of::() - + self.counts.capacity() * std::mem::size_of::() + self.m2s.capacity() * size_of::() + + self.means.capacity() * size_of::() + + self.counts.capacity() * size_of::() + } +} + +#[cfg(test)] +mod tests { + use datafusion_expr::EmitTo; + + use super::*; + + #[test] + fn test_groups_accumulator_merge_empty_states() -> Result<()> { + let state_1 = vec![ + Arc::new(UInt64Array::from(vec![0])) as ArrayRef, + Arc::new(Float64Array::from(vec![0.0])), + Arc::new(Float64Array::from(vec![0.0])), + ]; + let state_2 = vec![ + Arc::new(UInt64Array::from(vec![2])) as ArrayRef, + Arc::new(Float64Array::from(vec![1.0])), + Arc::new(Float64Array::from(vec![1.0])), + ]; + let mut acc = VarianceGroupsAccumulator::new(StatsType::Sample); + acc.merge_batch(&state_1, &[0], None, 1)?; + acc.merge_batch(&state_2, &[0], None, 1)?; + let result = acc.evaluate(EmitTo::All)?; + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result.value(0), 1.0); + Ok(()) } } diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index bdfb07031b8c1..e7254e4125cb0 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -46,18 +46,20 @@ arrow-buffer = { workspace = true } arrow-ord = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true } +datafusion-doc = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } +datafusion-macros = { workspace = true } datafusion-physical-expr-common = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "1.0.14" -rand = "0.8.5" [dev-dependencies] criterion = { version = "0.5", features = ["async_tokio"] } +rand = "0.8.5" [[bench]] harness = false diff --git a/datafusion/functions-nested/LICENSE.txt b/datafusion/functions-nested/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/functions-nested/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/functions-nested/NOTICE.txt b/datafusion/functions-nested/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/functions-nested/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 
e978591000584..9926bd630654e 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -75,7 +75,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("map_1000", |b| { let mut rng = rand::thread_rng(); - let field = Arc::new(Field::new("item", DataType::Utf8, true)); + let field = Arc::new(Field::new_list_field(DataType::Utf8, true)); let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); let key_list = ListArray::new( field, @@ -83,7 +83,7 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(StringArray::from(keys(&mut rng))), None, ); - let field = Arc::new(Field::new("item", DataType::Int32, true)); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); let value_list = ListArray::new( field, @@ -96,8 +96,9 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( + // TODO use invoke_with_args map_udf() - .invoke(&[keys.clone(), values.clone()]) + .invoke_batch(&[keys.clone(), values.clone()], 1) .expect("map should work on valid values"), ); }); diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index df1a336426d71..a83d3af410cfe 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -25,7 +25,10 @@ use arrow_buffer::BooleanBuffer; use datafusion_common::cast::as_generic_list_array; use datafusion_common::utils::string_utils::string_array_to_vec; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; @@ -54,6 +57,27 @@ make_udf_expr_and_func!(ArrayHasAny, array_has_any_udf // internal function name ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns true if the array contains the element.", + syntax_example = "array_has(array, element)", + sql_example = r#"```sql +> select array_has([1, 2, 3], 2); ++-----------------------------+ +| array_has(List([1,2,3]), 2) | ++-----------------------------+ +| true | ++-----------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "element", + description = "Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ) +)] #[derive(Debug)] pub struct ArrayHas { signature: Signature, @@ -95,11 +119,15 @@ impl ScalarUDFImpl for ArrayHas { Ok(DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match &args[1] { ColumnarValue::Array(array_needle) => { // the needle is already an array, convert the haystack to an array of the same length - let haystack = args[0].to_owned().into_array(array_needle.len())?; + let haystack = args[0].to_array(array_needle.len())?; let array = array_has_inner_for_array(&haystack, array_needle)?; Ok(ColumnarValue::Array(array)) } @@ -111,7 +139,7 @@ impl ScalarUDFImpl for ArrayHas { } // since the needle is a scalar, convert it to an array of size 1 - let haystack = args[0].to_owned().into_array(1)?; + let haystack = args[0].to_array(1)?; let needle = scalar_needle.to_array_of_size(1)?; let needle = Scalar::new(needle); let array = array_has_inner_for_scalar(&haystack, &needle)?; @@ -129,6 +157,10 @@ impl ScalarUDFImpl for ArrayHas { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } fn array_has_inner_for_scalar( @@ -245,6 +277,27 @@ fn array_has_any_inner(args: &[ArrayRef]) -> Result { } } +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns true if all elements of sub-array exist in array.", + syntax_example = "array_has_all(array, sub-array)", + sql_example = r#"```sql +> select array_has_all([1, 2, 3, 4], [2, 3]); ++--------------------------------------------+ +| array_has_all(List([1,2,3,4]), List([2,3])) | ++--------------------------------------------+ +| true | ++--------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "sub-array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] #[derive(Debug)] pub struct ArrayHasAll { signature: Signature, @@ -282,15 +335,44 @@ impl ScalarUDFImpl for ArrayHasAll { Ok(DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_has_all_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns true if any elements exist in both arrays.", + syntax_example = "array_has_any(array, sub-array)", + sql_example = r#"```sql +> select array_has_any([1, 2, 3], [3, 4]); ++------------------------------------------+ +| array_has_any(List([1,2,3]), List([3,4])) | ++------------------------------------------+ +| true | ++------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "sub-array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ) +)] #[derive(Debug)] pub struct ArrayHasAny { signature: Signature, @@ -328,13 +410,21 @@ impl ScalarUDFImpl for ArrayHasAny { Ok(DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_has_any_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Represents the type of comparison for array_has. diff --git a/datafusion/functions-nested/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs index ea07ac381affd..a46c5348d1230 100644 --- a/datafusion/functions-nested/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -27,9 +27,10 @@ use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array}; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -56,11 +57,34 @@ impl Cardinality { } } +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the total number of elements in the array.", + syntax_example = "cardinality(array)", + sql_example = r#"```sql +> select cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]); ++--------------------------------------+ +| cardinality(List([1,2,3,4,5,6,7,8])) | ++--------------------------------------+ +| 8 | ++--------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] #[derive(Debug)] -pub(super) struct Cardinality { +pub struct Cardinality { signature: Signature, aliases: Vec, } + +impl Default for Cardinality { + fn default() -> Self { + Self::new() + } +} impl ScalarUDFImpl for Cardinality { fn as_any(&self) -> &dyn Any { self @@ -82,13 +106,21 @@ impl ScalarUDFImpl for Cardinality { }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(cardinality_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Cardinality SQL function diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index c52118d0a5e2b..934c5a5fec733 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -17,7 +17,8 @@ //! [`ScalarUDFImpl`] definitions for `array_append`, `array_prepend` and `array_concat` functions. 
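For context on the `invoke` → `invoke_batch` changes in `array_has.rs` above: the needle argument may arrive as a scalar or as an array, so one side is materialized until both columns have the same length before the row-wise containment kernel runs. The sketch below is a simplified, self-contained analogy of that dispatch, using plain Rust stand-ins (an invented `Value` enum) rather than `ColumnarValue`/`ArrayRef`.

```rust
/// Stand-in for a column that is either a single scalar or one value per row.
#[derive(Debug, Clone)]
enum Value {
    Scalar(i64),
    Array(Vec<i64>),
}

/// Broadcast a scalar to `len` rows; an array is taken as-is.
fn to_rows(v: &Value, len: usize) -> Vec<i64> {
    match v {
        Value::Scalar(s) => vec![*s; len],
        Value::Array(a) => a.clone(),
    }
}

/// Row-wise "does this list contain the needle?" once both sides are aligned.
fn contains_rowwise(haystacks: &[Vec<i64>], needle: &Value) -> Vec<bool> {
    let needles = to_rows(needle, haystacks.len());
    haystacks
        .iter()
        .zip(needles)
        .map(|(h, n)| h.contains(&n))
        .collect()
}

fn main() {
    let lists = vec![vec![1, 2, 3], vec![4, 5]];
    assert_eq!(contains_rowwise(&lists, &Value::Scalar(2)), vec![true, false]);
    assert_eq!(
        contains_rowwise(&lists, &Value::Array(vec![3, 5])),
        vec![true, true]
    );
}
```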
-use std::{any::Any, cmp::Ordering, sync::Arc}; +use std::sync::Arc; +use std::{any::Any, cmp::Ordering}; use arrow::array::{Capacities, MutableArrayData}; use arrow_array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; @@ -28,9 +29,10 @@ use datafusion_common::{ cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims, }; use datafusion_expr::{ - type_coercion::binary::get_wider_type, ColumnarValue, ScalarUDFImpl, Signature, - Volatility, + type_coercion::binary::get_wider_type, ColumnarValue, Documentation, ScalarUDFImpl, + Signature, Volatility, }; +use datafusion_macros::user_doc; use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; @@ -42,6 +44,24 @@ make_udf_expr_and_func!( array_append_udf // internal function name ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Appends an element to the end of an array.", + syntax_example = "array_append(array, element)", + sql_example = r#"```sql +> select array_append([1, 2, 3], 4); ++--------------------------------------+ +| array_append(List([1,2,3]),Int64(4)) | ++--------------------------------------+ +| [1, 2, 3, 4] | ++--------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument(name = "element", description = "Element to append to the array.") +)] #[derive(Debug)] pub struct ArrayAppend { signature: Signature, @@ -84,13 +104,21 @@ impl ScalarUDFImpl for ArrayAppend { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_append_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } make_udf_expr_and_func!( @@ -101,6 +129,24 @@ make_udf_expr_and_func!( array_prepend_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Prepends an element to the beginning of an array.", + syntax_example = "array_prepend(element, array)", + sql_example = r#"```sql +> select array_prepend(1, [2, 3, 4]); ++---------------------------------------+ +| array_prepend(Int64(1),List([2,3,4])) | ++---------------------------------------+ +| [1, 2, 3, 4] | ++---------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ), + argument(name = "element", description = "Element to prepend to the array.") +)] #[derive(Debug)] pub struct ArrayPrepend { signature: Signature, @@ -143,13 +189,21 @@ impl ScalarUDFImpl for ArrayPrepend { Ok(arg_types[1].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_prepend_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } make_udf_expr_and_func!( @@ -159,6 +213,27 @@ make_udf_expr_and_func!( array_concat_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Concatenates arrays.", + syntax_example = "array_concat(array[, ..., array_n])", + sql_example = r#"```sql +> select array_concat([1, 2], [3, 4], [5, 6]); ++---------------------------------------------------+ +| array_concat(List([1,2]),List([3,4]),List([5,6])) | ++---------------------------------------------------+ +| [1, 2, 3, 4, 5, 6] | ++---------------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array_n", + description = "Subsequent array column or literal array to concatenate." + ) +)] #[derive(Debug)] pub struct ArrayConcat { signature: Signature, @@ -226,13 +301,21 @@ impl ScalarUDFImpl for ArrayConcat { Ok(expr_type) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_concat_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Array_concat/Array_cat SQL function @@ -312,7 +395,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result { .collect::>(); let list_arr = GenericListArray::::new( - Arc::new(Field::new("item", data_type, true)), + Arc::new(Field::new_list_field(data_type, true)), OffsetBuffer::from_lengths(array_lengths), Arc::new(arrow::compute::concat(elements.as_slice())?), Some(NullBuffer::new(buffer)), @@ -321,7 +404,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result { Ok(Arc::new(list_arr)) } -/// Kernal functions +// Kernel functions /// Array_append SQL function pub(crate) fn array_append_inner(args: &[ArrayRef]) -> Result { @@ -441,7 +524,7 @@ where let data = mutable.freeze(); Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", data_type.to_owned(), true)), + Arc::new(Field::new_list_field(data_type.to_owned(), true)), OffsetBuffer::new(offsets.into()), arrow_array::make_array(data), None, diff --git a/datafusion/functions-nested/src/dimension.rs b/datafusion/functions-nested/src/dimension.rs index d84fa0c19ee9a..702d0fc3a77dd 100644 --- a/datafusion/functions-nested/src/dimension.rs +++ b/datafusion/functions-nested/src/dimension.rs @@ -29,7 +29,10 @@ use datafusion_common::{exec_err, plan_err, Result}; use crate::utils::{compute_array_dims, make_scalar_function}; use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; use arrow_schema::Field; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::sync::Arc; make_udf_expr_and_func!( @@ -40,12 +43,35 @@ make_udf_expr_and_func!( 
array_dims_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns an array of the array's dimensions.", + syntax_example = "array_dims(array)", + sql_example = r#"```sql +> select array_dims([[1, 2, 3], [4, 5, 6]]); ++---------------------------------+ +| array_dims(List([1,2,3,4,5,6])) | ++---------------------------------+ +| [2, 3] | ++---------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] #[derive(Debug)] -pub(super) struct ArrayDims { +pub struct ArrayDims { signature: Signature, aliases: Vec, } +impl Default for ArrayDims { + fn default() -> Self { + Self::new() + } +} + impl ArrayDims { pub fn new() -> Self { Self { @@ -70,7 +96,7 @@ impl ScalarUDFImpl for ArrayDims { fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(match arg_types[0] { List(_) | LargeList(_) | FixedSizeList(_, _) => { - List(Arc::new(Field::new("item", UInt64, true))) + List(Arc::new(Field::new_list_field(UInt64, true))) } _ => { return plan_err!("The array_dims function can only accept List/LargeList/FixedSizeList."); @@ -78,13 +104,21 @@ impl ScalarUDFImpl for ArrayDims { }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_dims_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } make_udf_expr_and_func!( @@ -95,6 +129,24 @@ make_udf_expr_and_func!( array_ndims_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the number of dimensions of the array.", + syntax_example = "array_ndims(array, element)", + sql_example = r#"```sql +> select array_ndims([[1, 2, 3], [4, 5, 6]]); ++----------------------------------+ +| array_ndims(List([1,2,3,4,5,6])) | ++----------------------------------+ +| 2 | ++----------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument(name = "element", description = "Array element.") +)] #[derive(Debug)] pub(super) struct ArrayNdims { signature: Signature, @@ -130,13 +182,21 @@ impl ScalarUDFImpl for ArrayNdims { }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_ndims_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Array_dims SQL function diff --git a/datafusion/functions-nested/src/distance.rs b/datafusion/functions-nested/src/distance.rs index fa9394c73bcb0..8559b1096bc5d 100644 --- a/datafusion/functions-nested/src/distance.rs +++ b/datafusion/functions-nested/src/distance.rs @@ -17,21 +17,23 @@ //! [ScalarUDFImpl] definitions for array_distance function. 
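For context on the repeated `Field::new("item", ..)` → `Field::new_list_field(..)` substitutions (here in `dimension.rs`, and throughout this patch): assuming `arrow-schema` is available as a dependency, the helper builds the conventional `item` child field that list types use, so the swap should be behavior-preserving. A minimal check of that assumption:

```rust
use arrow_schema::{DataType, Field};

fn main() {
    // Both forms should describe the same list child field named "item".
    let by_name = Field::new("item", DataType::UInt64, true);
    let by_helper = Field::new_list_field(DataType::UInt64, true);
    assert_eq!(by_name, by_helper);
}
```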
-use crate::utils::{downcast_arg, make_scalar_function}; +use crate::utils::make_scalar_function; use arrow_array::{ Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait, }; use arrow_schema::DataType; use arrow_schema::DataType::{FixedSizeList, Float64, LargeList, List}; -use core::any::type_name; use datafusion_common::cast::{ as_float32_array, as_float64_array, as_generic_list_array, as_int32_array, as_int64_array, }; use datafusion_common::utils::coerced_fixed_size_list_to_list; -use datafusion_common::DataFusionError; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{exec_err, internal_datafusion_err, Result}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::{downcast_arg, downcast_named_arg}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -43,12 +45,39 @@ make_udf_expr_and_func!( array_distance_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the Euclidean distance between two input arrays of equal length.", + syntax_example = "array_distance(array1, array2)", + sql_example = r#"```sql +> select array_distance([1, 2], [1, 4]); ++------------------------------------+ +| array_distance(List([1,2], [1,4])) | ++------------------------------------+ +| 2.0 | ++------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] #[derive(Debug)] -pub(super) struct ArrayDistance { +pub struct ArrayDistance { signature: Signature, aliases: Vec, } +impl Default for ArrayDistance { + fn default() -> Self { + Self::new() + } +} + impl ArrayDistance { pub fn new() -> Self { Self { @@ -93,13 +122,21 @@ impl ScalarUDFImpl for ArrayDistance { Ok(result) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_distance_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } pub fn array_distance_inner(args: &[ArrayRef]) -> Result { @@ -207,7 +244,7 @@ fn compute_array_distance( /// Converts an array of any numeric type to a Float64Array. 
fn convert_to_f64_array(array: &ArrayRef) -> Result { match array.data_type() { - DataType::Float64 => Ok(as_float64_array(array)?.clone()), + Float64 => Ok(as_float64_array(array)?.clone()), DataType::Float32 => { let array = as_float32_array(array)?; let converted: Float64Array = diff --git a/datafusion/functions-nested/src/empty.rs b/datafusion/functions-nested/src/empty.rs index 36c82e92081d2..9739ffb15f6b7 100644 --- a/datafusion/functions-nested/src/empty.rs +++ b/datafusion/functions-nested/src/empty.rs @@ -23,7 +23,10 @@ use arrow_schema::DataType; use arrow_schema::DataType::{Boolean, FixedSizeList, LargeList, List}; use datafusion_common::cast::as_generic_list_array; use datafusion_common::{exec_err, plan_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -35,11 +38,34 @@ make_udf_expr_and_func!( array_empty_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns 1 for an empty array or 0 for a non-empty array.", + syntax_example = "empty(array)", + sql_example = r#"```sql +> select empty([1]); ++------------------+ +| empty(List([1])) | ++------------------+ +| 0 | ++------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] #[derive(Debug)] -pub(super) struct ArrayEmpty { +pub struct ArrayEmpty { signature: Signature, aliases: Vec, } + +impl Default for ArrayEmpty { + fn default() -> Self { + Self::new() + } +} impl ArrayEmpty { pub fn new() -> Self { Self { @@ -70,13 +96,21 @@ impl ScalarUDFImpl for ArrayEmpty { }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_empty_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Array_empty SQL function diff --git a/datafusion/functions-nested/src/except.rs b/datafusion/functions-nested/src/except.rs index 50ef20a7d4162..356c92983ae2e 100644 --- a/datafusion/functions-nested/src/except.rs +++ b/datafusion/functions-nested/src/except.rs @@ -23,10 +23,12 @@ use arrow_array::cast::AsArray; use arrow_array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow_buffer::OffsetBuffer; use arrow_schema::{DataType, FieldRef}; -use datafusion_common::{exec_err, internal_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{exec_err, internal_err, HashSet, Result}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; -use std::collections::HashSet; use std::sync::Arc; make_udf_expr_and_func!( @@ -37,12 +39,45 @@ make_udf_expr_and_func!( array_except_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns an array of the elements that appear in the first array but not in the second.", + syntax_example = "array_except(array1, array2)", + sql_example = r#"```sql +> select array_except([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 
2] | ++----------------------------------------------------+ +> select array_except([1, 2, 3, 4], [3, 4, 5, 6]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [3, 4, 5, 6]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] #[derive(Debug)] -pub(super) struct ArrayExcept { +pub struct ArrayExcept { signature: Signature, aliases: Vec, } +impl Default for ArrayExcept { + fn default() -> Self { + Self::new() + } +} + impl ArrayExcept { pub fn new() -> Self { Self { @@ -71,13 +106,21 @@ impl ScalarUDFImpl for ArrayExcept { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_except_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Array_except SQL function diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 7dfc736b76d3e..3406b83e03667 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -36,7 +36,10 @@ use datafusion_common::{ exec_err, internal_datafusion_err, plan_err, DataFusionError, Result, }; use datafusion_expr::Expr; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -77,12 +80,39 @@ make_udf_expr_and_func!( array_any_value_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Extracts the element with the index n from the array.", + syntax_example = "array_element(array, index)", + sql_example = r#"```sql +> select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "index", + description = "Index to extract the element from the array." 
+ ) +)] #[derive(Debug)] -pub(super) struct ArrayElement { +pub struct ArrayElement { signature: Signature, aliases: Vec, } +impl Default for ArrayElement { + fn default() -> Self { + Self::new() + } +} + impl ArrayElement { pub fn new() -> Self { Self { @@ -140,13 +170,21 @@ impl ScalarUDFImpl for ArrayElement { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_element_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// array_element SQL function @@ -254,6 +292,35 @@ pub fn array_slice(array: Expr, begin: Expr, end: Expr, stride: Option) -> array_slice_udf().call(args) } +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns a slice of the array based on 1-indexed start and end positions.", + syntax_example = "array_slice(array, begin, end)", + sql_example = r#"```sql +> select array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6); ++--------------------------------------------------------+ +| array_slice(List([1,2,3,4,5,6,7,8]),Int64(3),Int64(6)) | ++--------------------------------------------------------+ +| [3, 4, 5, 6] | ++--------------------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "begin", + description = "Index of the first element. If negative, it counts backward from the end of the array." + ), + argument( + name = "end", + description = "Index of the last element. If negative, it counts backward from the end of the array." + ), + argument( + name = "stride", + description = "Stride of the array slice. The default is 1." + ) +)] #[derive(Debug)] pub(super) struct ArraySlice { signature: Signature, @@ -307,13 +374,21 @@ impl ScalarUDFImpl for ArraySlice { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_slice_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// array_slice SQL function @@ -535,13 +610,30 @@ where let data = mutable.freeze(); Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", array.value_type(), true)), + Arc::new(Field::new_list_field(array.value_type(), true)), OffsetBuffer::::new(offsets.into()), arrow_array::make_array(data), None, )?)) } +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the array without the first element.", + syntax_example = "array_pop_front(array)", + sql_example = r#"```sql +> select array_pop_front([1, 2, 3]); ++-------------------------------+ +| array_pop_front(List([1,2,3])) | ++-------------------------------+ +| [2, 3] | ++-------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ) +)] #[derive(Debug)] pub(super) struct ArrayPopFront { signature: Signature, @@ -573,13 +665,21 @@ impl ScalarUDFImpl for ArrayPopFront { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_pop_front_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// array_pop_front SQL function @@ -617,6 +717,23 @@ where general_array_slice::(array, &from_array, &to_array, None) } +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the array without the last element.", + syntax_example = "array_pop_back(array)", + sql_example = r#"```sql +> select array_pop_back([1, 2, 3]); ++-------------------------------+ +| array_pop_back(List([1,2,3])) | ++-------------------------------+ +| [1, 2] | ++-------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] #[derive(Debug)] pub(super) struct ArrayPopBack { signature: Signature, @@ -648,13 +765,21 @@ impl ScalarUDFImpl for ArrayPopBack { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_pop_back_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// array_pop_back SQL function @@ -696,6 +821,23 @@ where general_array_slice::(array, &from_array, &to_array, None) } +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the first non-null element in the array.", + syntax_example = "array_any_value(array)", + sql_example = r#"```sql +> select array_any_value([NULL, 1, 2, 3]); ++-------------------------------+ +| array_any_value(List([NULL,1,2,3])) | ++-------------------------------------+ +| 1 | ++-------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
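The change repeated across these files replaces `fn invoke(&self, args: &[ColumnarValue])` with `fn invoke_batch(&self, args: &[ColumnarValue], _number_rows: usize)`. Below is a minimal sketch of what a UDF looks like after that migration, assuming the `datafusion_expr` API at the revision this patch targets; `PassThrough` is an invented example, not part of the patch:

```rust
use std::any::Any;

use arrow_schema::DataType;
use datafusion_common::{exec_err, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

/// Hypothetical pass-through UDF, shown only to illustrate the `invoke_batch` shape.
#[derive(Debug)]
struct PassThrough {
    signature: Signature,
}

impl PassThrough {
    fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for PassThrough {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "pass_through"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    // The batched entry point: same arguments as the old `invoke`, plus the
    // batch row count, which this sketch does not need.
    fn invoke_batch(
        &self,
        args: &[ColumnarValue],
        _number_rows: usize,
    ) -> Result<ColumnarValue> {
        match args {
            [only] => Ok(only.clone()),
            _ => exec_err!("pass_through expects exactly one argument"),
        }
    }
}
```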
+ ) +)] #[derive(Debug)] pub(super) struct ArrayAnyValue { signature: Signature, @@ -732,12 +874,20 @@ impl ScalarUDFImpl for ArrayAnyValue { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_any_value_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } fn array_any_value_inner(args: &[ArrayRef]) -> Result { @@ -807,3 +957,86 @@ where let data = mutable.freeze(); Ok(arrow::array::make_array(data)) } + +#[cfg(test)] +mod tests { + use super::array_element_udf; + use arrow_schema::{DataType, Field}; + use datafusion_common::{Column, DFSchema, ScalarValue}; + use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::{cast, Expr, ExprSchemable}; + use std::collections::HashMap; + + // Regression test for https://github.com/apache/datafusion/issues/13755 + #[test] + fn test_array_element_return_type_fixed_size_list() { + let fixed_size_list_type = DataType::FixedSizeList( + Field::new("some_arbitrary_test_field", DataType::Int32, false).into(), + 13, + ); + let array_type = DataType::List( + Field::new_list_field(fixed_size_list_type.clone(), true).into(), + ); + let index_type = DataType::Int64; + + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("my_array", array_type.clone(), false), + Field::new("my_index", index_type.clone(), false), + ] + .into(), + HashMap::default(), + ) + .unwrap(); + + let udf = array_element_udf(); + + // ScalarUDFImpl::return_type + assert_eq!( + udf.return_type(&[array_type.clone(), index_type.clone()]) + .unwrap(), + fixed_size_list_type + ); + + // ScalarUDFImpl::return_type_from_exprs with typed exprs + assert_eq!( + udf.return_type_from_exprs( + &[ + cast(Expr::from(ScalarValue::Null), array_type.clone()), + cast(Expr::from(ScalarValue::Null), index_type.clone()), + ], + &schema, + &[array_type.clone(), index_type.clone()] + ) + .unwrap(), + fixed_size_list_type + ); + + // ScalarUDFImpl::return_type_from_exprs with exprs not carrying type + assert_eq!( + udf.return_type_from_exprs( + &[ + Expr::Column(Column::new_unqualified("my_array")), + Expr::Column(Column::new_unqualified("my_index")), + ], + &schema, + &[array_type.clone(), index_type.clone()] + ) + .unwrap(), + fixed_size_list_type + ); + + // Via ExprSchemable::get_type (e.g. 
SimplifyInfo) + let udf_expr = Expr::ScalarFunction(ScalarFunction { + func: array_element_udf(), + args: vec![ + Expr::Column(Column::new_unqualified("my_array")), + Expr::Column(Column::new_unqualified("my_index")), + ], + }); + assert_eq!( + ExprSchemable::get_type(&udf_expr, &schema).unwrap(), + fixed_size_list_type + ); + } +} diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index b04c35667226c..30bf2fcbf6244 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -26,7 +26,11 @@ use datafusion_common::cast::{ as_generic_list_array, as_large_list_array, as_list_array, }; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -38,15 +42,45 @@ make_udf_expr_and_func!( flatten_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Converts an array of arrays to a flat array.\n\n- Applies to any depth of nested arrays\n- Does not change arrays that are already flat\n\nThe flattened array contains all the elements from all source arrays.", + syntax_example = "flatten(array)", + sql_example = r#"```sql +> select flatten([[1, 2], [3, 4]]); ++------------------------------+ +| flatten(List([1,2], [3,4])) | ++------------------------------+ +| [1, 2, 3, 4] | ++------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] #[derive(Debug)] -pub(super) struct Flatten { +pub struct Flatten { signature: Signature, aliases: Vec, } + +impl Default for Flatten { + fn default() -> Self { + Self::new() + } +} + impl Flatten { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature { + // TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::RecursiveArray, + ), + volatility: Volatility::Immutable, + }, aliases: vec![], } } @@ -88,13 +122,21 @@ impl ScalarUDFImpl for Flatten { Ok(data_type) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(flatten_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Flatten SQL function diff --git a/datafusion/functions-nested/src/length.rs b/datafusion/functions-nested/src/length.rs index 5d9ccd2901cfa..70a9188a2c3d0 100644 --- a/datafusion/functions-nested/src/length.rs +++ b/datafusion/functions-nested/src/length.rs @@ -17,17 +17,19 @@ //! [`ScalarUDFImpl`] definitions for array_length function. 
-use crate::utils::{downcast_arg, make_scalar_function}; +use crate::utils::make_scalar_function; use arrow_array::{ Array, ArrayRef, Int64Array, LargeListArray, ListArray, OffsetSizeTrait, UInt64Array, }; use arrow_schema::DataType; use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; -use core::any::type_name; use datafusion_common::cast::{as_generic_list_array, as_int64_array}; -use datafusion_common::DataFusionError; -use datafusion_common::{exec_err, plan_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{exec_err, internal_datafusion_err, plan_err, Result}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::{downcast_arg, downcast_named_arg}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -39,11 +41,36 @@ make_udf_expr_and_func!( array_length_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the length of the array dimension.", + syntax_example = "array_length(array, dimension)", + sql_example = r#"```sql +> select array_length([1, 2, 3, 4, 5], 1); ++-------------------------------------------+ +| array_length(List([1,2,3,4,5]), 1) | ++-------------------------------------------+ +| 5 | ++-------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument(name = "dimension", description = "Array dimension.") +)] #[derive(Debug)] -pub(super) struct ArrayLength { +pub struct ArrayLength { signature: Signature, aliases: Vec, } + +impl Default for ArrayLength { + fn default() -> Self { + Self::new() + } +} + impl ArrayLength { pub fn new() -> Self { Self { @@ -74,13 +101,21 @@ impl ScalarUDFImpl for ArrayLength { }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_length_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Array_length SQL function diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 301ddb36fc560..c47e4a696a1d8 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -14,6 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/functions-nested/src/macros.rs b/datafusion/functions-nested/src/macros.rs index 00247f39ac10f..cec7f2fd562d6 100644 --- a/datafusion/functions-nested/src/macros.rs +++ b/datafusion/functions-nested/src/macros.rs @@ -85,22 +85,17 @@ macro_rules! make_udf_expr_and_func { macro_rules! create_func { ($UDF:ty, $SCALAR_UDF_FN:ident) => { paste::paste! { - /// Singleton instance of [`$UDF`], ensures the UDF is only created once - /// named STATIC_$(UDF). 
For example `STATIC_ArrayToString` - #[allow(non_upper_case_globals)] - static [< STATIC_ $UDF >]: std::sync::OnceLock> = - std::sync::OnceLock::new(); - #[doc = concat!("ScalarFunction that returns a [`ScalarUDF`](datafusion_expr::ScalarUDF) for ")] #[doc = stringify!($UDF)] pub fn $SCALAR_UDF_FN() -> std::sync::Arc { - [< STATIC_ $UDF >] - .get_or_init(|| { + // Singleton instance of [`$UDF`], ensures the UDF is only created once + static INSTANCE: std::sync::LazyLock> = + std::sync::LazyLock::new(|| { std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( <$UDF>::new(), )) - }) - .clone() + }); + std::sync::Arc::clone(&INSTANCE) } } }; diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 51fc71e6b09dd..0283cdd402757 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -17,22 +17,28 @@ //! [`ScalarUDFImpl`] definitions for `make_array` function. +use std::any::Any; +use std::sync::Arc; use std::vec; -use std::{any::Any, sync::Arc}; +use crate::utils::make_scalar_function; use arrow::array::{ArrayData, Capacities, MutableArrayData}; use arrow_array::{ new_null_array, Array, ArrayRef, GenericListArray, NullArray, OffsetSizeTrait, }; use arrow_buffer::OffsetBuffer; -use arrow_schema::DataType::{LargeList, List, Null}; +use arrow_schema::DataType::{List, Null}; use arrow_schema::{DataType, Field}; -use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result}; -use datafusion_expr::binary::type_union_resolution; +use datafusion_common::utils::SingleRowListArrayBuilder; +use datafusion_common::{plan_err, Result}; +use datafusion_expr::binary::{ + try_type_union_resolution_with_struct, type_union_resolution, +}; use datafusion_expr::TypeSignature; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::make_scalar_function; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; make_udf_expr_and_func!( MakeArray, @@ -41,6 +47,23 @@ make_udf_expr_and_func!( make_array_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns an array using the specified input expressions.", + syntax_example = "make_array(expression1[, ..., expression_n])", + sql_example = r#"```sql +> select make_array(1, 2, 3, 4, 5); ++----------------------------------------------------------+ +| make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)) | ++----------------------------------------------------------+ +| [1, 2, 3, 4, 5] | ++----------------------------------------------------------+ +```"#, + argument( + name = "expression_n", + description = "Expression to include in the output array. Can be a constant, column, or function, and any combination of arithmetic or string operators." 
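The `create_func!` rewrite above drops the `OnceLock` plus `get_or_init` pair in favor of `std::sync::LazyLock` and hands the singleton out with an explicit `Arc::clone`. A stand-alone sketch of the same shape outside the macro, using only the standard library (the `Udf` struct here is a placeholder for the real `ScalarUDF`):

```rust
use std::sync::{Arc, LazyLock};

/// Placeholder for the real `ScalarUDF`; any expensive-to-build shared value works the same way.
#[derive(Debug)]
struct Udf {
    name: &'static str,
}

fn my_udf() -> Arc<Udf> {
    // Built at most once, on first access, then shared via cheap `Arc::clone`s.
    static INSTANCE: LazyLock<Arc<Udf>> =
        LazyLock::new(|| Arc::new(Udf { name: "my_udf" }));
    Arc::clone(&INSTANCE)
}

fn main() {
    let a = my_udf();
    let b = my_udf();
    // Both handles point at the same allocation.
    assert!(Arc::ptr_eq(&a, &b));
    println!("{}", a.name);
}
```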
+ ) +)] #[derive(Debug)] pub struct MakeArray { signature: Signature, @@ -57,7 +80,7 @@ impl MakeArray { pub fn new() -> Self { Self { signature: Signature::one_of( - vec![TypeSignature::UserDefined, TypeSignature::Any(0)], + vec![TypeSignature::Nullary, TypeSignature::UserDefined], Volatility::Immutable, ), aliases: vec![String::from("make_list")], @@ -83,8 +106,7 @@ impl ScalarUDFImpl for MakeArray { 0 => Ok(empty_array_type()), _ => { // At this point, all the type in array should be coerced to the same one - Ok(List(Arc::new(Field::new( - "item", + Ok(List(Arc::new(Field::new_list_field( arg_types[0].to_owned(), true, )))) @@ -92,22 +114,31 @@ impl ScalarUDFImpl for MakeArray { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(make_array_inner)(args) } - fn invoke_no_args(&self, _number_rows: usize) -> Result { - make_scalar_function(make_array_inner)(&[]) - } - fn aliases(&self) -> &[String] { &self.aliases } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let mut errors = vec![]; + match try_type_union_resolution_with_struct(arg_types) { + Ok(r) => return Ok(r), + Err(e) => { + errors.push(e); + } + } + if let Some(new_type) = type_union_resolution(arg_types) { + // TODO: Move FixedSizeList to List in type_union_resolution if let DataType::FixedSizeList(field, _) = new_type { - Ok(vec![DataType::List(field); arg_types.len()]) + Ok(vec![List(field); arg_types.len()]) } else if new_type.is_null() { Ok(vec![DataType::Int64; arg_types.len()]) } else { @@ -115,17 +146,22 @@ impl ScalarUDFImpl for MakeArray { } } else { plan_err!( - "Fail to find the valid type between {:?} for {}", + "Fail to find the valid type between {:?} for {}, errors are {:?}", arg_types, - self.name() + self.name(), + errors ) } } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } // Empty array is a special case that is useful for many other array functions pub(super) fn empty_array_type() -> DataType { - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))) + List(Arc::new(Field::new_list_field(DataType::Int64, true))) } /// `make_array_inner` is the implementation of the `make_array` function. @@ -147,9 +183,10 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { let length = arrays.iter().map(|a| a.len()).sum(); // By default Int64 let array = new_null_array(&DataType::Int64, length); - Ok(Arc::new(array_into_list_array_nullable(array))) + Ok(Arc::new( + SingleRowListArrayBuilder::new(array).build_list_array(), + )) } - LargeList(..) => array_array::(arrays, data_type), _ => array_array::(arrays, data_type), } } @@ -239,7 +276,7 @@ fn array_array( let data = mutable.freeze(); Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", data_type, true)), + Arc::new(Field::new_list_field(data_type, true)), OffsetBuffer::new(offsets.into()), arrow_array::make_array(data), None, diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index 1a5eba2d70645..a8c08d30bd095 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -16,7 +16,7 @@ // under the License. 
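`make_array.rs`, like most files in this patch, swaps `Field::new("item", …)` for `Field::new_list_field(…)`. To my understanding of the `arrow_schema` version this patch builds against, the helper is shorthand for the conventionally named list item field, which the small sketch below checks:

```rust
use arrow_schema::{DataType, Field};

fn main() {
    // The explicit form used before this patch...
    let explicit = Field::new("item", DataType::Int64, true);
    // ...and the shorthand the patch switches to; both describe a nullable
    // Int64 list item field named "item".
    let shorthand = Field::new_list_field(DataType::Int64, true);
    assert_eq!(explicit, shorthand);
}
```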
use std::any::Any; -use std::collections::{HashSet, VecDeque}; +use std::collections::VecDeque; use std::sync::Arc; use arrow::array::ArrayData; @@ -25,9 +25,12 @@ use arrow_buffer::{Buffer, ToByteSlice}; use arrow_schema::{DataType, Field, SchemaBuilder}; use datafusion_common::utils::{fixed_size_list_to_arrays, list_to_arrays}; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{exec_err, HashSet, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use crate::make_array::make_array; @@ -178,6 +181,50 @@ fn make_map_batch_internal( }) } +#[user_doc( + doc_section(label = "Map Functions"), + description = "Returns an Arrow map with the specified key-value pairs.\n\n\ + The `make_map` function creates a map from two lists: one for keys and one for values. Each key must be unique and non-null.", + syntax_example = "map(key, value)\nmap(key: value)\nmake_map(['key1', 'key2'], ['value1', 'value2'])", + sql_example = r#" +```sql +-- Using map function +SELECT MAP('type', 'test'); +---- +{type: test} + +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); +---- +{POST: 41, HEAD: 33, PATCH: NULL} + +SELECT MAP([[1,2], [3,4]], ['a', 'b']); +---- +{[1, 2]: a, [3, 4]: b} + +SELECT MAP { 'a': 1, 'b': 2 }; +---- +{a: 1, b: 2} + +-- Using make_map function +SELECT MAKE_MAP(['POST', 'HEAD'], [41, 33]); +---- +{POST: 41, HEAD: 33} + +SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]); +---- +{key1: value1, key2: } +```"#, + argument( + name = "key", + description = "For `map`: Expression to be used for key. Can be a constant, column, function, or any combination of arithmetic or string operators.\n\ + For `make_map`: The list of keys to be used in the map. Each key must be unique and non-null." + ), + argument( + name = "value", + description = "For `map`: Expression to be used for value. Can be a constant, column, function, or any combination of arithmetic or string operators.\n\ + For `make_map`: The list of values to be mapped to the corresponding keys." 
+ ) +)] #[derive(Debug)] pub struct MapFunc { signature: Signature, @@ -211,9 +258,9 @@ impl ScalarUDFImpl for MapFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() % 2 != 0 { + if arg_types.len() != 2 { return exec_err!( - "map requires an even number of arguments, got {} instead", + "map requires exactly 2 arguments, got {} instead", arg_types.len() ); } @@ -235,10 +282,19 @@ impl ScalarUDFImpl for MapFunc { )) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_map_batch(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } + fn get_element_type(data_type: &DataType) -> Result<&DataType> { match data_type { DataType::List(element) => Ok(element.data_type()), @@ -306,7 +362,6 @@ fn get_element_type(data_type: &DataType) -> Result<&DataType> { /// | +-------+ | | +-------+ | /// +-----------+ +-----------+ /// ```text - fn make_map_array_internal( keys: ArrayRef, values: ArrayRef, diff --git a/datafusion/functions-nested/src/map_extract.rs b/datafusion/functions-nested/src/map_extract.rs index 9f0c4ad29c60e..8ccfae0ff93e3 100644 --- a/datafusion/functions-nested/src/map_extract.rs +++ b/datafusion/functions-nested/src/map_extract.rs @@ -26,7 +26,10 @@ use arrow_buffer::OffsetBuffer; use arrow_schema::Field; use datafusion_common::{cast::as_map_array, exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; use std::vec; @@ -42,12 +45,44 @@ make_udf_expr_and_func!( map_extract_udf ); +#[user_doc( + doc_section(label = "Map Functions"), + description = "Returns a list containing the value for the given key or an empty list if the key is not present in the map.", + syntax_example = "map_extract(map, key)", + sql_example = r#"```sql +SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'); +---- +[1] + +SELECT map_extract(MAP {1: 'one', 2: 'two'}, 2); +---- +['two'] + +SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y'); +---- +[] +```"#, + argument( + name = "map", + description = "Map expression. Can be a constant, column, or function, and any combination of map operators." + ), + argument( + name = "key", + description = "Key to extract from the map. Can be a constant, column, or function, any combination of arithmetic or string operators, or a named expression of the previously listed." 
+ ) +)] #[derive(Debug)] -pub(super) struct MapExtract { +pub struct MapExtract { signature: Signature, aliases: Vec, } +impl Default for MapExtract { + fn default() -> Self { + Self::new() + } +} + impl MapExtract { pub fn new() -> Self { Self { @@ -75,14 +110,17 @@ impl ScalarUDFImpl for MapExtract { } let map_type = &arg_types[0]; let map_fields = get_map_entry_field(map_type)?; - Ok(DataType::List(Arc::new(Field::new( - "item", + Ok(DataType::List(Arc::new(Field::new_list_field( map_fields.last().unwrap().data_type().clone(), true, )))) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(map_extract_inner)(args) } @@ -101,6 +139,10 @@ impl ScalarUDFImpl for MapExtract { field.first().unwrap().data_type().clone(), ]) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } fn general_map_extract_inner( @@ -141,7 +183,7 @@ fn general_map_extract_inner( let data = mutable.freeze(); Ok(Arc::new(ListArray::new( - Arc::new(Field::new("item", map_array.value_type().clone(), true)), + Arc::new(Field::new_list_field(map_array.value_type().clone(), true)), OffsetBuffer::::new(offsets.into()), Arc::new(make_array(data)), None, diff --git a/datafusion/functions-nested/src/map_keys.rs b/datafusion/functions-nested/src/map_keys.rs index 0b1cebb27c866..74a106bb0c169 100644 --- a/datafusion/functions-nested/src/map_keys.rs +++ b/datafusion/functions-nested/src/map_keys.rs @@ -22,9 +22,10 @@ use arrow_array::{Array, ArrayRef, ListArray}; use arrow_schema::{DataType, Field}; use datafusion_common::{cast::as_map_array, exec_err, Result}; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -36,11 +37,35 @@ make_udf_expr_and_func!( map_keys_udf ); +#[user_doc( + doc_section(label = "Map Functions"), + description = "Returns a list of all keys in the map.", + syntax_example = "map_keys(map)", + sql_example = r#"```sql +SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[a, b, c] + +SELECT map_keys(map([100, 5], [42, 43])); +---- +[100, 5] +```"#, + argument( + name = "map", + description = "Map expression. Can be a constant, column, or function, and any combination of map operators." 
+ ) +)] #[derive(Debug)] -pub(crate) struct MapKeysFunc { +pub struct MapKeysFunc { signature: Signature, } +impl Default for MapKeysFunc { + fn default() -> Self { + Self::new() + } +} + impl MapKeysFunc { pub fn new() -> Self { Self { @@ -65,22 +90,29 @@ impl ScalarUDFImpl for MapKeysFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { if arg_types.len() != 1 { return exec_err!("map_keys expects single argument"); } let map_type = &arg_types[0]; let map_fields = get_map_entry_field(map_type)?; - Ok(DataType::List(Arc::new(Field::new( - "item", + Ok(DataType::List(Arc::new(Field::new_list_field( map_fields.first().unwrap().data_type().clone(), false, )))) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(map_keys_inner)(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } fn map_keys_inner(args: &[ArrayRef]) -> Result { @@ -94,7 +126,7 @@ fn map_keys_inner(args: &[ArrayRef]) -> Result { }; Ok(Arc::new(ListArray::new( - Arc::new(Field::new("item", map_array.key_type().clone(), false)), + Arc::new(Field::new_list_field(map_array.key_type().clone(), false)), map_array.offsets().clone(), Arc::clone(map_array.keys()), None, diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index 58c0d74eed5ff..d2fc0a12f9cf2 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -22,9 +22,10 @@ use arrow_array::{Array, ArrayRef, ListArray}; use arrow_schema::{DataType, Field}; use datafusion_common::{cast::as_map_array, exec_err, Result}; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -36,11 +37,35 @@ make_udf_expr_and_func!( map_values_udf ); +#[user_doc( + doc_section(label = "Map Functions"), + description = "Returns a list of all values in the map.", + syntax_example = "map_values(map)", + sql_example = r#"```sql +SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[1, , 3] + +SELECT map_values(map([100, 5], [42, 43])); +---- +[42, 43] +```"#, + argument( + name = "map", + description = "Map expression. Can be a constant, column, or function, and any combination of map operators." 
+ ) +)] #[derive(Debug)] pub(crate) struct MapValuesFunc { signature: Signature, } +impl Default for MapValuesFunc { + fn default() -> Self { + Self::new() + } +} + impl MapValuesFunc { pub fn new() -> Self { Self { @@ -65,22 +90,29 @@ impl ScalarUDFImpl for MapValuesFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { if arg_types.len() != 1 { return exec_err!("map_values expects single argument"); } let map_type = &arg_types[0]; let map_fields = get_map_entry_field(map_type)?; - Ok(DataType::List(Arc::new(Field::new( - "item", + Ok(DataType::List(Arc::new(Field::new_list_field( map_fields.last().unwrap().data_type().clone(), true, )))) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(map_values_inner)(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } fn map_values_inner(args: &[ArrayRef]) -> Result { @@ -94,7 +126,7 @@ fn map_values_inner(args: &[ArrayRef]) -> Result { }; Ok(Arc::new(ListArray::new( - Arc::new(Field::new("item", map_array.value_type().clone(), true)), + Arc::new(Field::new_list_field(map_array.value_type().clone(), true)), map_array.offsets().clone(), Arc::clone(map_array.values()), None, diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 9ae2fa781d87e..5ca51ac20f1e5 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -133,7 +133,6 @@ impl ExprPlanner for NestedFunctionPlanner { #[derive(Debug)] pub struct FieldAccessPlanner; - impl ExprPlanner for FieldAccessPlanner { fn plan_field_access( &self, @@ -186,5 +185,5 @@ impl ExprPlanner for FieldAccessPlanner { } fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - return agg_func.func.name() == "array_agg"; + agg_func.func.name() == "array_agg" } diff --git a/datafusion/functions-nested/src/position.rs b/datafusion/functions-nested/src/position.rs index a48332ceb0b30..eec2a32fa2a2f 100644 --- a/datafusion/functions-nested/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -19,7 +19,11 @@ use arrow_schema::DataType::{LargeList, List, UInt64}; use arrow_schema::{DataType, Field}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; + use std::any::Any; use std::sync::Arc; @@ -43,11 +47,45 @@ make_udf_expr_and_func!( array_position_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the position of the first occurrence of the specified element in the array.", + syntax_example = "array_position(array, element)\narray_position(array, element, index)", + sql_example = r#"```sql +> select array_position([1, 2, 2, 3, 1, 4], 2); ++----------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2)) | ++----------------------------------------------+ +| 2 | ++----------------------------------------------+ +> select array_position([1, 2, 2, 3, 1, 4], 2, 3); ++----------------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) | ++----------------------------------------------------+ +| 3 | 
++----------------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "element", + description = "Element to search for position in the array." + ), + argument(name = "index", description = "Index at which to start searching.") +)] #[derive(Debug)] -pub(super) struct ArrayPosition { +pub struct ArrayPosition { signature: Signature, aliases: Vec, } + +impl Default for ArrayPosition { + fn default() -> Self { + Self::new() + } +} impl ArrayPosition { pub fn new() -> Self { Self { @@ -79,13 +117,21 @@ impl ScalarUDFImpl for ArrayPosition { Ok(UInt64) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_position_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Array_position SQL function @@ -172,6 +218,28 @@ make_udf_expr_and_func!( "searches for an element in the array, returns all occurrences.", // doc array_positions_udf // internal function name ); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Searches for an element in the array, returns all occurrences.", + syntax_example = "array_positions(array, element)", + sql_example = r#"```sql +> select array_positions([1, 2, 2, 3, 1, 4], 2); ++-----------------------------------------------+ +| array_positions(List([1,2,2,3,1,4]),Int64(2)) | ++-----------------------------------------------+ +| [2, 3] | ++-----------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "element", + description = "Element to search for position in the array." 
+ ) +)] #[derive(Debug)] pub(super) struct ArrayPositions { signature: Signature, @@ -200,16 +268,24 @@ impl ScalarUDFImpl for ArrayPositions { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(List(Arc::new(Field::new("item", UInt64, true)))) + Ok(List(Arc::new(Field::new_list_field(UInt64, true)))) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_positions_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Array_positions SQL function diff --git a/datafusion/functions-nested/src/range.rs b/datafusion/functions-nested/src/range.rs index b3d8010cb6683..4f8132f59e85f 100644 --- a/datafusion/functions-nested/src/range.rs +++ b/datafusion/functions-nested/src/range.rs @@ -37,7 +37,10 @@ use datafusion_common::cast::{ use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, not_impl_datafusion_err, Result, }; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use itertools::Itertools; use std::any::Any; use std::cmp::Ordering; @@ -52,11 +55,50 @@ make_udf_expr_and_func!( "create a list of values in the range between start and stop", range_udf ); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns an Arrow array between start and stop with step. The range start..end contains all values with start <= x < end. It is empty if start >= end. Step cannot be 0.", + syntax_example = "range(start, stop, step)", + sql_example = r#"```sql +> select range(2, 10, 3); ++-----------------------------------+ +| range(Int64(2),Int64(10),Int64(3))| ++-----------------------------------+ +| [2, 5, 8] | ++-----------------------------------+ + +> select range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH); ++--------------------------------------------------------------+ +| range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH) | ++--------------------------------------------------------------+ +| [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01] | ++--------------------------------------------------------------+ +```"#, + argument( + name = "start", + description = "Start of the range. Ints, timestamps, dates or string types that can be coerced to Date32 are supported." + ), + argument( + name = "end", + description = "End of the range (not included). Type must be the same as start." + ), + argument( + name = "step", + description = "Increase by step (cannot be 0). Steps less than a day are supported only for timestamp ranges." 
+ ) +)] #[derive(Debug)] -pub(super) struct Range { +pub struct Range { signature: Signature, aliases: Vec, } + +impl Default for Range { + fn default() -> Self { + Self::new() + } +} impl Range { pub fn new() -> Self { Self { @@ -106,15 +148,18 @@ impl ScalarUDFImpl for Range { if arg_types.iter().any(|t| t.is_null()) { Ok(Null) } else { - Ok(List(Arc::new(Field::new( - "item", + Ok(List(Arc::new(Field::new_list_field( arg_types[0].clone(), true, )))) } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.iter().any(|arg| arg.data_type().is_null()) { return Ok(ColumnarValue::Array(Arc::new(NullArray::new(1)))); } @@ -133,6 +178,10 @@ impl ScalarUDFImpl for Range { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } make_udf_expr_and_func!( @@ -142,6 +191,32 @@ make_udf_expr_and_func!( "create a list of values in the range between start and stop, include upper bound", gen_series_udf ); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Similar to the range function, but it includes the upper bound.", + syntax_example = "generate_series(start, stop, step)", + sql_example = r#"```sql +> select generate_series(1,3); ++------------------------------------+ +| generate_series(Int64(1),Int64(3)) | ++------------------------------------+ +| [1, 2, 3] | ++------------------------------------+ +```"#, + argument( + name = "start", + description = "Start of the series. Ints, timestamps, dates or string types that can be coerced to Date32 are supported." + ), + argument( + name = "end", + description = "End of the series (included). Type must be the same as start." + ), + argument( + name = "step", + description = "Increase by step (can not be 0). Steps less than a day are supported only for timestamp ranges." + ) +)] #[derive(Debug)] pub(super) struct GenSeries { signature: Signature, @@ -196,15 +271,18 @@ impl ScalarUDFImpl for GenSeries { if arg_types.iter().any(|t| t.is_null()) { Ok(Null) } else { - Ok(List(Arc::new(Field::new( - "item", + Ok(List(Arc::new(Field::new_list_field( arg_types[0].clone(), true, )))) } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.iter().any(|arg| arg.data_type().is_null()) { return Ok(ColumnarValue::Array(Arc::new(NullArray::new(1)))); } @@ -226,6 +304,10 @@ impl ScalarUDFImpl for GenSeries { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Generates an array of integers from start to stop with a given step. @@ -297,7 +379,7 @@ pub(super) fn gen_range_inner( }; } let arr = Arc::new(ListArray::try_new( - Arc::new(Field::new("item", Int64, true)), + Arc::new(Field::new_list_field(Int64, true)), OffsetBuffer::new(offsets.into()), Arc::new(Int64Array::from(values)), Some(NullBuffer::new(valid.finish())), @@ -366,10 +448,18 @@ fn gen_range_date(args: &[ArrayRef], include_upper_bound: bool) -> Result Result select array_remove([1, 2, 2, 3, 2, 1, 4], 2); ++----------------------------------------------+ +| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | ++----------------------------------------------+ +| [1, 2, 3, 2, 1, 4] | ++----------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. 
Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "element", + description = "Element to be removed from the array." + ) +)] #[derive(Debug)] -pub(super) struct ArrayRemove { +pub struct ArrayRemove { signature: Signature, aliases: Vec, } +impl Default for ArrayRemove { + fn default() -> Self { + Self::new() + } +} + impl ArrayRemove { pub fn new() -> Self { Self { @@ -71,13 +101,21 @@ impl ScalarUDFImpl for ArrayRemove { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_remove_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } make_udf_expr_and_func!( @@ -88,6 +126,28 @@ make_udf_expr_and_func!( array_remove_n_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Removes the first `max` elements from the array equal to the given value.", + syntax_example = "array_remove_n(array, element, max))", + sql_example = r#"```sql +> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2); ++---------------------------------------------------------+ +| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) | ++---------------------------------------------------------+ +| [1, 3, 2, 1, 4] | ++---------------------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "element", + description = "Element to be removed from the array." + ), + argument(name = "max", description = "Number of first occurrences to remove.") +)] #[derive(Debug)] pub(super) struct ArrayRemoveN { signature: Signature, @@ -120,13 +180,21 @@ impl ScalarUDFImpl for ArrayRemoveN { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_remove_n_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } make_udf_expr_and_func!( @@ -137,6 +205,27 @@ make_udf_expr_and_func!( array_remove_all_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Removes all elements from the array equal to the given value.", + syntax_example = "array_remove_all(array, element)", + sql_example = r#"```sql +> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2); ++--------------------------------------------------+ +| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) | ++--------------------------------------------------+ +| [1, 3, 1, 4] | ++--------------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "element", + description = "Element to be removed from the array." 
+ ) +)] #[derive(Debug)] pub(super) struct ArrayRemoveAll { signature: Signature, @@ -169,13 +258,21 @@ impl ScalarUDFImpl for ArrayRemoveAll { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_remove_all_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Array_remove SQL function @@ -313,7 +410,7 @@ fn general_remove( }; Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", data_type, true)), + Arc::new(Field::new_list_field(data_type, true)), OffsetBuffer::new(offsets.into()), values, list_array.nulls().cloned(), diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs index 7ed913da3f2a0..da0aa5f12fde2 100644 --- a/datafusion/functions-nested/src/repeat.rs +++ b/datafusion/functions-nested/src/repeat.rs @@ -29,7 +29,10 @@ use arrow_schema::DataType::{LargeList, List}; use arrow_schema::{DataType, Field}; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -40,12 +43,46 @@ make_udf_expr_and_func!( "returns an array containing element `count` times.", // doc array_repeat_udf // internal function name ); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns an array containing element `count` times.", + syntax_example = "array_repeat(element, count)", + sql_example = r#"```sql +> select array_repeat(1, 3); ++---------------------------------+ +| array_repeat(Int64(1),Int64(3)) | ++---------------------------------+ +| [1, 1, 1] | ++---------------------------------+ +> select array_repeat([1, 2], 2); ++------------------------------------+ +| array_repeat(List([1,2]),Int64(2)) | ++------------------------------------+ +| [[1, 2], [1, 2]] | ++------------------------------------+ +```"#, + argument( + name = "element", + description = "Element expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "count", + description = "Value of how many times to repeat the element." 
+ ) +)] #[derive(Debug)] -pub(super) struct ArrayRepeat { +pub struct ArrayRepeat { signature: Signature, aliases: Vec, } +impl Default for ArrayRepeat { + fn default() -> Self { + Self::new() + } +} + impl ArrayRepeat { pub fn new() -> Self { Self { @@ -69,20 +106,27 @@ impl ScalarUDFImpl for ArrayRepeat { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(List(Arc::new(Field::new( - "item", + Ok(List(Arc::new(Field::new_list_field( arg_types[0].clone(), true, )))) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_repeat_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Array_repeat SQL function @@ -156,7 +200,7 @@ fn general_repeat( let values = compute::concat(&new_values)?; Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", data_type.to_owned(), true)), + Arc::new(Field::new_list_field(data_type.to_owned(), true)), OffsetBuffer::from_lengths(count_vec), values, None, @@ -207,7 +251,7 @@ fn general_list_repeat( let repeated_array = arrow_array::make_array(data); let list_arr = GenericListArray::::try_new( - Arc::new(Field::new("item", value_type.clone(), true)), + Arc::new(Field::new_list_field(value_type.clone(), true)), OffsetBuffer::::from_lengths(vec![original_data.len(); count]), repeated_array, None, @@ -224,7 +268,7 @@ fn general_list_repeat( let values = compute::concat(&new_values)?; Ok(Arc::new(ListArray::try_new( - Arc::new(Field::new("item", data_type.to_owned(), true)), + Arc::new(Field::new_list_field(data_type.to_owned(), true)), OffsetBuffer::::from_lengths(lengths), values, None, diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index 46a2e078aa4cd..0d3db07c647f6 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -27,7 +27,10 @@ use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; use arrow_schema::Field; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use crate::utils::compare_element_to_list; use crate::utils::make_scalar_function; @@ -55,12 +58,37 @@ make_udf_expr_and_func!(ArrayReplaceAll, array_replace_all_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Replaces the first occurrence of the specified element with another specified element.", + syntax_example = "array_replace(array, from, to)", + sql_example = r#"```sql +> select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5); ++--------------------------------------------------------+ +| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) | ++--------------------------------------------------------+ +| [1, 5, 2, 3, 2, 1, 4] | ++--------------------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ), + argument(name = "from", description = "Initial element."), + argument(name = "to", description = "Final element.") +)] #[derive(Debug)] -pub(super) struct ArrayReplace { +pub struct ArrayReplace { signature: Signature, aliases: Vec, } +impl Default for ArrayReplace { + fn default() -> Self { + Self::new() + } +} + impl ArrayReplace { pub fn new() -> Self { Self { @@ -87,15 +115,43 @@ impl ScalarUDFImpl for ArrayReplace { Ok(args[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_replace_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } +#[user_doc( + doc_section(label = "Array Functions"), + description = "Replaces the first `max` occurrences of the specified element with another specified element.", + syntax_example = "array_replace_n(array, from, to, max)", + sql_example = r#"```sql +> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); ++-------------------------------------------------------------------+ +| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | ++-------------------------------------------------------------------+ +| [1, 5, 5, 3, 2, 1, 4] | ++-------------------------------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument(name = "from", description = "Initial element."), + argument(name = "to", description = "Final element."), + argument(name = "max", description = "Number of first occurrences to replace.") +)] #[derive(Debug)] pub(super) struct ArrayReplaceN { signature: Signature, @@ -128,15 +184,42 @@ impl ScalarUDFImpl for ArrayReplaceN { Ok(args[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_replace_n_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } +#[user_doc( + doc_section(label = "Array Functions"), + description = "Replaces all occurrences of the specified element with another specified element.", + syntax_example = "array_replace_all(array, from, to)", + sql_example = r#"```sql +> select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5); ++------------------------------------------------------------+ +| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) | ++------------------------------------------------------------+ +| [1, 5, 5, 3, 5, 1, 4] | ++------------------------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ), + argument(name = "from", description = "Initial element."), + argument(name = "to", description = "Final element.") +)] #[derive(Debug)] pub(super) struct ArrayReplaceAll { signature: Signature, @@ -169,13 +252,21 @@ impl ScalarUDFImpl for ArrayReplaceAll { Ok(args[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_replace_all_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences @@ -282,7 +373,7 @@ fn general_replace( let data = mutable.freeze(); Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", list_array.value_type(), true)), + Arc::new(Field::new_list_field(list_array.value_type(), true)), OffsetBuffer::::new(offsets.into()), arrow_array::make_array(data), Some(NullBuffer::new(valid.finish())), diff --git a/datafusion/functions-nested/src/resize.rs b/datafusion/functions-nested/src/resize.rs index 83c545a26eb24..a2b95debd2063 100644 --- a/datafusion/functions-nested/src/resize.rs +++ b/datafusion/functions-nested/src/resize.rs @@ -19,13 +19,18 @@ use crate::utils::make_scalar_function; use arrow::array::{Capacities, MutableArrayData}; -use arrow_array::{ArrayRef, GenericListArray, Int64Array, OffsetSizeTrait}; -use arrow_buffer::{ArrowNativeType, OffsetBuffer}; +use arrow_array::{ + new_null_array, Array, ArrayRef, GenericListArray, Int64Array, OffsetSizeTrait, +}; +use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, OffsetBuffer}; use arrow_schema::DataType::{FixedSizeList, LargeList, List}; use arrow_schema::{DataType, FieldRef}; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -37,12 +42,40 @@ make_udf_expr_and_func!( array_resize_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set.", + syntax_example = "array_resize(array, size, value)", + sql_example = r#"```sql +> select array_resize([1, 2, 3], 5, 0); ++-------------------------------------+ +| array_resize(List([1,2,3],5,0)) | ++-------------------------------------+ +| [1, 2, 3, 0, 0] | ++-------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument(name = "size", description = "New size of given array."), + argument( + name = "value", + description = "Defines new elements' value or empty if value is not set." 
+ ) +)] #[derive(Debug)] -pub(super) struct ArrayResize { +pub struct ArrayResize { signature: Signature, aliases: Vec, } +impl Default for ArrayResize { + fn default() -> Self { + Self::new() + } +} + impl ArrayResize { pub fn new() -> Self { Self { @@ -75,13 +108,21 @@ impl ScalarUDFImpl for ArrayResize { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_resize_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// array_resize SQL function @@ -90,6 +131,23 @@ pub(crate) fn array_resize_inner(arg: &[ArrayRef]) -> Result { return exec_err!("array_resize needs two or three arguments"); } + let array = &arg[0]; + + // Checks if entire array is null + if array.null_count() == array.len() { + let return_type = match array.data_type() { + List(field) => List(Arc::clone(field)), + LargeList(field) => LargeList(Arc::clone(field)), + _ => { + return exec_err!( + "array_resize does not support type '{:?}'.", + array.data_type() + ) + } + }; + return Ok(new_null_array(&return_type, array.len())); + } + let new_len = as_int64_array(&arg[1])?; let new_element = if arg.len() == 3 { Some(Arc::clone(&arg[2])) @@ -140,7 +198,16 @@ fn general_list_resize>( capacity, ); + let mut null_builder = BooleanBufferBuilder::new(array.len()); + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + if array.is_null(row_index) { + null_builder.append(false); + offsets.push(offsets[row_index]); + continue; + } + null_builder.append(true); + let count = count_array.value(row_index).to_usize().ok_or_else(|| { internal_datafusion_err!("array_resize: failed to convert size to usize") })?; @@ -167,10 +234,12 @@ fn general_list_resize>( } let data = mutable.freeze(); + let null_bit_buffer: NullBuffer = null_builder.finish().into(); + Ok(Arc::new(GenericListArray::::try_new( Arc::clone(field), OffsetBuffer::::new(offsets.into()), arrow_array::make_array(data), - None, + Some(null_bit_buffer), )?)) } diff --git a/datafusion/functions-nested/src/reverse.rs b/datafusion/functions-nested/src/reverse.rs index 581caf5daf2b8..8538ba5cac12b 100644 --- a/datafusion/functions-nested/src/reverse.rs +++ b/datafusion/functions-nested/src/reverse.rs @@ -25,7 +25,10 @@ use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, FieldRef}; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -37,12 +40,35 @@ make_udf_expr_and_func!( array_reverse_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the array with the order of the elements reversed.", + syntax_example = "array_reverse(array)", + sql_example = r#"```sql +> select array_reverse([1, 2, 3, 4]); ++------------------------------------------------------------+ +| array_reverse(List([1, 2, 3, 4])) | ++------------------------------------------------------------+ +| [4, 3, 2, 1] | ++------------------------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. 
Can be a constant, column, or function, and any combination of array operators." + ) +)] #[derive(Debug)] -pub(super) struct ArrayReverse { +pub struct ArrayReverse { signature: Signature, aliases: Vec, } +impl Default for ArrayReverse { + fn default() -> Self { + Self::new() + } +} + impl ArrayReverse { pub fn new() -> Self { Self { @@ -69,13 +95,21 @@ impl ScalarUDFImpl for ArrayReverse { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_reverse_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// array_reverse SQL function diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index 1de9c264ddc2c..079e0e3ed214a 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -27,7 +27,10 @@ use arrow::row::{RowConverter, SortField}; use arrow_schema::DataType::{FixedSizeList, LargeList, List, Null}; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, internal_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use itertools::Itertools; use std::any::Any; use std::collections::HashSet; @@ -59,12 +62,45 @@ make_udf_expr_and_func!( array_distinct_udf ); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates.", + syntax_example = "array_union(array1, array2)", + sql_example = r#"```sql +> select array_union([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6] | ++----------------------------------------------------+ +> select array_union([1, 2, 3, 4], [5, 6, 7, 8]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 7, 8]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6, 7, 8] | ++----------------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
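// --- Illustration only (not part of this patch): every UDF touched here swaps
// `invoke(&self, args)` for `invoke_batch(&self, args, number_rows)`, where the extra
// argument carries the row count of the batch being evaluated. A minimal UDF following
// the same shape might look like the hypothetical sketch below; exact trait signatures
// may differ between DataFusion versions.
use std::any::Any;
use arrow_schema::DataType;
use datafusion_common::Result;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

#[derive(Debug)]
struct MyIdentity {
    signature: Signature,
}

impl MyIdentity {
    fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for MyIdentity {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "my_identity"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }
    // `_number_rows` is unused here, just as in the functions above that delegate to
    // `make_scalar_function`-wrapped kernels.
    fn invoke_batch(
        &self,
        args: &[ColumnarValue],
        _number_rows: usize,
    ) -> Result<ColumnarValue> {
        Ok(args[0].clone())
    }
}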
+ ) +)] #[derive(Debug)] -pub(super) struct ArrayUnion { +pub struct ArrayUnion { signature: Signature, aliases: Vec, } +impl Default for ArrayUnion { + fn default() -> Self { + Self::new() + } +} + impl ArrayUnion { pub fn new() -> Self { Self { @@ -95,15 +131,50 @@ impl ScalarUDFImpl for ArrayUnion { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_union_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns an array of elements in the intersection of array1 and array2.", + syntax_example = "array_intersect(array1, array2)", + sql_example = r#"```sql +> select array_intersect([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_intersect([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [3, 4] | ++----------------------------------------------------+ +> select array_intersect([1, 2, 3, 4], [5, 6, 7, 8]); ++----------------------------------------------------+ +| array_intersect([1, 2, 3, 4], [5, 6, 7, 8]); | ++----------------------------------------------------+ +| [] | ++----------------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] #[derive(Debug)] pub(super) struct ArrayIntersect { signature: Signature, @@ -140,15 +211,40 @@ impl ScalarUDFImpl for ArrayIntersect { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_intersect_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns distinct values from the array after removing duplicates.", + syntax_example = "array_distinct(array)", + sql_example = r#"```sql +> select array_distinct([1, 3, 2, 3, 1, 2, 4]); ++---------------------------------+ +| array_distinct(List([1,2,3,4])) | ++---------------------------------+ +| [1, 2, 3, 4] | ++---------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
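// --- Illustration only (not part of this patch): the `#[user_doc(...)]` attribute used
// throughout these files is assumed to generate an inherent `doc()` method returning
// `Option<&Documentation>`, which is why each `documentation()` override added above is
// a one-line forward to `self.doc()`. A hypothetical minimal use of the macro:
use datafusion_expr::{Documentation, Signature};
use datafusion_macros::user_doc;

#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Example entry rendered into the generated user documentation.",
    syntax_example = "my_func(array)"
)]
#[derive(Debug)]
struct MyFunc {
    signature: Signature,
}

impl MyFunc {
    // Mirrors the `documentation()` bodies added in this patch.
    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}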
+ ) +)] #[derive(Debug)] pub(super) struct ArrayDistinct { signature: Signature, @@ -179,13 +275,10 @@ impl ScalarUDFImpl for ArrayDistinct { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new( - "item", + List(field) | FixedSizeList(field, _) => Ok(List(Arc::new( + Field::new_list_field(field.data_type().clone(), true), + ))), + LargeList(field) => Ok(LargeList(Arc::new(Field::new_list_field( field.data_type().clone(), true, )))), @@ -195,13 +288,21 @@ impl ScalarUDFImpl for ArrayDistinct { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_distinct_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// array_distinct SQL function @@ -252,10 +353,10 @@ fn generic_set_lists( set_op: SetOp, ) -> Result { if matches!(l.value_type(), Null) { - let field = Arc::new(Field::new("item", r.value_type(), true)); + let field = Arc::new(Field::new_list_field(r.value_type(), true)); return general_array_distinct::(r, &field); } else if matches!(r.value_type(), Null) { - let field = Arc::new(Field::new("item", l.value_type(), true)); + let field = Arc::new(Field::new_list_field(l.value_type(), true)); return general_array_distinct::(l, &field); } @@ -412,17 +513,25 @@ fn general_array_distinct( array: &GenericListArray, field: &FieldRef, ) -> Result { + if array.is_empty() { + return Ok(Arc::new(array.clone()) as ArrayRef); + } let dt = array.value_type(); let mut offsets = Vec::with_capacity(array.len()); offsets.push(OffsetSize::usize_as(0)); let mut new_arrays = Vec::with_capacity(array.len()); let converter = RowConverter::new(vec![SortField::new(dt)])?; // distinct for each list in ListArray - for arr in array.iter().flatten() { + for arr in array.iter() { + let last_offset: OffsetSize = offsets.last().copied().unwrap(); + let Some(arr) = arr else { + // Add same offset for null + offsets.push(last_offset); + continue; + }; let values = converter.convert_columns(&[arr])?; // sort elements in list and remove duplicates let rows = values.iter().sorted().dedup().collect::>(); - let last_offset: OffsetSize = offsets.last().copied().unwrap(); offsets.push(last_offset + OffsetSize::usize_as(rows.len())); let arrays = converter.convert_rows(rows)?; let array = match arrays.first() { @@ -433,6 +542,9 @@ fn general_array_distinct( }; new_arrays.push(array); } + if new_arrays.is_empty() { + return Ok(Arc::new(array.clone()) as ArrayRef); + } let offsets = OffsetBuffer::new(offsets.into()); let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); let values = compute::concat(&new_arrays_ref)?; @@ -440,6 +552,7 @@ fn general_array_distinct( Arc::clone(field), offsets, values, - None, + // Keep the list nulls + array.nulls().cloned(), )?)) } diff --git a/datafusion/functions-nested/src/sort.rs b/datafusion/functions-nested/src/sort.rs index 9c1ae507636c9..3f2ad57cbe860 100644 --- a/datafusion/functions-nested/src/sort.rs +++ b/datafusion/functions-nested/src/sort.rs @@ -25,7 +25,10 @@ use arrow_schema::DataType::{FixedSizeList, LargeList, List}; use arrow_schema::{DataType, Field, SortOptions}; use datafusion_common::cast::{as_list_array, as_string_array}; use 
datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -37,12 +40,50 @@ make_udf_expr_and_func!( array_sort_udf ); +/// Implementation of `array_sort` function +/// +/// `array_sort` sorts the elements of an array +/// +/// # Example +/// +/// `array_sort([3, 1, 2])` returns `[1, 2, 3]` +#[user_doc( + doc_section(label = "Array Functions"), + description = "Sort array.", + syntax_example = "array_sort(array, desc, nulls_first)", + sql_example = r#"```sql +> select array_sort([3, 1, 2]); ++-----------------------------+ +| array_sort(List([3,1,2])) | ++-----------------------------+ +| [1, 2, 3] | ++-----------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "desc", + description = "Whether to sort in descending order(`ASC` or `DESC`)." + ), + argument( + name = "nulls_first", + description = "Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`)." + ) +)] #[derive(Debug)] -pub(super) struct ArraySort { +pub struct ArraySort { signature: Signature, aliases: Vec, } +impl Default for ArraySort { + fn default() -> Self { + Self::new() + } +} + impl ArraySort { pub fn new() -> Self { Self { @@ -67,13 +108,10 @@ impl ScalarUDFImpl for ArraySort { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new( - "item", + List(field) | FixedSizeList(field, _) => Ok(List(Arc::new( + Field::new_list_field(field.data_type().clone(), true), + ))), + LargeList(field) => Ok(LargeList(Arc::new(Field::new_list_field( field.data_type().clone(), true, )))), @@ -83,13 +121,21 @@ impl ScalarUDFImpl for ArraySort { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_sort_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Array_sort SQL function @@ -152,7 +198,7 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { .collect::>(); let list_arr = ListArray::new( - Arc::new(Field::new("item", data_type, true)), + Arc::new(Field::new_list_field(data_type, true)), OffsetBuffer::from_lengths(array_lengths), Arc::new(compute::concat(elements.as_slice())?), Some(NullBuffer::new(buffer)), diff --git a/datafusion/functions-nested/src/string.rs b/datafusion/functions-nested/src/string.rs index 2dc0a55e69519..bbe1dc2a01e9a 100644 --- a/datafusion/functions-nested/src/string.rs +++ b/datafusion/functions-nested/src/string.rs @@ -26,47 +26,34 @@ use arrow::array::{ use arrow::datatypes::{DataType, Field}; use datafusion_expr::TypeSignature; -use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; +use datafusion_common::{ + internal_datafusion_err, not_impl_err, plan_err, DataFusionError, Result, +}; -use std::any::{type_name, Any}; +use std::any::Any; -use crate::utils::{downcast_arg, make_scalar_function}; +use crate::utils::make_scalar_function; use arrow::compute::cast; +use 
arrow_array::builder::{ArrayBuilder, LargeStringBuilder, StringViewBuilder}; +use arrow_array::cast::AsArray; +use arrow_array::{GenericStringArray, StringArrayType, StringViewArray}; use arrow_schema::DataType::{ - Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, -}; -use datafusion_common::cast::{ - as_generic_string_array, as_large_list_array, as_list_array, as_string_array, + Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, Utf8View, }; +use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::exec_err; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::{downcast_arg, downcast_named_arg}; +use datafusion_macros::user_doc; use std::sync::Arc; -macro_rules! to_string { - ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - for x in arr { - match x { - Some(x) => { - $ARG.push_str(&x.to_string()); - $ARG.push_str($DELIMITER); - } - None => { - if $WITH_NULL_STRING { - $ARG.push_str($NULL_STRING); - $ARG.push_str($DELIMITER); - } - } - } - } - Ok($ARG) - }}; -} - macro_rules! call_array_function { ($DATATYPE:expr, false) => { match $DATATYPE { DataType::Utf8 => array_function!(StringArray), + DataType::Utf8View => array_function!(StringViewArray), DataType::LargeUtf8 => array_function!(LargeStringArray), DataType::Boolean => array_function!(BooleanArray), DataType::Float32 => array_function!(Float32Array), @@ -86,6 +73,7 @@ macro_rules! call_array_function { match $DATATYPE { DataType::List(_) => array_function!(ListArray), DataType::Utf8 => array_function!(StringArray), + DataType::Utf8View => array_function!(StringViewArray), DataType::LargeUtf8 => array_function!(LargeStringArray), DataType::Boolean => array_function!(BooleanArray), DataType::Float32 => array_function!(Float32Array), @@ -103,6 +91,27 @@ macro_rules! call_array_function { }}; } +macro_rules! to_string { + ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ + let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); + for x in arr { + match x { + Some(x) => { + $ARG.push_str(&x.to_string()); + $ARG.push_str($DELIMITER); + } + None => { + if $WITH_NULL_STRING { + $ARG.push_str($NULL_STRING); + $ARG.push_str($DELIMITER); + } + } + } + } + Ok($ARG) + }}; +} + // Create static instances of ScalarUDFs for each function make_udf_expr_and_func!( ArrayToString, @@ -111,12 +120,41 @@ make_udf_expr_and_func!( "converts each element to its text representation.", // doc array_to_string_udf // internal function name ); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Converts each element to its text representation.", + syntax_example = "array_to_string(array, delimiter[, null_string])", + sql_example = r#"```sql +> select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ','); ++----------------------------------------------------+ +| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) | ++----------------------------------------------------+ +| 1,2,3,4,5,6,7,8 | ++----------------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
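// --- Illustration only (not part of this patch): the Utf8View support added in this
// file repeatedly uses the same dispatch shape -- match on the argument's `DataType`
// and obtain `Option<&str>` items regardless of the physical string encoding, via the
// `AsArray` cast helpers imported above. A standalone sketch of that pattern:
use arrow_array::cast::AsArray;
use arrow_array::ArrayRef;
use arrow_schema::DataType;
use datafusion_common::{exec_err, Result};

fn collect_strings(arg: &ArrayRef) -> Result<Vec<Option<String>>> {
    Ok(match arg.data_type() {
        DataType::Utf8 => arg
            .as_string::<i32>()
            .iter()
            .map(|v| v.map(str::to_string))
            .collect(),
        DataType::LargeUtf8 => arg
            .as_string::<i64>()
            .iter()
            .map(|v| v.map(str::to_string))
            .collect(),
        DataType::Utf8View => arg
            .as_string_view()
            .iter()
            .map(|v| v.map(str::to_string))
            .collect(),
        other => return exec_err!("unsupported string type {other:?}"),
    })
}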
+ ), + argument(name = "delimiter", description = "Array element separator."), + argument( + name = "null_string", + description = "Optional. String to replace null values in the array. If not provided, nulls will be handled by default behavior." + ) +)] #[derive(Debug)] -pub(super) struct ArrayToString { +pub struct ArrayToString { signature: Signature, aliases: Vec, } +impl Default for ArrayToString { + fn default() -> Self { + Self::new() + } +} + impl ArrayToString { pub fn new() -> Self { Self { @@ -152,13 +190,21 @@ impl ScalarUDFImpl for ArrayToString { }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(array_to_string_inner)(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } make_udf_expr_and_func!( @@ -168,6 +214,32 @@ make_udf_expr_and_func!( "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`", // doc string_to_array_udf // internal function name ); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Splits a string into an array of substrings based on a delimiter. Any substrings matching the optional `null_str` argument are replaced with NULL.", + syntax_example = "string_to_array(str, delimiter[, null_str])", + sql_example = r#"```sql +> select string_to_array('abc##def', '##'); ++-----------------------------------+ +| string_to_array(Utf8('abc##def')) | ++-----------------------------------+ +| ['abc', 'def'] | ++-----------------------------------+ +> select string_to_array('abc def', ' ', 'def'); ++---------------------------------------------+ +| string_to_array(Utf8('abc def'), Utf8(' '), Utf8('def')) | ++---------------------------------------------+ +| ['abc', NULL] | ++---------------------------------------------+ +```"#, + argument(name = "str", description = "String expression to split."), + argument(name = "delimiter", description = "Delimiter string to split on."), + argument( + name = "null_str", + description = "Substring values to be replaced with `NULL`." + ) +)] #[derive(Debug)] pub(super) struct StringToArray { signature: Signature, @@ -178,10 +250,7 @@ impl StringToArray { pub fn new() -> Self { Self { signature: Signature::one_of( - vec![ - TypeSignature::Uniform(2, vec![Utf8, LargeUtf8]), - TypeSignature::Uniform(3, vec![Utf8, LargeUtf8]), - ], + vec![TypeSignature::String(2), TypeSignature::String(3)], Volatility::Immutable, ), aliases: vec![String::from("string_to_list")], @@ -204,23 +273,27 @@ impl ScalarUDFImpl for StringToArray { fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(match arg_types[0] { - Utf8 | LargeUtf8 => { - List(Arc::new(Field::new("item", arg_types[0].clone(), true))) + Utf8 | Utf8View | LargeUtf8 => { + List(Arc::new(Field::new_list_field(arg_types[0].clone(), true))) } _ => { return plan_err!( - "The string_to_array function can only accept Utf8 or LargeUtf8." + "The string_to_array function can only accept Utf8, Utf8View or LargeUtf8." 
); } }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { - Utf8 => make_scalar_function(string_to_array_inner::)(args), + Utf8 | Utf8View => make_scalar_function(string_to_array_inner::)(args), LargeUtf8 => make_scalar_function(string_to_array_inner::)(args), other => { - exec_err!("unsupported type for string_to_array function as {other}") + exec_err!("unsupported type for string_to_array function as {other:?}") } } } @@ -228,6 +301,10 @@ impl ScalarUDFImpl for StringToArray { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Array_to_string SQL function @@ -238,13 +315,22 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { let arr = &args[0]; - let delimiters = as_string_array(&args[1])?; - let delimiters: Vec> = delimiters.iter().collect(); + let delimiters: Vec> = match args[1].data_type() { + Utf8 => args[1].as_string::().iter().collect(), + Utf8View => args[1].as_string_view().iter().collect(), + LargeUtf8 => args[1].as_string::().iter().collect(), + other => return exec_err!("unsupported type for second argument to array_to_string function as {other:?}") + }; let mut null_string = String::from(""); let mut with_null_string = false; if args.len() == 3 { - null_string = as_string_array(&args[2])?.value(0).to_string(); + null_string = match args[2].data_type() { + Utf8 => args[2].as_string::().value(0).to_string(), + Utf8View => args[2].as_string_view().value(0).to_string(), + LargeUtf8 => args[2].as_string::().value(0).to_string(), + other => return exec_err!("unsupported type for second argument to array_to_string function as {other:?}") + }; with_null_string = true; } @@ -404,20 +490,173 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { /// String_to_array SQL function /// Splits string at occurrences of delimiter and returns an array of parts /// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' -pub fn string_to_array_inner(args: &[ArrayRef]) -> Result { +fn string_to_array_inner(args: &[ArrayRef]) -> Result { if args.len() < 2 || args.len() > 3 { return exec_err!("string_to_array expects two or three arguments"); } - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - let mut list_builder = ListBuilder::new(StringBuilder::with_capacity( - string_array.len(), - string_array.get_buffer_memory_size(), - )); + match args[0].data_type() { + Utf8 => { + let string_array = args[0].as_string::(); + let builder = StringBuilder::with_capacity(string_array.len(), string_array.get_buffer_memory_size()); + string_to_array_inner_2::<&GenericStringArray, StringBuilder>(args, string_array, builder) + } + Utf8View => { + let string_array = args[0].as_string_view(); + let builder = StringViewBuilder::with_capacity(string_array.len()); + string_to_array_inner_2::<&StringViewArray, StringViewBuilder>(args, string_array, builder) + } + LargeUtf8 => { + let string_array = args[0].as_string::(); + let builder = LargeStringBuilder::with_capacity(string_array.len(), string_array.get_buffer_memory_size()); + string_to_array_inner_2::<&GenericStringArray, LargeStringBuilder>(args, string_array, builder) + } + other => exec_err!("unsupported type for first argument to string_to_array function as {other:?}") + } +} + +fn string_to_array_inner_2<'a, StringArrType, StringBuilderType>( 
+ args: &'a [ArrayRef], + string_array: StringArrType, + string_builder: StringBuilderType, +) -> Result +where + StringArrType: StringArrayType<'a>, + StringBuilderType: StringArrayBuilderType, +{ + match args[1].data_type() { + Utf8 => { + let delimiter_array = args[1].as_string::(); + if args.len() == 2 { + string_to_array_impl::< + StringArrType, + &GenericStringArray, + &StringViewArray, + StringBuilderType, + >(string_array, delimiter_array, None, string_builder) + } else { + string_to_array_inner_3::, + StringBuilderType>(args, string_array, delimiter_array, string_builder) + } + } + Utf8View => { + let delimiter_array = args[1].as_string_view(); + + if args.len() == 2 { + string_to_array_impl::< + StringArrType, + &StringViewArray, + &StringViewArray, + StringBuilderType, + >(string_array, delimiter_array, None, string_builder) + } else { + string_to_array_inner_3::(args, string_array, delimiter_array, string_builder) + } + } + LargeUtf8 => { + let delimiter_array = args[1].as_string::(); + if args.len() == 2 { + string_to_array_impl::< + StringArrType, + &GenericStringArray, + &StringViewArray, + StringBuilderType, + >(string_array, delimiter_array, None, string_builder) + } else { + string_to_array_inner_3::, + StringBuilderType>(args, string_array, delimiter_array, string_builder) + } + } + other => exec_err!("unsupported type for second argument to string_to_array function as {other:?}") + } +} + +fn string_to_array_inner_3<'a, StringArrType, DelimiterArrType, StringBuilderType>( + args: &'a [ArrayRef], + string_array: StringArrType, + delimiter_array: DelimiterArrType, + string_builder: StringBuilderType, +) -> Result +where + StringArrType: StringArrayType<'a>, + DelimiterArrType: StringArrayType<'a>, + StringBuilderType: StringArrayBuilderType, +{ + match args[2].data_type() { + Utf8 => { + let null_type_array = Some(args[2].as_string::()); + string_to_array_impl::< + StringArrType, + DelimiterArrType, + &GenericStringArray, + StringBuilderType, + >( + string_array, + delimiter_array, + null_type_array, + string_builder, + ) + } + Utf8View => { + let null_type_array = Some(args[2].as_string_view()); + string_to_array_impl::< + StringArrType, + DelimiterArrType, + &StringViewArray, + StringBuilderType, + >( + string_array, + delimiter_array, + null_type_array, + string_builder, + ) + } + LargeUtf8 => { + let null_type_array = Some(args[2].as_string::()); + string_to_array_impl::< + StringArrType, + DelimiterArrType, + &GenericStringArray, + StringBuilderType, + >( + string_array, + delimiter_array, + null_type_array, + string_builder, + ) + } + other => { + exec_err!("unsupported type for string_to_array function as {other:?}") + } + } +} - match args.len() { - 2 => { +fn string_to_array_impl< + 'a, + StringArrType, + DelimiterArrType, + NullValueArrType, + StringBuilderType, +>( + string_array: StringArrType, + delimiter_array: DelimiterArrType, + null_value_array: Option, + string_builder: StringBuilderType, +) -> Result +where + StringArrType: StringArrayType<'a>, + DelimiterArrType: StringArrayType<'a>, + NullValueArrType: StringArrayType<'a>, + StringBuilderType: StringArrayBuilderType, +{ + let mut list_builder = ListBuilder::new(string_builder); + + match null_value_array { + None => { string_array.iter().zip(delimiter_array.iter()).for_each( |(string, delimiter)| { match (string, delimiter) { @@ -433,63 +672,90 @@ pub fn string_to_array_inner(args: &[ArrayRef]) -> Result { string.chars().map(|c| c.to_string()).for_each(|c| { - list_builder.values().append_value(c); + 
list_builder.values().append_value(c.as_str()); }); list_builder.append(true); } _ => list_builder.append(false), // null value } }, - ); + ) } - - 3 => { - let null_value_array = as_generic_string_array::(&args[2])?; - string_array - .iter() - .zip(delimiter_array.iter()) - .zip(null_value_array.iter()) - .for_each(|((string, delimiter), null_value)| { - match (string, delimiter) { - (Some(string), Some("")) => { - if Some(string) == null_value { + Some(null_value_array) => string_array + .iter() + .zip(delimiter_array.iter()) + .zip(null_value_array.iter()) + .for_each(|((string, delimiter), null_value)| { + match (string, delimiter) { + (Some(string), Some("")) => { + if Some(string) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(string); + } + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + if Some(s) == null_value { list_builder.values().append_null(); } else { - list_builder.values().append_value(string); + list_builder.values().append_value(s); } - list_builder.append(true); - } - (Some(string), Some(delimiter)) => { - string.split(delimiter).for_each(|s| { - if Some(s) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(s); - } - }); - list_builder.append(true); - } - (Some(string), None) => { - string.chars().map(|c| c.to_string()).for_each(|c| { - if Some(c.as_str()) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(c); - } - }); - list_builder.append(true); - } - _ => list_builder.append(false), // null value + }); + list_builder.append(true); } - }); - } - _ => { - return exec_err!( - "Expect string_to_array function to take two or three parameters" - ) - } - } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + if Some(c.as_str()) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(c.as_str()); + } + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }), + }; let list_array = list_builder.finish(); Ok(Arc::new(list_array) as ArrayRef) } + +trait StringArrayBuilderType: ArrayBuilder { + fn append_value(&mut self, val: &str); + + fn append_null(&mut self); +} + +impl StringArrayBuilderType for StringBuilder { + fn append_value(&mut self, val: &str) { + StringBuilder::append_value(self, val); + } + + fn append_null(&mut self) { + StringBuilder::append_null(self); + } +} + +impl StringArrayBuilderType for StringViewBuilder { + fn append_value(&mut self, val: &str) { + StringViewBuilder::append_value(self, val) + } + + fn append_null(&mut self) { + StringViewBuilder::append_null(self) + } +} + +impl StringArrayBuilderType for LargeStringBuilder { + fn append_value(&mut self, val: &str) { + LargeStringBuilder::append_value(self, val); + } + + fn append_null(&mut self) { + LargeStringBuilder::append_null(self); + } +} diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index 5382346a0a811..621e296f6308c 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -28,23 +28,12 @@ use arrow_array::{ use arrow_buffer::OffsetBuffer; use arrow_schema::{Field, Fields}; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; +use datafusion_common::{ + exec_err, 
internal_datafusion_err, internal_err, plan_err, Result, ScalarValue, +}; -use core::any::type_name; -use datafusion_common::DataFusionError; use datafusion_expr::ColumnarValue; - -macro_rules! downcast_arg { - ($ARG:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast to {}", - type_name::<$ARRAY_TYPE>() - )) - })? - }}; -} -pub(crate) use downcast_arg; +use datafusion_functions::{downcast_arg, downcast_named_arg}; pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { let data_type = args[0].data_type(); @@ -114,7 +103,7 @@ pub(crate) fn align_array_dimensions( let offsets = OffsetBuffer::::from_lengths(array_lengths); aligned_array = Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", data_type, true)), + Arc::new(Field::new_list_field(data_type, true)), offsets, aligned_array, None, @@ -274,27 +263,27 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { mod tests { use super::*; use arrow::datatypes::Int64Type; - use datafusion_common::utils::array_into_list_array_nullable; + use datafusion_common::utils::SingleRowListArrayBuilder; /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt` #[test] fn test_align_array_dimensions() { - let array1d_1 = + let array1d_1: ArrayRef = Arc::new(ListArray::from_iter_primitive::(vec![ Some(vec![Some(1), Some(2), Some(3)]), Some(vec![Some(4), Some(5)]), ])); - let array1d_2 = + let array1d_2: ArrayRef = Arc::new(ListArray::from_iter_primitive::(vec![ Some(vec![Some(6), Some(7), Some(8)]), ])); - let array2d_1 = Arc::new(array_into_list_array_nullable( - Arc::clone(&array1d_1) as ArrayRef - )) as ArrayRef; - let array2d_2 = Arc::new(array_into_list_array_nullable( - Arc::clone(&array1d_2) as ArrayRef - )) as ArrayRef; + let array2d_1: ArrayRef = Arc::new( + SingleRowListArrayBuilder::new(Arc::clone(&array1d_1)).build_list_array(), + ); + let array2d_2 = Arc::new( + SingleRowListArrayBuilder::new(Arc::clone(&array1d_2)).build_list_array(), + ); let res = align_array_dimensions::(vec![ array1d_1.to_owned(), @@ -310,10 +299,11 @@ mod tests { expected_dim ); - let array3d_1 = Arc::new(array_into_list_array_nullable(array2d_1)) as ArrayRef; - let array3d_2 = array_into_list_array_nullable(array2d_2.to_owned()); - let res = - align_array_dimensions::(vec![array1d_1, Arc::new(array3d_2)]).unwrap(); + let array3d_1: ArrayRef = + Arc::new(SingleRowListArrayBuilder::new(array2d_1).build_list_array()); + let array3d_2: ArrayRef = + Arc::new(SingleRowListArrayBuilder::new(array2d_2).build_list_array()); + let res = align_array_dimensions::(vec![array1d_1, array3d_2]).unwrap(); let expected = as_list_array(&array3d_1).unwrap(); let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type()); diff --git a/datafusion/functions-table/Cargo.toml b/datafusion/functions-table/Cargo.toml new file mode 100644 index 0000000000000..f722d698f3d38 --- /dev/null +++ b/datafusion/functions-table/Cargo.toml @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-functions-table" +description = "Traits and types for logical plans and expressions for DataFusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_functions_table" +path = "src/lib.rs" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arrow = { workspace = true } +async-trait = { workspace = true } +datafusion-catalog = { workspace = true } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-physical-plan = { workspace = true } +parking_lot = { workspace = true } +paste = "1.0.14" + +[dev-dependencies] +arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/functions-table/LICENSE.txt b/datafusion/functions-table/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/functions-table/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/functions-table/NOTICE.txt b/datafusion/functions-table/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/functions-table/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/core/tests/sqllogictests/MOVED.md b/datafusion/functions-table/README.md similarity index 73% rename from datafusion/core/tests/sqllogictests/MOVED.md rename to datafusion/functions-table/README.md index dd70dab9d11f2..c4e7a5aff9993 100644 --- a/datafusion/core/tests/sqllogictests/MOVED.md +++ b/datafusion/functions-table/README.md @@ -17,4 +17,10 @@ under the License. --> -The SQL Logic Test code has moved to `datafusion/sqllogictest` +# DataFusion Table Function Library + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate contains table functions that can be used in DataFusion queries. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs new file mode 100644 index 0000000000000..d087df3451e02 --- /dev/null +++ b/datafusion/functions-table/src/generate_series.rs @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Int64Array; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use async_trait::async_trait; +use datafusion_catalog::Session; +use datafusion_catalog::TableFunctionImpl; +use datafusion_catalog::TableProvider; +use datafusion_common::{plan_err, Result, ScalarValue}; +use datafusion_expr::{Expr, TableType}; +use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec}; +use datafusion_physical_plan::ExecutionPlan; +use parking_lot::RwLock; +use std::fmt; +use std::sync::Arc; + +/// Indicates the arguments used for generating a series. +#[derive(Debug, Clone)] +enum GenSeriesArgs { + /// ContainsNull signifies that at least one argument(start, end, step) was null, thus no series will be generated. + ContainsNull, + /// AllNotNullArgs holds the start, end, and step values for generating the series when all arguments are not null. + AllNotNullArgs { start: i64, end: i64, step: i64 }, +} + +/// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step +#[derive(Debug, Clone)] +struct GenerateSeriesTable { + schema: SchemaRef, + args: GenSeriesArgs, +} + +/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step +#[derive(Debug, Clone)] +struct GenerateSeriesState { + schema: SchemaRef, + start: i64, // Kept for display + end: i64, + step: i64, + batch_size: usize, + + /// Tracks current position when generating table + current: i64, +} + +impl GenerateSeriesState { + fn reach_end(&self, val: i64) -> bool { + if self.step > 0 { + return val > self.end; + } + + val < self.end + } +} + +/// Detail to display for 'Explain' plan +impl fmt::Display for GenerateSeriesState { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "generate_series: start={}, end={}, batch_size={}", + self.start, self.end, self.batch_size + ) + } +} + +impl LazyBatchGenerator for GenerateSeriesState { + fn generate_next_batch(&mut self) -> Result> { + let mut buf = Vec::with_capacity(self.batch_size); + while buf.len() < self.batch_size && !self.reach_end(self.current) { + buf.push(self.current); + self.current += self.step; + } + let array = Int64Array::from(buf); + + if array.is_empty() { + return Ok(None); + } + + let batch = RecordBatch::try_new(self.schema.clone(), vec![Arc::new(array)])?; + + Ok(Some(batch)) + } +} + +#[async_trait] +impl TableProvider for GenerateSeriesTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let batch_size = state.config_options().execution.batch_size; + + let state = match self.args { + // if args have null, then return 0 row + GenSeriesArgs::ContainsNull => GenerateSeriesState { + schema: self.schema.clone(), + start: 0, + end: 0, + step: 1, + current: 1, + batch_size, + }, + 
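// --- Illustration only (not part of this patch): `GenerateSeriesState` emits the series
// in `batch_size`-sized chunks through `LazyBatchGenerator`, so the full range is never
// materialized. The state struct is private to this module, so the sketch below only
// illustrates the control flow; the schema matches the one built in
// `GenerateSeriesFunc::call` further down.
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_physical_plan::memory::LazyBatchGenerator;

fn drive_generator_sketch() {
    let schema = Arc::new(Schema::new(vec![Field::new("value", DataType::Int64, false)]));
    let mut state = GenerateSeriesState {
        schema,
        start: 1,
        end: 5,
        step: 2,
        batch_size: 2,
        current: 1,
    };
    // Expected output: a batch of [1, 3], then [5], then `Ok(None)` once `current`
    // passes `end`.
    while let Ok(Some(batch)) = state.generate_next_batch() {
        println!("{:?}", batch.column(0));
    }
}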
GenSeriesArgs::AllNotNullArgs { start, end, step } => GenerateSeriesState { + schema: self.schema.clone(), + start, + end, + step, + current: start, + batch_size, + }, + }; + + Ok(Arc::new(LazyMemoryExec::try_new( + self.schema.clone(), + vec![Arc::new(RwLock::new(state))], + )?)) + } +} + +#[derive(Debug)] +pub struct GenerateSeriesFunc {} + +impl TableFunctionImpl for GenerateSeriesFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + if exprs.is_empty() || exprs.len() > 3 { + return plan_err!("generate_series function requires 1 to 3 arguments"); + } + + let mut normalize_args = Vec::new(); + for expr in exprs { + match expr { + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Null => {} + ScalarValue::Int64(Some(n)) => normalize_args.push(*n), + _ => return plan_err!("First argument must be an integer literal"), + }, + _ => return plan_err!("First argument must be an integer literal"), + }; + } + + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + false, + )])); + + if normalize_args.len() != exprs.len() { + // contain null + return Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::ContainsNull, + })); + } + + let (start, end, step) = match &normalize_args[..] { + [end] => (0, *end, 1), + [start, end] => (*start, *end, 1), + [start, end, step] => (*start, *end, *step), + _ => { + return plan_err!("generate_series function requires 1 to 3 arguments"); + } + }; + + if start > end && step > 0 { + return plan_err!("start is bigger than end, but increment is positive: cannot generate infinite series"); + } + + if start < end && step < 0 { + return plan_err!("start is smaller than end, but increment is negative: cannot generate infinite series"); + } + + if step == 0 { + return plan_err!("step cannot be zero"); + } + + Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::AllNotNullArgs { start, end, step }, + })) + } +} diff --git a/datafusion/functions-table/src/lib.rs b/datafusion/functions-table/src/lib.rs new file mode 100644 index 0000000000000..f5436f7bf8a6d --- /dev/null +++ b/datafusion/functions-table/src/lib.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod generate_series; + +use datafusion_catalog::TableFunction; +use std::sync::Arc; + +/// Returns all default table functions +pub fn all_default_table_functions() -> Vec> { + vec![generate_series()] +} + +/// Creates a singleton instance of a table function +/// - `$module`: A struct implementing `TableFunctionImpl` to create the function from +/// - `$name`: The name to give to the created function +/// +/// This is used to ensure creating the list of `TableFunction` only happens once. +#[macro_export] +macro_rules! 
create_udtf_function { + ($module:path, $name:expr) => { + paste::paste! { + pub fn [<$name:lower>]() -> Arc { + static INSTANCE: std::sync::LazyLock> = + std::sync::LazyLock::new(|| { + std::sync::Arc::new(TableFunction::new( + $name.to_string(), + Arc::new($module {}), + )) + }); + std::sync::Arc::clone(&INSTANCE) + } + } + }; +} + +create_udtf_function!(generate_series::GenerateSeriesFunc, "generate_series"); diff --git a/datafusion/functions-window-common/Cargo.toml b/datafusion/functions-window-common/Cargo.toml index 98b6f8c6dba5f..b5df212b7d2ad 100644 --- a/datafusion/functions-window-common/Cargo.toml +++ b/datafusion/functions-window-common/Cargo.toml @@ -39,3 +39,4 @@ path = "src/lib.rs" [dependencies] datafusion-common = { workspace = true } +datafusion-physical-expr-common = { workspace = true } diff --git a/datafusion/functions-window-common/LICENSE.txt b/datafusion/functions-window-common/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/functions-window-common/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/functions-window-common/NOTICE.txt b/datafusion/functions-window-common/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/functions-window-common/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/functions-window-common/src/expr.rs b/datafusion/functions-window-common/src/expr.rs new file mode 100644 index 0000000000000..1d99fe7acf152 --- /dev/null +++ b/datafusion/functions-window-common/src/expr.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::arrow::datatypes::DataType; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +/// Arguments passed to user-defined window function +#[derive(Debug, Default)] +pub struct ExpressionArgs<'a> { + /// The expressions passed as arguments to the user-defined window + /// function. + input_exprs: &'a [Arc], + /// The corresponding data types of expressions passed as arguments + /// to the user-defined window function. + input_types: &'a [DataType], +} + +impl<'a> ExpressionArgs<'a> { + /// Create an instance of [`ExpressionArgs`]. + /// + /// # Arguments + /// + /// * `input_exprs` - The expressions passed as arguments + /// to the user-defined window function. + /// * `input_types` - The data types corresponding to the + /// arguments to the user-defined window function. + /// + pub fn new( + input_exprs: &'a [Arc], + input_types: &'a [DataType], + ) -> Self { + Self { + input_exprs, + input_types, + } + } + + /// Returns the expressions passed as arguments to the user-defined + /// window function. 
+ pub fn input_exprs(&self) -> &'a [Arc] { + self.input_exprs + } + + /// Returns the [`DataType`]s corresponding to the input expressions + /// to the user-defined window function. + pub fn input_types(&self) -> &'a [DataType] { + self.input_types + } +} diff --git a/datafusion/functions-window-common/src/lib.rs b/datafusion/functions-window-common/src/lib.rs index 2e4bcbbc83b9a..da8d096da5621 100644 --- a/datafusion/functions-window-common/src/lib.rs +++ b/datafusion/functions-window-common/src/lib.rs @@ -18,4 +18,6 @@ //! Common user-defined window functionality for [DataFusion] //! //! [DataFusion]: +pub mod expr; pub mod field; +pub mod partition; diff --git a/datafusion/functions-window-common/src/partition.rs b/datafusion/functions-window-common/src/partition.rs new file mode 100644 index 0000000000000..64786d2fe7c70 --- /dev/null +++ b/datafusion/functions-window-common/src/partition.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::arrow::datatypes::DataType; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +/// Arguments passed to created user-defined window function state +/// during physical execution. +#[derive(Debug, Default)] +pub struct PartitionEvaluatorArgs<'a> { + /// The expressions passed as arguments to the user-defined window + /// function. + input_exprs: &'a [Arc], + /// The corresponding data types of expressions passed as arguments + /// to the user-defined window function. + input_types: &'a [DataType], + /// Set to `true` if the user-defined window function is reversed. + is_reversed: bool, + /// Set to `true` if `IGNORE NULLS` is specified. + ignore_nulls: bool, +} + +impl<'a> PartitionEvaluatorArgs<'a> { + /// Create an instance of [`PartitionEvaluatorArgs`]. + /// + /// # Arguments + /// + /// * `input_exprs` - The expressions passed as arguments + /// to the user-defined window function. + /// * `input_types` - The data types corresponding to the + /// arguments to the user-defined window function. + /// * `is_reversed` - Set to `true` if and only if the user-defined + /// window function is reversible and is reversed. + /// * `ignore_nulls` - Set to `true` when `IGNORE NULLS` is + /// specified. + /// + pub fn new( + input_exprs: &'a [Arc], + input_types: &'a [DataType], + is_reversed: bool, + ignore_nulls: bool, + ) -> Self { + Self { + input_exprs, + input_types, + is_reversed, + ignore_nulls, + } + } + + /// Returns the expressions passed as arguments to the user-defined + /// window function. + pub fn input_exprs(&self) -> &'a [Arc] { + self.input_exprs + } + + /// Returns the [`DataType`]s corresponding to the input expressions + /// to the user-defined window function. 
+ pub fn input_types(&self) -> &'a [DataType] { + self.input_types + } + + /// Returns `true` when the user-defined window function is + /// reversed, otherwise returns `false`. + pub fn is_reversed(&self) -> bool { + self.is_reversed + } + + /// Returns `true` when `IGNORE NULLS` is specified, otherwise + /// returns `false`. + pub fn ignore_nulls(&self) -> bool { + self.ignore_nulls + } +} diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml index 952e5720c77c1..fc1bc51bcc665 100644 --- a/datafusion/functions-window/Cargo.toml +++ b/datafusion/functions-window/Cargo.toml @@ -39,8 +39,11 @@ path = "src/lib.rs" [dependencies] datafusion-common = { workspace = true } +datafusion-doc = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions-window-common = { workspace = true } +datafusion-macros = { workspace = true } +datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.15" diff --git a/datafusion/functions-window/LICENSE.txt b/datafusion/functions-window/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/functions-window/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/functions-window/NOTICE.txt b/datafusion/functions-window/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/functions-window/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/functions-window/src/cume_dist.rs similarity index 52% rename from datafusion/physical-expr/src/window/cume_dist.rs rename to datafusion/functions-window/src/cume_dist.rs index 9720187ea83dd..d777f7932b0e6 100644 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ b/datafusion/functions-window/src/cume_dist.rs @@ -15,65 +15,91 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expression for `cume_dist` that can evaluated -//! at runtime during query execution - -use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; -use arrow::array::ArrayRef; -use arrow::array::Float64Array; -use arrow::datatypes::{DataType, Field}; +//! `cume_dist` window function implementation + +use datafusion_common::arrow::array::{ArrayRef, Float64Array}; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; use datafusion_common::Result; -use datafusion_expr::PartitionEvaluator; +use datafusion_expr::{ + Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, +}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_macros::user_doc; +use field::WindowUDFFieldArgs; use std::any::Any; +use std::fmt::Debug; use std::iter; use std::ops::Range; use std::sync::Arc; +define_udwf_and_expr!( + CumeDist, + cume_dist, + "Calculates the cumulative distribution of a value in a group of values." 
+); + /// CumeDist calculates the cume_dist in the window function with order by +#[user_doc( + doc_section(label = "Ranking Functions"), + description = "Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows).", + syntax_example = "cume_dist()" +)] #[derive(Debug)] pub struct CumeDist { - name: String, - /// Output data type - data_type: DataType, + signature: Signature, } -/// Create a cume_dist window function -pub fn cume_dist(name: String, data_type: &DataType) -> CumeDist { - CumeDist { - name, - data_type: data_type.clone(), +impl CumeDist { + pub fn new() -> Self { + Self { + signature: Signature::nullary(Volatility::Immutable), + } } } -impl BuiltInWindowFunctionExpr for CumeDist { +impl Default for CumeDist { + fn default() -> Self { + Self::new() + } +} + +impl WindowUDFImpl for CumeDist { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - let nullable = false; - Ok(Field::new(self.name(), self.data_type.clone(), nullable)) + fn name(&self) -> &str { + "cume_dist" } - fn expressions(&self) -> Vec> { - vec![] + fn signature(&self) -> &Signature { + &self.signature } - fn name(&self) -> &str { - &self.name + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + Ok(Box::::default()) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, false)) } - fn create_evaluator(&self) -> Result> { - Ok(Box::new(CumeDistEvaluator {})) + fn documentation(&self) -> Option<&Documentation> { + self.doc() } } -#[derive(Debug)] +#[derive(Debug, Default)] pub(crate) struct CumeDistEvaluator; impl PartitionEvaluator for CumeDistEvaluator { + /// Computes the cumulative distribution for all rows in the partition fn evaluate_all_with_rank( &self, num_rows: usize, @@ -105,40 +131,29 @@ mod tests { use super::*; use datafusion_common::cast::as_float64_array; - fn test_i32_result( - expr: &CumeDist, + fn test_f64_result( num_rows: usize, ranks: Vec>, expected: Vec, ) -> Result<()> { - let result = expr - .create_evaluator()? 
- .evaluate_all_with_rank(num_rows, &ranks)?; + let evaluator = CumeDistEvaluator; + let result = evaluator.evaluate_all_with_rank(num_rows, &ranks)?; let result = as_float64_array(&result)?; - let result = result.values(); - assert_eq!(expected, *result); + let result = result.values().to_vec(); + assert_eq!(expected, result); Ok(()) } #[test] #[allow(clippy::single_range_in_vec_init)] fn test_cume_dist() -> Result<()> { - let r = cume_dist("arr".into(), &DataType::Float64); - - let expected = vec![0.0; 0]; - test_i32_result(&r, 0, vec![], expected)?; - - let expected = vec![1.0; 1]; - test_i32_result(&r, 1, vec![0..1], expected)?; + test_f64_result(0, vec![], vec![])?; - let expected = vec![1.0; 2]; - test_i32_result(&r, 2, vec![0..2], expected)?; + test_f64_result(1, vec![0..1], vec![1.0])?; - let expected = vec![0.5, 0.5, 1.0, 1.0]; - test_i32_result(&r, 4, vec![0..2, 2..4], expected)?; + test_f64_result(2, vec![0..2], vec![1.0, 1.0])?; - let expected = vec![0.25, 0.5, 0.75, 1.0]; - test_i32_result(&r, 4, vec![0..1, 1..2, 2..3, 3..4], expected)?; + test_f64_result(4, vec![0..2, 2..4], vec![0.5, 0.5, 1.0, 1.0])?; Ok(()) } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs similarity index 55% rename from datafusion/physical-expr/src/window/lead_lag.rs rename to datafusion/functions-window/src/lead_lag.rs index 1656b7c3033a4..fc55151a0ba72 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -15,125 +15,317 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expression for `lead` and `lag` that can evaluated -//! at runtime during query execution -use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; +//! `lead` and `lag` window function implementations + +use crate::utils::{get_scalar_value_from_args, get_signed_integer}; +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::PartitionEvaluator; +use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; +use datafusion_expr::{ + Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature, + Volatility, WindowUDFImpl, +}; +use datafusion_functions_window_common::expr::ExpressionArgs; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; use std::cmp::min; use std::collections::VecDeque; use std::ops::{Neg, Range}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; + +get_or_init_udwf!( + Lag, + lag, + "Returns the row value that precedes the current row by a specified \ + offset within partition. If no such row exists, then returns the \ + default value.", + WindowShift::lag +); +get_or_init_udwf!( + Lead, + lead, + "Returns the value from a row that follows the current row by a \ + specified offset within the partition. 
If no such row exists, then \ + returns the default value.", + WindowShift::lead +); + +/// Create an expression to represent the `lag` window function +/// +/// returns value evaluated at the row that is offset rows before the current row within the partition; +/// if there is no such row, instead return default (which must be of the same type as value). +/// Both offset and default are evaluated with respect to the current row. +/// If omitted, offset defaults to 1 and default to null +pub fn lag( + arg: datafusion_expr::Expr, + shift_offset: Option, + default_value: Option, +) -> datafusion_expr::Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + + lag_udwf().call(vec![arg, shift_offset_lit, default_lit]) +} + +/// Create an expression to represent the `lead` window function +/// +/// returns value evaluated at the row that is offset rows after the current row within the partition; +/// if there is no such row, instead return default (which must be of the same type as value). +/// Both offset and default are evaluated with respect to the current row. +/// If omitted, offset defaults to 1 and default to null +pub fn lead( + arg: datafusion_expr::Expr, + shift_offset: Option, + default_value: Option, +) -> datafusion_expr::Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + + lead_udwf().call(vec![arg, shift_offset_lit, default_lit]) +} + +#[derive(Debug)] +enum WindowShiftKind { + Lag, + Lead, +} + +impl WindowShiftKind { + fn name(&self) -> &'static str { + match self { + WindowShiftKind::Lag => "lag", + WindowShiftKind::Lead => "lead", + } + } + + /// In [`WindowShiftEvaluator`] a positive offset is used to signal + /// computation of `lag()`. So here we negate the input offset + /// value when computing `lead()`. + fn shift_offset(&self, value: Option) -> i64 { + match self { + WindowShiftKind::Lag => value.unwrap_or(1), + WindowShiftKind::Lead => value.map(|v| v.neg()).unwrap_or(-1), + } + } +} /// window shift expression #[derive(Debug)] pub struct WindowShift { - name: String, - /// Output data type - data_type: DataType, - shift_offset: i64, - expr: Arc, - default_value: ScalarValue, - ignore_nulls: bool, + signature: Signature, + kind: WindowShiftKind, } impl WindowShift { - /// Get shift_offset of window shift expression - pub fn get_shift_offset(&self) -> i64 { - self.shift_offset + fn new(kind: WindowShiftKind) -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ), + kind, + } } - /// Get the default_value for window shift expression. 
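A usage sketch for the two fluent helpers defined above; the column name `amount` and the `main` wrapper are illustrative only, while the signatures are the ones introduced in this file.

use datafusion_common::ScalarValue;
use datafusion_expr::col;
use datafusion_functions_window::expr_fn::{lag, lead};

fn main() {
    // lag(amount, 2, 0): the value two rows back, falling back to 0 at the
    // start of the partition.
    let lag_two = lag(col("amount"), Some(2), Some(ScalarValue::from(0i64)));
    // lead(amount): offset defaults to 1 and the default value to NULL.
    let lead_one = lead(col("amount"), None, None);
    println!("{lag_two}, {lead_one}");
}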
- pub fn get_default_value(&self) -> ScalarValue { - self.default_value.clone() + pub fn lag() -> Self { + Self::new(WindowShiftKind::Lag) } -} -/// lead() window function -pub fn lead( - name: String, - data_type: DataType, - expr: Arc, - shift_offset: Option, - default_value: ScalarValue, - ignore_nulls: bool, -) -> WindowShift { - WindowShift { - name, - data_type, - shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1), - expr, - default_value, - ignore_nulls, + pub fn lead() -> Self { + Self::new(WindowShiftKind::Lead) } } -/// lag() window function -pub fn lag( - name: String, - data_type: DataType, - expr: Arc, - shift_offset: Option, - default_value: ScalarValue, - ignore_nulls: bool, -) -> WindowShift { - WindowShift { - name, - data_type, - shift_offset: shift_offset.unwrap_or(1), - expr, - default_value, - ignore_nulls, - } +static LAG_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_lag_doc() -> &'static Documentation { + LAG_DOCUMENTATION.get_or_init(|| { + Documentation::builder(DOC_SECTION_ANALYTICAL, "Returns value evaluated at the row that is offset rows before the \ + current row within the partition; if there is no such row, instead return default \ + (which must be of the same type as value).", "lag(expression, offset, default)") + .with_argument("expression", "Expression to operate on") + .with_argument("offset", "Integer. Specifies how many rows back \ + the value of expression should be retrieved. Defaults to 1.") + .with_argument("default", "The default value if the offset is \ + not within the partition. Must be of the same type as expression.") + .build() + }) +} + +static LEAD_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_lead_doc() -> &'static Documentation { + LEAD_DOCUMENTATION.get_or_init(|| { + Documentation::builder(DOC_SECTION_ANALYTICAL, + "Returns value evaluated at the row that is offset rows after the \ + current row within the partition; if there is no such row, instead return default \ + (which must be of the same type as value).", + "lead(expression, offset, default)") + .with_argument("expression", "Expression to operate on") + .with_argument("offset", "Integer. Specifies how many rows \ + forward the value of expression should be retrieved. Defaults to 1.") + .with_argument("default", "The default value if the offset is \ + not within the partition. Must be of the same type as expression.") + .build() + }) } -impl BuiltInWindowFunctionExpr for WindowShift { - /// Return a reference to Any that can be used for downcasting +impl WindowUDFImpl for WindowShift { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - let nullable = true; - Ok(Field::new(&self.name, self.data_type.clone(), nullable)) + fn name(&self) -> &str { + self.kind.name() } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] + fn signature(&self) -> &Signature { + &self.signature } - fn name(&self) -> &str { - &self.name + /// Handles the case where `NULL` expression is passed as an + /// argument to `lead`/`lag`. The type is refined depending + /// on the default value argument. + /// + /// For more details see: + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + parse_expr(expr_args.input_exprs(), expr_args.input_types()) + .into_iter() + .collect::>() } - fn create_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + let shift_offset = + get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)? 
+ .map(get_signed_integer) + .map_or(Ok(None), |v| v.map(Some)) + .map(|n| self.kind.shift_offset(n)) + .map(|offset| { + if partition_evaluator_args.is_reversed() { + -offset + } else { + offset + } + })?; + let default_value = parse_default_value( + partition_evaluator_args.input_exprs(), + partition_evaluator_args.input_types(), + )?; + Ok(Box::new(WindowShiftEvaluator { - shift_offset: self.shift_offset, - default_value: self.default_value.clone(), - ignore_nulls: self.ignore_nulls, + shift_offset, + default_value, + ignore_nulls: partition_evaluator_args.ignore_nulls(), non_null_offsets: VecDeque::new(), })) } - fn reverse_expr(&self) -> Option> { - Some(Arc::new(Self { - name: self.name.clone(), - data_type: self.data_type.clone(), - shift_offset: -self.shift_offset, - expr: Arc::clone(&self.expr), - default_value: self.default_value.clone(), - ignore_nulls: self.ignore_nulls, - })) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let return_type = parse_expr_type(field_args.input_types())?; + + Ok(Field::new(field_args.name(), return_type, true)) + } + + fn reverse_expr(&self) -> ReversedUDWF { + match self.kind { + WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()), + WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()), + } + } + + fn documentation(&self) -> Option<&Documentation> { + match self.kind { + WindowShiftKind::Lag => Some(get_lag_doc()), + WindowShiftKind::Lead => Some(get_lead_doc()), + } + } +} + +/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to +/// refine it by matching it with the type of the default value. +/// +/// For e.g. in `lead(NULL, 1, false)` the generic `ScalarValue::Null` +/// is refined into `ScalarValue::Boolean(None)`. Only the type is +/// refined, the expression value remains `NULL`. +/// +/// When the window function is evaluated with `NULL` expression +/// this guarantees that the type matches with that of the default +/// value. +/// +/// For more details see: +fn parse_expr( + input_exprs: &[Arc], + input_types: &[DataType], +) -> Result> { + assert!(!input_exprs.is_empty()); + assert!(!input_types.is_empty()); + + let expr = Arc::clone(input_exprs.first().unwrap()); + let expr_type = input_types.first().unwrap(); + + // Handles the most common case where NULL is unexpected + if !expr_type.is_null() { + return Ok(expr); + } + + let default_value = get_scalar_value_from_args(input_exprs, 2)?; + default_value.map_or(Ok(expr), |value| { + ScalarValue::try_from(&value.data_type()).map(|v| { + Arc::new(datafusion_physical_expr::expressions::Literal::from(v)) + as Arc + }) + }) +} + +/// Returns the data type of the default value(if provided) when the +/// expression is `NULL`. +/// +/// Otherwise, returns the expression type unchanged. +fn parse_expr_type(input_types: &[DataType]) -> Result { + assert!(!input_types.is_empty()); + let expr_type = input_types.first().unwrap_or(&DataType::Null); + + // Handles the most common case where NULL is unexpected + if !expr_type.is_null() { + return Ok(expr_type.clone()); } + + let default_value_type = input_types.get(2).unwrap_or(&DataType::Null); + Ok(default_value_type.clone()) +} + +/// Handles type coercion and null value refinement for default value +/// argument depending on the data type of the input expression. 
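To make the refinement rule concrete, here is a small standalone mirror of the private `parse_expr_type` helper above (`refined_return_type` is a hypothetical name used only for this sketch): when the first argument is `NULL`, the return type is taken from the default-value argument.

use datafusion_common::arrow::datatypes::DataType;

// Mirrors the rule above: a NULL expression takes the type of the third
// (default value) argument; otherwise the expression type is kept as-is.
fn refined_return_type(input_types: &[DataType]) -> DataType {
    let expr_type = input_types.first().cloned().unwrap_or(DataType::Null);
    if !expr_type.is_null() {
        return expr_type;
    }
    input_types.get(2).cloned().unwrap_or(DataType::Null)
}

fn main() {
    // lead(NULL, 1, false): the NULL expression is refined to Boolean.
    let types = [DataType::Null, DataType::Int64, DataType::Boolean];
    assert_eq!(refined_return_type(&types), DataType::Boolean);
    // lead(c1, 1): a non-NULL expression keeps its own type.
    assert_eq!(refined_return_type(&[DataType::Utf8]), DataType::Utf8);
}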
+fn parse_default_value( + input_exprs: &[Arc], + input_types: &[DataType], +) -> Result { + let expr_type = parse_expr_type(input_types)?; + let unparsed = get_scalar_value_from_args(input_exprs, 2)?; + + unparsed + .filter(|v| !v.data_type().is_null()) + .map(|v| v.cast_to(&expr_type)) + .unwrap_or(ScalarValue::try_from(expr_type)) } #[derive(Debug)] -pub(crate) struct WindowShiftEvaluator { +struct WindowShiftEvaluator { shift_offset: i64, default_value: ScalarValue, ignore_nulls: bool, @@ -205,7 +397,7 @@ fn shift_with_default_value( offset: i64, default_value: &ScalarValue, ) -> Result { - use arrow::compute::concat; + use datafusion_common::arrow::compute::concat; let value_len = array.len() as i64; if offset == 0 { @@ -402,19 +594,22 @@ impl PartitionEvaluator for WindowShiftEvaluator { #[cfg(test)] mod tests { use super::*; - use crate::expressions::Column; - use arrow::{array::*, datatypes::*}; + use arrow::array::*; use datafusion_common::cast::as_int32_array; - - fn test_i32_result(expr: WindowShift, expected: Int32Array) -> Result<()> { + use datafusion_physical_expr::expressions::{Column, Literal}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + + fn test_i32_result( + expr: WindowShift, + partition_evaluator_args: PartitionEvaluatorArgs, + expected: Int32Array, + ) -> Result<()> { let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])); let values = vec![arr]; - let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); - let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; - let values = expr.evaluate_args(&batch)?; + let num_rows = values.len(); let result = expr - .create_evaluator()? - .evaluate_all(&values, batch.num_rows())?; + .partition_evaluator(partition_evaluator_args)? 
+ .evaluate_all(&values, num_rows)?; let result = as_int32_array(&result)?; assert_eq!(expected, *result); Ok(()) @@ -466,16 +661,12 @@ mod tests { } #[test] - fn lead_lag_window_shift() -> Result<()> { + fn test_lead_window_shift() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + test_i32_result( - lead( - "lead".to_owned(), - DataType::Int32, - Arc::new(Column::new("c3", 0)), - None, - ScalarValue::Null.cast_to(&DataType::Int32)?, - false, - ), + WindowShift::lead(), + PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), [ Some(-2), Some(3), @@ -488,17 +679,16 @@ mod tests { ] .iter() .collect::(), - )?; + ) + } + + #[test] + fn test_lag_window_shift() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; test_i32_result( - lag( - "lead".to_owned(), - DataType::Int32, - Arc::new(Column::new("c3", 0)), - None, - ScalarValue::Null.cast_to(&DataType::Int32)?, - false, - ), + WindowShift::lag(), + PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), [ None, Some(1), @@ -511,17 +701,24 @@ mod tests { ] .iter() .collect::(), - )?; + ) + } + + #[test] + fn test_lag_with_default() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + let shift_offset = + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))) as Arc; + let default_value = Arc::new(Literal::from(ScalarValue::Int32(Some(100)))) + as Arc; + + let input_exprs = &[expr, shift_offset, default_value]; + let input_types: &[DataType] = + &[DataType::Int32, DataType::Int32, DataType::Int32]; test_i32_result( - lag( - "lead".to_owned(), - DataType::Int32, - Arc::new(Column::new("c3", 0)), - None, - ScalarValue::Int32(Some(100)), - false, - ), + WindowShift::lag(), + PartitionEvaluatorArgs::new(input_exprs, input_types, false, false), [ Some(100), Some(1), @@ -534,7 +731,6 @@ mod tests { ] .iter() .collect::(), - )?; - Ok(()) + ) } } diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index 6e98bb0914461..9f8e54a0423b8 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -22,6 +22,7 @@ //! //! [DataFusion]: https://crates.io/crates/datafusion //! 
+ use std::sync::Arc; use log::debug; @@ -31,16 +32,41 @@ use datafusion_expr::WindowUDF; #[macro_use] pub mod macros; + +pub mod cume_dist; +pub mod lead_lag; +pub mod nth_value; +pub mod ntile; +pub mod rank; pub mod row_number; +mod utils; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::cume_dist::cume_dist; + pub use super::lead_lag::lag; + pub use super::lead_lag::lead; + pub use super::nth_value::{first_value, last_value, nth_value}; + pub use super::ntile::ntile; + pub use super::rank::{dense_rank, percent_rank, rank}; pub use super::row_number::row_number; } /// Returns all default window functions pub fn all_default_window_functions() -> Vec> { - vec![row_number::row_number_udwf()] + vec![ + cume_dist::cume_dist_udwf(), + row_number::row_number_udwf(), + lead_lag::lead_udwf(), + lead_lag::lag_udwf(), + rank::rank_udwf(), + rank::dense_rank_udwf(), + rank::percent_rank_udwf(), + ntile::ntile_udwf(), + nth_value::first_value_udwf(), + nth_value::last_value_udwf(), + nth_value::nth_value_udwf(), + ] } /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all( diff --git a/datafusion/functions-window/src/macros.rs b/datafusion/functions-window/src/macros.rs index 843d8ecb38cc8..0a86ba6255330 100644 --- a/datafusion/functions-window/src/macros.rs +++ b/datafusion/functions-window/src/macros.rs @@ -45,6 +45,7 @@ /// # /// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; /// # use datafusion_functions_window::get_or_init_udwf; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # /// /// Defines the `simple_udwf()` user-defined window function. /// get_or_init_udwf!( @@ -80,6 +81,7 @@ /// # } /// # fn partition_evaluator( /// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } @@ -97,21 +99,16 @@ macro_rules! get_or_init_udwf { ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr, $CTOR:path) => { paste::paste! { - #[doc = concat!(" Singleton instance of [`", stringify!($OUT_FN_NAME), "`], ensures the user-defined")] - #[doc = concat!(" window function is only created once.")] - #[allow(non_upper_case_globals)] - static []: std::sync::OnceLock> = - std::sync::OnceLock::new(); - #[doc = concat!(" Returns a [`WindowUDF`](datafusion_expr::WindowUDF) for [`", stringify!($OUT_FN_NAME), "`].")] #[doc = ""] #[doc = concat!(" ", $DOC)] pub fn [<$OUT_FN_NAME _udwf>]() -> std::sync::Arc { - [] - .get_or_init(|| { + // Singleton instance of UDWF, ensures it is only created once. + static INSTANCE: std::sync::LazyLock> = + std::sync::LazyLock::new(|| { std::sync::Arc::new(datafusion_expr::WindowUDF::from($CTOR())) - }) - .clone() + }); + std::sync::Arc::clone(&INSTANCE) } } }; @@ -145,6 +142,8 @@ macro_rules! get_or_init_udwf { /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # use datafusion_functions_window::{create_udwf_expr, get_or_init_udwf}; /// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// /// # get_or_init_udwf!( /// # RowNumber, /// # row_number, @@ -193,6 +192,7 @@ macro_rules! get_or_init_udwf { /// # } /// # fn partition_evaluator( /// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } @@ -216,6 +216,7 @@ macro_rules! 
get_or_init_udwf { /// # use datafusion_common::arrow::datatypes::Field; /// # use datafusion_common::ScalarValue; /// # use datafusion_expr::{col, lit}; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # /// # get_or_init_udwf!(Lead, lead, "user-defined window function"); /// # @@ -278,6 +279,7 @@ macro_rules! get_or_init_udwf { /// # } /// # fn partition_evaluator( /// # &self, +/// # partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } @@ -296,7 +298,7 @@ macro_rules! create_udwf_expr { ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr) => { paste::paste! { #[doc = " Create a [`WindowFunction`](datafusion_expr::Expr::WindowFunction) expression for"] - #[doc = concat!(" [`", stringify!($UDWF), "`] user-defined window function.")] + #[doc = concat!(" `", stringify!($UDWF), "` user-defined window function.")] #[doc = ""] #[doc = concat!(" ", $DOC)] pub fn $OUT_FN_NAME() -> datafusion_expr::Expr { @@ -309,7 +311,7 @@ macro_rules! create_udwf_expr { ($UDWF:ident, $OUT_FN_NAME:ident, [$($PARAM:ident),+], $DOC:expr) => { paste::paste! { #[doc = " Create a [`WindowFunction`](datafusion_expr::Expr::WindowFunction) expression for"] - #[doc = concat!(" [`", stringify!($UDWF), "`] user-defined window function.")] + #[doc = concat!(" `", stringify!($UDWF), "` user-defined window function.")] #[doc = ""] #[doc = concat!(" ", $DOC)] pub fn $OUT_FN_NAME( @@ -355,6 +357,7 @@ macro_rules! create_udwf_expr { /// # /// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; /// # use datafusion_functions_window::{define_udwf_and_expr, get_or_init_udwf, create_udwf_expr}; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # /// /// 1. Defines the `simple_udwf()` user-defined window function. /// /// @@ -397,6 +400,7 @@ macro_rules! create_udwf_expr { /// # } /// # fn partition_evaluator( /// # &self, +/// # partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } @@ -415,6 +419,7 @@ macro_rules! create_udwf_expr { /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf}; /// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # /// /// 1. Defines the `row_number_udwf()` user-defined window function. /// /// @@ -459,6 +464,7 @@ macro_rules! create_udwf_expr { /// # } /// # fn partition_evaluator( /// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } @@ -484,6 +490,7 @@ macro_rules! create_udwf_expr { /// # use datafusion_common::arrow::datatypes::Field; /// # use datafusion_common::ScalarValue; /// # use datafusion_expr::{col, lit}; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # /// /// 1. Defines the `lead_udwf()` user-defined window function. /// /// @@ -543,6 +550,7 @@ macro_rules! create_udwf_expr { /// # } /// # fn partition_evaluator( /// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } @@ -570,6 +578,7 @@ macro_rules! 
create_udwf_expr { /// # use datafusion_common::arrow::datatypes::Field; /// # use datafusion_common::ScalarValue; /// # use datafusion_expr::{col, lit}; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # /// /// 1. Defines the `lead_udwf()` user-defined window function. /// /// @@ -630,6 +639,7 @@ macro_rules! create_udwf_expr { /// # } /// # fn partition_evaluator( /// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs new file mode 100644 index 0000000000000..d59357e5f25cf --- /dev/null +++ b/datafusion/functions-window/src/nth_value.rs @@ -0,0 +1,551 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `nth_value` window function implementation + +use crate::utils::{get_scalar_value_from_args, get_signed_integer}; + +use std::any::Any; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::OnceLock; + +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::datatypes::{DataType, Field}; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; +use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; +use datafusion_expr::window_state::WindowAggState; +use datafusion_expr::{ + Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature, + Volatility, WindowUDFImpl, +}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use field::WindowUDFFieldArgs; + +get_or_init_udwf!( + First, + first_value, + "returns the first value in the window frame", + NthValue::first +); +get_or_init_udwf!( + Last, + last_value, + "returns the last value in the window frame", + NthValue::last +); +get_or_init_udwf!( + NthValue, + nth_value, + "returns the nth value in the window frame", + NthValue::nth +); + +/// Create an expression to represent the `first_value` window function +/// +pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr { + first_value_udwf().call(vec![arg]) +} + +/// Create an expression to represent the `last_value` window function +/// +pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr { + last_value_udwf().call(vec![arg]) +} + +/// Create an expression to represent the `nth_value` window function +/// +pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr { + nth_value_udwf().call(vec![arg, n.lit()]) +} + +/// Tag to differentiate special use cases of the NTH_VALUE built-in window function. 
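A usage sketch for the three expression helpers defined above; the column name `price` is illustrative.

use datafusion_expr::col;
use datafusion_functions_window::expr_fn::{first_value, last_value, nth_value};

fn main() {
    // Each call builds an Expr::WindowFunction over the corresponding UDWF;
    // partitioning, ordering and frame clauses are attached later during
    // planning and are omitted here.
    let first = first_value(col("price"));
    let last = last_value(col("price"));
    let third = nth_value(col("price"), 3); // n counts from 1
    println!("{first}, {last}, {third}");
}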
+#[derive(Debug, Copy, Clone)] +pub enum NthValueKind { + First, + Last, + Nth, +} + +impl NthValueKind { + fn name(&self) -> &'static str { + match self { + NthValueKind::First => "first_value", + NthValueKind::Last => "last_value", + NthValueKind::Nth => "nth_value", + } + } +} + +#[derive(Debug)] +pub struct NthValue { + signature: Signature, + kind: NthValueKind, +} + +impl NthValue { + /// Create a new `nth_value` function + pub fn new(kind: NthValueKind) -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Any(0), + TypeSignature::Any(1), + TypeSignature::Any(2), + ], + Volatility::Immutable, + ), + kind, + } + } + + pub fn first() -> Self { + Self::new(NthValueKind::First) + } + + pub fn last() -> Self { + Self::new(NthValueKind::Last) + } + pub fn nth() -> Self { + Self::new(NthValueKind::Nth) + } +} + +static FIRST_VALUE_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_first_value_doc() -> &'static Documentation { + FIRST_VALUE_DOCUMENTATION.get_or_init(|| { + Documentation::builder( + DOC_SECTION_ANALYTICAL, + "Returns value evaluated at the row that is the first row of the window \ + frame.", + "first_value(expression)", + ) + .with_argument("expression", "Expression to operate on") + .build() + }) +} + +static LAST_VALUE_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_last_value_doc() -> &'static Documentation { + LAST_VALUE_DOCUMENTATION.get_or_init(|| { + Documentation::builder( + DOC_SECTION_ANALYTICAL, + "Returns value evaluated at the row that is the last row of the window \ + frame.", + "last_value(expression)", + ) + .with_argument("expression", "Expression to operate on") + .build() + }) +} + +static NTH_VALUE_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_nth_value_doc() -> &'static Documentation { + NTH_VALUE_DOCUMENTATION.get_or_init(|| { + Documentation::builder( + DOC_SECTION_ANALYTICAL, + "Returns value evaluated at the row that is the nth row of the window \ + frame (counting from 1); null if no such row.", + "nth_value(expression, n)", + ) + .with_argument( + "expression", + "The name the column of which nth \ + value to retrieve", + ) + .with_argument("n", "Integer. Specifies the n in nth") + .build() + }) +} + +impl WindowUDFImpl for NthValue { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.kind.name() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + let state = NthValueState { + finalized_result: None, + kind: self.kind, + }; + + if !matches!(self.kind, NthValueKind::Nth) { + return Ok(Box::new(NthValueEvaluator { + state, + ignore_nulls: partition_evaluator_args.ignore_nulls(), + n: 0, + })); + } + + let n = + match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1) + .map_err(|_e| { + exec_datafusion_err!( + "Expected a signed integer literal for the second argument of nth_value") + })? 
+ .map(get_signed_integer) + { + Some(Ok(n)) => { + if partition_evaluator_args.is_reversed() { + -n + } else { + n + } + } + _ => { + return exec_err!( + "Expected a signed integer literal for the second argument of nth_value" + ) + } + }; + + Ok(Box::new(NthValueEvaluator { + state, + ignore_nulls: partition_evaluator_args.ignore_nulls(), + n, + })) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let nullable = true; + let return_type = field_args.input_types().first().unwrap_or(&DataType::Null); + + Ok(Field::new(field_args.name(), return_type.clone(), nullable)) + } + + fn reverse_expr(&self) -> ReversedUDWF { + match self.kind { + NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()), + NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()), + NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()), + } + } + + fn documentation(&self) -> Option<&Documentation> { + match self.kind { + NthValueKind::First => Some(get_first_value_doc()), + NthValueKind::Last => Some(get_last_value_doc()), + NthValueKind::Nth => Some(get_nth_value_doc()), + } + } +} + +#[derive(Debug, Clone)] +pub struct NthValueState { + // In certain cases, we can finalize the result early. Consider this usage: + // ``` + // FIRST_VALUE(increasing_col) OVER window AS my_first_value + // WINDOW (ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) AS window + // ``` + // The result will always be the first entry in the table. We can store such + // early-finalizing results and then just reuse them as necessary. This opens + // opportunities to prune our datasets. + pub finalized_result: Option, + pub kind: NthValueKind, +} + +#[derive(Debug)] +pub(crate) struct NthValueEvaluator { + state: NthValueState, + ignore_nulls: bool, + n: i64, +} + +impl PartitionEvaluator for NthValueEvaluator { + /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING), + /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we + /// can memoize the result. Once result is calculated, it will always stay + /// same. Hence, we do not need to keep past data as we process the entire + /// dataset. + fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> { + let out = &state.out_col; + let size = out.len(); + let mut buffer_size = 1; + // Decide if we arrived at a final result yet: + let (is_prunable, is_reverse_direction) = match self.state.kind { + NthValueKind::First => { + let n_range = + state.window_frame_range.end - state.window_frame_range.start; + (n_range > 0 && size > 0, false) + } + NthValueKind::Last => (true, true), + NthValueKind::Nth => { + let n_range = + state.window_frame_range.end - state.window_frame_range.start; + match self.n.cmp(&0) { + Ordering::Greater => ( + n_range >= (self.n as usize) && size > (self.n as usize), + false, + ), + Ordering::Less => { + let reverse_index = (-self.n) as usize; + buffer_size = reverse_index; + // Negative index represents reverse direction. + (n_range >= reverse_index, true) + } + Ordering::Equal => (false, false), + } + } + }; + // Do not memoize results when nulls are ignored. 
+ if is_prunable && !self.ignore_nulls { + if self.state.finalized_result.is_none() && !is_reverse_direction { + let result = ScalarValue::try_from_array(out, size - 1)?; + self.state.finalized_result = Some(result); + } + state.window_frame_range.start = + state.window_frame_range.end.saturating_sub(buffer_size); + } + Ok(()) + } + + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &Range, + ) -> Result { + if let Some(ref result) = self.state.finalized_result { + Ok(result.clone()) + } else { + // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1. + let arr = &values[0]; + let n_range = range.end - range.start; + if n_range == 0 { + // We produce None if the window is empty. + return ScalarValue::try_from(arr.data_type()); + } + + // If null values exist and need to be ignored, extract the valid indices. + let valid_indices = if self.ignore_nulls { + // Calculate valid indices, inside the window frame boundaries. + let slice = arr.slice(range.start, n_range); + match slice.nulls() { + Some(nulls) => { + let valid_indices = nulls + .valid_indices() + .map(|idx| { + // Add offset `range.start` to valid indices, to point correct index in the original arr. + idx + range.start + }) + .collect::>(); + if valid_indices.is_empty() { + // If all values are null, return directly. + return ScalarValue::try_from(arr.data_type()); + } + Some(valid_indices) + } + None => None, + } + } else { + None + }; + match self.state.kind { + NthValueKind::First => { + if let Some(valid_indices) = &valid_indices { + ScalarValue::try_from_array(arr, valid_indices[0]) + } else { + ScalarValue::try_from_array(arr, range.start) + } + } + NthValueKind::Last => { + if let Some(valid_indices) = &valid_indices { + ScalarValue::try_from_array( + arr, + valid_indices[valid_indices.len() - 1], + ) + } else { + ScalarValue::try_from_array(arr, range.end - 1) + } + } + NthValueKind::Nth => { + match self.n.cmp(&0) { + Ordering::Greater => { + // SQL indices are not 0-based. 
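The index arithmetic in the branches below is easy to get wrong, so here is a standalone sketch of the same rule with nulls ignored (`nth_offset` is a hypothetical helper, not part of this patch): positive `n` counts forward from the frame start (1-based), negative `n` counts back from the frame end, and anything outside the frame maps to NULL.

// Offset of the selected row relative to range.start, or None for NULL.
fn nth_offset(frame_len: usize, n: i64) -> Option<usize> {
    if n > 0 {
        let idx = (n as usize) - 1;
        (idx < frame_len).then_some(idx)
    } else if n < 0 {
        let back = (-n) as usize;
        (back <= frame_len).then(|| frame_len - back)
    } else {
        None // n == 0 yields NULL, as in the evaluator
    }
}

fn main() {
    assert_eq!(nth_offset(5, 1), Some(0));  // behaves like first_value
    assert_eq!(nth_offset(5, -1), Some(4)); // behaves like last_value
    assert_eq!(nth_offset(5, 7), None);     // outside the frame => NULL
}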
+ let index = (self.n as usize) - 1; + if index >= n_range { + // Outside the range, return NULL: + ScalarValue::try_from(arr.data_type()) + } else if let Some(valid_indices) = valid_indices { + if index >= valid_indices.len() { + return ScalarValue::try_from(arr.data_type()); + } + ScalarValue::try_from_array(&arr, valid_indices[index]) + } else { + ScalarValue::try_from_array(arr, range.start + index) + } + } + Ordering::Less => { + let reverse_index = (-self.n) as usize; + if n_range < reverse_index { + // Outside the range, return NULL: + ScalarValue::try_from(arr.data_type()) + } else if let Some(valid_indices) = valid_indices { + if reverse_index > valid_indices.len() { + return ScalarValue::try_from(arr.data_type()); + } + let new_index = + valid_indices[valid_indices.len() - reverse_index]; + ScalarValue::try_from_array(&arr, new_index) + } else { + ScalarValue::try_from_array( + arr, + range.start + n_range - reverse_index, + ) + } + } + Ordering::Equal => ScalarValue::try_from(arr.data_type()), + } + } + } + } + } + + fn supports_bounded_execution(&self) -> bool { + true + } + + fn uses_window_frame(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::*; + use datafusion_common::cast::as_int32_array; + use datafusion_physical_expr::expressions::{Column, Literal}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use std::sync::Arc; + + fn test_i32_result( + expr: NthValue, + partition_evaluator_args: PartitionEvaluatorArgs, + expected: Int32Array, + ) -> Result<()> { + let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])); + let values = vec![arr]; + let mut ranges: Vec> = vec![]; + for i in 0..8 { + ranges.push(Range { + start: 0, + end: i + 1, + }) + } + let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?; + let result = ranges + .iter() + .map(|range| evaluator.evaluate(&values, range)) + .collect::>>()?; + let result = ScalarValue::iter_to_array(result.into_iter())?; + let result = as_int32_array(&result)?; + assert_eq!(expected, *result); + Ok(()) + } + + #[test] + fn first_value() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + test_i32_result( + NthValue::first(), + PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + Int32Array::from(vec![1; 8]).iter().collect::(), + ) + } + + #[test] + fn last_value() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + test_i32_result( + NthValue::last(), + PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + Int32Array::from(vec![ + Some(1), + Some(-2), + Some(3), + Some(-4), + Some(5), + Some(-6), + Some(7), + Some(8), + ]), + ) + } + + #[test] + fn nth_value_1() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + let n_value = + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))) as Arc; + + test_i32_result( + NthValue::nth(), + PartitionEvaluatorArgs::new( + &[expr, n_value], + &[DataType::Int32], + false, + false, + ), + Int32Array::from(vec![1; 8]), + )?; + Ok(()) + } + + #[test] + fn nth_value_2() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + let n_value = + Arc::new(Literal::from(ScalarValue::Int32(Some(2)))) as Arc; + + test_i32_result( + NthValue::nth(), + PartitionEvaluatorArgs::new( + &[expr, n_value], + &[DataType::Int32], + false, + false, + ), + Int32Array::from(vec![ + None, + Some(-2), + Some(-2), + Some(-2), + Some(-2), + Some(-2), + Some(-2), + Some(-2), + ]), + )?; + Ok(()) + } +} diff 
--git a/datafusion/functions-window/src/ntile.rs b/datafusion/functions-window/src/ntile.rs new file mode 100644 index 0000000000000..180f7ab02c03b --- /dev/null +++ b/datafusion/functions-window/src/ntile.rs @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `ntile` window function implementation + +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +use crate::utils::{ + get_scalar_value_from_args, get_signed_integer, get_unsigned_integer, +}; +use datafusion_common::arrow::array::{ArrayRef, UInt64Array}; +use datafusion_common::arrow::datatypes::{DataType, Field}; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::{ + Documentation, Expr, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, +}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_macros::user_doc; +use field::WindowUDFFieldArgs; + +get_or_init_udwf!( + Ntile, + ntile, + "integer ranging from 1 to the argument value, dividing the partition as equally as possible" +); + +pub fn ntile(arg: Expr) -> Expr { + ntile_udwf().call(vec![arg]) +} + +#[user_doc( + doc_section(label = "Ranking Functions"), + description = "Integer ranging from 1 to the argument value, dividing the partition as equally as possible", + syntax_example = "ntile(expression)", + argument( + name = "expression", + description = "An integer describing the number groups the partition should be split into" + ) +)] +#[derive(Debug)] +pub struct Ntile { + signature: Signature, +} + +impl Ntile { + /// Create a new `ntile` function + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![ + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + ], + Volatility::Immutable, + ), + } + } +} + +impl Default for Ntile { + fn default() -> Self { + Self::new() + } +} + +impl WindowUDFImpl for Ntile { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ntile" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + let scalar_n = + get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 0)? 
+ .ok_or_else(|| { + DataFusionError::Execution( + "NTILE requires a positive integer".to_string(), + ) + })?; + + if scalar_n.is_null() { + return exec_err!("NTILE requires a positive integer, but finds NULL"); + } + + if scalar_n.is_unsigned() { + let n = get_unsigned_integer(scalar_n)?; + Ok(Box::new(NtileEvaluator { n })) + } else { + let n: i64 = get_signed_integer(scalar_n)?; + if n <= 0 { + return exec_err!("NTILE requires a positive integer"); + } + Ok(Box::new(NtileEvaluator { n: n as u64 })) + } + } + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let nullable = false; + + Ok(Field::new(field_args.name(), DataType::UInt64, nullable)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +#[derive(Debug)] +struct NtileEvaluator { + n: u64, +} + +impl PartitionEvaluator for NtileEvaluator { + fn evaluate_all( + &mut self, + _values: &[ArrayRef], + num_rows: usize, + ) -> Result { + let num_rows = num_rows as u64; + let mut vec: Vec = Vec::new(); + let n = u64::min(self.n, num_rows); + for i in 0..num_rows { + let res = i * n / num_rows; + vec.push(res + 1) + } + Ok(Arc::new(UInt64Array::from(vec))) + } +} diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/functions-window/src/rank.rs similarity index 57% rename from datafusion/physical-expr/src/window/rank.rs rename to datafusion/functions-window/src/rank.rs index fa3d4e487f14f..dacee90bfad65 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/functions-window/src/rank.rs @@ -15,40 +15,83 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expression for `rank`, `dense_rank`, and `percent_rank` that can evaluated -//! at runtime during query execution - -use crate::expressions::Column; -use crate::window::window_expr::RankState; -use crate::window::BuiltInWindowFunctionExpr; -use crate::{PhysicalExpr, PhysicalSortExpr}; - -use arrow::array::ArrayRef; -use arrow::array::{Float64Array, UInt64Array}; -use arrow::datatypes::{DataType, Field}; -use arrow_schema::{SchemaRef, SortOptions}; -use datafusion_common::utils::get_row_at_idx; -use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::PartitionEvaluator; +//! Implementation of `rank`, `dense_rank`, and `percent_rank` window functions, +//! which can be evaluated at runtime during query execution. use std::any::Any; +use std::fmt::Debug; use std::iter; use std::ops::Range; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; + +use crate::define_udwf_and_expr; +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::array::{Float64Array, UInt64Array}; +use datafusion_common::arrow::compute::SortOptions; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; +use datafusion_common::utils::get_row_at_idx; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING; +use datafusion_expr::{ + Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, +}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use field::WindowUDFFieldArgs; + +define_udwf_and_expr!( + Rank, + rank, + "Returns rank of the current row with gaps. Same as `row_number` of its first peer", + Rank::basic +); + +define_udwf_and_expr!( + DenseRank, + dense_rank, + "Returns rank of the current row without gaps. 
This function counts peer groups", + Rank::dense_rank +); + +define_udwf_and_expr!( + PercentRank, + percent_rank, + "Returns the relative rank of the current row: (rank - 1) / (total rows - 1)", + Rank::percent_rank +); /// Rank calculates the rank in the window function with order by #[derive(Debug)] pub struct Rank { name: String, + signature: Signature, rank_type: RankType, - /// Output data type - data_type: DataType, } impl Rank { - /// Get rank_type of the rank in window function with order by - pub fn get_type(&self) -> RankType { - self.rank_type + /// Create a new `rank` function with the specified name and rank type + pub fn new(name: String, rank_type: RankType) -> Self { + Self { + name, + signature: Signature::nullary(Volatility::Immutable), + rank_type, + } + } + + /// Create a `rank` window function + pub fn basic() -> Self { + Rank::new("rank".to_string(), RankType::Basic) + } + + /// Create a `dense_rank` window function + pub fn dense_rank() -> Self { + Rank::new("dense_rank".to_string(), RankType::Dense) + } + + /// Create a `percent_rank` window function + pub fn percent_rank() -> Self { + Rank::new("percent_rank".to_string(), RankType::Percent) } } @@ -59,74 +102,107 @@ pub enum RankType { Percent, } -/// Create a rank window function -pub fn rank(name: String, data_type: &DataType) -> Rank { - Rank { - name, - rank_type: RankType::Basic, - data_type: data_type.clone(), - } +static RANK_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_rank_doc() -> &'static Documentation { + RANK_DOCUMENTATION.get_or_init(|| { + Documentation::builder( + DOC_SECTION_RANKING, + "Returns the rank of the current row within its partition, allowing \ + gaps between ranks. This function provides a ranking similar to `row_number`, but \ + skips ranks for identical values.", + + "rank()") + .build() + }) } -/// Create a dense rank window function -pub fn dense_rank(name: String, data_type: &DataType) -> Rank { - Rank { - name, - rank_type: RankType::Dense, - data_type: data_type.clone(), - } +static DENSE_RANK_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_dense_rank_doc() -> &'static Documentation { + DENSE_RANK_DOCUMENTATION.get_or_init(|| { + Documentation::builder(DOC_SECTION_RANKING, "Returns the rank of the current row without gaps. This function ranks \ + rows in a dense manner, meaning consecutive ranks are assigned even for identical \ + values.", "dense_rank()") + .build() + }) } -/// Create a percent rank window function -pub fn percent_rank(name: String, data_type: &DataType) -> Rank { - Rank { - name, - rank_type: RankType::Percent, - data_type: data_type.clone(), - } +static PERCENT_RANK_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_percent_rank_doc() -> &'static Documentation { + PERCENT_RANK_DOCUMENTATION.get_or_init(|| { + Documentation::builder(DOC_SECTION_RANKING, "Returns the percentage rank of the current row within its partition. 
\ + The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`.", "percent_rank()") + .build() + }) } -impl BuiltInWindowFunctionExpr for Rank { - /// Return a reference to Any that can be used for downcasting +impl WindowUDFImpl for Rank { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - let nullable = false; - Ok(Field::new(self.name(), self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![] - } - fn name(&self) -> &str { &self.name } - fn create_evaluator(&self) -> Result> { + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { Ok(Box::new(RankEvaluator { state: RankState::default(), rank_type: self.rank_type, })) } - fn get_result_ordering(&self, schema: &SchemaRef) -> Option { - // The built-in RANK window function (in all modes) introduces a new ordering: - schema.column_with_name(self.name()).map(|(idx, field)| { - let expr = Arc::new(Column::new(field.name(), idx)); - let options = SortOptions { - descending: false, - nulls_first: false, - }; // ASC, NULLS LAST - PhysicalSortExpr { expr, options } + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let return_type = match self.rank_type { + RankType::Basic | RankType::Dense => DataType::UInt64, + RankType::Percent => DataType::Float64, + }; + + let nullable = false; + Ok(Field::new(field_args.name(), return_type, nullable)) + } + + fn sort_options(&self) -> Option { + Some(SortOptions { + descending: false, + nulls_first: false, }) } + + fn documentation(&self) -> Option<&Documentation> { + match self.rank_type { + RankType::Basic => Some(get_rank_doc()), + RankType::Dense => Some(get_dense_rank_doc()), + RankType::Percent => Some(get_percent_rank_doc()), + } + } +} + +/// State for the RANK(rank) built-in window function. +#[derive(Debug, Clone, Default)] +pub struct RankState { + /// The last values for rank as these values change, we increase n_rank + pub last_rank_data: Option>, + /// The index where last_rank_boundary is started + pub last_rank_boundary: usize, + /// Keep the number of entries in current rank + pub current_group_count: usize, + /// Rank number kept from the start + pub n_rank: usize, } +/// State for the `rank` built-in window function. #[derive(Debug)] -pub(crate) struct RankEvaluator { +struct RankEvaluator { state: RankState, rank_type: RankType, } @@ -136,7 +212,6 @@ impl PartitionEvaluator for RankEvaluator { matches!(self.rank_type, RankType::Basic | RankType::Dense) } - /// Evaluates the window function inside the given range. 
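Before the evaluator itself, a standalone sketch of the two ranking rules it implements over peer groups (plain vectors instead of Arrow arrays; the peer groups below are chosen to reproduce the expected values used in the tests at the end of this file).

use std::iter::repeat;
use std::ops::Range;

// Ranks over peer groups of a sorted partition: RANK repeats the first peer's
// row number (leaving gaps after ties), DENSE_RANK numbers the peer groups.
fn ranks(peer_groups: &[Range<usize>]) -> (Vec<u64>, Vec<u64>) {
    let (mut basic, mut dense) = (Vec::new(), Vec::new());
    let mut next_rank = 1u64;
    for (group_idx, group) in peer_groups.iter().enumerate() {
        let len = group.end - group.start;
        basic.extend(repeat(next_rank).take(len));
        dense.extend(repeat(group_idx as u64 + 1).take(len));
        next_rank += len as u64;
    }
    (basic, dense)
}

fn main() {
    // Peer groups of sizes 2, 1, 3, 1, 1 over an 8-row partition.
    let groups = [0..2, 2..3, 3..6, 6..7, 7..8];
    let (basic, dense) = ranks(&groups);
    assert_eq!(basic, vec![1, 1, 3, 4, 4, 4, 7, 8]);
    assert_eq!(dense, vec![1, 1, 2, 3, 3, 3, 4, 5]);
}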
fn evaluate( &mut self, values: &[ArrayRef], @@ -163,6 +238,7 @@ impl PartitionEvaluator for RankEvaluator { // data is still in the same rank self.state.current_group_count += 1; } + match self.rank_type { RankType::Basic => Ok(ScalarValue::UInt64(Some( self.state.last_rank_boundary as u64 + 1, @@ -179,8 +255,19 @@ impl PartitionEvaluator for RankEvaluator { num_rows: usize, ranks_in_partition: &[Range], ) -> Result { - // see https://www.postgresql.org/docs/current/functions-window.html let result: ArrayRef = match self.rank_type { + RankType::Basic => Arc::new(UInt64Array::from_iter_values( + ranks_in_partition + .iter() + .scan(1_u64, |acc, range| { + let len = range.end - range.start; + let result = iter::repeat(*acc).take(len); + *acc += len as u64; + Some(result) + }) + .flatten(), + )), + RankType::Dense => Arc::new(UInt64Array::from_iter_values( ranks_in_partition .iter() @@ -190,9 +277,10 @@ impl PartitionEvaluator for RankEvaluator { iter::repeat(rank).take(len) }), )), + RankType::Percent => { - // Returns the relative rank of the current row, that is (rank - 1) / (total partition rows - 1). The value thus ranges from 0 to 1 inclusive. let denominator = num_rows as f64; + Arc::new(Float64Array::from_iter_values( ranks_in_partition .iter() @@ -206,18 +294,8 @@ impl PartitionEvaluator for RankEvaluator { .flatten(), )) } - RankType::Basic => Arc::new(UInt64Array::from_iter_values( - ranks_in_partition - .iter() - .scan(1_u64, |acc, range| { - let len = range.end - range.start; - let result = iter::repeat(*acc).take(len); - *acc += len as u64; - Some(result) - }) - .flatten(), - )), }; + Ok(result) } @@ -244,53 +322,57 @@ mod tests { test_i32_result(expr, vec![0..8], expected) } - fn test_f64_result( + fn test_i32_result( expr: &Rank, - num_rows: usize, ranks: Vec>, - expected: Vec, + expected: Vec, ) -> Result<()> { + let args = PartitionEvaluatorArgs::default(); let result = expr - .create_evaluator()? - .evaluate_all_with_rank(num_rows, &ranks)?; - let result = as_float64_array(&result)?; + .partition_evaluator(args)? + .evaluate_all_with_rank(8, &ranks)?; + let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(expected, *result); Ok(()) } - fn test_i32_result( + fn test_f64_result( expr: &Rank, + num_rows: usize, ranks: Vec>, - expected: Vec, + expected: Vec, ) -> Result<()> { - let result = expr.create_evaluator()?.evaluate_all_with_rank(8, &ranks)?; - let result = as_uint64_array(&result)?; + let args = PartitionEvaluatorArgs::default(); + let result = expr + .partition_evaluator(args)? 
+ .evaluate_all_with_rank(num_rows, &ranks)?; + let result = as_float64_array(&result)?; let result = result.values(); assert_eq!(expected, *result); Ok(()) } #[test] - fn test_dense_rank() -> Result<()> { - let r = dense_rank("arr".into(), &DataType::UInt64); + fn test_rank() -> Result<()> { + let r = Rank::basic(); test_without_rank(&r, vec![1; 8])?; - test_with_rank(&r, vec![1, 1, 2, 3, 3, 3, 4, 5])?; + test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?; Ok(()) } #[test] - fn test_rank() -> Result<()> { - let r = rank("arr".into(), &DataType::UInt64); + fn test_dense_rank() -> Result<()> { + let r = Rank::dense_rank(); test_without_rank(&r, vec![1; 8])?; - test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?; + test_with_rank(&r, vec![1, 1, 2, 3, 3, 3, 4, 5])?; Ok(()) } #[test] #[allow(clippy::single_range_in_vec_init)] fn test_percent_rank() -> Result<()> { - let r = percent_rank("arr".into(), &DataType::Float64); + let r = Rank::percent_rank(); // empty case let expected = vec![0.0; 0]; diff --git a/datafusion/functions-window/src/row_number.rs b/datafusion/functions-window/src/row_number.rs index c903f6778ae83..8f462528dbedc 100644 --- a/datafusion/functions-window/src/row_number.rs +++ b/datafusion/functions-window/src/row_number.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expression for `row_number` that can evaluated at runtime during query execution +//! `row_number` window function implementation use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::array::UInt64Array; @@ -23,16 +23,16 @@ use datafusion_common::arrow::compute::SortOptions; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING; use datafusion_expr::{ Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_macros::user_doc; use field::WindowUDFFieldArgs; use std::any::Any; use std::fmt::Debug; use std::ops::Range; -use std::sync::OnceLock; define_udwf_and_expr!( RowNumber, @@ -41,6 +41,11 @@ define_udwf_and_expr!( ); /// row_number expression +#[user_doc( + doc_section(label = "Ranking Functions"), + description = "Number of the current row within its partition, counting from 1.", + syntax_example = "row_number()" +)] #[derive(Debug)] pub struct RowNumber { signature: Signature, @@ -50,7 +55,7 @@ impl RowNumber { /// Create a new `row_number` function pub fn new() -> Self { Self { - signature: Signature::any(0, Volatility::Immutable), + signature: Signature::nullary(Volatility::Immutable), } } } @@ -61,21 +66,6 @@ impl Default for RowNumber { } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_row_number_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_RANKING) - .with_description( - "Number of the current row within its partition, counting from 1.", - ) - .with_syntax_example("row_number()") - .build() - .unwrap() - }) -} - impl WindowUDFImpl for RowNumber { fn as_any(&self) -> &dyn Any { self @@ -89,7 +79,10 @@ impl WindowUDFImpl for RowNumber { &self.signature } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { 
Ok(Box::::default()) } @@ -105,7 +98,7 @@ impl WindowUDFImpl for RowNumber { } fn documentation(&self) -> Option<&Documentation> { - Some(get_row_number_doc()) + self.doc() } } @@ -162,7 +155,7 @@ mod tests { let num_rows = values.len(); let actual = RowNumber::default() - .partition_evaluator()? + .partition_evaluator(PartitionEvaluatorArgs::default())? .evaluate_all(&[values], num_rows)?; let actual = as_uint64_array(&actual)?; @@ -178,7 +171,7 @@ mod tests { let num_rows = values.len(); let actual = RowNumber::default() - .partition_evaluator()? + .partition_evaluator(PartitionEvaluatorArgs::default())? .evaluate_all(&[values], num_rows)?; let actual = as_uint64_array(&actual)?; diff --git a/datafusion/functions-window/src/utils.rs b/datafusion/functions-window/src/utils.rs new file mode 100644 index 0000000000000..4073680515d27 --- /dev/null +++ b/datafusion/functions-window/src/utils.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +pub(crate) fn get_signed_integer(value: ScalarValue) -> Result { + if value.is_null() { + return Ok(0); + } + + if !value.data_type().is_integer() { + return exec_err!("Expected an integer value"); + } + + value.cast_to(&DataType::Int64)?.try_into() +} + +pub(crate) fn get_scalar_value_from_args( + args: &[Arc], + index: usize, +) -> Result> { + Ok(if let Some(field) = args.get(index) { + let tmp = field + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::NotImplemented( + format!("There is only support Literal types for field at idx: {index} in Window Function"), + ))? 
+ .scalar() + .value() + .clone(); + Some(tmp) + } else { + None + }) +} + +pub(crate) fn get_unsigned_integer(value: ScalarValue) -> Result { + if value.is_null() { + return Ok(0); + } + + if !value.data_type().is_integer() { + return exec_err!("Expected an integer value"); + } + + value.cast_to(&DataType::UInt64)?.try_into() +} diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index a3d114221d3f3..db3e6838f6a53 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -54,7 +54,7 @@ math_expressions = [] # enable regular expressions regex_expressions = ["regex"] # enable string functions -string_expressions = ["regex_expressions", "uuid"] +string_expressions = ["uuid"] # enable unicode functions unicode_expressions = ["hashbrown", "unicode-segmentation"] @@ -72,8 +72,11 @@ blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } chrono = { workspace = true } datafusion-common = { workspace = true } +datafusion-doc = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-expr-common = { workspace = true } +datafusion-macros = { workspace = true } hashbrown = { workspace = true, optional = true } hex = { version = "0.4", optional = true } itertools = { workspace = true } @@ -89,7 +92,6 @@ uuid = { version = "1.7", features = ["v4"], optional = true } arrow = { workspace = true, features = ["test_utils"] } criterion = "0.5" rand = { workspace = true } -rstest = { workspace = true } tokio = { workspace = true, features = ["macros", "rt", "sync"] } [[bench]] @@ -117,6 +119,11 @@ harness = false name = "make_date" required-features = ["datetime_expressions"] +[[bench]] +harness = false +name = "iszero" +required-features = ["math_expressions"] + [[bench]] harness = false name = "nullif" @@ -132,6 +139,16 @@ harness = false name = "to_char" required-features = ["datetime_expressions"] +[[bench]] +harness = false +name = "isnan" +required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "signum" +required-features = ["math_expressions"] + [[bench]] harness = false name = "substr_index" @@ -177,7 +194,32 @@ harness = false name = "character_length" required-features = ["unicode_expressions"] +[[bench]] +harness = false +name = "cot" +required-features = ["math_expressions"] + [[bench]] harness = false name = "strpos" required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "reverse" +required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "trunc" +required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "initcap" +required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "find_in_set" +required-features = ["unicode_expressions"] diff --git a/datafusion/functions/LICENSE.txt b/datafusion/functions/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/functions/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/functions/NOTICE.txt b/datafusion/functions/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/functions/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index 17c4dd1f89125..b3fdb8dc8561b 100644 --- a/datafusion/functions/benches/character_length.rs +++ 
b/datafusion/functions/benches/character_length.rs @@ -84,28 +84,52 @@ fn criterion_benchmark(c: &mut Criterion) { let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); c.bench_function( &format!("character_length_StringArray_ascii_str_len_{}", str_len), - |b| b.iter(|| black_box(character_length.invoke(&args_string_ascii))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(character_length.invoke_batch(&args_string_ascii, n_rows)) + }) + }, ); // StringArray UTF8 let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); c.bench_function( &format!("character_length_StringArray_utf8_str_len_{}", str_len), - |b| b.iter(|| black_box(character_length.invoke(&args_string_utf8))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(character_length.invoke_batch(&args_string_utf8, n_rows)) + }) + }, ); // StringViewArray ASCII only let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); c.bench_function( &format!("character_length_StringViewArray_ascii_str_len_{}", str_len), - |b| b.iter(|| black_box(character_length.invoke(&args_string_view_ascii))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box( + character_length.invoke_batch(&args_string_view_ascii, n_rows), + ) + }) + }, ); // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); c.bench_function( &format!("character_length_StringViewArray_utf8_str_len_{}", str_len), - |b| b.iter(|| black_box(character_length.invoke(&args_string_view_utf8))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box( + character_length.invoke_batch(&args_string_view_utf8, n_rows), + ) + }) + }, ); } } diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index bd3bc31b0c65c..507e5a52afc9c 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -38,7 +38,10 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args(size, 32); let mut group = c.benchmark_group("concat function"); group.bench_function(BenchmarkId::new("concat", size), |b| { - b.iter(|| criterion::black_box(concat().invoke(&args).unwrap())) + b.iter(|| { + // TODO use invoke_with_args + criterion::black_box(concat().invoke_batch(&args, size).unwrap()) + }) }); group.finish(); } diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs new file mode 100644 index 0000000000000..bb0585a2de9b7 --- /dev/null +++ b/datafusion/functions/benches/cot.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
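// Editorial sketch, not part of the patch: the repeated `// TODO use invoke_with_args`
// markers in these benches refer to the newer calling convention that the find_in_set
// and initcap benches added later in this patch already use. Assuming the same
// ScalarFunctionArgs shape (args / number_rows / return_type), and assuming a Float32
// return type for a Float32 input, the migration of one of these closures would look
// roughly like this:

use arrow::datatypes::DataType;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF};

/// Rough equivalent of `udf.invoke_batch(&args, rows)` under the `invoke_with_args` API.
fn invoke_float32_batch(
    udf: &ScalarUDF,
    args: Vec<ColumnarValue>,
    rows: usize,
) -> ColumnarValue {
    udf.invoke_with_args(ScalarFunctionArgs {
        args,
        number_rows: rows,
        // assumption: Float32 in, Float32 out for this family of math functions
        return_type: &DataType::Float32,
    })
    .expect("scalar UDF should succeed on valid input")
}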
+ +extern crate criterion; + +use arrow::{ + datatypes::{Float32Type, Float64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::math::cot; + +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let cot_fn = cot(); + for size in [1024, 4096, 8192] { + let f32_array = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(f32_array)]; + c.bench_function(&format!("cot f32 array: {}", size), |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(cot_fn.invoke_batch(&f32_args, size).unwrap()) + }) + }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + c.bench_function(&format!("cot f64 array: {}", size), |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(cot_fn.invoke_batch(&f64_args, size).unwrap()) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 7a92037ccc5da..e3ee3a13c2a67 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -19,7 +19,7 @@ extern crate criterion; use std::sync::Arc; -use arrow::array::{ArrayRef, TimestampSecondArray}; +use arrow::array::{Array, ArrayRef, TimestampSecondArray}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use rand::rngs::ThreadRng; @@ -40,13 +40,16 @@ fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("date_bin_1000", |b| { let mut rng = rand::thread_rng(); + let timestamps_array = Arc::new(timestamps(&mut rng)) as ArrayRef; + let batch_len = timestamps_array.len(); let interval = ColumnarValue::from(ScalarValue::new_interval_dt(0, 1_000_000)); - let timestamps = ColumnarValue::Array(Arc::new(timestamps(&mut rng)) as ArrayRef); + let timestamps = ColumnarValue::Array(timestamps_array); let udf = date_bin(); b.iter(|| { + // TODO use invoke_with_args black_box( - udf.invoke(&[interval.clone(), timestamps.clone()]) + udf.invoke_batch(&[interval.clone(), timestamps.clone()], batch_len) .expect("date_bin should work on valid values"), ) }) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index f92e88a464bcf..877b13c19a789 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -30,22 +30,36 @@ fn criterion_benchmark(c: &mut Criterion) { let str_array = Arc::new(create_string_array_with_len::(size, 0.2, 32)); c.bench_function(&format!("base64_decode/{size}"), |b| { let method = ColumnarValue::from(ScalarValue::from("base64")); + // TODO: use invoke_with_args let encoded = encoding::encode() - .invoke(&[ColumnarValue::Array(str_array.clone()), method.clone()]) + .invoke_batch( + &[ColumnarValue::Array(str_array.clone()), method.clone()], + size, + ) .unwrap(); let args = vec![encoded, method]; - b.iter(|| black_box(decode.invoke(&args).unwrap())) + b.iter(|| { + // TODO use invoke_with_args + black_box(decode.invoke_batch(&args, size).unwrap()) + }) }); c.bench_function(&format!("hex_decode/{size}"), |b| { let method = ColumnarValue::from(ScalarValue::from("hex")); + // TODO use invoke_with_args let encoded = encoding::encode() - 
.invoke(&[ColumnarValue::Array(str_array.clone()), method.clone()]) + .invoke_batch( + &[ColumnarValue::Array(str_array.clone()), method.clone()], + size, + ) .unwrap(); let args = vec![encoded, method]; - b.iter(|| black_box(decode.invoke(&args).unwrap())) + b.iter(|| { + // TODO use invoke_with_args + black_box(decode.invoke_batch(&args, size).unwrap()) + }) }); } } diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs new file mode 100644 index 0000000000000..eacae0c64ec1e --- /dev/null +++ b/datafusion/functions/benches/find_in_set.rs @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::{StringArray, StringViewArray}; +use arrow::datatypes::DataType; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::ScalarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use rand::distributions::Alphanumeric; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::sync::Arc; +use std::time::Duration; + +/// gen_arr(4096, 128, 0.1, 0.1, true) will generate a StringViewArray with +/// 4096 rows, each row containing a string with 128 random characters. +/// around 10% of the rows are null, around 10% of the rows are non-ASCII. 
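/// The returned args are a two-element vector `[element, set]`: each `set` row holds
/// five comma-separated elements, and `element` is drawn at random from that set,
/// matching find_in_set's (string, str_list) argument order used below.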
+fn gen_args_array( + n_rows: usize, + str_len_chars: usize, + null_density: f32, + utf8_density: f32, + is_string_view: bool, // false -> StringArray, true -> StringViewArray +) -> Vec { + let mut rng = StdRng::seed_from_u64(42); + let rng_ref = &mut rng; + + let num_elements = 5; // 5 elements separated by comma + let utf8 = "DataFusionДатаФусион数据融合📊🔥"; // includes utf8 encoding with 1~4 bytes + let corpus_char_count = utf8.chars().count(); + + let mut output_set_vec: Vec> = Vec::with_capacity(n_rows); + let mut output_element_vec: Vec> = Vec::with_capacity(n_rows); + for _ in 0..n_rows { + let rand_num = rng_ref.gen::(); // [0.0, 1.0) + if rand_num < null_density { + output_element_vec.push(None); + output_set_vec.push(None); + } else if rand_num < null_density + utf8_density { + // Generate random UTF-8 string with comma separators + let mut generated_string = String::with_capacity(str_len_chars); + for i in 0..num_elements { + for _ in 0..str_len_chars { + let idx = rng_ref.gen_range(0..corpus_char_count); + let char = utf8.chars().nth(idx).unwrap(); + generated_string.push(char); + } + if i < num_elements - 1 { + generated_string.push(','); + } + } + output_element_vec.push(Some(random_element_in_set(&generated_string))); + output_set_vec.push(Some(generated_string)); + } else { + // Generate random ASCII-only string with comma separators + let mut generated_string = String::with_capacity(str_len_chars); + for i in 0..num_elements { + for _ in 0..str_len_chars { + let c = rng_ref.sample(Alphanumeric); + generated_string.push(c as char); + } + if i < num_elements - 1 { + generated_string.push(','); + } + } + output_element_vec.push(Some(random_element_in_set(&generated_string))); + output_set_vec.push(Some(generated_string)); + } + } + + if is_string_view { + let set_array: StringViewArray = output_set_vec.into_iter().collect(); + let element_array: StringViewArray = output_element_vec.into_iter().collect(); + vec![ + ColumnarValue::Array(Arc::new(element_array)), + ColumnarValue::Array(Arc::new(set_array)), + ] + } else { + let set_array: StringArray = output_set_vec.clone().into_iter().collect(); + let element_array: StringArray = output_element_vec.into_iter().collect(); + vec![ + ColumnarValue::Array(Arc::new(element_array)), + ColumnarValue::Array(Arc::new(set_array)), + ] + } +} + +fn random_element_in_set(string: &str) -> String { + let elements: Vec<&str> = string.split(',').collect(); + + if elements.is_empty() || (elements.len() == 1 && elements[0].is_empty()) { + return String::new(); + } + + let mut rng = StdRng::seed_from_u64(44); + let random_index = rng.gen_range(0..elements.len()); + + elements[random_index].to_string() +} + +fn gen_args_scalar( + n_rows: usize, + str_len_chars: usize, + null_density: f32, + is_string_view: bool, // false -> StringArray, true -> StringViewArray +) -> Vec { + let str_list = "Apache,DataFusion,SQL,Query,Engine".to_string(); + if is_string_view { + let string = + create_string_view_array_with_len(n_rows, null_density, str_len_chars, false); + vec![ + ColumnarValue::Array(Arc::new(string)), + ColumnarValue::from(ScalarValue::Utf8(Some(str_list))), + ] + } else { + let string = + create_string_array_with_len::(n_rows, null_density, str_len_chars); + vec![ + ColumnarValue::Array(Arc::new(string)), + ColumnarValue::from(ScalarValue::Utf8(Some(str_list))), + ] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + // All benches are single batch run with 8192 rows + let find_in_set = datafusion_functions::unicode::find_in_set(); + + let 
n_rows = 8192; + for str_len in [8, 32, 1024] { + let mut group = c.benchmark_group("find_in_set"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(50); + group.measurement_time(Duration::from_secs(10)); + + let args = gen_args_array(n_rows, str_len, 0.1, 0.5, false); + group.bench_function(format!("string_len_{}", str_len), |b| { + b.iter(|| { + black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: n_rows, + return_type: &DataType::Int32, + })) + }) + }); + + let args = gen_args_array(n_rows, str_len, 0.1, 0.5, true); + group.bench_function(format!("string_view_len_{}", str_len), |b| { + b.iter(|| { + black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: n_rows, + return_type: &DataType::Int32, + })) + }) + }); + + group.finish(); + + let mut group = c.benchmark_group("find_in_set_scalar"); + + let args = gen_args_scalar(n_rows, str_len, 0.1, false); + group.bench_function(format!("string_len_{}", str_len), |b| { + b.iter(|| { + black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: n_rows, + return_type: &DataType::Int32, + })) + }) + }); + + let args = gen_args_scalar(n_rows, str_len, 0.1, true); + group.bench_function(format!("string_view_len_{}", str_len), |b| { + b.iter(|| { + black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: n_rows, + return_type: &DataType::Int32, + })) + }) + }); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs new file mode 100644 index 0000000000000..97c76831b33c8 --- /dev/null +++ b/datafusion/functions/benches/initcap.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
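// Editorial note, not part of the patch: the 8-byte vs. 16-byte cases below are meant to
// straddle Utf8View's 12-byte threshold (string-view values of 12 bytes or fewer are
// stored inline in the view struct, while longer values point into a data buffer), which
// is why the benchmark names say "shorter than 12" and "longer than 12".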
+ +extern crate criterion; + +use arrow::array::OffsetSizeTrait; +use arrow::datatypes::DataType; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::unicode; +use std::sync::Arc; + +fn create_args( + size: usize, + str_len: usize, + force_view_types: bool, +) -> Vec { + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.2, str_len, false)); + + vec![ColumnarValue::Array(string_array)] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.2, str_len)); + + vec![ColumnarValue::Array(string_array)] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let initcap = unicode::initcap(); + for size in [1024, 4096] { + let args = create_args::(size, 8, true); + c.bench_function( + format!("initcap string view shorter than 12 [size={}]", size).as_str(), + |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: size, + return_type: &DataType::Utf8View, + })) + }) + }, + ); + + let args = create_args::(size, 16, true); + c.bench_function( + format!("initcap string view longer than 12 [size={}]", size).as_str(), + |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: size, + return_type: &DataType::Utf8View, + })) + }) + }, + ); + + let args = create_args::(size, 16, false); + c.bench_function(format!("initcap string [size={}]", size).as_str(), |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: size, + return_type: &DataType::Utf8, + })) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs new file mode 100644 index 0000000000000..605a520715f4b --- /dev/null +++ b/datafusion/functions/benches/isnan.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::{ + datatypes::{Float32Type, Float64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::math::isnan; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let isnan = isnan(); + for size in [1024, 4096, 8192] { + let f32_array = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(f32_array)]; + c.bench_function(&format!("isnan f32 array: {}", size), |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(isnan.invoke_batch(&f32_args, size).unwrap()) + }) + }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + c.bench_function(&format!("isnan f64 array: {}", size), |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(isnan.invoke_batch(&f64_args, size).unwrap()) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs new file mode 100644 index 0000000000000..48fb6fbed9c38 --- /dev/null +++ b/datafusion/functions/benches/iszero.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
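// Editorial note, not part of the patch: the new [[bench]] entries in
// datafusion/functions/Cargo.toml gate this benchmark (and the other math benches added
// here) behind the `math_expressions` feature, so it can be run with something like the
// following (crate name assumed to be `datafusion-functions`; the feature flag is only
// needed if it is not already enabled by default):
//
//   cargo bench -p datafusion-functions --bench iszero --features math_expressions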
+ +extern crate criterion; + +use arrow::{ + datatypes::{Float32Type, Float64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::math::iszero; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let iszero = iszero(); + for size in [1024, 4096, 8192] { + let f32_array = Arc::new(create_primitive_array::(size, 0.2)); + let batch_len = f32_array.len(); + let f32_args = vec![ColumnarValue::Array(f32_array)]; + c.bench_function(&format!("iszero f32 array: {}", size), |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(iszero.invoke_batch(&f32_args, batch_len).unwrap()) + }) + }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); + let batch_len = f64_array.len(); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + c.bench_function(&format!("iszero f64 array: {}", size), |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(iszero.invoke_batch(&f64_args, batch_len).unwrap()) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 934c1c6bd189e..114ac4a16fe54 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -124,19 +124,32 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let args = create_args1(size, 32); c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| { - b.iter(|| black_box(lower.invoke(&args))) + b.iter(|| { + // TODO use invoke_with_args + black_box(lower.invoke_batch(&args, size)) + }) }); let args = create_args2(size); c.bench_function( &format!("lower_the_first_value_is_nonascii: {}", size), - |b| b.iter(|| black_box(lower.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(lower.invoke_batch(&args, size)) + }) + }, ); let args = create_args3(size); c.bench_function( &format!("lower_the_middle_value_is_nonascii: {}", size), - |b| b.iter(|| black_box(lower.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(lower.invoke_batch(&args, size)) + }) + }, ); } @@ -151,24 +164,33 @@ fn criterion_benchmark(c: &mut Criterion) { for &size in &sizes { let args = create_args4(size, str_len, *null_density, mixed); c.bench_function( - &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", + &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", size, str_len, null_density, mixed), - |b| b.iter(|| black_box(lower.invoke(&args))), - ); + |b| b.iter(|| { + // TODO use invoke_with_args + black_box(lower.invoke_batch(&args, size)) + }), + ); let args = create_args4(size, str_len, *null_density, mixed); c.bench_function( - &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", + &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", size, str_len, null_density, mixed), - |b| b.iter(|| black_box(lower.invoke(&args))), - ); + |b| b.iter(|| { + // TODO use invoke_with_args + black_box(lower.invoke_batch(&args, size)) + }), + ); let args = create_args5(size, 0.1, *null_density); c.bench_function( - &format!("lower_some_values_are_nonascii_string_views: size: {}, str_len: {}, non_ascii_density: {}, null_density: {}, mixed: {}", + 
&format!("lower_some_values_are_nonascii_string_views: size: {}, str_len: {}, non_ascii_density: {}, null_density: {}, mixed: {}", size, str_len, 0.1, null_density, mixed), - |b| b.iter(|| black_box(lower.invoke(&args))), - ); + |b| b.iter(|| { + // TODO use invoke_with_args + black_box(lower.invoke_batch(&args, size)) + }), + ); } } } diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 9d3f98fa98226..ef5e6e56a505a 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -139,7 +139,12 @@ fn run_with_string_type( format!( "{string_type} [size={size}, len_before={len}, len_after={remaining_len}]", ), - |b| b.iter(|| black_box(ltrim.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(ltrim.invoke_batch(&args, size)) + }) + }, ); } diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index a865953897eb3..3edbfdab036b0 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -19,7 +19,7 @@ extern crate criterion; use std::sync::Arc; -use arrow::array::{ArrayRef, Int32Array}; +use arrow::array::{Array, ArrayRef, Int32Array}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::rngs::ThreadRng; use rand::Rng; @@ -57,14 +57,20 @@ fn days(rng: &mut ThreadRng) -> Int32Array { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_date_col_col_col_1000", |b| { let mut rng = rand::thread_rng(); - let years = ColumnarValue::Array(Arc::new(years(&mut rng)) as ArrayRef); + let years_array = Arc::new(years(&mut rng)) as ArrayRef; + let batch_len = years_array.len(); + let years = ColumnarValue::Array(years_array); let months = ColumnarValue::Array(Arc::new(months(&mut rng)) as ArrayRef); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); b.iter(|| { + // TODO use invoke_with_args black_box( make_date() - .invoke(&[years.clone(), months.clone(), days.clone()]) + .invoke_batch( + &[years.clone(), months.clone(), days.clone()], + batch_len, + ) .expect("make_date should work on valid values"), ) }) @@ -73,13 +79,19 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_date_scalar_col_col_1000", |b| { let mut rng = rand::thread_rng(); let year = ColumnarValue::from(ScalarValue::Int32(Some(2025))); - let months = ColumnarValue::Array(Arc::new(months(&mut rng)) as ArrayRef); + let months_arr = Arc::new(months(&mut rng)) as ArrayRef; + let batch_len = months_arr.len(); + let months = ColumnarValue::Array(months_arr); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); b.iter(|| { + // TODO use invoke_with_args black_box( make_date() - .invoke(&[year.clone(), months.clone(), days.clone()]) + .invoke_batch( + &[year.clone(), months.clone(), days.clone()], + batch_len, + ) .expect("make_date should work on valid values"), ) }) @@ -89,12 +101,15 @@ fn criterion_benchmark(c: &mut Criterion) { let mut rng = rand::thread_rng(); let year = ColumnarValue::from(ScalarValue::Int32(Some(2025))); let month = ColumnarValue::from(ScalarValue::Int32(Some(11))); - let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); + let day_arr = Arc::new(days(&mut rng)); + let batch_len = day_arr.len(); + let days = ColumnarValue::Array(day_arr); b.iter(|| { + // TODO use invoke_with_args black_box( make_date() - .invoke(&[year.clone(), month.clone(), days.clone()]) + .invoke_batch(&[year.clone(), month.clone(), 
days.clone()], batch_len) .expect("make_date should work on valid values"), ) }) @@ -106,9 +121,10 @@ fn criterion_benchmark(c: &mut Criterion) { let day = ColumnarValue::from(ScalarValue::Int32(Some(26))); b.iter(|| { + // TODO use invoke_with_args black_box( make_date() - .invoke(&[year.clone(), month.clone(), day.clone()]) + .invoke_batch(&[year.clone(), month.clone(), day.clone()], 1) .expect("make_date should work on valid values"), ) }) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index 31192c1a749f9..03b009e3b451c 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -33,7 +33,10 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Array(array), ]; c.bench_function(&format!("nullif scalar array: {}", size), |b| { - b.iter(|| black_box(nullif.invoke(&args).unwrap())) + b.iter(|| { + // TODO use invoke_with_args + black_box(nullif.invoke_batch(&args, size).unwrap()) + }) }); } } diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index 71fa68762c1e0..6f267b350a35f 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -101,17 +101,26 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args::(size, 32, false); group.bench_function(BenchmarkId::new("utf8 type", size), |b| { - b.iter(|| criterion::black_box(lpad().invoke(&args).unwrap())) + b.iter(|| { + // TODO use invoke_with_args + criterion::black_box(lpad().invoke_batch(&args, size).unwrap()) + }) }); let args = create_args::(size, 32, false); group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { - b.iter(|| criterion::black_box(lpad().invoke(&args).unwrap())) + b.iter(|| { + // TODO use invoke_with_args + criterion::black_box(lpad().invoke_batch(&args, size).unwrap()) + }) }); let args = create_args::(size, 32, true); group.bench_function(BenchmarkId::new("stringview type", size), |b| { - b.iter(|| criterion::black_box(lpad().invoke(&args).unwrap())) + b.iter(|| { + // TODO use invoke_with_args + criterion::black_box(lpad().invoke_batch(&args, size).unwrap()) + }) }); group.finish(); @@ -120,18 +129,27 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args::(size, 32, false); group.bench_function(BenchmarkId::new("utf8 type", size), |b| { - b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap())) + b.iter(|| { + // TODO use invoke_with_args + criterion::black_box(rpad().invoke_batch(&args, size).unwrap()) + }) }); let args = create_args::(size, 32, false); group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { - b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap())) + b.iter(|| { + // TODO use invoke_with_args + criterion::black_box(rpad().invoke_batch(&args, size).unwrap()) + }) }); // rpad for stringview type let args = create_args::(size, 32, true); group.bench_function(BenchmarkId::new("stringview type", size), |b| { - b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap())) + b.iter(|| { + // TODO use invoke_with_args + criterion::black_box(rpad().invoke_batch(&args, size).unwrap()) + }) }); group.finish(); diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index a721836bb68ce..bc20e0ff11c1f 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -29,7 +29,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("random_1M_rows_batch_8192", |b| { b.iter(|| { for _ in 0..iterations { - 
black_box(random_func.invoke_no_args(8192).unwrap()); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + black_box(random_func.invoke_batch(&[], 8192).unwrap()); } }) }); @@ -39,7 +40,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("random_1M_rows_batch_128", |b| { b.iter(|| { for _ in 0..iterations_128 { - black_box(random_func.invoke_no_args(128).unwrap()); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + black_box(random_func.invoke_batch(&[], 128).unwrap()); } }) }); diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 45bfa23511281..468d3d548bcf0 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -18,8 +18,11 @@ extern crate criterion; use arrow::array::builder::StringBuilder; -use arrow::array::{ArrayRef, AsArray, StringArray}; +use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray}; +use arrow::compute::cast; +use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_functions::regex::regexpcount::regexp_count_func; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; @@ -59,6 +62,15 @@ fn regex(rng: &mut ThreadRng) -> StringArray { StringArray::from(data) } +fn start(rng: &mut ThreadRng) -> Int64Array { + let mut data: Vec = vec![]; + for _ in 0..1000 { + data.push(rng.gen_range(1..5)); + } + + Int64Array::from(data) +} + fn flags(rng: &mut ThreadRng) -> StringArray { let samples = [Some("i".to_string()), Some("im".to_string()), None]; let mut sb = StringBuilder::new(); @@ -75,20 +87,56 @@ fn flags(rng: &mut ThreadRng) -> StringArray { } fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("regexp_like_1000", |b| { + c.bench_function("regexp_count_1000 string", |b| { let mut rng = rand::thread_rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; + let start = Arc::new(start(&mut rng)) as ArrayRef; let flags = Arc::new(flags(&mut rng)) as ArrayRef; b.iter(|| { black_box( - regexp_like::(&[ + regexp_count_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_count should work on utf8"), + ) + }) + }); + + c.bench_function("regexp_count_1000 utf8view", |b| { + let mut rng = rand::thread_rng(); + let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); + let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); + + b.iter(|| { + black_box( + regexp_count_func(&[ Arc::clone(&data), Arc::clone(®ex), + Arc::clone(&start), Arc::clone(&flags), ]) - .expect("regexp_like should work on valid values"), + .expect("regexp_count should work on utf8view"), + ) + }) + }); + + c.bench_function("regexp_like_1000", |b| { + let mut rng = rand::thread_rng(); + let data = Arc::new(data(&mut rng)) as ArrayRef; + let regex = Arc::new(regex(&mut rng)) as ArrayRef; + let flags = Arc::new(flags(&mut rng)) as ArrayRef; + + b.iter(|| { + black_box( + regexp_like(&[Arc::clone(&data), Arc::clone(®ex), Arc::clone(&flags)]) + .expect("regexp_like should work on valid values"), ) }) }); diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 5643ccf071331..e7e3c634ea825 100644 --- 
a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -71,7 +71,12 @@ fn criterion_benchmark(c: &mut Criterion) { "repeat_string_view [size={}, repeat_times={}]", size, repeat_times ), - |b| b.iter(|| black_box(repeat.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(repeat.invoke_batch(&args, repeat_times as usize)) + }) + }, ); let args = create_args::(size, 32, repeat_times, false); @@ -80,7 +85,12 @@ fn criterion_benchmark(c: &mut Criterion) { "repeat_string [size={}, repeat_times={}]", size, repeat_times ), - |b| b.iter(|| black_box(repeat.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(repeat.invoke_batch(&args, repeat_times as usize)) + }) + }, ); let args = create_args::(size, 32, repeat_times, false); @@ -89,7 +99,12 @@ fn criterion_benchmark(c: &mut Criterion) { "repeat_large_string [size={}, repeat_times={}]", size, repeat_times ), - |b| b.iter(|| black_box(repeat.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(repeat.invoke_batch(&args, repeat_times as usize)) + }) + }, ); group.finish(); @@ -107,7 +122,12 @@ fn criterion_benchmark(c: &mut Criterion) { "repeat_string_view [size={}, repeat_times={}]", size, repeat_times ), - |b| b.iter(|| black_box(repeat.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(repeat.invoke_batch(&args, repeat_times as usize)) + }) + }, ); let args = create_args::(size, 32, repeat_times, false); @@ -116,7 +136,12 @@ fn criterion_benchmark(c: &mut Criterion) { "repeat_string [size={}, repeat_times={}]", size, repeat_times ), - |b| b.iter(|| black_box(repeat.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(repeat.invoke_batch(&args, size)) + }) + }, ); let args = create_args::(size, 32, repeat_times, false); @@ -125,7 +150,12 @@ fn criterion_benchmark(c: &mut Criterion) { "repeat_large_string [size={}, repeat_times={}]", size, repeat_times ), - |b| b.iter(|| black_box(repeat.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(repeat.invoke_batch(&args, repeat_times as usize)) + }) + }, ); group.finish(); diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs new file mode 100644 index 0000000000000..c7c1ef8a82209 --- /dev/null +++ b/datafusion/functions/benches/reverse.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::array::OffsetSizeTrait; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::unicode; +use std::sync::Arc; + +fn create_args( + size: usize, + str_len: usize, + force_view_types: bool, +) -> Vec { + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + + vec![ColumnarValue::Array(string_array)] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ColumnarValue::Array(string_array)] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let reverse = unicode::reverse(); + for size in [1024, 4096] { + let str_len = 8; + + let args = create_args::(size, str_len, true); + c.bench_function( + format!("reverse_string_view [size={}, str_len={}]", size, str_len).as_str(), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(reverse.invoke_batch(&args, str_len)) + }) + }, + ); + + let str_len = 32; + + let args = create_args::(size, str_len, true); + c.bench_function( + format!("reverse_string_view [size={}, str_len={}]", size, str_len).as_str(), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(reverse.invoke_batch(&args, str_len)) + }) + }, + ); + + let args = create_args::(size, str_len, false); + c.bench_function( + format!("reverse_string [size={}, str_len={}]", size, str_len).as_str(), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(reverse.invoke_batch(&args, str_len)) + }) + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs new file mode 100644 index 0000000000000..a51b2ebe5ab7b --- /dev/null +++ b/datafusion/functions/benches/signum.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::{ + datatypes::{Float32Type, Float64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::math::signum; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let signum = signum(); + for size in [1024, 4096, 8192] { + let f32_array = Arc::new(create_primitive_array::(size, 0.2)); + let batch_len = f32_array.len(); + let f32_args = vec![ColumnarValue::Array(f32_array)]; + c.bench_function(&format!("signum f32 array: {}", size), |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(signum.invoke_batch(&f32_args, batch_len).unwrap()) + }) + }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); + let batch_len = f64_array.len(); + + let f64_args = vec![ColumnarValue::Array(f64_array)]; + c.bench_function(&format!("signum f64 array: {}", size), |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(signum.invoke_batch(&f64_args, batch_len).unwrap()) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index c78e69826836b..f4962380dfbf1 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -112,28 +112,48 @@ fn criterion_benchmark(c: &mut Criterion) { let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); c.bench_function( &format!("strpos_StringArray_ascii_str_len_{}", str_len), - |b| b.iter(|| black_box(strpos.invoke(&args_string_ascii))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(strpos.invoke_batch(&args_string_ascii, n_rows)) + }) + }, ); // StringArray UTF8 let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); c.bench_function( &format!("strpos_StringArray_utf8_str_len_{}", str_len), - |b| b.iter(|| black_box(strpos.invoke(&args_string_utf8))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(strpos.invoke_batch(&args_string_utf8, n_rows)) + }) + }, ); // StringViewArray ASCII only let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); c.bench_function( &format!("strpos_StringViewArray_ascii_str_len_{}", str_len), - |b| b.iter(|| black_box(strpos.invoke(&args_string_view_ascii))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(strpos.invoke_batch(&args_string_view_ascii, n_rows)) + }) + }, ); // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); c.bench_function( &format!("strpos_StringViewArray_utf8_str_len_{}", str_len), - |b| b.iter(|| black_box(strpos.invoke(&args_string_view_utf8))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(strpos.invoke_batch(&args_string_view_utf8, n_rows)) + }) + }, ); } } diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 90ba75c1e8a51..8b8e8dbc42790 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -107,19 +107,34 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args_without_count::(size, len, true, true); group.bench_function( format!("substr_string_view [size={}, strlen={}]", size, len), - |b| b.iter(|| black_box(substr.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(substr.invoke_batch(&args, size)) + }) + }, ); let args 
= create_args_without_count::(size, len, false, false); group.bench_function( format!("substr_string [size={}, strlen={}]", size, len), - |b| b.iter(|| black_box(substr.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(substr.invoke_batch(&args, size)) + }) + }, ); let args = create_args_without_count::(size, len, true, false); group.bench_function( format!("substr_large_string [size={}, strlen={}]", size, len), - |b| b.iter(|| black_box(substr.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(substr.invoke_batch(&args, size)) + }) + }, ); group.finish(); @@ -137,7 +152,12 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_string_view [size={}, count={}, strlen={}]", size, count, len, ), - |b| b.iter(|| black_box(substr.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(substr.invoke_batch(&args, size)) + }) + }, ); let args = create_args_with_count::(size, len, count, false); @@ -146,7 +166,12 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_string [size={}, count={}, strlen={}]", size, count, len, ), - |b| b.iter(|| black_box(substr.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(substr.invoke_batch(&args, size)) + }) + }, ); let args = create_args_with_count::(size, len, count, false); @@ -155,7 +180,12 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_large_string [size={}, count={}, strlen={}]", size, count, len, ), - |b| b.iter(|| black_box(substr.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(substr.invoke_batch(&args, size)) + }) + }, ); group.finish(); @@ -173,7 +203,12 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_string_view [size={}, count={}, strlen={}]", size, count, len, ), - |b| b.iter(|| black_box(substr.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(substr.invoke_batch(&args, size)) + }) + }, ); let args = create_args_with_count::(size, len, count, false); @@ -182,7 +217,12 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_string [size={}, count={}, strlen={}]", size, count, len, ), - |b| b.iter(|| black_box(substr.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(substr.invoke_batch(&args, size)) + }) + }, ); let args = create_args_with_count::(size, len, count, false); @@ -191,7 +231,12 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_large_string [size={}, count={}, strlen={}]", size, count, len, ), - |b| b.iter(|| black_box(substr.invoke(&args))), + |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(substr.invoke_batch(&args, size)) + }) + }, ); group.finish(); diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index bb9a5b809eee4..1ea8e2606f0d7 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -84,15 +84,17 @@ fn data() -> (StringArray, StringArray, Int64Array) { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("substr_index_array_array_1000", |b| { let (strings, delimiters, counts) = data(); + let batch_len = counts.len(); let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef); let delimiters = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); let args = [strings, delimiters, counts]; b.iter(|| { + #[allow(deprecated)] // TODO: invoke_with_args black_box( substr_index() - .invoke(&args) 
+ .invoke_batch(&args, batch_len) .expect("substr_index should work on valid values"), ) }) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 14819c90a7a5a..c62dd5af20b08 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -82,13 +82,16 @@ fn patterns(rng: &mut ThreadRng) -> StringArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_char_array_array_1000", |b| { let mut rng = rand::thread_rng(); - let data = ColumnarValue::Array(Arc::new(data(&mut rng)) as ArrayRef); + let data_arr = data(&mut rng); + let batch_len = data_arr.len(); + let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); let patterns = ColumnarValue::Array(Arc::new(patterns(&mut rng)) as ArrayRef); b.iter(|| { + // TODO use invoke_with_args black_box( to_char() - .invoke(&[data.clone(), patterns.clone()]) + .invoke_batch(&[data.clone(), patterns.clone()], batch_len) .expect("to_char should work on valid values"), ) }) @@ -96,14 +99,17 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_char_array_scalar_1000", |b| { let mut rng = rand::thread_rng(); - let data = ColumnarValue::Array(Arc::new(data(&mut rng)) as ArrayRef); + let data_arr = data(&mut rng); + let batch_len = data_arr.len(); + let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); let patterns = ColumnarValue::from(ScalarValue::Utf8(Some("%Y-%m-%d".to_string()))); b.iter(|| { + // TODO use invoke_with_args black_box( to_char() - .invoke(&[data.clone(), patterns.clone()]) + .invoke_batch(&[data.clone(), patterns.clone()], batch_len) .expect("to_char should work on valid values"), ) }) @@ -123,9 +129,10 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::from(ScalarValue::Utf8(Some("%d-%m-%Y %H:%M:%S".to_string()))); b.iter(|| { + // TODO use invoke_with_args black_box( to_char() - .invoke(&[data.clone(), pattern.clone()]) + .invoke_batch(&[data.clone(), pattern.clone()], 1) .expect("to_char should work on valid values"), ) }) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index e734b6832f29c..9f5f6661f998a 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -20,103 +20,212 @@ extern crate criterion; use std::sync::Arc; use arrow::array::builder::StringBuilder; -use arrow::array::ArrayRef; +use arrow::array::{Array, ArrayRef, StringArray}; +use arrow::compute::cast; +use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ColumnarValue; use datafusion_functions::datetime::to_timestamp; +fn data() -> StringArray { + let data: Vec<&str> = vec![ + "1997-01-31T09:26:56.123Z", + "1997-01-31T09:26:56.123-05:00", + "1997-01-31 09:26:56.123-05:00", + "2023-01-01 04:05:06.789 -08", + "1997-01-31T09:26:56.123", + "1997-01-31 09:26:56.123", + "1997-01-31 09:26:56", + "1997-01-31 13:26:56", + "1997-01-31 13:26:56+04:00", + "1997-01-31", + ]; + + StringArray::from(data) +} + +fn data_with_formats() -> (StringArray, StringArray, StringArray, StringArray) { + let mut inputs = StringBuilder::new(); + let mut format1_builder = StringBuilder::with_capacity(2, 10); + let mut format2_builder = StringBuilder::with_capacity(2, 10); + let mut format3_builder = StringBuilder::with_capacity(2, 10); + + inputs.append_value("1997-01-31T09:26:56.123Z"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + 
format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%Z"); + + inputs.append_value("1997-01-31T09:26:56.123-05:00"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%z"); + + inputs.append_value("1997-01-31 09:26:56.123-05:00"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f%Z"); + + inputs.append_value("2023-01-01 04:05:06.789 -08"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f %#z"); + + inputs.append_value("1997-01-31T09:26:56.123"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f"); + + inputs.append_value("1997-01-31 09:26:56.123"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f"); + + inputs.append_value("1997-01-31 09:26:56"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S"); + + inputs.append_value("1997-01-31 092656"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H%M%S"); + + inputs.append_value("1997-01-31 092656+04:00"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H%M%S%:z"); + + inputs.append_value("Sun Jul 8 00:34:60 2001"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d 00:00:00"); + + ( + inputs.finish(), + format1_builder.finish(), + format2_builder.finish(), + format3_builder.finish(), + ) +} fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("to_timestamp_no_formats", |b| { - let mut inputs = StringBuilder::new(); - inputs.append_value("1997-01-31T09:26:56.123Z"); - inputs.append_value("1997-01-31T09:26:56.123-05:00"); - inputs.append_value("1997-01-31 09:26:56.123-05:00"); - inputs.append_value("2023-01-01 04:05:06.789 -08"); - inputs.append_value("1997-01-31T09:26:56.123"); - inputs.append_value("1997-01-31 09:26:56.123"); - inputs.append_value("1997-01-31 09:26:56"); - inputs.append_value("1997-01-31 13:26:56"); - inputs.append_value("1997-01-31 13:26:56+04:00"); - inputs.append_value("1997-01-31"); - - let string_array = ColumnarValue::Array(Arc::new(inputs.finish()) as ArrayRef); + c.bench_function("to_timestamp_no_formats_utf8", |b| { + let arr_data = data(); + let batch_len = arr_data.len(); + let string_array = ColumnarValue::Array(Arc::new(arr_data) as ArrayRef); + + b.iter(|| { + // TODO use invoke_with_args + black_box( + to_timestamp() + .invoke_batch(&[string_array.clone()], batch_len) + .expect("to_timestamp should work on valid values"), + ) + }) + }); + + c.bench_function("to_timestamp_no_formats_largeutf8", |b| { + let data = cast(&data(), &DataType::LargeUtf8).unwrap(); + let batch_len = data.len(); + let string_array = ColumnarValue::Array(Arc::new(data) as ArrayRef); + + b.iter(|| { + // TODO use invoke_with_args + black_box( + to_timestamp() + .invoke_batch(&[string_array.clone()], batch_len) + .expect("to_timestamp should work on valid values"), + ) + }) + }); + + c.bench_function("to_timestamp_no_formats_utf8view", |b| { + let data = cast(&data(), &DataType::Utf8View).unwrap(); + let batch_len = data.len(); + let string_array = 
ColumnarValue::Array(Arc::new(data) as ArrayRef); + + b.iter(|| { + // TODO use invoke_with_args + black_box( + to_timestamp() + .invoke_batch(&[string_array.clone()], batch_len) + .expect("to_timestamp should work on valid values"), + ) + }) + }); + + c.bench_function("to_timestamp_with_formats_utf8", |b| { + let (inputs, format1, format2, format3) = data_with_formats(); + let batch_len = inputs.len(); + + let args = [ + ColumnarValue::Array(Arc::new(inputs) as ArrayRef), + ColumnarValue::Array(Arc::new(format1) as ArrayRef), + ColumnarValue::Array(Arc::new(format2) as ArrayRef), + ColumnarValue::Array(Arc::new(format3) as ArrayRef), + ]; + b.iter(|| { + // TODO use invoke_with_args + black_box( + to_timestamp() + .invoke_batch(&args.clone(), batch_len) + .expect("to_timestamp should work on valid values"), + ) + }) + }); + + c.bench_function("to_timestamp_with_formats_largeutf8", |b| { + let (inputs, format1, format2, format3) = data_with_formats(); + let batch_len = inputs.len(); + let args = [ + ColumnarValue::Array( + Arc::new(cast(&inputs, &DataType::LargeUtf8).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format1, &DataType::LargeUtf8).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format2, &DataType::LargeUtf8).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format3, &DataType::LargeUtf8).unwrap()) as ArrayRef + ), + ]; b.iter(|| { + // TODO use invoke_with_args black_box( to_timestamp() - .invoke(&[string_array.clone()]) + .invoke_batch(&args.clone(), batch_len) .expect("to_timestamp should work on valid values"), ) }) }); - c.bench_function("to_timestamp_with_formats", |b| { - let mut inputs = StringBuilder::new(); - let mut format1_builder = StringBuilder::with_capacity(2, 10); - let mut format2_builder = StringBuilder::with_capacity(2, 10); - let mut format3_builder = StringBuilder::with_capacity(2, 10); - - inputs.append_value("1997-01-31T09:26:56.123Z"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%Z"); - - inputs.append_value("1997-01-31T09:26:56.123-05:00"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%z"); - - inputs.append_value("1997-01-31 09:26:56.123-05:00"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f%Z"); - - inputs.append_value("2023-01-01 04:05:06.789 -08"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f %#z"); - - inputs.append_value("1997-01-31T09:26:56.123"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f"); - - inputs.append_value("1997-01-31 09:26:56.123"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f"); - - inputs.append_value("1997-01-31 09:26:56"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H:%M:%S"); - - inputs.append_value("1997-01-31 092656"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H%M%S"); - - inputs.append_value("1997-01-31 092656+04:00"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - 
format3_builder.append_value("%Y-%m-%d %H%M%S%:z"); - - inputs.append_value("Sun Jul 8 00:34:60 2001"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d 00:00:00"); + c.bench_function("to_timestamp_with_formats_utf8view", |b| { + let (inputs, format1, format2, format3) = data_with_formats(); + + let batch_len = inputs.len(); let args = [ - ColumnarValue::Array(Arc::new(inputs.finish()) as ArrayRef), - ColumnarValue::Array(Arc::new(format1_builder.finish()) as ArrayRef), - ColumnarValue::Array(Arc::new(format2_builder.finish()) as ArrayRef), - ColumnarValue::Array(Arc::new(format3_builder.finish()) as ArrayRef), + ColumnarValue::Array( + Arc::new(cast(&inputs, &DataType::Utf8View).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format1, &DataType::Utf8View).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format2, &DataType::Utf8View).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format3, &DataType::Utf8View).unwrap()) as ArrayRef + ), ]; b.iter(|| { + // TODO use invoke_with_args black_box( to_timestamp() - .invoke(&args.clone()) + .invoke_batch(&args.clone(), batch_len) .expect("to_timestamp should work on valid values"), ) }) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs new file mode 100644 index 0000000000000..83d5b761e8097 --- /dev/null +++ b/datafusion/functions/benches/trunc.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::{ + datatypes::{Float32Type, Float64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::math::trunc; + +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let trunc = trunc(); + for size in [1024, 4096, 8192] { + let f32_array = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(f32_array)]; + c.bench_function(&format!("trunc f32 array: {}", size), |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(trunc.invoke_batch(&f32_args, size).unwrap()) + }) + }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + c.bench_function(&format!("trunc f64 array: {}", size), |b| { + b.iter(|| { + // TODO use invoke_with_args + black_box(trunc.invoke_batch(&f64_args, size).unwrap()) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index a3e5fbd7a4332..9b41a15b11c70 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -37,7 +37,10 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let args = create_args(size, 32); c.bench_function("upper_all_values_are_ascii", |b| { - b.iter(|| black_box(upper.invoke(&args))) + b.iter(|| { + // TODO use invoke_with_args + black_box(upper.invoke_batch(&args, size)) + }) }); } } diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 0b6f506d447ff..10e5f0fb30564 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -17,18 +17,19 @@ //! [`ArrowCastFunc`]: Implementation of the `arrow_cast` -use std::any::Any; - use arrow::datatypes::DataType; use datafusion_common::{ arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, ExprSchema, Result, ScalarValue, }; +use std::any::Any; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Expr, ExprSchemable, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, Expr, ExprSchemable, ScalarUDFImpl, Signature, + Volatility, }; +use datafusion_macros::user_doc; /// Implements casting to arbitrary arrow types (rather than SQL types) /// @@ -51,6 +52,31 @@ use datafusion_expr::{ /// ```sql /// select arrow_cast(column_x, 'Float64') /// ``` +#[user_doc( + doc_section(label = "Other Functions"), + description = "Casts a value to a specific Arrow data type.", + syntax_example = "arrow_cast(expression, datatype)", + sql_example = r#"```sql +> select arrow_cast(-5, 'Int8') as a, + arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b, + arrow_cast('bar', 'LargeUtf8') as c, + arrow_cast('2023-01-02T12:53:02', 'Timestamp(Microsecond, Some("+08:00"))') as d + ; ++----+-----+-----+---------------------------+ +| a | b | c | d | ++----+-----+-----+---------------------------+ +| -5 | foo | bar | 2023-01-02T12:53:02+08:00 | ++----+-----+-----+---------------------------+ +```"#, + argument( + name = "expression", + description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators." 
+ ), + argument( + name = "datatype", + description = "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]" + ) +)] #[derive(Debug)] pub struct ArrowCastFunc { signature: Signature, @@ -102,7 +128,11 @@ impl ScalarUDFImpl for ArrowCastFunc { data_type_from_args(args) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { internal_err!("arrow_cast should have been simplified to cast") } @@ -131,6 +161,10 @@ impl ScalarUDFImpl for ArrowCastFunc { // return the newly written argument to DataFusion Ok(ExprSimplifyResult::Simplified(new_expr)) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Returns the requested type from the arguments diff --git a/datafusion/functions/src/core/arrowtypeof.rs b/datafusion/functions/src/core/arrowtypeof.rs index dd502fb2686dd..b48192dbfc61c 100644 --- a/datafusion/functions/src/core/arrowtypeof.rs +++ b/datafusion/functions/src/core/arrowtypeof.rs @@ -17,10 +17,29 @@ use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; use std::any::Any; +#[user_doc( + doc_section(label = "Other Functions"), + description = "Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression.", + syntax_example = "arrow_typeof(expression)", + sql_example = r#"```sql +> select arrow_typeof('foo'), arrow_typeof(1); ++---------------------------+------------------------+ +| arrow_typeof(Utf8("foo")) | arrow_typeof(Int64(1)) | ++---------------------------+------------------------+ +| Utf8 | Int64 | ++---------------------------+------------------------+ +``` +"#, + argument( + name = "expression", + description = "Expression to evaluate. The expression can be a constant, column, or function, and any combination of operators." 
+ ) +)] #[derive(Debug)] pub struct ArrowTypeOfFunc { signature: Signature, @@ -56,7 +75,11 @@ impl ScalarUDFImpl for ArrowTypeOfFunc { Ok(DataType::Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.len() != 1 { return exec_err!( "arrow_typeof function requires 1 arguments, got {}", @@ -69,4 +92,8 @@ impl ScalarUDFImpl for ArrowTypeOfFunc { "{input_data_type}" )))) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 8155c04da6263..7b97490311430 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -20,14 +20,30 @@ use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, ExprSchema, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; -use datafusion_expr::type_coercion::binary::type_union_resolution; +use datafusion_expr::binary::try_type_union_resolution; use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; use itertools::Itertools; use std::any::Any; -use std::sync::OnceLock; +#[user_doc( + doc_section(label = "Conditional Functions"), + description = "Returns the first of its arguments that is not _null_. Returns _null_ if all arguments are _null_. This function is often used to substitute a default value for _null_ values.", + syntax_example = "coalesce(expression1[, ..., expression_n])", + sql_example = r#"```sql +> select coalesce(null, null, 'datafusion'); ++----------------------------------------+ +| coalesce(NULL,NULL,Utf8("datafusion")) | ++----------------------------------------+ +| datafusion | ++----------------------------------------+ +```"#, + argument( + name = "expression1, expression_n", + description = "Expression to use if previous expressions are _null_. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." + ) +)] #[derive(Debug)] pub struct CoalesceFunc { signature: Signature, @@ -47,23 +63,6 @@ impl CoalesceFunc { } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_coalesce_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_CONDITIONAL) - .with_description("Returns the first of its arguments that is not _null_. Returns _null_ if all arguments are _null_. This function is often used to substitute a default value for _null_ values.") - .with_syntax_example("coalesce(expression1[, ..., expression_n])") - .with_argument( - "expression1, expression_n", - "Expression to use if previous expressions are _null_. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." - ) - .build() - .unwrap() - }) -} - impl ScalarUDFImpl for CoalesceFunc { fn as_any(&self) -> &dyn Any { self @@ -91,7 +90,11 @@ impl ScalarUDFImpl for CoalesceFunc { } /// coalesce evaluates to the first value which is not NULL - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { // do not accept 0 arguments. 
if args.is_empty() { return exec_err!( @@ -154,13 +157,12 @@ impl ScalarUDFImpl for CoalesceFunc { if arg_types.is_empty() { return exec_err!("coalesce must have at least one argument"); } - let new_type = type_union_resolution(arg_types) - .unwrap_or(arg_types.first().unwrap().clone()); - Ok(vec![new_type; arg_types.len()]) + + try_type_union_resolution(arg_types) } fn documentation(&self) -> Option<&Documentation> { - Some(get_coalesce_doc()) + self.doc() } } diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 0ff6067dd37eb..dfd05f448c4e0 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -23,11 +23,52 @@ use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ exec_err, plan_datafusion_err, plan_err, ExprSchema, Result, ScalarValue, }; -use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; +#[user_doc( + doc_section(label = "Other Functions"), + description = r#"Returns a field within a map or a struct with the given key. + Note: most users invoke `get_field` indirectly via field access + syntax such as `my_struct_col['field_name']` which results in a call to + `get_field(my_struct_col, 'field_name')`."#, + syntax_example = "get_field(expression1, expression2)", + sql_example = r#"```sql +> create table t (idx varchar, v varchar) as values ('data','fusion'), ('apache', 'arrow'); +> select struct(idx, v) from t as c; ++-------------------------+ +| struct(c.idx,c.v) | ++-------------------------+ +| {c0: data, c1: fusion} | +| {c0: apache, c1: arrow} | ++-------------------------+ +> select get_field((select struct(idx, v) from t), 'c0'); ++-----------------------+ +| struct(t.idx,t.v)[c0] | ++-----------------------+ +| data | +| apache | ++-----------------------+ +> select get_field((select struct(idx, v) from t), 'c1'); ++-----------------------+ +| struct(t.idx,t.v)[c1] | ++-----------------------+ +| fusion | +| arrow | ++-----------------------+ +```"#, + argument( + name = "expression1", + description = "The map or struct to retrieve a field for." + ), + argument( + name = "expression2", + description = "The field name in the map or struct to retrieve data for. Must evaluate to a string." + ) +)] #[derive(Debug)] pub struct GetFieldFunc { signature: Signature, @@ -133,7 +174,7 @@ impl ScalarUDFImpl for GetFieldFunc { DataType::Struct(fields) if fields.len() == 2 => { // Arrow's MapArray is essentially a ListArray of structs with two columns. They are // often named "key", and "value", but we don't require any specific naming here; - // instead, we assume that the second columnis the "value" column both here and in + // instead, we assume that the second column is the "value" column both here and in // execution. 
let value_field = fields.get(1).expect("fields should have exactly two members"); Ok(value_field.data_type().clone()) @@ -155,11 +196,15 @@ impl ScalarUDFImpl for GetFieldFunc { "Only UTF8 strings are valid as an indexed field in a struct" ), (DataType::Null, _) => Ok(DataType::Null), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `Struct`, `Map` or `Null` types, got {other}"), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.len() != 2 { return exec_err!( "get_field function requires 2 arguments, got {}", @@ -190,7 +235,7 @@ impl ScalarUDFImpl for GetFieldFunc { let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; // note that this array has more entries than the expected output/input size - // because maparray is flatten + // because map_array is flattened let original_data = map_array.entries().column(1).to_data(); let capacity = Capacities::Array(original_data.len()); let mut mutable = @@ -205,7 +250,7 @@ impl ScalarUDFImpl for GetFieldFunc { keys.slice(start, end-start). iter().enumerate(). find(|(_, t)| t.unwrap()); - if maybe_matched.is_none(){ + if maybe_matched.is_none() { mutable.extend_nulls(1); continue } @@ -224,14 +269,18 @@ impl ScalarUDFImpl for GetFieldFunc { } } (DataType::Struct(_), name) => exec_err!( - "get indexed field is only possible on struct with utf8 indexes. \ - Tried with {name:?} index" + "get_field is only possible on struct with utf8 indexes. \ + Received with {name:?} index" ), (DataType::Null, _) => Ok(ColumnarValue::from(ScalarValue::Null)), (dt, name) => exec_err!( - "get indexed field is only possible on lists with int64 indexes or struct \ - with utf8 indexes. Tried {dt:?} with {name:?} index" + "get_field is only possible on maps with utf8 indexes or struct \ + with utf8 indexes. Received {dt:?} with {name:?} index" ), } } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/core/greatest.rs b/datafusion/functions/src/core/greatest.rs new file mode 100644 index 0000000000000..a20caa34c6f8b --- /dev/null +++ b/datafusion/functions/src/core/greatest.rs @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
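The `arrow_cast`, `arrow_typeof`, `coalesce`, and `get_field` hunks above all apply the same mechanical migration: the old `invoke(&self, args)` is replaced by `invoke_batch(&self, args, number_rows)`. A minimal sketch of a UDF written directly against the new entry point, assuming the DataFusion revision in this diff (where `invoke` has a deprecated default implementation); `EchoFunc` and its behavior are made up for illustration:

```rust
use std::any::Any;

use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility};

/// Toy UDF that returns its single argument unchanged.
#[derive(Debug)]
struct EchoFunc {
    signature: Signature,
}

impl EchoFunc {
    fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for EchoFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "echo"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }
    // Batch-aware entry point: same arguments as the old `invoke`, plus the
    // number of rows in the batch (useful for scalar-only or zero-arg calls).
    fn invoke_batch(
        &self,
        args: &[ColumnarValue],
        _number_rows: usize,
    ) -> Result<ColumnarValue> {
        Ok(args[0].clone())
    }
}

fn main() {
    let udf = ScalarUDF::new_from_impl(EchoFunc::new());
    assert_eq!(udf.name(), "echo");
}
```

The `version()` change later in the diff shows the other side of the same migration: `invoke_no_args` is folded into `invoke_batch`, which simply checks that `args` is empty.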
+ +use crate::core::greatest_least_utils::GreatestLeastOperator; +use arrow::array::{make_comparator, Array, BooleanArray}; +use arrow::compute::kernels::cmp; +use arrow::compute::SortOptions; +use arrow::datatypes::DataType; +use arrow_buffer::BooleanBuffer; +use datafusion_common::{internal_err, Result}; +use datafusion_doc::Documentation; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr_common::scalar::Scalar; +use datafusion_macros::user_doc; +use std::any::Any; + +const SORT_OPTIONS: SortOptions = SortOptions { + // We want greatest first + descending: false, + + // NULL will be less than any other value + nulls_first: true, +}; + +#[user_doc( + doc_section(label = "Conditional Functions"), + description = "Returns the greatest value in a list of expressions. Returns _null_ if all expressions are _null_.", + syntax_example = "greatest(expression1[, ..., expression_n])", + sql_example = r#"```sql +> select greatest(4, 7, 5); ++---------------------------+ +| greatest(4,7,5) | ++---------------------------+ +| 7 | ++---------------------------+ +```"#, + argument( + name = "expression1, expression_n", + description = "Expressions to compare and return the greatest value.. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." + ) +)] +#[derive(Debug)] +pub struct GreatestFunc { + signature: Signature, +} + +impl Default for GreatestFunc { + fn default() -> Self { + GreatestFunc::new() + } +} + +impl GreatestFunc { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl GreatestLeastOperator for GreatestFunc { + const NAME: &'static str = "greatest"; + + fn keep_scalar<'a>(lhs: &'a Scalar, rhs: &'a Scalar) -> Result<&'a Scalar> { + if !lhs.data_type().is_nested() { + return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) }; + } + + // If complex type we can't compare directly as we want null values to be smaller + let cmp = make_comparator( + lhs.to_array()?.as_ref(), + rhs.to_array()?.as_ref(), + SORT_OPTIONS, + )?; + + if cmp(0, 0).is_ge() { + Ok(lhs) + } else { + Ok(rhs) + } + } + + /// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array + /// Nulls are always considered smaller than any other value + fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result { + // Fast path: + // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorized kernel + // - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined. 
+ // - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case + if !lhs.data_type().is_nested() + && lhs.logical_null_count() == 0 + && rhs.logical_null_count() == 0 + { + return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into()); + } + + let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?; + + if lhs.len() != rhs.len() { + return internal_err!( + "All arrays should have the same length for greatest comparison" + ); + } + + let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge()); + + // No nulls as we only want to keep the values that are larger, its either true or false + Ok(BooleanArray::new(values, None)) + } +} + +impl ScalarUDFImpl for GreatestFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "greatest" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + super::greatest_least_utils::execute_conditional::(args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let coerced_type = + super::greatest_least_utils::find_coerced_type::(arg_types)?; + + Ok(vec![coerced_type; arg_types.len()]) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +#[cfg(test)] +mod test { + use crate::core; + use arrow::datatypes::DataType; + use datafusion_expr::ScalarUDFImpl; + + #[test] + fn test_greatest_return_types_without_common_supertype_in_arg_type() { + let greatest = core::greatest::GreatestFunc::new(); + let return_type = greatest + .coerce_types(&[DataType::Decimal128(10, 3), DataType::Decimal128(10, 4)]) + .unwrap(); + assert_eq!( + return_type, + vec![DataType::Decimal128(11, 4), DataType::Decimal128(11, 4)] + ); + } +} diff --git a/datafusion/functions/src/core/greatest_least_utils.rs b/datafusion/functions/src/core/greatest_least_utils.rs new file mode 100644 index 0000000000000..1a55cdf2d4649 --- /dev/null +++ b/datafusion/functions/src/core/greatest_least_utils.rs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
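`get_indexes_to_keep` in the `greatest` hunk above only uses the vectorized `cmp::gt_eq` kernel on the fast path; for nested types, or inputs containing nulls, it falls back to a row-by-row comparator so that nulls always lose. A self-contained sketch of that fallback (illustrative names and data, not part of the patch):

```rust
use arrow::array::{make_comparator, Array, BooleanArray, Int32Array};
use arrow::buffer::BooleanBuffer;
use arrow::compute::SortOptions;
use arrow::error::ArrowError;

/// Build a mask where `true` means "keep the value from `lhs`".
fn keep_lhs_mask(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray, ArrowError> {
    // nulls_first makes a null lhs compare as smaller, so it never wins.
    let opts = SortOptions {
        descending: false,
        nulls_first: true,
    };
    let cmp = make_comparator(lhs, rhs, opts)?;
    let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge());
    // No null buffer: every slot is decided one way or the other.
    Ok(BooleanArray::new(values, None))
}

fn main() -> Result<(), ArrowError> {
    let a = Int32Array::from(vec![Some(1), None, Some(5)]);
    let b = Int32Array::from(vec![Some(2), Some(3), Some(4)]);
    let mask = keep_lhs_mask(&a, &b)?;
    assert!(!mask.value(0)); // 1 < 2  -> take rhs
    assert!(!mask.value(1)); // null   -> take rhs
    assert!(mask.value(2)); //  5 >= 4 -> keep lhs
    Ok(())
}
```

The resulting mask is what `greatest_least_utils::keep_array` feeds to the `zip` kernel; `least` uses the mirror-image comparison with `nulls_first: false`.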
+ +use arrow::array::{Array, ArrayRef, BooleanArray}; +use arrow::compute::kernels::zip::zip; +use arrow::datatypes::DataType; +use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::scalar::Scalar; +use datafusion_expr_common::type_coercion::binary::type_union_resolution; +use std::sync::Arc; + +pub(super) trait GreatestLeastOperator { + const NAME: &'static str; + + fn keep_scalar<'a>(lhs: &'a Scalar, rhs: &'a Scalar) -> Result<&'a Scalar>; + + /// Return array with true for values that we should keep from the lhs array + fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result; +} + +fn keep_array( + lhs: ArrayRef, + rhs: ArrayRef, +) -> Result { + // True for values that we should keep from the left array + let keep_lhs = Op::get_indexes_to_keep(lhs.as_ref(), rhs.as_ref())?; + + let result = zip(&keep_lhs, &lhs, &rhs)?; + + Ok(result) +} + +pub(super) fn execute_conditional( + args: &[ColumnarValue], +) -> Result { + if args.is_empty() { + return internal_err!( + "{} was called with no arguments. It requires at least 1.", + Op::NAME + ); + } + + // Some engines (e.g. SQL Server) allow greatest/least with single arg, it's a noop + if args.len() == 1 { + return Ok(args[0].clone()); + } + + // Split to scalars and arrays for later optimization + let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x { + ColumnarValue::Scalar(_) => true, + ColumnarValue::Array(_) => false, + }); + + let mut arrays_iter = arrays.iter().map(|x| match x { + ColumnarValue::Array(a) => a, + _ => unreachable!(), + }); + + let first_array = arrays_iter.next(); + + let mut result: ArrayRef; + + // Optimization: merge all scalars into one to avoid recomputing (constant folding) + if !scalars.is_empty() { + let mut scalars_iter = scalars.iter().map(|x| match x { + ColumnarValue::Scalar(s) => s, + _ => unreachable!(), + }); + + // We have at least one scalar + let mut result_scalar = scalars_iter.next().unwrap(); + + for scalar in scalars_iter { + result_scalar = Op::keep_scalar(result_scalar, scalar)?; + } + + // If we only have scalars, return the one that we should keep (largest/least) + if arrays.is_empty() { + return Ok(ColumnarValue::Scalar(result_scalar.clone())); + } + + // We have at least one array + let first_array = first_array.unwrap(); + + // Start with the result value + result = keep_array::( + Arc::clone(first_array), + result_scalar.to_array_of_size(first_array.len())?, + )?; + } else { + // If we only have arrays, start with the first array + // (We must have at least one array) + result = Arc::clone(first_array.unwrap()); + } + + for array in arrays_iter { + result = keep_array::(Arc::clone(array), result)?; + } + + Ok(ColumnarValue::Array(result)) +} + +pub(super) fn find_coerced_type( + data_types: &[DataType], +) -> Result { + if data_types.is_empty() { + plan_err!( + "{} was called without any arguments. It requires at least 1.", + Op::NAME + ) + } else if let Some(coerced_type) = type_union_resolution(data_types) { + Ok(coerced_type) + } else { + plan_err!("Cannot find a common type for arguments") + } +} diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs new file mode 100644 index 0000000000000..78694eab3077a --- /dev/null +++ b/datafusion/functions/src/core/least.rs @@ -0,0 +1,190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::core::greatest_least_utils::GreatestLeastOperator; +use arrow::array::{make_comparator, Array, BooleanArray}; +use arrow::compute::kernels::cmp; +use arrow::compute::SortOptions; +use arrow::datatypes::DataType; +use arrow_buffer::BooleanBuffer; +use datafusion_common::{internal_err, Result}; +use datafusion_doc::Documentation; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr_common::scalar::Scalar; +use datafusion_macros::user_doc; +use std::any::Any; + +const SORT_OPTIONS: SortOptions = SortOptions { + // Having the smallest result first + descending: false, + + // NULL will be greater than any other value + nulls_first: false, +}; + +#[user_doc( + doc_section(label = "Conditional Functions"), + description = "Returns the smallest value in a list of expressions. Returns _null_ if all expressions are _null_.", + syntax_example = "least(expression1[, ..., expression_n])", + sql_example = r#"```sql +> select least(4, 7, 5); ++---------------------------+ +| least(4,7,5) | ++---------------------------+ +| 4 | ++---------------------------+ +```"#, + argument( + name = "expression1, expression_n", + description = "Expressions to compare and return the smallest value. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." + ) +)] +#[derive(Debug)] +pub struct LeastFunc { + signature: Signature, +} + +impl Default for LeastFunc { + fn default() -> Self { + LeastFunc::new() + } +} + +impl LeastFunc { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl GreatestLeastOperator for LeastFunc { + const NAME: &'static str = "least"; + + fn keep_scalar<'a>(lhs: &'a Scalar, rhs: &'a Scalar) -> Result<&'a Scalar> { + // Manual checking for nulls as: + // 1. If we're going to use <=, in Rust None is smaller than Some(T), which we don't want + // 2. 
And we can't use make_comparator as it has no natural order (Arrow error) + if lhs.is_null() { + return Ok(rhs); + } + + if rhs.is_null() { + return Ok(lhs); + } + + if !lhs.data_type().is_nested() { + return if lhs <= rhs { Ok(lhs) } else { Ok(rhs) }; + } + + // Not using <= as in Rust None is smaller than Some(T) + + // If complex type we can't compare directly as we want null values to be larger + let cmp = make_comparator( + lhs.to_array()?.as_ref(), + rhs.to_array()?.as_ref(), + SORT_OPTIONS, + )?; + + if cmp(0, 0).is_le() { + Ok(lhs) + } else { + Ok(rhs) + } + } + + /// Return boolean array where `arr[i] = lhs[i] <= rhs[i]` for all i, where `arr` is the result array + /// Nulls are always considered larger than any other value + fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result { + // Fast path: + // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorized kernel + // - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined. + // - both array does not have any nulls: cmp::lt_eq will return null if any of the input is null while we want to return false in that case + if !lhs.data_type().is_nested() + && lhs.logical_null_count() == 0 + && rhs.logical_null_count() == 0 + { + return cmp::lt_eq(&lhs, &rhs).map_err(|e| e.into()); + } + + let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?; + + if lhs.len() != rhs.len() { + return internal_err!( + "All arrays should have the same length for least comparison" + ); + } + + let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_le()); + + // No nulls as we only want to keep the values that are smaller, its either true or false + Ok(BooleanArray::new(values, None)) + } +} + +impl ScalarUDFImpl for LeastFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "least" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + super::greatest_least_utils::execute_conditional::(args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let coerced_type = + super::greatest_least_utils::find_coerced_type::(arg_types)?; + + Ok(vec![coerced_type; arg_types.len()]) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +#[cfg(test)] +mod test { + use crate::core::least::LeastFunc; + use arrow::datatypes::DataType; + use datafusion_expr::ScalarUDFImpl; + + #[test] + fn test_least_return_types_without_common_supertype_in_arg_type() { + let least = LeastFunc::new(); + let return_type = least + .coerce_types(&[DataType::Decimal128(10, 3), DataType::Decimal128(10, 4)]) + .unwrap(); + assert_eq!( + return_type, + vec![DataType::Decimal128(11, 4), DataType::Decimal128(11, 4)] + ); + } +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 1c69f9c9b2f37..76fb4bbe5b474 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -25,6 +25,9 @@ pub mod arrowtypeof; pub mod coalesce; pub mod expr_ext; pub mod getfield; +pub mod greatest; +mod greatest_least_utils; +pub mod least; pub mod named_struct; pub mod nullif; pub mod nvl; @@ -34,16 +37,18 @@ pub mod r#struct; pub mod version; // create UDFs -make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast); -make_udf_function!(nullif::NullIfFunc, NULLIF, nullif); 
-make_udf_function!(nvl::NVLFunc, NVL, nvl); -make_udf_function!(nvl2::NVL2Func, NVL2, nvl2); -make_udf_function!(arrowtypeof::ArrowTypeOfFunc, ARROWTYPEOF, arrow_typeof); -make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); -make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); -make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); -make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); -make_udf_function!(version::VersionFunc, VERSION, version); +make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); +make_udf_function!(nullif::NullIfFunc, nullif); +make_udf_function!(nvl::NVLFunc, nvl); +make_udf_function!(nvl2::NVL2Func, nvl2); +make_udf_function!(arrowtypeof::ArrowTypeOfFunc, arrow_typeof); +make_udf_function!(r#struct::StructFunc, r#struct); +make_udf_function!(named_struct::NamedStructFunc, named_struct); +make_udf_function!(getfield::GetFieldFunc, get_field); +make_udf_function!(coalesce::CoalesceFunc, coalesce); +make_udf_function!(greatest::GreatestFunc, greatest); +make_udf_function!(least::LeastFunc, least); +make_udf_function!(version::VersionFunc, version); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; @@ -80,6 +85,14 @@ pub mod expr_fn { coalesce, "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL", args, + ),( + greatest, + "Returns `greatest(args...)`, which evaluates to the greatest value in the list of expressions or NULL if all the expressions are NULL", + args, + ),( + least, + "Returns `least(args...)`, which evaluates to the smallest value in the list of expressions or NULL if all the expressions are NULL", + args, )); #[doc = "Returns the value of the field with the given name from the struct"] @@ -102,10 +115,13 @@ pub fn functions() -> Vec> { // `get_field(my_struct_col, "field_name")`. // // However, it is also exposed directly for use cases such as - // serializing / deserializing plans with the field access desugared to - // calls to `get_field` + // serializing / deserializing plans with the field access desugared to + // calls to [`get_field`] get_field(), coalesce(), + greatest(), + least(), version(), + r#struct(), ] } diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index b367f20ca1262..8e68a8e602287 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -17,16 +17,16 @@ use arrow::array::StructArray; use arrow::datatypes::{DataType, Field, Fields}; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; +use datafusion_common::{exec_err, internal_err, HashSet, Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use hashbrown::HashSet; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; -/// put values in a struct array. +/// Put values in a struct array. fn named_struct_expr(args: &[ColumnarValue]) -> Result { - // do not accept 0 arguments. + // Do not accept 0 arguments. 
if args.is_empty() { return exec_err!( "named_struct requires at least one pair of arguments, got 0 instead" @@ -86,6 +86,38 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(struct_array))) } +#[user_doc( + doc_section(label = "Struct Functions"), + description = "Returns an Arrow struct using the specified name and input expressions pairs.", + syntax_example = "named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input])", + sql_example = r#" +For example, this query converts two columns `a` and `b` to a single column with +a struct type of fields `field_a` and `field_b`: +```sql +> select * from t; ++---+---+ +| a | b | ++---+---+ +| 1 | 2 | +| 3 | 4 | ++---+---+ +> select named_struct('field_a', a, 'field_b', b) from t; ++-------------------------------------------------------+ +| named_struct(Utf8("field_a"),t.a,Utf8("field_b"),t.b) | ++-------------------------------------------------------+ +| {field_a: 1, field_b: 2} | +| {field_a: 3, field_b: 4} | ++-------------------------------------------------------+ +```"#, + argument( + name = "expression_n_name", + description = "Name of the column field. Must be a constant string." + ), + argument( + name = "expression_n_input", + description = "Expression to include in the output struct. Can be a constant, column, or function, and any combination of arithmetic or string operators." + ) +)] #[derive(Debug)] pub struct NamedStructFunc { signature: Signature, @@ -126,7 +158,7 @@ impl ScalarUDFImpl for NamedStructFunc { fn return_type_from_exprs( &self, - args: &[datafusion_expr::Expr], + args: &[Expr], schema: &dyn datafusion_common::ExprSchema, _arg_types: &[DataType], ) -> Result { @@ -165,7 +197,15 @@ impl ScalarUDFImpl for NamedStructFunc { Ok(DataType::Struct(Fields::from(return_fields))) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { named_struct_expr(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/core/nullif.rs b/datafusion/functions/src/core/nullif.rs index 9f95392c37bff..67b1bab9be0c7 100644 --- a/datafusion/functions/src/core/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -17,38 +17,47 @@ use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, Documentation}; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::nullif::nullif; use datafusion_common::ScalarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; use std::any::Any; - +#[user_doc( + doc_section(label = "Conditional Functions"), + description = "Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_. 
+This can be used to perform the inverse operation of [`coalesce`](#coalesce).", + syntax_example = "nullif(expression1, expression2)", + sql_example = r#"```sql +> select nullif('datafusion', 'data'); ++-----------------------------------------+ +| nullif(Utf8("datafusion"),Utf8("data")) | ++-----------------------------------------+ +| datafusion | ++-----------------------------------------+ +> select nullif('datafusion', 'datafusion'); ++-----------------------------------------------+ +| nullif(Utf8("datafusion"),Utf8("datafusion")) | ++-----------------------------------------------+ +| | ++-----------------------------------------------+ +```"#, + argument( + name = "expression1", + description = "Expression to compare and return if equal to expression2. Can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "expression2", + description = "Expression to compare to expression1. Can be a constant, column, or function, and any combination of operators." + ) +)] #[derive(Debug)] pub struct NullIfFunc { signature: Signature, } -/// Currently supported types by the nullif function. -/// The order of these types correspond to the order on which coercion applies -/// This should thus be from least informative to most informative -static SUPPORTED_NULLIF_TYPES: &[DataType] = &[ - DataType::Boolean, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - DataType::Utf8, - DataType::LargeUtf8, -]; - impl Default for NullIfFunc { fn default() -> Self { Self::new() @@ -58,11 +67,20 @@ impl Default for NullIfFunc { impl NullIfFunc { pub fn new() -> Self { Self { - signature: Signature::uniform( - 2, - SUPPORTED_NULLIF_TYPES.to_vec(), - Volatility::Immutable, - ), + // Documentation mentioned in Postgres, + // The result has the same type as the first argument — but there is a subtlety. + // What is actually returned is the first argument of the implied = operator, + // and in some cases that will have been promoted to match the second argument's type. + // For example, NULLIF(1, 2.2) yields numeric, because there is no integer = numeric operator, only numeric = numeric + // + // We don't strictly follow Postgres or DuckDB for **simplicity**. + // In this function, we will coerce arguments to the same data type for comparison need. Unlike DuckDB + // we don't return the **original** first argument type but return the final coerced type. + // + // In Postgres, nullif('2', 2) returns Null but nullif('2::varchar', 2) returns error. + // While in DuckDB both query returns Null. We follow DuckDB in this case since I think they are equivalent thing and should + // have the same result as well. 
+ signature: Signature::comparable(2, Volatility::Immutable), } } } @@ -80,19 +98,20 @@ impl ScalarUDFImpl for NullIfFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - // NULLIF has two args and they might get coerced, get a preview of this - let coerced_types = datafusion_expr::type_coercion::functions::data_types( - arg_types, - &self.signature, - ); - coerced_types - .map(|typs| typs[0].clone()) - .map_err(|e| e.context("Failed to coerce arguments for NULLIF")) + Ok(arg_types[0].to_owned()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { nullif_func(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Implements NULLIF(expr1, expr2) @@ -186,7 +205,7 @@ mod tests { #[test] // Ensure that arrays with no nulls can also invoke NULLIF() correctly - fn nullif_int32_nonulls() -> Result<()> { + fn nullif_int32_non_nulls() -> Result<()> { let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); let a = ColumnarValue::Array(Arc::new(a)); diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index 09d86ddfc0d90..075f58c195e02 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -20,9 +20,39 @@ use arrow::compute::is_not_null; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; use datafusion_common::{internal_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::sync::Arc; - +#[user_doc( + doc_section(label = "Conditional Functions"), + description = "Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_.", + syntax_example = "nvl(expression1, expression2)", + sql_example = r#"```sql +> select nvl(null, 'a'); ++---------------------+ +| nvl(NULL,Utf8("a")) | ++---------------------+ +| a | ++---------------------+\ +> select nvl('b', 'a'); ++--------------------------+ +| nvl(Utf8("b"),Utf8("a")) | ++--------------------------+ +| b | ++--------------------------+ +``` +"#, + argument( + name = "expression1", + description = "Expression to return if not null. Can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "expression2", + description = "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators." 
+ ) +)] #[derive(Debug)] pub struct NVLFunc { signature: Signature, @@ -44,6 +74,7 @@ static SUPPORTED_NVL_TYPES: &[DataType] = &[ DataType::Int64, DataType::Float32, DataType::Float64, + DataType::Utf8View, DataType::Utf8, DataType::LargeUtf8, ]; @@ -84,13 +115,21 @@ impl ScalarUDFImpl for NVLFunc { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { nvl_func(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } fn nvl_func(args: &[ColumnarValue]) -> Result { @@ -169,7 +208,7 @@ mod tests { #[test] // Ensure that arrays with no nulls can also invoke nvl() correctly - fn nvl_int32_nonulls() -> Result<()> { + fn nvl_int32_non_nulls() -> Result<()> { let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); let a = ColumnarValue::Array(Arc::new(a)); diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index f3027925b26a1..a00dc637f7e37 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -21,11 +21,44 @@ use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::{ - type_coercion::binary::comparison_coercion, ColumnarValue, ScalarUDFImpl, Signature, - Volatility, + type_coercion::binary::comparison_coercion, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; use std::sync::Arc; +#[user_doc( + doc_section(label = "Conditional Functions"), + description = "Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_.", + syntax_example = "nvl2(expression1, expression2, expression3)", + sql_example = r#"```sql +> select nvl2(null, 'a', 'b'); ++--------------------------------+ +| nvl2(NULL,Utf8("a"),Utf8("b")) | ++--------------------------------+ +| b | ++--------------------------------+ +> select nvl2('data', 'a', 'b'); ++----------------------------------------+ +| nvl2(Utf8("data"),Utf8("a"),Utf8("b")) | ++----------------------------------------+ +| a | ++----------------------------------------+ +``` +"#, + argument( + name = "expression1", + description = "Expression to test for null. Can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "expression2", + description = "Expression to return if expr1 is not null. Can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "expression3", + description = "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators." 
+ ) +)] #[derive(Debug)] pub struct NVL2Func { signature: Signature, @@ -62,7 +95,11 @@ impl ScalarUDFImpl for NVL2Func { Ok(arg_types[1].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { nvl2_func(args) } @@ -90,6 +127,10 @@ impl ScalarUDFImpl for NVL2Func { )?; Ok(vec![new_type; arg_types.len()]) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } fn nvl2_func(args: &[ColumnarValue]) -> Result { diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs index 5873b4e1af41c..717a74797c0b5 100644 --- a/datafusion/functions/src/core/planner.rs +++ b/datafusion/functions/src/core/planner.rs @@ -49,7 +49,7 @@ impl ExprPlanner for CoreFunctionPlanner { Ok(PlannerResult::Planned(Expr::ScalarFunction( ScalarFunction::new_udf( if is_named_struct { - crate::core::named_struct() + named_struct() } else { crate::core::r#struct() }, diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index bdddbb81beabe..f5bff2cc726b4 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -18,8 +18,9 @@ use arrow::array::{ArrayRef, StructArray}; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -54,9 +55,50 @@ fn struct_expr(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?)) } +#[user_doc( + doc_section(label = "Struct Functions"), + description = "Returns an Arrow struct using the specified input expressions optionally named. +Fields in the returned struct use the optional name or the `cN` naming convention. +For example: `c0`, `c1`, `c2`, etc.", + syntax_example = "struct(expression1[, ..., expression_n])", + sql_example = r#"For example, this query converts two columns `a` and `b` to a single column with +a struct type of fields `field_a` and `c1`: +```sql +> select * from t; ++---+---+ +| a | b | ++---+---+ +| 1 | 2 | +| 3 | 4 | ++---+---+ + +-- use default names `c0`, `c1` +> select struct(a, b) from t; ++-----------------+ +| struct(t.a,t.b) | ++-----------------+ +| {c0: 1, c1: 2} | +| {c0: 3, c1: 4} | ++-----------------+ + +-- name the first field `field_a` +select struct(a as field_a, b) from t; ++--------------------------------------------------+ +| named_struct(Utf8("field_a"),t.a,Utf8("c1"),t.b) | ++--------------------------------------------------+ +| {field_a: 1, c1: 2} | +| {field_a: 3, c1: 4} | ++--------------------------------------------------+ +```"#, + argument( + name = "expression1, expression_n", + description = "Expression to include in the output struct. Can be a constant, column, or function, any combination of arithmetic or string operators." 
+ ) +)] #[derive(Debug)] pub struct StructFunc { signature: Signature, + aliases: Vec, } impl Default for StructFunc { @@ -69,6 +111,7 @@ impl StructFunc { pub fn new() -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![String::from("row")], } } } @@ -81,6 +124,10 @@ impl ScalarUDFImpl for StructFunc { "struct" } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn signature(&self) -> &Signature { &self.signature } @@ -94,7 +141,15 @@ impl ScalarUDFImpl for StructFunc { Ok(DataType::Struct(Fields::from(return_fields))) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { struct_expr(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index bc3bc1339aa02..0a3f53cecf134 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -17,12 +17,26 @@ //! [`VersionFunc`]: Implementation of the `version` function. -use std::any::Any; - use arrow::datatypes::DataType; -use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - +use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; +use std::any::Any; +#[user_doc( + doc_section(label = "Other Functions"), + description = "Returns the version of DataFusion.", + syntax_example = "version()", + sql_example = r#"```sql +> select version(); ++--------------------------------------------+ +| version() | ++--------------------------------------------+ +| Apache DataFusion 42.0.0, aarch64 on macos | ++--------------------------------------------+ +```"# +)] #[derive(Debug)] pub struct VersionFunc { signature: Signature, @@ -63,11 +77,14 @@ impl ScalarUDFImpl for VersionFunc { } } - fn invoke(&self, _: &[ColumnarValue]) -> Result { - not_impl_err!("version does not take any arguments") - } - - fn invoke_no_args(&self, _: usize) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { + if !args.is_empty() { + return internal_err!("{} function does not accept arguments", self.name()); + } // TODO it would be great to add rust version and arrow version, // but that requires a `build.rs` script and/or adding a version const to arrow-rs let version = format!( @@ -78,6 +95,10 @@ impl ScalarUDFImpl for VersionFunc { ); Ok(ColumnarValue::from(ScalarValue::Utf8(Some(version)))) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } #[cfg(test)] @@ -88,7 +109,8 @@ mod test { #[tokio::test] async fn test_version_udf() { let version_udf = ScalarUDF::from(VersionFunc::new()); - let version = version_udf.invoke_no_args(0).unwrap(); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let version = version_udf.invoke_batch(&[], 1).unwrap(); let ColumnarValue::Scalar(version) = version else { panic!("Expected scalar version") diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index 5c2390d4b9528..fa56c1bf31fc4 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -17,17 +17,18 @@ //! 
"crypto" DataFusion functions -use arrow::array::StringArray; use arrow::array::{Array, ArrayRef, BinaryArray, OffsetSizeTrait}; +use arrow::array::{AsArray, GenericStringArray, StringArray, StringViewArray}; use arrow::datatypes::DataType; use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; use datafusion_common::cast::as_binary_array; +use arrow::compute::StringArrayType; use datafusion_common::plan_err; use datafusion_common::{ - cast::{as_generic_binary_array, as_generic_string_array}, - exec_err, internal_err, DataFusionError, Result, ScalarValue, + cast::as_generic_binary_array, exec_err, internal_err, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::ColumnarValue; use md5::Md5; @@ -120,11 +121,9 @@ pub fn digest(args: &[ColumnarValue]) -> Result { ); } let digest_algorithm = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar.value() { - ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { - method.parse::() - } - other => exec_err!("Unsupported data type {other:?} for function digest"), + ColumnarValue::Scalar(scalar) => match scalar.value().try_as_str() { + Some(Some(method)) => method.parse::(), + _ => exec_err!("Unsupported data type {scalar:?} for function digest"), }, ColumnarValue::Array(_) => { internal_err!("Digest using dynamically decided method is not yet supported") @@ -132,6 +131,7 @@ pub fn digest(args: &[ColumnarValue]) -> Result { }?; digest_process(&args[0], digest_algorithm) } + impl FromStr for DigestAlgorithm { type Err = DataFusionError; fn from_str(name: &str) -> Result { @@ -166,12 +166,14 @@ impl FromStr for DigestAlgorithm { }) } } + impl fmt::Display for DigestAlgorithm { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", format!("{self:?}").to_lowercase()) } } -// /// computes md5 hash digest of the given input + +/// computes md5 hash digest of the given input pub fn md5(args: &[ColumnarValue]) -> Result { if args.len() != 1 { return exec_err!( @@ -180,7 +182,9 @@ pub fn md5(args: &[ColumnarValue]) -> Result { DigestAlgorithm::Md5 ); } + let value = digest_process(&args[0], DigestAlgorithm::Md5)?; + // md5 requires special handling because of its unique utf8 return type Ok(match value { ColumnarValue::Array(array) => { @@ -216,7 +220,8 @@ pub fn utf8_or_binary_to_binary_type( name: &str, ) -> Result { Ok(match arg_type { - DataType::LargeUtf8 + DataType::Utf8View + | DataType::LargeUtf8 | DataType::Utf8 | DataType::Binary | DataType::LargeBinary => DataType::Binary, @@ -298,8 +303,30 @@ impl DigestAlgorithm { where T: OffsetSizeTrait, { - let input_value = as_generic_string_array::(value)?; - let array: ArrayRef = match self { + let array = match value.data_type() { + DataType::Utf8 | DataType::LargeUtf8 => { + let v = value.as_string::(); + self.digest_utf8_array_impl::<&GenericStringArray>(v) + } + DataType::Utf8View => { + let v = value.as_string_view(); + self.digest_utf8_array_impl::<&StringViewArray>(v) + } + other => { + return exec_err!("unsupported type for digest_utf_array: {other:?}") + } + }; + Ok(ColumnarValue::Array(array)) + } + + pub fn digest_utf8_array_impl<'a, StringArrType>( + self, + input_value: StringArrType, + ) -> ArrayRef + where + StringArrType: StringArrayType<'a>, + { + match self { Self::Md5 => digest_to_array!(Md5, input_value), Self::Sha224 => digest_to_array!(Sha224, input_value), Self::Sha256 => digest_to_array!(Sha256, input_value), @@ -320,8 +347,7 @@ impl DigestAlgorithm { .collect(); Arc::new(binary_array) } - }; - 
Ok(ColumnarValue::Array(array)) + } } } pub fn digest_process( @@ -330,6 +356,7 @@ pub fn digest_process( ) -> Result { match value { ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8View => digest_algorithm.digest_utf8_array::(a.as_ref()), DataType::Utf8 => digest_algorithm.digest_utf8_array::(a.as_ref()), DataType::LargeUtf8 => digest_algorithm.digest_utf8_array::(a.as_ref()), DataType::Binary => digest_algorithm.digest_binary_array::(a.as_ref()), @@ -341,7 +368,9 @@ pub fn digest_process( ), }, ColumnarValue::Scalar(scalar) => match scalar.value() { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { + ScalarValue::Utf8View(a) + | ScalarValue::Utf8(a) + | ScalarValue::LargeUtf8(a) => { Ok(digest_algorithm .digest_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) } diff --git a/datafusion/functions/src/crypto/digest.rs b/datafusion/functions/src/crypto/digest.rs index c9dd3c1f56a29..cc52f32614fdf 100644 --- a/datafusion/functions/src/crypto/digest.rs +++ b/datafusion/functions/src/crypto/digest.rs @@ -20,10 +20,37 @@ use super::basic::{digest, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, TypeSignature::*, Volatility, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::*, Volatility, }; +use datafusion_macros::user_doc; use std::any::Any; +#[user_doc( + doc_section(label = "Hashing Functions"), + description = "Computes the binary hash of an expression using the specified algorithm.", + syntax_example = "digest(expression, algorithm)", + sql_example = r#"```sql +> select digest('foo', 'sha256'); ++------------------------------------------+ +| digest(Utf8("foo"), Utf8("sha256")) | ++------------------------------------------+ +| | ++------------------------------------------+ +```"#, + standard_argument(name = "expression", prefix = "String"), + argument( + name = "algorithm", + description = "String expression specifying algorithm to use. 
Must be one of: + - md5 + - sha224 + - sha256 + - sha384 + - sha512 + - blake2s + - blake2b + - blake3" + ) +)] #[derive(Debug)] pub struct DigestFunc { signature: Signature, @@ -40,6 +67,7 @@ impl DigestFunc { Self { signature: Signature::one_of( vec![ + Exact(vec![Utf8View, Utf8View]), Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, Utf8]), Exact(vec![Binary, Utf8]), @@ -66,7 +94,15 @@ impl ScalarUDFImpl for DigestFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { digest(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index ccb6fbba80aad..636ca65735c9a 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -19,9 +19,26 @@ use crate::crypto::basic::md5; use arrow::datatypes::DataType; use datafusion_common::{plan_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; +#[user_doc( + doc_section(label = "Hashing Functions"), + description = "Computes an MD5 128-bit checksum for a string expression.", + syntax_example = "md5(expression)", + sql_example = r#"```sql +> select md5('foo'); ++-------------------------------------+ +| md5(Utf8("foo")) | ++-------------------------------------+ +| | ++-------------------------------------+ +```"#, + standard_argument(name = "expression", prefix = "String") +)] #[derive(Debug)] pub struct Md5Func { signature: Signature, @@ -38,7 +55,7 @@ impl Md5Func { Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8, Binary, LargeBinary], + vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], Volatility::Immutable, ), } @@ -60,11 +77,11 @@ impl ScalarUDFImpl for Md5Func { fn return_type(&self, arg_types: &[DataType]) -> Result { use DataType::*; Ok(match &arg_types[0] { - LargeUtf8 | LargeBinary => LargeUtf8, - Utf8 | Binary => Utf8, + LargeUtf8 | LargeBinary => Utf8, + Utf8View | Utf8 | Binary => Utf8, Null => Null, Dictionary(_, t) => match **t { - LargeUtf8 | LargeBinary => LargeUtf8, + LargeUtf8 | LargeBinary => Utf8, Utf8 | Binary => Utf8, Null => Null, _ => { @@ -81,7 +98,15 @@ impl ScalarUDFImpl for Md5Func { } }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { md5(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/crypto/mod.rs b/datafusion/functions/src/crypto/mod.rs index 46177fc22b601..62ea3c2e27371 100644 --- a/datafusion/functions/src/crypto/mod.rs +++ b/datafusion/functions/src/crypto/mod.rs @@ -27,12 +27,12 @@ pub mod sha224; pub mod sha256; pub mod sha384; pub mod sha512; -make_udf_function!(digest::DigestFunc, DIGEST, digest); -make_udf_function!(md5::Md5Func, MD5, md5); -make_udf_function!(sha224::SHA224Func, SHA224, sha224); -make_udf_function!(sha256::SHA256Func, SHA256, sha256); -make_udf_function!(sha384::SHA384Func, SHA384, sha384); -make_udf_function!(sha512::SHA512Func, SHA512, sha512); +make_udf_function!(digest::DigestFunc, digest); +make_udf_function!(md5::Md5Func, md5); 
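The crypto module registration above now uses the two-argument form of make_udf_function!, dropping the separate SCREAMING_CASE constant name. A hypothetical sketch of what such a two-argument macro could expand to, assuming it keeps a lazily initialised singleton inside the generated accessor (names and details are illustrative, not the real macro expansion):

```rust
// Hypothetical sketch of a two-argument `make_udf_function!`-style macro: the
// lazily initialised singleton lives inside the generated accessor, so no
// separate SCREAMING_CASE static name needs to be passed in. This is an
// assumption about the shape of the change, not the real expansion.
use std::sync::{Arc, OnceLock};

#[derive(Debug, Default)]
struct Md5Func; // stand-in for the real UDF implementation type

macro_rules! make_udf_function {
    ($udf_type:ty, $fn_name:ident) => {
        pub fn $fn_name() -> Arc<$udf_type> {
            static INSTANCE: OnceLock<Arc<$udf_type>> = OnceLock::new();
            Arc::clone(INSTANCE.get_or_init(|| Arc::new(<$udf_type>::default())))
        }
    };
}

make_udf_function!(Md5Func, md5);

fn main() {
    // Repeated calls hand back the same shared instance.
    assert!(Arc::ptr_eq(&md5(), &md5()));
}
```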
+make_udf_function!(sha224::SHA224Func, sha224); +make_udf_function!(sha256::SHA256Func, sha256); +make_udf_function!(sha384::SHA384Func, sha384); +make_udf_function!(sha512::SHA512Func, sha512); pub mod expr_fn { export_functions!(( diff --git a/datafusion/functions/src/crypto/sha224.rs b/datafusion/functions/src/crypto/sha224.rs index d603e5bcf2952..341b3495f9c65 100644 --- a/datafusion/functions/src/crypto/sha224.rs +++ b/datafusion/functions/src/crypto/sha224.rs @@ -19,13 +19,26 @@ use super::basic::{sha224, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; use std::any::Any; -use std::sync::OnceLock; +#[user_doc( + doc_section(label = "Hashing Functions"), + description = "Computes the SHA-224 hash of a binary string.", + syntax_example = "sha224(expression)", + sql_example = r#"```sql +> select sha224('foo'); ++------------------------------------------+ +| sha224(Utf8("foo")) | ++------------------------------------------+ +| | ++------------------------------------------+ +```"#, + standard_argument(name = "expression", prefix = "String") +)] #[derive(Debug)] pub struct SHA224Func { signature: Signature, @@ -43,27 +56,13 @@ impl SHA224Func { Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8, Binary, LargeBinary], + vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], Volatility::Immutable, ), } } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_sha224_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_HASHING) - .with_description("Computes the SHA-224 hash of a binary string.") - .with_syntax_example("sha224(expression)") - .with_standard_argument("expression", "String") - .build() - .unwrap() - }) -} - impl ScalarUDFImpl for SHA224Func { fn as_any(&self) -> &dyn Any { self @@ -81,11 +80,15 @@ impl ScalarUDFImpl for SHA224Func { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { sha224(args) } fn documentation(&self) -> Option<&Documentation> { - Some(get_sha224_doc()) + self.doc() } } diff --git a/datafusion/functions/src/crypto/sha256.rs b/datafusion/functions/src/crypto/sha256.rs index 0a3f3b26e4310..f40dd99c59fec 100644 --- a/datafusion/functions/src/crypto/sha256.rs +++ b/datafusion/functions/src/crypto/sha256.rs @@ -19,9 +19,26 @@ use super::basic::{sha256, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; +#[user_doc( + doc_section(label = "Hashing Functions"), + description = "Computes the SHA-256 hash of a binary string.", + syntax_example = "sha256(expression)", + sql_example = r#"```sql +> select sha256('foo'); ++--------------------------------------+ +| sha256(Utf8("foo")) | ++--------------------------------------+ +| | ++--------------------------------------+ +```"#, + standard_argument(name = "expression", prefix = "String") +)] #[derive(Debug)] pub struct SHA256Func { signature: Signature, @@ -38,7 +55,7 @@ 
impl SHA256Func { Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8, Binary, LargeBinary], + vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], Volatility::Immutable, ), } @@ -60,7 +77,16 @@ impl ScalarUDFImpl for SHA256Func { fn return_type(&self, arg_types: &[DataType]) -> Result { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { sha256(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/crypto/sha384.rs b/datafusion/functions/src/crypto/sha384.rs index c3f7845ce7bd7..e38a755826f81 100644 --- a/datafusion/functions/src/crypto/sha384.rs +++ b/datafusion/functions/src/crypto/sha384.rs @@ -19,9 +19,26 @@ use super::basic::{sha384, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; +#[user_doc( + doc_section(label = "Hashing Functions"), + description = "Computes the SHA-384 hash of a binary string.", + syntax_example = "sha384(expression)", + sql_example = r#"```sql +> select sha384('foo'); ++-----------------------------------------+ +| sha384(Utf8("foo")) | ++-----------------------------------------+ +| | ++-----------------------------------------+ +```"#, + standard_argument(name = "expression", prefix = "String") +)] #[derive(Debug)] pub struct SHA384Func { signature: Signature, @@ -38,7 +55,7 @@ impl SHA384Func { Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8, Binary, LargeBinary], + vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], Volatility::Immutable, ), } @@ -60,7 +77,16 @@ impl ScalarUDFImpl for SHA384Func { fn return_type(&self, arg_types: &[DataType]) -> Result { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { sha384(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/crypto/sha512.rs b/datafusion/functions/src/crypto/sha512.rs index dc3bfac9d8bdb..7fe2a26ebbce1 100644 --- a/datafusion/functions/src/crypto/sha512.rs +++ b/datafusion/functions/src/crypto/sha512.rs @@ -19,9 +19,26 @@ use super::basic::{sha512, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use std::any::Any; +#[user_doc( + doc_section(label = "Hashing Functions"), + description = "Computes the SHA-512 hash of a binary string.", + syntax_example = "sha512(expression)", + sql_example = r#"```sql +> select sha512('foo'); ++-------------------------------------------+ +| sha512(Utf8("foo")) | ++-------------------------------------------+ +| | ++-------------------------------------------+ +```"#, + standard_argument(name = "expression", prefix = "String") +)] #[derive(Debug)] pub struct SHA512Func { signature: Signature, @@ -38,7 +55,7 @@ impl SHA512Func { Self { signature: Signature::uniform( 1, - 
vec![Utf8, LargeUtf8, Binary, LargeBinary], + vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], Volatility::Immutable, ), } @@ -60,7 +77,16 @@ impl ScalarUDFImpl for SHA512Func { fn return_type(&self, arg_types: &[DataType]) -> Result { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { sha512(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index da1cc2f31c76e..533125e729a16 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -18,14 +18,14 @@ use std::sync::Arc; use arrow::array::{ - Array, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, + Array, ArrowPrimitiveType, AsArray, GenericStringArray, PrimitiveArray, + StringArrayType, StringViewArray, }; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::DataType; use chrono::format::{parse, Parsed, StrftimeItems}; use chrono::LocalResult::Single; use chrono::{DateTime, TimeZone, Utc}; -use itertools::Either; use datafusion_common::cast::as_generic_string_array; use datafusion_common::{ @@ -41,14 +41,15 @@ pub(crate) fn string_to_timestamp_nanos_shim(s: &str) -> Result { string_to_timestamp_nanos(s).map_err(|e| e.into()) } -/// Checks that all the arguments from the second are of type [Utf8] or [LargeUtf8] +/// Checks that all the arguments from the second are of type [Utf8], [LargeUtf8] or [Utf8View] /// /// [Utf8]: DataType::Utf8 /// [LargeUtf8]: DataType::LargeUtf8 +/// [Utf8View]: DataType::Utf8View pub(crate) fn validate_data_types(args: &[ColumnarValue], name: &str) -> Result<()> { for (idx, a) in args.iter().skip(1).enumerate() { match a.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { + DataType::Utf8View | DataType::LargeUtf8 | DataType::Utf8 => { // all good } _ => { @@ -178,38 +179,53 @@ pub(crate) fn string_to_timestamp_millis_formatted(s: &str, format: &str) -> Res .timestamp_millis()) } -pub(crate) fn handle<'a, O, F, S>( - args: &'a [ColumnarValue], +pub(crate) fn handle( + args: &[ColumnarValue], op: F, name: &str, ) -> Result where O: ArrowPrimitiveType, S: ScalarType, - F: Fn(&'a str) -> Result, + F: Fn(&str) -> Result, { match &args[0] { ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new( - unary_string_to_primitive_function::(&[a.as_ref()], op, name)?, + DataType::Utf8View => Ok(ColumnarValue::Array(Arc::new( + unary_string_to_primitive_function::<&StringViewArray, O, _>( + a.as_ref().as_string_view(), + op, + )?, + ))), + DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new( + unary_string_to_primitive_function::<&GenericStringArray, O, _>( + a.as_ref().as_string::(), + op, + )?, + ))), + DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new( + unary_string_to_primitive_function::<&GenericStringArray, O, _>( + a.as_ref().as_string::(), + op, + )?, ))), other => exec_err!("Unsupported data type {other:?} for function {name}"), }, - ColumnarValue::Scalar(scalar) => match scalar.value() { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| (op)(x)).transpose()?; + ColumnarValue::Scalar(scalar) => match scalar.value().try_as_str() { + Some(a) => { + let result = a.as_ref().map(|x| 
op(x)).transpose()?; Ok(ColumnarValue::from(S::scalar(result))) } - other => exec_err!("Unsupported data type {other:?} for function {name}"), + _ => exec_err!("Unsupported data type {scalar:?} for function {name}"), }, } } -// given an function that maps a `&str`, `&str` to an arrow native type, +// Given a function that maps a `&str`, `&str` to an arrow native type, // returns a `ColumnarValue` where the function is applied to either a `ArrayRef` or `ScalarValue` // depending on the `args`'s variant. -pub(crate) fn handle_multiple<'a, O, F, S, M>( - args: &'a [ColumnarValue], +pub(crate) fn handle_multiple( + args: &[ColumnarValue], op: F, op2: M, name: &str, @@ -217,24 +233,24 @@ pub(crate) fn handle_multiple<'a, O, F, S, M>( where O: ArrowPrimitiveType, S: ScalarType, - F: Fn(&'a str, &'a str) -> Result, + F: Fn(&str, &str) -> Result, M: Fn(O::Native) -> O::Native, { match &args[0] { ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { + DataType::Utf8View | DataType::LargeUtf8 | DataType::Utf8 => { // validate the column types for (pos, arg) in args.iter().enumerate() { match arg { ColumnarValue::Array(arg) => match arg.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { + DataType::Utf8View | DataType::LargeUtf8 | DataType::Utf8 => { // all good } other => return exec_err!("Unsupported data type {other:?} for function {name}, arg # {pos}"), }, ColumnarValue::Scalar(arg) => { match arg.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { + DataType::Utf8View| DataType::LargeUtf8 | DataType::Utf8 => { // all good } other => return exec_err!("Unsupported data type {other:?} for function {name}, arg # {pos}"), @@ -244,7 +260,7 @@ where } Ok(ColumnarValue::Array(Arc::new( - strings_to_primitive_function::(args, op, op2, name)?, + strings_to_primitive_function::(args, op, op2, name)?, ))) } other => { @@ -252,8 +268,8 @@ where } }, // if the first argument is a scalar utf8 all arguments are expected to be scalar utf8 - ColumnarValue::Scalar(scalar) => match scalar.value() { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { + ColumnarValue::Scalar(scalar) => match scalar.value().try_as_str() { + Some(a) => { let a = a.as_ref(); // ASK: Why do we trust `a` to be non-null at this point? 
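Several scalar branches in this hunk (and in the crypto changes earlier) now match on try_as_str instead of listing the Utf8 / LargeUtf8 / Utf8View variants explicitly. The nested Option distinguishes "not a string-typed scalar at all" from "a string-typed scalar that is NULL", which is exactly how the match arms above use it. A mocked sketch of that contract (stand-in enum; the real method lives on ScalarValue):

```rust
// Mocked sketch of the nested-Option contract behind `try_as_str`, as used by
// the scalar match arms in this patch. Stand-in enum, not DataFusion's ScalarValue.
#[allow(dead_code)]
#[derive(Debug)]
enum Scalar {
    Utf8(Option<String>),
    LargeUtf8(Option<String>),
    Utf8View(Option<String>),
    Int64(Option<i64>),
}

impl Scalar {
    fn try_as_str(&self) -> Option<Option<&str>> {
        match self {
            Scalar::Utf8(v) | Scalar::LargeUtf8(v) | Scalar::Utf8View(v) => {
                Some(v.as_deref()) // Some(None) means "string type, but NULL"
            }
            _ => None, // not a string type at all
        }
    }
}

fn main() {
    assert_eq!(
        Scalar::Utf8(Some("sha256".into())).try_as_str(),
        Some(Some("sha256"))
    );
    assert_eq!(Scalar::Utf8View(None).try_as_str(), Some(None)); // null string
    assert_eq!(Scalar::Int64(Some(1)).try_as_str(), None);       // wrong type
}
```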
let a = unwrap_or_internal_err!(a); @@ -265,13 +281,15 @@ where return exec_err!("Expected scalar of data type {v:?} for function {name}, arg # {pos}"); }; - let (ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x)) = v.value() + let (ScalarValue::Utf8View(x) + | ScalarValue::Utf8(x) + | ScalarValue::LargeUtf8(x)) = v.value() else { return exec_err!("Unsupported data type {v:?} for function {name}, arg # {pos}"); }; if let Some(s) = x { - match op(a.as_str(), s.as_str()) { + match op(a, s.as_str()) { Ok(r) => { ret = Some(Ok(ColumnarValue::from(S::scalar(Some( op2(r), @@ -301,18 +319,16 @@ where /// # Errors /// This function errors iff: /// * the number of arguments is not > 1 or -/// * the array arguments are not castable to a `GenericStringArray` or /// * the function `op` errors for all input -pub(crate) fn strings_to_primitive_function<'a, T, O, F, F2>( - args: &'a [ColumnarValue], +pub(crate) fn strings_to_primitive_function( + args: &[ColumnarValue], op: F, op2: F2, name: &str, ) -> Result> where O: ArrowPrimitiveType, - T: OffsetSizeTrait, - F: Fn(&'a str, &'a str) -> Result, + F: Fn(&str, &str) -> Result, F2: Fn(O::Native) -> O::Native, { if args.len() < 2 { @@ -323,50 +339,81 @@ where ); } - // this will throw the error if any of the array args are not castable to GenericStringArray - let data = args - .iter() - .map(|a| match a { - ColumnarValue::Array(a) => { - Ok(Either::Left(as_generic_string_array::(a.as_ref())?)) + match &args[0] { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8View => { + let string_array = a.as_string_view(); + handle_array_op::( + &string_array, + &args[1..], + op, + op2, + ) } - ColumnarValue::Scalar(s) => match s.value() { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => Ok(Either::Right(a)), - other => exec_err!( - "Unexpected scalar type encountered '{other}' for function '{name}'" - ), - }, - }) - .collect::, &Option>>>>()?; - - let first_arg = &data.first().unwrap().left().unwrap(); + DataType::LargeUtf8 => { + let string_array = as_generic_string_array::(&a)?; + handle_array_op::, F, F2>( + &string_array, + &args[1..], + op, + op2, + ) + } + DataType::Utf8 => { + let string_array = as_generic_string_array::(&a)?; + handle_array_op::, F, F2>( + &string_array, + &args[1..], + op, + op2, + ) + } + other => exec_err!( + "Unsupported data type {other:?} for function substr,\ + expected Utf8View, Utf8 or LargeUtf8." + ), + }, + other => exec_err!( + "Received {} data type, expected only array", + other.data_type() + ), + } +} - first_arg +fn handle_array_op<'a, O, V, F, F2>( + first: &V, + args: &[ColumnarValue], + op: F, + op2: F2, +) -> Result> +where + V: StringArrayType<'a>, + O: ArrowPrimitiveType, + F: Fn(&str, &str) -> Result, + F2: Fn(O::Native) -> O::Native, +{ + first .iter() .enumerate() .map(|(pos, x)| { let mut val = None; - if let Some(x) = x { - let param_args = data.iter().skip(1); - - // go through the args and find the first successful result. Only the last - // failure will be returned if no successful result was received. 
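The removed comment just above describes the behaviour that the new handle_array_op keeps: each format argument is tried in order, the first successful parse wins, and only the last failure surfaces if nothing parses. A small chrono-based sketch of that fallback loop (the real code routes through the string_to_timestamp_*_formatted helpers rather than chrono directly):

```rust
// Sketch of the "first successful parse wins" loop described by the comment
// above and preserved in handle_array_op. Uses chrono directly for brevity;
// the real code calls the string_to_timestamp_*_formatted helpers.
use chrono::NaiveDateTime;

fn parse_with_formats(value: &str, formats: &[&str]) -> Result<NaiveDateTime, String> {
    let mut last_err = String::from("no format arguments supplied");
    for fmt in formats {
        match NaiveDateTime::parse_from_str(value, fmt) {
            Ok(ts) => return Ok(ts), // first success wins
            Err(e) => last_err = format!("'{fmt}': {e}"), // keep only the last failure
        }
    }
    Err(last_err)
}

fn main() {
    let formats = ["%Y-%m-%d %H:%M:%S", "%d/%m/%Y %H:%M"];
    assert!(parse_with_formats("2023-01-01 18:18:18", &formats).is_ok());
    assert!(parse_with_formats("01/02/2023 09:30", &formats).is_ok());
    assert!(parse_with_formats("not a timestamp", &formats).is_err());
}
```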
- for param_arg in param_args { - // param_arg is an array, use the corresponding index into the array as the arg - // we're currently parsing - let p = *param_arg; - let r = if p.is_left() { - let p = p.left().unwrap(); - op(x, p.value(pos)) - } - // args is a scalar, use it directly - else if let Some(p) = p.right().unwrap() { - op(x, p.as_str()) - } else { - continue; - }; + for arg in args { + let v = match arg { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8View => Ok(a.as_string_view().value(pos)), + DataType::LargeUtf8 => Ok(a.as_string::().value(pos)), + DataType::Utf8 => Ok(a.as_string::().value(pos)), + other => exec_err!("Unexpected type encountered '{other}'"), + }, + ColumnarValue::Scalar(s) => match s.value().try_as_str() { + Some(Some(v)) => Ok(v), + Some(None) => continue, // null string + None => exec_err!("Unexpected scalar type encountered '{s}'"), + }, + }?; + let r = op(x, v); if r.is_ok() { val = Some(Ok(op2(r.unwrap()))); break; @@ -387,28 +434,16 @@ where /// # Errors /// This function errors iff: /// * the number of arguments is not 1 or -/// * the first argument is not castable to a `GenericStringArray` or /// * the function `op` errors -fn unary_string_to_primitive_function<'a, T, O, F>( - args: &[&'a dyn Array], +fn unary_string_to_primitive_function<'a, StringArrType, O, F>( + array: StringArrType, op: F, - name: &str, ) -> Result> where + StringArrType: StringArrayType<'a>, O: ArrowPrimitiveType, - T: OffsetSizeTrait, F: Fn(&'a str) -> Result, { - if args.len() != 1 { - return exec_err!( - "{:?} args were supplied but {} takes exactly one argument", - args.len(), - name - ); - } - - let array = as_generic_string_array::(args[0])?; - // first map is the iterator, second is for the `Option<_>` array.iter().map(|x| x.map(&op).transpose()).collect() } diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index 10bc9dabf90fe..624b9eb9e7480 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -23,8 +23,20 @@ use chrono::{Datelike, NaiveDate}; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = r#" +Returns the current UTC date. + +The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes. 
+"#, + syntax_example = "current_date()" +)] #[derive(Debug)] pub struct CurrentDateFunc { signature: Signature, @@ -40,7 +52,7 @@ impl Default for CurrentDateFunc { impl CurrentDateFunc { pub fn new() -> Self { Self { - signature: Signature::uniform(0, vec![], Volatility::Stable), + signature: Signature::nullary(Volatility::Stable), aliases: vec![String::from("today")], } } @@ -69,7 +81,11 @@ impl ScalarUDFImpl for CurrentDateFunc { Ok(Date32) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { internal_err!( "invoke should not be called on a simplified current_date() function" ) @@ -95,4 +111,8 @@ impl ScalarUDFImpl for CurrentDateFunc { ScalarValue::Date32(days), ))) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index 6872a13eebdb9..01a5057538794 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -15,16 +15,27 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::DataType; use arrow::datatypes::DataType::Time64; use arrow::datatypes::TimeUnit::Nanosecond; +use std::any::Any; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = r#" +Returns the current UTC time. +The `current_time()` return value is determined at query time and will return the same time, no matter when in the query plan the function executes. 
+"#, + syntax_example = "current_time()" +)] #[derive(Debug)] pub struct CurrentTimeFunc { signature: Signature, @@ -39,7 +50,7 @@ impl Default for CurrentTimeFunc { impl CurrentTimeFunc { pub fn new() -> Self { Self { - signature: Signature::uniform(0, vec![], Volatility::Stable), + signature: Signature::nullary(Volatility::Stable), } } } @@ -67,7 +78,11 @@ impl ScalarUDFImpl for CurrentTimeFunc { Ok(Time64(Nanosecond)) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { internal_err!( "invoke should not be called on a simplified current_time() function" ) @@ -84,4 +99,8 @@ impl ScalarUDFImpl for CurrentTimeFunc { ScalarValue::Time64Nanosecond(nano), ))) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 05e45327896a1..3032273efc2b7 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -35,11 +35,66 @@ use datafusion_common::{exec_err, not_impl_err, plan_err, Result, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; +use datafusion_macros::user_doc; use chrono::{DateTime, Datelike, Duration, Months, TimeDelta, Utc}; +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = r#" +Calculates time intervals and returns the start of the interval nearest to the specified timestamp. Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" and applying an aggregate or selector function to each window. + +For example, if you "bin" or "window" data into 15 minute intervals, an input timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 minute bin it is in: `2023-01-01T18:15:00Z`. +"#, + syntax_example = "date_bin(interval, expression, origin-timestamp)", + sql_example = r#"```sql +-- Bin the timestamp into 1 day intervals +> SELECT date_bin(interval '1 day', time) as bin +FROM VALUES ('2023-01-01T18:18:18Z'), ('2023-01-03T19:00:03Z') t(time); ++---------------------+ +| bin | ++---------------------+ +| 2023-01-01T00:00:00 | +| 2023-01-03T00:00:00 | ++---------------------+ +2 row(s) fetched. + +-- Bin the timestamp into 1 day intervals starting at 3AM on 2023-01-01 +> SELECT date_bin(interval '1 day', time, '2023-01-01T03:00:00') as bin +FROM VALUES ('2023-01-01T18:18:18Z'), ('2023-01-03T19:00:03Z') t(time); ++---------------------+ +| bin | ++---------------------+ +| 2023-01-01T03:00:00 | +| 2023-01-03T03:00:00 | ++---------------------+ +2 row(s) fetched. +```"#, + argument(name = "interval", description = "Bin interval."), + argument( + name = "expression", + description = "Time expression to operate on. Can be a constant, column, or function." + ), + argument( + name = "origin-timestamp", + description = r#"Optional. Starting point used to determine bin boundaries. If not specified defaults 1970-01-01T00:00:00Z (the UNIX epoch in UTC). 
The following intervals are supported: + + - nanoseconds + - microseconds + - milliseconds + - seconds + - minutes + - hours + - days + - weeks + - months + - years + - century +"# + ) +)] #[derive(Debug)] pub struct DateBinFunc { signature: Signature, @@ -132,7 +187,11 @@ impl ScalarUDFImpl for DateBinFunc { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.len() == 2 { // Default to unix EPOCH let origin = ColumnarValue::from(ScalarValue::TimestampNanosecond( @@ -163,6 +222,9 @@ impl ScalarUDFImpl for DateBinFunc { Ok(SortProperties::Unordered) } } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } enum Interval { @@ -201,7 +263,7 @@ fn date_bin_nanos_interval(stride_nanos: i64, source: i64, origin: i64) -> i64 { fn compute_distance(time_diff: i64, stride: i64) -> i64 { let time_delta = time_diff - (time_diff % stride); - if time_diff < 0 && stride > 1 { + if time_diff < 0 && stride > 1 && time_delta != time_diff { // The origin is later than the source timestamp, round down to the previous bin time_delta - stride } else { @@ -252,7 +314,7 @@ fn to_utc_date_time(nanos: i64) -> DateTime { // Supported intervals: // 1. IntervalDayTime: this means that the stride is in days, hours, minutes, seconds and milliseconds // We will assume month interval won't be converted into this type -// TODO (my next PR): without `INTERVAL` keyword, the stride was converted into ScalarValue::IntervalDayTime somwhere +// TODO (my next PR): without `INTERVAL` keyword, the stride was converted into ScalarValue::IntervalDayTime somewhere // for month interval. I need to find that and make it ScalarValue::IntervalMonthDayNano instead // 2. IntervalMonthDayNano fn date_bin_impl( @@ -447,7 +509,7 @@ mod tests { use crate::datetime::date_bin::{date_bin_nanos_interval, DateBinFunc}; use arrow::array::types::TimestampNanosecondType; - use arrow::array::{IntervalDayTimeArray, TimestampNanosecondArray}; + use arrow::array::{Array, IntervalDayTimeArray, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::{DataType, TimeUnit}; @@ -458,49 +520,69 @@ mod tests { use chrono::TimeDelta; #[test] + #[allow(deprecated)] // TODO migrate UDF invoke from invoke_batch fn test_date_bin() { - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 0, - milliseconds: 1, - }))), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 1, + }, + ))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + 1, + ); assert!(res.is_ok()); let timestamps = Arc::new((1..6).map(Some).collect::()); - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 0, - milliseconds: 1, - }))), - ColumnarValue::Array(timestamps), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + let batch_len = timestamps.len(); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 1, + }, 
+ ))), + ColumnarValue::Array(timestamps), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + batch_len, + ); assert!(res.is_ok()); - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 0, - milliseconds: 1, - }))), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 1, + }, + ))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + 1, + ); assert!(res.is_ok()); // stride supports month-day-nano - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano { - months: 0, - days: 0, - nanoseconds: 1, - }, - ))), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 1, + }, + ))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + 1, + ); assert!(res.is_ok()); // @@ -508,97 +590,129 @@ mod tests { // // invalid number of arguments - let res = DateBinFunc::new().invoke(&[ColumnarValue::from( - ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 0, - milliseconds: 1, - })), - )]); + let res = DateBinFunc::new().invoke_batch( + &[ColumnarValue::from(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 1, + }, + )))], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expected two or three arguments" ); // stride: invalid type - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::IntervalYearMonth(Some(1))), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::IntervalYearMonth(Some(1))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects stride argument to be an INTERVAL but got Interval(YearMonth)" ); // stride: invalid value - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 0, - milliseconds: 0, - }))), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 0, + }, + ))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride must be non-zero" ); // stride: overflow of day-time interval - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime::MAX))), 
- ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::IntervalDayTime(Some( + IntervalDayTime::MAX, + ))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); // stride: overflow of month-day-nano interval - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); // stride: month intervals - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::new_interval_mdn(1, 1, 1)), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::new_interval_mdn(1, 1, 1)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN stride does not support combination of month, day and nanosecond intervals" ); // origin: invalid type - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 0, - milliseconds: 1, - }))), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::from(ScalarValue::TimestampMicrosecond(Some(1), None)), - ]); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 1, + }, + ))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampMicrosecond(Some(1), None)), + ], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision but got Timestamp(Microsecond, None)" ); - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 0, - milliseconds: 1, - }))), - ColumnarValue::from(ScalarValue::TimestampMicrosecond(Some(1), None)), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 1, + }, + ))), + ColumnarValue::from(ScalarValue::TimestampMicrosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + 1, + ); assert!(res.is_ok()); // unsupported array type for stride @@ -612,11 +726,14 @@ mod tests { }) 
.collect::(), ); - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Array(intervals), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::Array(intervals), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the stride argument, not arrays" @@ -624,14 +741,20 @@ mod tests { // unsupported array type for origin let timestamps = Arc::new((1..6).map(Some).collect::()); - let res = DateBinFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 0, - milliseconds: 1, - }))), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Array(timestamps), - ]); + let batch_len = timestamps.len(); + let res = DateBinFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 1, + }, + ))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Array(timestamps), + ], + batch_len, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the origin argument, not arrays" @@ -746,15 +869,20 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::() .with_timezone_opt(tz_opt.clone()); + let batch_len = input.len(); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = DateBinFunc::new() - .invoke(&[ - ColumnarValue::from(ScalarValue::new_interval_dt(1, 0)), - ColumnarValue::Array(Arc::new(input)), - ColumnarValue::from(ScalarValue::TimestampNanosecond( - Some(string_to_timestamp_nanos(origin).unwrap()), - tz_opt.clone(), - )), - ]) + .invoke_batch( + &[ + ColumnarValue::from(ScalarValue::new_interval_dt(1, 0)), + ColumnarValue::Array(Arc::new(input)), + ColumnarValue::from(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos(origin).unwrap()), + tz_opt.clone(), + )), + ], + batch_len, + ) .unwrap(); if let ColumnarValue::Array(result) = result { assert_eq!( @@ -829,4 +957,32 @@ mod tests { assert_eq!(result, expected1, "{source} = {expected}"); }) } + + #[test] + fn test_date_bin_before_epoch() { + let cases = [ + ( + (TimeDelta::try_minutes(15), "1969-12-31T23:44:59.999999999"), + "1969-12-31T23:30:00", + ), + ( + (TimeDelta::try_minutes(15), "1969-12-31T23:45:00"), + "1969-12-31T23:45:00", + ), + ( + (TimeDelta::try_minutes(15), "1969-12-31T23:45:00.000000001"), + "1969-12-31T23:45:00", + ), + ]; + + cases.iter().for_each(|((stride, source), expected)| { + let stride = stride.unwrap(); + let stride1 = stride.num_nanoseconds().unwrap(); + let source1 = string_to_timestamp_nanos(source).unwrap(); + + let expected1 = string_to_timestamp_nanos(expected).unwrap(); + let result = date_bin_nanos_interval(stride1, source1, 0); + assert_eq!(result, expected1, "{source} = {expected}"); + }) + } } diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index b1867e09c2404..082ec2ce283cb 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -19,29 +19,64 @@ use std::any::Any; use 
std::str::FromStr; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, Float64Array}; +use arrow::array::{Array, ArrayRef, Float64Array, Int32Array}; use arrow::compute::kernels::cast_utils::IntervalUnit; -use arrow::compute::{binary, cast, date_part, DatePart}; +use arrow::compute::{binary, date_part, DatePart}; use arrow::datatypes::DataType::{ - Date32, Date64, Duration, Float64, Interval, Time32, Time64, Timestamp, Utf8, - Utf8View, + Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, }; -use arrow::datatypes::IntervalUnit::{DayTime, MonthDayNano, YearMonth}; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_int32_array, as_time32_millisecond_array, - as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, - as_timestamp_microsecond_array, as_timestamp_millisecond_array, - as_timestamp_nanosecond_array, as_timestamp_second_array, +use datafusion_common::not_impl_err; +use datafusion_common::{ + cast::{ + as_date32_array, as_date64_array, as_int32_array, as_time32_millisecond_array, + as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, + as_timestamp_microsecond_array, as_timestamp_millisecond_array, + as_timestamp_nanosecond_array, as_timestamp_second_array, + }, + exec_err, internal_err, + types::logical_string, + ExprSchema, Result, ScalarValue, }; -use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, TypeSignature, + Volatility, }; +use datafusion_expr_common::signature::TypeSignatureClass; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = "Returns the specified part of the date as an integer.", + syntax_example = "date_part(part, expression)", + alternative_syntax = "extract(field FROM source)", + argument( + name = "part", + description = r#"Part of the date to return. The following date parts are supported: + + - year + - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) + - month + - week (week of the year) + - day (day of the month) + - hour + - minute + - second + - millisecond + - microsecond + - nanosecond + - dow (day of the week) + - doy (day of the year) + - epoch (seconds since Unix epoch) +"# + ), + argument( + name = "expression", + description = "Time expression to operate on. Can be a constant, column, or function." 
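date_part now returns Int32 for every part except epoch (which stays Float64), and the sub-second parts are produced by the seconds_as_i32 helper added later in this file, which combines the whole-second and nanosecond components of the input. A worked sketch of that integer arithmetic, using the same second_factor / conversion_factor pairs as the code below:

```rust
// Worked sketch of the arithmetic in the `seconds_as_i32` helper added later
// in this file: combine DatePart::Second and DatePart::Nanosecond into a
// single i32 for the requested unit.
fn combine(secs: i32, subsec_nanos: i32, second_factor: i32, conversion_factor: i32) -> i32 {
    secs * second_factor + (subsec_nanos % 1_000_000_000) / conversion_factor
}

fn main() {
    // A timestamp whose time-of-day is 12:34:56.789012345
    let (secs, subsecs) = (56, 789_012_345);
    assert_eq!(combine(secs, subsecs, 1, 1_000_000_000), 56); // date_part('second', ..)
    assert_eq!(combine(secs, subsecs, 1_000, 1_000_000), 56_789); // date_part('millisecond', ..)
    assert_eq!(combine(secs, subsecs, 1_000_000, 1_000), 56_789_012); // date_part('microsecond', ..)
}
```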
+ ) +)] #[derive(Debug)] pub struct DatePartFunc { signature: Signature, @@ -59,72 +94,26 @@ impl DatePartFunc { Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Timestamp(Nanosecond, None)]), - Exact(vec![Utf8View, Timestamp(Nanosecond, None)]), - Exact(vec![ - Utf8, - Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Timestamp, ]), - Exact(vec![ - Utf8View, - Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Date, ]), - Exact(vec![Utf8, Timestamp(Millisecond, None)]), - Exact(vec![Utf8View, Timestamp(Millisecond, None)]), - Exact(vec![ - Utf8, - Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Time, ]), - Exact(vec![ - Utf8View, - Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Interval, ]), - Exact(vec![Utf8, Timestamp(Microsecond, None)]), - Exact(vec![Utf8View, Timestamp(Microsecond, None)]), - Exact(vec![ - Utf8, - Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Duration, ]), - Exact(vec![ - Utf8View, - Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![Utf8, Timestamp(Second, None)]), - Exact(vec![Utf8View, Timestamp(Second, None)]), - Exact(vec![ - Utf8, - Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![ - Utf8View, - Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![Utf8, Date64]), - Exact(vec![Utf8View, Date64]), - Exact(vec![Utf8, Date32]), - Exact(vec![Utf8View, Date32]), - Exact(vec![Utf8, Time32(Second)]), - Exact(vec![Utf8View, Time32(Second)]), - Exact(vec![Utf8, Time32(Millisecond)]), - Exact(vec![Utf8View, Time32(Millisecond)]), - Exact(vec![Utf8, Time64(Microsecond)]), - Exact(vec![Utf8View, Time64(Microsecond)]), - Exact(vec![Utf8, Time64(Nanosecond)]), - Exact(vec![Utf8View, Time64(Nanosecond)]), - Exact(vec![Utf8, Interval(YearMonth)]), - Exact(vec![Utf8View, Interval(YearMonth)]), - Exact(vec![Utf8, Interval(DayTime)]), - Exact(vec![Utf8View, Interval(DayTime)]), - Exact(vec![Utf8, Interval(MonthDayNano)]), - Exact(vec![Utf8View, Interval(MonthDayNano)]), - Exact(vec![Utf8, Duration(Second)]), - Exact(vec![Utf8View, Duration(Second)]), - Exact(vec![Utf8, Duration(Millisecond)]), - Exact(vec![Utf8View, Duration(Millisecond)]), - Exact(vec![Utf8, Duration(Microsecond)]), - Exact(vec![Utf8View, Duration(Microsecond)]), - Exact(vec![Utf8, Duration(Nanosecond)]), - Exact(vec![Utf8View, Duration(Nanosecond)]), ], Volatility::Immutable, ), @@ -147,10 +136,29 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Float64) + internal_err!("return_type_from_exprs should be called instead") + } + + fn return_type_from_exprs( + &self, + args: &[Expr], + _schema: &dyn ExprSchema, + _arg_types: &[DataType], + ) -> Result { + match &args[0] { + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(part)) if is_epoch(part) => Ok(DataType::Float64), + _ => Ok(DataType::Int32), + }, + _ => Ok(DataType::Int32), + } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + 
args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.len() != 2 { return exec_err!("Expected two arguments in DATE_PART"); } @@ -174,35 +182,31 @@ impl ScalarUDFImpl for DatePartFunc { ColumnarValue::Scalar(scalar) => scalar.to_array()?, }; - // to remove quotes at most 2 characters - let part_trim = part.trim_matches(|c| c == '\'' || c == '\"'); - if ![2, 0].contains(&(part.len() - part_trim.len())) { - return exec_err!("Date part '{part}' not supported"); - } + let part_trim = part_normalization(part); // using IntervalUnit here means we hand off all the work of supporting plurals (like "seconds") // and synonyms ( like "ms,msec,msecond,millisecond") to Arrow let arr = if let Ok(interval_unit) = IntervalUnit::from_str(part_trim) { match interval_unit { - IntervalUnit::Year => date_part_f64(array.as_ref(), DatePart::Year)?, - IntervalUnit::Month => date_part_f64(array.as_ref(), DatePart::Month)?, - IntervalUnit::Week => date_part_f64(array.as_ref(), DatePart::Week)?, - IntervalUnit::Day => date_part_f64(array.as_ref(), DatePart::Day)?, - IntervalUnit::Hour => date_part_f64(array.as_ref(), DatePart::Hour)?, - IntervalUnit::Minute => date_part_f64(array.as_ref(), DatePart::Minute)?, - IntervalUnit::Second => seconds(array.as_ref(), Second)?, - IntervalUnit::Millisecond => seconds(array.as_ref(), Millisecond)?, - IntervalUnit::Microsecond => seconds(array.as_ref(), Microsecond)?, - IntervalUnit::Nanosecond => seconds(array.as_ref(), Nanosecond)?, + IntervalUnit::Year => date_part(array.as_ref(), DatePart::Year)?, + IntervalUnit::Month => date_part(array.as_ref(), DatePart::Month)?, + IntervalUnit::Week => date_part(array.as_ref(), DatePart::Week)?, + IntervalUnit::Day => date_part(array.as_ref(), DatePart::Day)?, + IntervalUnit::Hour => date_part(array.as_ref(), DatePart::Hour)?, + IntervalUnit::Minute => date_part(array.as_ref(), DatePart::Minute)?, + IntervalUnit::Second => seconds_as_i32(array.as_ref(), Second)?, + IntervalUnit::Millisecond => seconds_as_i32(array.as_ref(), Millisecond)?, + IntervalUnit::Microsecond => seconds_as_i32(array.as_ref(), Microsecond)?, + IntervalUnit::Nanosecond => seconds_as_i32(array.as_ref(), Nanosecond)?, // century and decade are not supported by `DatePart`, although they are supported in postgres _ => return exec_err!("Date part '{part}' not supported"), } } else { // special cases that can be extracted (in postgres) but are not interval units match part_trim.to_lowercase().as_str() { - "qtr" | "quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?, - "doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?, - "dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?, + "qtr" | "quarter" => date_part(array.as_ref(), DatePart::Quarter)?, + "doy" => date_part(array.as_ref(), DatePart::DayOfYear)?, + "dow" => date_part(array.as_ref(), DatePart::DayOfWeekSunday0)?, "epoch" => epoch(array.as_ref())?, _ => return exec_err!("Date part '{part}' not supported"), } @@ -218,16 +222,81 @@ impl ScalarUDFImpl for DatePartFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn is_epoch(part: &str) -> bool { + let part = part_normalization(part); + matches!(part.to_lowercase().as_str(), "epoch") +} + +// Try to remove quote if exist, if the quote is invalid, return original string and let the downstream function handle the error +fn part_normalization(part: &str) -> &str { + part.strip_prefix(|c| c == '\'' || c == '\"') + .and_then(|s| 
s.strip_suffix(|c| c == '\'' || c == '\"')) + .unwrap_or(part) } -/// Invoke [`date_part`] and cast the result to Float64 -fn date_part_f64(array: &dyn Array, part: DatePart) -> Result { - Ok(cast(date_part(array, part)?.as_ref(), &Float64)?) +/// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the +/// result to a total number of seconds, milliseconds, microseconds or +/// nanoseconds +fn seconds_as_i32(array: &dyn Array, unit: TimeUnit) -> Result { + // Nanosecond is neither supported in Postgres nor DuckDB, to avoid dealing + // with overflow and precision issue we don't support nanosecond + if unit == Nanosecond { + return not_impl_err!("Date part {unit:?} not supported"); + } + + let conversion_factor = match unit { + Second => 1_000_000_000, + Millisecond => 1_000_000, + Microsecond => 1_000, + Nanosecond => 1, + }; + + let second_factor = match unit { + Second => 1, + Millisecond => 1_000, + Microsecond => 1_000_000, + Nanosecond => 1_000_000_000, + }; + + let secs = date_part(array, DatePart::Second)?; + // This assumes array is primitive and not a dictionary + let secs = as_int32_array(secs.as_ref())?; + let subsecs = date_part(array, DatePart::Nanosecond)?; + let subsecs = as_int32_array(subsecs.as_ref())?; + + // Special case where there are no nulls. + if subsecs.null_count() == 0 { + let r: Int32Array = binary(secs, subsecs, |secs, subsecs| { + secs * second_factor + (subsecs % 1_000_000_000) / conversion_factor + })?; + Ok(Arc::new(r)) + } else { + // Nulls in secs are preserved, nulls in subsecs are treated as zero to account for the case + // where the number of nanoseconds overflows. + let r: Int32Array = secs + .iter() + .zip(subsecs) + .map(|(secs, subsecs)| { + secs.map(|secs| { + let subsecs = subsecs.unwrap_or(0); + secs * second_factor + (subsecs % 1_000_000_000) / conversion_factor + }) + }) + .collect(); + Ok(Arc::new(r)) + } } /// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the /// result to a total number of seconds, milliseconds, microseconds or /// nanoseconds +/// +/// Given epoch return f64, this is a duplicated function to optimize for f64 type fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { let sf = match unit { Second => 1_f64, diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index eb9bc5109588b..32220045b48e2 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -36,13 +36,37 @@ use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; +use datafusion_macros::user_doc; use chrono::{ DateTime, Datelike, Duration, LocalResult, NaiveDateTime, Offset, TimeDelta, Timelike, }; +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = "Truncates a timestamp value to a specified precision.", + syntax_example = "date_trunc(precision, expression)", + argument( + name = "precision", + description = r#"Time precision to truncate to. 
The following precisions are supported: + + - year / YEAR + - quarter / QUARTER + - month / MONTH + - week / WEEK + - day / DAY + - hour / HOUR + - minute / MINUTE + - second / SECOND +"# + ), + argument( + name = "expression", + description = "Time expression to operate on. Can be a constant, column, or function." + ) +)] #[derive(Debug)] pub struct DateTruncFunc { signature: Signature, @@ -136,7 +160,11 @@ impl ScalarUDFImpl for DateTruncFunc { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let (granularity, array) = (&args[0], &args[1]); let ColumnarValue::Scalar(granularity) = granularity else { @@ -244,6 +272,9 @@ impl ScalarUDFImpl for DateTruncFunc { Ok(SortProperties::Unordered) } } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } fn _date_trunc_coarse(granularity: &str, value: Option) -> Result> @@ -452,7 +483,7 @@ mod tests { use arrow::array::cast::as_primitive_array; use arrow::array::types::TimestampNanosecondType; - use arrow::array::TimestampNanosecondArray; + use arrow::array::{Array, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::ScalarValue; @@ -692,11 +723,16 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::() .with_timezone_opt(tz_opt.clone()); + let batch_len = input.len(); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = DateTruncFunc::new() - .invoke(&[ - ColumnarValue::from(ScalarValue::from("day")), - ColumnarValue::Array(Arc::new(input)), - ]) + .invoke_batch( + &[ + ColumnarValue::from(ScalarValue::from("day")), + ColumnarValue::Array(Arc::new(input)), + ], + batch_len, + ) .unwrap(); if let ColumnarValue::Array(result) = result { assert_eq!( @@ -850,11 +886,16 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::() .with_timezone_opt(tz_opt.clone()); + let batch_len = input.len(); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = DateTruncFunc::new() - .invoke(&[ - ColumnarValue::from(ScalarValue::from("hour")), - ColumnarValue::Array(Arc::new(input)), - ]) + .invoke_batch( + &[ + ColumnarValue::from(ScalarValue::from("hour")), + ColumnarValue::Array(Arc::new(input)), + ], + batch_len, + ) .unwrap(); if let ColumnarValue::Array(result) = result { assert_eq!( diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index d36ebe735ee70..5cc69442c1c27 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -16,14 +16,36 @@ // under the License. 
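// A minimal standalone sketch (plain Rust, no Arrow arrays) of the per-element arithmetic
// that the `seconds_as_i32` helper added in the date_part.rs hunk above performs: the second
// component and the sub-second nanoseconds are combined into a single i32 in the requested
// unit. The real helper rejects Nanosecond to avoid i32 overflow; `Unit` and `combine` are
// illustrative names here, not DataFusion APIs.
#[derive(Clone, Copy)]
enum Unit {
    Second,
    Millisecond,
    Microsecond,
}

fn combine(secs: i32, subsec_nanos: i32, unit: Unit) -> i32 {
    // nanoseconds per output unit, and output units per second
    let (nanos_per_unit, units_per_second) = match unit {
        Unit::Second => (1_000_000_000, 1),
        Unit::Millisecond => (1_000_000, 1_000),
        Unit::Microsecond => (1_000, 1_000_000),
    };
    secs * units_per_second + (subsec_nanos % 1_000_000_000) / nanos_per_unit
}

fn main() {
    // 12 s + 345_678_901 ns -> 12 s, 12_345 ms, or 12_345_678 us
    assert_eq!(combine(12, 345_678_901, Unit::Second), 12);
    assert_eq!(combine(12, 345_678_901, Unit::Millisecond), 12_345);
    assert_eq!(combine(12, 345_678_901, Unit::Microsecond), 12_345_678);
}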
use std::any::Any; +use std::sync::Arc; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Int64, Timestamp}; +use arrow::datatypes::DataType::{Int64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::Second; +use datafusion_common::{exec_err, internal_err, ExprSchema, Result, ScalarValue}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = "Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp.", + syntax_example = "from_unixtime(expression[, timezone])", + sql_example = r#"```sql +> select from_unixtime(1599572549, 'America/New_York'); ++-----------------------------------------------------------+ +| from_unixtime(Int64(1599572549),Utf8("America/New_York")) | ++-----------------------------------------------------------+ +| 2020-09-08T09:42:29-04:00 | ++-----------------------------------------------------------+ +```"#, + standard_argument(name = "expression",), + argument( + name = "timezone", + description = "Optional timezone to use when converting the integer to a timestamp. If not provided, the default timezone is UTC." + ) +)] #[derive(Debug)] pub struct FromUnixtimeFunc { signature: Signature, @@ -38,7 +60,10 @@ impl Default for FromUnixtimeFunc { impl FromUnixtimeFunc { pub fn new() -> Self { Self { - signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + signature: Signature::one_of( + vec![Exact(vec![Int64, Utf8]), Exact(vec![Int64])], + Volatility::Immutable, + ), } } } @@ -56,26 +81,130 @@ impl ScalarUDFImpl for FromUnixtimeFunc { &self.signature } + fn return_type_from_exprs( + &self, + args: &[Expr], + _schema: &dyn ExprSchema, + arg_types: &[DataType], + ) -> Result { + match arg_types.len() { + 1 => Ok(Timestamp(Second, None)), + 2 => match &args[1] { + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(tz)) => Ok(Timestamp(Second, Some(Arc::from(tz.to_string())))), + _ => exec_err!( + "Second argument for `from_unixtime` must be non-null utf8, received {:?}", + arg_types[1]), + } , + _ => exec_err!( + "Second argument for `from_unixtime` must be non-null utf8, received {:?}", + arg_types[1]), + }, + _ => exec_err!( + "from_unixtime function requires 1 or 2 arguments, got {}", + arg_types.len() + ), + } + } + fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Second, None)) + internal_err!("call return_type_from_exprs instead") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - if args.len() != 1 { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { + let len = args.len(); + if len != 1 && len != 2 { return exec_err!( - "from_unixtime function requires 1 argument, got {}", + "from_unixtime function requires 1 or 2 argument, got {}", args.len() ); } - match args[0].data_type() { - Int64 => args[0].cast_to(&Timestamp(Second, None), None), - other => { - exec_err!( - "Unsupported data type {:?} for function from_unixtime", - other - ) + if *args[0].data_type() != Int64 { + return exec_err!( + "Unsupported data type {:?} for function from_unixtime", + 
args[0].data_type() + ); + } + + match len { + 1 => args[0].cast_to(&Timestamp(Second, None), None), + 2 => match &args[1] { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(tz)) => args[0].cast_to( + &Timestamp(Second, Some(Arc::from(tz.to_string()))), + None, + ), + _ => exec_err!( + "Unsupported data type {:?} for function from_unixtime", + args[1].data_type() + ), + }, + _ => { + exec_err!( + "Unsupported data type {:?} for function from_unixtime", + args[1].data_type() + ) + } + }, + _ => unreachable!(), + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +#[cfg(test)] +mod test { + use crate::datetime::from_unixtime::FromUnixtimeFunc; + use datafusion_common::ScalarValue; + use datafusion_common::ScalarValue::Int64; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use std::sync::Arc; + + #[test] + fn test_without_timezone() { + let args = [ColumnarValue::from(Int64(Some(1729900800)))]; + + // TODO use invoke_with_args + let result = FromUnixtimeFunc::new().invoke_batch(&args, 1).unwrap(); + + match result { + ColumnarValue::Scalar(scalar) => { + assert_eq!( + scalar.value(), + &ScalarValue::TimestampSecond(Some(1729900800), None) + ); + } + _ => panic!("Expected scalar value"), + } + } + + #[test] + fn test_with_timezone() { + let args = [ + ColumnarValue::from(Int64(Some(1729900800))), + ColumnarValue::from(ScalarValue::Utf8(Some("America/New_York".to_string()))), + ]; + + // TODO use invoke_with_args + let result = FromUnixtimeFunc::new().invoke_batch(&args, 2).unwrap(); + + match result { + ColumnarValue::Scalar(scalar) => { + let expected = ScalarValue::TimestampSecond( + Some(1729900800), + Some(Arc::from("America/New_York")), + ); + assert_eq!(scalar.value(), &expected); } + _ => panic!("Expected scalar value"), } } } diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index ed30f5f5a26cd..47d967751013c 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -27,8 +27,45 @@ use arrow::datatypes::DataType::{Date32, Int32, Int64, UInt32, UInt64, Utf8, Utf use chrono::prelude::*; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = "Make a date from year/month/day component parts.", + syntax_example = "make_date(year, month, day)", + sql_example = r#"```sql +> select make_date(2023, 1, 31); ++-------------------------------------------+ +| make_date(Int64(2023),Int64(1),Int64(31)) | ++-------------------------------------------+ +| 2023-01-31 | ++-------------------------------------------+ +> select make_date('2023', '01', '31'); ++-----------------------------------------------+ +| make_date(Utf8("2023"),Utf8("01"),Utf8("31")) | ++-----------------------------------------------+ +| 2023-01-31 | ++-----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) +"#, + argument( + name = "year", + description = "Year to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators." 
+ ), + argument( + name = "month", + description = "Month to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators." + ), + argument( + name = "day", + description = "Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators." + ) +)] #[derive(Debug)] pub struct MakeDateFunc { signature: Signature, @@ -69,7 +106,11 @@ impl ScalarUDFImpl for MakeDateFunc { Ok(Date32) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.len() != 3 { return exec_err!( "make_date function requires 3 arguments, got {}", @@ -86,9 +127,9 @@ impl ScalarUDFImpl for MakeDateFunc { ColumnarValue::Array(a) => Some(a.len()), }); - let years = args[0].cast_to(&DataType::Int32, None)?; - let months = args[1].cast_to(&DataType::Int32, None)?; - let days = args[2].cast_to(&DataType::Int32, None)?; + let years = args[0].cast_to(&Int32, None)?; + let months = args[1].cast_to(&Int32, None)?; + let days = args[2].cast_to(&Int32, None)?; let scalar_value_fn = |col: &ColumnarValue| -> Result { let ColumnarValue::Scalar(s) = col else { @@ -148,6 +189,9 @@ impl ScalarUDFImpl for MakeDateFunc { Ok(value) } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Converts the year/month/day fields to an `i32` representing the days from @@ -190,12 +234,16 @@ mod tests { #[test] fn test_make_date() { + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch let res = MakeDateFunc::new() - .invoke(&[ - ColumnarValue::from(ScalarValue::Int32(Some(2024))), - ColumnarValue::from(ScalarValue::Int64(Some(1))), - ColumnarValue::from(ScalarValue::UInt32(Some(14))), - ]) + .invoke_batch( + &[ + ColumnarValue::from(ScalarValue::Int32(Some(2024))), + ColumnarValue::from(ScalarValue::Int64(Some(1))), + ColumnarValue::from(ScalarValue::UInt32(Some(14))), + ], + 1, + ) .expect("that make_date parsed values without error"); let ColumnarValue::Scalar(scalar) = res else { @@ -211,12 +259,16 @@ mod tests { ) } + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch let res = MakeDateFunc::new() - .invoke(&[ - ColumnarValue::from(ScalarValue::Int64(Some(2024))), - ColumnarValue::from(ScalarValue::UInt64(Some(1))), - ColumnarValue::from(ScalarValue::UInt32(Some(14))), - ]) + .invoke_batch( + &[ + ColumnarValue::from(ScalarValue::Int64(Some(2024))), + ColumnarValue::from(ScalarValue::UInt64(Some(1))), + ColumnarValue::from(ScalarValue::UInt32(Some(14))), + ], + 1, + ) .expect("that make_date parsed values without error"); let ColumnarValue::Scalar(scalar) = res else { @@ -232,12 +284,16 @@ mod tests { ) } + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch let res = MakeDateFunc::new() - .invoke(&[ - ColumnarValue::from(ScalarValue::Utf8(Some("2024".to_string()))), - ColumnarValue::from(ScalarValue::LargeUtf8(Some("1".to_string()))), - ColumnarValue::from(ScalarValue::Utf8(Some("14".to_string()))), - ]) + .invoke_batch( + &[ + ColumnarValue::from(ScalarValue::Utf8(Some("2024".to_string()))), + ColumnarValue::from(ScalarValue::LargeUtf8(Some("1".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("14".to_string()))), + ], + 1, + ) .expect("that make_date parsed values without error"); let ColumnarValue::Scalar(scalar) = res else { @@ -256,12 +312,17 @@ mod tests { let years = Arc::new((2021..2025).map(Some).collect::()); let months = 
Arc::new((1..5).map(Some).collect::()); let days = Arc::new((11..15).map(Some).collect::()); + let batch_len = years.len(); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch let res = MakeDateFunc::new() - .invoke(&[ - ColumnarValue::Array(years), - ColumnarValue::Array(months), - ColumnarValue::Array(days), - ]) + .invoke_batch( + &[ + ColumnarValue::Array(years), + ColumnarValue::Array(months), + ColumnarValue::Array(days), + ], + batch_len, + ) .expect("that make_date parsed values without error"); if let ColumnarValue::Array(array) = res { @@ -281,41 +342,54 @@ mod tests { // // invalid number of arguments + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch let res = MakeDateFunc::new() - .invoke(&[ColumnarValue::from(ScalarValue::Int32(Some(1)))]); + .invoke_batch(&[ColumnarValue::from(ScalarValue::Int32(Some(1)))], 1); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: make_date function requires 3 arguments, got 1" ); // invalid type - let res = MakeDateFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::IntervalYearMonth(Some(1))), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch + let res = MakeDateFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::IntervalYearMonth(Some(1))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Casting from Interval(YearMonth) to Int32 not supported" ); // overflow of month - let res = MakeDateFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::Int32(Some(2023))), - ColumnarValue::from(ScalarValue::UInt64(Some(u64::MAX))), - ColumnarValue::from(ScalarValue::Int32(Some(22))), - ]); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch + let res = MakeDateFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::Int32(Some(2023))), + ColumnarValue::from(ScalarValue::UInt64(Some(u64::MAX))), + ColumnarValue::from(ScalarValue::Int32(Some(22))), + ], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Can't cast value 18446744073709551615 to type Int32" ); // overflow of day - let res = MakeDateFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::Int32(Some(2023))), - ColumnarValue::from(ScalarValue::Int32(Some(22))), - ColumnarValue::from(ScalarValue::UInt32(Some(u32::MAX))), - ]); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch + let res = MakeDateFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::Int32(Some(2023))), + ColumnarValue::from(ScalarValue::Int32(Some(22))), + ColumnarValue::from(ScalarValue::UInt32(Some(u32::MAX))), + ], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Can't cast value 4294967295 to type Int32" diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index db4e365267dd2..96ca63010ee4e 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -37,43 +37,23 @@ pub mod to_timestamp; pub mod to_unixtime; // create UDFs -make_udf_function!(current_date::CurrentDateFunc, CURRENT_DATE, current_date); -make_udf_function!(current_time::CurrentTimeFunc, CURRENT_TIME, current_time); 
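// The mod.rs hunk here drops the third (SCREAMING_CASE) argument from `make_udf_function!`.
// Below is one plausible shape of the two-argument form, sketched with a toy type so it is
// self-contained; the real macro in datafusion/functions/src/macros.rs may differ in detail.
// The point is that the per-UDF static can be generated inside the macro, so callers no
// longer have to name it.
use std::sync::{Arc, OnceLock};

#[derive(Default)]
struct ToyDateFunc; // stand-in for a real *Func UDF struct

macro_rules! make_udf_function {
    ($udf_type:ty, $fn_name:ident) => {
        fn $fn_name() -> Arc<$udf_type> {
            static INSTANCE: OnceLock<Arc<$udf_type>> = OnceLock::new();
            Arc::clone(INSTANCE.get_or_init(|| Arc::new(<$udf_type>::default())))
        }
    };
}

make_udf_function!(ToyDateFunc, to_date_toy);

fn main() {
    // Every call hands back the same shared instance.
    assert!(Arc::ptr_eq(&to_date_toy(), &to_date_toy()));
}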
-make_udf_function!(date_bin::DateBinFunc, DATE_BIN, date_bin); -make_udf_function!(date_part::DatePartFunc, DATE_PART, date_part); -make_udf_function!(date_trunc::DateTruncFunc, DATE_TRUNC, date_trunc); -make_udf_function!(make_date::MakeDateFunc, MAKE_DATE, make_date); -make_udf_function!( - from_unixtime::FromUnixtimeFunc, - FROM_UNIXTIME, - from_unixtime -); -make_udf_function!(now::NowFunc, NOW, now); -make_udf_function!(to_char::ToCharFunc, TO_CHAR, to_char); -make_udf_function!(to_date::ToDateFunc, TO_DATE, to_date); -make_udf_function!(to_local_time::ToLocalTimeFunc, TO_LOCAL_TIME, to_local_time); -make_udf_function!(to_unixtime::ToUnixtimeFunc, TO_UNIXTIME, to_unixtime); -make_udf_function!(to_timestamp::ToTimestampFunc, TO_TIMESTAMP, to_timestamp); -make_udf_function!( - to_timestamp::ToTimestampSecondsFunc, - TO_TIMESTAMP_SECONDS, - to_timestamp_seconds -); -make_udf_function!( - to_timestamp::ToTimestampMillisFunc, - TO_TIMESTAMP_MILLIS, - to_timestamp_millis -); -make_udf_function!( - to_timestamp::ToTimestampMicrosFunc, - TO_TIMESTAMP_MICROS, - to_timestamp_micros -); -make_udf_function!( - to_timestamp::ToTimestampNanosFunc, - TO_TIMESTAMP_NANOS, - to_timestamp_nanos -); +make_udf_function!(current_date::CurrentDateFunc, current_date); +make_udf_function!(current_time::CurrentTimeFunc, current_time); +make_udf_function!(date_bin::DateBinFunc, date_bin); +make_udf_function!(date_part::DatePartFunc, date_part); +make_udf_function!(date_trunc::DateTruncFunc, date_trunc); +make_udf_function!(make_date::MakeDateFunc, make_date); +make_udf_function!(from_unixtime::FromUnixtimeFunc, from_unixtime); +make_udf_function!(now::NowFunc, now); +make_udf_function!(to_char::ToCharFunc, to_char); +make_udf_function!(to_date::ToDateFunc, to_date); +make_udf_function!(to_local_time::ToLocalTimeFunc, to_local_time); +make_udf_function!(to_unixtime::ToUnixtimeFunc, to_unixtime); +make_udf_function!(to_timestamp::ToTimestampFunc, to_timestamp); +make_udf_function!(to_timestamp::ToTimestampSecondsFunc, to_timestamp_seconds); +make_udf_function!(to_timestamp::ToTimestampMillisFunc, to_timestamp_millis); +make_udf_function!(to_timestamp::ToTimestampMicrosFunc, to_timestamp_micros); +make_udf_function!(to_timestamp::ToTimestampNanosFunc, to_timestamp_nanos); // we cannot currently use the export_functions macro since it doesn't handle // functions with varargs currently diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index 80d2e4966f815..e326f0cf0f9a6 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -15,19 +15,31 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::DataType; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; +use std::any::Any; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{internal_err, ExprSchema, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = r#" +Returns the current UTC timestamp. 
+ +The `now()` return value is determined at query time and will return the same timestamp, no matter when in the query plan the function executes. +"#, + syntax_example = "now()" +)] #[derive(Debug)] pub struct NowFunc { signature: Signature, + aliases: Vec, } impl Default for NowFunc { @@ -39,7 +51,8 @@ impl Default for NowFunc { impl NowFunc { pub fn new() -> Self { Self { - signature: Signature::uniform(0, vec![], Volatility::Stable), + signature: Signature::nullary(Volatility::Stable), + aliases: vec!["current_timestamp".to_string()], } } } @@ -67,7 +80,11 @@ impl ScalarUDFImpl for NowFunc { Ok(Timestamp(Nanosecond, Some("+00:00".into()))) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { internal_err!("invoke should not be called on a simplified now() function") } @@ -84,4 +101,16 @@ impl ScalarUDFImpl for NowFunc { ScalarValue::TimestampNanosecond(now_ts, Some("+00:00".into())), ))) } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { + false + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 706540b6d1f88..85443c198c8d5 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -31,9 +31,38 @@ use arrow::util::display::{ArrayFormatter, DurationFormat, FormatOptions}; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; - +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = "Returns a string representation of a date, time, timestamp or duration based on a [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). Unlike the PostgreSQL equivalent of this function numerical formatting is not supported.", + syntax_example = "to_char(expression, format)", + sql_example = r#"```sql +> select to_char('2023-03-01'::date, '%d-%m-%Y'); ++----------------------------------------------+ +| to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | ++----------------------------------------------+ +| 01-03-2023 | ++----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs) +"#, + argument( + name = "expression", + description = "Expression to operate on. Can be a constant, column, or function that results in a date, time, timestamp or duration." + ), + argument( + name = "format", + description = "A [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) string to use to convert the expression." + ), + argument( + name = "day", + description = "Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators." 
+ ) +)] #[derive(Debug)] pub struct ToCharFunc { signature: Signature, @@ -53,34 +82,34 @@ impl ToCharFunc { vec![ Exact(vec![Date32, Utf8]), Exact(vec![Date64, Utf8]), + Exact(vec![Time64(Nanosecond), Utf8]), + Exact(vec![Time64(Microsecond), Utf8]), Exact(vec![Time32(Millisecond), Utf8]), Exact(vec![Time32(Second), Utf8]), - Exact(vec![Time64(Microsecond), Utf8]), - Exact(vec![Time64(Nanosecond), Utf8]), - Exact(vec![Timestamp(Second, None), Utf8]), Exact(vec![ - Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), - Exact(vec![Timestamp(Millisecond, None), Utf8]), + Exact(vec![Timestamp(Nanosecond, None), Utf8]), Exact(vec![ - Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), Exact(vec![Timestamp(Microsecond, None), Utf8]), Exact(vec![ - Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), - Exact(vec![Timestamp(Nanosecond, None), Utf8]), + Exact(vec![Timestamp(Millisecond, None), Utf8]), Exact(vec![ - Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), - Exact(vec![Duration(Second), Utf8]), - Exact(vec![Duration(Millisecond), Utf8]), - Exact(vec![Duration(Microsecond), Utf8]), + Exact(vec![Timestamp(Second, None), Utf8]), Exact(vec![Duration(Nanosecond), Utf8]), + Exact(vec![Duration(Microsecond), Utf8]), + Exact(vec![Duration(Millisecond), Utf8]), + Exact(vec![Duration(Second), Utf8]), ], Volatility::Immutable, ), @@ -106,7 +135,11 @@ impl ScalarUDFImpl for ToCharFunc { Ok(Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.len() != 2 { return exec_err!( "to_char function requires 2 arguments, got {}", @@ -137,6 +170,9 @@ impl ScalarUDFImpl for ToCharFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } fn _build_format_options<'a>( @@ -185,10 +221,7 @@ fn _to_char_scalar( if is_scalar_expression { return Ok(ColumnarValue::from(ScalarValue::Utf8(None))); } else { - return Ok(ColumnarValue::Array(new_null_array( - &DataType::Utf8, - array.len(), - ))); + return Ok(ColumnarValue::Array(new_null_array(&Utf8, array.len()))); } } @@ -350,8 +383,12 @@ mod tests { ]; for (value, format, expected) in scalar_data { + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch let result = ToCharFunc::new() - .invoke(&[ColumnarValue::from(value), ColumnarValue::from(format)]) + .invoke_batch( + &[ColumnarValue::from(value), ColumnarValue::from(format)], + 1, + ) .expect("that to_char parsed values without error"); match result { @@ -426,11 +463,16 @@ mod tests { ]; for (value, format, expected) in scalar_array_data { + let batch_len = format.len(); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch let result = ToCharFunc::new() - .invoke(&[ - ColumnarValue::from(value), - ColumnarValue::Array(Arc::new(format) as ArrayRef), - ]) + .invoke_batch( + &[ + ColumnarValue::from(value), + ColumnarValue::Array(Arc::new(format) as ArrayRef), + ], + batch_len, + ) .expect("that to_char parsed values without error"); match result { @@ -553,11 +595,16 @@ mod tests { ]; for (value, format, expected) in array_scalar_data { + let batch_len = value.len(); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch 
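// The `to_char` documentation added above leans on Chrono strftime-style format strings.
// A tiny sketch of that formatting in plain chrono (a crate these datetime functions
// already depend on); this is not the Arrow `ArrayFormatter` path that ToCharFunc uses.
use chrono::NaiveDate;

fn main() {
    let date = NaiveDate::from_ymd_opt(2023, 3, 1).expect("valid date");
    // Mirrors the documented example: to_char('2023-03-01'::date, '%d-%m-%Y') -> '01-03-2023'
    assert_eq!(date.format("%d-%m-%Y").to_string(), "01-03-2023");
}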
let result = ToCharFunc::new() - .invoke(&[ - ColumnarValue::Array(value as ArrayRef), - ColumnarValue::from(format), - ]) + .invoke_batch( + &[ + ColumnarValue::Array(value as ArrayRef), + ColumnarValue::from(format), + ], + batch_len, + ) .expect("that to_char parsed values without error"); if let ColumnarValue::Array(result) = result { @@ -569,11 +616,16 @@ mod tests { } for (value, format, expected) in array_array_data { + let batch_len = value.len(); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch let result = ToCharFunc::new() - .invoke(&[ - ColumnarValue::Array(value), - ColumnarValue::Array(Arc::new(format) as ArrayRef), - ]) + .invoke_batch( + &[ + ColumnarValue::Array(value), + ColumnarValue::Array(Arc::new(format) as ArrayRef), + ], + batch_len, + ) .expect("that to_char parsed values without error"); if let ColumnarValue::Array(result) = result { @@ -589,18 +641,23 @@ mod tests { // // invalid number of arguments - let result = - ToCharFunc::new().invoke(&[ColumnarValue::from(ScalarValue::Int32(Some(1)))]); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch + let result = ToCharFunc::new() + .invoke_batch(&[ColumnarValue::from(ScalarValue::Int32(Some(1)))], 1); assert_eq!( result.err().unwrap().strip_backtrace(), "Execution error: to_char function requires 2 arguments, got 1" ); // invalid type - let result = ToCharFunc::new().invoke(&[ - ColumnarValue::from(ScalarValue::Int32(Some(1))), - ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), - ]); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch + let result = ToCharFunc::new().invoke_batch( + &[ + ColumnarValue::from(ScalarValue::Int32(Some(1))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ], + 1, + ); assert_eq!( result.err().unwrap().strip_backtrace(), "Execution error: Format for `to_char` must be non-null Utf8, received Timestamp(Nanosecond, None)" diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index fc3f062e12e0a..259e5e02aec19 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -17,18 +17,52 @@ use crate::datetime::common::*; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::Date32; +use arrow::datatypes::DataType::*; use arrow::error::ArrowError::ParseError; use arrow::{array::types::Date32Type, compute::kernels::cast_utils::Parser}; use datafusion_common::error::DataFusionError; use datafusion_common::{arrow_err, exec_err, internal_datafusion_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; use std::any::Any; -use std::sync::OnceLock; +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = r"Converts a value to a date (`YYYY-MM-DD`). +Supports strings, integer and double types as input. +Strings are parsed as YYYY-MM-DD (e.g. '2023-07-20') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. +Integers and doubles are interpreted as days since the unix epoch (`1970-01-01T00:00:00Z`). +Returns the corresponding date. + +Note: `to_date` returns Date32, which represents its values as the number of days since unix epoch(`1970-01-01`) stored as signed 32 bit value. 
The largest supported date value is `9999-12-31`.", + syntax_example = "to_date('2017-05-31', '%Y-%m-%d')", + sql_example = r#"```sql +> select to_date('2023-01-31'); ++-------------------------------+ +| to_date(Utf8("2023-01-31")) | ++-------------------------------+ +| 2023-01-31 | ++-------------------------------+ +> select to_date('2023/01/31', '%Y-%m-%d', '%Y/%m/%d'); ++---------------------------------------------------------------------+ +| to_date(Utf8("2023/01/31"),Utf8("%Y-%m-%d"),Utf8("%Y/%m/%d")) | ++---------------------------------------------------------------------+ +| 2023-01-31 | ++---------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) +"#, + standard_argument(name = "expression", prefix = "String"), + argument( + name = "format_n", + description = r"Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order + they appear with the first successful one being returned. If none of the formats successfully parse the expression + an error will be returned." + ) +)] #[derive(Debug)] pub struct ToDateFunc { signature: Signature, @@ -79,50 +113,6 @@ impl ToDateFunc { } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_to_date_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_DATETIME) - .with_description(r#"Converts a value to a date (`YYYY-MM-DD`). -Supports strings, integer and double types as input. -Strings are parsed as YYYY-MM-DD (e.g. '2023-07-20') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. -Integers and doubles are interpreted as days since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding date. - -Note: `to_date` returns Date32, which represents its values as the number of days since unix epoch(`1970-01-01`) stored as signed 32 bit value. The largest supported date value is `9999-12-31`. -"#) - .with_syntax_example("to_date('2017-05-31', '%Y-%m-%d')") - .with_sql_example(r#"```sql -> select to_date('2023-01-31'); -+-----------------------------+ -| to_date(Utf8("2023-01-31")) | -+-----------------------------+ -| 2023-01-31 | -+-----------------------------+ -> select to_date('2023/01/31', '%Y-%m-%d', '%Y/%m/%d'); -+---------------------------------------------------------------+ -| to_date(Utf8("2023/01/31"),Utf8("%Y-%m-%d"),Utf8("%Y/%m/%d")) | -+---------------------------------------------------------------+ -| 2023-01-31 | -+---------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) -"#) - .with_standard_argument("expression", "String") - .with_argument( - "format_n", - "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. 
If none of the formats successfully parse the expression - an error will be returned.", - ) - .build() - .unwrap() - }) -} - impl ScalarUDFImpl for ToDateFunc { fn as_any(&self) -> &dyn Any { self @@ -140,7 +130,11 @@ impl ScalarUDFImpl for ToDateFunc { Ok(Date32) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.is_empty() { return exec_err!("to_date function requires 1 or more arguments, got 0"); } @@ -151,13 +145,10 @@ impl ScalarUDFImpl for ToDateFunc { } match args[0].data_type() { - DataType::Int32 - | DataType::Int64 - | DataType::Null - | DataType::Float64 - | DataType::Date32 - | DataType::Date64 => args[0].cast_to(&DataType::Date32, None), - DataType::Utf8 => self.to_date(args), + Int32 | Int64 | Null | Float64 | Date32 | Date64 => { + args[0].cast_to(&Date32, None) + } + Utf8View | LargeUtf8 | Utf8 => self.to_date(args), other => { exec_err!("Unsupported data type {:?} for function to_date", other) } @@ -165,15 +156,17 @@ impl ScalarUDFImpl for ToDateFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_to_date_doc()) + self.doc() } } #[cfg(test)] mod tests { + use arrow::array::{Array, Date32Array, GenericStringArray, StringViewArray}; use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use std::sync::Arc; use super::ToDateFunc; @@ -204,9 +197,19 @@ mod tests { ]; for tc in &test_cases { - let date_scalar = ScalarValue::Utf8(Some(tc.date_str.to_string())); + test_scalar(ScalarValue::Utf8(Some(tc.date_str.to_string())), tc); + test_scalar(ScalarValue::LargeUtf8(Some(tc.date_str.to_string())), tc); + test_scalar(ScalarValue::Utf8View(Some(tc.date_str.to_string())), tc); + + test_array::>(tc); + test_array::>(tc); + test_array::(tc); + } + + fn test_scalar(sv: ScalarValue, tc: &TestCase) { + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch let to_date_result = - ToDateFunc::new().invoke(&[ColumnarValue::from(date_scalar)]); + ToDateFunc::new().invoke_batch(&[ColumnarValue::from(sv)], 1); match to_date_result { Ok(ColumnarValue::Scalar(scalar)) => match scalar.into_value() { @@ -224,6 +227,35 @@ mod tests { _ => unreachable!(), } } + + fn test_array(tc: &TestCase) + where + A: From> + Array + 'static, + { + let date_array = A::from(vec![tc.date_str]); + let batch_len = date_array.len(); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch + let to_date_result = ToDateFunc::new() + .invoke_batch(&[ColumnarValue::Array(Arc::new(date_array))], batch_len); + + match to_date_result { + Ok(ColumnarValue::Array(a)) => { + assert_eq!(a.len(), 1); + + let expected = Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); + let mut builder = Date32Array::builder(4); + builder.append_value(expected.unwrap()); + + assert_eq!( + &builder.finish() as &dyn Array, + a.as_ref(), + "{}: to_date created wrong value", + tc.name + ); + } + _ => panic!("Could not convert '{}' to Date", tc.date_str), + } + } } #[test] @@ -275,14 +307,29 @@ mod tests { ]; for tc in &test_cases { - let formatted_date_scalar = - ScalarValue::Utf8(Some(tc.formatted_date.to_string())); + test_scalar(ScalarValue::Utf8(Some(tc.formatted_date.to_string())), tc); + test_scalar( + ScalarValue::LargeUtf8(Some(tc.formatted_date.to_string())), + tc, + ); + test_scalar( + ScalarValue::Utf8View(Some(tc.formatted_date.to_string())), + tc, + ); + + test_array::>(tc); 
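// A standalone sketch (assuming the `chrono` crate) of the two behaviors the `to_date`
// documentation above describes: format strings are tried in order until one parses, and
// the result is a Date32-style count of days since the Unix epoch. `parse_to_days` is an
// illustrative helper, not a DataFusion API.
use chrono::NaiveDate;

fn parse_to_days(s: &str, formats: &[&str]) -> Option<i32> {
    let epoch = NaiveDate::from_ymd_opt(1970, 1, 1)?;
    formats
        .iter()
        .find_map(|fmt| NaiveDate::parse_from_str(s, fmt).ok())
        .map(|date| date.signed_duration_since(epoch).num_days() as i32)
}

fn main() {
    let formats = ["%Y-%m-%d", "%Y/%m/%d"];
    // '%Y-%m-%d' fails on the slashed input, '%Y/%m/%d' succeeds.
    assert_eq!(parse_to_days("2023/01/31", &formats), Some(19_388));
    assert_eq!(parse_to_days("not a date", &formats), None);
}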
+ test_array::>(tc); + test_array::(tc); + } + + fn test_scalar(sv: ScalarValue, tc: &TestCase) { let format_scalar = ScalarValue::Utf8(Some(tc.format_str.to_string())); - let to_date_result = ToDateFunc::new().invoke(&[ - ColumnarValue::from(formatted_date_scalar), - ColumnarValue::from(format_scalar), - ]); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch + let to_date_result = ToDateFunc::new().invoke_batch( + &[ColumnarValue::from(sv), ColumnarValue::from(format_scalar)], + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(scalar)) => match scalar.into_value() { @@ -299,6 +346,46 @@ mod tests { _ => unreachable!(), } } + + fn test_array(tc: &TestCase) + where + A: From> + Array + 'static, + { + let date_array = A::from(vec![tc.formatted_date]); + let format_array = A::from(vec![tc.format_str]); + let batch_len = date_array.len(); + + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch + let to_date_result = ToDateFunc::new().invoke_batch( + &[ + ColumnarValue::Array(Arc::new(date_array)), + ColumnarValue::Array(Arc::new(format_array)), + ], + batch_len, + ); + + match to_date_result { + Ok(ColumnarValue::Array(a)) => { + assert_eq!(a.len(), 1); + + let expected = Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); + let mut builder = Date32Array::builder(4); + builder.append_value(expected.unwrap()); + + assert_eq!( + &builder.finish() as &dyn Array, a.as_ref(), + "{}: to_date created wrong value for date '{}' with format string '{}'", + tc.name, + tc.formatted_date, + tc.format_str + ); + } + _ => panic!( + "Could not convert '{}' with format string '{}'to Date: {:?}", + tc.formatted_date, tc.format_str, to_date_result + ), + } + } } #[test] @@ -307,11 +394,15 @@ mod tests { let format1_scalar = ScalarValue::Utf8(Some("%Y-%m-%d".into())); let format2_scalar = ScalarValue::Utf8(Some("%Y/%m/%d".into())); - let to_date_result = ToDateFunc::new().invoke(&[ - ColumnarValue::from(formatted_date_scalar), - ColumnarValue::from(format1_scalar), - ColumnarValue::from(format2_scalar), - ]); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch + let to_date_result = ToDateFunc::new().invoke_batch( + &[ + ColumnarValue::from(formatted_date_scalar), + ColumnarValue::from(format1_scalar), + ColumnarValue::from(format2_scalar), + ], + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(scalar)) => match scalar.into_value() { @@ -338,8 +429,9 @@ mod tests { for date_str in test_cases { let formatted_date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let to_date_result = - ToDateFunc::new().invoke(&[ColumnarValue::from(formatted_date_scalar)]); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch + let to_date_result = ToDateFunc::new() + .invoke_batch(&[ColumnarValue::from(formatted_date_scalar)], 1); match to_date_result { Ok(ColumnarValue::Scalar(scalar)) => match scalar.into_value() { @@ -360,8 +452,9 @@ mod tests { let date_str = "20241231"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch let to_date_result = - ToDateFunc::new().invoke(&[ColumnarValue::from(date_scalar)]); + ToDateFunc::new().invoke_batch(&[ColumnarValue::from(date_scalar)], 1); match to_date_result { Ok(ColumnarValue::Scalar(scalar)) => match scalar.into_value() { @@ -384,13 +477,14 @@ mod tests { let date_str = "202412311"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); + #[allow(deprecated)] // TODO migrate UDF to invoke from invoke_batch 
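// The repeated `#[allow(deprecated)]` / "TODO migrate" markers in these tests track the
// invoke -> invoke_batch migration this patch applies across the datetime UDFs. A
// simplified, self-contained sketch of that shape (ToyUdf/ToyValue are stand-ins, not the
// real datafusion_expr traits): the batch-aware entry point also receives the row count,
// so argument-free or scalar-only functions can size their output without inspecting args.
#[derive(Debug, Clone, PartialEq)]
enum ToyValue {
    Scalar(i64),
    Array(Vec<i64>),
}

trait ToyUdf {
    // Older style: no batch length available.
    fn invoke(&self, args: &[ToyValue]) -> ToyValue;

    // Newer style: callers pass the batch length explicitly.
    fn invoke_batch(&self, args: &[ToyValue], number_rows: usize) -> ToyValue {
        let _ = number_rows; // default simply falls back to the old behavior
        self.invoke(args)
    }
}

struct ConstantOne;

impl ToyUdf for ConstantOne {
    fn invoke(&self, _args: &[ToyValue]) -> ToyValue {
        ToyValue::Scalar(1)
    }

    fn invoke_batch(&self, _args: &[ToyValue], number_rows: usize) -> ToyValue {
        // With the row count in hand, the function can materialize a full column.
        ToyValue::Array(vec![1; number_rows])
    }
}

fn main() {
    assert_eq!(ConstantOne.invoke(&[]), ToyValue::Scalar(1));
    assert_eq!(ConstantOne.invoke_batch(&[], 3), ToyValue::Array(vec![1, 1, 1]));
}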
let to_date_result = - ToDateFunc::new().invoke(&[ColumnarValue::from(date_scalar)]); + ToDateFunc::new().invoke_batch(&[ColumnarValue::from(date_scalar)], 1); if let Ok(ColumnarValue::Scalar(scalar)) = to_date_result { if let ScalarValue::Date32(_) = scalar.value() { panic!( - "Conversion of {} succeded, but should have failed, ", + "Conversion of {} succeeded, but should have failed, ", date_str ) } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index d842f1a253c5c..5aeed735a68e5 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -31,11 +31,69 @@ use arrow::datatypes::{ use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc}; use datafusion_common::cast::as_primitive_array; use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; /// A UDF function that converts a timezone-aware timestamp to local time (with no offset or /// timezone information). In other words, this function strips off the timezone from the timestamp, /// while keep the display value of the timestamp the same. +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = "Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or timezone information). This function handles daylight saving time changes.", + syntax_example = "to_local_time(expression)", + sql_example = r#"```sql +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +FROM ( + SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time +); ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| time | type | to_local_time | to_local_time_type | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ + +# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather +# than UTC boundaries + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; ++---------------------+ +| date_bin | ++---------------------+ +| 2024-04-01T00:00:00 | ++---------------------+ + +> SELECT date_bin(interval '1 day', 
to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; ++---------------------------+ +| date_bin_with_timezone | ++---------------------------+ +| 2024-04-01T00:00:00+02:00 | ++---------------------------+ +```"#, + argument( + name = "expression", + description = "Time expression to operate on. Can be a constant, column, or function." + ) +)] #[derive(Debug)] pub struct ToLocalTimeFunc { signature: Signature, @@ -65,7 +123,7 @@ impl ToLocalTimeFunc { let time_value = &args[0]; let arg_type = time_value.data_type(); match arg_type { - DataType::Timestamp(_, None) => { + Timestamp(_, None) => { // if no timezone specified, just return the input Ok(time_value.clone()) } @@ -75,7 +133,7 @@ impl ToLocalTimeFunc { // for more details. // // Then remove the timezone in return type, i.e. return None - DataType::Timestamp(_, Some(timezone)) => { + Timestamp(_, Some(timezone)) => { let tz: Tz = timezone.parse()?; match time_value { @@ -309,7 +367,11 @@ impl ScalarUDFImpl for ToLocalTimeFunc { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.len() != 1 { return exec_err!( "to_local_time function requires 1 argument, got {:?}", @@ -343,6 +405,9 @@ impl ScalarUDFImpl for ToLocalTimeFunc { _ => plan_err!("The to_local_time function can only accept Timestamp as the arg got {first_arg}"), } } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } #[cfg(test)] @@ -354,7 +419,7 @@ mod tests { use arrow::datatypes::{DataType, TimeUnit}; use chrono::NaiveDateTime; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, Scalar, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, Scalar, ScalarFunctionArgs, ScalarUDFImpl}; use super::{adjust_to_local_time, ToLocalTimeFunc}; @@ -481,7 +546,11 @@ mod tests { fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { let res = ToLocalTimeFunc::new() - .invoke(&[ColumnarValue::from(input)]) + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::from(input)], + number_rows: 1, + return_type: &expected.data_type(), + }) .unwrap(); match res { ColumnarValue::Scalar(res) => { @@ -539,8 +608,10 @@ mod tests { .iter() .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::(); + let batch_size = input.len(); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = ToLocalTimeFunc::new() - .invoke(&[ColumnarValue::Array(Arc::new(input))]) + .invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size) .unwrap(); if let ColumnarValue::Array(result) = result { assert_eq!( diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index b8f148236c588..369a07f94d6e2 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -18,38 +18,189 @@ use std::any::Any; use std::sync::Arc; -use arrow::datatypes::DataType::Timestamp; +use arrow::datatypes::DataType::*; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{ ArrowTimestampType, DataType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use datafusion_common::{exec_err, Result, ScalarType}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - use crate::datetime::common::*; - +use 
datafusion_common::{exec_err, Result, ScalarType}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = r#" +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. +"#, + syntax_example = "to_timestamp(expression[, ..., format_n])", + sql_example = r#"```sql +> select to_timestamp('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------+ +| to_timestamp(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------+ +> select to_timestamp('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------+ +| to_timestamp(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++--------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#, + argument( + name = "expression", + description = "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ), + argument( + name = "format_n", + description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." + ) +)] #[derive(Debug)] pub struct ToTimestampFunc { signature: Signature, } +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = "Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). 
Returns the corresponding timestamp.", + syntax_example = "to_timestamp_seconds(expression[, ..., format_n])", + sql_example = r#"```sql +> select to_timestamp_seconds('2023-01-31T09:26:56.123456789-05:00'); ++-------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-------------------------------------------------------------------+ +| 2023-01-31T14:26:56 | ++-------------------------------------------------------------------+ +> select to_timestamp_seconds('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++----------------------------------------------------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++----------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00 | ++----------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#, + argument( + name = "expression", + description = "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ), + argument( + name = "format_n", + description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." + ) +)] #[derive(Debug)] pub struct ToTimestampSecondsFunc { signature: Signature, } +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = "Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). 
Returns the corresponding timestamp.", + syntax_example = "to_timestamp_millis(expression[, ..., format_n])", + sql_example = r#"```sql +> select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123 | ++------------------------------------------------------------------+ +> select to_timestamp_millis('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123 | ++---------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#, + argument( + name = "expression", + description = "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ), + argument( + name = "format_n", + description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." + ) +)] #[derive(Debug)] pub struct ToTimestampMillisFunc { signature: Signature, } +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = "Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. 
Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) Returns the corresponding timestamp.", + syntax_example = "to_timestamp_micros(expression[, ..., format_n])", + sql_example = r#"```sql +> select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456 | ++------------------------------------------------------------------+ +> select to_timestamp_micros('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456 | ++---------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#, + argument( + name = "expression", + description = "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ), + argument( + name = "format_n", + description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." + ) +)] #[derive(Debug)] pub struct ToTimestampMicrosFunc { signature: Signature, } +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = "Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). 
Returns the corresponding timestamp.", + syntax_example = "to_timestamp_nanos(expression[, ..., format_n])", + sql_example = r#"```sql +> select to_timestamp_nanos('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------------+ +> select to_timestamp_nanos('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++---------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#, + argument( + name = "expression", + description = "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ), + argument( + name = "format_n", + description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." + ) +)] #[derive(Debug)] pub struct ToTimestampNanosFunc { signature: Signature, @@ -148,7 +299,11 @@ impl ScalarUDFImpl for ToTimestampFunc { Ok(return_type_for(&arg_types[0], Nanosecond)) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.is_empty() { return exec_err!( "to_timestamp function requires 1 or more arguments, got {}", @@ -162,16 +317,16 @@ impl ScalarUDFImpl for ToTimestampFunc { } match args[0].data_type() { - DataType::Int32 | DataType::Int64 => args[0] + Int32 | Int64 => args[0] .cast_to(&Timestamp(Second, None), None)? 
.cast_to(&Timestamp(Nanosecond, None), None), - DataType::Null | DataType::Float64 | Timestamp(_, None) => { + Null | Float64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Nanosecond, None), None) } - DataType::Timestamp(_, Some(tz)) => { + Timestamp(_, Some(tz)) => { args[0].cast_to(&Timestamp(Nanosecond, Some(Arc::clone(tz))), None) } - DataType::Utf8 => { + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::(args, "to_timestamp") } other => { @@ -182,6 +337,9 @@ impl ScalarUDFImpl for ToTimestampFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } impl ScalarUDFImpl for ToTimestampSecondsFunc { @@ -201,7 +359,11 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { Ok(return_type_for(&arg_types[0], Second)) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.is_empty() { return exec_err!( "to_timestamp_seconds function requires 1 or more arguments, got {}", @@ -215,13 +377,13 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { } match args[0].data_type() { - DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { + Null | Int32 | Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Second, None), None) } - DataType::Timestamp(_, Some(tz)) => { + Timestamp(_, Some(tz)) => { args[0].cast_to(&Timestamp(Second, Some(Arc::clone(tz))), None) } - DataType::Utf8 => { + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::(args, "to_timestamp_seconds") } other => { @@ -232,6 +394,9 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } impl ScalarUDFImpl for ToTimestampMillisFunc { @@ -251,7 +416,11 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { Ok(return_type_for(&arg_types[0], Millisecond)) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.is_empty() { return exec_err!( "to_timestamp_millis function requires 1 or more arguments, got {}", @@ -265,13 +434,13 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { } match args[0].data_type() { - DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { + Null | Int32 | Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Millisecond, None), None) } - DataType::Timestamp(_, Some(tz)) => { + Timestamp(_, Some(tz)) => { args[0].cast_to(&Timestamp(Millisecond, Some(Arc::clone(tz))), None) } - DataType::Utf8 => { + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::(args, "to_timestamp_millis") } other => { @@ -282,6 +451,9 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } impl ScalarUDFImpl for ToTimestampMicrosFunc { @@ -301,7 +473,11 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { Ok(return_type_for(&arg_types[0], Microsecond)) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.is_empty() { return exec_err!( "to_timestamp_micros function requires 1 or more arguments, got {}", @@ -315,13 +491,13 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { } match args[0].data_type() { - DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { + Null | Int32 | Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Microsecond, None), None) } - DataType::Timestamp(_, Some(tz)) => { + 
Timestamp(_, Some(tz)) => { args[0].cast_to(&Timestamp(Microsecond, Some(Arc::clone(tz))), None) } - DataType::Utf8 => { + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::(args, "to_timestamp_micros") } other => { @@ -332,6 +508,9 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } impl ScalarUDFImpl for ToTimestampNanosFunc { @@ -351,7 +530,11 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { Ok(return_type_for(&arg_types[0], Nanosecond)) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.is_empty() { return exec_err!( "to_timestamp_nanos function requires 1 or more arguments, got {}", @@ -365,13 +548,13 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { } match args[0].data_type() { - DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { + Null | Int32 | Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Nanosecond, None), None) } - DataType::Timestamp(_, Some(tz)) => { + Timestamp(_, Some(tz)) => { args[0].cast_to(&Timestamp(Nanosecond, Some(Arc::clone(tz))), None) } - DataType::Utf8 => { + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::(args, "to_timestamp_nanos") } other => { @@ -382,6 +565,9 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Returns the return type for the to_timestamp_* function, preserving @@ -432,7 +618,6 @@ mod tests { use arrow::array::{ArrayRef, Int64Array, StringBuilder}; use arrow::datatypes::TimeUnit; use chrono::Utc; - use datafusion_common::{assert_contains, DataFusionError, ScalarValue}; use datafusion_expr::ScalarFunctionImplementation; @@ -804,17 +989,17 @@ mod tests { for udf in &udfs { for array in arrays { let rt = udf.return_type(&[array.data_type().clone()]).unwrap(); - assert!(matches!(rt, DataType::Timestamp(_, Some(_)))); - + assert!(matches!(rt, Timestamp(_, Some(_)))); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let res = udf - .invoke(&[array.clone()]) + .invoke_batch(&[array.clone()], 1) .expect("that to_timestamp parsed values without error"); let array = match res { ColumnarValue::Array(res) => res, _ => panic!("Expected a columnar array"), }; let ty = array.data_type(); - assert!(matches!(ty, DataType::Timestamp(_, Some(_)))); + assert!(matches!(ty, Timestamp(_, Some(_)))); } } @@ -847,17 +1032,17 @@ mod tests { for udf in &udfs { for array in arrays { let rt = udf.return_type(&[array.data_type().clone()]).unwrap(); - assert!(matches!(rt, DataType::Timestamp(_, None))); - + assert!(matches!(rt, Timestamp(_, None))); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let res = udf - .invoke(&[array.clone()]) + .invoke_batch(&[array.clone()], 1) .expect("that to_timestamp parsed values without error"); let array = match res { ColumnarValue::Array(res) => res, _ => panic!("Expected a columnar array"), }; let ty = array.data_type(); - assert!(matches!(ty, DataType::Timestamp(_, None))); + assert!(matches!(ty, Timestamp(_, None))); } } } @@ -933,10 +1118,7 @@ mod tests { .expect("that to_timestamp with format args parsed values without error"); if let ColumnarValue::Array(parsed_array) = parsed_timestamps { assert_eq!(parsed_array.len(), 1); - assert!(matches!( - parsed_array.data_type(), - DataType::Timestamp(_, None) - )); + assert!(matches!(parsed_array.data_type(), Timestamp(_, None))); match time_unit { Nanosecond => 
{ diff --git a/datafusion/functions/src/datetime/to_unixtime.rs b/datafusion/functions/src/datetime/to_unixtime.rs index 396dadccb4b3e..6776981bc74ac 100644 --- a/datafusion/functions/src/datetime/to_unixtime.rs +++ b/datafusion/functions/src/datetime/to_unixtime.rs @@ -15,16 +15,45 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - -use arrow::datatypes::{DataType, TimeUnit}; - +use super::to_timestamp::ToTimestampSecondsFunc; use crate::datetime::common::*; +use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use super::to_timestamp::ToTimestampSecondsFunc; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; +use std::any::Any; +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = "Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). Supports strings, dates, timestamps and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided.", + syntax_example = "to_unixtime(expression[, ..., format_n])", + sql_example = r#" +```sql +> select to_unixtime('2020-09-08T12:00:00+00:00'); ++------------------------------------------------+ +| to_unixtime(Utf8("2020-09-08T12:00:00+00:00")) | ++------------------------------------------------+ +| 1599566400 | ++------------------------------------------------+ +> select to_unixtime('01-14-2023 01:01:30+05:30', '%q', '%d-%m-%Y %H/%M/%S', '%+', '%m-%d-%Y %H:%M:%S%#z'); ++-----------------------------------------------------------------------------------------------------------------------------+ +| to_unixtime(Utf8("01-14-2023 01:01:30+05:30"),Utf8("%q"),Utf8("%d-%m-%Y %H/%M/%S"),Utf8("%+"),Utf8("%m-%d-%Y %H:%M:%S%#z")) | ++-----------------------------------------------------------------------------------------------------------------------------+ +| 1673638290 | ++-----------------------------------------------------------------------------------------------------------------------------+ +``` +"#, + argument( + name = "expression", + description = "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ), + argument( + name = "format_n", + description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." + ) +)] #[derive(Debug)] pub struct ToUnixtimeFunc { signature: Signature, @@ -61,7 +90,11 @@ impl ScalarUDFImpl for ToUnixtimeFunc { Ok(DataType::Int64) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + batch_size: usize, + ) -> Result { if args.is_empty() { return exec_err!("to_unixtime function requires 1 or more arguments, got 0"); } @@ -78,12 +111,16 @@ impl ScalarUDFImpl for ToUnixtimeFunc { DataType::Date64 | DataType::Date32 | DataType::Timestamp(_, None) => args[0] .cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)? 
.cast_to(&DataType::Int64, None), + #[allow(deprecated)] // TODO: migrate to invoke_with_args DataType::Utf8 => ToTimestampSecondsFunc::new() - .invoke(args)? + .invoke_batch(args, batch_size)? .cast_to(&DataType::Int64, None), other => { exec_err!("Unsupported data type {:?} for function to_unixtime", other) } } } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index a791d77697abf..5d69d0b9debc4 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -32,13 +32,27 @@ use datafusion_common::{ use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ColumnarValue, Documentation}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use std::{fmt, str::FromStr}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_BINARY_STRING; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; use std::any::Any; +#[user_doc( + doc_section(label = "Binary String Functions"), + description = "Encode binary data into a textual representation.", + syntax_example = "encode(expression, format)", + argument( + name = "expression", + description = "Expression containing string or binary data" + ), + argument( + name = "format", + description = "Supported formats are: `base64`, `hex`" + ), + related_udf(name = "decode") +)] #[derive(Debug)] pub struct EncodeFunc { signature: Signature, @@ -58,22 +72,6 @@ impl EncodeFunc { } } -static ENCODE_DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_encode_doc() -> &'static Documentation { - ENCODE_DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_BINARY_STRING) - .with_description("Encode binary data into a textual representation.") - .with_syntax_example("encode(expression, format)") - .with_argument("expression", "Expression containing string or binary data") - .with_argument("format", "Supported formats are: `base64`, `hex`") - .with_related_udf("decode") - .build() - .unwrap() - }) -} - impl ScalarUDFImpl for EncodeFunc { fn as_any(&self) -> &dyn Any { self @@ -87,10 +85,28 @@ impl ScalarUDFImpl for EncodeFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].to_owned()) + use DataType::*; + + Ok(match arg_types[0] { + Utf8 => Utf8, + LargeUtf8 => LargeUtf8, + Utf8View => Utf8, + Binary => Utf8, + LargeBinary => LargeUtf8, + Null => Null, + _ => { + return plan_err!( + "The encode function can only accept Utf8 or Binary or Null." 
+ ); + } + }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { encode(args) } @@ -108,12 +124,12 @@ impl ScalarUDFImpl for EncodeFunc { } match arg_types[0] { - DataType::Utf8 | DataType::Binary | DataType::Null => { + DataType::Utf8 | DataType::Utf8View | DataType::Null => { Ok(vec![DataType::Utf8; 2]) } - DataType::LargeUtf8 | DataType::LargeBinary => { - Ok(vec![DataType::LargeUtf8, DataType::Utf8]) - } + DataType::LargeUtf8 => Ok(vec![DataType::LargeUtf8, DataType::Utf8]), + DataType::Binary => Ok(vec![DataType::Binary, DataType::Utf8]), + DataType::LargeBinary => Ok(vec![DataType::LargeBinary, DataType::Utf8]), _ => plan_err!( "1st argument should be Utf8 or Binary or Null, got {:?}", arg_types[0] @@ -122,10 +138,21 @@ impl ScalarUDFImpl for EncodeFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_encode_doc()) + self.doc() } } +#[user_doc( + doc_section(label = "Binary String Functions"), + description = "Decode binary data from textual representation in string.", + syntax_example = "decode(expression, format)", + argument( + name = "expression", + description = "Expression containing encoded string data" + ), + argument(name = "format", description = "Same arguments as [encode](#encode)"), + related_udf(name = "encode") +)] #[derive(Debug)] pub struct DecodeFunc { signature: Signature, @@ -145,22 +172,6 @@ impl DecodeFunc { } } -static DECODE_DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_decode_doc() -> &'static Documentation { - DECODE_DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_BINARY_STRING) - .with_description("Decode binary data from textual representation in string.") - .with_syntax_example("decode(expression, format)") - .with_argument("expression", "Expression containing encoded string data") - .with_argument("format", "Same arguments as [encode](#encode)") - .with_related_udf("encode") - .build() - .unwrap() - }) -} - impl ScalarUDFImpl for DecodeFunc { fn as_any(&self) -> &dyn Any { self @@ -177,7 +188,11 @@ impl ScalarUDFImpl for DecodeFunc { Ok(arg_types[0].to_owned()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { decode(args) } @@ -195,7 +210,7 @@ impl ScalarUDFImpl for DecodeFunc { } match arg_types[0] { - DataType::Utf8 | DataType::Binary | DataType::Null => { + DataType::Utf8 | DataType::Utf8View | DataType::Binary | DataType::Null => { Ok(vec![DataType::Binary, DataType::Utf8]) } DataType::LargeUtf8 | DataType::LargeBinary => { @@ -209,7 +224,7 @@ impl ScalarUDFImpl for DecodeFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_decode_doc()) + self.doc() } } @@ -224,6 +239,7 @@ fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result match a.data_type() { DataType::Utf8 => encoding.encode_utf8_array::(a.as_ref()), DataType::LargeUtf8 => encoding.encode_utf8_array::(a.as_ref()), + DataType::Utf8View => encoding.encode_utf8_array::(a.as_ref()), DataType::Binary => encoding.encode_binary_array::(a.as_ref()), DataType::LargeBinary => encoding.encode_binary_array::(a.as_ref()), other => exec_err!( @@ -237,6 +253,9 @@ fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result Ok(encoding .encode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes()))), + ScalarValue::Utf8View(a) => { + Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) 
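Both `encode_process` and `decode_process` gain `Utf8View` arms around this point, and the `encode`/`decode` wrappers further down read the format argument through `ScalarValue::try_as_str` instead of matching `Utf8`/`LargeUtf8` variants one by one. A small sketch of that accessor, assuming the upstream signature `try_as_str(&self) -> Option<Option<&str>>`:

```rust
use datafusion_common::ScalarValue;

/// Returns the encoding name for any non-null string scalar
/// (Utf8, LargeUtf8, or Utf8View), and None otherwise.
fn encoding_name(v: &ScalarValue) -> Option<&str> {
    // Some(Some(s)): non-null string scalar of any string type
    // Some(None):    null string scalar
    // None:          not a string type at all
    v.try_as_str().flatten()
}

fn main() {
    assert_eq!(
        encoding_name(&ScalarValue::Utf8View(Some("hex".to_string()))),
        Some("hex")
    );
    assert_eq!(encoding_name(&ScalarValue::Int64(Some(1))), None);
}
```

The hunk below reaches the same accessor through `scalar.value().try_as_str()`, so only non-null string scalars select an encoding and everything else falls through to the `not_impl_err!` message.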
+ } ScalarValue::Binary(a) => Ok( encoding.encode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) ), @@ -255,6 +274,7 @@ fn decode_process(value: &ColumnarValue, encoding: Encoding) -> Result match a.data_type() { DataType::Utf8 => encoding.decode_utf8_array::(a.as_ref()), DataType::LargeUtf8 => encoding.decode_utf8_array::(a.as_ref()), + DataType::Utf8View => encoding.decode_utf8_array::(a.as_ref()), DataType::Binary => encoding.decode_binary_array::(a.as_ref()), DataType::LargeBinary => encoding.decode_binary_array::(a.as_ref()), other => exec_err!( @@ -268,6 +288,9 @@ fn decode_process(value: &ColumnarValue, encoding: Encoding) -> Result encoding .decode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes())), + ScalarValue::Utf8View(a) => { + encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes())) + } ScalarValue::Binary(a) => { encoding.decode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) } @@ -512,7 +535,7 @@ impl FromStr for Encoding { } } -/// Encodes the given data, accepts Binary, LargeBinary, Utf8 or LargeUtf8 and returns a [`ColumnarValue`]. +/// Encodes the given data, accepts Binary, LargeBinary, Utf8, Utf8View or LargeUtf8 and returns a [`ColumnarValue`]. /// Second argument is the encoding to use. /// Standard encodings are base64 and hex. fn encode(args: &[ColumnarValue]) -> Result { @@ -523,12 +546,10 @@ fn encode(args: &[ColumnarValue]) -> Result { ); } let encoding = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar.value() { - ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { - method.parse::() - } + ColumnarValue::Scalar(scalar) => match scalar.value().try_as_str() { + Some(Some(method)) => method.parse::(), _ => not_impl_err!( - "Second argument to encode must be a constant: Encode using dynamically decided method is not yet supported" + "Second argument to encode must be non null constant string: Encode using dynamically decided method is not yet supported. Got {scalar:?}" ), }, ColumnarValue::Array(_) => not_impl_err!( @@ -538,7 +559,7 @@ fn encode(args: &[ColumnarValue]) -> Result { encode_process(&args[0], encoding) } -/// Decodes the given data, accepts Binary, LargeBinary, Utf8 or LargeUtf8 and returns a [`ColumnarValue`]. +/// Decodes the given data, accepts Binary, LargeBinary, Utf8, Utf8View or LargeUtf8 and returns a [`ColumnarValue`]. /// Second argument is the encoding to use. /// Standard encodings are base64 and hex. fn decode(args: &[ColumnarValue]) -> Result { @@ -549,12 +570,10 @@ fn decode(args: &[ColumnarValue]) -> Result { ); } let encoding = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar.value() { - ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { - method.parse::() - } + ColumnarValue::Scalar(scalar) => match scalar.value().try_as_str() { + Some(Some(method))=> method.parse::(), _ => not_impl_err!( - "Second argument to decode must be a utf8 constant: Decode using dynamically decided method is not yet supported" + "Second argument to decode must be a non null constant string: Decode using dynamically decided method is not yet supported. 
Got {scalar:?}" ), }, ColumnarValue::Array(_) => not_impl_err!( diff --git a/datafusion/functions/src/encoding/mod.rs b/datafusion/functions/src/encoding/mod.rs index 48171370ad585..b0ddbd368a6b4 100644 --- a/datafusion/functions/src/encoding/mod.rs +++ b/datafusion/functions/src/encoding/mod.rs @@ -21,8 +21,8 @@ use std::sync::Arc; pub mod inner; // create `encode` and `decode` UDFs -make_udf_function!(inner::EncodeFunc, ENCODE, encode); -make_udf_function!(inner::DecodeFunc, DECODE, decode); +make_udf_function!(inner::EncodeFunc, encode); +make_udf_function!(inner::DecodeFunc, decode); // Export the functions out of this package, both as expr_fn as well as a list of functions pub mod expr_fn { diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 81be5552666d9..7278fe3ec5366 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -14,6 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -135,6 +136,8 @@ make_stub_package!(unicode, "unicode_expressions"); #[cfg(any(feature = "datetime_expressions", feature = "unicode_expressions"))] pub mod planner; +pub mod strings; + mod utils; /// Fluent-style API for creating `Expr`s diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index e850673ef8afc..48eff4fcd423e 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -65,26 +65,23 @@ macro_rules! export_functions { }; } -/// Creates a singleton `ScalarUDF` of the `$UDF` function named `$GNAME` and a -/// function named `$NAME` which returns that singleton. +/// Creates a singleton `ScalarUDF` of the `$UDF` function and a function +/// named `$NAME` which returns that singleton. /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. macro_rules! make_udf_function { - ($UDF:ty, $GNAME:ident, $NAME:ident) => { - /// Singleton instance of the function - static $GNAME: std::sync::OnceLock> = - std::sync::OnceLock::new(); - - #[doc = "Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation "] - #[doc = stringify!($UDF)] + ($UDF:ty, $NAME:ident) => { + #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))] pub fn $NAME() -> std::sync::Arc { - $GNAME - .get_or_init(|| { - std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( - <$UDF>::new(), - )) - }) - .clone() + // Singleton instance of the function + static INSTANCE: std::sync::LazyLock< + std::sync::Arc, + > = std::sync::LazyLock::new(|| { + std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( + <$UDF>::new(), + )) + }); + std::sync::Arc::clone(&INSTANCE) } }; } @@ -112,68 +109,63 @@ macro_rules! make_stub_package { }; } -/// Invokes a function on each element of an array and returns the result as a new array -/// -/// $ARG: ArrayRef -/// $NAME: name of the function (for error messages) -/// $ARGS_TYPE: the type of array to cast the argument to -/// $RETURN_TYPE: the type of array to return -/// $FUNC: the function to apply to each element of $ARG -macro_rules! 
make_function_scalar_inputs_return_type { - ($ARG: expr, $NAME:expr, $ARG_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARG_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$RETURN_TYPE>() - }}; -} - -/// Downcast an argument to a specific array type, returning an internal error +/// Downcast a named argument to a specific array type, returning an internal error /// if the cast fails /// /// $ARG: ArrayRef /// $NAME: name of the argument (for error messages) /// $ARRAY_TYPE: the type of array to cast the argument to -macro_rules! downcast_arg { +#[macro_export] +macro_rules! downcast_named_arg { ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( + internal_datafusion_err!( "could not cast {} to {}", $NAME, std::any::type_name::<$ARRAY_TYPE>() - )) + ) })? }}; } +/// Downcast an argument to a specific array type, returning an internal error +/// if the cast fails +/// +/// $ARG: ArrayRef +/// $ARRAY_TYPE: the type of array to cast the argument to +#[macro_export] +macro_rules! downcast_arg { + ($ARG:expr, $ARRAY_TYPE:ident) => {{ + downcast_named_arg!($ARG, "", $ARRAY_TYPE) + }}; +} + /// Macro to create a unary math UDF. /// /// A unary math function takes an argument of type Float32 or Float64, /// applies a unary floating function to the argument, and returns a value of the same type. /// /// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` -/// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $UNARY_FUNC: the unary function to apply to the argument /// $OUTPUT_ORDERING: the output ordering calculation method of the function +/// $GET_DOC: the function to get the documentation of the UDF macro_rules! make_math_unary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr) => { - make_udf_function!($NAME::$UDF, $GNAME, $NAME); + ($UDF:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => { + make_udf_function!($NAME::$UDF, $NAME); mod $NAME { use std::any::Any; use std::sync::Arc; - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, DataFusionError, Result}; + use arrow::array::{ArrayRef, AsArray}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use datafusion_common::{exec_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + }; #[derive(Debug)] pub struct $UDF { @@ -226,28 +218,23 @@ macro_rules! 
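The new `downcast_named_arg!`/`downcast_arg!` macros defined above replace the old private `downcast_arg!` and build their error with `internal_datafusion_err!`. A hand-expanded sketch of what one invocation amounts to; the `first_as_int64` helper is illustrative, not part of the patch:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use datafusion_common::{internal_datafusion_err, Result};

/// Roughly what `downcast_named_arg!(&args[0], "value", Int64Array)` expands to:
/// downcast the dynamically typed array or report an internal error.
fn first_as_int64(args: &[ArrayRef]) -> Result<&Int64Array> {
    args[0]
        .as_any()
        .downcast_ref::<Int64Array>()
        .ok_or_else(|| {
            internal_datafusion_err!(
                "could not cast {} to {}",
                "value",
                std::any::type_name::<Int64Array>()
            )
        })
}

fn main() -> Result<()> {
    let args: Vec<ArrayRef> = vec![Arc::new(Int64Array::from(vec![1, 2, 3]))];
    assert_eq!(first_as_int64(&args)?.value(2), 3);
    Ok(())
}
```

The `"value"` name only feeds the error message, which is why the new `downcast_arg!` variant can omit it and pass an empty string.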
make_math_unary_udf { $EVALUATE_BOUNDS(inputs) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let args = ColumnarValue::values_to_arrays(args)?; - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => { - Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float64Array, - Float64Array, - { f64::$UNARY_FUNC } - )) - } - DataType::Float32 => { - Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float32Array, - Float32Array, - { f32::$UNARY_FUNC } - )) - } + DataType::Float64 => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>(|x: f64| f64::$UNARY_FUNC(x)), + ) as ArrayRef, + DataType::Float32 => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>(|x: f32| f32::$UNARY_FUNC(x)), + ) as ArrayRef, other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -255,8 +242,13 @@ macro_rules! make_math_unary_udf { ) } }; + Ok(ColumnarValue::Array(arr)) } + + fn documentation(&self) -> Option<&Documentation> { + Some($GET_DOC()) + } } } }; @@ -268,24 +260,26 @@ macro_rules! make_math_unary_udf { /// applies a binary floating function to the argument, and returns a value of the same type. /// /// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` -/// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $BINARY_FUNC: the binary function to apply to the argument /// $OUTPUT_ORDERING: the output ordering calculation method of the function +/// $GET_DOC: the function to get the documentation of the UDF macro_rules! make_math_binary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $BINARY_FUNC:ident, $OUTPUT_ORDERING:expr) => { - make_udf_function!($NAME::$UDF, $GNAME, $NAME); + ($UDF:ident, $NAME:ident, $BINARY_FUNC:ident, $OUTPUT_ORDERING:expr, $GET_DOC:expr) => { + make_udf_function!($NAME::$UDF, $NAME); mod $NAME { use std::any::Any; use std::sync::Arc; - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, DataFusionError, Result}; + use arrow::array::{ArrayRef, AsArray}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + }; #[derive(Debug)] pub struct $UDF { @@ -336,27 +330,33 @@ macro_rules! 
make_math_binary_udf { $OUTPUT_ORDERING(input) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let args = ColumnarValue::values_to_arrays(args)?; - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float64Array, - { f64::$BINARY_FUNC } - )), - - DataType::Float32 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float32Array, - { f32::$BINARY_FUNC } - )), + DataType::Float64 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + y, + x, + |y, x| f64::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + DataType::Float32 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float32Type>( + y, + x, + |y, x| f32::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -364,49 +364,14 @@ macro_rules! make_math_binary_udf { ) } }; + Ok(ColumnarValue::Array(arr)) } + + fn documentation(&self) -> Option<&Documentation> { + Some($GET_DOC()) + } } } }; } - -macro_rules! make_function_scalar_inputs { - ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; -} - -macro_rules! make_function_inputs2 { - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE1>() - }}; -} diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index f7a17f0caf947..1af5e0dfaf37c 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -26,17 +26,20 @@ use arrow::array::{ }; use arrow::datatypes::DataType; use arrow::error::ArrowError; -use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result}; +use datafusion_common::{exec_err, internal_datafusion_err, not_impl_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; type MathArrayFunction = fn(&Vec) -> Result; macro_rules! 
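The rewritten `make_math_unary_udf!` and `make_math_binary_udf!` bodies drop the removed `make_function_scalar_inputs*` iterator macros in favour of Arrow compute kernels. A standalone sketch of the kernels they (and `iszero` further below) now lean on, using plain Arrow arrays with no DataFusion types involved:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray, BooleanArray, Float64Array};
use arrow::datatypes::Float64Type;
use arrow::error::ArrowError;

fn main() -> Result<(), ArrowError> {
    let a: ArrayRef = Arc::new(Float64Array::from(vec![1.0_f64, 4.0, 9.0]));
    let b: ArrayRef = Arc::new(Float64Array::from(vec![0.0_f64, 2.0, 3.0]));

    // Unary kernel: applies the closure element-wise and preserves nulls;
    // this is what the Float64 arm of the unary math macro expands to.
    let sqrt = a
        .as_primitive::<Float64Type>()
        .unary::<_, Float64Type>(f64::sqrt);
    assert_eq!(sqrt.value(2), 3.0);

    // Binary kernel: zips two equal-length arrays, as used for atan2 and log.
    let atan2 = arrow::compute::binary::<_, _, _, Float64Type>(
        a.as_primitive::<Float64Type>(),
        b.as_primitive::<Float64Type>(),
        f64::atan2,
    )?;
    assert_eq!(atan2.value(0), f64::atan2(1.0, 0.0));

    // Boolean kernel used by `iszero`: maps each element to true/false.
    let zero = BooleanArray::from_unary(b.as_primitive::<Float64Type>(), |x| x == 0.0);
    assert!(zero.value(0));

    Ok(())
}
```

Compared to the removed iterator macros, the kernels skip per-element `Option` matching and let Arrow carry the null bitmaps through directly.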
make_abs_function { ($ARRAY_TYPE:ident) => {{ |args: &Vec| { - let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE); + let array = downcast_named_arg!(&args[0], "abs arg", $ARRAY_TYPE); let res: $ARRAY_TYPE = array.unary(|x| x.abs()); Ok(Arc::new(res) as ArrayRef) } @@ -46,7 +49,7 @@ macro_rules! make_abs_function { macro_rules! make_try_abs_function { ($ARRAY_TYPE:ident) => {{ |args: &Vec| { - let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE); + let array = downcast_named_arg!(&args[0], "abs arg", $ARRAY_TYPE); let res: $ARRAY_TYPE = array.try_unary(|x| { x.checked_abs().ok_or_else(|| { ArrowError::ComputeError(format!( @@ -64,7 +67,7 @@ macro_rules! make_try_abs_function { macro_rules! make_decimal_abs_function { ($ARRAY_TYPE:ident) => {{ |args: &Vec| { - let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE); + let array = downcast_named_arg!(&args[0], "abs arg", $ARRAY_TYPE); let res: $ARRAY_TYPE = array .unary(|x| x.wrapping_abs()) .with_data_type(args[0].data_type().clone()); @@ -100,6 +103,12 @@ fn create_abs_function(input_data_type: &DataType) -> Result other => not_impl_err!("Unsupported data type {other:?} for function abs"), } } +#[user_doc( + doc_section(label = "Math Functions"), + description = "Returns the absolute value of a number.", + syntax_example = "abs(numeric_expression)", + standard_argument(name = "numeric_expression", prefix = "Numeric") +)] #[derive(Debug)] pub struct AbsFunc { signature: Signature, @@ -114,7 +123,7 @@ impl Default for AbsFunc { impl AbsFunc { pub fn new() -> Self { Self { - signature: Signature::any(1, Volatility::Immutable), + signature: Signature::numeric(1, Volatility::Immutable), } } } @@ -157,7 +166,11 @@ impl ScalarUDFImpl for AbsFunc { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let args = ColumnarValue::values_to_arrays(args)?; if args.len() != 1 { @@ -184,4 +197,8 @@ impl ScalarUDFImpl for AbsFunc { Ok(SortProperties::Unordered) } } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/math/cot.rs b/datafusion/functions/src/math/cot.rs index 66219960d9a2f..8b4f9317fe5fd 100644 --- a/datafusion/functions/src/math/cot.rs +++ b/datafusion/functions/src/math/cot.rs @@ -18,16 +18,22 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::DataType::{Float32, Float64}; - -use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use crate::utils::make_scalar_function; - +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, Documentation}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Math Functions"), + description = "Returns the cotangent of a number.", + syntax_example = r#"cot(numeric_expression)"#, + standard_argument(name = "numeric_expression", prefix = "Numeric") +)] #[derive(Debug)] pub struct CotFunc { signature: Signature, @@ -77,7 +83,15 @@ impl ScalarUDFImpl for CotFunc { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } + + fn invoke_batch( + &self, + 
args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(cot, vec![])(args) } } @@ -85,18 +99,16 @@ impl ScalarUDFImpl for CotFunc { ///cot SQL function fn cot(args: &[ArrayRef]) -> Result { match args[0].data_type() { - Float64 => Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "x", - Float64Array, - { compute_cot64 } - )) as ArrayRef), - Float32 => Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "x", - Float32Array, - { compute_cot32 } - )) as ArrayRef), + Float64 => Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>(|x: f64| compute_cot64(x)), + ) as ArrayRef), + Float32 => Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>(|x: f32| compute_cot32(x)), + ) as ArrayRef), other => exec_err!("Unsupported data type {other:?} for function cot"), } } diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index 74ad2c738a93c..2e519735eae46 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -26,9 +26,20 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use crate::utils::make_scalar_function; -use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - +use datafusion_common::{ + arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result, +}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Math Functions"), + description = "Factorial. Returns 1 if value is less than 2.", + syntax_example = "factorial(numeric_expression)", + standard_argument(name = "numeric_expression", prefix = "Numeric") +)] #[derive(Debug)] pub struct FactorialFunc { signature: Signature, @@ -65,16 +76,24 @@ impl ScalarUDFImpl for FactorialFunc { Ok(Int64) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(factorial, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Factorial SQL function fn factorial(args: &[ArrayRef]) -> Result { match args[0].data_type() { - DataType::Int64 => { - let arg = downcast_arg!((&args[0]), "value", Int64Array); + Int64 => { + let arg = downcast_named_arg!((&args[0]), "value", Int64Array); Ok(arg .iter() .map(|a| match a { @@ -97,7 +116,6 @@ fn factorial(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { - use datafusion_common::cast::as_int64_array; use super::*; diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 10faf9f390bb3..14503701f6616 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -25,10 +25,21 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use crate::utils::make_scalar_function; -use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; - +use datafusion_common::{ + arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result, +}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Math Functions"), + 
description = "Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero.", + syntax_example = "gcd(expression_x, expression_y)", + standard_argument(name = "expression_x", prefix = "First numeric"), + standard_argument(name = "expression_y", prefix = "Second numeric") +)] #[derive(Debug)] pub struct GcdFunc { signature: Signature, @@ -66,17 +77,25 @@ impl ScalarUDFImpl for GcdFunc { Ok(Int64) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(gcd, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Gcd SQL function fn gcd(args: &[ArrayRef]) -> Result { match args[0].data_type() { Int64 => { - let arg1 = downcast_arg!(&args[0], "x", Int64Array); - let arg2 = downcast_arg!(&args[1], "y", Int64Array); + let arg1 = downcast_named_arg!(&args[0], "x", Int64Array); + let arg2 = downcast_named_arg!(&args[1], "y", Int64Array); Ok(arg1 .iter() diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs index e6a7280533593..8e72ee2855189 100644 --- a/datafusion/functions/src/math/iszero.rs +++ b/datafusion/functions/src/math/iszero.rs @@ -18,17 +18,25 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray, BooleanArray}; use arrow::datatypes::DataType::{Boolean, Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use crate::utils::make_scalar_function; +#[user_doc( + doc_section(label = "Math Functions"), + description = "Returns true if a given number is +0.0 or -0.0 otherwise returns false.", + syntax_example = "iszero(numeric_expression)", + standard_argument(name = "numeric_expression", prefix = "Numeric") +)] #[derive(Debug)] pub struct IsZeroFunc { signature: Signature, @@ -69,28 +77,30 @@ impl ScalarUDFImpl for IsZeroFunc { Ok(Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(iszero, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Iszero SQL function pub fn iszero(args: &[ArrayRef]) -> Result { match args[0].data_type() { - Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float64Array, - BooleanArray, - { |x: f64| { x == 0_f64 } } + Float64 => Ok(Arc::new(BooleanArray::from_unary( + args[0].as_primitive::(), + |x| x == 0.0, )) as ArrayRef), - Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float32Array, - BooleanArray, - { |x: f32| { x == 0_f32 } } + Float32 => Ok(Arc::new(BooleanArray::from_unary( + args[0].as_primitive::(), + |x| x == 0.0, )) as ArrayRef), other => exec_err!("Unsupported data type {other:?} for function iszero"), diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index 21c201657e906..c2c72c89841db 100644 --- 
a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -23,13 +23,24 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use arrow::error::ArrowError; -use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{ + arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result, +}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use super::gcd::unsigned_gcd; use crate::utils::make_scalar_function; +#[user_doc( + doc_section(label = "Math Functions"), + description = "Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero.", + syntax_example = "lcm(expression_x, expression_y)", + standard_argument(name = "expression_x", prefix = "First numeric"), + standard_argument(name = "expression_y", prefix = "Second numeric") +)] #[derive(Debug)] pub struct LcmFunc { signature: Signature, @@ -67,9 +78,17 @@ impl ScalarUDFImpl for LcmFunc { Ok(Int64) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(lcm, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Lcm SQL function @@ -96,8 +115,8 @@ fn lcm(args: &[ArrayRef]) -> Result { match args[0].data_type() { Int64 => { - let arg1 = downcast_arg!(&args[0], "x", Int64Array); - let arg2 = downcast_arg!(&args[1], "y", Int64Array); + let arg1 = downcast_named_arg!(&args[0], "x", Int64Array); + let arg2 = downcast_named_arg!(&args[1], "y", Int64Array); Ok(arg1 .iter() diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index d79feb79ae352..6c4c554ea6f7b 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -18,25 +18,32 @@ //! Math function: `log()`. use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use super::power::PowerFunc; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, - ScalarValue, + exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, }; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ lit, ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature::*, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; - +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Math Functions"), + description = "Returns the base-x logarithm of a number. 
Can either provide a specified base, or if omitted then takes the base-10 of a number.", + syntax_example = r#"log(base, numeric_expression) +log(numeric_expression)"#, + standard_argument(name = "base", prefix = "Base numeric"), + standard_argument(name = "numeric_expression", prefix = "Numeric") +)] #[derive(Debug)] pub struct LogFunc { signature: Signature, @@ -48,22 +55,6 @@ impl Default for LogFunc { } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_log_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_MATH) - .with_description("Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number.") - .with_syntax_example(r#"log(base, numeric_expression) -log(numeric_expression)"#) - .with_standard_argument("base", "Base numeric") - .with_standard_argument("numeric_expression", "Numeric") - .build() - .unwrap() - }) -} - impl LogFunc { pub fn new() -> Self { use DataType::*; @@ -126,7 +117,11 @@ impl ScalarUDFImpl for LogFunc { } // Support overloaded log(base, x) and log(x) which defaults to log(10, x) - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let args = ColumnarValue::values_to_arrays(args)?; let mut base = ColumnarValue::from(ScalarValue::Float32(Some(10.0))); @@ -141,9 +136,9 @@ impl ScalarUDFImpl for LogFunc { DataType::Float64 => match base { ColumnarValue::Scalar(scalar) => match scalar.into_value() { ScalarValue::Float32(Some(base)) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { - |value: f64| f64::log(value, base as f64) - })) + Arc::new(x.as_primitive::().unary::<_, Float64Type>( + |value: f64| f64::log(value, base as f64), + )) } _ => { return exec_err!( @@ -151,33 +146,36 @@ impl ScalarUDFImpl for LogFunc { ) } }, - ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float64Array, - { f64::log } - )), + ColumnarValue::Array(base) => { + let x = x.as_primitive::(); + let base = base.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + x, + base, + f64::log, + )?; + Arc::new(result) as _ + } }, DataType::Float32 => match base { ColumnarValue::Scalar(scalar) => match scalar.into_value() { - ScalarValue::Float32(Some(base)) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { - |value: f32| f32::log(value, base) - })) - } + ScalarValue::Float32(Some(base)) => Arc::new( + x.as_primitive::() + .unary::<_, Float32Type>(|value: f32| f32::log(value, base)), + ), _ => return exec_err!("log function requires a Float32 scalar"), }, - ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float32Array, - { f32::log } - )), + ColumnarValue::Array(base) => { + let x = x.as_primitive::(); + let base = base.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float32Type>( + x, + base, + f32::log, + )?; + Arc::new(result) as _ + } }, other => { return exec_err!("Unsupported data type {other:?} for function log") @@ -188,7 +186,7 @@ impl ScalarUDFImpl for LogFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_log_doc()) + self.doc() } /// Simplify the `log` function by the relevant rules: @@ -265,12 +263,192 @@ mod tests { use super::*; + use arrow::array::{Float32Array, Float64Array, Int64Array}; use arrow::compute::SortOptions; use 
datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DFSchema; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::simplify::SimplifyContext; + #[test] + #[should_panic] + fn test_log_invalid_base_type() { + let args = [ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), + ]; + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let _ = LogFunc::new().invoke_batch(&args, 4); + } + + #[test] + fn test_log_invalid_value() { + let args = [ + ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num + ]; + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let result = LogFunc::new().invoke_batch(&args, 1); + result.expect_err("expected error"); + } + + #[test] + fn test_log_scalar_f32_unary() { + let args = [ + ColumnarValue::from(ScalarValue::Float32(Some(10.0))), // num + ]; + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let result = LogFunc::new() + .invoke_batch(&args, 1) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_f64_unary() { + let args = [ + ColumnarValue::from(ScalarValue::Float64(Some(10.0))), // num + ]; + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let result = LogFunc::new() + .invoke_batch(&args, 1) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_f32() { + let args = [ + ColumnarValue::from(ScalarValue::Float32(Some(2.0))), // num + ColumnarValue::from(ScalarValue::Float32(Some(32.0))), // num + ]; + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let result = LogFunc::new() + .invoke_batch(&args, 1) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 5.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_f64() { + let args = [ + ColumnarValue::from(ScalarValue::Float64(Some(2.0))), // num + ColumnarValue::from(ScalarValue::Float64(Some(64.0))), // num + ]; + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let result = LogFunc::new() + .invoke_batch(&args, 1) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 6.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_f64_unary() { + let args = [ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 10.0, 100.0, 1000.0, 
10000.0, + ]))), // num + ]; + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let result = LogFunc::new() + .invoke_batch(&args, 4) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 3.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_f32_unary() { + let args = [ + ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ]; + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let result = LogFunc::new() + .invoke_batch(&args, 4) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 3.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + #[test] fn test_log_f64() { let args = [ @@ -279,9 +457,9 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function log"); match result { @@ -309,9 +487,9 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function log"); match result { diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index b221fb900cfa3..4eb337a30110e 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -40,198 +40,204 @@ pub mod signum; pub mod trunc; // Create UDFs -make_udf_function!(abs::AbsFunc, ABS, abs); +make_udf_function!(abs::AbsFunc, abs); make_math_unary_udf!( AcosFunc, - ACOS, acos, acos, super::acos_order, - super::bounds::acos_bounds + super::bounds::acos_bounds, + super::get_acos_doc ); make_math_unary_udf!( AcoshFunc, - ACOSH, acosh, acosh, super::acosh_order, - super::bounds::acosh_bounds + super::bounds::acosh_bounds, + super::get_acosh_doc ); make_math_unary_udf!( AsinFunc, - ASIN, asin, asin, super::asin_order, - super::bounds::asin_bounds + super::bounds::asin_bounds, + super::get_asin_doc ); make_math_unary_udf!( AsinhFunc, - ASINH, asinh, asinh, super::asinh_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_asinh_doc ); make_math_unary_udf!( AtanFunc, - ATAN, atan, atan, super::atan_order, - super::bounds::atan_bounds + super::bounds::atan_bounds, + super::get_atan_doc ); make_math_unary_udf!( AtanhFunc, - ATANH, atanh, atanh, super::atanh_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_atanh_doc +); +make_math_binary_udf!( + Atan2, + atan2, + atan2, + super::atan2_order, + super::get_atan2_doc ); -make_math_binary_udf!(Atan2, ATAN2, atan2, atan2, super::atan2_order); 
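Each call site in this hunk loses its `$GNAME` argument because the reworked `make_udf_function!` keeps the singleton inside the accessor function using `std::sync::LazyLock`. A plain-std sketch of the expanded shape, with a stand-in `Udf` struct instead of the real `ScalarUDF`:

```rust
use std::sync::{Arc, LazyLock};

// Stand-in for datafusion_expr::ScalarUDF, to keep the sketch std-only.
#[derive(Debug)]
struct Udf {
    name: &'static str,
}

/// Roughly what `make_udf_function!(abs::AbsFunc, abs)` now generates:
/// the static lives inside the function, so no separate module-level
/// `ABS: OnceLock<...>` item is needed.
fn abs_udf() -> Arc<Udf> {
    static INSTANCE: LazyLock<Arc<Udf>> =
        LazyLock::new(|| Arc::new(Udf { name: "abs" }));
    Arc::clone(&INSTANCE)
}

fn main() {
    // Every call returns a clone of the same Arc.
    assert!(Arc::ptr_eq(&abs_udf(), &abs_udf()));
    println!("{}", abs_udf().name);
}
```

Callers keep the same `$NAME()` entry point; only the module-level static disappears.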
make_math_unary_udf!( CbrtFunc, - CBRT, cbrt, cbrt, super::cbrt_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_cbrt_doc ); make_math_unary_udf!( CeilFunc, - CEIL, ceil, ceil, super::ceil_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_ceil_doc ); make_math_unary_udf!( CosFunc, - COS, cos, cos, super::cos_order, - super::bounds::cos_bounds + super::bounds::cos_bounds, + super::get_cos_doc ); make_math_unary_udf!( CoshFunc, - COSH, cosh, cosh, super::cosh_order, - super::bounds::cosh_bounds + super::bounds::cosh_bounds, + super::get_cosh_doc ); -make_udf_function!(cot::CotFunc, COT, cot); +make_udf_function!(cot::CotFunc, cot); make_math_unary_udf!( DegreesFunc, - DEGREES, degrees, to_degrees, super::degrees_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_degrees_doc ); make_math_unary_udf!( ExpFunc, - EXP, exp, exp, super::exp_order, - super::bounds::exp_bounds + super::bounds::exp_bounds, + super::get_exp_doc ); -make_udf_function!(factorial::FactorialFunc, FACTORIAL, factorial); +make_udf_function!(factorial::FactorialFunc, factorial); make_math_unary_udf!( FloorFunc, - FLOOR, floor, floor, super::floor_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_floor_doc ); -make_udf_function!(log::LogFunc, LOG, log); -make_udf_function!(gcd::GcdFunc, GCD, gcd); -make_udf_function!(nans::IsNanFunc, ISNAN, isnan); -make_udf_function!(iszero::IsZeroFunc, ISZERO, iszero); -make_udf_function!(lcm::LcmFunc, LCM, lcm); +make_udf_function!(log::LogFunc, log); +make_udf_function!(gcd::GcdFunc, gcd); +make_udf_function!(nans::IsNanFunc, isnan); +make_udf_function!(iszero::IsZeroFunc, iszero); +make_udf_function!(lcm::LcmFunc, lcm); make_math_unary_udf!( LnFunc, - LN, ln, ln, super::ln_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_ln_doc ); make_math_unary_udf!( Log2Func, - LOG2, log2, log2, super::log2_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_log2_doc ); make_math_unary_udf!( Log10Func, - LOG10, log10, log10, super::log10_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_log10_doc ); -make_udf_function!(nanvl::NanvlFunc, NANVL, nanvl); -make_udf_function!(pi::PiFunc, PI, pi); -make_udf_function!(power::PowerFunc, POWER, power); +make_udf_function!(nanvl::NanvlFunc, nanvl); +make_udf_function!(pi::PiFunc, pi); +make_udf_function!(power::PowerFunc, power); make_math_unary_udf!( RadiansFunc, - RADIANS, radians, to_radians, super::radians_order, - super::bounds::radians_bounds + super::bounds::radians_bounds, + super::get_radians_doc ); -make_udf_function!(random::RandomFunc, RANDOM, random); -make_udf_function!(round::RoundFunc, ROUND, round); -make_udf_function!(signum::SignumFunc, SIGNUM, signum); +make_udf_function!(random::RandomFunc, random); +make_udf_function!(round::RoundFunc, round); +make_udf_function!(signum::SignumFunc, signum); make_math_unary_udf!( SinFunc, - SIN, sin, sin, super::sin_order, - super::bounds::sin_bounds + super::bounds::sin_bounds, + super::get_sin_doc ); make_math_unary_udf!( SinhFunc, - SINH, sinh, sinh, super::sinh_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_sinh_doc ); make_math_unary_udf!( SqrtFunc, - SQRT, sqrt, sqrt, super::sqrt_order, - super::bounds::sqrt_bounds + super::bounds::sqrt_bounds, + super::get_sqrt_doc ); make_math_unary_udf!( 
TanFunc, - TAN, tan, tan, super::tan_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_tan_doc ); make_math_unary_udf!( TanhFunc, - TANH, tanh, tanh, super::tanh_order, - super::bounds::tanh_bounds + super::bounds::tanh_bounds, + super::get_tanh_doc ); -make_udf_function!(trunc::TruncFunc, TRUNC, trunc); +make_udf_function!(trunc::TruncFunc, trunc); pub mod expr_fn { export_functions!( diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs index 52f2ec5171982..46c670b8e651c 100644 --- a/datafusion/functions/src/math/monotonicity.rs +++ b/datafusion/functions/src/math/monotonicity.rs @@ -15,9 +15,13 @@ // specific language governing permissions and limitations // under the License. +use std::sync::OnceLock; + use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::Documentation; /// Non-increasing on the interval \[−1, 1\], undefined otherwise. pub fn acos_order(input: &[ExprProperties]) -> Result { @@ -34,6 +38,20 @@ pub fn acos_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_ACOS: OnceLock = OnceLock::new(); + +pub fn get_acos_doc() -> &'static Documentation { + DOCUMENTATION_ACOS.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the arc cosine or inverse cosine of a number.", + "acos(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for x ≥ 1, undefined otherwise. pub fn acosh_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -51,6 +69,20 @@ pub fn acosh_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_ACOSH: OnceLock = OnceLock::new(); + +pub fn get_acosh_doc() -> &'static Documentation { + DOCUMENTATION_ACOSH.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the area hyperbolic cosine or inverse hyperbolic cosine of a number.", + + "acosh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing on the interval \[−1, 1\], undefined otherwise. pub fn asin_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -66,16 +98,58 @@ pub fn asin_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_ASIN: OnceLock = OnceLock::new(); + +pub fn get_asin_doc() -> &'static Documentation { + DOCUMENTATION_ASIN.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the arc sine or inverse sine of a number.", + "asin(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for all real numbers. pub fn asinh_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_ASINH: OnceLock = OnceLock::new(); + +pub fn get_asinh_doc() -> &'static Documentation { + DOCUMENTATION_ASINH.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the area hyperbolic sine or inverse hyperbolic sine of a number.", + "asinh(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for all real numbers. 
pub fn atan_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_ATAN: OnceLock = OnceLock::new(); + +pub fn get_atan_doc() -> &'static Documentation { + DOCUMENTATION_ATAN.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the arc tangent or inverse tangent of a number.", + "atan(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing on the interval \[−1, 1\], undefined otherwise. pub fn atanh_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -91,22 +165,81 @@ pub fn atanh_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_ATANH: OnceLock = OnceLock::new(); + +pub fn get_atanh_doc() -> &'static Documentation { + DOCUMENTATION_ATANH.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number.", + + "atanh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Order depends on the quadrant. // TODO: Implement ordering rule of the ATAN2 function. pub fn atan2_order(_input: &[ExprProperties]) -> Result { Ok(SortProperties::Unordered) } +static DOCUMENTATION_ATANH2: OnceLock = OnceLock::new(); + +pub fn get_atan2_doc() -> &'static Documentation { + DOCUMENTATION_ATANH2.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the arc tangent or inverse tangent of `expression_y / expression_x`.", + + "atan2(expression_y, expression_x)") + .with_argument("expression_y", r#"First numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators."#) + .with_argument("expression_x", r#"Second numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators."#) + .build() + }) +} + /// Non-decreasing for all real numbers. pub fn cbrt_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_CBRT: OnceLock = OnceLock::new(); + +pub fn get_cbrt_doc() -> &'static Documentation { + DOCUMENTATION_CBRT.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the cube root of a number.", + "cbrt(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for all real numbers. pub fn ceil_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_CEIL: OnceLock = OnceLock::new(); + +pub fn get_ceil_doc() -> &'static Documentation { + DOCUMENTATION_CEIL.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the nearest integer greater than or equal to a number.", + "ceil(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-increasing on \[0, π\] and then non-decreasing on \[π, 2π\]. /// This pattern repeats periodically with a period of 2π. // TODO: Implement ordering rule of the ATAN2 function. 
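Each `get_*_doc` helper added in this file follows the same lazily-initialised pattern: a `OnceLock` static plus an accessor that builds the `Documentation` once and afterwards hands out a `&'static` reference. The standalone sketch below shows only that shape, using a stand-in `Documentation` struct rather than the real `datafusion_expr::Documentation` builder.

```rust
use std::sync::OnceLock;

// Stand-in for datafusion_expr::Documentation, used only to show the shape of
// the `get_*_doc` helpers: build the value once, return a &'static reference.
struct Documentation {
    description: &'static str,
    syntax_example: &'static str,
}

static DOCUMENTATION_ACOS: OnceLock<Documentation> = OnceLock::new();

fn get_acos_doc() -> &'static Documentation {
    DOCUMENTATION_ACOS.get_or_init(|| Documentation {
        description: "Returns the arc cosine or inverse cosine of a number.",
        syntax_example: "acos(numeric_expression)",
    })
}

fn main() {
    // The init closure runs at most once; later calls return the cached value.
    assert!(std::ptr::eq(get_acos_doc(), get_acos_doc()));
    println!("{}: {}", get_acos_doc().syntax_example, get_acos_doc().description);
}
```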
@@ -114,6 +247,20 @@ pub fn cos_order(_input: &[ExprProperties]) -> Result { Ok(SortProperties::Unordered) } +static DOCUMENTATION_COS: OnceLock = OnceLock::new(); + +pub fn get_cos_doc() -> &'static Documentation { + DOCUMENTATION_COS.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the cosine of a number.", + "cos(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for x ≥ 0 and symmetrically non-increasing for x ≤ 0. pub fn cosh_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -130,21 +277,77 @@ pub fn cosh_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_COSH: OnceLock = OnceLock::new(); + +pub fn get_cosh_doc() -> &'static Documentation { + DOCUMENTATION_COSH.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the hyperbolic cosine of a number.", + "cosh(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing function that converts radians to degrees. pub fn degrees_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_DEGREES: OnceLock = OnceLock::new(); + +pub fn get_degrees_doc() -> &'static Documentation { + DOCUMENTATION_DEGREES.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Converts radians to degrees.", + "degrees(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for all real numbers. pub fn exp_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_EXP: OnceLock = OnceLock::new(); + +pub fn get_exp_doc() -> &'static Documentation { + DOCUMENTATION_EXP.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the base-e exponential of a number.", + "exp(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for all real numbers. pub fn floor_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_FLOOR: OnceLock = OnceLock::new(); + +pub fn get_floor_doc() -> &'static Documentation { + DOCUMENTATION_FLOOR.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the nearest integer less than or equal to a number.", + "floor(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for x ≥ 0, undefined otherwise. pub fn ln_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -159,6 +362,20 @@ pub fn ln_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_LN: OnceLock = OnceLock::new(); + +pub fn get_ln_doc() -> &'static Documentation { + DOCUMENTATION_LN.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the natural logarithm of a number.", + "ln(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for x ≥ 0, undefined otherwise. 
pub fn log2_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -173,6 +390,20 @@ pub fn log2_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_LOG2: OnceLock = OnceLock::new(); + +pub fn get_log2_doc() -> &'static Documentation { + DOCUMENTATION_LOG2.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the base-2 logarithm of a number.", + "log2(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for x ≥ 0, undefined otherwise. pub fn log10_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -187,11 +418,39 @@ pub fn log10_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_LOG10: OnceLock = OnceLock::new(); + +pub fn get_log10_doc() -> &'static Documentation { + DOCUMENTATION_LOG10.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the base-10 logarithm of a number.", + "log10(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for all real numbers x. pub fn radians_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_RADIONS: OnceLock = OnceLock::new(); + +pub fn get_radians_doc() -> &'static Documentation { + DOCUMENTATION_RADIONS.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Converts degrees to radians.", + "radians(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing on \[0, π\] and then non-increasing on \[π, 2π\]. /// This pattern repeats periodically with a period of 2π. // TODO: Implement ordering rule of the SIN function. @@ -199,11 +458,39 @@ pub fn sin_order(_input: &[ExprProperties]) -> Result { Ok(SortProperties::Unordered) } +static DOCUMENTATION_SIN: OnceLock = OnceLock::new(); + +pub fn get_sin_doc() -> &'static Documentation { + DOCUMENTATION_SIN.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the sine of a number.", + "sin(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for all real numbers. pub fn sinh_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_SINH: OnceLock = OnceLock::new(); + +pub fn get_sinh_doc() -> &'static Documentation { + DOCUMENTATION_SINH.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the hyperbolic sine of a number.", + "sinh(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for x ≥ 0, undefined otherwise. pub fn sqrt_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -218,6 +505,20 @@ pub fn sqrt_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_SQRT: OnceLock = OnceLock::new(); + +pub fn get_sqrt_doc() -> &'static Documentation { + DOCUMENTATION_SQRT.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the square root of a number.", + "sqrt(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing between vertical asymptotes at x = k * π ± π / 2 for any /// integer k. // TODO: Implement ordering rule of the TAN function. 
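The `*_order` functions in this file all encode one rule: an expression's sort properties can be propagated only where the math function is provably monotonic, and `sqrt`, `ln`, `log2`, and `log10` are defined and non-decreasing only for non-negative inputs. The sketch below restates that check with plain stand-in types instead of `ExprProperties` and `Interval`, purely to make the reasoning concrete; it is not the DataFusion API.

```rust
// Stand-ins for datafusion_expr's sort-property types, used only to illustrate
// the monotonicity rule applied by sqrt_order and friends.
#[derive(Clone, Copy, Debug, PartialEq)]
enum SortProperties {
    Ordered,   // the input's own ordering can be propagated
    Unordered, // ordering cannot be guaranteed
}

// sqrt is non-decreasing on [0, +inf) and undefined for negative inputs, so the
// argument's ordering survives only if its value range stays non-negative.
fn sqrt_order(range: (f64, f64), input: SortProperties) -> Result<SortProperties, String> {
    let (lo, _hi) = range;
    if lo >= 0.0 {
        Ok(input)
    } else {
        Err(format!(
            "Input range {:?} contains negative values; sqrt is undefined there",
            range
        ))
    }
}

fn main() {
    assert_eq!(
        sqrt_order((0.0, 100.0), SortProperties::Ordered),
        Ok(SortProperties::Ordered)
    );
    assert!(sqrt_order((-1.0, 4.0), SortProperties::Unordered).is_err());
}
```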
@@ -225,7 +526,35 @@ pub fn tan_order(_input: &[ExprProperties]) -> Result { Ok(SortProperties::Unordered) } +static DOCUMENTATION_TAN: OnceLock = OnceLock::new(); + +pub fn get_tan_doc() -> &'static Documentation { + DOCUMENTATION_TAN.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the tangent of a number.", + "tan(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} + /// Non-decreasing for all real numbers. pub fn tanh_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } + +static DOCUMENTATION_TANH: OnceLock = OnceLock::new(); + +pub fn get_tanh_doc() -> &'static Documentation { + DOCUMENTATION_TANH.get_or_init(|| { + Documentation::builder( + DOC_SECTION_MATH, + "Returns the hyperbolic tangent of a number.", + "tanh(numeric_expression)", + ) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + }) +} diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index b02839b40bd95..30c920c29a21c 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -17,15 +17,22 @@ //! Math function: `isnan()`. -use arrow::datatypes::DataType; -use datafusion_common::{exec_err, DataFusionError, Result}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, TypeSignature}; -use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use arrow::array::{ArrayRef, AsArray, BooleanArray}; +use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; +#[user_doc( + doc_section(label = "Math Functions"), + description = "Returns true if a given number is +NaN or -NaN otherwise returns false.", + syntax_example = "isnan(numeric_expression)", + standard_argument(name = "numeric_expression", prefix = "Numeric") +)] #[derive(Debug)] pub struct IsNanFunc { signature: Signature, @@ -68,24 +75,23 @@ impl ScalarUDFImpl for IsNanFunc { Ok(DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float64Array, - BooleanArray, - { f64::is_nan } - )), - DataType::Float32 => Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float32Array, - BooleanArray, - { f32::is_nan } - )), + DataType::Float64 => Arc::new(BooleanArray::from_unary( + args[0].as_primitive::(), + f64::is_nan, + )) as ArrayRef, + + DataType::Float32 => Arc::new(BooleanArray::from_unary( + args[0].as_primitive::(), + f32::is_nan, + )) as ArrayRef, other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -95,4 +101,8 @@ impl ScalarUDFImpl for IsNanFunc { }; Ok(ColumnarValue::Array(arr)) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index d81a690843b63..33823acce7518 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -18,17 +18,32 @@ use std::any::Any; use std::sync::Arc; -use 
arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Float32, Float64}; +use crate::utils::make_scalar_function; +use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array}; +use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::make_scalar_function; - +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Math Functions"), + description = r#"Returns the first argument if it's not _NaN_. +Returns the second argument otherwise."#, + syntax_example = "nanvl(expression_x, expression_y)", + argument( + name = "expression_x", + description = "Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators." + ), + argument( + name = "expression_y", + description = "Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators." + ) +)] #[derive(Debug)] pub struct NanvlFunc { signature: Signature, @@ -72,9 +87,17 @@ impl ScalarUDFImpl for NanvlFunc { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(nanvl, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Nanvl SQL function @@ -89,14 +112,11 @@ fn nanvl(args: &[ArrayRef]) -> Result { } }; - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float64Array, - { compute_nanvl } - )) as ArrayRef) + let x = args[0].as_primitive() as &Float64Array; + let y = args[1].as_primitive() as &Float64Array; + arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) } Float32 => { let compute_nanvl = |x: f32, y: f32| { @@ -107,14 +127,11 @@ fn nanvl(args: &[ArrayRef]) -> Result { } }; - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float32Array, - { compute_nanvl } - )) as ArrayRef) + let x = args[0].as_primitive() as &Float32Array; + let y = args[1].as_primitive() as &Float32Array; + arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) } other => exec_err!("Unsupported data type {other:?} for function nanvl"), } @@ -122,10 +139,12 @@ fn nanvl(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::math::nanvl::nanvl; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; #[test] fn test_nanvl_f64() { diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs index b4adfc190e697..10f61b829dfe5 100644 --- a/datafusion/functions/src/math/pi.rs +++ b/datafusion/functions/src/math/pi.rs @@ -19,10 +19,18 @@ use std::any::Any; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; -use datafusion_common::{not_impl_err, Result, ScalarValue}; +use datafusion_common::{internal_err, Result, ScalarValue}; use 
datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "Math Functions"), + description = "Returns an approximate value of π.", + syntax_example = "pi()" +)] #[derive(Debug)] pub struct PiFunc { signature: Signature, @@ -37,7 +45,7 @@ impl Default for PiFunc { impl PiFunc { pub fn new() -> Self { Self { - signature: Signature::exact(vec![], Volatility::Immutable), + signature: Signature::nullary(Volatility::Immutable), } } } @@ -59,11 +67,14 @@ impl ScalarUDFImpl for PiFunc { Ok(Float64) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { - not_impl_err!("{} function does not accept arguments", self.name()) - } - - fn invoke_no_args(&self, _number_rows: usize) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { + if !args.is_empty() { + return internal_err!("{} function does not accept arguments", self.name()); + } Ok(ColumnarValue::from(ScalarValue::Float64(Some( std::f64::consts::PI, )))) @@ -73,4 +84,8 @@ impl ScalarUDFImpl for PiFunc { // This function returns a constant value. Ok(SortProperties::Singleton) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 44d6a20d96b10..b5a9bfa6da658 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -16,24 +16,30 @@ // under the License. //! Math function: `power()`. +use std::any::Any; +use std::sync::Arc; -use arrow::datatypes::{ArrowNativeTypeOp, DataType}; +use super::log::LogFunc; +use arrow::array::{ArrayRef, AsArray, Int64Array}; +use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type}; use datafusion_common::{ - arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err, - DataFusionError, Result, ScalarValue, + arrow_datafusion_err, exec_datafusion_err, exec_err, internal_datafusion_err, + plan_datafusion_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDF, TypeSignature}; - -use arrow::array::{ArrayRef, Float64Array, Int64Array}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; - -use super::log::LogFunc; - +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Math Functions"), + description = "Returns a base expression raised to the power of an exponent.", + syntax_example = "power(base, exponent)", + standard_argument(name = "base", prefix = "Numeric"), + standard_argument(name = "exponent", prefix = "Exponent numeric") +)] #[derive(Debug)] pub struct PowerFunc { signature: Signature, @@ -85,22 +91,27 @@ impl ScalarUDFImpl for PowerFunc { &self.aliases } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "base", - "exponent", - Float64Array, - { 
f64::powf } - )), - + DataType::Float64 => { + let bases = args[0].as_primitive::(); + let exponents = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + bases, + exponents, + f64::powf, + )?; + Arc::new(result) as _ + } DataType::Int64 => { - let bases = downcast_arg!(&args[0], "base", Int64Array); - let exponents = downcast_arg!(&args[1], "exponent", Int64Array); + let bases = downcast_named_arg!(&args[0], "base", Int64Array); + let exponents = downcast_named_arg!(&args[1], "exponent", Int64Array); bases .iter() .zip(exponents.iter()) @@ -115,7 +126,7 @@ impl ScalarUDFImpl for PowerFunc { _ => Ok(None), }) .collect::>() - .map(Arc::new)? as ArrayRef + .map(Arc::new)? as _ } other => { @@ -168,6 +179,10 @@ impl ScalarUDFImpl for PowerFunc { _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])), } } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Return true if this function call is a call to `Log` @@ -177,6 +192,7 @@ fn is_log(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { + use arrow::array::Float64Array; use datafusion_common::cast::{as_float64_array, as_int64_array}; use super::*; @@ -187,9 +203,9 @@ mod tests { ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = PowerFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function power"); match result { @@ -214,9 +230,9 @@ mod tests { ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = PowerFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function power"); match result { diff --git a/datafusion/functions/src/math/random.rs b/datafusion/functions/src/math/random.rs index 20591a02a930d..197d065ea408f 100644 --- a/datafusion/functions/src/math/random.rs +++ b/datafusion/functions/src/math/random.rs @@ -23,10 +23,17 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; use rand::{thread_rng, Rng}; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "Math Functions"), + description = r#"Returns a random float value in the range [0, 1). 
+The random seed is unique to each row."#, + syntax_example = "random()" +)] #[derive(Debug)] pub struct RandomFunc { signature: Signature, @@ -41,7 +48,7 @@ impl Default for RandomFunc { impl RandomFunc { pub fn new() -> Self { Self { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), } } } @@ -63,11 +70,14 @@ impl ScalarUDFImpl for RandomFunc { Ok(Float64) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { - not_impl_err!("{} function does not accept arguments", self.name()) - } - - fn invoke_no_args(&self, num_rows: usize) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + num_rows: usize, + ) -> Result { + if !args.is_empty() { + return internal_err!("{} function does not accept arguments", self.name()); + } let mut rng = thread_rng(); let mut values = vec![0.0; num_rows]; // Equivalent to set each element with rng.gen_range(0.0..1.0), but more efficient @@ -76,4 +86,8 @@ impl ScalarUDFImpl for RandomFunc { Ok(ColumnarValue::Array(Arc::new(array))) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 9067a1748e197..d9e623445d2d1 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -20,17 +20,28 @@ use std::sync::Arc; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array}; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Float32, Float64, Int32}; -use datafusion_common::{ - exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, -}; +use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type}; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Math Functions"), + description = "Rounds a number to the nearest integer.", + syntax_example = "round(numeric_expression[, decimal_places])", + standard_argument(name = "numeric_expression", prefix = "Numeric"), + argument( + name = "decimal_places", + description = "Optional. The number of decimal places to round to. Defaults to 0." 
+ ) +)] #[derive(Debug)] pub struct RoundFunc { signature: Signature, @@ -79,7 +90,11 @@ impl ScalarUDFImpl for RoundFunc { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(round, vec![])(args) } @@ -97,6 +112,10 @@ impl ScalarUDFImpl for RoundFunc { Ok(SortProperties::Unordered) } } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Round SQL function @@ -115,7 +134,7 @@ pub fn round(args: &[ArrayRef]) -> Result { } match args[0].data_type() { - DataType::Float64 => match decimal_places { + Float64 => match decimal_places { ColumnarValue::Scalar(scalar) => match scalar.into_value() { ScalarValue::Int64(Some(decimal_places)) => { let decimal_places: i32 = decimal_places.try_into().map_err(|e| { @@ -124,17 +143,13 @@ pub fn round(args: &[ArrayRef]) -> Result { ) })?; - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float64Array, - { - |value: f64| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) + let result = args[0] + .as_primitive::() + .unary::<_, Float64Type>(|value: f64| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + }); + Ok(Arc::new(result) as _) } _ => { exec_err!("round function requires a Int64 scalar for decimal_places") @@ -149,24 +164,22 @@ pub fn round(args: &[ArrayRef]) -> Result { .map_err(|e| { exec_datafusion_err!("Invalid values for decimal places: {e}") })?; - Ok(Arc::new(make_function_inputs2!( - &args[0], + + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + values, decimal_places, - "value", - "decimal_places", - Float64Array, - Int32Array, - { - |value: f64, decimal_places: i32| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) + |value, decimal_places| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + }, + )?; + Ok(Arc::new(result) as _) } }, - DataType::Float32 => match decimal_places { + Float32 => match decimal_places { ColumnarValue::Scalar(scalar) => match scalar.into_value() { ScalarValue::Int64(Some(decimal_places)) => { let decimal_places: i32 = decimal_places.try_into().map_err(|e| { @@ -175,17 +188,13 @@ pub fn round(args: &[ArrayRef]) -> Result { ) })?; - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float32Array, - { - |value: f32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - } - )) as ArrayRef) + let result = args[0] + .as_primitive::() + .unary::<_, Float32Type>(|value: f32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + }); + Ok(Arc::new(result) as _) } _ => { exec_err!("round function requires a Int64 scalar for decimal_places") @@ -200,20 +209,17 @@ pub fn round(args: &[ArrayRef]) -> Result { panic!("Unexpected result of ColumnarValue::Array.cast") }; - Ok(Arc::new(make_function_inputs2!( - &args[0], + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result: PrimitiveArray = arrow::compute::binary( + values, decimal_places, - "value", - "decimal_places", - Float32Array, - Int32Array, - { - |value: f32, decimal_places: i32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - 
} - )) as ArrayRef) + |value, decimal_places| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + }, + )?; + Ok(Arc::new(result) as _) } }, diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index d2a806a46e136..f68834db375e5 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -18,17 +18,27 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; use crate::utils::make_scalar_function; +#[user_doc( + doc_section(label = "Math Functions"), + description = r#"Returns the sign of a number. +Negative numbers return `-1`. +Zero and positive numbers return `1`."#, + syntax_example = "signum(numeric_expression)", + standard_argument(name = "numeric_expression", prefix = "Numeric") +)] #[derive(Debug)] pub struct SignumFunc { signature: Signature, @@ -78,45 +88,49 @@ impl ScalarUDFImpl for SignumFunc { Ok(input[0].sort_properties) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(signum, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// signum SQL function pub fn signum(args: &[ArrayRef]) -> Result { match args[0].data_type() { - Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "signum", - Float64Array, - Float64Array, - { - |x: f64| { - if x == 0_f64 { - 0_f64 - } else { - x.signum() - } - } - } - )) as ArrayRef), - - Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "signum", - Float32Array, - Float32Array, - { - |x: f32| { - if x == 0_f32 { - 0_f32 - } else { - x.signum() - } - } - } - )) as ArrayRef), + Float64 => Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>( + |x: f64| { + if x == 0_f64 { + 0_f64 + } else { + x.signum() + } + }, + ), + ) as ArrayRef), + + Float32 => Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>( + |x: f32| { + if x == 0_f32 { + 0_f32 + } else { + x.signum() + } + }, + ), + ) as ArrayRef), other => exec_err!("Unsupported data type {other:?} for function signum"), } @@ -135,7 +149,7 @@ mod test { #[test] fn test_signum_f32() { - let args = [ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + let array = Arc::new(Float32Array::from(vec![ -1.0, -0.0, 0.0, @@ -145,10 +159,11 @@ mod test { f32::NAN, f32::INFINITY, f32::NEG_INFINITY, - ])))]; - + ])); + let batch_size = array.len(); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = SignumFunc::new() - .invoke(&args) + .invoke_batch(&[ColumnarValue::Array(array)], batch_size) .expect("failed to initialize function signum"); match result { @@ -175,7 +190,7 @@ mod test { #[test] fn test_signum_f64() { - let args = [ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + let array = 
Arc::new(Float64Array::from(vec![ -1.0, -0.0, 0.0, @@ -185,10 +200,11 @@ mod test { f64::NAN, f64::INFINITY, f64::NEG_INFINITY, - ])))]; - + ])); + let batch_size = array.len(); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = SignumFunc::new() - .invoke(&args) + .invoke_batch(&[ColumnarValue::Array(array)], batch_size) .expect("failed to initialize function signum"); match result { diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 2a2b33c1e279a..c38053097af2d 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -20,15 +20,32 @@ use std::sync::Arc; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type}; use datafusion_common::ScalarValue::Int64; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "Math Functions"), + description = "Truncates a number to a whole number or truncated to the specified decimal places.", + syntax_example = "trunc(numeric_expression[, decimal_places])", + standard_argument(name = "numeric_expression", prefix = "Numeric"), + argument( + name = "decimal_places", + description = r#"Optional. The number of decimal places to + truncate to. Defaults to 0 (truncate to a whole number). If + `decimal_places` is a positive integer, truncates digits to the + right of the decimal point. 
If `decimal_places` is a negative + integer, replaces digits to the left of the decimal point with `0`."# + ) +)] #[derive(Debug)] pub struct TruncFunc { signature: Signature, @@ -82,7 +99,11 @@ impl ScalarUDFImpl for TruncFunc { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(trunc, vec![])(args) } @@ -100,6 +121,10 @@ impl ScalarUDFImpl for TruncFunc { Ok(SortProperties::Unordered) } } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } /// Truncate(numeric, decimalPrecision) and trunc(numeric) SQL function @@ -111,8 +136,8 @@ fn trunc(args: &[ArrayRef]) -> Result { ); } - //if only one arg then invoke toolchain trunc(num) and precision = 0 by default - //or then invoke the compute_truncate method to process precision + // If only one arg then invoke toolchain trunc(num) and precision = 0 by default + // or then invoke the compute_truncate method to process precision let num = &args[0]; let precision = if args.len() == 1 { ColumnarValue::from(Int64(Some(0))) @@ -120,46 +145,58 @@ fn trunc(args: &[ArrayRef]) -> Result { ColumnarValue::from(Arc::clone(&args[1])) }; - match args[0].data_type() { + match num.data_type() { Float64 => match precision { ColumnarValue::Scalar(scalar) => match scalar.value() { - Int64(Some(0)) => Ok(Arc::new(make_function_scalar_inputs!( - num, - "num", - Float64Array, - { f64::trunc } - )) as ArrayRef), + Int64(Some(0)) => Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>(|x: f64| { + if x == 0_f64 { + 0_f64 + } else { + x.trunc() + } + }), + ) as ArrayRef), _ => exec_err!("trunc function requires a Int64 scalar for precision"), }, - ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( - num, - precision, - "x", - "y", - Float64Array, - Int64Array, - { compute_truncate64 } - )) as ArrayRef), + ColumnarValue::Array(precision) => { + let num_array = num.as_primitive::(); + let precision_array = precision.as_primitive::(); + let result: PrimitiveArray = + arrow::compute::binary(num_array, precision_array, |x, y| { + compute_truncate64(x, y) + })?; + + Ok(Arc::new(result) as ArrayRef) + } }, Float32 => match precision { ColumnarValue::Scalar(scalar) => match scalar.value() { - Int64(Some(0)) => Ok(Arc::new(make_function_scalar_inputs!( - num, - "num", - Float32Array, - { f32::trunc } - )) as ArrayRef), + Int64(Some(0)) => Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>(|x: f32| { + if x == 0_f32 { + 0_f32 + } else { + x.trunc() + } + }), + ) as ArrayRef), _ => exec_err!("trunc function requires a Int64 scalar for precision"), }, - ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( - num, - precision, - "x", - "y", - Float32Array, - Int64Array, - { compute_truncate32 } - )) as ArrayRef), + ColumnarValue::Array(precision) => { + let num_array = num.as_primitive::(); + let precision_array = precision.as_primitive::(); + let result: PrimitiveArray = + arrow::compute::binary(num_array, precision_array, |x, y| { + compute_truncate32(x, y) + })?; + + Ok(Arc::new(result) as ArrayRef) + } }, other => exec_err!("Unsupported data type {other:?} for function trunc"), } diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 4afbe6cbbb89c..13fbc049af582 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -17,22 +17,40 @@ //! 
"regex" DataFusion functions +use std::sync::Arc; + +pub mod regexpcount; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; // create UDFs -make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match); -make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like); -make_udf_function!( - regexpreplace::RegexpReplaceFunc, - REGEXP_REPLACE, - regexp_replace -); +make_udf_function!(regexpcount::RegexpCountFunc, regexp_count); +make_udf_function!(regexpmatch::RegexpMatchFunc, regexp_match); +make_udf_function!(regexplike::RegexpLikeFunc, regexp_like); +make_udf_function!(regexpreplace::RegexpReplaceFunc, regexp_replace); pub mod expr_fn { use datafusion_expr::Expr; + /// Returns the number of consecutive occurrences of a regular expression in a string. + pub fn regexp_count( + values: Expr, + regex: Expr, + start: Option, + flags: Option, + ) -> Expr { + let mut args = vec![values, regex]; + if let Some(start) = start { + args.push(start); + }; + + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_count().call(args) + } + /// Returns a list of regular expression matches in a string. pub fn regexp_match(values: Expr, regex: Expr, flags: Option) -> Expr { let mut args = vec![values, regex]; @@ -67,6 +85,11 @@ pub mod expr_fn { } /// Returns all DataFusion functions defined in this package -pub fn functions() -> Vec> { - vec![regexp_match(), regexp_like(), regexp_replace()] +pub fn functions() -> Vec> { + vec![ + regexp_count(), + regexp_match(), + regexp_like(), + regexp_replace(), + ] } diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs new file mode 100644 index 0000000000000..4e38ebcead9eb --- /dev/null +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -0,0 +1,1058 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType}; +use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{ + DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, +}; +use arrow::error::ArrowError; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact, + TypeSignature::Uniform, Volatility, +}; +use datafusion_macros::user_doc; +use itertools::izip; +use regex::Regex; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::Arc; + +#[user_doc( + doc_section(label = "Regular Expression Functions"), + description = "Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.", + syntax_example = "regexp_count(str, regexp[, start, flags])", + sql_example = r#"```sql +> select regexp_count('abcAbAbc', 'abc', 2, 'i'); ++---------------------------------------------------------------+ +| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) | ++---------------------------------------------------------------+ +| 1 | ++---------------------------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + standard_argument(name = "regexp", prefix = "Regular"), + argument( + name = "start", + description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function." + ), + argument( + name = "flags", + description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . 
to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"# + ) +)] +#[derive(Debug)] +pub struct RegexpCountFunc { + signature: Signature, +} + +impl Default for RegexpCountFunc { + fn default() -> Self { + Self::new() + } +} + +impl RegexpCountFunc { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + Uniform(2, vec![Utf8View, LargeUtf8, Utf8]), + Exact(vec![Utf8View, Utf8View, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Utf8View]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]), + Exact(vec![Utf8, Utf8, Int64, Utf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RegexpCountFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "regexp_count" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .map(|arg| arg.to_array(inferred_length)) + .collect::>>()?; + + let result = regexp_count_func(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::from) + } else { + result.map(ColumnarValue::Array) + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +pub fn regexp_count_func(args: &[ArrayRef]) -> Result { + let args_len = args.len(); + if !(2..=4).contains(&args_len) { + return exec_err!("regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."); + } + + let values = &args[0]; + match values.data_type() { + Utf8 | LargeUtf8 | Utf8View => (), + other => { + return internal_err!( + "Unsupported data type {other:?} for function regexp_count" + ); + } + } + + regexp_count( + values, + &args[1], + if args_len > 2 { Some(&args[2]) } else { None }, + if args_len > 3 { Some(&args[3]) } else { None }, + ) + .map_err(|e| e.into()) +} + +/// `arrow-rs` style implementation of `regexp_count` function. +/// This function `regexp_count` is responsible for counting the occurrences of a regular expression pattern +/// within a string array. It supports optional start positions and flags for case insensitivity. +/// +/// The function accepts a variable number of arguments: +/// - `values`: The array of strings to search within. +/// - `regex_array`: The array of regular expression patterns to search for. +/// - `start_array` (optional): The array of start positions for the search. +/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity). +/// +/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions, +/// and flags. It uses a cache to store compiled regular expressions for efficiency. +/// +/// # Errors +/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile. 
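Before the implementation of `regexp_count` itself, it may help to see the counting rule it builds on in isolation: compile the pattern, optionally prefixed with inline flags such as `i`, skip to a 1-based start position measured in characters, then count non-overlapping matches with `regex::Regex::find_iter`. The standalone sketch below reproduces that logic with simplified types; the real function additionally handles nulls, Arrow string arrays, and a per-batch regex cache.

```rust
use regex::Regex;

// Simplified version of the per-value counting step used by regexp_count.
fn count_matches(
    value: &str,
    pattern: &str,
    flags: Option<&str>,
    start: i64,
) -> Result<i64, String> {
    if start < 1 {
        return Err("regexp_count() requires start to be 1 based".to_string());
    }
    // Inline flags (e.g. "i" for case-insensitive) are prepended as "(?i)".
    let pattern = match flags {
        None | Some("") => pattern.to_string(),
        Some(flags) => format!("(?{}){}", flags, pattern),
    };
    let re = Regex::new(&pattern).map_err(|e| e.to_string())?;
    // The start offset is measured in characters, matching the diff's use of
    // `chars().skip(start - 1)` before searching.
    let haystack: String = value.chars().skip(start as usize - 1).collect();
    Ok(re.find_iter(&haystack).count() as i64)
}

fn main() {
    // Mirrors the SQL example in the doc attribute: searching "abcAbAbc" for
    // "abc" case-insensitively from position 2 finds one match ("Abc").
    assert_eq!(count_matches("abcAbAbc", "abc", Some("i"), 2), Ok(1));
    // From position 1 both the leading "abc" and the trailing "Abc" count.
    assert_eq!(count_matches("abcAbAbc", "abc", Some("i"), 1), Ok(2));
}
```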
+pub fn regexp_count( + values: &dyn Array, + regex_array: &dyn Datum, + start_array: Option<&dyn Datum>, + flags_array: Option<&dyn Datum>, +) -> Result { + let (regex_array, is_regex_scalar) = regex_array.get(); + let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| { + let (start, is_start_scalar) = start.get(); + (Some(start), is_start_scalar) + }); + let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| { + let (flags, is_flags_scalar) = flags.get(); + (Some(flags), is_flags_scalar) + }); + + match (values.data_type(), regex_array.data_type(), flags_array) { + (Utf8, Utf8, None) => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + ), + (LargeUtf8, LargeUtf8, None) => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + ), + (Utf8View, Utf8View, None) => regexp_count_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string_view()), + is_flags_scalar, + ), + _ => Err(ArrowError::ComputeError( + "regexp_count() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(), + )), + } +} + +pub fn regexp_count_inner<'a, S>( + values: S, + regex_array: S, + is_regex_scalar: bool, + start_array: Option<&Int64Array>, + is_start_scalar: bool, + flags_array: Option, + is_flags_scalar: bool, +) -> Result +where + S: StringArrayType<'a>, +{ + let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 { + (Some(regex_array.value(0)), true) + } else { + (None, false) + }; + + let (start_array, start_scalar, is_start_scalar) = + if let Some(start_array) = start_array { + if is_start_scalar || start_array.len() == 1 { + (None, Some(start_array.value(0)), true) + } else { + (Some(start_array), None, false) + } + } else { + (None, Some(1), true) + }; + + let (flags_array, flags_scalar, is_flags_scalar) = + if let Some(flags_array) = flags_array { + if is_flags_scalar || flags_array.len() == 1 { + (None, Some(flags_array.value(0)), true) + } else { + (Some(flags_array), None, false) + } + } else { + (None, None, true) + }; + + let mut regex_cache = HashMap::new(); + + match (is_regex_scalar, is_start_scalar, is_flags_scalar) { + (true, true, true) => { + let regex = match 
regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let pattern = compile_regex(regex, flags_scalar)?; + + Ok(Arc::new( + values + .iter() + .map(|value| count_matches(value, &pattern, start_scalar)) + .collect::>()?, + )) + } + (true, true, false) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new( + values + .iter() + .zip(flags_array.iter()) + .map(|(value, flags)| { + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, pattern, start_scalar) + }) + .collect::>()?, + )) + } + (true, false, true) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let pattern = compile_regex(regex, flags_scalar)?; + + let start_array = start_array.unwrap(); + + Ok(Arc::new( + values + .iter() + .zip(start_array.iter()) + .map(|(value, start)| count_matches(value, &pattern, start)) + .collect::>()?, + )) + } + (true, false, false) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new( + izip!( + values.iter(), + start_array.unwrap().iter(), + flags_array.iter() + ) + .map(|(value, start, flags)| { + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + + count_matches(value, pattern, start) + }) + .collect::>()?, + )) + } + (false, true, true) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + Ok(Arc::new( + values + .iter() + .zip(regex_array.iter()) + .map(|(value, regex)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = compile_and_cache_regex( + regex, + flags_scalar, + &mut regex_cache, + )?; + count_matches(value, pattern, start_scalar) + }) + .collect::>()?, + )) + } + (false, true, false) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new( + izip!(values.iter(), regex_array.iter(), flags_array.iter()) + .map(|(value, regex, flags)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + + count_matches(value, pattern, start_scalar) + }) + 
.collect::>()?, + )) + } + (false, false, true) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let start_array = start_array.unwrap(); + if values.len() != start_array.len() { + return Err(ArrowError::ComputeError(format!( + "start_array must be the same length as values array; got {} and {}", + start_array.len(), + values.len(), + ))); + } + + Ok(Arc::new( + izip!(values.iter(), regex_array.iter(), start_array.iter()) + .map(|(value, regex, start)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = compile_and_cache_regex( + regex, + flags_scalar, + &mut regex_cache, + )?; + count_matches(value, pattern, start) + }) + .collect::>()?, + )) + } + (false, false, false) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let start_array = start_array.unwrap(); + if values.len() != start_array.len() { + return Err(ArrowError::ComputeError(format!( + "start_array must be the same length as values array; got {} and {}", + start_array.len(), + values.len(), + ))); + } + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new( + izip!( + values.iter(), + regex_array.iter(), + start_array.iter(), + flags_array.iter() + ) + .map(|(value, regex, start, flags)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, pattern, start) + }) + .collect::>()?, + )) + } + } +} + +fn compile_and_cache_regex<'strings, 'cache>( + regex: &'strings str, + flags: Option<&'strings str>, + regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, +) -> Result<&'cache Regex, ArrowError> +where + 'strings: 'cache, +{ + let result = match regex_cache.entry((regex, flags)) { + Entry::Occupied(occupied_entry) => occupied_entry.into_mut(), + Entry::Vacant(vacant_entry) => { + let compiled = compile_regex(regex, flags)?; + vacant_entry.insert(compiled) + } + }; + Ok(result) +} + +fn compile_regex(regex: &str, flags: Option<&str>) -> Result { + let pattern = match flags { + None | Some("") => regex.to_string(), + Some(flags) => { + if flags.contains("g") { + return Err(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + )); + } + format!("(?{}){}", flags, regex) + } + }; + + Regex::new(&pattern).map_err(|_| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {}", + pattern + )) + }) +} + +fn count_matches( + value: Option<&str>, + pattern: &Regex, + start: Option, +) -> Result { + let value = match value { + None | Some("") => return Ok(0), + Some(value) => value, + }; + + if let Some(start) = start { + if start < 1 { + return Err(ArrowError::ComputeError( + "regexp_count() requires start to be 1 based".to_string(), + )); + } + + let find_slice = value.chars().skip(start as usize - 1).collect::(); + let count = pattern.find_iter(find_slice.as_str()).count(); + Ok(count as i64) + } else { + let count = 
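The compile_and_cache_regex / compile_regex pair above avoids recompiling identical patterns by keying a HashMap on (pattern, flags) and folding the flags into the pattern as an inline (?flags) group. A minimal standalone sketch of the same idea, using owned keys instead of the borrowed keys above (illustrative only, not the DataFusion API):

```rust
use std::collections::hash_map::Entry;
use std::collections::HashMap;

use regex::Regex;

// Hypothetical helper: compile a pattern once per distinct (pattern, flags) pair.
fn cached_regex<'c>(
    cache: &'c mut HashMap<(String, Option<String>), Regex>,
    pattern: &str,
    flags: Option<&str>,
) -> Result<&'c Regex, regex::Error> {
    match cache.entry((pattern.to_string(), flags.map(str::to_string))) {
        Entry::Occupied(e) => Ok(e.into_mut()),
        Entry::Vacant(e) => {
            // Inline flags, e.g. "(?i)abc" for case-insensitive matching.
            let full = match flags {
                None | Some("") => pattern.to_string(),
                Some(f) => format!("(?{f}){pattern}"),
            };
            Ok(e.insert(Regex::new(&full)?))
        }
    }
}

fn main() -> Result<(), regex::Error> {
    let mut cache = HashMap::new();
    let re = cached_regex(&mut cache, "abc", Some("i"))?;
    assert_eq!(re.find_iter("abcABCAbc").count(), 3);
    Ok(())
}
```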
pattern.find_iter(value).count(); + Ok(count as i64) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{GenericStringArray, StringViewArray}; + + #[test] + fn test_regexp_count() { + test_case_sensitive_regexp_count_scalar(); + test_case_sensitive_regexp_count_scalar_start(); + test_case_insensitive_regexp_count_scalar_flags(); + test_case_sensitive_regexp_count_start_scalar_complex(); + + test_case_sensitive_regexp_count_array::>(); + test_case_sensitive_regexp_count_array::>(); + test_case_sensitive_regexp_count_array::(); + + test_case_sensitive_regexp_count_array_start::>(); + test_case_sensitive_regexp_count_array_start::>(); + test_case_sensitive_regexp_count_array_start::(); + + test_case_insensitive_regexp_count_array_flags::>(); + test_case_insensitive_regexp_count_array_flags::>(); + test_case_insensitive_regexp_count_array_flags::(); + + test_case_sensitive_regexp_count_array_complex::>(); + test_case_sensitive_regexp_count_array_complex::>(); + test_case_sensitive_regexp_count_array_complex::(); + + test_case_regexp_count_cache_check::>(); + } + + fn test_case_sensitive_regexp_count_scalar() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let expected: Vec = vec![0, 1, 2, 1, 3]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let expected = ScalarValue::Int64(expected.get(pos).cloned()); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let re = RegexpCountFunc::new().invoke_batch( + &[ColumnarValue::from(v_sv), ColumnarValue::from(regex_sv)], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(scalar)) => { + assert_eq!( + scalar.value(), + &expected, + "regexp_count scalar test failed" + ); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let re = RegexpCountFunc::new().invoke_batch( + &[ColumnarValue::from(v_sv), ColumnarValue::from(regex_sv)], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(scalar)) => { + assert_eq!( + scalar.value(), + &expected, + "regexp_count scalar test failed" + ); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let re = RegexpCountFunc::new().invoke_batch( + &[ColumnarValue::from(v_sv), ColumnarValue::from(regex_sv)], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(scalar)) => { + assert_eq!( + scalar.value(), + &expected, + "regexp_count scalar test failed" + ); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_scalar_start() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let start = 2; + let expected: Vec = vec![0, 1, 1, 0, 2]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let expected = ScalarValue::Int64(expected.get(pos).cloned()); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::from(v_sv), + 
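count_matches, defined just above the tests, treats start as a 1-based character position: it skips start - 1 chars (not bytes) before matching, which keeps multi-byte UTF-8 input safe. A small illustration of that rule, detached from the arrow types used above:

```rust
use regex::Regex;

// Count matches of `pattern` in `value`, beginning at a 1-based character position,
// mirroring the char-skipping approach used by regexp_count above (illustrative only).
fn count_from(value: &str, pattern: &Regex, start: i64) -> Result<i64, String> {
    if start < 1 {
        return Err("start must be >= 1".to_string());
    }
    let tail: String = value.chars().skip(start as usize - 1).collect();
    Ok(pattern.find_iter(&tail).count() as i64)
}

fn main() {
    let re = Regex::new("abc").unwrap();
    // "abcabc": starting at position 2 skips the first 'a', so only one match remains.
    assert_eq!(count_from("abcabc", &re, 2), Ok(1));
    // Char-based skipping keeps multi-byte input safe: position 2 of "Köln" is 'ö'.
    assert_eq!(count_from("Köln", &Regex::new("ö").unwrap(), 2), Ok(1));
}
```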
ColumnarValue::from(regex_sv), + ColumnarValue::from(start_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(scalar)) => { + assert_eq!( + scalar.value(), + &expected, + "regexp_count scalar test failed" + ); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::from(v_sv), + ColumnarValue::from(regex_sv), + ColumnarValue::from(start_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(scalar)) => { + assert_eq!( + scalar.value(), + &expected, + "regexp_count scalar test failed" + ); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::from(v_sv), + ColumnarValue::from(regex_sv), + ColumnarValue::from(start_sv), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(scalar)) => { + assert_eq!( + scalar.value(), + &expected, + "regexp_count scalar test failed" + ); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_insensitive_regexp_count_scalar_flags() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let start = 1; + let flags = "i"; + let expected: Vec = vec![0, 1, 2, 2, 3]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); + let expected = ScalarValue::Int64(expected.get(pos).cloned()); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::from(v_sv), + ColumnarValue::from(regex_sv), + ColumnarValue::from(start_sv.clone()), + ColumnarValue::from(flags_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(scalar)) => { + assert_eq!( + scalar.value(), + &expected, + "regexp_count scalar test failed" + ); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::from(v_sv), + ColumnarValue::from(regex_sv), + ColumnarValue::from(start_sv.clone()), + ColumnarValue::from(flags_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(scalar)) => { + assert_eq!( + scalar.value(), + &expected, + "regexp_count scalar test failed" + ); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::from(v_sv), + ColumnarValue::from(regex_sv), + ColumnarValue::from(start_sv), + ColumnarValue::from(flags_sv.clone()), + ], + 1, + ); + match 
re { + Ok(ColumnarValue::Scalar(scalar)) => { + assert_eq!( + scalar.value(), + &expected, + "regexp_count scalar test failed" + ); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_array() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 2]); + + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_array_start() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + + let expected = Int64Array::from(vec![0, 0, 1, 1, 0]); + + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_insensitive_regexp_count_array_flags() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1]); + let flags = A::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); + + let re = regexp_count_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_start_scalar_complex() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = ["", "abc", "a", "bc", "ab"]; + let start = 5; + let flags = ["", "i", "", "", "i"]; + let expected: Vec = vec![0, 0, 0, 1, 1]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(regex.get(pos).map(|s| s.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); + let expected = ScalarValue::Int64(expected.get(pos).cloned()); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::from(v_sv), + ColumnarValue::from(regex_sv), + ColumnarValue::from(start_sv.clone()), + ColumnarValue::from(flags_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(scalar)) => { + assert_eq!( + scalar.value(), + &expected, + "regexp_count scalar test failed" + ); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); + let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::from(v_sv), + ColumnarValue::from(regex_sv), + ColumnarValue::from(start_sv.clone()), + ColumnarValue::from(flags_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(scalar)) => { + assert_eq!( + scalar.value(), + &expected, + "regexp_count scalar test failed" + ); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = 
ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); + let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); + #[allow(deprecated)] // TODO: migrate to invoke_with_args + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::from(v_sv), + ColumnarValue::from(regex_sv), + ColumnarValue::from(start_sv), + ColumnarValue::from(flags_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(scalar)) => { + assert_eq!( + scalar.value(), + &expected, + "regexp_count scalar test failed" + ); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_array_complex() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + let flags = A::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 1, 1, 1]); + + let re = regexp_count_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_regexp_count_cache_check() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["aaa", "Aaa", "aaa"]); + let regex = A::from(vec!["aaa", "aaa", "aaa"]); + let start = Int64Array::from(vec![1, 1, 1]); + let flags = A::from(vec!["", "i", ""]); + + let expected = Int64Array::from(vec![1, 1, 1]); + + let re = regexp_count_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } +} diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 61ef35bb6e5a8..56469283666e7 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -15,42 +15,28 @@ // specific language governing permissions and limitations // under the License. -//! Regx expressions -use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}; +//! 
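The tests above are written once and instantiated for StringArray, LargeStringArray and StringViewArray through a From<Vec<&str>> + Array bound. The same pattern can drive any fixture across the three string representations; a minimal sketch (array construction only, the regexp_count_func call is omitted):

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray};

// Build each string flavour from the same &str fixture and hand it to a checker,
// mirroring how the regexp_count tests instantiate one generic test per array type.
fn check<A>(expected_len: usize)
where
    A: From<Vec<&'static str>> + Array + 'static,
{
    let arr = A::from(vec!["a", "bb", "ccc"]);
    let arr: ArrayRef = Arc::new(arr);
    assert_eq!(arr.len(), expected_len);
}

fn main() {
    check::<StringArray>(3);
    check::<LargeStringArray>(3);
    check::<StringViewArray>(3);
}
```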
Regex expressions + +use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray}; use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::exec_err; use datafusion_common::ScalarValue; use datafusion_common::{arrow_datafusion_err, plan_err}; -use datafusion_common::{ - cast::as_generic_string_array, internal_err, DataFusionError, Result, -}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::{Arc, OnceLock}; - -#[derive(Debug)] -pub struct RegexpLikeFunc { - signature: Signature, -} +use datafusion_macros::user_doc; -impl Default for RegexpLikeFunc { - fn default() -> Self { - Self::new() - } -} - -static DOCUMENTATION: OnceLock = OnceLock::new(); +use std::any::Any; +use std::sync::Arc; -fn get_regexp_like_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_REGEX) - .with_description("Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise.") - .with_syntax_example("regexp_like(str, regexp[, flags])") - .with_sql_example(r#"```sql +#[user_doc( + doc_section(label = "Regular Expression Functions"), + description = "Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise.", + syntax_example = "regexp_like(str, regexp[, flags])", + sql_example = r#"```sql select regexp_like('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); +--------------------------------------------------------+ | regexp_like(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | @@ -65,32 +51,35 @@ SELECT regexp_like('aBc', '(b|d)', 'i'); +--------------------------------------------------+ ``` Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) -"#) - .with_standard_argument("str", "String") - .with_standard_argument("regexp","Regular") - .with_argument("flags", - r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: +"#, + standard_argument(name = "str", prefix = "String"), + standard_argument(name = "regexp", prefix = "Regular"), + argument( + name = "flags", + description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - **i**: case-insensitive: letters match both upper and lower case - **m**: multi-line mode: ^ and $ match begin/end of line - **s**: allow . 
to match \n - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - - **U**: swap the meaning of x* and x*?"#) - .build() - .unwrap() - }) + - **U**: swap the meaning of x* and x*?"# + ) +)] +#[derive(Debug)] +pub struct RegexpLikeFunc { + signature: Signature, +} + +impl Default for RegexpLikeFunc { + fn default() -> Self { + Self::new() + } } impl RegexpLikeFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Utf8, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), - TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]), - ], + vec![TypeSignature::String(2), TypeSignature::String(3)], Volatility::Immutable, ), } @@ -120,7 +109,12 @@ impl ScalarUDFImpl for RegexpLikeFunc { _ => Boolean, }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let len = args .iter() .fold(Option::::None, |acc, arg| match arg { @@ -132,10 +126,10 @@ impl ScalarUDFImpl for RegexpLikeFunc { let inferred_length = len.unwrap_or(1); let args = args .iter() - .map(|arg| arg.clone().into_array(inferred_length)) + .map(|arg| arg.to_array(inferred_length)) .collect::>>()?; - let result = regexp_like_func(&args); + let result = regexp_like(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); @@ -146,18 +140,10 @@ impl ScalarUDFImpl for RegexpLikeFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_regexp_like_doc()) - } -} -fn regexp_like_func(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Utf8 => regexp_like::(args), - DataType::LargeUtf8 => regexp_like::(args), - other => { - internal_err!("Unsupported data type {other:?} for function regexp_like") - } + self.doc() } } + /// Tests a string using a regular expression returning true if at /// least one match, false otherwise. 
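invoke_batch above first infers the batch length from whichever argument is already an array, then broadcasts scalar arguments to that length with ColumnarValue::to_array. A condensed sketch of that preparation step, with the original fold simplified to a find_map (a sketch, not the exact DataFusion code):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::ColumnarValue;

// Broadcast a mixed scalar/array argument list to a uniform set of arrays,
// the same preparation step regexp_like's invoke_batch performs above.
fn to_uniform_arrays(args: &[ColumnarValue]) -> Result<Vec<ArrayRef>> {
    // Length of the first array argument, or 1 if every argument is a scalar.
    let inferred_length = args
        .iter()
        .find_map(|arg| match arg {
            ColumnarValue::Array(a) => Some(a.len()),
            ColumnarValue::Scalar(_) => None,
        })
        .unwrap_or(1);

    args.iter().map(|arg| arg.to_array(inferred_length)).collect()
}

fn main() -> Result<()> {
    let args = vec![
        ColumnarValue::Array(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
        ColumnarValue::Scalar(ScalarValue::from("x")),
    ];
    let arrays = to_uniform_arrays(&args)?;
    assert_eq!(arrays[0].len(), 3);
    assert_eq!(arrays[1].len(), 3); // the scalar was broadcast to 3 rows
    Ok(())
}
```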
/// @@ -200,47 +186,141 @@ fn regexp_like_func(args: &[ArrayRef]) -> Result { /// # Ok(()) /// # } /// ``` -pub fn regexp_like(args: &[ArrayRef]) -> Result { +pub fn regexp_like(args: &[ArrayRef]) -> Result { match args.len() { - 2 => { - let values = as_generic_string_array::(&args[0])?; - let regex = as_generic_string_array::(&args[1])?; - let flags: Option<&GenericStringArray> = None; - let array = regexp::regexp_is_match(values, regex, flags) - .map_err(|e| arrow_datafusion_err!(e))?; - - Ok(Arc::new(array) as ArrayRef) - } + 2 => handle_regexp_like(&args[0], &args[1], None), 3 => { - let values = as_generic_string_array::(&args[0])?; - let regex = as_generic_string_array::(&args[1])?; - let flags = as_generic_string_array::(&args[2])?; + let flags = match args[2].data_type() { + Utf8 => args[2].as_string::(), + LargeUtf8 => { + let large_string_array = args[2].as_string::(); + let string_vec: Vec> = (0..large_string_array.len()).map(|i| { + if large_string_array.is_null(i) { + None + } else { + Some(large_string_array.value(i)) + } + }) + .collect(); + + &GenericStringArray::::from(string_vec) + }, + _ => { + let string_view_array = args[2].as_string_view(); + let string_vec: Vec> = (0..string_view_array.len()).map(|i| { + if string_view_array.is_null(i) { + None + } else { + Some(string_view_array.value(i).to_string()) + } + }) + .collect(); + &GenericStringArray::::from(string_vec) + }, + }; if flags.iter().any(|s| s == Some("g")) { return plan_err!("regexp_like() does not support the \"global\" option"); } - let array = regexp::regexp_is_match(values, regex, Some(flags)) - .map_err(|e| arrow_datafusion_err!(e))?; - - Ok(Arc::new(array) as ArrayRef) - } + handle_regexp_like(&args[0], &args[1], Some(flags)) + }, other => exec_err!( - "regexp_like was called with {other} arguments. It requires at least 2 and at most 3." + "`regexp_like` was called with {other} arguments. It requires at least 2 and at most 3." ), } } + +fn handle_regexp_like( + values: &ArrayRef, + patterns: &ArrayRef, + flags: Option<&GenericStringArray>, +) -> Result { + let array = match (values.data_type(), patterns.data_type()) { + (Utf8View, Utf8) => { + let value = values.as_string_view(); + let pattern = patterns.as_string::(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8View, Utf8View) => { + let value = values.as_string_view(); + let pattern = patterns.as_string_view(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8View, LargeUtf8) => { + let value = values.as_string_view(); + let pattern = patterns.as_string::(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8, Utf8) => { + let value = values.as_string::(); + let pattern = patterns.as_string::(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8, Utf8View) => { + let value = values.as_string::(); + let pattern = patterns.as_string_view(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8, LargeUtf8) => { + let value = values.as_string_view(); + let pattern = patterns.as_string::(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (LargeUtf8, Utf8) => { + let value = values.as_string::(); + let pattern = patterns.as_string::(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? 
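handle_regexp_like above enumerates the (value type, pattern type) pairs and downcasts each side with AsArray before calling the shared regexp_is_match kernel. The downcasting it relies on looks like this in isolation (a sketch over arrow's AsArray trait, not the full dispatch):

```rust
use arrow::array::{Array, AsArray, LargeStringArray, StringArray, StringViewArray};
use arrow::datatypes::DataType;

// Downcast a dynamically typed string array with AsArray, as each match arm
// of handle_regexp_like does above, then iterate its values uniformly.
fn collect_strings(array: &dyn Array) -> Vec<String> {
    match array.data_type() {
        DataType::Utf8 => array
            .as_string::<i32>()
            .iter()
            .flatten()
            .map(str::to_string)
            .collect(),
        DataType::LargeUtf8 => array
            .as_string::<i64>()
            .iter()
            .flatten()
            .map(str::to_string)
            .collect(),
        DataType::Utf8View => array
            .as_string_view()
            .iter()
            .flatten()
            .map(str::to_string)
            .collect(),
        _ => vec![],
    }
}

fn main() {
    assert_eq!(collect_strings(&StringArray::from(vec!["a"])), vec!["a"]);
    assert_eq!(collect_strings(&LargeStringArray::from(vec!["b"])), vec!["b"]);
    assert_eq!(collect_strings(&StringViewArray::from(vec!["c"])), vec!["c"]);
}
```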
+ } + (LargeUtf8, Utf8View) => { + let value = values.as_string::(); + let pattern = patterns.as_string_view(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (LargeUtf8, LargeUtf8) => { + let value = values.as_string::(); + let pattern = patterns.as_string::(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + other => { + return internal_err!( + "Unsupported data type {other:?} for function `regexp_like`" + ) + } + }; + + Ok(Arc::new(array) as ArrayRef) +} + #[cfg(test)] mod tests { use std::sync::Arc; - use arrow::array::BooleanBuilder; use arrow::array::StringArray; + use arrow::array::{BooleanBuilder, StringViewArray}; use crate::regex::regexplike::regexp_like; #[test] - fn test_case_sensitive_regexp_like() { + fn test_case_sensitive_regexp_like_utf8() { let values = StringArray::from(vec!["abc"; 5]); let patterns = @@ -254,13 +334,33 @@ mod tests { expected_builder.append_value(false); let expected = expected_builder.finish(); - let re = regexp_like::(&[Arc::new(values), Arc::new(patterns)]).unwrap(); + let re = regexp_like(&[Arc::new(values), Arc::new(patterns)]).unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_case_sensitive_regexp_like_utf8view() { + let values = StringViewArray::from(vec!["abc"; 5]); + + let patterns = + StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); + + let mut expected_builder: BooleanBuilder = BooleanBuilder::new(); + expected_builder.append_value(true); + expected_builder.append_value(false); + expected_builder.append_value(true); + expected_builder.append_value(false); + expected_builder.append_value(false); + let expected = expected_builder.finish(); + + let re = regexp_like(&[Arc::new(values), Arc::new(patterns)]).unwrap(); assert_eq!(re.as_ref(), &expected); } #[test] - fn test_case_insensitive_regexp_like() { + fn test_case_insensitive_regexp_like_utf8() { let values = StringArray::from(vec!["abc"; 5]); let patterns = StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); @@ -274,9 +374,29 @@ mod tests { expected_builder.append_value(false); let expected = expected_builder.finish(); - let re = - regexp_like::(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) - .unwrap(); + let re = regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_case_insensitive_regexp_like_utf8view() { + let values = StringViewArray::from(vec!["abc"; 5]); + let patterns = + StringViewArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); + let flags = StringArray::from(vec!["i"; 5]); + + let mut expected_builder: BooleanBuilder = BooleanBuilder::new(); + expected_builder.append_value(true); + expected_builder.append_value(true); + expected_builder.append_value(true); + expected_builder.append_value(true); + expected_builder.append_value(false); + let expected = expected_builder.finish(); + + let re = regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + .unwrap(); assert_eq!(re.as_ref(), &expected); } @@ -288,7 +408,7 @@ mod tests { let flags = StringArray::from(vec!["g"]); let re_err = - regexp_like::(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) .expect_err("unsupported flag should have failed"); assert_eq!( diff --git a/datafusion/functions/src/regex/regexpmatch.rs 
b/datafusion/functions/src/regex/regexpmatch.rs index 714e2c6a5339f..2b6d6a3d3eaac 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -26,11 +26,48 @@ use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; -use datafusion_expr::{ColumnarValue, TypeSignature}; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignature}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; +#[user_doc( + doc_section(label = "Regular Expression Functions"), + description = "Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string.", + syntax_example = "regexp_match(str, regexp[, flags])", + sql_example = r#"```sql + > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); + +---------------------------------------------------------+ + | regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | + +---------------------------------------------------------+ + | [Köln] | + +---------------------------------------------------------+ + SELECT regexp_match('aBc', '(b|d)', 'i'); + +---------------------------------------------------+ + | regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | + +---------------------------------------------------+ + | [B] | + +---------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +"#, + standard_argument(name = "str", prefix = "String"), + argument( + name = "regexp", + description = "Regular expression to match against. + Can be a constant, column, or function." + ), + argument( + name = "flags", + description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . 
to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"# + ) +)] #[derive(Debug)] pub struct RegexpMatchFunc { signature: Signature, @@ -79,10 +116,14 @@ impl ScalarUDFImpl for RegexpMatchFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(match &arg_types[0] { DataType::Null => DataType::Null, - other => DataType::List(Arc::new(Field::new("item", other.clone(), true))), + other => DataType::List(Arc::new(Field::new_list_field(other.clone(), true))), }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let len = args .iter() .fold(Option::::None, |acc, arg| match arg { @@ -94,7 +135,7 @@ impl ScalarUDFImpl for RegexpMatchFunc { let inferred_length = len.unwrap_or(1); let args = args .iter() - .map(|arg| arg.clone().into_array(inferred_length)) + .map(|arg| arg.to_array(inferred_length)) .collect::>>()?; let result = regexp_match_func(&args); @@ -106,7 +147,12 @@ impl ScalarUDFImpl for RegexpMatchFunc { result.map(ColumnarValue::Array) } } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } + fn regexp_match_func(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Utf8 => regexp_match::(args), diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 27800af347ea9..3d0cec0910999 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Regx expressions +//! Regex expressions use arrow::array::ArrayDataBuilder; use arrow::array::BufferBuilder; use arrow::array::GenericStringArray; @@ -34,12 +34,54 @@ use datafusion_common::{ use datafusion_expr::function::Hint; use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; use regex::Regex; use std::any::Any; use std::collections::HashMap; -use std::sync::Arc; -use std::sync::OnceLock; +use std::sync::{Arc, LazyLock}; + +#[user_doc( + doc_section(label = "Regular Expression Functions"), + description = "Replaces substrings in a string that match a [regular expression](https://docs.rs/regex/latest/regex/#syntax).", + syntax_example = "regexp_replace(str, regexp, replacement[, flags])", + sql_example = r#"```sql +> select regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); ++------------------------------------------------------------------------+ +| regexp_replace(Utf8("foobarbaz"),Utf8("b(..)"),Utf8("X\1Y"),Utf8("g")) | ++------------------------------------------------------------------------+ +| fooXarYXazY | ++------------------------------------------------------------------------+ +SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); ++-------------------------------------------------------------------+ +| regexp_replace(Utf8("aBc"),Utf8("(b|d)"),Utf8("Ab\1a"),Utf8("i")) | ++-------------------------------------------------------------------+ +| aAbBac | ++-------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +"#, + standard_argument(name = "str", prefix = "String"), + 
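regexp_match's return_type above wraps the input type in a List whose element field comes from Field::new_list_field, replacing the hand-written Field::new("item", ..). The same construction in isolation:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field};

// regexp_match returns one list of captured strings per input row, so the return
// type is List<input type>; new_list_field supplies the conventional "item" field.
fn regexp_match_return_type(arg: &DataType) -> DataType {
    match arg {
        DataType::Null => DataType::Null,
        other => DataType::List(Arc::new(Field::new_list_field(other.clone(), true))),
    }
}

fn main() {
    let t = regexp_match_return_type(&DataType::Utf8);
    assert_eq!(
        t,
        DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)))
    );
}
```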
argument( + name = "regexp", + description = "Regular expression to match against. + Can be a constant, column, or function." + ), + argument( + name = "replacement", + description = "Replacement string expression to operate on. Can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "flags", + description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: +- **g**: (global) Search globally and don't return after the first match +- **i**: case-insensitive: letters match both upper and lower case +- **m**: multi-line mode: ^ and $ match begin/end of line +- **s**: allow . to match \n +- **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used +- **U**: swap the meaning of x* and x*?"# + ) +)] #[derive(Debug)] pub struct RegexpReplaceFunc { signature: Signature, @@ -105,7 +147,11 @@ impl ScalarUDFImpl for RegexpReplaceFunc { } }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let len = args .iter() .fold(Option::::None, |acc, arg| match arg { @@ -123,6 +169,10 @@ impl ScalarUDFImpl for RegexpReplaceFunc { result.map(ColumnarValue::Array) } } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } } fn regexp_replace_func(args: &[ColumnarValue]) -> Result { @@ -139,11 +189,9 @@ fn regexp_replace_func(args: &[ColumnarValue]) -> Result { /// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) /// used by regexp_replace fn regex_replace_posix_groups(replacement: &str) -> String { - fn capture_groups_re() -> &'static Regex { - static CAPTURE_GROUPS_RE_LOCK: OnceLock = OnceLock::new(); - CAPTURE_GROUPS_RE_LOCK.get_or_init(|| Regex::new(r"(\\)(\d*)").unwrap()) - } - capture_groups_re() + static CAPTURE_GROUPS_RE_LOCK: LazyLock = + LazyLock::new(|| Regex::new(r"(\\)(\d*)").unwrap()); + CAPTURE_GROUPS_RE_LOCK .replace_all(replacement, "$${$2}") .into_owned() } @@ -526,7 +574,7 @@ pub fn specialize_regexp_replace( Hint::AcceptsSingular => 1, Hint::Pad => inferred_length, }; - arg.clone().into_array(expansion_len) + arg.to_array(expansion_len) }) .collect::>>()?; _regexp_replace_static_pattern_replace::(&args) @@ -537,7 +585,7 @@ pub fn specialize_regexp_replace( (_, _, _, _) => { let args = args .iter() - .map(|arg| arg.clone().into_array(inferred_length)) + .map(|arg| arg.to_array(inferred_length)) .collect::>>()?; match args[0].data_type() { diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 610366b6e6d90..1dd245de664fb 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -20,12 +20,33 @@ use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array}; use arrow::datatypes::DataType; use arrow::error::ArrowError; use datafusion_common::{internal_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns the Unicode character code of the first character in a string.", + syntax_example = "ascii(str)", + sql_example = r#"```sql +> select ascii('abc'); ++--------------------+ +| ascii(Utf8("abc")) | 
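The regexp_replace hunk above swaps the OnceLock-plus-helper pattern for std::sync::LazyLock, which compiles the static capture-group regex on first use. A standalone sketch of that pattern together with the \N to ${N} rewrite it drives (the pattern and replacement string are taken from the hunk above):

```rust
use std::sync::LazyLock;

use regex::Regex;

// Compiled once, on first access, from any thread.
static CAPTURE_GROUPS_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(\\)(\d*)").unwrap());

// Rewrite POSIX-style capture references (\1) into Rust regex syntax (${1}).
fn posix_groups_to_rust(replacement: &str) -> String {
    CAPTURE_GROUPS_RE.replace_all(replacement, "$${$2}").into_owned()
}

fn main() {
    assert_eq!(posix_groups_to_rust(r"X\1Y"), "X${1}Y");
}
```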
++--------------------+ +| 97 | ++--------------------+ +> select ascii('🚀'); ++-------------------+ +| ascii(Utf8("🚀")) | ++-------------------+ +| 128640 | ++-------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + related_udf(name = "chr") +)] #[derive(Debug)] pub struct AsciiFunc { signature: Signature, @@ -64,48 +85,19 @@ impl ScalarUDFImpl for AsciiFunc { Ok(Int32) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(ascii, vec![])(args) } fn documentation(&self) -> Option<&Documentation> { - Some(get_ascii_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_ascii_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description( - "Returns the Unicode character code of the first character in a string.", - ) - .with_syntax_example("ascii(str)") - .with_sql_example( - r#"```sql -> select ascii('abc'); -+--------------------+ -| ascii(Utf8("abc")) | -+--------------------+ -| 97 | -+--------------------+ -> select ascii('🚀'); -+-------------------+ -| ascii(Utf8("🚀")) | -+-------------------+ -| 128640 | -+-------------------+ -```"#, - ) - .with_standard_argument("str", "String") - .with_related_udf("chr") - .build() - .unwrap() - }) -} - fn calculate_ascii<'a, V>(array: V) -> Result where V: ArrayAccessor, @@ -155,7 +147,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( AsciiFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8($INPUT))], + vec![ColumnarValue::from(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, @@ -164,7 +156,7 @@ mod tests { test_function!( AsciiFunc::new(), - &[ColumnarValue::from(ScalarValue::LargeUtf8($INPUT))], + vec![ColumnarValue::from(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, i32, Int32, @@ -173,7 +165,7 @@ mod tests { test_function!( AsciiFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8View($INPUT))], + vec![ColumnarValue::from(ScalarValue::Utf8View($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 93d4fa25ae81b..b78f9cfec2cd5 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -18,14 +18,29 @@ use arrow::compute::kernels::length::bit_length; use arrow::datatypes::DataType; use std::any::Any; -use std::sync::OnceLock; use crate::utils::utf8_to_int_type; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns the bit length of a string.", + syntax_example = "bit_length(str)", + sql_example = r#"```sql +> select bit_length('datafusion'); ++--------------------------------+ +| bit_length(Utf8("datafusion")) | ++--------------------------------+ +| 80 | ++--------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + related_udf(name = "length"), + related_udf(name = "octet_length") +)] #[derive(Debug)] pub struct BitLengthFunc { signature: Signature, @@ -62,7 +77,11 @@ impl ScalarUDFImpl for BitLengthFunc { utf8_to_int_type(&arg_types[0], "bit_length") } - fn invoke(&self, args: 
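Despite its name, ascii above returns the Unicode code point of the first character, which is why the doc example yields 128640 for the rocket emoji. The per-value computation reduces to the sketch below (the 0 default for an empty input is an assumption here, not shown in the hunk):

```rust
// Code point of the first character, or 0 for an empty string; null handling
// and the arrow array plumbing used by calculate_ascii above are elided.
fn ascii_value(s: &str) -> i32 {
    s.chars().next().map_or(0, |c| c as i32)
}

fn main() {
    assert_eq!(ascii_value("abc"), 97);      // matches the SQL example above
    assert_eq!(ascii_value("🚀"), 128640);   // U+1F680
}
```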
&[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.len() != 1 { return exec_err!( "bit_length function requires 1 argument, got {}", @@ -79,38 +98,15 @@ impl ScalarUDFImpl for BitLengthFunc { ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::from(ScalarValue::Int64( v.as_ref().map(|x| (x.len() * 8) as i64), ))), - _ => unreachable!(), + ScalarValue::Utf8View(v) => Ok(ColumnarValue::from(ScalarValue::Int32( + v.as_ref().map(|x| (x.len() * 8) as i32), + ))), + _ => unreachable!("bit length"), }, } } fn documentation(&self) -> Option<&Documentation> { - Some(get_bit_length_doc()) + self.doc() } } - -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_bit_length_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Returns the bit length of a string.") - .with_syntax_example("bit_length(str)") - .with_sql_example( - r#"```sql -> select bit_length('datafusion'); -+--------------------------------+ -| bit_length(Utf8("datafusion")) | -+--------------------------------+ -| 80 | -+--------------------------------+ -```"#, - ) - .with_standard_argument("str", "String") - .with_related_udf("length") - .with_related_udf("octet_length") - .build() - .unwrap() - }) -} diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index 55687c2d1be77..1e3c27f4b496f 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -21,12 +21,11 @@ use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; use datafusion_expr::function::Hint; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; +use datafusion_macros::user_doc; use std::any::Any; -use std::sync::OnceLock; /// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. /// btrim('xyxtrimyyx', 'xyz') = 'trim' @@ -35,6 +34,28 @@ fn btrim(args: &[ArrayRef]) -> Result { general_trim::(args, TrimType::Both, use_string_view) } +#[user_doc( + doc_section(label = "String Functions"), + description = "Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string.", + syntax_example = "btrim(str[, trim_str])", + sql_example = r#"```sql +> select btrim('__datafusion____', '_'); ++-------------------------------------------+ +| btrim(Utf8("__datafusion____"),Utf8("_")) | ++-------------------------------------------+ +| datafusion | ++-------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument( + name = "trim_str", + description = r"String expression to operate on. Can be a constant, column, or function, and any combination of operators. 
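bit_length above is defined over the UTF-8 byte length rather than the character count, and the new scalar Utf8View branch returns Int32 like Utf8 while LargeUtf8 stays Int64. The arithmetic itself is simply len() * 8 on the encoded bytes:

```rust
// Bit length of a UTF-8 string is its encoded byte count times 8,
// which is what the scalar branches above compute with `x.len() * 8`.
fn bit_length(s: &str) -> i64 {
    (s.len() * 8) as i64
}

fn main() {
    assert_eq!(bit_length("datafusion"), 80); // matches the SQL example above
    assert_eq!(bit_length("é"), 16);          // 2 UTF-8 bytes, not 1 character
}
```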
_Default is whitespace characters._" + ), + alternative_syntax = "trim(BOTH trim_str FROM str)", + alternative_syntax = "trim(trim_str FROM str)", + related_udf(name = "ltrim"), + related_udf(name = "rtrim") +)] #[derive(Debug)] pub struct BTrimFunc { signature: Signature, @@ -80,7 +101,11 @@ impl ScalarUDFImpl for BTrimFunc { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { DataType::Utf8 | DataType::Utf8View => make_scalar_function( btrim::, @@ -102,35 +127,10 @@ impl ScalarUDFImpl for BTrimFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_btrim_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_btrim_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string.") - .with_syntax_example("btrim(str[, trim_str])") - .with_sql_example(r#"```sql -> select btrim('__datafusion____', '_'); -+-------------------------------------------+ -| btrim(Utf8("__datafusion____"),Utf8("_")) | -+-------------------------------------------+ -| datafusion | -+-------------------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_argument("trim_str", "String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._") - .with_related_udf("ltrim") - .with_related_udf("rtrim") - .build() - .unwrap() - }) -} - #[cfg(test)] mod tests { use arrow::array::{Array, StringArray, StringViewArray}; @@ -147,9 +147,9 @@ mod tests { // String view cases for checking normal logic test_function!( BTrimFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8View(Some( + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( String::from("alphabet ") - ))),], + )))], Ok(Some("alphabet")), &str, Utf8View, @@ -157,7 +157,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8View(Some( + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( String::from(" alphabet ") ))),], Ok(Some("alphabet")), @@ -167,7 +167,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -180,7 +180,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -193,7 +193,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -207,7 +207,7 @@ mod tests { // Special string view case for checking unlined output(len > 12) test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "xxxalphabetalphabetxxx" )))), @@ -221,7 +221,7 @@ mod tests { // String cases test_function!( BTrimFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + vec![ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "alphabet " )))),], Ok(Some("alphabet")), @@ -231,7 +231,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + vec![ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "alphabet " 
)))),], Ok(Some("alphabet")), @@ -241,7 +241,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::from(ScalarValue::Utf8(Some(String::from("t")))), ], @@ -252,7 +252,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::from(ScalarValue::Utf8(Some(String::from("alphabe")))), ], @@ -263,7 +263,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::from(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index ae0900af37d3d..e5f06d6213a2b 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::ArrayRef; use arrow::array::StringArray; @@ -27,9 +27,9 @@ use arrow::datatypes::DataType::Utf8; use crate::utils::make_scalar_function; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_macros::user_doc; /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' @@ -60,6 +60,21 @@ pub fn chr(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns the character with the specified ASCII or Unicode code value.", + syntax_example = "chr(expression)", + sql_example = r#"```sql +> select chr(128640); ++--------------------+ +| chr(Int64(128640)) | ++--------------------+ +| 🚀 | ++--------------------+ +```"#, + standard_argument(name = "expression", prefix = "String"), + related_udf(name = "ascii") +)] #[derive(Debug)] pub struct ChrFunc { signature: Signature, @@ -96,38 +111,15 @@ impl ScalarUDFImpl for ChrFunc { Ok(Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(chr, vec![])(args) } fn documentation(&self) -> Option<&Documentation> { - Some(get_chr_doc()) + self.doc() } } - -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_chr_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description( - "Returns the character with the specified ASCII or Unicode code value.", - ) - .with_syntax_example("chr(expression)") - .with_sql_example( - r#"```sql -> select chr(128640); -+--------------------+ -| chr(Int64(128640)) | -+--------------------+ -| 🚀 | -+--------------------+ -```"#, - ) - .with_standard_argument("expression", "String") - .with_related_udf("ascii") - .build() - .unwrap() - }) -} diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 1243c090f2479..ebbc018d7f3aa 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -20,12 +20,12 @@ use std::fmt::{Display, Formatter}; use std::sync::Arc; +use crate::strings::make_and_append_view; use arrow::array::{ 
- make_view, new_null_array, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, - ArrayRef, ByteView, GenericStringArray, GenericStringBuilder, LargeStringArray, - OffsetSizeTrait, StringArray, StringBuilder, StringViewArray, StringViewBuilder, + new_null_array, Array, ArrayRef, GenericStringArray, GenericStringBuilder, + OffsetSizeTrait, StringBuilder, StringViewArray, }; -use arrow::buffer::{Buffer, MutableBuffer, NullBuffer}; +use arrow::buffer::Buffer; use arrow::datatypes::DataType; use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; @@ -33,42 +33,6 @@ use datafusion_common::Result; use datafusion_common::{exec_err, ScalarValue}; use datafusion_expr::ColumnarValue; -/// Append a new view to the views buffer with the given substr -/// -/// # Safety -/// -/// original_view must be a valid view (the format described on -/// [`GenericByteViewArray`](arrow::array::GenericByteViewArray). -/// -/// # Arguments -/// - views_buffer: The buffer to append the new view to -/// - null_builder: The buffer to append the null value to -/// - original_view: The original view value -/// - substr: The substring to append. Must be a valid substring of the original view -/// - start_offset: The start offset of the substring in the view -pub(crate) fn make_and_append_view( - views_buffer: &mut Vec, - null_builder: &mut NullBufferBuilder, - original_view: &u128, - substr: &str, - start_offset: u32, -) { - let substr_len = substr.len(); - let sub_view = if substr_len > 12 { - let view = ByteView::from(*original_view); - make_view( - substr.as_bytes(), - view.buffer_index, - view.offset + start_offset, - ) - } else { - // inline value does not need block id or offset - make_view(substr.as_bytes(), 0, 0) - }; - views_buffer.push(sub_view); - null_builder.append_non_null(); -} - pub(crate) enum TrimType { Left, Right, @@ -97,7 +61,7 @@ pub(crate) fn general_trim( str::trim_start_matches::<&[char]>(input, pattern.as_ref()); // `ltrimmed_str` is actually `input`[start_offset..], // so `start_offset` = len(`input`) - len(`ltrimmed_str`) - let start_offset = input.as_bytes().len() - ltrimmed_str.as_bytes().len(); + let start_offset = input.len() - ltrimmed_str.len(); (ltrimmed_str, start_offset as u32) }, @@ -114,7 +78,7 @@ pub(crate) fn general_trim( str::trim_start_matches::<&[char]>(input, pattern.as_ref()); // `btrimmed_str` can be got by rtrim(ltrim(`input`)), // so its `start_offset` should be same as ltrim situation above - let start_offset = input.as_bytes().len() - ltrimmed_str.as_bytes().len(); + let start_offset = input.len() - ltrimmed_str.len(); let btrimmed_str = str::trim_end_matches::<&[char]>(ltrimmed_str, pattern.as_ref()); @@ -399,370 +363,6 @@ where } } -#[derive(Debug)] -pub(crate) enum ColumnarValueRef<'a> { - Scalar(&'a [u8]), - NullableArray(&'a StringArray), - NonNullableArray(&'a StringArray), - NullableLargeStringArray(&'a LargeStringArray), - NonNullableLargeStringArray(&'a LargeStringArray), - NullableStringViewArray(&'a StringViewArray), - NonNullableStringViewArray(&'a StringViewArray), -} - -impl<'a> ColumnarValueRef<'a> { - #[inline] - pub fn is_valid(&self, i: usize) -> bool { - match &self { - Self::Scalar(_) - | Self::NonNullableArray(_) - | Self::NonNullableLargeStringArray(_) - | Self::NonNullableStringViewArray(_) => true, - Self::NullableArray(array) => array.is_valid(i), - Self::NullableStringViewArray(array) => array.is_valid(i), - Self::NullableLargeStringArray(array) => array.is_valid(i), - } - 
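The general_trim change above drops the redundant as_bytes() calls: str::len already returns the byte length, so the offset of the trimmed slice inside the original string is just the difference of the two lengths, which is what the string-view path needs to rebuild a view over the same buffer. A small illustration of that offset arithmetic:

```rust
// For StringView arrays the trimmed result can reuse the original buffer;
// all that is needed is the byte offset where the trimmed slice starts.
fn ltrim_offset<'a>(input: &'a str, pattern: &[char]) -> (&'a str, usize) {
    let trimmed = input.trim_start_matches(pattern);
    // `trimmed` is a suffix of `input`, so its start offset is the length difference.
    (trimmed, input.len() - trimmed.len())
}

fn main() {
    let (trimmed, offset) = ltrim_offset("xyxtrimyyx", &['x', 'y', 'z']);
    assert_eq!(trimmed, "trimyyx");
    assert_eq!(offset, 3);
}
```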
} - - #[inline] - pub fn nulls(&self) -> Option { - match &self { - Self::Scalar(_) - | Self::NonNullableArray(_) - | Self::NonNullableStringViewArray(_) - | Self::NonNullableLargeStringArray(_) => None, - Self::NullableArray(array) => array.nulls().cloned(), - Self::NullableStringViewArray(array) => array.nulls().cloned(), - Self::NullableLargeStringArray(array) => array.nulls().cloned(), - } - } -} - -/// Abstracts iteration over different types of string arrays. -/// -/// The [`StringArrayType`] trait helps write generic code for string functions that can work with -/// different types of string arrays. -/// -/// Currently three types are supported: -/// - [`StringArray`] -/// - [`LargeStringArray`] -/// - [`StringViewArray`] -/// -/// It is inspired / copied from [arrow-rs]. -/// -/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/bf0ea9129e617e4a3cf915a900b747cc5485315f/arrow-string/src/like.rs#L151-L157 -/// -/// # Examples -/// Generic function that works for [`StringArray`], [`LargeStringArray`] -/// and [`StringViewArray`]: -/// ``` -/// # use arrow::array::{StringArray, LargeStringArray, StringViewArray}; -/// # use datafusion_functions::string::common::StringArrayType; -/// -/// /// Combines string values for any StringArrayType type. It can be invoked on -/// /// and combination of `StringArray`, `LargeStringArray` or `StringViewArray` -/// fn combine_values<'a, S1, S2>(array1: S1, array2: S2) -> Vec -/// where S1: StringArrayType<'a>, S2: StringArrayType<'a> -/// { -/// // iterate over the elements of the 2 arrays in parallel -/// array1 -/// .iter() -/// .zip(array2.iter()) -/// .map(|(s1, s2)| { -/// // if both values are non null, combine them -/// if let (Some(s1), Some(s2)) = (s1, s2) { -/// format!("{s1}{s2}") -/// } else { -/// "None".to_string() -/// } -/// }) -/// .collect() -/// } -/// -/// let string_array = StringArray::from(vec!["foo", "bar"]); -/// let large_string_array = LargeStringArray::from(vec!["foo2", "bar2"]); -/// let string_view_array = StringViewArray::from(vec!["foo3", "bar3"]); -/// -/// // can invoke this function a string array and large string array -/// assert_eq!( -/// combine_values(&string_array, &large_string_array), -/// vec![String::from("foofoo2"), String::from("barbar2")] -/// ); -/// -/// // Can call the same function with string array and string view array -/// assert_eq!( -/// combine_values(&string_array, &string_view_array), -/// vec![String::from("foofoo3"), String::from("barbar3")] -/// ); -/// ``` -/// -/// [`LargeStringArray`]: arrow::array::LargeStringArray -pub trait StringArrayType<'a>: ArrayAccessor + Sized { - /// Return an [`ArrayIter`] over the values of the array. - /// - /// This iterator iterates returns `Option<&str>` for each item in the array. - fn iter(&self) -> ArrayIter; - - /// Check if the array is ASCII only. - fn is_ascii(&self) -> bool; -} - -impl<'a, T: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray { - fn iter(&self) -> ArrayIter { - GenericStringArray::::iter(self) - } - - fn is_ascii(&self) -> bool { - GenericStringArray::::is_ascii(self) - } -} - -impl<'a> StringArrayType<'a> for &'a StringViewArray { - fn iter(&self) -> ArrayIter { - StringViewArray::iter(self) - } - - fn is_ascii(&self) -> bool { - StringViewArray::is_ascii(self) - } -} - -/// Optimized version of the StringBuilder in Arrow that: -/// 1. Precalculating the expected length of the result, avoiding reallocations. -/// 2. 
Avoids creating / incrementally creating a `NullBufferBuilder` -pub(crate) struct StringArrayBuilder { - offsets_buffer: MutableBuffer, - value_buffer: MutableBuffer, -} - -impl StringArrayBuilder { - pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { - let mut offsets_buffer = MutableBuffer::with_capacity( - (item_capacity + 1) * std::mem::size_of::(), - ); - // SAFETY: the first offset value is definitely not going to exceed the bounds. - unsafe { offsets_buffer.push_unchecked(0_i32) }; - Self { - offsets_buffer, - value_buffer: MutableBuffer::with_capacity(data_capacity), - } - } - - pub fn write( - &mut self, - column: &ColumnarValueRef, - i: usize, - ) { - match column { - ColumnarValueRef::Scalar(s) => { - self.value_buffer.extend_from_slice(s); - } - ColumnarValueRef::NullableArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - ColumnarValueRef::NullableLargeStringArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - ColumnarValueRef::NullableStringViewArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - ColumnarValueRef::NonNullableArray(array) => { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - ColumnarValueRef::NonNullableLargeStringArray(array) => { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - ColumnarValueRef::NonNullableStringViewArray(array) => { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - } - - pub fn append_offset(&mut self) { - let next_offset: i32 = self - .value_buffer - .len() - .try_into() - .expect("byte array offset overflow"); - unsafe { self.offsets_buffer.push_unchecked(next_offset) }; - } - - pub fn finish(self, null_buffer: Option) -> StringArray { - let array_builder = ArrayDataBuilder::new(DataType::Utf8) - .len(self.offsets_buffer.len() / std::mem::size_of::() - 1) - .add_buffer(self.offsets_buffer.into()) - .add_buffer(self.value_buffer.into()) - .nulls(null_buffer); - // SAFETY: all data that was appended was valid UTF8 and the values - // and offsets were created correctly - let array_data = unsafe { array_builder.build_unchecked() }; - StringArray::from(array_data) - } -} - -pub(crate) struct StringViewArrayBuilder { - builder: StringViewBuilder, - block: String, -} - -impl StringViewArrayBuilder { - pub fn with_capacity(_item_capacity: usize, data_capacity: usize) -> Self { - let builder = StringViewBuilder::with_capacity(data_capacity); - Self { - builder, - block: String::new(), - } - } - - pub fn write( - &mut self, - column: &ColumnarValueRef, - i: usize, - ) { - match column { - ColumnarValueRef::Scalar(s) => { - self.block.push_str(std::str::from_utf8(s).unwrap()); - } - ColumnarValueRef::NullableArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.block.push_str( - std::str::from_utf8(array.value(i).as_bytes()).unwrap(), - ); - } - } - ColumnarValueRef::NullableLargeStringArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.block.push_str( - std::str::from_utf8(array.value(i).as_bytes()).unwrap(), - ); - } - } - ColumnarValueRef::NullableStringViewArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.block.push_str( - std::str::from_utf8(array.value(i).as_bytes()).unwrap(), - ); - } - } - ColumnarValueRef::NonNullableArray(array) => { - 
self.block - .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); - } - ColumnarValueRef::NonNullableLargeStringArray(array) => { - self.block - .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); - } - ColumnarValueRef::NonNullableStringViewArray(array) => { - self.block - .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); - } - } - } - - pub fn append_offset(&mut self) { - self.builder.append_value(&self.block); - self.block = String::new(); - } - - pub fn finish(mut self) -> StringViewArray { - self.builder.finish() - } -} - -pub(crate) struct LargeStringArrayBuilder { - offsets_buffer: MutableBuffer, - value_buffer: MutableBuffer, -} - -impl LargeStringArrayBuilder { - pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { - let mut offsets_buffer = MutableBuffer::with_capacity( - (item_capacity + 1) * std::mem::size_of::(), - ); - // SAFETY: the first offset value is definitely not going to exceed the bounds. - unsafe { offsets_buffer.push_unchecked(0_i64) }; - Self { - offsets_buffer, - value_buffer: MutableBuffer::with_capacity(data_capacity), - } - } - - pub fn write( - &mut self, - column: &ColumnarValueRef, - i: usize, - ) { - match column { - ColumnarValueRef::Scalar(s) => { - self.value_buffer.extend_from_slice(s); - } - ColumnarValueRef::NullableArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - ColumnarValueRef::NullableLargeStringArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - ColumnarValueRef::NullableStringViewArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - ColumnarValueRef::NonNullableArray(array) => { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - ColumnarValueRef::NonNullableLargeStringArray(array) => { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - ColumnarValueRef::NonNullableStringViewArray(array) => { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - } - - pub fn append_offset(&mut self) { - let next_offset: i64 = self - .value_buffer - .len() - .try_into() - .expect("byte array offset overflow"); - unsafe { self.offsets_buffer.push_unchecked(next_offset) }; - } - - pub fn finish(self, null_buffer: Option) -> LargeStringArray { - let array_builder = ArrayDataBuilder::new(DataType::LargeUtf8) - .len(self.offsets_buffer.len() / std::mem::size_of::() - 1) - .add_buffer(self.offsets_buffer.into()) - .add_buffer(self.value_buffer.into()) - .nulls(null_buffer); - // SAFETY: all data that was appended was valid Large UTF8 and the values - // and offsets were created correctly - let array_data = unsafe { array_builder.build_unchecked() }; - LargeStringArray::from(array_data) - } -} - fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result where O: OffsetSizeTrait, diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 7a2426606b7c1..4490d02c24e05 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -17,19 +17,41 @@ use arrow::array::{as_largestring_array, Array}; use arrow::datatypes::DataType; +use datafusion_expr::sort_properties::ExprProperties; use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; -use 
crate::string::common::*; use crate::string::concat; +use crate::strings::{ + ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder, +}; use datafusion_common::cast::{as_string_array, as_string_view_array}; use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Concatenates multiple strings together.", + syntax_example = "concat(str[, ..., str_n])", + sql_example = r#"```sql +> select concat('data', 'f', 'us', 'ion'); ++-------------------------------------------------------+ +| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) | ++-------------------------------------------------------+ +| datafusion | ++-------------------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument( + name = "str_n", + description = "Subsequent string expressions to concatenate." + ), + related_udf(name = "concat_ws") +)] #[derive(Debug)] pub struct ConcatFunc { signature: Signature, @@ -46,7 +68,7 @@ impl ConcatFunc { use DataType::*; Self { signature: Signature::variadic( - vec![Utf8, Utf8View, LargeUtf8], + vec![Utf8View, Utf8, LargeUtf8], Volatility::Immutable, ), } @@ -83,13 +105,17 @@ impl ScalarUDFImpl for ConcatFunc { /// Concatenates the text representations of all the arguments. NULL arguments are ignored. /// concat('abcde', 2, NULL, 22) = 'abcde222' - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let mut return_datatype = DataType::Utf8; args.iter().for_each(|col| { - if col.data_type() == &DataType::Utf8View { + if *col.data_type() == DataType::Utf8View { return_datatype = col.data_type().clone(); } - if col.data_type() == &DataType::LargeUtf8 + if *col.data_type() == DataType::LargeUtf8 && return_datatype != DataType::Utf8View { return_datatype = col.data_type().clone(); @@ -108,10 +134,17 @@ impl ScalarUDFImpl for ConcatFunc { if array_len.is_none() { let mut result = String::new(); for arg in args { - if let ColumnarValue::Scalar(scalar) = arg { - if let ScalarValue::Utf8(Some(v)) = scalar.value() { - result.push_str(v); - } + let ColumnarValue::Scalar(scalar) = arg else { + return internal_err!("concat expected scalar value, got {arg:?}"); + }; + + match scalar.value().try_as_str() { + Some(Some(v)) => result.push_str(v), + Some(None) => {} // null literal + None => plan_err!( + "Concat function does not support scalar type {:?}", + scalar + )?, } } @@ -161,7 +194,7 @@ impl ScalarUDFImpl for ConcatFunc { ColumnarValueRef::NonNullableArray(string_array) }; columns.push(column); - }, + } DataType::LargeUtf8 => { let string_array = as_largestring_array(array); @@ -172,7 +205,7 @@ impl ScalarUDFImpl for ConcatFunc { ColumnarValueRef::NonNullableLargeStringArray(string_array) }; columns.push(column); - }, + } DataType::Utf8View => { let string_array = as_string_view_array(array)?; @@ -183,7 +216,7 @@ impl ScalarUDFImpl for ConcatFunc { ColumnarValueRef::NonNullableStringViewArray(string_array) }; columns.push(column); - }, + } other => { return plan_err!("Input was {other} which is not a supported 
datatype for concat function") } @@ -250,62 +283,66 @@ impl ScalarUDFImpl for ConcatFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_concat_doc()) + self.doc() } -} - -static DOCUMENTATION: OnceLock = OnceLock::new(); -fn get_concat_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Concatenates multiple strings together.") - .with_syntax_example("concat(str[, ..., str_n])") - .with_sql_example( - r#"```sql -> select concat('data', 'f', 'us', 'ion'); -+-------------------------------------------------------+ -| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) | -+-------------------------------------------------------+ -| datafusion | -+-------------------------------------------------------+ -```"#, - ) - .with_standard_argument("str", "String") - .with_argument("str_n", "Subsequent string expressions to concatenate.") - .with_related_udf("concat_ws") - .build() - .unwrap() - }) + fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result { + Ok(true) + } } pub fn simplify_concat(args: Vec) -> Result { let mut new_args = Vec::with_capacity(args.len()); let mut contiguous_scalar = "".to_string(); + let return_type = { + let data_types: Vec<_> = args + .iter() + .filter_map(|expr| match expr { + Expr::Literal(l) => Some(l.data_type().clone()), + _ => None, + }) + .collect(); + ConcatFunc::new().return_type(&data_types) + }?; + for arg in args.clone() { match arg { Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Utf8(None) => {} + ScalarValue::LargeUtf8(None) => {} + ScalarValue::Utf8View(None) => {} + + // filter out `null` args + // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. + // Concatenate it with the `contiguous_scalar`. + ScalarValue::Utf8(Some(v)) => { + contiguous_scalar += v; + } + ScalarValue::LargeUtf8(Some(v)) => { + contiguous_scalar += v; + } + ScalarValue::Utf8View(Some(v)) => { + contiguous_scalar += v; + } - // filter out `null` args - ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None) => {} - // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. - // Concatenate it with the `contiguous_scalar`. - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)) - => contiguous_scalar += v, - x => { - return internal_err!( - "The scalar {x} should be casted to string type during the type coercion." - ) - } + x => { + return internal_err!( + "The scalar {x} should be casted to string type during the type coercion." + ) + } } // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` (if it is not empty) and reset it to empty string. // Then pushing this arg to the `new_args`. 
arg => { if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); + match return_type { + DataType::Utf8 => new_args.push(lit(contiguous_scalar)), + DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), + DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), + _ => unreachable!(), + } contiguous_scalar = "".to_string(); } new_args.push(arg); @@ -314,7 +351,16 @@ pub fn simplify_concat(args: Vec) -> Result { } if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); + match return_type { + DataType::Utf8 => new_args.push(lit(contiguous_scalar)), + DataType::LargeUtf8 => { + new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))) + } + DataType::Utf8View => { + new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))) + } + _ => unreachable!(), + } } if !args.eq(&new_args) { @@ -341,7 +387,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("aa")), ColumnarValue::from(ScalarValue::from("bb")), ColumnarValue::from(ScalarValue::from("cc")), @@ -353,7 +399,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("aa")), ColumnarValue::from(ScalarValue::Utf8(None)), ColumnarValue::from(ScalarValue::from("cc")), @@ -365,7 +411,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8(None))], + vec![ColumnarValue::from(ScalarValue::Utf8(None))], Ok(Some("")), &str, Utf8, @@ -373,7 +419,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("aa")), ColumnarValue::from(ScalarValue::Utf8View(None)), ColumnarValue::from(ScalarValue::LargeUtf8(None)), @@ -386,7 +432,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("aa")), ColumnarValue::from(ScalarValue::LargeUtf8(None)), ColumnarValue::from(ScalarValue::from("cc")), @@ -396,6 +442,17 @@ mod tests { LargeUtf8, LargeStringArray ); + test_function!( + ConcatFunc::new(), + vec![ + ColumnarValue::from(ScalarValue::Utf8View(Some("aa".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aacc")), + &str, + Utf8View, + StringViewArray + ); Ok(()) } @@ -410,11 +467,19 @@ mod tests { None, Some("z"), ]))); - let args = &[c0, c1, c2]; + let c3 = ColumnarValue::from(ScalarValue::Utf8View(Some(",".to_string()))); + let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![ + Some("a"), + None, + Some("b"), + ]))); + let args = &[c0, c1, c2, c3, c4]; - let result = ConcatFunc::new().invoke(args)?; + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch + let result = ConcatFunc::new().invoke_batch(args, 3)?; let expected = - Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef; + Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"])) + as ArrayRef; match &result { ColumnarValue::Array(array) => { assert_eq!(&expected, array); diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 26e77ff593afc..32b3705deea3c 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -17,21 +17,47 @@ use arrow::array::{as_largestring_array, Array, StringArray}; use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::datatypes::DataType; -use 
crate::string::common::*; use crate::string::concat::simplify_concat; use crate::string::concat_ws; +use crate::strings::{ColumnarValueRef, StringArrayBuilder}; use datafusion_common::cast::{as_string_array, as_string_view_array}; use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Concatenates multiple strings together with a specified separator.", + syntax_example = "concat_ws(separator, str[, ..., str_n])", + sql_example = r#"```sql +> select concat_ws('_', 'data', 'fusion'); ++--------------------------------------------------+ +| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) | ++--------------------------------------------------+ +| data_fusion | ++--------------------------------------------------+ +```"#, + argument( + name = "separator", + description = "Separator to insert between concatenated strings." + ), + argument( + name = "str", + description = "String expression to operate on. Can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "str_n", + description = "Subsequent string expressions to concatenate." + ), + related_udf(name = "concat") +)] #[derive(Debug)] pub struct ConcatWsFunc { signature: Signature, @@ -75,7 +101,11 @@ impl ScalarUDFImpl for ConcatWsFunc { /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { // do not accept 0 arguments. 
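As a standalone illustration of the all-scalar fast path implemented just below (plain std types only, not DataFusion's `ColumnarValue`): the first argument is the separator, null string literals are skipped, and the remaining values are joined with the separator.

```rust
// Minimal sketch of the concat_ws scalar fast path, assuming the separator
// and all arguments have already been resolved to Option<&str>.
fn concat_ws_scalar(sep: &str, args: &[Option<&str>]) -> String {
    let mut result = String::new();
    // Skip NULL arguments entirely; they do not contribute a separator.
    let mut values = args.iter().copied().flatten();
    if let Some(first) = values.next() {
        result.push_str(first);
    }
    for value in values {
        result.push_str(sep);
        result.push_str(value);
    }
    result
}

fn main() {
    // concat_ws(',', 'abcde', NULL, '22') = 'abcde,22'
    assert_eq!(
        concat_ws_scalar(",", &[Some("abcde"), None, Some("22")]),
        "abcde,22"
    );
}
```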
if args.len() < 2 { return exec_err!( @@ -94,57 +124,54 @@ impl ScalarUDFImpl for ConcatWsFunc { // Scalar if array_len.is_none() { - let sep = match &args[0] { - ColumnarValue::Scalar(scalar) => match scalar.value() { - ScalarValue::Utf8(Some(s)) - | ScalarValue::Utf8View(Some(s)) - | ScalarValue::LargeUtf8(Some(s)) => s, - ScalarValue::Utf8(None) - | ScalarValue::Utf8View(None) - | ScalarValue::LargeUtf8(None) => { - return Ok(ColumnarValue::from(ScalarValue::Utf8(None))); - } - _ => unreachable!(), - }, - _ => unreachable!(), + let ColumnarValue::Scalar(scalar) = &args[0] else { + // loop above checks for all args being scalar + unreachable!() + }; + let sep = match scalar.value().try_as_str() { + Some(Some(s)) => s, + Some(None) => { + // null literal string + return Ok(ColumnarValue::from(ScalarValue::Utf8(None))); + } + None => return internal_err!("Expected string literal, got {scalar:?}"), }; let mut result = String::new(); - let iter = &mut args[1..].iter(); - - for arg in iter.by_ref() { - match arg { - ColumnarValue::Scalar(scalar) => match scalar.value() { - ScalarValue::Utf8(Some(s)) - | ScalarValue::Utf8View(Some(s)) - | ScalarValue::LargeUtf8(Some(s)) => { - result.push_str(s); - break; - } - ScalarValue::Utf8(None) - | ScalarValue::Utf8View(None) - | ScalarValue::LargeUtf8(None) => {} - _ => unreachable!(), - }, - _ => unreachable!(), + // iterator over Option + let iter = &mut args[1..].iter().map(|arg| { + let ColumnarValue::Scalar(scalar) = arg else { + // loop above checks for all args being scalar + unreachable!() + }; + scalar.value().try_as_str() + }); + + // append first non null arg + for scalar in iter.by_ref() { + match scalar { + Some(Some(s)) => { + result.push_str(s); + break; + } + Some(None) => {} // null literal string + None => { + return internal_err!("Expected string literal, got {scalar:?}") + } } } - for arg in iter.by_ref() { - match arg { - ColumnarValue::Scalar(scalar) => match scalar.value() { - ScalarValue::Utf8(Some(s)) - | ScalarValue::Utf8View(Some(s)) - | ScalarValue::LargeUtf8(Some(s)) => { - result.push_str(sep); - result.push_str(s); - } - ScalarValue::Utf8(None) - | ScalarValue::Utf8View(None) - | ScalarValue::LargeUtf8(None) => {} - _ => unreachable!(), - }, - _ => unreachable!(), + // handle subsequent non null args + for scalar in iter.by_ref() { + match scalar { + Some(Some(s)) => { + result.push_str(sep); + result.push_str(s); + } + Some(None) => {} // null literal string + None => { + return internal_err!("Expected string literal, got {scalar:?}") + } } } @@ -282,45 +309,10 @@ impl ScalarUDFImpl for ConcatWsFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_concat_ws_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_concat_ws_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description( - "Concatenates multiple strings together with a specified separator.", - ) - .with_syntax_example("concat_ws(separator, str[, ..., str_n])") - .with_sql_example( - r#"```sql -> select concat_ws('_', 'data', 'fusion'); -+--------------------------------------------------+ -| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) | -+--------------------------------------------------+ -| data_fusion | -+--------------------------------------------------+ -```"#, - ) - .with_argument( - "separator", - "Separator to insert between concatenated strings.", - ) - .with_standard_argument("str", "String") - 
.with_standard_argument( - "str_n", - "Subsequent string expressions to concatenate.", - ) - .with_related_udf("concat") - .build() - .unwrap() - }) -} - fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { Expr::Literal(scalar) => match scalar.value() { @@ -423,7 +415,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("|")), ColumnarValue::from(ScalarValue::from("aa")), ColumnarValue::from(ScalarValue::from("bb")), @@ -436,7 +428,7 @@ mod tests { ); test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("|")), ColumnarValue::from(ScalarValue::Utf8(None)), ], @@ -447,7 +439,7 @@ mod tests { ); test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(None)), ColumnarValue::from(ScalarValue::from("aa")), ColumnarValue::from(ScalarValue::from("bb")), @@ -460,7 +452,7 @@ mod tests { ); test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("|")), ColumnarValue::from(ScalarValue::from("aa")), ColumnarValue::from(ScalarValue::Utf8(None)), @@ -488,7 +480,8 @@ mod tests { ]))); let args = &[c0, c1, c2]; - let result = ConcatWsFunc::new().invoke(args)?; + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch + let result = ConcatWsFunc::new().invoke_batch(args, 3)?; let expected = Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; match &result { @@ -513,7 +506,8 @@ mod tests { ]))); let args = &[c0, c1, c2]; - let result = ConcatWsFunc::new().invoke(args)?; + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch + let result = ConcatWsFunc::new().invoke_batch(args, 3)?; let expected = Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")])) as ArrayRef; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 0f75731aa1c3f..27d9be9157b4f 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -16,20 +16,35 @@ // under the License. 
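The `contains` hunk below swaps the regex-based kernel for a literal, case-sensitive substring match. A plain-std illustration of those semantics (not the arrow kernel itself), mirroring the new unit test's use of regex metacharacters:

```rust
// Illustration only: regex metacharacters such as `?` and `(` are not
// treated specially, and matching is case-sensitive.
fn main() {
    let haystack = "xxx?()";
    assert!(haystack.contains("x?("));  // literal substring match succeeds
    assert!(!haystack.contains("X?(")); // case-sensitive, so this fails
}
```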
use crate::utils::make_scalar_function; -use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray}; -use arrow::compute::regexp_is_match; +use arrow::array::{Array, ArrayRef, AsArray}; +use arrow::compute::contains as arrow_contains; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View}; use datafusion_common::exec_err; use datafusion_common::DataFusionError; use datafusion_common::Result; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Return true if search_str is found within string (case-sensitive).", + syntax_example = "contains(str, search_str)", + sql_example = r#"```sql +> select contains('the quick brown fox', 'row'); ++---------------------------------------------------+ +| contains(Utf8("the quick brown fox"),Utf8("row")) | ++---------------------------------------------------+ +| true | ++---------------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument(name = "search_str", description = "The string to search for in str.") +)] #[derive(Debug)] pub struct ContainsFunc { signature: Signature, @@ -66,76 +81,38 @@ impl ScalarUDFImpl for ContainsFunc { Ok(Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(contains, vec![])(args) } fn documentation(&self) -> Option<&Documentation> { - Some(get_contains_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_contains_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description( - "Return true if search_str is found within string (case-sensitive).", - ) - .with_syntax_example("contains(str, search_str)") - .with_sql_example( - r#"```sql -> select contains('the quick brown fox', 'row'); -+---------------------------------------------------+ -| contains(Utf8("the quick brown fox"),Utf8("row")) | -+---------------------------------------------------+ -| true | -+---------------------------------------------------+ -```"#, - ) - .with_standard_argument("str", "String") - .with_argument("search_str", "The string to search for in str.") - .build() - .unwrap() - }) -} - -/// use regexp_is_match_utf8_scalar to do the calculation for contains +/// use `arrow::compute::contains` to do the calculation for contains pub fn contains(args: &[ArrayRef]) -> Result { match (args[0].data_type(), args[1].data_type()) { (Utf8View, Utf8View) => { let mod_str = args[0].as_string_view(); let match_str = args[1].as_string_view(); - let res = regexp_is_match::< - StringViewArray, - StringViewArray, - GenericStringArray, - >(mod_str, match_str, None)?; - + let res = arrow_contains(mod_str, match_str)?; Ok(Arc::new(res) as ArrayRef) } (Utf8, Utf8) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string::(); - let res = regexp_is_match::< - GenericStringArray, - GenericStringArray, - GenericStringArray, - >(mod_str, match_str, None)?; - + let res = arrow_contains(mod_str, match_str)?; Ok(Arc::new(res) as ArrayRef) } (LargeUtf8, LargeUtf8) => { let mod_str = args[0].as_string::(); let 
match_str = args[1].as_string::(); - let res = regexp_is_match::< - GenericStringArray, - GenericStringArray, - GenericStringArray, - >(mod_str, match_str, None)?; - + let res = arrow_contains(mod_str, match_str)?; Ok(Arc::new(res) as ArrayRef) } other => { @@ -143,3 +120,32 @@ pub fn contains(args: &[ArrayRef]) -> Result { } } } + +#[cfg(test)] +mod test { + use super::ContainsFunc; + use arrow::array::{BooleanArray, StringArray}; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use std::sync::Arc; + + #[test] + fn test_contains_udf() { + let udf = ContainsFunc::new(); + let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("xxx?()"), + Some("yyy?()"), + ]))); + let scalar = ColumnarValue::from(ScalarValue::Utf8(Some("x?(".to_string()))); + #[allow(deprecated)] // TODO migrate UDF to invoke + let actual = udf.invoke_batch(&[array, scalar], 2).unwrap(); + let expect = ColumnarValue::Array(Arc::new(BooleanArray::from(vec![ + Some(true), + Some(false), + ]))); + assert_eq!( + *actual.into_array(2).unwrap(), + *expect.into_array(2).unwrap() + ); + } +} diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index 42d673cf7b39c..ddd21beae4d18 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -16,17 +16,38 @@ // under the License. use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::ArrayRef; use arrow::datatypes::DataType; use crate::utils::make_scalar_function; use datafusion_common::{internal_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Tests if a string ends with a substring.", + syntax_example = "ends_with(str, substr)", + sql_example = r#"```sql +> select ends_with('datafusion', 'soin'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("soin")) | ++--------------------------------------------+ +| false | ++--------------------------------------------+ +> select ends_with('datafusion', 'sion'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("sion")) | ++--------------------------------------------+ +| true | ++--------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument(name = "substr", description = "Substring to test for.") +)] #[derive(Debug)] pub struct EndsWithFunc { signature: Signature, @@ -63,7 +84,11 @@ impl ScalarUDFImpl for EndsWithFunc { Ok(DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { make_scalar_function(ends_with, vec![])(args) @@ -75,41 +100,10 @@ impl ScalarUDFImpl for EndsWithFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_ends_with_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_ends_with_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Tests if a string ends with a substring.") - 
.with_syntax_example("ends_with(str, substr)") - .with_sql_example( - r#"```sql -> select ends_with('datafusion', 'soin'); -+--------------------------------------------+ -| ends_with(Utf8("datafusion"),Utf8("soin")) | -+--------------------------------------------+ -| false | -+--------------------------------------------+ -> select ends_with('datafusion', 'sion'); -+--------------------------------------------+ -| ends_with(Utf8("datafusion"),Utf8("sion")) | -+--------------------------------------------+ -| true | -+--------------------------------------------+ -```"#, - ) - .with_standard_argument("str", "String") - .with_argument("substr", "Substring to test for.") - .build() - .unwrap() - }) -} - /// Returns true if string ends with suffix. /// ends_with('alphabet', 'abet') = 't' pub fn ends_with(args: &[ArrayRef]) -> Result { @@ -134,7 +128,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from("alph")), ], @@ -145,7 +139,7 @@ mod tests { ); test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from("bet")), ], @@ -156,7 +150,7 @@ mod tests { ); test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(None)), ColumnarValue::from(ScalarValue::from("alph")), ], @@ -167,7 +161,7 @@ mod tests { ); test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs index 558e71239f84e..57392c114d79e 100644 --- a/datafusion/functions/src/string/levenshtein.rs +++ b/datafusion/functions/src/string/levenshtein.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; use arrow::datatypes::DataType; @@ -25,10 +25,31 @@ use crate::utils::{make_scalar_function, utf8_to_int_type}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::utils::datafusion_strsim; use datafusion_common::{exec_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings.", + syntax_example = "levenshtein(str1, str2)", + sql_example = r#"```sql +> select levenshtein('kitten', 'sitting'); ++---------------------------------------------+ +| levenshtein(Utf8("kitten"),Utf8("sitting")) | ++---------------------------------------------+ +| 3 | ++---------------------------------------------+ +```"#, + argument( + name = "str1", + description = "String expression to compute Levenshtein distance with str2." + ), + argument( + name = "str2", + description = "String expression to compute Levenshtein distance with str1." 
+ ) +)] #[derive(Debug)] pub struct LevenshteinFunc { signature: Signature, @@ -65,7 +86,11 @@ impl ScalarUDFImpl for LevenshteinFunc { utf8_to_int_type(&arg_types[0], "levenshtein") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { DataType::Utf8View | DataType::Utf8 => { make_scalar_function(levenshtein::, vec![])(args) @@ -78,33 +103,10 @@ impl ScalarUDFImpl for LevenshteinFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_levenshtein_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_levenshtein_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings.") - .with_syntax_example("levenshtein(str1, str2)") - .with_sql_example(r#"```sql -> select levenshtein('kitten', 'sitting'); -+---------------------------------------------+ -| levenshtein(Utf8("kitten"),Utf8("sitting")) | -+---------------------------------------------+ -| 3 | -+---------------------------------------------+ -```"#) - .with_argument("str1", "String expression to compute Levenshtein distance with str2.") - .with_argument("str2", "String expression to compute Levenshtein distance with str1.") - .build() - .unwrap() - }) -} - ///Returns the Levenshtein distance between the two given strings. /// LEVENSHTEIN('kitten', 'sitting') = 3 pub fn levenshtein(args: &[ArrayRef]) -> Result { diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index f82b11ca90512..e90c3804b1eea 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -17,15 +17,30 @@ use arrow::datatypes::DataType; use std::any::Any; -use std::sync::OnceLock; use crate::string::common::to_lower; use crate::utils::utf8_to_str_type; use datafusion_common::Result; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Converts a string to lower-case.", + syntax_example = "lower(str)", + sql_example = r#"```sql +> select lower('Ångström'); ++-------------------------+ +| lower(Utf8("Ångström")) | ++-------------------------+ +| ångström | ++-------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + related_udf(name = "initcap"), + related_udf(name = "upper") +)] #[derive(Debug)] pub struct LowerFunc { signature: Signature, @@ -62,52 +77,33 @@ impl ScalarUDFImpl for LowerFunc { utf8_to_str_type(&arg_types[0], "lower") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { to_lower(args, "lower") } fn documentation(&self) -> Option<&Documentation> { - Some(get_lower_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_lower_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Converts a string to lower-case.") - .with_syntax_example("lower(str)") - .with_sql_example( - r#"```sql -> select lower('Ångström'); 
-+-------------------------+ -| lower(Utf8("Ångström")) | -+-------------------------+ -| ångström | -+-------------------------+ -```"#, - ) - .with_standard_argument("str", "String") - .with_related_udf("initcap") - .with_related_udf("upper") - .build() - .unwrap() - }) -} #[cfg(test)] mod tests { use super::*; - use arrow::array::{ArrayRef, StringArray}; + use arrow::array::{Array, ArrayRef, StringArray}; use std::sync::Arc; fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = LowerFunc::new(); + let batch_len = input.len(); let args = vec![ColumnarValue::Array(input)]; - let result = match func.invoke(&args)? { + #[allow(deprecated)] // TODO migrate UDF to invoke + let result = match func.invoke_batch(&args, batch_len)? { ColumnarValue::Array(result) => result, - _ => unreachable!(), + _ => unreachable!("lower"), }; assert_eq!(&expected, &result); Ok(()) diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index 5bf6abaa8cbe3..7c7d58598cbec 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -18,15 +18,14 @@ use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; use std::any::Any; -use std::sync::OnceLock; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::function::Hint; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_macros::user_doc; /// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. /// ltrim('zzzytest', 'xyz') = 'test' @@ -35,6 +34,33 @@ fn ltrim(args: &[ArrayRef]) -> Result { general_trim::(args, TrimType::Left, use_string_view) } +#[user_doc( + doc_section(label = "String Functions"), + description = "Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string.", + syntax_example = "ltrim(str[, trim_str])", + sql_example = r#"```sql +> select ltrim(' datafusion '); ++-------------------------------+ +| ltrim(Utf8(" datafusion ")) | ++-------------------------------+ +| datafusion | ++-------------------------------+ +> select ltrim('___datafusion___', '_'); ++-------------------------------------------+ +| ltrim(Utf8("___datafusion___"),Utf8("_")) | ++-------------------------------------------+ +| datafusion___ | ++-------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument( + name = "trim_str", + description = r"String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. 
_Default is whitespace characters._" + ), + alternative_syntax = "trim(LEADING trim_str FROM str)", + related_udf(name = "btrim"), + related_udf(name = "rtrim") +)] #[derive(Debug)] pub struct LtrimFunc { signature: Signature, @@ -78,7 +104,11 @@ impl ScalarUDFImpl for LtrimFunc { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { DataType::Utf8 | DataType::Utf8View => make_scalar_function( ltrim::, @@ -96,41 +126,10 @@ impl ScalarUDFImpl for LtrimFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_ltrim_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_ltrim_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string.") - .with_syntax_example("ltrim(str[, trim_str])") - .with_sql_example(r#"```sql -> select ltrim(' datafusion '); -+-------------------------------+ -| ltrim(Utf8(" datafusion ")) | -+-------------------------------+ -| datafusion | -+-------------------------------+ -> select ltrim('___datafusion___', '_'); -+-------------------------------------------+ -| ltrim(Utf8("___datafusion___"),Utf8("_")) | -+-------------------------------------------+ -| datafusion___ | -+-------------------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_argument("trim_str", "String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. 
_Default is whitespace characters._") - .with_related_udf("btrim") - .with_related_udf("rtrim") - .build() - .unwrap() - }) -} - #[cfg(test)] mod tests { use arrow::array::{Array, StringArray, StringViewArray}; @@ -147,7 +146,7 @@ mod tests { // String view cases for checking normal logic test_function!( LtrimFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8View(Some( + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( String::from("alphabet ") ))),], Ok(Some("alphabet ")), @@ -157,7 +156,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8View(Some( + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( String::from(" alphabet ") ))),], Ok(Some("alphabet ")), @@ -167,7 +166,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -180,7 +179,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -193,7 +192,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -207,7 +206,7 @@ mod tests { // Special string view case for checking unlined output(len > 12) test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "xxxalphabetalphabet" )))), @@ -221,7 +220,7 @@ mod tests { // String cases test_function!( LtrimFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + vec![ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "alphabet " )))),], Ok(Some("alphabet ")), @@ -231,7 +230,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + vec![ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "alphabet " )))),], Ok(Some("alphabet ")), @@ -241,7 +240,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::from(ScalarValue::Utf8(Some(String::from("t")))), ], @@ -252,7 +251,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::from(ScalarValue::Utf8(Some(String::from("alphabe")))), ], @@ -263,7 +262,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::from(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 622802f0142bc..442c055ac37d6 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -30,7 +30,6 @@ pub mod concat; pub mod concat_ws; pub mod contains; pub mod ends_with; -pub mod initcap; pub mod levenshtein; pub mod lower; pub mod ltrim; @@ -45,28 +44,27 @@ pub mod to_hex; pub mod upper; pub mod uuid; // create UDFs -make_udf_function!(ascii::AsciiFunc, ASCII, ascii); -make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length); -make_udf_function!(btrim::BTrimFunc, BTRIM, btrim); -make_udf_function!(chr::ChrFunc, CHR, chr); -make_udf_function!(concat::ConcatFunc, CONCAT, concat); -make_udf_function!(concat_ws::ConcatWsFunc, CONCAT_WS, concat_ws); -make_udf_function!(ends_with::EndsWithFunc, ENDS_WITH, ends_with); -make_udf_function!(initcap::InitcapFunc, INITCAP, initcap); 
-make_udf_function!(levenshtein::LevenshteinFunc, LEVENSHTEIN, levenshtein); -make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); -make_udf_function!(lower::LowerFunc, LOWER, lower); -make_udf_function!(octet_length::OctetLengthFunc, OCTET_LENGTH, octet_length); -make_udf_function!(overlay::OverlayFunc, OVERLAY, overlay); -make_udf_function!(repeat::RepeatFunc, REPEAT, repeat); -make_udf_function!(replace::ReplaceFunc, REPLACE, replace); -make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim); -make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); -make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part); -make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); -make_udf_function!(upper::UpperFunc, UPPER, upper); -make_udf_function!(uuid::UuidFunc, UUID, uuid); -make_udf_function!(contains::ContainsFunc, CONTAINS, contains); +make_udf_function!(ascii::AsciiFunc, ascii); +make_udf_function!(bit_length::BitLengthFunc, bit_length); +make_udf_function!(btrim::BTrimFunc, btrim); +make_udf_function!(chr::ChrFunc, chr); +make_udf_function!(concat::ConcatFunc, concat); +make_udf_function!(concat_ws::ConcatWsFunc, concat_ws); +make_udf_function!(ends_with::EndsWithFunc, ends_with); +make_udf_function!(levenshtein::LevenshteinFunc, levenshtein); +make_udf_function!(ltrim::LtrimFunc, ltrim); +make_udf_function!(lower::LowerFunc, lower); +make_udf_function!(octet_length::OctetLengthFunc, octet_length); +make_udf_function!(overlay::OverlayFunc, overlay); +make_udf_function!(repeat::RepeatFunc, repeat); +make_udf_function!(replace::ReplaceFunc, replace); +make_udf_function!(rtrim::RtrimFunc, rtrim); +make_udf_function!(starts_with::StartsWithFunc, starts_with); +make_udf_function!(split_part::SplitPartFunc, split_part); +make_udf_function!(to_hex::ToHexFunc, to_hex); +make_udf_function!(upper::UpperFunc, upper); +make_udf_function!(uuid::UuidFunc, uuid); +make_udf_function!(contains::ContainsFunc, contains); pub mod expr_fn { use datafusion_expr::Expr; @@ -94,10 +92,6 @@ pub mod expr_fn { ends_with, "Returns true if the `string` ends with the `suffix`, false otherwise.", string suffix - ),( - initcap, - "Converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase", - string ),( levenshtein, "Returns the Levenshtein distance between the two given strings", @@ -151,7 +145,7 @@ pub mod expr_fn { "returns uuid v4 as a string value", ), ( contains, - "Return true if search_string is found within string. 
treated it like a reglike", + "Return true if search_string is found within string.", )); #[doc = "Removes all characters, spaces by default, from both sides of a string"] @@ -177,7 +171,6 @@ pub fn functions() -> Vec> { concat(), concat_ws(), ends_with(), - initcap(), levenshtein(), lower(), ltrim(), diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index 26ce70a416d03..40f7fd3ed9881 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -18,14 +18,29 @@ use arrow::compute::kernels::length::length; use arrow::datatypes::DataType; use std::any::Any; -use std::sync::OnceLock; use crate::utils::utf8_to_int_type; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns the length of a string in bytes.", + syntax_example = "octet_length(str)", + sql_example = r#"```sql +> select octet_length('Ångström'); ++--------------------------------+ +| octet_length(Utf8("Ångström")) | ++--------------------------------+ +| 10 | ++--------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + related_udf(name = "bit_length"), + related_udf(name = "length") +)] #[derive(Debug)] pub struct OctetLengthFunc { signature: Signature, @@ -62,7 +77,11 @@ impl ScalarUDFImpl for OctetLengthFunc { utf8_to_int_type(&arg_types[0], "octet_length") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { if args.len() != 1 { return exec_err!( "octet_length function requires 1 argument, got {}", @@ -82,42 +101,16 @@ impl ScalarUDFImpl for OctetLengthFunc { ScalarValue::Utf8View(v) => Ok(ColumnarValue::from(ScalarValue::Int32( v.as_ref().map(|x| x.len() as i32), ))), - _ => unreachable!(), + _ => unreachable!("OctetLengthFunc"), }, } } fn documentation(&self) -> Option<&Documentation> { - Some(get_octet_length_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_octet_length_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Returns the length of a string in bytes.") - .with_syntax_example("octet_length(str)") - .with_sql_example( - r#"```sql -> select octet_length('Ångström'); -+--------------------------------+ -| octet_length(Utf8("Ångström")) | -+--------------------------------+ -| 10 | -+--------------------------------+ -```"#, - ) - .with_standard_argument("str", "String") - .with_related_udf("bit_length") - .with_related_udf("length") - .build() - .unwrap() - }) -} - #[cfg(test)] mod tests { use std::sync::Arc; @@ -136,7 +129,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( OctetLengthFunc::new(), - &[ColumnarValue::from(ScalarValue::Int32(Some(12)))], + vec![ColumnarValue::from(ScalarValue::Int32(Some(12)))], exec_err!( "The OCTET_LENGTH function can only accept strings, but got Int32." 
), @@ -146,7 +139,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Array(Arc::new(StringArray::from(vec![ + vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![ String::from("chars"), String::from("chars2"), ])))], @@ -157,7 +150,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("chars")))), ColumnarValue::from(ScalarValue::Utf8(Some(String::from("chars")))) ], @@ -168,7 +161,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + vec![ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "chars" ))))], Ok(Some(5)), @@ -178,7 +171,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + vec![ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "josé" ))))], Ok(Some(5)), @@ -188,7 +181,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + vec![ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "" ))))], Ok(Some(0)), @@ -198,7 +191,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8(None))], + vec![ColumnarValue::from(ScalarValue::Utf8(None))], Ok(None), i32, Int32, @@ -206,7 +199,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8View(Some( + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( String::from("joséjoséjoséjosé") )))], Ok(Some(20)), @@ -216,7 +209,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8View(Some( + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( String::from("josé") )))], Ok(Some(5)), @@ -226,7 +219,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8View(Some( + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( String::from("") )))], Ok(Some(0)), diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index 3b31bc360851a..3389da0968f7c 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -16,7 +16,7 @@ // under the License. 
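The `overlay` semantics documented in the hunk that follows (1-based `pos`, replace `count` characters with `substr`) can be sketched in plain Rust. This is an illustration of the documented behaviour, not the DataFusion kernel:

```rust
// Character-based sketch of overlay(str PLACING substr FROM pos [FOR count]).
fn overlay(s: &str, substr: &str, pos: usize, count: usize) -> String {
    let chars: Vec<char> = s.chars().collect();
    // `pos` is 1-based; clamp to the string bounds.
    let start = pos.saturating_sub(1).min(chars.len());
    let end = (start + count).min(chars.len());
    let mut out: String = chars[..start].iter().collect();
    out.push_str(substr);          // inserted replacement
    out.extend(&chars[end..]);     // remainder after the replaced range
    out
}

fn main() {
    // overlay('Txxxxas' placing 'hom' from 2 for 4) = 'Thomas'
    assert_eq!(overlay("Txxxxas", "hom", 2, 4), "Thomas");
}
```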
use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::DataType; @@ -26,10 +26,33 @@ use datafusion_common::cast::{ as_generic_string_array, as_int64_array, as_string_view_array, }; use datafusion_common::{exec_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns the string which is replaced by another string from the specified position and specified count length.", + syntax_example = "overlay(str PLACING substr FROM pos [FOR count])", + sql_example = r#"```sql +> select overlay('Txxxxas' placing 'hom' from 2 for 4); ++--------------------------------------------------------+ +| overlay(Utf8("Txxxxas"),Utf8("hom"),Int64(2),Int64(4)) | ++--------------------------------------------------------+ +| Thomas | ++--------------------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument(name = "substr", description = "Substring to replace in str."), + argument( + name = "pos", + description = "The start position to start the replace in str." + ), + argument( + name = "count", + description = "The count of characters to be replaced from start position of str. If not specified, will use substr length instead." + ) +)] #[derive(Debug)] pub struct OverlayFunc { signature: Signature, @@ -77,7 +100,11 @@ impl ScalarUDFImpl for OverlayFunc { utf8_to_str_type(&arg_types[0], "overlay") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { DataType::Utf8View | DataType::Utf8 => { make_scalar_function(overlay::, vec![])(args) @@ -88,35 +115,10 @@ impl ScalarUDFImpl for OverlayFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_overlay_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_overlay_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Returns the string which is replaced by another string from the specified position and specified count length.") - .with_syntax_example("overlay(str PLACING substr FROM pos [FOR count])") - .with_sql_example(r#"```sql -> select overlay('Txxxxas' placing 'hom' from 2 for 4); -+--------------------------------------------------------+ -| overlay(Utf8("Txxxxas"),Utf8("hom"),Int64(2),Int64(4)) | -+--------------------------------------------------------+ -| Thomas | -+--------------------------------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_argument("substr", "Substring to replace in str.") - .with_argument("pos", "The start position to start the replace in str.") - .with_argument("count", "The count of characters to be replaced from start position of str. If not specified, will use substr length instead.") - .build() - .unwrap() - }) -} - macro_rules! 
process_overlay { // For the three-argument case ($string_array:expr, $characters_array:expr, $pos_num:expr) => {{ diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 50ca641c12d7e..1f5abde495ca8 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -16,22 +16,41 @@ // under the License. use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; -use crate::string::common::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, - OffsetSizeTrait, StringViewArray, + OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Int64, LargeUtf8, Utf8, Utf8View}; +use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::cast::as_int64_array; +use datafusion_common::types::{logical_int64, logical_string}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; -use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; - +use datafusion_expr_common::signature::TypeSignatureClass; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns a string with an input string repeated a specified number.", + syntax_example = "repeat(str, n)", + sql_example = r#"```sql +> select repeat('data', 3); ++-------------------------------+ +| repeat(Utf8("data"),Int64(3)) | ++-------------------------------+ +| datadatadata | ++-------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument( + name = "n", + description = "Number of times to repeat the input string." + ) +)] #[derive(Debug)] pub struct RepeatFunc { signature: Signature, @@ -46,14 +65,10 @@ impl Default for RepeatFunc { impl RepeatFunc { pub fn new() -> Self { Self { - signature: Signature::one_of( + signature: Signature::coercible( vec![ - // Planner attempts coercion to the target type starting with the most preferred candidate. - // For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`. - // If that fails, it proceeds to `(Utf8, Int64)`. 
- TypeSignature::Exact(vec![Utf8View, Int64]), - TypeSignature::Exact(vec![Utf8, Int64]), - TypeSignature::Exact(vec![LargeUtf8, Int64]), + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Native(logical_int64()), ], Volatility::Immutable, ), @@ -78,42 +93,19 @@ impl ScalarUDFImpl for RepeatFunc { utf8_to_str_type(&arg_types[0], "repeat") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(repeat, vec![])(args) } fn documentation(&self) -> Option<&Documentation> { - Some(get_repeat_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_repeat_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description( - "Returns a string with an input string repeated a specified number.", - ) - .with_syntax_example("repeat(str, n)") - .with_sql_example( - r#"```sql -> select repeat('data', 3); -+-------------------------------+ -| repeat(Utf8("data"),Int64(3)) | -+-------------------------------+ -| datadatadata | -+-------------------------------+ -```"#, - ) - .with_standard_argument("str", "String") - .with_argument("n", "Number of times to repeat the input string.") - .build() - .unwrap() - }) -} - /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' fn repeat(args: &[ArrayRef]) -> Result { @@ -175,7 +167,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("Pg")))), ColumnarValue::from(ScalarValue::Int64(Some(4))), ], @@ -186,7 +178,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(None)), ColumnarValue::from(ScalarValue::Int64(Some(4))), ], @@ -197,7 +189,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("Pg")))), ColumnarValue::from(ScalarValue::Int64(None)), ], @@ -209,7 +201,7 @@ mod tests { test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from("Pg")))), ColumnarValue::from(ScalarValue::Int64(Some(4))), ], @@ -220,7 +212,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(None)), ColumnarValue::from(ScalarValue::Int64(Some(4))), ], @@ -231,7 +223,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from("Pg")))), ColumnarValue::from(ScalarValue::Int64(None)), ], diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index e2ece6cb384f6..cce5aff5c8a79 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -16,7 +16,7 @@ // under the License. 
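The repeat hunk just above also illustrates the signature migration used in this patch: three `TypeSignature::Exact` candidates collapse into a single `Signature::coercible` over logical types. A sketch of the before/after shapes, using the same APIs as the diff (the helper function names are illustrative only):

```rust
use arrow::datatypes::DataType::{Int64, LargeUtf8, Utf8, Utf8View};
use datafusion_common::types::{logical_int64, logical_string};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datafusion_expr_common::signature::TypeSignatureClass;

/// Old shape: every accepted combination of concrete Arrow types is listed,
/// and the planner tries the candidates in order of preference.
fn repeat_signature_exact() -> Signature {
    Signature::one_of(
        vec![
            TypeSignature::Exact(vec![Utf8View, Int64]),
            TypeSignature::Exact(vec![Utf8, Int64]),
            TypeSignature::Exact(vec![LargeUtf8, Int64]),
        ],
        Volatility::Immutable,
    )
}

/// New shape: one coercible signature over logical types.
fn repeat_signature_coercible() -> Signature {
    Signature::coercible(
        vec![
            TypeSignatureClass::Native(logical_string()),
            TypeSignatureClass::Native(logical_int64()),
        ],
        Volatility::Immutable,
    )
}
```

The coercible form leaves it to the planner to map Utf8, LargeUtf8 and Utf8View (and any integer it can coerce to Int64) onto the declared logical types, rather than enumerating concrete combinations by hand.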
use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray}; use arrow::datatypes::DataType; @@ -24,10 +24,28 @@ use arrow::datatypes::DataType; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; - +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Replaces all occurrences of a specified substring in a string with a new substring.", + syntax_example = "replace(str, substr, replacement)", + sql_example = r#"```sql +> select replace('ABabbaBA', 'ab', 'cd'); ++-------------------------------------------------+ +| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) | ++-------------------------------------------------+ +| ABcdbaBA | ++-------------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + standard_argument( + name = "substr", + prefix = "Substring expression to replace in the input string. Substring" + ), + standard_argument(name = "replacement", prefix = "Replacement substring") +)] #[derive(Debug)] pub struct ReplaceFunc { signature: Signature, @@ -64,7 +82,11 @@ impl ScalarUDFImpl for ReplaceFunc { utf8_to_str_type(&arg_types[0], "replace") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { DataType::Utf8 => make_scalar_function(replace::, vec![])(args), DataType::LargeUtf8 => make_scalar_function(replace::, vec![])(args), @@ -76,34 +98,10 @@ impl ScalarUDFImpl for ReplaceFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_replace_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_replace_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Replaces all occurrences of a specified substring in a string with a new substring.") - .with_syntax_example("replace(str, substr, replacement)") - .with_sql_example(r#"```sql -> select replace('ABabbaBA', 'ab', 'cd'); -+-------------------------------------------------+ -| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) | -+-------------------------------------------------+ -| ABcdbaBA | -+-------------------------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_standard_argument("substr", "Substring expression to replace in the input string. 
Substring expression") - .with_standard_argument("replacement", "Replacement substring") - .build() - .unwrap() - }) -} - fn replace_view(args: &[ArrayRef]) -> Result { let string_array = as_string_view_array(&args[0])?; let from_array = as_string_view_array(&args[1])?; @@ -154,7 +152,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( ReplaceFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))), ColumnarValue::from(ScalarValue::Utf8(Some(String::from("bb")))), ColumnarValue::from(ScalarValue::Utf8(Some(String::from("ccc")))), @@ -167,7 +165,7 @@ mod tests { test_function!( ReplaceFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::LargeUtf8(Some(String::from("aabbb")))), ColumnarValue::from(ScalarValue::LargeUtf8(Some(String::from("bbb")))), ColumnarValue::from(ScalarValue::LargeUtf8(Some(String::from("cc")))), @@ -180,7 +178,7 @@ mod tests { test_function!( ReplaceFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from("aabbbcw")))), ColumnarValue::from(ScalarValue::Utf8View(Some(String::from("bb")))), ColumnarValue::from(ScalarValue::Utf8View(Some(String::from("cc")))), diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index 1ba20bedf6421..1799b45c8c875 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -18,15 +18,14 @@ use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; use std::any::Any; -use std::sync::OnceLock; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::function::Hint; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_macros::user_doc; /// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. /// rtrim('testxxzx', 'xyz') = 'test' @@ -35,6 +34,33 @@ fn rtrim(args: &[ArrayRef]) -> Result { general_trim::(args, TrimType::Right, use_string_view) } +#[user_doc( + doc_section(label = "String Functions"), + description = "Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string.", + syntax_example = "rtrim(str[, trim_str])", + alternative_syntax = "trim(TRAILING trim_str FROM str)", + sql_example = r#"```sql +> select rtrim(' datafusion '); ++-------------------------------+ +| rtrim(Utf8(" datafusion ")) | ++-------------------------------+ +| datafusion | ++-------------------------------+ +> select rtrim('___datafusion___', '_'); ++-------------------------------------------+ +| rtrim(Utf8("___datafusion___"),Utf8("_")) | ++-------------------------------------------+ +| ___datafusion | ++-------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument( + name = "trim_str", + description = "String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. 
_Default is whitespace characters._" + ), + related_udf(name = "btrim"), + related_udf(name = "ltrim") +)] #[derive(Debug)] pub struct RtrimFunc { signature: Signature, @@ -78,7 +104,11 @@ impl ScalarUDFImpl for RtrimFunc { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { DataType::Utf8 | DataType::Utf8View => make_scalar_function( rtrim::, @@ -96,41 +126,10 @@ impl ScalarUDFImpl for RtrimFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_rtrim_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_rtrim_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string.") - .with_syntax_example("rtrim(str[, trim_str])") - .with_sql_example(r#"```sql -> select rtrim(' datafusion '); -+-------------------------------+ -| rtrim(Utf8(" datafusion ")) | -+-------------------------------+ -| datafusion | -+-------------------------------+ -> select rtrim('___datafusion___', '_'); -+-------------------------------------------+ -| rtrim(Utf8("___datafusion___"),Utf8("_")) | -+-------------------------------------------+ -| ___datafusion | -+-------------------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_argument("trim_str", "String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._") - .with_related_udf("btrim") - .with_related_udf("ltrim") - .build() - .unwrap() - }) -} - #[cfg(test)] mod tests { use arrow::array::{Array, StringArray, StringViewArray}; @@ -147,7 +146,7 @@ mod tests { // String view cases for checking normal logic test_function!( RtrimFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8View(Some( + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -157,7 +156,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8View(Some( + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( String::from(" alphabet ") ))),], Ok(Some(" alphabet")), @@ -167,7 +166,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -180,7 +179,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -193,7 +192,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -207,7 +206,7 @@ mod tests { // Special string view case for checking unlined output(len > 12) test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabetalphabetxxx" )))), @@ -221,7 +220,7 @@ mod tests { // String cases test_function!( RtrimFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + vec![ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "alphabet " )))),], Ok(Some("alphabet")), @@ -231,7 +230,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + 
vec![ColumnarValue::from(ScalarValue::Utf8(Some(String::from( " alphabet " )))),], Ok(Some(" alphabet")), @@ -241,7 +240,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::from(ScalarValue::Utf8(Some(String::from("t ")))), ], @@ -252,7 +251,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::from(ScalarValue::Utf8(Some(String::from("alphabe")))), ], @@ -263,7 +262,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::from(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 5fedc5b172375..088c5cccdecd0 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -17,21 +17,36 @@ use crate::utils::utf8_to_str_type; use arrow::array::{ - ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringViewArray, + ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringArrayType, + StringViewArray, }; use arrow::array::{AsArray, GenericStringBuilder}; use arrow::datatypes::DataType; use datafusion_common::cast::as_int64_array; use datafusion_common::ScalarValue; use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_macros::user_doc; use std::any::Any; -use std::sync::{Arc, OnceLock}; - -use super::common::StringArrayType; +use std::sync::Arc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Splits a string based on a specified delimiter and returns the substring in the specified position.", + syntax_example = "split_part(str, delimiter, pos)", + sql_example = r#"```sql +> select split_part('1.2.3.4.5', '.', 3); ++--------------------------------------------------+ +| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) | ++--------------------------------------------------+ +| 3 | ++--------------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument(name = "delimiter", description = "String or character to split on."), + argument(name = "pos", description = "Position of the part to return.") +)] #[derive(Debug)] pub struct SplitPartFunc { signature: Signature, @@ -82,7 +97,11 @@ impl ScalarUDFImpl for SplitPartFunc { utf8_to_str_type(&arg_types[0], "split_part") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { // First, determine if any of the arguments is an Array let len = args.iter().find_map(|arg| match arg { ColumnarValue::Array(a) => Some(a.len()), @@ -179,34 +198,10 @@ impl ScalarUDFImpl for SplitPartFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_split_part_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_split_part_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Splits a string based on a specified delimiter and returns the substring in the specified position.") - 
.with_syntax_example("split_part(str, delimiter, pos)") - .with_sql_example(r#"```sql -> select split_part('1.2.3.4.5', '.', 3); -+--------------------------------------------------+ -| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) | -+--------------------------------------------------+ -| 3 | -+--------------------------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_argument("delimiter", "String or character to split on.") - .with_argument("pos", "Position of the part to return.") - .build() - .unwrap() - }) -} - /// impl pub fn split_part_impl<'a, StringArrType, DelimiterArrType, StringArrayLen>( string_array: StringArrType, @@ -268,7 +263,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), @@ -282,7 +277,7 @@ mod tests { ); test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), @@ -296,7 +291,7 @@ mod tests { ); test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), @@ -310,7 +305,7 @@ mod tests { ); test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index a7917036eb283..c4ad132d76916 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -16,16 +16,16 @@ // under the License. use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::ArrayRef; use arrow::datatypes::DataType; use crate::utils::make_scalar_function; use datafusion_common::{internal_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; /// Returns true if string starts with prefix. 
/// starts_with('alphabet', 'alph') = 't' @@ -34,6 +34,21 @@ pub fn starts_with(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } +#[user_doc( + doc_section(label = "String Functions"), + description = "Tests if a string starts with a substring.", + syntax_example = "starts_with(str, substr)", + sql_example = r#"```sql +> select starts_with('datafusion','data'); ++----------------------------------------------+ +| starts_with(Utf8("datafusion"),Utf8("data")) | ++----------------------------------------------+ +| true | ++----------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument(name = "substr", description = "Substring to test for.") +)] #[derive(Debug)] pub struct StartsWithFunc { signature: Signature, @@ -70,7 +85,11 @@ impl ScalarUDFImpl for StartsWithFunc { Ok(DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { make_scalar_function(starts_with, vec![])(args) @@ -80,35 +99,10 @@ impl ScalarUDFImpl for StartsWithFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_starts_with_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_starts_with_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Tests if a string starts with a substring.") - .with_syntax_example("starts_with(str, substr)") - .with_sql_example( - r#"```sql -> select starts_with('datafusion','data'); -+----------------------------------------------+ -| starts_with(Utf8("datafusion"),Utf8("data")) | -+----------------------------------------------+ -| true | -+----------------------------------------------+ -```"#, - ) - .with_standard_argument("str", "String") - .with_argument("substr", "Substring to test for.") - .build() - .unwrap() - }) -} - #[cfg(test)] mod tests { use crate::utils::test::test_function; @@ -155,7 +149,7 @@ mod tests { for (args, expected) in test_cases { test_function!( StartsWithFunc::new(), - &args, + args, Ok(expected), bool, Boolean, diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 72cd4fbffa332..64654ef6ef106 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::{ @@ -27,9 +27,10 @@ use crate::utils::make_scalar_function; use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; + use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; /// Converts the number to its equivalent hexadecimal representation. 
/// to_hex(2147483647) = '7fffffff' @@ -59,6 +60,20 @@ where Ok(Arc::new(result) as ArrayRef) } +#[user_doc( + doc_section(label = "String Functions"), + description = "Converts an integer to a hexadecimal string.", + syntax_example = "to_hex(int)", + sql_example = r#"```sql +> select to_hex(12345689); ++-------------------------+ +| to_hex(Int64(12345689)) | ++-------------------------+ +| bc6159 | ++-------------------------+ +```"#, + standard_argument(name = "int", prefix = "Integer") +)] #[derive(Debug)] pub struct ToHexFunc { signature: Signature, @@ -103,7 +118,11 @@ impl ScalarUDFImpl for ToHexFunc { }) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { DataType::Int32 => make_scalar_function(to_hex::, vec![])(args), DataType::Int64 => make_scalar_function(to_hex::, vec![])(args), @@ -112,34 +131,10 @@ impl ScalarUDFImpl for ToHexFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_to_hex_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_to_hex_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Converts an integer to a hexadecimal string.") - .with_syntax_example("to_hex(int)") - .with_sql_example( - r#"```sql -> select to_hex(12345689); -+-------------------------+ -| to_hex(Int64(12345689)) | -+-------------------------+ -| bc6159 | -+-------------------------+ -```"#, - ) - .with_standard_argument("int", "Integer") - .build() - .unwrap() - }) -} - #[cfg(test)] mod tests { use arrow::array::{Int32Array, StringArray}; diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index bfcb2a86994d0..7bab33e68a4d6 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -19,12 +19,27 @@ use crate::string::common::to_upper; use crate::utils::utf8_to_str_type; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; use std::any::Any; -use std::sync::OnceLock; +#[user_doc( + doc_section(label = "String Functions"), + description = "Converts a string to upper-case.", + syntax_example = "upper(str)", + sql_example = r#"```sql +> select upper('dataFusion'); ++---------------------------+ +| upper(Utf8("dataFusion")) | ++---------------------------+ +| DATAFUSION | ++---------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + related_udf(name = "initcap"), + related_udf(name = "lower") +)] #[derive(Debug)] pub struct UpperFunc { signature: Signature, @@ -61,53 +76,33 @@ impl ScalarUDFImpl for UpperFunc { utf8_to_str_type(&arg_types[0], "upper") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { to_upper(args, "upper") } fn documentation(&self) -> Option<&Documentation> { - Some(get_upper_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_upper_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Converts a string to upper-case.") - 
.with_syntax_example("upper(str)") - .with_sql_example( - r#"```sql -> select upper('dataFusion'); -+---------------------------+ -| upper(Utf8("dataFusion")) | -+---------------------------+ -| DATAFUSION | -+---------------------------+ -```"#, - ) - .with_standard_argument("str", "String") - .with_related_udf("initcap") - .with_related_udf("lower") - .build() - .unwrap() - }) -} - #[cfg(test)] mod tests { use super::*; - use arrow::array::{ArrayRef, StringArray}; + use arrow::array::{Array, ArrayRef, StringArray}; use std::sync::Arc; fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = UpperFunc::new(); + let batch_len = input.len(); let args = vec![ColumnarValue::Array(input)]; - let result = match func.invoke(&args)? { + #[allow(deprecated)] // TODO migrate UDF to invoke + let result = match func.invoke_batch(&args, batch_len)? { ColumnarValue::Array(result) => result, - _ => unreachable!(), + _ => unreachable!("upper"), }; assert_eq!(&expected, &result); Ok(()) diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index 0fbdce16ccd13..f6d6a941068d6 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -16,18 +16,31 @@ // under the License. use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::GenericStringArray; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Utf8; use uuid::Uuid; -use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_common::{internal_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns [`UUID v4`](https://en.wikipedia.org/wiki/Universally_unique_identifier#Version_4_(random)) string value which is unique per row.", + syntax_example = "uuid()", + sql_example = r#"```sql +> select uuid(); ++--------------------------------------+ +| uuid() | ++--------------------------------------+ +| 6ec17ef8-1934-41cc-8d59-d0c8f9eea1f0 | ++--------------------------------------+ +```"# +)] #[derive(Debug)] pub struct UuidFunc { signature: Signature, @@ -64,40 +77,22 @@ impl ScalarUDFImpl for UuidFunc { Ok(Utf8) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { - not_impl_err!("{} function does not accept arguments", self.name()) - } - /// Prints random (v4) uuid values per row /// uuid() = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11' - fn invoke_no_args(&self, num_rows: usize) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + num_rows: usize, + ) -> Result { + if !args.is_empty() { + return internal_err!("{} function does not accept arguments", self.name()); + } let values = std::iter::repeat_with(|| Uuid::new_v4().to_string()).take(num_rows); let array = GenericStringArray::::from_iter_values(values); Ok(ColumnarValue::Array(Arc::new(array))) } fn documentation(&self) -> Option<&Documentation> { - Some(get_uuid_doc()) + self.doc() } } - -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_uuid_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Returns [`UUID v4`](https://en.wikipedia.org/wiki/Universally_unique_identifier#Version_4_(random)) string value which is unique per row.") - .with_syntax_example("uuid()") - 
.with_sql_example(r#"```sql -> select uuid(); -+--------------------------------------+ -| uuid() | -+--------------------------------------+ -| 6ec17ef8-1934-41cc-8d59-d0c8f9eea1f0 | -+--------------------------------------+ -```"#) - .build() - .unwrap() - }) -} diff --git a/datafusion/functions/src/strings.rs b/datafusion/functions/src/strings.rs new file mode 100644 index 0000000000000..bb991c28fe4d2 --- /dev/null +++ b/datafusion/functions/src/strings.rs @@ -0,0 +1,424 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::mem::size_of; + +use arrow::array::{ + make_view, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ByteView, + GenericStringArray, LargeStringArray, OffsetSizeTrait, StringArray, StringViewArray, + StringViewBuilder, +}; +use arrow::datatypes::DataType; +use arrow_buffer::{MutableBuffer, NullBuffer, NullBufferBuilder}; + +/// Abstracts iteration over different types of string arrays. +#[deprecated(since = "45.0.0", note = "Use arrow::array::StringArrayType instead")] +pub trait StringArrayType<'a>: ArrayAccessor + Sized { + /// Return an [`ArrayIter`] over the values of the array. + /// + /// This iterator iterates returns `Option<&str>` for each item in the array. + fn iter(&self) -> ArrayIter; + + /// Check if the array is ASCII only. + fn is_ascii(&self) -> bool; +} + +#[allow(deprecated)] +impl<'a, T: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray { + fn iter(&self) -> ArrayIter { + GenericStringArray::::iter(self) + } + + fn is_ascii(&self) -> bool { + GenericStringArray::::is_ascii(self) + } +} + +#[allow(deprecated)] +impl<'a> StringArrayType<'a> for &'a StringViewArray { + fn iter(&self) -> ArrayIter { + StringViewArray::iter(self) + } + + fn is_ascii(&self) -> bool { + StringViewArray::is_ascii(self) + } +} + +/// Optimized version of the StringBuilder in Arrow that: +/// 1. Precalculating the expected length of the result, avoiding reallocations. +/// 2. Avoids creating / incrementally creating a `NullBufferBuilder` +pub struct StringArrayBuilder { + offsets_buffer: MutableBuffer, + value_buffer: MutableBuffer, +} + +impl StringArrayBuilder { + pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + let capacity = item_capacity + .checked_add(1) + .map(|i| i.saturating_mul(size_of::())) + .expect("capacity integer overflow"); + + let mut offsets_buffer = MutableBuffer::with_capacity(capacity); + // SAFETY: the first offset value is definitely not going to exceed the bounds. 
+ unsafe { offsets_buffer.push_unchecked(0_i32) }; + Self { + offsets_buffer, + value_buffer: MutableBuffer::with_capacity(data_capacity), + } + } + + pub fn write( + &mut self, + column: &ColumnarValueRef, + i: usize, + ) { + match column { + ColumnarValueRef::Scalar(s) => { + self.value_buffer.extend_from_slice(s); + } + ColumnarValueRef::NullableArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NullableLargeStringArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NullableStringViewArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NonNullableArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + ColumnarValueRef::NonNullableLargeStringArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + ColumnarValueRef::NonNullableStringViewArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + } + + pub fn append_offset(&mut self) { + let next_offset: i32 = self + .value_buffer + .len() + .try_into() + .expect("byte array offset overflow"); + self.offsets_buffer.push(next_offset); + } + + /// Finalize the builder into a concrete [`StringArray`]. + /// + /// # Panics + /// + /// This method can panic when: + /// + /// - the provided `null_buffer` is not the same length as the `offsets_buffer`. + pub fn finish(self, null_buffer: Option) -> StringArray { + let row_count = self.offsets_buffer.len() / size_of::() - 1; + if let Some(ref null_buffer) = null_buffer { + assert_eq!( + null_buffer.len(), + row_count, + "Null buffer and offsets buffer must be the same length" + ); + } + let array_builder = ArrayDataBuilder::new(DataType::Utf8) + .len(row_count) + .add_buffer(self.offsets_buffer.into()) + .add_buffer(self.value_buffer.into()) + .nulls(null_buffer); + // SAFETY: all data that was appended was valid UTF8 and the values + // and offsets were created correctly + let array_data = unsafe { array_builder.build_unchecked() }; + StringArray::from(array_data) + } +} + +pub struct StringViewArrayBuilder { + builder: StringViewBuilder, + block: String, +} + +impl StringViewArrayBuilder { + pub fn with_capacity(_item_capacity: usize, data_capacity: usize) -> Self { + let builder = StringViewBuilder::with_capacity(data_capacity); + Self { + builder, + block: String::new(), + } + } + + pub fn write( + &mut self, + column: &ColumnarValueRef, + i: usize, + ) { + match column { + ColumnarValueRef::Scalar(s) => { + self.block.push_str(std::str::from_utf8(s).unwrap()); + } + ColumnarValueRef::NullableArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.push_str( + std::str::from_utf8(array.value(i).as_bytes()).unwrap(), + ); + } + } + ColumnarValueRef::NullableLargeStringArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.push_str( + std::str::from_utf8(array.value(i).as_bytes()).unwrap(), + ); + } + } + ColumnarValueRef::NullableStringViewArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.push_str( + std::str::from_utf8(array.value(i).as_bytes()).unwrap(), + ); + } + } + ColumnarValueRef::NonNullableArray(array) => { + self.block + .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + } + 
ColumnarValueRef::NonNullableLargeStringArray(array) => { + self.block + .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + } + ColumnarValueRef::NonNullableStringViewArray(array) => { + self.block + .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + } + } + } + + pub fn append_offset(&mut self) { + self.builder.append_value(&self.block); + self.block = String::new(); + } + + pub fn finish(mut self) -> StringViewArray { + self.builder.finish() + } +} + +pub struct LargeStringArrayBuilder { + offsets_buffer: MutableBuffer, + value_buffer: MutableBuffer, +} + +impl LargeStringArrayBuilder { + pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + let capacity = item_capacity + .checked_add(1) + .map(|i| i.saturating_mul(size_of::())) + .expect("capacity integer overflow"); + + let mut offsets_buffer = MutableBuffer::with_capacity(capacity); + // SAFETY: the first offset value is definitely not going to exceed the bounds. + unsafe { offsets_buffer.push_unchecked(0_i64) }; + Self { + offsets_buffer, + value_buffer: MutableBuffer::with_capacity(data_capacity), + } + } + + pub fn write( + &mut self, + column: &ColumnarValueRef, + i: usize, + ) { + match column { + ColumnarValueRef::Scalar(s) => { + self.value_buffer.extend_from_slice(s); + } + ColumnarValueRef::NullableArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NullableLargeStringArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NullableStringViewArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NonNullableArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + ColumnarValueRef::NonNullableLargeStringArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + ColumnarValueRef::NonNullableStringViewArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + } + + pub fn append_offset(&mut self) { + let next_offset: i64 = self + .value_buffer + .len() + .try_into() + .expect("byte array offset overflow"); + self.offsets_buffer.push(next_offset); + } + + /// Finalize the builder into a concrete [`LargeStringArray`]. + /// + /// # Panics + /// + /// This method can panic when: + /// + /// - the provided `null_buffer` is not the same length as the `offsets_buffer`. 
+ pub fn finish(self, null_buffer: Option) -> LargeStringArray { + let row_count = self.offsets_buffer.len() / size_of::() - 1; + if let Some(ref null_buffer) = null_buffer { + assert_eq!( + null_buffer.len(), + row_count, + "Null buffer and offsets buffer must be the same length" + ); + } + let array_builder = ArrayDataBuilder::new(DataType::LargeUtf8) + .len(row_count) + .add_buffer(self.offsets_buffer.into()) + .add_buffer(self.value_buffer.into()) + .nulls(null_buffer); + // SAFETY: all data that was appended was valid Large UTF8 and the values + // and offsets were created correctly + let array_data = unsafe { array_builder.build_unchecked() }; + LargeStringArray::from(array_data) + } +} + +/// Append a new view to the views buffer with the given substr +/// +/// # Safety +/// +/// original_view must be a valid view (the format described on +/// [`GenericByteViewArray`](arrow::array::GenericByteViewArray). +/// +/// # Arguments +/// - views_buffer: The buffer to append the new view to +/// - null_builder: The buffer to append the null value to +/// - original_view: The original view value +/// - substr: The substring to append. Must be a valid substring of the original view +/// - start_offset: The start offset of the substring in the view +pub fn make_and_append_view( + views_buffer: &mut Vec, + null_builder: &mut NullBufferBuilder, + original_view: &u128, + substr: &str, + start_offset: u32, +) { + let substr_len = substr.len(); + let sub_view = if substr_len > 12 { + let view = ByteView::from(*original_view); + make_view( + substr.as_bytes(), + view.buffer_index, + view.offset + start_offset, + ) + } else { + // inline value does not need block id or offset + make_view(substr.as_bytes(), 0, 0) + }; + views_buffer.push(sub_view); + null_builder.append_non_null(); +} + +#[derive(Debug)] +pub enum ColumnarValueRef<'a> { + Scalar(&'a [u8]), + NullableArray(&'a StringArray), + NonNullableArray(&'a StringArray), + NullableLargeStringArray(&'a LargeStringArray), + NonNullableLargeStringArray(&'a LargeStringArray), + NullableStringViewArray(&'a StringViewArray), + NonNullableStringViewArray(&'a StringViewArray), +} + +impl ColumnarValueRef<'_> { + #[inline] + pub fn is_valid(&self, i: usize) -> bool { + match &self { + Self::Scalar(_) + | Self::NonNullableArray(_) + | Self::NonNullableLargeStringArray(_) + | Self::NonNullableStringViewArray(_) => true, + Self::NullableArray(array) => array.is_valid(i), + Self::NullableStringViewArray(array) => array.is_valid(i), + Self::NullableLargeStringArray(array) => array.is_valid(i), + } + } + + #[inline] + pub fn nulls(&self) -> Option { + match &self { + Self::Scalar(_) + | Self::NonNullableArray(_) + | Self::NonNullableStringViewArray(_) + | Self::NonNullableLargeStringArray(_) => None, + Self::NullableArray(array) => array.nulls().cloned(), + Self::NullableStringViewArray(array) => array.nulls().cloned(), + Self::NullableLargeStringArray(array) => array.nulls().cloned(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[should_panic(expected = "capacity integer overflow")] + fn test_overflow_string_array_builder() { + let _builder = StringArrayBuilder::with_capacity(usize::MAX, usize::MAX); + } + + #[test] + #[should_panic(expected = "capacity integer overflow")] + fn test_overflow_large_string_array_builder() { + let _builder = LargeStringArrayBuilder::with_capacity(usize::MAX, usize::MAX); + } +} diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index 
2e108f127a73f..d909d85f45996 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -15,20 +15,36 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray, + Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveBuilder, + StringArrayType, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; use datafusion_common::Result; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns the number of characters in a string.", + syntax_example = "character_length(str)", + sql_example = r#"```sql +> select character_length('Ångström'); ++------------------------------------+ +| character_length(Utf8("Ångström")) | ++------------------------------------+ +| 8 | ++------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + related_udf(name = "bit_length"), + related_udf(name = "octet_length") +)] #[derive(Debug)] pub struct CharacterLengthFunc { signature: Signature, @@ -72,7 +88,11 @@ impl ScalarUDFImpl for CharacterLengthFunc { utf8_to_int_type(&arg_types[0], "character_length") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(character_length, vec![])(args) } @@ -81,36 +101,10 @@ impl ScalarUDFImpl for CharacterLengthFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_character_length_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_character_length_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Returns the number of characters in a string.") - .with_syntax_example("character_length(str)") - .with_sql_example( - r#"```sql -> select character_length('Ångström'); -+------------------------------------+ -| character_length(Utf8("Ångström")) | -+------------------------------------+ -| 8 | -+------------------------------------+ -```"#, - ) - .with_standard_argument("str", "String") - .with_related_udf("bit_length") - .with_related_udf("octet_length") - .build() - .unwrap() - }) -} - /// Returns number of characters in the string. 
/// character_length('josé') = 4 /// The implementation counts UTF-8 code points to count the number of characters @@ -128,35 +122,56 @@ fn character_length(args: &[ArrayRef]) -> Result { let string_array = args[0].as_string_view(); character_length_general::(string_array) } - _ => unreachable!(), + _ => unreachable!("CharacterLengthFunc"), } } -fn character_length_general<'a, T: ArrowPrimitiveType, V: StringArrayType<'a>>( - array: V, -) -> Result +fn character_length_general<'a, T, V>(array: V) -> Result where + T: ArrowPrimitiveType, T::Native: OffsetSizeTrait, + V: StringArrayType<'a>, { + let mut builder = PrimitiveBuilder::::with_capacity(array.len()); + // String characters are variable length encoded in UTF-8, counting the // number of chars requires expensive decoding, however checking if the // string is ASCII only is relatively cheap. // If strings are ASCII only, count bytes instead. let is_array_ascii_only = array.is_ascii(); - let iter = array.iter(); - let result = iter - .map(|string| { - string.map(|string: &str| { - if is_array_ascii_only { - T::Native::usize_as(string.len()) - } else { - T::Native::usize_as(string.chars().count()) - } - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) + if array.null_count() == 0 { + if is_array_ascii_only { + for i in 0..array.len() { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.len())); + } + } else { + for i in 0..array.len() { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.chars().count())); + } + } + } else if is_array_ascii_only { + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.len())); + } + } + } else { + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.chars().count())); + } + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] @@ -172,7 +187,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8($INPUT))], + vec![ColumnarValue::from(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, @@ -181,7 +196,7 @@ mod tests { test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::from(ScalarValue::LargeUtf8($INPUT))], + vec![ColumnarValue::from(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, i64, Int64, @@ -190,7 +205,7 @@ mod tests { test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8View($INPUT))], + vec![ColumnarValue::from(ScalarValue::Utf8View($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index cad860e41088f..9a5b81d5f6bfb 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -16,22 +16,42 @@ // under the License. 
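As a footnote to the character_length rewrite above: the collected-iterator kernel is replaced by a `PrimitiveBuilder` loop that decides once per array whether the values are ASCII-only (byte length equals character count) before walking the rows. A simplified sketch of that idea for a plain `StringArray`; the patch itself is generic over `StringArrayType` and additionally specializes the loop by null count, and `character_lengths` here is an illustrative name:

```rust
use arrow::array::{Array, Int32Array, PrimitiveBuilder, StringArray};
use arrow::datatypes::Int32Type;

fn character_lengths(values: &StringArray) -> Int32Array {
    let mut builder = PrimitiveBuilder::<Int32Type>::with_capacity(values.len());
    // Checking for ASCII once per array is cheap; decoding UTF-8 per row is not.
    let ascii_only = values.is_ascii();

    for i in 0..values.len() {
        if values.is_null(i) {
            builder.append_null();
        } else {
            let v = values.value(i);
            // For ASCII-only arrays the byte length equals the character count.
            let n = if ascii_only { v.len() } else { v.chars().count() };
            builder.append_value(n as i32);
        }
    }
    builder.finish()
}
```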
use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ - ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, - PrimitiveArray, + new_null_array, ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, + OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; -use crate::utils::{make_scalar_function, utf8_to_int_type}; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use crate::utils::utf8_to_int_type; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; +use datafusion_expr_common::scalar::Scalar; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.", + syntax_example = "find_in_set(str, strlist)", + sql_example = r#"```sql +> select find_in_set('b', 'a,b,c,d'); ++----------------------------------------+ +| find_in_set(Utf8("b"),Utf8("a,b,c,d")) | ++----------------------------------------+ +| 2 | ++----------------------------------------+ +```"#, + argument(name = "str", description = "String expression to find in strlist."), + argument( + name = "strlist", + description = "A string list is a string composed of substrings separated by , characters." + ) +)] #[derive(Debug)] pub struct FindInSetFunc { signature: Signature, @@ -76,61 +96,196 @@ impl ScalarUDFImpl for FindInSetFunc { utf8_to_int_type(&arg_types[0], "find_in_set") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - make_scalar_function(find_in_set, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { mut args, .. } = args; + + if args.len() != 2 { + return exec_err!( + "find_in_set was called with {} arguments. It requires 2.", + args.len() + ); + } + + let str_list = args.pop().unwrap(); + let string = args.pop().unwrap(); + + match (string, str_list) { + // both inputs are scalars + (ColumnarValue::Scalar(string), ColumnarValue::Scalar(str_list)) => { + invoke_with_args_scalar_scalar(string, str_list) + } + + // `string` is an array, `str_list` is scalar + (ColumnarValue::Array(str_array), ColumnarValue::Scalar(str_list)) => { + invoke_with_args_array_scalar(str_array, str_list) + } + + // `string` is scalar, `str_list` is an array + (ColumnarValue::Scalar(scalar), ColumnarValue::Array(str_list_array)) => { + invoke_with_args_scalar_array(scalar, str_list_array) + } + + // both inputs are arrays + (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => { + let res = find_in_set(base_array, exp_array)?; + + Ok(ColumnarValue::Array(res)) + } + } } fn documentation(&self) -> Option<&Documentation> { - Some(get_find_in_set_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); +/// Handles the `find_in_set` case for two scalar arguments. 
+fn invoke_with_args_scalar_scalar( + string: Scalar, + str_list: Scalar, +) -> Result { + let string = match string.into_value() { + ScalarValue::Utf8View(string) + | ScalarValue::Utf8(string) + | ScalarValue::LargeUtf8(string) => string, + _ => { + return internal_err!( + "Invalid argument type for 'str' in `find_in_set` function" + ) + } + }; -fn get_find_in_set_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.") - .with_syntax_example("find_in_set(str, strlist)") - .with_sql_example(r#"```sql -> select find_in_set('b', 'a,b,c,d'); -+----------------------------------------+ -| find_in_set(Utf8("b"),Utf8("a,b,c,d")) | -+----------------------------------------+ -| 2 | -+----------------------------------------+ -```"#) - .with_argument("str", "String expression to find in strlist.") - .with_argument("strlist", "A string list is a string composed of substrings separated by , characters.") - .build() - .unwrap() - }) + let str_list = match str_list.into_value() { + ScalarValue::Utf8View(str_list) + | ScalarValue::Utf8(str_list) + | ScalarValue::LargeUtf8(str_list) => str_list, + _ => { + return internal_err!( + "Invalid argument type for 'str_list' in `find_in_set` function" + ) + } + }; + + let res = match (string, str_list) { + (Some(string), Some(str_list)) => { + let position = str_list + .split(',') + .position(|s| s == string) + .map_or(0, |idx| idx + 1); + + Some(position as i32) + } + _ => None, + }; + + Ok(ColumnarValue::from(ScalarValue::from(res))) } -///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings -///A string list is a string composed of substrings separated by , characters. -fn find_in_set(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!( - "find_in_set was called with {} arguments. It requires 2.", - args.len() - ); - } - match args[0].data_type() { +/// Handles the `find_in_set` case for an array and a scalar argument. +fn invoke_with_args_array_scalar( + str_array: ArrayRef, + str_list: Scalar, +) -> Result { + let str_list = match str_list.into_value() { + ScalarValue::Utf8View(str_list) + | ScalarValue::Utf8(str_list) + | ScalarValue::LargeUtf8(str_list) => str_list, + _ => { + return internal_err!( + "Invalid argument type for 'str_list' in `find_in_set` function" + ) + } + }; + + let result_array = match str_list { + // find_in_set(column_a, null) = null + None => new_null_array(str_array.data_type(), str_array.len()), + Some(str_list_literal) => { + let str_list = str_list_literal.split(',').collect::>(); + let result = match str_array.data_type() { + DataType::Utf8 => { + let string_array = str_array.as_string::(); + find_in_set_right_literal::(string_array, str_list) + } + DataType::LargeUtf8 => { + let string_array = str_array.as_string::(); + find_in_set_right_literal::(string_array, str_list) + } + DataType::Utf8View => { + let string_array = str_array.as_string_view(); + find_in_set_right_literal::(string_array, str_list) + } + other => { + exec_err!("Unsupported data type {other:?} for function find_in_set") + } + }; + Arc::new(result?) + } + }; + Ok(ColumnarValue::Array(result_array)) +} + +/// Handles the `find_in_set` case for a scalar and an array argument. 
+fn invoke_with_args_scalar_array( + string_literal: Scalar, + str_list_array: ArrayRef, +) -> Result { + let string_literal = match string_literal.into_value() { + ScalarValue::Utf8View(string_literal) + | ScalarValue::Utf8(string_literal) + | ScalarValue::LargeUtf8(string_literal) => string_literal, + _ => { + return internal_err!( + "Invalid argument type for 'str' in `find_in_set` function" + ) + } + }; + + let res = match string_literal { + // find_in_set(null, column_b) = null + None => new_null_array(str_list_array.data_type(), str_list_array.len()), + Some(string) => { + let result = match str_list_array.data_type() { + DataType::Utf8 => { + let str_list = str_list_array.as_string::(); + find_in_set_left_literal::(string, str_list) + } + DataType::LargeUtf8 => { + let str_list = str_list_array.as_string::(); + find_in_set_left_literal::(string, str_list) + } + DataType::Utf8View => { + let str_list = str_list_array.as_string_view(); + find_in_set_left_literal::(string, str_list) + } + other => { + exec_err!("Unsupported data type {other:?} for function find_in_set") + } + }; + Arc::new(result?) + } + }; + Ok(ColumnarValue::Array(res)) +} + +/// Returns a value in the range of 1 to N if the string `str` is in the string list `strlist` +/// consisting of N substrings. A string list is a string composed of substrings separated by `,` +/// characters. +fn find_in_set(str: ArrayRef, str_list: ArrayRef) -> Result { + match str.data_type() { DataType::Utf8 => { - let string_array = args[0].as_string::(); - let str_list_array = args[1].as_string::(); + let string_array = str.as_string::(); + let str_list_array = str_list.as_string::(); find_in_set_general::(string_array, str_list_array) } DataType::LargeUtf8 => { - let string_array = args[0].as_string::(); - let str_list_array = args[1].as_string::(); + let string_array = str.as_string::(); + let str_list_array = str_list.as_string::(); find_in_set_general::(string_array, str_list_array) } DataType::Utf8View => { - let string_array = args[0].as_string_view(); - let str_list_array = args[1].as_string_view(); + let string_array = str.as_string_view(); + let str_list_array = str_list.as_string_view(); find_in_set_general::(string_array, str_list_array) } other => { @@ -139,31 +294,280 @@ fn find_in_set(args: &[ArrayRef]) -> Result { } } -pub fn find_in_set_general<'a, T: ArrowPrimitiveType, V: ArrayAccessor>( +pub fn find_in_set_general<'a, T, V>( string_array: V, str_list_array: V, ) -> Result where + T: ArrowPrimitiveType, T::Native: OffsetSizeTrait, + V: ArrayAccessor, { let string_iter = ArrayIter::new(string_array); let str_list_iter = ArrayIter::new(str_list_array); - let result = string_iter + + let mut builder = PrimitiveArray::::builder(string_iter.len()); + + string_iter .zip(str_list_iter) - .map(|(string, str_list)| match (string, str_list) { - (Some(string), Some(str_list)) => { - let mut res = 0; - let str_set: Vec<&str> = str_list.split(',').collect(); - for (idx, str) in str_set.iter().enumerate() { - if str == &string { - res = idx + 1; - break; - } + .for_each( + |(string_opt, str_list_opt)| match (string_opt, str_list_opt) { + (Some(string), Some(str_list)) => { + let position = str_list + .split(',') + .position(|s| s == string) + .map_or(0, |idx| idx + 1); + builder.append_value(T::Native::from_usize(position).unwrap()); } - T::Native::from_usize(res) + _ => builder.append_null(), + }, + ); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +fn find_in_set_left_literal<'a, T, V>( + string: String, + str_list_array: V, +) 
-> Result +where + T: ArrowPrimitiveType, + T::Native: OffsetSizeTrait, + V: ArrayAccessor, +{ + let mut builder = PrimitiveArray::::builder(str_list_array.len()); + + let str_list_iter = ArrayIter::new(str_list_array); + + str_list_iter.for_each(|str_list_opt| match str_list_opt { + Some(str_list) => { + let position = str_list + .split(',') + .position(|s| s == string) + .map_or(0, |idx| idx + 1); + builder.append_value(T::Native::from_usize(position).unwrap()); + } + None => builder.append_null(), + }); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +fn find_in_set_right_literal<'a, T, V>( + string_array: V, + str_list: Vec<&str>, +) -> Result +where + T: ArrowPrimitiveType, + T::Native: OffsetSizeTrait, + V: ArrayAccessor, +{ + let mut builder = PrimitiveArray::::builder(string_array.len()); + + let string_iter = ArrayIter::new(string_array); + + string_iter.for_each(|string_opt| match string_opt { + Some(string) => { + let position = str_list + .iter() + .position(|s| *s == string) + .map_or(0, |idx| idx + 1); + builder.append_value(T::Native::from_usize(position).unwrap()); + } + None => builder.append_null(), + }); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use crate::unicode::find_in_set::FindInSetFunc; + use crate::utils::test::test_function; + use arrow::array::{Array, Int32Array, StringArray}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("a")))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("a,b,c")))), + ], + Ok(Some(1)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::from(ScalarValue::Utf8View(Some(String::from("🔥")))), + ColumnarValue::from(ScalarValue::Utf8View(Some(String::from("a,Д,🔥")))), + ], + Ok(Some(3)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::from(ScalarValue::Utf8View(Some(String::from("d")))), + ColumnarValue::from(ScalarValue::Utf8View(Some(String::from("a,b,c")))), + ], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( + "Apache Software Foundation" + )))), + ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( + "Github,Apache Software Foundation,DataFusion" + )))), + ], + Ok(Some(2)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("")))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("a,b,c")))), + ], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("a")))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("")))), + ], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::from(ScalarValue::Utf8View(Some(String::from("a")))), + ColumnarValue::from(ScalarValue::Utf8View(None)), + ], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::from(ScalarValue::Utf8View(None)), + ColumnarValue::from(ScalarValue::Utf8View(Some(String::from("a,b,c")))), + 
], + Ok(None), + i32, + Int32, + Int32Array + ); + + Ok(()) + } + + macro_rules! test_find_in_set { + ($test_name:ident, $args:expr, $expected:expr) => { + #[test] + fn $test_name() -> Result<()> { + let fis = crate::unicode::find_in_set(); + + let args = $args; + let expected = $expected; + + let type_array = args + .iter() + .map(|a| a.data_type().clone()) + .collect::>(); + let cardinality = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }) + .unwrap_or(1); + let return_type = fis.return_type(&type_array)?; + let result = fis.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: cardinality, + return_type: &return_type, + }); + assert!(result.is_ok()); + + let result = result? + .to_array(cardinality) + .expect("Failed to convert to array"); + let result = result + .as_any() + .downcast_ref::() + .expect("Failed to convert to type"); + assert_eq!(*result, expected); + + Ok(()) } - _ => None, - }) - .collect::>(); - Ok(Arc::new(result) as ArrayRef) + }; + } + + test_find_in_set!( + test_find_in_set_with_scalar_args, + vec![ + ColumnarValue::Array(Arc::new(StringArray::from(vec![ + "", "a", "b", "c", "d" + ]))), + ColumnarValue::from(ScalarValue::Utf8(Some("b,c,d".to_string()))), + ], + Int32Array::from(vec![0, 0, 1, 2, 3]) + ); + test_find_in_set!( + test_find_in_set_with_scalar_args_2, + vec![ + ColumnarValue::from(ScalarValue::Utf8View(Some( + "ApacheSoftware".to_string() + ))), + ColumnarValue::Array(Arc::new(StringArray::from(vec![ + "a,b,c", + "ApacheSoftware,Github,DataFusion", + "" + ]))), + ], + Int32Array::from(vec![0, 1, 0]) + ); + test_find_in_set!( + test_find_in_set_with_scalar_args_3, + vec![ + ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))), + ColumnarValue::from(ScalarValue::Utf8View(Some("a,b,c".to_string()))), + ], + Int32Array::from(vec![None::; 3]) + ); + test_find_in_set!( + test_find_in_set_with_scalar_args_4, + vec![ + ColumnarValue::from(ScalarValue::Utf8View(Some("a".to_string()))), + ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))), + ], + Int32Array::from(vec![None::; 3]) + ); } diff --git a/datafusion/functions/src/unicode/initcap.rs b/datafusion/functions/src/unicode/initcap.rs new file mode 100644 index 0000000000000..18f7d1d57d0a2 --- /dev/null +++ b/datafusion/functions/src/unicode/initcap.rs @@ -0,0 +1,291 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
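// A minimal, self-contained sketch of the 1-based matching rule that every
// `find_in_set` code path above shares (scalar/scalar, array/scalar,
// scalar/array, array/array): split the list on ',' and return position + 1,
// or 0 when the needle is absent. The helper name `find_in_set_position` is
// illustrative only and is not defined by this patch.
fn find_in_set_position(needle: &str, str_list: &str) -> i32 {
    str_list
        .split(',')
        .position(|item| item == needle)
        .map_or(0, |idx| (idx + 1) as i32)
}

#[cfg(test)]
mod find_in_set_position_sketch {
    use super::find_in_set_position;

    #[test]
    fn matches_the_documented_semantics() {
        assert_eq!(find_in_set_position("b", "a,b,c,d"), 2); // documented example
        assert_eq!(find_in_set_position("d", "a,b,c"), 0); // not found
        assert_eq!(find_in_set_position("", "a,b,c"), 0); // empty needle
    }
}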
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, GenericStringBuilder, OffsetSizeTrait, StringViewBuilder, +}; +use arrow::datatypes::DataType; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "String Functions"), + description = "Capitalizes the first character in each word in the input string. \ + Words are delimited by non-alphanumeric characters.", + syntax_example = "initcap(str)", + sql_example = r#"```sql +> select initcap('apache datafusion'); ++------------------------------------+ +| initcap(Utf8("apache datafusion")) | ++------------------------------------+ +| Apache Datafusion | ++------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + related_udf(name = "lower"), + related_udf(name = "upper") +)] +#[derive(Debug)] +pub struct InitcapFunc { + signature: Signature, +} + +impl Default for InitcapFunc { + fn default() -> Self { + InitcapFunc::new() + } +} + +impl InitcapFunc { + pub fn new() -> Self { + Self { + signature: Signature::string(1, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for InitcapFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "initcap" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if let DataType::Utf8View = arg_types[0] { + Ok(DataType::Utf8View) + } else { + utf8_to_str_type(&arg_types[0], "initcap") + } + } + + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(initcap::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(initcap::, vec![])(args), + DataType::Utf8View => make_scalar_function(initcap_utf8view, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function `initcap`") + } + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +/// Converts the first letter of each word to upper case and the rest to lower +/// case. Words are sequences of alphanumeric characters separated by +/// non-alphanumeric characters. 
+/// +/// Example: +/// ```sql +/// initcap('hi THOMAS') = 'Hi Thomas' +/// ``` +fn initcap(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + + let mut builder = GenericStringBuilder::::with_capacity( + string_array.len(), + string_array.value_data().len(), + ); + + string_array.iter().for_each(|str| match str { + Some(s) => { + let initcap_str = initcap_string(s); + builder.append_value(initcap_str); + } + None => builder.append_null(), + }); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +fn initcap_utf8view(args: &[ArrayRef]) -> Result { + let string_view_array = as_string_view_array(&args[0])?; + + let mut builder = StringViewBuilder::with_capacity(string_view_array.len()); + + string_view_array.iter().for_each(|str| match str { + Some(s) => { + let initcap_str = initcap_string(s); + builder.append_value(initcap_str); + } + None => builder.append_null(), + }); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +fn initcap_string(input: &str) -> String { + let mut result = String::with_capacity(input.len()); + let mut prev_is_alphanumeric = false; + + if input.is_ascii() { + for c in input.chars() { + if prev_is_alphanumeric { + result.push(c.to_ascii_lowercase()); + } else { + result.push(c.to_ascii_uppercase()); + }; + prev_is_alphanumeric = c.is_ascii_alphanumeric(); + } + } else { + for c in input.chars() { + if prev_is_alphanumeric { + result.extend(c.to_lowercase()); + } else { + result.extend(c.to_uppercase()); + } + prev_is_alphanumeric = c.is_alphanumeric(); + } + } + + result +} + +#[cfg(test)] +mod tests { + use crate::unicode::initcap::InitcapFunc; + use crate::utils::test::test_function; + use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + InitcapFunc::new(), + vec![ColumnarValue::from(ScalarValue::from("hi THOMAS"))], + Ok(Some("Hi Thomas")), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + vec![ColumnarValue::from(ScalarValue::Utf8(Some( + "êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ" + .to_string() + )))], + Ok(Some( + "Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική" + )), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + vec![ColumnarValue::from(ScalarValue::from(""))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + vec![ColumnarValue::from(ScalarValue::from(""))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + vec![ColumnarValue::from(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + + test_function!( + InitcapFunc::new(), + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( + "hi THOMAS".to_string() + )))], + Ok(Some("Hi Thomas")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + InitcapFunc::new(), + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( + "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string() + )))], + Ok(Some("Hi Thomas With M0re Than 12 Chars")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + InitcapFunc::new(), + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( + "đẸp đẼ êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ" + .to_string() + )))], + Ok(Some( + "Đẹp Đẽ Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική" + )), + 
&str, + Utf8View, + StringViewArray + ); + test_function!( + InitcapFunc::new(), + vec![ColumnarValue::from(ScalarValue::Utf8View(Some( + "".to_string() + )))], + Ok(Some("")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + InitcapFunc::new(), + vec![ColumnarValue::from(ScalarValue::Utf8View(None))], + Ok(None), + &str, + Utf8View, + StringViewArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index f8a507d51526f..17b64caf4bda5 100644 --- a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::cmp::Ordering; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array, @@ -31,12 +31,28 @@ use datafusion_common::cast::{ }; use datafusion_common::exec_err; use datafusion_common::Result; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns a specified number of characters from the left side of a string.", + syntax_example = "left(str, n)", + sql_example = r#"```sql +> select left('datafusion', 4); ++-----------------------------------+ +| left(Utf8("datafusion"),Int64(4)) | ++-----------------------------------+ +| data | ++-----------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument(name = "n", description = "Number of characters to return."), + related_udf(name = "right") +)] #[derive(Debug)] pub struct LeftFunc { signature: Signature, @@ -81,7 +97,11 @@ impl ScalarUDFImpl for LeftFunc { utf8_to_str_type(&arg_types[0], "left") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { DataType::Utf8 | DataType::Utf8View => { make_scalar_function(left::, vec![])(args) @@ -95,34 +115,10 @@ impl ScalarUDFImpl for LeftFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_left_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_left_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Returns a specified number of characters from the left side of a string.") - .with_syntax_example("left(str, n)") - .with_sql_example(r#"```sql -> select left('datafusion', 4); -+-----------------------------------+ -| left(Utf8("datafusion"),Int64(4)) | -+-----------------------------------+ -| data | -+-----------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_argument("n", "Number of characters to return.") - .with_related_udf("right") - .build() - .unwrap() - }) -} - /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. 
/// left('abcde', 2) = 'ab' /// The implementation uses UTF-8 code points as characters @@ -182,7 +178,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abcde")), ColumnarValue::from(ScalarValue::from(2i64)), ], @@ -193,7 +189,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abcde")), ColumnarValue::from(ScalarValue::from(200i64)), ], @@ -204,7 +200,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abcde")), ColumnarValue::from(ScalarValue::from(-2i64)), ], @@ -215,7 +211,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abcde")), ColumnarValue::from(ScalarValue::from(-200i64)), ], @@ -226,7 +222,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abcde")), ColumnarValue::from(ScalarValue::from(0i64)), ], @@ -237,7 +233,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(None)), ColumnarValue::from(ScalarValue::from(2i64)), ], @@ -248,7 +244,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abcde")), ColumnarValue::from(ScalarValue::Int64(None)), ], @@ -259,7 +255,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("joséésoj")), ColumnarValue::from(ScalarValue::from(5i64)), ], @@ -270,7 +266,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("joséésoj")), ColumnarValue::from(ScalarValue::from(-3i64)), ], diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index f1c3646fd168d..4a6ade61801b7 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -17,26 +17,45 @@ use std::any::Any; use std::fmt::Write; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, - OffsetSizeTrait, StringViewArray, + OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; use unicode_segmentation::UnicodeSegmentation; use DataType::{LargeUtf8, Utf8, Utf8View}; -use crate::string::common::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Pads the left side of a string with another string to a specified string length.", + syntax_example = "lpad(str, n[, padding_str])", + sql_example = r#"```sql +> select lpad('Dolly', 10, 'hello'); ++---------------------------------------------+ +| lpad(Utf8("Dolly"),Int64(10),Utf8("hello")) | ++---------------------------------------------+ +| helloDolly | ++---------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument(name = "n", description = "String length to pad to."), + argument( + name = "padding_str", + description = "Optional string expression to pad with. 
Can be a constant, column, or function, and any combination of string operators. _Default is a space._" + ), + related_udf(name = "rpad") +)] #[derive(Debug)] pub struct LPadFunc { signature: Signature, @@ -90,7 +109,11 @@ impl ScalarUDFImpl for LPadFunc { utf8_to_str_type(&arg_types[0], "lpad") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { Utf8 | Utf8View => make_scalar_function(lpad::, vec![])(args), LargeUtf8 => make_scalar_function(lpad::, vec![])(args), @@ -99,35 +122,10 @@ impl ScalarUDFImpl for LPadFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_lpad_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_lpad_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Pads the left side of a string with another string to a specified string length.") - .with_syntax_example("lpad(str, n[, padding_str])") - .with_sql_example(r#"```sql -> select lpad('Dolly', 10, 'hello'); -+---------------------------------------------+ -| lpad(Utf8("Dolly"),Int64(10),Utf8("hello")) | -+---------------------------------------------+ -| helloDolly | -+---------------------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_argument("n", "String length to pad to.") - .with_argument("padding_str", "Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._") - .with_related_udf("rpad") - .build() - .unwrap() - }) -} - /// Extends the string to length 'length' by prepending the characters fill (a space by default). /// If the string is already longer than length then it is truncated (on the right). 
/// lpad('hi', 5, 'xy') = 'xyxhi' @@ -162,7 +160,7 @@ pub fn lpad(args: &[ArrayRef]) -> Result { length_array, &args[2], ), - (_, _) => unreachable!(), + (_, _) => unreachable!("lpad"), } } @@ -295,7 +293,7 @@ mod tests { ($INPUT:expr, $LENGTH:expr, $EXPECTED:expr) => { test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8($INPUT)), ColumnarValue::from($LENGTH) ], @@ -307,7 +305,7 @@ mod tests { test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::from($LENGTH) ], @@ -319,7 +317,7 @@ mod tests { test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View($INPUT)), ColumnarValue::from($LENGTH) ], @@ -334,7 +332,7 @@ mod tests { // utf8, utf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8($INPUT)), ColumnarValue::from($LENGTH), ColumnarValue::from(ScalarValue::Utf8($REPLACE)) @@ -347,7 +345,7 @@ mod tests { // utf8, largeutf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8($INPUT)), ColumnarValue::from($LENGTH), ColumnarValue::from(ScalarValue::LargeUtf8($REPLACE)) @@ -360,7 +358,7 @@ mod tests { // utf8, utf8view test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8($INPUT)), ColumnarValue::from($LENGTH), ColumnarValue::from(ScalarValue::Utf8View($REPLACE)) @@ -374,7 +372,7 @@ mod tests { // largeutf8, utf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::from($LENGTH), ColumnarValue::from(ScalarValue::Utf8($REPLACE)) @@ -387,7 +385,7 @@ mod tests { // largeutf8, largeutf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::from($LENGTH), ColumnarValue::from(ScalarValue::LargeUtf8($REPLACE)) @@ -400,7 +398,7 @@ mod tests { // largeutf8, utf8view test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::from($LENGTH), ColumnarValue::from(ScalarValue::Utf8View($REPLACE)) @@ -414,7 +412,7 @@ mod tests { // utf8view, utf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View($INPUT)), ColumnarValue::from($LENGTH), ColumnarValue::from(ScalarValue::Utf8($REPLACE)) @@ -427,7 +425,7 @@ mod tests { // utf8view, largeutf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View($INPUT)), ColumnarValue::from($LENGTH), ColumnarValue::from(ScalarValue::LargeUtf8($REPLACE)) @@ -440,7 +438,7 @@ mod tests { // utf8view, utf8view test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View($INPUT)), ColumnarValue::from($LENGTH), ColumnarValue::from(ScalarValue::Utf8View($REPLACE)) diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 40915bc9efde8..3c5cde3789ea2 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -23,6 +23,7 @@ use datafusion_expr::ScalarUDF; pub mod character_length; pub mod find_in_set; +pub mod initcap; pub mod left; pub mod lpad; pub mod reverse; @@ -34,22 +35,19 @@ pub mod substrindex; pub mod translate; // create UDFs -make_udf_function!( - character_length::CharacterLengthFunc, - CHARACTER_LENGTH, - character_length -); -make_udf_function!(find_in_set::FindInSetFunc, FIND_IN_SET, find_in_set); -make_udf_function!(left::LeftFunc, LEFT, left); -make_udf_function!(lpad::LPadFunc, LPAD, lpad); 
-make_udf_function!(right::RightFunc, RIGHT, right); -make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); -make_udf_function!(rpad::RPadFunc, RPAD, rpad); -make_udf_function!(strpos::StrposFunc, STRPOS, strpos); -make_udf_function!(substr::SubstrFunc, SUBSTR, substr); -make_udf_function!(substr::SubstrFunc, SUBSTRING, substring); -make_udf_function!(substrindex::SubstrIndexFunc, SUBSTR_INDEX, substr_index); -make_udf_function!(translate::TranslateFunc, TRANSLATE, translate); +make_udf_function!(character_length::CharacterLengthFunc, character_length); +make_udf_function!(find_in_set::FindInSetFunc, find_in_set); +make_udf_function!(initcap::InitcapFunc, initcap); +make_udf_function!(left::LeftFunc, left); +make_udf_function!(lpad::LPadFunc, lpad); +make_udf_function!(right::RightFunc, right); +make_udf_function!(reverse::ReverseFunc, reverse); +make_udf_function!(rpad::RPadFunc, rpad); +make_udf_function!(strpos::StrposFunc, strpos); +make_udf_function!(substr::SubstrFunc, substr); +make_udf_function!(substr::SubstrFunc, substring); +make_udf_function!(substrindex::SubstrIndexFunc, substr_index); +make_udf_function!(translate::TranslateFunc, translate); pub mod expr_fn { use datafusion_expr::Expr; @@ -98,9 +96,13 @@ pub mod expr_fn { left, "returns the first `n` characters in the `string`", string n + ),( + initcap, + "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase", + string ),( find_in_set, - "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings", + "Returns a value in the range of 1 to N if the string `str` is in the string list `strlist` consisting of N substrings", string strlist )); @@ -130,6 +132,7 @@ pub fn functions() -> Vec> { vec![ character_length(), find_in_set(), + initcap(), left(), lpad(), reverse(), diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index 0190705966fb8..0502b01c3578e 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -16,21 +16,34 @@ // under the License. 
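// Hedged sketch of what the simplified two-argument `make_udf_function!(Type, name)`
// calls above amount to: a lazily initialized singleton `ScalarUDF` plus an
// accessor function. This is a hand-written approximation, not the macro's
// actual expansion, and the accessor name `initcap_udf` is illustrative only.
use std::sync::{Arc, OnceLock};

use datafusion_expr::ScalarUDF;

fn initcap_udf() -> Arc<ScalarUDF> {
    static INSTANCE: OnceLock<Arc<ScalarUDF>> = OnceLock::new();
    Arc::clone(INSTANCE.get_or_init(|| {
        Arc::new(ScalarUDF::new_from_impl(
            crate::unicode::initcap::InitcapFunc::new(),
        ))
    }))
}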
use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ - Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, - OffsetSizeTrait, + Array, ArrayRef, AsArray, GenericStringBuilder, OffsetSizeTrait, StringArrayType, }; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; use DataType::{LargeUtf8, Utf8, Utf8View}; +#[user_doc( + doc_section(label = "String Functions"), + description = "Reverses the character order of a string.", + syntax_example = "reverse(str)", + sql_example = r#"```sql +> select reverse('datafusion'); ++-----------------------------+ +| reverse(Utf8("datafusion")) | ++-----------------------------+ +| noisufatad | ++-----------------------------+ +```"#, + standard_argument(name = "str", prefix = "String") +)] #[derive(Debug)] pub struct ReverseFunc { signature: Signature, @@ -72,7 +85,11 @@ impl ScalarUDFImpl for ReverseFunc { utf8_to_str_type(&arg_types[0], "reverse") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { Utf8 | Utf8View => make_scalar_function(reverse::, vec![])(args), LargeUtf8 => make_scalar_function(reverse::, vec![])(args), @@ -83,36 +100,11 @@ impl ScalarUDFImpl for ReverseFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_reverse_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_reverse_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Reverses the character order of a string.") - .with_syntax_example("reverse(str)") - .with_sql_example( - r#"```sql -> select reverse('datafusion'); -+-----------------------------+ -| reverse(Utf8("datafusion")) | -+-----------------------------+ -| noisufatad | -+-----------------------------+ -```"#, - ) - .with_standard_argument("str", "String") - .build() - .unwrap() - }) -} - -/// Reverses the order of the characters in the string. -/// reverse('abcde') = 'edcba' +/// Reverses the order of the characters in the string `reverse('abcde') = 'edcba'`. 
/// The implementation uses UTF-8 code points as characters pub fn reverse(args: &[ArrayRef]) -> Result { if args[0].data_type() == &Utf8View { @@ -122,14 +114,23 @@ pub fn reverse(args: &[ArrayRef]) -> Result { } } -fn reverse_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( +fn reverse_impl<'a, T: OffsetSizeTrait, V: StringArrayType<'a>>( string_array: V, ) -> Result { - let result = ArrayIter::new(string_array) - .map(|string| string.map(|string: &str| string.chars().rev().collect::())) - .collect::>(); + let mut builder = GenericStringBuilder::::with_capacity(string_array.len(), 1024); + + let mut reversed = String::new(); + for string in string_array.iter() { + if let Some(s) = string { + reversed.extend(s.chars().rev()); + builder.append_value(&reversed); + reversed.clear(); + } else { + builder.append_null(); + } + } - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] @@ -147,7 +148,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( ReverseFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8($INPUT))], + vec![ColumnarValue::from(ScalarValue::Utf8($INPUT))], $EXPECTED, &str, Utf8, @@ -156,7 +157,7 @@ mod tests { test_function!( ReverseFunc::new(), - &[ColumnarValue::from(ScalarValue::LargeUtf8($INPUT))], + vec![ColumnarValue::from(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, &str, LargeUtf8, @@ -165,7 +166,7 @@ mod tests { test_function!( ReverseFunc::new(), - &[ColumnarValue::from(ScalarValue::Utf8View($INPUT))], + vec![ColumnarValue::from(ScalarValue::Utf8View($INPUT))], $EXPECTED, &str, Utf8, diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index 1ec08cb87eee9..b1da3cd512f94 100644 --- a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::cmp::{max, Ordering}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array, @@ -31,12 +31,28 @@ use datafusion_common::cast::{ }; use datafusion_common::exec_err; use datafusion_common::Result; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns a specified number of characters from the right side of a string.", + syntax_example = "right(str, n)", + sql_example = r#"```sql +> select right('datafusion', 6); ++------------------------------------+ +| right(Utf8("datafusion"),Int64(6)) | ++------------------------------------+ +| fusion | ++------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument(name = "n", description = "Number of characters to return."), + related_udf(name = "left") +)] #[derive(Debug)] pub struct RightFunc { signature: Signature, @@ -81,7 +97,11 @@ impl ScalarUDFImpl for RightFunc { utf8_to_str_type(&arg_types[0], "right") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match args[0].data_type() { DataType::Utf8 | DataType::Utf8View => { make_scalar_function(right::, vec![])(args) @@ -95,34 +115,10 @@ impl ScalarUDFImpl for RightFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_right_doc()) + self.doc() } } 
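// Hedged, self-contained sketch of the builder-plus-scratch-buffer pattern the
// new `reverse_impl` above uses: one reusable `String` collects the reversed
// characters for each row, so no per-row `String` allocation is needed. The
// function name `reverse_utf8` is illustrative only.
use std::sync::Arc;

use arrow::array::{ArrayRef, GenericStringBuilder, StringArray};

fn reverse_utf8(input: &StringArray) -> ArrayRef {
    let mut builder = GenericStringBuilder::<i32>::with_capacity(input.len(), 1024);
    let mut scratch = String::new();
    for value in input.iter() {
        match value {
            Some(s) => {
                scratch.extend(s.chars().rev());
                builder.append_value(&scratch);
                scratch.clear();
            }
            None => builder.append_null(),
        }
    }
    Arc::new(builder.finish())
}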
-static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_right_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Returns a specified number of characters from the right side of a string.") - .with_syntax_example("right(str, n)") - .with_sql_example(r#"```sql -> select right('datafusion', 6); -+------------------------------------+ -| right(Utf8("datafusion"),Int64(6)) | -+------------------------------------+ -| fusion | -+------------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_argument("n", "Number of characters to return") - .with_related_udf("left") - .build() - .unwrap() - }) -} - /// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. /// right('abcde', 2) = 'de' /// The implementation uses UTF-8 code points as characters @@ -186,7 +182,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abcde")), ColumnarValue::from(ScalarValue::from(2i64)), ], @@ -197,7 +193,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abcde")), ColumnarValue::from(ScalarValue::from(200i64)), ], @@ -208,7 +204,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abcde")), ColumnarValue::from(ScalarValue::from(-2i64)), ], @@ -219,7 +215,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abcde")), ColumnarValue::from(ScalarValue::from(-200i64)), ], @@ -230,7 +226,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abcde")), ColumnarValue::from(ScalarValue::from(0i64)), ], @@ -241,7 +237,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(None)), ColumnarValue::from(ScalarValue::from(2i64)), ], @@ -252,7 +248,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abcde")), ColumnarValue::from(ScalarValue::Int64(None)), ], @@ -263,7 +259,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("joséésoj")), ColumnarValue::from(ScalarValue::from(5i64)), ], @@ -274,7 +270,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("joséésoj")), ColumnarValue::from(ScalarValue::from(-3i64)), ], diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index 9d1d6c989eab4..b167ab9307161 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -15,27 +15,46 @@ // specific language governing permissions and limitations // under the License. 
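// Hedged sketch of the negative-count rule described above for `left`/`right`:
// a negative n keeps everything except the last (for left) or first (for
// right) |n| characters, counting UTF-8 code points. The helpers `left_chars`
// and `right_chars` are illustrative names, not the kernels in these files.
fn left_chars(s: &str, n: i64) -> String {
    let len = s.chars().count() as i64;
    let keep = if n >= 0 { n.min(len) } else { len.saturating_add(n).max(0) };
    s.chars().take(keep as usize).collect()
}

fn right_chars(s: &str, n: i64) -> String {
    let len = s.chars().count() as i64;
    let keep = if n >= 0 { n.min(len) } else { len.saturating_add(n).max(0) };
    s.chars().skip((len - keep) as usize).collect()
}

// left_chars("abcde", 2) == "ab",  left_chars("abcde", -2) == "abc"
// right_chars("abcde", 2) == "de", right_chars("abcde", -2) == "cde"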
-use crate::string::common::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, - OffsetSizeTrait, StringViewArray, + OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; use datafusion_common::cast::as_int64_array; use datafusion_common::DataFusionError; use datafusion_common::{exec_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; use std::any::Any; use std::fmt::Write; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use unicode_segmentation::UnicodeSegmentation; use DataType::{LargeUtf8, Utf8, Utf8View}; +#[user_doc( + doc_section(label = "String Functions"), + description = "Pads the right side of a string with another string to a specified string length.", + syntax_example = "rpad(str, n[, padding_str])", + sql_example = r#"```sql +> select rpad('datafusion', 20, '_-'); ++-----------------------------------------------+ +| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) | ++-----------------------------------------------+ +| datafusion_-_-_-_-_- | ++-----------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument(name = "n", description = "String length to pad to."), + argument( + name = "padding_str", + description = "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._" + ), + related_udf(name = "lpad") +)] #[derive(Debug)] pub struct RPadFunc { signature: Signature, @@ -89,7 +108,11 @@ impl ScalarUDFImpl for RPadFunc { utf8_to_str_type(&arg_types[0], "rpad") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { match ( args.len(), args[0].data_type(), @@ -118,39 +141,10 @@ impl ScalarUDFImpl for RPadFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_rpad_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_rpad_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Pads the right side of a string with another string to a specified string length.") - .with_syntax_example("rpad(str, n[, padding_str])") - .with_sql_example(r#"```sql -> select rpad('datafusion', 20, '_-'); -+-----------------------------------------------+ -| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) | -+-----------------------------------------------+ -| datafusion_-_-_-_-_- | -+-----------------------------------------------+ -```"#) - .with_standard_argument( - "str", - "String", - ) - .with_argument("n", "String length to pad to.") - .with_argument("padding_str", - "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. 
_Default is a space._") - .with_related_udf("lpad") - .build() - .unwrap() - }) -} - pub fn rpad( args: &[ArrayRef], ) -> Result { @@ -316,7 +310,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("josé")), ColumnarValue::from(ScalarValue::from(5i64)), ], @@ -327,7 +321,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("hi")), ColumnarValue::from(ScalarValue::from(5i64)), ], @@ -338,7 +332,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("hi")), ColumnarValue::from(ScalarValue::from(0i64)), ], @@ -349,7 +343,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("hi")), ColumnarValue::from(ScalarValue::Int64(None)), ], @@ -360,7 +354,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(None)), ColumnarValue::from(ScalarValue::from(5i64)), ], @@ -371,7 +365,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("hi")), ColumnarValue::from(ScalarValue::from(5i64)), ColumnarValue::from(ScalarValue::from("xy")), @@ -383,7 +377,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("hi")), ColumnarValue::from(ScalarValue::from(21i64)), ColumnarValue::from(ScalarValue::from("abcdef")), @@ -395,7 +389,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("hi")), ColumnarValue::from(ScalarValue::from(5i64)), ColumnarValue::from(ScalarValue::from(" ")), @@ -407,7 +401,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("hi")), ColumnarValue::from(ScalarValue::from(5i64)), ColumnarValue::from(ScalarValue::from("")), @@ -419,7 +413,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(None)), ColumnarValue::from(ScalarValue::from(5i64)), ColumnarValue::from(ScalarValue::from("xy")), @@ -431,7 +425,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("hi")), ColumnarValue::from(ScalarValue::Int64(None)), ColumnarValue::from(ScalarValue::from("xy")), @@ -443,7 +437,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("hi")), ColumnarValue::from(ScalarValue::from(5i64)), ColumnarValue::from(ScalarValue::Utf8(None)), @@ -455,7 +449,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("josé")), ColumnarValue::from(ScalarValue::from(10i64)), ColumnarValue::from(ScalarValue::from("xy")), @@ -467,7 +461,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("josé")), ColumnarValue::from(ScalarValue::from(10i64)), ColumnarValue::from(ScalarValue::from("éñ")), diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 5c90d1923996a..64f6eae5111c9 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -16,18 +16,35 @@ // under the License. 
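// Hedged sketch of the rpad semantics documented and tested above: pad (or
// truncate) the input to n characters, appending the padding string
// cyclically; an empty padding string leaves the input unchanged. The real
// kernel works on Arrow arrays and uses unicode_segmentation; this standalone
// helper uses plain chars for brevity, and the name `rpad_str` is
// illustrative only.
fn rpad_str(s: &str, n: usize, padding: &str) -> String {
    let current: Vec<char> = s.chars().collect();
    if current.len() >= n {
        // Already long enough: truncate on the right.
        return current[..n].iter().collect();
    }
    let mut out = String::from(s);
    let mut pad = padding.chars().cycle();
    for _ in current.len()..n {
        match pad.next() {
            Some(c) => out.push(c),
            // Empty padding string: nothing to append.
            None => break,
        }
    }
    out
}

// rpad_str("hi", 5, "xy") == "hixyx", rpad_str("hi", 5, "") == "hi"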
use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; -use crate::string::common::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_int_type}; -use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}; +use arrow::array::{ + ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType, +}; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.", + syntax_example = "strpos(str, substr)", + alternative_syntax = "position(substr in origstr)", + sql_example = r#"```sql +> select strpos('datafusion', 'fus'); ++----------------------------------------+ +| strpos(Utf8("datafusion"),Utf8("fus")) | ++----------------------------------------+ +| 5 | ++----------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument(name = "substr", description = "Substring expression to search for.") +)] #[derive(Debug)] pub struct StrposFunc { signature: Signature, @@ -66,7 +83,11 @@ impl ScalarUDFImpl for StrposFunc { utf8_to_int_type(&arg_types[0], "strpos/instr/position") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(strpos, vec![])(args) } @@ -75,33 +96,10 @@ impl ScalarUDFImpl for StrposFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_strpos_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_strpos_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.") - .with_syntax_example("strpos(str, substr)") - .with_sql_example(r#"```sql -> select strpos('datafusion', 'fus'); -+----------------------------------------+ -| strpos(Utf8("datafusion"),Utf8("fus")) | -+----------------------------------------+ -| 5 | -+----------------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_argument("substr", "Substring expression to search for.") - .build() - .unwrap() - }) -} - fn strpos(args: &[ArrayRef]) -> Result { match (args[0].data_type(), args[1].data_type()) { (DataType::Utf8, DataType::Utf8) => { @@ -169,13 +167,13 @@ where // the sub vector in the main vector. This is faster than string.find() method. if ascii_only { // If the substring is empty, the result is 1. 
- if substring.as_bytes().is_empty() { + if substring.is_empty() { T::Native::from_usize(1) } else { T::Native::from_usize( string .as_bytes() - .windows(substring.as_bytes().len()) + .windows(substring.len()) .position(|w| w == substring.as_bytes()) .map(|x| x + 1) .unwrap_or(0), @@ -214,7 +212,7 @@ mod tests { ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => { test_function!( StrposFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::$t1(Some($lhs.to_owned()))), ColumnarValue::from(ScalarValue::$t2(Some($rhs.to_owned()))), ], diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index d4e03edebb096..33e4ea37c80a3 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -16,23 +16,46 @@ // under the License. use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; -use crate::string::common::{make_and_append_view, StringArrayType}; +use crate::strings::make_and_append_view; use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ - Array, ArrayIter, ArrayRef, AsArray, GenericStringArray, Int64Array, OffsetSizeTrait, - StringViewArray, + Array, ArrayIter, ArrayRef, AsArray, GenericStringBuilder, Int64Array, + OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, plan_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; - +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "String Functions"), + description = "Extracts a substring of a specified number of characters from a specific starting position in a string.", + syntax_example = "substr(str, start_pos[, length])", + alternative_syntax = "substring(str from start_pos for length)", + sql_example = r#"```sql +> select substr('datafusion', 5, 3); ++----------------------------------------------+ +| substr(Utf8("datafusion"),Int64(5),Int64(3)) | ++----------------------------------------------+ +| fus | ++----------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument( + name = "start_pos", + description = "Character position to start the substring at. The first character in the string has a position of 1." + ), + argument( + name = "length", + description = "Number of characters to extract. If not specified, returns the rest of the string after the start position." 
+ ) +)] #[derive(Debug)] pub struct SubstrFunc { signature: Signature, @@ -75,7 +98,11 @@ impl ScalarUDFImpl for SubstrFunc { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(substr, vec![])(args) } @@ -84,6 +111,13 @@ impl ScalarUDFImpl for SubstrFunc { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() < 2 || arg_types.len() > 3 { + return plan_err!( + "The {} function requires 2 or 3 arguments, but got {}.", + self.name(), + arg_types.len() + ); + } let first_data_type = match &arg_types[0] { DataType::Null => Ok(DataType::Utf8), DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(arg_types[0].clone()), @@ -143,34 +177,10 @@ impl ScalarUDFImpl for SubstrFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_substr_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_substr_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Extracts a substring of a specified number of characters from a specific starting position in a string.") - .with_syntax_example("substr(str, start_pos[, length])") - .with_sql_example(r#"```sql -> select substr('datafusion', 5, 3); -+----------------------------------------------+ -| substr(Utf8("datafusion"),Int64(5),Int64(3)) | -+----------------------------------------------+ -| fus | -+----------------------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_argument("start_pos", "Character position to start the substring at. The first character in the string has a position of 1.") - .with_argument("length", "Number of characters to extract. If not specified, returns the rest of the string after the start position.") - .build() - .unwrap() - }) -} - /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) 
/// substr('alphabet', 3) = 'phabet' /// substr('alphabet', 3, 2) = 'ph' @@ -437,10 +447,9 @@ where match args.len() { 1 => { let iter = ArrayIter::new(string_array); - - let result = iter - .zip(start_array.iter()) - .map(|(string, start)| match (string, start) { + let mut result_builder = GenericStringBuilder::::new(); + for (string, start) in iter.zip(start_array.iter()) { + match (string, start) { (Some(string), Some(start)) => { let (start, end) = get_true_start_end( string, @@ -449,47 +458,51 @@ where enable_ascii_fast_path, ); // start, end is byte-based let substr = &string[start..end]; - Some(substr.to_string()) + result_builder.append_value(substr); } - _ => None, - }) - .collect::>(); - Ok(Arc::new(result) as ArrayRef) + _ => { + result_builder.append_null(); + } + } + } + Ok(Arc::new(result_builder.finish()) as ArrayRef) } 2 => { let iter = ArrayIter::new(string_array); let count_array = count_array_opt.unwrap(); + let mut result_builder = GenericStringBuilder::::new(); - let result = iter - .zip(start_array.iter()) - .zip(count_array.iter()) - .map(|((string, start), count)| { - match (string, start, count) { - (Some(string), Some(start), Some(count)) => { - if count < 0 { - exec_err!( + for ((string, start), count) in + iter.zip(start_array.iter()).zip(count_array.iter()) + { + match (string, start, count) { + (Some(string), Some(start), Some(count)) => { + if count < 0 { + return exec_err!( "negative substring length not allowed: substr(, {start}, {count})" - ) - } else { - if start == i64::MIN { - return exec_err!("negative overflow when calculating skip value"); - } - let (start, end) = get_true_start_end( - string, - start, - Some(count as u64), - enable_ascii_fast_path, - ); // start, end is byte-based - let substr = &string[start..end]; - Ok(Some(substr.to_string())) + ); + } else { + if start == i64::MIN { + return exec_err!( + "negative overflow when calculating skip value" + ); } + let (start, end) = get_true_start_end( + string, + start, + Some(count as u64), + enable_ascii_fast_path, + ); // start, end is byte-based + let substr = &string[start..end]; + result_builder.append_value(substr); } - _ => Ok(None), } - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) + _ => { + result_builder.append_null(); + } + } + } + Ok(Arc::new(result_builder.finish()) as ArrayRef) } other => { exec_err!("substr was called with {other} arguments. 
It requires 2 or 3.") @@ -512,7 +525,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(None)), ColumnarValue::from(ScalarValue::from(1i64)), ], @@ -523,7 +536,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -536,7 +549,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "this és longer than 12B" )))), @@ -550,7 +563,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "this is longer than 12B" )))), @@ -563,7 +576,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "joséésoj" )))), @@ -576,7 +589,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -590,7 +603,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -604,7 +617,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(0i64)), ], @@ -615,7 +628,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("joséésoj")), ColumnarValue::from(ScalarValue::from(5i64)), ], @@ -626,7 +639,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("joséésoj")), ColumnarValue::from(ScalarValue::from(-5i64)), ], @@ -637,7 +650,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(1i64)), ], @@ -648,7 +661,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(2i64)), ], @@ -659,7 +672,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(3i64)), ], @@ -670,7 +683,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(-3i64)), ], @@ -681,7 +694,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(30i64)), ], @@ -692,7 +705,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::Int64(None)), ], @@ -703,7 +716,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(3i64)), ColumnarValue::from(ScalarValue::from(2i64)), @@ -715,7 +728,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(3i64)), ColumnarValue::from(ScalarValue::from(20i64)), @@ -727,7 +740,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(0i64)), 
ColumnarValue::from(ScalarValue::from(5i64)), @@ -740,7 +753,7 @@ mod tests { // starting from 5 (10 + -5) test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(-5i64)), ColumnarValue::from(ScalarValue::from(10i64)), @@ -753,7 +766,7 @@ mod tests { // starting from -1 (4 + -5) test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(-5i64)), ColumnarValue::from(ScalarValue::from(4i64)), @@ -766,7 +779,7 @@ mod tests { // starting from 0 (5 + -5) test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(-5i64)), ColumnarValue::from(ScalarValue::from(5i64)), @@ -778,7 +791,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::Int64(None)), ColumnarValue::from(ScalarValue::from(20i64)), @@ -790,7 +803,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(3i64)), ColumnarValue::from(ScalarValue::Int64(None)), @@ -802,7 +815,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("alphabet")), ColumnarValue::from(ScalarValue::from(1i64)), ColumnarValue::from(ScalarValue::from(-1i64)), @@ -814,7 +827,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("joséésoj")), ColumnarValue::from(ScalarValue::from(5i64)), ColumnarValue::from(ScalarValue::from(2i64)), @@ -840,7 +853,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("abc")), ColumnarValue::from(ScalarValue::from(-9223372036854775808i64)), ], @@ -851,7 +864,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("overflow")), ColumnarValue::from(ScalarValue::from(-9223372036854775808i64)), ColumnarValue::from(ScalarValue::from(1i64)), diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index c628367cf3551..217a300c1c913 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, @@ -26,12 +26,42 @@ use arrow::datatypes::{DataType, Int32Type, Int64Type}; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = r#"Returns the substring from str before count occurrences of the delimiter delim. +If count is positive, everything to the left of the final delimiter (counting from the left) is returned. 
+If count is negative, everything to the right of the final delimiter (counting from the right) is returned."#, + syntax_example = "substr_index(str, delim, count)", + sql_example = r#"```sql +> select substr_index('www.apache.org', '.', 1); ++---------------------------------------------------------+ +| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) | ++---------------------------------------------------------+ +| www | ++---------------------------------------------------------+ +> select substr_index('www.apache.org', '.', -1); ++----------------------------------------------------------+ +| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) | ++----------------------------------------------------------+ +| org | ++----------------------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument( + name = "delim", + description = "The string to find in str to split str." + ), + argument( + name = "count", + description = "The number of times to search for the delimiter. Can be either a positive or negative number." + ) +)] #[derive(Debug)] pub struct SubstrIndexFunc { signature: Signature, @@ -78,7 +108,11 @@ impl ScalarUDFImpl for SubstrIndexFunc { utf8_to_str_type(&arg_types[0], "substr_index") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(substr_index, vec![])(args) } @@ -87,42 +121,10 @@ impl ScalarUDFImpl for SubstrIndexFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_substr_index_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_substr_index_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description(r#"Returns the substring from str before count occurrences of the delimiter delim. -If count is positive, everything to the left of the final delimiter (counting from the left) is returned. -If count is negative, everything to the right of the final delimiter (counting from the right) is returned."#) - .with_syntax_example("substr_index(str, delim, count)") - .with_sql_example(r#"```sql -> select substr_index('www.apache.org', '.', 1); -+---------------------------------------------------------+ -| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) | -+---------------------------------------------------------+ -| www | -+---------------------------------------------------------+ -> select substr_index('www.apache.org', '.', -1); -+----------------------------------------------------------+ -| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) | -+----------------------------------------------------------+ -| org | -+----------------------------------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_argument("delim", "The string to find in str to split str.") - .with_argument("count", "The number of times to search for the delimiter. Can be either a positive or negative number.") - .build() - .unwrap() - }) -} - /// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. 
/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www /// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache @@ -250,7 +252,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("www.apache.org")), ColumnarValue::from(ScalarValue::from(".")), ColumnarValue::from(ScalarValue::from(1i64)), @@ -262,7 +264,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("www.apache.org")), ColumnarValue::from(ScalarValue::from(".")), ColumnarValue::from(ScalarValue::from(2i64)), @@ -274,7 +276,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("www.apache.org")), ColumnarValue::from(ScalarValue::from(".")), ColumnarValue::from(ScalarValue::from(-2i64)), @@ -286,7 +288,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("www.apache.org")), ColumnarValue::from(ScalarValue::from(".")), ColumnarValue::from(ScalarValue::from(-1i64)), @@ -298,7 +300,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("www.apache.org")), ColumnarValue::from(ScalarValue::from(".")), ColumnarValue::from(ScalarValue::from(0i64)), @@ -310,7 +312,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("")), ColumnarValue::from(ScalarValue::from(".")), ColumnarValue::from(ScalarValue::from(1i64)), @@ -322,7 +324,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("www.apache.org")), ColumnarValue::from(ScalarValue::from("")), ColumnarValue::from(ScalarValue::from(1i64)), diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index ac5a6a2117b8d..b4267e1441d19 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -16,23 +16,42 @@ // under the License. use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait, }; use arrow::datatypes::DataType; -use hashbrown::HashMap; +use datafusion_common::HashMap; use unicode_segmentation::UnicodeSegmentation; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_macros::user_doc; +#[user_doc( + doc_section(label = "String Functions"), + description = "Translates characters in a string to specified translation characters.", + syntax_example = "translate(str, chars, translation)", + sql_example = r#"```sql +> select translate('twice', 'wic', 'her'); ++--------------------------------------------------+ +| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) | ++--------------------------------------------------+ +| there | ++--------------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + argument(name = "chars", description = "Characters to translate."), + argument( + name = "translation", + description = "Translation characters. Translation characters replace only characters at the same position in the **chars** string." 
+ ) +)] #[derive(Debug)] pub struct TranslateFunc { signature: Signature, @@ -76,39 +95,19 @@ impl ScalarUDFImpl for TranslateFunc { utf8_to_str_type(&arg_types[0], "translate") } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { make_scalar_function(invoke_translate, vec![])(args) } fn documentation(&self) -> Option<&Documentation> { - Some(get_translate_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_translate_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Translates characters in a string to specified translation characters.") - .with_syntax_example("translate(str, chars, translation)") - .with_sql_example(r#"```sql -> select translate('twice', 'wic', 'her'); -+--------------------------------------------------+ -| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) | -+--------------------------------------------------+ -| there | -+--------------------------------------------------+ -```"#) - .with_standard_argument("str", "String") - .with_argument("chars", "Characters to translate.") - .with_argument("translation", "Translation characters. Translation characters replace only characters at the same position in the **chars** string.") - .build() - .unwrap() - }) -} - fn invoke_translate(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Utf8View => { @@ -201,7 +200,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("12345")), ColumnarValue::from(ScalarValue::from("143")), ColumnarValue::from(ScalarValue::from("ax")) @@ -213,7 +212,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::Utf8(None)), ColumnarValue::from(ScalarValue::from("143")), ColumnarValue::from(ScalarValue::from("ax")) @@ -225,7 +224,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("12345")), ColumnarValue::from(ScalarValue::Utf8(None)), ColumnarValue::from(ScalarValue::from("ax")) @@ -237,7 +236,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("12345")), ColumnarValue::from(ScalarValue::from("143")), ColumnarValue::from(ScalarValue::Utf8(None)) @@ -249,7 +248,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("é2íñ5")), ColumnarValue::from(ScalarValue::from("éñí")), ColumnarValue::from(ScalarValue::from("óü")), @@ -262,7 +261,7 @@ mod tests { #[cfg(not(feature = "unicode_expressions"))] test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::from(ScalarValue::from("12345")), ColumnarValue::from(ScalarValue::from("143")), ColumnarValue::from(ScalarValue::from("ax")), diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index b1f5580ebe568..f913bdb8ddb7b 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -105,9 +105,9 @@ where Hint::AcceptsSingular => 1, Hint::Pad => inferred_length, }; - arg.clone().into_array(expansion_len) + arg.to_array(expansion_len) }) - .collect::>>()?; + .collect::>>()?; let result = (inner)(&args); if is_scalar { @@ -134,25 +134,27 @@ pub mod test { let func = $FUNC; let type_array = $ARGS.iter().map(|arg| arg.data_type().clone()).collect::>(); + 
let cardinality = $ARGS + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }) + .unwrap_or(1); let return_type = func.return_type(&type_array); match expected { Ok(expected) => { assert_eq!(return_type.is_ok(), true); - assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); + let return_type = return_type.unwrap(); + assert_eq!(return_type, $EXPECTED_DATA_TYPE); - let result = func.invoke($ARGS); + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); - let len = $ARGS - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - let inferred_length = len.unwrap_or(1); - let result = result.unwrap().clone().into_array(inferred_length).expect("Failed to convert to array"); + let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); + assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE); // value is correct match expected { @@ -169,7 +171,7 @@ pub mod test { } else { // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke($ARGS) { + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type.unwrap()}) { Ok(_) => assert!(false, "expected error"), Err(error) => { assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); diff --git a/datafusion/macros/Cargo.toml b/datafusion/macros/Cargo.toml new file mode 100644 index 0000000000000..8946640f20bb8 --- /dev/null +++ b/datafusion/macros/Cargo.toml @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
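The `test_function!` rewrite above routes every test call through the new `invoke_with_args` entry point, inferring `number_rows` from the widest array argument and defaulting to 1 when all arguments are scalars. A minimal sketch of the same call pattern outside the macro, assuming the `ScalarFunctionArgs` fields used in that hunk; the helper name and the example argument values are illustrative only:

```rust
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

// Scalar-only arguments imply a cardinality (`number_rows`) of 1, and the
// resolved return type is passed to the UDF alongside the arguments.
fn invoke_scalar_udf<F: ScalarUDFImpl>(func: &F) -> Result<ColumnarValue> {
    let args = vec![
        ColumnarValue::from(ScalarValue::from("alphabet")),
        ColumnarValue::from(ScalarValue::from(3i64)),
    ];
    let return_type = func.return_type(&[DataType::Utf8, DataType::Int64])?;
    func.invoke_with_args(ScalarFunctionArgs {
        args,
        number_rows: 1,
        return_type: &return_type,
    })
}
```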
+ +[package] +name = "datafusion-macros" +description = "Procedural macros for DataFusion query engine" +keywords = ["datafusion", "query", "sql"] +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_macros" +# lib.rs to be re-added in the future +path = "src/user_doc.rs" +proc-macro = true + +[dependencies] +datafusion-expr = { workspace = true } +quote = "1.0.37" +syn = { version = "2.0.79", features = ["full"] } diff --git a/datafusion/macros/LICENSE.txt b/datafusion/macros/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/macros/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/macros/NOTICE.txt b/datafusion/macros/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/macros/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/macros/src/user_doc.rs b/datafusion/macros/src/user_doc.rs new file mode 100644 index 0000000000000..6ca90ed376c37 --- /dev/null +++ b/datafusion/macros/src/user_doc.rs @@ -0,0 +1,278 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate proc_macro; +use datafusion_expr::scalar_doc_sections::doc_sections_const; +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, DeriveInput, LitStr}; + +/// This procedural macro is intended to parse a rust custom attribute and create user documentation +/// from it by constructing a `DocumentBuilder()` automatically. The `Documentation` can be +/// retrieved from the `documentation()` method +/// declared on `AggregateUDF`, `WindowUDFImpl`, `ScalarUDFImpl` traits. 
+/// For `doc_section`, this macro will try to find corresponding predefined `DocSection` by label field +/// Predefined `DocSection` can be found in datafusion/expr/src/udf.rs +/// Example: +/// ```ignore +/// #[user_doc( +/// doc_section(label = "Time and Date Functions"), +/// description = r"Converts a value to a date (`YYYY-MM-DD`).", +/// syntax_example = "to_date('2017-05-31', '%Y-%m-%d')", +/// sql_example = r#"```sql +/// > select to_date('2023-01-31'); +/// +-----------------------------+ +/// | to_date(Utf8(\"2023-01-31\")) | +/// +-----------------------------+ +/// | 2023-01-31 | +/// +-----------------------------+ +/// ```"#, +/// standard_argument(name = "expression", prefix = "String"), +/// argument( +/// name = "format_n", +/// description = r"Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order +/// they appear with the first successful one being returned. If none of the formats successfully parse the expression +/// an error will be returned." +/// ) +/// )] +/// #[derive(Debug)] +/// pub struct ToDateFunc { +/// signature: Signature, +/// } +/// ``` +/// will generate the following code +/// +/// ```ignore +/// pub struct ToDateFunc { +/// signature: Signature, +/// } +/// impl ToDateFunc { +/// fn doc(&self) -> Option<&datafusion_doc::Documentation> { +/// static DOCUMENTATION: std::sync::LazyLock< +/// datafusion_doc::Documentation, +/// > = std::sync::LazyLock::new(|| { +/// datafusion_doc::Documentation::builder( +/// datafusion_doc::DocSection { +/// include: true, +/// label: "Time and Date Functions", +/// description: None, +/// }, +/// r"Converts a value to a date (`YYYY-MM-DD`).".to_string(), +/// "to_date('2017-05-31', '%Y-%m-%d')".to_string(), +/// ) +/// .with_sql_example( +/// r#"```sql +/// > select to_date('2023-01-31'); +/// +-----------------------------+ +/// | to_date(Utf8(\"2023-01-31\")) | +/// +-----------------------------+ +/// | 2023-01-31 | +/// +-----------------------------+ +/// ```"#, +/// ) +/// .with_standard_argument("expression", "String".into()) +/// .with_argument( +/// "format_n", +/// r"Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order +/// they appear with the first successful one being returned. 
If none of the formats successfully parse the expression +/// an error will be returned.", +/// ) +/// .build() +/// }); +/// Some(&DOCUMENTATION) +/// } +/// } +/// ``` +#[proc_macro_attribute] +pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { + let mut doc_section_lbl: Option = None; + + let mut description: Option = None; + let mut syntax_example: Option = None; + let mut alt_syntax_example: Vec> = vec![]; + let mut sql_example: Option = None; + let mut standard_args: Vec<(Option, Option)> = vec![]; + let mut udf_args: Vec<(Option, Option)> = vec![]; + let mut related_udfs: Vec> = vec![]; + + let parser = syn::meta::parser(|meta| { + if meta.path.is_ident("doc_section") { + meta.parse_nested_meta(|meta| { + if meta.path.is_ident("label") { + doc_section_lbl = meta.value()?.parse()?; + return Ok(()); + } + Ok(()) + }) + } else if meta.path.is_ident("description") { + description = Some(meta.value()?.parse()?); + Ok(()) + } else if meta.path.is_ident("syntax_example") { + syntax_example = Some(meta.value()?.parse()?); + Ok(()) + } else if meta.path.is_ident("alternative_syntax") { + alt_syntax_example.push(Some(meta.value()?.parse()?)); + Ok(()) + } else if meta.path.is_ident("sql_example") { + sql_example = Some(meta.value()?.parse()?); + Ok(()) + } else if meta.path.is_ident("standard_argument") { + let mut standard_arg: (Option, Option) = (None, None); + let m = meta.parse_nested_meta(|meta| { + if meta.path.is_ident("name") { + standard_arg.0 = meta.value()?.parse()?; + return Ok(()); + } else if meta.path.is_ident("prefix") { + standard_arg.1 = meta.value()?.parse()?; + return Ok(()); + } + Ok(()) + }); + + standard_args.push(standard_arg.clone()); + + m + } else if meta.path.is_ident("argument") { + let mut arg: (Option, Option) = (None, None); + let m = meta.parse_nested_meta(|meta| { + if meta.path.is_ident("name") { + arg.0 = meta.value()?.parse()?; + return Ok(()); + } else if meta.path.is_ident("description") { + arg.1 = meta.value()?.parse()?; + return Ok(()); + } + Ok(()) + }); + + udf_args.push(arg.clone()); + + m + } else if meta.path.is_ident("related_udf") { + let mut arg: Option = None; + let m = meta.parse_nested_meta(|meta| { + if meta.path.is_ident("name") { + arg = meta.value()?.parse()?; + return Ok(()); + } + Ok(()) + }); + + related_udfs.push(arg.clone()); + + m + } else { + Err(meta.error(format!("Unsupported property: {:?}", meta.path.get_ident()))) + } + }); + + parse_macro_input!(args with parser); + + // Parse the input struct + let input = parse_macro_input!(input as DeriveInput); + let name = input.clone().ident; + + if doc_section_lbl.is_none() { + eprintln!("label for doc_section should exist"); + } + let label = doc_section_lbl.as_ref().unwrap().value(); + // Try to find a predefined const by label first. + // If there is no match but label exists, default value will be used for include and description + let doc_section_option = doc_sections_const().iter().find(|ds| ds.label == label); + let (doc_section_include, doc_section_label, doc_section_desc) = + match doc_section_option { + Some(section) => (section.include, section.label, section.description), + None => (true, label.as_str(), None), + }; + let doc_section_description = doc_section_desc + .map(|desc| quote! { Some(#desc)}) + .unwrap_or(quote! { None }); + + let sql_example = sql_example.map(|ex| { + quote! { + .with_sql_example(#ex) + } + }); + + let udf_args = udf_args + .iter() + .map(|(name, desc)| { + quote! 
{ + .with_argument(#name, #desc) + } + }) + .collect::>(); + + let standard_args = standard_args + .iter() + .map(|(name, desc)| { + let desc = if let Some(d) = desc { + quote! { #d.into() } + } else { + quote! { None } + }; + + quote! { + .with_standard_argument(#name, #desc) + } + }) + .collect::>(); + + let related_udfs = related_udfs + .iter() + .map(|name| { + quote! { + .with_related_udf(#name) + } + }) + .collect::>(); + + let alt_syntax_example = alt_syntax_example.iter().map(|syn| { + quote! { + .with_alternative_syntax(#syn) + } + }); + + let generated = quote! { + #input + + impl #name { + fn doc(&self) -> Option<&datafusion_doc::Documentation> { + static DOCUMENTATION: std::sync::LazyLock = + std::sync::LazyLock::new(|| { + datafusion_doc::Documentation::builder(datafusion_doc::DocSection { include: #doc_section_include, label: #doc_section_label, description: #doc_section_description }, + #description.to_string(), #syntax_example.to_string()) + #sql_example + #(#alt_syntax_example)* + #(#standard_args)* + #(#udf_args)* + #(#related_udfs)* + .build() + }); + Some(&DOCUMENTATION) + } + } + }; + + // Debug the generated code if needed + // if name == "ArrayAgg" { + // eprintln!("Generated code: {}", generated); + // } + + // Return the generated code + TokenStream::from(generated) +} diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 337a24ffae206..3f5ec9b0da030 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -36,27 +36,26 @@ name = "datafusion_optimizer" path = "src/lib.rs" [features] -default = ["regex_expressions"] -regex_expressions = ["datafusion-physical-expr/regex_expressions"] +recursive_protection = ["dep:recursive"] [dependencies] arrow = { workspace = true } -async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } -hashbrown = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } log = { workspace = true } -paste = "1.0.14" +recursive = { workspace = true, optional = true } +regex = { workspace = true } regex-syntax = "0.8.0" [dev-dependencies] -arrow-buffer = { workspace = true } +async-trait = { workspace = true } ctor = { workspace = true } datafusion-functions-aggregate = { workspace = true } +datafusion-functions-window = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-sql = { workspace = true } env_logger = { workspace = true } diff --git a/datafusion/optimizer/LICENSE.txt b/datafusion/optimizer/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/optimizer/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/optimizer/NOTICE.txt b/datafusion/optimizer/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/optimizer/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index b3b24724552a3..454afa24b628c 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -101,7 +101,7 @@ mod tests { use datafusion_expr::expr::Sort; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ - col, exists, expr, 
in_subquery, logical_plan::LogicalPlanBuilder, out_ref_col, + col, exists, in_subquery, logical_plan::LogicalPlanBuilder, out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, WindowFrameUnits, }; use datafusion_functions_aggregate::count::count_udaf; @@ -219,7 +219,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .window(vec![Expr::WindowFunction(expr::WindowFunction::new( + .window(vec![Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) diff --git a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs index a26ec4be5c851..9fbe54e1ccb92 100644 --- a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs @@ -26,7 +26,9 @@ use datafusion_expr::expr::PlannedReplaceSelectItem; use datafusion_expr::utils::{ expand_qualified_wildcard, expand_wildcard, find_base_plan, }; -use datafusion_expr::{Expr, LogicalPlan, Projection, SubqueryAlias}; +use datafusion_expr::{ + Distinct, DistinctOn, Expr, LogicalPlan, Projection, SubqueryAlias, +}; #[derive(Default, Debug)] pub struct ExpandWildcardRule {} @@ -59,12 +61,25 @@ fn expand_internal(plan: LogicalPlan) -> Result> { .map(LogicalPlan::Projection)?, )) } - // Teh schema of the plan should also be updated if the child plan is transformed. + // The schema of the plan should also be updated if the child plan is transformed. LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { Ok(Transformed::yes( SubqueryAlias::try_new(input, alias).map(LogicalPlan::SubqueryAlias)?, )) } + LogicalPlan::Distinct(Distinct::On(distinct_on)) => { + let projected_expr = + expand_exprlist(&distinct_on.input, distinct_on.select_expr)?; + validate_unique_names("Distinct", projected_expr.iter())?; + Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::On( + DistinctOn::try_new( + distinct_on.on_expr, + projected_expr, + distinct_on.sort_expr, + distinct_on.input, + )?, + )))) + } _ => Ok(Transformed::no(plan)), } } @@ -240,6 +255,18 @@ mod tests { assert_plan_eq(plan, expected) } + #[test] + fn test_expand_wildcard_in_distinct_on() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .distinct_on(vec![col("a")], vec![wildcard()], None)? 
+ .build()?; + let expected = "\ + DistinctOn: on_expr=[[test.a]], select_expr=[[test.a, test.b, test.c]], sort_expr=[[]] [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(plan, expected) + } + #[test] fn test_subquery_schema() -> Result<()> { let analyzer = Analyzer::with_rules(vec![Arc::new(ExpandWildcardRule::new())]); diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 342d85a915b4d..95781b395f3c3 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -23,8 +23,7 @@ use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, Result}; -use datafusion_expr::expr::WildcardOptions; -use datafusion_expr::{logical_plan::LogicalPlan, Expr, LogicalPlanBuilder}; +use datafusion_expr::{logical_plan::LogicalPlan, wildcard, Expr, LogicalPlanBuilder}; /// Analyzed rule that inlines TableScan that provide a [`LogicalPlan`] /// (DataFrame / ViewTable) @@ -93,10 +92,7 @@ fn generate_projection_expr( ))); } } else { - exprs.push(Expr::Wildcard { - qualifier: None, - options: WildcardOptions::default(), - }); + exprs.push(wildcard()); } Ok(exprs) } diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 4cd891664e7f5..9d0ac6b54cf45 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -24,17 +24,14 @@ use log::debug; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::expr::Exists; -use datafusion_expr::expr::InSubquery; +use datafusion_common::Result; use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::{Expr, LogicalPlan}; +use datafusion_expr::{InvariantLevel, LogicalPlan}; use crate::analyzer::count_wildcard_rule::CountWildcardRule; use crate::analyzer::expand_wildcard_rule::ExpandWildcardRule; use crate::analyzer::inline_table_scan::InlineTableScan; -use crate::analyzer::subquery::check_subquery_expr; +use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction; use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; @@ -44,17 +41,25 @@ pub mod count_wildcard_rule; pub mod expand_wildcard_rule; pub mod function_rewrite; pub mod inline_table_scan; -pub mod subquery; +pub mod resolve_grouping_function; pub mod type_coercion; +pub mod subquery { + #[deprecated( + since = "44.0.0", + note = "please use `datafusion_expr::check_subquery_expr` instead" + )] + pub use datafusion_expr::check_subquery_expr; +} + /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make /// the plan valid prior to the rest of the DataFusion optimization process. /// -/// This is different than an [`OptimizerRule`](crate::OptimizerRule) +/// `AnalyzerRule`s are different than an [`OptimizerRule`](crate::OptimizerRule)s /// which must preserve the semantics of the `LogicalPlan`, while computing /// results in a more optimal way. 
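The `AnalyzerRule` trait documented here has a small surface: a rule receives the `LogicalPlan` by value and returns a rewritten plan, plus a `name` that the analyzer uses for error context and logging. A pass-through sketch of that contract, modeled on the `ResolveGroupingFunction` rule added later in this patch; the struct, rule name, and import paths below are assumptions, not part of the patch:

```rust
use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_expr::LogicalPlan;
use datafusion_optimizer::analyzer::AnalyzerRule;

#[derive(Default, Debug)]
struct NoopAnalyzerRule;

impl AnalyzerRule for NoopAnalyzerRule {
    // A real rule would rewrite the plan (for example via `transform_up`);
    // this sketch returns it unchanged.
    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
        Ok(plan)
    }

    fn name(&self) -> &str {
        "noop_analyzer_rule"
    }
}
```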
/// -/// For example, an `AnalyzerRule` may resolve [`Expr`]s into more specific +/// For example, an `AnalyzerRule` may resolve [`Expr`](datafusion_expr::Expr)s into more specific /// forms such as a subquery reference, or do type coercion to ensure the types /// of operands are correct. /// @@ -70,10 +75,13 @@ pub trait AnalyzerRule: Debug { fn name(&self) -> &str; } -/// A rule-based Analyzer. +/// Rule-based Analyzer. /// -/// An `Analyzer` transforms a `LogicalPlan` -/// prior to the rest of the DataFusion optimization process. +/// Applies [`FunctionRewrite`]s and [`AnalyzerRule`]s to transform a +/// [`LogicalPlan`] in preparation for execution. +/// +/// For example, the `Analyzer` applies type coercion to ensure the types of +/// operands match the types required by functions. #[derive(Clone, Debug)] pub struct Analyzer { /// Expr --> Function writes to apply prior to analysis passes @@ -96,6 +104,7 @@ impl Analyzer { // Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule]. Arc::new(ExpandWildcardRule::new()), // [Expr::Wildcard] should be expanded before [TypeCoercion] + Arc::new(ResolveGroupingFunction::new()), Arc::new(TypeCoercion::new()), Arc::new(CountWildcardRule::new()), ]; @@ -134,6 +143,10 @@ impl Analyzer { where F: FnMut(&LogicalPlan, &dyn AnalyzerRule), { + // verify the logical plan required invariants at the start, before analyzer + plan.check_invariants(InvariantLevel::Always) + .map_err(|e| e.context("Invalid input plan passed to Analyzer"))?; + let start_time = Instant::now(); let mut new_plan = plan; @@ -155,39 +168,20 @@ impl Analyzer { // TODO add common rule executor for Analyzer and Optimizer for rule in rules { - new_plan = rule.analyze(new_plan, config).map_err(|e| { - DataFusionError::Context(rule.name().to_string(), Box::new(e)) - })?; + new_plan = rule + .analyze(new_plan, config) + .map_err(|e| e.context(rule.name()))?; log_plan(rule.name(), &new_plan); observer(&new_plan, rule.as_ref()); } - // for easier display in explain output - check_plan(&new_plan).map_err(|e| { - DataFusionError::Context("check_analyzed_plan".to_string(), Box::new(e)) - })?; + + // verify at the end, after the last LP analyzer pass, that the plan is executable. + new_plan + .check_invariants(InvariantLevel::Executable) + .map_err(|e| e.context("Invalid (non-executable) plan after Analyzer"))?; + log_plan("Final analyzed plan", &new_plan); debug!("Analyzer took {} ms", start_time.elapsed().as_millis()); Ok(new_plan) } } - -/// Do necessary check and fail the invalid plan -fn check_plan(plan: &LogicalPlan) -> Result<()> { - plan.apply_with_subqueries(|plan: &LogicalPlan| { - plan.apply_expressions(|expr| { - // recursively look for subqueries - expr.apply(|expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - check_subquery_expr(plan, &subquery.subquery, expr)?; - } - _ => {} - }; - Ok(TreeNodeRecursion::Continue) - }) - }) - }) - .map(|_| ()) -} diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs new file mode 100644 index 0000000000000..001df8e2ca0d9 --- /dev/null +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Analyzer rule that replaces calls to the `grouping` aggregate function +//! with values derived from the internal grouping id. + +use std::cmp::Ordering; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::analyzer::AnalyzerRule; + +use arrow::datatypes::DataType; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{ + internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue, +}; +use datafusion_expr::expr::{AggregateFunction, Alias}; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::utils::grouping_set_to_exprlist; +use datafusion_expr::{ + bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate, + Expr, Projection, +}; +use itertools::Itertools; + +/// Replaces the `grouping` aggregate function with a value derived from the internal grouping id +#[derive(Default, Debug)] +pub struct ResolveGroupingFunction; + +impl ResolveGroupingFunction { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for ResolveGroupingFunction { + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + plan.transform_up(analyze_internal).data() + } + + fn name(&self) -> &str { + "resolve_grouping_function" + } +} + +/// Create a map from grouping expr to index in the internal grouping id. +/// +/// For more details on how the grouping id bitmap works, see the documentation for +/// [`Aggregate::INTERNAL_GROUPING_ID`]. +fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result> { + Ok(grouping_set_to_exprlist(group_expr)? 
+ .into_iter() + .rev() + .enumerate() + .map(|(idx, v)| (v, idx)) + .collect::>()) +} + +fn replace_grouping_exprs( + input: Arc, + schema: DFSchemaRef, + group_expr: Vec, + aggr_expr: Vec, +) -> Result { + // Create HashMap from Expr to index in the grouping_id bitmap + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?; + let columns = schema.columns(); + let mut new_agg_expr = Vec::new(); + let mut projection_exprs = Vec::new(); + let grouping_id_len = if is_grouping_set { 1 } else { 0 }; + let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len; + projection_exprs.extend( + columns + .iter() + .take(group_expr_len) + .map(|column| Expr::Column(column.clone())), + ); + for (expr, column) in aggr_expr + .into_iter() + .zip(columns.into_iter().skip(group_expr_len + grouping_id_len)) + { + match expr { + Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => { + let grouping_expr = grouping_function_on_id( + function, + &group_expr_to_bitmap_index, + is_grouping_set, + )?; + projection_exprs.push(Expr::Alias(Alias::new( + grouping_expr, + column.relation, + column.name, + ))); + } + _ => { + projection_exprs.push(Expr::Column(column)); + new_agg_expr.push(expr); + } + } + } + // Recreate aggregate without grouping functions + let new_aggregate = + LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?); + // Create projection with grouping functions calculations + let projection = LogicalPlan::Projection(Projection::try_new( + projection_exprs, + new_aggregate.into(), + )?); + Ok(projection) +} + +fn analyze_internal(plan: LogicalPlan) -> Result> { + // rewrite any subqueries in the plan first + let transformed_plan = + plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?; + + let transformed_plan = transformed_plan.transform_data(|plan| match plan { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + .. + }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes( + replace_grouping_exprs(input, schema, group_expr, aggr_expr)?, + )), + _ => Ok(Transformed::no(plan)), + })?; + + Ok(transformed_plan) +} + +fn is_grouping_function(expr: &Expr) -> bool { + // TODO: Do something better than name here should grouping be a built + // in expression? + matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping") +} + +fn contains_grouping_function(exprs: &[Expr]) -> bool { + exprs.iter().any(is_grouping_function) +} + +/// Validate that the arguments to the grouping function are in the group by clause. 
+fn validate_args( + function: &AggregateFunction, + group_by_expr: &HashMap<&Expr, usize>, +) -> Result<()> { + let expr_not_in_group_by = function + .args + .iter() + .find(|expr| !group_by_expr.contains_key(expr)); + if let Some(expr) = expr_not_in_group_by { + plan_err!( + "Argument {} to grouping function is not in grouping columns {}", + expr, + group_by_expr.keys().map(|e| e.to_string()).join(", ") + ) + } else { + Ok(()) + } +} + +fn grouping_function_on_id( + function: &AggregateFunction, + group_by_expr: &HashMap<&Expr, usize>, + is_grouping_set: bool, +) -> Result { + validate_args(function, group_by_expr)?; + let args = &function.args; + + // Postgres allows grouping function for group by without grouping sets, the result is then + // always 0 + if !is_grouping_set { + return Ok(Expr::from(ScalarValue::from(0i32))); + } + + let group_by_expr_count = group_by_expr.len(); + let literal = |value: usize| { + if group_by_expr_count < 8 { + Expr::from(ScalarValue::from(value as u8)) + } else if group_by_expr_count < 16 { + Expr::from(ScalarValue::from(value as u16)) + } else if group_by_expr_count < 32 { + Expr::from(ScalarValue::from(value as u32)) + } else { + Expr::from(ScalarValue::from(value as u64)) + } + }; + + let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); + // The grouping call is exactly our internal grouping id + if args.len() == group_by_expr_count + && args + .iter() + .rev() + .enumerate() + .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) + { + return Ok(cast(grouping_id_column, DataType::Int32)); + } + + args.iter() + .rev() + .enumerate() + .map(|(arg_idx, expr)| { + group_by_expr.get(expr).map(|group_by_idx| { + let group_by_bit = + bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx)); + match group_by_idx.cmp(&arg_idx) { + Ordering::Less => { + bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx)) + } + Ordering::Greater => { + bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx)) + } + Ordering::Equal => group_by_bit, + } + }) + }) + .collect::>>() + .and_then(|bit_exprs| { + bit_exprs + .into_iter() + .reduce(bitwise_or) + .map(|expr| cast(expr, DataType::Int32)) + }) + .ok_or_else(|| { + internal_datafusion_err!("Grouping sets should contains at least one element") + }) +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 1d9aea25274fa..1f359da2d416d 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -51,8 +51,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, LogicalPlan, Operator, - Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits, + AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, Limit, LogicalPlan, + Operator, Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; /// Performs type coercion by determining the schema @@ -157,7 +158,7 @@ pub struct TypeCoercionRewriter<'a> { impl<'a> TypeCoercionRewriter<'a> { /// Create a new [`TypeCoercionRewriter`] with a provided schema /// representing both the inputs and output of the [`LogicalPlan`] node. 
- fn new(schema: &'a DFSchema) -> Self { + pub fn new(schema: &'a DFSchema) -> Self { Self { schema } } @@ -169,6 +170,7 @@ impl<'a> TypeCoercionRewriter<'a> { match plan { LogicalPlan::Join(join) => self.coerce_join(join), LogicalPlan::Union(union) => Self::coerce_union(union), + LogicalPlan::Limit(limit) => Self::coerce_limit(limit), _ => Ok(plan), } } @@ -230,6 +232,37 @@ impl<'a> TypeCoercionRewriter<'a> { })) } + /// Coerce the fetch and skip expression to Int64 type. + fn coerce_limit(limit: Limit) -> Result { + fn coerce_limit_expr( + expr: Expr, + schema: &DFSchema, + expr_name: &str, + ) -> Result { + let dt = expr.get_type(schema)?; + if dt.is_integer() || dt.is_null() { + expr.cast_to(&DataType::Int64, schema) + } else { + plan_err!("Expected {expr_name} to be an integer or null, but got {dt:?}") + } + } + + let empty_schema = DFSchema::empty(); + let new_fetch = limit + .fetch + .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT")) + .transpose()?; + let new_skip = limit + .skip + .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET")) + .transpose()?; + Ok(LogicalPlan::Limit(Limit { + input: limit.input, + fetch: new_fetch.map(Box::new), + skip: new_skip.map(Box::new), + })) + } + fn coerce_join_filter(&self, expr: Expr) -> Result { let expr_type = expr.get_type(self.schema)?; match expr_type { @@ -257,7 +290,7 @@ impl<'a> TypeCoercionRewriter<'a> { } } -impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { +impl TreeNodeRewriter for TypeCoercionRewriter<'_> { type Node = Expr; fn f_up(&mut self, expr: Expr) -> Result> { @@ -393,7 +426,7 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { )) })?; let high_type = high.get_type(self.schema)?; - let high_coerced_type = comparison_coercion(&expr_type, &low_type) + let high_coerced_type = comparison_coercion(&expr_type, &high_type) .ok_or_else(|| { DataFusionError::Internal(format!( "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" @@ -655,6 +688,22 @@ fn coerce_frame_bound( } } +fn extract_window_frame_target_type(col_type: &DataType) -> Result { + if col_type.is_numeric() + || is_utf8_or_large_utf8(col_type) + || matches!(col_type, DataType::Null) + || matches!(col_type, DataType::Boolean) + { + Ok(col_type.clone()) + } else if is_datetime(col_type) { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } else if let DataType::Dictionary(_, value_type) = col_type { + extract_window_frame_target_type(value_type) + } else { + return internal_err!("Cannot run range queries on datatype: {col_type:?}"); + } +} + // Coerces the given `window_frame` to use appropriate natural types. // For example, ROWS and GROUPS frames use `UInt64` during calculations. 
fn coerce_window_frame( @@ -663,33 +712,23 @@ fn coerce_window_frame( expressions: &[Sort], ) -> Result { let mut window_frame = window_frame; - let current_types = expressions - .iter() - .map(|s| s.expr.get_type(schema)) - .collect::>>()?; let target_type = match window_frame.units { WindowFrameUnits::Range => { - if let Some(col_type) = current_types.first() { - if col_type.is_numeric() - || is_utf8_or_large_utf8(col_type) - || matches!(col_type, DataType::Null) - { - col_type - } else if is_datetime(col_type) { - &DataType::Interval(IntervalUnit::MonthDayNano) - } else { - return internal_err!( - "Cannot run range queries on datatype: {col_type:?}" - ); - } + let current_types = expressions + .first() + .map(|s| s.expr.get_type(schema)) + .transpose()?; + if let Some(col_type) = current_types { + extract_window_frame_target_type(&col_type)? } else { return internal_err!("ORDER BY column cannot be empty"); } } - WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64, + WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64, }; - window_frame.start_bound = coerce_frame_bound(target_type, window_frame.start_bound)?; - window_frame.end_bound = coerce_frame_bound(target_type, window_frame.end_bound)?; + window_frame.start_bound = + coerce_frame_bound(&target_type, window_frame.start_bound)?; + window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?; Ok(window_frame) } @@ -904,7 +943,7 @@ pub fn coerce_union_schema(inputs: &[Arc]) -> Result { ); } - // coerce data type and nullablity for each field + // coerce data type and nullability for each field for (union_datatype, union_nullable, union_field_map, plan_field) in izip!( union_datatypes.iter_mut(), union_nullabilities.iter_mut(), @@ -956,16 +995,19 @@ fn project_with_column_index( .enumerate() .map(|(i, e)| match e { Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => { - e.unalias().alias(schema.field(i).name()) + Ok(e.unalias().alias(schema.field(i).name())) } Expr::Column(Column { relation: _, ref name, - }) if name != schema.field(i).name() => e.alias(schema.field(i).name()), - Expr::Alias { .. } | Expr::Column { .. } => e, - _ => e.alias(schema.field(i).name()), + }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())), + Expr::Alias { .. } | Expr::Column { .. } => Ok(e), + Expr::Wildcard { .. 
} => { + plan_err!("Wildcard should be expanded before type coercion") + } + _ => Ok(e.alias(schema.field(i).name())), }) - .collect::>(); + .collect::>>()?; Projection::try_new_with_schema(alias_expr, input, schema) .map(LogicalPlan::Projection) @@ -979,6 +1021,10 @@ mod test { use arrow::datatypes::DataType::Utf8; use arrow::datatypes::{DataType, Field, TimeUnit}; + use crate::analyzer::type_coercion::{ + coerce_case_expression, TypeCoercion, TypeCoercionRewriter, + }; + use crate::test::{assert_analyzed_plan_eq, assert_analyzed_plan_with_config_eq}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue}; @@ -993,11 +1039,6 @@ mod test { }; use datafusion_functions_aggregate::average::AvgAccumulator; - use crate::analyzer::type_coercion::{ - coerce_case_expression, TypeCoercion, TypeCoercionRewriter, - }; - use crate::test::{assert_analyzed_plan_eq, assert_analyzed_plan_with_config_eq}; - fn empty() -> Arc { Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -1209,10 +1250,14 @@ mod test { } fn return_type(&self, _args: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(Utf8) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { Ok(ColumnarValue::from(ScalarValue::from("a"))) } } @@ -1314,7 +1359,7 @@ mod test { let err = Projection::try_new(vec![udaf], empty).err().unwrap(); assert!( - err.strip_backtrace().starts_with("Error during planning: Error during planning: Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed") + err.strip_backtrace().starts_with("Error during planning: Failed to coerce arguments to satisfy a call to MY_AVG function: coercion from [Utf8] to the signature Uniform(1, [Float64]) failed") ); Ok(()) } @@ -1364,7 +1409,7 @@ mod test { .err() .unwrap() .strip_backtrace(); - assert!(err.starts_with("Error during planning: Error during planning: Coercion from [Utf8] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed.")); + assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to avg function: coercion from [Utf8] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed.")); Ok(()) } @@ -1412,10 +1457,10 @@ mod test { cast(lit("2002-05-08"), DataType::Date32) + lit(ScalarValue::new_interval_ym(0, 1)), ); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); let expected = - "Filter: a BETWEEN Utf8(\"2002-05-08\") AND CAST(CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AS Utf8)\ + "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) AND CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\")\ \n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } @@ -1428,7 +1473,7 @@ mod test { + lit(ScalarValue::new_interval_ym(0, 1)), lit("2002-12-08"), ); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); // TODO: we should cast col(a). 
let expected = @@ -1437,6 +1482,17 @@ mod test { assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } + #[test] + fn between_null() -> Result<()> { + let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64)); + let empty = empty(); + let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); + let expected = + "Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)\ + \n EmptyRelation"; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + } + #[test] fn is_bool_for_type_coercion() -> Result<()> { // is true @@ -1483,7 +1539,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE Utf8(\"abc\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1491,7 +1547,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1511,7 +1567,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1519,7 +1575,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE CAST(NULL AS Utf8)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1547,7 +1603,7 @@ mod test { let expected = "Projection: a IS UNKNOWN\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); let err = ret.unwrap_err().to_string(); @@ -1565,7 +1621,7 @@ mod test { #[test] fn concat_for_type_coercion() -> Result<()> { - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)]; // concat-type signature @@ -1700,7 +1756,7 @@ mod test { true, ), Field::new("binary", DataType::Binary, true), - Field::new("string", DataType::Utf8, true), + Field::new("string", Utf8, true), Field::new("decimal", DataType::Decimal128(10, 10), true), ] .into(), @@ 
-1717,7 +1773,7 @@ mod test { else_expr: None, }; let case_when_common_type = DataType::Boolean; - let then_else_common_type = DataType::Utf8; + let then_else_common_type = Utf8; let expected = cast_helper( case.clone(), &case_when_common_type, @@ -1736,8 +1792,8 @@ mod test { ], else_expr: Some(Box::new(col("string"))), }; - let case_when_common_type = DataType::Utf8; - let then_else_common_type = DataType::Utf8; + let case_when_common_type = Utf8; + let then_else_common_type = Utf8; let expected = cast_helper( case.clone(), &case_when_common_type, @@ -1804,7 +1860,7 @@ mod test { #[test] fn tes_case_when_list() -> Result<()> { - let inner_field = Arc::new(Field::new("item", DataType::Int64, true)); + let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true)); let schema = Arc::new(DFSchema::from_unqualified_fields( vec![ Field::new( @@ -1826,48 +1882,48 @@ mod test { test_case_expression!( Some("list"), vec![(Box::new(col("large_list")), Box::new(lit("1")))], - DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))), + Utf8, schema ); test_case_expression!( Some("large_list"), vec![(Box::new(col("list")), Box::new(lit("1")))], - DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))), + Utf8, schema ); test_case_expression!( Some("list"), vec![(Box::new(col("fixed_list")), Box::new(lit("1")))], - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))), + Utf8, schema ); test_case_expression!( Some("fixed_list"), vec![(Box::new(col("list")), Box::new(lit("1")))], - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))), + Utf8, schema ); test_case_expression!( Some("fixed_list"), vec![(Box::new(col("large_list")), Box::new(lit("1")))], - DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))), + Utf8, schema ); test_case_expression!( Some("large_list"), vec![(Box::new(col("fixed_list")), Box::new(lit("1")))], - DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))), + Utf8, schema ); Ok(()) @@ -1875,7 +1931,7 @@ mod test { #[test] fn test_then_else_list() -> Result<()> { - let inner_field = Arc::new(Field::new("item", DataType::Int64, true)); + let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true)); let schema = Arc::new(DFSchema::from_unqualified_fields( vec![ Field::new("boolean", DataType::Boolean, true), @@ -1903,7 +1959,7 @@ mod test { (Box::new(col("boolean")), Box::new(col("list"))) ], DataType::Boolean, - DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))), schema ); @@ -1914,7 +1970,7 @@ mod test { (Box::new(col("boolean")), Box::new(col("large_list"))) ], DataType::Boolean, - DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))), schema ); @@ -1926,7 +1982,7 @@ mod test { (Box::new(col("boolean")), 
Box::new(col("list"))) ], DataType::Boolean, - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))), schema ); @@ -1937,7 +1993,7 @@ mod test { (Box::new(col("boolean")), Box::new(col("fixed_list"))) ], DataType::Boolean, - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))), schema ); @@ -1949,7 +2005,7 @@ mod test { (Box::new(col("boolean")), Box::new(col("large_list"))) ], DataType::Boolean, - DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))), schema ); @@ -1960,7 +2016,7 @@ mod test { (Box::new(col("boolean")), Box::new(col("fixed_list"))) ], DataType::Boolean, - DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))), schema ); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index c13cb3a8e9734..4b9a83fd3e4c0 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -17,8 +17,8 @@ //! [`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions -use std::collections::{BTreeSet, HashMap}; -use std::hash::{BuildHasher, Hash, Hasher, RandomState}; +use std::collections::BTreeSet; +use std::fmt::Debug; use std::sync::Arc; use crate::{OptimizerConfig, OptimizerRule}; @@ -26,93 +26,18 @@ use crate::{OptimizerConfig, OptimizerRule}; use crate::optimizer::ApplyOrder; use crate::utils::NamePreserver; use datafusion_common::alias::AliasGenerator; -use datafusion_common::hash_utils::combine_hashes; -use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, - TreeNodeVisitor, -}; + +use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; -use datafusion_expr::tree_node::replace_sort_expressions; -use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator}; -use indexmap::IndexMap; +use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator, SortExpr}; const CSE_PREFIX: &str = "__common_expr"; -/// Identifier that represents a subexpression tree. -/// -/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and -/// "have no collision (as low as possible)" -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -struct Identifier<'n> { - // Hash of `expr` built up incrementally during the first, visiting traversal, but its - // value is not necessarily equal to `expr.hash()`. 
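// A small std-only sketch of the incremental hashing described just above: the
// removed `Identifier` accumulated a subtree hash bottom-up by folding each
// node's own hash together with the already-combined hashes of its children,
// instead of re-hashing the whole subtree at every level. `combine` here is a
// stand-in mixer, not the actual `datafusion_common::hash_utils::combine_hashes`.
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

fn node_hash<T: Hash + ?Sized>(node: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    node.hash(&mut hasher);
    hasher.finish()
}

// Fold a child's accumulated hash into the parent's hash (illustrative mixer).
fn combine(acc: u64, child: u64) -> u64 {
    acc ^ child.wrapping_mul(0x9E37_79B9_7F4A_7C15)
}

// The subtree identifier: the node's own hash combined with its children's
// already-accumulated subtree hashes.
fn subtree_hash(node_label: &str, child_hashes: &[u64]) -> u64 {
    child_hashes
        .iter()
        .fold(node_hash(node_label), |acc, child| combine(acc, *child))
}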
- hash: u64, - expr: &'n Expr, -} - -impl<'n> Identifier<'n> { - fn new(expr: &'n Expr, random_state: &RandomState) -> Self { - let mut hasher = random_state.build_hasher(); - expr.hash_node(&mut hasher); - let hash = hasher.finish(); - Self { hash, expr } - } - - fn combine(mut self, other: Option) -> Self { - other.map_or(self, |other_id| { - self.hash = combine_hashes(self.hash, other_id.hash); - self - }) - } -} - -impl Hash for Identifier<'_> { - fn hash(&self, state: &mut H) { - state.write_u64(self.hash); - } -} - -/// A cache that contains the postorder index and the identifier of expression tree nodes -/// by the preorder index of the nodes. -/// -/// This cache is filled by `ExprIdentifierVisitor` during the first traversal and is used -/// by `CommonSubexprRewriter` during the second traversal. -/// -/// The purpose of this cache is to quickly find the identifier of a node during the -/// second traversal. -/// -/// Elements in this array are added during `f_down` so the indexes represent the preorder -/// index of expression nodes and thus element 0 belongs to the root of the expression -/// tree. -/// The elements of the array are tuples that contain: -/// - Postorder index that belongs to the preorder index. Assigned during `f_up`, start -/// from 0. -/// - Identifier of the expression. If empty (`""`), expr should not be considered for -/// CSE. -/// -/// # Example -/// An expression like `(a + b)` would have the following `IdArray`: -/// ```text -/// [ -/// (2, "a + b"), -/// (1, "a"), -/// (0, "b") -/// ] -/// ``` -type IdArray<'n> = Vec<(usize, Option>)>; - -/// A map that contains the number of normal and conditional occurrences of expressions by -/// their identifiers. -type ExprStats<'n> = HashMap, (usize, usize)>; - -/// A map that contains the common expressions and their alias extracted during the -/// second, rewriting traversal. -type CommonExprs<'n> = IndexMap, (Expr, String)>; - /// Performs Common Sub-expression Elimination optimization. /// /// This optimization improves query performance by computing expressions that @@ -140,168 +65,11 @@ type CommonExprs<'n> = IndexMap, (Expr, String)>; /// ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once /// ``` #[derive(Debug)] -pub struct CommonSubexprEliminate { - random_state: RandomState, -} - -/// The result of potentially rewriting a list of expressions to eliminate common -/// subexpressions. -#[derive(Debug)] -enum FoundCommonExprs { - /// No common expressions were found - No { original_exprs_list: Vec> }, - /// Common expressions were found - Yes { - /// extracted common expressions - common_exprs: Vec<(Expr, String)>, - /// new expressions with common subexpressions replaced - new_exprs_list: Vec>, - /// original expressions - original_exprs_list: Vec>, - }, -} +pub struct CommonSubexprEliminate {} impl CommonSubexprEliminate { pub fn new() -> Self { - Self { - random_state: RandomState::new(), - } - } - - /// Returns the identifier list for each element in `exprs` and a flag to indicate if - /// rewrite phase of CSE make sense. - /// - /// Returns and array with 1 element for each input expr in `exprs` - /// - /// Each element is itself the result of [`CommonSubexprEliminate::expr_to_identifier`] for that expr - /// (e.g. 
the identifiers for each node in the tree) - fn to_arrays<'n>( - &self, - exprs: &'n [Expr], - expr_stats: &mut ExprStats<'n>, - expr_mask: ExprMask, - ) -> Result<(bool, Vec>)> { - let mut found_common = false; - exprs - .iter() - .map(|e| { - let mut id_array = vec![]; - self.expr_to_identifier(e, expr_stats, &mut id_array, expr_mask) - .map(|fc| { - found_common |= fc; - - id_array - }) - }) - .collect::>>() - .map(|id_arrays| (found_common, id_arrays)) - } - - /// Add an identifier to `id_array` for every subexpression in this tree. - fn expr_to_identifier<'n>( - &self, - expr: &'n Expr, - expr_stats: &mut ExprStats<'n>, - id_array: &mut IdArray<'n>, - expr_mask: ExprMask, - ) -> Result { - let mut visitor = ExprIdentifierVisitor { - expr_stats, - id_array, - visit_stack: vec![], - down_index: 0, - up_index: 0, - expr_mask, - random_state: &self.random_state, - found_common: false, - conditional: false, - }; - expr.visit(&mut visitor)?; - - Ok(visitor.found_common) - } - - /// Rewrites `exprs_list` with common sub-expressions replaced with a new - /// column. - /// - /// `common_exprs` is updated with any sub expressions that were replaced. - /// - /// Returns the rewritten expressions - fn rewrite_exprs_list<'n>( - &self, - exprs_list: Vec>, - arrays_list: &[Vec>], - expr_stats: &ExprStats<'n>, - common_exprs: &mut CommonExprs<'n>, - alias_generator: &AliasGenerator, - ) -> Result>> { - exprs_list - .into_iter() - .zip(arrays_list.iter()) - .map(|(exprs, arrays)| { - exprs - .into_iter() - .zip(arrays.iter()) - .map(|(expr, id_array)| { - replace_common_expr( - expr, - id_array, - expr_stats, - common_exprs, - alias_generator, - ) - }) - .collect::>>() - }) - .collect::>>() - } - - /// Extracts common sub-expressions and rewrites `exprs_list`. - /// - /// Returns `FoundCommonExprs` recording the result of the extraction - fn find_common_exprs( - &self, - exprs_list: Vec>, - config: &dyn OptimizerConfig, - expr_mask: ExprMask, - ) -> Result> { - let mut found_common = false; - let mut expr_stats = ExprStats::new(); - let id_arrays_list = exprs_list - .iter() - .map(|exprs| { - self.to_arrays(exprs, &mut expr_stats, expr_mask).map( - |(fc, id_arrays)| { - found_common |= fc; - - id_arrays - }, - ) - }) - .collect::>>()?; - if found_common { - let mut common_exprs = CommonExprs::new(); - let new_exprs_list = self.rewrite_exprs_list( - // Must clone as Identifiers use references to original expressions so we have - // to keep the original expressions intact. - exprs_list.clone(), - &id_arrays_list, - &expr_stats, - &mut common_exprs, - config.alias_generator().as_ref(), - )?; - assert!(!common_exprs.is_empty()); - - Ok(Transformed::yes(FoundCommonExprs::Yes { - common_exprs: common_exprs.into_values().collect(), - new_exprs_list, - original_exprs_list: exprs_list, - })) - } else { - Ok(Transformed::no(FoundCommonExprs::No { - original_exprs_list: exprs_list, - })) - } + Self {} } fn try_optimize_proj( @@ -322,6 +90,7 @@ impl CommonSubexprEliminate { .map(LogicalPlan::Projection) }) } + fn try_optimize_sort( &self, sort: Sort, @@ -329,12 +98,23 @@ impl CommonSubexprEliminate { ) -> Result> { let Sort { expr, input, fetch } = sort; let input = Arc::unwrap_or_clone(input); - let sort_expressions = expr.iter().map(|sort| sort.expr.clone()).collect(); + let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr + .into_iter() + .map(|sort| (sort.expr, (sort.asc, sort.nulls_first))) + .unzip(); let new_sort = self .try_unary_plan(sort_expressions, input, config)? 
.update_data(|(new_expr, new_input)| { LogicalPlan::Sort(Sort { - expr: replace_sort_expressions(expr, new_expr), + expr: new_expr + .into_iter() + .zip(sort_params) + .map(|(expr, (asc, nulls_first))| SortExpr { + expr, + asc, + nulls_first, + }) + .collect(), input: Arc::new(new_input), fetch, }) @@ -372,80 +152,83 @@ impl CommonSubexprEliminate { get_consecutive_window_exprs(window); // Extract common sub-expressions from the list. - self.find_common_exprs(window_expr_list, config, ExprMask::Normal)? - .map_data(|common| match common { - // If there are common sub-expressions, then the insert a projection node - // with the common expressions between the new window nodes and the - // original input. - FoundCommonExprs::Yes { - common_exprs, - new_exprs_list, - original_exprs_list, - } => { - build_common_expr_project_plan(input, common_exprs).map(|new_input| { - (new_exprs_list, new_input, Some(original_exprs_list)) + + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), + ExprMask::Normal, + )) + .extract_common_nodes(window_expr_list)? + { + // If there are common sub-expressions, then the insert a projection node + // with the common expressions between the new window nodes and the + // original input. + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: new_exprs_list, + original_nodes_list: original_exprs_list, + } => build_common_expr_project_plan(input, common_exprs).map(|new_input| { + Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list))) + }), + FoundCommonNodes::No { + original_nodes_list: original_exprs_list, + } => Ok(Transformed::no((original_exprs_list, input, None))), + }? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_window_expr_list, new_input, window_expr_list)| { + self.rewrite(new_input, config)?.map_data(|new_input| { + Ok((new_window_expr_list, new_input, window_expr_list)) + }) + })? + // Rebuild the consecutive window nodes. + .map_data(|(new_window_expr_list, new_input, window_expr_list)| { + // If there were common expressions extracted, then we need to make sure + // we restore the original column names. + // TODO: Although `find_common_exprs()` inserts aliases around extracted + // common expressions this doesn't mean that the original column names + // (schema) are preserved due to the inserted aliases are not always at + // the top of the expression. + // Let's consider improving `find_common_exprs()` to always keep column + // names and get rid of additional name preserving logic here. + if let Some(window_expr_list) = window_expr_list { + let name_preserver = NamePreserver::new_for_projection(); + let saved_names = window_expr_list + .iter() + .map(|exprs| { + exprs + .iter() + .map(|expr| name_preserver.save(expr)) + .collect::>() }) - } - FoundCommonExprs::No { - original_exprs_list, - } => Ok((original_exprs_list, input, None)), - })? - // Recurse into the new input. - // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) - .transform_data(|(new_window_expr_list, new_input, window_expr_list)| { - self.rewrite(new_input, config)?.map_data(|new_input| { - Ok((new_window_expr_list, new_input, window_expr_list)) - }) - })? - // Rebuild the consecutive window nodes. - .map_data(|(new_window_expr_list, new_input, window_expr_list)| { - // If there were common expressions extracted, then we need to make sure - // we restore the original column names. 
- // TODO: Although `find_common_exprs()` inserts aliases around extracted - // common expressions this doesn't mean that the original column names - // (schema) are preserved due to the inserted aliases are not always at - // the top of the expression. - // Let's consider improving `find_common_exprs()` to always keep column - // names and get rid of additional name preserving logic here. - if let Some(window_expr_list) = window_expr_list { - let name_preserver = NamePreserver::new_for_projection(); - let saved_names = window_expr_list - .iter() - .map(|exprs| { - exprs - .iter() - .map(|expr| name_preserver.save(expr)) - .collect::>() - }) - .collect::>(); - new_window_expr_list.into_iter().zip(saved_names).try_rfold( - new_input, - |plan, (new_window_expr, saved_names)| { - let new_window_expr = new_window_expr - .into_iter() - .zip(saved_names) - .map(|(new_window_expr, saved_name)| { - saved_name.restore(new_window_expr) - }) - .collect::>(); - Window::try_new(new_window_expr, Arc::new(plan)) - .map(LogicalPlan::Window) - }, - ) - } else { - new_window_expr_list - .into_iter() - .zip(window_schemas) - .try_rfold(new_input, |plan, (new_window_expr, schema)| { - Window::try_new_with_schema( - new_window_expr, - Arc::new(plan), - schema, - ) + .collect::>(); + new_window_expr_list.into_iter().zip(saved_names).try_rfold( + new_input, + |plan, (new_window_expr, saved_names)| { + let new_window_expr = new_window_expr + .into_iter() + .zip(saved_names) + .map(|(new_window_expr, saved_name)| { + saved_name.restore(new_window_expr) + }) + .collect::>(); + Window::try_new(new_window_expr, Arc::new(plan)) .map(LogicalPlan::Window) - }) - } - }) + }, + ) + } else { + new_window_expr_list + .into_iter() + .zip(window_schemas) + .try_rfold(new_input, |plan, (new_window_expr, schema)| { + Window::try_new_with_schema( + new_window_expr, + Arc::new(plan), + schema, + ) + .map(LogicalPlan::Window) + }) + } + }) } fn try_optimize_aggregate( @@ -462,174 +245,175 @@ impl CommonSubexprEliminate { } = aggregate; let input = Arc::unwrap_or_clone(input); // Extract common sub-expressions from the aggregate and grouping expressions. - self.find_common_exprs(vec![group_expr, aggr_expr], config, ExprMask::Normal)? - .map_data(|common| { - match common { - // If there are common sub-expressions, then insert a projection node - // with the common expressions between the new aggregate node and the - // original input. - FoundCommonExprs::Yes { - common_exprs, - mut new_exprs_list, - mut original_exprs_list, - } => { - let new_aggr_expr = new_exprs_list.pop().unwrap(); - let new_group_expr = new_exprs_list.pop().unwrap(); - - build_common_expr_project_plan(input, common_exprs).map( - |new_input| { - let aggr_expr = original_exprs_list.pop().unwrap(); - ( - new_aggr_expr, - new_group_expr, - new_input, - Some(aggr_expr), - ) - }, - ) - } - - FoundCommonExprs::No { - mut original_exprs_list, - } => { - let new_aggr_expr = original_exprs_list.pop().unwrap(); - let new_group_expr = original_exprs_list.pop().unwrap(); - - Ok((new_aggr_expr, new_group_expr, input, None)) - } - } - })? - // Recurse into the new input. - // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) - .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| { - self.rewrite(new_input, config)?.map_data(|new_input| { - Ok(( + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), + ExprMask::Normal, + )) + .extract_common_nodes(vec![group_expr, aggr_expr])? 
+ { + // If there are common sub-expressions, then insert a projection node + // with the common expressions between the new aggregate node and the + // original input. + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: mut original_exprs_list, + } => { + let new_aggr_expr = new_exprs_list.pop().unwrap(); + let new_group_expr = new_exprs_list.pop().unwrap(); + + build_common_expr_project_plan(input, common_exprs).map(|new_input| { + let aggr_expr = original_exprs_list.pop().unwrap(); + Transformed::yes(( new_aggr_expr, new_group_expr, - aggr_expr, - Arc::new(new_input), + new_input, + Some(aggr_expr), )) }) - })? - // Try extracting common aggregate expressions and rebuild the aggregate node. - .transform_data(|(new_aggr_expr, new_group_expr, aggr_expr, new_input)| { + } + + FoundCommonNodes::No { + original_nodes_list: mut original_exprs_list, + } => { + let new_aggr_expr = original_exprs_list.pop().unwrap(); + let new_group_expr = original_exprs_list.pop().unwrap(); + + Ok(Transformed::no(( + new_aggr_expr, + new_group_expr, + input, + None, + ))) + } + }? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| { + self.rewrite(new_input, config)?.map_data(|new_input| { + Ok(( + new_aggr_expr, + new_group_expr, + aggr_expr, + Arc::new(new_input), + )) + }) + })? + // Try extracting common aggregate expressions and rebuild the aggregate node. + .transform_data( + |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| { // Extract common aggregate sub-expressions from the aggregate expressions. - self.find_common_exprs( - vec![new_aggr_expr], - config, + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), ExprMask::NormalAndAggregates, - )? - .map_data(|common| { - match common { - FoundCommonExprs::Yes { - common_exprs, - mut new_exprs_list, - mut original_exprs_list, - } => { - let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); - let new_aggr_expr = original_exprs_list.pop().unwrap(); - - let mut agg_exprs = common_exprs - .into_iter() - .map(|(expr, expr_alias)| expr.alias(expr_alias)) - .collect::>(); + )) + .extract_common_nodes(vec![new_aggr_expr])? + { + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: mut original_exprs_list, + } => { + let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); + let new_aggr_expr = original_exprs_list.pop().unwrap(); - let mut proj_exprs = vec![]; - for expr in &new_group_expr { - extract_expressions(expr, &mut proj_exprs) - } - for (expr_rewritten, expr_orig) in - rewritten_aggr_expr.into_iter().zip(new_aggr_expr) - { - if expr_rewritten == expr_orig { - if let Expr::Alias(Alias { expr, name, .. 
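// A toy model (`ToyFound` is hypothetical, standing in for
// `datafusion_common::cse::FoundCommonNodes`) of the calling pattern used in this
// hunk: the generic extractor returns either rewritten expression lists plus the
// extracted common nodes, or the original lists untouched, and the aggregate code
// pops the per-list results back off in reverse push order (aggregate exprs
// first, then group exprs).
enum ToyFound<N> {
    Yes {
        common_nodes: Vec<(N, String)>,
        new_nodes_list: Vec<Vec<N>>,
        original_nodes_list: Vec<Vec<N>>,
    },
    No {
        original_nodes_list: Vec<Vec<N>>,
    },
}

fn split_group_and_aggr<N>(
    found: ToyFound<N>,
) -> (Vec<N>, Vec<N>, Option<Vec<(N, String)>>) {
    match found {
        ToyFound::Yes {
            common_nodes,
            mut new_nodes_list,
            ..
        } => {
            let aggr = new_nodes_list.pop().expect("aggregate expr list");
            let group = new_nodes_list.pop().expect("group expr list");
            (group, aggr, Some(common_nodes))
        }
        ToyFound::No {
            mut original_nodes_list,
        } => {
            let aggr = original_nodes_list.pop().expect("aggregate expr list");
            let group = original_nodes_list.pop().expect("group expr list");
            (group, aggr, None)
        }
    }
}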
}) = - expr_rewritten - { - agg_exprs.push(expr.alias(&name)); - proj_exprs - .push(Expr::Column(Column::from_name(name))); - } else { - let expr_alias = - config.alias_generator().next(CSE_PREFIX); - let (qualifier, field_name) = - expr_rewritten.qualified_name(); - let out_name = qualified_name( - qualifier.as_ref(), - &field_name, - ); - - agg_exprs.push(expr_rewritten.alias(&expr_alias)); - proj_exprs.push( - Expr::Column(Column::from_name(expr_alias)) - .alias(out_name), - ); - } + let mut agg_exprs = common_exprs + .into_iter() + .map(|(expr, expr_alias)| expr.alias(expr_alias)) + .collect::>(); + + let mut proj_exprs = vec![]; + for expr in &new_group_expr { + extract_expressions(expr, &mut proj_exprs) + } + for (expr_rewritten, expr_orig) in + rewritten_aggr_expr.into_iter().zip(new_aggr_expr) + { + if expr_rewritten == expr_orig { + if let Expr::Alias(Alias { expr, name, .. }) = + expr_rewritten + { + agg_exprs.push(expr.alias(&name)); + proj_exprs + .push(Expr::Column(Column::from_name(name))); } else { - proj_exprs.push(expr_rewritten); + let expr_alias = + config.alias_generator().next(CSE_PREFIX); + let (qualifier, field_name) = + expr_rewritten.qualified_name(); + let out_name = + qualified_name(qualifier.as_ref(), &field_name); + + agg_exprs.push(expr_rewritten.alias(&expr_alias)); + proj_exprs.push( + Expr::Column(Column::from_name(expr_alias)) + .alias(out_name), + ); } + } else { + proj_exprs.push(expr_rewritten); } - - let agg = LogicalPlan::Aggregate(Aggregate::try_new( - new_input, - new_group_expr, - agg_exprs, - )?); - Projection::try_new(proj_exprs, Arc::new(agg)) - .map(LogicalPlan::Projection) } - // If there aren't any common aggregate sub-expressions, then just - // rebuild the aggregate node. - FoundCommonExprs::No { - mut original_exprs_list, - } => { - let rewritten_aggr_expr = original_exprs_list.pop().unwrap(); - - // If there were common expressions extracted, then we need to - // make sure we restore the original column names. - // TODO: Although `find_common_exprs()` inserts aliases around - // extracted common expressions this doesn't mean that the - // original column names (schema) are preserved due to the - // inserted aliases are not always at the top of the - // expression. - // Let's consider improving `find_common_exprs()` to always - // keep column names and get rid of additional name - // preserving logic here. - if let Some(aggr_expr) = aggr_expr { - let name_perserver = NamePreserver::new_for_projection(); - let saved_names = aggr_expr - .iter() - .map(|expr| name_perserver.save(expr)) - .collect::>(); - let new_aggr_expr = rewritten_aggr_expr - .into_iter() - .zip(saved_names) - .map(|(new_expr, saved_name)| { - saved_name.restore(new_expr) - }) - .collect::>(); - - // Since `group_expr` may have changed, schema may also. - // Use `try_new()` method. - Aggregate::try_new( - new_input, - new_group_expr, - new_aggr_expr, - ) - .map(LogicalPlan::Aggregate) - } else { - Aggregate::try_new_with_schema( - new_input, - new_group_expr, - rewritten_aggr_expr, - schema, - ) + let agg = LogicalPlan::Aggregate(Aggregate::try_new( + new_input, + new_group_expr, + agg_exprs, + )?); + Projection::try_new(proj_exprs, Arc::new(agg)) + .map(|p| Transformed::yes(LogicalPlan::Projection(p))) + } + + // If there aren't any common aggregate sub-expressions, then just + // rebuild the aggregate node. 
+ FoundCommonNodes::No { + original_nodes_list: mut original_exprs_list, + } => { + let rewritten_aggr_expr = original_exprs_list.pop().unwrap(); + + // If there were common expressions extracted, then we need to + // make sure we restore the original column names. + // TODO: Although `find_common_exprs()` inserts aliases around + // extracted common expressions this doesn't mean that the + // original column names (schema) are preserved due to the + // inserted aliases are not always at the top of the + // expression. + // Let's consider improving `find_common_exprs()` to always + // keep column names and get rid of additional name + // preserving logic here. + if let Some(aggr_expr) = aggr_expr { + let name_preserver = NamePreserver::new_for_projection(); + let saved_names = aggr_expr + .iter() + .map(|expr| name_preserver.save(expr)) + .collect::>(); + let new_aggr_expr = rewritten_aggr_expr + .into_iter() + .zip(saved_names) + .map(|(new_expr, saved_name)| { + saved_name.restore(new_expr) + }) + .collect::>(); + + // Since `group_expr` may have changed, schema may also. + // Use `try_new()` method. + Aggregate::try_new(new_input, new_group_expr, new_aggr_expr) .map(LogicalPlan::Aggregate) - } + .map(Transformed::no) + } else { + Aggregate::try_new_with_schema( + new_input, + new_group_expr, + rewritten_aggr_expr, + schema, + ) + .map(LogicalPlan::Aggregate) + .map(Transformed::no) } } - }) - }) + } + }, + ) } /// Rewrites the expr list and input to remove common subexpressions @@ -653,30 +437,34 @@ impl CommonSubexprEliminate { config: &dyn OptimizerConfig, ) -> Result, LogicalPlan)>> { // Extract common sub-expressions from the expressions. - self.find_common_exprs(vec![exprs], config, ExprMask::Normal)? - .map_data(|common| match common { - FoundCommonExprs::Yes { - common_exprs, - mut new_exprs_list, - original_exprs_list: _, - } => { - let new_exprs = new_exprs_list.pop().unwrap(); - build_common_expr_project_plan(input, common_exprs) - .map(|new_input| (new_exprs, new_input)) - } - FoundCommonExprs::No { - mut original_exprs_list, - } => { - let new_exprs = original_exprs_list.pop().unwrap(); - Ok((new_exprs, input)) - } - })? - // Recurse into the new input. - // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) - .transform_data(|(new_exprs, new_input)| { - self.rewrite(new_input, config)? - .map_data(|new_input| Ok((new_exprs, new_input))) - }) + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), + ExprMask::Normal, + )) + .extract_common_nodes(vec![exprs])? + { + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: _, + } => { + let new_exprs = new_exprs_list.pop().unwrap(); + build_common_expr_project_plan(input, common_exprs) + .map(|new_input| Transformed::yes((new_exprs, new_input))) + } + FoundCommonNodes::No { + original_nodes_list: mut original_exprs_list, + } => { + let new_exprs = original_exprs_list.pop().unwrap(); + Ok(Transformed::no((new_exprs, input))) + } + }? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_exprs, new_input)| { + self.rewrite(new_input, config)? 
+ .map_data(|new_input| Ok((new_exprs, new_input))) + }) } } @@ -743,6 +531,7 @@ impl OptimizerRule for CommonSubexprEliminate { None } + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn rewrite( &self, plan: LogicalPlan, @@ -757,7 +546,6 @@ impl OptimizerRule for CommonSubexprEliminate { LogicalPlan::Window(window) => self.try_optimize_window(window, config)?, LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?, LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Repartition(_) | LogicalPlan::Union(_) | LogicalPlan::TableScan(_) @@ -776,8 +564,7 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Dml(_) | LogicalPlan::Copy(_) | LogicalPlan::Unnest(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Prepare(_) => { + | LogicalPlan::RecursiveQuery(_) => { // This rule handles recursion itself in a `ApplyOrder::TopDown` like // manner. plan.map_children(|c| self.rewrite(c, config))? @@ -800,71 +587,6 @@ impl OptimizerRule for CommonSubexprEliminate { } } -impl Default for CommonSubexprEliminate { - fn default() -> Self { - Self::new() - } -} - -/// Build the "intermediate" projection plan that evaluates the extracted common -/// expressions. -/// -/// # Arguments -/// input: the input plan -/// -/// common_exprs: which common subexpressions were used (and thus are added to -/// intermediate projection) -/// -/// expr_stats: the set of common subexpressions -fn build_common_expr_project_plan( - input: LogicalPlan, - common_exprs: Vec<(Expr, String)>, -) -> Result { - let mut fields_set = BTreeSet::new(); - let mut project_exprs = common_exprs - .into_iter() - .map(|(expr, expr_alias)| { - fields_set.insert(expr_alias.clone()); - Ok(expr.alias(expr_alias)) - }) - .collect::>>()?; - - for (qualifier, field) in input.schema().iter() { - if fields_set.insert(qualified_name(qualifier, field.name())) { - project_exprs.push(Expr::from((qualifier, field))); - } - } - - Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection) -} - -/// Build the projection plan to eliminate unnecessary columns produced by -/// the "intermediate" projection plan built in [build_common_expr_project_plan]. -/// -/// This is required to keep the schema the same for plans that pass the input -/// on to the output, such as `Filter` or `Sort`. -fn build_recover_project_plan( - schema: &DFSchema, - input: LogicalPlan, -) -> Result { - let col_exprs = schema.iter().map(Expr::from).collect(); - Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection) -} - -fn extract_expressions(expr: &Expr, result: &mut Vec) { - if let Expr::GroupingSet(groupings) = expr { - for e in groupings.distinct_expr() { - let (qualifier, field_name) = e.qualified_name(); - let col = Column::new(qualifier, field_name); - result.push(Expr::Column(col)) - } - } else { - let (qualifier, field_name) = expr.qualified_name(); - let col = Column::new(qualifier, field_name); - result.push(Expr::Column(col)); - } -} - /// Which type of [expressions](Expr) should be considered for rewriting? #[derive(Debug, Clone, Copy)] enum ExprMask { @@ -882,156 +604,36 @@ enum ExprMask { NormalAndAggregates, } -impl ExprMask { - fn ignores(&self, expr: &Expr) -> bool { - let is_normal_minus_aggregates = matches!( - expr, - Expr::Literal(..) - | Expr::Column(..) - | Expr::ScalarVariable(..) - | Expr::Alias(..) - | Expr::Wildcard { .. 
} - ); - - let is_aggr = matches!(expr, Expr::AggregateFunction(..)); - - match self { - Self::Normal => is_normal_minus_aggregates || is_aggr, - Self::NormalAndAggregates => is_normal_minus_aggregates, - } - } -} - -/// Go through an expression tree and generate identifiers for each subexpression. -/// -/// An identifier contains information of the expression itself and its sub-expression. -/// This visitor implementation use a stack `visit_stack` to track traversal, which -/// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called -/// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack. -/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `ExprItem` -/// before the first `EnterMark` is considered to be sub-tree of the leaving node. -/// -/// This visitor also records identifier in `id_array`. Makes the following traverse -/// pass can get the identifier of a node without recalculate it. We assign each node -/// in the expr tree a series number, start from 1, maintained by `series_number`. -/// Series number represents the order we left (`f_up()`) a node. Has the property -/// that child node's series number always smaller than parent's. While `id_array` is -/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to -/// get the index of `id_array` for each node. -/// -/// `Expr` without sub-expr (column, literal etc.) will not have identifier -/// because they should not be recognized as common sub-expr. -struct ExprIdentifierVisitor<'a, 'n> { - // statistics of expressions - expr_stats: &'a mut ExprStats<'n>, - // cache to speed up second traversal - id_array: &'a mut IdArray<'n>, - // inner states - visit_stack: Vec>, - // preorder index, start from 0. - down_index: usize, - // postorder index, start from 0. - up_index: usize, - // which expression should be skipped? - expr_mask: ExprMask, - // a `RandomState` to generate hashes during the first traversal - random_state: &'a RandomState, - // a flag to indicate that common expression found - found_common: bool, - // if we are in a conditional branch. A conditional branch means that the expression - // might not be executed depending on the runtime values of other expressions, and - // thus can not be extracted as a common expression. - conditional: bool, -} +struct ExprCSEController<'a> { + alias_generator: &'a AliasGenerator, + mask: ExprMask, -/// Record item that used when traversing an expression tree. -enum VisitRecord<'n> { - /// Marks the beginning of expression. It contains: - /// - The post-order index assigned during the first, visiting traversal. - EnterMark(usize), - - /// Marks an accumulated subexpression tree. It contains: - /// - The accumulated identifier of a subexpression. - /// - A boolean flag if the expression is valid for subexpression elimination. - /// The flag is propagated up from children to parent. (E.g. volatile expressions - /// are not valid and can't be extracted, but non-volatile children of volatile - /// expressions can be extracted.) - ExprItem(Identifier<'n>, bool), + // how many aliases have we seen so far + alias_counter: usize, } -impl<'n> ExprIdentifierVisitor<'_, 'n> { - /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` before - /// it. Returns a tuple that contains: - /// - The pre-order index of the expression we marked. - /// - The accumulated identifier of the children of the marked expression. 
- /// - An accumulated boolean flag from the children of the marked expression if all - /// children are valid for subexpression elimination (i.e. it is safe to extract the - /// expression as a common expression from its children POV). - /// (E.g. if any of the children of the marked expression is not valid (e.g. is - /// volatile) then the expression is also not valid, so we can propagate this - /// information up from children to parents via `visit_stack` during the first, - /// visiting traversal and no need to test the expression's validity beforehand with - /// an extra traversal). - fn pop_enter_mark(&mut self) -> (usize, Option>, bool) { - let mut expr_id = None; - let mut is_valid = true; - - while let Some(item) = self.visit_stack.pop() { - match item { - VisitRecord::EnterMark(down_index) => { - return (down_index, expr_id, is_valid); - } - VisitRecord::ExprItem(sub_expr_id, sub_expr_is_valid) => { - expr_id = Some(sub_expr_id.combine(expr_id)); - is_valid &= sub_expr_is_valid; - } - } +impl<'a> ExprCSEController<'a> { + fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self { + Self { + alias_generator, + mask, + alias_counter: 0, } - unreachable!("Enter mark should paired with node number"); - } - - /// Save the current `conditional` status and run `f` with `conditional` set to true. - fn conditionally Result<()>>( - &mut self, - mut f: F, - ) -> Result<()> { - let conditional = self.conditional; - self.conditional = true; - f(self)?; - self.conditional = conditional; - - Ok(()) } } -impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { +impl CSEController for ExprCSEController<'_> { type Node = Expr; - fn f_down(&mut self, expr: &'n Expr) -> Result { - self.id_array.push((0, None)); - self.visit_stack - .push(VisitRecord::EnterMark(self.down_index)); - self.down_index += 1; - - // If an expression can short-circuit then some of its children might not be - // executed so count the occurrence of subexpressions as conditional in all - // children. - Ok(match expr { - // If we are already in a conditionally evaluated subtree then continue - // traversal. - _ if self.conditional => TreeNodeRecursion::Continue, - + fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> { + match node { // In case of `ScalarFunction`s we don't know which children are surely // executed so start visiting all children conditionally and stop the // recursion with `TreeNodeRecursion::Jump`. 
Expr::ScalarFunction(ScalarFunction { func, args }) if func.short_circuits() => { - self.conditionally(|visitor| { - args.iter().try_for_each(|e| e.visit(visitor).map(|_| ())) - })?; - - TreeNodeRecursion::Jump + Some((vec![], args.iter().collect())) } // In case of `And` and `Or` the first child is surely executed, but we @@ -1040,12 +642,7 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { left, op: Operator::And | Operator::Or, right, - }) => { - left.visit(self)?; - self.conditionally(|visitor| right.visit(visitor).map(|_| ()))?; - - TreeNodeRecursion::Jump - } + }) => Some((vec![left.as_ref()], vec![right.as_ref()])), // In case of `Case` the optional base expression and the first when // expressions are surely executed, but we account subexpressions as @@ -1054,165 +651,150 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { expr, when_then_expr, else_expr, - }) => { - expr.iter().try_for_each(|e| e.visit(self).map(|_| ()))?; - when_then_expr.iter().take(1).try_for_each(|(when, then)| { - when.visit(self)?; - self.conditionally(|visitor| then.visit(visitor).map(|_| ())) - })?; - self.conditionally(|visitor| { - when_then_expr.iter().skip(1).try_for_each(|(when, then)| { - when.visit(visitor)?; - then.visit(visitor).map(|_| ()) - })?; - else_expr - .iter() - .try_for_each(|e| e.visit(visitor).map(|_| ())) - })?; - - TreeNodeRecursion::Jump - } + }) => Some(( + expr.iter() + .map(|e| e.as_ref()) + .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref())) + .collect(), + when_then_expr + .iter() + .take(1) + .map(|(_, then)| then.as_ref()) + .chain( + when_then_expr + .iter() + .skip(1) + .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]), + ) + .chain(else_expr.iter().map(|e| e.as_ref())) + .collect(), + )), + _ => None, + } + } - // In case of non-short-circuit expressions continue the traversal. - _ => TreeNodeRecursion::Continue, - }) + fn is_valid(node: &Expr) -> bool { + !node.is_volatile_node() } - fn f_up(&mut self, expr: &'n Expr) -> Result { - let (down_index, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark(); + fn is_ignored(&self, node: &Expr) -> bool { + let is_normal_minus_aggregates = matches!( + node, + Expr::Literal(..) + | Expr::Column(..) + | Expr::ScalarVariable(..) + | Expr::Alias(..) + | Expr::Wildcard { .. } + ); - let expr_id = Identifier::new(expr, self.random_state).combine(sub_expr_id); - let is_valid = !expr.is_volatile_node() && sub_expr_is_valid; + let is_aggr = matches!(node, Expr::AggregateFunction(..)); - self.id_array[down_index].0 = self.up_index; - if is_valid && !self.expr_mask.ignores(expr) { - self.id_array[down_index].1 = Some(expr_id); - let (count, conditional_count) = - self.expr_stats.entry(expr_id).or_insert((0, 0)); - if self.conditional { - *conditional_count += 1; - } else { - *count += 1; - } - if *count > 1 || (*count == 1 && *conditional_count > 0) { - self.found_common = true; - } + match self.mask { + ExprMask::Normal => is_normal_minus_aggregates || is_aggr, + ExprMask::NormalAndAggregates => is_normal_minus_aggregates, } - self.visit_stack - .push(VisitRecord::ExprItem(expr_id, is_valid)); - self.up_index += 1; - - Ok(TreeNodeRecursion::Continue) } -} -/// Rewrite expression by replacing detected common sub-expression with -/// the corresponding temporary column name. That column contains the -/// evaluate result of replaced expression. 
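// A toy rewrite (`ToyExpr` is hypothetical, not DataFusion's `Expr`) of the
// behaviour the rewriter's doc comment above describes: every occurrence of a
// detected common subexpression is replaced by a reference to a temporary column
// such as `__common_expr_1`, which the inserted projection computes once.
#[derive(Clone, PartialEq, Debug)]
enum ToyExpr {
    Column(String),
    Add(Box<ToyExpr>, Box<ToyExpr>),
    Mul(Box<ToyExpr>, Box<ToyExpr>),
}

fn replace_common(expr: &ToyExpr, common: &ToyExpr, alias: &str) -> ToyExpr {
    if expr == common {
        return ToyExpr::Column(alias.to_string());
    }
    match expr {
        ToyExpr::Column(_) => expr.clone(),
        ToyExpr::Add(l, r) => ToyExpr::Add(
            Box::new(replace_common(l, common, alias)),
            Box::new(replace_common(r, common, alias)),
        ),
        ToyExpr::Mul(l, r) => ToyExpr::Mul(
            Box::new(replace_common(l, common, alias)),
            Box::new(replace_common(r, common, alias)),
        ),
    }
}

#[test]
fn toy_cse_rewrite() {
    use ToyExpr::*;
    let a_plus_b = Add(Box::new(Column("a".into())), Box::new(Column("b".into())));
    // (a + b) * (a + b)  ->  __common_expr_1 * __common_expr_1
    let expr = Mul(Box::new(a_plus_b.clone()), Box::new(a_plus_b.clone()));
    let rewritten = replace_common(&expr, &a_plus_b, "__common_expr_1");
    let expected = Mul(
        Box::new(Column("__common_expr_1".into())),
        Box::new(Column("__common_expr_1".into())),
    );
    assert_eq!(rewritten, expected);
}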
-struct CommonSubexprRewriter<'a, 'n> { - // statistics of expressions - expr_stats: &'a ExprStats<'n>, - // cache to speed up second traversal - id_array: &'a IdArray<'n>, - // common expression, that are replaced during the second traversal, are collected to - // this map - common_exprs: &'a mut CommonExprs<'n>, - // preorder index, starts from 0. - down_index: usize, - // how many aliases have we seen so far - alias_counter: usize, - // alias generator for extracted common expressions - alias_generator: &'a AliasGenerator, -} + fn generate_alias(&self) -> String { + self.alias_generator.next(CSE_PREFIX) + } -impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> { - type Node = Expr; + fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + // alias the expressions without an `Alias` ancestor node + if self.alias_counter > 0 { + col(alias) + } else { + self.alias_counter += 1; + col(alias).alias(node.schema_name().to_string()) + } + } - fn f_down(&mut self, expr: Expr) -> Result> { - if matches!(expr, Expr::Alias(_)) { + fn rewrite_f_down(&mut self, node: &Expr) { + if matches!(node, Expr::Alias(_)) { self.alias_counter += 1; } + } + fn rewrite_f_up(&mut self, node: &Expr) { + if matches!(node, Expr::Alias(_)) { + self.alias_counter -= 1 + } + } +} - let (up_index, expr_id) = self.id_array[self.down_index]; - self.down_index += 1; +impl Default for CommonSubexprEliminate { + fn default() -> Self { + Self::new() + } +} - // Handle `Expr`s with identifiers only - if let Some(expr_id) = expr_id { - let (count, conditional_count) = self.expr_stats.get(&expr_id).unwrap(); - if *count > 1 || *count == 1 && *conditional_count > 0 { - // step index to skip all sub-node (which has smaller series number). - while self.down_index < self.id_array.len() - && self.id_array[self.down_index].0 < up_index - { - self.down_index += 1; - } +/// Build the "intermediate" projection plan that evaluates the extracted common +/// expressions. 
+/// +/// # Arguments +/// input: the input plan +/// +/// common_exprs: which common subexpressions were used (and thus are added to +/// intermediate projection) +/// +/// expr_stats: the set of common subexpressions +fn build_common_expr_project_plan( + input: LogicalPlan, + common_exprs: Vec<(Expr, String)>, +) -> Result { + let mut fields_set = BTreeSet::new(); + let mut project_exprs = common_exprs + .into_iter() + .map(|(expr, expr_alias)| { + fields_set.insert(expr_alias.clone()); + Ok(expr.alias(expr_alias)) + }) + .collect::>>()?; - let expr_name = expr.schema_name().to_string(); - let (_, expr_alias) = - self.common_exprs.entry(expr_id).or_insert_with(|| { - let expr_alias = self.alias_generator.next(CSE_PREFIX); - (expr, expr_alias) - }); - - // alias the expressions without an `Alias` ancestor node - let rewritten = if self.alias_counter > 0 { - col(expr_alias.clone()) - } else { - self.alias_counter += 1; - col(expr_alias.clone()).alias(expr_name) - }; - - return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)); - } + for (qualifier, field) in input.schema().iter() { + if fields_set.insert(qualified_name(qualifier, field.name())) { + project_exprs.push(Expr::from((qualifier, field))); } - - Ok(Transformed::no(expr)) } - fn f_up(&mut self, expr: Expr) -> Result> { - if matches!(expr, Expr::Alias(_)) { - self.alias_counter -= 1 - } + Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection) +} - Ok(Transformed::no(expr)) - } +/// Build the projection plan to eliminate unnecessary columns produced by +/// the "intermediate" projection plan built in [build_common_expr_project_plan]. +/// +/// This is required to keep the schema the same for plans that pass the input +/// on to the output, such as `Filter` or `Sort`. 
+fn build_recover_project_plan( + schema: &DFSchema, + input: LogicalPlan, +) -> Result { + let col_exprs = schema.iter().map(Expr::from).collect(); + Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection) } -/// Replace common sub-expression in `expr` with the corresponding temporary -/// column name, updating `common_exprs` with any replaced expressions -fn replace_common_expr<'n>( - expr: Expr, - id_array: &IdArray<'n>, - expr_stats: &ExprStats<'n>, - common_exprs: &mut CommonExprs<'n>, - alias_generator: &AliasGenerator, -) -> Result { - if id_array.is_empty() { - Ok(Transformed::no(expr)) +fn extract_expressions(expr: &Expr, result: &mut Vec) { + if let Expr::GroupingSet(groupings) = expr { + for e in groupings.distinct_expr() { + let (qualifier, field_name) = e.qualified_name(); + let col = Column::new(qualifier, field_name); + result.push(Expr::Column(col)) + } } else { - expr.rewrite(&mut CommonSubexprRewriter { - expr_stats, - id_array, - common_exprs, - down_index: 0, - alias_counter: 0, - alias_generator, - }) + let (qualifier, field_name) = expr.qualified_name(); + let col = Column::new(qualifier, field_name); + result.push(Expr::Column(col)); } - .data() } #[cfg(test)] mod test { use std::any::Any; - use std::collections::HashSet; use std::iter; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_expr::expr::AggregateFunction; use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ - grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, + grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility, }; @@ -1238,154 +820,6 @@ mod test { assert_eq!(expected, formatted_plan); } - #[test] - fn id_array_visitor() -> Result<()> { - let optimizer = CommonSubexprEliminate::new(); - - let a_plus_1 = col("a") + lit(1); - let avg_c = avg(col("c")); - let sum_a_plus_1 = sum(a_plus_1); - let sum_a_plus_1_minus_avg_c = sum_a_plus_1 - avg_c; - let expr = sum_a_plus_1_minus_avg_c * lit(2); - - let Expr::BinaryExpr(BinaryExpr { - left: sum_a_plus_1_minus_avg_c, - .. - }) = &expr - else { - panic!("Cannot extract subexpression reference") - }; - let Expr::BinaryExpr(BinaryExpr { - left: sum_a_plus_1, - right: avg_c, - .. - }) = sum_a_plus_1_minus_avg_c.as_ref() - else { - panic!("Cannot extract subexpression reference") - }; - let Expr::AggregateFunction(AggregateFunction { - args: a_plus_1_vec, .. 
- }) = sum_a_plus_1.as_ref() - else { - panic!("Cannot extract subexpression reference") - }; - let a_plus_1 = &a_plus_1_vec.as_slice()[0]; - - // skip aggregates - let mut id_array = vec![]; - optimizer.expr_to_identifier( - &expr, - &mut ExprStats::new(), - &mut id_array, - ExprMask::Normal, - )?; - - // Collect distinct hashes and set them to 0 in `id_array` - fn collect_hashes(id_array: &mut IdArray) -> HashSet { - id_array - .iter_mut() - .flat_map(|(_, expr_id_option)| { - expr_id_option.as_mut().map(|expr_id| { - let hash = expr_id.hash; - expr_id.hash = 0; - hash - }) - }) - .collect::>() - } - - let hashes = collect_hashes(&mut id_array); - assert_eq!(hashes.len(), 3); - - let expected = vec![ - ( - 8, - Some(Identifier { - hash: 0, - expr: &expr, - }), - ), - ( - 6, - Some(Identifier { - hash: 0, - expr: sum_a_plus_1_minus_avg_c, - }), - ), - (3, None), - ( - 2, - Some(Identifier { - hash: 0, - expr: a_plus_1, - }), - ), - (0, None), - (1, None), - (5, None), - (4, None), - (7, None), - ]; - assert_eq!(expected, id_array); - - // include aggregates - let mut id_array = vec![]; - optimizer.expr_to_identifier( - &expr, - &mut ExprStats::new(), - &mut id_array, - ExprMask::NormalAndAggregates, - )?; - - let hashes = collect_hashes(&mut id_array); - assert_eq!(hashes.len(), 5); - - let expected = vec![ - ( - 8, - Some(Identifier { - hash: 0, - expr: &expr, - }), - ), - ( - 6, - Some(Identifier { - hash: 0, - expr: sum_a_plus_1_minus_avg_c, - }), - ), - ( - 3, - Some(Identifier { - hash: 0, - expr: sum_a_plus_1, - }), - ), - ( - 2, - Some(Identifier { - hash: 0, - expr: a_plus_1, - }), - ), - (0, None), - (1, None), - ( - 5, - Some(Identifier { - hash: 0, - expr: avg_c, - }), - ), - (4, None), - (7, None), - ]; - assert_eq!(expected, id_array); - - Ok(()) - } - #[test] fn tpch_q1_simplified() -> Result<()> { // SQL: @@ -1517,7 +951,7 @@ mod test { )? .build()?; - let expected ="Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; @@ -1620,8 +1054,9 @@ mod test { .project(vec![lit(1) + col("a"), col("a") + lit(1)])? 
.build()?; - let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\ - \n TableScan: test"; + let expected = "Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)\ + \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; assert_optimized_plan_eq(expected, plan, None); @@ -1978,6 +1413,259 @@ mod test { Ok(()) } + #[test] + fn test_normalize_add_expression() -> Result<()> { + // a + b <=> b + a + let table_scan = test_table_scan()?; + let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ + \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_multi_expression() -> Result<()> { + // a * b <=> b * a + let table_scan = test_table_scan()?; + let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ + \n Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_bitset_and_expression() -> Result<()> { + // a & b <=> b & a + let table_scan = test_table_scan()?; + let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ + \n Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_bitset_or_expression() -> Result<()> { + // a | b <=> b | a + let table_scan = test_table_scan()?; + let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ + \n Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_bitset_xor_expression() -> Result<()> { + // a # b <=> b # a + let table_scan = test_table_scan()?; + let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ + \n Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_eq_expression() -> Result<()> { + // a = b <=> b = a + let table_scan = test_table_scan()?; + let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a"))); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 AND 
__common_expr_1\ + \n Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_ne_expression() -> Result<()> { + // a != b <=> b != a + let table_scan = test_table_scan()?; + let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a"))); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 AND __common_expr_1\ + \n Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_complex_expression() -> Result<()> { + // case1: a + b * c <=> b * c + a + let table_scan = test_table_scan()?; + let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a"))) + .eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 - __common_expr_1 = Int32(30)\ + \n Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1)) + let table_scan = test_table_scan()?; + let expr = (((col("a") + col("b") / col("c")) * col("c")) + / (col("c") * (col("b") / col("c") + col("a"))) + + col("a")) + .eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)\ + \n Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // c2 / (c1 + c3) <=> c2 / (c3 + c1) + let table_scan = test_table_scan()?; + let expr = ((col("b") / (col("a") + col("c"))) + * (col("b") / (col("c") + col("a")))) + .eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ + \n Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[derive(Debug)] + pub struct TestUdf { + signature: Signature, + } + + impl TestUdf { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for TestUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "my_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke(&self, _: &[ColumnarValue]) -> Result { + panic!("not implemented") + } + } + + #[test] + fn test_normalize_inner_binary_expression() -> Result<()> { + // Not(a == b) <=> Not(b == a) + let table_scan = test_table_scan()?; + let expr1 = not(col("a").eq(col("b"))); + let expr2 = not(col("b").eq(col("a"))); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? 
+ .build()?; + let expected = "Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a\ + \n Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // is_null(a == b) <=> is_null(b == a) + let table_scan = test_table_scan()?; + let expr1 = is_null(col("a").eq(col("b"))); + let expr2 = is_null(col("b").eq(col("a"))); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? + .build()?; + let expected = "Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL\ + \n Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // a + b between 0 and 10 <=> b + a between 0 and 10 + let table_scan = test_table_scan()?; + let expr1 = (col("a") + col("b")).between(lit(0), lit(10)); + let expr2 = (col("b") + col("a")).between(lit(0), lit(10)); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? + .build()?; + let expected = "Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)\ + \n Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // c between a + b and 10 <=> c between b + a and 10 + let table_scan = test_table_scan()?; + let expr1 = col("c").between(col("a") + col("b"), lit(10)); + let expr2 = col("c").between(col("b") + col("a"), lit(10)); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? + .build()?; + let expected = "Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)\ + \n Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // function call with argument <=> function call with argument + let udf = ScalarUDF::from(TestUdf::new()); + let table_scan = test_table_scan()?; + let expr1 = udf.call(vec![col("a") + col("b")]); + let expr2 = udf.call(vec![col("b") + col("a")]); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? + .build()?; + let expected = "Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)\ + \n Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + Ok(()) + } + /// returns a "random" function that is marked volatile (aka each invocation /// returns a different value) /// @@ -2016,7 +1704,11 @@ mod test { Ok(DataType::Float64) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { unimplemented!() } } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 154b3faf4711d..9de905f42cd18 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -17,22 +17,24 @@ //! 
[`PullUpCorrelatedExpr`] converts correlated subqueries to `Joins` -use std::collections::{BTreeSet, HashMap}; +use std::collections::BTreeSet; use std::ops::Deref; use std::sync::Arc; use crate::simplify_expressions::ExprSimplifier; -use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; +use datafusion_common::{plan_err, Column, DFSchemaRef, HashMap, Result, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; +use datafusion_expr::utils::{ + collect_subquery_cols, conjunction, find_join_exprs, split_conjunction, +}; use datafusion_expr::{ - expr, lit, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, Scalar, + expr, lit, BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan, + LogicalPlanBuilder, Operator, Scalar, }; use datafusion_physical_expr::execution_props::ExecutionProps; @@ -51,6 +53,9 @@ pub struct PullUpCorrelatedExpr { pub exists_sub_query: bool, /// Can the correlated expressions be pulled up. Defaults to **TRUE** pub can_pull_up: bool, + /// Indicates if we encounter any correlated expression that can not be pulled up + /// above a aggregation without changing the meaning of the query. + can_pull_over_aggregation: bool, /// Do we need to handle [the Count bug] during the pull up process /// /// [the Count bug]: https://github.com/apache/datafusion/pull/10500 @@ -75,6 +80,7 @@ impl PullUpCorrelatedExpr { in_predicate_opt: None, exists_sub_query: false, can_pull_up: true, + can_pull_over_aggregation: true, need_handle_count_bug: false, collected_count_expr_map: HashMap::new(), pull_up_having_expr: None, @@ -154,6 +160,11 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { match &plan { LogicalPlan::Filter(plan_filter) => { let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); + self.can_pull_over_aggregation = self.can_pull_over_aggregation + && subquery_filter_exprs + .iter() + .filter(|e| e.contains_outer()) + .all(|&e| can_pullup_over_aggregation(e)); let (mut join_filters, subquery_filters) = find_join_exprs(subquery_filter_exprs)?; if let Some(in_predicate) = &self.in_predicate_opt { @@ -259,6 +270,12 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { LogicalPlan::Aggregate(aggregate) if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => { + // If the aggregation is from a distinct it will not change the result for + // exists/in subqueries so we can still pull up all predicates. + let is_distinct = aggregate.aggr_expr.is_empty(); + if !is_distinct { + self.can_pull_up = self.can_pull_up && self.can_pull_over_aggregation; + } let mut local_correlated_cols = BTreeSet::new(); collect_local_correlated_cols( &plan, @@ -329,16 +346,15 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = match (self.exists_sub_query, self.join_filters.is_empty()) { // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) - (true, false) => Transformed::yes( - if limit.fetch.filter(|limit_row| *limit_row == 0).is_some() { + (true, false) => Transformed::yes(match limit.get_fetch_type()? { + FetchType::Literal(Some(0)) => { LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: Arc::clone(limit.input.schema()), }) - } else { - LogicalPlanBuilder::from((*limit.input).clone()).build()? 
- }, - ), + } + _ => LogicalPlanBuilder::from((*limit.input).clone()).build()?, + }), _ => Transformed::no(plan), }; if let Some(input_map) = input_expr_map { @@ -386,6 +402,33 @@ impl PullUpCorrelatedExpr { } } +fn can_pullup_over_aggregation(expr: &Expr) -> bool { + if let Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) = expr + { + match (left.deref(), right.deref()) { + (Expr::Column(_), right) => !right.any_column_refs(), + (left, Expr::Column(_)) => !left.any_column_refs(), + (Expr::Cast(Cast { expr, .. }), right) + if matches!(expr.deref(), Expr::Column(_)) => + { + !right.any_column_refs() + } + (left, Expr::Cast(Cast { expr, .. })) + if matches!(expr.deref(), Expr::Column(_)) => + { + !left.any_column_refs() + } + (_, _) => false, + } + } else { + false + } +} + fn collect_local_correlated_cols( plan: &LogicalPlan, all_cols_map: &HashMap>, diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index d1ac80003ba71..a87688c1a3179 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -27,11 +27,11 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::{internal_err, plan_err, Column, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; -use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned}; +use datafusion_expr::utils::{conjunction, split_conjunction_owned}; use datafusion_expr::{ exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, @@ -48,79 +48,6 @@ impl DecorrelatePredicateSubquery { pub fn new() -> Self { Self::default() } - - fn rewrite_subquery( - &self, - mut subquery: Subquery, - config: &dyn OptimizerConfig, - ) -> Result { - subquery.subquery = Arc::new( - self.rewrite(Arc::unwrap_or_clone(subquery.subquery), config)? 
- .data, - ); - Ok(subquery) - } - - /// Finds expressions that have the predicate subqueries (and recurses when found) - /// - /// # Arguments - /// - /// * `predicate` - A conjunction to split and search - /// * `optimizer_config` - For generating unique subquery aliases - /// - /// Returns a tuple (subqueries, non-subquery expressions) - fn extract_subquery_exprs( - &self, - predicate: Expr, - config: &dyn OptimizerConfig, - ) -> Result<(Vec, Vec)> { - let filters = split_conjunction_owned(predicate); // TODO: add ExistenceJoin to support disjunctions - - let mut subqueries = vec![]; - let mut others = vec![]; - for it in filters.into_iter() { - match it { - Expr::Not(not_expr) => match *not_expr { - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { - let new_subquery = self.rewrite_subquery(subquery, config)?; - subqueries.push(SubqueryInfo::new_with_in_expr( - new_subquery, - *expr, - !negated, - )); - } - Expr::Exists(Exists { subquery, negated }) => { - let new_subquery = self.rewrite_subquery(subquery, config)?; - subqueries.push(SubqueryInfo::new(new_subquery, !negated)); - } - expr => others.push(not(expr)), - }, - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { - let new_subquery = self.rewrite_subquery(subquery, config)?; - subqueries.push(SubqueryInfo::new_with_in_expr( - new_subquery, - *expr, - negated, - )); - } - Expr::Exists(Exists { subquery, negated }) => { - let new_subquery = self.rewrite_subquery(subquery, config)?; - subqueries.push(SubqueryInfo::new(new_subquery, negated)); - } - expr => others.push(expr), - } - } - - Ok((subqueries, others)) - } } impl OptimizerRule for DecorrelatePredicateSubquery { @@ -133,69 +60,51 @@ impl OptimizerRule for DecorrelatePredicateSubquery { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { + let plan = plan + .map_subqueries(|subquery| { + subquery.transform_down(|p| self.rewrite(p, config)) + })? + .data; + let LogicalPlan::Filter(filter) = plan else { return Ok(Transformed::no(plan)); }; - // if there are no subqueries in the predicate, return the original plan - let has_subqueries = - split_conjunction(&filter.predicate) - .iter() - .any(|expr| match expr { - Expr::Not(not_expr) => { - matches!(not_expr.as_ref(), Expr::InSubquery(_) | Expr::Exists(_)) - } - Expr::InSubquery(_) | Expr::Exists(_) => true, - _ => false, - }); - - if !has_subqueries { + if !has_subquery(&filter.predicate) { return Ok(Transformed::no(LogicalPlan::Filter(filter))); } - let Filter { - predicate, input, .. - } = filter; - let (subqueries, mut other_exprs) = - self.extract_subquery_exprs(predicate, config)?; - if subqueries.is_empty() { + let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = + split_conjunction_owned(filter.predicate) + .into_iter() + .partition(has_subquery); + + if with_subqueries.is_empty() { return internal_err!( "can not find expected subqueries in DecorrelatePredicateSubquery" ); } // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = Arc::unwrap_or_clone(input); - for subquery in subqueries { - if let Some(plan) = - build_join(&subquery, &cur_input, config.alias_generator())? 
- { - cur_input = plan; - } else { - // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter - let sub_query_expr = match subquery { - SubqueryInfo { - query, - where_in_expr: Some(expr), - negated: false, - } => in_subquery(expr, query.subquery), - SubqueryInfo { - query, - where_in_expr: Some(expr), - negated: true, - } => not_in_subquery(expr, query.subquery), - SubqueryInfo { - query, - where_in_expr: None, - negated: false, - } => exists(query.subquery), - SubqueryInfo { - query, - where_in_expr: None, - negated: true, - } => not_exists(query.subquery), - }; - other_exprs.push(sub_query_expr); + let mut cur_input = Arc::unwrap_or_clone(filter.input); + for subquery_expr in with_subqueries { + match extract_subquery_info(subquery_expr) { + // The subquery expression is at the top level of the filter + SubqueryPredicate::Top(subquery) => { + match build_join_top(&subquery, &cur_input, config.alias_generator())? + { + Some(plan) => cur_input = plan, + // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter + None => other_exprs.push(subquery.expr()), + } + } + // The subquery expression is embedded within another expression + SubqueryPredicate::Embedded(expr) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = plan; + other_exprs.push(expr_without_subqueries); + } } } @@ -216,6 +125,101 @@ impl OptimizerRule for DecorrelatePredicateSubquery { } } +fn rewrite_inner_subqueries( + outer: LogicalPlan, + expr: Expr, + config: &dyn OptimizerConfig, +) -> Result<(LogicalPlan, Expr)> { + let mut cur_input = outer; + let alias = config.alias_generator(); + let expr_without_subqueries = expr.transform(|e| match e { + Expr::Exists(Exists { + subquery: Subquery { subquery, .. }, + negated, + }) => match mark_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? { + Some((plan, exists_expr)) => { + cur_input = plan; + Ok(Transformed::yes(exists_expr)) + } + None if negated => Ok(Transformed::no(not_exists(subquery))), + None => Ok(Transformed::no(exists(subquery))), + }, + Expr::InSubquery(InSubquery { + expr, + subquery: Subquery { subquery, .. }, + negated, + }) => { + let in_predicate = subquery + .head_output_expr()? + .map_or(plan_err!("single expression required."), |output_expr| { + Ok(Expr::eq(*expr.clone(), output_expr)) + })?; + match mark_join( + &cur_input, + Arc::clone(&subquery), + Some(in_predicate), + negated, + alias, + )? 
{ + Some((plan, exists_expr)) => { + cur_input = plan; + Ok(Transformed::yes(exists_expr)) + } + None if negated => Ok(Transformed::no(not_in_subquery(*expr, subquery))), + None => Ok(Transformed::no(in_subquery(*expr, subquery))), + } + } + _ => Ok(Transformed::no(e)), + })?; + Ok((cur_input, expr_without_subqueries.data)) +} + +enum SubqueryPredicate { + // The subquery expression is at the top level of the filter and can be fully replaced by a + // semi/anti join + Top(SubqueryInfo), + // The subquery expression is embedded within another expression and is replaced using an + // existence join + Embedded(Expr), +} + +fn extract_subquery_info(expr: Expr) -> SubqueryPredicate { + match expr { + Expr::Not(not_expr) => match *not_expr { + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( + subquery, *expr, !negated, + )), + Expr::Exists(Exists { subquery, negated }) => { + SubqueryPredicate::Top(SubqueryInfo::new(subquery, !negated)) + } + expr => SubqueryPredicate::Embedded(not(expr)), + }, + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( + subquery, *expr, negated, + )), + Expr::Exists(Exists { subquery, negated }) => { + SubqueryPredicate::Top(SubqueryInfo::new(subquery, negated)) + } + expr => SubqueryPredicate::Embedded(expr), + } +} + +fn has_subquery(expr: &Expr) -> bool { + expr.exists(|e| match e { + Expr::InSubquery(_) | Expr::Exists(_) => Ok(true), + _ => Ok(false), + }) + .unwrap() +} + /// Optimize the subquery to left-anti/left-semi join. /// If the subquery is a correlated subquery, we need extract the join predicate from the subquery. /// @@ -246,7 +250,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery { /// Projection: t2.id /// TableScan: t2 /// ``` -fn build_join( +fn build_join_top( query_info: &SubqueryInfo, left: &LogicalPlan, alias: &Arc, @@ -265,9 +269,55 @@ fn build_join( }) .map_or(Ok(None), |v| v.map(Some))?; + let join_type = match query_info.negated { + true => JoinType::LeftAnti, + false => JoinType::LeftSemi, + }; let subquery = query_info.query.subquery.as_ref(); let subquery_alias = alias.next("__correlated_sq"); + build_join(left, subquery, in_predicate_opt, join_type, subquery_alias) +} + +/// This is used to handle the case when the subquery is embedded in a more complex boolean +/// expression like and OR. For example +/// +/// `select t1.id from t1 where t1.id < 0 OR exists(SELECT t2.id FROM t2 WHERE t1.id = t2.id)` +/// +/// The optimized plan will be: +/// +/// ```text +/// Projection: t1.id +/// Filter: t1.id < 0 OR __correlated_sq_1.mark +/// LeftMark Join: Filter: t1.id = __correlated_sq_1.id +/// TableScan: t1 +/// SubqueryAlias: __correlated_sq_1 +/// Projection: t2.id +/// TableScan: t2 +fn mark_join( + left: &LogicalPlan, + subquery: Arc, + in_predicate_opt: Option, + negated: bool, + alias_generator: &Arc, +) -> Result> { + let alias = alias_generator.next("__correlated_sq"); + + let exists_col = Expr::Column(Column::new(Some(alias.clone()), "mark")); + let exists_expr = if negated { !exists_col } else { exists_col }; + + Ok( + build_join(left, &subquery, in_predicate_opt, JoinType::LeftMark, alias)? 
+ .map(|plan| (plan, exists_expr)), + ) +} +fn build_join( + left: &LogicalPlan, + subquery: &LogicalPlan, + in_predicate_opt: Option, + join_type: JoinType, + alias: String, +) -> Result> { let mut pull_up = PullUpCorrelatedExpr::new() .with_in_predicate_opt(in_predicate_opt.clone()) .with_exists_sub_query(in_predicate_opt.is_none()); @@ -278,7 +328,7 @@ fn build_join( } let sub_query_alias = LogicalPlanBuilder::from(new_plan) - .alias(subquery_alias.to_string())? + .alias(alias.to_string())? .build()?; let mut all_correlated_cols = BTreeSet::new(); pull_up @@ -287,10 +337,9 @@ fn build_join( .for_each(|cols| all_correlated_cols.extend(cols.clone())); // alias the join filter - let join_filter_opt = - conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { - replace_qualified_name(filter, &all_correlated_cols, &subquery_alias) - .map(Option::Some) + let join_filter_opt = conjunction(pull_up.join_filters) + .map_or(Ok(None), |filter| { + replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some) })?; if let Some(join_filter) = match (join_filter_opt, in_predicate_opt) { @@ -302,7 +351,7 @@ fn build_join( right, })), ) => { - let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?; + let right_col = create_col_from_scalar_expr(right.deref(), alias)?; let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); Some(in_predicate.and(join_filter)) } @@ -315,17 +364,13 @@ fn build_join( right, })), ) => { - let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?; + let right_col = create_col_from_scalar_expr(right.deref(), alias)?; let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); Some(in_predicate) } _ => None, } { // join our sub query into the main plan - let join_type = match query_info.negated { - true => JoinType::LeftAnti, - false => JoinType::LeftSemi, - }; let new_plan = LogicalPlanBuilder::from(left.clone()) .join_on(sub_query_alias, join_type, Some(join_filter))? .build()?; @@ -361,6 +406,19 @@ impl SubqueryInfo { negated, } } + + pub fn expr(self) -> Expr { + match self.where_in_expr { + Some(expr) => match self.negated { + true => not_in_subquery(expr, self.query.subquery), + false => in_subquery(expr, self.query.subquery), + }, + None => match self.negated { + true => not_exists(self.query.subquery), + false => exists(self.query.subquery), + }, + } + } } #[cfg(test)] @@ -371,7 +429,7 @@ mod tests { use crate::test::*; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_expr::{and, binary_expr, col, lit, not, or, out_ref_col, table_scan}; + use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan}; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( @@ -442,60 +500,6 @@ mod tests { assert_optimized_plan_equal(plan, expected) } - /// Test for IN subquery with additional OR filter - /// filter expression not modified - #[test] - fn in_subquery_with_or_filters() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(or( - and( - binary_expr(col("a"), Operator::Eq, lit(1_u32)), - binary_expr(col("b"), Operator::Lt, lit(30_u32)), - ), - in_subquery(col("c"), test_subquery_with_name("sq")?), - ))? - .project(vec![col("test.b")])? 
- .build()?; - - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.a = UInt32(1) AND test.b < UInt32(30) OR test.c IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) - } - - #[test] - fn in_subquery_with_and_or_filters() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(and( - or( - binary_expr(col("a"), Operator::Eq, lit(1_u32)), - in_subquery(col("b"), test_subquery_with_name("sq1")?), - ), - in_subquery(col("c"), test_subquery_with_name("sq2")?), - ))? - .project(vec![col("test.b")])? - .build()?; - - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.a = UInt32(1) OR test.b IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [c:UInt32]\ - \n Projection: sq1.c [c:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq2.c [c:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) - } - /// Test for nested IN subqueries #[test] fn in_subquery_nested() -> Result<()> { @@ -512,51 +516,19 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\ + \n SubqueryAlias: __correlated_sq_2 [a:UInt32]\ \n Projection: sq.a [a:UInt32]\ - \n LeftSemi Join: Filter: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq_nested.c [c:UInt32]\ \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) } - /// Test for filter input modification in case filter not supported - /// Outer filter expression not modified while inner converted to join - #[test] - fn in_subquery_input_modified() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(in_subquery(col("c"), test_subquery_with_name("sq_inner")?))? - .project(vec![col("b"), col("c")])? - .alias("wrapped")? - .filter(or( - binary_expr(col("b"), Operator::Lt, lit(30_u32)), - in_subquery(col("c"), test_subquery_with_name("sq_outer")?), - ))? - .project(vec![col("b")])? 
- .build()?; - - let expected = "Projection: wrapped.b [b:UInt32]\ - \n Filter: wrapped.b < UInt32(30) OR wrapped.c IN () [b:UInt32, c:UInt32]\ - \n Subquery: [c:UInt32]\ - \n Projection: sq_outer.c [c:UInt32]\ - \n TableScan: sq_outer [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: wrapped [b:UInt32, c:UInt32]\ - \n Projection: test.b, test.c [b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq_inner.c [c:UInt32]\ - \n TableScan: sq_inner [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) - } - /// Test multiple correlated subqueries /// See subqueries.rs where_in_multiple() #[test] @@ -630,13 +602,13 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ + \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; @@ -863,7 +835,7 @@ mod tests { .build()?; // Maybe okay if the table only has a single column? - let expected = "check_analyzed_plan\ + let expected = "Invalid (non-executable) plan after Analyzer\ \ncaused by\ \nError during planning: InSubquery should only return one column, but found 4"; assert_analyzer_check_err(vec![], plan, expected); @@ -958,7 +930,7 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ + let expected = "Invalid (non-executable) plan after Analyzer\ \ncaused by\ \nError during planning: InSubquery should only return one column"; assert_analyzer_check_err(vec![], plan, expected); @@ -1003,44 +975,6 @@ mod tests { Ok(()) } - /// Test for correlated IN subquery filter with disjustions - #[test] - fn in_subquery_disjunction() -> Result<()> { - let sq = Arc::new( - LogicalPlanBuilder::from(scan_tpch_table("orders")) - .filter( - out_ref_col(DataType::Int64, "customer.c_custkey") - .eq(col("orders.o_custkey")), - )? - .project(vec![col("orders.o_custkey")])? - .build()?, - ); - - let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) - .filter( - in_subquery(col("customer.c_custkey"), sq) - .or(col("customer.c_custkey").eq(lit(1))), - )? - .project(vec![col("customer.c_custkey")])? 
- .build()?; - - // TODO: support disjunction - for now expect unaltered plan - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - Filter: customer.c_custkey IN () OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] - Subquery: [o_custkey:Int64] - Projection: orders.o_custkey [o_custkey:Int64] - Filter: outer_ref(customer.c_custkey) = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), - plan, - expected, - ); - Ok(()) - } - /// Test for correlated IN subquery filter #[test] fn in_subquery_correlated() -> Result<()> { @@ -1407,13 +1341,13 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n LeftSemi Join: Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ + \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_optimized_plan_equal(plan, expected) @@ -1659,7 +1593,7 @@ mod tests { assert_optimized_plan_equal(plan, expected) } - /// Test for correlated exists subquery filter with disjustions + /// Test for correlated exists subquery filter with disjunctions #[test] fn exists_subquery_disjunction() -> Result<()> { let sq = Arc::new( diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 550728ddd3f98..d35572e6d34a3 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -16,19 +16,18 @@ // under the License. //! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available. 
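As background for the `EliminateCrossJoin` changes that follow, here is a minimal sketch of the plan shape the rule targets, written in the style of the tests later in this hunk. It assumes the `test_table_scan_with_name` helper and the `LogicalPlanBuilder`/`col`/`lit` imports used by that test module; it is an illustration, not code from this PR.

```rust
// A filter carrying an equality predicate over a cross join is the shape
// EliminateCrossJoin rewrites: the equality becomes an inner-join key and
// any leftover predicate stays in a Filter above the new Inner Join.
fn cross_join_then_filter_example() -> Result<LogicalPlan> {
    let t1 = test_table_scan_with_name("t1")?;
    let t2 = test_table_scan_with_name("t2")?;

    // Filter: t1.a = t2.a AND t2.c < UInt32(20)
    //   Cross Join:
    //     TableScan: t1
    //     TableScan: t2
    LogicalPlanBuilder::from(t1)
        .cross_join(t2)?
        .filter(col("t1.a").eq(col("t2.a")).and(col("t2.c").lt(lit(20u32))))?
        .build()
}
```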
-use std::sync::Arc; - use crate::{OptimizerConfig, OptimizerRule}; +use std::sync::Arc; use crate::join_key_set::JoinKeySet; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::Result; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ - CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, + Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; -use datafusion_expr::{build_join_schema, ExprSchemable, Operator}; +use datafusion_expr::{and, build_join_schema, ExprSchemable, Operator}; #[derive(Default, Debug)] pub struct EliminateCrossJoin; @@ -51,7 +50,7 @@ impl EliminateCrossJoin { /// Looks like this: /// ```text /// Filter(a.x = b.y AND b.xx = 100) -/// CrossJoin +/// Cross Join /// TableScan a /// TableScan b /// ``` @@ -80,6 +79,7 @@ impl OptimizerRule for EliminateCrossJoin { true } + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn rewrite( &self, plan: LogicalPlan, @@ -88,19 +88,20 @@ impl OptimizerRule for EliminateCrossJoin { let plan_schema = Arc::clone(plan.schema()); let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec = vec![]; + let mut all_filters: Vec = vec![]; let parent_predicate = if let LogicalPlan::Filter(filter) = plan { // if input isn't a join that can potentially be rewritten // avoid unwrapping the input - let rewriteable = matches!( + let rewritable = matches!( filter.input.as_ref(), LogicalPlan::Join(Join { join_type: JoinType::Inner, .. - }) | LogicalPlan::CrossJoin(_) + }) ); - if !rewriteable { + if !rewritable { // recursively try to rewrite children return rewrite_children(self, LogicalPlan::Filter(filter), config); } @@ -116,6 +117,7 @@ impl OptimizerRule for EliminateCrossJoin { Arc::unwrap_or_clone(input), &mut possible_join_keys, &mut all_inputs, + &mut all_filters, )?; extract_possible_join_keys(&predicate, &mut possible_join_keys); @@ -130,7 +132,12 @@ impl OptimizerRule for EliminateCrossJoin { if !can_flatten_join_inputs(&plan) { return Ok(Transformed::no(plan)); } - flatten_join_inputs(plan, &mut possible_join_keys, &mut all_inputs)?; + flatten_join_inputs( + plan, + &mut possible_join_keys, + &mut all_inputs, + &mut all_filters, + )?; None } else { // recursively try to rewrite children @@ -158,6 +165,13 @@ impl OptimizerRule for EliminateCrossJoin { )); } + if !all_filters.is_empty() { + // Add any filters on top - PushDownFilter can push filters down to applicable join + let first = all_filters.swap_remove(0); + let predicate = all_filters.into_iter().fold(first, and); + left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?); + } + let Some(predicate) = parent_predicate else { return Ok(Transformed::yes(left)); }; @@ -206,37 +220,25 @@ fn flatten_join_inputs( plan: LogicalPlan, possible_join_keys: &mut JoinKeySet, all_inputs: &mut Vec, + all_filters: &mut Vec, ) -> Result<()> { match plan { LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { - // checked in can_flatten_join_inputs - if join.filter.is_some() { - return internal_err!( - "should not have filter in inner join in flatten_join_inputs" - ); + if let Some(filter) = join.filter { + all_filters.push(filter); } possible_join_keys.insert_all_owned(join.on); flatten_join_inputs( Arc::unwrap_or_clone(join.left), possible_join_keys, all_inputs, + all_filters, )?; flatten_join_inputs( 
Arc::unwrap_or_clone(join.right), possible_join_keys, all_inputs, - )?; - } - LogicalPlan::CrossJoin(join) => { - flatten_join_inputs( - Arc::unwrap_or_clone(join.left), - possible_join_keys, - all_inputs, - )?; - flatten_join_inputs( - Arc::unwrap_or_clone(join.right), - possible_join_keys, - all_inputs, + all_filters, )?; } _ => { @@ -253,30 +255,19 @@ fn flatten_join_inputs( fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { // can only flatten inner / cross joins match plan { - LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { - // The filter of inner join will lost, skip this rule. - // issue: https://github.com/apache/datafusion/issues/4844 - if join.filter.is_some() { - return false; - } - } - LogicalPlan::CrossJoin(_) => {} + LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {} _ => return false, }; for child in plan.inputs() { - match child { - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) - | LogicalPlan::CrossJoin(_) => { - if !can_flatten_join_inputs(child) { - return false; - } + if let LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) = child + { + if !can_flatten_join_inputs(child) { + return false; } - // the child is not a join/cross join - _ => (), } } true @@ -351,10 +342,15 @@ fn find_inner_join( &JoinType::Inner, )?); - Ok(LogicalPlan::CrossJoin(CrossJoin { + Ok(LogicalPlan::Join(Join { left: Arc::new(left_input), right: Arc::new(right), schema: join_schema, + on: vec![], + filter: None, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + null_equals_null: false, })) } @@ -462,12 +458,6 @@ mod tests { assert_eq!(&starting_schema, optimized_plan.schema()) } - fn assert_optimization_rule_fails(plan: LogicalPlan) { - let rule = EliminateCrossJoin::new(); - let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); - assert!(!transformed_plan.transformed) - } - #[test] fn eliminate_cross_with_simple_and() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -513,7 +503,7 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -601,7 +591,7 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -627,7 +617,7 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -637,8 +627,7 @@ mod tests { } #[test] - /// See https://github.com/apache/datafusion/issues/7530 - fn eliminate_cross_not_possible_nested_inner_join_with_filter() -> Result<()> { + fn 
eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; let t3 = test_table_scan_with_name("t3")?; @@ -655,7 +644,17 @@ mod tests { .filter(col("t1.a").gt(lit(15u32)))? .build()?; - assert_optimization_rule_fails(plan); + let expected = vec![ + "Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]" + ]; + + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -843,7 +842,7 @@ mod tests { let expected = vec![ "Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", @@ -924,7 +923,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", " Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -999,7 +998,7 @@ mod tests { "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", @@ -1237,10 +1236,10 @@ mod tests { .build()?; let expected = vec![ - "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, 
c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(plan, expected); @@ -1293,10 +1292,10 @@ mod tests { .build()?; let expected = vec![ - "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(plan, expected); diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 554985667fdf9..4669500920956 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -23,8 +23,10 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{Aggregate, Expr, Sort, SortExpr}; -use indexmap::IndexSet; use std::hash::{Hash, Hasher}; + +use indexmap::IndexSet; + /// Optimization rule that eliminate duplicated expr. #[derive(Default, Debug)] pub struct EliminateDuplicatedExpr; diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 13d03d647fe20..1213c8ffb3685 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -94,7 +94,7 @@ impl OptimizerRule for EliminateGroupByConstant { /// Checks if expression is constant, and can be eliminated from group by. /// /// Intended to be used only within this rule, helper function, which heavily -/// reiles on `SimplifyExpressions` result. +/// relies on `SimplifyExpressions` result. fn is_constant_expression(expr: &Expr) -> bool { match expr { Expr::Alias(e) => is_constant_expression(&e.expr), @@ -155,7 +155,11 @@ mod tests { fn return_type(&self, _args: &[DataType]) -> Result { Ok(DataType::Int32) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { unimplemented!() } } diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 9f4fe7bb110f7..cbdb568a7a83f 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -23,7 +23,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::JoinType::Inner; use datafusion_expr::{ logical_plan::{EmptyRelation, LogicalPlan}, - CrossJoin, Expr, + Expr, }; /// Eliminates joins when join condition is false. 
@@ -55,13 +55,6 @@ impl OptimizerRule for EliminateJoin { LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => { match &join.filter { Some(Expr::Literal(scalar)) => match scalar.value() { - ScalarValue::Boolean(Some(true)) => { - Ok(Transformed::yes(LogicalPlan::CrossJoin(CrossJoin { - left: join.left, - right: join.right, - schema: join.schema, - }))) - } ScalarValue::Boolean(Some(false)) => Ok(Transformed::yes( LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -108,21 +101,4 @@ mod tests { let expected = "EmptyRelation"; assert_optimized_plan_equal(plan, expected) } - - #[test] - fn join_on_true() -> Result<()> { - let plan = LogicalPlanBuilder::empty(false) - .join_on( - LogicalPlanBuilder::empty(false).build()?, - Inner, - Some(lit(true)), - )? - .build()?; - - let expected = "\ - CrossJoin:\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_optimized_plan_equal(plan, expected) - } } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 25304d4ccafaa..267615c3e0d93 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -20,7 +20,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; -use datafusion_expr::logical_plan::{EmptyRelation, LogicalPlan}; +use datafusion_expr::logical_plan::{EmptyRelation, FetchType, LogicalPlan, SkipType}; use std::sync::Arc; /// Optimizer rule to replace `LIMIT 0` or `LIMIT` whose ancestor LIMIT's skip is @@ -57,14 +57,16 @@ impl OptimizerRule for EliminateLimit { &self, plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result< - datafusion_common::tree_node::Transformed, - datafusion_common::DataFusionError, - > { + ) -> Result, datafusion_common::DataFusionError> { match plan { LogicalPlan::Limit(limit) => { - if let Some(fetch) = limit.fetch { - if fetch == 0 { + // Only supports rewriting for literal fetch + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + + if let Some(v) = fetch { + if v == 0 { return Ok(Transformed::yes(LogicalPlan::EmptyRelation( EmptyRelation { produce_one_row: false, @@ -72,11 +74,10 @@ impl OptimizerRule for EliminateLimit { }, ))); } - } else if limit.skip == 0 { - // input also can be Limit, so we should apply again. - return Ok(self - .rewrite(Arc::unwrap_or_clone(limit.input), _config) - .unwrap()); + } else if matches!(limit.get_skip_type()?, SkipType::Literal(0)) { + // If fetch is `None` and skip is 0, then Limit takes no effect and + // we can remove it. Its input also can be Limit, so we should apply again. + return self.rewrite(Arc::unwrap_or_clone(limit.input), _config); } Ok(Transformed::no(LogicalPlan::Limit(limit))) } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 66c7463c3d5d9..2e7a751ca4c57 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -264,6 +264,34 @@ mod tests { assert_optimized_plan_equal(plan, expected) } + #[test] + fn one_side_unqualified() -> Result<()> { + let (t1, t2) = test_tables()?; + let plan_from_exprs = LogicalPlanBuilder::from(t1.clone()) + .join_with_expr_keys( + t2.clone(), + JoinType::Inner, + (vec![col("optional_id")], vec![col("t2.optional_id")]), + None, + )? 
+ .build()?; + let plan_from_cols = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Inner, + (vec!["optional_id"], vec!["t2.optional_id"]), + None, + )? + .build()?; + let expected = "Inner Join: t1.optional_id = t2.optional_id\ + \n Filter: t1.optional_id IS NOT NULL\ + \n TableScan: t1\ + \n Filter: t2.optional_id IS NOT NULL\ + \n TableScan: t2"; + assert_optimized_plan_equal(plan_from_cols, expected)?; + assert_optimized_plan_equal(plan_from_exprs, expected) + } + fn build_plan( left_table: LogicalPlan, right_table: LogicalPlan, diff --git a/datafusion/optimizer/src/join_key_set.rs b/datafusion/optimizer/src/join_key_set.rs index c0eec78b183db..0a97173b30966 100644 --- a/datafusion/optimizer/src/join_key_set.rs +++ b/datafusion/optimizer/src/join_key_set.rs @@ -148,7 +148,7 @@ impl<'a> ExprPair<'a> { } } -impl<'a> Equivalent<(Expr, Expr)> for ExprPair<'a> { +impl Equivalent<(Expr, Expr)> for ExprPair<'_> { fn equivalent(&self, other: &(Expr, Expr)) -> bool { self.0 == &other.0 && self.1 == &other.1 } diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 3b1df3510d2a4..263770b81fcdc 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -14,6 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -51,7 +52,6 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; -pub mod rewrite_disjunctive_predicate; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index b5d581f3919f2..b7dd391586a18 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -19,15 +19,14 @@ mod required_indices; -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; - use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use std::collections::HashSet; +use std::sync::Arc; use datafusion_common::{ get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, Column, - JoinType, Result, + HashMap, JoinType, Result, }; use datafusion_expr::expr::Alias; use datafusion_expr::Unnest; @@ -36,10 +35,10 @@ use datafusion_expr::{ TableScan, Window, }; -use crate::optimize_projections::required_indices::RequiredIndicies; +use crate::optimize_projections::required_indices::RequiredIndices; use crate::utils::NamePreserver; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; /// Optimizer rule to prune unnecessary columns from intermediate schemas @@ -86,7 +85,7 @@ impl OptimizerRule for OptimizeProjections { config: &dyn OptimizerConfig, ) -> Result> { // All output fields are necessary: - let indices = RequiredIndicies::new_for_all_exprs(&plan); + let indices = RequiredIndices::new_for_all_exprs(&plan); optimize_projections(plan, config, indices) } } @@ -110,10 +109,11 @@ impl OptimizerRule for OptimizeProjections { /// columns. /// - `Ok(None)`: Signal that the given logical plan did not require any change. /// - `Err(error)`: An error occurred during the optimization process. 
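To make the column-pruning idea described above concrete, a small sketch of a plan where `OptimizeProjections` can drop unused columns. It assumes the shared `test_table_scan` helper (columns `a`, `b`, `c` of type `UInt32`) used throughout these optimizer tests and is illustrative only.

```rust
// Only `a` is referenced by the parent operators, so after optimization the
// scan no longer needs to produce `b` and `c` (for example, it can be shown
// as `TableScan: test projection=[a]` in the displayed plan).
fn unused_columns_example() -> Result<LogicalPlan> {
    LogicalPlanBuilder::from(test_table_scan()?)
        .filter(col("a").gt(lit(1u32)))?
        .project(vec![col("a")])?
        .build()
}
```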
+#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn optimize_projections( plan: LogicalPlan, config: &dyn OptimizerConfig, - indices: RequiredIndicies, + indices: RequiredIndices, ) -> Result> { // Recursively rewrite any nodes that may be able to avoid computation given // their parents' required indices. @@ -176,7 +176,7 @@ fn optimize_projections( let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); let schema = aggregate.input.schema(); let necessary_indices = - RequiredIndicies::new().with_exprs(schema, all_exprs_iter); + RequiredIndices::new().with_exprs(schema, all_exprs_iter); let necessary_exprs = necessary_indices.get_required_exprs(schema); return optimize_projections( @@ -274,7 +274,7 @@ fn optimize_projections( // For other plan node types, calculate indices for columns they use and // try to rewrite their children - let mut child_required_indices: Vec = match &plan { + let mut child_required_indices: Vec = match &plan { LogicalPlan::Sort(_) | LogicalPlan::Filter(_) | LogicalPlan::Repartition(_) @@ -295,7 +295,7 @@ fn optimize_projections( }) .collect::>()? } - LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => { + LogicalPlan::Limit(_) => { // Pass index requirements from the parent as well as column indices // that appear in this plan's expressions to its child. These operators // do not benefit from "small" inputs, so the projection_beneficial @@ -311,6 +311,7 @@ fn optimize_projections( | LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) | LogicalPlan::Subquery(_) + | LogicalPlan::Statement(_) | LogicalPlan::Distinct(Distinct::All(_)) => { // These plans require all their fields, and their children should // be treated as final plans -- otherwise, we may have schema a @@ -319,7 +320,7 @@ fn optimize_projections( // EXISTS expression), we may not need to require all indices. plan.inputs() .into_iter() - .map(RequiredIndicies::new_for_all_exprs) + .map(RequiredIndices::new_for_all_exprs) .collect() } LogicalPlan::Extension(extension) => { @@ -339,14 +340,13 @@ fn optimize_projections( .into_iter() .zip(necessary_children_indices) .map(|(child, necessary_indices)| { - RequiredIndicies::new_from_indices(necessary_indices) + RequiredIndices::new_from_indices(necessary_indices) .with_plan_exprs(&plan, child.schema()) }) .collect::>>()? } LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Statement(_) | LogicalPlan::Values(_) | LogicalPlan::DescribeTable(_) => { // These operators have no inputs, so stop the optimization process. @@ -367,17 +367,6 @@ fn optimize_projections( right_indices.with_projection_beneficial(), ] } - LogicalPlan::CrossJoin(cross_join) => { - let left_len = cross_join.left.schema().fields().len(); - let (left_indices, right_indices) = - split_join_requirements(left_len, indices, &JoinType::Inner); - // Joins benefit from "small" input tables (lower memory usage). - // Therefore, each child benefits from projection: - vec![ - left_indices.with_projection_beneficial(), - right_indices.with_projection_beneficial(), - ] - } // these nodes are explicitly rewritten in the match statement above LogicalPlan::Projection(_) | LogicalPlan::Aggregate(_) @@ -390,7 +379,7 @@ fn optimize_projections( LogicalPlan::Unnest(Unnest { dependency_indices, .. }) => { - vec![RequiredIndicies::new_from_indices( + vec![RequiredIndices::new_from_indices( dependency_indices.clone(), )] } @@ -454,7 +443,7 @@ fn optimize_projections( /// - `Ok(Some(Projection))`: Merge was beneficial and successful. 
Contains the /// merged projection. /// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place). -/// - `Err(error)`: An error occured during the function call. +/// - `Err(error)`: An error occurred during the function call. fn merge_consecutive_projections(proj: Projection) -> Result> { let Projection { expr, @@ -494,7 +483,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result( /// adjusted based on the join type. fn split_join_requirements( left_len: usize, - indices: RequiredIndicies, + indices: RequiredIndices, join_type: &JoinType, -) -> (RequiredIndicies, RequiredIndicies) { +) -> (RequiredIndices, RequiredIndices) { match join_type { // In these cases requirements are split between left/right children: - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftMark => { // Decrease right side indices by `left_len` so that they point to valid // positions within the right child: indices.split_off(left_len) } // All requirements can be re-routed to left child directly. - JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndicies::new()), + JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndices::new()), // All requirements can be re-routed to right side directly. // No need to change index, join schema is right child schema. - JoinType::RightSemi | JoinType::RightAnti => (RequiredIndicies::new(), indices), + JoinType::RightSemi | JoinType::RightAnti => (RequiredIndices::new(), indices), } } @@ -748,18 +741,18 @@ fn add_projection_on_top_if_helpful( /// /// - `Ok(Some(LogicalPlan))`: Contains the rewritten projection /// - `Ok(None)`: No rewrite necessary. -/// - `Err(error)`: An error occured during the function call. +/// - `Err(error)`: An error occurred during the function call. fn rewrite_projection_given_requirements( proj: Projection, config: &dyn OptimizerConfig, - indices: &RequiredIndicies, + indices: &RequiredIndices, ) -> Result> { let Projection { expr, input, .. } = proj; let exprs_used = indices.get_at_indices(&expr); let required_indices = - RequiredIndicies::new().with_exprs(input.schema(), exprs_used.iter()); + RequiredIndices::new().with_exprs(input.schema(), exprs_used.iter()); // rewrite the children projection, and if they are changed rewrite the // projection down diff --git a/datafusion/optimizer/src/optimize_projections/required_indices.rs b/datafusion/optimizer/src/optimize_projections/required_indices.rs index 60d8ef1a8e6ce..c1e0885c9b5f2 100644 --- a/datafusion/optimizer/src/optimize_projections/required_indices.rs +++ b/datafusion/optimizer/src/optimize_projections/required_indices.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`RequiredIndicies`] helper for OptimizeProjection +//! [`RequiredIndices`] helper for OptimizeProjection use crate::optimize_projections::outer_columns; use datafusion_common::tree_node::TreeNodeRecursion; @@ -33,9 +33,9 @@ use datafusion_expr::{Expr, LogicalPlan}; /// /// Indices are always in order and without duplicates. For example, if these /// indices were added `[3, 2, 4, 3, 6, 1]`, the instance would be represented -/// by `[1, 2, 3, 6]`. +/// by `[1, 2, 3, 4, 6]`. 
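The ordering and de-duplication invariant documented above (whose doc example this hunk corrects) boils down to the following; a self-contained sketch of the invariant only, not the actual `RequiredIndices` implementation.

```rust
// Indices are kept sorted and duplicate-free, so inserting
// [3, 2, 4, 3, 6, 1] yields [1, 2, 3, 4, 6].
fn normalized(mut indices: Vec<usize>) -> Vec<usize> {
    indices.sort_unstable();
    indices.dedup();
    indices
}

#[test]
fn normalized_example() {
    assert_eq!(normalized(vec![3, 2, 4, 3, 6, 1]), vec![1, 2, 3, 4, 6]);
}
```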
#[derive(Debug, Clone, Default)] -pub(super) struct RequiredIndicies { +pub(super) struct RequiredIndices { /// The indices of the required columns in the indices: Vec, /// If putting a projection above children is beneficial for the parent. @@ -43,7 +43,7 @@ pub(super) struct RequiredIndicies { projection_beneficial: bool, } -impl RequiredIndicies { +impl RequiredIndices { /// Create a new, empty instance pub fn new() -> Self { Self::default() diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 08dcefa22f08a..49bce3c1ce82c 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -17,19 +17,19 @@ //! [`Optimizer`] and [`OptimizerRule`] -use std::collections::HashSet; use std::fmt::Debug; use std::sync::Arc; use chrono::{DateTime, Utc}; use datafusion_expr::registry::FunctionRegistry; +use datafusion_expr::{assert_expected_schema, InvariantLevel}; use log::{debug, warn}; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; -use datafusion_common::{internal_err, DFSchema, DataFusionError, Result}; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::{internal_err, DFSchema, DataFusionError, HashSet, Result}; use datafusion_expr::logical_plan::LogicalPlan; use crate::common_subexpr_eliminate::CommonSubexprEliminate; @@ -51,7 +51,6 @@ use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; -use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; @@ -70,7 +69,6 @@ use crate::utils::log_plan; /// /// [`AnalyzerRule`]: crate::analyzer::AnalyzerRule /// [`SessionState::add_optimizer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html#method.add_optimizer_rule - pub trait OptimizerRule: Debug { /// Try and rewrite `plan` to an optimized form, returning None if the plan /// cannot be optimized by this rule. @@ -251,11 +249,6 @@ impl Optimizer { Arc::new(DecorrelatePredicateSubquery::new()), Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), - // simplify expressions does not simplify expressions in subqueries, so we - // run it again after running the optimizations that potentially converted - // subqueries to joins - Arc::new(SimplifyExpressions::new()), - Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), Arc::new(EliminateCrossJoin::new()), @@ -309,7 +302,7 @@ impl<'a> Rewriter<'a> { } } -impl<'a> TreeNodeRewriter for Rewriter<'a> { +impl TreeNodeRewriter for Rewriter<'_> { type Node = LogicalPlan; fn f_down(&mut self, node: LogicalPlan) -> Result> { @@ -363,6 +356,10 @@ impl Optimizer { where F: FnMut(&LogicalPlan, &dyn OptimizerRule), { + // verify LP is valid, before the first LP optimizer pass. 
+ plan.check_invariants(InvariantLevel::Executable) + .map_err(|e| e.context("Invalid input plan before LP Optimizers"))?; + let start_time = Instant::now(); let options = config.options(); let mut new_plan = plan; @@ -370,6 +367,8 @@ impl Optimizer { let mut previous_plans = HashSet::with_capacity(16); previous_plans.insert(LogicalPlanSignature::new(&new_plan)); + let starting_schema = Arc::clone(new_plan.schema()); + let mut i = 0; while i < options.optimizer.max_passes { log_plan(&format!("Optimizer input (pass {i})"), &new_plan); @@ -386,17 +385,22 @@ impl Optimizer { let result = match rule.apply_order() { // optimizer handles recursion - Some(apply_order) => new_plan.rewrite(&mut Rewriter::new( - apply_order, - rule.as_ref(), - config, - )), + Some(apply_order) => new_plan.rewrite_with_subqueries( + &mut Rewriter::new(apply_order, rule.as_ref(), config), + ), // rule handles recursion itself None => optimize_plan_node(new_plan, rule.as_ref(), config), } - // verify the rule didn't change the schema .and_then(|tnr| { - assert_schema_is_the_same(rule.name(), &starting_schema, &tnr.data)?; + // run optimizer-specific invariant checks, per optimizer rule applied + assert_valid_optimization(&tnr.data, &starting_schema) + .map_err(|e| e.context(format!("Check optimizer-specific invariants after optimizer rule: {}", rule.name())))?; + + // run LP invariant checks only in debug mode for performance reasons + #[cfg(debug_assertions)] + tnr.data.check_invariants(InvariantLevel::Executable) + .map_err(|e| e.context(format!("Invalid (non-executable) plan after Optimizer rule: {}", rule.name())))?; + Ok(tnr) }); @@ -455,35 +459,38 @@ impl Optimizer { } i += 1; } + + // verify that the optimizer passes only mutated what was permitted. + assert_valid_optimization(&new_plan, &starting_schema).map_err(|e| { + e.context("Check optimizer-specific invariants after all passes") + })?; + + // verify LP is valid, after the last optimizer pass. + new_plan + .check_invariants(InvariantLevel::Executable) + .map_err(|e| { + e.context("Invalid (non-executable) plan after LP Optimizers") + })?; + log_plan("Final optimized plan", &new_plan); debug!("Optimizer took {} ms", start_time.elapsed().as_millis()); Ok(new_plan) } } -/// Returns an error if `new_plan`'s schema is different than `prev_schema` +/// These are invariants which should hold true before and after [`LogicalPlan`] optimization. /// -/// It ignores metadata and nullability. -pub(crate) fn assert_schema_is_the_same( - rule_name: &str, - prev_schema: &DFSchema, - new_plan: &LogicalPlan, +/// This differs from [`LogicalPlan::check_invariants`], which addresses whether a singular +/// LogicalPlan is valid. Instead this addresses whether the optimization was valid based upon permitted changes.
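The invariant guarded above is that an optimizer pass must not change the plan's resolved schema. A minimal sketch of such a check, reusing the `equivalent_names_and_types` comparison from the removed `assert_schema_is_the_same`; `ensure_schema_unchanged` is a hypothetical helper, not part of the crate:

    use datafusion_common::{internal_err, DFSchema, Result};
    use datafusion_expr::LogicalPlan;

    // Hypothetical helper: fail if an optimizer pass changed the plan schema.
    // Names and types must match; metadata and nullability are ignored.
    fn ensure_schema_unchanged(before: &DFSchema, after: &LogicalPlan) -> Result<()> {
        if after.schema().equivalent_names_and_types(before) {
            Ok(())
        } else {
            internal_err!("optimizer pass changed the plan schema")
        }
    }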
+fn assert_valid_optimization( + plan: &LogicalPlan, + prev_schema: &Arc, ) -> Result<()> { - let equivalent = new_plan.schema().equivalent_names_and_types(prev_schema); + // verify invariant: optimizer passes should not change the schema + // Refer to + assert_expected_schema(prev_schema, plan)?; - if !equivalent { - let e = DataFusionError::Internal(format!( - "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", - prev_schema, - new_plan.schema() - )); - Err(DataFusionError::Context( - String::from(rule_name), - Box::new(e), - )) - } else { - Ok(()) - } + Ok(()) } #[cfg(test)] @@ -537,9 +544,11 @@ mod tests { schema: Arc::new(DFSchema::empty()), }); let err = opt.optimize(plan, &config, &observe).unwrap_err(); - assert_eq!( + assert!(err.strip_backtrace().starts_with( "Optimizer rule 'get table_scan rule' failed\n\ - caused by\nget table_scan rule\ncaused by\n\ + caused by\n\ + Check optimizer-specific invariants after optimizer rule: get table_scan rule\n\ + caused by\n\ Internal error: Failed due to a difference in schemas, \ original schema: DFSchema { inner: Schema { \ fields: [], \ @@ -555,10 +564,8 @@ mod tests { ], \ metadata: {} }, \ field_qualifiers: [Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" })], \ - functional_dependencies: FunctionalDependencies { deps: [] } }.\n\ - This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker", - err.strip_backtrace() - ); + functional_dependencies: FunctionalDependencies { deps: [] } }", + )); } #[test] diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index b5e1077ee5bea..d26df073dc6fd 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -72,19 +72,6 @@ impl OptimizerRule for PropagateEmptyRelation { } Ok(Transformed::no(plan)) } - LogicalPlan::CrossJoin(ref join) => { - let (left_empty, right_empty) = binary_plan_children_is_empty(&plan)?; - if left_empty || right_empty { - return Ok(Transformed::yes(LogicalPlan::EmptyRelation( - EmptyRelation { - produce_one_row: false, - schema: Arc::clone(plan.schema()), - }, - ))); - } - Ok(Transformed::no(LogicalPlan::CrossJoin(join.clone()))) - } - LogicalPlan::Join(ref join) => { // TODO: For Join, more join type need to be careful: // For LeftOut/Full Join, if the right side is empty, the Join can be eliminated with a Projection with left side diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index cdca86505dfaa..731fb27d28004 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1,3 +1,6 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance @@ -14,33 +17,30 @@ //! 
[`PushDownFilter`] applies filters as early as possible -use indexmap::IndexSet; use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use indexmap::IndexSet; use itertools::Itertools; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - internal_err, plan_err, qualified_name, Column, DFSchema, DFSchemaRef, - JoinConstraint, Result, + internal_err, plan_err, qualified_name, Column, DFSchema, Result, }; +use datafusion_expr::expr::WindowFunction; use datafusion_expr::expr_rewriter::replace_col; -use datafusion_expr::logical_plan::{ - CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union, -}; +use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union}; use datafusion_expr::utils::{ conjunction, expr_to_columns, split_conjunction, split_conjunction_owned, }; use datafusion_expr::{ - and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator, - Projection, TableProviderFilterPushDown, + and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, }; use crate::optimizer::ApplyOrder; -use crate::utils::has_all_column_refs; +use crate::utils::{has_all_column_refs, is_restrict_null_predicate}; use crate::{OptimizerConfig, OptimizerRule}; /// Optimizer rule for pushing (moving) filter expressions down in a plan so @@ -165,7 +165,7 @@ pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) { JoinType::Full => (false, false), // No columns from the right side of the join can be referenced in output // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::LeftSemi | JoinType::LeftAnti => (true, false), + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false), // No columns from the left side of the join can be referenced in output // predicates for semi/anti joins, so whether we specify t/f doesn't matter. JoinType::RightSemi | JoinType::RightAnti => (false, true), @@ -190,6 +190,7 @@ pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) { JoinType::LeftSemi | JoinType::RightSemi => (true, true), JoinType::LeftAnti => (false, true), JoinType::RightAnti => (true, false), + JoinType::LeftMark => (false, true), } } @@ -562,10 +563,6 @@ fn infer_join_predicates( predicates: &[Expr], on_filters: &[Expr], ) -> Result> { - if join.join_type != JoinType::Inner { - return Ok(vec![]); - } - // Only allow both side key is column. let join_col_keys = join .on @@ -577,55 +574,178 @@ fn infer_join_predicates( }) .collect::>(); - // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down - // For inner joins, duplicate filters for joined columns so filters can be pushed down - // to both sides. Take the following query as an example: - // - // ```sql - // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 - // ``` - // - // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while - // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. - // - // Join clauses with `Using` constraints also take advantage of this logic to make sure - // predicates reference the shared join columns are pushed to both sides. 
- // This logic should also been applied to conditions in JOIN ON clause - predicates - .iter() - .chain(on_filters.iter()) - .filter_map(|predicate| { - let mut join_cols_to_replace = HashMap::new(); - - let columns = predicate.column_refs(); - - for &col in columns.iter() { - for (l, r) in join_col_keys.iter() { - if col == *l { - join_cols_to_replace.insert(col, *r); - break; - } else if col == *r { - join_cols_to_replace.insert(col, *l); - break; - } - } - } + let join_type = join.join_type; - if join_cols_to_replace.is_empty() { - return None; - } + let mut inferred_predicates = InferredPredicates::new(join_type); - let join_side_predicate = - match replace_col(predicate.clone(), &join_cols_to_replace) { - Ok(p) => p, - Err(e) => { - return Some(Err(e)); - } - }; + infer_join_predicates_from_predicates( + &join_col_keys, + predicates, + &mut inferred_predicates, + )?; - Some(Ok(join_side_predicate)) - }) - .collect::>>() + infer_join_predicates_from_on_filters( + &join_col_keys, + join_type, + on_filters, + &mut inferred_predicates, + )?; + + Ok(inferred_predicates.predicates) +} + +/// Inferred predicates collector. +/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly +/// filter out NULL, otherwise ignore it. e.g. +/// ```text +/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL; +/// ``` +/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to +/// the left side, resulting in the wrong result. +struct InferredPredicates { + predicates: Vec, + is_inner_join: bool, +} + +impl InferredPredicates { + fn new(join_type: JoinType) -> Self { + Self { + predicates: vec![], + is_inner_join: matches!(join_type, JoinType::Inner), + } + } + + fn try_build_predicate( + &mut self, + predicate: Expr, + replace_map: &HashMap<&Column, &Column>, + ) -> Result<()> { + if self.is_inner_join + || matches!( + is_restrict_null_predicate( + predicate.clone(), + replace_map.keys().cloned() + ), + Ok(true) + ) + { + self.predicates.push(replace_col(predicate, replace_map)?); + } + + Ok(()) + } +} + +/// Infer predicates from the pushed down predicates. +/// +/// Parameters +/// * `join_col_keys` column pairs from the join ON clause +/// +/// * `predicates` the pushed down predicates +/// +/// * `inferred_predicates` the inferred results +/// +fn infer_join_predicates_from_predicates( + join_col_keys: &[(&Column, &Column)], + predicates: &[Expr], + inferred_predicates: &mut InferredPredicates, +) -> Result<()> { + infer_join_predicates_impl::( + join_col_keys, + predicates, + inferred_predicates, + ) +} + +/// Infer predicates from the join filter. 
+/// +/// Parameters +/// * `join_col_keys` column pairs from the join ON clause +/// +/// * `join_type` the JoinType of Join +/// +/// * `on_filters` filters from the join ON clause that have not already been +/// identified as join predicates +/// +/// * `inferred_predicates` the inferred results +/// +fn infer_join_predicates_from_on_filters( + join_col_keys: &[(&Column, &Column)], + join_type: JoinType, + on_filters: &[Expr], + inferred_predicates: &mut InferredPredicates, +) -> Result<()> { + match join_type { + JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()), + JoinType::Inner => infer_join_predicates_impl::( + join_col_keys, + on_filters, + inferred_predicates, + ), + JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => { + infer_join_predicates_impl::( + join_col_keys, + on_filters, + inferred_predicates, + ) + } + JoinType::Right | JoinType::RightSemi => { + infer_join_predicates_impl::( + join_col_keys, + on_filters, + inferred_predicates, + ) + } + } +} + +/// Infer predicates from the given predicates. +/// +/// Parameters +/// * `join_col_keys` column pairs from the join ON clause +/// +/// * `input_predicates` the given predicates. It can be the pushed down predicates, +/// or it can be the filters of the Join +/// +/// * `inferred_predicates` the inferred results +/// +/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can +/// be inferred from the left table related predicate +/// +/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can +/// be inferred from the right table related predicate +/// +fn infer_join_predicates_impl< + const ENABLE_LEFT_TO_RIGHT: bool, + const ENABLE_RIGHT_TO_LEFT: bool, +>( + join_col_keys: &[(&Column, &Column)], + input_predicates: &[Expr], + inferred_predicates: &mut InferredPredicates, +) -> Result<()> { + for predicate in input_predicates { + let mut join_cols_to_replace = HashMap::new(); + + for &col in &predicate.column_refs() { + for (l, r) in join_col_keys.iter() { + if ENABLE_LEFT_TO_RIGHT && col == *l { + join_cols_to_replace.insert(col, *r); + break; + } + if ENABLE_RIGHT_TO_LEFT && col == *r { + join_cols_to_replace.insert(col, *l); + break; + } + } + } + if join_cols_to_replace.is_empty() { + continue; + } + + inferred_predicates + .try_build_predicate(predicate.clone(), &join_cols_to_replace)?; + } + Ok(()) } impl OptimizerRule for PushDownFilter { @@ -866,31 +986,127 @@ impl OptimizerRule for PushDownFilter { } }) } - LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), - LogicalPlan::CrossJoin(cross_join) => { + // Tries to push filters based on the partition key(s) of the window function(s) used. + // Example: + // Before: + // Filter: (a > 1) and (b > 1) and (c > 1) + // Window: func() PARTITION BY [a] ... + // --- + // After: + // Filter: (b > 1) and (c > 1) + // Window: func() PARTITION BY [a] ... + // Filter: (a > 1) + LogicalPlan::Window(window) => { + // Retrieve the set of potential partition keys where we can push filters by. + // Unlike aggregations, where there is only one statement per SELECT, there can be + // multiple window functions, each with potentially different partition keys. + // Therefore, we need to ensure that any potential partition key returned is used in + // ALL window functions. Otherwise, filters cannot be pushed down through that column.
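The comment above boils down to a set intersection followed by a predicate split. A minimal, self-contained sketch of that decision, with columns and predicates as plain strings purely for illustration (no DataFusion types or APIs are used):

    use std::collections::HashSet;

    // Illustrative only: a predicate may move below the window node when every
    // column it references is a partition key of *every* window expression.
    fn split_predicates<'a>(
        per_window_partition_keys: &[HashSet<&'a str>],
        predicates: Vec<(&'a str, HashSet<&'a str>)>, // (predicate text, referenced columns)
    ) -> (Vec<&'a str>, Vec<&'a str>) {
        // Intersect the partition keys of all window expressions.
        let common: HashSet<&str> = per_window_partition_keys
            .iter()
            .cloned()
            .reduce(|a, b| &a & &b)
            .unwrap_or_default();

        let (mut push, mut keep) = (vec![], vec![]);
        for (pred, cols) in predicates {
            if cols.is_subset(&common) {
                push.push(pred); // safe to evaluate below the window
            } else {
                keep.push(pred); // must stay above the window
            }
        }
        (push, keep)
    }

    fn main() {
        // PARTITION BY [a] and PARTITION BY [b, a] share only {a}.
        let keys = [HashSet::from(["a"]), HashSet::from(["b", "a"])];
        let (push, keep) = split_predicates(
            &keys,
            vec![("a > 1", HashSet::from(["a"])), ("b > 1", HashSet::from(["b"]))],
        );
        assert_eq!(push, vec!["a > 1"]);
        assert_eq!(keep, vec!["b > 1"]);
    }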
+ let extract_partition_keys = |func: &WindowFunction| { + func.partition_by + .iter() + .map(|c| Column::from_qualified_name(c.schema_name().to_string())) + .collect::>() + }; + let potential_partition_keys = window + .window_expr + .iter() + .map(|e| { + match e { + Expr::WindowFunction(window_func) => { + extract_partition_keys(window_func) + } + Expr::Alias(alias) => { + if let Expr::WindowFunction(window_func) = + alias.expr.as_ref() + { + extract_partition_keys(window_func) + } else { + // window functions expressions are only Expr::WindowFunction + unreachable!() + } + } + _ => { + // window functions expressions are only Expr::WindowFunction + unreachable!() + } + } + }) + // performs the set intersection of the partition keys of all window functions, + // returning only the common ones + .reduce(|a, b| &a & &b) + .unwrap_or_default(); + let predicates = split_conjunction_owned(filter.predicate); - let join = convert_cross_join_to_inner_join(cross_join)?; - let plan = push_down_all_join(predicates, vec![], join, vec![])?; - convert_to_cross_join_if_beneficial(plan.data) + let mut keep_predicates = vec![]; + let mut push_predicates = vec![]; + for expr in predicates { + let cols = expr.column_refs(); + if cols.iter().all(|c| potential_partition_keys.contains(c)) { + push_predicates.push(expr); + } else { + keep_predicates.push(expr); + } + } + + // Unlike with aggregations, there are no cases where we have to replace, e.g., + // `a+b` with Column(a)+Column(b). This is because partition expressions are not + // available as standalone columns to the user. For example, while an aggregation on + // `a+b` becomes Column(a + b), in a window partition it becomes + // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in + // place, so we can use `push_predicates` directly. This is consistent with other + // optimizers, such as the one used by Postgres. + + let window_input = Arc::clone(&window.input); + Transformed::yes(LogicalPlan::Window(window)) + .transform_data(|new_plan| { + // If we have a filter to push, we push it down to the input of the window + if let Some(predicate) = conjunction(push_predicates) { + let new_filter = make_filter(predicate, window_input)?; + insert_below(new_plan, new_filter) + } else { + Ok(Transformed::no(new_plan)) + } + })? 
+ .map_data(|child_plan| { + // if there are any remaining predicates we can't push, add them + // back as a filter + if let Some(predicate) = conjunction(keep_predicates) { + make_filter(predicate, Arc::new(child_plan)) + } else { + Ok(child_plan) + } + }) } + LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); - let results = scan + + let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = + filter_predicates + .into_iter() + .partition(|pred| pred.is_volatile()); + + // Check which non-volatile filters are supported by source + let supported_filters = scan .source - .supports_filters_pushdown(filter_predicates.as_slice())?; - if filter_predicates.len() != results.len() { + .supports_filters_pushdown(non_volatile_filters.as_slice())?; + if non_volatile_filters.len() != supported_filters.len() { return internal_err!( "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}", - results.len(), - filter_predicates.len()); + supported_filters.len(), + non_volatile_filters.len()); } - let zip = filter_predicates.into_iter().zip(results); + // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type + let zip = non_volatile_filters.into_iter().zip(supported_filters); let new_scan_filters = zip .clone() .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported) .map(|(pred, _)| pred); + + // Add new scan filters let new_scan_filters: Vec = scan .filters .iter() @@ -898,9 +1114,13 @@ impl OptimizerRule for PushDownFilter { .unique() .cloned() .collect(); + + // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters let new_predicate: Vec = zip .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact) - .map(|(pred, _)| pred.clone()) + .map(|(pred, _)| pred) + .chain(volatile_filters) + .cloned() .collect(); let new_scan = LogicalPlan::TableScan(TableScan { @@ -1035,7 +1255,7 @@ fn rewrite_projection( (qualified_name(qualifier, field.name()), expr) }) - .partition(|(_, value)| value.is_volatile().unwrap_or(true)); + .partition(|(_, value)| value.is_volatile()); let mut push_predicates = vec![]; let mut keep_predicates = vec![]; @@ -1114,48 +1334,6 @@ impl PushDownFilter { } } -/// Converts the given cross join to an inner join with an empty equality -/// predicate and an empty filter condition. -fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { - let CrossJoin { left, right, .. } = cross_join; - let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; - Ok(Join { - left, - right, - join_type: JoinType::Inner, - join_constraint: JoinConstraint::On, - on: vec![], - filter: None, - schema: DFSchemaRef::new(join_schema), - null_equals_null: false, - }) -} - -/// Converts the given inner join with an empty equality predicate and an -/// empty filter condition to a cross join. -fn convert_to_cross_join_if_beneficial( - plan: LogicalPlan, -) -> Result> { - match plan { - // Can be converted back to cross join - LogicalPlan::Join(join) if join.on.is_empty() && join.filter.is_none() => { - LogicalPlanBuilder::from(Arc::unwrap_or_clone(join.left)) - .cross_join(Arc::unwrap_or_clone(join.right))? - .build() - .map(Transformed::yes) - } - LogicalPlan::Filter(filter) => { - convert_to_cross_join_if_beneficial(Arc::unwrap_or_clone(filter.input))? 
- .transform_data(|child_plan| { - Filter::try_new(filter.predicate, Arc::new(child_plan)) - .map(LogicalPlan::Filter) - .map(Transformed::yes) - }) - } - plan => Ok(Transformed::no(plan)), - } -} - /// replaces columns by its name on the projection. pub fn replace_cols_by_name( e: Expr, @@ -1203,17 +1381,17 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; - use datafusion_common::ScalarValue; - use datafusion_expr::expr::ScalarFunction; + use datafusion_common::{DFSchemaRef, ScalarValue}; + use datafusion_expr::expr::{ScalarFunction, WindowFunction}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ - col, in_list, in_subquery, lit, ColumnarValue, Extension, ScalarUDF, - ScalarUDFImpl, Signature, TableSource, TableType, UserDefinedLogicalNodeCore, - Volatility, + col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension, + LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType, + UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, }; use crate::optimizer::Optimizer; - use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; + use crate::simplify_expressions::SimplifyExpressions; use crate::test::*; use crate::OptimizerContext; use datafusion_expr::test::function_stub::sum; @@ -1235,7 +1413,7 @@ mod tests { expected: &str, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![ - Arc::new(RewriteDisjunctivePredicate::new()), + Arc::new(SimplifyExpressions::new()), Arc::new(PushDownFilter::new()), ]); let optimized_plan = @@ -1357,6 +1535,227 @@ mod tests { assert_optimized_plan_eq(plan, expected) } + /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed + #[test] + fn filter_move_window() -> Result<()> { + let table_scan = test_table_scan()?; + + let window = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("a"), col("b")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window])? + .filter(col("b").gt(lit(10i64)))? + .build()?; + + let expected = "\ + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test, full_filters=[test.b > Int64(10)]"; + assert_optimized_plan_eq(plan, expected) + } + + /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and + /// 'b' are pushed + #[test] + fn filter_move_complex_window() -> Result<()> { + let table_scan = test_table_scan()?; + + let window = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("a"), col("b")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window])? + .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))? 
+ .build()?; + + let expected = "\ + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]"; + assert_optimized_plan_eq(plan, expected) + } + + /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed + #[test] + fn filter_move_partial_window() -> Result<()> { + let table_scan = test_table_scan()?; + + let window = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("a")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window])? + .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))? + .build()?; + + let expected = "\ + Filter: test.b = Int64(1)\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test, full_filters=[test.a > Int64(10)]"; + assert_optimized_plan_eq(plan, expected) + } + + /// verifies that filters on partition expressions are not pushed, as the single expression + /// column is not available to the user, unlike with aggregations + #[test] + fn filter_expression_keep_window() -> Result<()> { + let table_scan = test_table_scan()?; + + let window = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![add(col("a"), col("b"))]) // PARTITION BY a + b + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window])? + // unlike with aggregations, single partition column "test.a + test.b" is not available + // to the plan, so we use multiple columns when filtering + .filter(add(col("a"), col("b")).gt(lit(10i64)))? + .build()?; + + let expected = "\ + Filter: test.a + test.b > Int64(10)\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test"; + assert_optimized_plan_eq(plan, expected) + } + + /// verifies that filters are not pushed on order by columns (that are not used in partitioning) + #[test] + fn filter_order_keep_window() -> Result<()> { + let table_scan = test_table_scan()?; + + let window = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("a")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window])? + .filter(col("c").gt(lit(10i64)))? 
+ .build()?; + + let expected = "\ + Filter: test.c > Int64(10)\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test"; + assert_optimized_plan_eq(plan, expected) + } + + /// verifies that when we use multiple window functions with a common partition key, the filter + /// on that key is pushed + #[test] + fn filter_multiple_windows_common_partitions() -> Result<()> { + let table_scan = test_table_scan()?; + + let window1 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("a")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let window2 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("b"), col("a")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window1, window2])? + .filter(col("a").gt(lit(10i64)))? // a appears in both window functions + .build()?; + + let expected = "\ + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test, full_filters=[test.a > Int64(10)]"; + assert_optimized_plan_eq(plan, expected) + } + + /// verifies that when we use multiple window functions with different partitions keys, the + /// filter cannot be pushed + #[test] + fn filter_multiple_windows_disjoint_partitions() -> Result<()> { + let table_scan = test_table_scan()?; + + let window1 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("a")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let window2 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("b"), col("a")]) + .order_by(vec![col("c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window1, window2])? + .filter(col("b").gt(lit(10i64)))? 
// b only appears in one window function + .build()?; + + let expected = "\ + Filter: test.b > Int64(10)\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: test"; + assert_optimized_plan_eq(plan, expected) + } + /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written #[test] fn alias() -> Result<()> { @@ -1727,7 +2126,7 @@ mod tests { .build()?; let expected = "Projection: test.a, test1.d\ - \n CrossJoin:\ + \n Cross Join: \ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.d, test1.e, test1.f\ @@ -1754,7 +2153,7 @@ mod tests { .build()?; let expected = "Projection: test.a, test1.a\ - \n CrossJoin:\ + \n Cross Join: \ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.a, test1.b, test1.c\ @@ -2044,7 +2443,7 @@ mod tests { let expected = "\ Filter: test2.a <= Int64(1)\ \n Left Join: Using test.a = test2.a\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2"; assert_optimized_plan_eq(plan, expected) @@ -2084,7 +2483,7 @@ mod tests { \n Right Join: Using test.a = test2.a\ \n TableScan: test\ \n Projection: test2.a\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; assert_optimized_plan_eq(plan, expected) } @@ -2439,28 +2838,36 @@ mod tests { .collect()) } - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } } - fn table_scan_with_pushdown_provider( + fn table_scan_with_pushdown_provider_builder( filter_support: TableProviderFilterPushDown, - ) -> Result { + filters: Vec, + projection: Option>, + ) -> Result { let test_provider = PushDownProvider { filter_support }; let table_scan = LogicalPlan::TableScan(TableScan { table_name: "test".into(), - filters: vec![], + filters, projected_schema: Arc::new(DFSchema::try_from( (*test_provider.schema()).clone(), )?), - projection: None, + projection, source: Arc::new(test_provider), fetch: None, }); - LogicalPlanBuilder::from(table_scan) + Ok(LogicalPlanBuilder::from(table_scan)) + } + + fn table_scan_with_pushdown_provider( + filter_support: TableProviderFilterPushDown, + ) -> Result { + table_scan_with_pushdown_provider_builder(filter_support, vec![], None)? .filter(col("a").eq(lit(1i64)))? .build() } @@ -2517,25 +2924,14 @@ mod tests { #[test] fn multi_combined_filter() -> Result<()> { - let test_provider = PushDownProvider { - filter_support: TableProviderFilterPushDown::Inexact, - }; - - let table_scan = LogicalPlan::TableScan(TableScan { - table_name: "test".into(), - filters: vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))], - projected_schema: Arc::new(DFSchema::try_from( - (*test_provider.schema()).clone(), - )?), - projection: Some(vec![0]), - source: Arc::new(test_provider), - fetch: None, - }); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))? - .project(vec![col("a"), col("b")])? - .build()?; + let plan = table_scan_with_pushdown_provider_builder( + TableProviderFilterPushDown::Inexact, + vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))], + Some(vec![0]), + )? + .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))? 
+ .project(vec![col("a"), col("b")])? + .build()?; let expected = "Projection: a, b\ \n Filter: a = Int64(10) AND b > Int64(11)\ @@ -2546,25 +2942,14 @@ mod tests { #[test] fn multi_combined_filter_exact() -> Result<()> { - let test_provider = PushDownProvider { - filter_support: TableProviderFilterPushDown::Exact, - }; - - let table_scan = LogicalPlan::TableScan(TableScan { - table_name: "test".into(), - filters: vec![], - projected_schema: Arc::new(DFSchema::try_from( - (*test_provider.schema()).clone(), - )?), - projection: Some(vec![0]), - source: Arc::new(test_provider), - fetch: None, - }); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))? - .project(vec![col("a"), col("b")])? - .build()?; + let plan = table_scan_with_pushdown_provider_builder( + TableProviderFilterPushDown::Exact, + vec![], + Some(vec![0]), + )? + .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))? + .project(vec![col("a"), col("b")])? + .build()?; let expected = r#" Projection: a, b @@ -2866,6 +3251,46 @@ Projection: a, b assert_optimized_plan_eq(optimized_plan, expected) } + #[test] + fn left_semi_join() -> Result<()> { + let left = test_table_scan_with_name("test1")?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::LeftSemi, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test2.a").lt_eq(lit(1i64)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test2.a <= Int64(1)\ + \n LeftSemi Join: test1.a = test2.a\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2" + ); + + // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side. + let expected = "\ + Filter: test2.a <= Int64(1)\ + \n LeftSemi Join: test1.a = test2.a\ + \n TableScan: test1, full_filters=[test1.a <= Int64(1)]\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn left_semi_join_with_filters() -> Result<()> { let left = test_table_scan_with_name("test1")?; @@ -2907,6 +3332,46 @@ Projection: a, b assert_optimized_plan_eq(plan, expected) } + #[test] + fn right_semi_join() -> Result<()> { + let left = test_table_scan_with_name("test1")?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::RightSemi, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test1.a").lt_eq(lit(1i64)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test1.a <= Int64(1)\ + \n RightSemi Join: test1.a = test2.a\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2", + ); + + // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side. 
+ let expected = "\ + Filter: test1.a <= Int64(1)\ + \n RightSemi Join: test1.a = test2.a\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn right_semi_join_with_filters() -> Result<()> { let left = test_table_scan_with_name("test1")?; @@ -2948,6 +3413,51 @@ Projection: a, b assert_optimized_plan_eq(plan, expected) } + #[test] + fn left_anti_join() -> Result<()> { + let table_scan = test_table_scan_with_name("test1")?; + let left = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::LeftAnti, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test2.a").gt(lit(2u32)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test2.a > UInt32(2)\ + \n LeftAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2", + ); + + // For left anti, filter of the right side can be pushed down.
+ let expected = "\ + Filter: test1.a > UInt32(2)\ + \n RightAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2, full_filters=[test2.a > UInt32(2)]"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn right_anti_join_with_filters() -> Result<()> { let table_scan = test_table_scan_with_name("test1")?; @@ -3060,7 +3615,11 @@ Projection: a, b Ok(DataType::Int32) } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + _args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { Ok(ColumnarValue::from(ScalarValue::from(1))) } } @@ -3144,4 +3703,87 @@ Projection: a, b \n TableScan: test2"; assert_optimized_plan_eq(plan, expected) } + + #[test] + fn test_push_down_volatile_table_scan() -> Result<()> { + // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1; + let table_scan = test_table_scan()?; + let fun = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::exact(vec![], Volatility::Volatile), + }); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .filter(expr.gt(lit(0.1)))? + .build()?; + + let expected_before = "Filter: TestScalarUDF() > Float64(0.1)\ + \n Projection: test.a, test.b\ + \n TableScan: test"; + assert_eq!(format!("{plan}"), expected_before); + + let expected_after = "Projection: test.a, test.b\ + \n Filter: TestScalarUDF() > Float64(0.1)\ + \n TableScan: test"; + assert_optimized_plan_eq(plan, expected_after) + } + + #[test] + fn test_push_down_volatile_mixed_table_scan() -> Result<()> { + // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10; + let table_scan = test_table_scan()?; + let fun = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::exact(vec![], Volatility::Volatile), + }); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .filter( + expr.gt(lit(0.1)) + .and(col("t.a").gt(lit(5))) + .and(col("t.b").gt(lit(10))), + )? + .build()?; + + let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\ + \n Projection: test.a, test.b\ + \n TableScan: test"; + assert_eq!(format!("{plan}"), expected_before); + + let expected_after = "Projection: test.a, test.b\ + \n Filter: TestScalarUDF() > Float64(0.1)\ + \n TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]"; + assert_optimized_plan_eq(plan, expected_after) + } + + #[test] + fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> { + // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10; + let fun = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::exact(vec![], Volatility::Volatile), + }); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let plan = table_scan_with_pushdown_provider_builder( + TableProviderFilterPushDown::Unsupported, + vec![], + None, + )? + .project(vec![col("a"), col("b")])? + .filter( + expr.gt(lit(0.1)) + .and(col("t.a").gt(lit(5))) + .and(col("t.b").gt(lit(10))), + )? 
+ .build()?; + + let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\ + \n Projection: a, b\ + \n TableScan: test"; + assert_eq!(format!("{plan}"), expected_before); + + let expected_after = "Projection: a, b\ + \n Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)\ + \n TableScan: test"; + assert_optimized_plan_eq(plan, expected_after) + } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 8b5e483001b32..8a3aa4bb84599 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -27,6 +27,7 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::utils::combine_limit; use datafusion_common::Result; use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan}; +use datafusion_expr::{lit, FetchType, SkipType}; /// Optimization rule that tries to push down `LIMIT`. /// @@ -56,16 +57,27 @@ impl OptimizerRule for PushDownLimit { return Ok(Transformed::no(plan)); }; - let Limit { skip, fetch, input } = limit; + // Currently only rewrite if skip and fetch are both literals + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; // Merge the Parent Limit and the Child Limit. - if let LogicalPlan::Limit(child) = input.as_ref() { - let (skip, fetch) = - combine_limit(limit.skip, limit.fetch, child.skip, child.fetch); - + if let LogicalPlan::Limit(child) = limit.input.as_ref() { + let SkipType::Literal(child_skip) = child.get_skip_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + let FetchType::Literal(child_fetch) = child.get_fetch_type()? 
else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + + let (skip, fetch) = combine_limit(skip, fetch, child_skip, child_fetch); let plan = LogicalPlan::Limit(Limit { - skip, - fetch, + skip: Some(Box::new(lit(skip as i64))), + fetch: fetch.map(|f| Box::new(lit(f as i64))), input: Arc::clone(&child.input), }); @@ -75,14 +87,10 @@ impl OptimizerRule for PushDownLimit { // no fetch to push, so return the original plan let Some(fetch) = fetch else { - return Ok(Transformed::no(LogicalPlan::Limit(Limit { - skip, - fetch, - input, - }))); + return Ok(Transformed::no(LogicalPlan::Limit(limit))); }; - match Arc::unwrap_or_clone(input) { + match Arc::unwrap_or_clone(limit.input) { LogicalPlan::TableScan(mut scan) => { let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; let new_fetch = scan @@ -110,13 +118,6 @@ impl OptimizerRule for PushDownLimit { transformed_limit(skip, fetch, LogicalPlan::Union(union)) } - LogicalPlan::CrossJoin(mut cross_join) => { - // push limit to both inputs - cross_join.left = make_arc_limit(0, fetch + skip, cross_join.left); - cross_join.right = make_arc_limit(0, fetch + skip, cross_join.right); - transformed_limit(skip, fetch, LogicalPlan::CrossJoin(cross_join)) - } - LogicalPlan::Join(join) => Ok(push_down_join(join, fetch + skip) .update_data(|join| { make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join))) @@ -162,8 +163,8 @@ impl OptimizerRule for PushDownLimit { .into_iter() .map(|child| { LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(fetch + skip), + skip: None, + fetch: Some(Box::new(lit((fetch + skip) as i64))), input: Arc::new(child.clone()), }) }) @@ -203,8 +204,8 @@ impl OptimizerRule for PushDownLimit { /// ``` fn make_limit(skip: usize, fetch: usize, input: Arc) -> LogicalPlan { LogicalPlan::Limit(Limit { - skip, - fetch: Some(fetch), + skip: Some(Box::new(lit(skip as i64))), + fetch: Some(Box::new(lit(fetch as i64))), input, }) } @@ -224,11 +225,7 @@ fn original_limit( fetch: usize, input: LogicalPlan, ) -> Result> { - Ok(Transformed::no(LogicalPlan::Limit(Limit { - skip, - fetch: Some(fetch), - input: Arc::new(input), - }))) + Ok(Transformed::no(make_limit(skip, fetch, Arc::new(input)))) } /// Returns the a transformed limit @@ -237,11 +234,7 @@ fn transformed_limit( fetch: usize, input: LogicalPlan, ) -> Result> { - Ok(Transformed::yes(LogicalPlan::Limit(Limit { - skip, - fetch: Some(fetch), - input: Arc::new(input), - }))) + Ok(Transformed::yes(make_limit(skip, fetch, Arc::new(input)))) } /// Adds a limit to the inputs of a join, if possible @@ -254,15 +247,15 @@ fn push_down_join(mut join: Join, limit: usize) -> Transformed { let (left_limit, right_limit) = if is_no_join_condition(&join) { match join.join_type { - Left | Right | Full => (Some(limit), Some(limit)), - LeftAnti | LeftSemi => (Some(limit), None), + Left | Right | Full | Inner => (Some(limit), Some(limit)), + LeftAnti | LeftSemi | LeftMark => (Some(limit), None), RightAnti | RightSemi => (None, Some(limit)), - Inner => (None, None), } } else { match join.join_type { Left => (Some(limit), None), Right => (None, Some(limit)), + Full => (Some(limit), Some(limit)), _ => (None, None), } }; @@ -1115,7 +1108,7 @@ mod test { .build()?; let expected = "Limit: skip=0, fetch=1000\ - \n CrossJoin:\ + \n Cross Join: \ \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000\ \n Limit: skip=0, fetch=1000\ @@ -1135,7 +1128,7 @@ mod test { .build()?; let expected = "Limit: skip=1000, fetch=1000\ - \n CrossJoin:\ + \n Cross Join: \ \n Limit: skip=0, fetch=2000\ \n 
TableScan: test, fetch=2000\ \n Limit: skip=0, fetch=2000\ diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index c026130c426f4..48b2828faf452 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -16,8 +16,10 @@ // under the License. //! [`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...` + use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; +use std::sync::Arc; use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; @@ -52,8 +54,6 @@ use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// ) /// ORDER BY a DESC /// ``` - -/// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] #[derive(Default, Debug)] pub struct ReplaceDistinctWithAggregate {} @@ -110,7 +110,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { let expr_cnt = on_expr.len(); // Construct the aggregation expression to be used to fetch the selected expressions. - let first_value_udaf: std::sync::Arc = + let first_value_udaf: Arc = config.function_registry().unwrap().udaf("first_value")?; let aggr_expr = select_expr.into_iter().map(|e| { if let Some(order_by) = &sort_expr { diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs deleted file mode 100644 index a6b633fdb8fe6..0000000000000 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ /dev/null @@ -1,430 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`RewriteDisjunctivePredicate`] rewrites predicates to reduce redundancy - -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::Transformed; -use datafusion_common::Result; -use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::logical_plan::Filter; -use datafusion_expr::{Expr, LogicalPlan, Operator}; - -/// Optimizer pass that rewrites predicates of the form -/// -/// ```text -/// (A = B AND ) OR (A = B AND ) OR ... (A = B AND ) -/// ``` -/// -/// Into -/// ```text -/// (A = B) AND ( OR OR ... ) -/// ``` -/// -/// Predicates connected by `OR` typically not able to be broken down -/// and distributed as well as those connected by `AND`. -/// -/// The idea is to rewrite predicates into `good_predicate1 AND -/// good_predicate2 AND ...` where `good_predicate` means the -/// predicate has special support in the execution engine. -/// -/// Equality join predicates (e.g. `col1 = col2`), or single column -/// expressions (e.g. `col = 5`) are examples of predicates with -/// special support. 
-/// -/// # TPCH Q19 -/// -/// This optimization is admittedly somewhat of a niche usecase. It's -/// main use is that it appears in TPCH Q19 and is required to avoid a -/// CROSS JOIN. -/// -/// Specifically, Q19 has a WHERE clause that looks like -/// -/// ```sql -/// where -/// ( -/// p_partkey = l_partkey -/// and p_brand = ‘[BRAND1]’ -/// and p_container in ( ‘SM CASE’, ‘SM BOX’, ‘SM PACK’, ‘SM PKG’) -/// and l_quantity >= [QUANTITY1] and l_quantity <= [QUANTITY1] + 10 -/// and p_size between 1 and 5 -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// ) -/// or -/// ( -/// p_partkey = l_partkey -/// and p_brand = ‘[BRAND2]’ -/// and p_container in (‘MED BAG’, ‘MED BOX’, ‘MED PKG’, ‘MED PACK’) -/// and l_quantity >= [QUANTITY2] and l_quantity <= [QUANTITY2] + 10 -/// and p_size between 1 and 10 -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// ) -/// or -/// ( -/// p_partkey = l_partkey -/// and p_brand = ‘[BRAND3]’ -/// and p_container in ( ‘LG CASE’, ‘LG BOX’, ‘LG PACK’, ‘LG PKG’) -/// and l_quantity >= [QUANTITY3] and l_quantity <= [QUANTITY3] + 10 -/// and p_size between 1 and 15 -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// ) -/// ) -/// ``` -/// -/// Naively planning this query will result in a CROSS join with that -/// single large OR filter. However, rewriting it using the rewrite in -/// this pass results in a proper join predicate, `p_partkey = l_partkey`: -/// -/// ```sql -/// where -/// p_partkey = l_partkey -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// and ( -/// ( -/// and p_brand = ‘[BRAND1]’ -/// and p_container in ( ‘SM CASE’, ‘SM BOX’, ‘SM PACK’, ‘SM PKG’) -/// and l_quantity >= [QUANTITY1] and l_quantity <= [QUANTITY1] + 10 -/// and p_size between 1 and 5 -/// ) -/// or -/// ( -/// and p_brand = ‘[BRAND2]’ -/// and p_container in (‘MED BAG’, ‘MED BOX’, ‘MED PKG’, ‘MED PACK’) -/// and l_quantity >= [QUANTITY2] and l_quantity <= [QUANTITY2] + 10 -/// and p_size between 1 and 10 -/// ) -/// or -/// ( -/// and p_brand = ‘[BRAND3]’ -/// and p_container in ( ‘LG CASE’, ‘LG BOX’, ‘LG PACK’, ‘LG PKG’) -/// and l_quantity >= [QUANTITY3] and l_quantity <= [QUANTITY3] + 10 -/// and p_size between 1 and 15 -/// ) -/// ) -/// ``` -/// -#[derive(Default, Debug)] -pub struct RewriteDisjunctivePredicate; - -impl RewriteDisjunctivePredicate { - pub fn new() -> Self { - Self - } -} - -impl OptimizerRule for RewriteDisjunctivePredicate { - fn name(&self) -> &str { - "rewrite_disjunctive_predicate" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - - fn supports_rewrite(&self) -> bool { - true - } - - fn rewrite( - &self, - plan: LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Filter(filter) => { - let predicate = predicate(filter.predicate)?; - let rewritten_predicate = rewrite_predicate(predicate); - let rewritten_expr = normalize_predicate(rewritten_predicate); - Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( - rewritten_expr, - filter.input, - )?))) - } - _ => Ok(Transformed::no(plan)), - } - } -} - -#[derive(Clone, PartialEq, Debug)] -enum Predicate { - And { args: Vec }, - Or { args: Vec }, - Other { expr: Box }, -} - -fn predicate(expr: Expr) -> Result { - match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::And => { - let args = vec![predicate(*left)?, predicate(*right)?]; - 
Ok(Predicate::And { args }) - } - Operator::Or => { - let args = vec![predicate(*left)?, predicate(*right)?]; - Ok(Predicate::Or { args }) - } - _ => Ok(Predicate::Other { - expr: Box::new(Expr::BinaryExpr(BinaryExpr::new(left, op, right))), - }), - }, - _ => Ok(Predicate::Other { - expr: Box::new(expr), - }), - } -} - -fn normalize_predicate(predicate: Predicate) -> Expr { - match predicate { - Predicate::And { args } => { - assert!(args.len() >= 2); - args.into_iter() - .map(normalize_predicate) - .reduce(Expr::and) - .expect("had more than one arg") - } - Predicate::Or { args } => { - assert!(args.len() >= 2); - args.into_iter() - .map(normalize_predicate) - .reduce(Expr::or) - .expect("had more than one arg") - } - Predicate::Other { expr } => *expr, - } -} - -fn rewrite_predicate(predicate: Predicate) -> Predicate { - match predicate { - Predicate::And { args } => { - let mut rewritten_args = Vec::with_capacity(args.len()); - for arg in args.into_iter() { - rewritten_args.push(rewrite_predicate(arg)); - } - rewritten_args = flatten_and_predicates(rewritten_args); - Predicate::And { - args: rewritten_args, - } - } - Predicate::Or { args } => { - let mut rewritten_args = vec![]; - for arg in args.into_iter() { - rewritten_args.push(rewrite_predicate(arg)); - } - rewritten_args = flatten_or_predicates(rewritten_args); - delete_duplicate_predicates(rewritten_args) - } - Predicate::Other { expr } => Predicate::Other { expr }, - } -} - -fn flatten_and_predicates( - and_predicates: impl IntoIterator, -) -> Vec { - let mut flattened_predicates = vec![]; - for predicate in and_predicates { - match predicate { - Predicate::And { args } => { - flattened_predicates.append(&mut flatten_and_predicates(args)); - } - _ => { - flattened_predicates.push(predicate); - } - } - } - flattened_predicates -} - -fn flatten_or_predicates( - or_predicates: impl IntoIterator, -) -> Vec { - let mut flattened_predicates = vec![]; - for predicate in or_predicates { - match predicate { - Predicate::Or { args } => { - flattened_predicates.append(&mut flatten_or_predicates(args)); - } - _ => { - flattened_predicates.push(predicate); - } - } - } - flattened_predicates -} - -fn delete_duplicate_predicates(or_predicates: Vec) -> Predicate { - let mut shortest_exprs: Vec = vec![]; - let mut shortest_exprs_len = 0; - // choose the shortest AND predicate - for or_predicate in or_predicates.iter() { - match or_predicate { - Predicate::And { args } => { - let args_num = args.len(); - if shortest_exprs.is_empty() || args_num < shortest_exprs_len { - shortest_exprs.clone_from(args); - shortest_exprs_len = args_num; - } - } - _ => { - // if there is no AND predicate, it must be the shortest expression. - shortest_exprs = vec![or_predicate.clone()]; - break; - } - } - } - - // dedup shortest_exprs - shortest_exprs.dedup(); - - // Check each element in shortest_exprs to see if it's in all the OR arguments. - let mut exist_exprs: Vec = vec![]; - for expr in shortest_exprs.iter() { - let found = or_predicates.iter().all(|or_predicate| match or_predicate { - Predicate::And { args } => args.contains(expr), - _ => or_predicate == expr, - }); - if found { - exist_exprs.push((*expr).clone()); - } - } - if exist_exprs.is_empty() { - return Predicate::Or { - args: or_predicates, - }; - } - - // Rebuild the OR predicate. - // (A AND B) OR A will be optimized to A. 
- let mut new_or_predicates = vec![]; - for or_predicate in or_predicates.into_iter() { - match or_predicate { - Predicate::And { mut args } => { - args.retain(|expr| !exist_exprs.contains(expr)); - if !args.is_empty() { - if args.len() == 1 { - new_or_predicates.push(args.remove(0)); - } else { - new_or_predicates.push(Predicate::And { args }); - } - } else { - new_or_predicates.clear(); - break; - } - } - _ => { - if exist_exprs.contains(&or_predicate) { - new_or_predicates.clear(); - break; - } - } - } - } - if !new_or_predicates.is_empty() { - if new_or_predicates.len() == 1 { - exist_exprs.push(new_or_predicates.remove(0)); - } else { - exist_exprs.push(Predicate::Or { - args: flatten_or_predicates(new_or_predicates), - }); - } - } - - if exist_exprs.len() == 1 { - exist_exprs.remove(0) - } else { - Predicate::And { - args: flatten_and_predicates(exist_exprs), - } - } -} - -#[cfg(test)] -mod tests { - use crate::rewrite_disjunctive_predicate::{ - normalize_predicate, predicate, rewrite_predicate, Predicate, - }; - - use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{and, col, lit, or}; - - #[test] - fn test_rewrite_predicate() -> Result<()> { - let equi_expr = col("t1.a").eq(col("t2.b")); - let gt_expr = col("t1.c").gt(lit(ScalarValue::Int8(Some(1)))); - let lt_expr = col("t1.d").lt(lit(ScalarValue::Int8(Some(2)))); - let expr = or( - and(equi_expr.clone(), gt_expr.clone()), - and(equi_expr.clone(), lt_expr.clone()), - ); - let predicate = predicate(expr)?; - assert_eq!( - predicate, - Predicate::Or { - args: vec![ - Predicate::And { - args: vec![ - Predicate::Other { - expr: Box::new(equi_expr.clone()) - }, - Predicate::Other { - expr: Box::new(gt_expr.clone()) - }, - ] - }, - Predicate::And { - args: vec![ - Predicate::Other { - expr: Box::new(equi_expr.clone()) - }, - Predicate::Other { - expr: Box::new(lt_expr.clone()) - }, - ] - }, - ] - } - ); - let rewritten_predicate = rewrite_predicate(predicate); - assert_eq!( - rewritten_predicate, - Predicate::And { - args: vec![ - Predicate::Other { - expr: Box::new(equi_expr.clone()) - }, - Predicate::Or { - args: vec![ - Predicate::Other { - expr: Box::new(gt_expr.clone()) - }, - Predicate::Other { - expr: Box::new(lt_expr.clone()) - }, - ] - }, - ] - } - ); - let rewritten_expr = normalize_predicate(rewritten_predicate); - assert_eq!(rewritten_expr, and(equi_expr, or(gt_expr, lt_expr))); - Ok(()) - } -} diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index d3a5cfc467402..93f9c2d892976 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -134,7 +134,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { return Ok(Transformed::no(LogicalPlan::Projection(projection))); } - let mut all_subqueryies = vec![]; + let mut all_subqueries = vec![]; let mut expr_to_rewrite_expr_map = HashMap::new(); let mut subquery_to_expr_map = HashMap::new(); for expr in projection.expr.iter() { @@ -143,15 +143,15 @@ impl OptimizerRule for ScalarSubqueryToJoin { for (subquery, _) in &subqueries { subquery_to_expr_map.insert(subquery.clone(), expr.clone()); } - all_subqueryies.extend(subqueries); + all_subqueries.extend(subqueries); expr_to_rewrite_expr_map.insert(expr, rewrite_exprs); } - if all_subqueryies.is_empty() { + if all_subqueries.is_empty() { return internal_err!("Expected subqueries not found in projection"); } // iterate through all subqueries in predicate, turning each into a left join let 
mut cur_input = projection.input.as_ref().clone(); - for (subquery, alias) in all_subqueryies { + for (subquery, alias) in all_subqueries { if let Some((optimized_subquery, expr_check_map)) = build_join(&subquery, &cur_input, &alias)? { @@ -318,8 +318,7 @@ fn build_join( // alias the join filter let join_filter_opt = conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { - replace_qualified_name(filter, &all_correlated_cols, subquery_alias) - .map(Option::Some) + replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some) })?; // join our sub query into the main plan @@ -625,11 +624,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ - \ncaused by\ - \nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) != orders.o_custkey"; + // Unsupported predicate, subquery should not be decorrelated + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ + \n Subquery: [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ + \n Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - assert_analyzer_check_err(vec![], plan, expected); + assert_multi_rules_optimized_plan_eq_display_indent( + vec![Arc::new(ScalarSubqueryToJoin::new())], + plan, + expected, + ); Ok(()) } @@ -652,11 +661,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ - \ncaused by\ - \nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) < orders.o_custkey"; + // Unsupported predicate, subquery should not be decorrelated + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ + \n Subquery: [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ + \n Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - assert_analyzer_check_err(vec![], plan, expected); + assert_multi_rules_optimized_plan_eq_display_indent( + vec![Arc::new(ScalarSubqueryToJoin::new())], + plan, + expected, + ); Ok(()) } @@ -680,11 +699,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "check_analyzed_plan\ - \ncaused by\ - \nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1)"; + // Unsupported predicate, subquery should not be decorrelated + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ + \n Subquery: [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ + \n Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - assert_analyzer_check_err(vec![], plan, expected); + assert_multi_rules_optimized_plan_eq_display_indent( + vec![Arc::new(ScalarSubqueryToJoin::new())], + plan, + expected, + ); Ok(()) } @@ -702,7 +731,7 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ + let expected = "Invalid (non-executable) plan after Analyzer\ \ncaused by\ \nError during planning: Scalar subquery should only return one column"; assert_analyzer_check_err(vec![], plan, expected); @@ -764,7 +793,7 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ + let expected = "Invalid (non-executable) plan after Analyzer\ \ncaused by\ \nError during planning: Scalar subquery should only return one column"; assert_analyzer_check_err(vec![], plan, expected); @@ -850,7 +879,7 @@ mod tests { Ok(()) } - /// Test for correlated scalar subquery filter with disjustions + /// Test for correlated scalar subquery filter with disjunctions #[test] fn scalar_subquery_disjunction() -> Result<()> { let sq = Arc::new( diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index ef9c8480dc86e..35d867a3a9249 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -34,13 +34,13 @@ use datafusion_common::{ use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, - WindowFunctionDefinition, + and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Scalar, + Volatility, WindowFunctionDefinition, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ expr::{InList, InSubquery, WindowFunction}, - Scalar, + utils::{iter_conjunction, iter_conjunction_owned}, }; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; @@ -48,6 +48,8 @@ use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::SimplifyInfo; +use indexmap::IndexSet; +use regex::Regex; use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; @@ -487,7 +489,7 @@ enum ConstSimplifyResult { 
SimplifyRuntimeError(DataFusionError, Expr), } -impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { +impl TreeNodeRewriter for ConstEvaluator<'_> { type Node = Expr; fn f_down(&mut self, expr: Expr) -> Result> { @@ -709,7 +711,7 @@ impl<'a, S> Simplifier<'a, S> { } } -impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { +impl TreeNodeRewriter for Simplifier<'_, S> { type Node = Expr; /// rewrite the expression simplifying any constant expressions @@ -854,6 +856,27 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Or, right, }) if is_op_with(And, &left, &right) => Transformed::yes(*right), + // Eliminate common factors in conjunctions e.g + // (A AND B) OR (A AND C) -> A AND (B OR C) + Expr::BinaryExpr(BinaryExpr { + left, + op: Or, + right, + }) if has_common_conjunction(&left, &right) => { + let lhs: IndexSet = iter_conjunction_owned(*left).collect(); + let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right) + .partition(|e| lhs.contains(e) && !e.is_volatile()); + + let new_rhs = rhs.into_iter().reduce(and); + let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and); + let common_conjunction = common.into_iter().reduce(and).unwrap(); + + let new_expr = match (new_lhs, new_rhs) { + (Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)), + (_, _) => common_conjunction, + }; + Transformed::yes(new_expr) + } // // Rules for AND @@ -1024,7 +1047,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_one(&right) => { - Transformed::yes(lit(0)) + Transformed::yes(Expr::from(ScalarValue::new_zero( + &info.get_data_type(&left)?, + )?)) } // @@ -1447,19 +1472,70 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { }) => Transformed::yes(simplify_regex_expr(left, op, right)?), // Rules for Like - Expr::Like(Like { - expr, - pattern, - negated, - escape_char: _, - case_insensitive: _, - }) if !is_null(&expr) - && matches!( - pattern.as_ref(), - Expr::Literal(Scalar{ value: ScalarValue::Utf8(Some(pattern_str)), ..}) if pattern_str == "%" - ) => - { - Transformed::yes(lit(!negated)) + Expr::Like(like) => { + // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291 + let escape_char = like.escape_char.unwrap_or('\\'); + match as_string_scalar(&like.pattern) { + Some((data_type, pattern_str)) => { + match pattern_str { + None => return Ok(Transformed::yes(lit_bool_null())), + Some(pattern_str) if pattern_str == "%" => { + // exp LIKE '%' is + // - when exp is not NULL, it's true + // - when exp is NULL, it's NULL + // exp NOT LIKE '%' is + // - when exp is not NULL, it's false + // - when exp is NULL, it's NULL + let result_for_non_null = lit(!like.negated); + Transformed::yes(if !info.nullable(&like.expr)? 
{ + result_for_non_null + } else { + Expr::Case(Case { + expr: Some(Box::new(Expr::IsNotNull(like.expr))), + when_then_expr: vec![( + Box::new(lit(true)), + Box::new(result_for_non_null), + )], + else_expr: None, + }) + }) + } + Some(pattern_str) + if pattern_str.contains("%%") + && !pattern_str.contains(escape_char) => + { + // Repeated occurrences of wildcard are redundant so remove them + // exp LIKE '%%' --> exp LIKE '%' + let simplified_pattern = Regex::new("%%+") + .unwrap() + .replace_all(pattern_str, "%") + .to_string(); + Transformed::yes(Expr::Like(Like { + pattern: Box::new(to_string_scalar( + data_type, + Some(simplified_pattern), + )), + ..like + })) + } + Some(pattern_str) + if !pattern_str + .contains(['%', '_', escape_char].as_ref()) => + { + // If the pattern does not contain any wildcards, we can simplify the like expression to an equality expression + // TODO: handle escape characters + Transformed::yes(Expr::BinaryExpr(BinaryExpr { + left: like.expr.clone(), + op: if like.negated { NotEq } else { Eq }, + right: like.pattern.clone(), + })) + } + + Some(_pattern_str) => Transformed::no(Expr::Like(like)), + } + } + None => Transformed::no(Expr::Like(like)), + } } // a is not null/unknown --> true (if a is not nullable) @@ -1512,7 +1588,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)` Expr::BinaryExpr(BinaryExpr { left, - op: Operator::Or, + op: Or, right, }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { let lhs = to_inlist(*left).unwrap(); @@ -1552,7 +1628,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1572,7 +1648,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1592,7 +1668,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1612,7 +1688,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1632,7 +1708,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::Or, + op: Or, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1656,6 +1732,32 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } +fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option)> { + match expr { + Expr::Literal(scalar) => match scalar.value() { + ScalarValue::Utf8(s) => Some((DataType::Utf8, s)), + ScalarValue::LargeUtf8(s) => Some((DataType::LargeUtf8, s)), + ScalarValue::Utf8View(s) => Some((DataType::Utf8View, s)), + _ => None, + }, + _ => None, + } +} + +fn to_string_scalar(data_type: DataType, value: Option) -> Expr { + match data_type { + DataType::Utf8 => Expr::from(ScalarValue::Utf8(value)), + DataType::LargeUtf8 => Expr::from(ScalarValue::LargeUtf8(value)), + DataType::Utf8View => Expr::from(ScalarValue::Utf8View(value)), + _ => unreachable!(), + } +} + +fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool { + let lhs_set: 
HashSet<&Expr> = iter_conjunction(lhs).collect(); + iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile()) +} + // TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121 fn are_inlist_and_eq_and_match_neg( left: &Expr, @@ -1783,6 +1885,8 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result { #[cfg(test)] mod tests { + use crate::simplify_expressions::SimplifyContext; + use crate::test::test_table_scan_with_name; use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ function::{ @@ -1793,15 +1897,13 @@ mod tests { *, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; + use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use std::{ collections::HashMap, ops::{BitAnd, BitOr, BitXor}, sync::Arc, }; - use crate::simplify_expressions::SimplifyContext; - use crate::test::test_table_scan_with_name; - use super::*; // ------------------------------ @@ -2165,11 +2267,11 @@ mod tests { #[test] fn test_simplify_modulo_by_one_non_null() { - let expr = col("c2_non_null") % lit(1); - let expected = lit(0); + let expr = col("c3_non_null") % lit(1); + let expected = lit(0_i64); assert_eq!(simplify(expr), expected); let expr = - col("c2_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)); + col("c3_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)); assert_eq!(simplify(expr), expected); } @@ -2747,16 +2849,31 @@ mod tests { assert_no_change(regex_match(col("c1"), lit("f_o"))); // empty cases - assert_change(regex_match(col("c1"), lit("")), lit(true)); - assert_change(regex_not_match(col("c1"), lit("")), lit(false)); - assert_change(regex_imatch(col("c1"), lit("")), lit(true)); - assert_change(regex_not_imatch(col("c1"), lit("")), lit(false)); + assert_change( + regex_match(col("c1"), lit("")), + if_not_null(col("c1"), true), + ); + assert_change( + regex_not_match(col("c1"), lit("")), + if_not_null(col("c1"), false), + ); + assert_change( + regex_imatch(col("c1"), lit("")), + if_not_null(col("c1"), true), + ); + assert_change( + regex_not_imatch(col("c1"), lit("")), + if_not_null(col("c1"), false), + ); // single character - assert_change(regex_match(col("c1"), lit("x")), like(col("c1"), "%x%")); + assert_change(regex_match(col("c1"), lit("x")), col("c1").like(lit("%x%"))); // single word - assert_change(regex_match(col("c1"), lit("foo")), like(col("c1"), "%foo%")); + assert_change( + regex_match(col("c1"), lit("foo")), + col("c1").like(lit("%foo%")), + ); // regular expressions that match an exact literal assert_change(regex_match(col("c1"), lit("^$")), col("c1").eq(lit(""))); @@ -2843,44 +2960,55 @@ mod tests { assert_no_change(regex_match(col("c1"), lit("$foo^"))); // regular expressions that match a partial literal - assert_change(regex_match(col("c1"), lit("^foo")), like(col("c1"), "foo%")); - assert_change(regex_match(col("c1"), lit("foo$")), like(col("c1"), "%foo")); + assert_change( + regex_match(col("c1"), lit("^foo")), + col("c1").like(lit("foo%")), + ); + assert_change( + regex_match(col("c1"), lit("foo$")), + col("c1").like(lit("%foo")), + ); assert_change( regex_match(col("c1"), lit("^foo|bar$")), - like(col("c1"), "foo%").or(like(col("c1"), "%bar")), + col("c1").like(lit("foo%")).or(col("c1").like(lit("%bar"))), ); // OR-chain assert_change( regex_match(col("c1"), lit("foo|bar|baz")), - like(col("c1"), "%foo%") - .or(like(col("c1"), "%bar%")) - .or(like(col("c1"), "%baz%")), + col("c1") + 
.like(lit("%foo%")) + .or(col("c1").like(lit("%bar%"))) + .or(col("c1").like(lit("%baz%"))), ); assert_change( regex_match(col("c1"), lit("foo|x|baz")), - like(col("c1"), "%foo%") - .or(like(col("c1"), "%x%")) - .or(like(col("c1"), "%baz%")), + col("c1") + .like(lit("%foo%")) + .or(col("c1").like(lit("%x%"))) + .or(col("c1").like(lit("%baz%"))), ); assert_change( regex_not_match(col("c1"), lit("foo|bar|baz")), - not_like(col("c1"), "%foo%") - .and(not_like(col("c1"), "%bar%")) - .and(not_like(col("c1"), "%baz%")), + col("c1") + .not_like(lit("%foo%")) + .and(col("c1").not_like(lit("%bar%"))) + .and(col("c1").not_like(lit("%baz%"))), ); // both anchored expressions (translated to equality) and unanchored assert_change( regex_match(col("c1"), lit("foo|^x$|baz")), - like(col("c1"), "%foo%") + col("c1") + .like(lit("%foo%")) .or(col("c1").eq(lit("x"))) - .or(like(col("c1"), "%baz%")), + .or(col("c1").like(lit("%baz%"))), ); assert_change( regex_not_match(col("c1"), lit("foo|^bar$|baz")), - not_like(col("c1"), "%foo%") + col("c1") + .not_like(lit("%foo%")) .and(col("c1").not_eq(lit("bar"))) - .and(not_like(col("c1"), "%baz%")), + .and(col("c1").not_like(lit("%baz%"))), ); // Too many patterns (MAX_REGEX_ALTERNATIONS_EXPANSION) assert_no_change(regex_match(col("c1"), lit("foo|bar|baz|blarg|bozo|etc"))); @@ -2930,46 +3058,6 @@ mod tests { }) } - fn like(expr: Expr, pattern: &str) -> Expr { - Expr::Like(Like { - negated: false, - expr: Box::new(expr), - pattern: Box::new(lit(pattern)), - escape_char: None, - case_insensitive: false, - }) - } - - fn not_like(expr: Expr, pattern: &str) -> Expr { - Expr::Like(Like { - negated: true, - expr: Box::new(expr), - pattern: Box::new(lit(pattern)), - escape_char: None, - case_insensitive: false, - }) - } - - fn ilike(expr: Expr, pattern: &str) -> Expr { - Expr::Like(Like { - negated: false, - expr: Box::new(expr), - pattern: Box::new(lit(pattern)), - escape_char: None, - case_insensitive: true, - }) - } - - fn not_ilike(expr: Expr, pattern: &str) -> Expr { - Expr::Like(Like { - negated: true, - expr: Box::new(expr), - pattern: Box::new(lit(pattern)), - escape_char: None, - case_insensitive: true, - }) - } - // ------------------------------ // ----- Simplifier tests ------- // ------------------------------ @@ -3575,33 +3663,122 @@ mod tests { } #[test] - fn test_like_and_ilke() { - // test non-null values - let expr = like(col("c1"), "%"); + fn test_like_and_ilike() { + let null = lit(ScalarValue::Utf8(None)); + + // expr [NOT] [I]LIKE NULL + let expr = col("c1").like(null.clone()); + assert_eq!(simplify(expr), lit_bool_null()); + + let expr = col("c1").not_like(null.clone()); + assert_eq!(simplify(expr), lit_bool_null()); + + let expr = col("c1").ilike(null.clone()); + assert_eq!(simplify(expr), lit_bool_null()); + + let expr = col("c1").not_ilike(null.clone()); + assert_eq!(simplify(expr), lit_bool_null()); + + // expr [NOT] [I]LIKE '%' + let expr = col("c1").like(lit("%")); + assert_eq!(simplify(expr), if_not_null(col("c1"), true)); + + let expr = col("c1").not_like(lit("%")); + assert_eq!(simplify(expr), if_not_null(col("c1"), false)); + + let expr = col("c1").ilike(lit("%")); + assert_eq!(simplify(expr), if_not_null(col("c1"), true)); + + let expr = col("c1").not_ilike(lit("%")); + assert_eq!(simplify(expr), if_not_null(col("c1"), false)); + + // expr [NOT] [I]LIKE '%%' + let expr = col("c1").like(lit("%%")); + assert_eq!(simplify(expr), if_not_null(col("c1"), true)); + + let expr = col("c1").not_like(lit("%%")); + assert_eq!(simplify(expr), 
if_not_null(col("c1"), false)); + + let expr = col("c1").ilike(lit("%%")); + assert_eq!(simplify(expr), if_not_null(col("c1"), true)); + + let expr = col("c1").not_ilike(lit("%%")); + assert_eq!(simplify(expr), if_not_null(col("c1"), false)); + + // not_null_expr [NOT] [I]LIKE '%' + let expr = col("c1_non_null").like(lit("%")); assert_eq!(simplify(expr), lit(true)); - let expr = not_like(col("c1"), "%"); + let expr = col("c1_non_null").not_like(lit("%")); assert_eq!(simplify(expr), lit(false)); - let expr = ilike(col("c1"), "%"); + let expr = col("c1_non_null").ilike(lit("%")); assert_eq!(simplify(expr), lit(true)); - let expr = not_ilike(col("c1"), "%"); + let expr = col("c1_non_null").not_ilike(lit("%")); assert_eq!(simplify(expr), lit(false)); - // test null values - let null = lit(ScalarValue::Utf8(None)); - let expr = like(null.clone(), "%"); + // not_null_expr [NOT] [I]LIKE '%%' + let expr = col("c1_non_null").like(lit("%%")); + assert_eq!(simplify(expr), lit(true)); + + let expr = col("c1_non_null").not_like(lit("%%")); + assert_eq!(simplify(expr), lit(false)); + + let expr = col("c1_non_null").ilike(lit("%%")); + assert_eq!(simplify(expr), lit(true)); + + let expr = col("c1_non_null").not_ilike(lit("%%")); + assert_eq!(simplify(expr), lit(false)); + + // null_constant [NOT] [I]LIKE '%' + let expr = null.clone().like(lit("%")); + assert_eq!(simplify(expr), lit_bool_null()); + + let expr = null.clone().not_like(lit("%")); assert_eq!(simplify(expr), lit_bool_null()); - let expr = not_like(null.clone(), "%"); + let expr = null.clone().ilike(lit("%")); assert_eq!(simplify(expr), lit_bool_null()); - let expr = ilike(null.clone(), "%"); + let expr = null.clone().not_ilike(lit("%")); assert_eq!(simplify(expr), lit_bool_null()); - let expr = not_ilike(null, "%"); + // null_constant [NOT] [I]LIKE '%%' + let expr = null.clone().like(lit("%%")); assert_eq!(simplify(expr), lit_bool_null()); + + let expr = null.clone().not_like(lit("%%")); + assert_eq!(simplify(expr), lit_bool_null()); + + let expr = null.clone().ilike(lit("%%")); + assert_eq!(simplify(expr), lit_bool_null()); + + let expr = null.clone().not_ilike(lit("%%")); + assert_eq!(simplify(expr), lit_bool_null()); + + // null_constant [NOT] [I]LIKE 'a%' + let expr = null.clone().like(lit("a%")); + assert_eq!(simplify(expr), lit_bool_null()); + + let expr = null.clone().not_like(lit("a%")); + assert_eq!(simplify(expr), lit_bool_null()); + + let expr = null.clone().ilike(lit("a%")); + assert_eq!(simplify(expr), lit_bool_null()); + + let expr = null.clone().not_ilike(lit("a%")); + assert_eq!(simplify(expr), lit_bool_null()); + + // expr [NOT] [I]LIKE with pattern without wildcards + let expr = col("c1").like(lit("a")); + assert_eq!(simplify(expr), col("c1").eq(lit("a"))); + let expr = col("c1").not_like(lit("a")); + assert_eq!(simplify(expr), col("c1").not_eq(lit("a"))); + let expr = col("c1").like(lit("a_")); + assert_eq!(simplify(expr), col("c1").like(lit("a_"))); + let expr = col("c1").not_like(lit("a_")); + assert_eq!(simplify(expr), col("c1").not_like(lit("a_"))); } #[test] @@ -3743,11 +3920,52 @@ mod tests { assert_eq!(expr, expected); assert_eq!(num_iter, 2); } + + fn boolean_test_schema() -> DFSchemaRef { + Schema::new(vec![ + Field::new("A", DataType::Boolean, false), + Field::new("B", DataType::Boolean, false), + Field::new("C", DataType::Boolean, false), + Field::new("D", DataType::Boolean, false), + ]) + .to_dfschema_ref() + .unwrap() + } + + #[test] + fn simplify_common_factor_conjunction_in_disjunction() { + let props = 
ExecutionProps::new(); + let schema = boolean_test_schema(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema)); + + let a = || col("A"); + let b = || col("B"); + let c = || col("C"); + let d = || col("D"); + + // (A AND B) OR (A AND C) -> A AND (B OR C) + let expr = a().and(b()).or(a().and(c())); + let expected = a().and(b().or(c())); + + assert_eq!(expected, simplifier.simplify(expr).unwrap()); + + // (A AND B) OR (A AND C) OR (A AND D) -> A AND (B OR C OR D) + let expr = a().and(b()).or(a().and(c())).or(a().and(d())); + let expected = a().and(b().or(c()).or(d())); + assert_eq!(expected, simplifier.simplify(expr).unwrap()); + + // A OR (B AND C AND A) -> A + let expr = a().or(b().and(c().and(a()))); + let expected = a(); + assert_eq!(expected, simplifier.simplify(expr).unwrap()); + } + #[test] fn test_simplify_udaf() { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify()); let aggregate_function_expr = - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(expr::AggregateFunction::new_udf( udaf.into(), vec![], false, @@ -3761,7 +3979,7 @@ mod tests { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify()); let aggregate_function_expr = - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(expr::AggregateFunction::new_udf( udaf.into(), vec![], false, @@ -3811,7 +4029,7 @@ mod tests { fn accumulator( &self, - _acc_args: function::AccumulatorArgs, + _acc_args: AccumulatorArgs, ) -> Result> { unimplemented!("not needed for tests") } @@ -3841,9 +4059,8 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = Expr::WindowFunction( - datafusion_expr::expr::WindowFunction::new(udwf, vec![]), - ); + let window_function_expr = + Expr::WindowFunction(WindowFunction::new(udwf, vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -3851,9 +4068,8 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = Expr::WindowFunction( - datafusion_expr::expr::WindowFunction::new(udwf, vec![]), - ); + let window_function_expr = + Expr::WindowFunction(WindowFunction::new(udwf, vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); @@ -3898,7 +4114,10 @@ mod tests { } } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { unimplemented!("not needed for tests") } @@ -3906,4 +4125,78 @@ mod tests { unimplemented!("not needed for tests") } } + #[derive(Debug)] + struct VolatileUdf { + signature: Signature, + } + + impl VolatileUdf { + pub fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Volatile), + } + } + } + impl ScalarUDFImpl for VolatileUdf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "VolatileUdf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int16) + } + } + + #[test] + fn test_optimize_volatile_conditions() { + let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new())); + let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, 
vec![])); + { + let expr = rand + .clone() + .eq(lit(0)) + .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))); + + assert_eq!(simplify(expr.clone()), expr); + } + + { + let expr = col("column1") + .eq(lit(2)) + .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))); + + assert_eq!(simplify(expr), col("column1").eq(lit(2))); + } + + { + let expr = (col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))).or(col( + "column1", + ) + .eq(lit(2)) + .and(rand.clone().eq(lit(0)))); + + assert_eq!( + simplify(expr), + col("column1") + .eq(lit(2)) + .and((rand.clone().eq(lit(0))).or(rand.clone().eq(lit(0)))) + ); + } + } + + fn if_not_null(expr: Expr, then: bool) -> Expr { + Expr::Case(Case { + expr: Some(expr.is_not_null().into()), + when_then_expr: vec![(lit(true).into(), lit(then).into())], + else_expr: None, + }) + } } diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 8047485c1c7aa..6877269e1d073 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -57,7 +57,7 @@ impl<'a> GuaranteeRewriter<'a> { } } -impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { +impl TreeNodeRewriter for GuaranteeRewriter<'_> { type Node = Expr; fn f_up(&mut self, expr: Expr) -> Result> { diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index c0142ae0fc5a6..200f1f159d813 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -208,7 +208,7 @@ mod tests { assert_eq!(1, table_scan.schema().fields().len()); assert_fields_eq(&table_scan, vec!["a"]); - let expected = "TableScan: test projection=[a], full_filters=[Boolean(true) AS b IS NOT NULL]"; + let expected = "TableScan: test projection=[a], full_filters=[Boolean(true)]"; assert_optimized_plan_eq(table_scan, expected) } diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index f9250fdd09148..001e924c71ca7 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -67,16 +67,21 @@ pub static POWS_OF_TEN: [i128; 38] = [ /// returns true if `needle` is found in a chain of search_op /// expressions. Such as: (A AND B) AND C -pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { +fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => { - expr_contains(left, needle, search_op) - || expr_contains(right, needle, search_op) + expr_contains_inner(left, needle, search_op) + || expr_contains_inner(right, needle, search_op) } _ => expr == needle, } } +/// check volatile calls and return if expr contains needle +pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { + expr_contains_inner(expr, needle, search_op) && !needle.is_volatile() +} + /// Deletes all 'needles' or remains one 'needle' that are found in a chain of xor /// expressions. 
Such as: A ^ (A ^ (B ^ A)) pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr { @@ -218,7 +223,7 @@ pub fn is_false(expr: &Expr) -> bool { /// returns true if `haystack` looks like (needle OP X) or (X OP needle) pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool { - matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref())) + matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()) && !needle.is_volatile()) } /// returns true if `not_expr` is !`expr` (not) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 74251e5caad2b..c8f3a4bc7859c 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -22,7 +22,9 @@ use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{internal_err, tree_node::Transformed, DataFusionError, Result}; +use datafusion_common::{ + internal_err, tree_node::Transformed, DataFusionError, HashSet, Result, +}; use datafusion_expr::builder::project; use datafusion_expr::{ col, @@ -31,8 +33,6 @@ use datafusion_expr::{ Expr, }; -use hashbrown::HashSet; - /// single distinct to group by optimizer rule /// ```text /// Before: @@ -279,7 +279,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { mod tests { use super::*; use crate::test::*; - use datafusion_expr::expr::{self, GroupingSet}; + use datafusion_expr::expr::GroupingSet; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use datafusion_functions_aggregate::count::count_udaf; @@ -288,7 +288,7 @@ mod tests { use datafusion_functions_aggregate::sum::sum_udaf; fn max_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(AggregateFunction::new_udf( max_udaf(), vec![expr], true, @@ -569,7 +569,7 @@ mod tests { let table_scan = test_table_scan()?; // sum(a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let expr = Expr::AggregateFunction(AggregateFunction::new_udf( sum_udaf(), vec![col("a")], false, @@ -612,7 +612,7 @@ mod tests { let table_scan = test_table_scan()?; // SUM(a ORDER BY a) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let expr = Expr::AggregateFunction(AggregateFunction::new_udf( sum_udaf(), vec![col("a")], false, diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index cabeafd8e7dea..94d07a0791b3b 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -133,20 +133,6 @@ pub fn assert_analyzed_plan_with_config_eq( Ok(()) } -pub fn assert_analyzed_plan_ne( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; - let formatted_plan = format!("{analyzed_plan}"); - assert_ne!(formatted_plan, expected); - - Ok(()) -} - pub fn assert_analyzed_plan_eq_display_indent( rule: Arc, plan: LogicalPlan, diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 
779d1e8950b0e..be4d0038f10f0 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -146,7 +146,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }; is_supported_type(&left_type) && is_supported_type(&right_type) - && op.is_comparison_operator() + && op.supports_propagation() } => { match (left.as_mut(), right.as_mut()) { @@ -281,7 +281,7 @@ fn is_supported_type(data_type: &DataType) -> bool { || is_supported_dictionary_type(data_type) } -/// Returns true if [[UnwrapCastExprRewriter]] suppors this numeric type +/// Returns true if [[UnwrapCastExprRewriter]] support this numeric type fn is_supported_numeric_type(data_type: &DataType) -> bool { matches!( data_type, @@ -475,12 +475,7 @@ fn try_cast_string_literal( lit_value: &Scalar, target_type: &DataType, ) -> Option { - let string_value = match lit_value.value() { - ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s) => { - s.clone() - } - _ => return None, - }; + let string_value = lit_value.value().try_as_str()?.map(|s| s.to_string()); let scalar_value = match target_type { DataType::Utf8 => ScalarValue::Utf8(string_value), DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 6972c16c0ddf8..93802212e0a04 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -21,11 +21,18 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column, DFSchema, Result}; +use crate::analyzer::type_coercion::TypeCoercionRewriter; +use arrow::array::{new_null_array, Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::{Column, DFSchema, Result, ScalarValue}; +use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::replace_col; -use datafusion_expr::{logical_plan::LogicalPlan, Expr}; - +use datafusion_expr::{logical_plan::LogicalPlan, ColumnarValue, Expr}; +use datafusion_physical_expr::create_physical_expr; use log::{debug, trace}; +use std::sync::Arc; /// Re-export of `NamesPreserver` for backwards compatibility, /// as it was initially placed here and then moved elsewhere. @@ -80,23 +87,6 @@ pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet) -> == column_refs.len() } -pub(crate) fn collect_subquery_cols( - exprs: &[Expr], - subquery_schema: &DFSchema, -) -> Result> { - exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { - let mut using_cols: Vec = vec![]; - for col in expr.column_refs().into_iter() { - if subquery_schema.has_column(col) { - using_cols.push(col.clone()); - } - } - - cols.extend(using_cols); - Result::<_>::Ok(cols) - }) -} - pub(crate) fn replace_qualified_name( expr: Expr, cols: &BTreeSet, @@ -117,3 +107,161 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) { debug!("{description}:\n{}\n", plan.display_indent()); trace!("{description}::\n{}\n", plan.display_indent_schema()); } + +/// Determine whether a predicate can restrict NULLs. e.g. +/// `c0 > 8` return true; +/// `c0 IS NULL` return false. 
+pub fn is_restrict_null_predicate<'a>( + predicate: Expr, + join_cols_of_predicate: impl IntoIterator, +) -> Result { + if matches!(predicate, Expr::Column(_)) { + return Ok(true); + } + + static DUMMY_COL_NAME: &str = "?"; + let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]); + let input_schema = DFSchema::try_from(schema.clone())?; + let column = new_null_array(&DataType::Null, 1); + let input_batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?; + let execution_props = ExecutionProps::default(); + let null_column = Column::from_name(DUMMY_COL_NAME); + + let join_cols_to_replace = join_cols_of_predicate + .into_iter() + .map(|column| (column, &null_column)) + .collect::>(); + + let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?; + let coerced_predicate = coerce(replaced_predicate, &input_schema)?; + let phys_expr = + create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?; + + let result_type = phys_expr.data_type(&schema)?; + if !matches!(&result_type, DataType::Boolean) { + return Ok(false); + } + + // If result is single `true`, return false; + // If result is single `NULL` or `false`, return true; + Ok(match phys_expr.evaluate(&input_batch)? { + ColumnarValue::Array(array) => { + if array.len() == 1 { + let boolean_array = as_boolean_array(&array)?; + boolean_array.is_null(0) || !boolean_array.value(0) + } else { + false + } + } + ColumnarValue::Scalar(scalar) => matches!( + scalar.value(), + ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)) + ), + }) +} + +fn coerce(expr: Expr, schema: &DFSchema) -> Result { + let mut expr_rewrite = TypeCoercionRewriter { schema }; + expr.rewrite(&mut expr_rewrite).data() +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_expr::{binary_expr, case, col, in_list, is_null, lit, Operator}; + + #[test] + fn expr_is_restrict_null_predicate() -> Result<()> { + let test_cases = vec![ + // a + (col("a"), true), + // a IS NULL + (is_null(col("a")), false), + // a IS NOT NULL + (Expr::IsNotNull(Box::new(col("a"))), true), + // a = NULL + ( + binary_expr(col("a"), Operator::Eq, Expr::from(ScalarValue::Null)), + true, + ), + // a > 8 + (binary_expr(col("a"), Operator::Gt, lit(8i64)), true), + // a <= 8 + (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true), + // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END + ( + case(col("a")) + .when(lit(1i64), lit(true)) + .when(lit(0i64), lit(false)) + .otherwise(lit(ScalarValue::Null))?, + true, + ), + // CASE a WHEN 1 THEN true ELSE false END + ( + case(col("a")) + .when(lit(1i64), lit(true)) + .otherwise(lit(false))?, + true, + ), + // CASE a WHEN 0 THEN false ELSE true END + ( + case(col("a")) + .when(lit(0i64), lit(false)) + .otherwise(lit(true))?, + false, + ), + // (CASE a WHEN 0 THEN false ELSE true END) OR false + ( + binary_expr( + case(col("a")) + .when(lit(0i64), lit(false)) + .otherwise(lit(true))?, + Operator::Or, + lit(false), + ), + false, + ), + // (CASE a WHEN 0 THEN true ELSE false END) OR false + ( + binary_expr( + case(col("a")) + .when(lit(0i64), lit(true)) + .otherwise(lit(false))?, + Operator::Or, + lit(false), + ), + true, + ), + // a IN (1, 2, 3) + ( + in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], false), + true, + ), + // a NOT IN (1, 2, 3) + ( + in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], true), + true, + ), + // a IN (NULL) + ( + in_list(col("a"), vec![Expr::from(ScalarValue::Null)], false), + true, + ), + // a NOT IN (NULL) + ( + 
in_list(col("a"), vec![Expr::from(ScalarValue::Null)], true), + true, + ), + ]; + + let column_a = Column::from_name("a"); + for (predicate, expected) in test_cases { + let join_cols_of_predicate = std::iter::once(&column_a); + let actual = + is_restrict_null_predicate(predicate.clone(), join_cols_of_predicate)?; + assert_eq!(actual, expected, "{}", predicate); + } + + Ok(()) + } +} diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 236167985790d..b9073f5ac881e 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -22,11 +22,13 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{plan_err, Result}; +use datafusion_common::{assert_contains, plan_err, Result}; +use datafusion_expr::sqlparser::dialect::PostgreSqlDialect; use datafusion_expr::test::function_stub::sum_udaf; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::count::count_udaf; +use datafusion_optimizer::analyzer::type_coercion::TypeCoercionRewriter; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; @@ -76,8 +78,9 @@ fn subquery_filter_with_cast() -> Result<()> { \n SubqueryAlias: __scalar_sq_1\ \n Aggregate: groupBy=[[]], aggr=[[avg(CAST(test.col_int32 AS Float64))]]\ \n Projection: test.col_int32\ - \n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\ - \n TableScan: test projection=[col_int32, col_utf8]"; + \n Filter: __common_expr_5 >= Date32(\"2002-05-08\") AND __common_expr_5 <= Date32(\"2002-05-13\")\ + \n Projection: CAST(test.col_utf8 AS Date32) AS __common_expr_5, test.col_int32\ + \n TableScan: test projection=[col_int32, col_utf8]"; assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -386,6 +389,32 @@ fn select_correlated_predicate_subquery_with_uppercase_ident() { assert_eq!(expected, format!("{plan}")); } +// The test should return an error +// because the wildcard didn't be expanded before type coercion +#[test] +fn test_union_coercion_with_wildcard() -> Result<()> { + let dialect = PostgreSqlDialect {}; + let context_provider = MyContextProvider::default(); + let sql = "select * from (SELECT col_int32, col_uint32 FROM test) union all select * from(SELECT col_uint32, col_int32 FROM test)"; + let statements = Parser::parse_sql(&dialect, sql)?; + let sql_to_rel = SqlToRel::new(&context_provider); + let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?; + + if let LogicalPlan::Union(union) = logical_plan { + let err = TypeCoercionRewriter::coerce_union(union) + .err() + .unwrap() + .to_string(); + assert_contains!( + err, + "Error during planning: Wildcard should be expanded before type coercion" + ); + } else { + panic!("Expected Union plan"); + } + Ok(()) +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... 
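The `is_restrict_null_predicate` helper added above decides whether a predicate filters out NULL-extended rows by evaluating it against a single all-NULL input: if the result can only be NULL or false, the predicate restricts NULLs. The following is a self-contained sketch of that idea using a toy three-valued-logic evaluator; it is not the DataFusion implementation, which builds a real physical expression and evaluates it over a one-row null `RecordBatch`:

```rust
// Toy model: evaluate a predicate under SQL three-valued logic with the join
// column replaced by NULL; it "restricts NULLs" if it cannot come out TRUE.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Tri {
    True,
    False,
    Null,
}

enum Pred {
    IsNull,                   // c IS NULL  -> TRUE when c is NULL
    Gt(i64),                  // c > k      -> NULL when c is NULL
    Or(Box<Pred>, Box<Pred>), // SQL OR under three-valued logic
}

fn eval_with_null_column(p: &Pred) -> Tri {
    match p {
        Pred::IsNull => Tri::True,
        Pred::Gt(_) => Tri::Null, // any comparison against NULL is NULL
        Pred::Or(l, r) => match (eval_with_null_column(l), eval_with_null_column(r)) {
            (Tri::True, _) | (_, Tri::True) => Tri::True,
            (Tri::Null, _) | (_, Tri::Null) => Tri::Null,
            _ => Tri::False,
        },
    }
}

/// NULL-extended rows survive the filter only if the predicate can be TRUE.
fn is_restrict_null(p: &Pred) -> bool {
    eval_with_null_column(p) != Tri::True
}

fn main() {
    assert!(is_restrict_null(&Pred::Gt(8)));   // c0 > 8      -> restricts NULLs
    assert!(!is_restrict_null(&Pred::IsNull)); // c0 IS NULL  -> keeps NULL rows
    // c0 > 1 OR c0 IS NULL can still be TRUE for a NULL row, so it does not restrict.
    let or = Pred::Or(Box::new(Pred::Gt(1)), Box::new(Pred::IsNull));
    assert!(!is_restrict_null(&or));
    println!("ok");
}
```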
diff --git a/datafusion/physical-expr-common/Cargo.toml b/datafusion/physical-expr-common/Cargo.toml index 45ccb08e52e91..14d6ca64d15e6 100644 --- a/datafusion/physical-expr-common/Cargo.toml +++ b/datafusion/physical-expr-common/Cargo.toml @@ -41,4 +41,4 @@ arrow = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr-common = { workspace = true } hashbrown = { workspace = true } -rand = { workspace = true } +itertools = { workspace = true } diff --git a/datafusion/physical-expr-common/LICENSE.txt b/datafusion/physical-expr-common/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/physical-expr-common/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/physical-expr-common/NOTICE.txt b/datafusion/physical-expr-common/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/physical-expr-common/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index f320ebcc06b56..8febbdd5b1f90 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -28,10 +28,10 @@ use arrow::array::{ use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::DataType; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt}; use std::any::type_name; use std::fmt::Debug; -use std::mem; +use std::mem::{size_of, swap}; use std::ops::Range; use std::sync::Arc; @@ -104,8 +104,9 @@ impl ArrowBytesSet { /// `Binary`, and `LargeBinary`) values that can produce the set of keys on /// output as `GenericBinaryArray` without copies. /// -/// Equivalent to `HashSet` but with better performance for arrow -/// data. +/// Equivalent to `HashSet` but with better performance if you need +/// to emit the keys as an Arrow `StringArray` / `BinaryArray`. For other +/// purposes it is the same as a `HashMap` /// /// # Generic Arguments /// @@ -214,7 +215,7 @@ where /// Should the output be String or Binary? 
output_type: OutputType, /// Underlying hash set for each distinct value - map: hashbrown::raw::RawTable>, + map: hashbrown::hash_table::HashTable>, /// Total size of the map in bytes map_size: usize, /// In progress arrow `Buffer` containing all values @@ -245,7 +246,7 @@ where pub fn new(output_type: OutputType) -> Self { Self { output_type, - map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY), + map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY), map_size: 0, buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), offsets: vec![O::default()], // first offset is always 0 @@ -259,7 +260,7 @@ where /// the same output type pub fn take(&mut self) -> Self { let mut new_self = Self::new(self.output_type); - mem::swap(self, &mut new_self); + swap(self, &mut new_self); new_self } @@ -348,7 +349,7 @@ where let batch_hashes = &mut self.hashes_buffer; batch_hashes.clear(); batch_hashes.resize(values.len(), 0); - create_hashes(&[values.clone()], &self.random_state, batch_hashes) + create_hashes(&[Arc::clone(values)], &self.random_state, batch_hashes) // hash is supported for all types and create_hashes only // returns errors for unsupported types .unwrap(); @@ -386,7 +387,7 @@ where let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize); // is value is already present in the set? - let entry = self.map.get_mut(hash, |header| { + let entry = self.map.find_mut(hash, |header| { // compare value if hashes match if header.len != value_len { return false; @@ -424,7 +425,7 @@ where // value is not "small" else { // Check if the value is already present in the set - let entry = self.map.get_mut(hash, |header| { + let entry = self.map.find_mut(hash, |header| { // compare value if hashes match if header.len != value_len { return false; @@ -544,7 +545,7 @@ where /// this set, not including `self` pub fn size(&self) -> usize { self.map_size - + self.buffer.capacity() * mem::size_of::() + + self.buffer.capacity() * size_of::() + self.offsets.allocated_size() + self.hashes_buffer.allocated_size() } @@ -574,7 +575,7 @@ where } /// Maximum size of a value that can be inlined in the hash table -const SHORT_VALUE_LEN: usize = mem::size_of::(); +const SHORT_VALUE_LEN: usize = size_of::(); /// Entry in the hash table -- see [`ArrowBytesMap`] for more details #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index bdcf7bbacc696..7ce943030a453 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -24,7 +24,7 @@ use arrow::array::cast::AsArray; use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder}; use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType}; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt}; use std::fmt::Debug; use std::sync::Arc; @@ -88,8 +88,9 @@ impl ArrowBytesViewSet { /// values that can produce the set of keys on /// output as `GenericBinaryViewArray` without copies. /// -/// Equivalent to `HashSet` but with better performance for arrow -/// data. +/// Equivalent to `HashSet` but with better performance if you need +/// to emit the keys as an Arrow `StringViewArray` / `BinaryViewArray`. 
For other +/// purposes it is the same as a `HashMap` /// /// # Generic Arguments /// @@ -113,7 +114,6 @@ impl ArrowBytesViewSet { /// This map is used by the special `COUNT DISTINCT` aggregate function to /// store the distinct values, and by the `GROUP BY` operator to store /// group values when they are a single string array. - pub struct ArrowBytesViewMap where V: Debug + PartialEq + Eq + Clone + Copy + Default, @@ -121,7 +121,7 @@ where /// Should the output be StringView or BinaryView? output_type: OutputType, /// Underlying hash set for each distinct value - map: hashbrown::raw::RawTable>, + map: hashbrown::hash_table::HashTable>, /// Total size of the map in bytes map_size: usize, @@ -147,7 +147,7 @@ where pub fn new(output_type: OutputType) -> Self { Self { output_type, - map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY), + map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY), map_size: 0, builder: GenericByteViewBuilder::new(), random_state: RandomState::new(), @@ -243,7 +243,7 @@ where let batch_hashes = &mut self.hashes_buffer; batch_hashes.clear(); batch_hashes.resize(values.len(), 0); - create_hashes(&[values.clone()], &self.random_state, batch_hashes) + create_hashes(&[Arc::clone(values)], &self.random_state, batch_hashes) // hash is supported for all types and create_hashes only // returns errors for unsupported types .unwrap(); @@ -273,7 +273,7 @@ where // get the value as bytes let value: &[u8] = value.as_ref(); - let entry = self.map.get_mut(hash, |header| { + let entry = self.map.find_mut(hash, |header| { let v = self.builder.get_value(header.view_idx); if v.len() != value.len() { @@ -392,7 +392,7 @@ where #[cfg(test)] mod tests { use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray}; - use hashbrown::HashMap; + use datafusion_common::HashMap; use super::*; diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index 6ecd47839546d..721c9e76e3dc8 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -8,7 +8,7 @@ // // http://www.apache.org/licenses/LICENSE-2.0 // -// UnLt required by applicable law or agreed to in writing, +// Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index 7e2ea0c49397f..a05f1c96306f2 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] + //! Physical Expr Common packages for [DataFusion] //! This package contains high level PhysicalExpr trait //! 
diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index cc725cf2cefbb..b1b889136b35f 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -31,6 +31,9 @@ use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::ExprProperties; +/// Shared [`PhysicalExpr`]. +pub type PhysicalExprRef = Arc; + /// [`PhysicalExpr`]s represent expressions such as `A + 1` or `CAST(c1 AS int)`. /// /// `PhysicalExpr` knows its type, nullability and can be evaluated directly on @@ -52,7 +55,7 @@ use datafusion_expr_common::sort_properties::ExprProperties; /// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html /// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html /// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html -pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { +pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { /// Returns the physical expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -141,38 +144,6 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { Ok(Some(vec![])) } - /// Update the hash `state` with this expression requirements from - /// [`Hash`]. - /// - /// This method is required to support hashing [`PhysicalExpr`]s. To - /// implement it, typically the type implementing - /// [`PhysicalExpr`] implements [`Hash`] and - /// then the following boiler plate is used: - /// - /// # Example: - /// ``` - /// // User defined expression that derives Hash - /// #[derive(Hash, Debug, PartialEq, Eq)] - /// struct MyExpr { - /// val: u64 - /// } - /// - /// // impl PhysicalExpr { - /// // ... - /// # impl MyExpr { - /// // Boiler plate to call the derived Hash impl - /// fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) { - /// use std::hash::Hash; - /// let mut s = state; - /// self.hash(&mut s); - /// } - /// // } - /// # } - /// ``` - /// Note: [`PhysicalExpr`] is not constrained by [`Hash`] - /// directly because it must remain object safe. - fn dyn_hash(&self, _state: &mut dyn Hasher); - /// Calculates the properties of this [`PhysicalExpr`] based on its /// children's properties (i.e. order and range), recursively aggregating /// the information from its children. In cases where the [`PhysicalExpr`] @@ -183,6 +154,40 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { } } +/// [`PhysicalExpr`] can't be constrained by [`Eq`] directly because it must remain object +/// safe. To ease implementation blanket implementation is provided for [`Eq`] types. +pub trait DynEq { + fn dyn_eq(&self, other: &dyn Any) -> bool; +} + +impl DynEq for T { + fn dyn_eq(&self, other: &dyn Any) -> bool { + other.downcast_ref::() == Some(self) + } +} + +impl PartialEq for dyn PhysicalExpr { + fn eq(&self, other: &Self) -> bool { + self.dyn_eq(other.as_any()) + } +} + +impl Eq for dyn PhysicalExpr {} + +/// [`PhysicalExpr`] can't be constrained by [`Hash`] directly because it must remain +/// object safe. To ease implementation blanket implementation is provided for [`Hash`] +/// types. 
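Replacing the old `dyn_hash` method with `DynEq`/`DynHash` supertraits means concrete expressions just derive `PartialEq`, `Eq` and `Hash`, while the blanket impls plus `Any` downcasting recover `==` and hashing for the trait object. The same pattern in isolation, with a hypothetical `Node` trait standing in for `PhysicalExpr`:

```rust
use std::any::Any;
use std::hash::{Hash, Hasher};

trait DynEq {
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        other.downcast_ref::<T>() == Some(self)
    }
}

trait DynHash {
    fn dyn_hash(&self, state: &mut dyn Hasher);
}

impl<T: Hash + Any> DynHash for T {
    fn dyn_hash(&self, mut state: &mut dyn Hasher) {
        // Hash the concrete type id first so different implementors with
        // identical fields do not collide "by construction".
        self.type_id().hash(&mut state);
        self.hash(&mut state)
    }
}

/// A hypothetical object-safe trait that still supports `==` and hashing.
trait Node: DynEq + DynHash {
    fn as_any(&self) -> &dyn Any;
}

impl PartialEq for dyn Node {
    fn eq(&self, other: &Self) -> bool {
        self.dyn_eq(other.as_any())
    }
}
impl Eq for dyn Node {}

impl Hash for dyn Node {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.dyn_hash(state)
    }
}

#[derive(PartialEq, Eq, Hash)]
struct Literal(i64);

impl Node for Literal {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

fn main() {
    let a: Box<dyn Node> = Box::new(Literal(1));
    let b: Box<dyn Node> = Box::new(Literal(1));
    assert!(*a == *b);
}
```

Implementors only need the derives; the boilerplate `dyn_hash`/`eq` forwarding that the removed doc example showed disappears.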
+pub trait DynHash { + fn dyn_hash(&self, _state: &mut dyn Hasher); +} + +impl DynHash for T { + fn dyn_hash(&self, mut state: &mut dyn Hasher) { + self.type_id().hash(&mut state); + self.hash(&mut state) + } +} + impl Hash for dyn PhysicalExpr { fn hash(&self, state: &mut H) { self.dyn_hash(state); @@ -210,6 +215,7 @@ pub fn with_new_children_if_necessary( } } +#[deprecated(since = "44.0.0")] pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { if any.is::>() { any.downcast_ref::>() @@ -227,11 +233,24 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { /// Returns [`Display`] able a list of [`PhysicalExpr`] /// /// Example output: `[a + 1, b]` -pub fn format_physical_expr_list(exprs: &[Arc]) -> impl Display + '_ { - struct DisplayWrapper<'a>(&'a [Arc]); - impl<'a> Display for DisplayWrapper<'a> { +pub fn format_physical_expr_list(exprs: T) -> impl Display +where + T: IntoIterator, + T::Item: Display, + T::IntoIter: Clone, +{ + struct DisplayWrapper(I) + where + I: Iterator + Clone, + I::Item: Display; + + impl Display for DisplayWrapper + where + I: Iterator + Clone, + I::Item: Display, + { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let mut iter = self.0.iter(); + let mut iter = self.0.clone(); write!(f, "[")?; if let Some(expr) = iter.next() { write!(f, "{}", expr)?; @@ -243,5 +262,6 @@ pub fn format_physical_expr_list(exprs: &[Arc]) -> impl Displa Ok(()) } } - DisplayWrapper(exprs) + + DisplayWrapper(exprs.into_iter()) } diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 6c4bf156ce568..b150d3dc9bd38 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -17,18 +17,20 @@ //! Sort expressions +use crate::physical_expr::PhysicalExpr; +use std::fmt; use std::fmt::{Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::ops::Deref; -use std::sync::Arc; - -use crate::physical_expr::PhysicalExpr; +use std::ops::{Deref, Index, Range, RangeFrom, RangeTo}; +use std::sync::{Arc, LazyLock}; +use std::vec::IntoIter; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_expr_common::columnar_value::ColumnarValue; +use itertools::Itertools; /// Represents Sort operation for a column in a RecordBatch /// @@ -56,14 +58,10 @@ use datafusion_expr_common::columnar_value::ColumnarValue; /// # fn evaluate(&self, batch: &RecordBatch) -> Result {todo!() } /// # fn children(&self) -> Vec<&Arc> {todo!()} /// # fn with_new_children(self: Arc, children: Vec>) -> Result> {todo!()} -/// # fn dyn_hash(&self, _state: &mut dyn Hasher) {todo!()} /// # } /// # impl Display for MyPhysicalExpr { /// # fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "a") } /// # } -/// # impl PartialEq for MyPhysicalExpr { -/// # fn eq(&self, _other: &dyn Any) -> bool { true } -/// # } /// # fn col(name: &str) -> Arc { Arc::new(MyPhysicalExpr) } /// // Sort by a ASC /// let options = SortOptions::default(); @@ -143,7 +141,7 @@ impl Hash for PhysicalSortExpr { } impl Display for PhysicalSortExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "{} {}", self.expr, to_str(&self.options)) } } @@ -183,26 +181,6 @@ impl PhysicalSortExpr { .map_or(true, |opts| self.options.descending == opts.descending) } } - - /// Returns a [`Display`]able list of 
`PhysicalSortExpr`. - pub fn format_list(input: &[PhysicalSortExpr]) -> impl Display + '_ { - struct DisplayableList<'a>(&'a [PhysicalSortExpr]); - impl<'a> Display for DisplayableList<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - let mut first = true; - for sort_expr in self.0 { - if first { - first = false; - } else { - write!(f, ",")?; - } - write!(f, "{}", sort_expr)?; - } - Ok(()) - } - } - DisplayableList(input) - } } /// Represents sort requirement associated with a plan @@ -237,7 +215,7 @@ impl From for PhysicalSortExpr { /// If options is `None`, the default sort options `ASC, NULLS LAST` is used. /// /// The default is picked to be consistent with - /// PostgreSQL: + /// PostgreSQL: fn from(value: PhysicalSortRequirement) -> Self { let options = value.options.unwrap_or(SortOptions { descending: false, @@ -260,7 +238,7 @@ impl PartialEq for PhysicalSortRequirement { } impl Display for PhysicalSortRequirement { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let opts_string = self.options.as_ref().map_or("NA", to_str); write!(f, "{} {}", self.expr, opts_string) } @@ -273,8 +251,8 @@ pub fn format_physical_sort_requirement_list( exprs: &[PhysicalSortRequirement], ) -> impl Display + '_ { struct DisplayWrapper<'a>(&'a [PhysicalSortRequirement]); - impl<'a> Display for DisplayWrapper<'a> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + impl Display for DisplayWrapper<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let mut iter = self.0.iter(); write!(f, "[")?; if let Some(expr) = iter.next() { @@ -313,43 +291,24 @@ impl PhysicalSortRequirement { /// Returns whether this requirement is equal or more specific than `other`. pub fn compatible(&self, other: &PhysicalSortRequirement) -> bool { self.expr.eq(&other.expr) - && other.options.map_or(true, |other_opts| { - self.options.map_or(false, |opts| opts == other_opts) - }) + && other + .options + .map_or(true, |other_opts| self.options == Some(other_opts)) } - /// Returns [`PhysicalSortRequirement`] that requires the exact - /// sort of the [`PhysicalSortExpr`]s in `ordering` - /// - /// This method takes `&'a PhysicalSortExpr` to make it easy to - /// use implementing [`ExecutionPlan::required_input_ordering`]. - /// - /// [`ExecutionPlan::required_input_ordering`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html#method.required_input_ordering + #[deprecated(since = "43.0.0", note = "use LexRequirement::from_lex_ordering")] pub fn from_sort_exprs<'a>( ordering: impl IntoIterator, ) -> LexRequirement { - LexRequirement::new( - ordering - .into_iter() - .cloned() - .map(PhysicalSortRequirement::from) - .collect(), - ) + let ordering = ordering.into_iter().cloned().collect(); + LexRequirement::from_lex_ordering(ordering) } - - /// Converts an iterator of [`PhysicalSortRequirement`] into a Vec - /// of [`PhysicalSortExpr`]s. - /// - /// This function converts `PhysicalSortRequirement` to `PhysicalSortExpr` - /// for each entry in the input. If required ordering is None for an entry - /// default ordering `ASC, NULLS LAST` if given (see the `PhysicalSortExpr::from`). 
+ #[deprecated(since = "43.0.0", note = "use LexOrdering::from_lex_requirement")] pub fn to_sort_exprs( requirements: impl IntoIterator, - ) -> Vec { - requirements - .into_iter() - .map(PhysicalSortExpr::from) - .collect() + ) -> LexOrdering { + let requirements = requirements.into_iter().collect(); + LexOrdering::from_lex_requirement(requirements) } } @@ -364,12 +323,239 @@ fn to_str(options: &SortOptions) -> &str { } } -///`LexOrdering` is an alias for the type `Vec`, which represents +///`LexOrdering` contains a `Vec`, which represents /// a lexicographical ordering. -pub type LexOrdering = Vec; +/// +/// For example, `vec![a ASC, b DESC]` represents a lexicographical ordering +/// that first sorts by column `a` in ascending order, then by column `b` in +/// descending order. +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +pub struct LexOrdering { + inner: Vec, +} + +impl AsRef for LexOrdering { + fn as_ref(&self) -> &LexOrdering { + self + } +} + +impl LexOrdering { + /// Creates a new [`LexOrdering`] from a vector + pub fn new(inner: Vec) -> Self { + Self { inner } + } + + /// Return an empty LexOrdering (no expressions) + pub fn empty() -> &'static LexOrdering { + static EMPTY_ORDER: LazyLock = LazyLock::new(LexOrdering::default); + &EMPTY_ORDER + } + + /// Returns the number of elements that can be stored in the LexOrdering + /// without reallocating. + pub fn capacity(&self) -> usize { + self.inner.capacity() + } + + /// Clears the LexOrdering, removing all elements. + pub fn clear(&mut self) { + self.inner.clear() + } + + /// Returns `true` if the LexOrdering contains `expr` + pub fn contains(&self, expr: &PhysicalSortExpr) -> bool { + self.inner.contains(expr) + } + + /// Add all elements from `iter` to the LexOrdering. + pub fn extend>(&mut self, iter: I) { + self.inner.extend(iter) + } + + /// Remove all elements from the LexOrdering where `f` evaluates to `false`. + pub fn retain(&mut self, f: F) + where + F: FnMut(&PhysicalSortExpr) -> bool, + { + self.inner.retain(f) + } + + /// Returns `true` if the LexOrdering contains no elements. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Returns an iterator over each `&PhysicalSortExpr` in the LexOrdering. + pub fn iter(&self) -> core::slice::Iter { + self.inner.iter() + } + + /// Returns the number of elements in the LexOrdering. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Removes the last element from the LexOrdering and returns it, or `None` if it is empty. + pub fn pop(&mut self) -> Option { + self.inner.pop() + } + + /// Appends an element to the back of the LexOrdering. + pub fn push(&mut self, physical_sort_expr: PhysicalSortExpr) { + self.inner.push(physical_sort_expr) + } + + /// Truncates the LexOrdering, keeping only the first `len` elements. + pub fn truncate(&mut self, len: usize) { + self.inner.truncate(len) + } + + /// Merge the contents of `other` into `self`, removing duplicates. + pub fn merge(mut self, other: LexOrdering) -> Self { + self.inner = self.inner.into_iter().chain(other).unique().collect(); + self + } + + /// Converts a `LexRequirement` into a `LexOrdering`. + /// + /// This function converts [`PhysicalSortRequirement`] to [`PhysicalSortExpr`] + /// for each entry in the input. + /// + /// If the required ordering is `None` for an entry in `requirement`, the + /// default ordering `ASC, NULLS LAST` is used (see + /// [`PhysicalSortExpr::from`]). 
+ pub fn from_lex_requirement(requirement: LexRequirement) -> LexOrdering { + requirement + .into_iter() + .map(PhysicalSortExpr::from) + .collect() + } + + /// Collapse a `LexOrdering` into a new duplicate-free `LexOrdering` based on expression. + /// + /// This function filters duplicate entries that have same physical + /// expression inside, ignoring [`SortOptions`]. For example: + /// + /// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. + pub fn collapse(self) -> Self { + let mut output = LexOrdering::default(); + for item in self { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); + } + } + output + } + + /// Transforms each `PhysicalSortExpr` in the `LexOrdering` + /// in place using the provided closure `f`. + pub fn transform(&mut self, f: F) + where + F: FnMut(&mut PhysicalSortExpr), + { + self.inner.iter_mut().for_each(f); + } +} + +impl From> for LexOrdering { + fn from(value: Vec) -> Self { + Self::new(value) + } +} + +impl From for LexOrdering { + fn from(value: LexRequirement) -> Self { + Self::from_lex_requirement(value) + } +} + +/// Convert a `LexOrdering` into a `Arc[]` for fast copies +impl From for Arc<[PhysicalSortExpr]> { + fn from(value: LexOrdering) -> Self { + value.inner.into() + } +} + +impl Deref for LexOrdering { + type Target = [PhysicalSortExpr]; + + fn deref(&self) -> &Self::Target { + self.inner.as_slice() + } +} + +impl Display for LexOrdering { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let mut first = true; + for sort_expr in &self.inner { + if first { + first = false; + } else { + write!(f, ", ")?; + } + write!(f, "{}", sort_expr)?; + } + Ok(()) + } +} + +impl FromIterator for LexOrdering { + fn from_iter>(iter: T) -> Self { + let mut lex_ordering = LexOrdering::default(); + + for i in iter { + lex_ordering.push(i); + } + + lex_ordering + } +} + +impl Index for LexOrdering { + type Output = PhysicalSortExpr; + + fn index(&self, index: usize) -> &Self::Output { + &self.inner[index] + } +} + +impl Index> for LexOrdering { + type Output = [PhysicalSortExpr]; + + fn index(&self, range: Range) -> &Self::Output { + &self.inner[range] + } +} + +impl Index> for LexOrdering { + type Output = [PhysicalSortExpr]; + + fn index(&self, range_from: RangeFrom) -> &Self::Output { + &self.inner[range_from] + } +} + +impl Index> for LexOrdering { + type Output = [PhysicalSortExpr]; + + fn index(&self, range_to: RangeTo) -> &Self::Output { + &self.inner[range_to] + } +} + +impl IntoIterator for LexOrdering { + type Item = PhysicalSortExpr; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} ///`LexOrderingRef` is an alias for the type &`[PhysicalSortExpr]`, which represents /// a reference to a lexicographical ordering. 
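`LexOrdering` stops being a bare `Vec<PhysicalSortExpr>` alias and becomes a newtype: `Deref<Target = [PhysicalSortExpr]>` keeps slice ergonomics, while ordering-specific helpers such as `collapse` get a natural home. A stripped-down sketch of that shape, with a placeholder `SortExpr` instead of the real `PhysicalSortExpr`:

```rust
use std::fmt::{self, Display, Formatter};
use std::ops::Deref;

/// Placeholder for `PhysicalSortExpr`: a column name plus direction.
#[derive(Debug, Clone, PartialEq, Eq)]
struct SortExpr {
    column: String,
    descending: bool,
}

impl Display for SortExpr {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{} {}", self.column, if self.descending { "DESC" } else { "ASC" })
    }
}

/// Newtype over `Vec<SortExpr>`, mirroring the `LexOrdering` shape.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct Ordering {
    inner: Vec<SortExpr>,
}

impl Ordering {
    fn new(inner: Vec<SortExpr>) -> Self {
        Self { inner }
    }

    /// Keep only the first sort expression per column, like `LexOrdering::collapse`.
    fn collapse(self) -> Self {
        let mut out = Ordering::default();
        for expr in self.inner {
            if !out.iter().any(|e| e.column == expr.column) {
                out.inner.push(expr);
            }
        }
        out
    }
}

// `Deref` to a slice provides `iter`, `len`, indexing, etc. for free.
impl Deref for Ordering {
    type Target = [SortExpr];
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl Display for Ordering {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let parts: Vec<String> = self.iter().map(|e| e.to_string()).collect();
        write!(f, "[{}]", parts.join(", "))
    }
}

fn main() {
    let ord = Ordering::new(vec![
        SortExpr { column: "a".into(), descending: false },
        SortExpr { column: "a".into(), descending: true },
        SortExpr { column: "b".into(), descending: false },
    ]);
    // `a DESC` is dropped because column `a` already appears.
    println!("{}", ord.collapse()); // [a ASC, b ASC]
}
```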
+#[deprecated(since = "43.0.0", note = "use &LexOrdering instead")] pub type LexOrderingRef<'a> = &'a [PhysicalSortExpr]; ///`LexRequirement` is an struct containing a `Vec`, which @@ -384,6 +570,10 @@ impl LexRequirement { Self { inner } } + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + pub fn iter(&self) -> impl Iterator { self.inner.iter() } @@ -391,6 +581,40 @@ impl LexRequirement { pub fn push(&mut self, physical_sort_requirement: PhysicalSortRequirement) { self.inner.push(physical_sort_requirement) } + + /// Create a new [`LexRequirement`] from a [`LexOrdering`] + /// + /// Returns [`LexRequirement`] that requires the exact + /// sort of the [`PhysicalSortExpr`]s in `ordering` + pub fn from_lex_ordering(ordering: LexOrdering) -> Self { + Self::new( + ordering + .into_iter() + .map(PhysicalSortRequirement::from) + .collect(), + ) + } + + /// Constructs a duplicate-free `LexOrderingReq` by filtering out + /// duplicate entries that have same physical expression inside. + /// + /// For example, `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a + /// Some(ASC)]`. + pub fn collapse(self) -> Self { + let mut output = Vec::::new(); + for item in self { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); + } + } + LexRequirement::new(output) + } +} + +impl From for LexRequirement { + fn from(value: LexOrdering) -> Self { + Self::from_lex_ordering(value) + } } impl Deref for LexRequirement { @@ -415,13 +639,23 @@ impl FromIterator for LexRequirement { impl IntoIterator for LexRequirement { type Item = PhysicalSortRequirement; - type IntoIter = std::vec::IntoIter; + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { self.inner.into_iter() } } +impl<'a> IntoIterator for &'a LexOrdering { + type Item = &'a PhysicalSortExpr; + type IntoIter = std::slice::Iter<'a, PhysicalSortExpr>; + + fn into_iter(self) -> Self::IntoIter { + self.inner.iter() + } +} + ///`LexRequirementRef` is an alias for the type &`[PhysicalSortRequirement]`, which /// represents a reference to a lexicographical ordering requirement. 
+/// #[deprecated(since = "43.0.0", note = "use &LexRequirement instead")] pub type LexRequirementRef<'a> = &'a [PhysicalSortRequirement]; diff --git a/datafusion/physical-expr-common/src/tree_node.rs b/datafusion/physical-expr-common/src/tree_node.rs index d9892ce555098..c37e67575bf00 100644 --- a/datafusion/physical-expr-common/src/tree_node.rs +++ b/datafusion/physical-expr-common/src/tree_node.rs @@ -62,7 +62,7 @@ impl ExprContext { } pub fn update_expr_from_children(mut self) -> Result { - let children_expr = self.children.iter().map(|c| c.expr.clone()).collect(); + let children_expr = self.children.iter().map(|c| Arc::clone(&c.expr)).collect(); self.expr = with_new_children_if_necessary(self.expr, children_expr)?; Ok(self) } diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index d2c9bf1a24085..114007bfa6afb 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -24,7 +24,7 @@ use datafusion_common::Result; use datafusion_expr_common::sort_properties::ExprProperties; use crate::physical_expr::PhysicalExpr; -use crate::sort_expr::PhysicalSortExpr; +use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; use crate::tree_node::ExprContext; /// Represents a [`PhysicalExpr`] node with associated properties (order and @@ -96,10 +96,10 @@ pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { /// Reverses the ORDER BY expression, which is useful during equivalent window /// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into /// 'ORDER BY a DESC, NULLS FIRST'. -pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec { +pub fn reverse_order_bys(order_bys: &LexOrdering) -> LexOrdering { order_bys .iter() - .map(|e| PhysicalSortExpr::new(e.expr.clone(), !e.options)) + .map(|e| PhysicalSortExpr::new(Arc::clone(&e.expr), !e.options)) .collect() } diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index c53f7a6c47715..5e0832673697d 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -35,46 +35,31 @@ workspace = true name = "datafusion_physical_expr" path = "src/lib.rs" -[features] -default = [ - "regex_expressions", - "encoding_expressions", -] -encoding_expressions = ["base64", "hex"] -regex_expressions = ["regex"] - [dependencies] ahash = { workspace = true } arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } -arrow-ord = { workspace = true } arrow-schema = { workspace = true } -arrow-string = { workspace = true } -base64 = { version = "0.22", optional = true } -chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } -datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } hashbrown = { workspace = true } -hex = { version = "0.4", optional = true } indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "^1.0" -petgraph = "0.6.2" -regex = { workspace = true, optional = true } +petgraph = "0.7.1" [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } criterion = "0.5" +datafusion-functions = { workspace = true } rand = { workspace = true } rstest = { workspace = true } -tokio 
= { workspace = true, features = ["rt-multi-thread"] } [[bench]] harness = false diff --git a/datafusion/physical-expr/LICENSE.txt b/datafusion/physical-expr/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/physical-expr/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/physical-expr/NOTICE.txt b/datafusion/physical-expr/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/physical-expr/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index 866596d0b6901..4eaabace7257a 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -28,6 +28,7 @@ pub(crate) mod stats { pub use datafusion_functions_aggregate_common::stats::StatsType; } pub mod utils { + #[allow(deprecated)] // allow adjust_output_array pub use datafusion_functions_aggregate_common::utils::{ adjust_output_array, get_accum_scalar_values_as_arrays, get_sort_options, ordering_fields, DecimalAverager, Hashable, @@ -45,7 +46,7 @@ use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs; use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs; use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_expr_common::utils::reverse_order_bys; use datafusion_expr_common::groups_accumulator::GroupsAccumulator; @@ -81,7 +82,7 @@ impl AggregateExprBuilder { args, alias: None, schema: Arc::new(Schema::empty()), - ordering_req: vec![], + ordering_req: LexOrdering::default(), ignore_nulls: false, is_distinct: false, is_reversed: false, @@ -111,7 +112,8 @@ impl AggregateExprBuilder { .map(|e| e.expr.data_type(&schema)) .collect::>>()?; - ordering_fields = utils::ordering_fields(&ordering_req, &ordering_types); + ordering_fields = + utils::ordering_fields(ordering_req.as_ref(), &ordering_types); } let input_exprs_types = args @@ -265,7 +267,7 @@ impl AggregateFunctionExpr { return_type: &self.data_type, schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: &self.ordering_req, + ordering_req: self.ordering_req.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -291,13 +293,13 @@ impl AggregateFunctionExpr { /// Order by requirements for the aggregate function /// By default it is `None` (there is no requirement) /// Order-sensitive aggregators, such as `FIRST_VALUE(x ORDER BY y)` should implement this - pub fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + pub fn order_bys(&self) -> Option<&LexOrdering> { if self.ordering_req.is_empty() { return None; } if !self.order_sensitivity().is_insensitive() { - return Some(&self.ordering_req); + return Some(self.ordering_req.as_ref()); } None @@ -328,7 +330,7 @@ impl AggregateFunctionExpr { /// not implement the method, returns an error. Order insensitive and hard /// requirement aggregators return `Ok(None)`. 
pub fn with_beneficial_ordering( - self, + self: Arc, beneficial_ordering: bool, ) -> Result> { let Some(updated_fn) = self @@ -340,7 +342,7 @@ impl AggregateFunctionExpr { }; AggregateExprBuilder::new(Arc::new(updated_fn), self.args.to_vec()) - .order_by(self.ordering_req.to_vec()) + .order_by(self.ordering_req.clone()) .schema(Arc::new(self.schema.clone())) .alias(self.name().to_string()) .with_ignore_nulls(self.ignore_nulls) @@ -356,7 +358,7 @@ impl AggregateFunctionExpr { return_type: &self.data_type, schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: &self.ordering_req, + ordering_req: self.ordering_req.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -425,7 +427,7 @@ impl AggregateFunctionExpr { return_type: &self.data_type, schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: &self.ordering_req, + ordering_req: self.ordering_req.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -444,7 +446,7 @@ impl AggregateFunctionExpr { return_type: &self.data_type, schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: &self.ordering_req, + ordering_req: self.ordering_req.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -462,7 +464,7 @@ impl AggregateFunctionExpr { ReversedUDAF::NotSupported => None, ReversedUDAF::Identical => Some(self.clone()), ReversedUDAF::Reversed(reverse_udf) => { - let reverse_ordering_req = reverse_order_bys(&self.ordering_req); + let reverse_ordering_req = reverse_order_bys(self.ordering_req.as_ref()); let mut name = self.name().to_string(); // If the function is changed, we need to reverse order_by clause as well // i.e. First(a order by b asc null first) -> Last(a order by b desc null last) @@ -473,7 +475,7 @@ impl AggregateFunctionExpr { replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); AggregateExprBuilder::new(reverse_udf, self.args.to_vec()) - .order_by(reverse_ordering_req.to_vec()) + .order_by(reverse_ordering_req) .schema(Arc::new(self.schema.clone())) .alias(name) .with_ignore_nulls(self.ignore_nulls) @@ -489,7 +491,10 @@ impl AggregateFunctionExpr { /// These expressions are (1)function arguments, (2) order by expressions. pub fn all_expressions(&self) -> AggregatePhysicalExpressions { let args = self.expressions(); - let order_bys = self.order_bys().unwrap_or(&[]); + let order_bys = self + .order_bys() + .cloned() + .unwrap_or_else(LexOrdering::default); let order_by_exprs = order_bys .iter() .map(|sort_expr| Arc::clone(&sort_expr.expr)) diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 8dc40c41fe9b9..96925e5b128c1 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -15,20 +15,19 @@ // specific language governing permissions and limitations // under the License. 
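The `with_beneficial_ordering` change above switches the receiver to `self: Arc<Self>`, so callers that already hold the aggregate behind an `Arc` hand over the handle instead of cloning the whole value first. The receiver form on its own, using a made-up `Expr` type (the real method returns a `Result<Option<AggregateFunctionExpr>>`):

```rust
use std::sync::Arc;

#[derive(Debug, Clone)]
struct Expr {
    name: String,
    ordered: bool,
}

impl Expr {
    /// Consumes the `Arc` handle; cheap when the caller already owns one.
    fn with_beneficial_ordering(self: Arc<Self>, ordered: bool) -> Option<Expr> {
        // Only rebuild when the flag actually changes.
        if self.ordered == ordered {
            return None;
        }
        Some(Expr {
            name: self.name.clone(),
            ordered,
        })
    }
}

fn main() {
    let expr = Arc::new(Expr { name: "sum".to_string(), ordered: false });
    let updated = Arc::clone(&expr).with_beneficial_ordering(true);
    assert!(updated.is_some());
}
```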
-use std::fmt::Display; -use std::sync::Arc; - -use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; +use super::{add_offset_to_expr, ProjectionMapping}; use crate::{ - expressions::Column, physical_expr::deduplicate_physical_exprs, - physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef, - LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, - PhysicalSortRequirement, + expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef, + PhysicalSortExpr, PhysicalSortRequirement, }; - use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::JoinType; +use datafusion_common::{JoinType, ScalarValue}; use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; +use std::fmt::Display; +use std::sync::Arc; +use std::vec::IntoIter; + +use indexmap::{IndexMap, IndexSet}; /// A structure representing a expression known to be constant in a physical execution plan. /// @@ -56,19 +55,50 @@ use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; /// // create a constant expression from a physical expression /// let const_expr = ConstExpr::from(col); /// ``` +// TODO: Consider refactoring the `across_partitions` and `value` fields into an enum: +// +// ``` +// enum PartitionValues { +// Uniform(Option), // Same value across all partitions +// Heterogeneous(Vec>) // Different values per partition +// } +// ``` +// +// This would provide more flexible representation of partition values. +// Note: This is a breaking change for the equivalence API and should be +// addressed in a separate issue/PR. #[derive(Debug, Clone)] pub struct ConstExpr { /// The expression that is known to be constant (e.g. a `Column`) expr: Arc, /// Does the constant have the same value across all partitions? See /// struct docs for more details - across_partitions: bool, + across_partitions: AcrossPartitions, +} + +#[derive(PartialEq, Clone, Debug)] +/// Represents whether a constant expression's value is uniform or varies across partitions. +/// +/// The `AcrossPartitions` enum is used to describe the nature of a constant expression +/// in a physical execution plan: +/// +/// - `Heterogeneous`: The constant expression may have different values for different partitions. +/// - `Uniform(Option)`: The constant expression has the same value across all partitions, +/// or is `None` if the value is not specified. +pub enum AcrossPartitions { + Heterogeneous, + Uniform(Option), +} + +impl Default for AcrossPartitions { + fn default() -> Self { + Self::Heterogeneous + } } impl PartialEq for ConstExpr { fn eq(&self, other: &Self) -> bool { - self.across_partitions == other.across_partitions - && self.expr.eq(other.expr.as_any()) + self.across_partitions == other.across_partitions && self.expr.eq(&other.expr) } } @@ -81,14 +111,14 @@ impl ConstExpr { Self { expr, // By default, assume constant expressions are not same across partitions. - across_partitions: false, + across_partitions: Default::default(), } } /// Set the `across_partitions` flag /// /// See struct docs for more details - pub fn with_across_partitions(mut self, across_partitions: bool) -> Self { + pub fn with_across_partitions(mut self, across_partitions: AcrossPartitions) -> Self { self.across_partitions = across_partitions; self } @@ -96,8 +126,8 @@ impl ConstExpr { /// Is the expression the same across all partitions? 
/// /// See struct docs for more details - pub fn across_partitions(&self) -> bool { - self.across_partitions + pub fn across_partitions(&self) -> AcrossPartitions { + self.across_partitions.clone() } pub fn expr(&self) -> &Arc { @@ -115,19 +145,19 @@ impl ConstExpr { let maybe_expr = f(&self.expr); maybe_expr.map(|expr| Self { expr, - across_partitions: self.across_partitions, + across_partitions: self.across_partitions.clone(), }) } /// Returns true if this constant expression is equal to the given expression pub fn eq_expr(&self, other: impl AsRef) -> bool { - self.expr.eq(other.as_ref().as_any()) + self.expr.as_ref() == other.as_ref() } /// Returns a [`Display`]able list of `ConstExpr`. pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ { struct DisplayableList<'a>(&'a [ConstExpr]); - impl<'a> Display for DisplayableList<'a> { + impl Display for DisplayableList<'_> { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let mut first = true; for const_expr in self.0 { @@ -145,14 +175,20 @@ impl ConstExpr { } } -/// Display implementation for `ConstExpr` -/// -/// Example `c` or `c(across_partitions)` impl Display for ConstExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.expr)?; - if self.across_partitions { - write!(f, "(across_partitions)")?; + match &self.across_partitions { + AcrossPartitions::Heterogeneous => { + write!(f, "(heterogeneous)")?; + } + AcrossPartitions::Uniform(value) => { + if let Some(val) = value { + write!(f, "(uniform: {})", val)?; + } else { + write!(f, "(uniform: unknown)")?; + } + } } Ok(()) } @@ -192,47 +228,47 @@ pub struct EquivalenceClass { /// The expressions in this equivalence class. The order doesn't /// matter for equivalence purposes /// - /// TODO: use a HashSet for this instead of a Vec - exprs: Vec>, + exprs: IndexSet>, } impl PartialEq for EquivalenceClass { /// Returns true if other is equal in the sense /// of bags (multi-sets), disregarding their orderings. 
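`ConstExpr` used to track partition behaviour with a plain `bool`; the new `AcrossPartitions` enum can additionally carry the constant's value when it is known to be uniform. The bool-to-enum shape reduced to a standalone sketch (the payload here is a `String` rather than a `ScalarValue`):

```rust
use std::fmt::{self, Display, Formatter};

/// Whether a constant has one value everywhere or may differ per partition.
/// `Uniform(None)` means "same everywhere, value unknown".
#[derive(Debug, Clone, PartialEq)]
enum AcrossPartitions {
    Heterogeneous,
    Uniform(Option<String>),
}

impl Default for AcrossPartitions {
    fn default() -> Self {
        // Assume nothing until a caller proves uniformity.
        Self::Heterogeneous
    }
}

struct ConstColumn {
    name: String,
    across_partitions: AcrossPartitions,
}

impl Display for ConstColumn {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.name)?;
        match &self.across_partitions {
            AcrossPartitions::Heterogeneous => write!(f, "(heterogeneous)"),
            AcrossPartitions::Uniform(Some(v)) => write!(f, "(uniform: {v})"),
            AcrossPartitions::Uniform(None) => write!(f, "(uniform: unknown)"),
        }
    }
}

fn main() {
    let c = ConstColumn {
        name: "region".into(),
        across_partitions: AcrossPartitions::Uniform(Some("eu".into())),
    };
    println!("{c}"); // region(uniform: eu)
}
```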
fn eq(&self, other: &Self) -> bool { - physical_exprs_bag_equal(&self.exprs, &other.exprs) + self.exprs.eq(&other.exprs) } } impl EquivalenceClass { /// Create a new empty equivalence class pub fn new_empty() -> Self { - Self { exprs: vec![] } + Self { + exprs: IndexSet::new(), + } } // Create a new equivalence class from a pre-existing `Vec` - pub fn new(mut exprs: Vec>) -> Self { - deduplicate_physical_exprs(&mut exprs); - Self { exprs } + pub fn new(exprs: Vec>) -> Self { + Self { + exprs: exprs.into_iter().collect(), + } } /// Return the inner vector of expressions pub fn into_vec(self) -> Vec> { - self.exprs + self.exprs.into_iter().collect() } /// Return the "canonical" expression for this class (the first element) /// if any fn canonical_expr(&self) -> Option> { - self.exprs.first().cloned() + self.exprs.iter().next().cloned() } /// Insert the expression into this class, meaning it is known to be equal to /// all other expressions in this class pub fn push(&mut self, expr: Arc) { - if !self.contains(&expr) { - self.exprs.push(expr); - } + self.exprs.insert(expr); } /// Inserts all the expressions from other into this class @@ -245,7 +281,7 @@ impl EquivalenceClass { /// Returns true if this equivalence class contains t expression pub fn contains(&self, expr: &Arc) -> bool { - physical_exprs_contains(&self.exprs, expr) + self.exprs.contains(expr) } /// Returns true if this equivalence class has any entries in common with `other` @@ -287,11 +323,10 @@ impl Display for EquivalenceClass { } } -/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each -/// class represents a distinct equivalence class in a relation. +/// A collection of distinct `EquivalenceClass`es #[derive(Debug, Clone)] pub struct EquivalenceGroup { - pub classes: Vec, + classes: Vec, } impl EquivalenceGroup { @@ -475,13 +510,13 @@ impl EquivalenceGroup { /// This function applies the `normalize_sort_expr` function for all sort /// expressions in `sort_exprs` and returns the corresponding normalized /// sort expressions. - pub fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { + pub fn normalize_sort_exprs(&self, sort_exprs: &LexOrdering) -> LexOrdering { // Convert sort expressions to sort requirements: - let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); + let sort_reqs = LexRequirement::from(sort_exprs.clone()); // Normalize the requirements: let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); // Convert sort requirements back to sort expressions: - PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs.inner) + LexOrdering::from(normalized_sort_reqs) } /// This function applies the `normalize_sort_requirement` function for all @@ -489,14 +524,15 @@ impl EquivalenceGroup { /// sort requirements. pub fn normalize_sort_requirements( &self, - sort_reqs: LexRequirementRef, + sort_reqs: &LexRequirement, ) -> LexRequirement { - collapse_lex_req(LexRequirement::new( + LexRequirement::new( sort_reqs .iter() .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) .collect(), - )) + ) + .collapse() } /// Projects `expr` according to the given projection mapping. @@ -520,7 +556,7 @@ impl EquivalenceGroup { // and the equivalence class `(a, b)`, expression `b` projects to `a1`. 
if self .get_equivalence_class(source) - .map_or(false, |group| group.contains(expr)) + .is_some_and(|group| group.contains(expr)) { return Some(Arc::clone(target)); } @@ -548,28 +584,20 @@ impl EquivalenceGroup { .collect::>(); (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) }); - // TODO: Convert the algorithm below to a version that uses `HashMap`. - // once `Arc` can be stored in `HashMap`. - // See issue: https://github.com/apache/datafusion/issues/8027 - let mut new_classes = vec![]; - for (source, target) in mapping.iter() { - if new_classes.is_empty() { - new_classes.push((source, vec![Arc::clone(target)])); - } - if let Some((_, values)) = - new_classes.iter_mut().find(|(key, _)| key.eq(source)) - { - if !physical_exprs_contains(values, target) { - values.push(Arc::clone(target)); - } - } - } + // the key is the source expression and the value is the EquivalenceClass that contains the target expression of the source expression. + let mut new_classes: IndexMap, EquivalenceClass> = + IndexMap::new(); + mapping.iter().for_each(|(source, target)| { + new_classes + .entry(Arc::clone(source)) + .or_insert_with(EquivalenceClass::new_empty) + .push(Arc::clone(target)); + }); // Only add equivalence classes with at least two members as singleton // equivalence classes are meaningless. let new_classes = new_classes .into_iter() - .filter_map(|(_, values)| (values.len() > 1).then_some(values)) - .map(EquivalenceClass::new); + .filter_map(|(_, cls)| (cls.len() > 1).then_some(cls)); let classes = projected_classes.chain(new_classes).collect(); Self::new(classes) @@ -632,10 +660,77 @@ impl EquivalenceGroup { } result } - JoinType::LeftSemi | JoinType::LeftAnti => self.clone(), + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(), JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), } } + + /// Checks if two expressions are equal either directly or through equivalence classes. + /// For complex expressions (e.g. a + b), checks that the expression trees are structurally + /// identical and their leaf nodes are equivalent either directly or through equivalence classes. + pub fn exprs_equal( + &self, + left: &Arc, + right: &Arc, + ) -> bool { + // Direct equality check + if left.eq(right) { + return true; + } + + // Check if expressions are equivalent through equivalence classes + // We need to check both directions since expressions might be in different classes + if let Some(left_class) = self.get_equivalence_class(left) { + if left_class.contains(right) { + return true; + } + } + if let Some(right_class) = self.get_equivalence_class(right) { + if right_class.contains(left) { + return true; + } + } + + // For non-leaf nodes, check structural equality + let left_children = left.children(); + let right_children = right.children(); + + // If either expression is a leaf node and we haven't found equality yet, + // they must be different + if left_children.is_empty() || right_children.is_empty() { + return false; + } + + // Type equality check through reflection + if left.as_any().type_id() != right.as_any().type_id() { + return false; + } + + // Check if the number of children is the same + if left_children.len() != right_children.len() { + return false; + } + + // Check if all children are equal + left_children + .into_iter() + .zip(right_children) + .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child)) + } + + /// Return the inner classes of this equivalence group. 
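The projection code above drops the hand-rolled `Vec<(source, targets)>` scan in favour of `IndexMap`/`IndexSet` from the `indexmap` crate, which behave like `HashMap`/`HashSet` but preserve insertion order, keeping plan output deterministic. A small sketch of that grouping pattern (the string pairs are stand-ins for expression mappings):

```rust
use indexmap::{IndexMap, IndexSet};

fn main() {
    // (source, target) pairs, like the entries of a projection mapping.
    let mapping = [("a", "a1"), ("b", "b1"), ("a", "a2"), ("a", "a1")];

    // Group targets by source while preserving first-seen order and
    // deduplicating targets, mirroring the new `project_classes` path.
    let mut classes: IndexMap<&str, IndexSet<&str>> = IndexMap::new();
    for (source, target) in mapping {
        classes
            .entry(source)
            .or_insert_with(IndexSet::new)
            .insert(target);
    }

    // Keep only classes with at least two members, as in the real code.
    let non_trivial: Vec<_> = classes
        .into_iter()
        .filter(|(_, targets)| targets.len() > 1)
        .collect();

    let expected: IndexSet<&str> = ["a1", "a2"].into_iter().collect();
    assert_eq!(non_trivial, vec![("a", expected)]);
}
```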
+ pub fn into_inner(self) -> Vec { + self.classes + } +} + +impl IntoIterator for EquivalenceGroup { + type Item = EquivalenceClass; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.classes.into_iter() + } } impl Display for EquivalenceGroup { @@ -657,9 +752,10 @@ mod tests { use super::*; use crate::equivalence::tests::create_test_params; - use crate::expressions::{lit, Literal}; + use crate::expressions::{lit, BinaryExpr, Literal}; use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::Operator; #[test] fn test_bridge_groups() -> Result<()> { @@ -787,4 +883,159 @@ mod tests { assert!(!cls1.contains_any(&cls3)); assert!(!cls2.contains_any(&cls3)); } + + #[test] + fn test_exprs_equal() -> Result<()> { + struct TestCase { + left: Arc, + right: Arc, + expected: bool, + description: &'static str, + } + + // Create test columns + let col_a = Arc::new(Column::new("a", 0)) as Arc; + let col_b = Arc::new(Column::new("b", 1)) as Arc; + let col_x = Arc::new(Column::new("x", 2)) as Arc; + let col_y = Arc::new(Column::new("y", 3)) as Arc; + + // Create test literals + let lit_1 = + Arc::new(Literal::from(ScalarValue::Int32(Some(1)))) as Arc; + let lit_2 = + Arc::new(Literal::from(ScalarValue::Int32(Some(2)))) as Arc; + + // Create equivalence group with classes (a = x) and (b = y) + let eq_group = EquivalenceGroup::new(vec![ + EquivalenceClass::new(vec![Arc::clone(&col_a), Arc::clone(&col_x)]), + EquivalenceClass::new(vec![Arc::clone(&col_b), Arc::clone(&col_y)]), + ]); + + let test_cases = vec![ + // Basic equality tests + TestCase { + left: Arc::clone(&col_a), + right: Arc::clone(&col_a), + expected: true, + description: "Same column should be equal", + }, + // Equivalence class tests + TestCase { + left: Arc::clone(&col_a), + right: Arc::clone(&col_x), + expected: true, + description: "Columns in same equivalence class should be equal", + }, + TestCase { + left: Arc::clone(&col_b), + right: Arc::clone(&col_y), + expected: true, + description: "Columns in same equivalence class should be equal", + }, + TestCase { + left: Arc::clone(&col_a), + right: Arc::clone(&col_b), + expected: false, + description: + "Columns in different equivalence classes should not be equal", + }, + // Literal tests + TestCase { + left: Arc::clone(&lit_1), + right: Arc::clone(&lit_1), + expected: true, + description: "Same literal should be equal", + }, + TestCase { + left: Arc::clone(&lit_1), + right: Arc::clone(&lit_2), + expected: false, + description: "Different literals should not be equal", + }, + // Complex expression tests + TestCase { + left: Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + Operator::Plus, + Arc::clone(&col_b), + )) as Arc, + right: Arc::new(BinaryExpr::new( + Arc::clone(&col_x), + Operator::Plus, + Arc::clone(&col_y), + )) as Arc, + expected: true, + description: + "Binary expressions with equivalent operands should be equal", + }, + TestCase { + left: Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + Operator::Plus, + Arc::clone(&col_b), + )) as Arc, + right: Arc::new(BinaryExpr::new( + Arc::clone(&col_x), + Operator::Plus, + Arc::clone(&col_a), + )) as Arc, + expected: false, + description: + "Binary expressions with non-equivalent operands should not be equal", + }, + TestCase { + left: Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + Operator::Plus, + Arc::clone(&lit_1), + )) as Arc, + right: Arc::new(BinaryExpr::new( + Arc::clone(&col_x), + Operator::Plus, + Arc::clone(&lit_1), + )) as Arc, + expected: true, + description: "Binary expressions with 
equivalent column and same literal should be equal", + }, + TestCase { + left: Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + Operator::Plus, + Arc::clone(&col_b), + )), + Operator::Multiply, + Arc::clone(&lit_1), + )) as Arc, + right: Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::clone(&col_x), + Operator::Plus, + Arc::clone(&col_y), + )), + Operator::Multiply, + Arc::clone(&lit_1), + )) as Arc, + expected: true, + description: "Nested binary expressions with equivalent operands should be equal", + }, + ]; + + for TestCase { + left, + right, + expected, + description, + } in test_cases + { + let actual = eq_group.exprs_equal(&left, &right); + assert_eq!( + actual, expected, + "{}: Failed comparing {:?} and {:?}, expected {}, got {}", + description, left, right, expected, actual + ); + } + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 38647f7ca1d4b..a5b85064e6252 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use crate::expressions::Column; -use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; +use crate::{LexRequirement, PhysicalExpr}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -27,7 +27,7 @@ mod ordering; mod projection; mod properties; -pub use class::{ConstExpr, EquivalenceClass, EquivalenceGroup}; +pub use class::{AcrossPartitions, ConstExpr, EquivalenceClass, EquivalenceGroup}; pub use ordering::OrderingEquivalenceClass; pub use projection::ProjectionMapping; pub use properties::{ @@ -41,14 +41,9 @@ pub use properties::{ /// It will also filter out entries that are ordered if the next entry is; /// for instance, `vec![floor(a) Some(ASC), a Some(ASC)]` will be collapsed to /// `vec![a Some(ASC)]`. +#[deprecated(since = "45.0.0", note = "Use LexRequirement::collapse")] pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { - let mut output = Vec::::new(); - for item in input { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); - } - } - LexRequirement::new(output) + input.collapse() } /// Adds the `offset` value to `Column` indices inside `expr`. 
This function is @@ -72,20 +67,17 @@ pub fn add_offset_to_expr( #[cfg(test)] mod tests { + use super::*; use crate::expressions::col; use crate::PhysicalSortExpr; - use arrow::compute::{lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::{plan_datafusion_err, Result}; - - use itertools::izip; - use rand::rngs::StdRng; - use rand::seq::SliceRandom; - use rand::{Rng, SeedableRng}; + use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, PhysicalSortRequirement, + }; pub fn output_schema( mapping: &ProjectionMapping, @@ -175,67 +167,6 @@ mod tests { Ok((test_schema, eq_properties)) } - // Generate a schema which consists of 6 columns (a, b, c, d, e, f) - fn create_test_schema_2() -> Result { - let a = Field::new("a", DataType::Float64, true); - let b = Field::new("b", DataType::Float64, true); - let c = Field::new("c", DataType::Float64, true); - let d = Field::new("d", DataType::Float64, true); - let e = Field::new("e", DataType::Float64, true); - let f = Field::new("f", DataType::Float64, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); - - Ok(schema) - } - - /// Construct a schema with random ordering - /// among column a, b, c, d - /// where - /// Column [a=f] (e.g they are aliases). - /// Column e is constant. - pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { - let test_schema = create_test_schema_2()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; - - let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); - // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f)?; - // Column e has constant value. 
- eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); - - // Randomly order columns for sorting - let mut rng = StdRng::seed_from_u64(seed); - let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted - - let options_asc = SortOptions { - descending: false, - nulls_first: false, - }; - - while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); - remaining_exprs.shuffle(&mut rng); - - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: options_asc, - }) - .collect(); - - eq_properties.add_new_orderings([ordering]); - } - - Ok((test_schema, eq_properties)) - } - // Convert each tuple to PhysicalSortRequirement pub fn convert_to_sort_reqs( in_data: &[(&Arc, Option)], @@ -251,7 +182,7 @@ mod tests { // Convert each tuple to PhysicalSortExpr pub fn convert_to_sort_exprs( in_data: &[(&Arc, SortOptions)], - ) -> Vec { + ) -> LexOrdering { in_data .iter() .map(|(expr, options)| PhysicalSortExpr { @@ -264,7 +195,7 @@ mod tests { // Convert each inner tuple to PhysicalSortExpr pub fn convert_to_orderings( orderings: &[Vec<(&Arc, SortOptions)>], - ) -> Vec> { + ) -> Vec { orderings .iter() .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) @@ -274,53 +205,28 @@ mod tests { // Convert each tuple to PhysicalSortExpr pub fn convert_to_sort_exprs_owned( in_data: &[(Arc, SortOptions)], - ) -> Vec { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(expr), - options: *options, - }) - .collect() + ) -> LexOrdering { + LexOrdering::new( + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: Arc::clone(expr), + options: *options, + }) + .collect(), + ) } // Convert each inner tuple to PhysicalSortExpr pub fn convert_to_orderings_owned( orderings: &[Vec<(Arc, SortOptions)>], - ) -> Vec> { + ) -> Vec { orderings .iter() .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs)) .collect() } - // Apply projection to the input_data, return projected equivalence properties and record batch - pub fn apply_projection( - proj_exprs: Vec<(Arc, String)>, - input_data: &RecordBatch, - input_eq_properties: &EquivalenceProperties, - ) -> Result<(RecordBatch, EquivalenceProperties)> { - let input_schema = input_data.schema(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; - - let output_schema = output_schema(&projection_mapping, &input_schema)?; - let num_rows = input_data.num_rows(); - // Apply projection to the input record batch. - let projected_values = projection_mapping - .iter() - .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) - .collect::>>()?; - let projected_batch = if projected_values.is_empty() { - RecordBatch::new_empty(Arc::clone(&output_schema)) - } else { - RecordBatch::try_new(Arc::clone(&output_schema), projected_values)? 
- }; - - let projected_eq = - input_eq_properties.project(&projection_mapping, output_schema); - Ok((projected_batch, projected_eq)) - } - #[test] fn add_equal_conditions_test() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -345,16 +251,16 @@ mod tests { // This new entry is redundant, size shouldn't increase eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 2); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); - // b and c are aliases. Exising equivalence class should expand, + // b and c are aliases. Existing equivalence class should expand, // however there shouldn't be any new equivalence class eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 3); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); @@ -368,7 +274,7 @@ mod tests { // Hence equivalent class count should decrease from 2 to 1. eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 5); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); @@ -378,168 +284,4 @@ mod tests { Ok(()) } - - /// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. - /// - /// The function works by adding a unique column of ascending integers to the original table. This column ensures - /// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can - /// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce - /// deterministic sorting results. - /// - /// If the table remains the same after sorting with the added unique column, it indicates that the table was - /// already sorted according to `required_ordering` to begin with. 
- pub fn is_table_same_after_sort( - mut required_ordering: Vec, - batch: RecordBatch, - ) -> Result { - // Clone the original schema and columns - let original_schema = batch.schema(); - let mut columns = batch.columns().to_vec(); - - // Create a new unique column - let n_row = batch.num_rows(); - let vals: Vec = (0..n_row).collect::>(); - let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); - let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; - columns.push(Arc::clone(&unique_col)); - - // Create a new schema with the added unique column - let unique_col_name = "unique"; - let unique_field = - Arc::new(Field::new(unique_col_name, DataType::Float64, false)); - let fields: Vec<_> = original_schema - .fields() - .iter() - .cloned() - .chain(std::iter::once(unique_field)) - .collect(); - let schema = Arc::new(Schema::new(fields)); - - // Create a new batch with the added column - let new_batch = RecordBatch::try_new(Arc::clone(&schema), columns)?; - - // Add the unique column to the required ordering to ensure deterministic results - required_ordering.push(PhysicalSortExpr { - expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), - options: Default::default(), - }); - - // Convert the required ordering to a list of SortColumn - let sort_columns = required_ordering - .iter() - .map(|order_expr| { - let expr_result = order_expr.expr.evaluate(&new_batch)?; - let values = expr_result.into_array(new_batch.num_rows())?; - Ok(SortColumn { - values, - options: Some(order_expr.options), - }) - }) - .collect::>>()?; - - // Check if the indices after sorting match the initial ordering - let sorted_indices = lexsort_to_indices(&sort_columns, None)?; - let original_indices = UInt32Array::from_iter_values(0..n_row as u32); - - Ok(sorted_indices == original_indices) - } - - // If we already generated a random result for one of the - // expressions in the equivalence classes. For other expressions in the same - // equivalence class use same result. This util gets already calculated result, when available. - fn get_representative_arr( - eq_group: &EquivalenceClass, - existing_vec: &[Option], - schema: SchemaRef, - ) -> Option { - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - if let Some(res) = &existing_vec[idx] { - return Some(Arc::clone(res)); - } - } - None - } - - // Generate a table that satisfies the given equivalence properties; i.e. - // equivalences, ordering equivalences, and constants. 
- pub fn generate_table_for_eq_properties( - eq_properties: &EquivalenceProperties, - n_elem: usize, - n_distinct: usize, - ) -> Result { - let mut rng = StdRng::seed_from_u64(23); - - let schema = eq_properties.schema(); - let mut schema_vec = vec![None; schema.fields.len()]; - - // Utility closure to generate random array - let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { - let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) - .collect(); - Arc::new(Float64Array::from_iter_values(values)) - }; - - // Fill constant columns - for constant in &eq_properties.constants { - let col = constant.expr().as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) - as ArrayRef; - schema_vec[idx] = Some(arr); - } - - // Fill columns based on ordering equivalences - for ordering in eq_properties.oeq_class.iter() { - let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering - .iter() - .map(|PhysicalSortExpr { expr, options }| { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = generate_random_array(n_elem, n_distinct); - ( - SortColumn { - values: arr, - options: Some(*options), - }, - idx, - ) - }) - .unzip(); - - let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; - for (idx, arr) in izip!(indices, sort_arrs) { - schema_vec[idx] = Some(arr); - } - } - - // Fill columns based on equivalence groups - for eq_group in eq_properties.eq_group.iter() { - let representative_array = - get_representative_arr(eq_group, &schema_vec, Arc::clone(schema)) - .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); - - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - schema_vec[idx] = Some(Arc::clone(&representative_array)); - } - } - - let res: Vec<_> = schema_vec - .into_iter() - .zip(schema.fields.iter()) - .map(|(elem, field)| { - ( - field.name(), - // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) - elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), - ) - }) - .collect(); - - Ok(RecordBatch::try_from_iter(res)?) - } } diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index bb3e9218bc418..4e324663dcd19 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use std::vec::IntoIter; use crate::equivalence::add_offset_to_expr; -use crate::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; +use crate::{LexOrdering, PhysicalExpr}; use arrow_schema::SortOptions; /// An `OrderingEquivalenceClass` object keeps track of different alternative @@ -39,7 +39,7 @@ use arrow_schema::SortOptions; /// ordering. In this case, we say that these orderings are equivalent. #[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] pub struct OrderingEquivalenceClass { - pub orderings: Vec, + orderings: Vec, } impl OrderingEquivalenceClass { @@ -53,13 +53,20 @@ impl OrderingEquivalenceClass { self.orderings.clear(); } - /// Creates new ordering equivalence class from the given orderings. 
+ /// Creates new ordering equivalence class from the given orderings + /// + /// Any redundant entries are removed pub fn new(orderings: Vec) -> Self { let mut result = Self { orderings }; result.remove_redundant_entries(); result } + /// Converts this OrderingEquivalenceClass to a vector of orderings. + pub fn into_inner(self) -> Vec { + self.orderings + } + /// Checks whether `ordering` is a member of this equivalence class. pub fn contains(&self, ordering: &LexOrdering) -> bool { self.orderings.contains(ordering) @@ -67,10 +74,12 @@ impl OrderingEquivalenceClass { /// Adds `ordering` to this equivalence class. #[allow(dead_code)] + #[deprecated( + since = "45.0.0", + note = "use OrderingEquivalenceClass::add_new_ordering instead" + )] fn push(&mut self, ordering: LexOrdering) { - self.orderings.push(ordering); - // Make sure that there are no redundant orderings: - self.remove_redundant_entries(); + self.add_new_ordering(ordering) } /// Checks whether this ordering equivalence class is empty. @@ -79,6 +88,9 @@ impl OrderingEquivalenceClass { } /// Returns an iterator over the equivalent orderings in this class. + /// + /// Note this class also implements [`IntoIterator`] to return an iterator + /// over owned [`LexOrdering`]s. pub fn iter(&self) -> impl Iterator { self.orderings.iter() } @@ -95,7 +107,7 @@ impl OrderingEquivalenceClass { self.remove_redundant_entries(); } - /// Adds new orderings into this ordering equivalence class. + /// Adds new orderings into this ordering equivalence class pub fn add_new_orderings( &mut self, orderings: impl IntoIterator, @@ -110,9 +122,10 @@ impl OrderingEquivalenceClass { self.add_new_orderings([ordering]); } - /// Removes redundant orderings from this equivalence class. For instance, - /// if we already have the ordering `[a ASC, b ASC, c DESC]`, then there is - /// no need to keep ordering `[a ASC, b ASC]` in the state. + /// Removes redundant orderings from this equivalence class. + /// + /// For instance, if we already have the ordering `[a ASC, b ASC, c DESC]`, + /// then there is no need to keep ordering `[a ASC, b ASC]` in the state. fn remove_redundant_entries(&mut self) { let mut work = true; while work { @@ -122,12 +135,12 @@ impl OrderingEquivalenceClass { let mut ordering_idx = idx + 1; let mut removal = self.orderings[idx].is_empty(); while ordering_idx < self.orderings.len() { - work |= resolve_overlap(&mut self.orderings, idx, ordering_idx); + work |= self.resolve_overlap(idx, ordering_idx); if self.orderings[idx].is_empty() { removal = true; break; } - work |= resolve_overlap(&mut self.orderings, ordering_idx, idx); + work |= self.resolve_overlap(ordering_idx, idx); if self.orderings[ordering_idx].is_empty() { self.orderings.swap_remove(ordering_idx); } else { @@ -143,11 +156,36 @@ impl OrderingEquivalenceClass { } } + /// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of + /// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise. + /// + /// For example, if `orderings[idx]` is `[a ASC, b ASC, c DESC]` and + /// `orderings[pre_idx]` is `[b ASC, c DESC]`, then the function will trim + /// `orderings[idx]` to `[a ASC]`. + fn resolve_overlap(&mut self, idx: usize, pre_idx: usize) -> bool { + let length = self.orderings[idx].len(); + let other_length = self.orderings[pre_idx].len(); + for overlap in 1..=length.min(other_length) { + if self.orderings[idx][length - overlap..] 
+ == self.orderings[pre_idx][..overlap] + { + self.orderings[idx].truncate(length - overlap); + return true; + } + } + false + } + /// Returns the concatenation of all the orderings. This enables merge /// operations to preserve all equivalent orderings simultaneously. pub fn output_ordering(&self) -> Option { - let output_ordering = self.orderings.iter().flatten().cloned().collect(); - let output_ordering = collapse_lex_ordering(output_ordering); + let output_ordering = self + .orderings + .iter() + .flatten() + .cloned() + .collect::() + .collapse(); (!output_ordering.is_empty()).then_some(output_ordering) } @@ -179,9 +217,9 @@ impl OrderingEquivalenceClass { /// ordering equivalence class. pub fn add_offset(&mut self, offset: usize) { for ordering in self.orderings.iter_mut() { - for sort_expr in ordering { + ordering.transform(|sort_expr| { sort_expr.expr = add_offset_to_expr(Arc::clone(&sort_expr.expr), offset); - } + }) } } @@ -198,6 +236,7 @@ impl OrderingEquivalenceClass { } } +/// Convert the `OrderingEquivalenceClass` into an iterator of LexOrderings impl IntoIterator for OrderingEquivalenceClass { type Item = LexOrdering; type IntoIter = IntoIter; @@ -207,42 +246,15 @@ impl IntoIterator for OrderingEquivalenceClass { } } -/// This function constructs a duplicate-free `LexOrdering` by filtering out -/// duplicate entries that have same physical expression inside. For example, -/// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. -pub fn collapse_lex_ordering(input: LexOrdering) -> LexOrdering { - let mut output = Vec::::new(); - for item in input { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); - } - } - output -} - -/// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of -/// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise. -fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> bool { - let length = orderings[idx].len(); - let other_length = orderings[pre_idx].len(); - for overlap in 1..=length.min(other_length) { - if orderings[idx][length - overlap..] 
== orderings[pre_idx][..overlap] { - orderings[idx].truncate(length - overlap); - return true; - } - } - false -} - impl Display for OrderingEquivalenceClass { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[")?; let mut iter = self.orderings.iter(); if let Some(ordering) = iter.next() { - write!(f, "[{}]", PhysicalSortExpr::format_list(ordering))?; + write!(f, "[{}]", ordering)?; } for ordering in iter { - write!(f, ", [{}]", PhysicalSortExpr::format_list(ordering))?; + write!(f, ", [{}]", ordering)?; } write!(f, "]")?; Ok(()) @@ -254,9 +266,7 @@ mod tests { use std::sync::Arc; use crate::equivalence::tests::{ - convert_to_orderings, convert_to_sort_exprs, create_random_schema, - create_test_params, create_test_schema, generate_table_for_eq_properties, - is_table_same_after_sort, + convert_to_orderings, convert_to_sort_exprs, create_test_schema, }; use crate::equivalence::{ EquivalenceClass, EquivalenceGroup, EquivalenceProperties, @@ -264,14 +274,13 @@ mod tests { }; use crate::expressions::{col, BinaryExpr, Column}; use crate::utils::tests::TestScalarUDF; - use crate::{ConstExpr, PhysicalExpr, PhysicalSortExpr}; + use crate::{AcrossPartitions, ConstExpr, PhysicalExpr, PhysicalSortExpr}; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::SortOptions; use datafusion_common::{DFSchema, Result}; use datafusion_expr::{Operator, ScalarUDF}; - - use itertools::Itertools; + use datafusion_physical_expr_common::sort_expr::LexOrdering; #[test] fn test_ordering_satisfy() -> Result<()> { @@ -279,11 +288,11 @@ mod tests { Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), ])); - let crude = vec![PhysicalSortExpr { + let crude = LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), - }]; - let finer = vec![ + }]); + let finer = LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), @@ -292,18 +301,20 @@ mod tests { expr: Arc::new(Column::new("b", 1)), options: SortOptions::default(), }, - ]; + ]); // finer ordering satisfies, crude ordering should return true - let mut eq_properties_finer = - EquivalenceProperties::new(Arc::clone(&input_schema)); - eq_properties_finer.oeq_class.push(finer.clone()); - assert!(eq_properties_finer.ordering_satisfy(&crude)); + let eq_properties_finer = EquivalenceProperties::new_with_orderings( + Arc::clone(&input_schema), + &[finer.clone()], + ); + assert!(eq_properties_finer.ordering_satisfy(crude.as_ref())); // Crude ordering doesn't satisfy finer ordering. 
should return false - let mut eq_properties_crude = - EquivalenceProperties::new(Arc::clone(&input_schema)); - eq_properties_crude.oeq_class.push(crude); - assert!(!eq_properties_crude.ordering_satisfy(&finer)); + let eq_properties_crude = EquivalenceProperties::new_with_orderings( + Arc::clone(&input_schema), + &[crude.clone()], + ); + assert!(!eq_properties_crude.ordering_satisfy(finer.as_ref())); Ok(()) } @@ -586,14 +597,15 @@ mod tests { let eq_group = EquivalenceGroup::new(eq_group); eq_properties.add_equivalence_group(eq_group); - let constants = constants - .into_iter() - .map(|expr| ConstExpr::from(expr).with_across_partitions(true)); + let constants = constants.into_iter().map(|expr| { + ConstExpr::from(expr) + .with_across_partitions(AcrossPartitions::Uniform(None)) + }); eq_properties = eq_properties.with_constants(constants); let reqs = convert_to_sort_exprs(&reqs); assert_eq!( - eq_properties.ordering_satisfy(&reqs), + eq_properties.ordering_satisfy(reqs.as_ref()), expected, "{}", err_msg @@ -603,305 +615,6 @@ mod tests { Ok(()) } - #[test] - fn test_ordering_satisfy_with_equivalence() -> Result<()> { - // Schema satisfies following orderings: - // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - // and - // Column [a=c] (e.g they are aliases). - let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, 625, 5)?; - - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it - (vec![(col_a, option_asc)], true), - (vec![(col_a, option_desc)], false), - // Test whether equivalence works as expected - (vec![(col_c, option_asc)], true), - (vec![(col_c, option_desc)], false), - // Test whether ordering equivalence works as expected - (vec![(col_d, option_asc)], true), - (vec![(col_d, option_asc), (col_b, option_asc)], true), - (vec![(col_d, option_desc), (col_b, option_asc)], false), - ( - vec![ - (col_e, option_desc), - (col_f, option_asc), - (col_g, option_asc), - ], - true, - ), - (vec![(col_e, option_desc), (col_f, option_asc)], true), - (vec![(col_e, option_asc), (col_f, option_asc)], false), - (vec![(col_e, option_desc), (col_b, option_asc)], false), - (vec![(col_e, option_asc), (col_b, option_asc)], false), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_f, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_f, option_asc), - ], - false, - ), - ( - 
vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_b, option_asc), - ], - false, - ), - (vec![(col_d, option_asc), (col_e, option_desc)], true), - ( - vec![ - (col_d, option_asc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_f, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_b, option_asc), - (col_f, option_asc), - ], - true, - ), - ]; - - for (cols, expected) in requirements { - let err_msg = format!("Error in test case:{cols:?}"); - let required = cols - .into_iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(expr), - options, - }) - .collect::>(); - - // Check expected result with experimental result. - assert_eq!( - is_table_same_after_sort( - required.clone(), - table_data_with_properties.clone() - )?, - expected - ); - assert_eq!( - eq_properties.ordering_satisfy(&required), - expected, - "{err_msg}" - ); - } - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 5; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let col_exprs = [ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - ]; - - for n_req in 0..=col_exprs.len() { - for exprs in col_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - // Check whether ordering_satisfy API result and - // experimental result matches. 
- assert_eq!( - eq_properties.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 100; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = crate::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let exprs = [ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - floor_a, - a_plus_b, - ]; - - for n_req in 0..=exprs.len() { - for exprs in exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - // Check whether ordering_satisfy API result and - // experimental result matches. 
- - assert_eq!( - eq_properties.ordering_satisfy(&requirement), - (expected | false), - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - #[test] fn test_ordering_satisfy_different_lengths() -> Result<()> { let test_schema = create_test_schema()?; @@ -952,7 +665,7 @@ mod tests { format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); let reqs = convert_to_sort_exprs(&reqs); assert_eq!( - eq_properties.ordering_satisfy(&reqs), + eq_properties.ordering_satisfy(reqs.as_ref()), expected, "{}", err_msg diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index ebf26d3262aa2..681484fd6bff3 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -139,23 +139,18 @@ fn project_index_to_exprs( mod tests { use super::*; use crate::equivalence::tests::{ - apply_projection, convert_to_orderings, convert_to_orderings_owned, - create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, - output_schema, + convert_to_orderings, convert_to_orderings_owned, output_schema, }; use crate::equivalence::EquivalenceProperties; use crate::expressions::{col, BinaryExpr}; use crate::udf::create_physical_expr; use crate::utils::tests::TestScalarUDF; - use crate::PhysicalSortExpr; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{SortOptions, TimeUnit}; use datafusion_common::DFSchema; use datafusion_expr::{Operator, ScalarUDF}; - use itertools::Itertools; - #[test] fn project_orderings() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -641,7 +636,7 @@ mod tests { let err_msg = format!( "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings.orderings, expected, projection_mapping + idx, orderings, expected, projection_mapping ); assert_eq!(orderings.len(), expected.len(), "{}", err_msg); @@ -830,7 +825,7 @@ mod tests { let err_msg = format!( "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings.orderings, expected, projection_mapping + idx, orderings, expected, projection_mapping ); assert_eq!(orderings.len(), expected.len(), "{}", err_msg); @@ -976,7 +971,7 @@ mod tests { let err_msg = format!( "actual: {:?}, expected: {:?}, projection_mapping: {:?}", - orderings.orderings, expected, projection_mapping + orderings, expected, projection_mapping ); assert_eq!(orderings.len(), expected.len(), "{}", err_msg); @@ -987,174 +982,4 @@ mod tests { Ok(()) } - - #[test] - fn project_orderings_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 20; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - // Floor(a) - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - // a + b - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let proj_exprs = vec![ - (col("a", &test_schema)?, "a_new"), - (col("b", &test_schema)?, "b_new"), - (col("c", &test_schema)?, "c_new"), - (col("d", &test_schema)?, 
"d_new"), - (col("e", &test_schema)?, "e_new"), - (col("f", &test_schema)?, "f_new"), - (floor_a, "floor(a)"), - (a_plus_b, "a+b"), - ]; - - for n_req in 0..=proj_exprs.len() { - for proj_exprs in proj_exprs.iter().combinations(n_req) { - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name.to_string())) - .collect::>(); - let (projected_batch, projected_eq) = apply_projection( - proj_exprs.clone(), - &table_data_with_properties, - &eq_properties, - )?; - - // Make sure each ordering after projection is valid. - for ordering in projected_eq.oeq_class().iter() { - let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs - ); - // Since ordered section satisfies schema, we expect - // that result will be same after sort (e.g sort was unnecessary). - assert!( - is_table_same_after_sort( - ordering.clone(), - projected_batch.clone(), - )?, - "{}", - err_msg - ); - } - } - } - } - - Ok(()) - } - - #[test] - fn ordering_satisfy_after_projection_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 20; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - // Floor(a) - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - // a + b - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let proj_exprs = vec![ - (col("a", &test_schema)?, "a_new"), - (col("b", &test_schema)?, "b_new"), - (col("c", &test_schema)?, "c_new"), - (col("d", &test_schema)?, "d_new"), - (col("e", &test_schema)?, "e_new"), - (col("f", &test_schema)?, "f_new"), - (floor_a, "floor(a)"), - (a_plus_b, "a+b"), - ]; - - for n_req in 0..=proj_exprs.len() { - for proj_exprs in proj_exprs.iter().combinations(n_req) { - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name.to_string())) - .collect::>(); - let (projected_batch, projected_eq) = apply_projection( - proj_exprs.clone(), - &table_data_with_properties, - &eq_properties, - )?; - - let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &test_schema)?; - - let projected_exprs = projection_mapping - .iter() - .map(|(_source, target)| Arc::clone(target)) - .collect::>(); - - for n_req in 0..=projected_exprs.len() { - for exprs in projected_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - projected_batch.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, 
projected_eq.constants: {:?}, projection_mapping: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - assert_eq!( - projected_eq.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - } - } - - Ok(()) - } } diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs old mode 100644 new mode 100755 index 5f18ffcda6e92..2b831bc766b7c --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -20,23 +20,23 @@ use std::hash::{Hash, Hasher}; use std::iter::Peekable; use std::slice::Iter; use std::sync::Arc; +use std::{fmt, mem}; -use super::ordering::collapse_lex_ordering; -use crate::equivalence::class::const_exprs_contains; +use crate::equivalence::class::{const_exprs_contains, AcrossPartitions}; use crate::equivalence::{ - collapse_lex_req, EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, - ProjectionMapping, + EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; use crate::expressions::{with_new_schema, CastExpr, Column, Literal}; use crate::{ - physical_exprs_contains, ConstExpr, LexOrdering, LexOrderingRef, LexRequirement, - LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, - PhysicalSortRequirement, + physical_exprs_contains, ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, + PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, }; use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{internal_err, plan_err, JoinSide, JoinType, Result}; +use datafusion_common::{ + internal_err, plan_err, Constraint, Constraints, HashMap, JoinSide, JoinType, Result, +}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_physical_expr_common::utils::ExprPropertiesNode; @@ -102,7 +102,7 @@ use itertools::Itertools; /// # use arrow_schema::{Schema, Field, DataType, SchemaRef}; /// # use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; /// # use datafusion_physical_expr::expressions::col; -/// use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +/// use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; /// # let schema: SchemaRef = Arc::new(Schema::new(vec![ /// # Field::new("a", DataType::Int32, false), /// # Field::new("b", DataType::Int32, false), @@ -115,24 +115,26 @@ use itertools::Itertools; /// // with a single constant value of b /// let mut eq_properties = EquivalenceProperties::new(schema) /// .with_constants(vec![ConstExpr::from(col_b)]); -/// eq_properties.add_new_ordering(vec![ +/// eq_properties.add_new_ordering(LexOrdering::new(vec![ /// PhysicalSortExpr::new_default(col_a).asc(), /// PhysicalSortExpr::new_default(col_c).desc(), -/// ]); +/// ])); /// -/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC,c@2 DESC]], const: [b@1]") +/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC, c@2 DESC]], const: [b@1(heterogeneous)]") /// ``` #[derive(Debug, Clone)] pub struct EquivalenceProperties { - /// Collection of equivalence classes that store expressions with the same - /// value. 
- pub eq_group: EquivalenceGroup, - /// Equivalent sort expressions for this table. - pub oeq_class: OrderingEquivalenceClass, - /// Expressions whose values are constant throughout the table. + /// Distinct equivalence classes (exprs known to have the same value) + eq_group: EquivalenceGroup, + /// Equivalent sort expressions + oeq_class: OrderingEquivalenceClass, + /// Expressions whose values are constant + /// /// TODO: We do not need to track constants separately, they can be tracked - inside `eq_groups` as `Literal` expressions. - pub constants: Vec<ConstExpr>, + inside `eq_group` as `Literal` expressions. + constants: Vec<ConstExpr>, + /// Table constraints + constraints: Constraints, /// Schema associated with this object. schema: SchemaRef, } @@ -144,16 +146,24 @@ impl EquivalenceProperties { eq_group: EquivalenceGroup::empty(), oeq_class: OrderingEquivalenceClass::empty(), constants: vec![], + constraints: Constraints::empty(), schema, } } + /// Adds constraints to the properties. + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + /// Creates a new `EquivalenceProperties` object with the given orderings. pub fn new_with_orderings(schema: SchemaRef, orderings: &[LexOrdering]) -> Self { Self { eq_group: EquivalenceGroup::empty(), oeq_class: OrderingEquivalenceClass::new(orderings.to_vec()), constants: vec![], + constraints: Constraints::empty(), schema, } } @@ -168,6 +178,11 @@ impl EquivalenceProperties { &self.oeq_class } + /// Return the inner OrderingEquivalenceClass, consuming self + pub fn into_oeq_class(self) -> OrderingEquivalenceClass { + self.oeq_class + } + /// Returns a reference to the equivalence group within. pub fn eq_group(&self) -> &EquivalenceGroup { &self.eq_group @@ -178,6 +193,10 @@ impl EquivalenceProperties { &self.constants } + pub fn constraints(&self) -> &Constraints { + &self.constraints + } + /// Returns the output ordering of the properties. pub fn output_ordering(&self) -> Option<LexOrdering> { let constants = self.constants(); @@ -216,7 +235,9 @@ impl EquivalenceProperties { /// Removes constant expressions that may change across partitions. /// This method should be used when data from different partitions are merged.
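// Editorial sketch (not part of the patch): minimal usage of the reworked
// constant handling shown above, where `across_partitions` is now the
// `AcrossPartitions` enum instead of a bool. The schema and column names are
// illustrative only; paths follow the imports used elsewhere in this diff.
use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::{AcrossPartitions, ConstExpr, EquivalenceProperties};
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

fn constant_tracking_sketch() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let col_a = col("a", &schema)?;
    let col_b = col("b", &schema)?;

    // `b` is constant and holds the same (unknown) value in every partition,
    // so it is kept by `clear_per_partition_constants`; a heterogeneous
    // constant would be dropped there instead.
    let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema))
        .with_constants(vec![ConstExpr::from(col_b)
            .with_across_partitions(AcrossPartitions::Uniform(None))]);
    eq_properties.add_new_ordering(LexOrdering::new(vec![
        PhysicalSortExpr::new_default(col_a).asc(),
    ]));

    eq_properties.clear_per_partition_constants();
    assert_eq!(eq_properties.constants().len(), 1);
    Ok(())
}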
pub fn clear_per_partition_constants(&mut self) { - self.constants.retain(|item| item.across_partitions()); + self.constants.retain(|item| { + matches!(item.across_partitions(), AcrossPartitions::Uniform(_)) + }) } /// Extends this `EquivalenceProperties` by adding the orderings inside the @@ -256,14 +277,16 @@ impl EquivalenceProperties { if self.is_expr_constant(left) { // Left expression is constant, add right as constant if !const_exprs_contains(&self.constants, right) { - self.constants - .push(ConstExpr::from(right).with_across_partitions(true)); + let const_expr = ConstExpr::from(right) + .with_across_partitions(self.get_expr_constant_value(left)); + self.constants.push(const_expr); } } else if self.is_expr_constant(right) { // Right expression is constant, add left as constant if !const_exprs_contains(&self.constants, left) { - self.constants - .push(ConstExpr::from(left).with_across_partitions(true)); + let const_expr = ConstExpr::from(left) + .with_across_partitions(self.get_expr_constant_value(right)); + self.constants.push(const_expr); } } @@ -292,30 +315,28 @@ impl EquivalenceProperties { mut self, constants: impl IntoIterator, ) -> Self { - let (const_exprs, across_partition_flags): ( - Vec>, - Vec, - ) = constants + let normalized_constants = constants .into_iter() - .map(|const_expr| { - let across_partitions = const_expr.across_partitions(); - let expr = const_expr.owned_expr(); - (expr, across_partitions) + .filter_map(|c| { + let across_partitions = c.across_partitions(); + let expr = c.owned_expr(); + let normalized_expr = self.eq_group.normalize_expr(expr); + + if const_exprs_contains(&self.constants, &normalized_expr) { + return None; + } + + let const_expr = ConstExpr::from(normalized_expr) + .with_across_partitions(across_partitions); + + Some(const_expr) }) - .unzip(); - for (expr, across_partitions) in self - .eq_group - .normalize_exprs(const_exprs) - .into_iter() - .zip(across_partition_flags) - { - if !const_exprs_contains(&self.constants, &expr) { - let const_expr = - ConstExpr::from(expr).with_across_partitions(across_partitions); - self.constants.push(const_expr); - } - } + .collect::>(); + + // Add all new normalized constants + self.constants.extend(normalized_constants); + // Discover any new orderings based on the constants for ordering in self.normalized_oeq_class().iter() { if let Err(e) = self.discover_new_orderings(&ordering[0].expr) { log::debug!("error discovering new orderings: {e}"); @@ -335,7 +356,6 @@ impl EquivalenceProperties { let normalized_expr = self.eq_group().normalize_expr(Arc::clone(expr)); let eq_class = self .eq_group - .classes .iter() .find_map(|class| { class @@ -345,40 +365,61 @@ impl EquivalenceProperties { .unwrap_or_else(|| vec![Arc::clone(&normalized_expr)]); let mut new_orderings: Vec = vec![]; - for (ordering, next_expr) in self - .normalized_oeq_class() - .iter() - .filter(|ordering| ordering[0].expr.eq(&normalized_expr)) - // First expression after leading ordering - .filter_map(|ordering| Some(ordering).zip(ordering.get(1))) - { - let leading_ordering = ordering[0].options; - // Currently, we only handle expressions with a single child. - // TODO: It should be possible to handle expressions orderings like - // f(a, b, c), a, b, c if f is monotonic in all arguments. 
+ for ordering in self.normalized_oeq_class().iter() { + if !ordering[0].expr.eq(&normalized_expr) { + continue; + } + + let leading_ordering_options = ordering[0].options; + for equivalent_expr in &eq_class { let children = equivalent_expr.children(); - if children.len() == 1 - && children[0].eq(&next_expr.expr) - && SortProperties::Ordered(leading_ordering) - == equivalent_expr - .get_properties(&[ExprProperties { - sort_properties: SortProperties::Ordered( - leading_ordering, - ), - range: Interval::make_unbounded( - &equivalent_expr.data_type(&self.schema)?, - )?, - }])? - .sort_properties - { - // Assume existing ordering is [a ASC, b ASC] - // When equality a = f(b) is given, If we know that given ordering `[b ASC]`, ordering `[f(b) ASC]` is valid, - // then we can deduce that ordering `[b ASC]` is also valid. - // Hence, ordering `[b ASC]` can be added to the state as valid ordering. - // (e.g. existing ordering where leading ordering is removed) - new_orderings.push(ordering[1..].to_vec()); - break; + if children.is_empty() { + continue; + } + + // Check if all children match the next expressions in the ordering + let mut all_children_match = true; + let mut child_properties = vec![]; + + // Build properties for each child based on the next expressions + for (i, child) in children.iter().enumerate() { + if let Some(next) = ordering.get(i + 1) { + if !child.as_ref().eq(next.expr.as_ref()) { + all_children_match = false; + break; + } + child_properties.push(ExprProperties { + sort_properties: SortProperties::Ordered(next.options), + range: Interval::make_unbounded( + &child.data_type(&self.schema)?, + )?, + preserves_lex_ordering: true, + }); + } else { + all_children_match = false; + break; + } + } + + if all_children_match { + // Check if the expression is monotonic in all arguments + if let Ok(expr_properties) = + equivalent_expr.get_properties(&child_properties) + { + if expr_properties.preserves_lex_ordering + && SortProperties::Ordered(leading_ordering_options) + == expr_properties.sort_properties + { + // Assume existing ordering is [c ASC, a ASC, b ASC] + // When equality c = f(a,b) is given, if we know that given ordering `[a ASC, b ASC]`, + // ordering `[f(a,b) ASC]` is valid, then we can deduce that ordering `[a ASC, b ASC]` is also valid. + // Hence, ordering `[a ASC, b ASC]` can be added to the state as a valid ordering. + // (e.g. existing ordering where leading ordering is removed) + new_orderings.push(LexOrdering::new(ordering[1..].to_vec())); + break; + } + } } } } @@ -390,12 +431,51 @@ impl EquivalenceProperties { /// Updates the ordering equivalence group within assuming that the table /// is re-sorted according to the argument `sort_exprs`. Note that constants /// and equivalence classes are unchanged as they are unaffected by a re-sort. - pub fn with_reorder(mut self, sort_exprs: Vec) -> Self { - // TODO: In some cases, existing ordering equivalences may still be valid add this analysis. - self.oeq_class = OrderingEquivalenceClass::new(vec![sort_exprs]); + /// If the given ordering is already satisfied, the function does nothing. 
+ pub fn with_reorder(mut self, sort_exprs: LexOrdering) -> Self { + // Filter out constant expressions as they don't affect ordering + let filtered_exprs = LexOrdering::new( + sort_exprs + .into_iter() + .filter(|expr| !self.is_expr_constant(&expr.expr)) + .collect(), + ); + + if filtered_exprs.is_empty() { + return self; + } + + let mut new_orderings = vec![filtered_exprs.clone()]; + + // Preserve valid suffixes from existing orderings + let oeq_class = mem::take(&mut self.oeq_class); + for existing in oeq_class { + if self.is_prefix_of(&filtered_exprs, &existing) { + let mut extended = filtered_exprs.clone(); + extended.extend(existing.into_iter().skip(filtered_exprs.len())); + new_orderings.push(extended); + } + } + + self.oeq_class = OrderingEquivalenceClass::new(new_orderings); self } + /// Checks if the new ordering matches a prefix of the existing ordering + /// (considering expression equivalences) + fn is_prefix_of(&self, new_order: &LexOrdering, existing: &LexOrdering) -> bool { + // Check if new order is longer than existing - can't be a prefix + if new_order.len() > existing.len() { + return false; + } + + // Check if new order matches existing prefix (considering equivalences) + new_order.iter().zip(existing).all(|(new, existing)| { + self.eq_group.exprs_equal(&new.expr, &existing.expr) + && new.options == existing.options + }) + } + /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the /// equivalence group and the ordering equivalence class within. /// @@ -406,13 +486,13 @@ impl EquivalenceProperties { /// function would return `vec![a ASC, c ASC]`. Internally, it would first /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result /// after deduplication. - fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { + fn normalize_sort_exprs(&self, sort_exprs: &LexOrdering) -> LexOrdering { // Convert sort expressions to sort requirements: - let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); + let sort_reqs = LexRequirement::from(sort_exprs.clone()); // Normalize the requirements: let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); // Convert sort requirements back to sort expressions: - PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) + LexOrdering::from(normalized_sort_reqs) } /// Normalizes the given sort requirements (i.e. `sort_reqs`) using the @@ -428,10 +508,7 @@ impl EquivalenceProperties { /// function would return `vec![a ASC, c ASC]`. Internally, it would first /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result /// after deduplication. 
- fn normalize_sort_requirements( - &self, - sort_reqs: LexRequirementRef, - ) -> LexRequirement { + fn normalize_sort_requirements(&self, sort_reqs: &LexRequirement) -> LexRequirement { let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); let mut constant_exprs = vec![]; constant_exprs.extend( @@ -441,31 +518,34 @@ impl EquivalenceProperties { ); let constants_normalized = self.eq_group.normalize_exprs(constant_exprs); // Prune redundant sections in the requirement: - collapse_lex_req( - normalized_sort_reqs - .iter() - .filter(|&order| { - !physical_exprs_contains(&constants_normalized, &order.expr) - }) - .cloned() - .collect(), - ) + normalized_sort_reqs + .iter() + .filter(|&order| !physical_exprs_contains(&constants_normalized, &order.expr)) + .cloned() + .collect::() + .collapse() } /// Checks whether the given ordering is satisfied by any of the existing /// orderings. - pub fn ordering_satisfy(&self, given: LexOrderingRef) -> bool { + pub fn ordering_satisfy(&self, given: &LexOrdering) -> bool { // Convert the given sort expressions to sort requirements: - let sort_requirements = PhysicalSortRequirement::from_sort_exprs(given.iter()); + let sort_requirements = LexRequirement::from(given.clone()); self.ordering_satisfy_requirement(&sort_requirements) } /// Checks whether the given sort requirements are satisfied by any of the /// existing orderings. - pub fn ordering_satisfy_requirement(&self, reqs: LexRequirementRef) -> bool { + pub fn ordering_satisfy_requirement(&self, reqs: &LexRequirement) -> bool { let mut eq_properties = self.clone(); // First, standardize the given requirement: let normalized_reqs = eq_properties.normalize_sort_requirements(reqs); + + // Check whether given ordering is satisfied by constraints first + if self.satisfied_by_constraints(&normalized_reqs) { + return true; + } + for normalized_req in normalized_reqs { // Check whether given ordering is satisfied if !eq_properties.ordering_satisfy_single(&normalized_req) { @@ -489,11 +569,87 @@ impl EquivalenceProperties { true } + /// Checks if the sort requirements are satisfied by any of the table constraints (primary key or unique). + /// Returns true if any constraint fully satisfies the requirements. + fn satisfied_by_constraints( + &self, + normalized_reqs: &[PhysicalSortRequirement], + ) -> bool { + self.constraints.iter().any(|constraint| match constraint { + Constraint::PrimaryKey(indices) | Constraint::Unique(indices) => self + .satisfied_by_constraint( + normalized_reqs, + indices, + matches!(constraint, Constraint::Unique(_)), + ), + }) + } + + /// Checks if sort requirements are satisfied by a constraint (primary key or unique). + /// Returns true if the constraint indices form a valid prefix of an existing ordering + /// that matches the requirements. For unique constraints, also verifies nullable columns. 
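// Editorial sketch (not part of the patch): the constraint-aware check
// described above in action. With a primary key on column `a` and a known
// `[a ASC]` ordering, the finer requirement `[a ASC, b ASC]` is satisfied,
// because `a` is unique so `b` can never act as a tie breaker.
// `Constraints::new_unverified` is assumed to be the `datafusion_common`
// constructor; everything else mirrors APIs visible in this diff.
use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{Constraint, Constraints, Result};
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

fn primary_key_satisfies_finer_ordering() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, true),
    ]));
    let a_asc = PhysicalSortExpr::new_default(col("a", &schema)?).asc();
    let b_asc = PhysicalSortExpr::new_default(col("b", &schema)?).asc();

    let eq_properties = EquivalenceProperties::new_with_orderings(
        Arc::clone(&schema),
        &[LexOrdering::new(vec![a_asc.clone()])],
    )
    .with_constraints(Constraints::new_unverified(vec![Constraint::PrimaryKey(
        vec![0],
    )]));

    // Without the constraint only `[a ASC]` would be satisfied.
    let finer = LexOrdering::new(vec![a_asc, b_asc]);
    assert!(eq_properties.ordering_satisfy(finer.as_ref()));
    Ok(())
}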
+ fn satisfied_by_constraint( + &self, + normalized_reqs: &[PhysicalSortRequirement], + indices: &[usize], + check_null: bool, + ) -> bool { + // Requirements must contain indices + if indices.len() > normalized_reqs.len() { + return false; + } + + // Iterate over all orderings + self.oeq_class.iter().any(|ordering| { + if indices.len() > ordering.len() { + return false; + } + + // Build a map of column positions in the ordering + let mut col_positions = HashMap::with_capacity(ordering.len()); + for (pos, req) in ordering.iter().enumerate() { + if let Some(col) = req.expr.as_any().downcast_ref::() { + col_positions.insert( + col.index(), + (pos, col.nullable(&self.schema).unwrap_or(true)), + ); + } + } + + // Check if all constraint indices appear in valid positions + if !indices.iter().all(|&idx| { + col_positions + .get(&idx) + .map(|&(pos, nullable)| { + // For unique constraints, verify column is not nullable if it's first/last + !check_null + || (pos != 0 && pos != ordering.len() - 1) + || !nullable + }) + .unwrap_or(false) + }) { + return false; + } + + // Check if this ordering matches requirements prefix + let ordering_len = ordering.len(); + normalized_reqs.len() >= ordering_len + && normalized_reqs[..ordering_len].iter().zip(ordering).all( + |(req, existing)| { + req.expr.eq(&existing.expr) + && req + .options + .map_or(true, |req_opts| req_opts == existing.options) + }, + ) + }) + } + /// Determines whether the ordering specified by the given sort requirement /// is satisfied based on the orderings within, equivalence classes, and /// constant expressions. /// - /// # Arguments + /// # Parameters /// /// - `req`: A reference to a `PhysicalSortRequirement` for which the ordering /// satisfaction check will be done. @@ -519,12 +675,12 @@ impl EquivalenceProperties { } } - /// Checks whether the `given`` sort requirements are equal or more specific + /// Checks whether the `given` sort requirements are equal or more specific /// than the `reference` sort requirements. pub fn requirements_compatible( &self, - given: LexRequirementRef, - reference: LexRequirementRef, + given: &LexRequirement, + reference: &LexRequirement, ) -> bool { let normalized_given = self.normalize_sort_requirements(given); let normalized_reference = self.normalize_sort_requirements(reference); @@ -546,15 +702,15 @@ impl EquivalenceProperties { /// the latter. pub fn get_finer_ordering( &self, - lhs: LexOrderingRef, - rhs: LexOrderingRef, + lhs: &LexOrdering, + rhs: &LexOrdering, ) -> Option { // Convert the given sort expressions to sort requirements: - let lhs = PhysicalSortRequirement::from_sort_exprs(lhs); - let rhs = PhysicalSortRequirement::from_sort_exprs(rhs); + let lhs = LexRequirement::from(lhs.clone()); + let rhs = LexRequirement::from(rhs.clone()); let finer = self.get_finer_requirement(&lhs, &rhs); // Convert the chosen sort requirements back to sort expressions: - finer.map(PhysicalSortRequirement::to_sort_exprs) + finer.map(LexOrdering::from) } /// Returns the finer ordering among the requirements `lhs` and `rhs`, @@ -567,8 +723,8 @@ impl EquivalenceProperties { /// is the latter. 
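// Editorial sketch (not part of the patch): `get_finer_ordering` after the
// signature change above (it now takes `&LexOrdering` rather than the removed
// `LexOrderingRef` alias). Given `[a ASC]` and `[a ASC, b ASC]`, the finer
// ordering is the longer one. Schema and column names are illustrative only.
use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

fn finer_ordering_sketch() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let a_asc = PhysicalSortExpr::new_default(col("a", &schema)?).asc();
    let b_asc = PhysicalSortExpr::new_default(col("b", &schema)?).asc();

    let coarse = LexOrdering::new(vec![a_asc.clone()]);
    let fine = LexOrdering::new(vec![a_asc, b_asc]);

    // With no extra equivalences, the finer of the two is simply the longer,
    // compatible ordering `[a ASC, b ASC]`.
    let eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
    let finer = eq_properties.get_finer_ordering(&coarse, &fine);
    assert_eq!(finer.map(|ordering| ordering.len()), Some(2));
    Ok(())
}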
pub fn get_finer_requirement( &self, - req1: LexRequirementRef, - req2: LexRequirementRef, + req1: &LexRequirement, + req2: &LexRequirement, ) -> Option { let mut lhs = self.normalize_sort_requirements(req1); let mut rhs = self.normalize_sort_requirements(req2); @@ -604,8 +760,8 @@ impl EquivalenceProperties { pub fn substitute_ordering_component( &self, mapping: &ProjectionMapping, - sort_expr: &[PhysicalSortExpr], - ) -> Result>> { + sort_expr: &LexOrdering, + ) -> Result> { let new_orderings = sort_expr .iter() .map(|sort_expr| { @@ -615,7 +771,7 @@ impl EquivalenceProperties { .filter(|source| expr_refers(source, &sort_expr.expr)) .cloned() .collect(); - let mut res = vec![sort_expr.clone()]; + let mut res = LexOrdering::new(vec![sort_expr.clone()]); // TODO: Add one-to-ones analysis for ScalarFunctions. for r_expr in referring_exprs { // we check whether this expression is substitutable or not @@ -639,6 +795,7 @@ impl EquivalenceProperties { let res = new_orderings .into_iter() .multi_cartesian_product() + .map(LexOrdering::new) .collect::>(); Ok(res) } @@ -649,8 +806,8 @@ impl EquivalenceProperties { /// Since it would cause bug in dependency constructions, we should substitute the input order in order to get correct /// dependency map, happen in issue 8838: pub fn substitute_oeq_class(&mut self, mapping: &ProjectionMapping) -> Result<()> { - let orderings = &self.oeq_class.orderings; - let new_order = orderings + let new_order = self + .oeq_class .iter() .map(|order| self.substitute_ordering_component(mapping, order)) .collect::>>()?; @@ -709,7 +866,7 @@ impl EquivalenceProperties { /// c ASC: Node {None, HashSet{a ASC}} /// ``` fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { - let mut dependency_map = IndexMap::new(); + let mut dependency_map = DependencyMap::new(); for ordering in self.normalized_oeq_class().iter() { for (idx, sort_expr) in ordering.iter().enumerate() { let target_sort_expr = @@ -731,13 +888,11 @@ impl EquivalenceProperties { let dependency = idx.checked_sub(1).map(|a| &ordering[a]); // Add sort expressions that can be projected or referred to // by any of the projection expressions to the dependency map: - dependency_map - .entry(sort_expr.clone()) - .or_insert_with(|| DependencyNode { - target_sort_expr: target_sort_expr.clone(), - dependencies: IndexSet::new(), - }) - .insert_dependency(dependency); + dependency_map.insert( + sort_expr, + target_sort_expr.as_ref(), + dependency, + ); } if !is_projected { // If we can not project, stop constructing the dependency @@ -837,7 +992,7 @@ impl EquivalenceProperties { if prefixes.is_empty() { // If prefix is empty, there is no dependency. Insert // empty ordering: - prefixes = vec![vec![]]; + prefixes = vec![LexOrdering::default()]; } // Append current ordering on top its dependencies: for ordering in prefixes.iter_mut() { @@ -851,7 +1006,7 @@ impl EquivalenceProperties { // Simplify each ordering by removing redundant sections: orderings .chain(projected_orderings) - .map(collapse_lex_ordering) + .map(|lex_ordering| lex_ordering.collapse()) .collect() } @@ -861,7 +1016,7 @@ impl EquivalenceProperties { /// constants based on the existing constants and the mapping. It ensures /// that constants are appropriately propagated through the projection. /// - /// # Arguments + /// # Parameters /// /// - `mapping`: A reference to a `ProjectionMapping` representing the /// mapping of source expressions to target expressions in the projection. 
@@ -877,37 +1032,76 @@ impl EquivalenceProperties { .constants .iter() .flat_map(|const_expr| { - const_expr.map(|expr| self.eq_group.project_expr(mapping, expr)) + const_expr + .map(|expr| self.eq_group.project_expr(mapping, expr)) + .map(|projected_expr| { + projected_expr + .with_across_partitions(const_expr.across_partitions()) + }) }) .collect::>(); + // Add projection expressions that are known to be constant: for (source, target) in mapping.iter() { if self.is_expr_constant(source) && !const_exprs_contains(&projected_constants, target) { - // Expression evaluates to single value - projected_constants - .push(ConstExpr::from(target).with_across_partitions(true)); + if self.is_expr_constant_accross_partitions(source) { + projected_constants.push( + ConstExpr::from(target) + .with_across_partitions(self.get_expr_constant_value(source)), + ) + } else { + projected_constants.push( + ConstExpr::from(target) + .with_across_partitions(AcrossPartitions::Heterogeneous), + ) + } } } projected_constants } - /// Projects the equivalences within according to `projection_mapping` + /// Projects constraints according to the given projection mapping. + /// + /// This function takes a projection mapping and extracts the column indices of the target columns. + /// It then projects the constraints to only include relationships between + /// columns that exist in the projected output. + /// + /// # Arguments + /// + /// * `mapping` - A reference to `ProjectionMapping` that defines how expressions are mapped + /// in the projection operation + /// + /// # Returns + /// + /// Returns a new `Constraints` object containing only the constraints + /// that are valid for the projected columns. + fn projected_constraints(&self, mapping: &ProjectionMapping) -> Option { + let indices = mapping + .iter() + .filter_map(|(_, target)| target.as_any().downcast_ref::()) + .map(|col| col.index()) + .collect::>(); + debug_assert_eq!(mapping.map.len(), indices.len()); + self.constraints.project(&indices) + } + + /// Projects the equivalences within according to `mapping` /// and `output_schema`. - pub fn project( - &self, - projection_mapping: &ProjectionMapping, - output_schema: SchemaRef, - ) -> Self { - let projected_constants = self.projected_constants(projection_mapping); - let projected_eq_group = self.eq_group.project(projection_mapping); - let projected_orderings = self.projected_orderings(projection_mapping); + pub fn project(&self, mapping: &ProjectionMapping, output_schema: SchemaRef) -> Self { + let eq_group = self.eq_group.project(mapping); + let oeq_class = OrderingEquivalenceClass::new(self.projected_orderings(mapping)); + let constants = self.projected_constants(mapping); + let constraints = self + .projected_constraints(mapping) + .unwrap_or_else(Constraints::empty); Self { - eq_group: projected_eq_group, - oeq_class: OrderingEquivalenceClass::new(projected_orderings), - constants: projected_constants, schema: output_schema, + eq_group, + oeq_class, + constants, + constraints, } } @@ -987,13 +1181,14 @@ impl EquivalenceProperties { // Add new ordered section to the state. result.extend(ordered_exprs); } - result.into_iter().unzip() + let (left, right) = result.into_iter().unzip(); + (LexOrdering::new(left), right) } /// This function determines whether the provided expression is constant /// based on the known constants. /// - /// # Arguments + /// # Parameters /// /// - `expr`: A reference to a `Arc` representing the /// expression to be checked. 
@@ -1015,6 +1210,76 @@ impl EquivalenceProperties { is_constant_recurse(&normalized_constants, &normalized_expr) } + /// This function determines whether the provided expression is constant + /// across partitions based on the known constants. + /// + /// # Parameters + /// + /// - `expr`: A reference to a `Arc` representing the + /// expression to be checked. + /// + /// # Returns + /// + /// Returns `true` if the expression is constant across all partitions according + /// to equivalence group, `false` otherwise. + pub fn is_expr_constant_accross_partitions( + &self, + expr: &Arc, + ) -> bool { + // As an example, assume that we know columns `a` and `b` are constant. + // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will + // return `false`. + let const_exprs = self + .constants + .iter() + .filter_map(|const_expr| { + if matches!( + const_expr.across_partitions(), + AcrossPartitions::Uniform { .. } + ) { + Some(Arc::clone(const_expr.expr())) + } else { + None + } + }) + .collect::>(); + let normalized_constants = self.eq_group.normalize_exprs(const_exprs); + let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); + is_constant_recurse(&normalized_constants, &normalized_expr) + } + + /// Retrieves the constant value of a given physical expression, if it exists. + /// + /// Normalizes the input expression and checks if it matches any known constants + /// in the current context. Returns whether the expression has a uniform value, + /// varies across partitions, or is not constant. + /// + /// # Parameters + /// - `expr`: A reference to the physical expression to evaluate. + /// + /// # Returns + /// - `AcrossPartitions::Uniform(value)`: If the expression has the same value across partitions. + /// - `AcrossPartitions::Heterogeneous`: If the expression varies across partitions. + /// - `None`: If the expression is not recognized as constant. + pub fn get_expr_constant_value( + &self, + expr: &Arc, + ) -> AcrossPartitions { + let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); + + if let Some(lit) = normalized_expr.as_any().downcast_ref::() { + return AcrossPartitions::Uniform(Some(lit.scalar().value().clone())); + } + + for const_expr in self.constants.iter() { + if normalized_expr.eq(const_expr.expr()) { + return const_expr.across_partitions(); + } + } + + AcrossPartitions::Heterogeneous + } + /// Retrieves the properties for a given physical expression. 
/// /// This function constructs an [`ExprProperties`] object for the given @@ -1075,7 +1340,7 @@ impl EquivalenceProperties { // Rewrite orderings according to new schema: let mut new_orderings = vec![]; - for ordering in self.oeq_class.orderings { + for ordering in self.oeq_class { let new_ordering = ordering .into_iter() .map(|mut sort_expr| { @@ -1088,7 +1353,7 @@ impl EquivalenceProperties { // Rewrite equivalence classes according to the new schema: let mut eq_classes = vec![]; - for eq_class in self.eq_group.classes { + for eq_class in self.eq_group { let new_eq_exprs = eq_class .into_vec() .into_iter() @@ -1114,7 +1379,7 @@ impl EquivalenceProperties { /// order: [[a ASC, b ASC], [a ASC, c ASC]], eq: [[a = b], [a = c]], const: [a = 1] /// ``` impl Display for EquivalenceProperties { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.eq_group.is_empty() && self.oeq_class.is_empty() && self.constants.is_empty() @@ -1189,7 +1454,7 @@ fn update_properties( /// This function determines whether the provided expression is constant /// based on the known constants. /// -/// # Arguments +/// # Parameters /// /// - `constants`: A `&[Arc]` containing expressions known to /// be a constant. @@ -1257,7 +1522,7 @@ fn referred_dependencies( // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: let mut expr_to_sort_exprs = IndexMap::::new(); for sort_expr in dependency_map - .keys() + .sort_exprs() .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) { let key = ExprWrapper(Arc::clone(&sort_expr.expr)); @@ -1270,10 +1535,16 @@ fn referred_dependencies( // Generate all valid dependencies for the source. For example, if the source // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. - expr_to_sort_exprs - .values() + let dependencies = expr_to_sort_exprs + .into_values() + .map(Dependencies::into_inner) + .collect::>(); + dependencies + .iter() .multi_cartesian_product() - .map(|referred_deps| referred_deps.into_iter().cloned().collect()) + .map(|referred_deps| { + Dependencies::new_from_iter(referred_deps.into_iter().cloned()) + }) .collect() } @@ -1296,7 +1567,9 @@ fn construct_prefix_orderings( dependency_map: &DependencyMap, ) -> Vec { let mut dep_enumerator = DependencyEnumerator::new(); - dependency_map[relevant_sort_expr] + dependency_map + .get(relevant_sort_expr) + .expect("no relevant sort expr found") .dependencies .iter() .flat_map(|dep| dep_enumerator.construct_orderings(dep, dependency_map)) @@ -1306,8 +1579,8 @@ fn construct_prefix_orderings( /// Generates all possible orderings where dependencies are satisfied for the /// current projection expression. /// -/// # Examaple -/// If `dependences` is `a + b ASC` and the dependency map holds dependencies +/// # Example +/// If `dependencies` is `a + b ASC` and the dependency map holds dependencies /// * `a ASC` --> `[c ASC]` /// * `b ASC` --> `[d DESC]`, /// @@ -1341,7 +1614,7 @@ fn generate_dependency_orderings( // No dependency, dependent is a leading ordering. 
if relevant_prefixes.is_empty() { // Return an empty ordering: - return vec![vec![]]; + return vec![LexOrdering::default()]; } relevant_prefixes @@ -1351,7 +1624,12 @@ fn generate_dependency_orderings( prefix_orderings .iter() .permutations(prefix_orderings.len()) - .map(|prefixes| prefixes.into_iter().flatten().cloned().collect()) + .map(|prefixes| { + prefixes + .into_iter() + .flat_map(|ordering| ordering.clone()) + .collect() + }) .collect::>() }) .collect() @@ -1382,11 +1660,13 @@ fn get_expr_properties( Ok(ExprProperties { sort_properties: SortProperties::Ordered(column_order.options), range: Interval::make_unbounded(&expr.data_type(schema)?)?, + preserves_lex_ordering: false, }) } else if expr.as_any().downcast_ref::().is_some() { Ok(ExprProperties { sort_properties: SortProperties::Unordered, range: Interval::make_unbounded(&expr.data_type(schema)?)?, + preserves_lex_ordering: false, }) } else if let Some(literal) = expr.as_any().downcast_ref::() { Ok(ExprProperties { @@ -1395,6 +1675,7 @@ fn get_expr_properties( literal.scalar().value().clone(), literal.scalar().value().clone(), )?, + preserves_lex_ordering: true, }) } else { // Find orderings of its children @@ -1436,13 +1717,161 @@ impl DependencyNode { } } -// Using `IndexMap` and `IndexSet` makes sure to generate consistent results across different executions for the same query. -// We could have used `HashSet`, `HashMap` in place of them without any loss of functionality. -// As an example, if existing orderings are `[a ASC, b ASC]`, `[c ASC]` for output ordering -// both `[a ASC, b ASC, c ASC]` and `[c ASC, a ASC, b ASC]` are valid (e.g. concatenated version of the alternative orderings). -// When using `HashSet`, `HashMap` it is not guaranteed to generate consistent result, among the possible 2 results in the example above. -type DependencyMap = IndexMap; -type Dependencies = IndexSet; +impl Display for DependencyNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(target) = &self.target_sort_expr { + write!(f, "(target: {}, ", target)?; + } else { + write!(f, "(")?; + } + write!(f, "dependencies: [{}])", self.dependencies) + } +} + +/// Maps an expression --> DependencyNode +/// +/// # Debugging / deplaying `DependencyMap` +/// +/// This structure implements `Display` to assist debugging. For example: +/// +/// ```text +/// DependencyMap: { +/// a@0 ASC --> (target: a@0 ASC, dependencies: [[]]) +/// b@1 ASC --> (target: b@1 ASC, dependencies: [[a@0 ASC, c@2 ASC]]) +/// c@2 ASC --> (target: c@2 ASC, dependencies: [[b@1 ASC, a@0 ASC]]) +/// d@3 ASC --> (target: d@3 ASC, dependencies: [[c@2 ASC, b@1 ASC]]) +/// } +/// ``` +/// +/// # Note on IndexMap Rationale +/// +/// Using `IndexMap` (which preserves insert order) to ensure consistent results +/// across different executions for the same query. We could have used +/// `HashSet`, `HashMap` in place of them without any loss of functionality. +/// +/// As an example, if existing orderings are +/// 1. `[a ASC, b ASC]` +/// 2. `[c ASC]` for +/// +/// Then both the following output orderings are valid +/// 1. `[a ASC, b ASC, c ASC]` +/// 2. `[c ASC, a ASC, b ASC]` +/// +/// (this are both valid as they are concatenated versions of the alternative +/// orderings). When using `HashSet`, `HashMap` it is not guaranteed to generate +/// consistent result, among the possible 2 results in the example above. 
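The `IndexMap`/`IndexSet` rationale documented above (deterministic iteration order, unlike `HashMap`/`HashSet`) can be seen in isolation. A minimal sketch, not part of the patch, assuming only the `indexmap` crate that this module already uses:

```rust
use std::collections::HashSet;

use indexmap::IndexSet;

fn main() {
    let exprs = ["a ASC", "b ASC", "c ASC", "d ASC"];

    // IndexSet yields elements in insertion order on every run, so the
    // orderings derived from it are stable across executions.
    let ordered: IndexSet<&str> = exprs.iter().copied().collect();
    assert!(ordered.iter().eq(exprs.iter()));

    // HashSet iteration order is unspecified and can change between runs,
    // which would make the generated output orderings non-deterministic.
    let unordered: HashSet<&str> = exprs.iter().copied().collect();
    let _some_order: Vec<&str> = unordered.into_iter().collect();
}
```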
+#[derive(Debug)] +struct DependencyMap { + inner: IndexMap, +} + +impl DependencyMap { + fn new() -> Self { + Self { + inner: IndexMap::new(), + } + } + + /// Insert a new dependency `sort_expr` --> `dependency` into the map. + /// + /// If `target_sort_expr` is none, a new entry is created with empty dependencies. + fn insert( + &mut self, + sort_expr: &PhysicalSortExpr, + target_sort_expr: Option<&PhysicalSortExpr>, + dependency: Option<&PhysicalSortExpr>, + ) { + self.inner + .entry(sort_expr.clone()) + .or_insert_with(|| DependencyNode { + target_sort_expr: target_sort_expr.cloned(), + dependencies: Dependencies::new(), + }) + .insert_dependency(dependency) + } + + /// Iterator over (sort_expr, DependencyNode) pairs + fn iter(&self) -> impl Iterator { + self.inner.iter() + } + + /// iterator over all sort exprs + fn sort_exprs(&self) -> impl Iterator { + self.inner.keys() + } + + /// Return the dependency node for the given sort expression, if any + fn get(&self, sort_expr: &PhysicalSortExpr) -> Option<&DependencyNode> { + self.inner.get(sort_expr) + } +} + +impl Display for DependencyMap { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "DependencyMap: {{")?; + for (sort_expr, node) in self.inner.iter() { + writeln!(f, " {sort_expr} --> {node}")?; + } + writeln!(f, "}}") + } +} + +/// A list of sort expressions that can be calculated from a known set of +/// dependencies. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +struct Dependencies { + inner: IndexSet, +} + +impl Display for Dependencies { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + let mut iter = self.inner.iter(); + if let Some(dep) = iter.next() { + write!(f, "{}", dep)?; + } + for dep in iter { + write!(f, ", {}", dep)?; + } + write!(f, "]") + } +} + +impl Dependencies { + /// Create a new empty `Dependencies` instance. + fn new() -> Self { + Self { + inner: IndexSet::new(), + } + } + + /// Create a new `Dependencies` from an iterator of `PhysicalSortExpr`. + fn new_from_iter(iter: impl IntoIterator) -> Self { + Self { + inner: iter.into_iter().collect(), + } + } + + /// Insert a new dependency into the set. + fn insert(&mut self, sort_expr: PhysicalSortExpr) { + self.inner.insert(sort_expr); + } + + /// Iterator over dependencies in the set + fn iter(&self) -> impl Iterator + Clone { + self.inner.iter() + } + + /// Return the inner set of dependencies + fn into_inner(self) -> IndexSet { + self.inner + } + + /// Returns true if there are no dependencies + fn is_empty(&self) -> bool { + self.inner.is_empty() + } +} /// Contains a mapping of all dependencies we have processed for each sort expr struct DependencyEnumerator<'a> { @@ -1490,17 +1919,18 @@ impl<'a> DependencyEnumerator<'a> { referred_sort_expr: &'a PhysicalSortExpr, dependency_map: &'a DependencyMap, ) -> Vec { - // We are sure that `referred_sort_expr` is inside `dependency_map`. - let node = &dependency_map[referred_sort_expr]; - // Since we work on intermediate nodes, we are sure `val.target_sort_expr` - // exists. + let node = dependency_map + .get(referred_sort_expr) + .expect("`referred_sort_expr` should be inside `dependency_map`"); // Since we work on intermediate nodes, we are sure `val.target_sort_expr` + // exists. let target_sort_expr = node.target_sort_expr.as_ref().unwrap(); // An empty dependency means the referred_sort_expr represents a global ordering. // Return its projected version, which is the target_expression. 
+ // An empty dependency means the referred_sort_expr represents a global ordering. + // Return its projected version, which is the target_expression. if node.dependencies.is_empty() { - return vec![vec![target_sort_expr.clone()]]; + return vec![LexOrdering::new(vec![target_sort_expr.clone()])]; }; - node.dependencies .iter() .flat_map(|dep| { @@ -1668,20 +2098,34 @@ fn calculate_union_binary( } // First, calculate valid constants for the union. An expression is constant - // at the output of the union if it is constant in both sides. + // at the output of the union if it is constant in both sides with matching values. let constants: Vec<_> = lhs .constants() .iter() - .filter(|const_expr| const_exprs_contains(rhs.constants(), const_expr.expr())) - .map(|const_expr| { - // TODO: When both sides have a constant column, and the actual - // constant value is the same, then the output properties could - // reflect the constant is valid across all partitions. However we - // don't track the actual value that the ConstExpr takes on, so we - // can't determine that yet - ConstExpr::new(Arc::clone(const_expr.expr())).with_across_partitions(false) + .filter_map(|lhs_const| { + // Find matching constant expression in RHS + rhs.constants() + .iter() + .find(|rhs_const| rhs_const.expr().eq(lhs_const.expr())) + .map(|rhs_const| { + let mut const_expr = ConstExpr::new(Arc::clone(lhs_const.expr())); + + // If both sides have matching constant values, preserve the value and set across_partitions=true + if let ( + AcrossPartitions::Uniform(Some(lhs_val)), + AcrossPartitions::Uniform(Some(rhs_val)), + ) = (lhs_const.across_partitions(), rhs_const.across_partitions()) + { + if lhs_val == rhs_val { + const_expr = const_expr.with_across_partitions( + AcrossPartitions::Uniform(Some(lhs_val)), + ) + } + } + const_expr + }) }) - .collect(); + .collect::>(); // remove any constants that are shared in both outputs (avoid double counting them) for c in &constants { @@ -1692,16 +2136,8 @@ fn calculate_union_binary( // Next, calculate valid orderings for the union by searching for prefixes // in both sides. let mut orderings = UnionEquivalentOrderingBuilder::new(); - orderings.add_satisfied_orderings( - lhs.normalized_oeq_class().orderings, - lhs.constants(), - &rhs, - ); - orderings.add_satisfied_orderings( - rhs.normalized_oeq_class().orderings, - rhs.constants(), - &lhs, - ); + orderings.add_satisfied_orderings(lhs.normalized_oeq_class(), lhs.constants(), &rhs); + orderings.add_satisfied_orderings(rhs.normalized_oeq_class(), rhs.constants(), &lhs); let orderings = orderings.build(); let mut eq_properties = @@ -1807,7 +2243,7 @@ impl UnionEquivalentOrderingBuilder { ) -> AddedOrdering { if ordering.is_empty() { AddedOrdering::Yes - } else if constants.is_empty() && properties.ordering_satisfy(&ordering) { + } else if constants.is_empty() && properties.ordering_satisfy(ordering.as_ref()) { // If the ordering satisfies the target properties, no need to // augment it with constants. 
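The union constant handling above only keeps a concrete value when both inputs pin the column to the same value. A standalone sketch of that merge rule, using a simplified stand-in for `AcrossPartitions` that stores an `i32` instead of a `ScalarValue`:

```rust
// Simplified stand-in for the real enum; the patch's version stores a ScalarValue.
#[derive(Debug, PartialEq)]
enum AcrossPartitions {
    Uniform(Option<i32>),
    Heterogeneous,
}

// A constant survives the union with its value only if both sides agree.
fn merge(lhs: &AcrossPartitions, rhs: &AcrossPartitions) -> AcrossPartitions {
    match (lhs, rhs) {
        (AcrossPartitions::Uniform(Some(l)), AcrossPartitions::Uniform(Some(r))) if l == r => {
            AcrossPartitions::Uniform(Some(*l))
        }
        // Unknown or mismatched values: still constant within each input,
        // but no single value holds across the union output.
        _ => AcrossPartitions::Heterogeneous,
    }
}

fn main() {
    let same = merge(
        &AcrossPartitions::Uniform(Some(10)),
        &AcrossPartitions::Uniform(Some(10)),
    );
    assert_eq!(same, AcrossPartitions::Uniform(Some(10)));

    let differ = merge(
        &AcrossPartitions::Uniform(Some(10)),
        &AcrossPartitions::Uniform(Some(20)),
    );
    assert_eq!(differ, AcrossPartitions::Heterogeneous);
}
```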
self.orderings.push(ordering); @@ -1840,7 +2276,7 @@ impl UnionEquivalentOrderingBuilder { // for each equivalent ordering in properties, try and augment // `ordering` it with the constants to match - for existing_ordering in &properties.oeq_class.orderings { + for existing_ordering in properties.oeq_class.iter() { if let Some(augmented_ordering) = self.augment_ordering( ordering, constants, @@ -1848,7 +2284,7 @@ impl UnionEquivalentOrderingBuilder { &properties.constants, ) { if !augmented_ordering.is_empty() { - assert!(properties.ordering_satisfy(&augmented_ordering)); + assert!(properties.ordering_satisfy(augmented_ordering.as_ref())); self.orderings.push(augmented_ordering); } } @@ -1868,7 +2304,7 @@ impl UnionEquivalentOrderingBuilder { existing_ordering: &LexOrdering, existing_constants: &[ConstExpr], ) -> Option { - let mut augmented_ordering = vec![]; + let mut augmented_ordering = LexOrdering::default(); let mut sort_expr_iter = ordering.iter().peekable(); let mut existing_sort_expr_iter = existing_ordering.iter().peekable(); @@ -1947,16 +2383,17 @@ mod tests { use crate::equivalence::add_offset_to_expr; use crate::equivalence::tests::{ convert_to_orderings, convert_to_sort_exprs, convert_to_sort_reqs, - create_random_schema, create_test_params, create_test_schema, - generate_table_for_eq_properties, is_table_same_after_sort, output_schema, + create_test_params, create_test_schema, output_schema, }; use crate::expressions::{col, BinaryExpr, Column}; - use crate::utils::tests::TestScalarUDF; + use crate::ScalarFunctionExpr; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{Fields, TimeUnit}; - use datafusion_common::DFSchema; - use datafusion_expr::{Operator, ScalarUDF}; + use datafusion_common::{Constraint, ScalarValue}; + use datafusion_expr::Operator; + + use datafusion_functions::string::concat; #[test] fn project_equivalence_properties_test() -> Result<()> { @@ -1997,7 +2434,7 @@ mod tests { // At the output a1=a2=a3=a4 assert_eq!(out_properties.eq_group().len(), 1); - let eq_class = &out_properties.eq_group().classes[0]; + let eq_class = out_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_class.len(), 4); assert!(eq_class.contains(col_a1)); assert!(eq_class.contains(col_a2)); @@ -2019,20 +2456,20 @@ mod tests { let mut input_properties = EquivalenceProperties::new(Arc::clone(&input_schema)); // add equivalent ordering [a, b, c, d] - input_properties.add_new_ordering(vec![ + input_properties.add_new_ordering(LexOrdering::new(vec![ parse_sort_expr("a", &input_schema), parse_sort_expr("b", &input_schema), parse_sort_expr("c", &input_schema), parse_sort_expr("d", &input_schema), - ]); + ])); // add equivalent ordering [a, c, b, d] - input_properties.add_new_ordering(vec![ + input_properties.add_new_ordering(LexOrdering::new(vec![ parse_sort_expr("a", &input_schema), parse_sort_expr("c", &input_schema), parse_sort_expr("b", &input_schema), // NB b and c are swapped parse_sort_expr("d", &input_schema), - ]); + ])); // simply project all the columns in order let proj_exprs = vec![ @@ -2046,7 +2483,7 @@ mod tests { assert_eq!( out_properties.to_string(), - "order: [[a@0 ASC,c@2 ASC,b@1 ASC,d@3 ASC], [a@0 ASC,b@1 ASC,c@2 ASC,d@3 ASC]]" + "order: [[a@0 ASC, c@2 ASC, b@1 ASC, d@3 ASC], [a@0 ASC, b@1 ASC, c@2 ASC, d@3 ASC]]" ); Ok(()) @@ -2120,17 +2557,12 @@ mod tests { Some(JoinSide::Left), &[], ); - let orderings = &join_eq.oeq_class.orderings; - let err_msg = format!("expected: {:?}, actual:{:?}", expected, orderings); - assert_eq!( - 
join_eq.oeq_class.orderings.len(), - expected.len(), - "{}", - err_msg - ); - for ordering in orderings { + let err_msg = + format!("expected: {:?}, actual:{:?}", expected, &join_eq.oeq_class); + assert_eq!(join_eq.oeq_class.len(), expected.len(), "{}", err_msg); + for ordering in join_eq.oeq_class { assert!( - expected.contains(ordering), + expected.contains(&ordering), "{}, ordering: {:?}", err_msg, ordering @@ -2252,27 +2684,27 @@ mod tests { eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr)?; let others = vec![ - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::clone(&col_b_expr), options: sort_options, - }], - vec![PhysicalSortExpr { + }]), + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::clone(&col_c_expr), options: sort_options, - }], + }]), ]; eq_properties.add_new_orderings(others); let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); expected_eqs.add_new_orderings([ - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::clone(&col_b_expr), options: sort_options, - }], - vec![PhysicalSortExpr { + }]), + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::clone(&col_c_expr), options: sort_options, - }], + }]), ]); let oeq_class = eq_properties.oeq_class().clone(); @@ -2295,7 +2727,7 @@ mod tests { let col_b = &col("b", &schema)?; let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([vec![ + eq_properties.add_new_orderings([LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("b", 1)), options: sort_options_not, @@ -2304,12 +2736,12 @@ mod tests { expr: Arc::new(Column::new("a", 0)), options: sort_options, }, - ]]); + ])]); let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); assert_eq!(idxs, vec![0, 1]); assert_eq!( result, - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::clone(col_b), options: sort_options_not @@ -2318,7 +2750,7 @@ mod tests { expr: Arc::clone(col_a), options: sort_options } - ] + ]) ); let schema = Schema::new(vec![ @@ -2331,11 +2763,11 @@ mod tests { let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); eq_properties.add_new_orderings([ - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("c", 2)), options: sort_options, - }], - vec![ + }]), + LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("b", 1)), options: sort_options_not, @@ -2344,13 +2776,13 @@ mod tests { expr: Arc::new(Column::new("a", 0)), options: sort_options, }, - ], + ]), ]); let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); assert_eq!(idxs, vec![0, 1]); assert_eq!( result, - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::clone(col_b), options: sort_options_not @@ -2359,7 +2791,7 @@ mod tests { expr: Arc::clone(col_a), options: sort_options } - ] + ]) ); let required_columns = [ @@ -2374,7 +2806,7 @@ mod tests { let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); // not satisfied orders - eq_properties.add_new_orderings([vec![ + eq_properties.add_new_orderings([LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("b", 1)), options: sort_options_not, @@ -2387,7 +2819,7 @@ mod tests { expr: Arc::new(Column::new("a", 0)), options: sort_options, }, - ]]); + ])]); let (_, idxs) = 
eq_properties.find_longest_permutation(&required_columns); assert_eq!(idxs, vec![0]); @@ -2416,14 +2848,14 @@ mod tests { eq_properties.add_equal_conditions(col_b, col_a)?; // [b ASC], [d ASC] eq_properties.add_new_orderings(vec![ - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::clone(col_b), options: option_asc, - }], - vec![PhysicalSortExpr { + }]), + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::clone(col_d), options: option_asc, - }], + }]), ]); let test_cases = vec![ @@ -2467,83 +2899,6 @@ mod tests { Ok(()) } - #[test] - fn test_find_longest_permutation_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 100; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = crate::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let exprs = [ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - floor_a, - a_plus_b, - ]; - - for n_req in 0..=exprs.len() { - for exprs in exprs.iter().combinations(n_req) { - let exprs = exprs.into_iter().cloned().collect::>(); - let (ordering, indices) = - eq_properties.find_longest_permutation(&exprs); - // Make sure that find_longest_permutation return values are consistent - let ordering2 = indices - .iter() - .zip(ordering.iter()) - .map(|(&idx, sort_expr)| PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options: sort_expr.options, - }) - .collect::>(); - assert_eq!( - ordering, ordering2, - "indices and lexicographical ordering do not match" - ); - - let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - assert_eq!(ordering.len(), indices.len(), "{}", err_msg); - // Since ordered section satisfies schema, we expect - // that result will be same after sort (e.g sort was unnecessary). - assert!( - is_table_same_after_sort( - ordering.clone(), - table_data_with_properties.clone(), - )?, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } #[test] fn test_find_longest_permutation() -> Result<()> { // Schema satisfies following orderings: @@ -2575,7 +2930,7 @@ mod tests { nulls_first: true, }; // [d ASC, h DESC] also satisfies schema. 
- eq_properties.add_new_orderings([vec![ + eq_properties.add_new_orderings([LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::clone(col_d), options: option_asc, @@ -2584,7 +2939,7 @@ mod tests { expr: Arc::clone(col_h), options: option_desc, }, - ]]); + ])]); let test_cases = vec![ // TEST CASE 1 (vec![col_a], vec![(col_a, option_asc)]), @@ -2866,7 +3221,7 @@ mod tests { Field::new("c", DataType::Timestamp(TimeUnit::Nanosecond, None), true), ])); let base_properties = EquivalenceProperties::new(Arc::clone(&schema)) - .with_reorder( + .with_reorder(LexOrdering::new( ["a", "b", "c"] .into_iter() .map(|c| { @@ -2879,7 +3234,7 @@ mod tests { }) }) .collect::>>()?, - ); + )); struct TestCase { name: &'static str, @@ -2968,10 +3323,10 @@ mod tests { options: SortOptions::default(), }) }) - .collect::>>()?; + .collect::>()?; assert_eq!( - properties.ordering_satisfy(&sort), + properties.ordering_satisfy(sort.as_ref()), case.should_satisfy_ordering, "failed test '{}'", case.name @@ -3490,7 +3845,7 @@ mod tests { ordering .iter() .map(|name| parse_sort_expr(name, schema)) - .collect::>() + .collect::() }) .collect::>(); @@ -3526,8 +3881,8 @@ mod tests { // Check whether orderings are same. let lhs_orderings = lhs.oeq_class(); - let rhs_orderings = &rhs.oeq_class.orderings; - for rhs_ordering in rhs_orderings { + let rhs_orderings = rhs.oeq_class(); + for rhs_ordering in rhs_orderings.iter() { assert!( lhs_orderings.contains(rhs_ordering), "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" @@ -3571,4 +3926,604 @@ mod tests { sort_expr } + + #[test] + fn test_ordering_equivalence_with_lex_monotonic_concat() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + ])); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + + let a_concat_b: Arc = Arc::new(ScalarFunctionExpr::new( + "concat", + concat(), + vec![Arc::clone(&col_a), Arc::clone(&col_b)], + DataType::Utf8, + )); + + // Assume existing ordering is [c ASC, a ASC, b ASC] + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + eq_properties.add_new_ordering(LexOrdering::from(vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_c)).asc(), + PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), + PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), + ])); + + // Add equality condition c = concat(a, b) + eq_properties.add_equal_conditions(&col_c, &a_concat_b)?; + + let orderings = eq_properties.oeq_class(); + + let expected_ordering1 = + LexOrdering::from(vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_c)).asc() + ]); + let expected_ordering2 = LexOrdering::from(vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), + PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), + ]); + + // The ordering should be [c ASC] and [a ASC, b ASC] + assert_eq!(orderings.len(), 2); + assert!(orderings.contains(&expected_ordering1)); + assert!(orderings.contains(&expected_ordering2)); + + Ok(()) + } + + #[test] + fn test_ordering_equivalence_with_non_lex_monotonic_multiply() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + + let a_times_b: Arc = Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + 
Operator::Multiply, + Arc::clone(&col_b), + )); + + // Assume existing ordering is [c ASC, a ASC, b ASC] + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + let initial_ordering = LexOrdering::from(vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_c)).asc(), + PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), + PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), + ]); + + eq_properties.add_new_ordering(initial_ordering.clone()); + + // Add equality condition c = a * b + eq_properties.add_equal_conditions(&col_c, &a_times_b)?; + + let orderings = eq_properties.oeq_class(); + + // The ordering should remain unchanged since multiplication is not lex-monotonic + assert_eq!(orderings.len(), 1); + assert!(orderings.contains(&initial_ordering)); + + Ok(()) + } + + #[test] + fn test_ordering_equivalence_with_concat_equality() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + ])); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + + let a_concat_b: Arc = Arc::new(ScalarFunctionExpr::new( + "concat", + concat(), + vec![Arc::clone(&col_a), Arc::clone(&col_b)], + DataType::Utf8, + )); + + // Assume existing ordering is [concat(a, b) ASC, a ASC, b ASC] + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + eq_properties.add_new_ordering(LexOrdering::from(vec![ + PhysicalSortExpr::new_default(Arc::clone(&a_concat_b)).asc(), + PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), + PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), + ])); + + // Add equality condition c = concat(a, b) + eq_properties.add_equal_conditions(&col_c, &a_concat_b)?; + + let orderings = eq_properties.oeq_class(); + + let expected_ordering1 = LexOrdering::from(vec![PhysicalSortExpr::new_default( + Arc::clone(&a_concat_b), + ) + .asc()]); + let expected_ordering2 = LexOrdering::from(vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), + PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), + ]); + + // The ordering should be [concat(a, b) ASC] and [a ASC, b ASC] + assert_eq!(orderings.len(), 2); + assert!(orderings.contains(&expected_ordering1)); + assert!(orderings.contains(&expected_ordering2)); + + Ok(()) + } + + #[test] + fn test_with_reorder_constant_filtering() -> Result<()> { + let schema = create_test_schema()?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + // Setup constant columns + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + eq_properties = eq_properties.with_constants([ConstExpr::from(&col_a)]); + + let sort_exprs = LexOrdering::new(vec![ + PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::clone(&col_b), + options: SortOptions::default(), + }, + ]); + + let result = eq_properties.with_reorder(sort_exprs); + + // Should only contain b since a is constant + assert_eq!(result.oeq_class().len(), 1); + let ordering = result.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.len(), 1); + assert!(ordering[0].expr.eq(&col_b)); + + Ok(()) + } + + #[test] + fn test_with_reorder_preserve_suffix() -> Result<()> { + let schema = create_test_schema()?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", 
&schema)?; + + let asc = SortOptions::default(); + let desc = SortOptions { + descending: true, + nulls_first: true, + }; + + // Initial ordering: [a ASC, b DESC, c ASC] + eq_properties.add_new_orderings([LexOrdering::new(vec![ + PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: asc, + }, + PhysicalSortExpr { + expr: Arc::clone(&col_b), + options: desc, + }, + PhysicalSortExpr { + expr: Arc::clone(&col_c), + options: asc, + }, + ])]); + + // New ordering: [a ASC] + let new_order = LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: asc, + }]); + + let result = eq_properties.with_reorder(new_order); + + // Should only contain [a ASC, b DESC, c ASC] + assert_eq!(result.oeq_class().len(), 1); + let ordering = result.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.len(), 3); + assert!(ordering[0].expr.eq(&col_a)); + assert!(ordering[0].options.eq(&asc)); + assert!(ordering[1].expr.eq(&col_b)); + assert!(ordering[1].options.eq(&desc)); + assert!(ordering[2].expr.eq(&col_c)); + assert!(ordering[2].options.eq(&asc)); + + Ok(()) + } + + #[test] + fn test_with_reorder_equivalent_expressions() -> Result<()> { + let schema = create_test_schema()?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + + // Make a and b equivalent + eq_properties.add_equal_conditions(&col_a, &col_b)?; + + let asc = SortOptions::default(); + + // Initial ordering: [a ASC, c ASC] + eq_properties.add_new_orderings([LexOrdering::new(vec![ + PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: asc, + }, + PhysicalSortExpr { + expr: Arc::clone(&col_c), + options: asc, + }, + ])]); + + // New ordering: [b ASC] + let new_order = LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_b), + options: asc, + }]); + + let result = eq_properties.with_reorder(new_order); + + // Should only contain [b ASC, c ASC] + assert_eq!(result.oeq_class().len(), 1); + + // Verify orderings + let ordering = result.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.len(), 2); + assert!(ordering[0].expr.eq(&col_b)); + assert!(ordering[0].options.eq(&asc)); + assert!(ordering[1].expr.eq(&col_c)); + assert!(ordering[1].options.eq(&asc)); + + Ok(()) + } + + #[test] + fn test_with_reorder_incompatible_prefix() -> Result<()> { + let schema = create_test_schema()?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + + let asc = SortOptions::default(); + let desc = SortOptions { + descending: true, + nulls_first: true, + }; + + // Initial ordering: [a ASC, b DESC] + eq_properties.add_new_orderings([LexOrdering::new(vec![ + PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: asc, + }, + PhysicalSortExpr { + expr: Arc::clone(&col_b), + options: desc, + }, + ])]); + + // New ordering: [a DESC] + let new_order = LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: desc, + }]); + + let result = eq_properties.with_reorder(new_order.clone()); + + // Should only contain the new ordering since options don't match + assert_eq!(result.oeq_class().len(), 1); + let ordering = result.oeq_class().iter().next().unwrap(); + assert_eq!(ordering, &new_order); + + Ok(()) + } + + #[test] + fn test_with_reorder_comprehensive() -> Result<()> { + let schema = create_test_schema()?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + 
let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + let col_d = col("d", &schema)?; + let col_e = col("e", &schema)?; + + let asc = SortOptions::default(); + + // Constants: c is constant + eq_properties = eq_properties.with_constants([ConstExpr::from(&col_c)]); + + // Equality: b = d + eq_properties.add_equal_conditions(&col_b, &col_d)?; + + // Orderings: [d ASC, a ASC], [e ASC] + eq_properties.add_new_orderings([ + LexOrdering::new(vec![ + PhysicalSortExpr { + expr: Arc::clone(&col_d), + options: asc, + }, + PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: asc, + }, + ]), + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_e), + options: asc, + }]), + ]); + + // Initial ordering: [b ASC, c ASC] + let new_order = LexOrdering::new(vec![ + PhysicalSortExpr { + expr: Arc::clone(&col_b), + options: asc, + }, + PhysicalSortExpr { + expr: Arc::clone(&col_c), + options: asc, + }, + ]); + + let result = eq_properties.with_reorder(new_order); + + // Should preserve the original [d ASC, a ASC] ordering + assert_eq!(result.oeq_class().len(), 1); + let ordering = result.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.len(), 2); + + // First expression should be either b or d (they're equivalent) + assert!( + ordering[0].expr.eq(&col_b) || ordering[0].expr.eq(&col_d), + "Expected b or d as first expression, got {:?}", + ordering[0].expr + ); + assert!(ordering[0].options.eq(&asc)); + + // Second expression should be a + assert!(ordering[1].expr.eq(&col_a)); + assert!(ordering[1].options.eq(&asc)); + + Ok(()) + } + + #[test] + fn test_union_constant_value_preservation() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let col_a = col("a", &schema)?; + let literal_10 = ScalarValue::Int32(Some(10)); + + // Create first input with a=10 + let const_expr1 = ConstExpr::new(Arc::clone(&col_a)) + .with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone()))); + let input1 = EquivalenceProperties::new(Arc::clone(&schema)) + .with_constants(vec![const_expr1]); + + // Create second input with a=10 + let const_expr2 = ConstExpr::new(Arc::clone(&col_a)) + .with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone()))); + let input2 = EquivalenceProperties::new(Arc::clone(&schema)) + .with_constants(vec![const_expr2]); + + // Calculate union properties + let union_props = calculate_union(vec![input1, input2], schema)?; + + // Verify column 'a' remains constant with value 10 + let const_a = &union_props.constants()[0]; + assert!(const_a.expr().eq(&col_a)); + assert_eq!( + const_a.across_partitions(), + AcrossPartitions::Uniform(Some(literal_10)) + ); + + Ok(()) + } + + #[test] + fn test_ordering_satisfaction_with_key_constraints() -> Result<()> { + let pk_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + ])); + + let unique_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + ])); + + // Test cases to run + let test_cases = vec![ + // (name, schema, constraint, base_ordering, satisfied_orderings, unsatisfied_orderings) + ( + "single column primary key", + &pk_schema, + 
vec![Constraint::PrimaryKey(vec![0])], + vec!["a"], // base ordering + vec![vec!["a", "b"], vec!["a", "c", "d"]], + vec![vec!["b", "a"], vec!["c", "a"]], + ), + ( + "single column unique", + &unique_schema, + vec![Constraint::Unique(vec![0])], + vec!["a"], // base ordering + vec![vec!["a", "b"], vec!["a", "c", "d"]], + vec![vec!["b", "a"], vec!["c", "a"]], + ), + ( + "multi-column primary key", + &pk_schema, + vec![Constraint::PrimaryKey(vec![0, 1])], + vec!["a", "b"], // base ordering + vec![vec!["a", "b", "c"], vec!["a", "b", "d"]], + vec![vec!["b", "a"], vec!["a", "c", "b"]], + ), + ( + "multi-column unique", + &unique_schema, + vec![Constraint::Unique(vec![0, 1])], + vec!["a", "b"], // base ordering + vec![vec!["a", "b", "c"], vec!["a", "b", "d"]], + vec![vec!["b", "a"], vec!["c", "a", "b"]], + ), + ( + "nullable unique", + &unique_schema, + vec![Constraint::Unique(vec![2, 3])], + vec!["c", "d"], // base ordering + vec![], + vec![vec!["c", "d", "a"]], + ), + ( + "ordering with arbitrary column unique", + &unique_schema, + vec![Constraint::Unique(vec![0, 1])], + vec!["a", "c", "b"], // base ordering + vec![vec!["a", "c", "b", "d"]], + vec![vec!["a", "b", "d"]], + ), + ( + "ordering with arbitrary column pk", + &pk_schema, + vec![Constraint::PrimaryKey(vec![0, 1])], + vec!["a", "c", "b"], // base ordering + vec![vec!["a", "c", "b", "d"]], + vec![vec!["a", "b", "d"]], + ), + ( + "ordering with arbitrary column pk complex", + &pk_schema, + vec![Constraint::PrimaryKey(vec![3, 1])], + vec!["b", "a", "d"], // base ordering + vec![vec!["b", "a", "d", "c"]], + vec![vec!["b", "c", "d", "a"], vec!["b", "a", "c", "d"]], + ), + ]; + + for ( + name, + schema, + constraints, + base_order, + satisfied_orders, + unsatisfied_orders, + ) in test_cases + { + let mut eq_properties = EquivalenceProperties::new(Arc::clone(schema)); + + // Convert base ordering + let base_ordering = LexOrdering::new( + base_order + .iter() + .map(|col_name| PhysicalSortExpr { + expr: col(col_name, schema).unwrap(), + options: SortOptions::default(), + }) + .collect(), + ); + + // Convert string column names to orderings + let satisfied_orderings: Vec = satisfied_orders + .iter() + .map(|cols| { + LexOrdering::new( + cols.iter() + .map(|col_name| PhysicalSortExpr { + expr: col(col_name, schema).unwrap(), + options: SortOptions::default(), + }) + .collect(), + ) + }) + .collect(); + + let unsatisfied_orderings: Vec = unsatisfied_orders + .iter() + .map(|cols| { + LexOrdering::new( + cols.iter() + .map(|col_name| PhysicalSortExpr { + expr: col(col_name, schema).unwrap(), + options: SortOptions::default(), + }) + .collect(), + ) + }) + .collect(); + + // Test that orderings are not satisfied before adding constraints + for ordering in &satisfied_orderings { + assert!( + !eq_properties.ordering_satisfy(ordering), + "{}: ordering {:?} should not be satisfied before adding constraints", + name, + ordering + ); + } + + // Add base ordering + eq_properties.add_new_ordering(base_ordering); + + // Add constraints + eq_properties = + eq_properties.with_constraints(Constraints::new_unverified(constraints)); + + // Test that expected orderings are now satisfied + for ordering in &satisfied_orderings { + assert!( + eq_properties.ordering_satisfy(ordering), + "{}: ordering {:?} should be satisfied after adding constraints", + name, + ordering + ); + } + + // Test that unsatisfied orderings remain unsatisfied + for ordering in &unsatisfied_orderings { + assert!( + !eq_properties.ordering_satisfy(ordering), + "{}: ordering {:?} should not be 
satisfied after adding constraints", + name, + ordering + ); + } + } + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index be58178511bbf..d4cfd90530b12 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -17,11 +17,10 @@ mod kernels; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::{any::Any, sync::Arc}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; -use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::array::*; @@ -48,7 +47,7 @@ use kernels::{ }; /// Binary expression -#[derive(Debug, Hash, Clone)] +#[derive(Debug, Clone, Eq)] pub struct BinaryExpr { left: Arc, op: Operator, @@ -57,6 +56,24 @@ pub struct BinaryExpr { fail_on_overflow: bool, } +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 +impl PartialEq for BinaryExpr { + fn eq(&self, other: &Self) -> bool { + self.left.eq(&other.left) + && self.op.eq(&other.op) + && self.right.eq(&other.right) + && self.fail_on_overflow.eq(&other.fail_on_overflow) + } +} +impl Hash for BinaryExpr { + fn hash(&self, state: &mut H) { + self.left.hash(state); + self.op.hash(state); + self.right.hash(state); + self.fail_on_overflow.hash(state); + } +} + impl BinaryExpr { /// Create new binary expression pub fn new( @@ -186,7 +203,9 @@ macro_rules! compute_utf8_flag_op { } macro_rules! binary_string_array_flag_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ + ($LEFT:ident, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ + // This macro is slightly different from binary_string_array_flag_op because, when comparing with a scalar value, + // the query can be optimized in such a way that operands will be dicts, so we need to support it here let result: Result> = match $LEFT.data_type() { DataType::Utf8View | DataType::Utf8 => { compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG) @@ -194,6 +213,27 @@ macro_rules! binary_string_array_flag_op_scalar { DataType::LargeUtf8 => { compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG) }, + DataType::Dictionary(_, _) => { + let values = $LEFT.as_any_dictionary().values(); + + match values.data_type() { + DataType::Utf8View | DataType::Utf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, StringArray, $NOT, $FLAG), + DataType::LargeUtf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG), + other => internal_err!( + "Data type {:?} not supported as a dictionary value type for binary_string_array_flag_op_scalar operation '{}' on string array", + other, stringify!($OP) + ), + }.map( + // downcast_dictionary_array duplicates code per possible key type, so we aim to do all prep work before + |evaluated_values| downcast_dictionary_array! { + $LEFT => { + let unpacked_dict = evaluated_values.take_iter($LEFT.keys().iter().map(|opt| opt.map(|v| v as _))).collect::(); + Arc::new(unpacked_dict) as _ + }, + _ => unreachable!(), + } + ) + }, other => internal_err!( "Data type {:?} not supported for binary_string_array_flag_op_scalar operation '{}' on string array", other, stringify!($OP) @@ -211,20 +251,23 @@ macro_rules! 
compute_utf8_flag_op_scalar { .downcast_ref::<$ARRAYTYPE>() .expect("compute_utf8_flag_op_scalar failed to downcast array"); - if let ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) = $RIGHT { - let flag = $FLAG.then_some("i"); - let mut array = - paste::expr! {[<$OP _scalar>]}(ll, &string_value, flag)?; - if $NOT { - array = not(&array).unwrap(); - } - Ok(Arc::new(array)) - } else { - internal_err!( - "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'", - $RIGHT, stringify!($OP) - ) + let string_value = match $RIGHT.try_as_str() { + Some(Some(string_value)) => string_value, + // null literal or non string + _ => return internal_err!( + "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'", + $RIGHT, stringify!($OP) + ) + }; + + let flag = $FLAG.then_some("i"); + let mut array = + paste::expr! {[<$OP _scalar>]}(ll, &string_value, flag)?; + if $NOT { + array = not(&array).unwrap(); } + + Ok(Arc::new(array)) }}; } @@ -355,7 +398,7 @@ impl PhysicalExpr for BinaryExpr { if self.op.eq(&Operator::And) { if interval.eq(&Interval::CERTAINLY_TRUE) { // A certainly true logical conjunction can only derive from possibly - // true operands. Otherwise, we prove infeasability. + // true operands. Otherwise, we prove infeasibility. Ok((!left_interval.eq(&Interval::CERTAINLY_FALSE) && !right_interval.eq(&Interval::CERTAINLY_FALSE)) .then(|| vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_TRUE])) @@ -395,7 +438,7 @@ impl PhysicalExpr for BinaryExpr { } else if self.op.eq(&Operator::Or) { if interval.eq(&Interval::CERTAINLY_FALSE) { // A certainly false logical conjunction can only derive from certainly - // false operands. Otherwise, we prove infeasability. + // false operands. Otherwise, we prove infeasibility. Ok((!left_interval.eq(&Interval::CERTAINLY_TRUE) && !right_interval.eq(&Interval::CERTAINLY_TRUE)) .then(|| vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE])) @@ -432,7 +475,7 @@ impl PhysicalExpr for BinaryExpr { // end-points of its children. Ok(Some(vec![])) } - } else if self.op.is_comparison_operator() { + } else if self.op.supports_propagation() { Ok( propagate_comparison(&self.op, interval, left_interval, right_interval)? .map(|(left, right)| vec![left, right]), @@ -445,11 +488,6 @@ impl PhysicalExpr for BinaryExpr { } } - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } - /// For each operator, [`BinaryExpr`] has distinct rules. /// TODO: There may be rules specific to some data types and expression ranges. 
fn get_properties(&self, children: &[ExprProperties]) -> Result { @@ -459,54 +497,48 @@ impl PhysicalExpr for BinaryExpr { Operator::Plus => Ok(ExprProperties { sort_properties: l_order.add(&r_order), range: l_range.add(r_range)?, + preserves_lex_ordering: false, }), Operator::Minus => Ok(ExprProperties { sort_properties: l_order.sub(&r_order), range: l_range.sub(r_range)?, + preserves_lex_ordering: false, }), Operator::Gt => Ok(ExprProperties { sort_properties: l_order.gt_or_gteq(&r_order), range: l_range.gt(r_range)?, + preserves_lex_ordering: false, }), Operator::GtEq => Ok(ExprProperties { sort_properties: l_order.gt_or_gteq(&r_order), range: l_range.gt_eq(r_range)?, + preserves_lex_ordering: false, }), Operator::Lt => Ok(ExprProperties { sort_properties: r_order.gt_or_gteq(&l_order), range: l_range.lt(r_range)?, + preserves_lex_ordering: false, }), Operator::LtEq => Ok(ExprProperties { sort_properties: r_order.gt_or_gteq(&l_order), range: l_range.lt_eq(r_range)?, + preserves_lex_ordering: false, }), Operator::And => Ok(ExprProperties { sort_properties: r_order.and_or(&l_order), range: l_range.and(r_range)?, + preserves_lex_ordering: false, }), Operator::Or => Ok(ExprProperties { sort_properties: r_order.and_or(&l_order), range: l_range.or(r_range)?, + preserves_lex_ordering: false, }), _ => Ok(ExprProperties::new_unknown()), } } } -impl PartialEq for BinaryExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.left.eq(&x.left) - && self.op == x.op - && self.right.eq(&x.right) - && self.fail_on_overflow.eq(&x.fail_on_overflow) - }) - .unwrap_or(false) - } -} - /// Casts dictionary array to result type for binary numerical operators. Such operators /// between array and scalar produce a dictionary array other than primitive array of the /// same operators between array and array. This leads to inconsistent result types causing diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 89e9d990e6a69..19dbb979db716 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -16,11 +16,10 @@ // under the License. use std::borrow::Cow; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::{any::Any, sync::Arc}; use crate::expressions::try_cast; -use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::array::*; @@ -28,7 +27,9 @@ use arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::as_boolean_array; -use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::ColumnarValue; use super::{Column, Literal}; @@ -37,7 +38,7 @@ use itertools::Itertools; type WhenThen = (Arc, Arc); -#[derive(Debug, Hash)] +#[derive(Debug, Hash, PartialEq, Eq)] enum EvalMethod { /// CASE WHEN condition THEN result /// [WHEN ...] 
@@ -61,6 +62,11 @@ enum EvalMethod { /// are literal values /// CASE WHEN condition THEN literal ELSE literal END ScalarOrScalar, + /// This is a specialization for a specific use case where we can take a fast path + /// if there is just one when/then pair and both the `then` and `else` are expressions + /// + /// CASE WHEN condition THEN expression ELSE expression END + ExpressionOrExpression, } /// The CASE expression is similar to a series of nested if/else and there are two forms that @@ -80,7 +86,7 @@ enum EvalMethod { /// [WHEN ...] /// [ELSE result] /// END -#[derive(Debug, Hash)] +#[derive(Debug, Hash, PartialEq, Eq)] pub struct CaseExpr { /// Optional base expression that can be compared to literal values in the "when" expressions expr: Option>, @@ -150,6 +156,8 @@ impl CaseExpr { && else_expr.as_ref().unwrap().as_any().is::() { EvalMethod::ScalarOrScalar + } else if when_then_expr.len() == 1 && else_expr.is_some() { + EvalMethod::ExpressionOrExpression } else { EvalMethod::NoExpression }; @@ -241,10 +249,9 @@ impl CaseExpr { remainder = and_not(&remainder, &when_match)?; } - if let Some(e) = &self.else_expr { + if let Some(e) = self.else_expr() { // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; // null and unmatched tuples should be assigned else value remainder = or(&base_nulls, &remainder)?; let else_ = expr @@ -274,11 +281,8 @@ impl CaseExpr { .0 .evaluate_selection(batch, &remainder)?; let when_value = when_value.into_array(batch.num_rows())?; - let when_value = as_boolean_array(&when_value).map_err(|e| { - DataFusionError::Context( - "WHEN expression did not return a BooleanArray".to_string(), - Box::new(e), - ) + let when_value = as_boolean_array(&when_value).map_err(|_| { + internal_datafusion_err!("WHEN expression did not return a BooleanArray") })?; // Treat 'NULL' as false value let when_value = match when_value.null_count() { @@ -312,10 +316,9 @@ impl CaseExpr { remainder = and_not(&remainder, &when_value)?; } - if let Some(e) = &self.else_expr { + if let Some(e) = self.else_expr() { // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_ = expr .evaluate_selection(batch, &remainder)? .into_array(batch.num_rows())?; @@ -343,7 +346,10 @@ impl CaseExpr { .downcast_ref::() .expect("predicate should evaluate to a boolean array"); // invert the bitmask - let bit_mask = not(bit_mask)?; + let bit_mask = match bit_mask.null_count() { + 0 => not(bit_mask)?, + _ => not(&prep_null_mask_filter(bit_mask))?, + }; match then_expr.evaluate(batch)? 
{ ColumnarValue::Array(array) => { Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?)) @@ -363,6 +369,35 @@ impl CaseExpr { // evaluate when expression let when_value = self.when_then_expr[0].0.evaluate(batch)?; let when_value = when_value.into_array(batch.num_rows())?; + let when_value = as_boolean_array(&when_value).map_err(|_| { + internal_datafusion_err!("WHEN expression did not return a BooleanArray") + })?; + + // Treat 'NULL' as false value + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), + }; + + // evaluate then_value + let then_value = self.when_then_expr[0].1.evaluate(batch)?; + let then_value = Scalar::new(then_value.into_array(1)?); + + let Some(e) = self.else_expr() else { + return internal_err!("expression did not evaluate to an array"); + }; + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?; + let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?); + Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) + } + + fn expr_or_expr(&self, batch: &RecordBatch) -> Result { + let return_type = self.data_type(&batch.schema())?; + + // evalute when condition on batch + let when_value = self.when_then_expr[0].0.evaluate(batch)?; + let when_value = when_value.into_array(batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|e| { DataFusionError::Context( "WHEN expression did not return a BooleanArray".to_string(), @@ -376,17 +411,22 @@ impl CaseExpr { _ => Cow::Owned(prep_null_mask_filter(when_value)), }; - // evaluate then_value - let then_value = self.when_then_expr[0].1.evaluate(batch)?; - let then_value = Scalar::new(then_value.into_array(1)?); + let then_value = self.when_then_expr[0] + .1 + .evaluate_selection(batch, &when_value)? + .into_array(batch.num_rows())?; - // keep `else_expr`'s data type and return type consistent + // evaluate else expression on the values not covered by when_value + let remainder = not(&when_value)?; let e = self.else_expr.as_ref().unwrap(); - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type) + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) .unwrap_or_else(|_| Arc::clone(e)); - let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?); + let else_ = expr + .evaluate_selection(batch, &remainder)? 
+ .into_array(batch.num_rows())?; - Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) + Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) } } @@ -451,6 +491,7 @@ impl PhysicalExpr for CaseExpr { self.case_column_or_null(batch) } EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), + EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch), } } @@ -502,39 +543,6 @@ impl PhysicalExpr for CaseExpr { )?)) } } - - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } -} - -impl PartialEq for CaseExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - let expr_eq = match (&self.expr, &x.expr) { - (Some(expr1), Some(expr2)) => expr1.eq(expr2), - (None, None) => true, - _ => false, - }; - let else_expr_eq = match (&self.else_expr, &x.else_expr) { - (Some(expr1), Some(expr2)) => expr1.eq(expr2), - (None, None) => true, - _ => false, - }; - expr_eq - && else_expr_eq - && self.when_then_expr.len() == x.when_then_expr.len() - && self.when_then_expr.iter().zip(x.when_then_expr.iter()).all( - |((when1, then1), (when2, then2))| { - when1.eq(when2) && then1.eq(then2) - }, - ) - }) - .unwrap_or(false) - } } /// Create a CASE expression @@ -915,6 +923,32 @@ mod tests { Ok(()) } + #[test] + fn test_when_null_and_some_cond_else_null() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + let when = binary( + Arc::new(Literal::from(ScalarValue::Boolean(None))), + Operator::And, + binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?, + &schema, + )?; + let then = col("a", &schema)?; + + // SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END + let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?); + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_string_array(&result); + + // all result values should be null + assert_eq!(result.logical_null_count(), batch.num_rows()); + Ok(()) + } + fn case_test_batch() -> Result { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]); @@ -1092,16 +1126,15 @@ mod tests { let expr2 = Arc::clone(&expr) .transform(|e| { - let transformed = - match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.scalar().value() { - ScalarValue::Utf8(Some(str_value)) => { - Some(lit(str_value.to_uppercase())) - } - _ => None, - }, + let transformed = match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.scalar().value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } _ => None, - }; + }, + _ => None, + }; Ok(if let Some(transformed) = transformed { Transformed::yes(transformed) } else { @@ -1113,16 +1146,15 @@ mod tests { let expr3 = Arc::clone(&expr) .transform_down(|e| { - let transformed = - match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.scalar().value() { - ScalarValue::Utf8(Some(str_value)) => { - Some(lit(str_value.to_uppercase())) - } - _ => None, - }, + let transformed = match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.scalar().value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } _ => None, - }; + }, + _ => None, + }; Ok(if let Some(transformed) = transformed { Transformed::yes(transformed) } else { @@ -1177,6 +1209,45 @@ mod tests { Ok(()) } + #[test] + fn test_expr_or_expr_specialization() -> Result<()> { + let batch = case_test_batch1()?; + let schema = batch.schema(); + let when = binary( + col("a", &schema)?, + Operator::LtEq, + lit(2i32), + &batch.schema(), + )?; + let then = binary( + col("a", &schema)?, + Operator::Plus, + lit(1i32), + &batch.schema(), + )?; + let else_expr = binary( + col("a", &schema)?, + Operator::Minus, + lit(1i32), + &batch.schema(), + )?; + let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?; + assert!(matches!( + expr.eval_method, + EvalMethod::ExpressionOrExpression + )); + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result).expect("failed to downcast to Int32Array"); + + let expected = &Int32Array::from(vec![Some(2), Some(1), None, Some(4)]); + + assert_eq!(expected, result); + Ok(()) + } + fn make_col(name: &str, index: usize) -> Arc { Arc::new(Column::new(name, index)) } diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 5621473c4fdb1..7eda5fb4beaa8 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -17,10 +17,10 @@ use std::any::Any; use std::fmt; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::sync::Arc; -use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; +use crate::physical_expr::PhysicalExpr; use arrow::compute::{can_cast_types, CastOptions}; use arrow::datatypes::{DataType, DataType::*, Schema}; @@ -42,7 +42,7 @@ const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions { }; /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq)] pub struct CastExpr { /// The expression to cast pub expr: Arc, @@ -52,6 +52,23 @@ pub struct CastExpr { cast_options: CastOptions<'static>, } +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 +impl PartialEq for CastExpr { + fn eq(&self, other: &Self) -> bool { + self.expr.eq(&other.expr) + && self.cast_type.eq(&other.cast_type) + && self.cast_options.eq(&other.cast_options) + } +} + +impl Hash for CastExpr { + fn hash(&self, state: &mut H) { + self.expr.hash(state); + self.cast_type.hash(state); + self.cast_options.hash(state); + } +} + impl CastExpr { /// Create a new CastExpr pub fn new( @@ -160,13 +177,6 @@ impl PhysicalExpr for CastExpr { ])) } - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.expr.hash(&mut s); - self.cast_type.hash(&mut s); - self.cast_options.hash(&mut s); - } - /// A [`CastExpr`] preserves the ordering of its child if the cast is done /// under the same datatype family. fn get_properties(&self, children: &[ExprProperties]) -> Result { @@ -186,19 +196,6 @@ impl PhysicalExpr for CastExpr { } } -impl PartialEq for CastExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.expr.eq(&x.expr) - && self.cast_type == x.cast_type - && self.cast_options == x.cast_options - }) - .unwrap_or(false) - } -} - /// Return a PhysicalExpression representing `expr` casted to /// `cast_type`, if any casting is needed. /// @@ -693,7 +690,7 @@ mod tests { let result = cast( col("a", &schema).unwrap(), &schema, - DataType::Interval(IntervalUnit::MonthDayNano), + Interval(IntervalUnit::MonthDayNano), ); result.expect_err("expected Invalid CAST"); } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 4aad959584ac4..0649cbd65d34d 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -18,9 +18,10 @@ //! 
Physical column reference: [`Column`] use std::any::Any; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::sync::Arc; +use crate::physical_expr::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -30,8 +31,6 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::ColumnarValue; -use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; - /// Represents the column at a given index in a RecordBatch /// /// This is a physical expression that represents a column at a given index in an @@ -43,7 +42,7 @@ use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; /// /// # Example: /// If the schema is `a`, `b`, `c` the `Column` for `b` would be represented by -/// index 1, since `b` is the second colum in the schema. +/// index 1, since `b` is the second column in the schema. /// /// ``` /// # use datafusion_physical_expr::expressions::Column; @@ -107,7 +106,7 @@ impl std::fmt::Display for Column { impl PhysicalExpr for Column { /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -139,20 +138,6 @@ impl PhysicalExpr for Column { ) -> Result> { Ok(self) } - - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } -} - -impl PartialEq for Column { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self == x) - .unwrap_or(false) - } } impl Column { diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index b44afe40b3c54..a16f42beacb3e 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -22,7 +22,7 @@ use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::{down_cast_any_ref, physical_exprs_bag_equal}; +use crate::physical_expr::physical_exprs_bag_equal; use crate::PhysicalExpr; use arrow::array::*; @@ -44,8 +44,8 @@ use datafusion_expr::{ColumnarValue, Scalar}; use datafusion_physical_expr_common::datum::compare_with_eq; use ahash::RandomState; +use datafusion_common::HashMap; use hashbrown::hash_map::RawEntryMut; -use hashbrown::HashMap; /// InList pub struct InListExpr { @@ -246,7 +246,7 @@ trait IsEqual: HashValue { fn is_equal(&self, other: &Self) -> bool; } -impl<'a, T: IsEqual + ?Sized> IsEqual for &'a T { +impl IsEqual for &T { fn is_equal(&self, other: &Self) -> bool { T::is_equal(self, other) } @@ -400,26 +400,24 @@ impl PhysicalExpr for InListExpr { self.static_filter.clone(), ))) } +} - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.expr.hash(&mut s); - self.negated.hash(&mut s); - self.list.hash(&mut s); - // Add `self.static_filter` when hash is available +impl PartialEq for InListExpr { + fn eq(&self, other: &Self) -> bool { + self.expr.eq(&other.expr) + && physical_exprs_bag_equal(&self.list, &other.list) + && self.negated == other.negated } } -impl PartialEq for InListExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.expr.eq(&x.expr) - && physical_exprs_bag_equal(&self.list, &x.list) - && self.negated == x.negated - }) - .unwrap_or(false) +impl Eq for InListExpr {} + +impl Hash for InListExpr { + fn hash(&self, state: &mut H) { + self.expr.hash(state); + self.negated.hash(state); + 
self.list.hash(state); + // Add `self.static_filter` when hash is available } } @@ -1104,7 +1102,7 @@ mod tests { let mut phy_exprs = vec![ lit(1i64), expressions::cast(lit(2i32), &schema, DataType::Int64)?, - expressions::try_cast(lit(3.13f32), &schema, DataType::Int64)?, + try_cast(lit(3.13f32), &schema, DataType::Int64)?, ]; let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); @@ -1132,7 +1130,7 @@ mod tests { try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); // column - phy_exprs.push(expressions::col("a", &schema)?); + phy_exprs.push(col("a", &schema)?); assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_err()); Ok(()) diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index c16db7e8d4561..e68910bd1afc0 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -17,10 +17,9 @@ //! IS NOT NULL expression -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::{any::Any, sync::Arc}; -use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, @@ -31,12 +30,25 @@ use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; /// IS NOT NULL expression -#[derive(Debug, Hash)] +#[derive(Debug, Eq)] pub struct IsNotNullExpr { /// The input expression arg: Arc, } +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 +impl PartialEq for IsNotNullExpr { + fn eq(&self, other: &Self) -> bool { + self.arg.eq(&other.arg) + } +} + +impl Hash for IsNotNullExpr { + fn hash(&self, state: &mut H) { + self.arg.hash(state); + } +} + impl IsNotNullExpr { /// Create new not expression pub fn new(arg: Arc) -> Self { @@ -92,21 +104,8 @@ impl PhysicalExpr for IsNotNullExpr { ) -> Result> { Ok(Arc::new(IsNotNullExpr::new(Arc::clone(&children[0])))) } - - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } } -impl PartialEq for IsNotNullExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.arg.eq(&x.arg)) - .unwrap_or(false) - } -} /// Create an IS NOT NULL expression pub fn is_not_null(arg: Arc) -> Result> { Ok(Arc::new(IsNotNullExpr::new(arg))) diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 65dc04fdfab0b..dff45e51f8b47 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -17,27 +17,38 @@ //! 
IS NULL expression -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::{any::Any, sync::Arc}; +use crate::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; - -use crate::physical_expr::down_cast_any_ref; -use crate::PhysicalExpr; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; /// IS NULL expression -#[derive(Debug, Hash)] +#[derive(Debug, Eq)] pub struct IsNullExpr { /// Input expression arg: Arc, } +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 +impl PartialEq for IsNullExpr { + fn eq(&self, other: &Self) -> bool { + self.arg.eq(&other.arg) + } +} + +impl Hash for IsNullExpr { + fn hash(&self, state: &mut H) { + self.arg.hash(state); + } +} + impl IsNullExpr { /// Create new not expression pub fn new(arg: Arc) -> Self { @@ -92,20 +103,6 @@ impl PhysicalExpr for IsNullExpr { ) -> Result> { Ok(Arc::new(IsNullExpr::new(Arc::clone(&children[0])))) } - - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } -} - -impl PartialEq for IsNullExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.arg.eq(&x.arg)) - .unwrap_or(false) - } } /// Create an IS NULL expression diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index b84ba82b642dc..d61cd63c35b1e 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -15,11 +15,10 @@ // specific language governing permissions and limitations // under the License. -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::{any::Any, sync::Arc}; -use crate::{physical_expr::down_cast_any_ref, PhysicalExpr}; - +use crate::PhysicalExpr; use arrow::record_batch::RecordBatch; use arrow_schema::{DataType, Schema}; use datafusion_common::{internal_err, Result}; @@ -27,7 +26,7 @@ use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::apply_cmp; // Like expression -#[derive(Debug, Hash)] +#[derive(Debug, Eq)] pub struct LikeExpr { negated: bool, case_insensitive: bool, @@ -35,6 +34,25 @@ pub struct LikeExpr { pattern: Arc, } +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 +impl PartialEq for LikeExpr { + fn eq(&self, other: &Self) -> bool { + self.negated == other.negated + && self.case_insensitive == other.case_insensitive + && self.expr.eq(&other.expr) + && self.pattern.eq(&other.pattern) + } +} + +impl Hash for LikeExpr { + fn hash(&self, state: &mut H) { + self.negated.hash(state); + self.case_insensitive.hash(state); + self.expr.hash(state); + self.pattern.hash(state); + } +} + impl LikeExpr { pub fn new( negated: bool, @@ -127,25 +145,6 @@ impl PhysicalExpr for LikeExpr { Arc::clone(&children[1]), ))) } - - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } -} - -impl PartialEq for LikeExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.negated == x.negated - && self.case_insensitive == x.case_insensitive - && self.expr.eq(&x.expr) - && self.pattern.eq(&x.pattern) - }) - .unwrap_or(false) - } } /// used for optimize Dictionary like diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 42c592958125f..c6465aaf10e79 100644 --- 
a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -18,10 +18,10 @@ //! Literal expressions for physical operations use std::any::Any; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::sync::Arc; -use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; +use crate::physical_expr::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, @@ -92,11 +92,6 @@ impl PhysicalExpr for Literal { Ok(self) } - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } - fn get_properties(&self, _children: &[ExprProperties]) -> Result { Ok(ExprProperties { sort_properties: SortProperties::Singleton, @@ -104,19 +99,11 @@ impl PhysicalExpr for Literal { self.scalar.value().clone(), self.scalar.value().clone(), )?, + preserves_lex_ordering: true, }) } } -impl PartialEq for Literal { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self == x) - .unwrap_or(false) - } -} - /// Create a literal expression pub fn lit(value: T) -> Arc { match value.lit() { @@ -135,7 +122,7 @@ mod tests { #[test] fn literal_i32() -> Result<()> { - // create an arbitrary record bacth + // create an arbitrary record batch let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 177fd799ae792..f00b49f503141 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -35,11 +35,6 @@ mod unknown_column; /// Module with some convenient methods used in expression building pub use crate::aggregate::stats::StatsType; -pub use crate::window::cume_dist::{cume_dist, CumeDist}; -pub use crate::window::lead_lag::{lag, lead, WindowShift}; -pub use crate::window::nth_value::NthValue; -pub use crate::window::ntile::Ntile; -pub use crate::window::rank::{dense_rank, percent_rank, rank, Rank, RankType}; pub use crate::PhysicalSortExpr; pub use binary::{binary, similar_to, BinaryExpr}; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index fd6584cccbe0a..44d4a66368618 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -18,10 +18,9 @@ //! 
Negation (-) expression use std::any::Any; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::sync::Arc; -use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::{ @@ -38,12 +37,25 @@ use datafusion_expr::{ }; /// Negative expression -#[derive(Debug, Hash)] +#[derive(Debug, Eq)] pub struct NegativeExpr { /// Input expression arg: Arc, } +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 +impl PartialEq for NegativeExpr { + fn eq(&self, other: &Self) -> bool { + self.arg.eq(&other.arg) + } +} + +impl Hash for NegativeExpr { + fn hash(&self, state: &mut H) { + self.arg.hash(state); + } +} + impl NegativeExpr { /// Create new not expression pub fn new(arg: Arc) -> Self { @@ -100,11 +112,6 @@ impl PhysicalExpr for NegativeExpr { Ok(Arc::new(NegativeExpr::new(Arc::clone(&children[0])))) } - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } - /// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval. /// It replaces the upper and lower bounds after multiplying them with -1. /// Ex: `(a, b]` => `[-b, -a)` @@ -138,19 +145,11 @@ impl PhysicalExpr for NegativeExpr { Ok(ExprProperties { sort_properties: -children[0].sort_properties, range: children[0].range.clone().arithmetic_negate()?, + preserves_lex_ordering: false, }) } } -impl PartialEq for NegativeExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.arg.eq(&x.arg)) - .unwrap_or(false) - } -} - /// Creates a unary expression NEGATIVE /// /// # Errors @@ -224,9 +223,7 @@ mod tests { #[test] fn test_evaluate_bounds() -> Result<()> { - let negative_expr = NegativeExpr { - arg: Arc::new(Column::new("a", 0)), - }; + let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0))); let child_interval = Interval::make(Some(-2), Some(1))?; let negative_expr_interval = Interval::make(Some(-1), Some(2))?; assert_eq!( @@ -238,9 +235,7 @@ mod tests { #[test] fn test_propagate_constraints() -> Result<()> { - let negative_expr = NegativeExpr { - arg: Arc::new(Column::new("a", 0)), - }; + let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0))); let original_child_interval = Interval::make(Some(-2), Some(3))?; let negative_expr_interval = Interval::make(Some(0), Some(4))?; let after_propagation = Some(vec![Interval::make(Some(-2), Some(0))?]); @@ -257,7 +252,7 @@ mod tests { #[test] fn test_negation_valid_types() -> Result<()> { let negatable_types = [ - DataType::Int8, + Int8, DataType::Timestamp(TimeUnit::Second, None), DataType::Interval(IntervalUnit::YearMonth), ]; diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index 9148cb7c1c1de..c17b52f5cdfff 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -18,7 +18,7 @@ //! 
NoOp placeholder for physical operations use std::any::Any; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::sync::Arc; use arrow::{ @@ -26,7 +26,6 @@ use arrow::{ record_batch::RecordBatch, }; -use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; @@ -78,18 +77,4 @@ impl PhysicalExpr for NoOp { ) -> Result> { Ok(self) } - - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } -} - -impl PartialEq for NoOp { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self == x) - .unwrap_or(false) - } } diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 7a0afaa1a637a..ee886e5a1562f 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -19,23 +19,36 @@ use std::any::Any; use std::fmt; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::sync::Arc; -use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{cast::as_boolean_array, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; /// Not expression -#[derive(Debug, Hash)] +#[derive(Debug, Eq)] pub struct NotExpr { /// Input expression arg: Arc, } +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 +impl PartialEq for NotExpr { + fn eq(&self, other: &Self) -> bool { + self.arg.eq(&other.arg) + } +} + +impl Hash for NotExpr { + fn hash(&self, state: &mut H) { + self.arg.hash(state); + } +} + impl NotExpr { /// Create new not expression pub fn new(arg: Arc) -> Self { @@ -97,19 +110,8 @@ impl PhysicalExpr for NotExpr { ) -> Result> { Ok(Arc::new(NotExpr::new(Arc::clone(&children[0])))) } - - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } -} - -impl PartialEq for NotExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.arg.eq(&x.arg)) - .unwrap_or(false) + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + children[0].not() } } @@ -123,10 +125,11 @@ mod tests { use super::*; use crate::expressions::col; use arrow::{array::BooleanArray, datatypes::*}; + use std::sync::LazyLock; #[test] fn neg_op() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); + let schema = schema(); let expr = not(col("a", &schema)?)?; assert_eq!(expr.data_type(&schema)?, DataType::Boolean); @@ -135,8 +138,7 @@ mod tests { let input = BooleanArray::from(vec![Some(true), None, Some(false)]); let expected = &BooleanArray::from(vec![Some(false), None, Some(true)]); - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; + let batch = RecordBatch::try_new(schema, vec![Arc::new(input)])?; let result = expr .evaluate(&batch)? 
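
The `NotExpr` hunk above adds `evaluate_bounds`, which just calls `Interval::not()` on the child's bounds. As a self-contained stand-in for the boolean case (a toy type, not DataFusion's `Interval`), negating an interval negates and swaps its bounds, so "certainly true" becomes "certainly false" and the unknown interval `[false, true]` maps to itself:

```rust
/// A minimal stand-in for a boolean interval: `lower <= upper` with
/// `false < true`, so `[false, true]` means "unknown".
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
struct BoolInterval {
    lower: bool,
    upper: bool,
}

impl BoolInterval {
    /// Negating an interval negates and swaps its bounds: NOT [l, u] = [!u, !l].
    fn not(self) -> Self {
        BoolInterval { lower: !self.upper, upper: !self.lower }
    }
}

fn main() {
    // Unknown stays unknown.
    assert_eq!(
        BoolInterval { lower: false, upper: true }.not(),
        BoolInterval { lower: false, upper: true }
    );
    // Certainly true becomes certainly false, and vice versa.
    assert_eq!(
        BoolInterval { lower: true, upper: true }.not(),
        BoolInterval { lower: false, upper: false }
    );
    assert_eq!(
        BoolInterval { lower: false, upper: false }.not(),
        BoolInterval { lower: true, upper: true }
    );
}
```

These are the same three mappings the `test_evaluate_bounds` test in the next hunk asserts for the real `Interval` type.
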
@@ -148,4 +150,47 @@ mod tests { Ok(()) } + + #[test] + fn test_evaluate_bounds() -> Result<()> { + // Note that `None` for boolean intervals is converted to `Some(false)` + // / `Some(true)` by `Interval::make`, so it is not explicitly tested + // here + + // if the bounds are all booleans (false, true) so is the negation + assert_evaluate_bounds( + Interval::make(Some(false), Some(true))?, + Interval::make(Some(false), Some(true))?, + )?; + // (true, false) is not tested because it is not a valid interval (lower + // bound is greater than upper bound) + assert_evaluate_bounds( + Interval::make(Some(true), Some(true))?, + Interval::make(Some(false), Some(false))?, + )?; + assert_evaluate_bounds( + Interval::make(Some(false), Some(false))?, + Interval::make(Some(true), Some(true))?, + )?; + Ok(()) + } + + fn assert_evaluate_bounds( + interval: Interval, + expected_interval: Interval, + ) -> Result<()> { + let not_expr = not(col("a", &schema())?)?; + assert_eq!( + not_expr.evaluate_bounds(&[&interval]).unwrap(), + expected_interval + ); + Ok(()) + } + + fn schema() -> SchemaRef { + static SCHEMA: LazyLock = LazyLock::new(|| { + Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)])) + }); + Arc::clone(&SCHEMA) + } } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index 8490f78b1fe52..9723acbbfe302 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -17,10 +17,9 @@ use std::any::Any; use std::fmt; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::sync::Arc; -use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::compute; use arrow::compute::{cast_with_options, CastOptions}; @@ -32,7 +31,7 @@ use datafusion_common::{not_impl_err, Result, ScalarValue}; use datafusion_expr::ColumnarValue; /// TRY_CAST expression casts an expression to a specific data type and returns NULL on invalid cast -#[derive(Debug, Hash)] +#[derive(Debug, Eq)] pub struct TryCastExpr { /// The expression to cast expr: Arc, @@ -40,6 +39,20 @@ pub struct TryCastExpr { cast_type: DataType, } +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 +impl PartialEq for TryCastExpr { + fn eq(&self, other: &Self) -> bool { + self.expr.eq(&other.expr) && self.cast_type == other.cast_type + } +} + +impl Hash for TryCastExpr { + fn hash(&self, state: &mut H) { + self.expr.hash(state); + self.cast_type.hash(state); + } +} + impl TryCastExpr { /// Create a new CastExpr pub fn new(expr: Arc, cast_type: DataType) -> Self { @@ -110,20 +123,6 @@ impl PhysicalExpr for TryCastExpr { self.cast_type.clone(), ))) } - - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } -} - -impl PartialEq for TryCastExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.expr.eq(&x.expr) && self.cast_type == x.cast_type) - .unwrap_or(false) - } } /// Return a PhysicalExpression representing `expr` casted to diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs index cb7221e7fa151..a63caf7e13056 100644 --- a/datafusion/physical-expr/src/expressions/unknown_column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -30,7 +30,7 @@ use arrow::{ use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; 
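
A pattern recurring throughout this patch: the old `dyn_hash` method and `PartialEq<dyn Any>` impls are dropped in favor of ordinary `PartialEq`, `Eq`, and `Hash` impls written by hand, each carrying the same comment pointing at rust-lang/rust#78808 (deriving those traits on a struct holding an `Arc<dyn Trait>` field would require the trait object itself to implement them). A toy sketch of the shape of that workaround, with hypothetical `Expr`/`ToyNot`/`Col` types that are not part of DataFusion:

```rust
use std::hash::{Hash, Hasher};
use std::sync::Arc;

// A stand-in for an object-safe expression trait.
trait Expr: std::fmt::Debug {
    fn name(&self) -> &str;
}

#[derive(Debug)]
struct ToyNot {
    arg: Arc<dyn Expr>,
}

// `#[derive(PartialEq, Hash)]` would demand `dyn Expr: PartialEq + Hash`,
// so the impls are written out by hand and delegate to whatever the field
// actually supports -- here, simply the expression name.
impl PartialEq for ToyNot {
    fn eq(&self, other: &Self) -> bool {
        self.arg.name() == other.arg.name()
    }
}

impl Eq for ToyNot {}

impl Hash for ToyNot {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.arg.name().hash(state);
    }
}

fn main() {
    #[derive(Debug)]
    struct Col(String);
    impl Expr for Col {
        fn name(&self) -> &str {
            &self.0
        }
    }

    let a = ToyNot { arg: Arc::new(Col("a".into())) };
    let b = ToyNot { arg: Arc::new(Col("a".into())) };
    assert_eq!(a, b);

    // With Eq + Hash in place, the type can be used as a set/map key.
    let mut set = std::collections::HashSet::new();
    set.insert(a);
    set.insert(b);
    assert_eq!(set.len(), 1);
}
```
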
-#[derive(Debug, Hash, PartialEq, Eq, Clone)] +#[derive(Debug, Clone, Eq)] pub struct UnKnownColumn { name: String, } @@ -57,7 +57,7 @@ impl std::fmt::Display for UnKnownColumn { impl PhysicalExpr for UnKnownColumn { /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -86,15 +86,16 @@ impl PhysicalExpr for UnKnownColumn { ) -> Result> { Ok(self) } +} - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); +impl Hash for UnKnownColumn { + fn hash(&self, state: &mut H) { + self.name.hash(state); } } -impl PartialEq for UnKnownColumn { - fn eq(&self, _other: &dyn Any) -> bool { +impl PartialEq for UnKnownColumn { + fn eq(&self, _other: &Self) -> bool { // UnknownColumn is not a valid expression, so it should not be equal to any other expression. // See https://github.com/apache/datafusion/pull/11536 false diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 3f861013b345d..d2b89b7e4b234 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -6,7 +6,7 @@ // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // -//http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an @@ -19,6 +19,7 @@ use std::collections::HashSet; use std::fmt::{Display, Formatter}; +use std::mem::{size_of, size_of_val}; use std::sync::Arc; use super::utils::{ @@ -128,12 +129,11 @@ impl ExprIntervalGraph { /// Estimate size of bytes including `Self`. pub fn size(&self) -> usize { let node_memory_usage = self.graph.node_count() - * (std::mem::size_of::() - + std::mem::size_of::()); - let edge_memory_usage = self.graph.edge_count() - * (std::mem::size_of::() + std::mem::size_of::() * 2); + * (size_of::() + size_of::()); + let edge_memory_usage = + self.graph.edge_count() * (size_of::() + size_of::() * 2); - std::mem::size_of_val(self) + node_memory_usage + edge_memory_usage + size_of_val(self) + node_memory_usage + edge_memory_usage } } diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs b/datafusion/physical-expr/src/intervals/test_utils.rs index 49163fba60aff..e63fe43d17b49 100644 --- a/datafusion/physical-expr/src/intervals/test_utils.rs +++ b/datafusion/physical-expr/src/intervals/test_utils.rs @@ -6,7 +6,7 @@ // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // -//http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index b426a656fba9e..496db7b454df6 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -6,7 +6,7 @@ // "License"); you may not use this file except in compliance // with the License. 
You may obtain a copy of the License at // -//http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 46185712413ef..4c55f4ddba93b 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -27,7 +27,6 @@ pub mod binary_map { pub mod equivalence; pub mod expressions; pub mod intervals; -pub mod math_expressions; mod partitioning; mod physical_expr; pub mod planner; @@ -46,7 +45,9 @@ pub mod execution_props { pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, NullState}; pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; -pub use equivalence::{calculate_union, ConstExpr, EquivalenceProperties}; +pub use equivalence::{ + calculate_union, AcrossPartitions, ConstExpr, EquivalenceProperties, +}; pub use partitioning::{Distribution, Partitioning}; pub use physical_expr::{ physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, @@ -55,8 +56,7 @@ pub use physical_expr::{ pub use datafusion_physical_expr_common::physical_expr::PhysicalExpr; pub use datafusion_physical_expr_common::sort_expr::{ - LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalSortExpr, - PhysicalSortRequirement, + LexOrdering, LexRequirement, PhysicalSortExpr, PhysicalSortRequirement, }; pub use planner::{create_physical_expr, create_physical_exprs}; diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs deleted file mode 100644 index 503565b1e2613..0000000000000 --- a/datafusion/physical-expr/src/math_expressions.rs +++ /dev/null @@ -1,126 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Math expressions - -use std::any::type_name; -use std::sync::Arc; - -use arrow::array::ArrayRef; -use arrow::array::{BooleanArray, Float32Array, Float64Array}; -use arrow::datatypes::DataType; -use arrow_array::Array; - -use datafusion_common::exec_err; -use datafusion_common::{DataFusionError, Result}; - -macro_rules! downcast_arg { - ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast {} from {} to {}", - $NAME, - $ARG.data_type(), - type_name::<$ARRAY_TYPE>() - )) - })? - }}; -} - -macro_rules! 
make_function_scalar_inputs_return_type { - ($ARG: expr, $NAME:expr, $ARGS_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARGS_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$RETURN_TYPE>() - }}; -} - -/// Isnan SQL function -pub fn isnan(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float64Array, - BooleanArray, - { f64::is_nan } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float32Array, - BooleanArray, - { f32::is_nan } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function isnan"), - } -} - -#[cfg(test)] -mod tests { - - use datafusion_common::cast::as_boolean_array; - - use super::*; - - #[test] - fn test_isnan_f64() { - let args: Vec = vec![Arc::new(Float64Array::from(vec![ - 1.0, - f64::NAN, - 3.0, - -f64::NAN, - ]))]; - - let result = isnan(&args).expect("failed to initialize function isnan"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function isnan"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } - - #[test] - fn test_isnan_f32() { - let args: Vec = vec![Arc::new(Float32Array::from(vec![ - 1.0, - f32::NAN, - 3.0, - f32::NAN, - ]))]; - - let result = isnan(&args).expect("failed to initialize function isnan"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function isnan"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } -} diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index 01f72a8efd9a5..eb7e1ea6282bb 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -97,7 +97,7 @@ use std::sync::Arc; /// # Additional Examples /// /// A simple `FileScanExec` might produce one output stream (partition) for each -/// file (note the actual DataFusion file scaners can read individual files in +/// file (note the actual DataFusion file scanners can read individual files in /// parallel, potentially producing multiple partitions per file) /// /// Plans such as `SortPreservingMerge` produce a single output stream @@ -121,8 +121,8 @@ pub enum Partitioning { UnknownPartitioning(usize), } -impl fmt::Display for Partitioning { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl Display for Partitioning { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Partitioning::RoundRobinBatch(size) => write!(f, "RoundRobinBatch({size})"), Partitioning::Hash(phy_exprs, size) => { diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index d4c91e8ef3b11..5d8bd5f385f55 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -17,14 +17,11 @@ use std::sync::Arc; +use datafusion_common::HashMap; pub(crate) use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +pub use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; use itertools::izip; -pub use datafusion_physical_expr_common::physical_expr::down_cast_any_ref; - -/// Shared 
[`PhysicalExpr`]. -pub type PhysicalExprRef = Arc; - /// This function is similar to the `contains` method of `Vec`. It finds /// whether `expr` is among `physical_exprs`. pub fn physical_exprs_contains( @@ -50,51 +47,24 @@ pub fn physical_exprs_bag_equal( lhs: &[Arc], rhs: &[Arc], ) -> bool { - // TODO: Once we can use `HashMap`s with `Arc`, this - // function should use a `HashMap` to reduce computational complexity. - if lhs.len() == rhs.len() { - let mut rhs_vec = rhs.to_vec(); - for expr in lhs { - if let Some(idx) = rhs_vec.iter().position(|e| expr.eq(e)) { - rhs_vec.swap_remove(idx); - } else { - return false; - } - } - true - } else { - false + let mut multi_set_lhs: HashMap<_, usize> = HashMap::new(); + let mut multi_set_rhs: HashMap<_, usize> = HashMap::new(); + for expr in lhs { + *multi_set_lhs.entry(expr).or_insert(0) += 1; } -} - -/// This utility function removes duplicates from the given `exprs` vector. -/// Note that this function does not necessarily preserve its input ordering. -pub fn deduplicate_physical_exprs(exprs: &mut Vec>) { - // TODO: Once we can use `HashSet`s with `Arc`, this - // function should use a `HashSet` to reduce computational complexity. - // See issue: https://github.com/apache/datafusion/issues/8027 - let mut idx = 0; - while idx < exprs.len() { - let mut rest_idx = idx + 1; - while rest_idx < exprs.len() { - if exprs[idx].eq(&exprs[rest_idx]) { - exprs.swap_remove(rest_idx); - } else { - rest_idx += 1; - } - } - idx += 1; + for expr in rhs { + *multi_set_rhs.entry(expr).or_insert(0) += 1; } + multi_set_lhs == multi_set_rhs } #[cfg(test)] mod tests { - use std::sync::Arc; + use super::*; use crate::expressions::{Column, Literal}; use crate::physical_expr::{ - deduplicate_physical_exprs, physical_exprs_bag_equal, physical_exprs_contains, - physical_exprs_equal, PhysicalExpr, + physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, }; use datafusion_common::ScalarValue; @@ -210,41 +180,4 @@ mod tests { assert!(physical_exprs_bag_equal(list3.as_slice(), list3.as_slice())); assert!(physical_exprs_bag_equal(list4.as_slice(), list4.as_slice())); } - - #[test] - fn test_deduplicate_physical_exprs() { - let lit_true = &(Arc::new(Literal::from(ScalarValue::Boolean(Some(true)))) - as Arc); - let lit_false = &(Arc::new(Literal::from(ScalarValue::Boolean(Some(false)))) - as Arc); - let lit4 = &(Arc::new(Literal::from(ScalarValue::Int32(Some(4)))) - as Arc); - let lit2 = &(Arc::new(Literal::from(ScalarValue::Int32(Some(2)))) - as Arc); - let col_a_expr = &(Arc::new(Column::new("a", 0)) as Arc); - let col_b_expr = &(Arc::new(Column::new("b", 1)) as Arc); - - // First vector in the tuple is arguments, second one is the expected value. 
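
With `Arc<dyn PhysicalExpr>` now usable as a hash-map key, `physical_exprs_bag_equal` above compares the two slices as multisets by counting occurrences, replacing the old quadratic `swap_remove` loop. A standalone sketch of the same counting idea over plain values (not the DataFusion types):

```rust
use std::collections::HashMap;
use std::hash::Hash;

/// Returns true if `lhs` and `rhs` contain the same elements with the same
/// multiplicities, regardless of order.
fn bag_equal<T: Eq + Hash>(lhs: &[T], rhs: &[T]) -> bool {
    fn counts<T: Eq + Hash>(items: &[T]) -> HashMap<&T, usize> {
        let mut map = HashMap::new();
        for item in items {
            *map.entry(item).or_insert(0) += 1;
        }
        map
    }
    counts(lhs) == counts(rhs)
}

fn main() {
    assert!(bag_equal(&[1, 2, 2, 3], &[2, 3, 2, 1]));
    assert!(!bag_equal(&[1, 2, 2], &[1, 2])); // different multiplicities
    assert!(!bag_equal(&[1, 2, 3], &[1, 2, 4])); // different elements
}
```
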
- let test_cases = vec![ - // ---------- TEST CASE 1----------// - ( - vec![ - lit_true, lit_false, lit4, lit2, col_a_expr, col_a_expr, col_b_expr, - lit_true, lit2, - ], - vec![lit_true, lit_false, lit4, lit2, col_a_expr, col_b_expr], - ), - // ---------- TEST CASE 2----------// - ( - vec![lit_true, lit_true, lit_false, lit4], - vec![lit_true, lit4, lit_false], - ), - ]; - for (exprs, expected) in test_cases { - let mut exprs = exprs.into_iter().cloned().collect::>(); - let expected = expected.into_iter().cloned().collect::>(); - deduplicate_physical_exprs(&mut exprs); - assert!(physical_exprs_equal(&exprs, &expected)); - } - } } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 7e606fd5e8679..a1a0be5598a89 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -28,7 +28,7 @@ use datafusion_common::{ exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, }; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction}; +use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ @@ -200,8 +200,11 @@ pub fn create_physical_expr( escape_char, case_insensitive, }) => { - if escape_char.is_some() { - return exec_err!("LIKE does not support escape_char"); + // `\` is the implicit escape, see https://github.com/apache/datafusion/issues/13291 + if escape_char.unwrap_or('\\') != '\\' { + return exec_err!( + "LIKE does not support escape_char other than the backslash (\\)" + ); } let physical_expr = create_physical_expr(expr, input_dfschema, execution_props)?; @@ -358,6 +361,9 @@ pub fn create_physical_expr( expressions::in_list(value_expr, list_exprs, negated, input_schema) } }, + Expr::Placeholder(Placeholder { id, .. 
}) => { + exec_err!("Placeholder '{id}' was not provided a value for execution.") + } other => { not_impl_err!("Physical plan does not support logical expression {other:?}") } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 130c335d1c95e..37a8e0258103c 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -31,21 +31,22 @@ use std::any::Any; use std::fmt::{self, Debug, Formatter}; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::sync::Arc; -use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal}; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, DFSchema, Result}; +use arrow_array::Array; +use datafusion_common::{internal_err, DFSchema, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; -use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarUDF}; +use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF}; /// Physical expression of a scalar function +#[derive(Eq, PartialEq, Hash)] pub struct ScalarFunctionExpr { fun: Arc, name: String, @@ -133,22 +134,36 @@ impl PhysicalExpr for ScalarFunctionExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let inputs = self + let args = self .args .iter() .map(|e| e.evaluate(batch)) .collect::>>()?; + let input_empty = args.is_empty(); + let input_all_scalar = args + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + // evaluate the function - let output = match self.args.is_empty() { - true => self.fun.invoke_no_args(batch.num_rows()), - false => self.fun.invoke(&inputs), - }?; + let output = self.fun.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: batch.num_rows(), + return_type: &self.return_type, + })?; if let ColumnarValue::Array(array) = &output { if array.len() != batch.num_rows() { - return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}", - batch.num_rows(), array.len()); + // If the arguments are a non-empty slice of scalar values, we can assume that + // returning a one-element array is equivalent to returning a scalar. + let preserve_scalar = + array.len() == 1 && !input_empty && input_all_scalar; + return if preserve_scalar { + ScalarValue::try_from_array(array, 0).map(ColumnarValue::from) + } else { + internal_err!("UDF {} returned a different number of rows than expected. 
Expected: {}, Got: {}", + self.name, batch.num_rows(), array.len()) + }; } } Ok(output) @@ -185,16 +200,9 @@ impl PhysicalExpr for ScalarFunctionExpr { self.fun.propagate_constraints(interval, children) } - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.name.hash(&mut s); - self.args.hash(&mut s); - self.return_type.hash(&mut s); - // Add `self.fun` when hash is available - } - fn get_properties(&self, children: &[ExprProperties]) -> Result { let sort_properties = self.fun.output_ordering(children)?; + let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?; let children_range = children .iter() .map(|props| &props.range) @@ -204,24 +212,11 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(ExprProperties { sort_properties, range, + preserves_lex_ordering, }) } } -impl PartialEq for ScalarFunctionExpr { - /// Comparing name, args and return_type - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && physical_exprs_equal(&self.args, &x.args) - && self.return_type == x.return_type - }) - .unwrap_or(false) - } -} - /// Create a physical expression for the UDF. pub fn create_physical_expr( fun: &ScalarUDF, @@ -238,7 +233,7 @@ pub fn create_physical_expr( // verify that input data types is consistent with function's `TypeSignature` data_types_with_scalar_udf(&input_expr_types, fun)?; - // Since we have arg_types, we dont need args and schema. + // Since we have arg_types, we don't need args and schema. let return_type = fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?; diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 1a92b06654449..3ac604cbd3908 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -20,9 +20,9 @@ use crate::utils::split_disjunction; use crate::{split_conjunction, PhysicalExpr}; -use datafusion_common::{Column, ScalarValue}; +use datafusion_common::{Column, HashMap, ScalarValue}; use datafusion_expr::Operator; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::fmt::{self, Display, Formatter}; use std::sync::Arc; @@ -124,7 +124,7 @@ impl LiteralGuarantee { // for an `AND` conjunction to be true, all terms individually must be true .fold(GuaranteeBuilder::new(), |builder, expr| { if let Some(cel) = ColOpLit::try_new(expr) { - return builder.aggregate_conjunct(cel); + builder.aggregate_conjunct(cel) } else if let Some(inlist) = expr .as_any() .downcast_ref::() @@ -412,7 +412,7 @@ impl<'a> ColOpLit<'a> { #[cfg(test)] mod test { - use std::sync::OnceLock; + use std::sync::LazyLock; use super::*; use crate::planner::logical2physical; @@ -808,7 +808,7 @@ mod test { vec![not_in_guarantee("b", [1, 2, 3]), in_guarantee("b", [3, 4])], ); // b IN (1, 2, 3) OR b = 2 - // TODO this should be in_guarantee("b", [1, 2, 3]) but currently we don't support to anylize this kind of disjunction. Only `ColOpLit OR ColOpLit` is supported. + // TODO this should be in_guarantee("b", [1, 2, 3]) but currently we don't support to analyze this kind of disjunction. Only `ColOpLit OR ColOpLit` is supported. 
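
The `ScalarFunctionExpr::evaluate` change earlier in this patch tolerates a UDF that returns a one-element array when every argument it received was a scalar, folding the result back into a scalar instead of raising a row-count error. A sketch of just that decision logic, using a stand-in enum rather than DataFusion's `ColumnarValue`:

```rust
/// A stand-in for `ColumnarValue`, for illustration only.
#[derive(Debug, PartialEq, Clone)]
enum Value {
    Scalar(i64),
    Array(Vec<i64>),
}

/// Mirrors the shape of the added check: a one-element array produced from
/// all-scalar, non-empty inputs is folded back into a scalar; any other
/// length mismatch stays an error.
fn normalize_output(args: &[Value], output: Value, num_rows: usize) -> Result<Value, String> {
    if let Value::Array(values) = &output {
        if values.len() != num_rows {
            let input_all_scalar = args.iter().all(|a| matches!(a, Value::Scalar(_)));
            let preserve_scalar = values.len() == 1 && !args.is_empty() && input_all_scalar;
            return if preserve_scalar {
                Ok(Value::Scalar(values[0]))
            } else {
                Err(format!("expected {num_rows} rows, got {}", values.len()))
            };
        }
    }
    Ok(output)
}

fn main() {
    // Scalar inputs + one-element array output: folded back into a scalar.
    let out = normalize_output(&[Value::Scalar(2)], Value::Array(vec![4]), 8);
    assert_eq!(out, Ok(Value::Scalar(4)));

    // A wrong-length array produced from array inputs is still an error.
    let out = normalize_output(&[Value::Array(vec![1, 2])], Value::Array(vec![4]), 8);
    assert!(out.is_err());
}
```
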
test_analyze( col("b") .in_list(vec![lit(1), lit(2), lit(3)], false) @@ -863,13 +863,12 @@ mod test { // Schema for testing fn schema() -> SchemaRef { - Arc::clone(SCHEMA.get_or_init(|| { + static SCHEMA: LazyLock = LazyLock::new(|| { Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Int32, false), ])) - })) + }); + Arc::clone(&SCHEMA) } - - static SCHEMA: OnceLock = OnceLock::new(); } diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 4c37db4849a7f..c06efd5540985 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -17,10 +17,8 @@ mod guarantee; pub use guarantee::{Guarantee, LiteralGuarantee}; -use hashbrown::HashSet; use std::borrow::Borrow; -use std::collections::HashMap; use std::sync::Arc; use crate::expressions::{BinaryExpr, Column}; @@ -32,9 +30,10 @@ use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; -use datafusion_common::Result; +use datafusion_common::{HashMap, HashSet, Result}; use datafusion_expr::Operator; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use itertools::Itertools; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; @@ -86,6 +85,10 @@ pub fn map_columns_before_projection( parent_required: &[Arc], proj_exprs: &[(Arc, String)], ) -> Vec> { + if parent_required.is_empty() { + // No need to build mapping. + return vec![]; + } let column_mapping = proj_exprs .iter() .filter_map(|(expr, name)| { @@ -143,9 +146,7 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result< constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode) -> Result> - PhysicalExprDAEGBuilder<'a, T, F> -{ +impl) -> Result> PhysicalExprDAEGBuilder<'_, T, F> { // This method mutates an expression node by transforming it to a physical expression // and adding it to the graph. The method returns the mutated expression node. fn mutate( @@ -241,10 +242,7 @@ pub fn reassign_predicate_columns( } /// Merge left and right sort expressions, checking for duplicates. 
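
The guarantee-test schema above (and the `not.rs` test earlier in this patch) switches from `OnceLock::get_or_init` to `std::sync::LazyLock`, stable since Rust 1.80, which keeps the initializer next to the static. A standalone version of the same pattern; the two columns simply mirror the test schema shown above:

```rust
use std::sync::{Arc, LazyLock};

use arrow_schema::{DataType, Field, Schema, SchemaRef};

/// Build the shared test schema once, on first use.
fn schema() -> SchemaRef {
    static SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Int32, false),
        ]))
    });
    Arc::clone(&SCHEMA)
}

fn main() {
    // Every call hands back a clone of the same underlying Arc.
    assert!(Arc::ptr_eq(&schema(), &schema()));
}
```
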
-pub fn merge_vectors( - left: &[PhysicalSortExpr], - right: &[PhysicalSortExpr], -) -> Vec { +pub fn merge_vectors(left: &LexOrdering, right: &LexOrdering) -> LexOrdering { left.iter() .cloned() .chain(right.iter().cloned()) @@ -311,7 +309,11 @@ pub(crate) mod tests { Ok(input[0].sort_properties) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + _number_rows: usize, + ) -> Result { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { @@ -525,7 +527,7 @@ pub(crate) mod tests { ) .unwrap(); - assert_eq!(actual.as_ref(), expected.as_any()); + assert_eq!(actual.as_ref(), expected.as_ref()); } #[test] diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index d012fef93b675..0c56bdc929857 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -25,40 +25,40 @@ use arrow::array::Array; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{Accumulator, WindowFrame}; - use crate::aggregate::AggregateFunctionExpr; use crate::window::window_expr::AggregateWindowExpr; use crate::window::{ PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr, }; -use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; +use crate::{reverse_order_bys, PhysicalExpr}; +use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::{Accumulator, WindowFrame}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; /// A window expr that takes the form of an aggregate function. /// /// See comments on [`WindowExpr`] for more details. 
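
In the `utils` test UDF above, `invoke` becomes `invoke_batch`, which also receives the number of rows in the input batch. That extra argument is what lets functions with no array inputs (for example, zero-argument functions) size their output without inspecting a column. The sketch below only illustrates that idea with a free function over arrow arrays; it is not the actual `ScalarUDFImpl` trait signature:

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int64Array};

/// A zero-argument "function" has no input column to infer the output length
/// from, so the caller must pass `number_rows` explicitly -- the same reason
/// `invoke_batch` carries it alongside the argument values.
fn sequence_udf(args: &[ArrayRef], number_rows: usize) -> ArrayRef {
    assert!(args.is_empty(), "this sketch models a zero-argument function");
    Arc::new(Int64Array::from_iter_values(0..number_rows as i64))
}

fn main() {
    let out = sequence_udf(&[], 4);
    assert_eq!(out.len(), 4);
}
```
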
#[derive(Debug)] pub struct PlainAggregateWindowExpr { - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: Vec>, - order_by: Vec, + order_by: LexOrdering, window_frame: Arc, } impl PlainAggregateWindowExpr { /// Create a new aggregate window function expression pub fn new( - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: &[Arc], - order_by: &[PhysicalSortExpr], + order_by: &LexOrdering, window_frame: Arc, ) -> Self { Self { aggregate, partition_by: partition_by.to_vec(), - order_by: order_by.to_vec(), + order_by: order_by.clone(), window_frame, } } @@ -124,8 +124,8 @@ impl WindowExpr for PlainAggregateWindowExpr { &self.partition_by } - fn order_by(&self) -> &[PhysicalSortExpr] { - &self.order_by + fn order_by(&self) -> &LexOrdering { + self.order_by.as_ref() } fn get_window_frame(&self) -> &Arc { @@ -137,16 +137,16 @@ impl WindowExpr for PlainAggregateWindowExpr { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { Arc::new(PlainAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), + reverse_order_bys(self.order_by.as_ref()).as_ref(), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), + reverse_order_bys(self.order_by.as_ref()).as_ref(), Arc::new(self.window_frame.reverse()), )) as _ } diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 2aeb053331027..bc7c716783bdc 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -16,21 +16,21 @@ // under the License. mod aggregate; -mod built_in; -mod built_in_window_function_expr; -pub(crate) mod cume_dist; -pub(crate) mod lead_lag; -pub(crate) mod nth_value; -pub(crate) mod ntile; -pub(crate) mod rank; mod sliding_aggregate; +mod standard; +mod standard_window_function_expr; mod window_expr; +#[deprecated(since = "44.0.0", note = "use StandardWindowExpr")] +pub type BuiltInWindowExpr = StandardWindowExpr; + +#[deprecated(since = "44.0.0", note = "use StandardWindowFunctionExpr")] +pub type BuiltInWindowFunctionExpr = dyn StandardWindowFunctionExpr; + pub use aggregate::PlainAggregateWindowExpr; -pub use built_in::BuiltInWindowExpr; -pub use built_in_window_function_expr::BuiltInWindowFunctionExpr; pub use sliding_aggregate::SlidingAggregateWindowExpr; -pub use window_expr::NthValueKind; +pub use standard::StandardWindowExpr; +pub use standard_window_function_expr::StandardWindowFunctionExpr; pub use window_expr::PartitionBatches; pub use window_expr::PartitionKey; pub use window_expr::PartitionWindowAggStates; diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs deleted file mode 100644 index d94983c5adf74..0000000000000 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ /dev/null @@ -1,415 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions for `FIRST_VALUE`, `LAST_VALUE`, and `NTH_VALUE` -//! functions that can be evaluated at run time during query execution. - -use std::any::Any; -use std::cmp::Ordering; -use std::ops::Range; -use std::sync::Arc; - -use crate::window::window_expr::{NthValueKind, NthValueState}; -use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; - -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::window_state::WindowAggState; -use datafusion_expr::PartitionEvaluator; - -/// nth_value expression -#[derive(Debug)] -pub struct NthValue { - name: String, - expr: Arc, - /// Output data type - data_type: DataType, - kind: NthValueKind, - ignore_nulls: bool, -} - -impl NthValue { - /// Create a new FIRST_VALUE window aggregate function - pub fn first( - name: impl Into, - expr: Arc, - data_type: DataType, - ignore_nulls: bool, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - kind: NthValueKind::First, - ignore_nulls, - } - } - - /// Create a new LAST_VALUE window aggregate function - pub fn last( - name: impl Into, - expr: Arc, - data_type: DataType, - ignore_nulls: bool, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - kind: NthValueKind::Last, - ignore_nulls, - } - } - - /// Create a new NTH_VALUE window aggregate function - pub fn nth( - name: impl Into, - expr: Arc, - data_type: DataType, - n: i64, - ignore_nulls: bool, - ) -> Result { - Ok(Self { - name: name.into(), - expr, - data_type, - kind: NthValueKind::Nth(n), - ignore_nulls, - }) - } - - /// Get the NTH_VALUE kind - pub fn get_kind(&self) -> NthValueKind { - self.kind - } -} - -impl BuiltInWindowFunctionExpr for NthValue { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = true; - Ok(Field::new(&self.name, self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_evaluator(&self) -> Result> { - let state = NthValueState { - finalized_result: None, - kind: self.kind, - }; - Ok(Box::new(NthValueEvaluator { - state, - ignore_nulls: self.ignore_nulls, - })) - } - - fn reverse_expr(&self) -> Option> { - let reversed_kind = match self.kind { - NthValueKind::First => NthValueKind::Last, - NthValueKind::Last => NthValueKind::First, - NthValueKind::Nth(idx) => NthValueKind::Nth(-idx), - }; - Some(Arc::new(Self { - name: self.name.clone(), - expr: Arc::clone(&self.expr), - data_type: self.data_type.clone(), - kind: reversed_kind, - ignore_nulls: self.ignore_nulls, - })) - } -} - -/// Value evaluator for nth_value functions -#[derive(Debug)] -pub(crate) struct NthValueEvaluator { - state: NthValueState, - ignore_nulls: bool, -} - -impl PartitionEvaluator for NthValueEvaluator { - /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING), - /// for some functions such as FIRST_VALUE, LAST_VALUE and 
NTH_VALUE, we - /// can memoize the result. Once result is calculated, it will always stay - /// same. Hence, we do not need to keep past data as we process the entire - /// dataset. - fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> { - let out = &state.out_col; - let size = out.len(); - let mut buffer_size = 1; - // Decide if we arrived at a final result yet: - let (is_prunable, is_reverse_direction) = match self.state.kind { - NthValueKind::First => { - let n_range = - state.window_frame_range.end - state.window_frame_range.start; - (n_range > 0 && size > 0, false) - } - NthValueKind::Last => (true, true), - NthValueKind::Nth(n) => { - let n_range = - state.window_frame_range.end - state.window_frame_range.start; - match n.cmp(&0) { - Ordering::Greater => { - (n_range >= (n as usize) && size > (n as usize), false) - } - Ordering::Less => { - let reverse_index = (-n) as usize; - buffer_size = reverse_index; - // Negative index represents reverse direction. - (n_range >= reverse_index, true) - } - Ordering::Equal => (true, false), - } - } - }; - // Do not memoize results when nulls are ignored. - if is_prunable && !self.ignore_nulls { - if self.state.finalized_result.is_none() && !is_reverse_direction { - let result = ScalarValue::try_from_array(out, size - 1)?; - self.state.finalized_result = Some(result); - } - state.window_frame_range.start = - state.window_frame_range.end.saturating_sub(buffer_size); - } - Ok(()) - } - - fn evaluate( - &mut self, - values: &[ArrayRef], - range: &Range, - ) -> Result { - if let Some(ref result) = self.state.finalized_result { - Ok(result.clone()) - } else { - // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1. - let arr = &values[0]; - let n_range = range.end - range.start; - if n_range == 0 { - // We produce None if the window is empty. - return ScalarValue::try_from(arr.data_type()); - } - - // Extract valid indices if ignoring nulls. - let valid_indices = if self.ignore_nulls { - // Calculate valid indices, inside the window frame boundaries - let slice = arr.slice(range.start, n_range); - let valid_indices = slice - .nulls() - .map(|nulls| { - nulls - .valid_indices() - // Add offset `range.start` to valid indices, to point correct index in the original arr. - .map(|idx| idx + range.start) - .collect::>() - }) - .unwrap_or_default(); - if valid_indices.is_empty() { - return ScalarValue::try_from(arr.data_type()); - } - Some(valid_indices) - } else { - None - }; - match self.state.kind { - NthValueKind::First => { - if let Some(valid_indices) = &valid_indices { - ScalarValue::try_from_array(arr, valid_indices[0]) - } else { - ScalarValue::try_from_array(arr, range.start) - } - } - NthValueKind::Last => { - if let Some(valid_indices) = &valid_indices { - ScalarValue::try_from_array( - arr, - valid_indices[valid_indices.len() - 1], - ) - } else { - ScalarValue::try_from_array(arr, range.end - 1) - } - } - NthValueKind::Nth(n) => { - match n.cmp(&0) { - Ordering::Greater => { - // SQL indices are not 0-based. 
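
The evaluator deleted above resolves `NTH_VALUE(expr, n)` with 1-based indexing from the front of the window for positive `n`, from the back for negative `n`, and NULL when `n` is zero or falls outside the window. A standalone sketch of that index arithmetic over a plain slice, with `Option` standing in for SQL NULL and the `ignore_nulls` path omitted:

```rust
use std::cmp::Ordering;

/// NTH_VALUE-style indexing: `n > 0` counts 1-based from the start,
/// `n < 0` counts from the end, and out-of-range or `n == 0` yields None.
fn nth_value<T: Copy>(window: &[T], n: i64) -> Option<T> {
    match n.cmp(&0) {
        Ordering::Greater => window.get(n as usize - 1).copied(),
        Ordering::Less => {
            let reverse_index = (-n) as usize;
            if reverse_index > window.len() {
                None
            } else {
                window.get(window.len() - reverse_index).copied()
            }
        }
        Ordering::Equal => None,
    }
}

fn main() {
    let window = [10, 20, 30, 40];
    assert_eq!(nth_value(&window, 1), Some(10)); // FIRST_VALUE
    assert_eq!(nth_value(&window, -1), Some(40)); // LAST_VALUE
    assert_eq!(nth_value(&window, 2), Some(20));
    assert_eq!(nth_value(&window, -2), Some(30));
    assert_eq!(nth_value(&window, 5), None); // outside the window -> NULL
    assert_eq!(nth_value(&window, 0), None);
}
```
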
- let index = (n as usize) - 1; - if index >= n_range { - // Outside the range, return NULL: - ScalarValue::try_from(arr.data_type()) - } else if let Some(valid_indices) = valid_indices { - if index >= valid_indices.len() { - return ScalarValue::try_from(arr.data_type()); - } - ScalarValue::try_from_array(&arr, valid_indices[index]) - } else { - ScalarValue::try_from_array(arr, range.start + index) - } - } - Ordering::Less => { - let reverse_index = (-n) as usize; - if n_range < reverse_index { - // Outside the range, return NULL: - ScalarValue::try_from(arr.data_type()) - } else if let Some(valid_indices) = valid_indices { - if reverse_index > valid_indices.len() { - return ScalarValue::try_from(arr.data_type()); - } - let new_index = - valid_indices[valid_indices.len() - reverse_index]; - ScalarValue::try_from_array(&arr, new_index) - } else { - ScalarValue::try_from_array( - arr, - range.start + n_range - reverse_index, - ) - } - } - Ordering::Equal => ScalarValue::try_from(arr.data_type()), - } - } - } - } - } - - fn supports_bounded_execution(&self) -> bool { - true - } - - fn uses_window_frame(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::Column; - use arrow::{array::*, datatypes::*}; - use datafusion_common::cast::as_int32_array; - - fn test_i32_result(expr: NthValue, expected: Int32Array) -> Result<()> { - let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])); - let values = vec![arr]; - let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); - let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; - let mut ranges: Vec> = vec![]; - for i in 0..8 { - ranges.push(Range { - start: 0, - end: i + 1, - }) - } - let mut evaluator = expr.create_evaluator()?; - let values = expr.evaluate_args(&batch)?; - let result = ranges - .iter() - .map(|range| evaluator.evaluate(&values, range)) - .collect::>>()?; - let result = ScalarValue::iter_to_array(result.into_iter())?; - let result = as_int32_array(&result)?; - assert_eq!(expected, *result); - Ok(()) - } - - #[test] - fn first_value() -> Result<()> { - let first_value = NthValue::first( - "first_value".to_owned(), - Arc::new(Column::new("arr", 0)), - DataType::Int32, - false, - ); - test_i32_result(first_value, Int32Array::from(vec![1; 8]))?; - Ok(()) - } - - #[test] - fn last_value() -> Result<()> { - let last_value = NthValue::last( - "last_value".to_owned(), - Arc::new(Column::new("arr", 0)), - DataType::Int32, - false, - ); - test_i32_result( - last_value, - Int32Array::from(vec![ - Some(1), - Some(-2), - Some(3), - Some(-4), - Some(5), - Some(-6), - Some(7), - Some(8), - ]), - )?; - Ok(()) - } - - #[test] - fn nth_value_1() -> Result<()> { - let nth_value = NthValue::nth( - "nth_value".to_owned(), - Arc::new(Column::new("arr", 0)), - DataType::Int32, - 1, - false, - )?; - test_i32_result(nth_value, Int32Array::from(vec![1; 8]))?; - Ok(()) - } - - #[test] - fn nth_value_2() -> Result<()> { - let nth_value = NthValue::nth( - "nth_value".to_owned(), - Arc::new(Column::new("arr", 0)), - DataType::Int32, - 2, - false, - )?; - test_i32_result( - nth_value, - Int32Array::from(vec![ - None, - Some(-2), - Some(-2), - Some(-2), - Some(-2), - Some(-2), - Some(-2), - Some(-2), - ]), - )?; - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs deleted file mode 100644 index fb7a7ad84fb70..0000000000000 --- a/datafusion/physical-expr/src/window/ntile.rs +++ /dev/null @@ 
-1,111 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expression for `ntile` that can evaluated -//! at runtime during query execution - -use crate::expressions::Column; -use crate::window::BuiltInWindowFunctionExpr; -use crate::{PhysicalExpr, PhysicalSortExpr}; - -use arrow::array::{ArrayRef, UInt64Array}; -use arrow::datatypes::Field; -use arrow_schema::{DataType, SchemaRef, SortOptions}; -use datafusion_common::Result; -use datafusion_expr::PartitionEvaluator; - -use std::any::Any; -use std::sync::Arc; - -#[derive(Debug)] -pub struct Ntile { - name: String, - n: u64, - /// Output data type - data_type: DataType, -} - -impl Ntile { - pub fn new(name: String, n: u64, data_type: &DataType) -> Self { - Self { - name, - n, - data_type: data_type.clone(), - } - } - - pub fn get_n(&self) -> u64 { - self.n - } -} - -impl BuiltInWindowFunctionExpr for Ntile { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = false; - Ok(Field::new(self.name(), self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_evaluator(&self) -> Result> { - Ok(Box::new(NtileEvaluator { n: self.n })) - } - - fn get_result_ordering(&self, schema: &SchemaRef) -> Option { - // The built-in NTILE window function introduces a new ordering: - schema.column_with_name(self.name()).map(|(idx, field)| { - let expr = Arc::new(Column::new(field.name(), idx)); - let options = SortOptions { - descending: false, - nulls_first: false, - }; // ASC, NULLS LAST - PhysicalSortExpr { expr, options } - }) - } -} - -#[derive(Debug)] -pub(crate) struct NtileEvaluator { - n: u64, -} - -impl PartitionEvaluator for NtileEvaluator { - fn evaluate_all( - &mut self, - _values: &[ArrayRef], - num_rows: usize, - ) -> Result { - let num_rows = num_rows as u64; - let mut vec: Vec = Vec::new(); - let n = u64::min(self.n, num_rows); - for i in 0..num_rows { - let res = i * n / num_rows; - vec.push(res + 1) - } - Ok(Arc::new(UInt64Array::from(vec))) - } -} diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 143d59eb44953..572eb8866a44e 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -25,15 +25,15 @@ use arrow::array::{Array, ArrayRef}; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{Accumulator, WindowFrame}; - use crate::aggregate::AggregateFunctionExpr; use crate::window::window_expr::AggregateWindowExpr; use crate::window::{ PartitionBatches, PartitionWindowAggStates, 
PlainAggregateWindowExpr, WindowExpr, }; use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{Accumulator, WindowFrame}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; /// A window expr that takes the form of an aggregate function that /// can be incrementally computed over sliding windows. @@ -41,24 +41,24 @@ use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; /// See comments on [`WindowExpr`] for more details. #[derive(Debug)] pub struct SlidingAggregateWindowExpr { - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: Vec>, - order_by: Vec, + order_by: LexOrdering, window_frame: Arc, } impl SlidingAggregateWindowExpr { /// Create a new (sliding) aggregate window function expression. pub fn new( - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: &[Arc], - order_by: &[PhysicalSortExpr], + order_by: &LexOrdering, window_frame: Arc, ) -> Self { Self { aggregate, partition_by: partition_by.to_vec(), - order_by: order_by.to_vec(), + order_by: order_by.clone(), window_frame, } } @@ -108,8 +108,8 @@ impl WindowExpr for SlidingAggregateWindowExpr { &self.partition_by } - fn order_by(&self) -> &[PhysicalSortExpr] { - &self.order_by + fn order_by(&self) -> &LexOrdering { + self.order_by.as_ref() } fn get_window_frame(&self) -> &Arc { @@ -121,16 +121,16 @@ impl WindowExpr for SlidingAggregateWindowExpr { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { Arc::new(PlainAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), + reverse_order_bys(self.order_by.as_ref()).as_ref(), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), + reverse_order_bys(self.order_by.as_ref()).as_ref(), Arc::new(self.window_frame.reverse()), )) as _ } @@ -157,9 +157,12 @@ impl WindowExpr for SlidingAggregateWindowExpr { expr: new_expr, options: req.options, }) - .collect::>(); + .collect::(); Some(Arc::new(SlidingAggregateWindowExpr { - aggregate: self.aggregate.with_new_expressions(args, vec![])?, + aggregate: self + .aggregate + .with_new_expressions(args, vec![]) + .map(Arc::new)?, partition_by: partition_bys, order_by: new_order_by, window_frame: Arc::clone(&self.window_frame), diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/standard.rs similarity index 90% rename from datafusion/physical-expr/src/window/built_in.rs rename to datafusion/physical-expr/src/window/standard.rs index 8ff277db37dfd..82e48a5f93382 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/standard.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. -//! Physical exec for built-in window function expressions. +//! Physical exec for standard window function expressions. 
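The NtileEvaluator removed above assigns row i of num_rows to bucket i * n / num_rows + 1, producing n roughly equal buckets numbered from 1. A minimal standalone sketch of that arithmetic (not the DataFusion evaluator; the helper name is illustrative):

// Sketch of the NTILE bucket assignment used by the removed NtileEvaluator.
fn ntile(num_rows: u64, n: u64) -> Vec<u64> {
    let n = n.min(num_rows); // never create more buckets than rows
    (0..num_rows).map(|i| i * n / num_rows + 1).collect()
}

fn main() {
    // 8 rows into 3 buckets: sizes 3, 3, 2
    assert_eq!(ntile(8, 3), vec![1, 1, 1, 2, 2, 2, 3, 3]);
    // more buckets requested than rows: every row gets its own bucket
    assert_eq!(ntile(2, 5), vec![1, 2]);
    println!("ntile sketch ok");
}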
use std::any::Any; use std::ops::Range; use std::sync::Arc; -use super::{BuiltInWindowFunctionExpr, WindowExpr}; -use crate::expressions::PhysicalSortExpr; +use super::{StandardWindowFunctionExpr, WindowExpr}; use crate::window::window_expr::{get_orderby_values, WindowFn}; use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState}; use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; @@ -34,34 +33,35 @@ use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::window_state::{WindowAggState, WindowFrameContext}; use datafusion_expr::WindowFrame; +use datafusion_physical_expr_common::sort_expr::LexOrdering; -/// A window expr that takes the form of a [`BuiltInWindowFunctionExpr`]. +/// A window expr that takes the form of a [`StandardWindowFunctionExpr`]. #[derive(Debug)] -pub struct BuiltInWindowExpr { - expr: Arc, +pub struct StandardWindowExpr { + expr: Arc, partition_by: Vec>, - order_by: Vec, + order_by: LexOrdering, window_frame: Arc, } -impl BuiltInWindowExpr { - /// create a new built-in window function expression +impl StandardWindowExpr { + /// create a new standard window function expression pub fn new( - expr: Arc, + expr: Arc, partition_by: &[Arc], - order_by: &[PhysicalSortExpr], + order_by: &LexOrdering, window_frame: Arc, ) -> Self { Self { expr, partition_by: partition_by.to_vec(), - order_by: order_by.to_vec(), + order_by: order_by.clone(), window_frame, } } - /// Get BuiltInWindowFunction expr of BuiltInWindowExpr - pub fn get_built_in_func_expr(&self) -> &Arc { + /// Get StandardWindowFunction expr of StandardWindowExpr + pub fn get_standard_func_expr(&self) -> &Arc { &self.expr } @@ -76,9 +76,10 @@ impl BuiltInWindowExpr { if let Some(fn_res_ordering) = self.expr.get_result_ordering(schema) { if self.partition_by.is_empty() { // In the absence of a PARTITION BY, ordering of `self.expr` is global: - eq_properties.add_new_orderings([vec![fn_res_ordering]]); + eq_properties + .add_new_orderings([LexOrdering::new(vec![fn_res_ordering])]); } else { - // If we have a PARTITION BY, built-in functions can not introduce + // If we have a PARTITION BY, standard functions can not introduce // a global ordering unless the existing ordering is compatible // with PARTITION BY expressions. To elaborate, when PARTITION BY // expressions and existing ordering expressions are equal (w.r.t. 
@@ -95,7 +96,7 @@ impl BuiltInWindowExpr { } } -impl WindowExpr for BuiltInWindowExpr { +impl WindowExpr for StandardWindowExpr { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -117,8 +118,8 @@ impl WindowExpr for BuiltInWindowExpr { &self.partition_by } - fn order_by(&self) -> &[PhysicalSortExpr] { - &self.order_by + fn order_by(&self) -> &LexOrdering { + self.order_by.as_ref() } fn evaluate(&self, batch: &RecordBatch) -> Result { @@ -263,10 +264,10 @@ impl WindowExpr for BuiltInWindowExpr { fn get_reverse_expr(&self) -> Option> { self.expr.reverse_expr().map(|reverse_expr| { - Arc::new(BuiltInWindowExpr::new( + Arc::new(StandardWindowExpr::new( reverse_expr, &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), + reverse_order_bys(self.order_by.as_ref()).as_ref(), Arc::new(self.window_frame.reverse()), )) as _ }) diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/standard_window_function_expr.rs similarity index 92% rename from datafusion/physical-expr/src/window/built_in_window_function_expr.rs rename to datafusion/physical-expr/src/window/standard_window_function_expr.rs index 7aa4f6536a6e4..d308812a0e351 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/standard_window_function_expr.rs @@ -36,7 +36,7 @@ use std::sync::Arc; /// but others such as `first_value`, `last_value`, and /// `nth_value` need the value. #[allow(rustdoc::private_intra_doc_links)] -pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { +pub trait StandardWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Returns the aggregate expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -50,7 +50,7 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default /// implementation returns placeholder text. fn name(&self) -> &str { - "BuiltInWindowFunctionExpr: default name" + "StandardWindowFunctionExpr: default name" } /// Evaluate window function's arguments against the input window @@ -71,7 +71,7 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// a particular partition. fn create_evaluator(&self) -> Result>; - /// Construct a new [`BuiltInWindowFunctionExpr`] that produces + /// Construct a new [`StandardWindowFunctionExpr`] that produces /// the same result as this function on a window with reverse /// order. The return value of this function is used by the /// DataFusion optimizer to avoid re-sorting the data when @@ -80,7 +80,7 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Returns `None` (the default) if no reverse is known (or possible). /// /// For example, the reverse of `lead(10)` is `lag(10)`. 
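A minimal sketch of the kind of mapping a reverse_expr implementation returns, mirroring the NthValueKind reversal earlier in this patch (First <-> Last, Nth(n) -> Nth(-n)); the enum and function names here are illustrative only, not part of the trait:

// Sketch of a window-function reversal mapping: evaluating the reversed kind
// over a reversed frame gives the same result as the original.
#[derive(Debug, PartialEq, Clone, Copy)]
enum FrameKind {
    First,
    Last,
    Nth(i64),
}

fn reversed(kind: FrameKind) -> FrameKind {
    match kind {
        FrameKind::First => FrameKind::Last,
        FrameKind::Last => FrameKind::First,
        // An index from the front becomes the same index from the back.
        FrameKind::Nth(n) => FrameKind::Nth(-n),
    }
}

fn main() {
    assert_eq!(reversed(FrameKind::First), FrameKind::Last);
    assert_eq!(reversed(FrameKind::Nth(2)), FrameKind::Nth(-2));
    println!("reverse sketch ok");
}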
- fn reverse_expr(&self) -> Option> { + fn reverse_expr(&self) -> Option> { None } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 8f6f78df8cb85..8b130506cdea7 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -20,7 +20,7 @@ use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; -use crate::{LexOrderingRef, PhysicalExpr, PhysicalSortExpr}; +use crate::{LexOrdering, PhysicalExpr}; use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; @@ -109,7 +109,7 @@ pub trait WindowExpr: Send + Sync + Debug { fn partition_by(&self) -> &[Arc]; /// Expressions that's from the window function's order by clause, empty if absent - fn order_by(&self) -> &[PhysicalSortExpr]; + fn order_by(&self) -> &LexOrdering; /// Get order by columns, empty if absent fn order_by_columns(&self, batch: &RecordBatch) -> Result> { @@ -344,7 +344,7 @@ pub(crate) fn is_end_bound_safe( window_frame_ctx: &WindowFrameContext, order_bys: &[ArrayRef], most_recent_order_bys: Option<&[ArrayRef]>, - sort_exprs: LexOrderingRef, + sort_exprs: &LexOrdering, idx: usize, ) -> Result { if sort_exprs.is_empty() { @@ -530,41 +530,6 @@ pub enum WindowFn { Aggregate(Box), } -/// State for the RANK(percent_rank, rank, dense_rank) built-in window function. -#[derive(Debug, Clone, Default)] -pub struct RankState { - /// The last values for rank as these values change, we increase n_rank - pub last_rank_data: Option>, - /// The index where last_rank_boundary is started - pub last_rank_boundary: usize, - /// Keep the number of entries in current rank - pub current_group_count: usize, - /// Rank number kept from the start - pub n_rank: usize, -} - -/// Tag to differentiate special use cases of the NTH_VALUE built-in window function. -#[derive(Debug, Copy, Clone)] -pub enum NthValueKind { - First, - Last, - Nth(i64), -} - -#[derive(Debug, Clone)] -pub struct NthValueState { - // In certain cases, we can finalize the result early. Consider this usage: - // ``` - // FIRST_VALUE(increasing_col) OVER window AS my_first_value - // WINDOW (ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) AS window - // ``` - // The result will always be the first entry in the table. We can store such - // early-finalizing results and then just reuse them as necessary. This opens - // opportunities to prune our datasets. 
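The early-finalization idea described in the comment above can be sketched with a tiny stand-in evaluator: once FIRST_VALUE over a frame anchored at UNBOUNDED PRECEDING has been computed, it never changes, so the result can be cached and reused while buffered rows are pruned. Names below are illustrative, not DataFusion API:

// Sketch of the "finalized result" caching pattern for FIRST_VALUE with a
// fixed frame start: the first observed value is cached and reused.
struct FirstValueState {
    finalized: Option<i64>,
}

impl FirstValueState {
    fn new() -> Self {
        Self { finalized: None }
    }

    // Returns the running FIRST_VALUE for each incoming value; after the first
    // call the cached result is simply reused.
    fn update(&mut self, value: i64) -> i64 {
        *self.finalized.get_or_insert(value)
    }
}

fn main() {
    let mut state = FirstValueState::new();
    let out: Vec<i64> = [7, 3, 9, 1].into_iter().map(|v| state.update(v)).collect();
    assert_eq!(out, vec![7, 7, 7, 7]);
    println!("memoization sketch ok");
}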
- pub finalized_result: Option, - pub kind: NthValueKind, -} - /// Key for IndexMap for each unique partition /// /// For instance, if window frame is `OVER(PARTITION BY a,b)`, diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index acf3eee105d4d..40074e8eecd88 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -31,10 +31,27 @@ rust-version = { workspace = true } [lints] workspace = true +[features] +recursive_protection = ["dep:recursive"] + [dependencies] +arrow = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-expr-common = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } +datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } +futures = { workspace = true } itertools = { workspace = true } +log = { workspace = true } +recursive = { workspace = true, optional = true } + +[dev-dependencies] +datafusion-expr = { workspace = true } +datafusion-functions-aggregate = { workspace = true } +datafusion-functions-nested = { workspace = true } +rstest = { workspace = true } +tokio = { workspace = true } diff --git a/datafusion/physical-optimizer/LICENSE.txt b/datafusion/physical-optimizer/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/physical-optimizer/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/physical-optimizer/NOTICE.txt b/datafusion/physical-optimizer/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/physical-optimizer/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index a11b498b955ca..a00bc4b1d5714 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -16,19 +16,18 @@ // under the License. //! 
Utilizing exact statistics from sources to avoid scanning data -use std::sync::Arc; - use datafusion_common::config::ConfigOptions; use datafusion_common::scalar::ScalarValue; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; use datafusion_physical_plan::{expressions, ExecutionPlan}; +use std::sync::Arc; use crate::PhysicalOptimizerRule; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; /// Optimizer that uses available statistics for aggregate functions #[derive(Default, Debug)] @@ -42,6 +41,7 @@ impl AggregateStatistics { } impl PhysicalOptimizerRule for AggregateStatistics { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn optimize( &self, plan: Arc, @@ -146,3 +146,373 @@ fn take_optimizable_value_from_statistics( let value = agg_expr.fun().value_from_stats(statistics_args); value.map(|val| (val, agg_expr.name().to_string())) } + +#[cfg(test)] +mod tests { + use crate::aggregate_statistics::AggregateStatistics; + use crate::PhysicalOptimizerRule; + use datafusion_common::config::ConfigOptions; + use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; + use datafusion_execution::TaskContext; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_plan::aggregates::AggregateExec; + use datafusion_physical_plan::projection::ProjectionExec; + use datafusion_physical_plan::udaf::AggregateFunctionExpr; + use datafusion_physical_plan::ExecutionPlan; + use std::sync::Arc; + + use datafusion_common::Result; + use datafusion_expr_common::operator::Operator; + + use datafusion_physical_plan::aggregates::PhysicalGroupBy; + use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; + use datafusion_physical_plan::common; + use datafusion_physical_plan::filter::FilterExec; + use datafusion_physical_plan::memory::MemoryExec; + + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::cast::as_int64_array; + use datafusion_physical_expr::expressions::{self, cast}; + use datafusion_physical_plan::aggregates::AggregateMode; + + /// Describe the type of aggregate being tested + pub enum TestAggregate { + /// Testing COUNT(*) type aggregates + CountStar, + + /// Testing for COUNT(column) aggregate + ColumnA(Arc), + } + + impl TestAggregate { + /// Create a new COUNT(*) aggregate + pub fn new_count_star() -> Self { + Self::CountStar + } + + /// Create a new COUNT(column) aggregate + pub fn new_count_column(schema: &Arc) -> Self { + Self::ColumnA(Arc::clone(schema)) + } + + /// Return appropriate expr depending if COUNT is for col or table (*) + pub fn count_expr(&self, schema: &Schema) -> AggregateFunctionExpr { + AggregateExprBuilder::new(count_udaf(), vec![self.column()]) + .schema(Arc::new(schema.clone())) + .alias(self.column_name()) + .build() + .unwrap() + } + + /// what argument would this aggregate need in the plan? 
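A minimal sketch of the optimization this rule performs, using a stand-in Precision type rather than DataFusion's own statistics structs: COUNT(*) may be replaced by a constant only when the input reports an exact row count; inexact or missing statistics (for example, after a filter) leave the plan untouched:

// Sketch of statistics-based COUNT(*) evaluation. The types are illustrative
// stand-ins for DataFusion's Precision/Statistics.
#[derive(Debug, Clone, Copy)]
enum Precision {
    Exact(usize),
    Inexact(usize),
    Absent,
}

fn count_star_from_stats(num_rows: Precision) -> Option<i64> {
    match num_rows {
        // Only an exact statistic may replace the aggregate with a constant.
        Precision::Exact(n) => Some(n as i64),
        // Inexact (e.g. after a filter) or missing statistics: keep the plan.
        Precision::Inexact(_) | Precision::Absent => None,
    }
}

fn main() {
    assert_eq!(count_star_from_stats(Precision::Exact(3)), Some(3));
    assert_eq!(count_star_from_stats(Precision::Inexact(3)), None);
    assert_eq!(count_star_from_stats(Precision::Absent), None);
    println!("statistics sketch ok");
}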
+ fn column(&self) -> Arc { + match self { + Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION), + Self::ColumnA(s) => expressions::col("a", s).unwrap(), + } + } + + /// What name would this aggregate produce in a plan? + pub fn column_name(&self) -> &'static str { + match self { + Self::CountStar => "COUNT(*)", + Self::ColumnA(_) => "COUNT(a)", + } + } + + /// What is the expected count? + pub fn expected_count(&self) -> i64 { + match self { + TestAggregate::CountStar => 3, + TestAggregate::ColumnA(_) => 2, + } + } + } + + /// Mock data using a MemoryExec which has an exact count statistic + fn mock_data() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2), None])), + Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])), + ], + )?; + + Ok(Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?)) + } + + /// Checks that the count optimization was applied and we still get the right result + async fn assert_count_optim_success( + plan: AggregateExec, + agg: TestAggregate, + ) -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let plan: Arc = Arc::new(plan); + + let config = ConfigOptions::new(); + let optimized = + AggregateStatistics::new().optimize(Arc::clone(&plan), &config)?; + + // A ProjectionExec is a sign that the count optimization was applied + assert!(optimized.as_any().is::()); + + // run both the optimized and nonoptimized plan + let optimized_result = + common::collect(optimized.execute(0, Arc::clone(&task_ctx))?).await?; + let nonoptimized_result = common::collect(plan.execute(0, task_ctx)?).await?; + assert_eq!(optimized_result.len(), nonoptimized_result.len()); + + // and validate the results are the same and expected + assert_eq!(optimized_result.len(), 1); + check_batch(optimized_result.into_iter().next().unwrap(), &agg); + // check the non optimized one too to ensure types and names remain the same + assert_eq!(nonoptimized_result.len(), 1); + check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg); + + Ok(()) + } + + fn check_batch(batch: RecordBatch, agg: &TestAggregate) { + let schema = batch.schema(); + let fields = schema.fields(); + assert_eq!(fields.len(), 1); + + let field = &fields[0]; + assert_eq!(field.name(), agg.column_name()); + assert_eq!(field.data_type(), &DataType::Int64); + // note that nullability differs + + assert_eq!( + as_int64_array(batch.column(0)).unwrap().values(), + &[agg.expected_count()] + ); + } + + #[tokio::test] + async fn test_count_partial_direct_child() -> Result<()> { + // basic test case with the aggregation applied on a source with exact statistics + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + source, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_count_partial_with_nulls_direct_child() -> Result<()> { + // basic test case with the aggregation 
applied on a source with exact statistics + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + source, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_count_partial_indirect_child() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + source, + Arc::clone(&schema), + )?; + + // We introduce an intermediate optimization step between the partial and final aggregator + let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + Arc::new(coalesce), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + source, + Arc::clone(&schema), + )?; + + // We introduce an intermediate optimization step between the partial and final aggregator + let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + Arc::new(coalesce), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_count_inexact_stat() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + // adding a filter makes the statistics inexact + let filter = Arc::new(FilterExec::try_new( + expressions::binary( + expressions::col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?, + source, + )?); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + filter, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + let conf = ConfigOptions::new(); + let optimized = + AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; + + // check that the original ExecutionPlan was not replaced + assert!(optimized.as_any().is::()); + + Ok(()) + } + + #[tokio::test] + async fn test_count_with_nulls_inexact_stat() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = 
TestAggregate::new_count_column(&schema); + + // adding a filter makes the statistics inexact + let filter = Arc::new(FilterExec::try_new( + expressions::binary( + expressions::col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?, + source, + )?); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + filter, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + let conf = ConfigOptions::new(); + let optimized = + AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; + + // check that the original ExecutionPlan was not replaced + assert!(optimized.as_any().is::()); + + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/physical-optimizer/src/coalesce_batches.rs similarity index 88% rename from datafusion/core/src/physical_optimizer/coalesce_batches.rs rename to datafusion/physical-optimizer/src/coalesce_batches.rs index 2f834813ede91..5cf2c877c61a4 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/physical-optimizer/src/coalesce_batches.rs @@ -18,19 +18,19 @@ //! CoalesceBatches optimizer that groups batches together rows //! in bigger batches to avoid overhead with small batches +use crate::PhysicalOptimizerRule; + use std::sync::Arc; -use crate::{ - config::ConfigOptions, - error::Result, - physical_plan::{ - coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec, - repartition::RepartitionExec, Partitioning, - }, +use datafusion_common::config::ConfigOptions; +use datafusion_common::error::Result; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_plan::{ + coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec, + repartition::RepartitionExec, ExecutionPlan, }; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_optimizer::PhysicalOptimizerRule; /// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that /// are produced by highly selective filters @@ -46,9 +46,9 @@ impl CoalesceBatches { impl PhysicalOptimizerRule for CoalesceBatches { fn optimize( &self, - plan: Arc, + plan: Arc, config: &ConfigOptions, - ) -> Result> { + ) -> Result> { if !config.execution.coalesce_batches { return Ok(plan); } diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index 4e352e25b52c9..86f7e73e9e359 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -125,7 +125,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { type GroupExprsRef<'a> = ( &'a PhysicalGroupBy, - &'a [AggregateFunctionExpr], + &'a [Arc], &'a [Option>], ); diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs similarity index 79% rename from datafusion/core/src/physical_optimizer/join_selection.rs rename to datafusion/physical-optimizer/src/join_selection.rs index 499fb9cbbcf03..5f7f1f396a168 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ 
b/datafusion/physical-optimizer/src/join_selection.rs @@ -25,23 +25,22 @@ use std::sync::Arc; -use crate::config::ConfigOptions; -use crate::error::Result; -use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; -use crate::physical_plan::joins::{ - CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, - StreamJoinPartitionMode, SymmetricHashJoinExec, -}; -use crate::physical_plan::projection::ProjectionExec; -use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use crate::PhysicalOptimizerRule; -use arrow_schema::Schema; +use datafusion_common::config::ConfigOptions; +use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, JoinSide, JoinType}; -use datafusion_expr::sort_properties::SortProperties; +use datafusion_expr_common::sort_properties::SortProperties; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_expr::LexOrdering; +use datafusion_physical_plan::execution_plan::EmissionType; +use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; +use datafusion_physical_plan::joins::{ + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, + StreamJoinPartitionMode, SymmetricHashJoinExec, +}; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; /// The [`JoinSelection`] rule tries to modify a given plan so that it can /// accommodate infinite sources and optimize joins in the plan according to @@ -59,7 +58,7 @@ impl JoinSelection { // TODO: We need some performance test for Right Semi/Right Join swap to Left Semi/Left Join in case that the right side is smaller but not much smaller. // TODO: In PrestoSQL, the optimizer flips join sides only if one side is much smaller than the other by more than SIZE_DIFFERENCE_THRESHOLD times, by default is 8 times. /// Checks statistics for join swap. -fn should_swap_join_order( +pub(crate) fn should_swap_join_order( left: &dyn ExecutionPlan, right: &dyn ExecutionPlan, ) -> Result { @@ -106,179 +105,49 @@ fn supports_collect_by_thresholds( } /// Predicate that checks whether the given join type supports input swapping. -fn supports_swap(join_type: JoinType) -> bool { - matches!( - join_type, - JoinType::Inner - | JoinType::Left - | JoinType::Right - | JoinType::Full - | JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - ) +#[deprecated(since = "45.0.0", note = "use JoinType::supports_swap instead")] +#[allow(dead_code)] +pub(crate) fn supports_swap(join_type: JoinType) -> bool { + join_type.supports_swap() } /// This function returns the new join type we get after swapping the given /// join's inputs. -fn swap_join_type(join_type: JoinType) -> JoinType { - match join_type { - JoinType::Inner => JoinType::Inner, - JoinType::Full => JoinType::Full, - JoinType::Left => JoinType::Right, - JoinType::Right => JoinType::Left, - JoinType::LeftSemi => JoinType::RightSemi, - JoinType::RightSemi => JoinType::LeftSemi, - JoinType::LeftAnti => JoinType::RightAnti, - JoinType::RightAnti => JoinType::LeftAnti, - } -} - -/// This function swaps the given join's projection. 
-fn swap_join_projection( - left_schema_len: usize, - right_schema_len: usize, - projection: Option<&Vec>, -) -> Option> { - projection.map(|p| { - p.iter() - .map(|i| { - // If the index is less than the left schema length, it is from the left schema, so we add the right schema length to it. - // Otherwise, it is from the right schema, so we subtract the left schema length from it. - if *i < left_schema_len { - *i + right_schema_len - } else { - *i - left_schema_len - } - }) - .collect() - }) +#[deprecated(since = "45.0.0", note = "use datafusion-functions-nested instead")] +#[allow(dead_code)] +pub(crate) fn swap_join_type(join_type: JoinType) -> JoinType { + join_type.swap() } /// This function swaps the inputs of the given join operator. /// This function is public so other downstream projects can use it /// to construct `HashJoinExec` with right side as the build side. +#[deprecated(since = "45.0.0", note = "use HashJoinExec::swap_inputs instead")] pub fn swap_hash_join( hash_join: &HashJoinExec, partition_mode: PartitionMode, ) -> Result> { - let left = hash_join.left(); - let right = hash_join.right(); - let new_join = HashJoinExec::try_new( - Arc::clone(right), - Arc::clone(left), - hash_join - .on() - .iter() - .map(|(l, r)| (r.clone(), l.clone())) - .collect(), - swap_join_filter(hash_join.filter()), - &swap_join_type(*hash_join.join_type()), - swap_join_projection( - left.schema().fields().len(), - right.schema().fields().len(), - hash_join.projection.as_ref(), - ), - partition_mode, - hash_join.null_equals_null(), - )?; - if matches!( - hash_join.join_type(), - JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - ) { - Ok(Arc::new(new_join)) - } else { - // TODO avoid adding ProjectionExec again and again, only adding Final Projection - let proj = ProjectionExec::try_new( - swap_reverting_projection(&left.schema(), &right.schema()), - Arc::new(new_join), - )?; - Ok(Arc::new(proj)) - } + hash_join.swap_inputs(partition_mode) } /// Swaps inputs of `NestedLoopJoinExec` and wraps it into `ProjectionExec` is required -fn swap_nl_join(join: &NestedLoopJoinExec) -> Result> { - let new_filter = swap_join_filter(join.filter()); - let new_join_type = &swap_join_type(*join.join_type()); - - let new_join = NestedLoopJoinExec::try_new( - Arc::clone(join.right()), - Arc::clone(join.left()), - new_filter, - new_join_type, - )?; - - // For Semi/Anti joins, swap result will produce same output schema, - // no need to wrap them into additional projection - let plan: Arc = if matches!( - join.join_type(), - JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - ) { - Arc::new(new_join) - } else { - let projection = - swap_reverting_projection(&join.left().schema(), &join.right().schema()); - - Arc::new(ProjectionExec::try_new(projection, Arc::new(new_join))?) - }; - - Ok(plan) -} - -/// When the order of the join is changed by the optimizer, the columns in -/// the output should not be impacted. This function creates the expressions -/// that will allow to swap back the values from the original left as the first -/// columns and those on the right next. 
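The index bookkeeping around here can be checked with a small worked sketch: after a swap, the right input's columns come first in the raw join output, so an embedded projection has its indices remapped, and a reverting projection restores the original left-then-right column order. Plain integers stand in for column expressions; the helper names are illustrative:

// Sketch of the index remapping after swapping join inputs.
fn swap_projection_index(i: usize, left_len: usize, right_len: usize) -> usize {
    if i < left_len {
        i + right_len // originally a left column: shifted past the right block
    } else {
        i - left_len // originally a right column: now at the front
    }
}

fn reverting_projection(left_len: usize, right_len: usize) -> Vec<usize> {
    // Original left columns now live at right_len.., original right columns at 0..
    (0..left_len)
        .map(|i| right_len + i)
        .chain(0..right_len)
        .collect()
}

fn main() {
    // left has 2 columns, right has 1 column
    assert_eq!(swap_projection_index(0, 2, 1), 1); // left col 0 -> position 1
    assert_eq!(swap_projection_index(2, 2, 1), 0); // right col 0 -> position 0
    // selecting [1, 2, 0] from the swapped output restores the original order
    assert_eq!(reverting_projection(2, 1), vec![1, 2, 0]);
    println!("swap projection sketch ok");
}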
-fn swap_reverting_projection( - left_schema: &Schema, - right_schema: &Schema, -) -> Vec<(Arc, String)> { - let right_cols = right_schema.fields().iter().enumerate().map(|(i, f)| { - ( - Arc::new(Column::new(f.name(), i)) as Arc, - f.name().to_owned(), - ) - }); - let right_len = right_cols.len(); - let left_cols = left_schema.fields().iter().enumerate().map(|(i, f)| { - ( - Arc::new(Column::new(f.name(), right_len + i)) as Arc, - f.name().to_owned(), - ) - }); - - left_cols.chain(right_cols).collect() -} - -/// Swaps join sides for filter column indices and produces new JoinFilter -fn swap_filter(filter: &JoinFilter) -> JoinFilter { - let column_indices = filter - .column_indices() - .iter() - .map(|idx| ColumnIndex { - index: idx.index, - side: idx.side.negate(), - }) - .collect(); - - JoinFilter::new( - filter.expression().clone(), - column_indices, - filter.schema().clone(), - ) +#[deprecated(since = "45.0.0", note = "use NestedLoopJoinExec::swap_inputs")] +#[allow(dead_code)] +pub(crate) fn swap_nl_join(join: &NestedLoopJoinExec) -> Result> { + join.swap_inputs() } /// Swaps join sides for filter column indices and produces new `JoinFilter` (if exists). +#[deprecated(since = "45.0.0", note = "use filter.map(JoinFilter::swap) instead")] +#[allow(dead_code)] fn swap_join_filter(filter: Option<&JoinFilter>) -> Option { - filter.map(swap_filter) + filter.map(JoinFilter::swap) +} + +#[deprecated(since = "45.0.0", note = "use JoinFilter::swap instead")] +#[allow(dead_code)] +pub(crate) fn swap_filter(filter: &JoinFilter) -> JoinFilter { + filter.swap() } impl PhysicalOptimizerRule for JoinSelection { @@ -339,7 +208,7 @@ impl PhysicalOptimizerRule for JoinSelection { /// `CollectLeft` mode is applicable. Otherwise, it will try to swap the join sides. /// When the `ignore_threshold` is false, this function will also check left /// and right sizes in bytes or rows. -fn try_collect_left( +pub(crate) fn try_collect_left( hash_join: &HashJoinExec, ignore_threshold: bool, threshold_byte_size: usize, @@ -363,10 +232,10 @@ fn try_collect_left( match (left_can_collect, right_can_collect) { (true, true) => { - if should_swap_join_order(&**left, &**right)? - && supports_swap(*hash_join.join_type()) + if hash_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)? { - Ok(Some(swap_hash_join(hash_join, PartitionMode::CollectLeft)?)) + Ok(Some(hash_join.swap_inputs(PartitionMode::CollectLeft)?)) } else { Ok(Some(Arc::new(HashJoinExec::try_new( Arc::clone(left), @@ -391,8 +260,8 @@ fn try_collect_left( hash_join.null_equals_null(), )?))), (false, true) => { - if supports_swap(*hash_join.join_type()) { - swap_hash_join(hash_join, PartitionMode::CollectLeft).map(Some) + if hash_join.join_type().supports_swap() { + hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some) } else { Ok(None) } @@ -401,12 +270,19 @@ fn try_collect_left( } } -fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result> { +/// Creates a partitioned hash join execution plan, swapping inputs if beneficial. +/// +/// Checks if the join order should be swapped based on the join type and input statistics. +/// If swapping is optimal and supported, creates a swapped partitioned hash join; otherwise, +/// creates a standard partitioned hash join. +pub(crate) fn partitioned_hash_join( + hash_join: &HashJoinExec, +) -> Result> { let left = hash_join.left(); let right = hash_join.right(); - if should_swap_join_order(&**left, &**right)? 
&& supports_swap(*hash_join.join_type()) + if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? { - swap_hash_join(hash_join, PartitionMode::Partitioned) + hash_join.swap_inputs(PartitionMode::Partitioned) } else { Ok(Arc::new(HashJoinExec::try_new( Arc::clone(left), @@ -449,10 +325,12 @@ fn statistical_join_selection_subrule( PartitionMode::Partitioned => { let left = hash_join.left(); let right = hash_join.right(); - if should_swap_join_order(&**left, &**right)? - && supports_swap(*hash_join.join_type()) + if hash_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)? { - swap_hash_join(hash_join, PartitionMode::Partitioned).map(Some)? + hash_join + .swap_inputs(PartitionMode::Partitioned) + .map(Some)? } else { None } @@ -462,21 +340,17 @@ fn statistical_join_selection_subrule( let left = cross_join.left(); let right = cross_join.right(); if should_swap_join_order(&**left, &**right)? { - let new_join = CrossJoinExec::new(Arc::clone(right), Arc::clone(left)); - // TODO avoid adding ProjectionExec again and again, only adding Final Projection - let proj: Arc = Arc::new(ProjectionExec::try_new( - swap_reverting_projection(&left.schema(), &right.schema()), - Arc::new(new_join), - )?); - Some(proj) + cross_join.swap_inputs().map(Some)? } else { None } } else if let Some(nl_join) = plan.as_any().downcast_ref::() { let left = nl_join.left(); let right = nl_join.right(); - if should_swap_join_order(&**left, &**right)? { - swap_nl_join(nl_join).map(Some)? + if nl_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)? + { + nl_join.swap_inputs().map(Some)? } else { None } @@ -495,7 +369,8 @@ fn statistical_join_selection_subrule( pub type PipelineFixerSubrule = dyn Fn(Arc, &ConfigOptions) -> Result>; -/// Converts a hash join to a symmetric hash join in the case of infinite inputs on both sides. +/// Converts a hash join to a symmetric hash join if both its inputs are +/// unbounded and incremental. /// /// This subrule checks if a hash join can be replaced with a symmetric hash join when dealing /// with unbounded (infinite) inputs on both sides. This replacement avoids pipeline breaking and @@ -516,10 +391,18 @@ fn hash_join_convert_symmetric_subrule( ) -> Result> { // Check if the current plan node is a HashJoinExec. if let Some(hash_join) = input.as_any().downcast_ref::() { - let left_unbounded = hash_join.left.execution_mode().is_unbounded(); - let right_unbounded = hash_join.right.execution_mode().is_unbounded(); - // Process only if both left and right sides are unbounded. - if left_unbounded && right_unbounded { + let left_unbounded = hash_join.left.boundedness().is_unbounded(); + let left_incremental = matches!( + hash_join.left.pipeline_behavior(), + EmissionType::Incremental | EmissionType::Both + ); + let right_unbounded = hash_join.right.boundedness().is_unbounded(); + let right_incremental = matches!( + hash_join.right.pipeline_behavior(), + EmissionType::Incremental | EmissionType::Both + ); + // Process only if both left and right sides are unbounded and incrementally emit. + if left_unbounded && right_unbounded & left_incremental & right_incremental { // Determine the partition mode based on configuration. let mode = if config_options.optimizer.repartition_joins { StreamJoinPartitionMode::Partitioned @@ -535,7 +418,7 @@ fn hash_join_convert_symmetric_subrule( // the function concludes that no specific order is required for the SymmetricHashJoinExec. 
This approach // ensures that the symmetric hash join operation only imposes ordering constraints when necessary, // based on the properties of the child nodes and the filter condition. - let determine_order = |side: JoinSide| -> Option> { + let determine_order = |side: JoinSide| -> Option { hash_join .filter() .map(|filter| { @@ -558,6 +441,7 @@ fn hash_join_convert_symmetric_subrule( hash_join.right().equivalence_properties(), hash_join.right().schema(), ), + JoinSide::None => return false, }; let name = schema.field(*index).name(); @@ -573,8 +457,9 @@ fn hash_join_convert_symmetric_subrule( match side { JoinSide::Left => hash_join.left().output_ordering(), JoinSide::Right => hash_join.right().output_ordering(), + JoinSide::None => unreachable!(), } - .map(|p| p.to_vec()) + .map(|p| LexOrdering::new(p.to_vec())) }) .flatten() }; @@ -584,8 +469,8 @@ fn hash_join_convert_symmetric_subrule( let right_order = determine_order(JoinSide::Right); return SymmetricHashJoinExec::try_new( - hash_join.left().clone(), - hash_join.right().clone(), + Arc::clone(hash_join.left()), + Arc::clone(hash_join.right()), hash_join.on().to_vec(), hash_join.filter().cloned(), hash_join.join_type(), @@ -646,8 +531,8 @@ fn hash_join_swap_subrule( _config_options: &ConfigOptions, ) -> Result> { if let Some(hash_join) = input.as_any().downcast_ref::() { - if hash_join.left.execution_mode().is_unbounded() - && !hash_join.right.execution_mode().is_unbounded() + if hash_join.left.boundedness().is_unbounded() + && !hash_join.right.boundedness().is_unbounded() && matches!( *hash_join.join_type(), JoinType::Inner @@ -667,7 +552,9 @@ fn hash_join_swap_subrule( /// [`JoinType::Full`], [`JoinType::Right`], [`JoinType::RightAnti`] and /// [`JoinType::RightSemi`] can not run with an unbounded left side, even if /// we swap join sides. Therefore, we do not consider them here. -fn swap_join_according_to_unboundedness( +/// This function is crate public as it is useful for downstream projects +/// to implement, or experiment with, their own join selection rules. 
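A minimal sketch of the boundedness-driven decisions these subrules make, with illustrative enums and the emission-type check omitted for brevity: both inputs unbounded converts to a symmetric hash join, an unbounded build (left) side with a bounded probe side swaps inputs for the join types listed above, and everything else keeps the plan:

// Sketch of the join-selection decision based on input boundedness.
#[derive(Debug, PartialEq, Clone, Copy)]
enum JoinKind {
    Inner,
    Left,
    LeftSemi,
    LeftAnti,
    Right,
    Full,
}

#[derive(Debug, PartialEq)]
enum Decision {
    ConvertToSymmetric,
    SwapInputs,
    Keep,
}

fn select(left_unbounded: bool, right_unbounded: bool, join: JoinKind) -> Decision {
    if left_unbounded && right_unbounded {
        Decision::ConvertToSymmetric
    } else if left_unbounded
        && !right_unbounded
        && matches!(
            join,
            JoinKind::Inner | JoinKind::Left | JoinKind::LeftSemi | JoinKind::LeftAnti
        )
    {
        // An unbounded build side would never finish; putting the bounded
        // input on the build side keeps the pipeline runnable.
        Decision::SwapInputs
    } else {
        Decision::Keep
    }
}

fn main() {
    assert_eq!(select(true, true, JoinKind::Inner), Decision::ConvertToSymmetric);
    assert_eq!(select(true, false, JoinKind::Left), Decision::SwapInputs);
    assert_eq!(select(true, false, JoinKind::LeftSemi), Decision::SwapInputs);
    assert_eq!(select(true, false, JoinKind::LeftAnti), Decision::SwapInputs);
    assert_eq!(select(true, false, JoinKind::Right), Decision::Keep);
    assert_eq!(select(false, false, JoinKind::Full), Decision::Keep);
    println!("join selection sketch ok");
}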
+pub(crate) fn swap_join_according_to_unboundedness( hash_join: &HashJoinExec, ) -> Result> { let partition_mode = hash_join.partition_mode(); @@ -678,10 +565,10 @@ fn swap_join_according_to_unboundedness( JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full, ) => internal_err!("{join_type} join cannot be swapped for unbounded input."), (PartitionMode::Partitioned, _) => { - swap_hash_join(hash_join, PartitionMode::Partitioned) + hash_join.swap_inputs(PartitionMode::Partitioned) } (PartitionMode::CollectLeft, _) => { - swap_hash_join(hash_join, PartitionMode::CollectLeft) + hash_join.swap_inputs(PartitionMode::CollectLeft) } (PartitionMode::Auto, _) => { internal_err!("Auto is not acceptable for unbounded input here.") @@ -704,19 +591,19 @@ fn apply_subrules( #[cfg(test)] mod tests_statistical { - use super::*; - use crate::{ - physical_plan::{displayable, ColumnStatistics, Statistics}, - test::StatisticsExec, - }; + use util_tests::StatisticsExec; - use arrow::datatypes::{DataType, Field}; - use datafusion_common::{stats::Precision, JoinType, ScalarValue}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{ + stats::Precision, ColumnStatistics, JoinType, ScalarValue, Statistics, + }; use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::expressions::BinaryExpr; - use datafusion_physical_expr::PhysicalExprRef; - + use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; + use datafusion_physical_plan::displayable; + use datafusion_physical_plan::projection::ProjectionExec; use rstest::rstest; /// Return statistics for empty table @@ -988,32 +875,27 @@ mod tests_statistical { for join_type in join_types { let (big, small) = create_big_and_small(); - let join = Arc::new( - HashJoinExec::try_new( - Arc::clone(&big), - Arc::clone(&small), - vec![( - Arc::new( - Column::new_with_schema("big_col", &big.schema()).unwrap(), - ), - Arc::new( - Column::new_with_schema("small_col", &small.schema()) - .unwrap(), - ), - )], - None, - &join_type, - None, - PartitionMode::Partitioned, - false, - ) - .unwrap(), - ); + let join = HashJoinExec::try_new( + Arc::clone(&big), + Arc::clone(&small), + vec![( + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), + Arc::new( + Column::new_with_schema("small_col", &small.schema()).unwrap(), + ), + )], + None, + &join_type, + None, + PartitionMode::Partitioned, + false, + ) + .unwrap(); let original_schema = join.schema(); let optimized_join = JoinSelection::new() - .optimize(join.clone(), &ConfigOptions::new()) + .optimize(Arc::new(join), &ConfigOptions::new()) .unwrap(); let swapped_join = optimized_join @@ -1067,8 +949,8 @@ mod tests_statistical { Arc::clone(&big), Arc::clone(&small), vec![( - Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), - Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()), + col("big_col", &big.schema()).unwrap(), + col("small_col", &small.schema()).unwrap(), )], None, &JoinType::Inner, @@ -1084,10 +966,8 @@ mod tests_statistical { Arc::clone(&medium), Arc::new(child_join), vec![( - Arc::new( - Column::new_with_schema("medium_col", &medium.schema()).unwrap(), - ), - Arc::new(Column::new_with_schema("small_col", &child_schema).unwrap()), + col("medium_col", &medium.schema()).unwrap(), + col("small_col", &child_schema).unwrap(), )], None, &JoinType::Left, @@ -1175,6 +1055,7 @@ mod tests_statistical { Arc::clone(&small), nl_join_filter(), &join_type, + None, ) .unwrap(), ); 
@@ -1243,12 +1124,16 @@ mod tests_statistical { Arc::clone(&small), nl_join_filter(), &join_type, + None, ) .unwrap(), ); let optimized_join = JoinSelection::new() - .optimize(join.clone(), &ConfigOptions::new()) + .optimize( + Arc::::clone(&join), + &ConfigOptions::new(), + ) .unwrap(); let swapped_join = optimized_join @@ -1287,30 +1172,64 @@ mod tests_statistical { ); } + #[rstest( + join_type, projection, small_on_right, + case::inner(JoinType::Inner, vec![1], true), + case::left(JoinType::Left, vec![1], true), + case::right(JoinType::Right, vec![1], true), + case::full(JoinType::Full, vec![1], true), + case::left_anti(JoinType::LeftAnti, vec![0], false), + case::left_semi(JoinType::LeftSemi, vec![0], false), + case::right_anti(JoinType::RightAnti, vec![0], true), + case::right_semi(JoinType::RightSemi, vec![0], true), + )] #[tokio::test] - async fn test_swap_reverting_projection() { - let left_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ]); - - let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]); + async fn test_hash_join_swap_on_joins_with_projections( + join_type: JoinType, + projection: Vec, + small_on_right: bool, + ) -> Result<()> { + let (big, small) = create_big_and_small(); - let proj = swap_reverting_projection(&left_schema, &right_schema); + let left = if small_on_right { &big } else { &small }; + let right = if small_on_right { &small } else { &big }; - assert_eq!(proj.len(), 3); + let left_on = if small_on_right { + "big_col" + } else { + "small_col" + }; + let right_on = if small_on_right { + "small_col" + } else { + "big_col" + }; - let (col, name) = &proj[0]; - assert_eq!(name, "a"); - assert_col_expr(col, "a", 1); + let join = Arc::new(HashJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + vec![( + Arc::new(Column::new_with_schema(left_on, &left.schema())?), + Arc::new(Column::new_with_schema(right_on, &right.schema())?), + )], + None, + &join_type, + Some(projection), + PartitionMode::Partitioned, + false, + )?); - let (col, name) = &proj[1]; - assert_eq!(name, "b"); - assert_col_expr(col, "b", 2); + let swapped = join + .swap_inputs(PartitionMode::Partitioned) + .expect("swap_hash_join must support joins with projections"); + let swapped_join = swapped.as_any().downcast_ref::().expect( + "ProjectionExec won't be added above if HashJoinExec contains embedded projection", + ); - let (col, name) = &proj[2]; - assert_eq!(name, "c"); - assert_col_expr(col, "c", 0); + assert_eq!(swapped_join.projection, Some(vec![0_usize])); + assert_eq!(swapped.schema().fields.len(), 1); + assert_eq!(swapped.schema().fields[0].name(), "small_col"); + Ok(()) } fn assert_col_expr(expr: &Arc, name: &str, index: usize) { @@ -1340,44 +1259,44 @@ mod tests_statistical { )); let join_on = vec![( - Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _, + col("small_col", &small.schema()).unwrap(), + col("big_col", &big.schema()).unwrap(), )]; check_join_partition_mode( - small.clone(), - big.clone(), + Arc::::clone(&small), + Arc::::clone(&big), join_on, false, PartitionMode::CollectLeft, ); let join_on = vec![( - Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _, + col("big_col", &big.schema()).unwrap(), + col("small_col", &small.schema()).unwrap(), )]; check_join_partition_mode( big, - 
small.clone(), + Arc::::clone(&small), join_on, true, PartitionMode::CollectLeft, ); let join_on = vec![( - Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _, + col("small_col", &small.schema()).unwrap(), + col("empty_col", &empty.schema()).unwrap(), )]; check_join_partition_mode( - small.clone(), - empty.clone(), + Arc::::clone(&small), + Arc::::clone(&empty), join_on, false, PartitionMode::CollectLeft, ); let join_on = vec![( - Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _, + col("empty_col", &empty.schema()).unwrap(), + col("small_col", &small.schema()).unwrap(), )]; check_join_partition_mode( empty, @@ -1411,8 +1330,8 @@ mod tests_statistical { as _, )]; check_join_partition_mode( - big.clone(), - bigger.clone(), + Arc::::clone(&big), + Arc::::clone(&bigger), join_on, false, PartitionMode::Partitioned, @@ -1425,7 +1344,7 @@ mod tests_statistical { )]; check_join_partition_mode( bigger, - big.clone(), + Arc::::clone(&big), join_on, true, PartitionMode::Partitioned, @@ -1436,8 +1355,8 @@ mod tests_statistical { Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _, )]; check_join_partition_mode( - empty.clone(), - big.clone(), + Arc::::clone(&empty), + Arc::::clone(&big), join_on, false, PartitionMode::Partitioned, @@ -1501,13 +1420,258 @@ mod tests_statistical { #[cfg(test)] mod util_tests { - use std::sync::Arc; + use std::{ + any::Any, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + }; - use arrow_schema::{DataType, Field, Schema}; + use arrow::{ + array::RecordBatch, + datatypes::{DataType, Field, Schema, SchemaRef}, + }; + use datafusion_common::{Result, Statistics}; + use datafusion_execution::{ + RecordBatchStream, SendableRecordBatchStream, TaskContext, + }; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr}; use datafusion_physical_expr::intervals::utils::check_support; - use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalExpr}; + use datafusion_physical_plan::{ + execution_plan::{Boundedness, EmissionType}, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + }; + use futures::Stream; + + #[derive(Debug)] + struct UnboundedStream { + batch_produce: Option, + count: usize, + batch: RecordBatch, + } + + impl Stream for UnboundedStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if let Some(val) = self.batch_produce { + if val <= self.count { + return Poll::Ready(None); + } + } + self.count += 1; + Poll::Ready(Some(Ok(self.batch.clone()))) + } + } + + impl RecordBatchStream for UnboundedStream { + fn schema(&self) -> SchemaRef { + self.batch.schema() + } + } + + /// A mock execution plan that simply returns the provided data source characteristic + #[derive(Debug, Clone)] + pub struct UnboundedExec { + batch_produce: Option, + batch: RecordBatch, + cache: PlanProperties, + } + + impl UnboundedExec { + /// Create new exec that clones the given record batch to its output. + /// + /// Set `batch_produce` to `Some(n)` to emit exactly `n` batches per partition. 
+ pub fn new( + batch_produce: Option, + batch: RecordBatch, + partitions: usize, + ) -> Self { + let cache = + Self::compute_properties(batch.schema(), batch_produce, partitions); + Self { + batch_produce, + batch, + cache, + } + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties( + schema: SchemaRef, + batch_produce: Option, + n_partitions: usize, + ) -> PlanProperties { + let boundedness = if batch_produce.is_none() { + Boundedness::Unbounded { + requires_infinite_memory: false, + } + } else { + Boundedness::Bounded + }; + PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(n_partitions), + EmissionType::Incremental, + boundedness, + ) + } + } + + impl DisplayAs for UnboundedExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "UnboundedExec: unbounded={}", + self.batch_produce.is_none(), + ) + } + } + } + } + + impl ExecutionPlan for UnboundedExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + Ok(Box::pin(UnboundedStream { + batch_produce: self.batch_produce, + count: 0, + batch: self.batch.clone(), + })) + } + } + + #[derive(Eq, PartialEq, Debug)] + pub enum SourceType { + Unbounded, + Bounded, + } + + /// A mock execution plan that simply returns the provided statistics + #[derive(Debug, Clone)] + pub struct StatisticsExec { + stats: Statistics, + schema: Arc, + cache: PlanProperties, + } + + impl StatisticsExec { + pub fn new(stats: Statistics, schema: Schema) -> Self { + assert_eq!( + stats.column_statistics.len(), schema.fields().len(), + "if defined, the column statistics vector length should be the number of fields" + ); + let cache = Self::compute_properties(Arc::new(schema.clone())); + Self { + stats, + schema: Arc::new(schema), + cache, + } + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
+ fn compute_properties(schema: SchemaRef) -> PlanProperties { + PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(2), + EmissionType::Incremental, + Boundedness::Bounded, + ) + } + } + + impl DisplayAs for StatisticsExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "StatisticsExec: col_count={}, row_count={:?}", + self.schema.fields().len(), + self.stats.num_rows, + ) + } + } + } + } + + impl ExecutionPlan for StatisticsExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!("This plan only serves for testing statistics") + } + + fn statistics(&self) -> Result { + Ok(self.stats.clone()) + } + } #[test] fn check_expr_supported() { @@ -1541,11 +1705,12 @@ mod util_tests { #[cfg(test)] mod hash_join_tests { use super::*; - use crate::physical_optimizer::test_utils::SourceType; - use crate::test_util::UnboundedExec; + use util_tests::{SourceType, UnboundedExec}; - use arrow::datatypes::{DataType, Field}; + use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_plan::projection::ProjectionExec; struct TestCase { case: String, @@ -1625,7 +1790,7 @@ mod hash_join_tests { initial_join_type: join_type, initial_mode: PartitionMode::CollectLeft, expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: swap_join_type(join_type), + expected_join_type: join_type.swap(), expected_mode: PartitionMode::CollectLeft, expecting_swap: true, }); @@ -1668,7 +1833,7 @@ mod hash_join_tests { initial_join_type: join_type, initial_mode: PartitionMode::Partitioned, expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: swap_join_type(join_type), + expected_join_type: join_type.swap(), expected_mode: PartitionMode::Partitioned, expecting_swap: true, }); @@ -1726,7 +1891,7 @@ mod hash_join_tests { initial_join_type: join_type, initial_mode: PartitionMode::Partitioned, expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: swap_join_type(join_type), + expected_join_type: join_type.swap(), expected_mode: PartitionMode::Partitioned, expecting_swap: true, }); @@ -1888,7 +2053,7 @@ mod hash_join_tests { false, )]))), 2, - )) as Arc; + )) as _; let right_exec = Arc::new(UnboundedExec::new( (!right_unbounded).then_some(1), RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( @@ -1897,21 +2062,21 @@ mod hash_join_tests { false, )]))), 2, - )) as Arc; + )) as _; let join = Arc::new(HashJoinExec::try_new( Arc::clone(&left_exec), Arc::clone(&right_exec), vec![( - Arc::new(Column::new_with_schema("a", &left_exec.schema())?), - Arc::new(Column::new_with_schema("b", &right_exec.schema())?), + col("a", &left_exec.schema())?, + col("b", &right_exec.schema())?, )], None, &t.initial_join_type, None, t.initial_mode, false, - )?); + )?) 
as _; let optimized_join_plan = hash_join_swap_subrule(join, &ConfigOptions::new())?; @@ -1924,7 +2089,7 @@ mod hash_join_tests { .expect( "A proj is required to swap columns back to their original order", ); - proj.input().clone() + Arc::::clone(proj.input()) } else { optimized_join_plan }; @@ -1944,12 +2109,12 @@ mod hash_join_tests { assert_eq!( ( t.case.as_str(), - if left.execution_mode().is_unbounded() { + if left.boundedness().is_unbounded() { SourceType::Unbounded } else { SourceType::Bounded }, - if right.execution_mode().is_unbounded() { + if right.boundedness().is_unbounded() { SourceType::Unbounded } else { SourceType::Bounded diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index 439f1dc873d1e..ccb18f6791711 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -14,15 +14,22 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] pub mod aggregate_statistics; +pub mod coalesce_batches; pub mod combine_partial_final_agg; +pub mod join_selection; pub mod limit_pushdown; pub mod limited_distinct_aggregation; mod optimizer; pub mod output_requirements; +pub mod pruning; +pub mod sanity_checker; +pub mod test_utils; pub mod topk_aggregation; +pub mod update_aggr_exprs; pub use optimizer::PhysicalOptimizerRule; diff --git a/datafusion/physical-optimizer/src/limit_pushdown.rs b/datafusion/physical-optimizer/src/limit_pushdown.rs index 8f392b683077f..7a44b2e90dde7 100644 --- a/datafusion/physical-optimizer/src/limit_pushdown.rs +++ b/datafusion/physical-optimizer/src/limit_pushdown.rs @@ -339,3 +339,480 @@ fn add_global_limit( } // See tests in datafusion/core/tests/physical_optimizer + +#[cfg(test)] +mod test { + use super::*; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::config::ConfigOptions; + use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::BinaryExpr; + use datafusion_physical_expr::expressions::{col, lit}; + use datafusion_physical_expr::{Partitioning, PhysicalSortExpr}; + use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; + use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; + use datafusion_physical_plan::empty::EmptyExec; + use datafusion_physical_plan::filter::FilterExec; + use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; + use datafusion_physical_plan::projection::ProjectionExec; + use datafusion_physical_plan::repartition::RepartitionExec; + use datafusion_physical_plan::sorts::sort::SortExec; + use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; + use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; + use datafusion_physical_plan::{ + get_plan_string, ExecutionPlan, ExecutionPlanProperties, + }; + use std::sync::Arc; + + #[derive(Debug)] + struct DummyStreamPartition { + schema: SchemaRef, + } + impl PartitionStream for DummyStreamPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + unreachable!() + } + } + + #[test] + fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero( + ) -> 
Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(schema)?; + let global_limit = global_limit_exec(streaming_table, 0, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=0, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_limit_when_skip_is_nonzero( + ) -> Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(schema)?; + let global_limit = global_limit_exec(streaming_table, 2, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=2, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "GlobalLimitExec: skip=2, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limit( + ) -> Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let repartition = repartition_exec(streaming_table)?; + let filter = filter_exec(schema, repartition)?; + let coalesce_batches = coalesce_batches_exec(filter); + let local_limit = local_limit_exec(coalesce_batches, 5); + let coalesce_partitions = coalesce_partitions_exec(local_limit); + let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=0, fetch=5", + " CoalescePartitionsExec", + " LocalLimitExec: fetch=5", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c3@2 > 0", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "GlobalLimitExec: skip=0, fetch=5", + " CoalescePartitionsExec", + " CoalesceBatchesExec: target_batch_size=8192, fetch=5", + " FilterExec: c3@2 > 0", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let filter = filter_exec(Arc::clone(&schema), streaming_table)?; + let projection = projection_exec(schema, filter)?; + let global_limit = global_limit_exec(projection, 0, Some(5)); + + let initial = 
get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=0, fetch=5", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " FilterExec: c3@2 > 0", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " GlobalLimitExec: skip=0, fetch=5", + " FilterExec: c3@2 > 0", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batches_exec_into_fetching_version( + ) -> Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(Arc::clone(&schema)).unwrap(); + let coalesce_batches = coalesce_batches_exec(streaming_table); + let projection = projection_exec(schema, coalesce_batches)?; + let global_limit = global_limit_exec(projection, 0, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=0, fetch=5", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " CoalesceBatchesExec: target_batch_size=8192", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " CoalesceBatchesExec: target_batch_size=8192, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(Arc::clone(&schema)).unwrap(); + let coalesce_batches = coalesce_batches_exec(streaming_table); + let projection = projection_exec(Arc::clone(&schema), coalesce_batches)?; + let repartition = repartition_exec(projection)?; + let sort = sort_exec( + vec![PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }], + repartition, + ); + let spm = + sort_preserving_merge_exec(sort.output_ordering().unwrap().to_vec(), sort); + let global_limit = global_limit_exec(spm, 0, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=0, fetch=5", + " SortPreservingMergeExec: [c1@0 ASC]", + " SortExec: expr=[c1@0 ASC], preserve_partitioning=[false]", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " CoalesceBatchesExec: target_batch_size=8192", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "SortPreservingMergeExec: [c1@0 ASC], fetch=5", + " SortExec: TopK(fetch=5), expr=[c1@0 ASC], preserve_partitioning=[false]", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as 
c3]", + " CoalesceBatchesExec: target_batch_size=8192", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn keeps_pushed_local_limit_exec_when_there_are_multiple_input_partitions( + ) -> Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let repartition = repartition_exec(streaming_table)?; + let filter = filter_exec(schema, repartition)?; + let coalesce_partitions = coalesce_partitions_exec(filter); + let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=0, fetch=5", + " CoalescePartitionsExec", + " FilterExec: c3@2 > 0", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "GlobalLimitExec: skip=0, fetch=5", + " CoalescePartitionsExec", + " FilterExec: c3@2 > 0", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn merges_local_limit_with_local_limit() -> Result<()> { + let schema = create_schema(); + let empty_exec = empty_exec(schema); + let child_local_limit = local_limit_exec(empty_exec, 10); + let parent_local_limit = local_limit_exec(child_local_limit, 20); + + let initial = get_plan_string(&parent_local_limit); + let expected_initial = [ + "LocalLimitExec: fetch=20", + " LocalLimitExec: fetch=10", + " EmptyExec", + ]; + + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(parent_local_limit, &ConfigOptions::new())?; + + let expected = ["GlobalLimitExec: skip=0, fetch=10", " EmptyExec"]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn merges_global_limit_with_global_limit() -> Result<()> { + let schema = create_schema(); + let empty_exec = empty_exec(schema); + let child_global_limit = global_limit_exec(empty_exec, 10, Some(30)); + let parent_global_limit = global_limit_exec(child_global_limit, 10, Some(20)); + + let initial = get_plan_string(&parent_global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=10, fetch=20", + " GlobalLimitExec: skip=10, fetch=30", + " EmptyExec", + ]; + + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(parent_global_limit, &ConfigOptions::new())?; + + let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn merges_global_limit_with_local_limit() -> Result<()> { + let schema = create_schema(); + let empty_exec = empty_exec(schema); + let local_limit = local_limit_exec(empty_exec, 40); + let global_limit = global_limit_exec(local_limit, 20, Some(30)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=20, fetch=30", + " LocalLimitExec: fetch=40", + " EmptyExec", + ]; + + assert_eq!(initial, expected_initial); + + let after_optimize = + 
LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn merges_local_limit_with_global_limit() -> Result<()> { + let schema = create_schema(); + let empty_exec = empty_exec(schema); + let global_limit = global_limit_exec(empty_exec, 20, Some(30)); + let local_limit = local_limit_exec(global_limit, 20); + + let initial = get_plan_string(&local_limit); + let expected_initial = [ + "LocalLimitExec: fetch=20", + " GlobalLimitExec: skip=20, fetch=30", + " EmptyExec", + ]; + + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(local_limit, &ConfigOptions::new())?; + + let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + fn create_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + Field::new("c3", DataType::Int32, true), + ])) + } + + fn streaming_table_exec(schema: SchemaRef) -> Result> { + Ok(Arc::new(StreamingTableExec::try_new( + Arc::clone(&schema), + vec![Arc::new(DummyStreamPartition { schema }) as _], + None, + None, + true, + None, + )?)) + } + + fn global_limit_exec( + input: Arc, + skip: usize, + fetch: Option, + ) -> Arc { + Arc::new(GlobalLimitExec::new(input, skip, fetch)) + } + + fn local_limit_exec( + input: Arc, + fetch: usize, + ) -> Arc { + Arc::new(LocalLimitExec::new(input, fetch)) + } + + fn sort_exec( + sort_exprs: impl IntoIterator, + input: Arc, + ) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortExec::new(sort_exprs, input)) + } + + fn sort_preserving_merge_exec( + sort_exprs: impl IntoIterator, + input: Arc, + ) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) + } + + fn projection_exec( + schema: SchemaRef, + input: Arc, + ) -> Result> { + Ok(Arc::new(ProjectionExec::try_new( + vec![ + (col("c1", schema.as_ref()).unwrap(), "c1".to_string()), + (col("c2", schema.as_ref()).unwrap(), "c2".to_string()), + (col("c3", schema.as_ref()).unwrap(), "c3".to_string()), + ], + input, + )?)) + } + + fn filter_exec( + schema: SchemaRef, + input: Arc, + ) -> Result> { + Ok(Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new( + col("c3", schema.as_ref()).unwrap(), + Operator::Gt, + lit(0), + )), + input, + )?)) + } + + fn coalesce_batches_exec(input: Arc) -> Arc { + Arc::new(CoalesceBatchesExec::new(input, 8192)) + } + + fn coalesce_partitions_exec( + local_limit: Arc, + ) -> Arc { + Arc::new(CoalescePartitionsExec::new(local_limit)) + } + + fn repartition_exec( + streaming_table: Arc, + ) -> Result> { + Ok(Arc::new(RepartitionExec::try_new( + streaming_table, + Partitioning::RoundRobinBatch(8), + )?)) + } + + fn empty_exec(schema: SchemaRef) -> Arc { + Arc::new(EmptyExec::new(schema)) + } +} diff --git a/datafusion/physical-optimizer/src/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs index 4f6f91a2348f4..e107bb85d7b8a 100644 --- a/datafusion/physical-optimizer/src/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -33,7 +33,7 @@ use datafusion_physical_plan::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use 
datafusion_common::{Result, Statistics}; -use datafusion_physical_expr::{Distribution, LexRequirement, PhysicalSortRequirement}; +use datafusion_physical_expr::{Distribution, LexRequirement}; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::{ExecutionPlanProperties, PlanProperties}; @@ -44,7 +44,7 @@ use crate::PhysicalOptimizerRule; /// `new_add_mode` and `new_remove_mode`. With this rule, we can keep track of /// the global requirements (ordering and distribution) across rules. /// -/// The primary usecase of this node and rule is to specify and preserve the desired output +/// The primary use case of this node and rule is to specify and preserve the desired output /// ordering and distribution the entire plan. When sending to a single client, a single partition may /// be desirable, but when sending to a multi-partitioned writer, keeping multiple partitions may be /// better. @@ -121,7 +121,8 @@ impl OutputRequirementExec { PlanProperties::new( input.equivalence_properties().clone(), // Equivalence Properties input.output_partitioning().clone(), // Output Partitioning - input.execution_mode(), // Execution Mode + input.pipeline_behavior(), // Pipeline Behavior + input.boundedness(), // Boundedness ) } } @@ -256,13 +257,13 @@ fn require_top_ordering_helper( // Therefore; we check the sort expression field of the SortExec to assign the requirements. let req_ordering = sort_exec.expr(); let req_dist = sort_exec.required_input_distribution()[0].clone(); - let reqs = PhysicalSortRequirement::from_sort_exprs(req_ordering); + let reqs = LexRequirement::from(req_ordering.clone()); Ok(( Arc::new(OutputRequirementExec::new(plan, Some(reqs), req_dist)) as _, true, )) } else if let Some(spm) = plan.as_any().downcast_ref::() { - let reqs = PhysicalSortRequirement::from_sort_exprs(spm.expr()); + let reqs = LexRequirement::from(spm.expr().clone()); Ok(( Arc::new(OutputRequirementExec::new( plan, diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/physical-optimizer/src/pruning.rs similarity index 76% rename from datafusion/core/src/physical_optimizer/pruning.rs rename to datafusion/physical-optimizer/src/pruning.rs index e88b25cc23635..b331bc88b8de3 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/physical-optimizer/src/pruning.rs @@ -18,33 +18,30 @@ //! [`PruningPredicate`] to apply filter [`Expr`] to prune "containers" //! based on statistics (e.g. Parquet Row Groups) //! -//! [`Expr`]: crate::prelude::Expr +//! 
[`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html use std::collections::HashSet; use std::sync::Arc; -use crate::{ - common::{Column, DFSchema}, - error::{DataFusionError, Result}, - logical_expr::Operator, - physical_plan::{ColumnarValue, PhysicalExpr}, -}; - +use arrow::array::AsArray; use arrow::{ array::{new_null_array, ArrayRef, BooleanArray}, datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::{RecordBatch, RecordBatchOptions}, }; -use arrow_array::cast::AsArray; +use log::trace; + +use datafusion_common::error::{DataFusionError, Result}; use datafusion_common::tree_node::TransformedResult; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, tree_node::{Transformed, TreeNode}, ScalarValue, }; +use datafusion_common::{Column, DFSchema}; +use datafusion_expr_common::operator::Operator; use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; - -use log::trace; +use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; /// A source of runtime statistical information to [`PruningPredicate`]s. /// @@ -290,7 +287,12 @@ pub trait PruningStatistics { /// predicate can never possibly be true). The container can be pruned (skipped) /// entirely. /// -/// Note that in order to be correct, `PruningPredicate` must return false +/// While `PruningPredicate` will never return a `NULL` value, the +/// rewritten predicate (as returned by `build_predicate_expression` and used internally +/// by `PruningPredicate`) may evaluate to `NULL` when some of the min/max values +/// or null / row counts are not known. +/// +/// In order to be correct, `PruningPredicate` must return false /// **only** if it can determine that for all rows in the container, the /// predicate could never evaluate to `true` (always evaluates to either `NULL` /// or `false`). 
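(Editor's illustration, not part of the patch: the doc comment above describes how the rewritten predicate may evaluate to `NULL` when min/max or null/row counts are unknown, and why a `NULL` result must be treated like `true` at the container level. A minimal sketch of that three-valued logic in plain Rust follows; the helper names `and3` and `keep_container` are invented for this example and are not DataFusion APIs.)

```rust
// Editor's sketch (hypothetical helpers, not DataFusion APIs): SQL-style
// three-valued AND over Option<bool>, where None models an unknown (NULL) result.
fn and3(a: Option<bool>, b: Option<bool>) -> Option<bool> {
    match (a, b) {
        (Some(false), _) | (_, Some(false)) => Some(false), // false dominates
        (Some(true), Some(true)) => Some(true),
        _ => None, // otherwise the result is unknown
    }
}

/// Container-level check for `x = v`, mirroring the rewritten predicate
/// `x_null_count != x_row_count AND (x_min <= v AND v <= x_max)`.
/// Returns `false` (prune) only when the container provably has no matching rows.
fn keep_container(
    min: Option<i32>,
    max: Option<i32>,
    null_count: Option<u64>,
    row_count: Option<u64>,
    v: i32,
) -> bool {
    let not_all_null = match (null_count, row_count) {
        (Some(n), Some(r)) => Some(n != r),
        _ => None, // missing counts: unknown
    };
    let in_range = and3(min.map(|m| m <= v), max.map(|m| v <= m));
    // An unknown (NULL) overall result keeps the container, just like `true`.
    and3(not_all_null, in_range).unwrap_or(true)
}

fn main() {
    assert!(keep_container(Some(1), Some(100), Some(0), Some(10), 5)); // 5 may be present
    assert!(!keep_container(Some(10), Some(20), Some(0), Some(10), 5)); // provably absent: prune
    assert!(keep_container(None, None, None, Some(10), 5)); // unknown statistics: keep
    assert!(!keep_container(None, None, Some(10), Some(10), 5)); // all rows NULL: prune
}
```

The outcomes match the worked examples in the surrounding doc comments: a provably non-matching range or an all-NULL column prunes the container, while missing statistics never do.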
@@ -330,12 +332,12 @@ pub trait PruningStatistics {
 ///
 /// Original Predicate | Rewritten Predicate
 /// ------------------ | --------------------
-/// `x = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END`
-/// `x < 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_max < 5 END`
-/// `x = 5 AND y = 10` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END AND CASE WHEN y_null_count = y_row_count THEN false ELSE y_min <= 10 AND 10 <= y_max END`
+/// `x = 5` | `x_null_count != x_row_count AND (x_min <= 5 AND 5 <= x_max)`
+/// `x < 5` | `x_null_count != x_row_count AND (x_max < 5)`
+/// `x = 5 AND y = 10` | `x_null_count != x_row_count AND (x_min <= 5 AND 5 <= x_max) AND y_null_count != y_row_count AND (y_min <= 10 AND 10 <= y_max)`
 /// `x IS NULL` | `x_null_count > 0`
 /// `x IS NOT NULL` | `x_null_count != row_count`
-/// `CAST(x as int) = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE CAST(x_min as int) <= 5 AND 5 <= CAST(x_max as int) END`
+/// `CAST(x as int) = 5` | `x_null_count != x_row_count AND (CAST(x_min as int) <= 5 AND 5 <= CAST(x_max as int))`
 ///
 /// ## Predicate Evaluation
 /// The PruningPredicate works in two passes
@@ -355,15 +357,9 @@ pub trait PruningStatistics {
 /// Given the predicate, `x = 5 AND y = 10`, the rewritten predicate would look like:
 ///
 /// ```sql
-/// CASE
-/// WHEN x_null_count = x_row_count THEN false
-/// ELSE x_min <= 5 AND 5 <= x_max
-/// END
+/// x_null_count != x_row_count AND (x_min <= 5 AND 5 <= x_max)
 /// AND
-/// CASE
-/// WHEN y_null_count = y_row_count THEN false
-/// ELSE y_min <= 10 AND 10 <= y_max
-/// END
+/// y_null_count != y_row_count AND (y_min <= 10 AND 10 <= y_max)
 /// ```
 ///
 /// If we know that for a given container, `x` is between `1 and 100` and we know that
 ///
@@ -384,16 +380,22 @@ pub trait PruningStatistics {
 /// When these statistics values are substituted in to the rewritten predicate and
 /// simplified, the result is `false`:
 ///
-/// * `CASE WHEN null = null THEN false ELSE 1 <= 5 AND 5 <= 100 END AND CASE WHEN null = null THEN false ELSE 4 <= 10 AND 10 <= 7 END`
-/// * `null = null` is `null` which is not true, so the `CASE` expression will use the `ELSE` clause
-/// * `1 <= 5 AND 5 <= 100 AND 4 <= 10 AND 10 <= 7`
-/// * `true AND true AND true AND false`
+/// * `null != null AND (1 <= 5 AND 5 <= 100) AND null != null AND (4 <= 10 AND 10 <= 7)`
+/// * `null != null` is `null`, which is not `true`, so the `AND` moves on to the next clause
+/// * `null AND (1 <= 5 AND 5 <= 100) AND null AND (4 <= 10 AND 10 <= 7)`
+/// * evaluating the clauses further we get:
+/// * `null AND true AND null AND false`
+/// * `null AND false`
 /// * `false`
 ///
 /// Returning `false` means the container can be pruned, which matches the
 /// intuition that `x = 5 AND y = 10` can’t be true for any row if all values of `y`
 /// are `7` or less.
 ///
+/// Note that if we had ended up with `null AND true AND null AND true` the result
+/// would have been `null`.
+/// `null` is treated the same as `true`, because we can't prove that the predicate is `false`.
+///
 /// If, for some other container, we knew `y` was between the values `4` and
 /// `15`, then the rewritten predicate evaluates to `true` (verifying this is
 /// left as an exercise to the reader -- are you still here?), and the container
@@ -408,15 +410,9 @@ pub trait PruningStatistics {
 /// look like the same as example 1:
 ///
 /// ```sql
-/// CASE
-/// WHEN x_null_count = x_row_count THEN false
-/// ELSE x_min <= 5 AND 5 <= x_max
-/// END
+/// x_null_count != x_row_count AND (x_min <= 5 AND 5 <= x_max)
 /// AND
-/// CASE
-/// WHEN y_null_count = y_row_count THEN false
-/// ELSE y_min <= 10 AND 10 <= y_max
-/// END
+/// y_null_count != y_row_count AND (y_min <= 10 AND 10 <= y_max)
 /// ```
 ///
 /// If we know that for another given container, `x_min` is NULL and `x_max` is
@@ -438,14 +434,13 @@ pub trait PruningStatistics {
 /// When these statistics values are substituted in to the rewritten predicate and
 /// simplified, the result is `false`:
 ///
-/// * `CASE WHEN 100 = 100 THEN false ELSE null <= 5 AND 5 <= null END AND CASE WHEN null = null THEN false ELSE 4 <= 10 AND 10 <= 7 END`
-/// * Since `100 = 100` is `true`, the `CASE` expression will use the `THEN` clause, i.e. `false`
-/// * The other `CASE` expression will use the `ELSE` clause, i.e. `4 <= 10 AND 10 <= 7`
-/// * `false AND true`
+/// * `100 != 100 AND (null <= 5 AND 5 <= null) AND null != null AND (4 <= 10 AND 10 <= 7)`
+/// * `false AND null AND null AND false`
+/// * `false AND false`
 /// * `false`
 ///
 /// Returning `false` means the container can be pruned, which matches the
-/// intuition that `x = 5 AND y = 10` can’t be true for all values in `x`
+/// intuition that `x = 5 AND y = 10` can’t be true because all values in `x`
 /// are known to be NULL.
 ///
 /// # Related Work
@@ -458,7 +453,7 @@ pub trait PruningStatistics {
 /// [`Snowflake SIGMOD Paper`]: https://dl.acm.org/doi/10.1145/2882903.2903741
 /// [small materialized aggregates]: https://www.vldb.org/conf/1998/p476.pdf
 /// [zone maps]: https://dl.acm.org/doi/10.1007/978-3-642-03730-6_10
-///[data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515
+/// [data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515
 #[derive(Debug, Clone)]
 pub struct PruningPredicate {
     /// The input schema against which the predicate will be evaluated
@@ -478,6 +473,36 @@ pub struct PruningPredicate {
     literal_guarantees: Vec<LiteralGuarantee>,
 }
 
+/// Rewrites predicates that [`PredicateRewriter`] can not handle, e.g. certain
+/// complex expressions or predicates that reference columns that are not in the
+/// schema.
+pub trait UnhandledPredicateHook {
+    /// Called when a predicate can not be rewritten in terms of statistics or
+    /// references a column that is not in the schema.
+ fn handle(&self, expr: &Arc) -> Arc; +} + +/// The default handling for unhandled predicates is to return a constant `true` +/// (meaning don't prune the container) +#[derive(Debug, Clone)] +struct ConstantUnhandledPredicateHook { + default: Arc, +} + +impl Default for ConstantUnhandledPredicateHook { + fn default() -> Self { + Self { + default: Arc::new(phys_expr::Literal::from(ScalarValue::from(true))), + } + } +} + +impl UnhandledPredicateHook for ConstantUnhandledPredicateHook { + fn handle(&self, _expr: &Arc) -> Arc { + Arc::clone(&self.default) + } +} + impl PruningPredicate { /// Try to create a new instance of [`PruningPredicate`] /// @@ -502,10 +527,16 @@ impl PruningPredicate { /// See the struct level documentation on [`PruningPredicate`] for more /// details. pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { + let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; + // build predicate expression once let mut required_columns = RequiredColumns::new(); - let predicate_expr = - build_predicate_expression(&expr, schema.as_ref(), &mut required_columns); + let predicate_expr = build_predicate_expression( + &expr, + schema.as_ref(), + &mut required_columns, + &unhandled_hook, + ); let literal_guarantees = LiteralGuarantee::analyze(&expr); @@ -531,7 +562,7 @@ impl PruningPredicate { /// expressions like `b = false`, but it does handle the /// simplified version `b`. See [`ExprSimplifier`] to simplify expressions. /// - /// [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier + /// [`ExprSimplifier`]: https://docs.rs/datafusion/latest/datafusion/optimizer/simplify_expressions/struct.ExprSimplifier.html pub fn prune(&self, statistics: &S) -> Result> { let mut builder = BoolVecBuilder::new(statistics.num_containers()); @@ -617,7 +648,7 @@ impl PruningPredicate { // this is only used by `parquet` feature right now #[allow(dead_code)] - pub(crate) fn required_columns(&self) -> &RequiredColumns { + pub fn required_columns(&self) -> &RequiredColumns { &self.required_columns } @@ -728,7 +759,7 @@ fn is_always_true(expr: &Arc) -> bool { /// Handles creating references to the min/max statistics /// for columns as well as recording which statistics are needed #[derive(Debug, Default, Clone)] -pub(crate) struct RequiredColumns { +pub struct RequiredColumns { /// The statistics required to evaluate this predicate: /// * The unqualified column in the input schema /// * Statistics type (e.g. 
Min or Max or Null_Count) @@ -752,7 +783,7 @@ impl RequiredColumns { /// * `true` returns None #[allow(dead_code)] // this fn is only used by `parquet` feature right now, thus the `allow(dead_code)` - pub(crate) fn single_column(&self) -> Option<&phys_expr::Column> { + pub fn single_column(&self) -> Option<&phys_expr::Column> { if self.columns.windows(2).all(|w| { // check if all columns are the same (ignoring statistics and field) let c1 = &w[0].0; @@ -823,7 +854,7 @@ impl RequiredColumns { Field::new(stat_column.name(), field.data_type().clone(), nullable); self.columns.push((column.clone(), stat_type, stat_field)); } - rewrite_column_expr(column_expr.clone(), column, &stat_column) + rewrite_column_expr(Arc::clone(column_expr), column, &stat_column) } /// rewrite col --> col_min @@ -1096,7 +1127,7 @@ fn rewrite_expr_to_prunable( .is_some() { // `col op lit()` - Ok((column_expr.clone(), op, scalar_expr.clone())) + Ok((Arc::clone(column_expr), op, Arc::clone(scalar_expr))) } else if let Some(cast) = column_expr_any.downcast_ref::() { // `cast(col) op lit()` let arrow_schema: SchemaRef = schema.clone().into(); @@ -1141,8 +1172,8 @@ fn rewrite_expr_to_prunable( .downcast_ref::() .is_some() { - let left = not.arg().clone(); - let right = Arc::new(phys_expr::NotExpr::new(scalar_expr.clone())); + let left = Arc::clone(not.arg()); + let right = Arc::new(phys_expr::NotExpr::new(Arc::clone(scalar_expr))); Ok((left, reverse_operator(op)?, right)) } else { plan_err!("Not with complex expression {column_expr:?} is not supported") @@ -1161,6 +1192,8 @@ fn is_compare_op(op: Operator) -> bool { | Operator::LtEq | Operator::Gt | Operator::GtEq + | Operator::LikeMatch + | Operator::NotLikeMatch ) } @@ -1314,27 +1347,78 @@ fn build_is_null_column_expr( /// an OR chain const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20; +/// Rewrite a predicate expression in terms of statistics (min/max/null_counts) +/// for use as a [`PruningPredicate`]. +pub struct PredicateRewriter { + unhandled_hook: Arc, +} + +impl Default for PredicateRewriter { + fn default() -> Self { + Self { + unhandled_hook: Arc::new(ConstantUnhandledPredicateHook::default()), + } + } +} + +impl PredicateRewriter { + /// Create a new `PredicateRewriter` + pub fn new() -> Self { + Self::default() + } + + /// Set the unhandled hook to be used when a predicate can not be rewritten + pub fn with_unhandled_hook( + self, + unhandled_hook: Arc, + ) -> Self { + Self { unhandled_hook } + } + + /// Translate logical filter expression into pruning predicate + /// expression that will evaluate to FALSE if it can be determined no + /// rows between the min/max values could pass the predicates. + /// + /// Any predicates that can not be translated will be passed to `unhandled_hook`. + /// + /// Returns the pruning predicate as an [`PhysicalExpr`] + /// + /// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook` + pub fn rewrite_predicate_to_statistics_predicate( + &self, + expr: &Arc, + schema: &Schema, + ) -> Arc { + let mut required_columns = RequiredColumns::new(); + build_predicate_expression( + expr, + schema, + &mut required_columns, + &self.unhandled_hook, + ) + } +} + /// Translate logical filter expression into pruning predicate /// expression that will evaluate to FALSE if it can be determined no /// rows between the min/max values could pass the predicates. /// +/// Any predicates that can not be translated will be passed to `unhandled_hook`. 
+/// /// Returns the pruning predicate as an [`PhysicalExpr`] /// -/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE +/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook` fn build_predicate_expression( expr: &Arc, schema: &Schema, required_columns: &mut RequiredColumns, + unhandled_hook: &Arc, ) -> Arc { - // Returned for unsupported expressions. Such expressions are - // converted to TRUE. - let unhandled = Arc::new(phys_expr::Literal::from(ScalarValue::Boolean(Some(true)))); - // predicate expression can only be a binary expression let expr_any = expr.as_any(); if let Some(is_null) = expr_any.downcast_ref::() { return build_is_null_column_expr(is_null.arg(), schema, required_columns, false) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(is_not_null) = expr_any.downcast_ref::() { return build_is_null_column_expr( @@ -1343,19 +1427,19 @@ fn build_predicate_expression( required_columns, true, ) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(col) = expr_any.downcast_ref::() { return build_single_column_expr(col, schema, required_columns, false) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(not) = expr_any.downcast_ref::() { // match !col (don't do so recursively) if let Some(col) = not.arg().as_any().downcast_ref::() { return build_single_column_expr(col, schema, required_columns, true) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } else { - return unhandled; + return unhandled_hook.handle(expr); } } if let Some(in_list) = expr_any.downcast_ref::() { @@ -1377,34 +1461,56 @@ fn build_predicate_expression( .iter() .map(|e| { Arc::new(phys_expr::BinaryExpr::new( - in_list.expr().clone(), + Arc::clone(in_list.expr()), eq_op, - e.clone(), + Arc::clone(e), )) as _ }) .reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _) .unwrap(); - return build_predicate_expression(&change_expr, schema, required_columns); + return build_predicate_expression( + &change_expr, + schema, + required_columns, + unhandled_hook, + ); } else { - return unhandled; + return unhandled_hook.handle(expr); } } let (left, op, right) = { if let Some(bin_expr) = expr_any.downcast_ref::() { ( - bin_expr.left().clone(), + Arc::clone(bin_expr.left()), *bin_expr.op(), - bin_expr.right().clone(), + Arc::clone(bin_expr.right()), + ) + } else if let Some(like_expr) = expr_any.downcast_ref::() { + if like_expr.case_insensitive() { + return unhandled_hook.handle(expr); + } + let op = match (like_expr.negated(), like_expr.case_insensitive()) { + (false, false) => Operator::LikeMatch, + (true, false) => Operator::NotLikeMatch, + (false, true) => Operator::ILikeMatch, + (true, true) => Operator::NotILikeMatch, + }; + ( + Arc::clone(like_expr.expr()), + op, + Arc::clone(like_expr.pattern()), ) } else { - return unhandled; + return unhandled_hook.handle(expr); } }; if op == Operator::And || op == Operator::Or { - let left_expr = build_predicate_expression(&left, schema, required_columns); - let right_expr = build_predicate_expression(&right, schema, required_columns); + let left_expr = + build_predicate_expression(&left, schema, required_columns, unhandled_hook); + let right_expr = + build_predicate_expression(&right, schema, required_columns, unhandled_hook); // simplify boolean expression if applicable let expr = match (&left_expr, op, &right_expr) { 
(left, Operator::And, _) if is_always_true(left) => right_expr, @@ -1412,7 +1518,7 @@ fn build_predicate_expression( (left, Operator::Or, right) if is_always_true(left) || is_always_true(right) => { - unhandled + Arc::new(phys_expr::Literal::from(ScalarValue::Boolean(Some(true)))) } _ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)), }; @@ -1425,12 +1531,11 @@ fn build_predicate_expression( Ok(builder) => builder, // allow partial failure in predicate expression generation // this can still produce a useful predicate when multiple conditions are joined using AND - Err(_) => { - return unhandled; - } + Err(_) => return unhandled_hook.handle(expr), }; - build_statistics_expr(&mut expr_builder).unwrap_or(unhandled) + build_statistics_expr(&mut expr_builder) + .unwrap_or_else(|_| unhandled_hook.handle(expr)) } fn build_statistics_expr( @@ -1447,11 +1552,11 @@ fn build_statistics_expr( Arc::new(phys_expr::BinaryExpr::new( min_column_expr, Operator::NotEq, - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), )), Operator::Or, Arc::new(phys_expr::BinaryExpr::new( - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), Operator::NotEq, max_column_expr, )), @@ -1466,22 +1571,27 @@ fn build_statistics_expr( Arc::new(phys_expr::BinaryExpr::new( min_column_expr, Operator::LtEq, - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), )), Operator::And, Arc::new(phys_expr::BinaryExpr::new( - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), Operator::LtEq, max_column_expr, )), )) } + Operator::LikeMatch => build_like_match(expr_builder).ok_or_else(|| { + plan_datafusion_err!( + "LIKE expression with wildcard at the beginning is not supported" + ) + })?, Operator::Gt => { // column > literal => (min, max) > literal => max > literal Arc::new(phys_expr::BinaryExpr::new( expr_builder.max_column_expr()?, Operator::Gt, - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), )) } Operator::GtEq => { @@ -1489,7 +1599,7 @@ fn build_statistics_expr( Arc::new(phys_expr::BinaryExpr::new( expr_builder.max_column_expr()?, Operator::GtEq, - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), )) } Operator::Lt => { @@ -1497,7 +1607,7 @@ fn build_statistics_expr( Arc::new(phys_expr::BinaryExpr::new( expr_builder.min_column_expr()?, Operator::Lt, - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), )) } Operator::LtEq => { @@ -1505,7 +1615,7 @@ fn build_statistics_expr( Arc::new(phys_expr::BinaryExpr::new( expr_builder.min_column_expr()?, Operator::LtEq, - expr_builder.scalar_expr().clone(), + Arc::clone(expr_builder.scalar_expr()), )) } // other expressions are not supported @@ -1515,13 +1625,134 @@ fn build_statistics_expr( ); } }; - let statistics_expr = wrap_case_expr(statistics_expr, expr_builder)?; + let statistics_expr = wrap_null_count_check_expr(statistics_expr, expr_builder)?; Ok(statistics_expr) } -/// Wrap the statistics expression in a case expression. -/// This is necessary to handle the case where the column is known -/// to be all nulls. +/// Convert `column LIKE literal` where P is a constant prefix of the literal +/// to a range check on the column: `P <= column && column < P'`, where P' is the +/// lowest string after all P* strings. 
+fn build_like_match( + expr_builder: &mut PruningExpressionBuilder, +) -> Option> { + // column LIKE literal => (min, max) LIKE literal split at % => min <= split literal && split literal <= max + // column LIKE 'foo%' => min <= 'foo' && 'foo' <= max + // column LIKE '%foo' => min <= '' && '' <= max => true + // column LIKE '%foo%' => min <= '' && '' <= max => true + // column LIKE 'foo' => min <= 'foo' && 'foo' <= max + + /// returns the string literal of the scalar value if it is a string + fn unpack_string(s: &ScalarValue) -> Option<&str> { + s.try_as_str().flatten() + } + + fn extract_string_literal(expr: &Arc) -> Option<&str> { + if let Some(lit) = expr.as_any().downcast_ref::() { + let s = unpack_string(lit.scalar().value())?; + return Some(s); + } + None + } + + // TODO Handle ILIKE perhaps by making the min lowercase and max uppercase + // this may involve building the physical expressions that call lower() and upper() + let min_column_expr = expr_builder.min_column_expr().ok()?; + let max_column_expr = expr_builder.max_column_expr().ok()?; + let scalar_expr = expr_builder.scalar_expr(); + // check that the scalar is a string literal + let s = extract_string_literal(scalar_expr)?; + // ANSI SQL specifies two wildcards: % and _. % matches zero or more characters, _ matches exactly one character. + let first_wildcard_index = s.find(['%', '_']); + if first_wildcard_index == Some(0) { + // there's no filtering we could possibly do, return an error and have this be handled by the unhandled hook + return None; + } + let (lower_bound, upper_bound) = if let Some(wildcard_index) = first_wildcard_index { + let prefix = &s[..wildcard_index]; + let lower_bound_lit = Arc::new(phys_expr::Literal::from(ScalarValue::Utf8( + Some(prefix.to_string()), + ))); + let upper_bound_lit = Arc::new(phys_expr::Literal::from(ScalarValue::Utf8( + Some(increment_utf8(prefix)?), + ))); + (lower_bound_lit, upper_bound_lit) + } else { + // the like expression is a literal and can be converted into a comparison + let bound = Arc::new(phys_expr::Literal::from(ScalarValue::Utf8(Some( + s.to_string(), + )))); + (Arc::clone(&bound), bound) + }; + let lower_bound_expr = Arc::new(phys_expr::BinaryExpr::new( + lower_bound, + Operator::LtEq, + Arc::clone(&max_column_expr), + )); + let upper_bound_expr = Arc::new(phys_expr::BinaryExpr::new( + Arc::clone(&min_column_expr), + Operator::LtEq, + upper_bound, + )); + let combined = Arc::new(phys_expr::BinaryExpr::new( + upper_bound_expr, + Operator::And, + lower_bound_expr, + )); + Some(combined) +} + +/// Increment a UTF8 string by one, returning `None` if it can't be incremented. +/// This makes it so that the returned string will always compare greater than the input string +/// or any other string with the same prefix. +/// This is necessary since the statistics may have been truncated: if we have a min statistic +/// of "fo" that may have originally been "foz" or anything else with the prefix "fo". +/// E.g. 
`increment_utf8("foo") >= "foo"` and `increment_utf8("foo") >= "fooz"`
+/// In this example `increment_utf8("foo") == "fop"`
+fn increment_utf8(data: &str) -> Option<String> {
+    // Helper function to check if a character is valid to use
+    fn is_valid_unicode(c: char) -> bool {
+        let cp = c as u32;
+
+        // Filter out non-characters (https://www.unicode.org/versions/corrigendum9.html)
+        if [0xFFFE, 0xFFFF].contains(&cp) || (0xFDD0..=0xFDEF).contains(&cp) {
+            return false;
+        }
+
+        // Filter out private use area
+        if cp >= 0x110000 {
+            return false;
+        }
+
+        true
+    }
+
+    // Convert string to vector of code points
+    let mut code_points: Vec<char> = data.chars().collect();
+
+    // Work backwards through code points
+    for idx in (0..code_points.len()).rev() {
+        let original = code_points[idx] as u32;
+
+        // Try incrementing the code point
+        if let Some(next_char) = char::from_u32(original + 1) {
+            if is_valid_unicode(next_char) {
+                code_points[idx] = next_char;
+                // truncate the string to the current index
+                code_points.truncate(idx + 1);
+                return Some(code_points.into_iter().collect());
+            }
+        }
+    }
+
+    None
+}
+
+/// Wrap the statistics expression in a check that skips the expression if the column is all nulls.
+///
+/// This is important not only as an optimization but also because statistics may not be
+/// accurate for columns that are all nulls.
+/// For example, for an `int` column `x` with all nulls, the min/max/null_count statistics
+/// might be set to 0 and evaluating `x = 0` would incorrectly include the column.
 ///
 /// For example:
 ///
@@ -1530,33 +1761,29 @@ fn build_statistics_expr(
 /// will become
 ///
 /// ```sql
-/// CASE
-/// WHEN x_null_count = x_row_count THEN false
-/// ELSE x_min <= 10 AND 10 <= x_max
-/// END
+/// x_null_count != x_row_count AND (x_min <= 10 AND 10 <= x_max)
 /// ````
 ///
 /// If the column is known to be all nulls, then the expression
 /// `x_null_count = x_row_count` will be true, which will cause the
-/// case expression to return false. Therefore, prune out the container.
-fn wrap_case_expr(
+/// boolean expression to return false. Therefore, prune out the container.
+fn wrap_null_count_check_expr( statistics_expr: Arc, expr_builder: &mut PruningExpressionBuilder, ) -> Result> { - // x_null_count = x_row_count - let when_null_count_eq_row_count = Arc::new(phys_expr::BinaryExpr::new( + // x_null_count != x_row_count + let not_when_null_count_eq_row_count = Arc::new(phys_expr::BinaryExpr::new( expr_builder.null_count_column_expr()?, - Operator::Eq, + Operator::NotEq, expr_builder.row_count_column_expr()?, )); - let then = Arc::new(phys_expr::Literal::from(ScalarValue::Boolean(Some(false)))); - - // CASE WHEN x_null_count = x_row_count THEN false ELSE END - Ok(Arc::new(phys_expr::CaseExpr::try_new( - None, - vec![(when_null_count_eq_row_count, then)], - Some(statistics_expr), - )?)) + + // (x_null_count != x_row_count) AND () + Ok(Arc::new(phys_expr::BinaryExpr::new( + not_when_null_count_eq_row_count, + Operator::And, + statistics_expr, + ))) } #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -1573,17 +1800,18 @@ mod tests { use std::ops::{Not, Rem}; use super::*; - use crate::assert_batches_eq; - use crate::logical_expr::{col, lit}; + use datafusion_common::assert_batches_eq; + use datafusion_expr::{col, lit}; use arrow::array::Decimal128Array; use arrow::{ - array::{BinaryArray, Int32Array, Int64Array, StringArray}, + array::{BinaryArray, Int32Array, Int64Array, StringArray, UInt64Array}, datatypes::TimeUnit, }; - use arrow_array::UInt64Array; use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; + use datafusion_functions_nested::expr_fn::{array_has, make_array}; + use datafusion_physical_expr::expressions as phys_expr; use datafusion_physical_expr::planner::logical2physical; #[derive(Debug, Default)] @@ -1963,6 +2191,110 @@ mod tests { } } + #[test] + fn prune_all_rows_null_counts() { + // if null_count = row_count then we should prune the container for i = 0 + // regardless of the statistics + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let statistics = TestStatistics::new().with( + "i", + ContainerStats::new_i32( + vec![Some(0)], // min + vec![Some(0)], // max + ) + .with_null_counts(vec![Some(1)]) + .with_row_counts(vec![Some(1)]), + ); + let expected_ret = &[false]; + prune_with_expr(col("i").eq(lit(0)), &schema, &statistics, expected_ret); + + // this should be true even if the container stats are missing + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let container_stats = ContainerStats { + min: Some(Arc::new(Int32Array::from(vec![None]))), + max: Some(Arc::new(Int32Array::from(vec![None]))), + null_counts: Some(Arc::new(UInt64Array::from(vec![Some(1)]))), + row_counts: Some(Arc::new(UInt64Array::from(vec![Some(1)]))), + ..ContainerStats::default() + }; + let statistics = TestStatistics::new().with("i", container_stats); + let expected_ret = &[false]; + prune_with_expr(col("i").eq(lit(0)), &schema, &statistics, expected_ret); + + // If the null counts themselves are missing we should be able to fall back to the stats + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let container_stats = ContainerStats { + min: Some(Arc::new(Int32Array::from(vec![Some(0)]))), + max: Some(Arc::new(Int32Array::from(vec![Some(0)]))), + null_counts: Some(Arc::new(UInt64Array::from(vec![None]))), + row_counts: Some(Arc::new(UInt64Array::from(vec![Some(1)]))), + ..ContainerStats::default() + }; + let statistics = TestStatistics::new().with("i", container_stats); + let expected_ret = &[true]; + 
prune_with_expr(col("i").eq(lit(0)), &schema, &statistics, expected_ret); + let expected_ret = &[false]; + prune_with_expr(col("i").gt(lit(0)), &schema, &statistics, expected_ret); + + // Same for the row counts + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let container_stats = ContainerStats { + min: Some(Arc::new(Int32Array::from(vec![Some(0)]))), + max: Some(Arc::new(Int32Array::from(vec![Some(0)]))), + null_counts: Some(Arc::new(UInt64Array::from(vec![Some(1)]))), + row_counts: Some(Arc::new(UInt64Array::from(vec![None]))), + ..ContainerStats::default() + }; + let statistics = TestStatistics::new().with("i", container_stats); + let expected_ret = &[true]; + prune_with_expr(col("i").eq(lit(0)), &schema, &statistics, expected_ret); + let expected_ret = &[false]; + prune_with_expr(col("i").gt(lit(0)), &schema, &statistics, expected_ret); + } + + #[test] + fn prune_missing_statistics() { + // If the min or max stats are missing we should not prune + // (unless we know all rows are null, see `prune_all_rows_null_counts`) + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let container_stats = ContainerStats { + min: Some(Arc::new(Int32Array::from(vec![None, Some(0)]))), + max: Some(Arc::new(Int32Array::from(vec![Some(0), None]))), + null_counts: Some(Arc::new(UInt64Array::from(vec![Some(0), Some(0)]))), + row_counts: Some(Arc::new(UInt64Array::from(vec![Some(1), Some(1)]))), + ..ContainerStats::default() + }; + let statistics = TestStatistics::new().with("i", container_stats); + let expected_ret = &[true, true]; + prune_with_expr(col("i").eq(lit(0)), &schema, &statistics, expected_ret); + let expected_ret = &[false, true]; + prune_with_expr(col("i").gt(lit(0)), &schema, &statistics, expected_ret); + let expected_ret = &[true, false]; + prune_with_expr(col("i").lt(lit(0)), &schema, &statistics, expected_ret); + } + + #[test] + fn prune_null_stats() { + // if null_count = row_count then we should prune the container for i = 0 + // regardless of the statistics + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + + let statistics = TestStatistics::new().with( + "i", + ContainerStats::new_i32( + vec![Some(0)], // min + vec![Some(0)], // max + ) + .with_null_counts(vec![Some(1)]) + .with_row_counts(vec![Some(1)]), + ); + + let expected_ret = &[false]; + + // i = 0 + prune_with_expr(col("i").eq(lit(0)), &schema, &statistics, expected_ret); + } + #[test] fn test_build_statistics_record_batch() { // Request a record batch with of s1_min, s2_max, s3_max, s3_min @@ -2144,7 +2476,8 @@ mod tests { #[test] fn row_group_predicate_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 END"; + let expected_expr = + "c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1"; // test column on the left let expr = col("c1").eq(lit(1)); @@ -2164,7 +2497,8 @@ mod tests { #[test] fn row_group_predicate_not_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != 1 OR 1 != c1_max@1 END"; + let expected_expr = + "c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != 1 OR 1 != c1_max@1)"; // test column on the left let expr = col("c1").not_eq(lit(1)); @@ -2184,8 +2518,7 @@ mod tests { #[test] fn 
row_group_predicate_gt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = - "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_max@0 > 1 END"; + let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_max@0 > 1"; // test column on the left let expr = col("c1").gt(lit(1)); @@ -2205,7 +2538,7 @@ mod tests { #[test] fn row_group_predicate_gt_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_max@0 >= 1 END"; + let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_max@0 >= 1"; // test column on the left let expr = col("c1").gt_eq(lit(1)); @@ -2224,8 +2557,7 @@ mod tests { #[test] fn row_group_predicate_lt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = - "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < 1 END"; + let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_min@0 < 1"; // test column on the left let expr = col("c1").lt(lit(1)); @@ -2245,7 +2577,7 @@ mod tests { #[test] fn row_group_predicate_lt_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 <= 1 END"; + let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_min@0 <= 1"; // test column on the left let expr = col("c1").lt_eq(lit(1)); @@ -2270,8 +2602,7 @@ mod tests { ]); // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); - let expected_expr = - "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < 1 END"; + let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_min@0 < 1"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2337,7 +2668,7 @@ mod tests { #[test] fn row_group_predicate_lt_bool() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); - let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < true END"; + let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_min@0 < true"; // DF doesn't support arithmetic on boolean columns so // this predicate will error when evaluated @@ -2360,20 +2691,11 @@ mod tests { let expr = col("c1") .lt(lit(1)) .and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3)))); - let expected_expr = "\ - CASE \ - WHEN c1_null_count@1 = c1_row_count@2 THEN false \ - ELSE c1_min@0 < 1 \ - END \ - AND (\ - CASE \ - WHEN c2_null_count@5 = c2_row_count@6 THEN false \ - ELSE c2_min@3 <= 2 AND 2 <= c2_max@4 \ - END \ - OR CASE \ - WHEN c2_null_count@5 = c2_row_count@6 THEN false \ - ELSE c2_min@3 <= 3 AND 3 <= c2_max@4 \ - END\ + let expected_expr = "c1_null_count@1 != c1_row_count@2 \ + AND c1_min@0 < 1 AND (\ + c2_null_count@5 != c2_row_count@6 \ + AND c2_min@3 <= 2 AND 2 <= c2_max@4 OR \ + c2_null_count@5 != c2_row_count@6 AND c2_min@3 <= 3 AND 3 <= c2_max@4\ )"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut required_columns); @@ -2465,18 +2787,7 @@ mod tests { vec![lit(1), lit(2), lit(3)], false, )); - let expected_expr = "CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 \ - END \ - 
OR CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 <= 2 AND 2 <= c1_max@1 \ - END \ - OR CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 <= 3 AND 3 <= c1_max@1 \ - END"; + let expected_expr = "c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 3 AND 3 <= c1_max@1"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2512,19 +2823,7 @@ mod tests { vec![lit(1), lit(2), lit(3)], true, )); - let expected_expr = "\ - CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 != 1 OR 1 != c1_max@1 \ - END \ - AND CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 != 2 OR 2 != c1_max@1 \ - END \ - AND CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 != 3 OR 3 != c1_max@1 \ - END"; + let expected_expr = "c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != 1 OR 1 != c1_max@1) AND c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != 2 OR 2 != c1_max@1) AND c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != 3 OR 3 != c1_max@1)"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2570,24 +2869,7 @@ mod tests { // test c1 in(1, 2) and c2 BETWEEN 4 AND 5 let expr3 = expr1.and(expr2); - let expected_expr = "\ - (\ - CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 \ - END \ - OR CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 <= 2 AND 2 <= c1_max@1 \ - END\ - ) AND CASE \ - WHEN c2_null_count@5 = c2_row_count@6 THEN false \ - ELSE c2_max@4 >= 4 \ - END \ - AND CASE \ - WHEN c2_null_count@5 = c2_row_count@6 THEN false \ - ELSE c2_min@7 <= 5 \ - END"; + let expected_expr = "(c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 2 AND 2 <= c1_max@1) AND c2_null_count@5 != c2_row_count@6 AND c2_max@4 >= 4 AND c2_null_count@5 != c2_row_count@6 AND c2_min@7 <= 5"; let predicate_expr = test_build_predicate_expression(&expr3, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2614,10 +2896,7 @@ mod tests { #[test] fn row_group_predicate_cast() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \ - END"; + let expected_expr = "c1_null_count@2 != c1_row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)"; // test cast(c1 as int64) = 1 // test column on the left @@ -2632,10 +2911,8 @@ mod tests { test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); - let expected_expr = "CASE \ - WHEN c1_null_count@1 = c1_row_count@2 THEN false \ - ELSE TRY_CAST(c1_max@0 AS Int64) > 1 \ - END"; + let expected_expr = + "c1_null_count@1 != c1_row_count@2 AND TRY_CAST(c1_max@0 AS Int64) > 1"; // test column on the left let expr = @@ -2667,18 +2944,7 @@ mod tests { ], false, )); - let expected_expr = "CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS 
Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \ - END \ - OR CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) \ - END \ - OR CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64) \ - END"; + let expected_expr = "c1_null_count@2 != c1_row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR c1_null_count@2 != c1_row_count@3 AND CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR c1_null_count@2 != c1_row_count@3 AND CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2692,18 +2958,7 @@ mod tests { ], true, )); - let expected_expr = "CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64) \ - END \ - AND CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64) \ - END \ - AND CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64) \ - END"; + let expected_expr = "c1_null_count@2 != c1_row_count@3 AND (CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) AND c1_null_count@2 != c1_row_count@3 AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) AND c1_null_count@2 != c1_row_count@3 AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -3350,6 +3605,425 @@ mod tests { ); } + #[test] + fn test_increment_utf8() { + // Basic ASCII + assert_eq!(increment_utf8("abc").unwrap(), "abd"); + assert_eq!(increment_utf8("abz").unwrap(), "ab{"); + + // Test around ASCII 127 (DEL) + assert_eq!(increment_utf8("~").unwrap(), "\u{7f}"); // 126 -> 127 + assert_eq!(increment_utf8("\u{7f}").unwrap(), "\u{80}"); // 127 -> 128 + + // Test 2-byte UTF-8 sequences + assert_eq!(increment_utf8("ß").unwrap(), "à"); // U+00DF -> U+00E0 + + // Test 3-byte UTF-8 sequences + assert_eq!(increment_utf8("℣").unwrap(), "ℤ"); // U+2123 -> U+2124 + + // Test at UTF-8 boundaries + assert_eq!(increment_utf8("\u{7FF}").unwrap(), "\u{800}"); // 2-byte to 3-byte boundary + assert_eq!(increment_utf8("\u{FFFF}").unwrap(), "\u{10000}"); // 3-byte to 4-byte boundary + + // Test that if we can't increment we return None + assert!(increment_utf8("").is_none()); + assert!(increment_utf8("\u{10FFFF}").is_none()); // U+10FFFF is the max code point + + // Test that if we can't increment the last character we do the previous one and truncate + assert_eq!(increment_utf8("a\u{10FFFF}").unwrap(), "b"); + + // Test surrogate pair range (0xD800..=0xDFFF) + assert_eq!(increment_utf8("a\u{D7FF}").unwrap(), "b"); + assert!(increment_utf8("\u{D7FF}").is_none()); + + // Test non-characters range (0xFDD0..=0xFDEF) + assert_eq!(increment_utf8("a\u{FDCF}").unwrap(), "b"); + assert!(increment_utf8("\u{FDCF}").is_none()); + + // Test private use area limit (>= 0x110000) + assert_eq!(increment_utf8("a\u{10FFFF}").unwrap(), "b"); + assert!(increment_utf8("\u{10FFFF}").is_none()); // Can't increment past max valid codepoint + } + + /// Creates a setup for chunk pruning, modeling a utf8 column "s1" + /// with 5 
different containers (e.g. RowGroups). They have [min, + /// max]: + /// s1 ["A", "Z"] + /// s1 ["A", "L"] + /// s1 ["N", "Z"] + /// s1 [NULL, NULL] + /// s1 ["A", NULL] + /// s1 ["", "A"] + /// s1 ["", ""] + /// s1 ["AB", "A\u{10ffff}"] + /// s1 ["A\u{10ffff}\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] + fn utf8_setup() -> (SchemaRef, TestStatistics) { + let schema = Arc::new(Schema::new(vec![Field::new("s1", DataType::Utf8, true)])); + + let statistics = TestStatistics::new().with( + "s1", + ContainerStats::new_utf8( + vec![ + Some("A"), + Some("A"), + Some("N"), + Some("M"), + None, + Some("A"), + Some(""), + Some(""), + Some("AB"), + Some("A\u{10ffff}\u{10ffff}"), + ], // min + vec![ + Some("Z"), + Some("L"), + Some("Z"), + Some("M"), + None, + None, + Some("A"), + Some(""), + Some("A\u{10ffff}\u{10ffff}\u{10ffff}"), + Some("A\u{10ffff}\u{10ffff}"), + ], // max + ), + ); + (schema, statistics) + } + + #[test] + fn prune_utf8_eq() { + let (schema, statistics) = utf8_setup(); + + let expr = col("s1").eq(lit("A")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> no rows can pass (not keep) + false, + // s1 ["M", "M"] ==> no rows can pass (not keep) + false, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> unknown (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> no rows can pass (not keep) + false, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> no rows can pass (not keep) + false, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no rows can pass (not keep) + false, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").eq(lit("")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> no rows can pass (not keep) + false, + // s1 ["A", "L"] ==> no rows can pass (not keep) + false, + // s1 ["N", "Z"] ==> no rows can pass (not keep) + false, + // s1 ["M", "M"] ==> no rows can pass (not keep) + false, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> no rows can pass (not keep) + false, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> all rows must pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> no rows can pass (not keep) + false, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no rows can pass (not keep) + false, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + } + + #[test] + fn prune_utf8_not_eq() { + let (schema, statistics) = utf8_setup(); + + let expr = col("s1").not_eq(lit("A")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> all rows must pass (must keep) + true, + // s1 ["M", "M"] ==> all rows must pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> unknown (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> all rows must pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}"] ==> all rows must pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> all rows must pass (must keep) + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = 
col("s1").not_eq(lit("")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> all rows must pass (must keep) + true, + // s1 ["A", "L"] ==> all rows must pass (must keep) + true, + // s1 ["N", "Z"] ==> all rows must pass (must keep) + true, + // s1 ["M", "M"] ==> all rows must pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> unknown (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> no rows can pass (not keep) + false, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> all rows must pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> all rows must pass (must keep) + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + } + + #[test] + fn prune_utf8_like_one() { + let (schema, statistics) = utf8_setup(); + + let expr = col("s1").like(lit("A_")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> no rows can pass (not keep) + false, + // s1 ["M", "M"] ==> no rows can pass (not keep) + false, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> unknown (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> no rows can pass (not keep) + false, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").like(lit("_A_")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["M", "M"] ==> some rows could pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> unknown (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> some rows could pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").like(lit("_")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> all rows must pass (must keep) + true, + // s1 ["A", "L"] ==> all rows must pass (must keep) + true, + // s1 ["N", "Z"] ==> all rows must pass (must keep) + true, + // s1 ["M", "M"] ==> all rows must pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> unknown (must keep) + true, + // s1 ["", "A"] ==> all rows must pass (must keep) + true, + // s1 ["", ""] ==> all rows must pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> all rows must pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> all rows must pass (must keep) + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").like(lit("")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> no rows can pass (not keep) + false, + // s1 ["A", "L"] ==> no rows can 
pass (not keep) + false, + // s1 ["N", "Z"] ==> no rows can pass (not keep) + false, + // s1 ["M", "M"] ==> no rows can pass (not keep) + false, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> no rows can pass (not keep) + false, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> all rows must pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> no rows can pass (not keep) + false, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no rows can pass (not keep) + false, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + } + + #[test] + fn prune_utf8_like_many() { + let (schema, statistics) = utf8_setup(); + + let expr = col("s1").like(lit("A%")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> no rows can pass (not keep) + false, + // s1 ["M", "M"] ==> no rows can pass (not keep) + false, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> unknown (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> no rows can pass (not keep) + false, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").like(lit("%A%")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["M", "M"] ==> some rows could pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> unknown (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> some rows could pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").like(lit("%")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> all rows must pass (must keep) + true, + // s1 ["A", "L"] ==> all rows must pass (must keep) + true, + // s1 ["N", "Z"] ==> all rows must pass (must keep) + true, + // s1 ["M", "M"] ==> all rows must pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> unknown (must keep) + true, + // s1 ["", "A"] ==> all rows must pass (must keep) + true, + // s1 ["", ""] ==> all rows must pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> all rows must pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> all rows must pass (must keep) + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").like(lit("")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> no rows can pass (not keep) + false, + // s1 ["A", "L"] ==> no rows can pass (not keep) + false, + // s1 ["N", "Z"] ==> no rows can pass (not keep) + false, + // s1 ["M", "M"] ==> no rows can pass (not keep) + false, + // s1 [NULL, NULL] 
==> unknown (must keep) + true, + // s1 ["A", NULL] ==> no rows can pass (not keep) + false, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> all rows must pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> no rows can pass (not keep) + false, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no rows can pass (not keep) + false, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + } + #[test] fn test_rewrite_expr_to_prunable() { let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); @@ -3399,6 +4073,74 @@ mod tests { // TODO: add test for other case and op } + #[test] + fn test_rewrite_expr_to_prunable_custom_unhandled_hook() { + struct CustomUnhandledHook; + + impl UnhandledPredicateHook for CustomUnhandledHook { + /// This handles an arbitrary case of a column that doesn't exist in the schema + /// by renaming it to yet another column that doesn't exist in the schema + /// (the transformation is arbitrary, the point is that it can do whatever it wants) + fn handle(&self, _expr: &Arc) -> Arc { + Arc::new(phys_expr::Literal::from(ScalarValue::Int32(Some(42)))) + } + } + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let schema_with_b = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let rewriter = PredicateRewriter::new() + .with_unhandled_hook(Arc::new(CustomUnhandledHook {})); + + let transform_expr = |expr| { + let expr = logical2physical(&expr, &schema_with_b); + rewriter.rewrite_predicate_to_statistics_predicate(&expr, &schema) + }; + + // transform an arbitrary valid expression that we know is handled + let known_expression = col("a").eq(lit(12)); + let known_expression_transformed = PredicateRewriter::new() + .rewrite_predicate_to_statistics_predicate( + &logical2physical(&known_expression, &schema), + &schema, + ); + + // an expression referencing an unknown column (that is not in the schema) gets passed to the hook + let input = col("b").eq(lit(12)); + let expected = logical2physical(&lit(42), &schema); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // more complex case with unknown column + let input = known_expression.clone().and(input.clone()); + let expected = phys_expr::BinaryExpr::new( + Arc::::clone(&known_expression_transformed), + Operator::And, + logical2physical(&lit(42), &schema), + ); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // an unknown expression gets passed to the hook + let input = array_has(make_array(vec![lit(1)]), col("a")); + let expected = logical2physical(&lit(42), &schema); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // more complex case with unknown expression + let input = known_expression.and(input); + let expected = phys_expr::BinaryExpr::new( + Arc::::clone(&known_expression_transformed), + Operator::And, + logical2physical(&lit(42), &schema), + ); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + } + #[test] fn test_rewrite_expr_to_prunable_error() { // cast string value to numeric value @@ -3877,7 +4619,7 @@ mod tests { ) { println!("Pruning with expr: {}", expr); let expr = logical2physical(&expr, schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + 
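Editor's note: beyond the arbitrary constant used in the test above, a more typical hook would map anything the rewriter cannot handle to the literal `true`, i.e. "this container might match, keep it". A hedged sketch using the `UnhandledPredicateHook` trait and `PredicateRewriter` introduced in this diff (the hook name is hypothetical, and the exact trait bounds are assumed from the test):

```rust
struct KeepContainerHook;

impl UnhandledPredicateHook for KeepContainerHook {
    /// Any predicate that cannot be rewritten into a statistics expression is
    /// treated as "unknown", so no container is ever pruned because of it.
    fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
        Arc::new(phys_expr::Literal::from(ScalarValue::Boolean(Some(true))))
    }
}

// Hypothetical wiring, mirroring the test above:
// let rewriter = PredicateRewriter::new().with_unhandled_hook(Arc::new(KeepContainerHook));
```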
let p = PruningPredicate::try_new(expr, Arc::::clone(schema)).unwrap(); let result = p.prune(statistics).unwrap(); assert_eq!(result, expected); } @@ -3888,6 +4630,7 @@ mod tests { required_columns: &mut RequiredColumns, ) -> Arc { let expr = logical2physical(expr, schema); - build_predicate_expression(&expr, schema, required_columns) + let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; + build_predicate_expression(&expr, schema, required_columns, &unhandled_hook) } } diff --git a/datafusion/physical-optimizer/src/sanity_checker.rs b/datafusion/physical-optimizer/src/sanity_checker.rs new file mode 100644 index 0000000000000..1cf89ed8d8a48 --- /dev/null +++ b/datafusion/physical-optimizer/src/sanity_checker.rs @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The [SanityCheckPlan] rule ensures that a given plan can +//! accommodate its infinite sources, if there are any. It will reject +//! non-runnable query plans that use pipeline-breaking operators on +//! infinite input(s). In addition, it will check if all order and +//! distribution requirements of a plan are satisfied by its children. + +use std::sync::Arc; + +use datafusion_common::Result; +use datafusion_physical_plan::ExecutionPlan; + +use datafusion_common::config::{ConfigOptions, OptimizerOptions}; +use datafusion_common::plan_err; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion_physical_plan::joins::SymmetricHashJoinExec; +use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; + +use crate::PhysicalOptimizerRule; +use datafusion_physical_expr_common::sort_expr::format_physical_sort_requirement_list; +use itertools::izip; + +/// The SanityCheckPlan rule rejects the following query plans: +/// 1. Invalid plans containing nodes whose order and/or distribution requirements +/// are not satisfied by their children. +/// 2. 
Plans that use pipeline-breaking operators on infinite input(s), +/// it is impossible to execute such queries (they will never generate output nor finish) +#[derive(Default, Debug)] +pub struct SanityCheckPlan {} + +impl SanityCheckPlan { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for SanityCheckPlan { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_up(|p| check_plan_sanity(p, &config.optimizer)) + .data() + } + + fn name(&self) -> &str { + "SanityCheckPlan" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// This function propagates finiteness information and rejects any plan with +/// pipeline-breaking operators acting on infinite inputs. +pub fn check_finiteness_requirements( + input: Arc, + optimizer_options: &OptimizerOptions, +) -> Result>> { + if let Some(exec) = input.as_any().downcast_ref::() { + if !(optimizer_options.allow_symmetric_joins_without_pruning + || (exec.check_if_order_information_available()? && is_prunable(exec))) + { + return plan_err!("Join operation cannot operate on a non-prunable stream without enabling \ + the 'allow_symmetric_joins_without_pruning' configuration flag"); + } + } + + if matches!( + input.boundedness(), + Boundedness::Unbounded { + requires_infinite_memory: true + } + ) || (input.boundedness().is_unbounded() + && input.pipeline_behavior() == EmissionType::Final) + { + plan_err!( + "Cannot execute pipeline breaking queries, operator: {:?}", + input + ) + } else { + Ok(Transformed::no(input)) + } +} + +/// This function returns whether a given symmetric hash join is amenable to +/// data pruning. For this to be possible, it needs to have a filter where +/// all involved [`PhysicalExpr`]s, [`Operator`]s and data types support +/// interval calculations. +/// +/// [`PhysicalExpr`]: datafusion_physical_plan::PhysicalExpr +/// [`Operator`]: datafusion_expr::Operator +fn is_prunable(join: &SymmetricHashJoinExec) -> bool { + join.filter().is_some_and(|filter| { + check_support(filter.expression(), &join.schema()) + && filter + .schema() + .fields() + .iter() + .all(|f| is_datatype_supported(f.data_type())) + }) +} + +/// Ensures that the plan is pipeline friendly and the order and +/// distribution requirements from its children are satisfied. +pub fn check_plan_sanity( + plan: Arc, + optimizer_options: &OptimizerOptions, +) -> Result>> { + check_finiteness_requirements(Arc::clone(&plan), optimizer_options)?; + + for ((idx, child), sort_req, dist_req) in izip!( + plan.children().into_iter().enumerate(), + plan.required_input_ordering(), + plan.required_input_distribution(), + ) { + let child_eq_props = child.equivalence_properties(); + if let Some(sort_req) = sort_req { + if !child_eq_props.ordering_satisfy_requirement(&sort_req) { + let plan_str = get_plan_string(&plan); + return plan_err!( + "Plan: {:?} does not satisfy order requirements: {}. Child-{} order: {}", + plan_str, + format_physical_sort_requirement_list(&sort_req), + idx, + child_eq_props.oeq_class() + ); + } + } + + if !child + .output_partitioning() + .satisfy(&dist_req, child_eq_props) + { + let plan_str = get_plan_string(&plan); + return plan_err!( + "Plan: {:?} does not satisfy distribution requirements: {}. 
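Editor's note: a hedged usage sketch for the new rule (normally it runs as part of DataFusion's default physical optimizer passes; invoking it directly is mainly useful in tests or custom pipelines). The wrapper function name is hypothetical; the `optimize` signature is the one shown in this file:

```rust
use datafusion_common::config::ConfigOptions;

/// Returns the plan unchanged if its ordering/distribution requirements are
/// satisfied and it is executable, otherwise a `plan_err!` describing the violation.
fn assert_plan_is_sane(
    plan: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>> {
    SanityCheckPlan::new().optimize(plan, &ConfigOptions::default())
}
```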
Child-{} output partitioning: {}", + plan_str, + dist_req, + idx, + child.output_partitioning() + ); + } + } + + Ok(Transformed::no(plan)) +} diff --git a/datafusion/physical-optimizer/src/test_utils.rs b/datafusion/physical-optimizer/src/test_utils.rs new file mode 100644 index 0000000000000..9f0b5ddf6f403 --- /dev/null +++ b/datafusion/physical-optimizer/src/test_utils.rs @@ -0,0 +1,336 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Collection of testing utility functions that are leveraged by the query optimizer rules + +use std::sync::Arc; + +use std::any::Any; +use std::fmt::Formatter; + +use arrow_schema::{Schema, SchemaRef, SortOptions}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{JoinType, Result}; +use datafusion_expr::test::function_stub::count_udaf; +use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::joins::utils::{JoinFilter, JoinOn}; +use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; +use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::memory::MemoryExec; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion_physical_plan::union::UnionExec; +use datafusion_physical_plan::windows::{create_window_expr, BoundedWindowAggExec}; +use datafusion_physical_plan::{InputOrderMode, Partitioning}; + +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_plan::ExecutionPlan; + +use datafusion_physical_plan::tree_node::PlanContext; +use datafusion_physical_plan::{ + displayable, DisplayAs, DisplayFormatType, PlanProperties, +}; + +pub fn sort_merge_join_exec( + left: Arc, + right: Arc, + join_on: &JoinOn, + join_type: &JoinType, +) -> Arc { + Arc::new( + SortMergeJoinExec::try_new( + left, + right, + join_on.clone(), + None, + *join_type, + vec![SortOptions::default(); join_on.len()], + false, + ) + .unwrap(), + ) +} + +/// make PhysicalSortExpr with default options +pub fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr { + 
sort_expr_options(name, schema, SortOptions::default()) +} + +/// PhysicalSortExpr with specified options +pub fn sort_expr_options( + name: &str, + schema: &Schema, + options: SortOptions, +) -> PhysicalSortExpr { + PhysicalSortExpr { + expr: col(name, schema).unwrap(), + options, + } +} + +pub fn coalesce_partitions_exec(input: Arc) -> Arc { + Arc::new(CoalescePartitionsExec::new(input)) +} + +pub fn memory_exec(schema: &SchemaRef) -> Arc { + Arc::new(MemoryExec::try_new(&[vec![]], Arc::clone(schema), None).unwrap()) +} + +pub fn hash_join_exec( + left: Arc, + right: Arc, + on: JoinOn, + filter: Option, + join_type: &JoinType, +) -> Result> { + Ok(Arc::new(HashJoinExec::try_new( + left, + right, + on, + filter, + join_type, + None, + PartitionMode::Partitioned, + true, + )?)) +} + +pub fn bounded_window_exec( + col_name: &str, + sort_exprs: impl IntoIterator, + input: Arc, +) -> Arc { + let sort_exprs: LexOrdering = sort_exprs.into_iter().collect(); + let schema = input.schema(); + + Arc::new( + BoundedWindowAggExec::try_new( + vec![create_window_expr( + &WindowFunctionDefinition::AggregateUDF(count_udaf()), + "count".to_owned(), + &[col(col_name, &schema).unwrap()], + &[], + sort_exprs.as_ref(), + Arc::new(WindowFrame::new(Some(false))), + schema.as_ref(), + false, + ) + .unwrap()], + Arc::clone(&input), + vec![], + InputOrderMode::Sorted, + ) + .unwrap(), + ) +} + +pub fn filter_exec( + predicate: Arc, + input: Arc, +) -> Arc { + Arc::new(FilterExec::try_new(predicate, input).unwrap()) +} + +pub fn sort_preserving_merge_exec( + sort_exprs: impl IntoIterator, + input: Arc, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) +} + +pub fn union_exec(input: Vec>) -> Arc { + Arc::new(UnionExec::new(input)) +} + +pub fn limit_exec(input: Arc) -> Arc { + global_limit_exec(local_limit_exec(input)) +} + +pub fn local_limit_exec(input: Arc) -> Arc { + Arc::new(LocalLimitExec::new(input, 100)) +} + +pub fn global_limit_exec(input: Arc) -> Arc { + Arc::new(GlobalLimitExec::new(input, 0, Some(100))) +} + +pub fn repartition_exec(input: Arc) -> Arc { + Arc::new(RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)).unwrap()) +} + +pub fn spr_repartition_exec(input: Arc) -> Arc { + Arc::new( + RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(), + ) +} + +pub fn aggregate_exec(input: Arc) -> Arc { + let schema = input.schema(); + Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![], + vec![], + input, + schema, + ) + .unwrap(), + ) +} + +pub fn coalesce_batches_exec(input: Arc) -> Arc { + Arc::new(CoalesceBatchesExec::new(input, 128)) +} + +pub fn sort_exec( + sort_exprs: impl IntoIterator, + input: Arc, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortExec::new(sort_exprs, input)) +} + +/// A test [`ExecutionPlan`] whose requirements can be configured. 
+#[derive(Debug)] +pub struct RequirementsTestExec { + required_input_ordering: LexOrdering, + maintains_input_order: bool, + input: Arc, +} + +impl RequirementsTestExec { + pub fn new(input: Arc) -> Self { + Self { + required_input_ordering: LexOrdering::default(), + maintains_input_order: true, + input, + } + } + + /// sets the required input ordering + pub fn with_required_input_ordering( + mut self, + required_input_ordering: LexOrdering, + ) -> Self { + self.required_input_ordering = required_input_ordering; + self + } + + /// set the maintains_input_order flag + pub fn with_maintains_input_order(mut self, maintains_input_order: bool) -> Self { + self.maintains_input_order = maintains_input_order; + self + } + + /// returns this ExecutionPlan as an `Arc` + pub fn into_arc(self) -> Arc { + Arc::new(self) + } +} + +impl DisplayAs for RequirementsTestExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "RequiredInputOrderingExec") + } +} + +impl ExecutionPlan for RequirementsTestExec { + fn name(&self) -> &str { + "RequiredInputOrderingExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + self.input.properties() + } + + fn required_input_ordering(&self) -> Vec> { + let requirement = LexRequirement::from(self.required_input_ordering.clone()); + vec![Some(requirement)] + } + + fn maintains_input_order(&self) -> Vec { + vec![self.maintains_input_order] + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 1); + Ok(RequirementsTestExec::new(Arc::clone(&children[0])) + .with_required_input_ordering(self.required_input_ordering.clone()) + .with_maintains_input_order(self.maintains_input_order) + .into_arc()) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!("Test exec does not support execution") + } +} + +/// A [`PlanContext`] object is susceptible to being left in an inconsistent state after +/// untested mutable operations. It is crucial that there be no discrepancies between a plan +/// associated with the root node and the plan generated after traversing all nodes +/// within the [`PlanContext`] tree. In addition to verifying the plans resulting from optimizer +/// rules, it is essential to ensure that the overall tree structure corresponds with the plans +/// contained within the node contexts. +/// TODO: Once [`ExecutionPlan`] implements [`PartialEq`], string comparisons should be +/// replaced with direct plan equality checks. 
+pub fn check_integrity(context: PlanContext) -> Result> { + context + .transform_up(|node| { + let children_plans = node.plan.children(); + assert_eq!(node.children.len(), children_plans.len()); + for (child_plan, child_node) in + children_plans.iter().zip(node.children.iter()) + { + assert_eq!( + displayable(child_plan.as_ref()).one_line().to_string(), + displayable(child_node.plan.as_ref()).one_line().to_string() + ); + } + Ok(Transformed::no(node)) + }) + .data() +} diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index 804dd165d335c..0e5fb82d9e93e 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -19,21 +19,18 @@ use std::sync::Arc; -use datafusion_physical_plan::aggregates::AggregateExec; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion_physical_plan::filter::FilterExec; -use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::sorts::sort::SortExec; -use datafusion_physical_plan::ExecutionPlan; - -use arrow_schema::DataType; +use crate::PhysicalOptimizerRule; +use arrow::datatypes::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::PhysicalSortExpr; - -use crate::PhysicalOptimizerRule; +use datafusion_physical_expr::LexOrdering; +use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::execution_plan::CardinalityEffect; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::ExecutionPlan; use itertools::Itertools; /// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed @@ -48,12 +45,13 @@ impl TopKAggregation { fn transform_agg( aggr: &AggregateExec, - order: &PhysicalSortExpr, + order_by: &str, + order_desc: bool, limit: usize, ) -> Option> { // ensure the sort direction matches aggregate function let (field, desc) = aggr.get_minmax_desc()?; - if desc != order.options.descending { + if desc != order_desc { return None; } let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?; @@ -66,8 +64,7 @@ impl TopKAggregation { } // ensure the sort is on the same field as the aggregate output - let col = order.expr.as_any().downcast_ref::()?; - if col.name() != field.name() { + if order_by != field.name() { return None; } @@ -92,16 +89,11 @@ impl TopKAggregation { let child = children.into_iter().exactly_one().ok()?; let order = sort.properties().output_ordering()?; let order = order.iter().exactly_one().ok()?; + let order_desc = order.options.descending; + let order = order.expr.as_any().downcast_ref::()?; + let mut cur_col_name = order.name().to_string(); let limit = sort.fetch()?; - let is_cardinality_preserving = |plan: Arc| { - plan.as_any() - .downcast_ref::() - .is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - }; - let mut cardinality_preserved = true; let closure = |plan: Arc| { if !cardinality_preserved { @@ -109,20 +101,33 @@ impl TopKAggregation { } if let Some(aggr) = plan.as_any().downcast_ref::() { // either we run into an Aggregate and transform it - match Self::transform_agg(aggr, order, limit) { + match Self::transform_agg(aggr, 
&cur_col_name, order_desc, limit) { None => cardinality_preserved = false, Some(plan) => return Ok(Transformed::yes(plan)), } + } else if let Some(proj) = plan.as_any().downcast_ref::() { + // track renames due to successive projections + for (src_expr, proj_name) in proj.expr() { + let Some(src_col) = src_expr.as_any().downcast_ref::() else { + continue; + }; + if *proj_name == cur_col_name { + cur_col_name = src_col.name().to_string(); + } + } } else { - // or we continue down whitelisted nodes of other types - if !is_cardinality_preserving(Arc::clone(&plan)) { - cardinality_preserved = false; + // or we continue down through types that don't reduce cardinality + match plan.cardinality_effect() { + CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {} + CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => { + cardinality_preserved = false; + } } } Ok(Transformed::no(plan)) }; let child = Arc::clone(child).transform_down(closure).data().ok()?; - let sort = SortExec::new(sort.expr().to_vec(), child) + let sort = SortExec::new(LexOrdering::new(sort.expr().to_vec()), child) .with_fetch(sort.fetch()) .with_preserve_partitioning(sort.preserve_partitioning()); Some(Arc::new(sort)) diff --git a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs b/datafusion/physical-optimizer/src/update_aggr_exprs.rs similarity index 78% rename from datafusion/core/src/physical_optimizer/update_aggr_exprs.rs rename to datafusion/physical-optimizer/src/update_aggr_exprs.rs index c0d9140c025e5..6228ed10ec341 100644 --- a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs +++ b/datafusion/physical-optimizer/src/update_aggr_exprs.rs @@ -27,13 +27,15 @@ use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::{ reverse_order_bys, EquivalenceProperties, PhysicalSortRequirement, }; -use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_expr::{LexOrdering, LexRequirement}; use datafusion_physical_plan::aggregates::concat_slices; use datafusion_physical_plan::windows::get_ordered_partition_by_indices; use datafusion_physical_plan::{ aggregates::AggregateExec, ExecutionPlan, ExecutionPlanProperties, }; +use crate::PhysicalOptimizerRule; + /// This optimizer rule checks ordering requirements of aggregate expressions. /// /// There are 3 kinds of aggregators in terms of ordering requirements: @@ -59,6 +61,20 @@ impl OptimizeAggregateOrder { } impl PhysicalOptimizerRule for OptimizeAggregateOrder { + /// Applies the `OptimizeAggregateOrder` rule to the provided execution plan. + /// + /// This function traverses the execution plan tree, identifies `AggregateExec` nodes, + /// and optimizes their aggregate expressions based on existing input orderings. + /// If optimizations are applied, it returns a modified execution plan. + /// + /// # Arguments + /// + /// * `plan` - The root of the execution plan to optimize. + /// * `_config` - Configuration options (currently unused). + /// + /// # Returns + /// + /// A `Result` containing the potentially optimized execution plan or an error. fn optimize( &self, plan: Arc, @@ -84,7 +100,12 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder { let requirement = indices .iter() .map(|&idx| { - PhysicalSortRequirement::new(groupby_exprs[idx].clone(), None) + PhysicalSortRequirement::new( + Arc::::clone( + &groupby_exprs[idx], + ), + None, + ) }) .collect::>(); @@ -131,19 +152,17 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder { /// successfully. 
Any errors occurring during the conversion process are /// passed through. fn try_convert_aggregate_if_better( - aggr_exprs: Vec, + aggr_exprs: Vec>, prefix_requirement: &[PhysicalSortRequirement], eq_properties: &EquivalenceProperties, -) -> Result> { +) -> Result>> { aggr_exprs .into_iter() .map(|aggr_expr| { - let aggr_sort_exprs = aggr_expr.order_bys().unwrap_or(&[]); + let aggr_sort_exprs = aggr_expr.order_bys().unwrap_or(LexOrdering::empty()); let reverse_aggr_sort_exprs = reverse_order_bys(aggr_sort_exprs); - let aggr_sort_reqs = - PhysicalSortRequirement::from_sort_exprs(aggr_sort_exprs); - let reverse_aggr_req = - PhysicalSortRequirement::from_sort_exprs(&reverse_aggr_sort_exprs); + let aggr_sort_reqs = LexRequirement::from(aggr_sort_exprs.clone()); + let reverse_aggr_req = LexRequirement::from(reverse_aggr_sort_exprs); // If the aggregate expression benefits from input ordering, and // there is an actual ordering enabling this, try to update the @@ -151,24 +170,32 @@ fn try_convert_aggregate_if_better( // Otherwise, leave it as is. if aggr_expr.order_sensitivity().is_beneficial() && !aggr_sort_reqs.is_empty() { - let reqs = concat_slices(prefix_requirement, &aggr_sort_reqs); + let reqs = LexRequirement { + inner: concat_slices(prefix_requirement, &aggr_sort_reqs), + }; + + let prefix_requirement = LexRequirement { + inner: prefix_requirement.to_vec(), + }; + if eq_properties.ordering_satisfy_requirement(&reqs) { // Existing ordering satisfies the aggregator requirements: - aggr_expr.with_beneficial_ordering(true)? - } else if eq_properties.ordering_satisfy_requirement(&concat_slices( - prefix_requirement, - &reverse_aggr_req, - )) { + aggr_expr.with_beneficial_ordering(true)?.map(Arc::new) + } else if eq_properties.ordering_satisfy_requirement(&LexRequirement { + inner: concat_slices(&prefix_requirement, &reverse_aggr_req), + }) { // Converting to reverse enables more efficient execution // given the existing ordering (if possible): aggr_expr .reverse_expr() + .map(Arc::new) .unwrap_or(aggr_expr) .with_beneficial_ordering(true)? + .map(Arc::new) } else { // There is no beneficial ordering present -- aggregation // will still work albeit in a less efficient mode. - aggr_expr.with_beneficial_ordering(false)? 
+ aggr_expr.with_beneficial_ordering(false)?.map(Arc::new) } .ok_or_else(|| { plan_datafusion_err!( diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index c3f1b7eb0e95c..83dc9549531de 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -51,8 +51,6 @@ datafusion-common = { workspace = true, default-features = true } datafusion-common-runtime = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-functions-aggregate = { workspace = true } -datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } datafusion-physical-expr-common = { workspace = true } @@ -62,13 +60,16 @@ hashbrown = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } -once_cell = "1.18.0" parking_lot = { workspace = true } pin-project-lite = "^0.2.7" -rand = { workspace = true } tokio = { workspace = true } [dev-dependencies] +criterion = { version = "0.5", features = ["async_futures"] } +datafusion-functions-aggregate = { workspace = true } +datafusion-functions-window = { workspace = true } +once_cell = "1.18.0" +rand = { workspace = true } rstest = { workspace = true } rstest_reuse = "0.7.0" tokio = { workspace = true, features = [ @@ -76,3 +77,7 @@ tokio = { workspace = true, features = [ "fs", "parking_lot", ] } + +[[bench]] +harness = false +name = "spm" diff --git a/datafusion/physical-plan/LICENSE.txt b/datafusion/physical-plan/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/physical-plan/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/physical-plan/NOTICE.txt b/datafusion/physical-plan/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/physical-plan/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/physical-plan/benches/spm.rs b/datafusion/physical-plan/benches/spm.rs new file mode 100644 index 0000000000000..fbbd274091738 --- /dev/null +++ b/datafusion/physical-plan/benches/spm.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow::record_batch::RecordBatch; +use arrow_array::{ArrayRef, Int32Array, Int64Array, StringArray}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::memory::MemoryExec; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion_physical_plan::{collect, ExecutionPlan}; + +use criterion::async_executor::FuturesExecutor; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +fn generate_spm_for_round_robin_tie_breaker( + has_same_value: bool, + enable_round_robin_repartition: bool, + batch_count: usize, + partition_count: usize, +) -> SortPreservingMergeExec { + let row_size = 256; + let rb = if has_same_value { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size])); + let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size])); + RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap() + } else { + let v = (0i32..row_size as i32).collect::>(); + let a: ArrayRef = Arc::new(Int32Array::from(v)); + + // Use alphanumeric characters + let charset: Vec = + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + .chars() + .collect(); + + let mut strings = Vec::new(); + for i in 0..256 { + let mut s = String::new(); + s.push(charset[i % charset.len()]); + s.push(charset[(i / charset.len()) % charset.len()]); + strings.push(Some(s)); + } + + let b: ArrayRef = Arc::new(StringArray::from_iter(strings)); + + let v = (0i64..row_size as i64).collect::>(); + let c: ArrayRef = Arc::new(Int64Array::from_iter(v)); + RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap() + }; + + let rbs = (0..batch_count).map(|_| rb.clone()).collect::>(); + let partitiones = vec![rbs.clone(); partition_count]; + + let schema = rb.schema(); + let sort = LexOrdering::new(vec![ + PhysicalSortExpr { + expr: col("b", &schema).unwrap(), + options: Default::default(), + }, + PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: Default::default(), + }, + ]); + + let exec = MemoryExec::try_new(&partitiones, schema, None).unwrap(); + SortPreservingMergeExec::new(sort, Arc::new(exec)) + .with_round_robin_repartition(enable_round_robin_repartition) +} + +fn run_bench( + c: &mut Criterion, + has_same_value: bool, + enable_round_robin_repartition: bool, + batch_count: usize, + partition_count: usize, + description: &str, +) { + let task_ctx = TaskContext::default(); + let task_ctx = Arc::new(task_ctx); + + let spm = Arc::new(generate_spm_for_round_robin_tie_breaker( + has_same_value, + enable_round_robin_repartition, + batch_count, + partition_count, + )) as Arc; + + c.bench_function(description, |b| { + b.to_async(FuturesExecutor) + .iter(|| black_box(collect(Arc::clone(&spm), Arc::clone(&task_ctx)))) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let params = [ + (true, false, "low_card_without_tiebreaker"), // low cardinality, no tie breaker + (true, true, "low_card_with_tiebreaker"), // low cardinality, with tie breaker + (false, false, "high_card_without_tiebreaker"), // high cardinality, no tie breaker + (false, true, "high_card_with_tiebreaker"), // high cardinality, with tie breaker + ]; + + let batch_counts = [1, 25, 625]; + let partition_counts = [2, 8, 32]; + + for &(has_same_value, 
enable_round_robin_repartition, cardinality_label) in &params { + for &batch_count in &batch_counts { + for &partition_count in &partition_counts { + let description = format!( + "{}_batch_count_{}_partition_count_{}", + cardinality_label, batch_count, partition_count + ); + run_bench( + c, + has_same_value, + enable_round_robin_repartition, + batch_count, + partition_count, + &description, + ); + } + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index fb7b667750924..e4a7eb049e9eb 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -18,26 +18,36 @@ //! [`GroupValues`] trait for storing and interning group keys use arrow::record_batch::RecordBatch; +use arrow_array::types::{ + Date32Type, Date64Type, Decimal128Type, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; use arrow_array::{downcast_primitive, ArrayRef}; +use arrow_schema::TimeUnit; use arrow_schema::{DataType, SchemaRef}; -use bytes_view::GroupValuesBytesView; use datafusion_common::Result; -pub(crate) mod primitive; use datafusion_expr::EmitTo; -use primitive::GroupValuesPrimitive; -mod column; +pub(crate) mod multi_group_by; + mod row; -use column::GroupValuesColumn; +mod single_group_by; +use datafusion_physical_expr::binary_map::OutputType; +use multi_group_by::GroupValuesColumn; use row::GroupValuesRows; -mod bytes; -mod bytes_view; -use bytes::GroupValuesByes; -use datafusion_physical_expr::binary_map::OutputType; +pub(crate) use single_group_by::primitive::HashValue; + +use crate::aggregates::{ + group_values::single_group_by::{ + bytes::GroupValuesByes, bytes_view::GroupValuesBytesView, + primitive::GroupValuesPrimitive, + }, + order::GroupOrdering, +}; -mod group_column; mod null_builder; /// Stores the group values during hash aggregation. @@ -76,7 +86,7 @@ mod null_builder; /// Each distinct group in a hash aggregation is identified by a unique group id /// (usize) which is assigned by instances of this trait. Group ids are /// continuous without gaps, starting from 0. -pub trait GroupValues: Send { +pub(crate) trait GroupValues: Send { /// Calculates the group id for each input row of `cols`, assigning new /// group ids as necessary. /// @@ -105,7 +115,24 @@ pub trait GroupValues: Send { } /// Return a specialized implementation of [`GroupValues`] for the given schema. -pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> { +/// +/// How the [`GroupValues`] implementation is chosen: +/// +/// - If grouping by a single column whose type has a specialized +/// [`GroupValues`] implementation, that implementation is chosen. +/// +/// - If grouping by multiple columns and all column types have specialized +/// [`GroupColumn`] implementations, [`GroupValuesColumn`] is chosen. +/// +/// - Otherwise, the general implementation [`GroupValuesRows`] is chosen.
+/// +/// [`GroupColumn`]: crate::aggregates::group_values::multi_group_by::GroupColumn +/// +pub(crate) fn new_group_values( + schema: SchemaRef, + group_ordering: &GroupOrdering, +) -> Result<Box<dyn GroupValues>> { if schema.fields.len() == 1 { let d = schema.fields[0].data_type(); @@ -121,6 +148,31 @@ pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> { } match d { + DataType::Date32 => { + downcast_helper!(Date32Type, d); + } + DataType::Date64 => { + downcast_helper!(Date64Type, d); + } + DataType::Time32(t) => match t { + TimeUnit::Second => downcast_helper!(Time32SecondType, d), + TimeUnit::Millisecond => downcast_helper!(Time32MillisecondType, d), + _ => {} + }, + DataType::Time64(t) => match t { + TimeUnit::Microsecond => downcast_helper!(Time64MicrosecondType, d), + TimeUnit::Nanosecond => downcast_helper!(Time64NanosecondType, d), + _ => {} + }, + DataType::Timestamp(t, _tz) => match t { + TimeUnit::Second => downcast_helper!(TimestampSecondType, d), + TimeUnit::Millisecond => downcast_helper!(TimestampMillisecondType, d), + TimeUnit::Microsecond => downcast_helper!(TimestampMicrosecondType, d), + TimeUnit::Nanosecond => downcast_helper!(TimestampNanosecondType, d), + }, + DataType::Decimal128(_, _) => { + downcast_helper!(Decimal128Type, d); + } DataType::Utf8 => { return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Utf8))); } @@ -143,8 +195,12 @@ pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> { } } - if GroupValuesColumn::supported_schema(schema.as_ref()) { - Ok(Box::new(GroupValuesColumn::try_new(schema)?)) + if multi_group_by::supported_schema(schema.as_ref()) { + if matches!(group_ordering, GroupOrdering::None) { + Ok(Box::new(GroupValuesColumn::<false>::try_new(schema)?)) + } else { + Ok(Box::new(GroupValuesColumn::<true>::try_new(schema)?)) + } } else { Ok(Box::new(GroupValuesRows::try_new(schema)?)) } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs new file mode 100644 index 0000000000000..8e975e10180f1 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs @@ -0,0 +1,633 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
+ +use crate::aggregates::group_values::multi_group_by::{nulls_equal_to, GroupColumn}; +use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow::array::{AsArray, BufferBuilder, GenericBinaryArray, GenericStringArray}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{ByteArrayType, DataType, GenericBinaryType}; +use arrow_array::types::GenericStringType; +use arrow_array::{Array, ArrayRef, GenericByteArray, OffsetSizeTrait}; +use datafusion_common::utils::proxy::VecAllocExt; +use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; +use itertools::izip; +use std::mem::size_of; +use std::sync::Arc; +use std::vec; + +/// An implementation of [`GroupColumn`] for binary and utf8 types. +/// +/// Stores a collection of binary or utf8 group values in a single buffer +/// in a way that allows: +/// +/// 1. Efficient comparison of incoming rows to existing rows +/// 2. Efficient construction of the final output array +pub struct ByteGroupValueBuilder +where + O: OffsetSizeTrait, +{ + output_type: OutputType, + buffer: BufferBuilder, + /// Offsets into `buffer` for each distinct value. These offsets as used + /// directly to create the final `GenericBinaryArray`. The `i`th string is + /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values + /// are stored as a zero length string. + offsets: Vec, + /// Nulls + nulls: MaybeNullBufferBuilder, +} + +impl ByteGroupValueBuilder +where + O: OffsetSizeTrait, +{ + pub fn new(output_type: OutputType) -> Self { + Self { + output_type, + buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), + offsets: vec![O::default()], + nulls: MaybeNullBufferBuilder::new(), + } + } + + fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool + where + B: ByteArrayType, + { + let array = array.as_bytes::(); + self.do_equal_to_inner(lhs_row, array, rhs_row) + } + + fn append_val_inner(&mut self, array: &ArrayRef, row: usize) + where + B: ByteArrayType, + { + let arr = array.as_bytes::(); + if arr.is_null(row) { + self.nulls.append(true); + // nulls need a zero length in the offset buffer + let offset = self.buffer.len(); + self.offsets.push(O::usize_as(offset)); + } else { + self.nulls.append(false); + self.do_append_val_inner(arr, row); + } + } + + fn vectorized_equal_to_inner( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ) where + B: ByteArrayType, + { + let array = array.as_bytes::(); + + let iter = izip!( + lhs_rows.iter(), + rhs_rows.iter(), + equal_to_results.iter_mut(), + ); + + for (&lhs_row, &rhs_row, equal_to_result) in iter { + // Has found not equal to, don't need to check + if !*equal_to_result { + continue; + } + + *equal_to_result = self.do_equal_to_inner(lhs_row, array, rhs_row); + } + } + + fn vectorized_append_inner(&mut self, array: &ArrayRef, rows: &[usize]) + where + B: ByteArrayType, + { + let arr = array.as_bytes::(); + let null_count = array.null_count(); + let num_rows = array.len(); + let all_null_or_non_null = if null_count == 0 { + Some(true) + } else if null_count == num_rows { + Some(false) + } else { + None + }; + + match all_null_or_non_null { + None => { + for &row in rows { + if arr.is_null(row) { + self.nulls.append(true); + // nulls need a zero length in the offset buffer + let offset = self.buffer.len(); + self.offsets.push(O::usize_as(offset)); + } else { + self.nulls.append(false); + self.do_append_val_inner(arr, row); + } + } + } + + Some(true) 
=> { + self.nulls.append_n(rows.len(), false); + for &row in rows { + self.do_append_val_inner(arr, row); + } + } + + Some(false) => { + self.nulls.append_n(rows.len(), true); + + let new_len = self.offsets.len() + rows.len(); + let offset = self.buffer.len(); + self.offsets.resize(new_len, O::usize_as(offset)); + } + } + } + + fn do_equal_to_inner( + &self, + lhs_row: usize, + array: &GenericByteArray, + rhs_row: usize, + ) -> bool + where + B: ByteArrayType, + { + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + // Otherwise, we need to check their values + self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8]) + } + + fn do_append_val_inner(&mut self, array: &GenericByteArray, row: usize) + where + B: ByteArrayType, + { + let value: &[u8] = array.value(row).as_ref(); + self.buffer.append_slice(value); + self.offsets.push(O::usize_as(self.buffer.len())); + } + + /// return the current value of the specified row irrespective of null + pub fn value(&self, row: usize) -> &[u8] { + let l = self.offsets[row].as_usize(); + let r = self.offsets[row + 1].as_usize(); + // Safety: the offsets are constructed correctly and never decrease + unsafe { self.buffer.as_slice().get_unchecked(l..r) } + } +} + +impl GroupColumn for ByteGroupValueBuilder +where + O: OffsetSizeTrait, +{ + fn equal_to(&self, lhs_row: usize, column: &ArrayRef, rhs_row: usize) -> bool { + // Sanity array type + match self.output_type { + OutputType::Binary => { + debug_assert!(matches!( + column.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.equal_to_inner::>(lhs_row, column, rhs_row) + } + OutputType::Utf8 => { + debug_assert!(matches!( + column.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + )); + self.equal_to_inner::>(lhs_row, column, rhs_row) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } + + fn append_val(&mut self, column: &ArrayRef, row: usize) { + // Sanity array type + match self.output_type { + OutputType::Binary => { + debug_assert!(matches!( + column.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.append_val_inner::>(column, row) + } + OutputType::Utf8 => { + debug_assert!(matches!( + column.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + )); + self.append_val_inner::>(column, row) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + }; + } + + fn vectorized_equal_to( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ) { + // Sanity array type + match self.output_type { + OutputType::Binary => { + debug_assert!(matches!( + array.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.vectorized_equal_to_inner::>( + lhs_rows, + array, + rhs_rows, + equal_to_results, + ); + } + OutputType::Utf8 => { + debug_assert!(matches!( + array.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + )); + self.vectorized_equal_to_inner::>( + lhs_rows, + array, + rhs_rows, + equal_to_results, + ); + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } + + fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) { + match self.output_type { + OutputType::Binary => { + debug_assert!(matches!( + column.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.vectorized_append_inner::>(column, rows) + } + OutputType::Utf8 => { + debug_assert!(matches!( + column.data_type(), + 
DataType::Utf8 | DataType::LargeUtf8 + )); + self.vectorized_append_inner::>(column, rows) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + }; + } + + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn size(&self) -> usize { + self.buffer.capacity() * size_of::() + + self.offsets.allocated_size() + + self.nulls.allocated_size() + } + + fn build(self: Box) -> ArrayRef { + let Self { + output_type, + mut buffer, + offsets, + nulls, + } = *self; + + let null_buffer = nulls.build(); + + // SAFETY: the offsets were constructed correctly in `insert_if_new` -- + // monotonically increasing, overflows were checked. + let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) }; + let values = buffer.finish(); + match output_type { + OutputType::Binary => { + // SAFETY: the offsets were constructed correctly + Arc::new(unsafe { + GenericBinaryArray::new_unchecked(offsets, values, null_buffer) + }) + } + OutputType::Utf8 => { + // SAFETY: + // 1. the offsets were constructed safely + // + // 2. the input arrays were all the correct type and thus since + // all the values that went in were valid (e.g. utf8) so are all + // the values that come out + Arc::new(unsafe { + GenericStringArray::new_unchecked(offsets, values, null_buffer) + }) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + debug_assert!(self.len() >= n); + let null_buffer = self.nulls.take_n(n); + let first_remaining_offset = O::as_usize(self.offsets[n]); + + // Given offsets like [0, 2, 4, 5] and n = 1, we expect to get + // offsets [0, 2, 3]. We first create two offsets for first_n as [0, 2] and the remaining as [2, 4, 5]. + // And we shift the offset starting from 0 for the remaining one, [2, 4, 5] -> [0, 2, 3]. + let mut first_n_offsets = self.offsets.drain(0..n).collect::>(); + let offset_n = *self.offsets.first().unwrap(); + self.offsets + .iter_mut() + .for_each(|offset| *offset = offset.sub(offset_n)); + first_n_offsets.push(offset_n); + + // SAFETY: the offsets were constructed correctly in `insert_if_new` -- + // monotonically increasing, overflows were checked. + let offsets = + unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(first_n_offsets)) }; + + let mut remaining_buffer = + BufferBuilder::new(self.buffer.len() - first_remaining_offset); + // TODO: Current approach copy the remaining and truncate the original one + // Find out a way to avoid copying buffer but split the original one into two. + remaining_buffer.append_slice(&self.buffer.as_slice()[first_remaining_offset..]); + self.buffer.truncate(first_remaining_offset); + let values = self.buffer.finish(); + self.buffer = remaining_buffer; + + match self.output_type { + OutputType::Binary => { + // SAFETY: the offsets were constructed correctly + Arc::new(unsafe { + GenericBinaryArray::new_unchecked(offsets, values, null_buffer) + }) + } + OutputType::Utf8 => { + // SAFETY: + // 1. the offsets were constructed safely + // + // 2. we asserted the input arrays were all the correct type and + // thus since all the values that went in were valid (e.g. 
utf8) + // so are all the values that come out + Arc::new(unsafe { + GenericStringArray::new_unchecked(offsets, values, null_buffer) + }) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::aggregates::group_values::multi_group_by::bytes::ByteGroupValueBuilder; + use arrow_array::{ArrayRef, StringArray}; + use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; + use datafusion_physical_expr::binary_map::OutputType; + + use super::GroupColumn; + + #[test] + fn test_byte_take_n() { + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + let array = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef; + // a, null, null + builder.append_val(&array, 0); + builder.append_val(&array, 1); + builder.append_val(&array, 1); + + // (a, null) remaining: null + let output = builder.take_n(2); + assert_eq!(&output, &array); + + // null, a, null, a + builder.append_val(&array, 0); + builder.append_val(&array, 1); + builder.append_val(&array, 0); + + // (null, a) remaining: (null, a) + let output = builder.take_n(2); + let array = Arc::new(StringArray::from(vec![None, Some("a")])) as ArrayRef; + assert_eq!(&output, &array); + + let array = Arc::new(StringArray::from(vec![ + Some("a"), + None, + Some("longstringfortest"), + ])) as ArrayRef; + + // null, a, longstringfortest, null, null + builder.append_val(&array, 2); + builder.append_val(&array, 1); + builder.append_val(&array, 1); + + // (null, a, longstringfortest, null) remaining: (null) + let output = builder.take_n(4); + let array = Arc::new(StringArray::from(vec![ + None, + Some("a"), + Some("longstringfortest"), + None, + ])) as ArrayRef; + assert_eq!(&output, &array); + } + + #[test] + fn test_byte_equal_to() { + let append = |builder: &mut ByteGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + for &index in append_rows { + builder.append_val(builder_array, index); + } + }; + + let equal_to = |builder: &ByteGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); + } + }; + + test_byte_equal_to_internal(append, equal_to); + } + + #[test] + fn test_byte_vectorized_equal_to() { + let append = |builder: &mut ByteGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + builder.vectorized_append(builder_array, append_rows); + }; + + let equal_to = |builder: &ByteGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; + + test_byte_equal_to_internal(append, equal_to); + } + + #[test] + fn test_byte_vectorized_operation_special_case() { + // Test the special `all nulls` or `not nulls` input array case + // for vectorized append and equal to + + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + + // All nulls input array + let all_nulls_input_array = Arc::new(StringArray::from(vec![ + Option::<&str>::None, + None, + None, + None, + None, + ])) as _; + builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + + let mut equal_to_results = vec![true; all_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[0, 1, 2, 3, 4], + &all_nulls_input_array, + &[0, 1, 
2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + + // All not nulls input array + let all_not_nulls_input_array = Arc::new(StringArray::from(vec![ + Some("string1"), + Some("string2"), + Some("string3"), + Some("string4"), + Some("string5"), + ])) as _; + builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + + let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[5, 6, 7, 8, 9], + &all_not_nulls_input_array, + &[0, 1, 2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + } + + fn test_byte_equal_to_internal(mut append: A, mut equal_to: E) + where + A: FnMut(&mut ByteGroupValueBuilder, &ArrayRef, &[usize]), + E: FnMut( + &ByteGroupValueBuilder, + &[usize], + &ArrayRef, + &[usize], + &mut Vec, + ), + { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define PrimitiveGroupValueBuilder + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + let builder_array = Arc::new(StringArray::from(vec![ + None, + None, + None, + Some("foo"), + Some("bar"), + Some("baz"), + ])) as ArrayRef; + append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]); + + // Define input array + let (offsets, buffer, _nulls) = StringArray::from(vec![ + Some("foo"), + Some("bar"), + None, + None, + Some("foo"), + Some("baz"), + ]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some("bar") to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = + Arc::new(StringArray::new(offsets, buffer, Some(nulls))) as ArrayRef; + + // Check + let mut equal_to_results = vec![true; builder.len()]; + equal_to( + &builder, + &[0, 1, 2, 3, 4, 5], + &input_array, + &[0, 1, 2, 3, 4, 5], + &mut equal_to_results, + ); + + assert!(!equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(!equal_to_results[3]); + assert!(!equal_to_results[4]); + assert!(equal_to_results[5]); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs new file mode 100644 index 0000000000000..811790f4e5885 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs @@ -0,0 +1,911 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::multi_group_by::{nulls_equal_to, GroupColumn}; +use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow::array::{make_view, AsArray, ByteView}; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::ByteViewType; +use arrow_array::{Array, ArrayRef, GenericByteViewArray}; +use arrow_buffer::Buffer; +use itertools::izip; +use std::marker::PhantomData; +use std::mem::{replace, size_of}; +use std::sync::Arc; + +const BYTE_VIEW_MAX_BLOCK_SIZE: usize = 2 * 1024 * 1024; + +/// An implementation of [`GroupColumn`] for binary view and utf8 view types. +/// +/// Stores a collection of binary view or utf8 view group values in a buffer +/// whose structure is similar to `GenericByteViewArray`, which provides the following benefits: +/// +/// 1. Efficient comparison of incoming rows to existing rows +/// 2. Efficient construction of the final output array +/// 3. More efficient `take_n` compared to using `GenericByteViewBuilder` +pub struct ByteViewGroupValueBuilder<B: ByteViewType> { + /// The views of string values + /// + /// If string len <= 12, the view's format will be: + /// string(12B) | len(4B) + /// + /// If string len > 12, its format will be: + /// offset(4B) | buffer_index(4B) | prefix(4B) | len(4B) + views: Vec<u128>, + + /// The in-progress block + /// + /// New values will be inserted into it until its capacity + /// is not enough (see `max_block_size` for details). + in_progress: Vec<u8>, + + /// The completed blocks + completed: Vec<Buffer>, + + /// The max size of `in_progress` + /// + /// `in_progress` will be flushed into `completed`, and a new `in_progress` + /// created, when its remaining capacity (`max_block_size` - `len(in_progress)`) + /// is not enough to store the appended value. + /// + /// Currently it is fixed at 2MB.
+ max_block_size: usize, + + /// Nulls + nulls: MaybeNullBufferBuilder, + + /// phantom data so the type requires `` + _phantom: PhantomData, +} + +impl ByteViewGroupValueBuilder { + pub fn new() -> Self { + Self { + views: Vec::new(), + in_progress: Vec::new(), + completed: Vec::new(), + max_block_size: BYTE_VIEW_MAX_BLOCK_SIZE, + nulls: MaybeNullBufferBuilder::new(), + _phantom: PhantomData {}, + } + } + + /// Set the max block size + fn with_max_block_size(mut self, max_block_size: usize) -> Self { + self.max_block_size = max_block_size; + self + } + + fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { + let array = array.as_byte_view::(); + self.do_equal_to_inner(lhs_row, array, rhs_row) + } + + fn append_val_inner(&mut self, array: &ArrayRef, row: usize) { + let arr = array.as_byte_view::(); + + // Null row case, set and return + if arr.is_null(row) { + self.nulls.append(true); + self.views.push(0); + return; + } + + // Not null row case + self.nulls.append(false); + self.do_append_val_inner(arr, row); + } + + fn vectorized_equal_to_inner( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ) { + let array = array.as_byte_view::(); + + let iter = izip!( + lhs_rows.iter(), + rhs_rows.iter(), + equal_to_results.iter_mut(), + ); + + for (&lhs_row, &rhs_row, equal_to_result) in iter { + // Has found not equal to, don't need to check + if !*equal_to_result { + continue; + } + + *equal_to_result = self.do_equal_to_inner(lhs_row, array, rhs_row); + } + } + + fn vectorized_append_inner(&mut self, array: &ArrayRef, rows: &[usize]) { + let arr = array.as_byte_view::(); + let null_count = array.null_count(); + let num_rows = array.len(); + let all_null_or_non_null = if null_count == 0 { + Some(true) + } else if null_count == num_rows { + Some(false) + } else { + None + }; + + match all_null_or_non_null { + None => { + for &row in rows { + // Null row case, set and return + if arr.is_valid(row) { + self.nulls.append(false); + self.do_append_val_inner(arr, row); + } else { + self.nulls.append(true); + self.views.push(0); + } + } + } + + Some(true) => { + self.nulls.append_n(rows.len(), false); + for &row in rows { + self.do_append_val_inner(arr, row); + } + } + + Some(false) => { + self.nulls.append_n(rows.len(), true); + let new_len = self.views.len() + rows.len(); + self.views.resize(new_len, 0); + } + } + } + + fn do_append_val_inner(&mut self, array: &GenericByteViewArray, row: usize) + where + B: ByteViewType, + { + let value: &[u8] = array.value(row).as_ref(); + + let value_len = value.len(); + let view = if value_len <= 12 { + make_view(value, 0, 0) + } else { + // Ensure big enough block to hold the value firstly + self.ensure_in_progress_big_enough(value_len); + + // Append value + let buffer_index = self.completed.len(); + let offset = self.in_progress.len(); + self.in_progress.extend_from_slice(value); + + make_view(value, buffer_index as u32, offset as u32) + }; + + // Append view + self.views.push(view); + } + + fn ensure_in_progress_big_enough(&mut self, value_len: usize) { + debug_assert!(value_len > 12); + let require_cap = self.in_progress.len() + value_len; + + // If current block isn't big enough, flush it and create a new in progress block + if require_cap > self.max_block_size { + let flushed_block = replace( + &mut self.in_progress, + Vec::with_capacity(self.max_block_size), + ); + let buffer = Buffer::from_vec(flushed_block); + self.completed.push(buffer); + } + } + + fn 
do_equal_to_inner( + &self, + lhs_row: usize, + array: &GenericByteViewArray, + rhs_row: usize, + ) -> bool { + // Check if nulls equal firstly + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + + // Otherwise, we need to check their values + let exist_view = self.views[lhs_row]; + let exist_view_len = exist_view as u32; + + let input_view = array.views()[rhs_row]; + let input_view_len = input_view as u32; + + // The check logic + // - Check len equality + // - If inlined, check inlined value + // - If non-inlined, check prefix and then check value in buffer + // when needed + if exist_view_len != input_view_len { + return false; + } + + if exist_view_len <= 12 { + let exist_inline = unsafe { + GenericByteViewArray::::inline_value( + &exist_view, + exist_view_len as usize, + ) + }; + let input_inline = unsafe { + GenericByteViewArray::::inline_value( + &input_view, + input_view_len as usize, + ) + }; + exist_inline == input_inline + } else { + let exist_prefix = + unsafe { GenericByteViewArray::::inline_value(&exist_view, 4) }; + let input_prefix = + unsafe { GenericByteViewArray::::inline_value(&input_view, 4) }; + + if exist_prefix != input_prefix { + return false; + } + + let exist_full = { + let byte_view = ByteView::from(exist_view); + self.value( + byte_view.buffer_index as usize, + byte_view.offset as usize, + byte_view.length as usize, + ) + }; + let input_full: &[u8] = unsafe { array.value_unchecked(rhs_row).as_ref() }; + exist_full == input_full + } + } + + fn value(&self, buffer_index: usize, offset: usize, length: usize) -> &[u8] { + debug_assert!(buffer_index <= self.completed.len()); + + if buffer_index < self.completed.len() { + let block = &self.completed[buffer_index]; + &block[offset..offset + length] + } else { + &self.in_progress[offset..offset + length] + } + } + + fn build_inner(self) -> ArrayRef { + let Self { + views, + in_progress, + mut completed, + nulls, + .. + } = self; + + // Build nulls + let null_buffer = nulls.build(); + + // Build values + // Flush `in_process` firstly + if !in_progress.is_empty() { + let buffer = Buffer::from(in_progress); + completed.push(buffer); + } + + let views = ScalarBuffer::from(views); + + // Safety: + // * all views were correctly made + // * (if utf8): Input was valid Utf8 so buffer contents are + // valid utf8 as well + unsafe { + Arc::new(GenericByteViewArray::::new_unchecked( + views, + completed, + null_buffer, + )) + } + } + + fn take_n_inner(&mut self, n: usize) -> ArrayRef { + debug_assert!(self.len() >= n); + + // The `n == len` case, we need to take all + if self.len() == n { + let new_builder = Self::new().with_max_block_size(self.max_block_size); + let cur_builder = replace(self, new_builder); + return cur_builder.build_inner(); + } + + // The `n < len` case + // Take n for nulls + let null_buffer = self.nulls.take_n(n); + + // Take n for values: + // - Take first n `view`s from `views` + // + // - Find the last non-inlined `view`, if all inlined, + // we can build array and return happily, otherwise we + // we need to continue to process related buffers + // + // - Get the last related `buffer index`(let's name it `buffer index n`) + // from last non-inlined `view` + // + // - Take buffers, the key is that we need to know if we need to take + // the whole last related buffer. 
The logic is a bit complex, you can + // detail in `take_buffers_with_whole_last`, `take_buffers_with_partial_last` + // and other related steps in following + // + // - Shift the `buffer index` of remaining non-inlined `views` + // + let first_n_views = self.views.drain(0..n).collect::>(); + + let last_non_inlined_view = first_n_views + .iter() + .rev() + .find(|view| ((**view) as u32) > 12); + + // All taken views inlined + let Some(view) = last_non_inlined_view else { + let views = ScalarBuffer::from(first_n_views); + + // Safety: + // * all views were correctly made + // * (if utf8): Input was valid Utf8 so buffer contents are + // valid utf8 as well + unsafe { + return Arc::new(GenericByteViewArray::::new_unchecked( + views, + Vec::new(), + null_buffer, + )); + } + }; + + // Unfortunately, some taken views non-inlined + let view = ByteView::from(*view); + let last_remaining_buffer_index = view.buffer_index as usize; + + // Check should we take the whole `last_remaining_buffer_index` buffer + let take_whole_last_buffer = self.should_take_whole_buffer( + last_remaining_buffer_index, + (view.offset + view.length) as usize, + ); + + // Take related buffers + let buffers = if take_whole_last_buffer { + self.take_buffers_with_whole_last(last_remaining_buffer_index) + } else { + self.take_buffers_with_partial_last( + last_remaining_buffer_index, + (view.offset + view.length) as usize, + ) + }; + + // Shift `buffer index`s finally + let shifts = if take_whole_last_buffer { + last_remaining_buffer_index + 1 + } else { + last_remaining_buffer_index + }; + + self.views.iter_mut().for_each(|view| { + if (*view as u32) > 12 { + let mut byte_view = ByteView::from(*view); + byte_view.buffer_index -= shifts as u32; + *view = byte_view.as_u128(); + } + }); + + // Build array and return + let views = ScalarBuffer::from(first_n_views); + + // Safety: + // * all views were correctly made + // * (if utf8): Input was valid Utf8 so buffer contents are + // valid utf8 as well + unsafe { + Arc::new(GenericByteViewArray::::new_unchecked( + views, + buffers, + null_buffer, + )) + } + } + + fn take_buffers_with_whole_last( + &mut self, + last_remaining_buffer_index: usize, + ) -> Vec { + if last_remaining_buffer_index == self.completed.len() { + self.flush_in_progress(); + } + self.completed + .drain(0..last_remaining_buffer_index + 1) + .collect() + } + + fn take_buffers_with_partial_last( + &mut self, + last_remaining_buffer_index: usize, + last_take_len: usize, + ) -> Vec { + let mut take_buffers = Vec::with_capacity(last_remaining_buffer_index + 1); + + // Take `0 ~ last_remaining_buffer_index - 1` buffers + if !self.completed.is_empty() || last_remaining_buffer_index == 0 { + take_buffers.extend(self.completed.drain(0..last_remaining_buffer_index)); + } + + // Process the `last_remaining_buffer_index` buffers + let last_buffer = if last_remaining_buffer_index < self.completed.len() { + // If it is in `completed`, simply clone + self.completed[last_remaining_buffer_index].clone() + } else { + // If it is `in_progress`, copied `0 ~ offset` part + let taken_last_buffer = self.in_progress[0..last_take_len].to_vec(); + Buffer::from_vec(taken_last_buffer) + }; + take_buffers.push(last_buffer); + + take_buffers + } + + #[inline] + fn should_take_whole_buffer(&self, buffer_index: usize, take_len: usize) -> bool { + if buffer_index < self.completed.len() { + take_len == self.completed[buffer_index].len() + } else { + take_len == self.in_progress.len() + } + } + + fn flush_in_progress(&mut self) { + let flushed_block 
= replace( + &mut self.in_progress, + Vec::with_capacity(self.max_block_size), + ); + let buffer = Buffer::from_vec(flushed_block); + self.completed.push(buffer); + } +} + +impl GroupColumn for ByteViewGroupValueBuilder { + fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { + self.equal_to_inner(lhs_row, array, rhs_row) + } + + fn append_val(&mut self, array: &ArrayRef, row: usize) { + self.append_val_inner(array, row) + } + + fn vectorized_equal_to( + &self, + group_indices: &[usize], + array: &ArrayRef, + rows: &[usize], + equal_to_results: &mut [bool], + ) { + self.vectorized_equal_to_inner(group_indices, array, rows, equal_to_results); + } + + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) { + self.vectorized_append_inner(array, rows); + } + + fn len(&self) -> usize { + self.views.len() + } + + fn size(&self) -> usize { + let buffers_size = self + .completed + .iter() + .map(|buf| buf.capacity() * size_of::()) + .sum::(); + + self.nulls.allocated_size() + + self.views.capacity() * size_of::() + + self.in_progress.capacity() * size_of::() + + buffers_size + + size_of::() + } + + fn build(self: Box) -> ArrayRef { + Self::build_inner(*self) + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + self.take_n_inner(n) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::aggregates::group_values::multi_group_by::bytes_view::ByteViewGroupValueBuilder; + use arrow::array::AsArray; + use arrow::datatypes::StringViewType; + use arrow_array::{ArrayRef, StringViewArray}; + use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; + + use super::GroupColumn; + + #[test] + fn test_byte_view_append_val() { + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + let builder_array = StringViewArray::from(vec![ + Some("this string is quite long"), // in buffer 0 + Some("foo"), + None, + Some("bar"), + Some("this string is also quite long"), // buffer 0 + Some("this string is quite long"), // buffer 1 + Some("bar"), + ]); + let builder_array: ArrayRef = Arc::new(builder_array); + for row in 0..builder_array.len() { + builder.append_val(&builder_array, row); + } + + let output = Box::new(builder).build(); + // should be 2 output buffers to hold all the data + assert_eq!(output.as_string_view().data_buffers().len(), 2); + assert_eq!(&output, &builder_array) + } + + #[test] + fn test_byte_view_equal_to() { + let append = |builder: &mut ByteViewGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + for &index in append_rows { + builder.append_val(builder_array, index); + } + }; + + let equal_to = |builder: &ByteViewGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); + } + }; + + test_byte_view_equal_to_internal(append, equal_to); + } + + #[test] + fn test_byte_view_vectorized_equal_to() { + let append = |builder: &mut ByteViewGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + builder.vectorized_append(builder_array, append_rows); + }; + + let equal_to = |builder: &ByteViewGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; + + 
test_byte_view_equal_to_internal(append, equal_to); + } + + #[test] + fn test_byte_view_vectorized_operation_special_case() { + // Test the special `all nulls` or `not nulls` input array case + // for vectorized append and equal to + + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + + // All nulls input array + let all_nulls_input_array = Arc::new(StringViewArray::from(vec![ + Option::<&str>::None, + None, + None, + None, + None, + ])) as _; + builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + + let mut equal_to_results = vec![true; all_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[0, 1, 2, 3, 4], + &all_nulls_input_array, + &[0, 1, 2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + + // All not nulls input array + let all_not_nulls_input_array = Arc::new(StringViewArray::from(vec![ + Some("stringview1"), + Some("stringview2"), + Some("stringview3"), + Some("stringview4"), + Some("stringview5"), + ])) as _; + builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + + let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[5, 6, 7, 8, 9], + &all_not_nulls_input_array, + &[0, 1, 2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + } + + fn test_byte_view_equal_to_internal(mut append: A, mut equal_to: E) + where + A: FnMut(&mut ByteViewGroupValueBuilder, &ArrayRef, &[usize]), + E: FnMut( + &ByteViewGroupValueBuilder, + &[usize], + &ArrayRef, + &[usize], + &mut Vec, + ), + { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; value lens not equal + // - exist not null, input not null; value not equal(inlined case) + // - exist not null, input not null; value equal(inlined case) + // + // - exist not null, input not null; value not equal + // (non-inlined case + prefix not equal) + // + // - exist not null, input not null; value not equal + // (non-inlined case + value in `completed`) + // + // - exist not null, input not null; value equal + // (non-inlined case + value in `completed`) + // + // - exist not null, input not null; value not equal + // (non-inlined case + value in `in_progress`) + // + // - exist not null, input not null; value equal + // (non-inlined case + value in `in_progress`) + + // Set the block size to 40 for ensuring some unlined values are in `in_progress`, + // and some are in `completed`, so both two branches in `value` function can be covered. 
+ let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + let builder_array = Arc::new(StringViewArray::from(vec![ + None, + None, + None, + Some("foo"), + Some("bazz"), + Some("foo"), + Some("bar"), + Some("I am a long string for test eq in completed"), + Some("I am a long string for test eq in progress"), + ])) as ArrayRef; + append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5, 6, 7, 8]); + + // Define input array + let (views, buffer, _nulls) = StringViewArray::from(vec![ + Some("foo"), + Some("bar"), // set to null + None, + None, + Some("baz"), + Some("oof"), + Some("bar"), + Some("i am a long string for test eq in completed"), + Some("I am a long string for test eq in COMPLETED"), + Some("I am a long string for test eq in completed"), + Some("I am a long string for test eq in PROGRESS"), + Some("I am a long string for test eq in progress"), + ]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(9); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some("bar") to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = + Arc::new(StringViewArray::new(views, buffer, Some(nulls))) as ArrayRef; + + // Check + let mut equal_to_results = vec![true; input_array.len()]; + equal_to( + &builder, + &[0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 8, 8], + &input_array, + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + &mut equal_to_results, + ); + + assert!(!equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(!equal_to_results[3]); + assert!(!equal_to_results[4]); + assert!(!equal_to_results[5]); + assert!(equal_to_results[6]); + assert!(!equal_to_results[7]); + assert!(!equal_to_results[8]); + assert!(equal_to_results[9]); + assert!(!equal_to_results[10]); + assert!(equal_to_results[11]); + } + + #[test] + fn test_byte_view_take_n() { + // ####### Define cases and init ####### + + // `take_n` is really complex, we should consider and test following situations: + // 1. Take nulls + // 2. Take all `inlined`s + // 3. Take non-inlined + partial last buffer in `completed` + // 4. Take non-inlined + whole last buffer in `completed` + // 5. Take non-inlined + partial last `in_progress` + // 6. Take non-inlined + whole last buffer in `in_progress` + // 7. 
Take all views at once + + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + let input_array = StringViewArray::from(vec![ + // Test situation 1 + None, + None, + // Test situation 2 (also test take null together) + None, + Some("foo"), + Some("bar"), + // Test situation 3 (also test take null + inlined) + None, + Some("foo"), + Some("this string is quite long"), + Some("this string is also quite long"), + // Test situation 4 (also test take null + inlined) + None, + Some("bar"), + Some("this string is quite long"), + // Test situation 5 (also test take null + inlined) + None, + Some("foo"), + Some("another string that is is quite long"), + Some("this string not so long"), + // Test situation 6 (also test take null + inlined + insert again after taking) + None, + Some("bar"), + Some("this string is quite long"), + // Insert 4 and just take 3 to ensure it will go the path of situation 6 + None, + // Finally, we create a new builder, insert the whole array and then + // take whole at once for testing situation 7 + ]); + + let input_array: ArrayRef = Arc::new(input_array); + let first_ones_to_append = 16; // For testing situation 1~5 + let second_ones_to_append = 4; // For testing situation 6 + let final_ones_to_append = input_array.len(); // For testing situation 7 + + // ####### Test situation 1~5 ####### + for row in 0..first_ones_to_append { + builder.append_val(&input_array, row); + } + + assert_eq!(builder.completed.len(), 2); + assert_eq!(builder.in_progress.len(), 59); + + // Situation 1 + let taken_array = builder.take_n(2); + assert_eq!(&taken_array, &input_array.slice(0, 2)); + + // Situation 2 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(2, 3)); + + // Situation 3 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(5, 3)); + + let taken_array = builder.take_n(1); + assert_eq!(&taken_array, &input_array.slice(8, 1)); + + // Situation 4 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(9, 3)); + + // Situation 5 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(12, 3)); + + let taken_array = builder.take_n(1); + assert_eq!(&taken_array, &input_array.slice(15, 1)); + + // ####### Test situation 6 ####### + assert!(builder.completed.is_empty()); + assert!(builder.in_progress.is_empty()); + assert!(builder.views.is_empty()); + + for row in first_ones_to_append..first_ones_to_append + second_ones_to_append { + builder.append_val(&input_array, row); + } + + assert!(builder.completed.is_empty()); + assert_eq!(builder.in_progress.len(), 25); + + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(16, 3)); + + // ####### Test situation 7 ####### + // Create a new builder + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + + for row in 0..final_ones_to_append { + builder.append_val(&input_array, row); + } + + assert_eq!(builder.completed.len(), 3); + assert_eq!(builder.in_progress.len(), 25); + + let taken_array = builder.take_n(final_ones_to_append); + assert_eq!(&taken_array, &input_array); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs new file mode 100644 index 0000000000000..540f9c3c64804 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -0,0 +1,1796 @@ +// Licensed to the Apache 
Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `GroupValues` implementations for multi group by cases + +mod bytes; +mod bytes_view; +mod primitive; + +use std::mem::{self, size_of}; + +use crate::aggregates::group_values::multi_group_by::{ + bytes::ByteGroupValueBuilder, bytes_view::ByteViewGroupValueBuilder, + primitive::PrimitiveGroupValueBuilder, +}; +use crate::aggregates::group_values::GroupValues; +use ahash::RandomState; +use arrow::compute::cast; +use arrow::datatypes::{ + BinaryViewType, Date32Type, Date64Type, Decimal128Type, Float32Type, Float64Type, + Int16Type, Int32Type, Int64Type, Int8Type, StringViewType, Time32MillisecondType, + Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow::record_batch::RecordBatch; +use arrow_array::{Array, ArrayRef}; +use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit}; +use datafusion_common::hash_utils::create_hashes; +use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt}; +use datafusion_expr::EmitTo; +use datafusion_physical_expr::binary_map::OutputType; + +use hashbrown::hash_table::HashTable; + +const NON_INLINED_FLAG: u64 = 0x8000000000000000; +const VALUE_MASK: u64 = 0x7FFFFFFFFFFFFFFF; + +/// Trait for storing a single column of group values in [`GroupValuesColumn`] +/// +/// Implementations of this trait store an in-progress collection of group values +/// (similar to various builders in Arrow-rs) that allow for quick comparison to +/// incoming rows. +/// +/// [`GroupValuesColumn`]: crate::aggregates::group_values::GroupValuesColumn +pub trait GroupColumn: Send + Sync { + /// Returns equal if the row stored in this builder at `lhs_row` is equal to + /// the row in `array` at `rhs_row` + /// + /// Note that this comparison returns true if both elements are NULL + fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool; + + /// Appends the row at `row` in `array` to this builder + fn append_val(&mut self, array: &ArrayRef, row: usize); + + /// The vectorized version equal to + /// + /// When found nth row stored in this builder at `lhs_row` + /// is equal to the row in `array` at `rhs_row`, + /// it will record the `true` result at the corresponding + /// position in `equal_to_results`. + /// + /// And if found nth result in `equal_to_results` is already + /// `false`, the check for nth row will be skipped. 
+ /// + fn vectorized_equal_to( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ); + + /// The vectorized version `append_val` + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]); + + /// Returns the number of rows stored in this builder + fn len(&self) -> usize; + + /// Returns the number of bytes used by this [`GroupColumn`] + fn size(&self) -> usize; + + /// Builds a new array from all of the stored rows + fn build(self: Box) -> ArrayRef; + + /// Builds a new array from the first `n` stored rows, shifting the + /// remaining rows to the start of the builder + fn take_n(&mut self, n: usize) -> ArrayRef; +} + +/// Determines if the nullability of the existing and new input array can be used +/// to short-circuit the comparison of the two values. +/// +/// Returns `Some(result)` if the result of the comparison can be determined +/// from the nullness of the two values, and `None` if the comparison must be +/// done on the values themselves. +pub fn nulls_equal_to(lhs_null: bool, rhs_null: bool) -> Option { + match (lhs_null, rhs_null) { + (true, true) => Some(true), + (false, true) | (true, false) => Some(false), + _ => None, + } +} + +/// The view of indices pointing to the actual values in `GroupValues` +/// +/// If only single `group index` represented by view, +/// value of view is just the `group index`, and we call it a `inlined view`. +/// +/// If multiple `group indices` represented by view, +/// value of view is the actually the index pointing to `group indices`, +/// and we call it `non-inlined view`. +/// +/// The view(a u64) format is like: +/// +---------------------+---------------------------------------------+ +/// | inlined flag(1bit) | group index / index to group indices(63bit) | +/// +---------------------+---------------------------------------------+ +/// +/// `inlined flag`: 1 represents `non-inlined`, and 0 represents `inlined` +/// +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct GroupIndexView(u64); + +impl GroupIndexView { + #[inline] + pub fn is_non_inlined(&self) -> bool { + (self.0 & NON_INLINED_FLAG) > 0 + } + + #[inline] + pub fn new_inlined(group_index: u64) -> Self { + Self(group_index) + } + + #[inline] + pub fn new_non_inlined(list_offset: u64) -> Self { + let non_inlined_value = list_offset | NON_INLINED_FLAG; + Self(non_inlined_value) + } + + #[inline] + pub fn value(&self) -> u64 { + self.0 & VALUE_MASK + } +} + +/// A [`GroupValues`] that stores multiple columns of group values, +/// and supports vectorized operators for them +/// +pub struct GroupValuesColumn { + /// The output schema + schema: SchemaRef, + + /// Logically maps group values to a group_index in + /// [`Self::group_values`] and in each accumulator + /// + /// It is a `hashtable` based on `hashbrown`. + /// + /// Key and value in the `hashtable`: + /// - The `key` is `hash value(u64)` of the `group value` + /// - The `value` is the `group values` with the same `hash value` + /// + /// We don't really store the actual `group values` in `hashtable`, + /// instead we store the `group indices` pointing to values in `GroupValues`. + /// And we use [`GroupIndexView`] to represent such `group indices` in table. 
+ /// + /// + map: HashTable<(u64, GroupIndexView)>, + + /// The size of `map` in bytes + map_size: usize, + + /// The lists for group indices with the same hash value + /// + /// It is possible that hash value collision exists, + /// and we will chain the `group indices` with same hash value + /// + /// The chained indices is like: + /// `latest group index -> older group index -> even older group index -> ...` + /// + group_index_lists: Vec>, + + /// When emitting first n, we need to decrease/erase group indices in + /// `map` and `group_index_lists`. + /// + /// This buffer is used to temporarily store the remaining group indices in + /// a specific list in `group_index_lists`. + emit_group_index_list_buffer: Vec, + + /// Buffers for `vectorized_append` and `vectorized_equal_to` + vectorized_operation_buffers: VectorizedOperationBuffers, + + /// The actual group by values, stored column-wise. Compare from + /// the left to right, each column is stored as [`GroupColumn`]. + /// + /// Performance tests showed that this design is faster than using the + /// more general purpose [`GroupValuesRows`]. See the ticket for details: + /// + /// + /// [`GroupValuesRows`]: crate::aggregates::group_values::row::GroupValuesRows + group_values: Vec>, + + /// reused buffer to store hashes + hashes_buffer: Vec, + + /// Random state for creating hashes + random_state: RandomState, +} + +/// Buffers to store intermediate results in `vectorized_append` +/// and `vectorized_equal_to`, for reducing memory allocation +#[derive(Default)] +struct VectorizedOperationBuffers { + /// The `vectorized append` row indices buffer + append_row_indices: Vec, + + /// The `vectorized_equal_to` row indices buffer + equal_to_row_indices: Vec, + + /// The `vectorized_equal_to` group indices buffer + equal_to_group_indices: Vec, + + /// The `vectorized_equal_to` result buffer + equal_to_results: Vec, + + /// The buffer for storing row indices found not equal to + /// exist groups in `group_values` in `vectorized_equal_to`. + /// We will perform `scalarized_intern` for such rows. + remaining_row_indices: Vec, +} + +impl VectorizedOperationBuffers { + fn clear(&mut self) { + self.append_row_indices.clear(); + self.equal_to_row_indices.clear(); + self.equal_to_group_indices.clear(); + self.equal_to_results.clear(); + self.remaining_row_indices.clear(); + } +} + +impl GroupValuesColumn { + // ======================================================================== + // Initialization functions + // ======================================================================== + + /// Create a new instance of GroupValuesColumn if supported for the specified schema + pub fn try_new(schema: SchemaRef) -> Result { + let map = HashTable::with_capacity(0); + Ok(Self { + schema, + map, + group_index_lists: Vec::new(), + emit_group_index_list_buffer: Vec::new(), + vectorized_operation_buffers: VectorizedOperationBuffers::default(), + map_size: 0, + group_values: vec![], + hashes_buffer: Default::default(), + random_state: Default::default(), + }) + } + + // ======================================================================== + // Scalarized intern + // ======================================================================== + + /// Scalarized intern + /// + /// This is used only for `streaming aggregation`, because `streaming aggregation` + /// depends on the order between `input rows` and their corresponding `group indices`. 
+ /// + /// For example, assuming `input rows` in `cols` with 4 new rows + /// (not equal to `exist rows` in `group_values`, and need to create + /// new groups for them): + /// + /// ```text + /// row1 (hash collision with the exist rows) + /// row2 + /// row3 (hash collision with the exist rows) + /// row4 + /// ``` + /// + /// # In `scalarized_intern`, their `group indices` will be + /// + /// ```text + /// row1 --> 0 + /// row2 --> 1 + /// row3 --> 2 + /// row4 --> 3 + /// ``` + /// + /// `Group indices` order agrees with their input order, and the `streaming aggregation` + /// depends on this. + /// + /// # However In `vectorized_intern`, their `group indices` will be + /// + /// ```text + /// row1 --> 2 + /// row2 --> 0 + /// row3 --> 3 + /// row4 --> 1 + /// ``` + /// + /// `Group indices` order are against with their input order, and this will lead to error + /// in `streaming aggregation`. + /// + fn scalarized_intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + ) -> Result<()> { + let n_rows = cols[0].len(); + + // tracks to which group each of the input rows belongs + groups.clear(); + + // 1.1 Calculate the group keys for the group values + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(n_rows, 0); + create_hashes(cols, &self.random_state, batch_hashes)?; + + for (row, &target_hash) in batch_hashes.iter().enumerate() { + let entry = self + .map + .find_mut(target_hash, |(exist_hash, group_idx_view)| { + // It is ensured to be inlined in `scalarized_intern` + debug_assert!(!group_idx_view.is_non_inlined()); + + // Somewhat surprisingly, this closure can be called even if the + // hash doesn't match, so check the hash first with an integer + // comparison first avoid the more expensive comparison with + // group value. 
https://github.com/apache/datafusion/pull/11718 + if target_hash != *exist_hash { + return false; + } + + fn check_row_equal( + array_row: &dyn GroupColumn, + lhs_row: usize, + array: &ArrayRef, + rhs_row: usize, + ) -> bool { + array_row.equal_to(lhs_row, array, rhs_row) + } + + for (i, group_val) in self.group_values.iter().enumerate() { + if !check_row_equal( + group_val.as_ref(), + group_idx_view.value() as usize, + &cols[i], + row, + ) { + return false; + } + } + + true + }); + + let group_idx = match entry { + // Existing group_index for this group value + Some((_hash, group_idx_view)) => group_idx_view.value() as usize, + // 1.2 Need to create new entry for the group + None => { + // Add new entry to aggr_state and save newly created index + // let group_idx = group_values.num_rows(); + // group_values.push(group_rows.row(row)); + + let mut checklen = 0; + let group_idx = self.group_values[0].len(); + for (i, group_value) in self.group_values.iter_mut().enumerate() { + group_value.append_val(&cols[i], row); + let len = group_value.len(); + if i == 0 { + checklen = len; + } else { + debug_assert_eq!(checklen, len); + } + } + + // for hasher function, use precomputed hash value + self.map.insert_accounted( + (target_hash, GroupIndexView::new_inlined(group_idx as u64)), + |(hash, _group_index)| *hash, + &mut self.map_size, + ); + group_idx + } + }; + groups.push(group_idx); + } + + Ok(()) + } + + // ======================================================================== + // Vectorized intern + // ======================================================================== + + /// Vectorized intern + /// + /// This is used in `non-streaming aggregation` without requiring the order between + /// rows in `cols` and corresponding groups in `group_values`. + /// + /// The vectorized approach can offer higher performance for avoiding row by row + /// downcast for `cols` and being able to implement even more optimizations(like simd). + /// + fn vectorized_intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + ) -> Result<()> { + let n_rows = cols[0].len(); + + // tracks to which group each of the input rows belongs + groups.clear(); + groups.resize(n_rows, usize::MAX); + + let mut batch_hashes = mem::take(&mut self.hashes_buffer); + batch_hashes.clear(); + batch_hashes.resize(n_rows, 0); + create_hashes(cols, &self.random_state, &mut batch_hashes)?; + + // General steps for one round `vectorized equal_to & append`: + // 1. Collect vectorized context by checking hash values of `cols` in `map`, + // mainly fill `vectorized_append_row_indices`, `vectorized_equal_to_row_indices` + // and `vectorized_equal_to_group_indices` + // + // 2. Perform `vectorized_append` for `vectorized_append_row_indices`. + // `vectorized_append` must be performed before `vectorized_equal_to`, + // because some `group indices` in `vectorized_equal_to_group_indices` + // maybe still point to no actual values in `group_values` before performing append. + // + // 3. Perform `vectorized_equal_to` for `vectorized_equal_to_row_indices` + // and `vectorized_equal_to_group_indices`. If found some rows in input `cols` + // not equal to `exist rows` in `group_values`, place them in `remaining_row_indices` + // and perform `scalarized_intern_remaining` for them similar as `scalarized_intern` + // after. + // + // 4. Perform `scalarized_intern_remaining` for rows mentioned above, about in what situation + // we will process this can see the comments of `scalarized_intern_remaining`. + // + + // 1. 
Collect vectorized context by checking hash values of `cols` in `map` + self.collect_vectorized_process_context(&batch_hashes, groups); + + // 2. Perform `vectorized_append` + self.vectorized_append(cols); + + // 3. Perform `vectorized_equal_to` + self.vectorized_equal_to(cols, groups); + + // 4. Perform scalarized inter for remaining rows + // (about remaining rows, can see comments for `remaining_row_indices`) + self.scalarized_intern_remaining(cols, &batch_hashes, groups); + + self.hashes_buffer = batch_hashes; + + Ok(()) + } + + /// Collect vectorized context by checking hash values of `cols` in `map` + /// + /// 1. If bucket not found + /// - Build and insert the `new inlined group index view` + /// and its hash value to `map` + /// - Add row index to `vectorized_append_row_indices` + /// - Set group index to row in `groups` + /// + /// 2. bucket found + /// - Add row index to `vectorized_equal_to_row_indices` + /// - Check if the `group index view` is `inlined` or `non_inlined`: + /// If it is inlined, add to `vectorized_equal_to_group_indices` directly. + /// Otherwise get all group indices from `group_index_lists`, and add them. + /// + fn collect_vectorized_process_context( + &mut self, + batch_hashes: &[u64], + groups: &mut [usize], + ) { + self.vectorized_operation_buffers.append_row_indices.clear(); + self.vectorized_operation_buffers + .equal_to_row_indices + .clear(); + self.vectorized_operation_buffers + .equal_to_group_indices + .clear(); + + let mut group_values_len = self.group_values[0].len(); + for (row, &target_hash) in batch_hashes.iter().enumerate() { + let entry = self + .map + .find(target_hash, |(exist_hash, _)| target_hash == *exist_hash); + + let Some((_, group_index_view)) = entry else { + // 1. Bucket not found case + // Build `new inlined group index view` + let current_group_idx = group_values_len; + let group_index_view = + GroupIndexView::new_inlined(current_group_idx as u64); + + // Insert the `group index view` and its hash into `map` + // for hasher function, use precomputed hash value + self.map.insert_accounted( + (target_hash, group_index_view), + |(hash, _)| *hash, + &mut self.map_size, + ); + + // Add row index to `vectorized_append_row_indices` + self.vectorized_operation_buffers + .append_row_indices + .push(row); + + // Set group index to row in `groups` + groups[row] = current_group_idx; + + group_values_len += 1; + continue; + }; + + // 2. bucket found + // Check if the `group index view` is `inlined` or `non_inlined` + if group_index_view.is_non_inlined() { + // Non-inlined case, the value of view is offset in `group_index_lists`. + // We use it to get `group_index_list`, and add related `rows` and `group_indices` + // into `vectorized_equal_to_row_indices` and `vectorized_equal_to_group_indices`. 
+ let list_offset = group_index_view.value() as usize; + let group_index_list = &self.group_index_lists[list_offset]; + for &group_index in group_index_list { + self.vectorized_operation_buffers + .equal_to_row_indices + .push(row); + self.vectorized_operation_buffers + .equal_to_group_indices + .push(group_index); + } + } else { + let group_index = group_index_view.value() as usize; + self.vectorized_operation_buffers + .equal_to_row_indices + .push(row); + self.vectorized_operation_buffers + .equal_to_group_indices + .push(group_index); + } + } + } + + /// Perform `vectorized_append`` for `rows` in `vectorized_append_row_indices` + fn vectorized_append(&mut self, cols: &[ArrayRef]) { + if self + .vectorized_operation_buffers + .append_row_indices + .is_empty() + { + return; + } + + let iter = self.group_values.iter_mut().zip(cols.iter()); + for (group_column, col) in iter { + group_column.vectorized_append( + col, + &self.vectorized_operation_buffers.append_row_indices, + ); + } + } + + /// Perform `vectorized_equal_to` + /// + /// 1. Perform `vectorized_equal_to` for `rows` in `vectorized_equal_to_group_indices` + /// and `group_indices` in `vectorized_equal_to_group_indices`. + /// + /// 2. Check `equal_to_results`: + /// + /// If found equal to `rows`, set the `group_indices` to `rows` in `groups`. + /// + /// If found not equal to `row`s, just add them to `scalarized_indices`, + /// and perform `scalarized_intern` for them after. + /// Usually, such `rows` having same hash but different value with `exists rows` + /// are very few. + fn vectorized_equal_to(&mut self, cols: &[ArrayRef], groups: &mut [usize]) { + assert_eq!( + self.vectorized_operation_buffers + .equal_to_group_indices + .len(), + self.vectorized_operation_buffers.equal_to_row_indices.len() + ); + + self.vectorized_operation_buffers + .remaining_row_indices + .clear(); + + if self + .vectorized_operation_buffers + .equal_to_group_indices + .is_empty() + { + return; + } + + // 1. Perform `vectorized_equal_to` for `rows` in `vectorized_equal_to_group_indices` + // and `group_indices` in `vectorized_equal_to_group_indices` + let mut equal_to_results = + mem::take(&mut self.vectorized_operation_buffers.equal_to_results); + equal_to_results.clear(); + equal_to_results.resize( + self.vectorized_operation_buffers + .equal_to_group_indices + .len(), + true, + ); + + for (col_idx, group_col) in self.group_values.iter().enumerate() { + group_col.vectorized_equal_to( + &self.vectorized_operation_buffers.equal_to_group_indices, + &cols[col_idx], + &self.vectorized_operation_buffers.equal_to_row_indices, + &mut equal_to_results, + ); + } + + // 2. Check `equal_to_results`, if found not equal to `row`s, just add them + // to `scalarized_indices`, and perform `scalarized_intern` for them after. 
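The loop that follows implements this check. As a hedged aside, the same idea can be read in isolation: candidate `(row, group_index)` pairs for one input row are contiguous, their per-pair results are OR-combined, and a row with no matching candidate falls back to the scalarized path. The names below are illustrative, not the DataFusion code.

```rust
// Hedged, self-contained sketch of the result-splitting idea used by the
// loop below (illustrative names, plain slices instead of the real buffers).
fn split_equal_to_results(
    row_indices: &[usize],
    candidate_group_indices: &[usize],
    equal_to_results: &[bool],
    groups: &mut [usize],
    remaining: &mut Vec<usize>,
) {
    let mut current_row_matched = false;
    for (idx, &row) in row_indices.iter().enumerate() {
        if equal_to_results[idx] {
            // Matched one of the candidate groups for this row
            groups[row] = candidate_group_indices[idx];
        }
        current_row_matched |= equal_to_results[idx];

        // Peek at the next pair; when the row changes, every candidate
        // group for `row` has been checked.
        let next_row = row_indices.get(idx + 1).copied().unwrap_or(usize::MAX);
        if row != next_row {
            if !current_row_matched {
                remaining.push(row);
            }
            current_row_matched = false;
        }
    }
}

fn main() {
    // Row 0 has two candidate groups (hash collision); only the second matches.
    // Row 1 has one candidate that does not match, so it becomes "remaining".
    let rows = [0, 0, 1];
    let candidates = [5, 9, 3];
    let results = [false, true, false];
    let mut groups = vec![usize::MAX; 2];
    let mut remaining = vec![];
    split_equal_to_results(&rows, &candidates, &results, &mut groups, &mut remaining);
    assert_eq!(groups, vec![9, usize::MAX]);
    assert_eq!(remaining, vec![1]);
}
```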
+ let mut current_row_equal_to_result = false; + for (idx, &row) in self + .vectorized_operation_buffers + .equal_to_row_indices + .iter() + .enumerate() + { + let equal_to_result = equal_to_results[idx]; + + // Equal to case, set the `group_indices` to `rows` in `groups` + if equal_to_result { + groups[row] = + self.vectorized_operation_buffers.equal_to_group_indices[idx]; + } + current_row_equal_to_result |= equal_to_result; + + // Look forward next one row to check if have checked all results + // of current row + let next_row = self + .vectorized_operation_buffers + .equal_to_row_indices + .get(idx + 1) + .unwrap_or(&usize::MAX); + + // Have checked all results of current row, check the total result + if row != *next_row { + // Not equal to case, add `row` to `scalarized_indices` + if !current_row_equal_to_result { + self.vectorized_operation_buffers + .remaining_row_indices + .push(row); + } + + // Init the total result for checking next row + current_row_equal_to_result = false; + } + } + + self.vectorized_operation_buffers.equal_to_results = equal_to_results; + } + + /// It is possible that some `input rows` have the same + /// hash values with the `exist rows`, but have the different + /// actual values the exists. + /// + /// We can found them in `vectorized_equal_to`, and put them + /// into `scalarized_indices`. And for these `input rows`, + /// we will perform the `scalarized_intern` similar as what in + /// [`GroupValuesColumn`]. + /// + /// This design can make the process simple and still efficient enough: + /// + /// # About making the process simple + /// + /// Some corner cases become really easy to solve, like following cases: + /// + /// ```text + /// input row1 (same hash value with exist rows, but value different) + /// input row1 + /// ... + /// input row1 + /// ``` + /// + /// After performing `vectorized_equal_to`, we will found multiple `input rows` + /// not equal to the `exist rows`. However such `input rows` are repeated, only + /// one new group should be create for them. + /// + /// If we don't fallback to `scalarized_intern`, it is really hard for us to + /// distinguish the such `repeated rows` in `input rows`. And if we just fallback, + /// it is really easy to solve, and the performance is at least not worse than origin. + /// + /// # About performance + /// + /// The hash collision may be not frequent, so the fallback will indeed hardly happen. + /// In most situations, `scalarized_indices` will found to be empty after finishing to + /// preform `vectorized_equal_to`. + /// + fn scalarized_intern_remaining( + &mut self, + cols: &[ArrayRef], + batch_hashes: &[u64], + groups: &mut [usize], + ) { + if self + .vectorized_operation_buffers + .remaining_row_indices + .is_empty() + { + return; + } + + let mut map = mem::take(&mut self.map); + + for &row in &self.vectorized_operation_buffers.remaining_row_indices { + let target_hash = batch_hashes[row]; + let entry = map.find_mut(target_hash, |(exist_hash, _)| { + // Somewhat surprisingly, this closure can be called even if the + // hash doesn't match, so check the hash first with an integer + // comparison first avoid the more expensive comparison with + // group value. https://github.com/apache/datafusion/pull/11718 + target_hash == *exist_hash + }); + + // Only `rows` having the same hash value with `exist rows` but different value + // will be process in `scalarized_intern`. + // So related `buckets` in `map` is ensured to be `Some`. 
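The fallback that follows then appends the new group and, when needed, promotes the bucket's view from inlined to non-inlined. A hedged standalone sketch of that promotion step, with plain `u64` views and illustrative constants standing in for the real `GroupIndexView` and `HashTable` types:

```rust
// Hedged sketch (not the DataFusion code) of promoting an inlined view to a
// non-inlined view when a second group with the same hash value appears.
const NON_INLINED_FLAG: u64 = 1u64 << 63;
const VALUE_MASK: u64 = !NON_INLINED_FLAG;

fn add_colliding_group(
    view: &mut u64,
    new_group_idx: u64,
    group_index_lists: &mut Vec<Vec<u64>>,
) {
    if *view & NON_INLINED_FLAG != 0 {
        // Already non-inlined: push the new index onto the existing list.
        group_index_lists[(*view & VALUE_MASK) as usize].push(new_group_idx);
    } else {
        // Inlined: build a list [old_index, new_index] and re-point the view
        // at that list, setting the non-inlined flag.
        let list_offset = group_index_lists.len() as u64;
        group_index_lists.push(vec![*view & VALUE_MASK, new_group_idx]);
        *view = list_offset | NON_INLINED_FLAG;
    }
}

fn main() {
    let mut lists: Vec<Vec<u64>> = vec![];
    let mut view = 5u64; // inlined view holding group index 5

    add_colliding_group(&mut view, 9, &mut lists);
    assert!(view & NON_INLINED_FLAG != 0);
    assert_eq!(lists, vec![vec![5, 9]]);

    add_colliding_group(&mut view, 12, &mut lists);
    assert_eq!(lists, vec![vec![5, 9, 12]]);
}
```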
+ let Some((_, group_index_view)) = entry else { + unreachable!() + }; + + // Perform scalarized equal to + if self.scalarized_equal_to_remaining(group_index_view, cols, row, groups) { + // Found the row actually exists in group values, + // don't need to create new group for it. + continue; + } + + // Insert the `row` to `group_values` before checking `next row` + let group_idx = self.group_values[0].len(); + let mut checklen = 0; + for (i, group_value) in self.group_values.iter_mut().enumerate() { + group_value.append_val(&cols[i], row); + let len = group_value.len(); + if i == 0 { + checklen = len; + } else { + debug_assert_eq!(checklen, len); + } + } + + // Check if the `view` is `inlined` or `non-inlined` + if group_index_view.is_non_inlined() { + // Non-inlined case, get `group_index_list` from `group_index_lists`, + // then add the new `group` with the same hash values into it. + let list_offset = group_index_view.value() as usize; + let group_index_list = &mut self.group_index_lists[list_offset]; + group_index_list.push(group_idx); + } else { + // Inlined case + let list_offset = self.group_index_lists.len(); + + // Create new `group_index_list` including + // `exist group index` + `new group index`. + // Add new `group_index_list` into ``group_index_lists`. + let exist_group_index = group_index_view.value() as usize; + let new_group_index_list = vec![exist_group_index, group_idx]; + self.group_index_lists.push(new_group_index_list); + + // Update the `group_index_view` to non-inlined + let new_group_index_view = + GroupIndexView::new_non_inlined(list_offset as u64); + *group_index_view = new_group_index_view; + } + + groups[row] = group_idx; + } + + self.map = map; + } + + fn scalarized_equal_to_remaining( + &self, + group_index_view: &GroupIndexView, + cols: &[ArrayRef], + row: usize, + groups: &mut [usize], + ) -> bool { + // Check if this row exists in `group_values` + fn check_row_equal( + array_row: &dyn GroupColumn, + lhs_row: usize, + array: &ArrayRef, + rhs_row: usize, + ) -> bool { + array_row.equal_to(lhs_row, array, rhs_row) + } + + if group_index_view.is_non_inlined() { + let list_offset = group_index_view.value() as usize; + let group_index_list = &self.group_index_lists[list_offset]; + + for &group_idx in group_index_list { + let mut check_result = true; + for (i, group_val) in self.group_values.iter().enumerate() { + if !check_row_equal(group_val.as_ref(), group_idx, &cols[i], row) { + check_result = false; + break; + } + } + + if check_result { + groups[row] = group_idx; + return true; + } + } + + // All groups unmatched, return false result + false + } else { + let group_idx = group_index_view.value() as usize; + for (i, group_val) in self.group_values.iter().enumerate() { + if !check_row_equal(group_val.as_ref(), group_idx, &cols[i], row) { + return false; + } + } + + groups[row] = group_idx; + true + } + } + + /// Return group indices of the hash, also if its `group_index_view` is non-inlined + #[cfg(test)] + fn get_indices_by_hash(&self, hash: u64) -> Option<(Vec, GroupIndexView)> { + let entry = self.map.find(hash, |(exist_hash, _)| hash == *exist_hash); + + match entry { + Some((_, group_index_view)) => { + if group_index_view.is_non_inlined() { + let list_offset = group_index_view.value() as usize; + Some(( + self.group_index_lists[list_offset].clone(), + *group_index_view, + )) + } else { + let group_index = group_index_view.value() as usize; + Some((vec![group_index], *group_index_view)) + } + } + None => None, + } + } +} + +/// instantiates a 
[`PrimitiveGroupValueBuilder`] and pushes it into $v +/// +/// Arguments: +/// `$v`: the vector to push the new builder into +/// `$nullable`: whether the input can contains nulls +/// `$t`: the primitive type of the builder +/// +macro_rules! instantiate_primitive { + ($v:expr, $nullable:expr, $t:ty, $data_type:ident) => { + if $nullable { + let b = PrimitiveGroupValueBuilder::<$t, true>::new($data_type.to_owned()); + $v.push(Box::new(b) as _) + } else { + let b = PrimitiveGroupValueBuilder::<$t, false>::new($data_type.to_owned()); + $v.push(Box::new(b) as _) + } + }; +} + +impl GroupValues for GroupValuesColumn { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + if self.group_values.is_empty() { + let mut v = Vec::with_capacity(cols.len()); + + for f in self.schema.fields().iter() { + let nullable = f.is_nullable(); + let data_type = f.data_type(); + match data_type { + &DataType::Int8 => { + instantiate_primitive!(v, nullable, Int8Type, data_type) + } + &DataType::Int16 => { + instantiate_primitive!(v, nullable, Int16Type, data_type) + } + &DataType::Int32 => { + instantiate_primitive!(v, nullable, Int32Type, data_type) + } + &DataType::Int64 => { + instantiate_primitive!(v, nullable, Int64Type, data_type) + } + &DataType::UInt8 => { + instantiate_primitive!(v, nullable, UInt8Type, data_type) + } + &DataType::UInt16 => { + instantiate_primitive!(v, nullable, UInt16Type, data_type) + } + &DataType::UInt32 => { + instantiate_primitive!(v, nullable, UInt32Type, data_type) + } + &DataType::UInt64 => { + instantiate_primitive!(v, nullable, UInt64Type, data_type) + } + &DataType::Float32 => { + instantiate_primitive!(v, nullable, Float32Type, data_type) + } + &DataType::Float64 => { + instantiate_primitive!(v, nullable, Float64Type, data_type) + } + &DataType::Date32 => { + instantiate_primitive!(v, nullable, Date32Type, data_type) + } + &DataType::Date64 => { + instantiate_primitive!(v, nullable, Date64Type, data_type) + } + &DataType::Time32(t) => match t { + TimeUnit::Second => { + instantiate_primitive!( + v, + nullable, + Time32SecondType, + data_type + ) + } + TimeUnit::Millisecond => { + instantiate_primitive!( + v, + nullable, + Time32MillisecondType, + data_type + ) + } + _ => {} + }, + &DataType::Time64(t) => match t { + TimeUnit::Microsecond => { + instantiate_primitive!( + v, + nullable, + Time64MicrosecondType, + data_type + ) + } + TimeUnit::Nanosecond => { + instantiate_primitive!( + v, + nullable, + Time64NanosecondType, + data_type + ) + } + _ => {} + }, + &DataType::Timestamp(t, _) => match t { + TimeUnit::Second => { + instantiate_primitive!( + v, + nullable, + TimestampSecondType, + data_type + ) + } + TimeUnit::Millisecond => { + instantiate_primitive!( + v, + nullable, + TimestampMillisecondType, + data_type + ) + } + TimeUnit::Microsecond => { + instantiate_primitive!( + v, + nullable, + TimestampMicrosecondType, + data_type + ) + } + TimeUnit::Nanosecond => { + instantiate_primitive!( + v, + nullable, + TimestampNanosecondType, + data_type + ) + } + }, + &DataType::Decimal128(_, _) => { + instantiate_primitive! 
{ + v, + nullable, + Decimal128Type, + data_type + } + } + &DataType::Utf8 => { + let b = ByteGroupValueBuilder::::new(OutputType::Utf8); + v.push(Box::new(b) as _) + } + &DataType::LargeUtf8 => { + let b = ByteGroupValueBuilder::::new(OutputType::Utf8); + v.push(Box::new(b) as _) + } + &DataType::Binary => { + let b = ByteGroupValueBuilder::::new(OutputType::Binary); + v.push(Box::new(b) as _) + } + &DataType::LargeBinary => { + let b = ByteGroupValueBuilder::::new(OutputType::Binary); + v.push(Box::new(b) as _) + } + &DataType::Utf8View => { + let b = ByteViewGroupValueBuilder::::new(); + v.push(Box::new(b) as _) + } + &DataType::BinaryView => { + let b = ByteViewGroupValueBuilder::::new(); + v.push(Box::new(b) as _) + } + dt => { + return not_impl_err!("{dt} not supported in GroupValuesColumn") + } + } + } + self.group_values = v; + } + + if !STREAMING { + self.vectorized_intern(cols, groups) + } else { + self.scalarized_intern(cols, groups) + } + } + + fn size(&self) -> usize { + let group_values_size: usize = self.group_values.iter().map(|v| v.size()).sum(); + group_values_size + self.map_size + self.hashes_buffer.allocated_size() + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn len(&self) -> usize { + if self.group_values.is_empty() { + return 0; + } + + self.group_values[0].len() + } + + fn emit(&mut self, emit_to: EmitTo) -> Result> { + let mut output = match emit_to { + EmitTo::All => { + let group_values = mem::take(&mut self.group_values); + debug_assert!(self.group_values.is_empty()); + + group_values + .into_iter() + .map(|v| v.build()) + .collect::>() + } + EmitTo::First(n) => { + let output = self + .group_values + .iter_mut() + .map(|v| v.take_n(n)) + .collect::>(); + let mut next_new_list_offset = 0; + + self.map.retain(|(_exist_hash, group_idx_view)| { + // In non-streaming case, we need to check if the `group index view` + // is `inlined` or `non-inlined` + if !STREAMING && group_idx_view.is_non_inlined() { + // Non-inlined case + // We take `group_index_list` from `old_group_index_lists` + + // list_offset is incrementally + self.emit_group_index_list_buffer.clear(); + let list_offset = group_idx_view.value() as usize; + for group_index in self.group_index_lists[list_offset].iter() { + if let Some(remaining) = group_index.checked_sub(n) { + self.emit_group_index_list_buffer.push(remaining); + } + } + + // The possible results: + // - `new_group_index_list` is empty, we should erase this bucket + // - only one value in `new_group_index_list`, switch the `view` to `inlined` + // - still multiple values in `new_group_index_list`, build and set the new `unlined view` + if self.emit_group_index_list_buffer.is_empty() { + false + } else if self.emit_group_index_list_buffer.len() == 1 { + let group_index = + self.emit_group_index_list_buffer.first().unwrap(); + *group_idx_view = + GroupIndexView::new_inlined(*group_index as u64); + true + } else { + let group_index_list = + &mut self.group_index_lists[next_new_list_offset]; + group_index_list.clear(); + group_index_list + .extend(self.emit_group_index_list_buffer.iter()); + *group_idx_view = GroupIndexView::new_non_inlined( + next_new_list_offset as u64, + ); + next_new_list_offset += 1; + true + } + } else { + // In `streaming case`, the `group index view` is ensured to be `inlined` + debug_assert!(!group_idx_view.is_non_inlined()); + + // Inlined case, we just decrement group index by n) + let group_index = group_idx_view.value() as usize; + match group_index.checked_sub(n) { + // Group index was >= n, shift 
value down + Some(sub) => { + *group_idx_view = GroupIndexView::new_inlined(sub as u64); + true + } + // Group index was < n, so remove from table + None => false, + } + } + }); + + if !STREAMING { + self.group_index_lists.truncate(next_new_list_offset); + } + + output + } + }; + + // TODO: Materialize dictionaries in group keys (#7647) + for (field, array) in self.schema.fields.iter().zip(&mut output) { + let expected = field.data_type(); + if let DataType::Dictionary(_, v) = expected { + let actual = array.data_type(); + if v.as_ref() != actual { + return Err(DataFusionError::Internal(format!( + "Converted group rows expected dictionary of {v} got {actual}" + ))); + } + *array = cast(array.as_ref(), expected)?; + } + } + + Ok(output) + } + + fn clear_shrink(&mut self, batch: &RecordBatch) { + let count = batch.num_rows(); + self.group_values.clear(); + self.map.clear(); + self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); + self.hashes_buffer.clear(); + self.hashes_buffer.shrink_to(count); + + // Such structures are only used in `non-streaming` case + if !STREAMING { + self.group_index_lists.clear(); + self.emit_group_index_list_buffer.clear(); + self.vectorized_operation_buffers.clear(); + } + } +} + +/// Returns true if [`GroupValuesColumn`] supported for the specified schema +pub fn supported_schema(schema: &Schema) -> bool { + schema + .fields() + .iter() + .map(|f| f.data_type()) + .all(supported_type) +} + +/// Returns true if the specified data type is supported by [`GroupValuesColumn`] +/// +/// In order to be supported, there must be a specialized implementation of +/// [`GroupColumn`] for the data type, instantiated in [`GroupValuesColumn::intern`] +fn supported_type(data_type: &DataType) -> bool { + matches!( + *data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Timestamp(_, _) + | DataType::Utf8View + | DataType::BinaryView + ) +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, sync::Arc}; + + use arrow::{compute::concat_batches, util::pretty::pretty_format_batches}; + use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StringViewArray}; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::utils::proxy::HashTableAllocExt; + use datafusion_expr::EmitTo; + + use crate::aggregates::group_values::{ + multi_group_by::GroupValuesColumn, GroupValues, + }; + + use super::GroupIndexView; + + #[test] + fn test_intern_for_vectorized_group_values() { + let data_set = VectorizedTestDataSet::new(); + let mut group_values = + GroupValuesColumn::::try_new(data_set.schema()).unwrap(); + + data_set.load_to_group_values(&mut group_values); + let actual_batch = group_values.emit(EmitTo::All).unwrap(); + let actual_batch = RecordBatch::try_new(data_set.schema(), actual_batch).unwrap(); + + check_result(&actual_batch, &data_set.expected_batch); + } + + #[test] + fn test_emit_first_n_for_vectorized_group_values() { + let data_set = VectorizedTestDataSet::new(); + let mut group_values = + GroupValuesColumn::::try_new(data_set.schema()).unwrap(); + + // 1~num_rows times to emit the 
groups + let num_rows = data_set.expected_batch.num_rows(); + let schema = data_set.schema(); + for times_to_take in 1..=num_rows { + // Write data after emitting + data_set.load_to_group_values(&mut group_values); + + // Emit `times_to_take` times, collect and concat the sub-results to total result, + // then check it + let suggest_num_emit = data_set.expected_batch.num_rows() / times_to_take; + let mut num_remaining_rows = num_rows; + let mut actual_sub_batches = Vec::new(); + + for nth_time in 0..times_to_take { + let num_emit = if nth_time == times_to_take - 1 { + num_remaining_rows + } else { + suggest_num_emit + }; + + let sub_batch = group_values.emit(EmitTo::First(num_emit)).unwrap(); + let sub_batch = + RecordBatch::try_new(Arc::clone(&schema), sub_batch).unwrap(); + actual_sub_batches.push(sub_batch); + + num_remaining_rows -= num_emit; + } + assert!(num_remaining_rows == 0); + + let actual_batch = concat_batches(&schema, &actual_sub_batches).unwrap(); + check_result(&actual_batch, &data_set.expected_batch); + } + } + + #[test] + fn test_hashtable_modifying_in_emit_first_n() { + // Situations should be covered: + // 1. Erase inlined group index view + // 2. Erase whole non-inlined group index view + // 3. Erase + decrease group indices in non-inlined group index view + // + view still non-inlined after decreasing + // 4. Erase + decrease group indices in non-inlined group index view + // + view switch to inlined after decreasing + // 5. Only decrease group index in inlined group index view + // 6. Only decrease group indices in non-inlined group index view + // 7. Erase all things + + let field = Field::new_list_field(DataType::Int32, true); + let schema = Arc::new(Schema::new_with_metadata(vec![field], HashMap::new())); + let mut group_values = GroupValuesColumn::::try_new(schema).unwrap(); + + // Insert group index views and check if success to insert + insert_inline_group_index_view(&mut group_values, 0, 0); + insert_non_inline_group_index_view(&mut group_values, 1, vec![1, 2]); + insert_non_inline_group_index_view(&mut group_values, 2, vec![3, 4, 5]); + insert_inline_group_index_view(&mut group_values, 3, 6); + insert_non_inline_group_index_view(&mut group_values, 4, vec![7, 8]); + insert_non_inline_group_index_view(&mut group_values, 5, vec![9, 10, 11]); + + assert_eq!( + group_values.get_indices_by_hash(0).unwrap(), + (vec![0], GroupIndexView::new_inlined(0)) + ); + assert_eq!( + group_values.get_indices_by_hash(1).unwrap(), + (vec![1, 2], GroupIndexView::new_non_inlined(0)) + ); + assert_eq!( + group_values.get_indices_by_hash(2).unwrap(), + (vec![3, 4, 5], GroupIndexView::new_non_inlined(1)) + ); + assert_eq!( + group_values.get_indices_by_hash(3).unwrap(), + (vec![6], GroupIndexView::new_inlined(6)) + ); + assert_eq!( + group_values.get_indices_by_hash(4).unwrap(), + (vec![7, 8], GroupIndexView::new_non_inlined(2)) + ); + assert_eq!( + group_values.get_indices_by_hash(5).unwrap(), + (vec![9, 10, 11], GroupIndexView::new_non_inlined(3)) + ); + assert_eq!(group_values.map.len(), 6); + + // Emit first 4 to test cases 1~3, 5~6 + let _ = group_values.emit(EmitTo::First(4)).unwrap(); + assert!(group_values.get_indices_by_hash(0).is_none()); + assert!(group_values.get_indices_by_hash(1).is_none()); + assert_eq!( + group_values.get_indices_by_hash(2).unwrap(), + (vec![0, 1], GroupIndexView::new_non_inlined(0)) + ); + assert_eq!( + group_values.get_indices_by_hash(3).unwrap(), + (vec![2], GroupIndexView::new_inlined(2)) + ); + assert_eq!( + 
group_values.get_indices_by_hash(4).unwrap(), + (vec![3, 4], GroupIndexView::new_non_inlined(1)) + ); + assert_eq!( + group_values.get_indices_by_hash(5).unwrap(), + (vec![5, 6, 7], GroupIndexView::new_non_inlined(2)) + ); + assert_eq!(group_values.map.len(), 4); + + // Emit first 1 to test case 4, and cases 5~6 again + let _ = group_values.emit(EmitTo::First(1)).unwrap(); + assert_eq!( + group_values.get_indices_by_hash(2).unwrap(), + (vec![0], GroupIndexView::new_inlined(0)) + ); + assert_eq!( + group_values.get_indices_by_hash(3).unwrap(), + (vec![1], GroupIndexView::new_inlined(1)) + ); + assert_eq!( + group_values.get_indices_by_hash(4).unwrap(), + (vec![2, 3], GroupIndexView::new_non_inlined(0)) + ); + assert_eq!( + group_values.get_indices_by_hash(5).unwrap(), + (vec![4, 5, 6], GroupIndexView::new_non_inlined(1)) + ); + assert_eq!(group_values.map.len(), 4); + + // Emit first 5 to test cases 1~3 again + let _ = group_values.emit(EmitTo::First(5)).unwrap(); + assert_eq!( + group_values.get_indices_by_hash(5).unwrap(), + (vec![0, 1], GroupIndexView::new_non_inlined(0)) + ); + assert_eq!(group_values.map.len(), 1); + + // Emit first 1 to test cases 4 again + let _ = group_values.emit(EmitTo::First(1)).unwrap(); + assert_eq!( + group_values.get_indices_by_hash(5).unwrap(), + (vec![0], GroupIndexView::new_inlined(0)) + ); + assert_eq!(group_values.map.len(), 1); + + // Emit first 1 to test cases 7 + let _ = group_values.emit(EmitTo::First(1)).unwrap(); + assert!(group_values.map.is_empty()); + } + + /// Test data set for [`GroupValuesColumn::vectorized_intern`] + /// + /// Define the test data and support loading them into test [`GroupValuesColumn::vectorized_intern`] + /// + /// The covering situations: + /// + /// Array type: + /// - Primitive array + /// - String(byte) array + /// - String view(byte view) array + /// + /// Repeation and nullability in single batch: + /// - All not null rows + /// - Mixed null + not null rows + /// - All null rows + /// - All not null rows(repeated) + /// - Null + not null rows(repeated) + /// - All not null rows(repeated) + /// + /// If group exists in `map`: + /// - Group exists in inlined group view + /// - Group exists in non-inlined group view + /// - Group not exist + bucket not found in `map` + /// - Group not exist + not equal to inlined group view(tested in hash collision) + /// - Group not exist + not equal to non-inlined group view(tested in hash collision) + /// + struct VectorizedTestDataSet { + test_batches: Vec>, + expected_batch: RecordBatch, + } + + impl VectorizedTestDataSet { + fn new() -> Self { + // Intern batch 1 + let col1 = Int64Array::from(vec![ + // Repeated rows in batch + Some(42), // all not nulls + repeated rows + exist in map case + None, // mixed + repeated rows + exist in map case + None, // mixed + repeated rows + not exist in map case + Some(1142), // mixed + repeated rows + not exist in map case + None, // all nulls + repeated rows + exist in map case + Some(42), + None, + None, + Some(1142), + None, + // Unique rows in batch + Some(4211), // all not nulls + unique rows + exist in map case + None, // mixed + unique rows + exist in map case + None, // mixed + unique rows + not exist in map case + Some(4212), // mixed + unique rows + not exist in map case + ]); + + let col2 = StringArray::from(vec![ + // Repeated rows in batch + Some("string1"), // all not nulls + repeated rows + exist in map case + None, // mixed + repeated rows + exist in map case + Some("string2"), // mixed + repeated rows + not exist in map case + 
None, // mixed + repeated rows + not exist in map case + None, // all nulls + repeated rows + exist in map case + Some("string1"), + None, + Some("string2"), + None, + None, + // Unique rows in batch + Some("string3"), // all not nulls + unique rows + exist in map case + None, // mixed + unique rows + exist in map case + Some("string4"), // mixed + unique rows + not exist in map case + None, // mixed + unique rows + not exist in map case + ]); + + let col3 = StringViewArray::from(vec![ + // Repeated rows in batch + Some("stringview1"), // all not nulls + repeated rows + exist in map case + Some("stringview2"), // mixed + repeated rows + exist in map case + None, // mixed + repeated rows + not exist in map case + None, // mixed + repeated rows + not exist in map case + None, // all nulls + repeated rows + exist in map case + Some("stringview1"), + Some("stringview2"), + None, + None, + None, + // Unique rows in batch + Some("stringview3"), // all not nulls + unique rows + exist in map case + Some("stringview4"), // mixed + unique rows + exist in map case + None, // mixed + unique rows + not exist in map case + None, // mixed + unique rows + not exist in map case + ]); + let batch1 = vec![ + Arc::new(col1) as _, + Arc::new(col2) as _, + Arc::new(col3) as _, + ]; + + // Intern batch 2 + let col1 = Int64Array::from(vec![ + // Repeated rows in batch + Some(42), // all not nulls + repeated rows + exist in map case + None, // mixed + repeated rows + exist in map case + None, // mixed + repeated rows + not exist in map case + Some(21142), // mixed + repeated rows + not exist in map case + None, // all nulls + repeated rows + exist in map case + Some(42), + None, + None, + Some(21142), + None, + // Unique rows in batch + Some(4211), // all not nulls + unique rows + exist in map case + None, // mixed + unique rows + exist in map case + None, // mixed + unique rows + not exist in map case + Some(24212), // mixed + unique rows + not exist in map case + ]); + + let col2 = StringArray::from(vec![ + // Repeated rows in batch + Some("string1"), // all not nulls + repeated rows + exist in map case + None, // mixed + repeated rows + exist in map case + Some("2string2"), // mixed + repeated rows + not exist in map case + None, // mixed + repeated rows + not exist in map case + None, // all nulls + repeated rows + exist in map case + Some("string1"), + None, + Some("2string2"), + None, + None, + // Unique rows in batch + Some("string3"), // all not nulls + unique rows + exist in map case + None, // mixed + unique rows + exist in map case + Some("2string4"), // mixed + unique rows + not exist in map case + None, // mixed + unique rows + not exist in map case + ]); + + let col3 = StringViewArray::from(vec![ + // Repeated rows in batch + Some("stringview1"), // all not nulls + repeated rows + exist in map case + Some("stringview2"), // mixed + repeated rows + exist in map case + None, // mixed + repeated rows + not exist in map case + None, // mixed + repeated rows + not exist in map case + None, // all nulls + repeated rows + exist in map case + Some("stringview1"), + Some("stringview2"), + None, + None, + None, + // Unique rows in batch + Some("stringview3"), // all not nulls + unique rows + exist in map case + Some("stringview4"), // mixed + unique rows + exist in map case + None, // mixed + unique rows + not exist in map case + None, // mixed + unique rows + not exist in map case + ]); + let batch2 = vec![ + Arc::new(col1) as _, + Arc::new(col2) as _, + Arc::new(col3) as _, + ]; + + // Intern batch 3 + let 
col1 = Int64Array::from(vec![ + // Repeated rows in batch + Some(42), // all not nulls + repeated rows + exist in map case + None, // mixed + repeated rows + exist in map case + None, // mixed + repeated rows + not exist in map case + Some(31142), // mixed + repeated rows + not exist in map case + None, // all nulls + repeated rows + exist in map case + Some(42), + None, + None, + Some(31142), + None, + // Unique rows in batch + Some(4211), // all not nulls + unique rows + exist in map case + None, // mixed + unique rows + exist in map case + None, // mixed + unique rows + not exist in map case + Some(34212), // mixed + unique rows + not exist in map case + ]); + + let col2 = StringArray::from(vec![ + // Repeated rows in batch + Some("string1"), // all not nulls + repeated rows + exist in map case + None, // mixed + repeated rows + exist in map case + Some("3string2"), // mixed + repeated rows + not exist in map case + None, // mixed + repeated rows + not exist in map case + None, // all nulls + repeated rows + exist in map case + Some("string1"), + None, + Some("3string2"), + None, + None, + // Unique rows in batch + Some("string3"), // all not nulls + unique rows + exist in map case + None, // mixed + unique rows + exist in map case + Some("3string4"), // mixed + unique rows + not exist in map case + None, // mixed + unique rows + not exist in map case + ]); + + let col3 = StringViewArray::from(vec![ + // Repeated rows in batch + Some("stringview1"), // all not nulls + repeated rows + exist in map case + Some("stringview2"), // mixed + repeated rows + exist in map case + None, // mixed + repeated rows + not exist in map case + None, // mixed + repeated rows + not exist in map case + None, // all nulls + repeated rows + exist in map case + Some("stringview1"), + Some("stringview2"), + None, + None, + None, + // Unique rows in batch + Some("stringview3"), // all not nulls + unique rows + exist in map case + Some("stringview4"), // mixed + unique rows + exist in map case + None, // mixed + unique rows + not exist in map case + None, // mixed + unique rows + not exist in map case + ]); + let batch3 = vec![ + Arc::new(col1) as _, + Arc::new(col2) as _, + Arc::new(col3) as _, + ]; + + // Expected batch + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Utf8View, true), + ])); + + let col1 = Int64Array::from(vec![ + // Repeated rows in batch + Some(42), + None, + None, + Some(1142), + None, + Some(21142), + None, + Some(31142), + None, + // Unique rows in batch + Some(4211), + None, + None, + Some(4212), + None, + Some(24212), + None, + Some(34212), + ]); + + let col2 = StringArray::from(vec![ + // Repeated rows in batch + Some("string1"), + None, + Some("string2"), + None, + Some("2string2"), + None, + Some("3string2"), + None, + None, + // Unique rows in batch + Some("string3"), + None, + Some("string4"), + None, + Some("2string4"), + None, + Some("3string4"), + None, + ]); + + let col3 = StringViewArray::from(vec![ + // Repeated rows in batch + Some("stringview1"), + Some("stringview2"), + None, + None, + None, + None, + None, + None, + None, + // Unique rows in batch + Some("stringview3"), + Some("stringview4"), + None, + None, + None, + None, + None, + None, + ]); + let expected_batch = vec![ + Arc::new(col1) as _, + Arc::new(col2) as _, + Arc::new(col3) as _, + ]; + let expected_batch = RecordBatch::try_new(schema, expected_batch).unwrap(); + + Self { + test_batches: vec![batch1, batch2, 
batch3], + expected_batch, + } + } + + fn load_to_group_values(&self, group_values: &mut impl GroupValues) { + for batch in self.test_batches.iter() { + group_values.intern(batch, &mut vec![]).unwrap(); + } + } + + fn schema(&self) -> SchemaRef { + self.expected_batch.schema() + } + } + + fn check_result(actual_batch: &RecordBatch, expected_batch: &RecordBatch) { + let formatted_actual_batch = pretty_format_batches(&[actual_batch.clone()]) + .unwrap() + .to_string(); + let mut formatted_actual_batch_sorted: Vec<&str> = + formatted_actual_batch.trim().lines().collect(); + formatted_actual_batch_sorted.sort_unstable(); + + let formatted_expected_batch = pretty_format_batches(&[expected_batch.clone()]) + .unwrap() + .to_string(); + let mut formatted_expected_batch_sorted: Vec<&str> = + formatted_expected_batch.trim().lines().collect(); + formatted_expected_batch_sorted.sort_unstable(); + + for (i, (actual_line, expected_line)) in formatted_actual_batch_sorted + .iter() + .zip(&formatted_expected_batch_sorted) + .enumerate() + { + assert_eq!( + (i, actual_line), + (i, expected_line), + "Inconsistent result\n\n\ + Actual batch:\n{}\n\ + Expected batch:\n{}\n\ + ", + formatted_actual_batch, + formatted_expected_batch, + ); + } + } + + fn insert_inline_group_index_view( + group_values: &mut GroupValuesColumn, + hash_key: u64, + group_index: u64, + ) { + let group_index_view = GroupIndexView::new_inlined(group_index); + group_values.map.insert_accounted( + (hash_key, group_index_view), + |(hash, _)| *hash, + &mut group_values.map_size, + ); + } + + fn insert_non_inline_group_index_view( + group_values: &mut GroupValuesColumn, + hash_key: u64, + group_indices: Vec, + ) { + let list_offset = group_values.group_index_lists.len(); + let group_index_view = GroupIndexView::new_non_inlined(list_offset as u64); + group_values.group_index_lists.push(group_indices); + group_values.map.insert_accounted( + (hash_key, group_index_view), + |(hash, _)| *hash, + &mut group_values.map_size, + ); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs new file mode 100644 index 0000000000000..4ceeb634bad2e --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs @@ -0,0 +1,479 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
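Before the per-type primitive builder introduced in the new file below, a short end-to-end illustration of driving the multi-column `GroupValuesColumn` defined above may help reviewers. This is a hedged sketch modeled on the unit tests (same crate-internal paths and `try_new`/`intern`/`emit` calls); the assertions reflect the first-appearance group ordering described in the comments, not a statement about the public API surface.

```rust
// Hedged, test-style sketch of the multi-column group-values path.
use std::sync::Arc;

use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::EmitTo;

use crate::aggregates::group_values::{multi_group_by::GroupValuesColumn, GroupValues};

fn example() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Utf8, true),
    ]));

    // `STREAMING = false` selects the vectorized intern path.
    let mut group_values = GroupValuesColumn::<false>::try_new(Arc::clone(&schema))?;

    let cols: Vec<ArrayRef> = vec![
        Arc::new(Int64Array::from(vec![Some(1), Some(1), None])),
        Arc::new(StringArray::from(vec![Some("x"), Some("x"), None])),
    ];

    // `groups[i]` is the group index assigned to input row `i`:
    // rows 0 and 1 are equal, row 2 starts a new (all-null) group.
    let mut groups = vec![];
    group_values.intern(&cols, &mut groups)?;
    assert_eq!(groups, vec![0, 0, 1]);

    // Emit the distinct group values, column by column.
    let emitted = group_values.emit(EmitTo::All)?;
    let batch = RecordBatch::try_new(schema, emitted)?;
    assert_eq!(batch.num_rows(), 2);
    Ok(())
}
```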
+ +use crate::aggregates::group_values::multi_group_by::{nulls_equal_to, GroupColumn}; +use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow::buffer::ScalarBuffer; +use arrow_array::cast::AsArray; +use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; +use arrow_schema::DataType; +use datafusion_execution::memory_pool::proxy::VecAllocExt; +use itertools::izip; +use std::iter; +use std::sync::Arc; + +/// An implementation of [`GroupColumn`] for primitive values +/// +/// Optimized to skip null buffer construction if the input is known to be non nullable +/// +/// # Template parameters +/// +/// `T`: the native Rust type that stores the data +/// `NULLABLE`: if the data can contain any nulls +#[derive(Debug)] +pub struct PrimitiveGroupValueBuilder { + data_type: DataType, + group_values: Vec, + nulls: MaybeNullBufferBuilder, +} + +impl PrimitiveGroupValueBuilder +where + T: ArrowPrimitiveType, +{ + /// Create a new `PrimitiveGroupValueBuilder` + pub fn new(data_type: DataType) -> Self { + Self { + data_type, + group_values: vec![], + nulls: MaybeNullBufferBuilder::new(), + } + } +} + +impl GroupColumn + for PrimitiveGroupValueBuilder +{ + fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { + // Perf: skip null check (by short circuit) if input is not nullable + if NULLABLE { + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + // Otherwise, we need to check their values + } + + self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) + } + + fn append_val(&mut self, array: &ArrayRef, row: usize) { + // Perf: skip null check if input can't have nulls + if NULLABLE { + if array.is_null(row) { + self.nulls.append(true); + self.group_values.push(T::default_value()); + } else { + self.nulls.append(false); + self.group_values.push(array.as_primitive::().value(row)); + } + } else { + self.group_values.push(array.as_primitive::().value(row)); + } + } + + fn vectorized_equal_to( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ) { + let array = array.as_primitive::(); + + let iter = izip!( + lhs_rows.iter(), + rhs_rows.iter(), + equal_to_results.iter_mut(), + ); + + for (&lhs_row, &rhs_row, equal_to_result) in iter { + // Has found not equal to in previous column, don't need to check + if !*equal_to_result { + continue; + } + + // Perf: skip null check (by short circuit) if input is not nullable + if NULLABLE { + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + *equal_to_result = result; + continue; + } + // Otherwise, we need to check their values + } + + *equal_to_result = self.group_values[lhs_row] == array.value(rhs_row); + } + } + + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) { + let arr = array.as_primitive::(); + + let null_count = array.null_count(); + let num_rows = array.len(); + let all_null_or_non_null = if null_count == 0 { + Some(true) + } else if null_count == num_rows { + Some(false) + } else { + None + }; + + match (NULLABLE, all_null_or_non_null) { + (true, None) => { + for &row in rows { + if array.is_null(row) { + self.nulls.append(true); + self.group_values.push(T::default_value()); + } else { + self.nulls.append(false); + self.group_values.push(arr.value(row)); + } + } + } + + (true, 
Some(true)) => { + self.nulls.append_n(rows.len(), false); + for &row in rows { + self.group_values.push(arr.value(row)); + } + } + + (true, Some(false)) => { + self.nulls.append_n(rows.len(), true); + self.group_values + .extend(iter::repeat(T::default_value()).take(rows.len())); + } + + (false, _) => { + for &row in rows { + self.group_values.push(arr.value(row)); + } + } + } + } + + fn len(&self) -> usize { + self.group_values.len() + } + + fn size(&self) -> usize { + self.group_values.allocated_size() + self.nulls.allocated_size() + } + + fn build(self: Box) -> ArrayRef { + let Self { + data_type, + group_values, + nulls, + } = *self; + + let nulls = nulls.build(); + if !NULLABLE { + assert!(nulls.is_none(), "unexpected nulls in non nullable input"); + } + + let arr = PrimitiveArray::::new(ScalarBuffer::from(group_values), nulls); + // Set timezone information for timestamp + Arc::new(arr.with_data_type(data_type)) + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + let first_n = self.group_values.drain(0..n).collect::>(); + + let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None }; + + Arc::new( + PrimitiveArray::::new(ScalarBuffer::from(first_n), first_n_nulls) + .with_data_type(self.data_type.clone()), + ) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::aggregates::group_values::multi_group_by::primitive::PrimitiveGroupValueBuilder; + use arrow::datatypes::Int64Type; + use arrow_array::{ArrayRef, Int64Array}; + use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; + use arrow_schema::DataType; + + use super::GroupColumn; + + #[test] + fn test_nullable_primitive_equal_to() { + let append = |builder: &mut PrimitiveGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + for &index in append_rows { + builder.append_val(builder_array, index); + } + }; + + let equal_to = |builder: &PrimitiveGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); + } + }; + + test_nullable_primitive_equal_to_internal(append, equal_to); + } + + #[test] + fn test_nullable_primitive_vectorized_equal_to() { + let append = |builder: &mut PrimitiveGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + builder.vectorized_append(builder_array, append_rows); + }; + + let equal_to = |builder: &PrimitiveGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; + + test_nullable_primitive_equal_to_internal(append, equal_to); + } + + fn test_nullable_primitive_equal_to_internal(mut append: A, mut equal_to: E) + where + A: FnMut(&mut PrimitiveGroupValueBuilder, &ArrayRef, &[usize]), + E: FnMut( + &PrimitiveGroupValueBuilder, + &[usize], + &ArrayRef, + &[usize], + &mut Vec, + ), + { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define PrimitiveGroupValueBuilder + let mut builder = + PrimitiveGroupValueBuilder::::new(DataType::Int64); + let builder_array = 
Arc::new(Int64Array::from(vec![ + None, + None, + None, + Some(1), + Some(2), + Some(3), + ])) as ArrayRef; + append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]); + + // Define input array + let (_nulls, values, _) = + Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some(2) to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = Arc::new(Int64Array::new(values, Some(nulls))) as ArrayRef; + + // Check + let mut equal_to_results = vec![true; builder.len()]; + equal_to( + &builder, + &[0, 1, 2, 3, 4, 5], + &input_array, + &[0, 1, 2, 3, 4, 5], + &mut equal_to_results, + ); + + assert!(!equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(!equal_to_results[3]); + assert!(!equal_to_results[4]); + assert!(equal_to_results[5]); + } + + #[test] + fn test_not_nullable_primitive_equal_to() { + let append = |builder: &mut PrimitiveGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + for &index in append_rows { + builder.append_val(builder_array, index); + } + }; + + let equal_to = |builder: &PrimitiveGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); + } + }; + + test_not_nullable_primitive_equal_to_internal(append, equal_to); + } + + #[test] + fn test_not_nullable_primitive_vectorized_equal_to() { + let append = |builder: &mut PrimitiveGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + builder.vectorized_append(builder_array, append_rows); + }; + + let equal_to = |builder: &PrimitiveGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; + + test_not_nullable_primitive_equal_to_internal(append, equal_to); + } + + fn test_not_nullable_primitive_equal_to_internal(mut append: A, mut equal_to: E) + where + A: FnMut(&mut PrimitiveGroupValueBuilder, &ArrayRef, &[usize]), + E: FnMut( + &PrimitiveGroupValueBuilder, + &[usize], + &ArrayRef, + &[usize], + &mut Vec, + ), + { + // Will cover such cases: + // - values equal + // - values not equal + + // Define PrimitiveGroupValueBuilder + let mut builder = + PrimitiveGroupValueBuilder::::new(DataType::Int64); + let builder_array = + Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef; + append(&mut builder, &builder_array, &[0, 1]); + + // Define input array + let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef; + + // Check + let mut equal_to_results = vec![true; builder.len()]; + equal_to( + &builder, + &[0, 1], + &input_array, + &[0, 1], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(!equal_to_results[1]); + } + + #[test] + fn test_nullable_primitive_vectorized_operation_special_case() { + // Test the special 
`all nulls` or `not nulls` input array case + // for vectorized append and equal to + + let mut builder = + PrimitiveGroupValueBuilder::::new(DataType::Int64); + + // All nulls input array + let all_nulls_input_array = Arc::new(Int64Array::from(vec![ + Option::::None, + None, + None, + None, + None, + ])) as _; + builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + + let mut equal_to_results = vec![true; all_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[0, 1, 2, 3, 4], + &all_nulls_input_array, + &[0, 1, 2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + + // All not nulls input array + let all_not_nulls_input_array = Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + ])) as _; + builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + + let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[5, 6, 7, 8, 9], + &all_not_nulls_input_array, + &[0, 1, 2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs b/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs index 0249390f38cdd..a584cf58e50a0 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs @@ -70,6 +70,24 @@ impl MaybeNullBufferBuilder { } } + pub fn append_n(&mut self, n: usize, is_null: bool) { + match self { + Self::NoNulls { row_count } if is_null => { + // have seen no nulls so far, this is the first null, + // need to create the nulls buffer for all currently valid values + // alloc 2x the need given we push a new but immediately + let mut nulls = BooleanBufferBuilder::new(*row_count * 2); + nulls.append_n(*row_count, true); + nulls.append_n(n, false); + *self = Self::Nulls(nulls); + } + Self::NoNulls { row_count } => { + *row_count += n; + } + Self::Nulls(builder) => builder.append_n(n, !is_null), + } + } + /// return the number of heap allocated bytes used by this structure to store boolean values pub fn allocated_size(&self) -> usize { match self { diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 8ca88257bf1a7..edc3f909bbd61 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -24,9 +24,11 @@ use arrow_array::{Array, ArrayRef, ListArray, StructArray}; use arrow_schema::{DataType, SchemaRef}; use datafusion_common::hash_utils::create_hashes; use datafusion_common::Result; -use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; -use hashbrown::raw::RawTable; +use hashbrown::hash_table::HashTable; +use log::debug; +use std::mem::size_of; use std::sync::Arc; /// A [`GroupValues`] making use of [`Rows`] @@ -52,7 +54,7 @@ pub struct GroupValuesRows { /// /// keys: u64 hashes of the GroupValue /// values: (hash, group_index) - map: RawTable<(u64, usize)>, + map: HashTable<(u64, usize)>, /// 
The size of `map` in bytes map_size: usize, @@ -79,6 +81,9 @@ pub struct GroupValuesRows { impl GroupValuesRows { pub fn try_new(schema: SchemaRef) -> Result { + // Print a debugging message, so it is clear when the (slower) fallback + // GroupValuesRows is used. + debug!("Creating GroupValuesRows for schema: {}", schema); let row_converter = RowConverter::new( schema .fields() @@ -87,7 +92,7 @@ impl GroupValuesRows { .collect(), )?; - let map = RawTable::with_capacity(0); + let map = HashTable::with_capacity(0); let starting_rows_capacity = 1000; @@ -130,7 +135,7 @@ impl GroupValues for GroupValuesRows { create_hashes(cols, &self.random_state, batch_hashes)?; for (row, &target_hash) in batch_hashes.iter().enumerate() { - let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| { + let entry = self.map.find_mut(target_hash, |(exist_hash, group_idx)| { // Somewhat surprisingly, this closure can be called even if the // hash doesn't match, so check the hash first with an integer // comparison first avoid the more expensive comparison with @@ -211,18 +216,18 @@ impl GroupValues for GroupValuesRows { } std::mem::swap(&mut new_group_values, &mut group_values); - // SAFETY: self.map outlives iterator and is not modified concurrently - unsafe { - for bucket in self.map.iter() { - // Decrement group index by n - match bucket.as_ref().1.checked_sub(n) { - // Group index was >= n, shift value down - Some(sub) => bucket.as_mut().1 = sub, - // Group index was < n, so remove from table - None => self.map.erase(bucket), + self.map.retain(|(_exists_hash, group_idx)| { + // Decrement group index by n + match group_idx.checked_sub(n) { + // Group index was >= n, shift value down + Some(sub) => { + *group_idx = sub; + true } + // Group index was < n, so remove from table + None => false, } - } + }); output } }; @@ -231,10 +236,8 @@ impl GroupValues for GroupValuesRows { // https://github.com/apache/datafusion/issues/7647 for (field, array) in self.schema.fields.iter().zip(&mut output) { let expected = field.data_type(); - *array = dictionary_encode_if_necessary( - Arc::::clone(array), - expected, - )?; + *array = + dictionary_encode_if_necessary(Arc::::clone(array), expected)?; } self.group_values = Some(group_values); @@ -249,7 +252,7 @@ impl GroupValues for GroupValuesRows { }); self.map.clear(); self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared - self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>(); + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); self.hashes_buffer.clear(); self.hashes_buffer.shrink_to(count); } @@ -267,7 +270,7 @@ fn dictionary_encode_if_necessary( .zip(struct_array.columns()) .map(|(expected_field, column)| { dictionary_encode_if_necessary( - Arc::::clone(column), + Arc::::clone(column), expected_field.data_type(), ) }) @@ -286,13 +289,13 @@ fn dictionary_encode_if_necessary( Arc::::clone(expected_field), list.offsets().clone(), dictionary_encode_if_necessary( - Arc::::clone(list.values()), + Arc::::clone(list.values()), expected_field.data_type(), )?, list.nulls().cloned(), )?)) } (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?), - (_, _) => Ok(Arc::::clone(&array)), + (_, _) => Ok(Arc::::clone(&array)), } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs similarity index 98% rename from datafusion/physical-plan/src/aggregates/group_values/bytes.rs rename to 
datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs index f789af8b8a024..013c027e7306c 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs @@ -19,6 +19,7 @@ use crate::aggregates::group_values::GroupValues; use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; use datafusion_expr::EmitTo; use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; +use std::mem::size_of; /// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values /// @@ -73,7 +74,7 @@ impl GroupValues for GroupValuesByes { } fn size(&self) -> usize { - self.map.size() + std::mem::size_of::() + self.map.size() + size_of::() } fn is_empty(&self) -> bool { diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs similarity index 98% rename from datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs rename to datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs index 1a0cb90a16d47..7379b7a538b49 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs @@ -20,6 +20,7 @@ use arrow_array::{Array, ArrayRef, RecordBatch}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap; +use std::mem::size_of; /// A [`GroupValues`] storing single column of Utf8View/BinaryView values /// @@ -74,7 +75,7 @@ impl GroupValues for GroupValuesBytesView { } fn size(&self) -> usize { - self.map.size() + std::mem::size_of::() + self.map.size() + size_of::() } fn is_empty(&self) -> bool { diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/mod.rs new file mode 100644 index 0000000000000..417618ba66af4 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
`GroupValues` implementations for single group by cases + +pub(crate) mod bytes; +pub(crate) mod bytes_view; +pub(crate) mod primitive; diff --git a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs similarity index 83% rename from datafusion/physical-plan/src/aggregates/group_values/primitive.rs rename to datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index d5b7f1b11ac55..6b69c00bca740 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -29,7 +29,8 @@ use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; use half::f16; -use hashbrown::raw::RawTable; +use hashbrown::hash_table::HashTable; +use std::mem::size_of; use std::sync::Arc; /// A trait to allow hashing of floating point numbers @@ -85,7 +86,7 @@ pub struct GroupValuesPrimitive { /// /// We don't store the hashes as hashing fixed width primitives /// is fast enough for this not to benefit performance - map: RawTable, + map: HashTable, /// The group index of the null value if any null_group: Option, /// The values for each group index @@ -99,7 +100,7 @@ impl GroupValuesPrimitive { assert!(PrimitiveArray::::is_compatible(&data_type)); Self { data_type, - map: RawTable::with_capacity(128), + map: HashTable::with_capacity(128), values: Vec::with_capacity(128), null_group: None, random_state: Default::default(), @@ -125,22 +126,19 @@ where Some(key) => { let state = &self.random_state; let hash = key.hash(state); - let insert = self.map.find_or_find_insert_slot( + let insert = self.map.entry( hash, |g| unsafe { self.values.get_unchecked(*g).is_eq(key) }, |g| unsafe { self.values.get_unchecked(*g).hash(state) }, ); - // SAFETY: No mutation occurred since find_or_find_insert_slot - unsafe { - match insert { - Ok(v) => *v.as_ref(), - Err(slot) => { - let g = self.values.len(); - self.map.insert_in_slot(hash, slot, g); - self.values.push(key); - g - } + match insert { + hashbrown::hash_table::Entry::Occupied(o) => *o.get(), + hashbrown::hash_table::Entry::Vacant(v) => { + let g = self.values.len(); + v.insert(g); + self.values.push(key); + g } } } @@ -151,7 +149,7 @@ where } fn size(&self) -> usize { - self.map.capacity() * std::mem::size_of::() + self.values.allocated_size() + self.map.capacity() * size_of::() + self.values.allocated_size() } fn is_empty(&self) -> bool { @@ -182,18 +180,18 @@ where build_primitive(std::mem::take(&mut self.values), self.null_group.take()) } EmitTo::First(n) => { - // SAFETY: self.map outlives iterator and is not modified concurrently - unsafe { - for bucket in self.map.iter() { - // Decrement group index by n - match bucket.as_ref().checked_sub(n) { - // Group index was >= n, shift value down - Some(sub) => *bucket.as_mut() = sub, - // Group index was < n, so remove from table - None => self.map.erase(bucket), + self.map.retain(|group_idx| { + // Decrement group index by n + match group_idx.checked_sub(n) { + // Group index was >= n, shift value down + Some(sub) => { + *group_idx = sub; + true } + // Group index was < n, so remove from table + None => false, } - } + }); let null_group = match &mut self.null_group { Some(v) if *v >= n => { *v -= n; @@ -207,6 +205,7 @@ where build_primitive(split, null_group) } }; + Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))]) } diff --git 
a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index daa5d7b81c587..360f4f88cad29 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -20,11 +20,12 @@ use std::any::Any; use std::sync::Arc; -use super::{DisplayAs, ExecutionMode, ExecutionPlanProperties, PlanProperties}; +use super::{DisplayAs, ExecutionPlanProperties, PlanProperties}; use crate::aggregates::{ no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream, topk_stream::GroupedTopKAggregateStream, }; +use crate::execution_plan::{CardinalityEffect, EmissionType}; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::projection::get_field_metadata; use crate::windows::get_ordered_partition_by_indices; @@ -38,20 +39,19 @@ use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_array::{UInt16Array, UInt32Array, UInt64Array, UInt8Array}; use datafusion_common::stats::Precision; -use datafusion_common::{internal_err, not_impl_err, Result}; +use datafusion_common::{internal_err, not_impl_err, Constraint, Constraints, Result}; use datafusion_execution::TaskContext; use datafusion_expr::{Accumulator, Aggregate}; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::{ - equivalence::{collapse_lex_req, ProjectionMapping}, - expressions::Column, - physical_exprs_contains, EquivalenceProperties, LexOrdering, LexRequirement, - PhysicalExpr, PhysicalSortRequirement, + equivalence::ProjectionMapping, expressions::Column, physical_exprs_contains, + EquivalenceProperties, LexOrdering, LexRequirement, PhysicalExpr, + PhysicalSortRequirement, }; -use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use itertools::Itertools; -pub mod group_values; +pub(crate) mod group_values; mod no_grouping; pub mod order; mod row_hash; @@ -250,6 +250,10 @@ impl PhysicalGroupBy { } } + pub fn group_schema(&self, schema: &Schema) -> Result { + Ok(Arc::new(Schema::new(self.group_fields(schema)?))) + } + /// Returns the fields that are used as the grouping keys. fn group_fields(&self, input_schema: &Schema) -> Result> { let mut fields = Vec::with_capacity(self.num_group_exprs()); @@ -343,14 +347,14 @@ impl From for SendableRecordBatchStream { } /// Hash aggregate execution plan -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct AggregateExec { /// Aggregation mode (full, partial) mode: AggregateMode, /// Group by expressions group_by: PhysicalGroupBy, /// Aggregate expressions - aggr_expr: Vec, + aggr_expr: Vec>, /// FILTER (WHERE clause) expression for each aggregate expression filter_expr: Vec>>, /// Set if the output of this aggregation is truncated by a upstream sort/limit clause @@ -377,7 +381,10 @@ impl AggregateExec { /// Function used in `OptimizeAggregateOrder` optimizer rule, /// where we need parts of the new value, others cloned from the old one /// Rewrites aggregate exec with new aggregate expressions. 
- pub fn with_new_aggr_exprs(&self, aggr_expr: Vec) -> Self { + pub fn with_new_aggr_exprs( + &self, + aggr_expr: Vec>, + ) -> Self { Self { aggr_expr, // clone the rest of the fields @@ -403,7 +410,7 @@ impl AggregateExec { pub fn try_new( mode: AggregateMode, group_by: PhysicalGroupBy, - aggr_expr: Vec, + aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, @@ -434,7 +441,7 @@ impl AggregateExec { fn try_new_with_schema( mode: AggregateMode, group_by: PhysicalGroupBy, - mut aggr_expr: Vec, + mut aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, @@ -469,7 +476,7 @@ impl AggregateExec { &mode, )?; new_requirement.inner.extend(req); - new_requirement = collapse_lex_req(new_requirement); + new_requirement = new_requirement.collapse(); // If our aggregation has grouping sets then our base grouping exprs will // be expanded based on the flags in `group_by.groups` where for each @@ -493,7 +500,7 @@ impl AggregateExec { }; // construct a map from the input expression to the output expression of the Aggregation group by - let projection_mapping = + let group_expr_mapping = ProjectionMapping::try_new(&group_by.expr, &input.schema())?; let required_input_ordering = @@ -502,7 +509,7 @@ impl AggregateExec { let cache = Self::compute_properties( &input, Arc::clone(&schema), - &projection_mapping, + &group_expr_mapping, &mode, &input_order_mode, ); @@ -544,7 +551,7 @@ impl AggregateExec { } /// Aggregate expressions - pub fn aggr_expr(&self) -> &[AggregateFunctionExpr] { + pub fn aggr_expr(&self) -> &[Arc] { &self.aggr_expr } @@ -638,14 +645,33 @@ impl AggregateExec { pub fn compute_properties( input: &Arc, schema: SchemaRef, - projection_mapping: &ProjectionMapping, + group_expr_mapping: &ProjectionMapping, mode: &AggregateMode, input_order_mode: &InputOrderMode, ) -> PlanProperties { // Construct equivalence properties: - let eq_properties = input + let mut eq_properties = input .equivalence_properties() - .project(projection_mapping, schema); + .project(group_expr_mapping, schema); + + // Group by expression will be a distinct value after the aggregation. + // Add it into the constraint set. + let mut constraints = eq_properties.constraints().to_vec(); + let new_constraint = Constraint::Unique( + group_expr_mapping + .map + .iter() + .filter_map(|(_, target_col)| { + target_col + .as_any() + .downcast_ref::() + .map(|c| c.index()) + }) + .collect(), + ); + constraints.push(new_constraint); + eq_properties = + eq_properties.with_constraints(Constraints::new_unverified(constraints)); // Get output partitioning: let input_partitioning = input.output_partitioning().clone(); @@ -654,21 +680,24 @@ impl AggregateExec { // but needs to respect aliases (e.g. mapping in the GROUP BY // expression). 
let input_eq_properties = input.equivalence_properties(); - input_partitioning.project(projection_mapping, input_eq_properties) + input_partitioning.project(group_expr_mapping, input_eq_properties) } else { input_partitioning.clone() }; - // Determine execution mode: - let mut exec_mode = input.execution_mode(); - if exec_mode == ExecutionMode::Unbounded - && *input_order_mode == InputOrderMode::Linear - { - // Cannot run without breaking the pipeline - exec_mode = ExecutionMode::PipelineBreaking; - } + // TODO: Emission type and boundedness information can be enhanced here + let emission_type = if *input_order_mode == InputOrderMode::Linear { + EmissionType::Final + } else { + input.pipeline_behavior() + }; - PlanProperties::new(eq_properties, output_partitioning, exec_mode) + PlanProperties::new( + eq_properties, + output_partitioning, + emission_type, + input.boundedness(), + ) } pub fn input_order_mode(&self) -> &InputOrderMode { @@ -785,6 +814,19 @@ impl ExecutionPlan for AggregateExec { vec![self.required_input_ordering.clone()] } + /// The output ordering of [`AggregateExec`] is determined by its `group_by` + /// columns. Although this method is not explicitly used by any optimizer + /// rules yet, overriding the default implementation ensures that it + /// accurately reflects the actual behavior. + /// + /// If the [`InputOrderMode`] is `Linear`, the `group_by` columns don't have + /// an ordering, which means the results do not either. However, in the + /// `Ordered` and `PartiallyOrdered` cases, the `group_by` columns do have + /// an ordering, which is preserved in the output. + fn maintains_input_order(&self) -> Vec { + vec![self.input_order_mode != InputOrderMode::Linear] + } + fn children(&self) -> Vec<&Arc> { vec![&self.input] } @@ -866,12 +908,16 @@ impl ExecutionPlan for AggregateExec { } } } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual + } } fn create_schema( input_schema: &Schema, group_by: &PhysicalGroupBy, - aggr_expr: &[AggregateFunctionExpr], + aggr_expr: &[Arc], mode: AggregateMode, ) -> Result { let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len()); @@ -901,10 +947,6 @@ fn create_schema( )) } -fn group_schema(input_schema: &Schema, group_by: &PhysicalGroupBy) -> Result { - Ok(Arc::new(Schema::new(group_by.group_fields(input_schema)?))) -} - /// Determines the lexical ordering requirement for an aggregate expression. /// /// # Parameters @@ -929,10 +971,10 @@ fn get_aggregate_expr_req( // necessary, or the aggregation is performing a "second stage" calculation, // then ignore the ordering requirement. if !aggr_expr.order_sensitivity().hard_requires() || !agg_mode.is_first_stage() { - return vec![]; + return LexOrdering::default(); } - let mut req = aggr_expr.order_bys().unwrap_or_default().to_vec(); + let mut req = aggr_expr.order_bys().cloned().unwrap_or_default(); // In non-first stage modes, we accumulate data (using `merge_batch`) from // different partitions (i.e. merge partial results). During this merge, we @@ -975,7 +1017,7 @@ fn finer_ordering( agg_mode: &AggregateMode, ) -> Option { let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode); - eq_properties.get_finer_ordering(existing_req, &aggr_req) + eq_properties.get_finer_ordering(existing_req, aggr_req.as_ref()) } /// Concatenates the given slices. 
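Aside on the API shape used in the ordering helpers above: requirements now flow as `LexOrdering` rather than a bare `Vec<PhysicalSortExpr>`. Below is a minimal sketch (not part of this patch, with import paths assumed and possibly differing between DataFusion versions) of building the single-key `b ASC NULLS LAST` ordering that the FIRST_VALUE/LAST_VALUE test helpers in this diff wrap via `LexOrdering::new(ordering_req.to_vec())`:

```rust
// Sketch only: assemble a one-key LexOrdering for `b ASC NULLS LAST`.
// Import paths are assumptions and may vary by version.
use std::sync::Arc;
use arrow::compute::SortOptions;
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr};

fn example_ordering() -> Result<LexOrdering> {
    let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, false)]));
    let sort_key = PhysicalSortExpr {
        expr: col("b", &schema)?,
        options: SortOptions { descending: false, nulls_first: false },
    };
    // LexOrdering is a lexicographic list of sort keys; one entry here.
    Ok(LexOrdering::new(vec![sort_key]))
}
```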
@@ -1001,17 +1043,17 @@ pub fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { /// A `LexRequirement` instance, which is the requirement that satisfies all the /// aggregate requirements. Returns an error in case of conflicting requirements. pub fn get_finer_aggregate_exprs_requirement( - aggr_exprs: &mut [AggregateFunctionExpr], + aggr_exprs: &mut [Arc], group_by: &PhysicalGroupBy, eq_properties: &EquivalenceProperties, agg_mode: &AggregateMode, ) -> Result { - let mut requirement = vec![]; + let mut requirement = LexOrdering::default(); for aggr_expr in aggr_exprs.iter_mut() { if let Some(finer_ordering) = finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) { - if eq_properties.ordering_satisfy(&finer_ordering) { + if eq_properties.ordering_satisfy(finer_ordering.as_ref()) { // Requirement is satisfied by existing ordering requirement = finer_ordering; continue; @@ -1025,11 +1067,11 @@ pub fn get_finer_aggregate_exprs_requirement( eq_properties, agg_mode, ) { - if eq_properties.ordering_satisfy(&finer_ordering) { + if eq_properties.ordering_satisfy(finer_ordering.as_ref()) { // Reverse requirement is satisfied by exiting ordering. // Hence reverse the aggregator requirement = finer_ordering; - *aggr_expr = reverse_aggr_expr; + *aggr_expr = Arc::new(reverse_aggr_expr); continue; } } @@ -1053,7 +1095,7 @@ pub fn get_finer_aggregate_exprs_requirement( // There is a requirement that both satisfies existing requirement and reverse // aggregate requirement. Use updated requirement requirement = finer_ordering; - *aggr_expr = reverse_aggr_expr; + *aggr_expr = Arc::new(reverse_aggr_expr); continue; } } @@ -1066,7 +1108,7 @@ pub fn get_finer_aggregate_exprs_requirement( ); } - Ok(PhysicalSortRequirement::from_sort_exprs(&requirement)) + Ok(LexRequirement::from(requirement)) } /// Returns physical expressions for arguments to evaluate against a batch. 
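For readers following the signature changes above: aggregate expressions are now passed around as `Arc<AggregateFunctionExpr>` instead of owned values. A rough sketch of constructing such a list with `AggregateExprBuilder`, mirroring the `.build().map(Arc::new)?` pattern used throughout the test updates in this diff (import paths are assumptions):

```rust
// Sketch only: build a COUNT(1) aggregate in the new Arc-wrapped form.
// Import paths are assumptions and may vary by version.
use std::sync::Arc;
use arrow_schema::Schema;
use datafusion_common::Result;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion_physical_expr::expressions::lit;

fn count_one_expr(schema: &Arc<Schema>) -> Result<Vec<Arc<AggregateFunctionExpr>>> {
    // `build()` yields an owned AggregateFunctionExpr; `.map(Arc::new)` adapts
    // it to the Arc-wrapped shape now expected by AggregateExec::try_new.
    let count = AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
        .schema(Arc::clone(schema))
        .alias("COUNT(1)")
        .build()
        .map(Arc::new)?;
    Ok(vec![count])
}
```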
@@ -1075,7 +1117,7 @@ pub fn get_finer_aggregate_exprs_requirement( /// * Partial: AggregateFunctionExpr::expressions /// * Final: columns of `AggregateFunctionExpr::state_fields()` pub fn aggregate_expressions( - aggr_expr: &[AggregateFunctionExpr], + aggr_expr: &[Arc], mode: &AggregateMode, col_idx_base: usize, ) -> Result>>> { @@ -1130,7 +1172,7 @@ fn merge_expressions( pub type AccumulatorItem = Box; pub fn create_accumulators( - aggr_expr: &[AggregateFunctionExpr], + aggr_expr: &[Arc], ) -> Result> { aggr_expr .iter() @@ -1290,8 +1332,10 @@ mod tests { use crate::coalesce_batches::CoalesceBatchesExec; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common; + use crate::execution_plan::Boundedness; use crate::expressions::col; use crate::memory::MemoryExec; + use crate::metrics::MetricValue; use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::RecordBatchStream; @@ -1421,7 +1465,7 @@ mod tests { fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc { let session_config = SessionConfig::new().with_batch_size(batch_size); - let runtime = RuntimeEnvBuilder::default() + let runtime = RuntimeEnvBuilder::new() .with_memory_pool(Arc::new(FairSpillPool::new(max_memory))) .build_arc() .unwrap(); @@ -1453,10 +1497,12 @@ mod tests { ], ); - let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) - .schema(Arc::clone(&input_schema)) - .alias("COUNT(1)") - .build()?]; + let aggregates = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) + .schema(Arc::clone(&input_schema)) + .alias("COUNT(1)") + .build()?, + )]; let task_ctx = if spill { // adjust the max memory size to have the partial aggregate result for spill mode. 
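The spill-related tests above depend on a task context whose memory pool is deliberately small, so partial aggregate state overflows the pool and is spilled to disk. A hedged sketch of that wiring, modeled on `new_spill_ctx` in this diff (import paths assumed):

```rust
// Sketch only: a memory-limited TaskContext that forces hash aggregation to spill.
// Import paths are assumptions and may vary by version.
use std::sync::Arc;
use datafusion_common::Result;
use datafusion_execution::config::SessionConfig;
use datafusion_execution::memory_pool::FairSpillPool;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_execution::TaskContext;

fn spill_ctx(batch_size: usize, max_memory: usize) -> Result<Arc<TaskContext>> {
    let runtime = RuntimeEnvBuilder::new()
        // Cap the pool so buffered group state exceeds it and triggers a spill.
        .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
        .build_arc()?;
    let task_ctx = TaskContext::default()
        .with_session_config(SessionConfig::new().with_batch_size(batch_size))
        .with_runtime(runtime);
    Ok(Arc::new(task_ctx))
}
```

Executing the aggregate with such a context is what makes the new `spill_count` / `spilled_bytes` / `spilled_rows` metrics asserted later in this file become non-zero.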
@@ -1475,7 +1521,7 @@ mod tests { )?); let result = - common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; + collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let expected = if spill { // In spill mode, we test with the limited memory, if the mem usage exceeds, @@ -1547,8 +1593,7 @@ mod tests { input_schema, )?); - let result = - common::collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; + let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let batch = concat_batches(&result[0].schema(), &result)?; assert_eq!(batch.num_columns(), 4); assert_eq!(batch.num_rows(), 12); @@ -1591,13 +1636,12 @@ mod tests { vec![vec![false]], ); - let aggregates: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .schema(Arc::clone(&input_schema)) - .alias("AVG(b)") - .build()?, - ]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("AVG(b)") + .build()?, + )]; let task_ctx = if spill { // set to an appropriate value to trigger spill @@ -1616,7 +1660,7 @@ mod tests { )?); let result = - common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; + collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let expected = if spill { vec![ @@ -1662,7 +1706,7 @@ mod tests { } else { Arc::clone(&task_ctx) }; - let result = common::collect(merged_aggregate.execute(0, task_ctx)?).await?; + let result = collect(merged_aggregate.execute(0, task_ctx)?).await?; let batch = concat_batches(&result[0].schema(), &result)?; assert_eq!(batch.num_columns(), 2); assert_eq!(batch.num_rows(), 3); @@ -1681,12 +1725,24 @@ mod tests { let metrics = merged_aggregate.metrics().unwrap(); let output_rows = metrics.output_rows().unwrap(); + let spill_count = metrics.spill_count().unwrap(); + let spilled_bytes = metrics.spilled_bytes().unwrap(); + let spilled_rows = metrics.spilled_rows().unwrap(); + if spill { // When spilling, the output rows metrics become partial output size + final output size // This is because final aggregation starts while partial aggregation is still emitting assert_eq!(8, output_rows); + + assert!(spill_count > 0); + assert!(spilled_bytes > 0); + assert!(spilled_rows > 0); } else { assert_eq!(3, output_rows); + + assert_eq!(0, spill_count); + assert_eq!(0, spilled_bytes); + assert_eq!(0, spilled_rows); } Ok(()) @@ -1710,13 +1766,11 @@ mod tests { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
fn compute_properties(schema: SchemaRef) -> PlanProperties { - let eq_properties = EquivalenceProperties::new(schema); PlanProperties::new( - eq_properties, - // Output Partitioning + EquivalenceProperties::new(schema), Partitioning::UnknownPartitioning(1), - // Execution Mode - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ) } } @@ -1894,7 +1948,7 @@ mod tests { let input: Arc = Arc::new(TestYieldingExec::new(true)); let input_schema = input.schema(); - let runtime = RuntimeEnvBuilder::default() + let runtime = RuntimeEnvBuilder::new() .with_memory_limit(1, 1.0) .build_arc()?; let task_ctx = TaskContext::default().with_runtime(runtime); @@ -1908,17 +1962,16 @@ mod tests { ); // something that allocates within the aggregator - let aggregates_v0: Vec = - vec![test_median_agg_expr(Arc::clone(&input_schema))?]; + let aggregates_v0: Vec> = + vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)]; // use fast-path in `row_hash.rs`. - let aggregates_v2: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .schema(Arc::clone(&input_schema)) - .alias("AVG(b)") - .build()?, - ]; + let aggregates_v2: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("AVG(b)") + .build()?, + )]; for (version, groups, aggregates) in [ (0, groups_none, aggregates_v0), @@ -1951,7 +2004,7 @@ mod tests { } let stream: SendableRecordBatchStream = stream.into(); - let err = common::collect(stream).await.unwrap_err(); + let err = collect(stream).await.unwrap_err(); // error root cause traversal is a bit complicated, see #4172. let err = err.find_root(); @@ -1972,13 +2025,12 @@ mod tests { let groups = PhysicalGroupBy::default(); - let aggregates: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) - .schema(Arc::clone(&schema)) - .alias("AVG(a)") - .build()?, - ]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(a)") + .build()?, + )]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -2012,13 +2064,12 @@ mod tests { let groups = PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); - let aggregates: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) - .alias("AVG(b)") - .build()?, - ]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build()?, + )]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -2063,7 +2114,7 @@ mod tests { fn test_first_value_agg_expr( schema: &Schema, sort_options: SortOptions, - ) -> Result { + ) -> Result> { let ordering_req = [PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, @@ -2071,27 +2122,29 @@ mod tests { let args = [col("b", schema)?]; AggregateExprBuilder::new(first_value_udaf(), args.to_vec()) - .order_by(ordering_req.to_vec()) + .order_by(LexOrdering::new(ordering_req.to_vec())) .schema(Arc::new(schema.clone())) .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]")) .build() + .map(Arc::new) } // LAST_VALUE(b ORDER BY b ) fn test_last_value_agg_expr( schema: &Schema, sort_options: SortOptions, - ) -> Result { + ) -> Result> { let ordering_req = [PhysicalSortExpr { 
expr: col("b", schema)?, options: sort_options, }]; let args = [col("b", schema)?]; AggregateExprBuilder::new(last_value_udaf(), args.to_vec()) - .order_by(ordering_req.to_vec()) + .order_by(LexOrdering::new(ordering_req.to_vec())) .schema(Arc::new(schema.clone())) .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]")) .build() + .map(Arc::new) } // This function either constructs the physical plan below, @@ -2136,7 +2189,7 @@ mod tests { descending: false, nulls_first: false, }; - let aggregates: Vec = if is_first_acc { + let aggregates: Vec> = if is_first_acc { vec![test_first_value_agg_expr(&schema, sort_options)?] } else { vec![test_last_value_agg_expr(&schema, sort_options)?] @@ -2253,7 +2306,7 @@ mod tests { ]), ]; - let common_requirement = vec![ + let common_requirement = LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::clone(col_a), options: options1, @@ -2262,16 +2315,17 @@ mod tests { expr: Arc::clone(col_c), options: options1, }, - ]; + ]); let mut aggr_exprs = order_by_exprs .into_iter() .map(|order_by_expr| { let ordering_req = order_by_expr.unwrap_or_default(); AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)]) .alias("a") - .order_by(ordering_req.to_vec()) + .order_by(LexOrdering::new(ordering_req.to_vec())) .schema(Arc::clone(&test_schema)) .build() + .map(Arc::new) .unwrap() }) .collect::>(); @@ -2282,7 +2336,7 @@ mod tests { &eq_properties, &AggregateMode::Partial, )?; - let res = PhysicalSortRequirement::to_sort_exprs(res); + let res = LexOrdering::from(res); assert_eq!(res, common_requirement); Ok(()) } @@ -2301,7 +2355,7 @@ mod tests { }; let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); - let aggregates: Vec = vec![ + let aggregates: Vec> = vec![ test_first_value_agg_expr(&schema, option_desc)?, test_last_value_agg_expr(&schema, option_desc)?, ]; @@ -2359,11 +2413,12 @@ mod tests { ], ); - let aggregates: Vec = + let aggregates: Vec> = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) .schema(Arc::clone(&schema)) .alias("1") - .build()?]; + .build() + .map(Arc::new)?]; let input_batches = (0..4) .map(|_| { @@ -2415,25 +2470,21 @@ mod tests { "labels".to_string(), DataType::Struct( vec![ - Field::new_dict( + Field::new( "a".to_string(), DataType::Dictionary( Box::new(DataType::Int32), Box::new(DataType::Utf8), ), true, - 0, - false, ), - Field::new_dict( + Field::new( "b".to_string(), DataType::Dictionary( Box::new(DataType::Int32), Box::new(DataType::Utf8), ), true, - 0, - false, ), ] .into(), @@ -2445,15 +2496,13 @@ mod tests { vec![ Arc::new(StructArray::from(vec![ ( - Arc::new(Field::new_dict( + Arc::new(Field::new( "a".to_string(), DataType::Dictionary( Box::new(DataType::Int32), Box::new(DataType::Utf8), ), true, - 0, - false, )), Arc::new( vec![Some("a"), None, Some("a")] @@ -2462,15 +2511,13 @@ mod tests { ) as ArrayRef, ), ( - Arc::new(Field::new_dict( + Arc::new(Field::new( "b".to_string(), DataType::Dictionary( Box::new(DataType::Int32), Box::new(DataType::Utf8), ), true, - 0, - false, )), Arc::new( vec![Some("b"), Some("c"), Some("b")] @@ -2495,11 +2542,12 @@ mod tests { ) .schema(Arc::clone(&batch.schema())) .alias(String::from("SUM(value)")) - .build()?]; + .build() + .map(Arc::new)?]; let input = Arc::new(MemoryExec::try_new( &[vec![batch.clone()]], - Arc::::clone(&batch.schema()), + Arc::::clone(&batch.schema()), None, )?); let aggregate_exec = Arc::new(AggregateExec::try_new( @@ -2543,7 +2591,8 @@ mod tests { AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) 
.schema(Arc::clone(&schema)) .alias(String::from("COUNT(val)")) - .build()?, + .build() + .map(Arc::new)?, ]; let input_data = vec![ @@ -2624,7 +2673,8 @@ mod tests { AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) .schema(Arc::clone(&schema)) .alias(String::from("COUNT(val)")) - .build()?, + .build() + .map(Arc::new)?, ]; let input_data = vec![ @@ -2711,7 +2761,8 @@ mod tests { AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?]) .schema(Arc::clone(&input_schema)) .alias("COUNT(a)") - .build()?, + .build() + .map(Arc::new)?, ]; let grouping_set = PhysicalGroupBy::new( @@ -2743,4 +2794,137 @@ mod tests { assert_eq!(aggr_schema, expected_schema); Ok(()) } + + // test for https://github.com/apache/datafusion/issues/13949 + async fn run_test_with_spill_pool_if_necessary( + pool_size: usize, + expect_spill: bool, + ) -> Result<()> { + fn create_record_batch( + schema: &Arc, + data: (Vec, Vec), + ) -> Result { + Ok(RecordBatch::try_new( + Arc::clone(schema), + vec![ + Arc::new(UInt32Array::from(data.0)), + Arc::new(Float64Array::from(data.1)), + ], + )?) + } + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::Float64, false), + ])); + + let batches = vec![ + create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?, + create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?, + ]; + let plan: Arc = + Arc::new(MemoryExec::try_new(&[batches], Arc::clone(&schema), None)?); + + let grouping_set = PhysicalGroupBy::new( + vec![(col("a", &schema)?, "a".to_string())], + vec![], + vec![vec![false]], + ); + + // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count). + let aggregates: Vec> = vec![ + Arc::new( + AggregateExprBuilder::new( + datafusion_functions_aggregate::min_max::min_udaf(), + vec![col("b", &schema)?], + ) + .schema(Arc::clone(&schema)) + .alias("MIN(b)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build()?, + ), + ]; + + let single_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + grouping_set, + aggregates, + vec![None, None], + plan, + Arc::clone(&schema), + )?); + + let batch_size = 2; + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )), + ); + + let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; + + assert_spill_count_metric(expect_spill, single_aggregate); + + #[rustfmt::skip] + assert_batches_sorted_eq!( + [ + "+---+--------+--------+", + "| a | MIN(b) | AVG(b) |", + "+---+--------+--------+", + "| 2 | 1.0 | 1.0 |", + "| 3 | 2.0 | 2.0 |", + "| 4 | 3.0 | 3.5 |", + "+---+--------+--------+", + ], + &result + ); + + Ok(()) + } + + fn assert_spill_count_metric( + expect_spill: bool, + single_aggregate: Arc, + ) { + if let Some(metrics_set) = single_aggregate.metrics() { + let mut spill_count = 0; + + // Inspect metrics for SpillCount + for metric in metrics_set.iter() { + if let MetricValue::SpillCount(count) = metric.value() { + spill_count = count.value(); + break; + } + } + + if expect_spill && spill_count == 0 { + panic!( + "Expected spill but SpillCount metric not found or SpillCount was 0." 
+ ); + } else if !expect_spill && spill_count > 0 { + panic!("Expected no spill but found SpillCount metric with value greater than 0."); + } + } else { + panic!("No metrics returned from the operator; cannot verify spilling."); + } + } + + #[tokio::test] + async fn test_aggregate_with_spill_if_necessary() -> Result<()> { + // test with spill + run_test_with_spill_pool_if_necessary(2_000, true).await?; + // test without spill + run_test_with_spill_pool_if_necessary(20_000, false).await?; + Ok(()) + } } diff --git a/datafusion/physical-plan/src/aggregates/order/full.rs b/datafusion/physical-plan/src/aggregates/order/full.rs index d64c99ba1bee3..218855459b1e2 100644 --- a/datafusion/physical-plan/src/aggregates/order/full.rs +++ b/datafusion/physical-plan/src/aggregates/order/full.rs @@ -16,6 +16,7 @@ // under the License. use datafusion_expr::EmitTo; +use std::mem::size_of; /// Tracks grouping state when the data is ordered entirely by its /// group keys @@ -139,7 +140,7 @@ impl GroupOrderingFull { } pub(crate) fn size(&self) -> usize { - std::mem::size_of::() + size_of::() } } diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index 483150ee61af6..7d9a50e20ae05 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -19,7 +19,8 @@ use arrow_array::ArrayRef; use arrow_schema::Schema; use datafusion_common::Result; use datafusion_expr::EmitTo; -use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use std::mem::size_of; mod full; mod partial; @@ -44,7 +45,7 @@ impl GroupOrdering { pub fn try_new( input_schema: &Schema, mode: &InputOrderMode, - ordering: &[PhysicalSortExpr], + ordering: &LexOrdering, ) -> Result { match mode { InputOrderMode::Linear => Ok(GroupOrdering::None), @@ -118,7 +119,7 @@ impl GroupOrdering { /// Return the size of memory used by the ordering state, in bytes pub fn size(&self) -> usize { - std::mem::size_of::() + size_of::() + match self { GroupOrdering::None => 0, GroupOrdering::Partial(partial) => partial.size(), diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs index 2cbe3bbb784ec..5a05b88798eff 100644 --- a/datafusion/physical-plan/src/aggregates/order/partial.rs +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -21,7 +21,8 @@ use arrow_schema::Schema; use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; -use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use std::mem::size_of; use std::sync::Arc; /// Tracks grouping state when the data is ordered by some subset of @@ -106,7 +107,7 @@ impl GroupOrderingPartial { pub fn try_new( input_schema: &Schema, order_indices: &[usize], - ordering: &[PhysicalSortExpr], + ordering: &LexOrdering, ) -> Result { assert!(!order_indices.is_empty()); assert!(order_indices.len() <= ordering.len()); @@ -244,7 +245,7 @@ impl GroupOrderingPartial { /// Return the size of memory allocated by this structure pub(crate) fn size(&self) -> usize { - std::mem::size_of::() + size_of::() + self.order_indices.allocated_size() + self.row_converter.size() } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 5121e6cc3b354..cc95ce51c15b3 100644 --- 
a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -24,7 +24,7 @@ use std::vec; use crate::aggregates::group_values::{new_group_values, GroupValues}; use crate::aggregates::order::GroupOrderingFull; use crate::aggregates::{ - evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode, + create_schema, evaluate_group_by, evaluate_many, evaluate_optional, AggregateMode, PhysicalGroupBy, }; use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; @@ -48,14 +48,14 @@ use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; +use super::order::GroupOrdering; +use super::AggregateExec; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use futures::ready; use futures::stream::{Stream, StreamExt}; use log::debug; -use super::order::GroupOrdering; -use super::AggregateExec; - #[derive(Debug, Clone)] /// This object tracks the aggregation phase (input/output) pub(crate) enum ExecutionState { @@ -80,7 +80,7 @@ struct SpillState { // the execution. // ======================================================================== /// Sorting expression for spilling batches - spill_expr: Vec, + spill_expr: LexOrdering, /// Schema for spilling batches spill_schema: SchemaRef, @@ -102,6 +102,19 @@ struct SpillState { /// true when streaming merge is in progress is_stream_merging: bool, + + // ======================================================================== + // METRICS: + // ======================================================================== + /// Peak memory used for buffered data. + /// Calculated as sum of peak memory values across partitions + peak_mem_used: metrics::Gauge, + /// count of spill files during the execution of the operator + spill_count: metrics::Count, + /// total spilled bytes during the execution of the operator + spilled_bytes: metrics::Count, + /// total spilled rows during the execution of the operator + spilled_rows: metrics::Count, } /// Tracks if the aggregate should skip partial aggregations @@ -124,7 +137,7 @@ struct SkipAggregationProbe { // ======================================================================== // STATES: // Fields changes during execution. Can be buffer, or state flags that - // influence the exeuction in parent `GroupedHashAggregateStream` + // influence the execution in parent `GroupedHashAggregateStream` // ======================================================================== /// Number of processed input rows (updated during probing) input_rows: usize, @@ -138,6 +151,9 @@ struct SkipAggregationProbe { /// make any effect (set either while probing or on probing completion) is_locked: bool, + // ======================================================================== + // METRICS: + // ======================================================================== /// Number of rows where state was output without aggregation. 
/// /// * If 0, all input rows were aggregated (should_skip was always false) @@ -473,7 +489,32 @@ impl GroupedHashAggregateStream { .map(create_group_accumulator) .collect::>()?; - let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?; + let group_schema = agg_group_by.group_schema(&agg.input().schema())?; + + // fix https://github.com/apache/datafusion/issues/13949 + // Builds a **partial aggregation** schema by combining the group columns and + // the accumulator state columns produced by each aggregate expression. + // + // # Why Partial Aggregation Schema Is Needed + // + // In a multi-stage (partial/final) aggregation strategy, each partial-aggregate + // operator produces *intermediate* states (e.g., partial sums, counts) rather + // than final scalar values. These extra columns do **not** exist in the original + // input schema (which may be something like `[colA, colB, ...]`). Instead, + // each aggregator adds its own internal state columns (e.g., `[acc_state_1, acc_state_2, ...]`). + // + // Therefore, when we spill these intermediate states or pass them to another + // aggregation operator, we must use a schema that includes both the group + // columns **and** the partial-state columns. + let partial_agg_schema = create_schema( + &agg.input().schema(), + &agg_group_by, + &aggregate_exprs, + AggregateMode::Partial, + )?; + + let partial_agg_schema = Arc::new(partial_agg_schema); + let spill_expr = group_schema .fields .into_iter() @@ -495,10 +536,10 @@ impl GroupedHashAggregateStream { let group_ordering = GroupOrdering::try_new( &group_schema, &agg.input_order_mode, - ordering.as_slice(), + ordering.as_ref(), )?; - let group_values = new_group_values(group_schema)?; + let group_values = new_group_values(group_schema, &group_ordering)?; timer.done(); let exec_state = ExecutionState::ReadingInput; @@ -506,10 +547,15 @@ impl GroupedHashAggregateStream { let spill_state = SpillState { spills: vec![], spill_expr, - spill_schema: Arc::clone(&agg_schema), + spill_schema: partial_agg_schema, is_stream_merging: false, merging_aggregate_arguments, merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()), + peak_mem_used: MetricBuilder::new(&agg.metrics) + .gauge("peak_mem_used", partition), + spill_count: MetricBuilder::new(&agg.metrics).spill_count(partition), + spilled_bytes: MetricBuilder::new(&agg.metrics).spilled_bytes(partition), + spilled_rows: MetricBuilder::new(&agg.metrics).spilled_rows(partition), }; // Skip aggregation is supported if: @@ -570,7 +616,7 @@ impl GroupedHashAggregateStream { /// that is supported by the aggregate, or a /// [`GroupsAccumulatorAdapter`] if not. 
pub(crate) fn create_group_accumulator( - agg_expr: &AggregateFunctionExpr, + agg_expr: &Arc, ) -> Result> { if agg_expr.groups_accumulator_supported() { agg_expr.create_groups_accumulator() @@ -580,7 +626,7 @@ pub(crate) fn create_group_accumulator( "Creating GroupsAccumulatorAdapter for {}: {agg_expr:?}", agg_expr.name() ); - let agg_expr_captured = agg_expr.clone(); + let agg_expr_captured = Arc::clone(agg_expr); let factory = move || agg_expr_captured.create_accumulator(); Ok(Box::new(GroupsAccumulatorAdapter::new(factory))) } @@ -633,9 +679,13 @@ impl Stream for GroupedHashAggregateStream { } if let Some(to_emit) = self.group_ordering.emit_to() { - let batch = extract_ok!(self.emit(to_emit, false)); - self.exec_state = ExecutionState::ProducingOutput(batch); timer.done(); + if let Some(batch) = + extract_ok!(self.emit(to_emit, false)) + { + self.exec_state = + ExecutionState::ProducingOutput(batch); + }; // make sure the exec_state just set is not overwritten below break 'reading_input; } @@ -672,9 +722,13 @@ impl Stream for GroupedHashAggregateStream { } if let Some(to_emit) = self.group_ordering.emit_to() { - let batch = extract_ok!(self.emit(to_emit, false)); - self.exec_state = ExecutionState::ProducingOutput(batch); timer.done(); + if let Some(batch) = + extract_ok!(self.emit(to_emit, false)) + { + self.exec_state = + ExecutionState::ProducingOutput(batch); + }; // make sure the exec_state just set is not overwritten below break 'reading_input; } @@ -747,6 +801,9 @@ impl Stream for GroupedHashAggregateStream { let output = batch.slice(0, size); (ExecutionState::ProducingOutput(remaining), output) }; + // Empty record batches should not be emitted. + // They need to be treated as [`Option`]es and handled separately + debug_assert!(output_batch.num_rows() > 0); return Poll::Ready(Some(Ok( output_batch.record_output(&self.baseline_metrics) ))); @@ -838,14 +895,13 @@ impl GroupedHashAggregateStream { )?; } _ => { + if opt_filter.is_some() { + return internal_err!("aggregate filter should be applied in partial stage, there should be no filter in final stage"); + } + // if aggregation is over intermediate states, // use merge - acc.merge_batch( - values, - group_indices, - opt_filter, - total_num_groups, - )?; + acc.merge_batch(values, group_indices, None, total_num_groups)?; } } } @@ -865,23 +921,31 @@ impl GroupedHashAggregateStream { fn update_memory_reservation(&mut self) -> Result<()> { let acc = self.accumulators.iter().map(|x| x.size()).sum::(); - self.reservation.try_resize( + let reservation_result = self.reservation.try_resize( acc + self.group_values.size() + self.group_ordering.size() + self.current_group_indices.allocated_size(), - ) + ); + + if reservation_result.is_ok() { + self.spill_state + .peak_mem_used + .set_max(self.reservation.size()); + } + + reservation_result } /// Create an output RecordBatch with the group keys and /// accumulator states/values specified in emit_to - fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result { + fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result> { let schema = if spilling { Arc::clone(&self.spill_state.spill_schema) } else { self.schema() }; if self.group_values.is_empty() { - return Ok(RecordBatch::new_empty(schema)); + return Ok(None); } let mut output = self.group_values.emit(emit_to)?; @@ -909,7 +973,8 @@ impl GroupedHashAggregateStream { // over the target memory size after emission, we can emit again rather than returning Err. 
let _ = self.update_memory_reservation(); let batch = RecordBatch::try_new(schema, output)?; - Ok(batch) + debug_assert!(batch.num_rows() > 0); + Ok(Some(batch)) } /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly @@ -924,9 +989,6 @@ impl GroupedHashAggregateStream { && self.update_memory_reservation().is_err() { assert_ne!(self.mode, AggregateMode::Partial); - // Use input batch (Partial mode) schema for spilling because - // the spilled data will be merged and re-evaluated later. - self.spill_state.spill_schema = batch.schema(); self.spill()?; self.clear_shrink(batch); } @@ -935,8 +997,10 @@ impl GroupedHashAggregateStream { /// Emit all rows, sort them, and store them on disk. fn spill(&mut self) -> Result<()> { - let emit = self.emit(EmitTo::All, true)?; - let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?; + let Some(emit) = self.emit(EmitTo::All, true)? else { + return Ok(()); + }; + let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?; let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?; // TODO: slice large `sorted` and write to multiple files in parallel spill_record_batch_by_size( @@ -946,6 +1010,14 @@ impl GroupedHashAggregateStream { self.batch_size, )?; self.spill_state.spills.push(spillfile); + + // Update metrics + self.spill_state.spill_count.add(1); + self.spill_state + .spilled_bytes + .add(sorted.get_array_memory_size()); + self.spill_state.spilled_rows.add(sorted.num_rows()); + Ok(()) } @@ -972,8 +1044,9 @@ impl GroupedHashAggregateStream { { assert_eq!(self.mode, AggregateMode::Partial); let n = self.group_values.len() / self.batch_size * self.batch_size; - let batch = self.emit(EmitTo::First(n), false)?; - self.exec_state = ExecutionState::ProducingOutput(batch); + if let Some(batch) = self.emit(EmitTo::First(n), false)? { + self.exec_state = ExecutionState::ProducingOutput(batch); + }; } Ok(()) } @@ -983,7 +1056,9 @@ impl GroupedHashAggregateStream { /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`]. fn update_merged_stream(&mut self) -> Result<()> { - let batch = self.emit(EmitTo::All, true)?; + let Some(batch) = self.emit(EmitTo::All, true)? else { + return Ok(()); + }; // clear up memory for streaming_merge self.clear_all(); self.update_memory_reservation()?; @@ -993,7 +1068,7 @@ impl GroupedHashAggregateStream { streams.push(Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&schema), futures::stream::once(futures::future::lazy(move |_| { - sort_batch(&batch, &expr, None) + sort_batch(&batch, expr.as_ref(), None) })), ))); for spill in self.spill_state.spills.drain(..) { @@ -1004,7 +1079,7 @@ impl GroupedHashAggregateStream { self.input = StreamingMergeBuilder::new() .with_streams(streams) .with_schema(schema) - .with_expressions(&self.spill_state.spill_expr) + .with_expressions(self.spill_state.spill_expr.as_ref()) .with_metrics(self.baseline_metrics.clone()) .with_batch_size(self.batch_size) .with_reservation(self.reservation.new_empty()) @@ -1031,7 +1106,7 @@ impl GroupedHashAggregateStream { let timer = elapsed_compute.timer(); self.exec_state = if self.spill_state.spills.is_empty() { let batch = self.emit(EmitTo::All, false)?; - ExecutionState::ProducingOutput(batch) + batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput) } else { // If spill files exist, stream-merge them. 
self.update_merged_stream()?; @@ -1060,8 +1135,9 @@ impl GroupedHashAggregateStream { fn switch_to_skip_aggregation(&mut self) -> Result<()> { if let Some(probe) = self.skip_aggregation_probe.as_mut() { if probe.should_skip() { - let batch = self.emit(EmitTo::All, false)?; - self.exec_state = ExecutionState::ProducingOutput(batch); + if let Some(batch) = self.emit(EmitTo::All, false)? { + self.exec_state = ExecutionState::ProducingOutput(batch); + }; } } diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index 232b87de32314..23a07ebec305f 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -17,7 +17,7 @@ //! A wrapper around `hashbrown::RawTable` that allows entries to be tracked by index -use crate::aggregates::group_values::primitive::HashValue; +use crate::aggregates::group_values::HashValue; use crate::aggregates::topk::heap::Comparable; use ahash::RandomState; use arrow::datatypes::i256; @@ -109,7 +109,7 @@ impl StringHashTable { Self { owned, map: TopKHashTable::new(limit, limit * 10), - rnd: ahash::RandomState::default(), + rnd: RandomState::default(), } } } @@ -181,7 +181,7 @@ where Self { owned, map: TopKHashTable::new(limit, limit * 10), - rnd: ahash::RandomState::default(), + rnd: RandomState::default(), } } } diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs index e694422e443da..ec1277f8fd558 100644 --- a/datafusion/physical-plan/src/aggregates/topk/heap.rs +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -20,11 +20,11 @@ use arrow::datatypes::i256; use arrow_array::cast::AsArray; use arrow_array::{downcast_primitive, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; -use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; +use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, ScalarBuffer}; use arrow_schema::DataType; use datafusion_common::DataFusionError; use datafusion_common::Result; -use datafusion_physical_expr::aggregate::utils::adjust_output_array; + use half::f16; use std::cmp::Ordering; use std::fmt::{Debug, Display, Formatter}; @@ -151,10 +151,11 @@ where } fn drain(&mut self) -> (ArrayRef, Vec) { + let nulls = None; let (vals, map_idxs) = self.heap.drain(); - let vals = Arc::new(PrimitiveArray::::from_iter_values(vals)); - let vals = adjust_output_array(&self.data_type, vals).expect("Type is incorrect"); - (vals, map_idxs) + let arr = PrimitiveArray::::new(ScalarBuffer::from(vals), nulls) + .with_data_type(self.data_type.clone()); + (Arc::new(arr), map_idxs) } } @@ -366,7 +367,7 @@ impl TopKHeap { impl Display for TopKHeap { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let mut output = String::new(); - if self.heap.first().is_some() { + if !self.heap.is_empty() { self._tree_print(0, String::new(), true, &mut output); } write!(f, "{}", output) diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index 075d8c5f28833..5d18f40d13bc7 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -123,7 +123,7 @@ impl Stream for GroupedTopKAggregateStream { batch.num_rows() ); if log::log_enabled!(Level::Trace) && batch.num_rows() < 20 { - print_batches(&[batch.clone()])?; + print_batches(std::slice::from_ref(&batch))?; } self.row_count += batch.num_rows(); let 
batches = &[batch]; @@ -165,7 +165,7 @@ impl Stream for GroupedTopKAggregateStream { batch.num_rows() ); if log::log_enabled!(Level::Trace) { - print_batches(&[batch.clone()])?; + print_batches(std::slice::from_ref(&batch))?; } return Poll::Ready(Some(Ok(batch))); } diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index 287446328f8de..708f006b0d390 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -40,9 +40,9 @@ use futures::StreamExt; /// discards the results, and then prints out an annotated plan with metrics #[derive(Debug, Clone)] pub struct AnalyzeExec { - /// control how much extra to print + /// Control how much extra to print verbose: bool, - /// if statistics should be displayed + /// If statistics should be displayed show_statistics: bool, /// The input plan (the plan being analyzed) pub(crate) input: Arc, @@ -69,12 +69,12 @@ impl AnalyzeExec { } } - /// access to verbose + /// Access to verbose pub fn verbose(&self) -> bool { self.verbose } - /// access to show_statistics + /// Access to show_statistics pub fn show_statistics(&self) -> bool { self.show_statistics } @@ -89,10 +89,12 @@ impl AnalyzeExec { input: &Arc, schema: SchemaRef, ) -> PlanProperties { - let eq_properties = EquivalenceProperties::new(schema); - let output_partitioning = Partitioning::UnknownPartitioning(1); - let exec_mode = input.execution_mode(); - PlanProperties::new(eq_properties, output_partitioning, exec_mode) + PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(1), + input.pipeline_behavior(), + input.boundedness(), + ) } } @@ -171,7 +173,7 @@ impl ExecutionPlan for AnalyzeExec { ); } - // Create future that computes thefinal output + // Create future that computes the final output let start = Instant::now(); let captured_input = Arc::clone(&self.input); let captured_schema = Arc::clone(&self.schema); diff --git a/datafusion/physical-plan/src/coalesce/mod.rs b/datafusion/physical-plan/src/coalesce/mod.rs index 46875fae94fc3..f38876d93ec11 100644 --- a/datafusion/physical-plan/src/coalesce/mod.rs +++ b/datafusion/physical-plan/src/coalesce/mod.rs @@ -180,7 +180,7 @@ impl BatchCoalescer { /// Indicates the state of the [`BatchCoalescer`] buffer after the /// [`BatchCoalescer::push_batch()`] operation. /// -/// The caller should take diferent actions, depending on the variant returned. +/// The caller should take different actions, depending on the variant returned. pub enum CoalescerState { /// Neither the limit nor the target batch size is reached. /// diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index 7caf5b8ab65a3..fa8d125d62d1f 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -34,6 +34,7 @@ use datafusion_common::Result; use datafusion_execution::TaskContext; use crate::coalesce::{BatchCoalescer, CoalescerState}; +use crate::execution_plan::CardinalityEffect; use futures::ready; use futures::stream::{Stream, StreamExt}; @@ -47,11 +48,11 @@ use futures::stream::{Stream, StreamExt}; /// reaches the `fetch` value. 
/// /// See [`BatchCoalescer`] for more information -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct CoalesceBatchesExec { /// The input plan input: Arc, - /// Minimum number of rows for coalesces batches + /// Minimum number of rows for coalescing batches target_batch_size: usize, /// Maximum number of rows to fetch, `None` means fetching all rows fetch: Option, @@ -96,7 +97,8 @@ impl CoalesceBatchesExec { PlanProperties::new( input.equivalence_properties().clone(), // Equivalence Properties input.output_partitioning().clone(), // Output Partitioning - input.execution_mode(), // Execution Mode + input.pipeline_behavior(), + input.boundedness(), ) } } @@ -199,6 +201,10 @@ impl ExecutionPlan for CoalesceBatchesExec { fn fetch(&self) -> Option { self.fetch } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } } /// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details. diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 486ae41901db3..7c1bdba2f339c 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -30,12 +30,13 @@ use super::{ use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; +use crate::execution_plan::CardinalityEffect; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; /// Merge execution plan executes partitions in parallel and combines them into a single /// partition. No guarantees are made about the order of the resulting partition. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct CoalescePartitionsExec { /// Input execution plan input: Arc, @@ -69,7 +70,8 @@ impl CoalescePartitionsExec { PlanProperties::new( eq_properties, // Equivalence Properties Partitioning::UnknownPartitioning(1), // Output Partitioning - input.execution_mode(), // Execution Mode + input.pipeline_behavior(), + input.boundedness(), ) } } @@ -178,6 +180,10 @@ impl ExecutionPlan for CoalescePartitionsExec { fn supports_limit_pushdown(&self) -> bool { true } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } } #[cfg(test)] @@ -231,10 +237,10 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); let refs = blocking_exec.refs(); - let coaelesce_partitions_exec = + let coalesce_partitions_exec = Arc::new(CoalescePartitionsExec::new(blocking_exec)); - let fut = collect(coaelesce_partitions_exec, task_ctx); + let fut = collect(coalesce_partitions_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 4b5eea6b760df..aefb90d1d1b71 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -109,7 +109,7 @@ pub(crate) fn spawn_buffered( builder.spawn(async move { while let Some(item) = input.next().await { if sender.send(item).await.is_err() { - // receiver dropped when query is shutdown early (e.g., limit) or error, + // Receiver dropped when query is shutdown early (e.g., limit) or error, // no need to return propagate the send error. 
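The `CoalesceBatchesExec` documentation above describes buffering small input batches until roughly `target_batch_size` rows are available. A rough standalone sketch of that idea built on arrow's `concat_batches` kernel (a conceptual illustration only, not the actual `BatchCoalescer` implementation):

```rust
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::compute::concat_batches;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

/// Buffer batches until at least `target` rows are available, then emit one
/// concatenated batch (conceptual sketch of what CoalesceBatchesExec does).
fn coalesce(batches: Vec<RecordBatch>, target: usize) -> Result<Vec<RecordBatch>, ArrowError> {
    if batches.is_empty() {
        return Ok(vec![]);
    }
    let schema = batches[0].schema();
    let mut out = Vec::new();
    let mut buf: Vec<RecordBatch> = Vec::new();
    let mut buffered = 0;
    for batch in batches {
        buffered += batch.num_rows();
        buf.push(batch);
        if buffered >= target {
            out.push(concat_batches(&schema, &buf)?);
            buf.clear();
            buffered = 0;
        }
    }
    if !buf.is_empty() {
        out.push(concat_batches(&schema, &buf)?);
    }
    Ok(out)
}

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
    let small = |vals: Vec<i32>| {
        RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(Int32Array::from(vals))])
            .unwrap()
    };
    // Three small batches (2 + 1 + 3 rows) are combined into one 6-row batch.
    let out = coalesce(vec![small(vec![1, 2]), small(vec![3]), small(vec![4, 5, 6])], 4)?;
    assert_eq!(out.len(), 1);
    assert_eq!(out[0].num_rows(), 6);
    Ok(())
}
```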
return Ok(()); } @@ -156,7 +156,11 @@ pub fn compute_record_batch_statistics( for partition in batches.iter() { for batch in partition { for (stat_index, col_index) in projection.iter().enumerate() { - null_counts[stat_index] += batch.column(*col_index).null_count(); + null_counts[stat_index] += batch + .column(*col_index) + .logical_nulls() + .map(|nulls| nulls.null_count()) + .unwrap_or_default(); } } } @@ -178,15 +182,15 @@ pub fn compute_record_batch_statistics( /// Write in Arrow IPC format. pub struct IPCWriter { - /// path + /// Path pub path: PathBuf, - /// inner writer + /// Inner writer pub writer: FileWriter, - /// batches written + /// Batches written pub num_batches: usize, - /// rows written + /// Rows written pub num_rows: usize, - /// bytes written + /// Bytes written pub num_bytes: usize, } @@ -257,7 +261,7 @@ pub fn can_project( if columns .iter() .max() - .map_or(false, |&i| i >= schema.fields().len()) + .is_some_and(|&i| i >= schema.fields().len()) { Err(arrow_schema::ArrowError::SchemaError(format!( "project index {} out of bounds, max field {}", @@ -311,7 +315,7 @@ mod tests { ], )?; - // just select f32,f64 + // Just select f32,f64 let select_projection = Some(vec![0, 1]); let byte_size = batch .project(&select_projection.clone().unwrap()) diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 0d2653c5c7753..961d2f639897c 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -25,7 +25,7 @@ use arrow_schema::SchemaRef; use datafusion_common::display::{GraphvizBuilder, PlanType, StringifiedPlan}; use datafusion_expr::display_schema; -use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr::LexOrdering; use super::{accept, ExecutionPlan, ExecutionPlanVisitor}; @@ -38,7 +38,39 @@ pub enum DisplayFormatType { Verbose, } -/// Wraps an `ExecutionPlan` with various ways to display this plan +/// Wraps an `ExecutionPlan` with various methods for formatting +/// +/// +/// # Example +/// ``` +/// # use std::sync::Arc; +/// # use arrow_schema::{Field, Schema, DataType}; +/// # use datafusion_expr::Operator; +/// # use datafusion_physical_expr::expressions::{binary, col, lit}; +/// # use datafusion_physical_plan::{displayable, ExecutionPlan}; +/// # use datafusion_physical_plan::empty::EmptyExec; +/// # use datafusion_physical_plan::filter::FilterExec; +/// # let schema = Schema::new(vec![Field::new("i", DataType::Int32, false)]); +/// # let plan = EmptyExec::new(Arc::new(schema)); +/// # let i = col("i", &plan.schema()).unwrap(); +/// # let predicate = binary(i, Operator::Eq, lit(1), &plan.schema()).unwrap(); +/// # let plan: Arc = Arc::new(FilterExec::try_new(predicate, Arc::new(plan)).unwrap()); +/// // Get a one line description (Displayable) +/// let display_plan = displayable(plan.as_ref()); +/// +/// // you can use the returned objects to format plans +/// // where you can use `Display` such as format! or println! 
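The statistics change above (and the `check_not_null_constraints` change later in this diff) switches from `Array::null_count()` to counting via `Array::logical_nulls()`. The difference matters for encodings whose physical null buffer does not describe logical nullability, such as dictionary and run-end encoded arrays. A standalone arrow example of the distinction (illustrative only):

```rust
use std::sync::Arc;

use arrow::array::{Array, DictionaryArray, Int32Array};

fn main() {
    // Every key points at a valid dictionary slot, but one dictionary *value*
    // is null, so some rows are logically null even though the key buffer
    // contains no nulls.
    let values = Arc::new(Int32Array::from(vec![Some(10), None, Some(30)]));
    let keys = Int32Array::from(vec![0, 1, 2, 1]);
    let dict = DictionaryArray::new(keys, values);

    // Physical null count (of the keys) is zero.
    assert_eq!(dict.null_count(), 0);

    // Logical nulls resolve through the dictionary: rows 1 and 3 are null.
    let logical = dict
        .logical_nulls()
        .map(|nulls| nulls.null_count())
        .unwrap_or_default();
    assert_eq!(logical, 2);
}
```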
+/// assert_eq!( +/// &format!("The plan is: {}", display_plan.one_line()), +/// "The plan is: FilterExec: i@0 = 1\n" +/// ); +/// // You can also print out the plan and its children in indented mode +/// assert_eq!(display_plan.indent(false).to_string(), +/// "FilterExec: i@0 = 1\ +/// \n EmptyExec\ +/// \n" +/// ); +/// ``` #[derive(Debug, Clone)] pub struct DisplayableExecutionPlan<'a> { inner: &'a dyn ExecutionPlan, @@ -124,8 +156,8 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: bool, show_schema: bool, } - impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + impl fmt::Display for Wrapper<'_> { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = IndentVisitor { t: self.format_type, f, @@ -163,8 +195,8 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: ShowMetrics, show_statistics: bool, } - impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + impl fmt::Display for Wrapper<'_> { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let t = DisplayFormatType::Default; let mut visitor = GraphvizVisitor { @@ -202,8 +234,8 @@ impl<'a> DisplayableExecutionPlan<'a> { show_schema: bool, } - impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + impl fmt::Display for Wrapper<'_> { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = IndentVisitor { f, t: DisplayFormatType::Default, @@ -231,6 +263,7 @@ impl<'a> DisplayableExecutionPlan<'a> { } } +/// Enum representing the different levels of metrics to display #[derive(Debug, Clone, Copy)] enum ShowMetrics { /// Do not show any metrics @@ -256,7 +289,7 @@ struct IndentVisitor<'a, 'b> { /// How to format each node t: DisplayFormatType, /// Write to this formatter - f: &'a mut fmt::Formatter<'b>, + f: &'a mut Formatter<'b>, /// Indent size indent: usize, /// How to show metrics @@ -267,7 +300,7 @@ struct IndentVisitor<'a, 'b> { show_schema: bool, } -impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { +impl ExecutionPlanVisitor for IndentVisitor<'_, '_> { type Error = fmt::Error; fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { write!(self.f, "{:indent$}", "", indent = self.indent * 2)?; @@ -317,7 +350,7 @@ impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { } struct GraphvizVisitor<'a, 'b> { - f: &'a mut fmt::Formatter<'b>, + f: &'a mut Formatter<'b>, /// How to format each node t: DisplayFormatType, /// How to show metrics @@ -348,8 +381,8 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { struct Wrapper<'a>(&'a dyn ExecutionPlan, DisplayFormatType); - impl<'a> std::fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + impl fmt::Display for Wrapper<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self.0.fmt_as(self.1, f) } } @@ -421,23 +454,23 @@ pub trait DisplayAs { /// different from the default one /// /// Should not include a newline - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result; + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result; } -/// A newtype wrapper to display `T` implementing`DisplayAs` using the `Default` mode +/// A new type wrapper to display `T` implementing`DisplayAs` using the `Default` mode pub struct DefaultDisplay(pub T); impl fmt::Display for DefaultDisplay { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 
self.0.fmt_as(DisplayFormatType::Default, f) } } -/// A newtype wrapper to display `T` implementing `DisplayAs` using the `Verbose` mode +/// A new type wrapper to display `T` implementing `DisplayAs` using the `Verbose` mode pub struct VerboseDisplay(pub T); impl fmt::Display for VerboseDisplay { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self.0.fmt_as(DisplayFormatType::Verbose, f) } } @@ -446,8 +479,8 @@ impl fmt::Display for VerboseDisplay { #[derive(Debug)] pub struct ProjectSchemaDisplay<'a>(pub &'a SchemaRef); -impl<'a> fmt::Display for ProjectSchemaDisplay<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl fmt::Display for ProjectSchemaDisplay<'_> { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let parts: Vec<_> = self .0 .fields() @@ -458,23 +491,6 @@ impl<'a> fmt::Display for ProjectSchemaDisplay<'a> { } } -/// A wrapper to customize output ordering display. -#[derive(Debug)] -pub struct OutputOrderingDisplay<'a>(pub &'a [PhysicalSortExpr]); - -impl<'a> fmt::Display for OutputOrderingDisplay<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "[")?; - for (i, e) in self.0.iter().enumerate() { - if i > 0 { - write!(f, ", ")? - } - write!(f, "{e}")?; - } - write!(f, "]") - } -} - pub fn display_orderings(f: &mut Formatter, orderings: &[LexOrdering]) -> fmt::Result { if let Some(ordering) = orderings.first() { if !ordering.is_empty() { @@ -488,8 +504,8 @@ pub fn display_orderings(f: &mut Formatter, orderings: &[LexOrdering]) -> fmt::R orderings.iter().enumerate().filter(|(_, o)| !o.is_empty()) { match idx { - 0 => write!(f, "{}", OutputOrderingDisplay(ordering))?, - _ => write!(f, ", {}", OutputOrderingDisplay(ordering))?, + 0 => write!(f, "[{}]", ordering)?, + _ => write!(f, ", [{}]", ordering)?, } } let end = if orderings.len() == 1 { "" } else { "]" }; diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 4bacea48c3473..5168c3cc101f2 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -20,11 +20,12 @@ use std::any::Any; use std::sync::Arc; -use super::{ - common, DisplayAs, ExecutionMode, PlanProperties, SendableRecordBatchStream, - Statistics, +use super::{common, DisplayAs, PlanProperties, SendableRecordBatchStream, Statistics}; +use crate::{ + execution_plan::{Boundedness, EmissionType}, + memory::MemoryStream, + DisplayFormatType, ExecutionPlan, Partitioning, }; -use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; @@ -35,7 +36,7 @@ use datafusion_physical_expr::EquivalenceProperties; use log::trace; /// Execution plan for empty relation with produce_one_row=false -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct EmptyExec { /// The schema for the produced row schema: SchemaRef, @@ -74,14 +75,11 @@ impl EmptyExec { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
fn compute_properties(schema: SchemaRef, n_partitions: usize) -> PlanProperties { - let eq_properties = EquivalenceProperties::new(schema); - let output_partitioning = Self::output_partitioning_helper(n_partitions); PlanProperties::new( - eq_properties, - // Output Partitioning - output_partitioning, - // Execution Mode - ExecutionMode::Bounded, + EquivalenceProperties::new(schema), + Self::output_partitioning_helper(n_partitions), + EmissionType::Incremental, + Boundedness::Bounded, ) } } @@ -173,7 +171,7 @@ mod tests { let empty = EmptyExec::new(Arc::clone(&schema)); assert_eq!(empty.schema(), schema); - // we should have no results + // We should have no results let iter = empty.execute(0, task_ctx)?; let batches = common::collect(iter).await?; assert!(batches.is_empty()); diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index b14021f4a99ba..6d79355531160 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -21,13 +21,14 @@ use std::sync::Arc; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use arrow_array::Array; use futures::stream::{StreamExt, TryStreamExt}; use tokio::task::JoinSet; use datafusion_common::config::ConfigOptions; pub use datafusion_common::hash_utils; pub use datafusion_common::utils::project_schema; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, Constraints, Result}; pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; use datafusion_execution::TaskContext; pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; @@ -36,7 +37,7 @@ pub use datafusion_physical_expr::window::WindowExpr; pub use datafusion_physical_expr::{ expressions, udf, Distribution, Partitioning, PhysicalExpr, }; -use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use datafusion_physical_expr_common::sort_expr::LexRequirement; use crate::coalesce_partitions::CoalescePartitionsExec; @@ -416,6 +417,11 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { fn fetch(&self) -> Option { None } + + /// Gets the effect on cardinality, if known + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Unknown + } } /// Extension trait provides an easy API to fetch various properties of @@ -425,11 +431,6 @@ pub trait ExecutionPlanProperties { /// partitions. fn output_partitioning(&self) -> &Partitioning; - /// Specifies whether this plan generates an infinite stream of records. - /// If the plan does not support pipelining, but its input(s) are - /// infinite, returns [`ExecutionMode::PipelineBreaking`] to indicate this. - fn execution_mode(&self) -> ExecutionMode; - /// If the output of this `ExecutionPlan` within each partition is sorted, /// returns `Some(keys)` describing the ordering. A `None` return value /// indicates no assumptions should be made on the output ordering. @@ -437,7 +438,15 @@ pub trait ExecutionPlanProperties { /// For example, `SortExec` (obviously) produces sorted output as does /// `SortPreservingMergeStream`. Less obviously, `Projection` produces sorted /// output if its input is sorted as it does not reorder the input rows. - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]>; + fn output_ordering(&self) -> Option<&LexOrdering>; + + /// Boundedness information of the stream corresponding to this `ExecutionPlan`. 
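With `ExecutionMode` removed, `compute_properties` for every operator (as in `EmptyExec` above) now states emission type and boundedness separately when building `PlanProperties`. A hedged sketch of how hypothetical leaf sources might fill in the new four-argument constructor; the function names are made up for illustration and the module paths are assumed from this diff:

```rust
use arrow_schema::SchemaRef;
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion_physical_plan::{Partitioning, PlanProperties};

/// Hypothetical: plan properties for a finite, single-partition scan.
fn bounded_scan_properties(schema: SchemaRef) -> PlanProperties {
    PlanProperties::new(
        EquivalenceProperties::new(schema),
        Partitioning::UnknownPartitioning(1),
        EmissionType::Incremental, // rows are produced as they are read
        Boundedness::Bounded,      // the scan eventually finishes
    )
}

/// Hypothetical: a streaming source never finishes, but it only needs bounded
/// memory to keep producing batches.
fn streaming_source_properties(schema: SchemaRef) -> PlanProperties {
    PlanProperties::new(
        EquivalenceProperties::new(schema),
        Partitioning::UnknownPartitioning(1),
        EmissionType::Incremental,
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        },
    )
}
```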
+ /// For more details, see [`Boundedness`]. + fn boundedness(&self) -> Boundedness; + + /// Indicates how the stream of this `ExecutionPlan` emits its results. + /// For more details, see [`EmissionType`]. + fn pipeline_behavior(&self) -> EmissionType; /// Get the [`EquivalenceProperties`] within the plan. /// @@ -464,12 +473,16 @@ impl ExecutionPlanProperties for Arc { self.properties().output_partitioning() } - fn execution_mode(&self) -> ExecutionMode { - self.properties().execution_mode() + fn output_ordering(&self) -> Option<&LexOrdering> { + self.properties().output_ordering() } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.properties().output_ordering() + fn boundedness(&self) -> Boundedness { + self.properties().boundedness + } + + fn pipeline_behavior(&self) -> EmissionType { + self.properties().emission_type } fn equivalence_properties(&self) -> &EquivalenceProperties { @@ -482,12 +495,16 @@ impl ExecutionPlanProperties for &dyn ExecutionPlan { self.properties().output_partitioning() } - fn execution_mode(&self) -> ExecutionMode { - self.properties().execution_mode() + fn output_ordering(&self) -> Option<&LexOrdering> { + self.properties().output_ordering() + } + + fn boundedness(&self) -> Boundedness { + self.properties().boundedness } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.properties().output_ordering() + fn pipeline_behavior(&self) -> EmissionType { + self.properties().emission_type } fn equivalence_properties(&self) -> &EquivalenceProperties { @@ -495,82 +512,142 @@ impl ExecutionPlanProperties for &dyn ExecutionPlan { } } -/// Describes the execution mode of the result of calling -/// [`ExecutionPlan::execute`] with respect to its size and behavior. +/// Represents whether a stream of data **generated** by an operator is bounded (finite) +/// or unbounded (infinite). /// -/// The mode of the execution plan is determined by the mode of its input -/// execution plans and the details of the operator itself. For example, a -/// `FilterExec` operator will have the same execution mode as its input, but a -/// `SortExec` operator may have a different execution mode than its input, -/// depending on how the input stream is sorted. +/// This is used to determine whether an execution plan will eventually complete +/// processing all its data (bounded) or could potentially run forever (unbounded). /// -/// There are three possible execution modes: `Bounded`, `Unbounded` and -/// `PipelineBreaking`. -#[derive(Clone, Copy, PartialEq, Debug)] -pub enum ExecutionMode { - /// The stream is bounded / finite. - /// - /// In this case the stream will eventually return `None` to indicate that - /// there are no more records to process. +/// For unbounded streams, it also tracks whether the operator requires finite memory +/// to process the stream or if memory usage could grow unbounded. +/// +/// Boundedness of the output stream is based on the the boundedness of the input stream and the nature of +/// the operator. For example, limit or topk with fetch operator can convert an unbounded stream to a bounded stream. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Boundedness { + /// The data stream is bounded (finite) and will eventually complete Bounded, - /// The stream is unbounded / infinite. - /// - /// In this case, the stream will never be done (never return `None`), - /// except in case of error. - /// - /// This mode is often used in "Steaming" use cases where data is - /// incrementally processed as it arrives. 
- /// - /// Note that even though the operator generates an unbounded stream of - /// results, it can execute with bounded memory and incrementally produces - /// output. - Unbounded, - /// Some of the operator's input stream(s) are unbounded, but the operator - /// cannot generate streaming results from these streaming inputs. - /// - /// In this case, the execution mode will be pipeline breaking, e.g. the - /// operator requires unbounded memory to generate results. This - /// information is used by the planner when performing sanity checks - /// on plans processings unbounded data sources. - PipelineBreaking, + /// The data stream is unbounded (infinite) and could run forever + Unbounded { + /// Whether this operator requires infinite memory to process the unbounded stream. + /// If false, the operator can process an infinite stream with bounded memory. + /// If true, memory usage may grow unbounded while processing the stream. + /// + /// For example, `Median` requires infinite memory to compute the median of an unbounded stream. + /// `Min/Max` requires infinite memory if the stream is unordered, but can be computed with bounded memory if the stream is ordered. + requires_infinite_memory: bool, + }, } -impl ExecutionMode { - /// Check whether the execution mode is unbounded or not. +impl Boundedness { pub fn is_unbounded(&self) -> bool { - matches!(self, ExecutionMode::Unbounded) + matches!(self, Boundedness::Unbounded { .. }) } +} - /// Check whether the execution is pipeline friendly. If so, operator can - /// execute safely. - pub fn pipeline_friendly(&self) -> bool { - matches!(self, ExecutionMode::Bounded | ExecutionMode::Unbounded) - } +/// Represents how an operator emits its output records. +/// +/// This is used to determine whether an operator emits records incrementally as they arrive, +/// only emits a final result at the end, or can do both. Note that it generates the output -- record batch with `batch_size` rows +/// but it may still buffer data internally until it has enough data to emit a record batch or the source is exhausted. +/// +/// For example, in the following plan: +/// ```text +/// SortExec [EmissionType::Final] +/// |_ on: [col1 ASC] +/// FilterExec [EmissionType::Incremental] +/// |_ pred: col2 > 100 +/// CsvExec [EmissionType::Incremental] +/// |_ file: "data.csv" +/// ``` +/// - CsvExec emits records incrementally as it reads from the file +/// - FilterExec processes and emits filtered records incrementally as they arrive +/// - SortExec must wait for all input records before it can emit the sorted result, +/// since it needs to see all values to determine their final order +/// +/// Left joins can emit both incrementally and finally: +/// - Incrementally emit matches as they are found +/// - Finally emit non-matches after all input is processed +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EmissionType { + /// Records are emitted incrementally as they arrive and are processed + Incremental, + /// Records are only emitted once all input has been processed + Final, + /// Records can be emitted both incrementally and as a final result + Both, } -/// Conservatively "combines" execution modes of a given collection of operators. -pub(crate) fn execution_mode_from_children<'a>( +/// Utility to determine an operator's boundedness based on its children's boundedness. +/// +/// Assumes boundedness can be inferred from child operators: +/// - Unbounded (requires_infinite_memory: true) takes precedence. 
+/// - Unbounded (requires_infinite_memory: false) is considered next. +/// - Otherwise, the operator is bounded. +/// +/// **Note:** This is a general-purpose utility and may not apply to +/// all multi-child operators. Ensure your operator's behavior aligns +/// with these assumptions before using. +pub(crate) fn boundedness_from_children<'a>( children: impl IntoIterator>, -) -> ExecutionMode { - let mut result = ExecutionMode::Bounded; - for mode in children.into_iter().map(|child| child.execution_mode()) { - match (mode, result) { - (ExecutionMode::PipelineBreaking, _) - | (_, ExecutionMode::PipelineBreaking) => { - // If any of the modes is `PipelineBreaking`, so is the result: - return ExecutionMode::PipelineBreaking; - } - (ExecutionMode::Unbounded, _) | (_, ExecutionMode::Unbounded) => { - // Unbounded mode eats up bounded mode: - result = ExecutionMode::Unbounded; +) -> Boundedness { + let mut unbounded_with_finite_mem = false; + + for child in children { + match child.boundedness() { + Boundedness::Unbounded { + requires_infinite_memory: true, + } => { + return Boundedness::Unbounded { + requires_infinite_memory: true, + } } - (ExecutionMode::Bounded, ExecutionMode::Bounded) => { - // When both modes are bounded, so is the result: - result = ExecutionMode::Bounded; + Boundedness::Unbounded { + requires_infinite_memory: false, + } => { + unbounded_with_finite_mem = true; } + Boundedness::Bounded => {} + } + } + + if unbounded_with_finite_mem { + Boundedness::Unbounded { + requires_infinite_memory: false, + } + } else { + Boundedness::Bounded + } +} + +/// Determines the emission type of an operator based on its children's pipeline behavior. +/// +/// The precedence of emission types is: +/// - `Final` has the highest precedence. +/// - `Both` is next: if any child emits both incremental and final results, the parent inherits this behavior unless a `Final` is present. +/// - `Incremental` is the default if all children emit incremental results. +/// +/// **Note:** This is a general-purpose utility and may not apply to +/// all multi-child operators. Verify your operator's behavior aligns +/// with these assumptions. +pub(crate) fn emission_type_from_children<'a>( + children: impl IntoIterator>, +) -> EmissionType { + let mut inc_and_final = false; + + for child in children { + match child.pipeline_behavior() { + EmissionType::Final => return EmissionType::Final, + EmissionType::Both => inc_and_final = true, + EmissionType::Incremental => continue, } } - result + + if inc_and_final { + EmissionType::Both + } else { + EmissionType::Incremental + } } /// Stores certain, often expensive to compute, plan properties used in query @@ -585,8 +662,10 @@ pub struct PlanProperties { pub eq_properties: EquivalenceProperties, /// See [ExecutionPlanProperties::output_partitioning] pub partitioning: Partitioning, - /// See [ExecutionPlanProperties::execution_mode] - pub execution_mode: ExecutionMode, + /// See [ExecutionPlanProperties::pipeline_behavior] + pub emission_type: EmissionType, + /// See [ExecutionPlanProperties::boundedness] + pub boundedness: Boundedness, /// See [ExecutionPlanProperties::output_ordering] output_ordering: Option, } @@ -596,14 +675,16 @@ impl PlanProperties { pub fn new( eq_properties: EquivalenceProperties, partitioning: Partitioning, - execution_mode: ExecutionMode, + emission_type: EmissionType, + boundedness: Boundedness, ) -> Self { // Output ordering can be derived from `eq_properties`. 
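The two helpers above fold children's properties with a fixed precedence: for boundedness, infinite-memory unbounded beats bounded-memory unbounded, which beats bounded; for emission, `Final` beats `Both`, which beats `Incremental`. A self-contained mirror of that precedence with a concrete input/output check, using local copies of the enums so the sketch compiles without DataFusion (the real helpers operate on `&dyn ExecutionPlan` children):

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Boundedness {
    Bounded,
    Unbounded { requires_infinite_memory: bool },
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EmissionType {
    Incremental,
    Final,
    Both,
}

/// Mirror of the precedence used by `boundedness_from_children`.
fn fold_boundedness(children: &[Boundedness]) -> Boundedness {
    let mut unbounded_with_finite_mem = false;
    for child in children {
        match child {
            Boundedness::Unbounded { requires_infinite_memory: true } => {
                return Boundedness::Unbounded { requires_infinite_memory: true }
            }
            Boundedness::Unbounded { requires_infinite_memory: false } => {
                unbounded_with_finite_mem = true
            }
            Boundedness::Bounded => {}
        }
    }
    if unbounded_with_finite_mem {
        Boundedness::Unbounded { requires_infinite_memory: false }
    } else {
        Boundedness::Bounded
    }
}

/// Mirror of the precedence used by `emission_type_from_children`.
fn fold_emission(children: &[EmissionType]) -> EmissionType {
    let mut saw_both = false;
    for child in children {
        match child {
            EmissionType::Final => return EmissionType::Final,
            EmissionType::Both => saw_both = true,
            EmissionType::Incremental => {}
        }
    }
    if saw_both { EmissionType::Both } else { EmissionType::Incremental }
}

fn main() {
    // A bounded input joined with an unbounded-but-finite-memory input stays
    // unbounded with finite memory; one Final child forces Final overall.
    assert_eq!(
        fold_boundedness(&[
            Boundedness::Bounded,
            Boundedness::Unbounded { requires_infinite_memory: false },
        ]),
        Boundedness::Unbounded { requires_infinite_memory: false }
    );
    assert_eq!(
        fold_emission(&[EmissionType::Incremental, EmissionType::Final]),
        EmissionType::Final
    );
}
```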
let output_ordering = eq_properties.output_ordering(); Self { eq_properties, partitioning, - execution_mode, + emission_type, + boundedness, output_ordering, } } @@ -614,12 +695,6 @@ impl PlanProperties { self } - /// Overwrite the execution Mode with its new value. - pub fn with_execution_mode(mut self, execution_mode: ExecutionMode) -> Self { - self.execution_mode = execution_mode; - self - } - /// Overwrite equivalence properties with its new value. pub fn with_eq_properties(mut self, eq_properties: EquivalenceProperties) -> Self { // Changing equivalence properties also changes output ordering, so @@ -629,6 +704,24 @@ impl PlanProperties { self } + /// Overwrite boundedness with its new value. + pub fn with_boundedness(mut self, boundedness: Boundedness) -> Self { + self.boundedness = boundedness; + self + } + + /// Overwrite emission type with its new value. + pub fn with_emission_type(mut self, emission_type: EmissionType) -> Self { + self.emission_type = emission_type; + self + } + + /// Overwrite constraints with its new value. + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.eq_properties = self.eq_properties.with_constraints(constraints); + self + } + pub fn equivalence_properties(&self) -> &EquivalenceProperties { &self.eq_properties } @@ -637,12 +730,8 @@ impl PlanProperties { &self.partitioning } - pub fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.output_ordering.as_deref() - } - - pub fn execution_mode(&self) -> ExecutionMode { - self.execution_mode + pub fn output_ordering(&self) -> Option<&LexOrdering> { + self.output_ordering.as_ref() } /// Get schema of the node. @@ -700,9 +789,11 @@ pub fn with_new_children_if_necessary( } } -/// Return a [wrapper](DisplayableExecutionPlan) around an +/// Return a [`DisplayableExecutionPlan`] wrapper around an /// [`ExecutionPlan`] which can be displayed in various easier to /// understand ways. +/// +/// See examples on [`DisplayableExecutionPlan`] pub fn displayable(plan: &dyn ExecutionPlan) -> DisplayableExecutionPlan<'_> { DisplayableExecutionPlan::new(plan) } @@ -818,7 +909,7 @@ pub fn execute_stream_partitioned( /// and context. It then checks if there are any columns in the input that might /// violate the `not null` constraints specified in the `sink_schema`. If there are /// such columns, it wraps the resulting stream to enforce the `not null` constraints -/// by invoking the `check_not_null_contraits` function on each batch of the stream. +/// by invoking the [`check_not_null_constraints`] function on each batch of the stream. pub fn execute_input_stream( input: Arc, sink_schema: SchemaRef, @@ -847,7 +938,7 @@ pub fn execute_input_stream( Ok(Box::pin(RecordBatchStreamAdapter::new( sink_schema, input_stream - .map(move |batch| check_not_null_contraits(batch?, &risky_columns)), + .map(move |batch| check_not_null_constraints(batch?, &risky_columns)), ))) } } @@ -867,7 +958,7 @@ pub fn execute_input_stream( /// This function iterates over the specified column indices and ensures that none /// of the columns contain null values. If any column contains null values, an error /// is returned. 
-pub fn check_not_null_contraits( +pub fn check_not_null_constraints( batch: RecordBatch, column_indices: &Vec, ) -> Result { @@ -880,7 +971,13 @@ pub fn check_not_null_contraits( ); } - if batch.column(index).null_count() > 0 { + if batch + .column(index) + .logical_nulls() + .map(|nulls| nulls.null_count()) + .unwrap_or_default() + > 0 + { return exec_err!( "Invalid batch column at '{}' has null but schema specifies non-nullable", index @@ -898,14 +995,28 @@ pub fn get_plan_string(plan: &Arc) -> Vec { actual.iter().map(|elem| elem.to_string()).collect() } +/// Indicates the effect an execution plan operator will have on the cardinality +/// of its input stream +pub enum CardinalityEffect { + /// Unknown effect. This is the default + Unknown, + /// The operator is guaranteed to produce exactly one row for + /// each input row + Equal, + /// The operator may produce fewer output rows than it receives input rows + LowerEqual, + /// The operator may produce more output rows than it receives input rows + GreaterEqual, +} + #[cfg(test)] mod tests { use super::*; + use arrow_array::{DictionaryArray, Int32Array, NullArray, RunArray}; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; use std::any::Any; use std::sync::Arc; - use arrow_schema::{Schema, SchemaRef}; - use datafusion_common::{Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; @@ -1049,6 +1160,125 @@ mod tests { fn use_execution_plan_as_trait_object(plan: &dyn ExecutionPlan) { let _ = plan.name(); } -} -// pub mod test; + #[test] + fn test_check_not_null_constraints_accept_non_null() -> Result<()> { + check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])), + vec![Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]))], + )?, + &vec![0], + )?; + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_reject_null() -> Result<()> { + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])), + vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]))], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().strip_backtrace(), + "Execution error: Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_with_run_end_array() -> Result<()> { + // some null value inside REE array + let run_ends = Int32Array::from(vec![1, 2, 3, 4]); + let values = Int32Array::from(vec![Some(0), None, Some(1), None]); + let run_end_array = RunArray::try_new(&run_ends, &values)?; + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "a", + run_end_array.data_type().to_owned(), + true, + )])), + vec![Arc::new(run_end_array)], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().strip_backtrace(), + "Execution error: Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_with_dictionary_array_with_null() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)])); + let keys = Int32Array::from(vec![0, 1, 2, 3]); + let dictionary = DictionaryArray::new(keys, values); + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "a", + dictionary.data_type().to_owned(), + true, + )])), + 
vec![Arc::new(dictionary)], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().strip_backtrace(), + "Execution error: Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_with_dictionary_masking_null() -> Result<()> { + // some null value marked out by dictionary array + let values = Arc::new(Int32Array::from(vec![ + Some(1), + None, // this null value is masked by dictionary keys + Some(3), + Some(4), + ])); + let keys = Int32Array::from(vec![0, /*1,*/ 2, 3]); + let dictionary = DictionaryArray::new(keys, values); + check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "a", + dictionary.data_type().to_owned(), + true, + )])), + vec![Arc::new(dictionary)], + )?, + &vec![0], + )?; + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_on_null_type() -> Result<()> { + // null value of Null type + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Null, true)])), + vec![Arc::new(NullArray::new(3))], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().strip_backtrace(), + "Execution error: Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/explain.rs b/datafusion/physical-plan/src/explain.rs index 56dc35e8819d5..cb00958cec4cb 100644 --- a/datafusion/physical-plan/src/explain.rs +++ b/datafusion/physical-plan/src/explain.rs @@ -20,7 +20,8 @@ use std::any::Any; use std::sync::Arc; -use super::{DisplayAs, ExecutionMode, PlanProperties, SendableRecordBatchStream}; +use super::{DisplayAs, PlanProperties, SendableRecordBatchStream}; +use crate::execution_plan::{Boundedness, EmissionType}; use crate::stream::RecordBatchStreamAdapter; use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; @@ -67,18 +68,18 @@ impl ExplainExec { &self.stringified_plans } - /// access to verbose + /// Access to verbose pub fn verbose(&self) -> bool { self.verbose } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
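The new `cardinality_effect()` hook (default `Unknown`, `Equal` for the coalescing operators earlier in this diff, `LowerEqual` for `FilterExec` below) lets optimizer rules reason about row counts without downcasting to concrete operators. A standalone sketch of how effects might compose along a chain of operators; the local enum mirrors the one added in this diff, but the `combine` rule is an illustrative assumption rather than an actual DataFusion rule:

```rust
/// Local mirror of the `CardinalityEffect` enum added in this diff.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CardinalityEffect {
    Unknown,
    Equal,
    LowerEqual,
    GreaterEqual,
}

/// Illustrative combination rule: `Unknown` poisons the chain, `Equal` is the
/// identity, and mixing "can shrink" with "can grow" yields `Unknown`.
fn combine(acc: CardinalityEffect, next: CardinalityEffect) -> CardinalityEffect {
    use CardinalityEffect::*;
    match (acc, next) {
        (Unknown, _) | (_, Unknown) => Unknown,
        (Equal, x) | (x, Equal) => x,
        (LowerEqual, LowerEqual) => LowerEqual,
        (GreaterEqual, GreaterEqual) => GreaterEqual,
        _ => Unknown,
    }
}

fn main() {
    // e.g. a coalescing operator (Equal) over a filter (LowerEqual): the
    // subtree as a whole can only reduce the number of rows.
    let chain = [CardinalityEffect::Equal, CardinalityEffect::LowerEqual];
    let total = chain.into_iter().fold(CardinalityEffect::Equal, combine);
    assert_eq!(total, CardinalityEffect::LowerEqual);
}
```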
fn compute_properties(schema: SchemaRef) -> PlanProperties { - let eq_properties = EquivalenceProperties::new(schema); PlanProperties::new( - eq_properties, + EquivalenceProperties::new(schema), Partitioning::UnknownPartitioning(1), - ExecutionMode::Bounded, + EmissionType::Final, + Boundedness::Bounded, ) } } @@ -112,7 +113,7 @@ impl ExecutionPlan for ExplainExec { } fn children(&self) -> Vec<&Arc> { - // this is a leaf node and has no children + // This is a leaf node and has no children vec![] } @@ -132,7 +133,6 @@ impl ExecutionPlan for ExplainExec { if 0 != partition { return internal_err!("ExplainExec invalid partition {partition}"); } - let mut type_builder = StringBuilder::with_capacity(self.stringified_plans.len(), 1024); let mut plan_builder = diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 8d032aaad7191..7e6c4a98e2340 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -45,15 +45,17 @@ use datafusion_physical_expr::expressions::BinaryExpr; use datafusion_physical_expr::intervals::utils::check_support; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - analyze, split_conjunction, AnalysisContext, ConstExpr, ExprBoundaries, PhysicalExpr, + analyze, split_conjunction, AcrossPartitions, AnalysisContext, ConstExpr, + ExprBoundaries, PhysicalExpr, }; +use crate::execution_plan::CardinalityEffect; use futures::stream::{Stream, StreamExt}; use log::trace; /// FilterExec evaluates a boolean predicate against all input batches to determine which rows to /// include in its output batches. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct FilterExec { /// The expression to filter on. This expression must evaluate to a boolean value. predicate: Arc, @@ -114,7 +116,7 @@ impl FilterExec { /// Return new instance of [FilterExec] with the given projection. 
pub fn with_projection(&self, projection: Option>) -> Result { - // check if the projection is valid + // Check if the projection is valid can_project(&self.schema(), projection.as_ref())?; let projection = match projection { @@ -156,7 +158,7 @@ impl FilterExec { self.default_selectivity } - /// projection + /// Projection pub fn projection(&self) -> Option<&Vec> { self.projection.as_ref() } @@ -217,13 +219,23 @@ impl FilterExec { if binary.op() == &Operator::Eq { // Filter evaluates to single value for all partitions if input_eqs.is_expr_constant(binary.left()) { + let (expr, across_parts) = ( + binary.right(), + input_eqs.get_expr_constant_value(binary.right()), + ); res_constants.push( - ConstExpr::from(binary.right()).with_across_partitions(true), - ) + ConstExpr::new(Arc::clone(expr)) + .with_across_partitions(across_parts), + ); } else if input_eqs.is_expr_constant(binary.right()) { + let (expr, across_parts) = ( + binary.left(), + input_eqs.get_expr_constant_value(binary.left()), + ); res_constants.push( - ConstExpr::from(binary.left()).with_across_partitions(true), - ) + ConstExpr::new(Arc::clone(expr)) + .with_across_partitions(across_parts), + ); } } } @@ -251,12 +263,16 @@ impl FilterExec { .into_iter() .filter(|column| stats.column_statistics[column.index()].is_singleton()) .map(|column| { + let value = stats.column_statistics[column.index()] + .min_value + .get_value(); let expr = Arc::new(column) as _; - ConstExpr::new(expr).with_across_partitions(true) + ConstExpr::new(expr) + .with_across_partitions(AcrossPartitions::Uniform(value.cloned())) }); - // this is for statistics + // This is for statistics eq_properties = eq_properties.with_constants(constants); - // this is for logical constant (for example: a = '1', then a could be marked as a constant) + // This is for logical constant (for example: a = '1', then a could be marked as a constant) // to do: how to deal with multiple situation to represent = (for example c1 between 0 and 0) eq_properties = eq_properties.with_constants(Self::extend_constants(input, predicate)); @@ -271,10 +287,12 @@ impl FilterExec { output_partitioning.project(&projection_mapping, &eq_properties); eq_properties = eq_properties.project(&projection_mapping, out_schema); } + Ok(PlanProperties::new( eq_properties, output_partitioning, - input.execution_mode(), + input.pipeline_behavior(), + input.boundedness(), )) } } @@ -330,7 +348,7 @@ impl ExecutionPlan for FilterExec { } fn maintains_input_order(&self) -> Vec { - // tell optimizer this operator doesn't reorder its input + // Tell optimizer this operator doesn't reorder its input vec![true] } @@ -370,7 +388,16 @@ impl ExecutionPlan for FilterExec { /// The output statistics of a filtering operation can be estimated if the /// predicate's selectivity value can be determined for the incoming data. fn statistics(&self) -> Result { - Self::statistics_helper(&self.input, self.predicate(), self.default_selectivity) + let stats = Self::statistics_helper( + &self.input, + self.predicate(), + self.default_selectivity, + )?; + Ok(stats.project(self.projection.as_ref())) + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual } } @@ -420,7 +447,7 @@ struct FilterExecStream { predicate: Arc, /// The input partition to filter. 
input: SendableRecordBatchStream, - /// runtime metrics recording + /// Runtime metrics recording baseline_metrics: BaselineMetrics, /// The projection indices of the columns in the input schema projection: Option>, @@ -444,7 +471,7 @@ fn filter_and_project( .and_then(|v| v.into_array(batch.num_rows())) .and_then(|array| { Ok(match (as_boolean_array(&array), projection) { - // apply filter array to record batch + // Apply filter array to record batch (Ok(filter_array), None) => filter_record_batch(batch, filter_array)?, (Ok(filter_array), Some(projection)) => { let projected_columns = projection @@ -485,7 +512,7 @@ impl Stream for FilterExecStream { &self.schema, )?; timer.done(); - // skip entirely filtered batches + // Skip entirely filtered batches if filtered_batch.num_rows() == 0 { continue; } @@ -502,7 +529,7 @@ impl Stream for FilterExecStream { } fn size_hint(&self) -> (usize, Option) { - // same number of record batches + // Same number of record batches self.input.size_hint() } } diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 5dc27bc239d26..bfb1e9d53df51 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -23,11 +23,12 @@ use std::fmt::Debug; use std::sync::Arc; use super::{ - execute_input_stream, DisplayAs, DisplayFormatType, ExecutionPlan, - ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, + execute_input_stream, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + PlanProperties, SendableRecordBatchStream, }; use crate::metrics::MetricsSet; use crate::stream::RecordBatchStreamAdapter; +use crate::ExecutionPlanProperties; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; @@ -56,7 +57,12 @@ pub trait DataSink: DisplayAs + Debug + Send + Sync { /// [DataSink]. /// /// See [ExecutionPlan::metrics()] for more details - fn metrics(&self) -> Option; + fn metrics(&self) -> Option { + None + } + + /// Returns the sink schema + fn schema(&self) -> &SchemaRef; // TODO add desired input ordering // How does this sink want its input ordered? @@ -73,19 +79,15 @@ pub trait DataSink: DisplayAs + Debug + Send + Sync { ) -> Result; } -#[deprecated(since = "38.0.0", note = "Use [`DataSinkExec`] instead")] -pub type FileSinkExec = DataSinkExec; - /// Execution plan for writing record batches to a [`DataSink`] /// /// Returns a single row with the number of values written +#[derive(Clone)] pub struct DataSinkExec { /// Input plan that produces the record batches to be written. input: Arc, /// Sink to which to write sink: Arc, - /// Schema of the sink for validating the input data - sink_schema: SchemaRef, /// Schema describing the structure of the output data. count_schema: SchemaRef, /// Optional required sort order for output data. 
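`filter_and_project` above evaluates the predicate to a boolean array and then leans on arrow's `filter_record_batch` kernel, optionally combined with a column projection. A standalone arrow example of those two kernels (illustrative; the real operator evaluates a `PhysicalExpr` instead of building the mask by hand):

```rust
use std::sync::Arc;

use arrow::array::{BooleanArray, Int32Array, StringArray};
use arrow::compute::filter_record_batch;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("name", DataType::Utf8, false),
    ]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ],
    )?;

    // Keep only the "name" column (projection), then keep rows where the mask is true.
    let projected = batch.project(&[1])?;
    let mask = BooleanArray::from(vec![true, false, true]);
    let filtered = filter_record_batch(&projected, &mask)?;
    assert_eq!(filtered.num_rows(), 2);
    assert_eq!(filtered.num_columns(), 1);
    Ok(())
}
```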
@@ -93,7 +95,7 @@ pub struct DataSinkExec { cache: PlanProperties, } -impl fmt::Debug for DataSinkExec { +impl Debug for DataSinkExec { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "DataSinkExec schema: {:?}", self.count_schema) } @@ -104,7 +106,6 @@ impl DataSinkExec { pub fn new( input: Arc, sink: Arc, - sink_schema: SchemaRef, sort_order: Option, ) -> Self { let count_schema = make_count_schema(); @@ -112,7 +113,6 @@ impl DataSinkExec { Self { input, sink, - sink_schema, count_schema: make_count_schema(), sort_order, cache, @@ -142,17 +142,14 @@ impl DataSinkExec { PlanProperties::new( eq_properties, Partitioning::UnknownPartitioning(1), - input.execution_mode(), + input.pipeline_behavior(), + input.boundedness(), ) } } impl DisplayAs for DataSinkExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!(f, "DataSinkExec: sink=")?; @@ -213,7 +210,6 @@ impl ExecutionPlan for DataSinkExec { Ok(Arc::new(Self::new( Arc::clone(&children[0]), Arc::clone(&self.sink), - Arc::clone(&self.sink_schema), self.sort_order.clone(), ))) } @@ -230,7 +226,7 @@ impl ExecutionPlan for DataSinkExec { } let data = execute_input_stream( Arc::clone(&self.input), - Arc::clone(&self.sink_schema), + Arc::clone(self.sink.schema()), 0, Arc::clone(&context), )?; @@ -271,7 +267,7 @@ fn make_count_batch(count: u64) -> RecordBatch { } fn make_count_schema() -> SchemaRef { - // define a schema. + // Define a schema. Arc::new(Schema::new(vec![Field::new( "count", DataType::UInt64, diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index a70645f3d6c0c..69300fce77454 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -19,15 +19,16 @@ //! and producing batches in parallel for the right partitions use super::utils::{ - adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut, + adjust_right_output_partitioning, reorder_output_after_swap, BatchSplitter, + BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut, StatefulStreamResult, }; use crate::coalesce_partitions::CoalescePartitionsExec; +use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ - execution_mode_from_children, handle_state, ColumnStatistics, DisplayAs, - DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan, - ExecutionPlanProperties, PlanProperties, RecordBatchStream, + handle_state, ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, + ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; use arrow::compute::concat_batches; @@ -45,11 +46,31 @@ use datafusion_physical_expr::equivalence::join_equivalence_properties; use async_trait::async_trait; use futures::{ready, Stream, StreamExt, TryStreamExt}; -/// Data of the left side -type JoinLeftData = (RecordBatch, MemoryReservation); +/// Data of the left side that is buffered into memory +#[derive(Debug)] +struct JoinLeftData { + /// Single RecordBatch with all rows from the left side + merged_batch: RecordBatch, + /// Track memory reservation for merged_batch. Relies on drop + /// semantics to release reservation when JoinLeftData is dropped. 
+ _reservation: MemoryReservation, +} -/// executes partitions in parallel and combines them into a set of -/// partitions by combining all values from the left with all values on the right +#[allow(rustdoc::private_intra_doc_links)] +/// Cross Join Execution Plan +/// +/// This operator is used when there are no predicates between two tables and +/// returns the Cartesian product of the two tables. +/// +/// Buffers the left input into memory and then streams batches from each +/// partition on the right input combining them with the buffered left input +/// to generate the output. +/// +/// # Clone / Shared State +/// +/// Note this structure includes a [`OnceAsync`] that is used to coordinate the +/// loading of the left side with the processing in each output stream. +/// Therefore it can not be [`Clone`] #[derive(Debug)] pub struct CrossJoinExec { /// left (build) side which gets loaded in memory @@ -58,10 +79,16 @@ pub struct CrossJoinExec { pub right: Arc, /// The schema once the join is applied schema: SchemaRef, - /// Build-side data + /// Buffered copy of left (build) side in memory. + /// + /// This structure is *shared* across all output streams. + /// + /// Each output stream waits on the `OnceAsync` to signal the completion of + /// the left side loading. left_fut: OnceAsync, /// Execution plan metrics metrics: ExecutionPlanMetricsSet, + /// Properties such as schema, equivalence properties, ordering, partitioning, etc. cache: PlanProperties, } @@ -86,6 +113,7 @@ impl CrossJoinExec { let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata)); let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)); + CrossJoinExec { left, right, @@ -133,14 +161,25 @@ impl CrossJoinExec { left.schema().fields.len(), ); - // Determine the execution mode: - let mut mode = execution_mode_from_children([left, right]); - if mode.is_unbounded() { - // If any of the inputs is unbounded, cross join breaks the pipeline. - mode = ExecutionMode::PipelineBreaking; - } + PlanProperties::new( + eq_properties, + output_partitioning, + EmissionType::Final, + boundedness_from_children([left, right]), + ) + } - PlanProperties::new(eq_properties, output_partitioning, mode) + /// Returns a new `ExecutionPlan` that computes the same join as this one, + /// with the left and right inputs swapped using the specified + /// `partition_mode`. 
+ pub fn swap_inputs(&self) -> Result> { + let new_join = + CrossJoinExec::new(Arc::clone(&self.right), Arc::clone(&self.left)); + reorder_output_after_swap( + Arc::new(new_join), + &self.left.schema(), + &self.right.schema(), + ) } } @@ -162,23 +201,29 @@ async fn load_left_input( // Load all batches and count the rows let (batches, _metrics, reservation) = stream - .try_fold((Vec::new(), metrics, reservation), |mut acc, batch| async { - let batch_size = batch.get_array_memory_size(); - // Reserve memory for incoming batch - acc.2.try_grow(batch_size)?; - // Update metrics - acc.1.build_mem_used.add(batch_size); - acc.1.build_input_batches.add(1); - acc.1.build_input_rows.add(batch.num_rows()); - // Push batch to output - acc.0.push(batch); - Ok(acc) - }) + .try_fold( + (Vec::new(), metrics, reservation), + |(mut batches, metrics, mut reservation), batch| async { + let batch_size = batch.get_array_memory_size(); + // Reserve memory for incoming batch + reservation.try_grow(batch_size)?; + // Update metrics + metrics.build_mem_used.add(batch_size); + metrics.build_input_batches.add(1); + metrics.build_input_rows.add(batch.num_rows()); + // Push batch to output + batches.push(batch); + Ok((batches, metrics, reservation)) + }, + ) .await?; let merged_batch = concat_batches(&left_schema, &batches)?; - Ok((merged_batch, reservation)) + Ok(JoinLeftData { + merged_batch, + _reservation: reservation, + }) } impl DisplayAs for CrossJoinExec { @@ -246,6 +291,10 @@ impl ExecutionPlan for CrossJoinExec { let reservation = MemoryConsumer::new("CrossJoinExec").register(context.memory_pool()); + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let left_fut = self.left_fut.once(|| { load_left_input( Arc::clone(&self.left), @@ -255,15 +304,29 @@ impl ExecutionPlan for CrossJoinExec { ) }); - Ok(Box::pin(CrossJoinStream { - schema: Arc::clone(&self.schema), - left_fut, - right: stream, - left_index: 0, - join_metrics, - state: CrossJoinStreamState::WaitBuildSide, - left_data: RecordBatch::new_empty(self.left().schema()), - })) + if enforce_batch_size_in_joins { + Ok(Box::pin(CrossJoinStream { + schema: Arc::clone(&self.schema), + left_fut, + right: stream, + left_index: 0, + join_metrics, + state: CrossJoinStreamState::WaitBuildSide, + left_data: RecordBatch::new_empty(self.left().schema()), + batch_transformer: BatchSplitter::new(batch_size), + })) + } else { + Ok(Box::pin(CrossJoinStream { + schema: Arc::clone(&self.schema), + left_fut, + right: stream, + left_index: 0, + join_metrics, + state: CrossJoinStreamState::WaitBuildSide, + left_data: RecordBatch::new_empty(self.left().schema()), + batch_transformer: NoopBatchTransformer::new(), + })) + } } fn statistics(&self) -> Result { @@ -319,7 +382,7 @@ fn stats_cartesian_product( } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. 
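The new `enforce_batch_size_in_joins` path wraps the cross join's output in a `BatchSplitter`, so a large cross-product batch is handed out in `batch_size`-row chunks instead of one oversized batch. A standalone sketch of the splitting idea using `RecordBatch::slice` (illustrative only; the real `BatchSplitter` in `joins::utils` is stateful and integrated with the stream):

```rust
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

/// Split one batch into zero-copy slices of at most `batch_size` rows.
fn split(batch: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
    (0..batch.num_rows())
        .step_by(batch_size)
        .map(|offset| {
            let len = batch_size.min(batch.num_rows() - offset);
            batch.slice(offset, len)
        })
        .collect()
}

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Int32Array::from((0..10).collect::<Vec<i32>>()))],
    )
    .unwrap();

    // Ten rows split with a target of four rows per batch: 4 + 4 + 2.
    let parts = split(&batch, 4);
    assert_eq!(
        parts.iter().map(|b| b.num_rows()).collect::<Vec<_>>(),
        vec![4, 4, 2]
    );
}
```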
-struct CrossJoinStream { +struct CrossJoinStream { /// Input schema schema: Arc, /// Future for data from left side @@ -332,11 +395,13 @@ struct CrossJoinStream { join_metrics: BuildProbeJoinMetrics, /// State of the stream state: CrossJoinStreamState, - /// Left data + /// Left data (copy of the entire buffered left side) left_data: RecordBatch, + /// Batch transformer + batch_transformer: T, } -impl RecordBatchStream for CrossJoinStream { +impl RecordBatchStream for CrossJoinStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } @@ -390,24 +455,24 @@ fn build_batch( } #[async_trait] -impl Stream for CrossJoinStream { +impl Stream for CrossJoinStream { type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { self.poll_next_impl(cx) } } -impl CrossJoinStream { +impl CrossJoinStream { /// Separate implementation function that unpins the [`CrossJoinStream`] so /// that partial borrows work correctly fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll>> { + ) -> Poll>> { loop { return match self.state { CrossJoinStreamState::WaitBuildSide => { @@ -430,16 +495,17 @@ impl CrossJoinStream { cx: &mut std::task::Context<'_>, ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); - let (left_data, _) = match ready!(self.left_fut.get(cx)) { + let left_data = match ready!(self.left_fut.get(cx)) { Ok(left_data) => left_data, Err(e) => return Poll::Ready(Err(e)), }; build_timer.done(); + let left_data = left_data.merged_batch.clone(); let result = if left_data.num_rows() == 0 { StatefulStreamResult::Ready(None) } else { - self.left_data = left_data.clone(); + self.left_data = left_data; self.state = CrossJoinStreamState::FetchProbeBatch; StatefulStreamResult::Continue }; @@ -470,21 +536,33 @@ impl CrossJoinStream { fn build_batches(&mut self) -> Result>> { let right_batch = self.state.try_as_record_batch()?; if self.left_index < self.left_data.num_rows() { - let join_timer = self.join_metrics.join_time.timer(); - let result = - build_batch(self.left_index, right_batch, &self.left_data, &self.schema); - join_timer.done(); - - if let Ok(ref batch) = result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); + match self.batch_transformer.next() { + None => { + let join_timer = self.join_metrics.join_time.timer(); + let result = build_batch( + self.left_index, + right_batch, + &self.left_data, + &self.schema, + ); + join_timer.done(); + + self.batch_transformer.set_batch(result?); + } + Some((batch, last)) => { + if last { + self.left_index += 1; + } + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + return Ok(StatefulStreamResult::Ready(Some(batch))); + } } - self.left_index += 1; - result.map(|r| StatefulStreamResult::Ready(Some(r))) } else { self.state = CrossJoinStreamState::FetchProbeBatch; - Ok(StatefulStreamResult::Continue) } + Ok(StatefulStreamResult::Continue) } } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index f6446548ca40e..81f2e9312013b 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -18,34 +18,38 @@ //! 
[`HashJoinExec`] Partitioned Hash Join Operator use std::fmt; +use std::mem::size_of; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; use std::{any::Any, vec}; -use super::utils::asymmetric_join_output_partitioning; +use super::utils::{ + asymmetric_join_output_partitioning, get_final_indices_from_shared_bitmap, + reorder_output_after_swap, swap_join_projection, +}; use super::{ utils::{OnceAsync, OnceFut}, - PartitionMode, + PartitionMode, SharedBitmapBuilder, }; +use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::ExecutionPlanProperties; use crate::{ coalesce_partitions::CoalescePartitionsExec, common::can_project, - execution_mode_from_children, handle_state, + handle_state, hash_utils::create_hashes, joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, check_join_is_valid, - estimate_join_statistics, get_final_indices_from_bit_map, - need_produce_result_in_final, symmetric_join_output_partitioning, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMap, JoinHashMapOffset, - JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult, + estimate_join_statistics, need_produce_result_in_final, + symmetric_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, + JoinFilter, JoinHashMap, JoinHashMapOffset, JoinHashMapType, JoinOn, JoinOnRef, + StatefulStreamResult, }, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - DisplayAs, DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan, - Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, - Statistics, + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; use arrow::array::{ @@ -70,29 +74,31 @@ use datafusion_physical_expr::equivalence::{ }; use datafusion_physical_expr::PhysicalExprRef; +use crate::spill::get_record_batch_memory_size; use ahash::RandomState; use datafusion_expr::Operator; use datafusion_physical_expr_common::datum::compare_op_for_nested; use futures::{ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; -type SharedBitmapBuilder = Mutex; - /// HashTable and input data for the left (build side) of a join struct JoinLeftData { /// The hash table with indices into `batch` hash_map: JoinHashMap, /// The input rows for the build side batch: RecordBatch, + /// The build side on expressions values + values: Vec, /// Shared bitmap builder for visited left indices - visited_indices_bitmap: Mutex, + visited_indices_bitmap: SharedBitmapBuilder, /// Counter of running probe-threads, potentially /// able to update `visited_indices_bitmap` probe_threads_counter: AtomicUsize, - /// Memory reservation that tracks memory used by `hash_map` hash table - /// `batch`. Cleared on drop. - #[allow(dead_code)] - reservation: MemoryReservation, + /// We need to keep this field to maintain accurate memory accounting, even though we don't directly use it. + /// Without holding onto this reservation, the recorded memory usage would become inconsistent with actual usage. + /// This could hide potential out-of-memory issues, especially when upstream operators increase their memory consumption. + /// The MemoryReservation ensures proper tracking of memory resources throughout the join operation's lifecycle. 
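`JoinLeftData` now caches the evaluated join-key arrays (`values`) alongside the buffered build-side batch, so the probe phase does not have to re-evaluate the `on` expressions. A hedged sketch of how such values are typically produced from the `on` expressions; `evaluate_build_keys` is a hypothetical helper, shown only to illustrate the `PhysicalExpr::evaluate` / `ColumnarValue::into_array` pattern the build phase relies on:

```rust
use arrow::array::ArrayRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};

/// Hypothetical helper: evaluate each build-side key expression once against
/// the concatenated build batch, yielding one `ArrayRef` per join key.
fn evaluate_build_keys(
    on_left: &[PhysicalExprRef],
    batch: &RecordBatch,
) -> Result<Vec<ArrayRef>> {
    on_left
        .iter()
        .map(|expr| expr.evaluate(batch)?.into_array(batch.num_rows()))
        .collect()
}
```

Caching these arrays once at build time trades a little extra memory (already tracked by the held `MemoryReservation`) for avoiding repeated expression evaluation on every probe batch.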
+ _reservation: MemoryReservation, } impl JoinLeftData { @@ -100,6 +106,7 @@ impl JoinLeftData { fn new( hash_map: JoinHashMap, batch: RecordBatch, + values: Vec, visited_indices_bitmap: SharedBitmapBuilder, probe_threads_counter: AtomicUsize, reservation: MemoryReservation, @@ -107,9 +114,10 @@ impl JoinLeftData { Self { hash_map, batch, + values, visited_indices_bitmap, probe_threads_counter, - reservation, + _reservation: reservation, } } @@ -123,6 +131,11 @@ impl JoinLeftData { &self.batch } + /// returns a reference to the build side expressions values + fn values(&self) -> &[ArrayRef] { + &self.values + } + /// returns a reference to the visited indices bitmap fn visited_indices_bitmap(&self) -> &SharedBitmapBuilder { &self.visited_indices_bitmap @@ -135,13 +148,14 @@ impl JoinLeftData { } } -/// Join execution plan: Evaluates eqijoin predicates in parallel on multiple +#[allow(rustdoc::private_intra_doc_links)] +/// Join execution plan: Evaluates equijoin predicates in parallel on multiple /// partitions using a hash table and an optional filter list to apply post /// join. /// /// # Join Expressions /// -/// This implementation is optimized for evaluating eqijoin predicates ( +/// This implementation is optimized for evaluating equijoin predicates ( /// ` = `) expressions, which are represented as a list of `Columns` /// in [`Self::on`]. /// @@ -195,7 +209,7 @@ impl JoinLeftData { /// /// Original build-side data Inserting build-side values into hashmap Concatenated build-side batch /// ┌───────────────────────────┐ -/// hasmap.insert(row-hash, row-idx + offset) │ idx │ +/// hashmap.insert(row-hash, row-idx + offset) │ idx │ /// ┌───────┐ │ ┌───────┐ │ /// │ Row 1 │ 1) update_hash for batch 3 with offset 0 │ │ Row 6 │ 0 │ /// Batch 1 │ │ - hashmap.insert(Row 7, idx 1) │ Batch 3 │ │ │ @@ -292,6 +306,12 @@ impl JoinLeftData { /// │ "dimension" │ │ "fact" │ /// └───────────────┘ └───────────────┘ /// ``` +/// +/// # Clone / Shared State +/// +/// Note this structure includes a [`OnceAsync`] that is used to coordinate the +/// loading of the left side with the processing in each output stream. +/// Therefore it can not be [`Clone`] #[derive(Debug)] pub struct HashJoinExec { /// left (build) side which gets hashed @@ -308,6 +328,11 @@ pub struct HashJoinExec { /// if there is a projection, the schema isn't the same as the output schema. join_schema: SchemaRef, /// Future that consumes left input and builds the hash table + /// + /// For CollectLeft partition mode, this structure is *shared* across all output streams. + /// + /// Each output stream waits on the `OnceAsync` to signal the completion of + /// the hash table creation. left_fut: OnceAsync, /// Shared the `RandomState` for the hashing algorithm random_state: RandomState, @@ -415,6 +440,12 @@ impl HashJoinExec { &self.join_type } + /// The schema after join. Please be careful when using this schema, + /// if there is a projection, the schema isn't the same as the output schema. + pub fn join_schema(&self) -> &SchemaRef { + &self.join_schema + } + /// The partitioning mode of this hash join pub fn partition_mode(&self) -> &PartitionMode { &self.mode @@ -446,7 +477,7 @@ impl HashJoinExec { } /// Return whether the join contains a projection - pub fn contain_projection(&self) -> bool { + pub fn contains_projection(&self) -> bool { self.projection.is_some() } @@ -506,23 +537,26 @@ impl HashJoinExec { } }; - // Determine execution mode by checking whether this join is pipeline - // breaking. 
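The build/probe picture above reduces to a multimap from key hash to build-side row index (shifted by each batch's offset), plus an equality re-check at probe time because distinct keys can share a hash, which is what `equal_rows_arr` does on the Arrow arrays. Below is a compact sketch of that flow on plain integer keys; DataFusion's `JoinHashMap` uses a more specialized chained layout, so this is only a conceptual model.

```rust
// Simplified model of the build/probe flow: a multimap keyed by the row hash,
// with candidate matches re-checked against the actual key values to discard
// hash collisions.
use std::collections::HashMap;

fn toy_hash(key: i64) -> u64 {
    // Deliberately weak hash so collisions are easy to demonstrate.
    key.rem_euclid(4) as u64
}

fn build(batches: &[&[i64]]) -> (HashMap<u64, Vec<usize>>, Vec<i64>) {
    let mut map: HashMap<u64, Vec<usize>> = HashMap::new();
    let mut concatenated = Vec::new();
    let mut offset = 0;
    for batch in batches {
        for (i, key) in batch.iter().enumerate() {
            // row index within the batch, shifted by the running offset
            map.entry(toy_hash(*key)).or_default().push(offset + i);
        }
        concatenated.extend_from_slice(batch);
        offset += batch.len();
    }
    (map, concatenated)
}

fn probe(map: &HashMap<u64, Vec<usize>>, build_keys: &[i64], probe_key: i64) -> Vec<usize> {
    map.get(&toy_hash(probe_key))
        .into_iter()
        .flatten()
        .copied()
        // collision check: keep only candidates whose key really matches
        .filter(|&idx| build_keys[idx] == probe_key)
        .collect()
}

fn main() {
    let batch1: &[i64] = &[10, 7];
    let batch2: &[i64] = &[3, 6];
    let (map, build_keys) = build(&[batch1, batch2]);
    // Key 3 collides with key 7 under toy_hash; the equality re-check keeps
    // only the genuine match at concatenated index 2.
    assert_eq!(probe(&map, &build_keys, 3), vec![2]);
    assert_eq!(probe(&map, &build_keys, 42), Vec::<usize>::new());
}
```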
This happens when the left side is unbounded, or the right - // side is unbounded with `Left`, `Full`, `LeftAnti` or `LeftSemi` join types. - let pipeline_breaking = left.execution_mode().is_unbounded() - || (right.execution_mode().is_unbounded() - && matches!( - join_type, - JoinType::Left - | JoinType::Full - | JoinType::LeftAnti - | JoinType::LeftSemi - )); - - let mode = if pipeline_breaking { - ExecutionMode::PipelineBreaking + let emission_type = if left.boundedness().is_unbounded() { + EmissionType::Final + } else if right.pipeline_behavior() == EmissionType::Incremental { + match join_type { + // If we only need to generate matched rows from the probe side, + // we can emit rows incrementally. + JoinType::Inner + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::Right + | JoinType::RightAnti => EmissionType::Incremental, + // If we need to generate unmatched rows from the *build side*, + // we need to emit them at the end. + JoinType::Left + | JoinType::LeftAnti + | JoinType::LeftMark + | JoinType::Full => EmissionType::Both, + } } else { - execution_mode_from_children([left, right]) + right.pipeline_behavior() }; // If contains projection, update the PlanProperties. @@ -535,12 +569,61 @@ impl HashJoinExec { output_partitioning.project(&projection_mapping, &eq_properties); eq_properties = eq_properties.project(&projection_mapping, out_schema); } + Ok(PlanProperties::new( eq_properties, output_partitioning, - mode, + emission_type, + boundedness_from_children([left, right]), )) } + + /// Returns a new `ExecutionPlan` that computes the same join as this one, + /// with the left and right inputs swapped using the specified + /// `partition_mode`. + /// + /// # Notes: + /// + /// This function is public so other downstream projects can use it to + /// construct `HashJoinExec` with right side as the build side. 
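The emission-type selection above reads as a small decision table: an unbounded build side forces all output to the end, probe-side-only join types stay incremental, and join types that must also emit unmatched build-side rows produce both incremental and final output. Here is a condensed sketch of that table, using hypothetical `Emission`/`Join` enums rather than DataFusion's `EmissionType`/`JoinType`.

```rust
// Hypothetical types standing in for DataFusion's EmissionType / JoinType;
// the mapping mirrors the match expressions in compute_properties above.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Emission {
    Incremental, // rows can be produced as probe batches arrive
    Final,       // rows are only produced once all input is seen
    Both,        // incremental output plus a final pass (unmatched build rows)
}

#[derive(Clone, Copy)]
enum Join {
    Inner,
    Left,
    Right,
    Full,
    LeftSemi,
    LeftAnti,
    LeftMark,
    RightSemi,
    RightAnti,
}

fn emission(left_unbounded: bool, right_emission: Emission, join: Join) -> Emission {
    if left_unbounded {
        // The whole build side must be consumed before any output.
        Emission::Final
    } else if right_emission == Emission::Incremental {
        match join {
            // Only matched probe-side rows are emitted: stay incremental.
            Join::Inner | Join::LeftSemi | Join::RightSemi | Join::Right | Join::RightAnti => {
                Emission::Incremental
            }
            // Unmatched build-side rows are emitted after the probe side ends.
            Join::Left | Join::LeftAnti | Join::LeftMark | Join::Full => Emission::Both,
        }
    } else {
        // Otherwise inherit the probe side's behavior.
        right_emission
    }
}

fn main() {
    assert_eq!(emission(false, Emission::Incremental, Join::Inner), Emission::Incremental);
    assert_eq!(emission(false, Emission::Incremental, Join::Full), Emission::Both);
    assert_eq!(emission(true, Emission::Incremental, Join::Inner), Emission::Final);
}
```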
+ pub fn swap_inputs( + &self, + partition_mode: PartitionMode, + ) -> Result> { + let left = self.left(); + let right = self.right(); + let new_join = HashJoinExec::try_new( + Arc::clone(right), + Arc::clone(left), + self.on() + .iter() + .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) + .collect(), + self.filter().map(JoinFilter::swap), + &self.join_type().swap(), + swap_join_projection( + left.schema().fields().len(), + right.schema().fields().len(), + self.projection.as_ref(), + self.join_type(), + ), + partition_mode, + self.null_equals_null(), + )?; + // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again + if matches!( + self.join_type(), + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + ) || self.projection.is_some() + { + Ok(Arc::new(new_join)) + } else { + reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema()) + } + } } impl DisplayAs for HashJoinExec { @@ -551,7 +634,7 @@ impl DisplayAs for HashJoinExec { || "".to_string(), |f| format!(", filter={}", f.expression()), ); - let display_projections = if self.contain_projection() { + let display_projections = if self.contains_projection() { format!( ", projection=[{}]", self.projection @@ -747,7 +830,6 @@ impl ExecutionPlan for HashJoinExec { Ok(Box::pin(HashJoinStream { schema: self.schema(), - on_left, on_right, filter: self.filter.clone(), join_type: self.join_type, @@ -772,7 +854,7 @@ impl ExecutionPlan for HashJoinExec { // TODO stats: it is not possible in general to know the output size of joins // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` - let mut stats = estimate_join_statistics( + let stats = estimate_join_statistics( Arc::clone(&self.left), Arc::clone(&self.right), self.on.clone(), @@ -780,16 +862,7 @@ impl ExecutionPlan for HashJoinExec { &self.join_schema, )?; // Project statistics if there is a projection - if let Some(projection) = &self.projection { - stats.column_statistics = stats - .column_statistics - .into_iter() - .enumerate() - .filter(|(i, _)| projection.contains(i)) - .map(|(_, s)| s) - .collect(); - } - Ok(stats) + Ok(stats.project(self.projection.as_ref())) } } @@ -826,14 +899,14 @@ async fn collect_left_input( let initial = (Vec::new(), 0, metrics, reservation); let (batches, num_rows, metrics, mut reservation) = stream .try_fold(initial, |mut acc, batch| async { - let batch_size = batch.get_array_memory_size(); + let batch_size = get_record_batch_memory_size(&batch); // Reserve memory for incoming batch acc.3.try_grow(batch_size)?; // Update metrics acc.2.build_mem_used.add(batch_size); acc.2.build_input_batches.add(1); acc.2.build_input_rows.add(batch.num_rows()); - // Update rowcount + // Update row count acc.1 += batch.num_rows(); // Push batch to output acc.0.push(batch); @@ -843,7 +916,7 @@ async fn collect_left_input( // Estimation of memory size, required for hashtable, prior to allocation. // Final result can be verified using `RawTable.allocation_info()` - let fixed_size = std::mem::size_of::(); + let fixed_size = size_of::(); let estimated_hashtable_size = estimate_memory_size::<(u64, u64)>(num_rows, fixed_size)?; @@ -887,9 +960,18 @@ async fn collect_left_input( BooleanBufferBuilder::new(0) }; + let left_values = on_left + .iter() + .map(|c| { + c.evaluate(&single_batch)? 
+ .into_array(single_batch.num_rows()) + }) + .collect::>>()?; + let data = JoinLeftData::new( hashmap, single_batch, + left_values, Mutex::new(visited_indices_bitmap), AtomicUsize::new(probe_threads_count), reservation, @@ -1009,6 +1091,7 @@ impl BuildSide { /// └─ ProcessProbeBatch /// /// ``` +#[derive(Debug, Clone)] enum HashJoinStreamState { /// Initial state for HashJoinStream indicating that build-side data not collected yet WaitBuildSide, @@ -1034,9 +1117,12 @@ impl HashJoinStreamState { } /// Container for HashJoinStreamState::ProcessProbeBatch related data +#[derive(Debug, Clone)] struct ProcessProbeBatchState { /// Current probe-side batch batch: RecordBatch, + /// Probe-side on expressions values + values: Vec, /// Starting offset for JoinHashMap lookups offset: JoinHashMapOffset, /// Max joined probe-side index from current batch @@ -1063,8 +1149,6 @@ impl ProcessProbeBatchState { struct HashJoinStream { /// Input schema schema: Arc, - /// equijoin columns from the left (build side) - on_left: Vec, /// equijoin columns from the right (probe side) on_right: Vec, /// optional join filter @@ -1150,27 +1234,13 @@ impl RecordBatchStream for HashJoinStream { #[allow(clippy::too_many_arguments)] fn lookup_join_hashmap( build_hashmap: &JoinHashMap, - build_input_buffer: &RecordBatch, - probe_batch: &RecordBatch, - build_on: &[PhysicalExprRef], - probe_on: &[PhysicalExprRef], + build_side_values: &[ArrayRef], + probe_side_values: &[ArrayRef], null_equals_null: bool, hashes_buffer: &[u64], limit: usize, offset: JoinHashMapOffset, ) -> Result<(UInt64Array, UInt32Array, Option)> { - let keys_values = probe_on - .iter() - .map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())) - .collect::>>()?; - let build_join_values = build_on - .iter() - .map(|c| { - c.evaluate(build_input_buffer)? 
- .into_array(build_input_buffer.num_rows()) - }) - .collect::>>()?; - let (probe_indices, build_indices, next_offset) = build_hashmap .get_matched_indices_with_limit_offset(hashes_buffer, None, limit, offset); @@ -1180,8 +1250,8 @@ fn lookup_join_hashmap( let (build_indices, probe_indices) = equal_rows_arr( &build_indices, &probe_indices, - &build_join_values, - &keys_values, + build_side_values, + probe_side_values, null_equals_null, )?; @@ -1253,14 +1323,6 @@ pub fn equal_rows_arr( )) } -fn get_final_indices_from_shared_bitmap( - shared_bitmap: &SharedBitmapBuilder, - join_type: JoinType, -) -> (UInt64Array, UInt32Array) { - let bitmap = shared_bitmap.lock(); - get_final_indices_from_bit_map(&bitmap, join_type) -} - impl HashJoinStream { /// Separate implementation function that unpins the [`HashJoinStream`] so /// that partial borrows work correctly @@ -1339,6 +1401,7 @@ impl HashJoinStream { self.state = HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { batch, + values: keys_values, offset: (0, None), joined_probe_idx: None, }); @@ -1363,10 +1426,8 @@ impl HashJoinStream { // get the matched by join keys indices let (left_indices, right_indices, next_offset) = lookup_join_hashmap( build_side.left_data.hash_map(), - build_side.left_data.batch(), - &state.batch, - &self.on_left, - &self.on_right, + build_side.left_data.values(), + &state.values, self.null_equals_null, &self.hashes_buffer, self.batch_size, @@ -1432,7 +1493,7 @@ impl HashJoinStream { index_alignment_range_start..index_alignment_range_end, self.join_type, self.right_side_ordered, - ); + )?; let result = build_batch_from_indices( &self.schema, @@ -1518,7 +1579,7 @@ impl Stream for HashJoinStream { fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { self.poll_next_impl(cx) } } @@ -1550,7 +1611,7 @@ mod tests { use rstest_reuse::*; fn div_ceil(a: usize, b: usize) -> usize { - (a + b - 1) / b + a.div_ceil(b) } #[template] @@ -3084,6 +3145,94 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] + #[tokio::test] + async fn join_left_mark(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + )]; + + let (columns, batches) = join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::LeftMark, + false, + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); + + let expected = [ + "+----+----+----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+----+-------+", + "| 1 | 4 | 7 | true |", + "| 2 | 5 | 8 | true |", + "| 3 | 7 | 9 | false |", + "+----+----+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[apply(batch_sizes)] + #[tokio::test] + async fn partitioned_join_left_mark(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 4, 5, 6]), + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (columns, batches) = partitioned_join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::LeftMark, + false, + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); + + let expected = [ + "+----+----+----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+----+-------+", + "| 1 | 4 | 7 | true |", + "| 2 | 5 | 8 | true |", + "| 3 | 7 | 9 | false |", + "+----+----+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + #[test] fn join_with_hash_collision() -> Result<()> { let mut hashmap_left = RawTable::with_capacity(2); @@ -3118,17 +3267,20 @@ mod tests { let join_hash_map = JoinHashMap::new(hashmap_left, next); + let left_keys_values = key_column.evaluate(&left)?.into_array(left.num_rows())?; let right_keys_values = key_column.evaluate(&right)?.into_array(right.num_rows())?; let mut hashes_buffer = vec![0; right.num_rows()]; - create_hashes(&[right_keys_values], &random_state, &mut hashes_buffer)?; + create_hashes( + &[Arc::clone(&right_keys_values)], + &random_state, + &mut hashes_buffer, + )?; let (l, r, _) = lookup_join_hashmap( &join_hash_map, - &left, - &right, - &[Arc::clone(&key_column)], - &[key_column], + &[left_keys_values], + &[right_keys_values], false, &hashes_buffer, 8192, @@ -3384,7 +3536,7 @@ mod tests { Ok(()) } - /// Test for parallelised HashJoinExec with PartitionMode::CollectLeft + /// Test for parallelized HashJoinExec with PartitionMode::CollectLeft #[tokio::test] async fn test_collect_left_multiple_partitions_join() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -3469,6 +3621,15 @@ mod tests { "| 30 | 6 | 90 |", "+----+----+----+", ]; + let expected_left_mark = vec![ + "+----+----+----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+----+-------+", + "| 1 | 4 | 7 | true |", + "| 2 | 5 | 8 | true |", + "| 3 | 7 | 9 | false |", + "+----+----+----+-------+", + ]; let test_cases = vec![ (JoinType::Inner, expected_inner), @@ -3479,6 +3640,7 @@ mod tests { (JoinType::LeftAnti, expected_left_anti), (JoinType::RightSemi, expected_right_semi), (JoinType::RightAnti, expected_right_anti), + (JoinType::LeftMark, expected_left_mark), ]; for (join_type, expected) in test_cases { @@ -3588,10 +3750,7 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); // Expect that an error is returned - let result_string = crate::common::collect(stream) - .await - .unwrap_err() - 
.to_string(); + let result_string = common::collect(stream).await.unwrap_err().to_string(); assert!( result_string.contains("bad data error"), "actual: {result_string}" @@ -3764,6 +3923,7 @@ mod tests { JoinType::LeftAnti, JoinType::RightSemi, JoinType::RightAnti, + JoinType::LeftMark, ]; for join_type in join_types { @@ -3789,6 +3949,11 @@ mod tests { err.to_string(), "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput" ); + + assert_contains!( + err.to_string(), + "Failed to allocate additional 120 bytes for HashJoinInput" + ); } Ok(()) @@ -3870,6 +4035,11 @@ mod tests { "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput[1]" ); + + assert_contains!( + err.to_string(), + "Failed to allocate additional 120 bytes for HashJoinInput[1]" + ); } Ok(()) diff --git a/datafusion/physical-plan/src/joins/join_filter.rs b/datafusion/physical-plan/src/joins/join_filter.rs new file mode 100644 index 0000000000000..b99afd87c92ac --- /dev/null +++ b/datafusion/physical-plan/src/joins/join_filter.rs @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::joins::utils::ColumnIndex; +use arrow_schema::Schema; +use datafusion_common::JoinSide; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +/// Filter applied before join output. Fields are crate-public to allow +/// downstream implementations to experiment with custom joins. 
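Before the `JoinFilter` definition that follows, here is a sketch of the bookkeeping it encapsulates: each column index records which side of the join and which column feeds the filter's intermediate batch, and swapping the join inputs only flips the side of every entry. This mirrors the `build_column_indices` and `swap` methods below, with simplified stand-in types in place of `ColumnIndex`/`JoinSide`.

```rust
// Simplified stand-ins for ColumnIndex / JoinSide, for illustration only.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Side {
    Left,
    Right,
}

impl Side {
    fn negate(self) -> Side {
        match self {
            Side::Left => Side::Right,
            Side::Right => Side::Left,
        }
    }
}

#[derive(Clone, Copy, Debug, PartialEq)]
struct ColIdx {
    index: usize,
    side: Side,
}

/// Interleaves left and right column positions into one list, as the
/// build_column_indices helper below does.
fn build_column_indices(left: Vec<usize>, right: Vec<usize>) -> Vec<ColIdx> {
    left.into_iter()
        .map(|index| ColIdx { index, side: Side::Left })
        .chain(right.into_iter().map(|index| ColIdx { index, side: Side::Right }))
        .collect()
}

/// When the join inputs are swapped, the filter expression and intermediate
/// schema stay the same; only the side of every column flips.
fn swap(indices: &[ColIdx]) -> Vec<ColIdx> {
    indices
        .iter()
        .map(|c| ColIdx { index: c.index, side: c.side.negate() })
        .collect()
}

fn main() {
    let idx = build_column_indices(vec![1], vec![0, 2]);
    let swapped = swap(&idx);
    assert_eq!(swapped[0], ColIdx { index: 1, side: Side::Right });
}
```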
+#[derive(Debug, Clone)] +pub struct JoinFilter { + /// Filter expression + pub(crate) expression: Arc, + /// Column indices required to construct intermediate batch for filtering + pub(crate) column_indices: Vec, + /// Physical schema of intermediate batch + pub(crate) schema: Schema, +} + +impl JoinFilter { + /// Creates new JoinFilter + pub fn new( + expression: Arc, + column_indices: Vec, + schema: Schema, + ) -> JoinFilter { + JoinFilter { + expression, + column_indices, + schema, + } + } + + /// Helper for building ColumnIndex vector from left and right indices + pub fn build_column_indices( + left_indices: Vec, + right_indices: Vec, + ) -> Vec { + left_indices + .into_iter() + .map(|i| ColumnIndex { + index: i, + side: JoinSide::Left, + }) + .chain(right_indices.into_iter().map(|i| ColumnIndex { + index: i, + side: JoinSide::Right, + })) + .collect() + } + + /// Filter expression + pub fn expression(&self) -> &Arc { + &self.expression + } + + /// Column indices for intermediate batch creation + pub fn column_indices(&self) -> &[ColumnIndex] { + &self.column_indices + } + + /// Intermediate batch schema + pub fn schema(&self) -> &Schema { + &self.schema + } + + /// Rewrites the join filter if the inputs to the join are rewritten + pub fn swap(&self) -> JoinFilter { + let column_indices = self + .column_indices() + .iter() + .map(|idx| ColumnIndex { + index: idx.index, + side: idx.side.negate(), + }) + .collect(); + + JoinFilter::new( + Arc::clone(self.expression()), + column_indices, + self.schema().clone(), + ) + } +} diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 6ddf19c511933..bfdeb2fd6e27b 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -17,9 +17,11 @@ //! DataFusion Join implementations +use arrow_buffer::BooleanBufferBuilder; pub use cross_join::CrossJoinExec; pub use hash_join::HashJoinExec; pub use nested_loop_join::NestedLoopJoinExec; +use parking_lot::Mutex; // Note: SortMergeJoin is not used in plans yet pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; @@ -31,18 +33,20 @@ mod stream_join_utils; mod symmetric_hash_join; pub mod utils; +mod join_filter; #[cfg(test)] pub mod test_utils; #[derive(Clone, Copy, Debug, PartialEq, Eq)] -/// Partitioning mode to use for hash join +/// Hash join Partitioning mode pub enum PartitionMode { /// Left/right children are partitioned using the left and right keys Partitioned, /// Left side will collected into one partition CollectLeft, - /// When set to Auto, DataFusion optimizer will decide which PartitionMode mode(Partitioned/CollectLeft) is optimal based on statistics. - /// It will also consider swapping the left and right inputs for the Join + /// DataFusion optimizer decides which PartitionMode + /// mode(Partitioned/CollectLeft) is optimal based on statistics. 
It will + /// also consider swapping the left and right inputs for the Join Auto, } @@ -54,3 +58,6 @@ pub enum StreamJoinPartitionMode { /// Both sides will collected into one partition SinglePartition, } + +/// Shared bitmap for visited left-side indices +type SharedBitmapBuilder = Mutex; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index ebc8a7338d9a7..c97ca6dd2db67 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Defines the nested loop join plan, it supports all [`JoinType`]. -//! The nested loop join can execute in parallel by partitions and it is -//! determined by the [`JoinType`]. +//! [`NestedLoopJoinExec`]: joins without equijoin (equality predicates). use std::any::Any; use std::fmt::Formatter; @@ -25,37 +23,44 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; -use super::utils::{asymmetric_join_output_partitioning, need_produce_result_in_final}; +use super::utils::{ + asymmetric_join_output_partitioning, get_final_indices_from_shared_bitmap, + need_produce_result_in_final, reorder_output_after_swap, swap_join_projection, + BatchSplitter, BatchTransformer, NoopBatchTransformer, StatefulStreamResult, +}; use crate::coalesce_partitions::CoalescePartitionsExec; +use crate::common::can_project; +use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, check_join_is_valid, estimate_join_statistics, - get_final_indices_from_bit_map, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, - OnceAsync, OnceFut, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceAsync, OnceFut, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ - execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution, - ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, + handle_state, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, + ExecutionPlanProperties, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, }; use arrow::array::{BooleanBufferBuilder, UInt32Array, UInt64Array}; use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow::util::bit_util; -use datafusion_common::{exec_datafusion_err, JoinSide, Result, Statistics}; +use datafusion_common::{ + exec_datafusion_err, internal_err, project_schema, JoinSide, Result, Statistics, +}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_expr::JoinType; -use datafusion_physical_expr::equivalence::join_equivalence_properties; +use datafusion_physical_expr::equivalence::{ + join_equivalence_properties, ProjectionMapping, +}; +use crate::joins::SharedBitmapBuilder; use futures::{ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; -/// Shared bitmap for visited left-side indices -type SharedBitmapBuilder = Mutex; /// Left (build-side) data struct JoinLeftData { /// Build-side data collected to single batch @@ -66,8 +71,7 @@ struct JoinLeftData { probe_threads_counter: AtomicUsize, /// Memory reservation for tracking batch and bitmap /// Cleared on 
`JoinLeftData` drop - #[allow(dead_code)] - reservation: MemoryReservation, + _reservation: MemoryReservation, } impl JoinLeftData { @@ -75,13 +79,13 @@ impl JoinLeftData { batch: RecordBatch, bitmap: SharedBitmapBuilder, probe_threads_counter: AtomicUsize, - reservation: MemoryReservation, + _reservation: MemoryReservation, ) -> Self { Self { batch, bitmap, probe_threads_counter, - reservation, + _reservation, } } @@ -100,6 +104,7 @@ impl JoinLeftData { } } +#[allow(rustdoc::private_intra_doc_links)] /// NestedLoopJoinExec is build-probe join operator, whose main task is to /// perform joins without any equijoin conditions in `ON` clause. /// @@ -135,6 +140,11 @@ impl JoinLeftData { /// "reports" about probe phase completion (which means that "visited" bitmap won't be /// updated anymore), and only the last thread, reporting about completion, will return output. /// +/// # Clone / Shared State +/// +/// Note this structure includes a [`OnceAsync`] that is used to coordinate the +/// loading of the left side with the processing in each output stream. +/// Therefore it can not be [`Clone`] #[derive(Debug)] pub struct NestedLoopJoinExec { /// left side @@ -146,11 +156,19 @@ pub struct NestedLoopJoinExec { /// How the join is performed pub(crate) join_type: JoinType, /// The schema once the join is applied - schema: SchemaRef, - /// Build-side data + join_schema: SchemaRef, + /// Future that consumes left input and buffers it in memory + /// + /// This structure is *shared* across all output streams. + /// + /// Each output stream waits on the `OnceAsync` to signal the completion of + /// the hash table creation. inner_table: OnceAsync, /// Information of index and left / right placement of columns column_indices: Vec, + /// Projection to apply to the output of the join + projection: Option>, + /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. @@ -164,24 +182,31 @@ impl NestedLoopJoinExec { right: Arc, filter: Option, join_type: &JoinType, + projection: Option>, ) -> Result { let left_schema = left.schema(); let right_schema = right.schema(); check_join_is_valid(&left_schema, &right_schema, &[])?; - let (schema, column_indices) = + let (join_schema, column_indices) = build_join_schema(&left_schema, &right_schema, join_type); - let schema = Arc::new(schema); - let cache = - Self::compute_properties(&left, &right, Arc::clone(&schema), *join_type); + let join_schema = Arc::new(join_schema); + let cache = Self::compute_properties( + &left, + &right, + Arc::clone(&join_schema), + *join_type, + projection.as_ref(), + )?; Ok(NestedLoopJoinExec { left, right, filter, join_type: *join_type, - schema, + join_schema, inner_table: Default::default(), column_indices, + projection, metrics: Default::default(), cache, }) @@ -207,35 +232,71 @@ impl NestedLoopJoinExec { &self.join_type } + pub fn projection(&self) -> Option<&Vec> { + self.projection.as_ref() + } + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
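The "only the last thread, reporting about completion, will return output" coordination described above comes down to an atomic countdown shared by all output partitions. The following is a hedged sketch of that pattern; the method name mirrors `report_probe_completed`, but the exact implementation and ordering semantics here are assumptions.

```rust
// Illustrative countdown used to decide which output partition emits the
// unmatched build-side rows.
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

struct BuildSide {
    probe_threads_counter: AtomicUsize,
}

impl BuildSide {
    fn new(probe_threads: usize) -> Self {
        Self {
            probe_threads_counter: AtomicUsize::new(probe_threads),
        }
    }

    /// Returns true only for the last partition to finish probing; that
    /// partition is responsible for scanning the shared visited bitmap and
    /// emitting unmatched build-side rows.
    fn report_probe_completed(&self) -> bool {
        self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
    }
}

fn main() {
    let shared = Arc::new(BuildSide::new(3));
    let finished: Vec<bool> = (0..3).map(|_| shared.report_probe_completed()).collect();
    // Exactly one caller observes `true`.
    assert_eq!(finished.iter().filter(|f| **f).count(), 1);
}
```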
fn compute_properties( left: &Arc, right: &Arc, schema: SchemaRef, join_type: JoinType, - ) -> PlanProperties { + projection: Option<&Vec>, + ) -> Result { // Calculate equivalence properties: - let eq_properties = join_equivalence_properties( + let mut eq_properties = join_equivalence_properties( left.equivalence_properties().clone(), right.equivalence_properties().clone(), &join_type, - schema, + Arc::clone(&schema), &Self::maintains_input_order(join_type), None, // No on columns in nested loop join &[], ); - let output_partitioning = + let mut output_partitioning = asymmetric_join_output_partitioning(left, right, &join_type); - // Determine execution mode: - let mut mode = execution_mode_from_children([left, right]); - if mode.is_unbounded() { - mode = ExecutionMode::PipelineBreaking; + let emission_type = if left.boundedness().is_unbounded() { + EmissionType::Final + } else if right.pipeline_behavior() == EmissionType::Incremental { + match join_type { + // If we only need to generate matched rows from the probe side, + // we can emit rows incrementally. + JoinType::Inner + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::Right + | JoinType::RightAnti => EmissionType::Incremental, + // If we need to generate unmatched rows from the *build side*, + // we need to emit them at the end. + JoinType::Left + | JoinType::LeftAnti + | JoinType::LeftMark + | JoinType::Full => EmissionType::Both, + } + } else { + right.pipeline_behavior() + }; + + if let Some(projection) = projection { + // construct a map from the input expressions to the output expression of the Projection + let projection_mapping = + ProjectionMapping::from_indices(projection, &schema)?; + let out_schema = project_schema(&schema, Some(projection))?; + output_partitioning = + output_partitioning.project(&projection_mapping, &eq_properties); + eq_properties = eq_properties.project(&projection_mapping, out_schema); } - PlanProperties::new(eq_properties, output_partitioning, mode) + Ok(PlanProperties::new( + eq_properties, + output_partitioning, + emission_type, + boundedness_from_children([left, right]), + )) } /// Returns a vector indicating whether the left and right inputs maintain their order. @@ -262,6 +323,69 @@ impl NestedLoopJoinExec { ), ] } + + pub fn contains_projection(&self) -> bool { + self.projection.is_some() + } + + pub fn with_projection(&self, projection: Option>) -> Result { + // check if the projection is valid + can_project(&self.schema(), projection.as_ref())?; + let projection = match projection { + Some(projection) => match &self.projection { + Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), + None => Some(projection), + }, + None => None, + }; + Self::try_new( + Arc::clone(&self.left), + Arc::clone(&self.right), + self.filter.clone(), + &self.join_type, + projection, + ) + } + + /// Returns a new `ExecutionPlan` that runs NestedLoopsJoins with the left + /// and right inputs swapped. 
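`with_projection` above has to compose a new projection with one that may already be embedded in the operator: the incoming indices refer to the already-projected output, so they are translated back through the existing projection. A minimal sketch of that index composition on plain `Vec<usize>`:

```rust
/// Composes a new projection (expressed against the already-projected output)
/// with an optional existing projection against the join schema. Mirrors the
/// index mapping in `with_projection`, simplified to plain index vectors.
fn compose_projection(
    existing: Option<&Vec<usize>>,
    new: Option<Vec<usize>>,
) -> Option<Vec<usize>> {
    match new {
        Some(new) => match existing {
            // `i` indexes the projected schema, so translate it back to the
            // join schema through the existing projection.
            Some(p) => Some(new.iter().map(|i| p[*i]).collect()),
            None => Some(new),
        },
        None => None,
    }
}

fn main() {
    // Join schema has 5 columns; the operator already projects [4, 2, 0].
    // Asking for columns [2, 0] of that output selects join columns [0, 4].
    assert_eq!(
        compose_projection(Some(&vec![4, 2, 0]), Some(vec![2, 0])),
        Some(vec![0, 4])
    );
}
```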
+ pub fn swap_inputs(&self) -> Result> { + let left = self.left(); + let right = self.right(); + let new_join = NestedLoopJoinExec::try_new( + Arc::clone(right), + Arc::clone(left), + self.filter().map(JoinFilter::swap), + &self.join_type().swap(), + swap_join_projection( + left.schema().fields().len(), + right.schema().fields().len(), + self.projection.as_ref(), + self.join_type(), + ), + )?; + + // For Semi/Anti joins, swap result will produce same output schema, + // no need to wrap them into additional projection + let plan: Arc = if matches!( + self.join_type(), + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + ) || self.projection.is_some() + { + Arc::new(new_join) + } else { + reorder_output_after_swap( + Arc::new(new_join), + &self.left().schema(), + &self.right().schema(), + )? + }; + + Ok(plan) + } } impl DisplayAs for NestedLoopJoinExec { @@ -272,10 +396,28 @@ impl DisplayAs for NestedLoopJoinExec { || "".to_string(), |f| format!(", filter={}", f.expression()), ); + let display_projections = if self.contains_projection() { + format!( + ", projection=[{}]", + self.projection + .as_ref() + .unwrap() + .iter() + .map(|index| format!( + "{}@{}", + self.join_schema.fields().get(*index).unwrap().name(), + index + )) + .collect::>() + .join(", ") + ) + } else { + "".to_string() + }; write!( f, - "NestedLoopJoinExec: join_type={:?}{}", - self.join_type, display_filter + "NestedLoopJoinExec: join_type={:?}{}{}", + self.join_type, display_filter, display_projections ) } } @@ -319,6 +461,7 @@ impl ExecutionPlan for NestedLoopJoinExec { Arc::clone(&children[1]), self.filter.clone(), &self.join_type, + self.projection.clone(), )?)) } @@ -345,6 +488,10 @@ impl ExecutionPlan for NestedLoopJoinExec { ) }); + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let outer_table = self.right.execute(partition, context)?; let indices_cache = (UInt64Array::new_null(0), UInt32Array::new_null(0)); @@ -352,18 +499,47 @@ impl ExecutionPlan for NestedLoopJoinExec { // Right side has an order and it is maintained during operation. 
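The `swap_inputs` implementations in this patch share one subtlety: after swapping, the raw output column order becomes right columns followed by left columns, so unless the join type only outputs one side (semi/anti) or an embedded projection already fixes the order, a reordering projection is added on top. The real `reorder_output_after_swap` helper applies that reordering to the swapped plan; the sketch below shows only the index arithmetic involved.

```rust
/// After swapping the inputs of a join, the raw output column order becomes
/// [right columns..., left columns...]. Restoring the original order is an
/// index shuffle over the swapped output.
fn restore_order(left_cols: usize, right_cols: usize) -> Vec<usize> {
    // Original left columns now sit after the right ones...
    (right_cols..right_cols + left_cols)
        // ...and the original right columns sit at the front.
        .chain(0..right_cols)
        .collect()
}

fn main() {
    // 2 left columns, 3 right columns: swapped output is [r0, r1, r2, l0, l1],
    // so picking [3, 4, 0, 1, 2] restores [l0, l1, r0, r1, r2].
    assert_eq!(restore_order(2, 3), vec![3, 4, 0, 1, 2]);
}
```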
let right_side_ordered = self.maintains_input_order()[1] && self.right.output_ordering().is_some(); - Ok(Box::pin(NestedLoopJoinStream { - schema: Arc::clone(&self.schema), - filter: self.filter.clone(), - join_type: self.join_type, - outer_table, - inner_table, - is_exhausted: false, - column_indices: self.column_indices.clone(), - join_metrics, - indices_cache, - right_side_ordered, - })) + + // update column indices to reflect the projection + let column_indices_after_projection = match &self.projection { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + None => self.column_indices.clone(), + }; + + if enforce_batch_size_in_joins { + Ok(Box::pin(NestedLoopJoinStream { + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + outer_table, + inner_table, + column_indices: column_indices_after_projection, + join_metrics, + indices_cache, + right_side_ordered, + state: NestedLoopJoinStreamState::WaitBuildSide, + batch_transformer: BatchSplitter::new(batch_size), + left_data: None, + })) + } else { + Ok(Box::pin(NestedLoopJoinStream { + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + outer_table, + inner_table, + column_indices: column_indices_after_projection, + join_metrics, + indices_cache, + right_side_ordered, + state: NestedLoopJoinStreamState::WaitBuildSide, + batch_transformer: NoopBatchTransformer::new(), + left_data: None, + })) + } } fn metrics(&self) -> Option { @@ -376,7 +552,7 @@ impl ExecutionPlan for NestedLoopJoinExec { Arc::clone(&self.right), vec![], &self.join_type, - &self.schema, + &self.join_schema, ) } } @@ -402,17 +578,17 @@ async fn collect_left_input( let (batches, metrics, mut reservation) = stream .try_fold( (Vec::new(), join_metrics, reservation), - |mut acc, batch| async { + |(mut batches, metrics, mut reservation), batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch - acc.2.try_grow(batch_size)?; + reservation.try_grow(batch_size)?; // Update metrics - acc.1.build_mem_used.add(batch_size); - acc.1.build_input_batches.add(1); - acc.1.build_input_rows.add(batch.num_rows()); + metrics.build_mem_used.add(batch_size); + metrics.build_input_batches.add(1); + metrics.build_input_rows.add(batch.num_rows()); // Push batch to output - acc.0.push(batch); - Ok(acc) + batches.push(batch); + Ok((batches, metrics, reservation)) }, ) .await?; @@ -421,14 +597,13 @@ async fn collect_left_input( // Reserve memory for visited_left_side bitmap if required by join type let visited_left_side = if with_visited_left_side { - // TODO: Replace `ceil` wrapper with stable `div_cell` after - // https://github.com/rust-lang/rust/issues/88581 - let buffer_size = bit_util::ceil(merged_batch.num_rows(), 8); + let n_rows = merged_batch.num_rows(); + let buffer_size = n_rows.div_ceil(8); reservation.try_grow(buffer_size)?; metrics.build_mem_used.add(buffer_size); - let mut buffer = BooleanBufferBuilder::new(merged_batch.num_rows()); - buffer.append_n(merged_batch.num_rows(), false); + let mut buffer = BooleanBufferBuilder::new(n_rows); + buffer.append_n(n_rows, false); buffer } else { BooleanBufferBuilder::new(0) @@ -442,8 +617,37 @@ async fn collect_left_input( )) } +/// This enumeration represents various states of the nested loop join algorithm. 
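The visited-bitmap reservation above is one bit per build-side row, i.e. `ceil(n_rows / 8)` bytes. A short sketch of that accounting using arrow's `BooleanBufferBuilder`, as in the code above:

```rust
// One bit per build-side row, so the backing buffer needs ceil(n_rows / 8) bytes.
use arrow::array::BooleanBufferBuilder;

fn visited_bitmap(n_rows: usize) -> (BooleanBufferBuilder, usize) {
    let buffer_size = n_rows.div_ceil(8); // bytes to reserve from the memory pool
    let mut buffer = BooleanBufferBuilder::new(n_rows);
    buffer.append_n(n_rows, false); // nothing visited yet
    (buffer, buffer_size)
}

fn main() {
    let (bitmap, bytes) = visited_bitmap(1000);
    assert_eq!(bitmap.len(), 1000);
    assert_eq!(bytes, 125); // 1000 rows -> 125 reserved bytes
}
```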
+#[derive(Debug, Clone)] +enum NestedLoopJoinStreamState { + /// The initial state, indicating that build-side data not collected yet + WaitBuildSide, + /// Indicates that build-side has been collected, and stream is ready for + /// fetching probe-side + FetchProbeBatch, + /// Indicates that a non-empty batch has been fetched from probe-side, and + /// is ready to be processed + ProcessProbeBatch(RecordBatch), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that NestedLoopJoinStream execution is completed + Completed, +} + +impl NestedLoopJoinStreamState { + /// Tries to extract a `ProcessProbeBatchState` from the + /// `NestedLoopJoinStreamState` enum. Returns an error if state is not + /// `ProcessProbeBatchState`. + fn try_as_process_probe_batch(&mut self) -> Result<&RecordBatch> { + match self { + NestedLoopJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => internal_err!("Expected join stream in ProcessProbeBatch state"), + } + } +} + /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct NestedLoopJoinStream { +struct NestedLoopJoinStream { /// Input schema schema: Arc, /// join filter @@ -454,8 +658,6 @@ struct NestedLoopJoinStream { outer_table: SendableRecordBatchStream, /// the inner table data of the nested loop join inner_table: OnceFut, - /// There is nothing to process anymore and left side is processed in case of full join - is_exhausted: bool, /// Information of index and left / right placement of columns column_indices: Vec, // TODO: support null aware equal @@ -466,6 +668,12 @@ struct NestedLoopJoinStream { indices_cache: (UInt64Array, UInt32Array), /// Whether the right side is ordered right_side_ordered: bool, + /// Current state of the stream + state: NestedLoopJoinStreamState, + /// Transforms the output batch before returning. 
+ batch_transformer: T, + /// Result of the left data future + left_data: Option>, } /// Creates a Cartesian product of two input batches, preserving the order of the right batch, @@ -544,107 +752,164 @@ fn build_join_indices( } } -impl NestedLoopJoinStream { +impl NestedLoopJoinStream { fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { - // all left row + loop { + return match self.state { + NestedLoopJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + NestedLoopJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + NestedLoopJoinStreamState::ProcessProbeBatch(_) => { + handle_state!(self.process_probe_batch()) + } + NestedLoopJoinStreamState::ExhaustedProbeSide => { + handle_state!(self.process_unmatched_build_batch()) + } + NestedLoopJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); - let left_data = match ready!(self.inner_table.get_shared(cx)) { - Ok(data) => data, - Err(e) => return Poll::Ready(Some(Err(e))), - }; + // build hash table from left (build) side, if not yet done + self.left_data = Some(ready!(self.inner_table.get_shared(cx))?); build_timer.done(); - // Get or initialize visited_left_side bitmap if required by join type + self.state = NestedLoopJoinStreamState::FetchProbeBatch; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If a non-empty batch has been fetched, updates state to + /// `ProcessProbeBatchState`, otherwise updates state to `ExhaustedProbeSide`. + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.outer_table.poll_next_unpin(cx)) { + None => { + self.state = NestedLoopJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(right_batch)) => { + self.state = NestedLoopJoinStreamState::ProcessProbeBatch(right_batch); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with + /// matched output, updates state to `FetchProbeBatch`. 
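The `poll_next_impl` loop above keeps dispatching on the current state until a handler either produces output or signals completion; `handle_state!` just threads the `Continue`/`Ready` distinction (and errors) through `Poll`. Below is a deliberately simplified, synchronous analogue of that loop with toy placeholder states and payloads; it is not the operator's real control flow.

```rust
// Simplified, synchronous model of the stream's state loop.
enum Step<T> {
    Continue,         // state changed, loop again
    Ready(Option<T>), // emit an item (Some) or finish (None)
}

enum State {
    WaitBuildSide,
    FetchProbeBatch,
    ProcessProbeBatch,
    ExhaustedProbeSide,
    Completed,
}

fn next_item(state: &mut State, probe_batches: &mut Vec<&'static str>) -> Option<&'static str> {
    loop {
        let step = match state {
            State::WaitBuildSide => {
                // pretend the build side is instantly available
                *state = State::FetchProbeBatch;
                Step::Continue
            }
            State::FetchProbeBatch => match probe_batches.pop() {
                Some(_) => {
                    *state = State::ProcessProbeBatch;
                    Step::Continue
                }
                None => {
                    *state = State::ExhaustedProbeSide;
                    Step::Continue
                }
            },
            State::ProcessProbeBatch => {
                *state = State::FetchProbeBatch;
                Step::Ready(Some("joined batch"))
            }
            State::ExhaustedProbeSide => {
                *state = State::Completed;
                Step::Ready(Some("unmatched build-side rows"))
            }
            State::Completed => Step::Ready(None),
        };
        if let Step::Ready(item) = step {
            return item;
        }
    }
}

fn main() {
    let mut state = State::WaitBuildSide;
    let mut probes = vec!["b1", "b2"];
    let mut out = Vec::new();
    while let Some(item) = next_item(&mut state, &mut probes) {
        out.push(item);
    }
    assert_eq!(out.len(), 3); // two joined batches + one final pass
}
```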
+ fn process_probe_batch( + &mut self, + ) -> Result>> { + let Some(left_data) = self.left_data.clone() else { + return internal_err!( + "Expected left_data to be Some in ProcessProbeBatch state" + ); + }; let visited_left_side = left_data.bitmap(); + let batch = self.state.try_as_process_probe_batch()?; + + match self.batch_transformer.next() { + None => { + // Setting up timer & updating input metrics + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + + let result = join_left_and_right_batch( + left_data.batch(), + batch, + self.join_type, + self.filter.as_ref(), + &self.column_indices, + &self.schema, + visited_left_side, + &mut self.indices_cache, + self.right_side_ordered, + ); + timer.done(); + + self.batch_transformer.set_batch(result?); + Ok(StatefulStreamResult::Continue) + } + Some((batch, last)) => { + if last { + self.state = NestedLoopJoinStreamState::FetchProbeBatch; + } - // Check is_exhausted before polling the outer_table, such that when the outer table - // does not support `FusedStream`, Self will not poll it again - if self.is_exhausted { - return Poll::Ready(None); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + Ok(StatefulStreamResult::Ready(Some(batch))) + } } + } - self.outer_table - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - Some(Ok(right_batch)) => { - // Setting up timer & updating input metrics - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(right_batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - let result = join_left_and_right_batch( - left_data.batch(), - &right_batch, - self.join_type, - self.filter.as_ref(), - &self.column_indices, - &self.schema, - visited_left_side, - &mut self.indices_cache, - self.right_side_ordered, - ); - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) - } - Some(err) => Some(err), - None => { - if need_produce_result_in_final(self.join_type) { - // At this stage `visited_left_side` won't be updated, so it's - // safe to report about probe completion. 
- // - // Setting `is_exhausted` / returning None will prevent from - // multiple calls of `report_probe_completed()` - if !left_data.report_probe_completed() { - self.is_exhausted = true; - return None; - }; - - // Only setting up timer, input is exhausted - let timer = self.join_metrics.join_time.timer(); - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = - get_final_indices_from_shared_bitmap( - visited_left_side, - self.join_type, - ); - let empty_right_batch = - RecordBatch::new_empty(self.outer_table.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.is_exhausted = true; - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) - } else { - // end of the join loop - None - } - } - }) + /// Processes unmatched build-side rows for certain join types and produces + /// output batch, updates state to `Completed`. + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let Some(left_data) = self.left_data.clone() else { + return internal_err!( + "Expected left_data to be Some in ExhaustedProbeSide state" + ); + }; + let visited_left_side = left_data.bitmap(); + if need_produce_result_in_final(self.join_type) { + // At this stage `visited_left_side` won't be updated, so it's + // safe to report about probe completion. + // + // Setting `is_exhausted` / returning None will prevent from + // multiple calls of `report_probe_completed()` + if !left_data.report_probe_completed() { + self.state = NestedLoopJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(None)); + }; + + // Only setting up timer, input is exhausted + let timer = self.join_metrics.join_time.timer(); + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_shared_bitmap(visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.outer_table.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.state = NestedLoopJoinStreamState::Completed; + + // Recording time + if result.is_ok() { + timer.done(); + } + + Ok(StatefulStreamResult::Ready(Some(result?))) + } else { + // end of the join loop + self.state = NestedLoopJoinStreamState::Completed; + Ok(StatefulStreamResult::Ready(None)) + } } } @@ -684,7 +949,7 @@ fn join_left_and_right_batch( 0..right_batch.num_rows(), join_type, right_side_ordered, - ); + )?; build_batch_from_indices( schema, @@ -697,15 +962,7 @@ fn join_left_and_right_batch( ) } -fn get_final_indices_from_shared_bitmap( - shared_bitmap: &SharedBitmapBuilder, - join_type: JoinType, -) -> (UInt64Array, UInt32Array) { - let bitmap = shared_bitmap.lock(); - get_final_indices_from_bit_map(&bitmap, join_type) -} - -impl Stream for NestedLoopJoinStream { +impl Stream for NestedLoopJoinStream { type Item = Result; fn poll_next( @@ -716,14 +973,14 @@ impl Stream for NestedLoopJoinStream { } } -impl RecordBatchStream for NestedLoopJoinStream { +impl RecordBatchStream for 
NestedLoopJoinStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; use crate::{ common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, @@ -738,7 +995,7 @@ mod tests { use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; use datafusion_physical_expr::{Partitioning, PhysicalExpr}; - use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use rstest::rstest; @@ -768,7 +1025,7 @@ mod tests { let mut exec = MemoryExec::try_new(&[batches], Arc::clone(&schema), None).unwrap(); if !sorted_column_names.is_empty() { - let mut sort_info = Vec::new(); + let mut sort_info = LexOrdering::default(); for name in sorted_column_names { let index = schema.index_of(name).unwrap(); let sort_expr = PhysicalSortExpr { @@ -780,7 +1037,7 @@ mod tests { }; sort_info.push(sort_expr); } - exec = exec.with_sort_information(vec![sort_info]); + exec = exec.try_with_sort_information(vec![sort_info]).unwrap(); } Arc::new(exec) @@ -850,7 +1107,7 @@ mod tests { JoinFilter::new(filter_expression, column_indices, intermediate_schema) } - async fn multi_partitioned_join_collect( + pub(crate) async fn multi_partitioned_join_collect( left: Arc, right: Arc, join_type: &JoinType, @@ -867,7 +1124,7 @@ mod tests { // Use the required distribution for nested loop join to test partition data let nested_loop_join = - NestedLoopJoinExec::try_new(left, right, join_filter, join_type)?; + NestedLoopJoinExec::try_new(left, right, join_filter, join_type, None)?; let columns = columns(&nested_loop_join.schema()); let mut batches = vec![]; for i in 0..partition_count { @@ -1124,6 +1381,37 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_left_mark_with_filter() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = build_left_table(); + let right = build_right_table(); + + let filter = prepare_join_filter(); + let (columns, batches) = multi_partitioned_join_collect( + left, + right, + &JoinType::LeftMark, + Some(filter), + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); + let expected = [ + "+----+----+-----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+-----+-------+", + "| 11 | 8 | 110 | false |", + "| 5 | 5 | 50 | true |", + "| 9 | 8 | 90 | false |", + "+----+----+-----+-------+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + #[tokio::test] async fn test_overallocation() -> Result<()> { let left = build_table( @@ -1149,6 +1437,7 @@ mod tests { JoinType::Full, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightSemi, JoinType::RightAnti, ]; @@ -1271,6 +1560,7 @@ mod tests { Arc::clone(&right), Some(filter), &join_type, + None, )?) 
as Arc; assert_eq!(nested_loop_join.maintains_input_order(), vec![false, true]); diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 2118c1a5266fb..bcacc7dcae0fc 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -26,49 +26,100 @@ use std::collections::{HashMap, VecDeque}; use std::fmt::Formatter; use std::fs::File; use std::io::BufReader; -use std::mem; +use std::mem::size_of; use std::ops::Range; use std::pin::Pin; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::task::{Context, Poll}; use arrow::array::*; -use arrow::compute::{self, concat_batches, take, SortOptions}; +use arrow::compute::{ + self, concat_batches, filter_record_batch, is_not_null, take, SortOptions, +}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use arrow::ipc::reader::FileReader; use arrow_array::types::UInt64Type; -use futures::{Stream, StreamExt}; -use hashbrown::HashSet; - use datafusion_common::{ - exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, - Result, + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, HashSet, JoinSide, + JoinType, Result, }; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; -use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_expr::PhysicalExprRef; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::expressions::PhysicalSortExpr; use crate::joins::utils::{ build_join_schema, check_join_is_valid, estimate_join_statistics, - symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, + reorder_output_after_swap, symmetric_join_output_partitioning, JoinFilter, JoinOn, + JoinOnRef, }; use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use crate::spill::spill_record_batches; use crate::{ - execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, - ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, Statistics, + metrics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, + ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, Statistics, }; -/// join execution plan executes partitions in parallel and combines them into a set of -/// partitions. -#[derive(Debug)] +use futures::{Stream, StreamExt}; + +/// Join execution plan that executes equi-join predicates on multiple partitions using Sort-Merge +/// join algorithm and applies an optional filter post join. Can be used to join arbitrarily large +/// inputs where one or both of the inputs don't fit in the available memory. +/// +/// # Join Expressions +/// +/// Equi-join predicate (e.g. ` = `) expressions are represented by [`Self::on`]. +/// +/// Non-equality predicates, which can not be pushed down to join inputs (e.g. 
+/// ` != `) are known as "filter expressions" and are evaluated +/// after the equijoin predicates. They are represented by [`Self::filter`]. These are optional +/// expressions. +/// +/// # Sorting +/// +/// Assumes that both the left and right input to the join are pre-sorted. It is not the +/// responsibility of this execution plan to sort the inputs. +/// +/// # "Streamed" vs "Buffered" +/// +/// The number of record batches of streamed input currently present in the memory will depend +/// on the output batch size of the execution plan. There is no spilling support for streamed input. +/// The comparisons are performed from values of join keys in streamed input with the values of +/// join keys in buffered input. One row in streamed record batch could be matched with multiple rows in +/// buffered input batches. The streamed input is managed through the states in `StreamedState` +/// and streamed input batches are represented by `StreamedBatch`. +/// +/// Buffered input is buffered for all record batches having the same value of join key. +/// If the memory limit increases beyond the specified value and spilling is enabled, +/// buffered batches could be spilled to disk. If spilling is disabled, the execution +/// will fail under the same conditions. Multiple record batches of buffered could currently reside +/// in memory/disk during the execution. The number of buffered batches residing in +/// memory/disk depends on the number of rows of buffered input having the same value +/// of join key as that of streamed input rows currently present in memory. Due to pre-sorted inputs, +/// the algorithm understands when it is not needed anymore, and releases the buffered batches +/// from memory/disk. The buffered input is managed through the states in `BufferedState` +/// and buffered input batches are represented by `BufferedBatch`. +/// +/// Depending on the type of join, left or right input may be selected as streamed or buffered +/// respectively. For example, in a left-outer join, the left execution plan will be selected as +/// streamed input while in a right-outer join, the right execution plan will be selected as the +/// streamed input. +/// +/// Reference for the algorithm: +/// . +/// +/// Helpful short video demonstration: +/// . 
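The streamed/buffered description above is easiest to see on plain sorted keys: advance whichever side has the smaller key, and on equal keys buffer the whole run of equal values on the buffered side so that every streamed row with that key can be joined against it. Below is a deliberately simplified inner-join sketch on integer keys; the real operator works on record batches and additionally handles nulls, filters, other join types, memory limits and spilling.

```rust
/// Toy inner sort-merge join over pre-sorted integer keys, returning
/// (streamed_index, buffered_index) pairs.
fn sort_merge_inner(streamed: &[i64], buffered: &[i64]) -> Vec<(usize, usize)> {
    let mut out = Vec::new();
    let (mut s, mut b) = (0, 0);
    while s < streamed.len() && b < buffered.len() {
        if streamed[s] < buffered[b] {
            s += 1; // streamed key too small: advance streamed side
        } else if streamed[s] > buffered[b] {
            b += 1; // buffered key too small: advance buffered side
        } else {
            // Equal keys: find the run of equal keys on the buffered side
            // (the part that may have to be kept in memory or spilled when
            // the run is large).
            let run_start = b;
            let mut run_end = b;
            while run_end < buffered.len() && buffered[run_end] == streamed[s] {
                run_end += 1;
            }
            // Every streamed row with this key joins the whole buffered run.
            while s < streamed.len() && streamed[s] == buffered[run_start] {
                for bi in run_start..run_end {
                    out.push((s, bi));
                }
                s += 1;
            }
            b = run_end;
        }
    }
    out
}

fn main() {
    let streamed: &[i64] = &[1, 3, 3, 5];
    let buffered: &[i64] = &[3, 3, 4, 5];
    assert_eq!(
        sort_merge_inner(streamed, buffered),
        vec![(1, 0), (1, 1), (2, 0), (2, 1), (3, 3)]
    );
}
```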
+#[derive(Debug, Clone)] pub struct SortMergeJoinExec { /// Left sorted joining execution plan pub left: Arc, @@ -85,9 +136,9 @@ pub struct SortMergeJoinExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// The left SortExpr - left_sort_exprs: Vec, + left_sort_exprs: LexOrdering, /// The right SortExpr - right_sort_exprs: Vec, + right_sort_exprs: LexOrdering, /// Sort options of join columns used in sorting left and right execution plans pub sort_options: Vec, /// If null_equals_null is true, null == null else null != null @@ -156,8 +207,8 @@ impl SortMergeJoinExec { join_type, schema, metrics: ExecutionPlanMetricsSet::new(), - left_sort_exprs, - right_sort_exprs, + left_sort_exprs: LexOrdering::new(left_sort_exprs), + right_sort_exprs: LexOrdering::new(right_sort_exprs), sort_options, null_equals_null, cache, @@ -177,7 +228,8 @@ impl SortMergeJoinExec { | JoinType::Left | JoinType::Full | JoinType::LeftAnti - | JoinType::LeftSemi => JoinSide::Left, + | JoinType::LeftSemi + | JoinType::LeftMark => JoinSide::Left, } } @@ -185,7 +237,10 @@ impl SortMergeJoinExec { fn maintains_input_order(join_type: JoinType) -> Vec { match join_type { JoinType::Inner => vec![true, false], - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![true, false], + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark => vec![true, false], JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { vec![false, true] } @@ -198,18 +253,36 @@ impl SortMergeJoinExec { &self.on } + /// Ref to right execution plan pub fn right(&self) -> &Arc { &self.right } + /// Join type pub fn join_type(&self) -> JoinType { self.join_type } + /// Ref to left execution plan pub fn left(&self) -> &Arc { &self.left } + /// Ref to join filter + pub fn filter(&self) -> &Option { + &self.filter + } + + /// Ref to sort options + pub fn sort_options(&self) -> &[SortOptions] { + &self.sort_options + } + + /// Null equals null + pub fn null_equals_null(&self) -> bool { + self.null_equals_null + } + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties( left: &Arc, @@ -232,10 +305,43 @@ impl SortMergeJoinExec { let output_partitioning = symmetric_join_output_partitioning(left, right, &join_type); - // Determine execution mode: - let mode = execution_mode_from_children([left, right]); + PlanProperties::new( + eq_properties, + output_partitioning, + EmissionType::Incremental, + boundedness_from_children([left, right]), + ) + } + + pub fn swap_inputs(&self) -> Result> { + let left = self.left(); + let right = self.right(); + let new_join = SortMergeJoinExec::try_new( + Arc::clone(right), + Arc::clone(left), + self.on() + .iter() + .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) + .collect::>(), + self.filter().as_ref().map(JoinFilter::swap), + self.join_type().swap(), + self.sort_options.clone(), + self.null_equals_null, + )?; - PlanProperties::new(eq_properties, output_partitioning, mode) + // TODO: OR this condition with having a built-in projection (like + // ordinary hash join) when we support it. 
+ if matches!( + self.join_type(), + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + ) { + Ok(Arc::new(new_join)) + } else { + reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema()) + } } } @@ -291,12 +397,8 @@ impl ExecutionPlan for SortMergeJoinExec { fn required_input_ordering(&self) -> Vec> { vec![ - Some(PhysicalSortRequirement::from_sort_exprs( - &self.left_sort_exprs, - )), - Some(PhysicalSortRequirement::from_sort_exprs( - &self.right_sort_exprs, - )), + Some(LexRequirement::from(self.left_sort_exprs.clone())), + Some(LexRequirement::from(self.right_sort_exprs.clone())), ] } @@ -369,7 +471,7 @@ impl ExecutionPlan for SortMergeJoinExec { .register(context.memory_pool()); // create join stream - Ok(Box::pin(SMJStream::try_new( + Ok(Box::pin(SortMergeJoinStream::try_new( Arc::clone(&self.schema), self.sort_options.clone(), self.null_equals_null, @@ -410,13 +512,13 @@ struct SortMergeJoinMetrics { /// Total time for joining probe-side batches to the build-side batches join_time: metrics::Time, /// Number of batches consumed by this operator - input_batches: metrics::Count, + input_batches: Count, /// Number of rows consumed by this operator - input_rows: metrics::Count, + input_rows: Count, /// Number of batches produced by this operator - output_batches: metrics::Count, + output_batches: Count, /// Number of rows produced by this operator - output_rows: metrics::Count, + output_rows: Count, /// Peak memory used for buffered data. /// Calculated as sum of peak memory values across partitions peak_mem_used: metrics::Gauge, @@ -459,7 +561,7 @@ impl SortMergeJoinMetrics { /// State of SMJ stream #[derive(Debug, PartialEq, Eq)] -enum SMJState { +enum SortMergeJoinState { /// Init joining with a new streamed row or a new buffered batches Init, /// Polling one streamed row or one buffered batch, or both @@ -509,6 +611,9 @@ struct StreamedJoinedChunk { buffered_indices: UInt64Builder, } +/// Represents a record batch from streamed input. +/// +/// Also stores information of matching rows from buffered batches. struct StreamedBatch { /// The streamed record batch pub batch: RecordBatch, @@ -595,11 +700,11 @@ struct BufferedBatch { pub null_joined: Vec, /// Size estimation used for reserving / releasing memory pub size_estimation: usize, - /// The indices of buffered batch that failed the join filter. - /// This is a map between buffered row index and a boolean value indicating whether all joined row - /// of the buffered row failed the join filter. + /// The indices of buffered batch that the join filter doesn't satisfy. + /// This is a map between right row index and a boolean value indicating whether all joined row + /// of the right row does not satisfy the filter . /// When dequeuing the buffered batch, we need to produce null joined rows for these indices. - pub join_filter_failed_map: HashMap, + pub join_filter_not_matched_map: HashMap, /// Current buffered batch number of rows. 
Equal to batch.num_rows() /// but if batch is spilled to disk this property is preferable /// and less expensive @@ -629,9 +734,9 @@ impl BufferedBatch { .iter() .map(|arr| arr.get_array_memory_size()) .sum::() - + batch.num_rows().next_power_of_two() * mem::size_of::() - + mem::size_of::>() - + mem::size_of::(); + + batch.num_rows().next_power_of_two() * size_of::() + + size_of::>() + + size_of::(); let num_rows = batch.num_rows(); BufferedBatch { @@ -640,18 +745,18 @@ impl BufferedBatch { join_arrays, null_joined: vec![], size_estimation, - join_filter_failed_map: HashMap::new(), + join_filter_not_matched_map: HashMap::new(), num_rows, spill_file: None, } } } -/// Sort-merge join stream that consumes streamed and buffered data stream -/// and produces joined output -struct SMJStream { +/// Sort-Merge join stream that consumes streamed and buffered data streams +/// and produces joined output stream. +struct SortMergeJoinStream { /// Current state of the stream - pub state: SMJState, + pub state: SortMergeJoinState, /// Output schema pub schema: SchemaRef, /// Sort options of join columns used to sort streamed and buffered data stream @@ -687,7 +792,7 @@ struct SMJStream { /// optional join filter pub filter: Option, /// Staging output array builders - pub output_record_batches: Vec, + pub output_record_batches: JoinedRecordBatches, /// Staging output size, including output batches and staging joined results. /// Increased when we put rows into buffer and decreased after we actually output batches. /// Used to trigger output when sufficient rows are ready @@ -702,15 +807,217 @@ struct SMJStream { pub reservation: MemoryReservation, /// Runtime env pub runtime_env: Arc, + /// A unique number for each batch + pub streamed_batch_counter: AtomicUsize, } -impl RecordBatchStream for SMJStream { +/// Joined batches with attached join filter information +struct JoinedRecordBatches { + /// Joined batches. 
Each batch is already joined columns from left and right sources + pub batches: Vec, + /// Filter match mask for each row(matched/non-matched) + pub filter_mask: BooleanBuilder, + /// Row indices to glue together rows in `batches` and `filter_mask` + pub row_indices: UInt64Builder, + /// Which unique batch id the row belongs to + /// It is necessary to differentiate rows that are distributed the way when they point to the same + /// row index but in not the same batches + pub batch_ids: Vec, +} + +impl JoinedRecordBatches { + fn clear(&mut self) { + self.batches.clear(); + self.batch_ids.clear(); + self.filter_mask = BooleanBuilder::new(); + self.row_indices = UInt64Builder::new(); + } +} +impl RecordBatchStream for SortMergeJoinStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } } -impl Stream for SMJStream { +/// True if next index refers to either: +/// - another batch id +/// - another row index within same batch id +/// - end of row indices +#[inline(always)] +fn last_index_for_row( + row_index: usize, + indices: &UInt64Array, + batch_ids: &[usize], + indices_len: usize, +) -> bool { + row_index == indices_len - 1 + || batch_ids[row_index] != batch_ids[row_index + 1] + || indices.value(row_index) != indices.value(row_index + 1) +} + +// Returns a corrected boolean bitmask for the given join type +// Values in the corrected bitmask can be: true, false, null +// `true` - the row found its match and sent to the output +// `null` - the row ignored, no output +// `false` - the row sent as NULL joined row +fn get_corrected_filter_mask( + join_type: JoinType, + row_indices: &UInt64Array, + batch_ids: &[usize], + filter_mask: &BooleanArray, + expected_size: usize, +) -> Option { + let row_indices_length = row_indices.len(); + let mut corrected_mask: BooleanBuilder = + BooleanBuilder::with_capacity(row_indices_length); + let mut seen_true = false; + + match join_type { + JoinType::Left | JoinType::Right => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + + // Generate null joined rows for records which have no matching join key + corrected_mask.append_n(expected_size - corrected_mask.len(), false); + Some(corrected_mask.finish()) + } + JoinType::LeftMark => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) && !seen_true { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + + // Generate null joined rows for records which have no matching join key + corrected_mask.append_n(expected_size - corrected_mask.len(), false); + Some(corrected_mask.finish()) + } + JoinType::LeftSemi => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) && !seen_true { + seen_true = true; + 
corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); // to be ignored and not set to output + } + + if last_index { + seen_true = false; + } + } + + Some(corrected_mask.finish()) + } + JoinType::LeftAnti | JoinType::RightAnti => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + + if filter_mask.value(i) { + seen_true = true; + } + + if last_index { + if !seen_true { + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); + } + + seen_true = false; + } else { + corrected_mask.append_null(); + } + } + // Generate null joined rows for records which have no matching join key, + // for LeftAnti non-matched considered as true + corrected_mask.append_n(expected_size - corrected_mask.len(), true); + Some(corrected_mask.finish()) + } + JoinType::Full => { + let mut mask: Vec> = vec![Some(true); row_indices_length]; + let mut last_true_idx = 0; + let mut first_row_idx = 0; + let mut seen_false = false; + + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + let val = filter_mask.value(i); + let is_null = filter_mask.is_null(i); + + if val { + // memoize the first seen matched row + if !seen_true { + last_true_idx = i; + } + seen_true = true; + } + + if is_null || val { + mask[i] = Some(true); + } else if !is_null && !val && (seen_true || seen_false) { + mask[i] = None; + } else { + mask[i] = Some(false); + } + + if !is_null && !val { + seen_false = true; + } + + if last_index { + // If the left row seen as true its needed to output it once + // To do that we mark all other matches for same row as null to avoid the output + if seen_true { + #[allow(clippy::needless_range_loop)] + for j in first_row_idx..last_true_idx { + mask[j] = None; + } + } + + seen_true = false; + seen_false = false; + last_true_idx = 0; + first_row_idx = i + 1; + } + } + + Some(BooleanArray::from(mask)) + } + // Only outer joins needs to keep track of processed rows and apply corrected filter mask + _ => None, + } +} + +impl Stream for SortMergeJoinStream { type Item = Result; fn poll_next( @@ -719,20 +1026,43 @@ impl Stream for SMJStream { ) -> Poll> { let join_time = self.join_metrics.join_time.clone(); let _timer = join_time.timer(); - loop { match &self.state { - SMJState::Init => { + SortMergeJoinState::Init => { let streamed_exhausted = self.streamed_state == StreamedState::Exhausted; let buffered_exhausted = self.buffered_state == BufferedState::Exhausted; self.state = if streamed_exhausted && buffered_exhausted { - SMJState::Exhausted + SortMergeJoinState::Exhausted } else { match self.current_ordering { Ordering::Less | Ordering::Equal => { if !streamed_exhausted { + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::Right + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::Full + ) + { + self.freeze_all()?; + + if !self.output_record_batches.batches.is_empty() + { + let out_filtered_batch = + self.filter_joined_batch()?; + return Poll::Ready(Some(Ok( + out_filtered_batch, + ))); + } + } + self.streamed_joined = false; self.streamed_state = StreamedState::Init; } @@ -744,10 +1074,10 @@ impl Stream for SMJStream { } } } - SMJState::Polling + SortMergeJoinState::Polling }; } - SMJState::Polling => { + SortMergeJoinState::Polling => { if ![StreamedState::Exhausted, StreamedState::Ready] .contains(&self.streamed_state) { @@ -770,43 
+1100,81 @@ impl Stream for SMJStream { let buffered_exhausted = self.buffered_state == BufferedState::Exhausted; if streamed_exhausted && buffered_exhausted { - self.state = SMJState::Exhausted; + self.state = SortMergeJoinState::Exhausted; continue; } self.current_ordering = self.compare_streamed_buffered()?; - self.state = SMJState::JoinOutput; + self.state = SortMergeJoinState::JoinOutput; } - SMJState::JoinOutput => { + SortMergeJoinState::JoinOutput => { self.join_partial()?; if self.output_size < self.batch_size { if self.buffered_data.scanning_finished() { self.buffered_data.scanning_reset(); - self.state = SMJState::Init; + self.state = SortMergeJoinState::Init; } } else { self.freeze_all()?; - if !self.output_record_batches.is_empty() { + if !self.output_record_batches.batches.is_empty() { let record_batch = self.output_record_batch_and_reset()?; + // For non-filtered join output whenever the target output batch size + // is hit. For filtered join its needed to output on later phase + // because target output batch size can be hit in the middle of + // filtering causing the filtering to be incomplete and causing + // correctness issues + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::Full + ) + { + continue; + } + return Poll::Ready(Some(Ok(record_batch))); } return Poll::Pending; } } - SMJState::Exhausted => { + SortMergeJoinState::Exhausted => { self.freeze_all()?; - if !self.output_record_batches.is_empty() { - let record_batch = self.output_record_batch_and_reset()?; - return Poll::Ready(Some(Ok(record_batch))); + + if !self.output_record_batches.batches.is_empty() { + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::Full + | JoinType::LeftMark + ) + { + let out = self.filter_joined_batch()?; + return Poll::Ready(Some(Ok(out))); + } else { + let record_batch = self.output_record_batch_and_reset()?; + return Poll::Ready(Some(Ok(record_batch))); + } + } else { + return Poll::Ready(None); } - return Poll::Ready(None); } } } } } -impl SMJStream { +impl SortMergeJoinStream { #[allow(clippy::too_many_arguments)] pub fn try_new( schema: SchemaRef, @@ -826,7 +1194,7 @@ impl SMJStream { let streamed_schema = streamed.schema(); let buffered_schema = buffered.schema(); Ok(Self { - state: SMJState::Init, + state: SortMergeJoinState::Init, sort_options, null_equals_null, schema, @@ -844,13 +1212,19 @@ impl SMJStream { on_streamed, on_buffered, filter, - output_record_batches: vec![], + output_record_batches: JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }, output_size: 0, batch_size, join_type, join_metrics, reservation, runtime_env, + streamed_batch_counter: AtomicUsize::new(0), }) } @@ -882,6 +1256,10 @@ impl SMJStream { self.join_metrics.input_rows.add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); + // Every incoming streaming batch should have its unique id + // Check `JoinedRecordBatches.self.streamed_batch_counter` documentation + self.streamed_batch_counter + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); self.streamed_state = StreamedState::Ready; } } @@ -959,9 +1337,10 @@ impl SMJStream { // If the head batch is fully processed, dequeue it and produce output of 
it. if head_batch.range.end == head_batch.num_rows { self.freeze_dequeuing_buffered()?; - if let Some(buffered_batch) = + if let Some(mut buffered_batch) = self.buffered_data.batches.pop_front() { + self.produce_buffered_not_matched(&mut buffered_batch)?; self.free_reservation(buffered_batch)?; } } else { @@ -1062,14 +1441,14 @@ impl SMJStream { return Ok(Ordering::Less); } - return compare_join_arrays( + compare_join_arrays( &self.streamed_batch.join_arrays, self.streamed_batch.idx, &self.buffered_data.head_batch().join_arrays, self.buffered_data.head_batch().range.start, &self.sort_options, self.null_equals_null, - ); + ) } /// Produce join and fill output buffer until reaching target batch size @@ -1079,6 +1458,8 @@ impl SMJStream { let mut join_streamed = false; // Whether to join buffered rows let mut join_buffered = false; + // For Mark join we store a dummy id to indicate the the row has a match + let mut mark_row_as_match = false; // determine whether we need to join streamed/buffered rows match self.current_ordering { @@ -1090,12 +1471,15 @@ impl SMJStream { | JoinType::RightSemi | JoinType::Full | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark ) { join_streamed = !self.streamed_joined; } } Ordering::Equal => { - if matches!(self.join_type, JoinType::LeftSemi) { + if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftMark) { + mark_row_as_match = matches!(self.join_type, JoinType::LeftMark); // if the join filter is specified then its needed to output the streamed index // only if it has not been emitted before // the `join_filter_matched_idxs` keeps track on if streamed index has a successful @@ -1121,12 +1505,10 @@ impl SMJStream { join_buffered = true; }; - if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() { - join_streamed = !self - .streamed_batch - .join_filter_matched_idxs - .contains(&(self.streamed_batch.idx as u64)) - && !self.streamed_joined; + if matches!(self.join_type, JoinType::LeftAnti | JoinType::RightAnti) + && self.filter.is_some() + { + join_streamed = !self.streamed_joined; join_buffered = join_streamed; } } @@ -1176,9 +1558,11 @@ impl SMJStream { } else { Some(self.buffered_data.scanning_batch_idx) }; + // For Mark join we store a dummy id to indicate the the row has a match + let scanning_idx = mark_row_as_match.then_some(0); self.streamed_batch - .append_output_pair(scanning_batch_idx, None); + .append_output_pair(scanning_batch_idx, scanning_idx); self.output_size += 1; self.buffered_data.scanning_finish(); self.streamed_joined = true; @@ -1187,8 +1571,8 @@ impl SMJStream { } fn freeze_all(&mut self) -> Result<()> { + self.freeze_buffered(self.buffered_data.batches.len())?; self.freeze_streamed()?; - self.freeze_buffered(self.buffered_data.batches.len(), false)?; Ok(()) } @@ -1199,7 +1583,7 @@ impl SMJStream { fn freeze_dequeuing_buffered(&mut self) -> Result<()> { self.freeze_streamed()?; // Only freeze and produce the first batch in buffered_data as the batch is fully processed - self.freeze_buffered(1, true)?; + self.freeze_buffered(1)?; Ok(()) } @@ -1208,13 +1592,7 @@ impl SMJStream { // // Applicable only in case of Full join. // - // If `output_not_matched_filter` is true, this will also produce record batches - // for buffered rows which are joined with streamed side but don't match join filter. 
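The renamed `join_filter_not_matched_map` introduced below tracks, for a Full join, whether every joined row of a given buffered row failed the join filter; only then is a null-joined row emitted for it. A minimal standalone sketch of that bookkeeping follows; the helper name `record_filter_result` is made up for illustration.

use std::collections::HashMap;

// Each buffered row index maps to "all of its joined rows have failed the filter so far".
// Starting from `true`, the flag is AND-ed with the negated filter result of every joined
// row, so it stays `true` only if no joined row ever passed the filter.
fn record_filter_result(map: &mut HashMap<u64, bool>, buffered_index: u64, passed_filter: bool) {
    let all_failed_so_far = *map.get(&buffered_index).unwrap_or(&true);
    map.insert(buffered_index, all_failed_so_far && !passed_filter);
}

fn main() {
    let mut map = HashMap::new();
    record_filter_result(&mut map, 0, false); // first joined row fails the filter
    record_filter_result(&mut map, 0, true);  // a later joined row passes
    assert_eq!(map[&0], false); // row 0 does NOT need a null-joined output row
}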
- fn freeze_buffered( - &mut self, - batch_count: usize, - output_not_matched_filter: bool, - ) -> Result<()> { + fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> { if !matches!(self.join_type, JoinType::Full) { return Ok(()); } @@ -1228,34 +1606,64 @@ impl SMJStream { &buffered_indices, buffered_batch, )? { - self.output_record_batches.push(record_batch); + let num_rows = record_batch.num_rows(); + self.output_record_batches + .filter_mask + .append_nulls(num_rows); + self.output_record_batches + .row_indices + .append_nulls(num_rows); + self.output_record_batches + .batch_ids + .resize(self.output_record_batches.batch_ids.len() + num_rows, 0); + + self.output_record_batches.batches.push(record_batch); } buffered_batch.null_joined.clear(); + } + Ok(()) + } - // For buffered row which is joined with streamed side rows but all joined rows - // don't satisfy the join filter - if output_not_matched_filter { - let not_matched_buffered_indices = buffered_batch - .join_filter_failed_map - .iter() - .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None }) - .collect::>(); + fn produce_buffered_not_matched( + &mut self, + buffered_batch: &mut BufferedBatch, + ) -> Result<()> { + if !matches!(self.join_type, JoinType::Full) { + return Ok(()); + } - let buffered_indices = UInt64Array::from_iter_values( - not_matched_buffered_indices.iter().copied(), - ); + // For buffered row which is joined with streamed side rows but all joined rows + // don't satisfy the join filter + let not_matched_buffered_indices = buffered_batch + .join_filter_not_matched_map + .iter() + .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None }) + .collect::>(); - if let Some(record_batch) = produce_buffered_null_batch( - &self.schema, - &self.streamed_schema, - &buffered_indices, - buffered_batch, - )? { - self.output_record_batches.push(record_batch); - } - buffered_batch.join_filter_failed_map.clear(); - } + let buffered_indices = + UInt64Array::from_iter_values(not_matched_buffered_indices.iter().copied()); + + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? 
{ + let num_rows = record_batch.num_rows(); + + self.output_record_batches + .filter_mask + .append_nulls(num_rows); + self.output_record_batches + .row_indices + .append_nulls(num_rows); + self.output_record_batches + .batch_ids + .resize(self.output_record_batches.batch_ids.len() + num_rows, 0); + self.output_record_batches.batches.push(record_batch); } + buffered_batch.join_filter_not_matched_map.clear(); + Ok(()) } @@ -1264,63 +1672,73 @@ impl SMJStream { fn freeze_streamed(&mut self) -> Result<()> { for chunk in self.streamed_batch.output_indices.iter_mut() { // The row indices of joined streamed batch - let streamed_indices = chunk.streamed_indices.finish(); + let left_indices = chunk.streamed_indices.finish(); - if streamed_indices.is_empty() { + if left_indices.is_empty() { continue; } - let mut streamed_columns = self + let mut left_columns = self .streamed_batch .batch .columns() .iter() - .map(|column| take(column, &streamed_indices, None)) + .map(|column| take(column, &left_indices, None)) .collect::, ArrowError>>()?; // The row indices of joined buffered batch - let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); - let mut buffered_columns = - if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { - vec![] - } else if let Some(buffered_idx) = chunk.buffered_batch_idx { - get_buffered_columns( - &self.buffered_data, - buffered_idx, - &buffered_indices, - )? - } else { - // If buffered batch none, meaning it is null joined batch. - // We need to create null arrays for buffered columns to join with streamed rows. - self.buffered_schema - .fields() - .iter() - .map(|f| new_null_array(f.data_type(), buffered_indices.len())) - .collect::>() - }; - - let streamed_columns_length = streamed_columns.len(); - let buffered_columns_length = buffered_columns.len(); + let right_indices: UInt64Array = chunk.buffered_indices.finish(); + let mut right_columns = if matches!(self.join_type, JoinType::LeftMark) { + vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef] + } else if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightAnti + ) { + vec![] + } else if let Some(buffered_idx) = chunk.buffered_batch_idx { + fetch_right_columns_by_idxs( + &self.buffered_data, + buffered_idx, + &right_indices, + )? + } else { + // If buffered batch none, meaning it is null joined batch. + // We need to create null arrays for buffered columns to join with streamed rows. + create_unmatched_columns( + self.join_type, + &self.buffered_schema, + right_indices.len(), + ) + }; // Prepare the columns we apply join filter on later. // Only for joined rows between streamed and buffered. 
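The filter's input columns are assembled from the already-joined left/right columns according to the filter's column indices and sides, as the code that follows does per join type. A simplified standalone sketch of that selection is below, using plain `Vec<i32>` columns and made-up names (`Side`, `pick_filter_columns`) rather than DataFusion's `JoinFilter`/`get_filter_column` API.

#[derive(Clone, Copy)]
enum Side { Left, Right }

// Pick the filter's input columns from the joined left/right columns by (side, index).
fn pick_filter_columns(
    column_indices: &[(Side, usize)],
    left: &[Vec<i32>],
    right: &[Vec<i32>],
) -> Vec<Vec<i32>> {
    column_indices
        .iter()
        .map(|(side, idx)| match side {
            Side::Left => left[*idx].clone(),
            Side::Right => right[*idx].clone(),
        })
        .collect()
}

fn main() {
    let left = vec![vec![1, 2], vec![10, 20]];
    let right = vec![vec![100, 200]];
    // A filter referencing left column 1 and right column 0.
    let cols = pick_filter_columns(&[(Side::Left, 1), (Side::Right, 0)], &left, &right);
    assert_eq!(cols, vec![vec![10, 20], vec![100, 200]]);
}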
let filter_columns = if chunk.buffered_batch_idx.is_some() { - if matches!(self.join_type, JoinType::Right) { - get_filter_column(&self.filter, &buffered_columns, &streamed_columns) - } else if matches!( - self.join_type, - JoinType::LeftSemi | JoinType::LeftAnti - ) { - // unwrap is safe here as we check is_some on top of if statement - let buffered_columns = get_buffered_columns( - &self.buffered_data, - chunk.buffered_batch_idx.unwrap(), - &buffered_indices, - )?; + if !matches!(self.join_type, JoinType::Right) { + if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark + ) { + let right_cols = fetch_right_columns_by_idxs( + &self.buffered_data, + chunk.buffered_batch_idx.unwrap(), + &right_indices, + )?; + + get_filter_column(&self.filter, &left_columns, &right_cols) + } else if matches!(self.join_type, JoinType::RightAnti) { + let right_cols = fetch_right_columns_by_idxs( + &self.buffered_data, + chunk.buffered_batch_idx.unwrap(), + &right_indices, + )?; - get_filter_column(&self.filter, &streamed_columns, &buffered_columns) + get_filter_column(&self.filter, &right_cols, &left_columns) + } else { + get_filter_column(&self.filter, &left_columns, &right_columns) + } } else { - get_filter_column(&self.filter, &streamed_columns, &buffered_columns) + get_filter_column(&self.filter, &right_columns, &left_columns) } } else { // This chunk is totally for null joined rows (outer join), we don't need to apply join filter. @@ -1328,17 +1746,15 @@ impl SMJStream { vec![] }; - let columns = if matches!(self.join_type, JoinType::Right) { - buffered_columns.extend(streamed_columns.clone()); - buffered_columns + let columns = if !matches!(self.join_type, JoinType::Right) { + left_columns.extend(right_columns); + left_columns } else { - streamed_columns.extend(buffered_columns); - streamed_columns + right_columns.extend(left_columns); + right_columns }; - let output_batch = - RecordBatch::try_new(Arc::clone(&self.schema), columns.clone())?; - + let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; // Apply join filter if any if !filter_columns.is_empty() { if let Some(f) = &self.filter { @@ -1367,137 +1783,66 @@ impl SMJStream { pre_mask.clone() }; - // For certain join types, we need to adjust the initial mask to handle the join filter. 
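The per-batch mask adjustment removed here is replaced by the deferred correction in `get_corrected_filter_mask` shown earlier. As a sanity check of the LeftSemi rule (emit a streamed row once, on its first filter match), the standalone sketch below re-implements that loop over plain vectors; it mirrors the logic for illustration and is not the operator's code.

// Simplified model of the LeftSemi mask correction: within each run of equal row indices
// (same streamed row, same batch id), keep only the first `true` filter result and turn
// every other position into `None` so it is ignored in the output.
fn corrected_left_semi_mask(
    row_indices: &[u64],
    batch_ids: &[usize],
    filter_mask: &[bool],
) -> Vec<Option<bool>> {
    let n = row_indices.len();
    let mut out = Vec::with_capacity(n);
    let mut seen_true = false;
    for i in 0..n {
        let last_index = i == n - 1
            || batch_ids[i] != batch_ids[i + 1]
            || row_indices[i] != row_indices[i + 1];
        if filter_mask[i] && !seen_true {
            seen_true = true;
            out.push(Some(true));
        } else {
            out.push(None); // ignored, not sent to the output
        }
        if last_index {
            seen_true = false;
        }
    }
    out
}

fn main() {
    // Streamed row 0 joined twice (filter fails, then passes), streamed row 1 joined once.
    let mask = corrected_left_semi_mask(&[0, 0, 1], &[0, 0, 0], &[false, true, true]);
    assert_eq!(mask, vec![None, Some(true), Some(true)]);
}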
- let maybe_filtered_join_mask: Option<(BooleanArray, Vec)> = - get_filtered_join_mask( - self.join_type, - &streamed_indices, - &mask, - &self.streamed_batch.join_filter_matched_idxs, - &self.buffered_data.scanning_offset, - ); - - let mask = - if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { - self.streamed_batch - .join_filter_matched_idxs - .extend(&filtered_join_mask.1); - &filtered_join_mask.0 - } else { - &mask - }; - // Push the filtered batch which contains rows passing join filter to the output - let filtered_batch = - compute::filter_record_batch(&output_batch, mask)?; - self.output_record_batches.push(filtered_batch); + if matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::Full + ) { + self.output_record_batches.batches.push(output_batch); + } else { + let filtered_batch = filter_record_batch(&output_batch, &mask)?; + self.output_record_batches.batches.push(filtered_batch); + } + + if !matches!(self.join_type, JoinType::Full) { + self.output_record_batches.filter_mask.extend(&mask); + } else { + self.output_record_batches.filter_mask.extend(pre_mask); + } + self.output_record_batches.row_indices.extend(&left_indices); + self.output_record_batches.batch_ids.resize( + self.output_record_batches.batch_ids.len() + left_indices.len(), + self.streamed_batch_counter.load(Relaxed), + ); // For outer joins, we need to push the null joined rows to the output if // all joined rows are failed on the join filter. // I.e., if all rows joined from a streamed row are failed with the join filter, // we need to join it with nulls as buffered side. - if matches!( - self.join_type, - JoinType::Left | JoinType::Right | JoinType::Full - ) { - // We need to get the mask for row indices that the joined rows are failed - // on the join filter. I.e., for a row in streamed side, if all joined rows - // between it and all buffered rows are failed on the join filter, we need to - // output it with null columns from buffered side. For the mask here, it - // behaves like LeftAnti join. - let null_mask: BooleanArray = get_filtered_join_mask( - // Set a mask slot as true only if all joined rows of same streamed index - // are failed on the join filter. - // The masking behavior is like LeftAnti join. 
- JoinType::LeftAnti, - &streamed_indices, - mask, - &self.streamed_batch.join_filter_matched_idxs, - &self.buffered_data.scanning_offset, - ) - .unwrap() - .0; - - let null_joined_batch = - compute::filter_record_batch(&output_batch, &null_mask)?; - - let mut buffered_columns = self - .buffered_schema - .fields() - .iter() - .map(|f| { - new_null_array( - f.data_type(), - null_joined_batch.num_rows(), - ) - }) - .collect::>(); - - let columns = if matches!(self.join_type, JoinType::Right) { - let streamed_columns = null_joined_batch - .columns() - .iter() - .skip(buffered_columns_length) - .cloned() - .collect::>(); - - buffered_columns.extend(streamed_columns); - buffered_columns - } else { - // Left join or full outer join - let mut streamed_columns = null_joined_batch - .columns() - .iter() - .take(streamed_columns_length) - .cloned() - .collect::>(); - - streamed_columns.extend(buffered_columns); - streamed_columns - }; - - // Push the streamed/buffered batch joined nulls to the output - let null_joined_streamed_batch = RecordBatch::try_new( - Arc::clone(&self.schema), - columns.clone(), - )?; - self.output_record_batches.push(null_joined_streamed_batch); - - // For full join, we also need to output the null joined rows from the buffered side. - // Usually this is done by `freeze_buffered`. However, if a buffered row is joined with - // streamed side, it won't be outputted by `freeze_buffered`. - // We need to check if a buffered row is joined with streamed side and output. - // If it is joined with streamed side, but doesn't match the join filter, - // we need to output it with nulls as streamed side. - if matches!(self.join_type, JoinType::Full) { - let buffered_batch = &mut self.buffered_data.batches - [chunk.buffered_batch_idx.unwrap()]; - - for i in 0..pre_mask.len() { - // If the buffered row is not joined with streamed side, - // skip it. - if buffered_indices.is_null(i) { - continue; - } + if matches!(self.join_type, JoinType::Full) { + let buffered_batch = &mut self.buffered_data.batches + [chunk.buffered_batch_idx.unwrap()]; + + for i in 0..pre_mask.len() { + // If the buffered row is not joined with streamed side, + // skip it. 
+ if right_indices.is_null(i) { + continue; + } - let buffered_index = buffered_indices.value(i); + let buffered_index = right_indices.value(i); - buffered_batch.join_filter_failed_map.insert( - buffered_index, - *buffered_batch - .join_filter_failed_map - .get(&buffered_index) - .unwrap_or(&true) - && !pre_mask.value(i), - ); - } + buffered_batch.join_filter_not_matched_map.insert( + buffered_index, + *buffered_batch + .join_filter_not_matched_map + .get(&buffered_index) + .unwrap_or(&true) + && !pre_mask.value(i), + ); } } } else { - self.output_record_batches.push(output_batch); + self.output_record_batches.batches.push(output_batch); } } else { - self.output_record_batches.push(output_batch); + self.output_record_batches.batches.push(output_batch); } } @@ -1507,7 +1852,8 @@ impl SMJStream { } fn output_record_batch_and_reset(&mut self) -> Result { - let record_batch = concat_batches(&self.schema, &self.output_record_batches)?; + let record_batch = + concat_batches(&self.schema, &self.output_record_batches.batches)?; self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(record_batch.num_rows()); // If join filter exists, `self.output_size` is not accurate as we don't know the exact @@ -1520,9 +1866,206 @@ impl SMJStream { } else { self.output_size -= record_batch.num_rows(); } - self.output_record_batches.clear(); + + if !(self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::Full + )) + { + self.output_record_batches.batches.clear(); + } Ok(record_batch) } + + fn filter_joined_batch(&mut self) -> Result { + let record_batch = self.output_record_batch_and_reset()?; + let mut out_indices = self.output_record_batches.row_indices.finish(); + let mut out_mask = self.output_record_batches.filter_mask.finish(); + let mut batch_ids = &self.output_record_batches.batch_ids; + let default_batch_ids = vec![0; record_batch.num_rows()]; + + // If only nulls come in and indices sizes doesn't match with expected record batch count + // generate missing indices + // Happens for null joined batches for Full Join + if out_indices.null_count() == out_indices.len() + && out_indices.len() != record_batch.num_rows() + { + out_mask = BooleanArray::from(vec![None; record_batch.num_rows()]); + out_indices = UInt64Array::from(vec![None; record_batch.num_rows()]); + batch_ids = &default_batch_ids; + } + + if out_mask.is_empty() { + self.output_record_batches.batches.clear(); + return Ok(record_batch); + } + + let maybe_corrected_mask = get_corrected_filter_mask( + self.join_type, + &out_indices, + batch_ids, + &out_mask, + record_batch.num_rows(), + ); + + let corrected_mask = if let Some(ref filtered_join_mask) = maybe_corrected_mask { + filtered_join_mask + } else { + &out_mask + }; + + self.filter_record_batch_by_join_type(record_batch, corrected_mask) + } + + fn filter_record_batch_by_join_type( + &mut self, + record_batch: RecordBatch, + corrected_mask: &BooleanArray, + ) -> Result { + let mut filtered_record_batch = + filter_record_batch(&record_batch, corrected_mask)?; + let left_columns_length = self.streamed_schema.fields.len(); + let right_columns_length = self.buffered_schema.fields.len(); + + if matches!( + self.join_type, + JoinType::Left | JoinType::LeftMark | JoinType::Right + ) { + let null_mask = compute::not(corrected_mask)?; + let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?; + + let mut right_columns = 
create_unmatched_columns( + self.join_type, + &self.buffered_schema, + null_joined_batch.num_rows(), + ); + + let columns = if !matches!(self.join_type, JoinType::Right) { + let mut left_columns = null_joined_batch + .columns() + .iter() + .take(right_columns_length) + .cloned() + .collect::>(); + + left_columns.extend(right_columns); + left_columns + } else { + let left_columns = null_joined_batch + .columns() + .iter() + .skip(left_columns_length) + .cloned() + .collect::>(); + + right_columns.extend(left_columns); + right_columns + }; + + // Push the streamed/buffered batch joined nulls to the output + let null_joined_streamed_batch = + RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + + filtered_record_batch = concat_batches( + &self.schema, + &[filtered_record_batch, null_joined_streamed_batch], + )?; + } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + let output_column_indices = (0..left_columns_length).collect::>(); + filtered_record_batch = + filtered_record_batch.project(&output_column_indices)?; + } else if matches!(self.join_type, JoinType::RightAnti) { + let output_column_indices = (0..right_columns_length).collect::>(); + filtered_record_batch = + filtered_record_batch.project(&output_column_indices)?; + } else if matches!(self.join_type, JoinType::Full) + && corrected_mask.false_count() > 0 + { + // Find rows which joined by key but Filter predicate evaluated as false + let joined_filter_not_matched_mask = compute::not(corrected_mask)?; + let joined_filter_not_matched_batch = + filter_record_batch(&record_batch, &joined_filter_not_matched_mask)?; + + // Add left unmatched rows adding the right side as nulls + let right_null_columns = self + .buffered_schema + .fields() + .iter() + .map(|f| { + new_null_array( + f.data_type(), + joined_filter_not_matched_batch.num_rows(), + ) + }) + .collect::>(); + + let mut result_joined = joined_filter_not_matched_batch + .columns() + .iter() + .take(left_columns_length) + .cloned() + .collect::>(); + + result_joined.extend(right_null_columns); + + let left_null_joined_batch = + RecordBatch::try_new(Arc::clone(&self.schema), result_joined)?; + + // Add right unmatched rows adding the left side as nulls + let mut result_joined = self + .streamed_schema + .fields() + .iter() + .map(|f| { + new_null_array( + f.data_type(), + joined_filter_not_matched_batch.num_rows(), + ) + }) + .collect::>(); + + let right_data = joined_filter_not_matched_batch + .columns() + .iter() + .skip(left_columns_length) + .cloned() + .collect::>(); + + result_joined.extend(right_data); + + filtered_record_batch = concat_batches( + &self.schema, + &[filtered_record_batch, left_null_joined_batch], + )?; + } + + self.output_record_batches.clear(); + + Ok(filtered_record_batch) + } +} + +fn create_unmatched_columns( + join_type: JoinType, + schema: &SchemaRef, + size: usize, +) -> Vec { + if matches!(join_type, JoinType::LeftMark) { + vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef] + } else { + schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), size)) + .collect::>() + } } /// Gets the arrays which join filters are applied on. 
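`create_unmatched_columns` above fills in the missing side of the output: a LeftMark join gets a single all-`false` mark column, while every other join type gets all-null columns for the missing schema. The standalone sketch below restates the same idea with the arrow APIs; the helper name `unmatched_columns` and the boolean flag are made up for this example rather than the crate-private signature.

use std::sync::Arc;
use arrow::array::{new_null_array, Array, ArrayRef, BooleanArray};
use arrow::datatypes::{DataType, Field, Schema};

// LeftMark: one `false` mark column; other join types: one null array per missing-side field.
fn unmatched_columns(is_left_mark: bool, schema: &Schema, size: usize) -> Vec<ArrayRef> {
    if is_left_mark {
        vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef]
    } else {
        schema
            .fields()
            .iter()
            .map(|f| new_null_array(f.data_type(), size))
            .collect()
    }
}

fn main() {
    let right_schema = Schema::new(vec![Field::new("b", DataType::Int32, true)]);
    assert_eq!(unmatched_columns(true, &right_schema, 3).len(), 1); // just the mark column
    assert_eq!(unmatched_columns(false, &right_schema, 3)[0].null_count(), 3);
}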
@@ -1566,39 +2109,39 @@ fn produce_buffered_null_batch( } // Take buffered (right) columns - let buffered_columns = - get_buffered_columns_from_batch(buffered_batch, buffered_indices)?; + let right_columns = + fetch_right_columns_from_batch_by_idxs(buffered_batch, buffered_indices)?; // Create null streamed (left) columns - let mut streamed_columns = streamed_schema + let mut left_columns = streamed_schema .fields() .iter() .map(|f| new_null_array(f.data_type(), buffered_indices.len())) .collect::>(); - streamed_columns.extend(buffered_columns); + left_columns.extend(right_columns); Ok(Some(RecordBatch::try_new( Arc::clone(schema), - streamed_columns, + left_columns, )?)) } -/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` +/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` by specific column indices #[inline(always)] -fn get_buffered_columns( +fn fetch_right_columns_by_idxs( buffered_data: &BufferedData, buffered_batch_idx: usize, buffered_indices: &UInt64Array, ) -> Result> { - get_buffered_columns_from_batch( + fetch_right_columns_from_batch_by_idxs( &buffered_data.batches[buffered_batch_idx], buffered_indices, ) } #[inline(always)] -fn get_buffered_columns_from_batch( +fn fetch_right_columns_from_batch_by_idxs( buffered_batch: &BufferedBatch, buffered_indices: &UInt64Array, ) -> Result> { @@ -1631,101 +2174,6 @@ fn get_buffered_columns_from_batch( } } -/// Calculate join filter bit mask considering join type specifics -/// `streamed_indices` - array of streamed datasource JOINED row indices -/// `mask` - array booleans representing computed join filter expression eval result: -/// true = the row index matches the join filter -/// false = the row index doesn't match the join filter -/// `streamed_indices` have the same length as `mask` -/// `matched_indices` array of streaming indices that already has a join filter match -/// `scanning_buffered_offset` current buffered offset across batches -/// -/// This return a tuple of: -/// - corrected mask with respect to the join type -/// - indices of rows in streamed batch that have a join filter match -fn get_filtered_join_mask( - join_type: JoinType, - streamed_indices: &UInt64Array, - mask: &BooleanArray, - matched_indices: &HashSet, - scanning_buffered_offset: &usize, -) -> Option<(BooleanArray, Vec)> { - let mut seen_as_true: bool = false; - let streamed_indices_length = streamed_indices.len(); - let mut corrected_mask: BooleanBuilder = - BooleanBuilder::with_capacity(streamed_indices_length); - - let mut filter_matched_indices: Vec = vec![]; - - #[allow(clippy::needless_range_loop)] - match join_type { - // for LeftSemi Join the filter mask should be calculated in its own way: - // if we find at least one matching row for specific streaming index - // we don't need to check any others for the same index - JoinType::LeftSemi => { - // have we seen a filter match for a streaming index before - for i in 0..streamed_indices_length { - // LeftSemi respects only first true values for specific streaming index, - // others true values for the same index must be false - let streamed_idx = streamed_indices.value(i); - if mask.value(i) - && !seen_as_true - && !matched_indices.contains(&streamed_idx) - { - seen_as_true = true; - corrected_mask.append_value(true); - filter_matched_indices.push(streamed_idx); - } else { - corrected_mask.append_value(false); - } - - // if switched to next streaming index(e.g. 
from 0 to 1, or from 1 to 2), we reset seen_as_true flag - if i < streamed_indices_length - 1 - && streamed_idx != streamed_indices.value(i + 1) - { - seen_as_true = false; - } - } - Some((corrected_mask.finish(), filter_matched_indices)) - } - // LeftAnti semantics: return true if for every x in the collection the join matching filter is false. - // `filter_matched_indices` needs to be set once per streaming index - // to prevent duplicates in the output - JoinType::LeftAnti => { - // have we seen a filter match for a streaming index before - for i in 0..streamed_indices_length { - let streamed_idx = streamed_indices.value(i); - if mask.value(i) - && !seen_as_true - && !matched_indices.contains(&streamed_idx) - { - seen_as_true = true; - filter_matched_indices.push(streamed_idx); - } - - // Reset `seen_as_true` flag and calculate mask for the current streaming index - // - if within the batch it switched to next streaming index(e.g. from 0 to 1, or from 1 to 2) - // - if it is at the end of the all buffered batches for the given streaming index, 0 index comes last - if (i < streamed_indices_length - 1 - && streamed_idx != streamed_indices.value(i + 1)) - || (i == streamed_indices_length - 1 - && *scanning_buffered_offset == 0) - { - corrected_mask.append_value( - !matched_indices.contains(&streamed_idx) && !seen_as_true, - ); - seen_as_true = false; - } else { - corrected_mask.append_value(false); - } - } - - Some((corrected_mask.finish(), filter_matched_indices)) - } - _ => None, - } -} - /// Buffered data contains all buffered batches with one unique join key #[derive(Debug, Default)] struct BufferedData { @@ -1966,13 +2414,14 @@ mod tests { use std::sync::Arc; use arrow::array::{Date32Array, Date64Array, Int32Array}; - use arrow::compute::SortOptions; + use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use arrow_array::builder::{BooleanBuilder, UInt64Builder}; use arrow_array::{BooleanArray, UInt64Array}; - use hashbrown::HashSet; - use datafusion_common::JoinType::{LeftAnti, LeftSemi}; + use datafusion_common::JoinSide; + use datafusion_common::JoinType::*; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; @@ -1980,13 +2429,15 @@ mod tests { use datafusion_execution::disk_manager::DiskManagerConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::TaskContext; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::BinaryExpr; use crate::expressions::Column; - use crate::joins::sort_merge_join::get_filtered_join_mask; - use crate::joins::utils::JoinOn; + use crate::joins::sort_merge_join::{get_corrected_filter_mask, JoinedRecordBatches}; + use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn}; use crate::joins::SortMergeJoinExec; use crate::memory::MemoryExec; - use crate::test::build_table_i32; + use crate::test::{build_table_i32, build_table_i32_two_cols}; use crate::{common, ExecutionPlan}; fn build_table( @@ -2077,6 +2528,15 @@ mod tests { Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) } + pub fn build_table_two_cols( + a: (&str, &Vec), + b: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32_two_cols(a, b); + let schema = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + fn join( left: Arc, right: Arc, @@ -2106,6 +2566,26 @@ mod tests { ) } + fn join_with_filter( + left: Arc, + right: 
Arc, + on: JoinOn, + filter: JoinFilter, + join_type: JoinType, + sort_options: Vec, + null_equals_null: bool, + ) -> Result { + SortMergeJoinExec::try_new( + left, + right, + on, + Some(filter), + join_type, + sort_options, + null_equals_null, + ) + } + async fn join_collect( left: Arc, right: Arc, @@ -2116,6 +2596,25 @@ mod tests { join_collect_with_options(left, right, on, join_type, sort_options, false).await } + async fn join_collect_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + let sort_options = vec![SortOptions::default(); on.len()]; + + let task_ctx = Arc::new(TaskContext::default()); + let join = + join_with_filter(left, right, on, filter, join_type, sort_options, false)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + async fn join_collect_with_options( left: Arc, right: Arc, @@ -2175,7 +2674,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", @@ -2214,7 +2713,7 @@ mod tests { ), ]; - let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_columns, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2252,7 +2751,7 @@ mod tests { ), ]; - let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_columns, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2291,7 +2790,7 @@ mod tests { ), ]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2332,7 +2831,7 @@ mod tests { left, right, on, - JoinType::Inner, + Inner, vec![ SortOptions { descending: true, @@ -2382,7 +2881,7 @@ mod tests { ]; let (_, batches) = - join_collect_batch_size_equals_two(left, right, on, JoinType::Inner).await?; + join_collect_batch_size_equals_two(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2417,7 +2916,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let (_, batches) = join_collect(left, right, on, Left).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2449,7 +2948,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let (_, batches) = join_collect(left, right, on, Right).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2481,7 +2980,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Full).await?; + let (_, batches) = join_collect(left, right, on, Full).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2497,7 +2996,7 @@ mod tests { } #[tokio::test] - async fn join_anti() -> Result<()> { + async fn join_left_anti() -> Result<()> { let left = build_table( ("a1", &vec![1, 2, 2, 3, 5]), ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right @@ -2513,7 +3012,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::LeftAnti).await?; + let (_, batches) = join_collect(left, right, on, LeftAnti).await?; let expected = [ "+----+----+----+", "| a1 | b1 | c1 |", @@ -2527,6 +3026,310 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_right_anti_one_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = + build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6])); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + let expected = [ + "+----+----+", + "| a2 | b1 |", + "+----+----+", + "| 30 | 6 |", + "+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + + let left2 = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right2 = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left2.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right2.schema())?) as _, + )]; + + let (_, batches2) = join_collect(left2, right2, on, RightAnti).await?; + let expected2 = [ + "+----+----+----+", + "| a2 | b1 | c2 |", + "+----+----+----+", + "| 30 | 6 | 90 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected2, &batches2); + + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_two_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = + build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6])); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + let expected = [ + "+----+----+", + "| a2 | b1 |", + "+----+----+", + "| 10 | 4 |", + "| 20 | 5 |", + "| 30 | 6 |", + "+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + let expected = [ + "+----+----+----+", + "| a2 | b1 | c2 |", + "+----+----+----+", + "| 10 | 4 | 70 |", + "| 20 | 5 | 80 |", + "| 30 | 6 | 90 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_two_with_filter() -> Result<()> { + let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30])); + let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20])); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c2", 1)), + Operator::Gt, + Arc::new(Column::new("c1", 0)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ]), + ); + let (_, batches) = + join_collect_with_filter(left, right, on, filter, RightAnti).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 10 | 20 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_with_nulls() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]), + ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field + ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 2 | | 8 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_with_nulls_with_options() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(1), Some(0), Some(2)]), + ("b1", &vec![Some(4), Some(5), Some(5), None, Some(5)]), + ("c1", &vec![Some(7), Some(8), Some(8), Some(60), None]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]), + ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field + ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect_with_options( + left, + right, + on, + RightAnti, + vec![ + SortOptions { + descending: true, + nulls_first: false, + }; + 2 + ], + true, + ) + .await?; + + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 3 | | 9 |", + "| 2 | 5 | |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_output_two_batches() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = + join_collect_batch_size_equals_two(left, right, on, LeftAnti).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 2); + assert_eq!(batches[1].num_rows(), 1); + assert_batches_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn join_semi() -> Result<()> { let left = build_table( @@ -2544,7 +3347,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::LeftSemi).await?; + let (_, batches) = join_collect(left, right, on, LeftSemi).await?; let expected = [ "+----+----+----+", "| a1 | b1 | c1 |", @@ -2559,6 +3362,39 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_left_mark() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 4, 5, 6]), // 5 is double on the right + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, LeftMark).await?; + let expected = [ + "+----+----+----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+----+-------+", + "| 1 | 4 | 7 | true |", + "| 2 | 5 | 8 | true |", + "| 2 | 5 | 8 | true |", + "| 3 | 7 | 9 | false |", + "+----+----+----+-------+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn join_with_duplicated_column_names() -> Result<()> { let left = build_table( @@ -2577,7 +3413,7 @@ mod tests { Arc::new(Column::new_with_schema("b", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+---+---+---+----+---+----+", "| a | b | c | a | b | c |", @@ -2609,7 +3445,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = ["+------------+------------+------------+------------+------------+------------+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2641,7 +3477,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = ["+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2672,7 +3508,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let (_, batches) = join_collect(left, right, on, Left).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2708,7 +3544,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let (_, batches) = join_collect(left, right, on, Right).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2752,7 +3588,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let (_, batches) = join_collect(left, right, on, Left).await?; let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2801,7 +3637,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let (_, batches) = join_collect(left, right, on, Right).await?; let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2850,7 +3686,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Full).await?; + let (_, batches) = join_collect(left, right, on, Full).await?; let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2890,14 +3726,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() @@ -2975,14 +3804,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() @@ -3038,14 +3860,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() @@ -3146,14 +3961,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() @@ -3213,171 +4021,679 @@ mod tests { Ok(()) } + fn build_joined_record_batches() -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + ])); + + let mut batches = JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }; + + // Insert already prejoined non-filtered rows + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![10, 10])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 9])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![11])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 12])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 13])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), 
+ vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![13])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![14, 14])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 11])), + ], + )?); + + let streamed_indices = vec![0, 0]; + batches.batch_ids.extend(vec![0; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![1]; + batches.batch_ids.extend(vec![0; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + batches.batch_ids.extend(vec![1; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0]; + batches.batch_ids.extend(vec![2; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + batches.batch_ids.extend(vec![3; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + batches + .filter_mask + .extend(&BooleanArray::from(vec![true, false])); + batches.filter_mask.extend(&BooleanArray::from(vec![true])); + batches + .filter_mask + .extend(&BooleanArray::from(vec![false, true])); + batches.filter_mask.extend(&BooleanArray::from(vec![false])); + batches + .filter_mask + .extend(&BooleanArray::from(vec![false, false])); + + Ok(batches) + } + #[tokio::test] - async fn left_semi_join_filtered_mask() -> Result<()> { + async fn test_left_outer_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 1, 1]), - &BooleanArray::from(vec![true, true, false, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + true, false, false, false, false, false, false, false + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 1]), + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + false, false, false, false, false, false, false, false + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], &BooleanArray::from(vec![true, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, true]), vec![0, 1])) + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + true, true, false, false, false, false, false, false + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![false, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, true]), vec![1])) + get_corrected_filter_mask( + 
Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true, true, true, false, false, false, false, false]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false]), vec![0])) + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, true, true, true, true, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, true, false, true, false, false]), - vec![0, 1] - )) + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, false, false, false, false, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, false, false, false, true]), - vec![1] - )) + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + Some(true), + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![true, false, false, false, false, true]), - &HashSet::from_iter(vec![1]), - &0, - ), - Some(( - BooleanArray::from(vec![true, false, false, false, false, false]), - vec![0] - )) + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) ); + let corrected_mask = get_corrected_filter_mask( + Left, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + Some(false), + None, + Some(false) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); + + // output null rows + + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + Some(true), + None, + Some(true) + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + 
"+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 13 | 1 | 12 |", + "| 1 | 14 | 1 | 11 |", + "+---+----+---+----+", + ], + &[null_joined_batch] + ); Ok(()) } #[tokio::test] - async fn left_anti_join_filtered_mask() -> Result<()> { + async fn test_left_semi_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 1, 1]), - &BooleanArray::from(vec![true, true, false, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, false, false, true]), vec![0])) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 1]), + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], &BooleanArray::from(vec![true, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, false]), vec![0, 1])) + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![false, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false]), vec![1])) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, true]), vec![0])) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, true, true, true, true, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, false, false, false, false]), - vec![0, 1] - )) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true),]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, false, false, false, false, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, true, false, false, false]), - vec![1] - )) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + 
BooleanArray::from(vec![None, Some(true), None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + let corrected_mask = get_corrected_filter_mask( + LeftSemi, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + None, + None, + None + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] ); + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + None, + None, + None + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); + Ok(()) + } + + #[tokio::test] + async fn test_anti_join_filtered_mask() -> Result<()> { + for join_type in [LeftAnti, RightAnti] { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![Some(true)]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + 2 + ) + .unwrap(), + BooleanArray::from(vec![None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, 
Some(true)]) + ); + + let corrected_mask = get_corrected_filter_mask( + join_type, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(true), + None, + Some(true) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 13 | 1 | 12 |", + "| 1 | 14 | 1 | 11 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); + + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(false), + None, + Some(false), + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); + } Ok(()) } diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index ba9384aef1a65..cea04ccad3fce 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -19,6 +19,7 @@ //! related functionality, used both in join calculations and optimization rules. use std::collections::{HashMap, VecDeque}; +use std::mem::size_of; use std::sync::Arc; use crate::joins::utils::{JoinFilter, JoinHashMapType}; @@ -31,8 +32,7 @@ use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ - arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, - ScalarValue, + arrow_datafusion_err, DataFusionError, HashSet, JoinSide, Result, ScalarValue, }; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; @@ -40,8 +40,8 @@ use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use hashbrown::raw::RawTable; -use hashbrown::HashSet; /// Implementation of `JoinHashMapType` for `PruningJoinHashMap`. impl JoinHashMapType for PruningJoinHashMap { @@ -154,8 +154,7 @@ impl PruningJoinHashMap { /// # Returns /// The size of the hash map in bytes. 
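
The `size` accounting in the hunk above sums the map's reported allocation with the capacity of the `next` chain, charging `size_of` per slot. A minimal sketch of that capacity-based estimate, with a hypothetical struct standing in for `PruningJoinHashMap` (field names here are illustrative, not the real ones):

```rust
use std::mem::{size_of, size_of_val};

/// Hypothetical pruning map: a bucket table plus a `next`-pointer chain,
/// mirroring the "allocation + chain capacity" accounting in the hunk above.
struct ChainedMap {
    buckets: Vec<Option<u64>>,
    next: Vec<u64>,
}

impl ChainedMap {
    /// Estimate heap usage from *capacity*, not length, since reserved but
    /// unused slots still occupy memory.
    fn size(&self) -> usize {
        size_of_val(self)
            + self.buckets.capacity() * size_of::<Option<u64>>()
            + self.next.capacity() * size_of::<u64>()
    }
}

fn main() {
    let map = ChainedMap {
        buckets: Vec::with_capacity(16),
        next: Vec::with_capacity(16),
    };
    println!("estimated bytes: {}", map.size());
}
```
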
pub(crate) fn size(&self) -> usize { - self.map.allocation_info().1.size() - + self.next.capacity() * std::mem::size_of::() + self.map.allocation_info().1.size() + self.next.capacity() * size_of::() } /// Removes hash values from the map and the list based on the given pruning @@ -369,34 +368,40 @@ impl SortedFilterExpr { filter_expr: Arc, filter_schema: &Schema, ) -> Result { - let dt = &filter_expr.data_type(filter_schema)?; + let dt = filter_expr.data_type(filter_schema)?; Ok(Self { origin_sorted_expr, filter_expr, - interval: Interval::make_unbounded(dt)?, + interval: Interval::make_unbounded(&dt)?, node_index: 0, }) } + /// Get origin expr information pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { &self.origin_sorted_expr } + /// Get filter expr information pub fn filter_expr(&self) -> &Arc { &self.filter_expr } + /// Get interval information pub fn interval(&self) -> &Interval { &self.interval } + /// Sets interval pub fn set_interval(&mut self, interval: Interval) { self.interval = interval; } + /// Node index in ExprIntervalGraph pub fn node_index(&self) -> usize { self.node_index } + /// Node index setter in ExprIntervalGraph pub fn set_node_index(&mut self, node_index: usize) { self.node_index = node_index; @@ -409,41 +414,45 @@ impl SortedFilterExpr { /// on the first or the last value of the expression in `build_input_buffer` /// and `probe_batch`. /// -/// # Arguments +/// # Parameters /// /// * `build_input_buffer` - The [RecordBatch] on the build side of the join. /// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update. /// * `probe_batch` - The `RecordBatch` on the probe side of the join. /// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update. /// -/// ### Note -/// ```text +/// ## Note /// -/// Interval arithmetic is used to calculate viable join ranges for build-side -/// pruning. This is done by first creating an interval for join filter values in -/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the -/// ordering (descending/ascending) of the filter expression. Here, FV denotes the -/// first value on the build side. This range is then compared with the probe side -/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering -/// (ascending/descending) of the probe side. Here, LV denotes the last value on -/// the probe side. +/// Utilizing interval arithmetic, this function computes feasible join intervals +/// on the pruning side by evaluating the prospective value ranges that might +/// emerge in subsequent data batches from the enforcer side. This is done by +/// first creating an interval for join filter values in the pruning side of the +/// join, which spans `[-∞, FV]` or `[FV, ∞]` depending on the ordering (descending/ +/// ascending) of the filter expression. Here, `FV` denotes the first value on the +/// pruning side. This range is then compared with the enforcer side interval, +/// which either spans `[-∞, LV]` or `[LV, ∞]` depending on the ordering (ascending/ +/// descending) of the probe side. Here, `LV` denotes the last value on the enforcer +/// side. /// /// As a concrete example, consider the following query: /// +/// ```text /// SELECT * FROM left_table, right_table /// WHERE /// left_key = right_key AND /// a > b - 3 AND /// a < b + 10 +/// ``` /// -/// where columns "a" and "b" come from tables "left_table" and "right_table", +/// where columns `a` and `b` come from tables `left_table` and `right_table`, /// respectively. 
When a new `RecordBatch` arrives at the right side, the -/// condition a > b - 3 will possibly indicate a prunable range for the left +/// condition `a > b - 3` will possibly indicate a prunable range for the left /// side. Conversely, when a new `RecordBatch` arrives at the left side, the -/// condition a < b + 10 will possibly indicate prunability for the right side. -/// Let’s inspect what happens when a new RecordBatch` arrives at the right +/// condition `a < b + 10` will possibly indicate prunability for the right side. +/// Let’s inspect what happens when a new `RecordBatch` arrives at the right /// side (i.e. when the left side is the build side): /// +/// ```text /// Build Probe /// +-------+ +-------+ /// | a | z | | b | y | @@ -456,13 +465,13 @@ impl SortedFilterExpr { /// |+--|--+| |+--|--+| /// | 7 | 1 | | 6 | 3 | /// +-------+ +-------+ +/// ``` /// /// In this case, the interval representing viable (i.e. joinable) values for -/// column "a" is [1, ∞], and the interval representing possible future values -/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate +/// column `a` is `[1, ∞]`, and the interval representing possible future values +/// for column `b` is `[6, ∞]`. With these intervals at hand, we next calculate /// intervals for the whole filter expression and propagate join constraint by /// traversing the expression graph. -/// ``` pub fn calculate_filter_expr_intervals( build_input_buffer: &RecordBatch, build_sorted_filter_expr: &mut SortedFilterExpr, @@ -710,13 +719,21 @@ fn update_sorted_exprs_with_node_indices( } } -/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// Prepares and sorts expressions based on a given filter, left and right schemas, +/// and sort expressions. /// -/// # Arguments +/// This function prepares sorted filter expressions for both the left and right +/// sides of a join operation. It first builds the filter order for each side +/// based on the provided `ExecutionPlan`. If both sides have valid sorted filter +/// expressions, the function then constructs an expression interval graph and +/// updates the sorted expressions with node indices. The final sorted filter +/// expressions for both sides are then returned. +/// +/// # Parameters /// /// * `filter` - The join filter to base the sorting on. -/// * `left` - The left execution plan. -/// * `right` - The right execution plan. +/// * `left` - The `ExecutionPlan` for the left side of the join. +/// * `right` - The `ExecutionPlan` for the right side of the join. /// * `left_sort_exprs` - The expressions to sort on the left side. /// * `right_sort_exprs` - The expressions to sort on the right side. /// @@ -727,12 +744,14 @@ pub fn prepare_sorted_exprs( filter: &JoinFilter, left: &Arc, right: &Arc, - left_sort_exprs: &[PhysicalSortExpr], - right_sort_exprs: &[PhysicalSortExpr], + left_sort_exprs: &LexOrdering, + right_sort_exprs: &LexOrdering, ) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { - // Build the filter order for the left side - let err = || plan_datafusion_err!("Filter does not include the child order"); + let err = || { + datafusion_common::plan_datafusion_err!("Filter does not include the child order") + }; + // Build the filter order for the left side: let left_temp_sorted_filter_expr = build_filter_input_order( JoinSide::Left, filter, @@ -741,7 +760,7 @@ pub fn prepare_sorted_exprs( )? 
.ok_or_else(err)?; - // Build the filter order for the right side + // Build the filter order for the right side: let right_temp_sorted_filter_expr = build_filter_input_order( JoinSide::Right, filter, @@ -952,15 +971,15 @@ pub mod tests { let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { - index: 0, + index: left_schema.index_of("la1")?, side: JoinSide::Left, }, ColumnIndex { - index: 4, + index: left_schema.index_of("la2")?, side: JoinSide::Left, }, ColumnIndex { - index: 0, + index: right_schema.index_of("ra1")?, side: JoinSide::Right, }, ]; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index ac718a95e9f4c..72fd5a0feb1a3 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -27,12 +27,13 @@ use std::any::Any; use std::fmt::{self, Debug}; +use std::mem::{size_of, size_of_val}; use std::sync::Arc; use std::task::{Context, Poll}; use std::vec; use crate::common::SharedMemoryReservation; -use crate::handle_state; +use crate::execution_plan::{boundedness_from_children, emission_type_from_children}; use crate::joins::hash_join::{equal_rows_arr, update_hash}; use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, @@ -42,12 +43,11 @@ use crate::joins::stream_join_utils::{ }; use crate::joins::utils::{ apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, symmetric_join_output_partitioning, ColumnIndex, JoinFilter, - JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult, + check_join_is_valid, symmetric_join_output_partitioning, BatchSplitter, + BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, JoinOnRef, + NoopBatchTransformer, StatefulStreamResult, }; use crate::{ - execution_mode_from_children, - expressions::PhysicalSortExpr, joins::StreamJoinPartitionMode, metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, @@ -61,20 +61,20 @@ use arrow::array::{ use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_buffer::ArrowNativeType; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::bisect; -use datafusion_common::{internal_err, plan_err, JoinSide, JoinType, Result}; +use datafusion_common::{internal_err, plan_err, HashSet, JoinSide, JoinType, Result}; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; -use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; +use datafusion_physical_expr::PhysicalExprRef; use ahash::RandomState; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; use futures::{ready, Stream, StreamExt}; -use hashbrown::HashSet; use parking_lot::Mutex; const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4; @@ -164,7 +164,7 @@ const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4; /// making the smallest value in 'left_sorted' 1231 and any rows below (since ascending) /// than that can be dropped from the inner buffer. 
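
Both doc comments above describe the same pruning idea: combine the join filter with the sort order to work out which buffered build-side rows can never match any future probe-side row, and drop them. A minimal sketch of that arithmetic for the `a > b - 3` example from the doc comment (a hypothetical helper, not DataFusion's interval-graph machinery):

```rust
/// For the filter `a > b - 3` with an ascending build column `a` and a probe
/// side whose future values satisfy `b >= probe_min_b`, a build row can only
/// ever join if `a > probe_min_b - 3`. Leading rows failing that bound are prunable.
fn prunable_prefix_len(build_a_sorted_asc: &[i64], probe_min_b: i64) -> usize {
    let bound = probe_min_b - 3;
    // Count the leading rows with a <= bound; they can never match any future b.
    build_a_sorted_asc.partition_point(|a| *a <= bound)
}

fn main() {
    // Mirrors the doc example: build column `a` starts at 1, probe side has seen b = 6.
    let build_a = [1, 2, 3, 4, 5, 6, 7];
    assert_eq!(prunable_prefix_len(&build_a, 6), 3); // rows with a in {1, 2, 3}
    println!("prunable rows: {}", prunable_prefix_len(&build_a, 6));
}
```
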
/// ``` -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct SymmetricHashJoinExec { /// Left side stream pub(crate) left: Arc, @@ -185,9 +185,9 @@ pub struct SymmetricHashJoinExec { /// If null_equals_null is true, null == null else null != null pub(crate) null_equals_null: bool, /// Left side sort expression(s) - pub(crate) left_sort_exprs: Option>, + pub(crate) left_sort_exprs: Option, /// Right side sort expression(s) - pub(crate) right_sort_exprs: Option>, + pub(crate) right_sort_exprs: Option, /// Partition Mode mode: StreamJoinPartitionMode, /// Cache holding plan properties like equivalences, output partitioning etc. @@ -209,8 +209,8 @@ impl SymmetricHashJoinExec { filter: Option, join_type: &JoinType, null_equals_null: bool, - left_sort_exprs: Option>, - right_sort_exprs: Option>, + left_sort_exprs: Option, + right_sort_exprs: Option, mode: StreamJoinPartitionMode, ) -> Result { let left_schema = left.schema(); @@ -275,10 +275,12 @@ impl SymmetricHashJoinExec { let output_partitioning = symmetric_join_output_partitioning(left, right, &join_type); - // Determine execution mode: - let mode = execution_mode_from_children([left, right]); - - PlanProperties::new(eq_properties, output_partitioning, mode) + PlanProperties::new( + eq_properties, + output_partitioning, + emission_type_from_children([left, right]), + boundedness_from_children([left, right]), + ) } /// left stream @@ -317,13 +319,13 @@ impl SymmetricHashJoinExec { } /// Get left_sort_exprs - pub fn left_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> { - self.left_sort_exprs.as_deref() + pub fn left_sort_exprs(&self) -> Option<&LexOrdering> { + self.left_sort_exprs.as_ref() } /// Get right_sort_exprs - pub fn right_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> { - self.right_sort_exprs.as_deref() + pub fn right_sort_exprs(&self) -> Option<&LexOrdering> { + self.right_sort_exprs.as_ref() } /// Check if order information covers every column in the filter expression. @@ -415,10 +417,12 @@ impl ExecutionPlan for SymmetricHashJoinExec { vec![ self.left_sort_exprs .as_ref() - .map(PhysicalSortRequirement::from_sort_exprs), + .cloned() + .map(LexRequirement::from), self.right_sort_exprs .as_ref() - .map(PhysicalSortRequirement::from_sort_exprs), + .cloned() + .map(LexRequirement::from), ] } @@ -465,23 +469,27 @@ impl ExecutionPlan for SymmetricHashJoinExec { consider using RepartitionExec" ); } - // If `filter_state` and `filter` are both present, then calculate sorted filter expressions - // for both sides, and build an expression graph. - let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = - match (&self.left_sort_exprs, &self.right_sort_exprs, &self.filter) { - (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { - let (left, right, graph) = prepare_sorted_exprs( - filter, - &self.left, - &self.right, - left_sort_exprs, - right_sort_exprs, - )?; - (Some(left), Some(right), Some(graph)) - } - // If `filter_state` or `filter` is not present, then return None for all three values: - _ => (None, None, None), - }; + // If `filter_state` and `filter` are both present, then calculate sorted + // filter expressions for both sides, and build an expression graph. 
+ let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match ( + self.left_sort_exprs(), + self.right_sort_exprs(), + &self.filter, + ) { + (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { + let (left, right, graph) = prepare_sorted_exprs( + filter, + &self.left, + &self.right, + left_sort_exprs, + right_sort_exprs, + )?; + (Some(left), Some(right), Some(graph)) + } + // If `filter_state` or `filter` is not present, then return None + // for all three values: + _ => (None, None, None), + }; let (on_left, on_right) = self.on.iter().cloned().unzip(); @@ -494,6 +502,10 @@ impl ExecutionPlan for SymmetricHashJoinExec { let right_stream = self.right.execute(partition, Arc::clone(&context))?; + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let reservation = Arc::new(Mutex::new( MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) .register(context.memory_pool()), @@ -502,29 +514,52 @@ impl ExecutionPlan for SymmetricHashJoinExec { reservation.lock().try_grow(g.size())?; } - Ok(Box::pin(SymmetricHashJoinStream { - left_stream, - right_stream, - schema: self.schema(), - filter: self.filter.clone(), - join_type: self.join_type, - random_state: self.random_state.clone(), - left: left_side_joiner, - right: right_side_joiner, - column_indices: self.column_indices.clone(), - metrics: StreamJoinMetrics::new(partition, &self.metrics), - graph, - left_sorted_filter_expr, - right_sorted_filter_expr, - null_equals_null: self.null_equals_null, - state: SHJStreamState::PullRight, - reservation, - })) + if enforce_batch_size_in_joins { + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: SHJStreamState::PullRight, + reservation, + batch_transformer: BatchSplitter::new(batch_size), + })) + } else { + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: SHJStreamState::PullRight, + reservation, + batch_transformer: NoopBatchTransformer::new(), + })) + } } } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct SymmetricHashJoinStream { +struct SymmetricHashJoinStream { /// Input streams left_stream: SendableRecordBatchStream, right_stream: SendableRecordBatchStream, @@ -556,20 +591,24 @@ struct SymmetricHashJoinStream { reservation: SharedMemoryReservation, /// State machine for input execution state: SHJStreamState, + /// Transforms the output batch before returning. 
+ batch_transformer: T, } -impl RecordBatchStream for SymmetricHashJoinStream { +impl RecordBatchStream + for SymmetricHashJoinStream +{ fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } } -impl Stream for SymmetricHashJoinStream { +impl Stream for SymmetricHashJoinStream { type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + cx: &mut Context<'_>, ) -> Poll> { self.poll_next_impl(cx) } @@ -634,7 +673,11 @@ fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> if build_side == JoinSide::Left { matches!( join_type, - JoinType::Left | JoinType::LeftAnti | JoinType::Full | JoinType::LeftSemi + JoinType::Left + | JoinType::LeftAnti + | JoinType::Full + | JoinType::LeftSemi + | JoinType::LeftMark ) } else { matches!( @@ -673,6 +716,20 @@ where { // Store the result in a tuple let result = match (build_side, join_type) { + (JoinSide::Left, JoinType::LeftMark) => { + let build_indices = (0..prune_length) + .map(L::Native::from_usize) + .collect::>(); + let probe_indices = (0..prune_length) + .map(|idx| { + // For mark join we output a dummy index 0 to indicate the row had a match + visited_rows + .contains(&(idx + deleted_offset)) + .then_some(R::Native::from_usize(0).unwrap()) + }) + .collect(); + (build_indices, probe_indices) + } // In the case of `Left` or `Right` join, or `Full` join, get the anti indices (JoinSide::Left, JoinType::Left | JoinType::LeftAnti) | (JoinSide::Right, JoinType::Right | JoinType::RightAnti) @@ -836,6 +893,7 @@ pub(crate) fn join_with_probe_batch( JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftSemi + | JoinType::LeftMark | JoinType::RightSemi ) { Ok(None) @@ -969,15 +1027,15 @@ pub struct OneSideHashJoiner { impl OneSideHashJoiner { pub fn size(&self) -> usize { let mut size = 0; - size += std::mem::size_of_val(self); - size += std::mem::size_of_val(&self.build_side); + size += size_of_val(self); + size += size_of_val(&self.build_side); size += self.input_buffer.get_array_memory_size(); - size += std::mem::size_of_val(&self.on); + size += size_of_val(&self.on); size += self.hashmap.size(); - size += self.hashes_buffer.capacity() * std::mem::size_of::(); - size += self.visited_rows.capacity() * std::mem::size_of::(); - size += std::mem::size_of_val(&self.offset); - size += std::mem::size_of_val(&self.deleted_offset); + size += self.hashes_buffer.capacity() * size_of::(); + size += self.visited_rows.capacity() * size_of::(); + size += size_of_val(&self.offset); + size += size_of_val(&self.deleted_offset); size } pub fn new( @@ -1140,7 +1198,7 @@ impl OneSideHashJoiner { /// - Transition to `BothExhausted { final_result: true }`: /// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are /// exhausted, indicating completion of processing and availability of final results. -impl SymmetricHashJoinStream { +impl SymmetricHashJoinStream { /// Implements the main polling logic for the join stream. 
/// /// This method continuously checks the state of the join stream and @@ -1159,26 +1217,45 @@ impl SymmetricHashJoinStream { cx: &mut Context<'_>, ) -> Poll>> { loop { - return match self.state() { - SHJStreamState::PullRight => { - handle_state!(ready!(self.fetch_next_from_right_stream(cx))) - } - SHJStreamState::PullLeft => { - handle_state!(ready!(self.fetch_next_from_left_stream(cx))) + match self.batch_transformer.next() { + None => { + let result = match self.state() { + SHJStreamState::PullRight => { + ready!(self.fetch_next_from_right_stream(cx)) + } + SHJStreamState::PullLeft => { + ready!(self.fetch_next_from_left_stream(cx)) + } + SHJStreamState::RightExhausted => { + ready!(self.handle_right_stream_end(cx)) + } + SHJStreamState::LeftExhausted => { + ready!(self.handle_left_stream_end(cx)) + } + SHJStreamState::BothExhausted { + final_result: false, + } => self.prepare_for_final_results_after_exhaustion(), + SHJStreamState::BothExhausted { final_result: true } => { + return Poll::Ready(None); + } + }; + + match result? { + StatefulStreamResult::Ready(None) => { + return Poll::Ready(None); + } + StatefulStreamResult::Ready(Some(batch)) => { + self.batch_transformer.set_batch(batch); + } + _ => {} + } } - SHJStreamState::RightExhausted => { - handle_state!(ready!(self.handle_right_stream_end(cx))) + Some((batch, _)) => { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Poll::Ready(Some(Ok(batch))); } - SHJStreamState::LeftExhausted => { - handle_state!(ready!(self.handle_left_stream_end(cx))) - } - SHJStreamState::BothExhausted { - final_result: false, - } => { - handle_state!(self.prepare_for_final_results_after_exhaustion()) - } - SHJStreamState::BothExhausted { final_result: true } => Poll::Ready(None), - }; + } } } /// Asynchronously pulls the next batch from the right stream. 
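
The reworked polling loop above routes every state-machine result through a batch transformer: with `enforce_batch_size_in_joins` enabled a `BatchSplitter` is plugged in, which presumably re-slices large join outputs down to the configured batch size before they are emitted, while `NoopBatchTransformer` passes batches through unchanged. A minimal sketch of the slicing idea, using only the public arrow `RecordBatch` API (the splitter type below is illustrative, not the actual `BatchTransformer` implementation):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::record_batch::RecordBatch;

/// Illustrative splitter: holds one pending batch and hands it back in
/// slices of at most `batch_size` rows.
struct SimpleBatchSplitter {
    batch_size: usize,
    pending: Option<(RecordBatch, usize)>, // (batch, next offset)
}

impl SimpleBatchSplitter {
    fn set_batch(&mut self, batch: RecordBatch) {
        self.pending = Some((batch, 0));
    }

    /// Returns the next slice, or `None` once the pending batch is exhausted.
    fn next(&mut self) -> Option<RecordBatch> {
        let (batch, offset) = self.pending.take()?;
        if offset >= batch.num_rows() {
            return None;
        }
        let len = self.batch_size.min(batch.num_rows() - offset);
        let slice = batch.slice(offset, len);
        // Put the batch back with the advanced offset so later calls continue it.
        self.pending = Some((batch, offset + len));
        Some(slice)
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let batch = RecordBatch::try_from_iter([(
        "a",
        Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef,
    )])?;
    let mut splitter = SimpleBatchSplitter { batch_size: 2, pending: None };
    splitter.set_batch(batch);
    while let Some(slice) = splitter.next() {
        println!("slice rows: {}", slice.num_rows()); // prints 2, 2, 1
    }
    Ok(())
}
```

Draining the splitter before polling the inputs again is what keeps output batches at or below the configured size while buffering only one join result at a time, which matches the `match self.batch_transformer.next()` structure in the hunk above.
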
@@ -1384,11 +1461,8 @@ impl SymmetricHashJoinStream { // Combine the left and right results: let result = combine_two_batches(&self.schema, left_result, right_result)?; - // Update the metrics and return the result: - if let Some(batch) = &result { - // Update the metrics: - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); + // Return the result: + if result.is_some() { return Ok(StatefulStreamResult::Ready(result)); } Ok(StatefulStreamResult::Continue) @@ -1412,18 +1486,18 @@ impl SymmetricHashJoinStream { fn size(&self) -> usize { let mut size = 0; - size += std::mem::size_of_val(&self.schema); - size += std::mem::size_of_val(&self.filter); - size += std::mem::size_of_val(&self.join_type); + size += size_of_val(&self.schema); + size += size_of_val(&self.filter); + size += size_of_val(&self.join_type); size += self.left.size(); size += self.right.size(); - size += std::mem::size_of_val(&self.column_indices); + size += size_of_val(&self.column_indices); size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0); - size += std::mem::size_of_val(&self.left_sorted_filter_expr); - size += std::mem::size_of_val(&self.right_sorted_filter_expr); - size += std::mem::size_of_val(&self.random_state); - size += std::mem::size_of_val(&self.null_equals_null); - size += std::mem::size_of_val(&self.metrics); + size += size_of_val(&self.left_sorted_filter_expr); + size += size_of_val(&self.right_sorted_filter_expr); + size += size_of_val(&self.random_state); + size += size_of_val(&self.null_equals_null); + size += size_of_val(&self.metrics); size } @@ -1523,11 +1597,6 @@ impl SymmetricHashJoinStream { let capacity = self.size(); self.metrics.stream_memory_usage.set(capacity); self.reservation.lock().try_resize(capacity)?; - // Update the metrics if we have a batch; otherwise, continue the loop. 
- if let Some(batch) = &result { - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - } Ok(result) } } @@ -1579,6 +1648,7 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{binary, col, lit, Column}; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use once_cell::sync::Lazy; use rstest::*; @@ -1660,6 +1730,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1678,7 +1749,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: binary( col("la1", left_schema)?, Operator::Plus, @@ -1686,11 +1757,11 @@ mod tests { left_schema, )?, options: SortOptions::default(), - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1716,15 +1787,15 @@ mod tests { let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { - index: 0, + index: left_schema.index_of("la1")?, side: JoinSide::Left, }, ColumnIndex { - index: 4, + index: left_schema.index_of("la2")?, side: JoinSide::Left, }, ColumnIndex { - index: 0, + index: right_schema.index_of("ra1")?, side: JoinSide::Right, }, ]; @@ -1744,6 +1815,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1756,14 +1828,14 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1771,10 +1843,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1811,6 +1880,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1825,10 +1895,7 @@ mod tests { let (left, right) = create_memory_table(left_partition, right_partition, vec![], vec![])?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) 
as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1865,6 +1932,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1877,10 +1945,7 @@ mod tests { let (left, right) = create_memory_table(left_partition, right_partition, vec![], vec![])?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; experiment(left, right, None, join_type, on, task_ctx).await?; Ok(()) } @@ -1895,6 +1960,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1905,20 +1971,20 @@ mod tests { let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?; let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("la1_des", left_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("ra1_des", right_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1926,10 +1992,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1966,20 +2029,20 @@ mod tests { let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?; let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("l_asc_null_first", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("r_asc_null_first", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1987,10 +2050,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) 
as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2027,20 +2087,20 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("l_asc_null_last", left_schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("r_asc_null_last", right_schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2048,10 +2108,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2090,20 +2147,20 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("l_desc_null_first", left_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("r_desc_null_first", right_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2111,10 +2168,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2154,15 +2208,15 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]; + }]); - let right_sorted = vec![PhysicalSortExpr { + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2170,10 +2224,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) 
as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), @@ -2215,20 +2266,20 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let left_sorted = vec![ - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }], - vec![PhysicalSortExpr { + }]), + LexOrdering::new(vec![PhysicalSortExpr { expr: col("la2", left_schema)?, options: SortOptions::default(), - }], + }]), ]; - let right_sorted = vec![PhysicalSortExpr { + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]; + }]); let (left, right) = create_memory_table( left_partition, @@ -2237,10 +2288,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), @@ -2278,6 +2326,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -2296,24 +2345,21 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; - let left_sorted = vec![PhysicalSortExpr { + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("lt1", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("rt1", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2363,6 +2409,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -2380,24 +2427,21 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) 
as _, - )]; - let left_sorted = vec![PhysicalSortExpr { + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("li1", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("ri1", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2440,6 +2484,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -2458,14 +2503,14 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("l_float", left_schema)?, options: SortOptions::default(), - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("r_float", right_schema)?, options: SortOptions::default(), - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2473,10 +2518,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Float64, true), diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 264f297ffb4c4..37d6c0aff8503 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -47,21 +47,23 @@ use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; pub fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) { + let left_row_num: usize = collected_1.iter().map(|batch| batch.num_rows()).sum(); + let right_row_num: usize = collected_2.iter().map(|batch| batch.num_rows()).sum(); + if left_row_num == 0 && right_row_num == 0 { + return; + } // compare let first_formatted = pretty_format_batches(collected_1).unwrap().to_string(); let second_formatted = pretty_format_batches(collected_2).unwrap().to_string(); - let mut first_formatted_sorted: Vec<&str> = first_formatted.trim().lines().collect(); - first_formatted_sorted.sort_unstable(); + let mut first_lines: Vec<&str> = first_formatted.trim().lines().collect(); + first_lines.sort_unstable(); - let mut second_formatted_sorted: Vec<&str> = - second_formatted.trim().lines().collect(); - second_formatted_sorted.sort_unstable(); + let mut second_lines: Vec<&str> = second_formatted.trim().lines().collect(); + second_lines.sort_unstable(); - for (i, (first_line, second_line)) in first_formatted_sorted - .iter() - .zip(&second_formatted_sorted) - .enumerate() + for (i, (first_line, second_line)) in + first_lines.iter().zip(&second_lines).enumerate() { assert_eq!((i, first_line), (i, second_line)); } @@ -101,8 +103,10 @@ pub async fn partitioned_sym_join_with_filter( filter, join_type, null_equals_null, - left.output_ordering().map(|p| p.to_vec()), - right.output_ordering().map(|p| p.to_vec()), + left.output_ordering().map(|p| LexOrdering::new(p.to_vec())), + right + 
.output_ordering() + .map(|p| LexOrdering::new(p.to_vec())), StreamJoinPartitionMode::Partitioned, )?; @@ -289,7 +293,7 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(10 as $type)), (Operator::Gt, Operator::Lt), ), - // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10 + // left_col - 1 > right_col + 3 AND left_col + 3 < right_col + 15 1 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -300,9 +304,9 @@ macro_rules! join_expr_tests { Operator::Plus, ), ScalarValue::$SCALAR(Some(1 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), ScalarValue::$SCALAR(Some(3 as $type)), - ScalarValue::$SCALAR(Some(10 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(15 as $type)), (Operator::Gt, Operator::Lt), ), // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10 @@ -353,7 +357,8 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(3 as $type)), (Operator::Gt, Operator::Lt), ), - // left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3 + // left_col - 2 >= right_col + 5 AND left_col + 7 <= right_col - 3 + // (filters all input rows) 5 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -369,7 +374,7 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(3 as $type)), (Operator::GtEq, Operator::LtEq), ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 + // left_col + 28 >= right_col - 11 AND left_col + 21 <= right_col + 39 6 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -385,7 +390,7 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(39 as $type)), (Operator::Gt, Operator::LtEq), ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col + 39 + // left_col + 28 >= right_col - 11 AND left_col - 21 <= right_col + 39 7 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -526,10 +531,10 @@ pub fn create_memory_table( ) -> Result<(Arc, Arc)> { let left_schema = left_partition[0].schema(); let left = MemoryExec::try_new(&[left_partition], left_schema, None)? - .with_sort_information(left_sorted); + .try_with_sort_information(left_sorted)?; let right_schema = right_partition[0].schema(); let right = MemoryExec::try_new(&[right_partition], right_schema, None)? 
- .with_sort_information(right_sorted); + .try_with_sort_information(right_sorted)?; Ok((Arc::new(left), Arc::new(right))) } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 89f3feaf07be6..dea4305fa6a11 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -20,6 +20,7 @@ use std::collections::HashSet; use std::fmt::{self, Debug}; use std::future::Future; +use std::iter::once; use std::ops::{IndexMut, Range}; use std::sync::Arc; use std::task::{Context, Poll}; @@ -28,6 +29,8 @@ use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; use crate::{ ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics, }; +// compatibility +pub use super::join_filter::JoinFilter; use arrow::array::{ downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array, @@ -50,9 +53,11 @@ use datafusion_physical_expr::equivalence::add_offset_to_expr; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::{collect_columns, merge_vectors}; use datafusion_physical_expr::{ - LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, + LexOrdering, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, }; +use crate::joins::SharedBitmapBuilder; +use crate::projection::ProjectionExec; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use hashbrown::raw::RawTable; @@ -369,7 +374,7 @@ impl JoinHashMapType for JoinHashMap { } } -impl fmt::Debug for JoinHashMap { +impl Debug for JoinHashMap { fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { Ok(()) } @@ -448,10 +453,10 @@ pub fn adjust_right_output_partitioning( /// the left column (zeroth index in the tuple) inside `right_ordering`. fn replace_on_columns_of_right_ordering( on_columns: &[(PhysicalExprRef, PhysicalExprRef)], - right_ordering: &mut [PhysicalSortExpr], + right_ordering: &mut LexOrdering, ) -> Result<()> { for (left_col, right_col) in on_columns { - for item in right_ordering.iter_mut() { + right_ordering.transform(|item| { let new_expr = Arc::clone(&item.expr) .transform(|e| { if e.eq(right_col) { @@ -460,18 +465,19 @@ fn replace_on_columns_of_right_ordering( Ok(Transformed::no(e)) } }) - .data()?; + .data() + .expect("closure is infallible"); item.expr = new_expr; - } + }); } Ok(()) } fn offset_ordering( - ordering: LexOrderingRef, + ordering: &LexOrdering, join_type: &JoinType, offset: usize, -) -> Vec { +) -> LexOrdering { match join_type { // In the case below, right ordering should be offsetted with the left // side length, since we append the right table to the left table. @@ -482,14 +488,14 @@ fn offset_ordering( options: sort_expr.options, }) .collect(), - _ => ordering.to_vec(), + _ => ordering.clone(), } } /// Calculate the output ordering of a given join operation. 
pub fn calculate_join_output_ordering( - left_ordering: LexOrderingRef, - right_ordering: LexOrderingRef, + left_ordering: &LexOrdering, + right_ordering: &LexOrdering, join_type: JoinType, on_columns: &[(PhysicalExprRef, PhysicalExprRef)], left_columns_len: usize, @@ -502,15 +508,16 @@ pub fn calculate_join_output_ordering( if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) { replace_on_columns_of_right_ordering( on_columns, - &mut right_ordering.to_vec(), + &mut right_ordering.clone(), ) .ok()?; merge_vectors( left_ordering, - &offset_ordering(right_ordering, &join_type, left_columns_len), + offset_ordering(right_ordering, &join_type, left_columns_len) + .as_ref(), ) } else { - left_ordering.to_vec() + left_ordering.clone() } } [false, true] => { @@ -518,11 +525,12 @@ pub fn calculate_join_output_ordering( if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) { replace_on_columns_of_right_ordering( on_columns, - &mut right_ordering.to_vec(), + &mut right_ordering.clone(), ) .ok()?; merge_vectors( - &offset_ordering(right_ordering, &join_type, left_columns_len), + offset_ordering(right_ordering, &join_type, left_columns_len) + .as_ref(), left_ordering, ) } else { @@ -546,65 +554,6 @@ pub struct ColumnIndex { pub side: JoinSide, } -/// Filter applied before join output -#[derive(Debug, Clone)] -pub struct JoinFilter { - /// Filter expression - expression: Arc, - /// Column indices required to construct intermediate batch for filtering - column_indices: Vec, - /// Physical schema of intermediate batch - schema: Schema, -} - -impl JoinFilter { - /// Creates new JoinFilter - pub fn new( - expression: Arc, - column_indices: Vec, - schema: Schema, - ) -> JoinFilter { - JoinFilter { - expression, - column_indices, - schema, - } - } - - /// Helper for building ColumnIndex vector from left and right indices - pub fn build_column_indices( - left_indices: Vec, - right_indices: Vec, - ) -> Vec { - left_indices - .into_iter() - .map(|i| ColumnIndex { - index: i, - side: JoinSide::Left, - }) - .chain(right_indices.into_iter().map(|i| ColumnIndex { - index: i, - side: JoinSide::Right, - })) - .collect() - } - - /// Filter expression - pub fn expression(&self) -> &Arc { - &self.expression - } - - /// Column indices for intermediate batch creation - pub fn column_indices(&self) -> &[ColumnIndex] { - &self.column_indices - } - - /// Intermediate batch schema - pub fn schema(&self) -> &Schema { - &self.schema - } -} - /// Returns the output field given the input field. Outer joins may /// insert nulls even if the input was not null /// @@ -618,6 +567,7 @@ fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> JoinType::RightSemi => false, // doesn't introduce nulls JoinType::LeftAnti => false, // doesn't introduce nulls (or can it??) JoinType::RightAnti => false, // doesn't introduce nulls (or can it??) 
+ JoinType::LeftMark => false, }; if force_nullable { @@ -634,44 +584,10 @@ pub fn build_join_schema( right: &Schema, join_type: &JoinType, ) -> (Schema, Vec) { - let (fields, column_indices): (SchemaBuilder, Vec) = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - let left_fields = left - .fields() - .iter() - .map(|f| output_join_field(f, join_type, true)) - .enumerate() - .map(|(index, f)| { - ( - f, - ColumnIndex { - index, - side: JoinSide::Left, - }, - ) - }); - let right_fields = right - .fields() - .iter() - .map(|f| output_join_field(f, join_type, false)) - .enumerate() - .map(|(index, f)| { - ( - f, - ColumnIndex { - index, - side: JoinSide::Right, - }, - ) - }); - - // left then right - left_fields.chain(right_fields).unzip() - } - JoinType::LeftSemi | JoinType::LeftAnti => left - .fields() + let left_fields = || { + left.fields() .iter() - .cloned() + .map(|f| output_join_field(f, join_type, true)) .enumerate() .map(|(index, f)| { ( @@ -682,11 +598,13 @@ pub fn build_join_schema( }, ) }) - .unzip(), - JoinType::RightSemi | JoinType::RightAnti => right + }; + + let right_fields = || { + right .fields() .iter() - .cloned() + .map(|f| output_join_field(f, join_type, false)) .enumerate() .map(|(index, f)| { ( @@ -697,17 +615,49 @@ pub fn build_join_schema( }, ) }) - .unzip(), }; - (fields.finish(), column_indices) + let (fields, column_indices): (SchemaBuilder, Vec) = match join_type { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + // left then right + left_fields().chain(right_fields()).unzip() + } + JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(), + JoinType::LeftMark => { + let right_field = once(( + Field::new("mark", arrow_schema::DataType::Boolean, false), + ColumnIndex { + index: 0, + side: JoinSide::None, + }, + )); + left_fields().chain(right_field).unzip() + } + JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(), + }; + + let metadata = left + .metadata() + .clone() + .into_iter() + .chain(right.metadata().clone()) + .collect(); + (fields.finish().with_metadata(metadata), column_indices) } -/// A [`OnceAsync`] can be used to run an async closure once, with subsequent calls -/// to [`OnceAsync::once`] returning a [`OnceFut`] to the same asynchronous computation +/// A [`OnceAsync`] runs an `async` closure once, where multiple calls to +/// [`OnceAsync::once`] return a [`OnceFut`] that resolves to the result of the +/// same computation. +/// +/// This is useful for joins where the results of one child are needed to proceed +/// with multiple output stream +/// /// -/// This is useful for joins where the results of one child are buffered in memory -/// and shared across potentially multiple output partitions +/// For example, in a hash join, one input is buffered and shared across +/// potentially multiple output partitions. Each output partition must wait for +/// the hash table to be built before proceeding. +/// +/// Each output partition waits on the same `OnceAsync` before proceeding. 
pub(crate) struct OnceAsync { fut: Mutex>>, } @@ -720,8 +670,8 @@ impl Default for OnceAsync { } } -impl std::fmt::Debug for OnceAsync { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Debug for OnceAsync { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "OnceAsync") } } @@ -895,6 +845,16 @@ fn estimate_join_cardinality( column_statistics: outer_stats.column_statistics, }) } + + JoinType::LeftMark => { + let num_rows = *left_stats.num_rows.get_value()?; + let mut column_statistics = left_stats.column_statistics; + column_statistics.push(ColumnStatistics::new_unknown()); + Some(PartialJoinStatistics { + num_rows, + column_statistics, + }) + } } } @@ -1146,10 +1106,22 @@ impl OnceFut { pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool { matches!( join_type, - JoinType::Left | JoinType::LeftAnti | JoinType::LeftSemi | JoinType::Full + JoinType::Left + | JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::Full ) } +pub(crate) fn get_final_indices_from_shared_bitmap( + shared_bitmap: &SharedBitmapBuilder, + join_type: JoinType, +) -> (UInt64Array, UInt32Array) { + let bitmap = shared_bitmap.lock(); + get_final_indices_from_bit_map(&bitmap, join_type) +} + /// In the end of join execution, need to use bit map of the matched /// indices to generate the final left and right indices. /// @@ -1164,6 +1136,13 @@ pub(crate) fn get_final_indices_from_bit_map( join_type: JoinType, ) -> (UInt64Array, UInt32Array) { let left_size = left_bit_map.len(); + if join_type == JoinType::LeftMark { + let left_indices = (0..left_size as u64).collect::(); + let right_indices = (0..left_size) + .map(|idx| left_bit_map.get_bit(idx).then_some(0)) + .collect::(); + return (left_indices, right_indices); + } let left_indices = if join_type == JoinType::LeftSemi { (0..left_size) .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64)) @@ -1247,7 +1226,10 @@ pub(crate) fn build_batch_from_indices( let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); for column_index in column_indices { - let array = if column_index.side == build_side { + let array = if column_index.side == JoinSide::None { + // LeftMark join, the mark column is a true if the indices is not null, otherwise it will be false + Arc::new(compute::is_not_null(probe_indices)?) + } else if column_index.side == build_side { let array = build_input_buffer.column(column_index.index); if array.is_empty() || build_indices.null_count() == build_indices.len() { // Outer join would generate a null index when finding no match at our side. 
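// Illustrative sketch (not part of this patch): how the LeftMark branch added to
// `get_final_indices_from_bit_map` above derives the final indices, and how
// `build_batch_from_indices` then obtains the boolean "mark" column via
// `arrow::compute::is_not_null` over the probe indices. The helper name
// `left_mark_final_indices` is invented for this example.
use arrow::array::{BooleanBufferBuilder, UInt32Array, UInt64Array};

fn left_mark_final_indices(left_bit_map: &BooleanBufferBuilder) -> (UInt64Array, UInt32Array) {
    let left_size = left_bit_map.len();
    // One output row per build-side (left) row, in order.
    let left_indices = UInt64Array::from_iter_values(0..left_size as u64);
    // The probe index is Some(0) when the bitmap recorded a match, None otherwise,
    // so `is_not_null` over this array is exactly the mark column.
    let right_indices: UInt32Array = (0..left_size)
        .map(|idx| left_bit_map.get_bit(idx).then_some(0u32))
        .collect();
    (left_indices, right_indices)
}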
@@ -1280,15 +1262,15 @@ pub(crate) fn adjust_indices_by_join_type( adjust_range: Range, join_type: JoinType, preserve_order_for_right: bool, -) -> (UInt64Array, UInt32Array) { +) -> Result<(UInt64Array, UInt32Array)> { match join_type { JoinType::Inner => { // matched - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::Left => { // matched - (left_indices, right_indices) + Ok((left_indices, right_indices)) // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap } JoinType::Right => { @@ -1307,22 +1289,22 @@ pub(crate) fn adjust_indices_by_join_type( // need to remove the duplicated record in the right side let right_indices = get_semi_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right semi` join - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::RightAnti => { // need to remove the duplicated record in the right side // get the anti index for the right side let right_indices = get_anti_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right anti` join - (left_indices, right_indices) + Ok((left_indices, right_indices)) } - JoinType::LeftSemi | JoinType::LeftAnti => { + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { // matched or unmatched left row will be produced in the end of loop // When visit the right batch, we can output the matched left row and don't need to wait the end of loop - ( + Ok(( UInt64Array::from_iter_values(vec![]), UInt32Array::from_iter_values(vec![]), - ) + )) } } } @@ -1347,27 +1329,64 @@ pub(crate) fn append_right_indices( right_indices: UInt32Array, adjust_range: Range, preserve_order_for_right: bool, -) -> (UInt64Array, UInt32Array) { +) -> Result<(UInt64Array, UInt32Array)> { if preserve_order_for_right { - append_probe_indices_in_order(left_indices, right_indices, adjust_range) + Ok(append_probe_indices_in_order( + left_indices, + right_indices, + adjust_range, + )) } else { let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); if right_unmatched_indices.is_empty() { - (left_indices, right_indices) + Ok((left_indices, right_indices)) } else { - let unmatched_size = right_unmatched_indices.len(); + // `into_builder()` can fail here when there is nothing to be filtered and + // left_indices or right_indices has the same reference to the cached indices. + // In that case, we use a slower alternative. 
+ // the new left indices: left_indices + null array + let mut new_left_indices_builder = + left_indices.into_builder().unwrap_or_else(|left_indices| { + let mut builder = UInt64Builder::with_capacity( + left_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + left_indices.null_count(), + 0, + "expected left indices to have no nulls" + ); + builder.append_slice(left_indices.values()); + builder + }); + new_left_indices_builder.append_nulls(right_unmatched_indices.len()); + let new_left_indices = UInt64Array::from(new_left_indices_builder.finish()); + // the new right indices: right_indices + right_unmatched_indices - let new_left_indices = left_indices - .iter() - .chain(std::iter::repeat(None).take(unmatched_size)) - .collect(); - let new_right_indices = right_indices - .iter() - .chain(right_unmatched_indices.iter()) - .collect(); - (new_left_indices, new_right_indices) + let mut new_right_indices_builder = right_indices + .into_builder() + .unwrap_or_else(|right_indices| { + let mut builder = UInt32Builder::with_capacity( + right_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + right_indices.null_count(), + 0, + "expected right indices to have no nulls" + ); + builder.append_slice(right_indices.values()); + builder + }); + debug_assert_eq!( + right_unmatched_indices.null_count(), + 0, + "expected right unmatched indices to have no nulls" + ); + new_right_indices_builder.append_slice(right_unmatched_indices.values()); + let new_right_indices = UInt32Array::from(new_right_indices_builder.finish()); + + Ok((new_left_indices, new_right_indices)) } } } @@ -1579,7 +1598,7 @@ macro_rules! handle_state { /// Represents the result of a stateful operation. /// -/// This enumueration indicates whether the state produced a result that is +/// This enumeration indicates whether the state produced a result that is /// ready for use (`Ready`) or if the operation requires continuation (`Continue`). /// /// Variants: @@ -1602,7 +1621,7 @@ pub(crate) fn symmetric_join_output_partitioning( let left_partitioning = left.output_partitioning(); let right_partitioning = right.output_partitioning(); match join_type { - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { left_partitioning.clone() } JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(), @@ -1627,14 +1646,178 @@ pub(crate) fn asymmetric_join_output_partitioning( left.schema().fields().len(), ), JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(), - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Full => { - Partitioning::UnknownPartitioning( - right.output_partitioning().partition_count(), - ) + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::Full + | JoinType::LeftMark => Partitioning::UnknownPartitioning( + right.output_partitioning().partition_count(), + ), + } +} + +/// Trait for incrementally generating Join output. +/// +/// This trait is used to limit some join outputs +/// so it does not produce single large batches +pub(crate) trait BatchTransformer: Debug + Clone { + /// Sets the next `RecordBatch` to be processed. + fn set_batch(&mut self, batch: RecordBatch); + + /// Retrieves the next `RecordBatch` from the transformer. + /// Returns `None` if all batches have been produced. + /// The boolean flag indicates whether the batch is the last one. 
+ fn next(&mut self) -> Option<(RecordBatch, bool)>; +} + +#[derive(Debug, Clone)] +/// A batch transformer that does nothing. +pub(crate) struct NoopBatchTransformer { + /// RecordBatch to be processed + batch: Option, +} + +impl NoopBatchTransformer { + pub fn new() -> Self { + Self { batch: None } + } +} + +impl BatchTransformer for NoopBatchTransformer { + fn set_batch(&mut self, batch: RecordBatch) { + self.batch = Some(batch); + } + + fn next(&mut self) -> Option<(RecordBatch, bool)> { + self.batch.take().map(|batch| (batch, true)) + } +} + +#[derive(Debug, Clone)] +/// Splits large batches into smaller batches with a maximum number of rows. +pub(crate) struct BatchSplitter { + /// RecordBatch to be split + batch: Option, + /// Maximum number of rows in a split batch + batch_size: usize, + /// Current row index + row_index: usize, +} + +impl BatchSplitter { + /// Creates a new `BatchSplitter` with the specified batch size. + pub(crate) fn new(batch_size: usize) -> Self { + Self { + batch: None, + batch_size, + row_index: 0, } } } +impl BatchTransformer for BatchSplitter { + fn set_batch(&mut self, batch: RecordBatch) { + self.batch = Some(batch); + self.row_index = 0; + } + + fn next(&mut self) -> Option<(RecordBatch, bool)> { + let Some(batch) = &self.batch else { + return None; + }; + + let remaining_rows = batch.num_rows() - self.row_index; + let rows_to_slice = remaining_rows.min(self.batch_size); + let sliced_batch = batch.slice(self.row_index, rows_to_slice); + self.row_index += rows_to_slice; + + let mut last = false; + if self.row_index >= batch.num_rows() { + self.batch = None; + last = true; + } + + Some((sliced_batch, last)) + } +} + +/// When the order of the join inputs are changed, the output order of columns +/// must remain the same. +/// +/// Joins output columns from their left input followed by their right input. +/// Thus if the inputs are reordered, the output columns must be reordered to +/// match the original order. +pub(crate) fn reorder_output_after_swap( + plan: Arc, + left_schema: &Schema, + right_schema: &Schema, +) -> Result> { + let proj = ProjectionExec::try_new( + swap_reverting_projection(left_schema, right_schema), + plan, + )?; + Ok(Arc::new(proj)) +} + +/// When the order of the join is changed, the output order of columns must +/// remain the same. +/// +/// Returns the expressions that will allow to swap back the values from the +/// original left as the first columns and those on the right next. +fn swap_reverting_projection( + left_schema: &Schema, + right_schema: &Schema, +) -> Vec<(Arc, String)> { + let right_cols = right_schema.fields().iter().enumerate().map(|(i, f)| { + ( + Arc::new(Column::new(f.name(), i)) as Arc, + f.name().to_owned(), + ) + }); + let right_len = right_cols.len(); + let left_cols = left_schema.fields().iter().enumerate().map(|(i, f)| { + ( + Arc::new(Column::new(f.name(), right_len + i)) as Arc, + f.name().to_owned(), + ) + }); + + left_cols.chain(right_cols).collect() +} + +/// This function swaps the given join's projection. 
+pub(super) fn swap_join_projection( + left_schema_len: usize, + right_schema_len: usize, + projection: Option<&Vec>, + join_type: &JoinType, +) -> Option> { + match join_type { + // For Anti/Semi join types, projection should remain unmodified, + // since these joins output schema remains the same after swap + JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::RightAnti + | JoinType::RightSemi => projection.cloned(), + + _ => projection.map(|p| { + p.iter() + .map(|i| { + // If the index is less than the left schema length, it is from + // the left schema, so we add the right schema length to it. + // Otherwise, it is from the right schema, so we subtract the left + // schema length from it. + if *i < left_schema_len { + *i + right_schema_len + } else { + *i - left_schema_len + } + }) + .collect() + }), + } +} + #[cfg(test)] mod tests { use std::pin::Pin; @@ -1643,11 +1826,13 @@ mod tests { use arrow::datatypes::{DataType, Fields}; use arrow::error::{ArrowError, Result as ArrowResult}; + use arrow_array::Int32Array; use arrow_schema::SortOptions; - use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; + use rstest::rstest; + fn check( left: &[Column], right: &[Column], @@ -1821,13 +2006,13 @@ mod tests { ) -> Statistics { Statistics { num_rows: if is_exact { - num_rows.map(Precision::Exact) + num_rows.map(Exact) } else { - num_rows.map(Precision::Inexact) + num_rows.map(Inexact) } - .unwrap_or(Precision::Absent), + .unwrap_or(Absent), column_statistics: column_stats, - total_byte_size: Precision::Absent, + total_byte_size: Absent, } } @@ -2073,17 +2258,17 @@ mod tests { assert_eq!( estimate_inner_join_cardinality( Statistics { - num_rows: Precision::Inexact(400), - total_byte_size: Precision::Absent, + num_rows: Inexact(400), + total_byte_size: Absent, column_statistics: left_col_stats, }, Statistics { - num_rows: Precision::Inexact(400), - total_byte_size: Precision::Absent, + num_rows: Inexact(400), + total_byte_size: Absent, column_statistics: right_col_stats, }, ), - Some(Precision::Inexact((400 * 400) / 200)) + Some(Inexact((400 * 400) / 200)) ); Ok(()) } @@ -2091,33 +2276,33 @@ mod tests { #[test] fn test_inner_join_cardinality_decimal_range() -> Result<()> { let left_col_stats = vec![ColumnStatistics { - distinct_count: Precision::Absent, - min_value: Precision::Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)), - max_value: Precision::Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)), + distinct_count: Absent, + min_value: Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)), + max_value: Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)), ..Default::default() }]; let right_col_stats = vec![ColumnStatistics { - distinct_count: Precision::Absent, - min_value: Precision::Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)), - max_value: Precision::Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)), + distinct_count: Absent, + min_value: Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)), + max_value: Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)), ..Default::default() }]; assert_eq!( estimate_inner_join_cardinality( Statistics { - num_rows: Precision::Inexact(100), - total_byte_size: Precision::Absent, + num_rows: Inexact(100), + total_byte_size: Absent, column_statistics: left_col_stats, }, Statistics { - num_rows: Precision::Inexact(100), - total_byte_size: Precision::Absent, + num_rows: Inexact(100), + total_byte_size: Absent, column_statistics: right_col_stats, }, ), - 
Some(Precision::Inexact(100)) + Some(Inexact(100)) ); Ok(()) } @@ -2455,7 +2640,7 @@ mod tests { #[test] fn test_calculate_join_output_ordering() -> Result<()> { let options = SortOptions::default(); - let left_ordering = vec![ + let left_ordering = LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options, @@ -2468,8 +2653,8 @@ mod tests { expr: Arc::new(Column::new("d", 3)), options, }, - ]; - let right_ordering = vec![ + ]); + let right_ordering = LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("z", 2)), options, @@ -2478,7 +2663,7 @@ mod tests { expr: Arc::new(Column::new("y", 1)), options, }, - ]; + ]); let join_type = JoinType::Inner; let on_columns = [( Arc::new(Column::new("b", 1)) as _, @@ -2489,7 +2674,7 @@ mod tests { let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)]; let expected = [ - Some(vec![ + Some(LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options, @@ -2510,8 +2695,8 @@ mod tests { expr: Arc::new(Column::new("y", 6)), options, }, - ]), - Some(vec![ + ])), + Some(LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("z", 7)), options, @@ -2532,7 +2717,7 @@ mod tests { expr: Arc::new(Column::new("d", 3)), options, }, - ]), + ])), ]; for (i, (maintains_input_order, probe_side)) in @@ -2540,8 +2725,8 @@ mod tests { { assert_eq!( calculate_join_output_ordering( - &left_ordering, - &right_ordering, + left_ordering.as_ref(), + right_ordering.as_ref(), join_type, &on_columns, left_columns_len, @@ -2554,4 +2739,84 @@ mod tests { Ok(()) } + + fn create_test_batch(num_rows: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let data = Arc::new(Int32Array::from_iter_values(0..num_rows as i32)); + RecordBatch::try_new(schema, vec![data]).unwrap() + } + + fn assert_split_batches( + batches: Vec<(RecordBatch, bool)>, + batch_size: usize, + num_rows: usize, + ) { + let mut row_count = 0; + for (batch, last) in batches.into_iter() { + assert_eq!(batch.num_rows(), (num_rows - row_count).min(batch_size)); + let column = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + assert_eq!(column.value(i), i as i32 + row_count as i32); + } + row_count += batch.num_rows(); + assert_eq!(last, row_count == num_rows); + } + } + + #[rstest] + #[test] + fn test_batch_splitter( + #[values(1, 3, 11)] batch_size: usize, + #[values(1, 6, 50)] num_rows: usize, + ) { + let mut splitter = BatchSplitter::new(batch_size); + splitter.set_batch(create_test_batch(num_rows)); + + let mut batches = Vec::with_capacity(num_rows.div_ceil(batch_size)); + while let Some(batch) = splitter.next() { + batches.push(batch); + } + + assert!(splitter.next().is_none()); + assert_split_batches(batches, batch_size, num_rows); + } + + #[tokio::test] + async fn test_swap_reverting_projection() { + let left_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]); + + let proj = swap_reverting_projection(&left_schema, &right_schema); + + assert_eq!(proj.len(), 3); + + let (col, name) = &proj[0]; + assert_eq!(name, "a"); + assert_col_expr(col, "a", 1); + + let (col, name) = &proj[1]; + assert_eq!(name, "b"); + assert_col_expr(col, "b", 2); + + let (col, name) = &proj[2]; + assert_eq!(name, "c"); + assert_col_expr(col, "c", 0); + } + + fn assert_col_expr(expr: &Arc, name: &str, 
index: usize) { + let col = expr + .as_any() + .downcast_ref::() + .expect("Projection items should be Column expression"); + assert_eq!(col.name(), name); + assert_eq!(col.index(), index); + } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 9b41ebed763fc..c30d5f5b085c4 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -35,11 +35,10 @@ pub use datafusion_physical_expr::{ }; pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; -pub(crate) use crate::execution_plan::execution_mode_from_children; pub use crate::execution_plan::{ collect, collect_partitioned, displayable, execute_input_stream, execute_stream, execute_stream_partitioned, get_plan_string, with_new_children_if_necessary, - ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + ExecutionPlan, ExecutionPlanProperties, PlanProperties, }; pub use crate::metrics::Metric; pub use crate::ordering::InputOrderMode; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 360e942226d24..9665a09e42c98 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -24,9 +24,10 @@ use std::task::{Context, Poll}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{ - DisplayAs, ExecutionMode, ExecutionPlanProperties, PlanProperties, RecordBatchStream, + DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::execution_plan::{Boundedness, CardinalityEffect}; use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning}; use arrow::datatypes::SchemaRef; @@ -38,7 +39,7 @@ use futures::stream::{Stream, StreamExt}; use log::trace; /// Limit execution plan -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct GlobalLimitExec { /// Input execution plan input: Arc, @@ -85,7 +86,9 @@ impl GlobalLimitExec { PlanProperties::new( input.equivalence_properties().clone(), // Equivalence Properties Partitioning::UnknownPartitioning(1), // Output Partitioning - ExecutionMode::Bounded, // Execution Mode + input.pipeline_behavior(), + // Limit operations are always bounded since they output a finite number of rows + Boundedness::Bounded, ) } } @@ -241,7 +244,9 @@ impl LocalLimitExec { PlanProperties::new( input.equivalence_properties().clone(), // Equivalence Properties input.output_partitioning().clone(), // Output Partitioning - ExecutionMode::Bounded, // Execution Mode + input.pipeline_behavior(), + // Limit operations are always bounded since they output a finite number of rows + Boundedness::Bounded, ) } } @@ -336,6 +341,10 @@ impl ExecutionPlan for LocalLimitExec { fn supports_limit_pushdown(&self) -> bool { true } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual + } } /// A Limit stream skips `skip` rows, and then fetch up to `fetch` rows. 
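// Illustrative sketch (not part of this patch): the per-batch logic that
// `LimitStream` applies once `skip` rows have been discarded, mirroring
// `LimitStream::stream_limit` shown in the hunk below. The free function name
// `apply_fetch` is invented for this example; `fetch` is the number of rows
// still wanted from the stream.
use arrow::record_batch::RecordBatch;

fn apply_fetch(fetch: &mut usize, batch: RecordBatch) -> Option<RecordBatch> {
    if *fetch == 0 {
        // Limit already satisfied; the stream clears its input so it can be dropped early.
        None
    } else if batch.num_rows() < *fetch {
        // The whole batch fits under the remaining limit; pass it through and keep counting.
        *fetch -= batch.num_rows();
        Some(batch)
    } else {
        // Only part of the batch is needed; slice it and mark the limit as satisfied.
        let rows = *fetch;
        *fetch = 0;
        Some(batch.slice(0, rows))
    }
}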
@@ -393,7 +402,7 @@ impl LimitStream { if batch.num_rows() > 0 { break poll; } else { - // continue to poll input stream + // Continue to poll input stream } } Poll::Ready(Some(Err(_e))) => break poll, @@ -403,12 +412,12 @@ impl LimitStream { } } - /// fetches from the batch + /// Fetches from the batch fn stream_limit(&mut self, batch: RecordBatch) -> Option { // records time on drop let _timer = self.baseline_metrics.elapsed_compute().timer(); if self.fetch == 0 { - self.input = None; // clear input so it can be dropped early + self.input = None; // Clear input so it can be dropped early None } else if batch.num_rows() < self.fetch { // @@ -417,7 +426,7 @@ impl LimitStream { } else if batch.num_rows() >= self.fetch { let batch_rows = self.fetch; self.fetch = 0; - self.input = None; // clear input so it can be dropped early + self.input = None; // Clear input so it can be dropped early // It is guaranteed that batch_rows is <= batch.num_rows Some(batch.slice(0, batch_rows)) @@ -448,7 +457,7 @@ impl Stream for LimitStream { other => other, }) } - // input has been cleared + // Input has been cleared None => Poll::Ready(None), }; @@ -468,7 +477,7 @@ mod tests { use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common::collect; - use crate::{common, test}; + use crate::test; use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use arrow_array::RecordBatchOptions; @@ -484,17 +493,17 @@ mod tests { let num_partitions = 4; let csv = test::scan_partitioned(num_partitions); - // input should have 4 partitions + // Input should have 4 partitions assert_eq!(csv.output_partitioning().partition_count(), num_partitions); let limit = GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7)); - // the result should contain 4 batches (one per input partition) + // The result should contain 4 batches (one per input partition) let iter = limit.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; + let batches = collect(iter).await?; - // there should be a total of 100 rows + // There should be a total of 100 rows let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum(); assert_eq!(row_count, 7); @@ -515,7 +524,7 @@ mod tests { let index = input.index(); assert_eq!(index.value(), 0); - // limit of six needs to consume the entire first record batch + // Limit of six needs to consume the entire first record batch // (5 rows) and 1 row from the second (1 row) let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let limit_stream = @@ -545,7 +554,7 @@ mod tests { let index = input.index(); assert_eq!(index.value(), 0); - // limit of six needs to consume the entire first record batch + // Limit of six needs to consume the entire first record batch // (6 rows) and stop immediately let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let limit_stream = @@ -575,7 +584,7 @@ mod tests { let index = input.index(); assert_eq!(index.value(), 0); - // limit of six needs to consume the entire first record batch + // Limit of six needs to consume the entire first record batch // (6 rows) and stop immediately let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let limit_stream = @@ -593,7 +602,7 @@ mod tests { Ok(()) } - // test cases for "skip" + // Test cases for "skip" async fn skip_and_fetch(skip: usize, fetch: Option) -> Result { let task_ctx = Arc::new(TaskContext::default()); @@ -606,9 +615,9 @@ mod tests { let offset = 
GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch); - // the result should contain 4 batches (one per input partition) + // The result should contain 4 batches (one per input partition) let iter = offset.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; + let batches = collect(iter).await?; Ok(batches.iter().map(|batch| batch.num_rows()).sum()) } @@ -628,7 +637,7 @@ mod tests { #[tokio::test] async fn skip_3_fetch_none() -> Result<()> { - // there are total of 400 rows, we skipped 3 rows (offset = 3) + // There are total of 400 rows, we skipped 3 rows (offset = 3) let row_count = skip_and_fetch(3, None).await?; assert_eq!(row_count, 397); Ok(()) @@ -636,7 +645,7 @@ mod tests { #[tokio::test] async fn skip_3_fetch_10_stats() -> Result<()> { - // there are total of 100 rows, we skipped 3 rows (offset = 3) + // There are total of 100 rows, we skipped 3 rows (offset = 3) let row_count = skip_and_fetch(3, Some(10)).await?; assert_eq!(row_count, 10); Ok(()) @@ -651,7 +660,7 @@ mod tests { #[tokio::test] async fn skip_400_fetch_1() -> Result<()> { - // there are a total of 400 rows + // There are a total of 400 rows let row_count = skip_and_fetch(400, Some(1)).await?; assert_eq!(row_count, 0); Ok(()) @@ -659,7 +668,7 @@ mod tests { #[tokio::test] async fn skip_401_fetch_none() -> Result<()> { - // there are total of 400 rows, we skipped 401 rows (offset = 3) + // There are total of 400 rows, we skipped 401 rows (offset = 3) let row_count = skip_and_fetch(401, None).await?; assert_eq!(row_count, 0); Ok(()) diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 3aa445d295cb0..58cb842ce91d7 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -17,27 +17,36 @@ //! 
Execution plan for reading in-memory batches of data +use parking_lot::RwLock; use std::any::Any; use std::fmt; use std::sync::Arc; use std::task::{Context, Poll}; -use super::expressions::PhysicalSortExpr; use super::{ - common, DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, + common, ColumnarValue, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream, + Statistics, }; +use crate::execution_plan::{Boundedness, EmissionType}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, project_schema, Result}; +use arrow_array::RecordBatchOptions; +use arrow_schema::Schema; +use datafusion_common::{internal_err, plan_err, project_schema, Constraints, Result}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; +use datafusion_expr::Scalar; use futures::Stream; /// Execution plan for reading in-memory batches of data +#[derive(Clone)] pub struct MemoryExec { /// The partitions to query partitions: Vec>, @@ -66,11 +75,7 @@ impl fmt::Debug for MemoryExec { } impl DisplayAs for MemoryExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let partition_sizes: Vec<_> = @@ -80,21 +85,29 @@ impl DisplayAs for MemoryExec { .sort_information .first() .map(|output_ordering| { - format!( - ", output_ordering={}", - PhysicalSortExpr::format_list(output_ordering) - ) + format!(", output_ordering={}", output_ordering) }) .unwrap_or_default(); + let constraints = self.cache.equivalence_properties().constraints(); + let constraints = if constraints.is_empty() { + String::new() + } else { + format!(", {}", constraints) + }; + if self.show_sizes { write!( f, - "MemoryExec: partitions={}, partition_sizes={partition_sizes:?}{output_ordering}", + "MemoryExec: partitions={}, partition_sizes={partition_sizes:?}{output_ordering}{constraints}", partition_sizes.len(), ) } else { - write!(f, "MemoryExec: partitions={}", partition_sizes.len(),) + write!( + f, + "MemoryExec: partitions={}{output_ordering}{constraints}", + partition_sizes.len(), + ) } } } @@ -116,7 +129,7 @@ impl ExecutionPlan for MemoryExec { } fn children(&self) -> Vec<&Arc> { - // this is a leaf node and has no children + // This is a leaf node and has no children vec![] } @@ -163,8 +176,13 @@ impl MemoryExec { projection: Option>, ) -> Result { let projected_schema = project_schema(&schema, projection.as_ref())?; - let cache = - Self::compute_properties(Arc::clone(&projected_schema), &[], partitions); + let constraints = Constraints::empty(); + let cache = Self::compute_properties( + Arc::clone(&projected_schema), + &[], + constraints, + partitions, + ); Ok(Self { partitions: partitions.to_vec(), schema, @@ -176,20 +194,137 @@ impl MemoryExec { }) } - /// set `show_sizes` to determine whether to display partition sizes + /// Create a new execution plan from a list of constant values (`ValuesExec`) + pub fn try_new_as_values( + schema: 
SchemaRef, + data: Vec>>, + ) -> Result { + if data.is_empty() { + return plan_err!("Values list cannot be empty"); + } + + let n_row = data.len(); + let n_col = schema.fields().len(); + + // We have this single row batch as a placeholder to satisfy evaluation argument + // and generate a single output row + let placeholder_schema = Arc::new(Schema::empty()); + let placeholder_batch = RecordBatch::try_new_with_options( + Arc::clone(&placeholder_schema), + vec![], + &RecordBatchOptions::new().with_row_count(Some(1)), + )?; + + // Evaluate each column + let arrays = (0..n_col) + .map(|j| { + (0..n_row) + .map(|i| { + let expr = &data[i][j]; + let result = expr.evaluate(&placeholder_batch)?; + + match result { + ColumnarValue::Scalar(scalar) => Ok(scalar), + ColumnarValue::Array(array) if array.len() == 1 => { + Scalar::try_from_array(&array, 0) + } + ColumnarValue::Array(_) => { + plan_err!("Cannot have array values in a values list") + } + } + }) + .collect::>>() + .and_then(Scalar::iter_to_array) + }) + .collect::>>()?; + + let batch = RecordBatch::try_new_with_options( + Arc::clone(&schema), + arrays, + &RecordBatchOptions::new().with_row_count(Some(n_row)), + )?; + + let partitions = vec![batch]; + Self::try_new_from_batches(Arc::clone(&schema), partitions) + } + + /// Create a new plan using the provided schema and batches. + /// + /// Errors if any of the batches don't match the provided schema, or if no + /// batches are provided. + pub fn try_new_from_batches( + schema: SchemaRef, + batches: Vec, + ) -> Result { + if batches.is_empty() { + return plan_err!("Values list cannot be empty"); + } + + for batch in &batches { + let batch_schema = batch.schema(); + if batch_schema != schema { + return plan_err!( + "Batch has invalid schema. Expected: {}, got: {}", + schema, + batch_schema + ); + } + } + + let partitions = vec![batches]; + let cache = Self::compute_properties( + Arc::clone(&schema), + &[], + Constraints::empty(), + &partitions, + ); + Ok(Self { + partitions, + schema: Arc::clone(&schema), + projected_schema: Arc::clone(&schema), + projection: None, + sort_information: vec![], + cache, + show_sizes: true, + }) + } + + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.cache = self.cache.with_constraints(constraints); + self + } + + /// Set `show_sizes` to determine whether to display partition sizes pub fn with_show_sizes(mut self, show_sizes: bool) -> Self { self.show_sizes = show_sizes; self } + /// Ref to constraints + pub fn constraints(&self) -> &Constraints { + self.cache.equivalence_properties().constraints() + } + + /// Ref to partitions pub fn partitions(&self) -> &[Vec] { &self.partitions } + /// Ref to projection pub fn projection(&self) -> &Option> { &self.projection } + /// Show sizes + pub fn show_sizes(&self) -> bool { + self.show_sizes + } + + /// Ref to sort information + pub fn sort_information(&self) -> &[LexOrdering] { + &self.sort_information + } + /// A memory table can be ordered by multiple expressions simultaneously. /// [`EquivalenceProperties`] keeps track of expressions that describe the /// global ordering of the schema. These columns are not necessarily same; e.g. @@ -206,18 +341,66 @@ impl MemoryExec { /// where both `a ASC` and `b DESC` can describe the table ordering. With /// [`EquivalenceProperties`], we can keep track of these equivalences /// and treat `a ASC` and `b DESC` as the same ordering requirement. 
- pub fn with_sort_information(mut self, sort_information: Vec) -> Self { - self.sort_information = sort_information; + /// + /// Note that if there is an internal projection, that projection will be + /// also applied to the given `sort_information`. + pub fn try_with_sort_information( + mut self, + mut sort_information: Vec, + ) -> Result { + // All sort expressions must refer to the original schema + let fields = self.schema.fields(); + let ambiguous_column = sort_information + .iter() + .flat_map(|ordering| ordering.clone()) + .flat_map(|expr| collect_columns(&expr.expr)) + .find(|col| { + fields + .get(col.index()) + .map(|field| field.name() != col.name()) + .unwrap_or(true) + }); + if let Some(col) = ambiguous_column { + return internal_err!( + "Column {:?} is not found in the original schema of the MemoryExec", + col + ); + } + + // If there is a projection on the source, we also need to project orderings + if let Some(projection) = &self.projection { + let base_eqp = EquivalenceProperties::new_with_orderings( + self.original_schema(), + &sort_information, + ); + let proj_exprs = projection + .iter() + .map(|idx| { + let base_schema = self.original_schema(); + let name = base_schema.field(*idx).name(); + (Arc::new(Column::new(name, *idx)) as _, name.to_string()) + }) + .collect::>(); + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?; + sort_information = base_eqp + .project(&projection_mapping, self.schema()) + .into_oeq_class() + .into_inner(); + } + self.sort_information = sort_information; // We need to update equivalence properties when updating sort information. let eq_properties = EquivalenceProperties::new_with_orderings( self.schema(), &self.sort_information, ); self.cache = self.cache.with_eq_properties(eq_properties); - self + + Ok(self) } + /// Arc clone of ref to original schema pub fn original_schema(&self) -> SchemaRef { Arc::clone(&self.schema) } @@ -226,13 +409,15 @@ impl MemoryExec { fn compute_properties( schema: SchemaRef, orderings: &[LexOrdering], + constraints: Constraints, partitions: &[Vec], ) -> PlanProperties { - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings); PlanProperties::new( - eq_properties, // Equivalence Properties - Partitioning::UnknownPartitioning(partitions.len()), // Output Partitioning - ExecutionMode::Bounded, // Execution Mode + EquivalenceProperties::new_with_orderings(schema, orderings) + .with_constraints(constraints), + Partitioning::UnknownPartitioning(partitions.len()), + EmissionType::Incremental, + Boundedness::Bounded, ) } } @@ -309,8 +494,166 @@ impl RecordBatchStream for MemoryStream { } } +pub trait LazyBatchGenerator: Send + Sync + fmt::Debug + fmt::Display { + /// Generate the next batch, return `None` when no more batches are available + fn generate_next_batch(&mut self) -> Result>; +} + +/// Execution plan for lazy in-memory batches of data +/// +/// This plan generates output batches lazily, it doesn't have to buffer all batches +/// in memory up front (compared to `MemoryExec`), thus consuming constant memory. 
+pub struct LazyMemoryExec { + /// Schema representing the data + schema: SchemaRef, + /// Functions to generate batches for each partition + batch_generators: Vec>>, + /// Plan properties cache storing equivalence properties, partitioning, and execution mode + cache: PlanProperties, +} + +impl LazyMemoryExec { + /// Create a new lazy memory execution plan + pub fn try_new( + schema: SchemaRef, + generators: Vec>>, + ) -> Result { + let cache = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + Partitioning::RoundRobinBatch(generators.len()), + EmissionType::Incremental, + Boundedness::Bounded, + ); + Ok(Self { + schema, + batch_generators: generators, + cache, + }) + } +} + +impl fmt::Debug for LazyMemoryExec { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("LazyMemoryExec") + .field("schema", &self.schema) + .field("batch_generators", &self.batch_generators) + .finish() + } +} + +impl DisplayAs for LazyMemoryExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "LazyMemoryExec: partitions={}, batch_generators=[{}]", + self.batch_generators.len(), + self.batch_generators + .iter() + .map(|g| g.read().to_string()) + .collect::>() + .join(", ") + ) + } + } + } +} + +impl ExecutionPlan for LazyMemoryExec { + fn name(&self) -> &'static str { + "LazyMemoryExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.is_empty() { + Ok(self) + } else { + internal_err!("Children cannot be replaced in LazyMemoryExec") + } + } + + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> Result { + if partition >= self.batch_generators.len() { + return internal_err!( + "Invalid partition {} for LazyMemoryExec with {} partitions", + partition, + self.batch_generators.len() + ); + } + + Ok(Box::pin(LazyMemoryStream { + schema: Arc::clone(&self.schema), + generator: Arc::clone(&self.batch_generators[partition]), + })) + } + + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema)) + } +} + +/// Stream that generates record batches on demand +pub struct LazyMemoryStream { + schema: SchemaRef, + /// Generator to produce batches + /// + /// Note: Idiomatically, DataFusion uses plan-time parallelism - each stream + /// should have a unique `LazyBatchGenerator`. Use RepartitionExec or + /// construct multiple `LazyMemoryStream`s during planning to enable + /// parallel execution. + /// Sharing generators between streams should be used with caution. 
+ generator: Arc>, +} + +impl Stream for LazyMemoryStream { + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + let batch = self.generator.write().generate_next_batch(); + + match batch { + Ok(Some(batch)) => Poll::Ready(Some(Ok(batch))), + Ok(None) => Poll::Ready(None), + Err(e) => Poll::Ready(Some(Err(e))), + } + } +} + +impl RecordBatchStream for LazyMemoryStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + #[cfg(test)] -mod tests { +mod memory_exec_tests { use std::sync::Arc; use crate::memory::MemoryExec; @@ -319,6 +662,7 @@ mod tests { use arrow_schema::{DataType, Field, Schema, SortOptions}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr_common::sort_expr::LexOrdering; #[test] fn test_memory_order_eq() -> datafusion_common::Result<()> { @@ -327,7 +671,7 @@ mod tests { Field::new("b", DataType::Int64, false), Field::new("c", DataType::Int64, false), ])); - let sort1 = vec![ + let sort1 = LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), @@ -336,22 +680,22 @@ mod tests { expr: col("b", &schema)?, options: SortOptions::default(), }, - ]; - let sort2 = vec![PhysicalSortExpr { + ]); + let sort2 = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema)?, options: SortOptions::default(), - }]; - let mut expected_output_order = vec![]; + }]); + let mut expected_output_order = LexOrdering::default(); expected_output_order.extend(sort1.clone()); expected_output_order.extend(sort2.clone()); let sort_information = vec![sort1.clone(), sort2.clone()]; let mem_exec = MemoryExec::try_new(&[vec![]], schema, None)? - .with_sort_information(sort_information); + .try_with_sort_information(sort_information)?; assert_eq!( mem_exec.properties().output_ordering().unwrap(), - expected_output_order + &expected_output_order ); let eq_properties = mem_exec.properties().equivalence_properties(); assert!(eq_properties.oeq_class().contains(&sort1)); @@ -359,3 +703,217 @@ mod tests { Ok(()) } } + +#[cfg(test)] +mod lazy_memory_tests { + use super::*; + use arrow::array::Int64Array; + use arrow::datatypes::{DataType, Field, Schema}; + use futures::StreamExt; + + #[derive(Debug, Clone)] + struct TestGenerator { + counter: i64, + max_batches: i64, + batch_size: usize, + schema: SchemaRef, + } + + impl fmt::Display for TestGenerator { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "TestGenerator: counter={}, max_batches={}, batch_size={}", + self.counter, self.max_batches, self.batch_size + ) + } + } + + impl LazyBatchGenerator for TestGenerator { + fn generate_next_batch(&mut self) -> Result> { + if self.counter >= self.max_batches { + return Ok(None); + } + + let array = Int64Array::from_iter_values( + (self.counter * self.batch_size as i64) + ..(self.counter * self.batch_size as i64 + self.batch_size as i64), + ); + self.counter += 1; + Ok(Some(RecordBatch::try_new( + Arc::clone(&self.schema), + vec![Arc::new(array)], + )?)) + } + } + + #[tokio::test] + async fn test_lazy_memory_exec() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let generator = TestGenerator { + counter: 0, + max_batches: 3, + batch_size: 2, + schema: Arc::clone(&schema), + }; + + let exec = + LazyMemoryExec::try_new(schema, vec![Arc::new(RwLock::new(generator))])?; + + // Test schema + assert_eq!(exec.schema().fields().len(), 1); + 
assert_eq!(exec.schema().field(0).name(), "a"); + + // Test execution + let stream = exec.execute(0, Arc::new(TaskContext::default()))?; + let batches: Vec<_> = stream.collect::>().await; + + assert_eq!(batches.len(), 3); + + // Verify batch contents + let batch0 = batches[0].as_ref().unwrap(); + let array0 = batch0 + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(array0.values(), &[0, 1]); + + let batch1 = batches[1].as_ref().unwrap(); + let array1 = batch1 + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(array1.values(), &[2, 3]); + + let batch2 = batches[2].as_ref().unwrap(); + let array2 = batch2 + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(array2.values(), &[4, 5]); + + Ok(()) + } + + #[tokio::test] + async fn test_lazy_memory_exec_invalid_partition() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let generator = TestGenerator { + counter: 0, + max_batches: 1, + batch_size: 1, + schema: Arc::clone(&schema), + }; + + let exec = + LazyMemoryExec::try_new(schema, vec![Arc::new(RwLock::new(generator))])?; + + // Test invalid partition + let result = exec.execute(1, Arc::new(TaskContext::default())); + + // partition is 0-indexed, so there only should be partition 0 + assert!(matches!( + result, + Err(e) if e.to_string().contains("Invalid partition 1 for LazyMemoryExec with 1 partitions") + )); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::lit; + use crate::test::{self, make_partition}; + + use arrow_schema::{DataType, Field}; + use datafusion_common::stats::{ColumnStatistics, Precision}; + use datafusion_common::ScalarValue; + + #[tokio::test] + async fn values_empty_case() -> Result<()> { + let schema = test::aggr_test_schema(); + let empty = MemoryExec::try_new_as_values(schema, vec![]); + assert!(empty.is_err()); + Ok(()) + } + + #[test] + fn new_exec_with_batches() { + let batch = make_partition(7); + let schema = batch.schema(); + let batches = vec![batch.clone(), batch]; + let _exec = MemoryExec::try_new_from_batches(schema, batches).unwrap(); + } + + #[test] + fn new_exec_with_batches_empty() { + let batch = make_partition(7); + let schema = batch.schema(); + let _ = MemoryExec::try_new_from_batches(schema, Vec::new()).unwrap_err(); + } + + #[test] + fn new_exec_with_batches_invalid_schema() { + let batch = make_partition(7); + let batches = vec![batch.clone(), batch]; + + let invalid_schema = Arc::new(Schema::new(vec![ + Field::new("col0", DataType::UInt32, false), + Field::new("col1", DataType::Utf8, false), + ])); + let _ = MemoryExec::try_new_from_batches(invalid_schema, batches).unwrap_err(); + } + + // Test issue: https://github.com/apache/datafusion/issues/8763 + #[test] + fn new_exec_with_non_nullable_schema() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col0", + DataType::UInt32, + false, + )])); + let _ = MemoryExec::try_new_as_values(Arc::clone(&schema), vec![vec![lit(1u32)]]) + .unwrap(); + // Test that a null value is rejected + let _ = MemoryExec::try_new_as_values( + schema, + vec![vec![lit(ScalarValue::UInt32(None))]], + ) + .unwrap_err(); + } + + #[test] + fn values_stats_with_nulls_only() -> Result<()> { + let data = vec![ + vec![lit(ScalarValue::Null)], + vec![lit(ScalarValue::Null)], + vec![lit(ScalarValue::Null)], + ]; + let rows = data.len(); + let values = MemoryExec::try_new_as_values( + Arc::new(Schema::new(vec![Field::new("col0", DataType::Null, true)])), + data, + )?; + + 
assert_eq!( + values.statistics()?, + Statistics { + num_rows: Precision::Exact(rows), + total_byte_size: Precision::Exact(8), // not important + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(rows), // there are only nulls + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + },], + } + ); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/metrics/builder.rs b/datafusion/physical-plan/src/metrics/builder.rs index 2037ddb70c2d0..dbda0a310ce52 100644 --- a/datafusion/physical-plan/src/metrics/builder.rs +++ b/datafusion/physical-plan/src/metrics/builder.rs @@ -50,7 +50,7 @@ pub struct MetricBuilder<'a> { /// optional partition number partition: Option, - /// arbitrary name=value pairs identifiying this metric + /// arbitrary name=value pairs identifying this metric labels: Vec